forked from xinshuoweng/AB3DMOT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
visualization.py
85 lines (73 loc) · 3.35 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os, numpy as np, sys, cv2
from PIL import Image
from utils import is_path_exists, mkdir_if_missing, load_list_from_folder, fileparts, random_colors
from kitti_utils import read_label, compute_box_3d, draw_projected_box3d, Calibration
# Size of the track-colour palette; a track ID picks its colour as colors[id % max_color].
max_color = 30
# Palette from utils.random_colors; entries are scaled by 255 when drawn,
# so they are presumably RGB floats in [0, 1] (confirm against utils).
colors = random_colors(max_color)

# Object classes that are drawn; every other type is filtered out.
type_whitelist = 'Car Pedestrian Cyclist'.split()

# Objects that carry a score below this value are skipped; the very low
# default presumably disables score filtering in practice.
score_threshold = -10000

# Every saved visualisation is resized to this (width, height) in pixels.
width, height = 1242, 374

# Sequences rendered by vis() when no explicit list is given.
seq_list = ['0000', '0003']
def _show_image_with_boxes(img, objects_res, object_gt, calib, save_path, height_threshold=0):
    """Draw each object's projected 3D box and track ID on a copy of ``img`` and save it.

    Args:
        img: image array; it is copied, the caller's array is not modified.
        objects_res: objects to draw; each needs an integer ``.id`` plus whatever
            ``compute_box_3d`` reads.
        object_gt: unused; kept for signature compatibility.
        calib: calibration object; only ``calib.P`` is used.
        save_path: path the rendered image is written to.
        height_threshold: unused; kept for signature compatibility.
    """
    img2 = np.copy(img)
    for obj in objects_res:
        box3d_pts_2d, _ = compute_box_3d(obj, calib.P)
        # compute_box_3d can yield None for the 2D corners; there is then
        # nothing to draw, neither the box nor its label.
        if box3d_pts_2d is None:
            continue
        # Same track ID -> same colour across frames (palette is cyclic).
        color_tmp = tuple(int(c * 255) for c in colors[obj.id % max_color])
        img2 = draw_projected_box3d(img2, box3d_pts_2d, color=color_tmp)
        text = 'ID: %d' % obj.id
        # Label is placed 8 px above projected corner 4 of the box.
        label_org = (int(box3d_pts_2d[4, 0]), int(box3d_pts_2d[4, 1]) - 8)
        img2 = cv2.putText(img2, text, label_org, cv2.FONT_HERSHEY_TRIPLEX, 0.5, color=color_tmp)
    out = Image.fromarray(img2).resize((width, height))
    out.save(save_path)


def _filter_objects(objects):
    """Return the objects whose type is whitelisted and whose score, if present, is >= score_threshold."""
    kept = []
    for obj in objects:
        if obj.type not in type_whitelist:
            continue
        # Objects without a score attribute are kept unconditionally.
        if hasattr(obj, 'score') and obj.score < score_threshold:
            continue
        kept.append(obj)
    return kept


def vis(result_sha, data_root, result_root, seqs=None):
    """Render tracking results as 3D boxes on the sequence images, one JPEG per frame.

    For each sequence, images are listed from ``<data_root>/image_02/<seq>``,
    the calibration is read from ``<data_root>/calib/<seq>.txt`` and per-frame
    results from ``<result_root>/<result_sha>/trk_withid/<seq>/<frame:06d>.txt``
    (a missing result file means no objects for that frame). Output goes to
    ``<result_root>/<result_sha>/trk_image_vis/<seq>/<frame:06d>.jpg``.

    Args:
        result_sha: name of the result set (sub-directory of ``result_root``).
        data_root: dataset root containing ``image_02`` and ``calib``.
        result_root: directory holding the result sets.
        seqs: sequence names to render; defaults to the module-level ``seq_list``.
    """
    if seqs is None:
        seqs = seq_list
    for seq in seqs:
        image_dir = os.path.join(data_root, 'image_02/%s' % seq)
        calib_file = os.path.join(data_root, 'calib/%s.txt' % seq)
        result_dir = os.path.join(result_root, '%s/trk_withid/%s' % (result_sha, seq))
        # Same location as result_dir/../../trk_image_vis/<seq>, built directly so
        # it does not depend on result_dir existing for '..' to resolve.
        save_3d_bbox_dir = os.path.join(result_root, result_sha, 'trk_image_vis', seq)
        mkdir_if_missing(save_3d_bbox_dir)

        images_list, num_images = load_list_from_folder(image_dir)
        print('number of images to visualize is %d' % num_images)

        # One calibration file per sequence: load it once, not once per frame.
        calib_tmp = Calibration(calib_file)

        for count, image_path in enumerate(images_list[:num_images]):
            if not is_path_exists(image_path):
                continue
            # Frame index is the numeric file stem, e.g. '000042.png' -> 42.
            image_index = int(fileparts(image_path)[1])
            with Image.open(image_path) as pil_img:
                image = np.array(pil_img)

            result_tmp = os.path.join(result_dir, '%06d.txt' % image_index)
            object_res = read_label(result_tmp) if is_path_exists(result_tmp) else []
            print('processing index: %d, %d/%d, results from %s' % (image_index, count + 1, num_images, result_tmp))

            object_res_filtered = _filter_objects(object_res)
            save_path = os.path.join(save_3d_bbox_dir, '%06d.jpg' % image_index)
            _show_image_with_boxes(image, object_res_filtered, [], calib_tmp, save_path=save_path)
            print('number of objects to plot is %d' % len(object_res_filtered))
if __name__ == "__main__":
    # Command-line entry point: visualize the result set named by argv[1].
    if len(sys.argv) != 2:
        print("Usage: python visualization.py result_sha(e.g., 3d_det_test)")
        sys.exit(1)
    result_root = './results'
    result_sha = sys.argv[1]
    # Only test-split images are expected under ./data/KITTI/test. Names that
    # mention train/val, or do not mention test at all, are rejected
    # (train/val take precedence over test).
    if ('train' in result_sha) or ('val' in result_sha) or ('test' not in result_sha):
        print("No image data is provided for %s, please download the KITTI dataset" % result_sha)
        sys.exit(1)
    data_root = './data/KITTI/test'
    vis(result_sha, data_root, result_root)