-
Notifications
You must be signed in to change notification settings - Fork 56
/
functions.py
executable file
·109 lines (103 loc) · 3.94 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import matplotlib as mpl
if not os.environ.get('DISPLAY', ''):
    # Headless session (no X display available): fall back to the
    # file-only Agg backend. Done before matplotlib.pyplot is imported so
    # the backend choice takes effect.
    print('no display found. Using non-interactive Agg backend')
    mpl.use('Agg')
import matplotlib.pyplot as plt
import scipy.io as sio
import torch
import torch.nn as nn
import numpy as np
from src.torchviz import make_dot, make_dot_from_trace
class stats:
    """Per-epoch history of training/validation objective, top-1 and top-5.

    Every history attribute is a list with one entry per epoch. When resuming
    (``start_epoch != 0``) the first ``start_epoch`` entries are restored from
    ``<path>/stats.mat``; otherwise every history starts empty.
    """

    # Names of the per-epoch histories. They are also the field names read
    # from the 'data' struct stored in stats.mat.
    _FIELDS = ('trainObj', 'trainTop1', 'trainTop5',
               'valObj', 'valTop1', 'valTop5')

    def __init__(self, path, start_epoch):
        """Create the history, optionally resuming from a saved file.

        Args:
            path: directory containing ``stats.mat``; only read when
                ``start_epoch`` is non-zero.
            start_epoch: number of already-recorded epochs to restore.
                0 means start with empty histories.
        """
        # Compare by value: ``is`` on int literals depends on CPython's
        # small-int cache and emits SyntaxWarning on Python 3.8+.
        if start_epoch != 0:
            stats_ = sio.loadmat(os.path.join(path, 'stats.mat'))
            content = stats_['data'][0, 0]
            for name in self._FIELDS:
                # ravel() always yields a 1-D array, so tolist() returns a
                # list even for a single epoch (squeeze() would collapse that
                # case to a bare scalar).
                history = content[name][:, :start_epoch].ravel().tolist()
                setattr(self, name, history)
        else:
            for name in self._FIELDS:
                setattr(self, name, [])

    def _update(self, trainObj, top1, top5, valObj, prec1, prec5):
        """Append one epoch of results to every history.

        Args:
            trainObj: training objective, stored as given.
            top1, top5: training accuracies; anything with
                ``.cpu().numpy()`` (e.g. torch tensors). Stored as NumPy
                arrays on the host.
            valObj: validation objective, stored as given.
            prec1, prec5: validation accuracies, converted like ``top1``.
        """
        self.trainObj.append(trainObj)
        self.trainTop1.append(top1.cpu().numpy())
        self.trainTop5.append(top5.cpu().numpy())
        self.valObj.append(valObj)
        self.valTop1.append(prec1.cpu().numpy())
        self.valTop5.append(prec5.cpu().numpy())
def vizNet(model, path, input_size=(10, 3, 224, 224)):
    """Render the autograd graph of ``model`` to ``<path>/graph``.

    A random input of shape ``input_size`` is pushed through the model in
    eval mode. The output is passed to ``make_dot`` and the resulting graph
    is written with ``render(..., view=False)``, so no viewer is opened.

    Args:
        model: the network to visualize.
        path: output directory for the rendered graph.
        input_size: shape of the random dummy input; defaults to
            (10, 3, 224, 224).

    The model's train/eval mode is restored afterwards, so this can be
    called in the middle of training.
    """
    was_training = model.training
    model.eval()
    try:
        # Create the dummy input on the model's device: a CPU tensor fails
        # against a model that has been moved to the GPU. Parameter-free
        # models fall back to the default device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else None
        x = torch.randn(*input_size, device=device)
        # Run with autograd enabled: make_dot presumably walks the output's
        # grad_fn graph, which torch.no_grad() would suppress.
        y = model(x)
        g = make_dot(y)
        g.render(os.path.join(path, 'graph'), view=False)
    finally:
        model.train(was_training)
def _plot_panel(ax, epochs, train, val, title):
    """Draw the train and val curves for one metric on ``ax``."""
    ax.plot(epochs, train, 'o-', label='train')
    ax.plot(epochs, val, 'o-', label='val')
    ax.set_xlabel('epoch')
    ax.set_title(title)
    # Reverse the legend order so 'val' (plotted last) is listed first.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1])


def plot_curve(stats, path, iserr):
    """Save objective/top-1/top-5 learning curves to ``<path>/net-train.pdf``.

    Args:
        stats: object exposing per-epoch lists ``trainObj``, ``valObj``,
            ``trainTop1``, ``trainTop5``, ``valTop1`` and ``valTop5``.
        path: output directory.
        iserr: if true, plot ``100 - value`` (error) for the top-1/top-5
            panels instead of the raw values (accuracy).

    The figure is always closed, even if plotting or saving fails, so
    repeated calls during training do not accumulate open figures.
    """
    trainObj = np.array(stats.trainObj)
    valObj = np.array(stats.valObj)
    trainTop1 = np.array(stats.trainTop1)
    trainTop5 = np.array(stats.trainTop5)
    valTop1 = np.array(stats.valTop1)
    valTop5 = np.array(stats.valTop5)
    if iserr:
        # Accuracies appear to be percentages, so error = 100 - accuracy.
        trainTop1, trainTop5, valTop1, valTop5 = (
            100 - a for a in (trainTop1, trainTop5, valTop1, valTop5))
        titleName = 'error'
    else:
        titleName = 'accuracy'

    epochs = range(1, len(trainObj) + 1)
    figure = plt.figure()
    try:
        _plot_panel(figure.add_subplot(1, 3, 1), epochs,
                    trainObj, valObj, 'objective')
        _plot_panel(figure.add_subplot(1, 3, 2), epochs,
                    trainTop1, valTop1, 'top1' + titleName)
        _plot_panel(figure.add_subplot(1, 3, 3), epochs,
                    trainTop5, valTop5, 'top5' + titleName)
        filename = os.path.join(path, 'net-train.pdf')
        figure.savefig(filename, bbox_inches='tight')
    finally:
        plt.close(figure)
def decode_params(input_params):
    """Parse a comma-separated list of numbers into a list of floats.

    Only ``input_params[0]`` is parsed; the remaining elements are ignored.

    >>> decode_params(['0.1,2,3e-2'])
    [0.1, 2.0, 0.03]

    Raises:
        ValueError: if any comma-separated field is not a valid float,
            including empty fields (e.g. a trailing comma).
        IndexError: if ``input_params`` is empty.
    """
    return [float(field) for field in input_params[0].split(',')]