forked from HawkAaron/RNN-Transducer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
80 lines (71 loc) · 2.94 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import argparse
import logging
import math
import os
import time
import editdistance
import kaldi_io
import mxnet as mx
import numpy as np
from model import Transducer, RNNModel
from DataLoader import TokenAcc, rephone
# Command-line interface of the decoding script.
parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Acoustic Model on TIMIT.')
parser.add_argument('model', help='trained model filename')
# A beam width of 0 (the default) makes decode() fall back to greedy decoding.
parser.add_argument('--beam', type=int, default=0, help='beam width; 0 selects greedy decoding')
parser.add_argument('--ctc', action='store_true', help='decode CTC acoustic model')
parser.add_argument('--bi', action='store_true', help='bidirectional LSTM')
parser.add_argument('--dataset', default='test', help='name of the data/<dataset> directory to decode')
# --out is handed to logging.basicConfig as a file name, so it is described as a
# file here rather than as a directory.
parser.add_argument('--out', type=str, default='', help='decode log file (default: decode.log next to the model)')
parser.add_argument('--cuda', action='store_true', help='use gpu')
args = parser.parse_args()
# Log file: --out when given, otherwise decode.log in the model's directory.
# os.path.join keeps the default relative when the model path has no directory
# part; plain string concatenation produced '/decode.log' (the filesystem root)
# in that case.
logdir = args.out if args.out else os.path.join(os.path.dirname(args.model), 'decode.log')
logging.basicConfig(format='%(asctime)s: %(message)s', datefmt="%H:%M:%S", filename=logdir, level=logging.INFO)
# Device used for the model parameters and the input features.
context = mx.gpu(0) if args.cuda else mx.cpu(0)
# Load model
# --ctc selects the RNNModel class (CTC acoustic model); otherwise the Transducer is used.
if args.ctc:
    Model = RNNModel
else:
    Model = Transducer
# NOTE(review): the meaning of the positional arguments 62, 250, 3 is defined by
# the constructors in model.py, which is not visible here — confirm before changing.
model = Model(62, 250, 3, bidirectional=args.bi)
# Restore the trained parameters from the file given on the command line.
model.collect_params().load(args.model, context)
# data set
# Kaldi read specifier (a shell pipeline ending in '|') that decode() hands to
# kaldi_io.read_mat_ark. Stages: copy-feats -> apply-cmvn -> add-deltas ->
# nnet-forward. feats.scp, utt2spk and cmvn.scp are taken from data/<dataset>/;
# the feature transform is read from data/final.feature_transform.
# NOTE: the trailing backslash continues the string literal itself, so any
# leading whitespace on the second line would become part of the command.
feat = 'ark:copy-feats scp:data/{}/feats.scp ark:- | apply-cmvn --utt2spk=ark:data/{}/utt2spk scp:data/{}/cmvn.scp ark:- ark:- |\
add-deltas --delta-order=2 ark:- ark:- | nnet-forward data/final.feature_transform ark:- ark:- |'.format(args.dataset, args.dataset, args.dataset)
# Reference transcripts: first field of each line -> list of the remaining fields.
# decode() looks entries up by the utterance key returned from the feature reader.
label = {}
with open('data/'+args.dataset+'/text', 'r') as f:
    for line in f:
        fields = line.split()
        # Skip blank lines; indexing fields[0] on them would raise IndexError.
        if not fields:
            continue
        label[fields[0]] = fields[1:]
# Phone map: first column -> third column of conf/phones.60-48-39.map.
# Entries with fewer than three columns are mapped to rephone[0], the symbol
# that distance() discards as blank, so they do not contribute to scoring.
pmap = {rephone[0]: rephone[0]}
with open('conf/phones.60-48-39.map', 'r') as f:
    for line in f:
        cols = line.split()
        # Skip blank lines; indexing cols[0] on them would raise IndexError.
        if not cols:
            continue
        pmap[cols[0]] = cols[2] if len(cols) >= 3 else rephone[0]
print(pmap)
def distance(y, t, blank=rephone[0]):
    """Normalise a hypothesis/reference pair and compute their edit distance.

    In both sequences a symbol is dropped when it equals `blank` or when it
    equals the symbol immediately before it in the input, so repeats are
    merged unless a blank separates them.

    Args:
        y: hypothesis symbol sequence.
        t: reference symbol sequence.
        blank: symbol to discard; defaults to ``rephone[0]``.

    Returns:
        Tuple ``(y, t, dist)``: the two normalised lists and the value of
        ``editdistance.eval`` computed on them.
    """
    def remap(symbols, blank):
        # Pair every symbol with its predecessor (the first one is paired with
        # `blank`) and keep it only if it is neither blank nor a repeat.
        symbols = list(symbols)
        return [cur for prev, cur in zip([blank] + symbols, symbols)
                if cur != blank and cur != prev]

    hyp = remap(y, blank)
    ref = remap(t, blank)
    return hyp, ref, editdistance.eval(hyp, ref)
def decode():
    """Decode every utterance of the selected data set and log the phone error rate.

    Feature matrices are read from the Kaldi pipeline in ``feat``. Each one is
    decoded with beam search when ``--beam`` is positive, otherwise greedily.
    Hypothesis and reference are mapped through ``pmap`` and scored with
    ``distance``. The reference, the hypothesis and the value returned by the
    decoder are logged per utterance, followed by the overall PER.

    Returns:
        None. All results go to the logging output.
    """
    # TODO separate decode and score
    # Name the model that is actually being decoded; the messages used to say
    # Transducer/Transduction even when --ctc selected the CTC model.
    ctc = args.ctc
    logging.info('Decoding {} model:'.format('CTC' if ctc else 'Transduction'))
    err = cnt = 0
    for k, v in kaldi_io.read_mat_ark(feat):
        # v[None, ...] adds a leading axis of size 1 (presumably the batch
        # dimension expected by the model — confirm against model.py).
        xs = mx.nd.array(v[None, ...]).as_in_context(context)
        if args.beam > 0:
            y, nll = model.beam_search(xs, args.beam)
        else:
            y, nll = model.greedy_decode(xs)
        # Decoder output indexes rephone; both sides are mapped through pmap.
        y = [pmap[rephone[idx]] for idx in y]
        t = [pmap[phone] for phone in label[k]]
        y, t, e = distance(y, t)
        err += e
        cnt += len(t)
        logging.info('[{}]: {}'.format(k, ' '.join(t)))
        logging.info('[{}]: {}\nlog-likelihood: {:.2f}\n'.format(k, ' '.join(y), nll))
    if cnt == 0:
        # No reference phones were scored (e.g. empty feature archive); avoid
        # dividing by zero and report the situation instead.
        logging.warning('{} set: no reference phones scored, PER undefined\n'.format(args.dataset.capitalize()))
        return
    logging.info('{} set {} PER {:.2f}%\n'.format(
        args.dataset.capitalize(), 'CTC' if ctc else 'Transducer', 100.0 * err / cnt))
# Script entry point: run decoding and scoring with the arguments parsed above.
decode()