-
Notifications
You must be signed in to change notification settings - Fork 351
/
demo.py
105 lines (79 loc) · 3.28 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
import numpy as np
import tensorflow as tf
from PIL import Image
import time
import yolo_v3
import yolo_v3_tiny
from utils import load_coco_names, draw_boxes, get_boxes_and_inputs, get_boxes_and_inputs_pb, non_max_suppression, \
load_graph, letter_box_image
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string(
'input_img', '', 'Input image')
tf.app.flags.DEFINE_string(
'output_img', '', 'Output image')
tf.app.flags.DEFINE_string(
'class_names', 'coco.names', 'File with class names')
tf.app.flags.DEFINE_string(
'weights_file', 'yolov3.weights', 'Binary file with detector weights')
tf.app.flags.DEFINE_string(
'data_format', 'NCHW', 'Data format: NCHW (gpu only) / NHWC')
tf.app.flags.DEFINE_string(
'ckpt_file', './saved_model/model.ckpt', 'Checkpoint file')
tf.app.flags.DEFINE_string(
'frozen_model', '', 'Frozen tensorflow protobuf model')
tf.app.flags.DEFINE_bool(
'tiny', False, 'Use tiny version of YOLOv3')
tf.app.flags.DEFINE_bool(
'spp', False, 'Use SPP version of YOLOv3')
tf.app.flags.DEFINE_integer(
'size', 416, 'Image size')
tf.app.flags.DEFINE_float(
'conf_threshold', 0.5, 'Confidence threshold')
tf.app.flags.DEFINE_float(
'iou_threshold', 0.4, 'IoU threshold')
tf.app.flags.DEFINE_float(
'gpu_memory_fraction', 1.0, 'Gpu memory fraction to use')
def main(argv=None):
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=FLAGS.gpu_memory_fraction)
config = tf.ConfigProto(
gpu_options=gpu_options,
log_device_placement=False,
)
img = Image.open(FLAGS.input_img)
img_resized = letter_box_image(img, FLAGS.size, FLAGS.size, 128)
img_resized = img_resized.astype(np.float32)
classes = load_coco_names(FLAGS.class_names)
if FLAGS.frozen_model:
t0 = time.time()
frozenGraph = load_graph(FLAGS.frozen_model)
print("Loaded graph in {:.2f}s".format(time.time()-t0))
boxes, inputs = get_boxes_and_inputs_pb(frozenGraph)
with tf.Session(graph=frozenGraph, config=config) as sess:
t0 = time.time()
detected_boxes = sess.run(
boxes, feed_dict={inputs: [img_resized]})
else:
if FLAGS.tiny:
model = yolo_v3_tiny.yolo_v3_tiny
elif FLAGS.spp:
model = yolo_v3.yolo_v3_spp
else:
model = yolo_v3.yolo_v3
boxes, inputs = get_boxes_and_inputs(model, len(classes), FLAGS.size, FLAGS.data_format)
saver = tf.train.Saver(var_list=tf.global_variables(scope='detector'))
with tf.Session(config=config) as sess:
t0 = time.time()
saver.restore(sess, FLAGS.ckpt_file)
print('Model restored in {:.2f}s'.format(time.time()-t0))
t0 = time.time()
detected_boxes = sess.run(
boxes, feed_dict={inputs: [img_resized]})
filtered_boxes = non_max_suppression(detected_boxes,
confidence_threshold=FLAGS.conf_threshold,
iou_threshold=FLAGS.iou_threshold)
print("Predictions found in {:.2f}s".format(time.time() - t0))
draw_boxes(filtered_boxes, img, classes, (FLAGS.size, FLAGS.size), True)
img.save(FLAGS.output_img)
if __name__ == '__main__':
tf.app.run()