-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
92 lines (78 loc) · 2.46 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import sys
import random
import datetime
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from config import *
from train import train
from metrics import Estimator
from data import generate_dataset
from modules import generate_model
from utils import print_config, select_out_features
def _prepare_save_path(save_path):
    """Make sure the checkpoint directory *save_path* exists.

    If it already exists, ask the user on stdin whether to reuse (overwrite)
    it; any answer other than y/yes terminates the process with status 0.
    Otherwise the directory (and any missing parents) is created.
    """
    if os.path.exists(save_path):
        answer = input('Save path {} exists.\nDo you want to overwrite it? (y/n)\n'.format(save_path))
        # Be lenient about case and stray whitespace; anything else aborts.
        if answer.strip().lower() not in ('y', 'yes'):
            sys.exit(0)
    else:
        # exist_ok avoids a crash if the directory appears between the check and here.
        os.makedirs(save_path, exist_ok=True)


def _create_logger(record_root):
    """Return a SummaryWriter logging to ``<record_root>/log-YYYYmmdd-HHMMSS``."""
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    return SummaryWriter(os.path.join(record_root, 'log-' + timestamp))


def main():
    """Run one training session driven by the module-level config dicts.

    Reads BASIC_CONFIG, DATA_CONFIG, TRAIN_CONFIG and NET_CONFIG (imported from
    ``config``), prepares the save directory, opens a TensorBoard writer, prints
    the configuration, seeds the RNGs, builds the model/discriminator and the
    datasets, and finally delegates to ``train``.
    """
    # create folder
    save_path = BASIC_CONFIG['save_path']
    _prepare_save_path(save_path)

    # create logger
    logger = _create_logger(BASIC_CONFIG['record_path'])
    # Always close the writer so buffered TensorBoard events reach disk,
    # even when training raises or is interrupted.
    try:
        # print configuration
        print_config({
            'BASIC CONFIG': BASIC_CONFIG,
            'DATA CONFIG': DATA_CONFIG,
            'TRAIN CONFIG': TRAIN_CONFIG
        })

        # reproducibility
        set_random_seed(BASIC_CONFIG['random_seed'])

        # build model
        net_name = BASIC_CONFIG['network']
        # NOTE(review): `backbone` and `out_features` are computed but never passed
        # to generate_model (it receives NET_CONFIG['args'] instead). Kept because
        # the lookups may be the only fail-fast check on `network` / `criterion`
        # — confirm whether generate_model should receive them.
        backbone = NET_CONFIG[net_name]
        device = BASIC_CONFIG['device']
        criterion = TRAIN_CONFIG['criterion']
        num_classes = BASIC_CONFIG['num_classes']
        out_features = select_out_features(num_classes, criterion)
        model, discriminator = generate_model(
            device,
            BASIC_CONFIG['pretrained'],
            BASIC_CONFIG['checkpoint'],
            NET_CONFIG['args']
        )

        # create dataset
        train_dataset, val_dataset = generate_dataset(
            DATA_CONFIG,
            BASIC_CONFIG['data_path'],
            BASIC_CONFIG['data_index'],
            TRAIN_CONFIG['batch_size'],
            TRAIN_CONFIG['num_workers']
        )

        # create estimator and then train
        estimator = Estimator(criterion, num_classes, device)
        train(
            model=model,
            discriminator=discriminator,
            train_config=TRAIN_CONFIG,
            data_config=DATA_CONFIG,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            save_path=save_path,
            estimator=estimator,
            device=device,
            logger=logger
        )
    finally:
        logger.close()
def set_random_seed(seed: int) -> None:
    """Seed every RNG used during training so that runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices). It also tells cuDNN to use deterministic kernels and not
    to auto-tune them.

    Args:
        seed: integer seed passed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU (cuda.manual_seed seeds only the
    # current one); it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode selects kernels by timing them, which can vary between runs
    # and defeats the deterministic setting above.
    torch.backends.cudnn.benchmark = False
if __name__ == "__main__":
    # Start training only when executed as a script; importing this module
    # has no side effects beyond its definitions.
    main()