-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserini_run_to_json.py
85 lines (62 loc) · 2.45 KB
/
serini_run_to_json.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import csv
import click
import json
import pandas as pd
import tqdm
from ranx import Run, Qrels, evaluate
def _nested_scores(frame, query_col, doc_col, score_col, progress=False):
    """Group the rows of *frame* into ``{query_id: {doc_id: score}}``.

    The three columns are walked in lockstep instead of via
    ``DataFrame.iterrows()``: ``iterrows`` packs every row into a
    single-dtype Series, which silently turns integer ids into floats
    (``12`` -> ``"12.0"``) whenever a float score sits in the same row and
    hands back NumPy scalars that ``json.dump`` cannot serialise.
    ``Series.tolist()`` yields native Python values instead.

    If the same (query, doc) pair occurs more than once, the last row wins.
    """
    rows = zip(
        frame[query_col].tolist(),
        frame[doc_col].tolist(),
        frame[score_col].tolist(),
    )
    if progress:
        rows = tqdm.tqdm(rows, total=len(frame))

    nested = {}
    for query_id, doc_id, score in rows:
        nested.setdefault(str(query_id), {})[str(doc_id)] = score
    return nested


@click.command()
@click.option(
    "--data_folder",
    type=str,
    required=True,
)
def main(data_folder):
    """Convert test qrels and a BM25 run to JSON and print ranx metrics.

    Reads ``qrels/test.tsv`` (tab-separated, with ``query-id``, ``corpus-id``
    and ``score`` columns) and ``run.txt`` (space-separated, no header; the
    columns at positions 0, 2 and 4 are used as query id, document id and
    score) from DATA_FOLDER. Writes ``qrels.json``, ``bm25_run.json`` and
    ``bm25_run.lz4`` back into the same folder, then prints the evaluation.
    """
    # Ids are read as strings so they are never reinterpreted as numbers
    # (which would drop leading zeros or append ".0").
    qrel_df = pd.read_csv(
        f'{data_folder}/qrels/test.tsv',
        sep='\t',
        dtype={'query-id': str, 'corpus-id': str},
    )
    qrels = _nested_scores(qrel_df, 'query-id', 'corpus-id', 'score')

    # QUOTE_NONE: treat quote characters inside ids as literal text.
    bm25_serini = pd.read_csv(
        f'{data_folder}/run.txt',
        sep=' ',
        header=None,
        quoting=csv.QUOTE_NONE,
        dtype={0: str, 2: str},
    )
    run = _nested_scores(bm25_serini, 0, 2, 4, progress=True)

    with open(f'{data_folder}/qrels.json', 'w') as f:
        json.dump(qrels, f, indent=2)

    with open(f'{data_folder}/bm25_run.json', 'w') as f:
        json.dump(run, f, indent=2)

    ranx_qrels = Qrels(qrels)
    ranx_run = Run(run)
    ranx_run.name = 'BM25'
    ranx_run.save(f'{data_folder}/bm25_run.lz4')

    metrics = ['map@100', 'mrr@10', 'recall@100', 'precision@5', 'ndcg@10']
    # make_comparable=True is passed so that a mismatch between the query
    # sets of the qrels and the run does not abort the evaluation.
    print(evaluate(ranx_qrels, ranx_run, metrics, make_comparable=True))
# Disabled dev-split variant of the test-split evaluation. It is wrapped in a
# bare string literal, so none of it executes. As written it would skip the
# 'nq' and 'climate-fever' folders, read qrels/dev.tsv and dev_run.txt, write
# dev_qrels.json and dev_bm25_run.json, and print the same metrics (without
# make_comparable and without saving an .lz4 run).
"""
if data_folder != 'nq' and data_folder != 'climate-fever':
qrel_df = pd.read_csv(f'{data_folder}/qrels/dev.tsv', sep='\t')
qrels = {}
for index, row in qrel_df.iterrows():
q_id = str(row['query-id'])
if not q_id in qrels:
qrels[q_id] = {}
qrels[q_id][str(row['corpus-id'])] = row['score']
bm25_serini = pd.read_csv(f'{data_folder}/dev_run.txt', sep=' ', header=None, quoting=csv.QUOTE_NONE)
run = {}
for _id, row in tqdm.tqdm(bm25_serini[[0, 2, 4]].iterrows()):
q_id = str(row[0])
if not q_id in run:
run[q_id] = {}
run[q_id][str(row[2])] = row[4]
with open(f'{data_folder}/dev_qrels.json', 'w') as f:
json.dump(qrels, f, indent=2)
with open(f'{data_folder}/dev_bm25_run.json', 'w') as f:
json.dump(run, f, indent=2)
ranx_qrels = Qrels(qrels)
ranx_run = Run(run)
print(evaluate(ranx_qrels, ranx_run, ['map@100', 'mrr@10', 'recall@100', 'precision@5', 'ndcg@10']))
"""
# Script entry point: run the click command only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()