-
Notifications
You must be signed in to change notification settings - Fork 1
/
sort.py
79 lines (60 loc) · 2.58 KB
/
sort.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import print_function
import numpy as np
from kalman_tracker import KalmanBoxTracker
from data_association import associate_detections_to_trackers
class Sort:
    """Frame-by-frame multi-object tracker (SORT).

    Keeps a list of ``KalmanBoxTracker`` instances in ``self.trackers`` and,
    on every call to :meth:`update`, associates the new detections with the
    predicted tracker positions, spawns trackers for unmatched detections and
    drops trackers that have gone too long without an update.
    """

    def __init__(self, max_age=20, min_hits=5):
        """Set the key parameters for SORT.

        Args:
            max_age: A tracker whose ``time_since_update`` exceeds this value
                is removed; trackers at or below it may still be reported.
            min_hits: Minimum ``hit_streak`` a tracker needs before it is
                reported. The requirement is waived while
                ``frame_count <= min_hits``.
        """
        self.max_age = max_age
        self.min_hits = min_hits
        self.trackers = []
        self.frame_count = 0

    def update(self, dets, img=None):
        """Advance the tracker by one frame.

        This method must be called once for each frame, even when there are
        no detections (pass an empty array, an empty list or ``None``).

        Args:
            dets: Detections for the current frame, one row per detection,
                in the format ``[[x, y, w, h, score], ...]`` (format as
                stated by the original author; the box convention is
                ultimately defined by ``KalmanBoxTracker``). Any sequence
                accepted by ``np.asarray`` is allowed.
            img: Optional frame, forwarded unchanged to the trackers'
                ``predict`` and ``update`` methods.

        Returns:
            A 2-D array with one row per reported track: the tracker state
            from ``get_state()`` followed by the object ID (``trk.id + 1``)
            in the last column. An array of shape ``(0, 5)`` is returned when
            nothing is reported.

        NOTE: The number of objects returned may differ from the number of
        detections provided.
        """
        self.frame_count += 1

        # A missing detection set is handled exactly like an empty one, so
        # the "call once per frame" contract cannot crash on ``None``.
        if dets is None:
            dets = []
        has_dets = len(dets) != 0
        if has_dets:
            # Row/column indexing below needs an ndarray, not a plain list.
            dets = np.asarray(dets)

        # Get predicted locations from existing trackers.
        trks = np.zeros((len(self.trackers), 5))
        valid = []
        for t, row in enumerate(trks):
            pos = self.trackers[t].predict(img)
            row[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            # One validity rule (no NaN and no +/-inf) decides both whether
            # the row is kept and whether the tracker survives.  Using
            # different rules for the two would leave the rows of ``trks``
            # out of step with the indices of ``self.trackers``.
            valid.append(bool(np.all(np.isfinite(pos))))
        trks = trks[np.array(valid, dtype=bool)]
        # In-place update so external references to the list stay valid.
        self.trackers[:] = [
            trk for trk, ok in zip(self.trackers, valid) if ok
        ]

        if has_dets:
            matched, unmatched_dets, unmatched_trks = (
                associate_detections_to_trackers(dets, trks)
            )

            # Update matched trackers with assigned detections.
            for t, trk in enumerate(self.trackers):
                if t not in unmatched_trks:
                    d = matched[np.where(matched[:, 1] == t)[0], 0]
                    trk.update(dets[d, :][0], img)

            # Create and initialise new trackers for unmatched detections.
            for det_idx in unmatched_dets:
                self.trackers.append(KalmanBoxTracker(dets[det_idx, :]))

        ret = []
        # Walk the list backwards with ``i`` tracking the current index, so
        # popping a dead tracker never shifts an entry not yet visited.
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            if not has_dets:
                # No detections this frame: every tracker gets an empty update.
                trk.update([], img)
            d = trk.get_state()
            recent = trk.time_since_update <= self.max_age
            confirmed = (
                trk.hit_streak >= self.min_hits
                or self.frame_count <= self.min_hits
            )
            if recent and confirmed:
                # IDs are reported offset by one (trk.id + 1).
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))
            i -= 1
            # Remove dead tracklet.
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)

        if len(ret) > 0:
            return np.concatenate(ret)
        return np.empty((0, 5))