"""CNN for predicting activity of a guide sequence: classification and
regression.
"""
import argparse
from collections import defaultdict
import gzip
import os
import pickle
import fnn
import parse_data
import numpy as np
import scipy.stats
import sklearn
import sklearn.metrics
import sklearn.utils
import tensorflow as tf
__author__ = 'Hayden Metsky <hayden@mit.edu>'
def parse_args():
"""Parse arguments.
Returns:
argument namespace
"""
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--load-model',
help=("Path from which to load parameters and model weights "
"for model found by hyperparameter search; if set, "
"any other arguments provided about the model "
"architecture or hyperparameters will be overridden and "
"this will skip training and only test the model"))
parser.add_argument('--load-model-as-tf-savedmodel',
help=("Path to directory containing a model in TensorFlow's "
"SavedModel architecture; this cannot be set along "
"with --load-model"))
parser.add_argument('--dataset',
choices=['cas13'],
default='cas13',
help=("Dataset to use."))
parser.add_argument('--cas13-subset',
choices=['exp', 'pos', 'neg', 'exp-and-pos'],
help=("Use a subset of the Cas13 data. See parse_data module "
"for descriptions of the subsets. To use all data, do not "
"set."))
parser.add_argument('--cas13-classify',
action='store_true',
help=("If set, only classify Cas13 activity into inactive/active"))
parser.add_argument('--cas13-regress-on-all',
action='store_true',
help=("If set, perform regression for Cas13 data on all data "
"(this can be reduced using --cas13-subset)"))
parser.add_argument('--cas13-regress-only-on-active',
action='store_true',
help=("If set, perform regression for Cas13 data only on the "
"active class"))
parser.add_argument('--cas13-normalize-crrna-activity',
action='store_true',
help=("If set, normalize the activity of each crRNA (guide) "
"across its targets to have mean 0 and stdev 1; this means "
"prediction is performed based on target differences (e.g., "
"mismatches) rather than inherent sequence of the crRNA"))
parser.add_argument('--cas13-use-difference-from-wildtype-activity',
action='store_true',
help=("If set, use the activity value of a guide g and target t "
"pair to be the difference between the measured activity of "
"g-t and the mean activity between g and all wildtype "
"(matching) targets of g; this means prediction is "
"performed based on targeted differences (e.g., mismatches) "
"rather than inherent sequence of the crRNA"))
parser.add_argument('--use-median-measurement',
action='store_true',
help=("If set, use the median measurment across replicates "
"(instead, resample)"))
parser.add_argument('--context-nt',
type=int,
default=10,
help=("nt of target sequence context to include alongside each "
"guide"))
parser.add_argument('--conv-filter-width',
type=int,
nargs='+',
help=("Width of the convolutional filter (nt) (or multiple widths "
"to perform parallel convolutions). If not set, do not "
"use convolutional layers (or the batch norm or pooling "
"that follow it)."))
parser.add_argument('--conv-num-filters',
type=int,
default=20,
help=("Number of convolutional filters (i.e., output channels) "
"in the first layer"))
parser.add_argument('--pool-window-width',
type=int,
default=2,
help=("Width of the pooling window"))
parser.add_argument('--fully-connected-dim',
type=int,
nargs='+',
default=[20],
help=("Dimension of each fully connected layer (i.e., of its "
"output space); specify multiple dimensions for multiple "
"fully connected layers"))
parser.add_argument('--pool-strategy',
choices=['max', 'avg', 'max-and-avg'],
default='max',
help=("For pooling, 'max' only does max pooling; 'avg' only does "
"average pooling; 'max-and-avg' does both and concatenates."))
parser.add_argument('--locally-connected-width',
type=int,
nargs='+',
help=("If set, width (kernel size) of the locally connected layer. "
"Use multiple widths to have parallel locally connected layers "
"that get concatenated. If not set, do not use have a locally "
"connected layer."))
parser.add_argument('--locally-connected-dim',
type=int,
default=1,
help=("Dimension of each locally connected layer (i.e., of its "
"output space)"))
parser.add_argument('--skip-batch-norm',
action='store_true',
help=("If set, skip batch normalization layer"))
parser.add_argument('--add-gc-content',
action='store_true',
help=("If set, add GC content of a guide explicitly into the "
"first fully connected layer of the predictor"))
parser.add_argument('--activation-fn',
choices=['relu', 'elu'],
default='relu',
help=("Activation function to use on hidden layers"))
parser.add_argument('--dropout-rate',
type=float,
default=0.25,
help=("Rate of dropout in between the 2 fully connected layers"))
parser.add_argument('--l2-factor',
type=float,
default=0,
help=("L2 regularization factor. This is applied to weights "
"(kernal_regularizer). Note that this does not regularize "
"bias of activity."))
parser.add_argument('--sample-weight-scaling-factor',
type=float,
default=0,
help=("Hyperparameter p where sample weight is (1 + p*["
"difference in activity from mean wildtype activity]); "
"p must be >= 0. Note that p=0 means that all samples are "
"weighted the same; higher p means that guide-target pairs "
"whose activity deviates from the wildtype from the guide "
"are treated as more important. This is only used for "
"regression."))
parser.add_argument('--batch-size',
type=int,
default=32,
help=("Batch size"))
parser.add_argument('--learning-rate',
type=float,
default=0.00001,
help=("Learning rate for Adam optimizer"))
parser.add_argument('--max-num-epochs',
type=int,
default=1000,
help=("Maximum number of training epochs (this employs early "
"stopping)"))
parser.add_argument('--test-split-frac',
type=float,
default=0.3,
help=("Fraction of the dataset to use for testing the final "
"model"))
parser.add_argument('--seed',
type=int,
default=1,
help=("Random seed"))
parser.add_argument('--serialize-model-with-tf-savedmodel',
help=("Serialize the model with TensorFlow's SavedModel format. "
"This should be a directory in which to serialize the "
"model; this saves the entire model (architecture, "
"weights, training configuration"))
parser.add_argument('--plot-roc-curve',
help=("If set, path to PDF at which to save plot of ROC curve"))
parser.add_argument('--plot-predictions',
help=("If set, path to PDF at which to save plot of predictions "
"vs. true values"))
parser.add_argument('--write-test-tsv',
help=("If set, path to .tsv.gz at which to write test results, "
"including sequences in the test set and predictions "
"(one row per test data point)"))
parser.add_argument('--write-test-confusion-matrices',
help=("If set, path to .tsv.gz at which to write confusion "
"matrices at different thresholds using the test data"))
parser.add_argument('--determine-classifier-threshold-for-precision',
type=float,
default=0.975,
help=("If set, determine thresholds (across folds) that "
"achieve this precision; does not use test data"))
parser.add_argument('--compute-confusion-matrices-across-thresholds-and-folds',
help=("If set, compute confusion matrices at different thresholds "
"(and across folds); does not use test data. The argument "
"should give a path at which this writes a TSV file with "
"the results"))
parser.add_argument('--filter-test-data-by-classification-score',
nargs=2,
help=("If set, only test on data that is classified as active. "
"This consists of 2 arguments: (1) path to TSV file "
"written by test functions with classification scores; "
"(2) score to use as threshold for classifying (>= "
"threshold is active). This is useful when evaluating "
"regression models trained on active data points; we "
"want to test only on data that has been classified as "
"active."))
args = parser.parse_args()
# Print the arguments provided
print(args)
return args
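# An illustrative command-line invocation (hypothetical paths; all flags are
# defined above, and unspecified hyperparameters keep their defaults):
#     python predictor.py --dataset cas13 --cas13-subset exp-and-pos \
#         --cas13-regress-only-on-active --context-nt 10 \
#         --conv-filter-width 2 --conv-num-filters 20 \
#         --write-test-tsv test-results.tsv.gz --seed 1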
def set_seed(seed):
"""Set tensorflow and numpy seed.
Args:
seed: random seed
"""
tf.random.set_seed(seed)
np.random.seed(seed)
def read_data(args, split_frac=None, make_feats_for_baseline=None):
"""Read input/output data.
Args:
args: argument namespace
split_frac: if set, (train, validate, test) fractions (must sum
to 1); if None, use 0.3 for the test set, 0.7*(2/3) for the
train set, and 0.7*(1/3) for the validate set
make_feats_for_baseline: if set, make feature vector for baseline
models; see parse_data module for description of values
Returns:
data parser object from parse_data
"""
if make_feats_for_baseline is not None and args.dataset != 'cas13':
raise Exception("make_feats_for_baseline only works with Cas13 data")
# Read data
if args.dataset == 'cas13':
parser_class = parse_data.Cas13ActivityParser
subset = args.cas13_subset
if args.cas13_classify:
regression = False
else:
regression = True
if split_frac is None:
test_frac = 0.3
train_frac = (1.0 - test_frac) * (2.0/3.0)
validation_frac = (1.0 - test_frac) * (1.0/3.0)
else:
train_frac, validation_frac, test_frac = split_frac
data_parser = parser_class(
subset=subset,
context_nt=args.context_nt,
split=(train_frac, validation_frac, test_frac),
shuffle_seed=args.seed,
stratify_by_pos=True,
use_median_measurement=args.use_median_measurement)
if args.dataset == 'cas13':
classify_activity = args.cas13_classify
regress_on_all = args.cas13_regress_on_all
regress_only_on_active = args.cas13_regress_only_on_active
data_parser.set_activity_mode(
classify_activity, regress_on_all, regress_only_on_active)
if make_feats_for_baseline is not None:
data_parser.set_make_feats_for_baseline(make_feats_for_baseline)
if args.cas13_normalize_crrna_activity:
data_parser.set_normalize_crrna_activity()
if args.cas13_use_difference_from_wildtype_activity:
data_parser.set_use_difference_from_wildtype_activity()
data_parser.read()
x_train, y_train = data_parser.train_set()
x_validate, y_validate = data_parser.validate_set()
x_test, y_test = data_parser.test_set()
# Print the size of each data set
data_sizes = 'DATA SIZES - Train: {}, Validate: {}, Test: {}'
print(data_sizes.format(len(x_train), len(x_validate), len(x_test)))
if regression:
        # Print the mean output and its variance
print('Mean train output: {}'.format(np.mean(y_train)))
print('Variance of train output: {}'.format(np.var(y_train)))
else:
# Print the fraction of the training data points that are in each class
classes = set(tuple(y) for y in y_train)
for c in classes:
num_c = sum(1 for y in y_train if tuple(y) == c)
frac_c = float(num_c) / len(y_train)
frac_c_msg = 'Fraction of train data in class {}: {}'
print(frac_c_msg.format(c, frac_c))
if len(x_validate) == 0:
print('No validation data')
else:
for c in classes:
num_c = sum(1 for y in y_validate if tuple(y) == c)
frac_c = float(num_c) / len(y_validate)
frac_c_msg = 'Fraction of validate data in class {}: {}'
print(frac_c_msg.format(c, frac_c))
for c in classes:
num_c = sum(1 for y in y_test if tuple(y) == c)
frac_c = float(num_c) / len(y_test)
frac_c_msg = 'Fraction of test data in class {}: {}'
print(frac_c_msg.format(c, frac_c))
if args.dataset == 'cas13' and args.cas13_classify:
print('Note that inactive=1 and active=0')
return data_parser
def make_dataset_and_batch(x, y, batch_size=32):
"""Make tensorflow dataset and batch.
Args:
x: input data
y: outputs (labels if classification)
batch_size: batch size
Returns:
batched tf.data.Dataset object
"""
return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)
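# For example, the splits returned by read_data() can each be batched for the
# training and evaluation loops below (illustrative):
#     train_ds = make_dataset_and_batch(x_train, y_train,
#             batch_size=args.batch_size)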
def load_model(load_path, params, x_train, y_train):
"""Construct model and load weights according to hyperparameter search.
Args:
load_path: path containing model weights
params: dict of parameters
x_train, y_train: train data (only needed for data shape and class
weights)
    Returns:
fnn.CasCNNWithParallelFilters object
"""
# First construct the model
model = construct_model(params, x_train.shape,
regression=params['regression'],
y_train=y_train,
compile_for_keras=True)
    # Note: Previously, this would have had to train the model on one data
    # point before loading weights; that is no longer needed with Keras.
    # See https://www.tensorflow.org/beta/guide/keras/saving_and_serializing
    # for details on loading a serialized subclassed model.
    # The old rationale: to initialize variables used by the optimizer and
    # any stateful metric variables, the model had to be trained on some data
    # (one data point, for 1 epoch) before calling `load_weights`; otherwise,
    # there are no variables in the model, and nothing gets loaded.
def copy_weights(model):
# Copy weights, so we can verify that they changed after loading
return [tf.Variable(w) for w in model.weights]
def weights_are_eq(weights1, weights2):
# Determine whether weights1 == weights2
for w1, w2 in zip(weights1, weights2):
# 'w1' and 'w2' are each collections of weights (e.g., the kernel
# for some layer); they are tf.Variable objects (effectively,
# tensors)
# Make a tensor containing element-wise boolean comparisons (it
# is a 1D tensor with True/False)
elwise_eq = tf.equal(w1, w2)
# Check if all elements in 'elwise_eq' are True (this will make a
# Tensor with one element, True or False)
all_are_eq_tensor = tf.reduce_all(elwise_eq)
# Convert the tensor 'all_are_eq_tensor' to a boolean
all_are_eq = all_are_eq_tensor.numpy()
if not all_are_eq:
return False
return True
def load_weights(model, fn):
# Load weights
# There are some concerns about whether weights are actually being
# loaded (e.g., https://github.com/tensorflow/tensorflow/issues/27937),
# so check that they have changed after calling `load_weights`
# Use expect_partial() to silence warnings because this will not
# load optimizer parameters, which are loaded in construct_model()
w_before = copy_weights(model)
w_before2 = copy_weights(model)
model.load_weights(os.path.join(load_path, fn)).expect_partial()
w_after = copy_weights(model)
w_after2 = copy_weights(model)
assert (weights_are_eq(w_before, w_before2) is True)
assert (weights_are_eq(w_before, w_after) is False)
assert (weights_are_eq(w_after, w_after2) is True)
load_weights(model, 'model.weights')
return model
def construct_model(params, shape, regression=False, compile_for_keras=True,
y_train=None, parallelize_over_gpus=False):
"""Construct model.
This uses the fnn module.
This can also compile the model for Keras, to use multiple GPUs if
available.
Args:
params: dict of hyperparameters
shape: shape of input data; only used for printing model summary
regression: if True, perform regression; if False, classification
compile_for_keras: if set, compile for keras
y_train: training data to use for computing class weights; only needed
if compile_for_keras is True and regression is False
parallelize_over_gpus: if True, parallelize over all available GPUs
Returns:
fnn.CasCNNWithParallelFilters object
"""
if not compile_for_keras:
# Just return a model
return fnn.construct_model(params, shape, regression=regression)
def make():
model = fnn.construct_model(params, shape, regression=regression)
# Define an optimizer, loss, metrics, etc.
if model.regression:
# When doing regression, sometimes the output would always be the
# same value regardless of input; decreasing the learning rate fixed this
            optimizer = tf.keras.optimizers.Adam(learning_rate=model.learning_rate)
loss = 'mse'
# Note that using other custom metrics like R^2, Pearson, etc. (as
# implemented above) seems to raise errors; they are really only
# needed during testing
metrics = ['mse', 'mae']
model.class_weight = None
else:
            optimizer = tf.keras.optimizers.Adam(learning_rate=model.learning_rate)
            loss = 'binary_crossentropy'  # using class_weight (below) weights this loss by class
# Note that using other custom metrics like auROC (as implemented
# above) seems to raise errors; they are really only needed during
# testing
metrics = ['bce', 'accuracy']
assert y_train is not None
y_train_labels = [y_train[i][0] for i in range(len(y_train))]
            class_weight = sklearn.utils.class_weight.compute_class_weight(
                class_weight='balanced',
                classes=sorted(np.unique(y_train_labels)),
                y=y_train_labels)
model.class_weight = {i: weight for i, weight in enumerate(class_weight)}
# Compile the model
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
return model
if parallelize_over_gpus:
# Use a MirroredStrategy to take advantage of multiple GPUs, if there are
# multiple
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
model = make()
else:
model = make()
return model
def pred_from_nt(model, pairs):
"""Predict activity from nucleotide sequence.
Args:
model: model object with call() function
pairs: list of tuples (target with context, guide)
Returns:
output of model.call()
"""
FASTA_CODES = {'A': set(('A')),
'T': set(('T')),
'C': set(('C')),
'G': set(('G')),
'K': set(('G', 'T')),
'M': set(('A', 'C')),
'R': set(('A', 'G')),
'Y': set(('C', 'T')),
'S': set(('C', 'G')),
'W': set(('A', 'T')),
'B': set(('C', 'G', 'T')),
'V': set(('A', 'C', 'G')),
'H': set(('A', 'C', 'T')),
'D': set(('A', 'G', 'T')),
'N': set(('A', 'T', 'C', 'G'))}
onehot_idx = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
def onehot(b):
# One-hot encoding of base b
real_bases = FASTA_CODES[b]
v = [0, 0, 0, 0]
for b_real in real_bases:
assert b_real in onehot_idx.keys()
v[onehot_idx[b_real]] = 1.0 / len(real_bases)
return v
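    # For example, onehot('A') gives [1, 0, 0, 0], and ambiguity codes split
    # the mass uniformly: onehot('N') gives [0.25, 0.25, 0.25, 0.25]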
context_nt = model.context_nt
l = 2*context_nt + len(pairs[0][1])
x = np.empty((len(pairs), l, 8), dtype='f')
for i, (target_with_context, guide) in enumerate(pairs):
assert len(target_with_context) == 2*context_nt + len(guide)
# Determine one-hot encodings -- i.e., an input vector
input_vec = []
for pos in range(context_nt):
v_target = onehot(target_with_context[pos])
v_guide = [0, 0, 0, 0]
input_vec += [v_target + v_guide]
for pos in range(len(guide)):
v_target = onehot(target_with_context[context_nt + pos])
v_guide = onehot(guide[pos])
input_vec += [v_target + v_guide]
for pos in range(context_nt):
v_target = onehot(target_with_context[context_nt + len(guide) + pos])
v_guide = [0, 0, 0, 0]
input_vec += [v_target + v_guide]
input_vec = np.array(input_vec, dtype='f')
x[i] = input_vec
pred_activity = model.call(x, training=False)
pred_activity = [p[0] for p in pred_activity.numpy()]
return pred_activity
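# A minimal usage sketch (hypothetical sequences; each target must have
# model.context_nt nt of context flanking each side of the guide-binding
# site, so here context_nt=10 and a 12-nt guide imply a 32-nt target):
#     model = load_model_for_cas13_regression_on_active('path/to/model')
#     pairs = [('CCAACGGATCTCGATACGGATTACGGATACCA', 'TCGATACGGATT')]
#     print(pred_from_nt(model, pairs))
# where load_model_for_cas13_regression_on_active() is defined below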
def load_model_for_cas13_regression_on_active(load_path):
"""Construct model and load parameters and weights.
This wraps load_model(), without the need to specify x_train, etc. for
initializing variables.
Args:
load_path: path containing model weights
    Returns:
fnn.CasCNNWithParallelFilters object
"""
# Load parameters
load_path_params = os.path.join(load_path,
'model.params.pkl')
with open(load_path_params, 'rb') as f:
saved_params = pickle.load(f)
params = {'dataset': 'cas13', 'cas13_subset': 'exp-and-pos',
'cas13_regress_only_on_active': True}
for k, v in saved_params.items():
params[k] = v
# Load data; we only need 1 data point, which is used to initialize
# variables
parser_class = parse_data.Cas13ActivityParser
subset = 'exp-and-pos'
regression = True
test_frac = 0.3
train_frac = (1.0 - test_frac) * (2.0/3.0)
validation_frac = (1.0 - test_frac) * (1.0/3.0)
context_nt = params['context_nt']
data_parser = parser_class(
subset=subset,
context_nt=context_nt,
split=(train_frac, validation_frac, test_frac),
shuffle_seed=1,
stratify_by_pos=True)
data_parser.set_activity_mode(False, False, True)
data_parser.read()
x_train, y_train = data_parser.train_set()
# Load the model
return load_model(load_path, params, x_train, y_train)
def determine_classifier_threshold_for_precision(params, x, y,
num_splits, data_parser, precision_threshold):
"""Find a threshold, via cross-valiation, to achieve a desired precision.
This focuses on precision because it is an important metric for
deploying assays.
It finds the smallest threshold that achieves a desired precision.
It does this across multiple splits of the training data.
Args:
params: model parameters (model should *not* be pre-trained)
x, y: data to perform cross-validation with
num_splits: number of folds to compute threshold
data_parser: object to parse data from parse_data
precision_threshold: desired threshold on precision
Returns:
list of thresholds, one per split
"""
# Construct a function that the test function will callback
best_thresholds = []
    def find_threshold(y_true, y_pred):
        # Compute threshold
        pr_curve = sklearn.metrics.precision_recall_curve(y_true, y_pred)
        precision, recall, thresholds = pr_curve
        # Find the smallest threshold (highest i) where precision is
        # >= precision_threshold; note that precision has one more element
        # than thresholds (its final value, 1.0, has no corresponding
        # threshold), so fall back to the largest threshold if only that
        # final point achieves the desired precision
        thres = float(thresholds[-1])
        for i, prec in enumerate(precision):
            if prec >= precision_threshold:
                if i < len(thresholds):
                    thres = float(thresholds[i])
                break
        best_thresholds.append(thres)
import predictor_hyperparam_search as phs
phs.cross_validate(params, x, y, num_splits, False,
callback=find_threshold, dp=data_parser)
return best_thresholds
def compute_confusion_matrix_at_thresholds_across_splits(params, x, y,
num_splits, data_parser, out_tsv=None):
"""Calculate confusion matrix, across different splits.
It calculates the matrix at each output predicted value from the
classifier (each is a threshold that could change the confusion
matrix).
Args:
params: model parameters (model should *not* be pre-trained)
x, y: data to perform cross-validation with
num_splits: number of folds on which to compute confusion matrices
data_parser: object to parse data from parse_data
out_tsv: if set, write the results to this TSV file
Returns:
list [(i, S_i)] where each S_i represents a split of the data and
S_i is a list [(t, d)] where t gives a threshold and d is
a dict giving the number of true positives, false positives,
true negatives, and false negatives (keys: 'tp', 'fp', 'tn', 'fn')
"""
# Construct a function that the test function will callback
curr_split = 0
results = []
def compute_confusion_matrices(y_true, y_pred):
nonlocal curr_split
nonlocal results
y_true = y_true.flatten()
y_pred = y_pred.flatten()
# Test thresholds of 0, 1, and all predicted values
thresholds_to_test = sorted(list(set([0] + list(y_pred) + [1])))
# Calculate a confusion matrix at every threshold
cms = []
for thres in thresholds_to_test:
y_pred_decided = [int(y >= thres) for y in y_pred]
cm = sklearn.metrics.confusion_matrix(y_true, y_pred_decided,
labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
d = {'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp}
cms += [(thres, d)]
results += [(curr_split, cms)]
curr_split += 1
import predictor_hyperparam_search as phs
phs.cross_validate(params, x, y, num_splits, False,
callback=compute_confusion_matrices, dp=data_parser)
if out_tsv is not None:
header = ['split', 'threshold', 'tn', 'fp', 'fn', 'tp']
with gzip.open(out_tsv, 'wt') as fw:
def write_row(row):
fw.write('\t'.join(str(x) for x in row) + '\n')
write_row(header)
for split, cms in results:
for thres, d in cms:
row = [split, thres, d['tn'], d['fp'], d['fn'], d['tp']]
write_row(row)
return results
def compute_confusion_matrix_at_thresholds(y_true, y_pred, out_tsv=None):
"""Calculate confusion matrix on test data.
It calculates the matrix at each output predicted value from the
classifier (each is a threshold that could change the confusion
matrix).
Args:
y_true/y_pred: arrays of true and predicted values for the test data
out_tsv: if set, write the results to this TSV file
Returns:
list [(t, d)] where t gives a threshold and d is a dict giving
statistics at that threshold
"""
y_true = y_true.flatten()
y_pred = y_pred.flatten()
# Test thresholds of 0, 1, and all predicted values
thresholds_to_test = sorted(list(set([0] + list(y_pred) + [1])))
# Calculate a confusion matrix at every threshold
cms = []
for thres in thresholds_to_test:
y_pred_decided = [int(y >= thres) for y in y_pred]
cm = sklearn.metrics.confusion_matrix(y_true, y_pred_decided,
labels=[0, 1])
tn, fp, fn, tp = cm.ravel()
sensitivity = tp / (tp + fn)
specificity = tn / (tn + fp)
youden = sensitivity + specificity - 1
precision = tp / (tp + fp)
npv = tn / (tn + fn)
accuracy = (tp + tn) / (tp + tn + fp + fn)
f1 = 2*tp / (2*tp + fp + fn)
d = {'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp,
'sensitivity': sensitivity, 'specificity': specificity,
'youden': youden, 'precision': precision,
'npv': npv, 'accuracy': accuracy, 'f1': f1}
cms += [(thres, d)]
if out_tsv is not None:
header = ['threshold', 'tn', 'fp', 'fn', 'tp',
'sensitivity', 'specificity', 'youden', 'precision',
'npv', 'accuracy', 'f1']
with gzip.open(out_tsv, 'wt') as fw:
def write_row(row):
fw.write('\t'.join(str(x) for x in row) + '\n')
write_row(header)
for thres, d in cms:
row = [thres] + [d[h] for h in header[1:]]
write_row(row)
return cms
def filter_test_data_by_classification_score(x_test, y_test,
data_parser, classification_test_tsv, score_threshold):
"""Select test data points that are classified as positive.
This is useful if we wish to evaluate regression that was trained on active
data points. We would first classify the test data, and then only
evaluate regression using the data points that are classified
as active.
Args:
x_test, y_test: test data
data_parser: object to parse data from parse_data
classification_test_tsv: the output TSV file ('write_test_tsv')
written by the test functions
        score_threshold: classification score (between 0 and 1); deem
            all test data with scores >= score_threshold to be
            active/positive
Returns:
list of tuples (x_test, y_test), filtered to only
contain active data points
"""
# Read all rows in classification_test_tsv
header_idx = {}
rows = []
with gzip.open(classification_test_tsv, 'rt') as f:
for i, line in enumerate(f):
ls = line.rstrip().split('\t')
if i == 0:
# Parse header
for j in range(len(ls)):
header_idx[ls[j]] = j
else:
rows += [ls]
rows_new = []
for row in rows:
row_dict = {k: row[header_idx[k]] for k in header_idx.keys()}
rows_new += [row_dict]
rows = rows_new
# Convert x_test, y_test into the same encoding that is
# used in the test TSV; keep just the target and guide
# as strings, and crRNA position, which is enough to
# uniquely identify data points (up to technical replicates)
    # In particular, map such a tuple to indices of the test
    # data
encoding_idx = defaultdict(list)
for i in range(len(x_test)):
m = data_parser.seq_features_from_encoding(x_test[i])
crrna_pos = data_parser.pos_for_input(x_test[i])
enc = (m['target'], m['guide'], crrna_pos)
encoding_idx[enc].append(i)
# Keep all test data that is classified as active
# Note that there are replicates, but all the replicates for a single
# guide-target pair will have the same classification score because the
# classification is deterministic and depends only on the guide-target
# sequence (however their measured activity may differ). So if a
# guide-target pair is classified as active, keep all of its replicates;
# and if it is classified as inactive, discard all replicates
x_test_filtered, y_test_filtered = [], []
added_enc = set()
for row in rows:
enc = (row['target'], row['guide'], int(row['crrna_pos']))
if enc in added_enc:
# Already added data points for this
continue
if float(row['predicted_activity']) >= score_threshold:
# Classify this as active, and add all data points (replicates)
# for this guide-target pair
for i in encoding_idx[enc]:
x_test_filtered.append(x_test[i])
y_test_filtered.append(y_test[i])
added_enc.add(enc)
x_test_filtered = np.array(x_test_filtered)
y_test_filtered = np.array(y_test_filtered)
return x_test_filtered, y_test_filtered
#####################################################################
#####################################################################
# Custom functions for training and testing
#####################################################################
#####################################################################
# For classification, use cross-entropy as the loss function
# This expects sigmoids (values in [0,1]) as the output; it will
# transform back to logits (not bounded between 0 and 1) before
# calling tf.nn.sigmoid_cross_entropy_with_logits
bce_per_sample = tf.keras.losses.BinaryCrossentropy()
# For regression, use mean squared error as the loss function
mse_per_sample = tf.keras.losses.MeanSquaredError()
# When outputting loss, take the mean across the samples from each batch
train_loss_metric = tf.keras.metrics.Mean(name='train_loss')
validate_loss_metric = tf.keras.metrics.Mean(name='validate_loss')
test_loss_metric = tf.keras.metrics.Mean(name='test_loss')
# Define metrics for regression
# tf.keras.metrics does not have Pearson correlation or Spearman's correlation,
# so we have to define these; note that it becomes much easier to use these
# outside of the tf.function functions rather than inside of them (like the
# other metrics are used)
# This also defines a metric for R^2 (below, R2Score)
# Note that R2Score does not necessarily equal r^2 here, where r is
# pearson_corr. R2Score is computed from the definition of R^2 (1 minus
# (residual sum of squares)/(total sum of squares)) on the true vs. predicted
# values. This is why R2Score can be negative: the model can do an even worse
# job at prediction than simply predicting the mean. r is the Pearson
# correlation of y_true and y_pred, equivalent to fitting the simple linear
# regression y_pred = m*y_true + b (so r^2 is nonnegative). R2Score measures
# the goodness-of-fit of the specific line y_pred = y_true, whereas r
# measures the correlation from the fitted regression.
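# As a concrete illustration of the difference: with y_true = [0, 1, 2] and
# y_pred = [4, 3, 2], the points fall on a perfect line with negative slope,
# so r = -1 and r^2 = 1; but R^2 = 1 - (16+4+0)/(1+0+1) = -9, because y_pred
# is far from the specific line y_pred = y_true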
def pearson_corr(y_true, y_pred):
if len(y_true) < 2:
# Avoid exception
r = np.nan
else:
r, _ = scipy.stats.pearsonr(y_true, y_pred)
return r
def spearman_corr(y_true, y_pred):
if len(y_true) < 2:
# Avoid exception
rho = np.nan
else:
rho, _ = scipy.stats.spearmanr(y_true, y_pred)
return rho
class CustomMetric:
def __init__(self, name):
self.__name__ = name
self.y_true = []
self.y_pred = []
def __call__(self, y_true, y_pred):
# Save y_true and y_pred (tensors) into a list
self.y_true += [y_true]
self.y_pred += [y_pred]
def to_np_array(self):
# Concat tensors and convert to numpy arrays
y_true_np = tf.reshape(tf.concat(self.y_true, 0), [-1]).numpy()
y_pred_np = tf.reshape(tf.concat(self.y_pred, 0), [-1]).numpy()
return y_true_np, y_pred_np
def result(self):
raise NotImplementedError("result() must be implemented in a subclass")
def reset_states(self):
self.y_true = []
self.y_pred = []
class Correlation(CustomMetric):
def __init__(self, corrtype, name='correlation'):
assert corrtype in ('pearson_corr', 'spearman_corr')
if corrtype == 'pearson_corr':
self.corr_fn = pearson_corr
if corrtype == 'spearman_corr':
self.corr_fn = spearman_corr
super().__init__(name)
def result(self):
y_true_np, y_pred_np = super(Correlation, self).to_np_array()
return self.corr_fn(y_true_np, y_pred_np)
class R2Score(CustomMetric):
def __init__(self, name='r2_score'):
super().__init__(name)
def result(self):
y_true_np, y_pred_np = super(R2Score, self).to_np_array()
return sklearn.metrics.r2_score(y_true_np, y_pred_np)
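# These custom metrics accumulate batches and are computed once per epoch,
# e.g. (illustrative values):
#     m = Correlation('spearman_corr')
#     m(tf.constant([[0.1], [0.9]]), tf.constant([[0.2], [0.8]]))  # per batch
#     rho = m.result()    # concatenates saved batches, then computes rho
#     m.reset_states()    # clear accumulated values between epochs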
train_mse_metric = tf.keras.metrics.MeanSquaredError(name='train_mse')
train_mse_weighted_metric = tf.keras.metrics.MeanSquaredError(name='train_mse_weighted')
train_mae_metric = tf.keras.metrics.MeanAbsoluteError(name='train_mae')
train_mape_metric = tf.keras.metrics.MeanAbsolutePercentageError(name='train_mape')
train_r2_score_metric = R2Score(name='train_r2_score')
train_pearson_corr_metric = Correlation('pearson_corr', name='train_pearson_corr')
train_spearman_corr_metric = Correlation('spearman_corr', name='train_spearman_corr')
validate_mse_metric = tf.keras.metrics.MeanSquaredError(name='validate_mse')
validate_mse_weighted_metric = tf.keras.metrics.MeanSquaredError(name='validate_mse_weighted')
validate_mae_metric = tf.keras.metrics.MeanAbsoluteError(name='validate_mae')
validate_mape_metric = tf.keras.metrics.MeanAbsolutePercentageError(name='validate_mape')
validate_r2_score_metric = R2Score(name='validate_r2_score')
validate_pearson_corr_metric = Correlation('pearson_corr', name='validate_pearson_corr')
validate_spearman_corr_metric = Correlation('spearman_corr', name='validate_spearman_corr')
test_mse_metric = tf.keras.metrics.MeanSquaredError(name='test_mse')
test_mse_weighted_metric = tf.keras.metrics.MeanSquaredError(name='test_mse_weighted')
test_mae_metric = tf.keras.metrics.MeanAbsoluteError(name='test_mae')
test_mape_metric = tf.keras.metrics.MeanAbsolutePercentageError(name='test_mape')
test_r2_score_metric = R2Score(name='test_r2_score')
test_pearson_corr_metric = Correlation('pearson_corr', name='test_pearson_corr')
test_spearman_corr_metric = Correlation('spearman_corr', name='test_spearman_corr')
# Define metrics for classification
# Report on the accuracy and AUC for each epoch (each metric is updated
# with data from each batch, and computed using data from all batches)
train_bce_metric = tf.keras.metrics.BinaryCrossentropy(name='train_bce')
train_bce_weighted_metric = tf.keras.metrics.BinaryCrossentropy(name='train_bce_weighted')
train_accuracy_metric = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
train_auc_roc_metric = tf.keras.metrics.AUC(
num_thresholds=200, curve='ROC', name='train_auc_roc')
train_auc_pr_metric = tf.keras.metrics.AUC(
num_thresholds=200, curve='PR', name='train_auc_pr')
validate_bce_metric = tf.keras.metrics.BinaryCrossentropy(name='validate_bce')
validate_bce_weighted_metric = tf.keras.metrics.BinaryCrossentropy(name='validate_bce_weighted')
validate_accuracy_metric = tf.keras.metrics.BinaryAccuracy(name='validate_accuracy')
validate_auc_roc_metric = tf.keras.metrics.AUC(
num_thresholds=200, curve='ROC', name='validate_auc_roc')
validate_auc_pr_metric = tf.keras.metrics.AUC(
num_thresholds=200, curve='PR', name='validate_auc_pr')
test_bce_metric = tf.keras.metrics.BinaryCrossentropy(name='test_bce')
test_bce_weighted_metric = tf.keras.metrics.BinaryCrossentropy(name='test_bce_weighted')
test_accuracy_metric = tf.keras.metrics.BinaryAccuracy(name='test_accuracy')
test_auc_roc_metric = tf.keras.metrics.AUC(
num_thresholds=200, curve='ROC', name='test_auc_roc')
test_auc_pr_metric = tf.keras.metrics.AUC(
num_thresholds=200, curve='PR', name='test_auc_pr')
# Store the model and optimizer as global (module-wide) variables
# If passing them directly to train_step(), validate_step(), and test_step(),
# TensorFlow complains about having to do tf.function retracing, which is
# expensive and due to passing Python objects instead of tensors
_model = None
_optimizer = None
# Train the model using GradientTape; this is called on each batch
def train_step(seqs, outputs, sample_weight=None):
if _model.regression:
loss_fn = mse_per_sample
else:
loss_fn = bce_per_sample
with tf.GradientTape() as tape:
# Compute predictions and loss
# Pass along `training=True` so that this can be given to
# the batchnorm and dropout layers; an alternative to passing
# it along would be to use `tf.keras.backend.set_learning_phase(1)`
# to set the training phase
predictions = _model(seqs, training=True)
prediction_loss = loss_fn(outputs, predictions,
sample_weight=sample_weight)
# Add the regularization losses
regularization_loss = tf.add_n(_model.losses)
loss = prediction_loss + regularization_loss
    # Compute gradients and optimize parameters
gradients = tape.gradient(loss, _model.trainable_variables)