-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathscript_train.py
75 lines (62 loc) · 2.02 KB
/
script_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import cnn
import datetime
from datetime import timedelta
import fire
import kymatio
import os
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
import time
import torch
import sys
def run(loss_type, sav_dir, job_id):
    """Train ``cnn.EffNet`` on ``cnn.ChirpTextureDataModule`` and report timing.

    Intended as the command-line entry point via ``fire.Fire(run)``.

    Args:
        loss_type: Loss identifier forwarded to ``cnn.EffNet``; also used to
            name the ``models_<loss_type>`` and ``logs_<loss_type>``
            subdirectories of ``sav_dir``.
        sav_dir: Root output directory. ``<sav_dir>/models_<loss_type>`` is
            passed to ``cnn.EffNet`` (presumably where checkpoints go —
            confirm in ``cnn``); TensorBoard logs are written under
            ``<sav_dir>/logs_<loss_type>``.
        job_id: Job identifier; only printed in the header.
    """
    # Keep full float precision so the reported elapsed time is not
    # inflated by truncating the start timestamp.
    start_time = time.time()

    # Print header.
    print(f"{datetime.datetime.now()} Start.")
    # __doc__ is the module docstring; it is None when the module has none,
    # and concatenating None with a str would raise TypeError.
    if __doc__:
        print(__doc__ + "\n")
    # f-strings rather than "+" so non-str values (e.g. a job_id or
    # loss_type that Fire parsed as a number) cannot raise TypeError.
    print(f"Loss type: {loss_type}")
    print(f"Save directory: {sav_dir}")
    print(f"Job ID: {job_id}")
    print("\n")

    # Print version numbers of the main dependencies for reproducibility.
    for module in (kymatio, torch, pl):
        print(f"{module.__name__} version: {module.__version__}")
    print("\n")
    sys.stdout.flush()

    # Dataset configuration.
    n_densities = 32
    n_slopes = 32
    n_folds = 8
    batch_size = 64
    dataset = cnn.ChirpTextureDataModule(
        n_densities=n_densities,
        n_slopes=n_slopes,
        n_folds=n_folds,
        batch_size=batch_size,
    )
    dataset.setup()

    # NOTE(review): training is limited to a single batch per epoch. A full
    # epoch of 768 samples would be ``768 // dataset.batch_size`` batches —
    # use integer division, because Lightning interprets a float
    # ``limit_train_batches`` as a *fraction* of the epoch. Confirm that the
    # 1-batch limit is intentional.
    steps_per_epoch = 1

    models_dir = os.path.join(sav_dir, f"models_{loss_type}")
    logs_dir = os.path.join(sav_dir, f"logs_{loss_type}")
    model = cnn.EffNet(loss_type, models_dir, steps_per_epoch)
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=logs_dir)

    # max_epochs=-1 disables the epoch limit; training stops on max_time.
    trainer = pl.Trainer(
        max_epochs=-1,
        limit_train_batches=steps_per_epoch,
        callbacks=[],
        logger=tb_logger,
        max_time=timedelta(hours=12),
    )
    trainer.fit(model, dataset)

    # Print elapsed time as HH:MM:SS.ss.
    print(f"{datetime.datetime.now()} Success.")
    elapsed_time = time.time() - start_time
    elapsed_hours = int(elapsed_time / (60 * 60))
    elapsed_minutes = int((elapsed_time % (60 * 60)) / 60)
    elapsed_seconds = elapsed_time % 60.0
    elapsed_str = (
        f"{elapsed_hours:02d}:{elapsed_minutes:02d}:{elapsed_seconds:05.2f}"
    )
    print("Total elapsed time: " + elapsed_str + ".")
if __name__ == "__main__":
    # Expose ``run`` as the command-line entry point; Fire maps CLI
    # arguments onto its parameters (loss_type, sav_dir, job_id).
    fire.Fire(run)