-
Notifications
You must be signed in to change notification settings - Fork 1
/
eval.py
114 lines (93 loc) · 2.82 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import sys
import numpy as np
import argparse
# import SentEval
sys.path.insert(0, './SentEval')
import spacy
import torch
from pytorch_lightning import Trainer
from train import NLINet
from data import SNLIData
# Load spaCy's small English pipeline once, at import time.
# NOTE(review): `spacy_en` is not referenced anywhere in the visible part of
# this file -- confirm it is still needed before keeping this load.
spacy_en = spacy.load('en_core_web_sm')
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def senteval(checkpoint_path, params_senteval):
    """Evaluate a trained sentence encoder on a fixed set of SentEval tasks.

    Args:
        checkpoint_path: Path handed to ``NLINet.load_from_checkpoint``.
        params_senteval: Parameter dict passed unchanged to
            ``senteval.engine.SE``.

    Returns:
        Whatever ``se.eval`` returns for the transfer tasks listed below.
        The same object is also printed to stdout.
    """
    # Imported locally: at module level the name ``senteval`` is this very
    # function, so the package can only be bound inside the body.
    import senteval

    model = NLINet.load_from_checkpoint(checkpoint_path).to(device)
    model.eval()
    data = SNLIData(batch_size=128)

    def prepare(params, samples):
        """Store the vocabulary and the padding length on ``params``."""
        params.vocab = data.get_vocab()
        # Every batch of the task is padded to the longest sample seen here.
        params.max_len = np.max([len(x) for x in samples])
        params.wvec_dim = 300

    def batcher(params, batch):
        """Map one batch of tokenized sentences to a NumPy array of encodings."""
        samples = []
        for sent in batch:
            sent_idxs = [params.vocab[token] for token in sent]
            # Right-pad with "<pad>" up to the task-wide maximum length.
            # ``int`` keeps the arithmetic in plain Python (max_len is a
            # NumPy scalar).
            n_pad = int(params.max_len) - len(sent_idxs)
            if n_pad > 0:
                sent_idxs.extend([params.vocab["<pad>"]] * n_pad)
            samples.append(sent_idxs)
        embed = torch.LongTensor(samples).to(device)
        # Inference only: skip building an autograd graph for every batch.
        with torch.no_grad():
            encoded = model.model.encode(embed)
        return encoded.detach().cpu().numpy()

    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = [
        'MR',
        'CR',
        'SUBJ',
        'MPQA',
        'SST2',
        'TREC',
        'MRPC',
        'SICKEntailment',
    ]
    results = se.eval(transfer_tasks)
    print(results)
    return results
def process_senteval_result(result_dict):
    """Print macro- and micro-averaged dev accuracy over SentEval results.

    "Macro" is the unweighted mean of the per-task dev accuracies; "Micro"
    weights each task's dev accuracy by its share of the total ``ndev``.
    ``nan`` is printed for an average that has nothing to aggregate (no
    usable task, or a total ``ndev`` of zero).

    Args:
        result_dict: Mapping of task name to that task's result mapping.
            Only entries providing both ``'devacc'`` and ``'ndev'`` are
            aggregated; any other entry is skipped.
    """
    devaccs = []
    ndevs = []
    for res in result_dict.values():
        try:
            # Read both fields before appending so the two lists can never
            # get out of step when an entry has only one of them.
            devacc, ndev = res['devacc'], res['ndev']
        except (KeyError, TypeError):
            # Entry does not report dev statistics; leave it out.
            continue
        devaccs.append(devacc)
        ndevs.append(ndev)

    macro = np.mean(devaccs) if devaccs else float('nan')
    print('Macro accuracy:', macro)

    total_dev = sum(ndevs)
    if total_dev:
        micro = sum(
            (ndev / total_dev) * devacc
            for ndev, devacc in zip(ndevs, devaccs)
        )
    else:
        micro = float('nan')
    print('Micro accuracy:', micro)
def snli(checkpoint_path):
    """Run the checkpointed model on the SNLI test loader and return the result.

    Args:
        checkpoint_path: Path handed to ``NLINet.load_from_checkpoint``.

    Returns:
        The value returned by ``Trainer.test``; it is also printed.
    """
    snli_data = SNLIData(batch_size=128)
    # ``get_iters`` yields three loaders; only the last one is used here.
    _first, _second, test_loader = snli_data.get_iters()

    net = NLINet.load_from_checkpoint(checkpoint_path)
    net = net.to(device)
    net.eval()

    runner = Trainer(weights_summary=None)
    outcome = runner.test(net, test_dataloaders=test_loader, verbose=True)
    print(outcome)
    return outcome
if __name__ == "__main__":
    # Script entry point: evaluate one checkpoint on the SNLI test set, then
    # on the SentEval transfer tasks, and print the aggregated dev accuracy.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Required: without it ``None`` would be passed on as the checkpoint path
    # and fail far from the actual cause.
    parser.add_argument('--checkpoint_path',
                        type=str,
                        required=True,
                        help='Path to model checkpoint')
    args = parser.parse_args()

    # Settings handed to ``senteval.engine.SE`` by ``senteval()``.
    params_senteval = {
        'task_path': './SentEval/data/',
        'usepytorch': False,
        'kfold': 5,
        'classifier': {
            'nhid': 0,
            'optim': 'rmsprop',
            'batch_size': 128,
            'tenacity': 3,
            'epoch_size': 2,
        },
    }

    print("######### Evaluating SNLI #########")
    snli(args.checkpoint_path)

    print("######### Evaluating SENTEVAL #########")
    result_dict = senteval(args.checkpoint_path, params_senteval)
    process_senteval_result(result_dict)