# main_training.py: dMaSIF training script (from the casperg92/MaSIF_colab fork)
# Standard imports:
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
from torch_geometric.data import DataLoader
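# (Note: on PyTorch Geometric >= 2.0 this import moved to torch_geometric.loader.DataLoader.)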
from torch_geometric.transforms import Compose
from pathlib import Path
# Custom data loader and model:
from data import ProteinPairsSurfaces, PairData, CenterPairAtoms
from data import RandomRotationPairAtoms, NormalizeChemFeatures, iface_valid_filter
from model import dMaSIF
from data_iteration import iterate, iterate_surface_precompute
from helper import *
from Arguments import parser
# Parse the arguments, prepare the TensorBoard writer:
args = parser.parse_args()
writer = SummaryWriter("runs/{}".format(args.experiment_name))
model_path = "models/" + args.experiment_name
if not Path("models/").exists():
    Path("models/").mkdir(exist_ok=False)
# Ensure reproducibility:
torch.backends.cudnn.deterministic = True
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)
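# cudnn.deterministic = True forces deterministic convolution kernels, which can
# be slower; together with the fixed seeds it makes runs repeatable on a given setup.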
# Create the model, with a warm restart if applicable:
net = dMaSIF(args)
net = net.to(args.device)
# We load the train and test datasets.
# Random transforms, to ensure that no network/baseline overfits on pose parameters:
transformations = (
    Compose([NormalizeChemFeatures(), CenterPairAtoms(), RandomRotationPairAtoms()])
    if args.random_rotation
    else Compose([NormalizeChemFeatures()])
)
# PyTorch geometric expects an explicit list of "batched variables":
batch_vars = ["xyz_p1", "xyz_p2", "atom_coords_p1", "atom_coords_p2"]
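# For each attribute listed here, the loader also builds an "<attr>_batch" vector
# mapping every point/atom to its sample index, so the two proteins' point clouds
# can still be told apart after batching.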
# Load the train dataset:
train_dataset = ProteinPairsSurfaces(
    "surface_data", ppi=args.search, train=True, transform=transformations
)
train_dataset = [data for data in train_dataset if iface_valid_filter(data)]
train_loader = DataLoader(
    train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
print("Preprocessing training dataset")
train_dataset = iterate_surface_precompute(train_loader, net, args)
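# The precompute pass runs the surface generation once and replaces the raw
# dataset with the processed pairs, so the (expensive) surfaces are not rebuilt
# at every epoch.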
# Train/Validation split:
train_nsamples = len(train_dataset)
val_nsamples = int(train_nsamples * args.validation_fraction)
train_nsamples = train_nsamples - val_nsamples
train_dataset, val_dataset = random_split(
    train_dataset, [train_nsamples, val_nsamples]
)
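# random_split draws from torch's global RNG, so the split above is fixed by the
# torch.manual_seed(args.seed) call and comes out the same on every run.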
# Load the test dataset:
test_dataset = ProteinPairsSurfaces(
    "surface_data", ppi=args.search, train=False, transform=transformations
)
test_dataset = [data for data in test_dataset if iface_valid_filter(data)]
test_loader = DataLoader(
    test_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
print("Preprocessing testing dataset")
test_dataset = iterate_surface_precompute(test_loader, net, args)
# PyTorch Geometric data loaders:
train_loader = DataLoader(
    train_dataset, batch_size=1, follow_batch=batch_vars, shuffle=True
)
val_loader = DataLoader(val_dataset, batch_size=1, follow_batch=batch_vars)
test_loader = DataLoader(test_dataset, batch_size=1, follow_batch=batch_vars)
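# batch_size=1: each PairData sample is already a full protein pair, so we
# process one pair per step; shuffle defaults to False for the validation and
# test loaders, which keeps their order fixed across epochs.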
# Baseline optimizer:
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, amsgrad=True)
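# lr=3e-4 is kept fixed (no scheduler is used); amsgrad=True selects the AMSGrad
# variant of Adam, which tracks the running maximum of the second-moment estimate.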
best_loss = 1e10 # We save the "best model so far"
starting_epoch = 0
if args.restart_training != "":
    # Load the checkpoint onto the active device:
    checkpoint = torch.load("models/" + args.restart_training, map_location=args.device)
    net.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # The saved epoch has already been completed, so resume from the next one:
    starting_epoch = checkpoint["epoch"] + 1
    best_loss = checkpoint["best_loss"]
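# Checkpoints are written below as models/<experiment_name>_epoch<i>, so a warm
# restart would look like (hypothetical run name):
#   python main_training.py --experiment_name my_run --restart_training my_run_epoch42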
# Main training loop: one Train/Validation/Test pass per epoch, for args.n_epochs epochs:
for i in range(starting_epoch, args.n_epochs):
    # Train first, then evaluate on the validation and test sets:
    for dataset_type in ["Train", "Validation", "Test"]:
        test = dataset_type != "Train"  # Gradients are only computed on "Train"
        suffix = dataset_type
        if dataset_type == "Train":
            dataloader = train_loader
        elif dataset_type == "Validation":
            dataloader = val_loader
        elif dataset_type == "Test":
            dataloader = test_loader

        # Perform one pass through the data:
        info = iterate(
            net,
            dataloader,
            optimizer,
            args,
            test=test,
            summary_writer=writer,
            epoch_number=i,
        )

        # Write down the results using a TensorBoard writer:
        for key, val in info.items():
            if key in [
                "Loss",
                "ROC-AUC",
                "Distance/Positives",
                "Distance/Negatives",
                "Matching ROC-AUC",
            ]:
                writer.add_scalar(f"{key}/{suffix}", np.mean(val), i)

            if "R_values/" in key:
                val = np.array(val)
                # Only average over the positive R values:
                writer.add_scalar(f"{key}/{suffix}", np.mean(val[val > 0]), i)

        if dataset_type == "Validation":  # Store the validation loss for model saving
            val_loss = np.mean(info["Loss"])

    # Save a checkpoint whenever the validation loss improves:
    if val_loss < best_loss:
        print("Validation loss {}, saving model".format(val_loss))
        torch.save(
            {
                "epoch": i,
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "best_loss": val_loss,  # Record the loss of the model being saved
            },
            model_path + "_epoch{}".format(i),
        )
        best_loss = val_loss