train.py
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset_utils import PromptTrainDataset
from net.model import PromptIR
from utils.schedulers import LinearWarmupCosineAnnealingLR
from options import options as opt
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint


class PromptIRModel(pl.LightningModule):
    """LightningModule that wraps the PromptIR restoration network and trains it with an L1 loss."""

    def __init__(self):
        super().__init__()
        self.net = PromptIR(decoder=True)
        self.loss_fn = nn.L1Loss()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop; it is independent of forward.
        ([clean_name, de_id], degrad_patch, clean_patch) = batch
        restored = self.net(degrad_patch)
        loss = self.loss_fn(restored, clean_patch)
        # Logged to the configured logger (TensorBoard or Weights & Biases).
        self.log("train_loss", loss)
        return loss

    def lr_scheduler_step(self, scheduler, metric):
        # Step the warmup/cosine scheduler once per epoch using the epoch index.
        scheduler.step(self.current_epoch)

    def configure_optimizers(self):
        # AdamW with 15 epochs of linear warmup followed by cosine annealing up to 150 epochs.
        optimizer = optim.AdamW(self.parameters(), lr=2e-4)
        scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer, warmup_epochs=15, max_epochs=150)
        return [optimizer], [scheduler]


def main():
    print("Options")
    print(opt)

    # Log to Weights & Biases when a project name is given, otherwise to TensorBoard.
    if opt.wblogger is not None:
        logger = WandbLogger(project=opt.wblogger, name="PromptIR-Train")
    else:
        logger = TensorBoardLogger(save_dir="logs/")

    trainset = PromptTrainDataset(opt)
    # Save a checkpoint every epoch and keep all of them (save_top_k=-1).
    checkpoint_callback = ModelCheckpoint(dirpath=opt.ckpt_dir, every_n_epochs=1, save_top_k=-1)
    trainloader = DataLoader(trainset, batch_size=opt.batch_size, pin_memory=True, shuffle=True,
                             drop_last=True, num_workers=opt.num_workers)

    model = PromptIRModel()
    trainer = pl.Trainer(max_epochs=opt.epochs, accelerator="gpu", devices=opt.num_gpus,
                         strategy="ddp_find_unused_parameters_true", logger=logger,
                         callbacks=[checkpoint_callback])
    trainer.fit(model=model, train_dataloaders=trainloader)


if __name__ == '__main__':
    main()
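

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not from the repository): a minimal example of
# how a trained checkpoint could be loaded back for inference. The checkpoint
# filename below is illustrative only; actual names depend on opt.ckpt_dir and
# the ModelCheckpoint defaults.
#
#   import torch
#   from train import PromptIRModel
#
#   model = PromptIRModel.load_from_checkpoint("ckpt/epoch=149.ckpt", map_location="cpu")
#   model.eval()
#   degraded = torch.rand(1, 3, 128, 128)      # stand-in for a degraded RGB patch in [0, 1]
#   with torch.no_grad():
#       restored = model(degraded)             # restored patch, same shape as the input
# ---------------------------------------------------------------------------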