OpenShot Library | libopenshot  0.5.0
sort.cpp
Go to the documentation of this file.
1 // © OpenShot Studios, LLC
2 //
3 // SPDX-License-Identifier: LGPL-3.0-or-later
4 
5 #include "sort.hpp"
6 
7 using namespace std;
8 
9 // Constructor
10 SortTracker::SortTracker(int max_age, int min_hits, int max_missed, double min_iou, double nms_iou_thresh, double min_conf)
11 {
12  _min_hits = min_hits;
13  _max_age = max_age;
14  _max_missed = max_missed;
15  _min_iou = min_iou;
16  _nms_iou_thresh = nms_iou_thresh;
17  _min_conf = min_conf;
18  _next_id = 0;
19  alive_tracker = true;
20 }
21 
22 // Computes IOU between two bounding boxes
23 double SortTracker::GetIOU(cv::Rect_<float> bb_test, cv::Rect_<float> bb_gt)
24 {
25  float in = (bb_test & bb_gt).area();
26  float un = bb_test.area() + bb_gt.area() - in;
27 
28  if (un < DBL_EPSILON)
29  return 0;
30 
31  return (double)(in / un);
32 }
33 
34 // Computes centroid distance between two bounding boxes
36  cv::Rect_<float> bb_test,
37  cv::Rect_<float> bb_gt)
38 {
39  float bb_test_centroid_x = (bb_test.x + bb_test.width / 2);
40  float bb_test_centroid_y = (bb_test.y + bb_test.height / 2);
41 
42  float bb_gt_centroid_x = (bb_gt.x + bb_gt.width / 2);
43  float bb_gt_centroid_y = (bb_gt.y + bb_gt.height / 2);
44 
45  double distance = (double)sqrt(pow(bb_gt_centroid_x - bb_test_centroid_x, 2) + pow(bb_gt_centroid_y - bb_test_centroid_y, 2));
46 
47  return distance;
48 }
49 
50 // Function to apply NMS on detections
51 void apply_nms(vector<TrackingBox>& detections, double nms_iou_thresh) {
52  if (detections.empty()) return;
53 
54  // Sort detections by confidence descending
55  std::sort(detections.begin(), detections.end(), [](const TrackingBox& a, const TrackingBox& b) {
56  return a.confidence > b.confidence;
57  });
58 
59  vector<bool> suppressed(detections.size(), false);
60 
61  for (size_t i = 0; i < detections.size(); ++i) {
62  if (suppressed[i]) continue;
63 
64  for (size_t j = i + 1; j < detections.size(); ++j) {
65  if (suppressed[j]) continue;
66 
67  if (detections[i].classId == detections[j].classId &&
68  SortTracker::GetIOU(detections[i].box, detections[j].box) > nms_iou_thresh) {
69  suppressed[j] = true;
70  }
71  }
72  }
73 
74  // Remove suppressed detections
75  vector<TrackingBox> filtered;
76  for (size_t i = 0; i < detections.size(); ++i) {
77  if (!suppressed[i]) {
78  filtered.push_back(detections[i]);
79  }
80  }
81  detections = filtered;
82 }
83 
84 void SortTracker::update(vector<cv::Rect> detections_cv, int frame_count, double image_diagonal, std::vector<float> confidences, std::vector<int> classIds)
85 {
86  vector<TrackingBox> detections;
87  if (trackers.size() == 0) // the first frame met
88  {
89  alive_tracker = false;
90  // initialize kalman trackers using first detections.
91  for (unsigned int i = 0; i < detections_cv.size(); i++)
92  {
93  if (confidences[i] < _min_conf) continue; // filter low conf
94 
95  TrackingBox tb;
96 
97  tb.box = cv::Rect_<float>(detections_cv[i]);
98  tb.classId = classIds[i];
99  tb.confidence = confidences[i];
100  detections.push_back(tb);
101 
102  KalmanTracker trk = KalmanTracker(detections.back().box, detections.back().confidence, detections.back().classId, _next_id++);
103  trackers.push_back(trk);
104  }
105  return;
106  }
107  else
108  {
109  for (unsigned int i = 0; i < detections_cv.size(); i++)
110  {
111  if (confidences[i] < _min_conf) continue; // filter low conf
112 
113  TrackingBox tb;
114  tb.box = cv::Rect_<float>(detections_cv[i]);
115  tb.classId = classIds[i];
116  tb.confidence = confidences[i];
117  detections.push_back(tb);
118  }
119 
120  // Apply NMS to remove duplicates
121  apply_nms(detections, _nms_iou_thresh);
122 
123  for (auto it = frameTrackingResult.begin(); it != frameTrackingResult.end(); it++)
124  {
125  int frame_age = frame_count - it->frame;
126  if (frame_age >= _max_age || frame_age < 0)
127  {
128  dead_trackers_id.push_back(it->id);
129  }
130  }
131  }
132 
134  // 3.1. get predicted locations from existing trackers.
135  predictedBoxes.clear();
136  for (unsigned int i = 0; i < trackers.size();)
137  {
138  cv::Rect_<float> pBox = trackers[i].predict();
139  if (pBox.x >= 0 && pBox.y >= 0)
140  {
141  predictedBoxes.push_back(pBox);
142  i++;
143  continue;
144  }
145  trackers.erase(trackers.begin() + i);
146  }
147 
148  trkNum = predictedBoxes.size();
149  detNum = detections.size();
150 
151  cost_matrix.clear();
152  cost_matrix.resize(trkNum, vector<double>(detNum, 0));
153 
154  for (unsigned int i = 0; i < trkNum; i++) // compute cost matrix using 1 - IOU with gating
155  {
156  for (unsigned int j = 0; j < detNum; j++)
157  {
158  double iou = GetIOU(predictedBoxes[i], detections[j].box);
159  double dist = GetCentroidsDistance(predictedBoxes[i], detections[j].box) / image_diagonal;
160  if (trackers[i].classId != detections[j].classId || dist > max_centroid_dist_norm)
161  {
162  cost_matrix[i][j] = 1e9; // large cost for gating
163  }
164  else
165  {
166  cost_matrix[i][j] = 1 - iou + (1 - detections[j].confidence) * 0.1; // slight penalty for low conf
167  }
168  }
169  }
170 
171  HungarianAlgorithm HungAlgo;
172  assignment.clear();
173  HungAlgo.Solve(cost_matrix, assignment);
174  // find matches, unmatched_detections and unmatched_predictions
175  unmatchedTrajectories.clear();
176  unmatchedDetections.clear();
177  allItems.clear();
178  matchedItems.clear();
179 
180  if (detNum > trkNum) // there are unmatched detections
181  {
182  for (unsigned int n = 0; n < detNum; n++)
183  allItems.insert(n);
184 
185  for (unsigned int i = 0; i < trkNum; ++i)
186  matchedItems.insert(assignment[i]);
187 
188  set_difference(allItems.begin(), allItems.end(),
189  matchedItems.begin(), matchedItems.end(),
190  insert_iterator<set<int>>(unmatchedDetections, unmatchedDetections.begin()));
191  }
192  else if (detNum < trkNum) // there are unmatched trajectory/predictions
193  {
194  for (unsigned int i = 0; i < trkNum; ++i)
195  if (assignment[i] == -1) // unassigned label will be set as -1 in the assignment algorithm
196  unmatchedTrajectories.insert(i);
197  }
198  else
199  ;
200 
201  // filter out matched with low IOU
202  matchedPairs.clear();
203  for (unsigned int i = 0; i < trkNum; ++i)
204  {
205  if (assignment[i] == -1) // pass over invalid values
206  continue;
207  if (cost_matrix[i][assignment[i]] > 1 - _min_iou)
208  {
209  unmatchedTrajectories.insert(i);
210  unmatchedDetections.insert(assignment[i]);
211  }
212  else
213  matchedPairs.push_back(cv::Point(i, assignment[i]));
214  }
215 
216  for (unsigned int i = 0; i < matchedPairs.size(); i++)
217  {
218  int trkIdx = matchedPairs[i].x;
219  int detIdx = matchedPairs[i].y;
220  trackers[trkIdx].update(detections[detIdx].box);
221  trackers[trkIdx].classId = detections[detIdx].classId;
222  trackers[trkIdx].confidence = detections[detIdx].confidence;
223  }
224 
225  // create and initialise new trackers for unmatched detections
226  for (auto umd : unmatchedDetections)
227  {
228  KalmanTracker tracker = KalmanTracker(detections[umd].box, detections[umd].confidence, detections[umd].classId, _next_id++);
229  trackers.push_back(tracker);
230  }
231 
232  for (auto it2 = dead_trackers_id.begin(); it2 != dead_trackers_id.end(); it2++)
233  {
234  for (unsigned int i = 0; i < trackers.size();)
235  {
236  if (trackers[i].m_id == (*it2))
237  {
238  trackers.erase(trackers.begin() + i);
239  continue;
240  }
241  i++;
242  }
243  }
244 
245  // get trackers' output
246  frameTrackingResult.clear();
247  for (unsigned int i = 0; i < trackers.size();)
248  {
249  if ((trackers[i].m_hits >= _min_hits && trackers[i].m_time_since_update <= _max_missed) ||
250  frame_count <= _min_hits)
251  {
252  alive_tracker = true;
253  TrackingBox res;
254  res.box = trackers[i].get_state();
255  res.id = trackers[i].m_id;
256  res.frame = frame_count;
257  res.classId = trackers[i].classId;
258  res.confidence = trackers[i].confidence;
259  frameTrackingResult.push_back(res);
260  }
261 
262  // remove dead tracklet
263  if (trackers[i].m_time_since_update >= _max_age)
264  {
265  trackers.erase(trackers.begin() + i);
266  continue;
267  }
268  i++;
269  }
270 }
TrackingBox::confidence
float confidence
Definition: sort.hpp:25
HungarianAlgorithm
Definition: Hungarian.h:22
TrackingBox
Definition: sort.hpp:22
TrackingBox::frame
int frame
Definition: sort.hpp:24
SortTracker::GetCentroidsDistance
double GetCentroidsDistance(cv::Rect_< float > bb_test, cv::Rect_< float > bb_gt)
Definition: sort.cpp:35
KalmanTracker
This class represents the internal state of individual tracked objects observed as a bounding box.
Definition: KalmanTracker.h:18
TrackingBox::classId
int classId
Definition: sort.hpp:26
SortTracker::SortTracker
SortTracker(int max_age=50, int min_hits=5, int max_missed=7, double min_iou=0.1, double nms_iou_thresh=0.5, double min_conf=0.3)
Definition: sort.cpp:10
SortTracker::GetIOU
static double GetIOU(cv::Rect_< float > bb_test, cv::Rect_< float > bb_gt)
Definition: sort.cpp:23
TrackingBox::box
cv::Rect_< float > box
Definition: sort.hpp:28
TrackingBox::id
int id
Definition: sort.hpp:27
HungarianAlgorithm::Solve
double Solve(std::vector< std::vector< double >> &DistMatrix, std::vector< int > &Assignment)
Definition: Hungarian.cpp:26
sort.hpp
SortTracker::update
void update(std::vector< cv::Rect > detection, int frame_count, double image_diagonal, std::vector< float > confidences, std::vector< int > classIds)
Definition: sort.cpp:84
apply_nms
void apply_nms(vector< TrackingBox > &detections, double nms_iou_thresh)
Definition: sort.cpp:51