#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Example Usage:
    python collect_contexts.py \
        --corpus_file resources/data/Topical-Chat/KGD/train.json \
        --outfile resources/data/Topical-Chat/KGD/contexts/train_questions.txt --extract q

    python collect_contexts.py \
        --corpus_file resources/data/Topical-Chat/KGD/train.json \
        --outfile resources/data/Topical-Chat/KGD/contexts/train_exclamations.txt --extract e
"""
import argparse
from pathlib import Path
import re

import spacy
from datasets import load_dataset
from tqdm import tqdm


def set_args():
    ap = argparse.ArgumentParser()
    ap.add_argument('--corpus_file', type=str, default='resources/data/Topical-Chat/KGD/train.json', help='path to the corpus json file')
    ap.add_argument('--outfile', type=str, required=True, help='output file name')
    ap.add_argument('--max_contexts', type=int, default=None, help='max number of contexts to write')
    ap.add_argument('--extract', type=str, default='q', choices=['q', 'e'], help='q for questions, e for exclamations')
    return ap.parse_args()

def load_corpus(corpus_file):
    if 'topical-chat' in corpus_file.lower():
        extension = corpus_file.split(".")[-1]
        dataset_dict = load_dataset(extension, data_files=corpus_file)
        corpus_sents = dataset_dict['train']['target']
    elif 'commonsense-dialogues' in corpus_file.lower():
        dataset_dict = load_dataset('json', data_files=corpus_file)
        corpus_sents = dataset_dict['train']['target']
        # note: for commonsense, we also consider sentences from the context
        for turns in dataset_dict['train']['turns']:
            corpus_sents.extend(turn for turn in turns if turn != '')
    elif 'daily-dialog' in corpus_file.lower():
        dataset_dict = load_dataset('json', data_files=corpus_file)
        corpus_sents = dataset_dict['train']['target']
    else:
        raise ValueError(f'Unrecognised corpus file: {corpus_file}')
    print(f'Corpus sentences: {len(corpus_sents)}')
    return corpus_sents
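
# Note on the input format (inferred from load_corpus above, not from a spec):
# each JSON record is expected to provide a 'target' string, and the
# Commonsense-Dialogues files additionally a 'turns' list of utterances, e.g.
# roughly {"turns": ["Hey, how was the game?", "It was great!"],
#          "target": "Did you watch it too?"}. Any further fields are ignored.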

def clean(string):
    """remove speaker tags"""
    return re.sub(r'<speaker\d>\s?', '', string).strip()
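
# Illustrative example: clean('<speaker1> Are you a fan of football? ')
# returns 'Are you a fan of football?'.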

def extract_questions(corpus, nlp, sentence_length_threshold=6):
    """extract questions from corpus"""
    questions = set()
    for doc in tqdm(nlp.pipe(corpus, batch_size=100, n_process=8), total=len(corpus)):
        for sent in doc.sents:
            if len(sent) >= sentence_length_threshold and sent.text.strip().endswith('?'):
                questions.add(clean(sent.text))
    print(f'Found {len(questions)} questions in corpus')
    return questions
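
# Illustrative example: for a target like 'I love that team. Did you watch the
# game last night?', only the second sentence is kept -- it ends with '?' and
# has at least 6 spaCy tokens (len(sent) counts tokens, punctuation included).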

def extract_exclamations(corpus, nlp, sentence_length_threshold=6):
    """extract exclamations from corpus"""
    exclamations = set()
    for doc in tqdm(nlp.pipe(corpus, batch_size=100, n_process=8), total=len(corpus)):
        for sent in doc.sents:
            if len(sent) >= sentence_length_threshold and sent.text.strip().endswith('!'):
                exclamations.add(clean(sent.text))
    print(f'Found {len(exclamations)} exclamations in corpus')
    return exclamations

def write_to_outfile(iterable, outfile, total=None):
    Path(outfile).parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with open(outfile, 'w', encoding='utf8') as f:
        for i, item in enumerate(iterable):
            if total is not None and i >= total:
                break
            f.write(item + '\n')
            count += 1
    print(f'Wrote {count} items to {outfile}')
    return
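
# Note: questions/exclamations are accumulated in a Python set, so the order in
# which they are written here is not deterministic across runs; sort the output
# file afterwards if a fixed order is needed.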

def main(args):
    # Load dataset
    corpus = load_corpus(args.corpus_file)

    # Load spacy model
    nlp = spacy.load('en_core_web_sm')
    nlp.add_pipe("sentencizer")

    if args.extract == 'q':  # Extract questions
        questions = extract_questions(corpus, nlp)
        write_to_outfile(questions, args.outfile, args.max_contexts)
    elif args.extract == 'e':  # Extract exclamations
        exclamations = extract_exclamations(corpus, nlp)
        write_to_outfile(exclamations, args.outfile, args.max_contexts)


if __name__ == '__main__':
    args = set_args()
    main(args)
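
# Example invocation with a cap on the number of extracted contexts (paths as
# in the module docstring, the cap value is only illustrative):
#   python collect_contexts.py \
#       --corpus_file resources/data/Topical-Chat/KGD/train.json \
#       --outfile resources/data/Topical-Chat/KGD/contexts/train_questions.txt \
#       --extract q --max_contexts 1000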