-
Notifications
You must be signed in to change notification settings - Fork 203
/
test.py
111 lines (86 loc) · 4.13 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import util
import argparse
from model import *
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
parser = argparse.ArgumentParser()
parser.add_argument('--device',type=str,default='cuda:3',help='')
parser.add_argument('--data',type=str,default='data/METR-LA',help='data path')
parser.add_argument('--adjdata',type=str,default='data/sensor_graph/adj_mx.pkl',help='adj data path')
parser.add_argument('--adjtype',type=str,default='doubletransition',help='adj type')
parser.add_argument('--gcn_bool',action='store_true',help='whether to add graph convolution layer')
parser.add_argument('--aptonly',action='store_true',help='whether only adaptive adj')
parser.add_argument('--addaptadj',action='store_true',help='whether add adaptive adj')
parser.add_argument('--randomadj',action='store_true',help='whether random initialize adaptive adj')
parser.add_argument('--seq_length',type=int,default=12,help='')
parser.add_argument('--nhid',type=int,default=32,help='')
parser.add_argument('--in_dim',type=int,default=2,help='inputs dimension')
parser.add_argument('--num_nodes',type=int,default=207,help='number of nodes')
parser.add_argument('--batch_size',type=int,default=64,help='batch size')
parser.add_argument('--learning_rate',type=float,default=0.001,help='learning rate')
parser.add_argument('--dropout',type=float,default=0.3,help='dropout rate')
parser.add_argument('--weight_decay',type=float,default=0.0001,help='weight decay rate')
parser.add_argument('--checkpoint',type=str,help='')
parser.add_argument('--plotheatmap',type=str,default='True',help='')
args = parser.parse_args()
def main():
device = torch.device(args.device)
_, _, adj_mx = util.load_adj(args.adjdata,args.adjtype)
supports = [torch.tensor(i).to(device) for i in adj_mx]
if args.randomadj:
adjinit = None
else:
adjinit = supports[0]
if args.aptonly:
supports = None
model = gwnet(device, args.num_nodes, args.dropout, supports=supports, gcn_bool=args.gcn_bool, addaptadj=args.addaptadj, aptinit=adjinit)
model.to(device)
model.load_state_dict(torch.load(args.checkpoint))
model.eval()
print('model load successfully')
dataloader = util.load_dataset(args.data, args.batch_size, args.batch_size, args.batch_size)
scaler = dataloader['scaler']
outputs = []
realy = torch.Tensor(dataloader['y_test']).to(device)
realy = realy.transpose(1,3)[:,0,:,:]
for iter, (x, y) in enumerate(dataloader['test_loader'].get_iterator()):
testx = torch.Tensor(x).to(device)
testx = testx.transpose(1,3)
with torch.no_grad():
preds = model(testx).transpose(1,3)
outputs.append(preds.squeeze())
yhat = torch.cat(outputs,dim=0)
yhat = yhat[:realy.size(0),...]
amae = []
amape = []
armse = []
for i in range(12):
pred = scaler.inverse_transform(yhat[:,:,i])
real = realy[:,:,i]
metrics = util.metric(pred,real)
log = 'Evaluate best model on test data for horizon {:d}, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
print(log.format(i+1, metrics[0], metrics[1], metrics[2]))
amae.append(metrics[0])
amape.append(metrics[1])
armse.append(metrics[2])
log = 'On average over 12 horizons, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'
print(log.format(np.mean(amae),np.mean(amape),np.mean(armse)))
if args.plotheatmap == "True":
adp = F.softmax(F.relu(torch.mm(model.nodevec1, model.nodevec2)), dim=1)
device = torch.device('cpu')
adp.to(device)
adp = adp.cpu().detach().numpy()
adp = adp*(1/np.max(adp))
df = pd.DataFrame(adp)
sns.heatmap(df, cmap="RdYlBu")
plt.savefig("./emb"+ '.pdf')
y12 = realy[:,99,11].cpu().detach().numpy()
yhat12 = scaler.inverse_transform(yhat[:,99,11]).cpu().detach().numpy()
y3 = realy[:,99,2].cpu().detach().numpy()
yhat3 = scaler.inverse_transform(yhat[:,99,2]).cpu().detach().numpy()
df2 = pd.DataFrame({'real12':y12,'pred12':yhat12, 'real3': y3, 'pred3':yhat3})
df2.to_csv('./wave.csv',index=False)
if __name__ == "__main__":
main()