-
Notifications
You must be signed in to change notification settings - Fork 4
/
autotuning.py
157 lines (130 loc) · 5.22 KB
/
autotuning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import numpy as np
import tvm
from tvm import autotvm
from tvm import relay
import tvm.relay.testing
from tvm.autotvm.tuner import XGBTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
import sys
from ctypes import *
from tvm.contrib.download import download_testdata
from tvm.relay.testing.darknet import __darknetffi__
import tvm.relay.testing.yolo_detection
import tvm.relay.testing.darknet
#################################################################
# Parameters Setting
# Darknet network definition file; passed to load_network() in get_network().
cfg_path = 'yolov3.cfg'
# Darknet weights file; passed to load_network() in get_network().
weights_path = 'yolov3.weights'
# Target used for task extraction, compilation and the runtime context.
# Alternative LLVM target, kept for reference:
#target = 'llvm'
target = tvm.target.cuda()
# File holding the best tuning records; written through tuning_option's
# 'log_filename' and read back when compiling in tune_and_evaluate().
log_file = 'yolov3_auto.log'
#################################################################
# Define Network
# --------------
def get_network(batch_size):
    """Load the Darknet YOLOv3 network and convert it to Relay.

    The network is loaded from the module-level ``cfg_path`` and
    ``weights_path`` files through the prebuilt Darknet shared library,
    which is obtained with ``download_testdata``.

    Parameters
    ----------
    batch_size : int
        Leading (batch) dimension of the input shape handed to the Relay
        frontend.

    Returns
    -------
    mod
        Module returned by ``relay.frontend.from_darknet``.
    params
        Parameters returned by ``relay.frontend.from_darknet``.
    input_shape : tuple
        ``(batch_size, net.c, net.h, net.w)``.
    """
    repo_url = 'https://github.com/dmlc/web-data/blob/master/darknet/'
    lib_name = 'libdarknet2.0.so'
    lib_url = repo_url + 'lib/' + lib_name + '?raw=true'
    lib_path = download_testdata(lib_url, lib_name, module="darknet")

    # Keep the library handle separate from the library file name.
    darknet_lib = __darknetffi__.dlopen(lib_path)
    net = darknet_lib.load_network(cfg_path.encode('utf-8'), weights_path.encode('utf-8'), 0)

    dtype = 'float32'
    # Honour the requested batch size instead of hard-coding 1; the shape is
    # built directly rather than allocating an array just to read .shape.
    input_shape = (batch_size, net.c, net.h, net.w)

    print("Converting darknet to relay functions...")
    mod, params = relay.frontend.from_darknet(net, dtype=dtype, shape=input_shape)
    return mod, params, input_shape
###########################################
# Set Tuning Options
# ------------------
# Element type of the random input tensor created in tune_and_evaluate().
dtype = 'float32'
# Keyword arguments that tune_and_evaluate() forwards to tune_tasks().
tuning_option = {
# File that receives the best records picked after tuning.
'log_filename': log_file,
# Tuner selector; tune_tasks() accepts only 'xgb' and 'xgb-rank'.
'tuner': 'xgb',
# Requested number of trials; tune_tasks() caps it by the config-space size.
'n_trial': 2000,
# Forwarded unchanged to the tuner's tune() call.
'early_stopping': 600,
# Build/measurement settings, using autotvm's LocalBuilder and LocalRunner.
'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(timeout=10),
runner=autotvm.LocalRunner(number=20, repeat=3, timeout=4, min_repeat_ms=150)
),
}
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True,
               try_winograd=True):
    """Tune every task in turn and keep the best records in ``log_filename``.

    Parameters
    ----------
    tasks : list
        AutoTVM tasks to tune. When ``try_winograd`` is true, entries may be
        replaced in place by their 'winograd' variant.
    measure_option
        Forwarded to the tuner's ``tune`` call.
    tuner : str
        Tuner selector; only ``'xgb'`` and ``'xgb-rank'`` are supported.
    n_trial : int
        Maximum number of trials per task; capped per task by the size of
        that task's config space.
    early_stopping
        Forwarded to the tuner's ``tune`` call.
    log_filename : str
        File that receives the best records. Intermediate records go to
        ``log_filename + ".tmp"``, which is removed at the end.
    use_transfer_learning : bool
        If true, each tuner is seeded with the records already written to
        the temporary log by previously tuned tasks.
    try_winograd : bool
        If true, try to create a 'winograd' version of each task and use it
        when ``workload[1][1]`` (read here as the input channel count) is at
        least 64.

    Raises
    ------
    ValueError
        If ``tuner`` is not a supported selector (raised when the first task
        is about to be tuned).
    """
    if try_winograd:
        for i, task in enumerate(tasks):
            try:  # try winograd template
                winograd_task = autotvm.task.create(task.name, task.args,
                                                    task.target, task.target_host,
                                                    'winograd')
                input_channel = winograd_task.workload[1][1]
                if input_channel >= 64:
                    tasks[i] = winograd_task
            except Exception:
                # Deliberate best effort: a task for which the winograd
                # template cannot be created simply keeps its original form.
                pass

    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        else:
            raise ValueError("Invalid tuner: " + tuner)

        if use_transfer_learning and os.path.isfile(tmp_log_file):
            tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        # Use a per-task budget: overwriting ``n_trial`` here would let one
        # task with a small config space shrink the budget of all later tasks.
        task_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial=task_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(task_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)
def tune_and_evaluate(tuning_opt):
    """Tune the network's conv2d tasks, compile with the best records and time it.

    Parameters
    ----------
    tuning_opt : dict
        Keyword arguments forwarded to ``tune_tasks``. Its ``'log_filename'``
        entry, when present, is also the file the best records are read back
        from for compilation.

    The function returns nothing; it prints progress messages and the mean
    and standard deviation of the measured inference time in milliseconds.
    """
    # extract workloads from relay program
    print("Extract tasks...")
    mod, params, input_shape = get_network(batch_size=1)
    tasks = autotvm.task.extract_from_program(mod["main"], target=target,
                                              params=params, ops=(relay.op.nn.conv2d,))

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # Read back the file tune_tasks actually wrote: the caller's
    # 'log_filename', or tune_tasks' own default when none was given.
    best_log = tuning_opt.get('log_filename', 'tuning.log')

    # compile kernels with history best records
    with autotvm.apply_history_best(best_log):
        print("Compile...")
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build_module.build(
                mod, target=target, params=params)

        # export library
        tmp = tempdir()
        filename = "net.tar"
        lib.export_library(tmp.relpath(filename))

        # load parameters
        ctx = tvm.context(str(target), 0)
        module = runtime.create(graph, lib, ctx)
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module.set_input('data', data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))
# Script entry point: runs extraction, tuning, compilation and timing with the
# module-level tuning_option as soon as this file is executed.
tune_and_evaluate(tuning_option)