"""
Main entry point. Uses Hydra to load config files and override defaults with command-line arguments.
"""

import logging
import os
import random
import time
import warnings
from pathlib import Path

import hydra
import numpy as np
import torch
import wandb
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf

from sdofm import utils  # import days_hours_mins_secs_str
from sdofm.utils import flatten_dict

# import torch_xla.debug.profiler as xp

wandb_logger = None
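
# Example invocation (illustrative only; exact override keys depend on the config
# files under ../experiments, but `log_level`, `experiment.seed` and
# `experiment.task` are read directly by this script):
#
#   python main.py experiment.task=pretrain experiment.seed=42 log_level=DEBUG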


# loads the config file
@hydra.main(config_path="../experiments", config_name="default")
def main(cfg: DictConfig) -> None:
    match cfg.log_level:
        case "DEBUG":
            logging.basicConfig(level=logging.DEBUG)
        case _:
            logging.basicConfig(level=logging.INFO)

    # set seed
    torch.manual_seed(cfg.experiment.seed)
    np.random.seed(cfg.experiment.seed)
    random.seed(cfg.experiment.seed)
    seed_everything(cfg.experiment.seed)
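
    # Note: Lightning's seed_everything also seeds Python's random module, NumPy and
    # PyTorch, so the three explicit calls above are redundant but harmless.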

    # set up the profiler named in the config, if any
    profiler = None
    match cfg.experiment.profiler:
        case "XLAProfiler":
            from lightning.pytorch.profilers import XLAProfiler

            profiler = XLAProfiler(port=9014)
        case "PyTorchProfiler":
            from lightning.pytorch.profilers import PyTorchProfiler

            profiler = PyTorchProfiler(
                on_trace_ready=torch.profiler.tensorboard_trace_handler("./log/sdofm"),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
            )
        case "Profiler":
            from lightning.pytorch.profilers import Profiler

            profiler = Profiler()
        case None:
            profiler = None
        case _:
            raise NotImplementedError(
                f"Profiler {cfg.experiment.profiler} is not implemented."
            )
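
    # When PyTorchProfiler is selected, traces are written by the TensorBoard handler
    # above to ./log/sdofm and can be viewed in TensorBoard (torch-tb-profiler plugin).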

    # set precision of torch tensors
    match cfg.experiment.precision:
        case 64:
            torch.set_default_tensor_type(torch.DoubleTensor)
        case 32:
            torch.set_default_tensor_type(torch.FloatTensor)
        case _:
            warnings.warn(
                f"Precision {cfg.experiment.precision} will be passed through to the trainer but will not affect other operations."
            )
            # raise NotImplementedError(
            #     f"Precision {cfg.experiment.precision} not implemented"
            # )
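
    # Note: torch.set_default_tensor_type is deprecated in recent PyTorch releases;
    # torch.set_default_dtype(torch.float64) / torch.set_default_dtype(torch.float32)
    # covers the two cases above.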

    # run experiment
    print("\nRunning with config:")
    print(OmegaConf.to_yaml(cfg, resolve=False, sort_keys=False))
    print("\n")

    print(f"Using device: {cfg.experiment.accelerator}")

    # set up wandb logging
    if cfg.experiment.wandb.enable:
        wandb.login()
        output_dir = Path(cfg.experiment.wandb.output_directory)
        output_dir.mkdir(exist_ok=True, parents=True)
        print(
            f"Created directory for storing results: {cfg.experiment.wandb.output_directory}"
        )
        cache_dir = Path(f"{cfg.experiment.wandb.output_directory}/.cache")
        cache_dir.mkdir(exist_ok=True, parents=True)
        os.environ["WANDB_CACHE_DIR"] = (
            f"{cfg.experiment.wandb.output_directory}/.cache"
        )
        os.environ["WANDB_MODE"] = (
            "offline" if not cfg.experiment.wandb.enable else "online"
        )

        resume = "never"
        run_id = None
        if cfg.experiment.resuming:
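            # The checkpoint reference is assumed to look like "model-<run_id>:v<N>",
            # the naming W&B uses for model artifacts logged by WandbLogger; the run
            # id is recovered from it so the existing run can be resumed.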
            resume = "allow"
            run_id = cfg.experiment.checkpoint.split(":")[0].split("-")[-1]
            print("Will attempt to resume W&B run", run_id)

        logger = WandbLogger(
            # WandbLogger params
            name=cfg.experiment.name,
            project=cfg.experiment.project,
            dir=cfg.experiment.wandb.output_directory,
            log_model=cfg.experiment.wandb.log_model,
            # kwargs for wandb.init
            tags=cfg.experiment.wandb.tags,
            notes=cfg.experiment.wandb.notes,
            group=cfg.experiment.wandb.group,
            save_code=True,
            job_type=cfg.experiment.wandb.job_type,
            config=flatten_dict(cfg),
            resume=resume,
            id=run_id,
        )
    else:
        logger = None

    # set up checkpointing to gcp bucket
    if cfg.experiment.gcp_storage:
        from google.cloud import storage

        client = storage.Client()  # project='myproject'
        bucket = client.get_bucket(cfg.experiment.gcp_storage.bucket)
        if not bucket.exists():
            raise ValueError(
                "Cannot find the provided Google Cloud Storage bucket; is your machine authenticated?"
            )
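
        # Note: storage.Client() authenticates via Application Default Credentials,
        # e.g. a GOOGLE_APPLICATION_CREDENTIALS service-account key or
        # `gcloud auth application-default login`.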

    match cfg.experiment.task:
        case "pretrain":
            from scripts.pretrain import Pretrainer

            pretrainer = Pretrainer(cfg, logger=logger, profiler=profiler)
            pretrainer.run()
        case "finetune":
            from scripts.finetune import Finetuner

            finetuner = Finetuner(cfg, logger=logger, profiler=profiler)
            finetuner.run()
        case "ablation":
            from scripts.ablation import Ablation

            ablation = Ablation(cfg, logger=logger, profiler=profiler)
            ablation.run()
        case _:
            raise NotImplementedError(
                f"Experiment {cfg.experiment.task} not implemented"
            )
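
    # Pretrainer, Finetuner and Ablation are imported lazily inside each case, which
    # keeps the other tasks' dependencies out of runs that do not need them.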


if __name__ == "__main__":
    # server = xp.start_server(9012)
    time_start = time.time()

    # errors
    os.environ["HYDRA_FULL_ERROR"] = "1"  # Produce a complete stack trace
    main()

    print(
        "\nTotal duration: {}".format(
            utils.days_hours_mins_secs_str(time.time() - time_start)
        )
    )