# dataprocess.py (forked from HanbaekLyu/RNN_NMF_chatbot)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
def printLines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()
    for line in lines[:n]:
        print(line)
#printLines(r"C:\Users\wxwyl\Desktop\wyl code\cornell-movie\cornell movie-dialogs corpus\movie_lines.txt")
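# Illustrative note (not from the original file): printLines opens the corpus
# in binary mode, so each printed line is a bytes object in the raw
# " +++$+++ "-separated format, e.g. (values are illustrative):
#   b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'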
# Splits each line of the file into a dictionary of fields
def loadLines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            lineObj = {}
            for i, field in enumerate(fields):
                lineObj[field] = values[i]
            lines[lineObj['lineID']] = lineObj
    return lines
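# Illustrative sketch (values are assumptions, not from the original file):
# given a raw corpus line such as
#   "L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n"
# loadLines(path, MOVIE_LINES_FIELDS) produces an entry like
#   lines["L1045"] == {"lineID": "L1045", "characterID": "u0", "movieID": "m0",
#                      "character": "BIANCA", "text": "They do not!\n"}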
# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            utterance_id_pattern = re.compile('L[0-9]+')
            lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)
    return conversations
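# Illustrative sketch (values are assumptions): a record in
# movie_conversations.txt such as
#   "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196']\n"
# becomes a dict whose "lines" field holds the full line dicts, in order:
#   {"character1ID": "u0", "character2ID": "u2", "movieID": "m0",
#    "utteranceIDs": "['L194', 'L195', 'L196']\n",
#    "lines": [lines["L194"], lines["L195"], lines["L196"]]}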
# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
    for conversation in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
            targetLine = conversation["lines"][i+1]["text"].strip()
            # Filter out bad samples (where one of the lines is empty)
            if inputLine and targetLine:
                qa_pairs.append([inputLine, targetLine])
    return qa_pairs
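# Illustrative sketch (assumption): a conversation whose lines are
# ["Hi.", "Hello.", "How are you?"] yields two input/target pairs:
#   [["Hi.", "Hello."], ["Hello.", "How are you?"]]
# i.e. each utterance is paired with the utterance that answers it.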
# Define path to new file
datafile = r"C:\Users\wxwyl\Desktop\wylcode\cornell-movie-seq2seqta-new\cornell movie-dialogs corpus\formatted_movie_lines.txt"
datafile_validation = r"C:\Users\wxwyl\Desktop\wylcode\cornell-movie-seq2seqta-new\cornell movie-dialogs corpus\formatted_movie_lines_validation.txt"
delimiter = '\t'
# Unescape the delimiter (a no-op for a literal tab; it matters when the
# delimiter is supplied in escaped form such as "\\t")
delimiter = str(codecs.decode(delimiter, "unicode_escape"))
# Initialize lines dict, conversations list, and field ids
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
if __name__ == '__main__':
    # Load lines and process conversations
    print("\nProcessing corpus...")
    lines = loadLines(r"C:\Users\wxwyl\Desktop\wylcode\cornell-movie-seq2seqta-new\cornell movie-dialogs corpus\movie_lines.txt", MOVIE_LINES_FIELDS)
    print("\nLoading conversations...")
    conversations = loadConversations(r"C:\Users\wxwyl\Desktop\wylcode\cornell-movie-seq2seqta-new\cornell movie-dialogs corpus\movie_conversations.txt",
                                      lines, MOVIE_CONVERSATIONS_FIELDS)
    # Write two new tab-separated files: the first 60000 sentence pairs go to
    # the validation file, the remaining pairs to the main data file
    number = 1
    print("\nWriting newly formatted files...")
    with open(datafile_validation, 'w', encoding='utf-8') as outputfile:
        writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
        for pair in extractSentencePairs(conversations):
            if number <= 60000:
                writer.writerow(pair)
            number += 1
    number = 1
    with open(datafile, 'w', encoding='utf-8') as outputfile:
        writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
        for pair in extractSentencePairs(conversations):
            if number > 60000:
                writer.writerow(pair)
            number += 1
# Print a sample of lines
#print("\nSample lines from file:")
#printLines(datafile)
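# A minimal verification sketch (not in the original script; previewPairs and
# its behavior are assumptions). It reads a few question/answer pairs back from
# the tab-separated output to confirm the formatting. Uncomment to use.
#
# def previewPairs(path, n=5):
#     with open(path, encoding='utf-8') as f:
#         reader = csv.reader(f, delimiter='\t')
#         for i, row in enumerate(reader):
#             if i >= n:
#                 break
#             print(row[0], '->', row[1])
#
# previewPairs(datafile)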