# baseline-majority.py -- forked from erickrf/assin
# -*- coding: utf-8 -*-
"""
Script implementing the majority baseline for the ASSIN shared task.
For the similarity task, it computes the training data average similarity
and outputs that value for all test pairs. For entailment, it tags
all test pairs with the majority class in the training data.
It produces an XML file as the output, which can be evaluated with the
`assin-eval.py` script.
"""
import argparse
from collections import Counter

try:
    from xml.etree.cElementTree import ElementTree as ET
except ImportError:
    # xml.etree.cElementTree was removed in Python 3.9; the plain module
    # provides the same ElementTree class.
    from xml.etree.ElementTree import ElementTree as ET

import numpy as np
def compute_majority_baseline(train_root):
    """Learn the baseline predictions from the training pairs.

    Args:
        train_root: Root element of the training XML. Every direct child
            is treated as a pair and must carry ``similarity`` and
            ``entailment`` attributes.

    Returns:
        A ``(similarity_avg, majority_entailment)`` tuple: the mean of the
        training similarity values as a plain ``float`` and the most
        frequent entailment label.

    Raises:
        ValueError: If the root has no children, if a child lacks one of
            the two attributes, or if a similarity value is not numeric.
    """
    pairs = list(train_root)
    if not pairs:
        # Without this guard the mean would be NaN and most_common(1)[0]
        # would fail with an unhelpful IndexError.
        raise ValueError('training data contains no pairs')

    similarities = []
    entailments = []
    for pair in pairs:
        similarity = pair.get('similarity')
        entailment = pair.get('entailment')
        if similarity is None or entailment is None:
            raise ValueError(
                'training pair %r lacks a similarity or entailment attribute'
                % pair.get('id'))
        similarities.append(float(similarity))
        entailments.append(entailment)

    # Convert to a built-in float so the value written to the XML does not
    # depend on how a numpy scalar is rendered as text.
    similarity_avg = float(np.array(similarities).mean())
    majority_entailment, _ = Counter(entailments).most_common(1)[0]
    return similarity_avg, majority_entailment


def tag_with_baseline(test_root, similarity_avg, majority_entailment):
    """Set the baseline predictions on every direct child of *test_root*.

    Args:
        test_root: Root element of the test XML; modified in place.
        similarity_avg: Similarity value written (as text) to each child.
        majority_entailment: Entailment label written to each child.
    """
    similarity_text = str(similarity_avg)
    for pair in test_root:
        pair.set('similarity', similarity_text)
        pair.set('entailment', majority_entailment)


def main(argv=None):
    """Run the majority baseline from the command line.

    Reads the training XML, computes the average similarity and the
    majority entailment label, tags every pair of the test XML with them
    and writes the result to the output path.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            argparse read ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('train', help='XML file with training data')
    parser.add_argument('test', help='XML file with test data')
    parser.add_argument('output', help='Output tagged XML file')
    args = parser.parse_args(argv)

    train_root = ET().parse(args.train)
    similarity_avg, majority_entailment = compute_majority_baseline(train_root)

    # A separate tree for the test data keeps it obvious which document is
    # being modified and written out.
    test_tree = ET()
    test_root = test_tree.parse(args.test)
    tag_with_baseline(test_root, similarity_avg, majority_entailment)
    test_tree.write(args.output, 'utf-8')


if __name__ == '__main__':
    main()