train.py
import logging
import os

import hydra
import pytorch_lightning as pl
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader

from src.dataloader.lm_utils import get_list_id_obj_from_split_name
from src.utils.dataloader import concat_dataloader
from src.utils.weight import load_checkpoint

# make runs reproducible
pl.seed_everything(2022)

# set logging level
logging.basicConfig(level=logging.INFO)

@hydra.main(version_base=None, config_path="configs", config_name="train")
def train(cfg: DictConfig):
    OmegaConf.set_struct(cfg, False)
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    output_path = hydra_cfg["runtime"]["output_dir"]
    os.makedirs(cfg.callback.checkpoint.dirpath, exist_ok=True)
    logging.info(
        f"Training script. The outputs of hydra will be stored in: {output_path}"
    )
    logging.info(f"Checkpoints will be stored in: {cfg.callback.checkpoint.dirpath}")

    # Delayed import so that argument parsing stays fast
    from hydra.utils import instantiate

    logging.info("Initializing logger, callbacks and trainer")
    os.environ["WANDB_API_KEY"] = cfg.user.wandb_api_key
    if cfg.machine.dryrun:
        os.environ["WANDB_MODE"] = "offline"
    logging.info(f"Wandb logger initialized at {cfg.save_dir}")

    # On SLURM, read GPU/node counts from the environment and override the trainer config
    if cfg.machine.name == "slurm":
        num_gpus = int(os.environ["SLURM_GPUS_ON_NODE"])
        num_nodes = int(os.environ["SLURM_NNODES"])
        cfg.machine.trainer.devices = num_gpus
        cfg.machine.trainer.num_nodes = num_nodes
        logging.info(f"Slurm config: {num_gpus} gpus, {num_nodes} nodes")
    trainer = instantiate(cfg.machine.trainer)
    logging.info("Trainer initialized")

    model = instantiate(cfg.model)
    logging.info(f"Model '{cfg.model.modelname}' loaded")
    # Optionally initialize the backbone from a pretrained checkpoint
    if cfg.model.pretrained_weight is not None:
        load_checkpoint(
            model.backbone,
            cfg.model.pretrained_weight,
            prefix="",
            checkpoint_key="model",
        )

    # Build one validation dataloader per training dataset ("hope" has no validation split)
    val_dataloaders = {}
    for data_name in cfg.train_datasets:
        if data_name == "hope":
            continue
        config_dataloader = cfg.data[data_name].dataloader
        # Keep only the single train*/val* split directory found under root_dir
        splits = [
            split
            for split in os.listdir(config_dataloader.root_dir)
            if os.path.isdir(os.path.join(config_dataloader.root_dir, split))
        ]
        splits = [
            split
            for split in splits
            if split.startswith("train") or split.startswith("val")
        ]
        assert len(splits) == 1, f"Found {splits} splits for {data_name}"
        split = splits[0]
        config_dataloader.reset_metaData = True
        config_dataloader.split = split
        config_dataloader.isTesting = True
        val_dataloader = DataLoader(
            instantiate(config_dataloader),
            batch_size=cfg.machine.batch_size,
            num_workers=cfg.machine.num_workers,
            shuffle=False,  # no shuffling, so visualized validation samples stay deterministic
        )
        val_dataloaders[data_name] = val_dataloader
        logging.info(
            f"Loading validation dataloader with {data_name}, size {len(val_dataloader)} done!"
        )
    val_dataloaders = concat_dataloader(val_dataloaders)

    # Build the training dataloaders, with augmentation options taken from the config
    train_dataloaders = {}
    for data_name in cfg.train_datasets:
        config_dataloader = cfg.data[data_name].dataloader
        splits = [
            split
            for split in os.listdir(config_dataloader.root_dir)
            if os.path.isdir(os.path.join(config_dataloader.root_dir, split))
        ]
        splits = [
            split
            for split in splits
            if split.startswith("train") or split.startswith("val")
        ]
        assert len(splits) == 1, f"Found {splits} train splits for {data_name}"
        split = splits[0]
        config_dataloader.split = split
        config_dataloader.reset_metaData = False
        config_dataloader.isTesting = False
        config_dataloader.use_augmentation = cfg.use_augmentation
        config_dataloader.use_random_rotation = cfg.use_random_rotation
        config_dataloader.use_random_scale_translation = (
            cfg.use_random_scale_translation
        )
        config_dataloader.use_additional_negative_samples_for_training = (
            cfg.use_additional_negative_samples_for_training
        )
        train_dataloader = DataLoader(
            instantiate(config_dataloader),
            batch_size=cfg.machine.batch_size,
            num_workers=cfg.machine.num_workers,
            shuffle=True,
        )
        logging.info(
            f"Loading train dataloader with {data_name}, size {len(train_dataloader)} done!"
        )
        logging.info("---" * 100)
        train_dataloaders[data_name] = train_dataloader
    train_dataloaders = concat_dataloader(train_dataloaders)

    logging.info(
        f"Fitting the model: train_size={len(train_dataloaders)}, val_size={len(val_dataloaders)}"
    )
    trainer.fit(
        model,
        train_dataloaders=train_dataloaders,
        val_dataloaders=val_dataloaders,
    )
    logging.info("Fitting done")


if __name__ == "__main__":
    train()
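
# Usage sketch (assuming a standard Hydra layout with configs/train.yaml; the exact
# config groups and values below, e.g. "machine=local", are assumptions and may
# differ in this repository):
#   python train.py machine=local machine.batch_size=16 use_augmentation=true
# Hydra creates a run-specific output directory, and the checkpoint callback writes
# weights to cfg.callback.checkpoint.dirpath.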