-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
457 lines (418 loc) · 24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
import os
import argparse
import time
import logging
import datetime
from copy import deepcopy
import cv2
import numpy as np
from tqdm import tqdm
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.distributed as dist
import torch.utils.data
from torch.backends import cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
from fastai.vision import *
from Dino.utils.utils import Config, Logger, MyConcatDataset
from Dino.utils.util import Averager
from Dino.dataset.datasetsupervised_kmeans import ImageDatasetSelfSupervisedKmeans
from Dino.dataset.dataset import collate_fn_filter_none
from Dino.modules.vision_transformer import DINOHead
from Dino.modules import vision_transformer as vits
from Dino.modules.segmentor import SegHead
from Dino.model.dino_vision import ABIDINOModel
from Dino.modules import utils
from Dino.loss.Dino_loss import DINOLoss
from torchvision import models as torchvision_models
import warnings
# Silence every Python warning for the whole process.
warnings.filterwarnings("ignore")

# Default compute device: CUDA when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sorted names of the lower-case, non-dunder callables exported by
# torchvision.models (presumably the model-builder functions).
torchvision_archs = sorted(
    [
        arch_name
        for arch_name, arch_obj in torchvision_models.__dict__.items()
        if arch_name.islower() and not arch_name.startswith("__") and callable(arch_obj)
    ]
)
def _build_backbones(config):
    """Instantiate fresh student and teacher backbones for ``config.arch``.

    ``config.arch`` is normalised in place ("deit" -> "vit"). For ViT and XCiT
    architectures only the student receives ``config.drop_path_rate``.

    Returns:
        tuple: ``(student, teacher, embed_dim)``.

    Raises:
        ValueError: if the architecture is neither in ``vits``, in the XCiT
            torch.hub listing, nor a torchvision model.
    """
    # we changed the name DeiT-S for ViT-S to avoid confusions
    config.arch = config.arch.replace("deit", "vit")
    # Vision Transformer (e.g. vit_tiny, vit_small, vit_base)
    if config.arch in vits.__dict__.keys():
        student = vits.__dict__[config.arch](
            patch_size=config.patch_size,
            drop_path_rate=config.drop_path_rate,  # stochastic depth
        )
        teacher = vits.__dict__[config.arch](patch_size=config.patch_size)
        return student, teacher, student.embed_dim
    # XCiT (torch.hub.list typically needs to reach GitHub)
    if config.arch in torch.hub.list("facebookresearch/xcit:main"):
        student = torch.hub.load('facebookresearch/xcit:main', config.arch,
                                 pretrained=False, drop_path_rate=config.drop_path_rate)
        teacher = torch.hub.load('facebookresearch/xcit:main', config.arch, pretrained=False)
        return student, teacher, student.embed_dim
    # torchvision models
    if config.arch in torchvision_models.__dict__.keys():
        student = torchvision_models.__dict__[config.arch]()
        teacher = torchvision_models.__dict__[config.arch]()
        return student, teacher, student.fc.weight.shape[1]
    # Previously this only printed a message and train() then died with a
    # NameError on the undefined `student`; fail fast with a clear error instead.
    raise ValueError(f"Unknown architecture: {config.arch}")


def _save_checkpoint_and_log(config, metric_logger, student, teacher, optimizer,
                             dino_loss, fp16_scaler, epoch, iteration):
    """Sync meter stats across ranks, save checkpoints and append stats to log.txt.

    ``checkpoint.pth`` is saved via ``utils.save_on_master`` every call, plus
    ``checkpoint{epoch:04}.pth`` when ``config.saveckp_freq`` divides ``epoch``.
    The averaged stats are appended as one JSON line to
    ``<output_dir>/<global_name>/log.txt`` on the main process only.
    """
    import json
    from pathlib import Path

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    run_dir = os.path.join(config.output_dir, config.global_name)
    save_dict = {
        'student': student.state_dict(),
        'teacher': teacher.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'iteration': iteration,
        'dino_loss': dino_loss.state_dict(),
    }
    if fp16_scaler is not None:
        save_dict['fp16_scaler'] = fp16_scaler.state_dict()
    utils.save_on_master(save_dict, os.path.join(run_dir, 'checkpoint.pth'))
    if config.saveckp_freq and epoch % config.saveckp_freq == 0:
        utils.save_on_master(save_dict, os.path.join(run_dir, f'checkpoint{epoch:04}.pth'))
    log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}
    if utils.is_main_process():
        with (Path(config.output_dir) / f"{config.global_name}/log.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")


def _ema_update_teacher(student, teacher_without_ddp, momentum):
    """EMA-update the teacher's backbone and head from the DDP-wrapped student.

    ``teacher = momentum * teacher + (1 - momentum) * student``, parameter-wise.
    """
    with torch.no_grad():
        for part in ("backbone", "head"):
            student_params = getattr(student.module, part).parameters()
            teacher_params = getattr(teacher_without_ddp, part).parameters()
            for param_q, param_k in zip(student_params, teacher_params):
                param_k.data.mul_(momentum).add_((1 - momentum) * param_q.detach().data)


def train(config):
    """Run DINO-style self-distillation training of an ABIDINOModel student/teacher pair.

    The student is optimised on ``DINOLoss``. The teacher tracks an EMA of the
    student's ``backbone`` and ``head``; the student's SegHead has no teacher
    counterpart. Training runs ``config.training_epochs`` passes over the
    dataloader. Checkpoints and stats are written whenever the "epoch" counter,
    measured in units of ``config.imgnet_based`` images, advances.

    Args:
        config: project ``Config``. It is mutated in place (``iter_num``,
            ``arch``, ``epochs``) and must carry a TensorBoard ``writer``.

    Side effects:
        Initialises distributed mode and writes ``checkpoint*.pth`` / ``log.txt``
        under ``<output_dir>/<global_name>``. Sets the module global
        ``global_epoch``. Calls ``sys.exit(1)`` if the loss becomes non-finite.

    Raises:
        ValueError: on an unknown ``config.arch`` or ``config.optimizer``.
    """
    import math

    # ============ parameter configuration ============
    utils.init_distributed_mode(config)
    utils.fix_random_seeds(config.seed)
    cudnn.benchmark = True

    # ============ dataset preparation ============
    logging.info('Construct dataset.')
    train_dataloader = _get_databaunch(config)
    iters_per_epoch = len(train_dataloader)
    config.iter_num = iters_per_epoch
    logging.info(f"each epoch iteration: {config.iter_num}")

    # ============ building student and teacher networks ... ============
    student, teacher, embed_dim = _build_backbones(config)
    # multi-crop wrapper handles forward with inputs of different resolutions;
    # only the student carries the segmentation head.
    student = ABIDINOModel(
        student,
        SegHead(in_channels=config.model_seg_channel, mla_channels=128, mlahead_channels=64, num_classes=2),
        DINOHead(embed_dim, config.out_dim, use_bn=config.use_bn_in_head,
                 norm_last_layer=config.norm_last_layer),
    )
    teacher = ABIDINOModel(
        teacher,
        None,
        DINOHead(embed_dim, config.out_dim, config.use_bn_in_head),
    )
    # move networks to gpu
    student, teacher = student.cuda(), teacher.cuda()
    # synchronize batch norms (if any)
    if utils.has_batchnorms(student):
        student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
        teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
        # we need DDP wrapper to have synchro batch norms working...
        teacher = nn.parallel.DistributedDataParallel(teacher, device_ids=[config.gpu])
        teacher_without_ddp = teacher.module
    else:
        # teacher_without_ddp and teacher are the same thing
        teacher_without_ddp = teacher
    student = nn.parallel.DistributedDataParallel(student, device_ids=[config.gpu], find_unused_parameters=True)
    # teacher and student start with the same backbone/head weights
    teacher_without_ddp.backbone.load_state_dict(student.module.backbone.state_dict())
    teacher_without_ddp.head.load_state_dict(student.module.head.state_dict())
    # there is no backpropagation through the teacher, so no need for gradients
    for p in teacher.parameters():
        p.requires_grad = False
    print(f"Student and Teacher are built: they are both {config.arch} network.")

    # ============ preparing loss ... ============
    global_batch = config.batch_size_per_gpu * utils.get_world_size()
    total_iters = config.training_epochs * iters_per_epoch
    # DINOLoss is scheduled in "epochs" of config.imgnet_based images.
    config.epochs = int(total_iters * global_batch / config.imgnet_based) + 1
    print(f'training epochs is {config.epochs}')
    dino_loss = DINOLoss(
        config.out_dim,
        config.crops_number,
        config.warmup_teacher_temp,
        config.teacher_temp,
        config.warmup_teacher_temp_epochs,
        config.epochs,
    ).cuda()

    # ============ preparing optimizer ... ============
    params_groups = utils.get_params_groups(student)
    if config.optimizer == "adamw":
        optimizer = torch.optim.AdamW(params_groups)  # to use with ViTs
    elif config.optimizer == "sgd":
        optimizer = torch.optim.SGD(params_groups, lr=0, momentum=0.9)  # lr is set by scheduler
    elif config.optimizer == "lars":
        optimizer = utils.LARS(params_groups)  # to use with convnet and large batches
    else:
        raise ValueError(f"Unknown optimizer: {config.optimizer}")
    # for mixed precision training
    fp16_scaler = torch.cuda.amp.GradScaler() if config.use_fp16 else None

    # ============ init schedulers (indexed by iteration) ... ============
    lr_schedule = utils.cosine_iter_scheduler(
        config.lr * global_batch / 256.,  # linear scaling rule
        config.min_lr,
        total_iters,
        warmup_iters=int((config.warmup_epoch * config.imgnet_based) / global_batch),
    )
    wd_schedule = utils.cosine_iter_scheduler(
        config.weight_decay,
        config.weight_decay_end,
        total_iters,
    )
    # momentum parameter is increased to 1. during training with a cosine schedule
    momentum_schedule = utils.cosine_iter_scheduler(config.momentum_teacher, 1, total_iters)
    print("Loss, optimizer and schedulers ready.")

    # ============ optionally resume training ... ============
    to_restore = {"epoch": 0, 'iteration': 0}
    utils.restart_from_checkpoint(
        os.path.join(config.output_dir, config.global_name, "checkpoint.pth"),
        run_variables=to_restore,
        student=student,
        teacher=teacher,
        optimizer=optimizer,
        fp16_scaler=fp16_scaler,
        dino_loss=dino_loss,
    )
    iteration = int(to_restore["iteration"])
    epoch = to_restore["epoch"]
    print(f'continue to train:{iteration}:{epoch}')

    start_time = time.time()
    global global_epoch
    global_epoch = 0
    # Resume at the dataloader pass that contains `iteration`. Restarting at pass 0
    # (as before) kept incrementing `iteration` past the schedules' end.
    start_epoch = iteration // iters_per_epoch if iters_per_epoch else 0
    print("Starting DINO training !")
    for train_epoch in range(start_epoch, config.training_epochs):
        if iteration >= total_iters:
            break
        train_dataloader.sampler.set_epoch(train_epoch)
        metric_logger = utils.MetricLogger(delimiter=" ")
        header = 'Epoch: [{}/{}]'.format(train_epoch, config.training_epochs)
        for (image_tensors, masks, metrics) in metric_logger.log_every(train_dataloader, 10, header):
            # Schedules are assumed to hold `total_iters` entries (TODO confirm against
            # utils.cosine_iter_scheduler); stop before indexing past them.
            if iteration >= total_iters:
                break
            epoch = int((iteration + 1) * global_batch / config.imgnet_based)
            # examine epoch updating state
            if epoch != global_epoch:
                global_epoch = epoch
                _save_checkpoint_and_log(config, metric_logger, student, teacher, optimizer,
                                         dino_loss, fp16_scaler, epoch, iteration)
                # NOTE(review): log_every(...) above stays bound to the previous logger,
                # so its periodic console output stops reflecting new updates after this.
                metric_logger = utils.MetricLogger(delimiter=" ")

            image_tensors = image_tensors.cuda(non_blocking=True)
            masks = masks.cuda(non_blocking=True)
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_schedule[iteration]
                if i == 0:  # only the first group is regularized
                    param_group["weight_decay"] = wd_schedule[iteration]

            # teacher and student forward passes + compute dino loss
            with torch.cuda.amp.autocast(fp16_scaler is not None):
                metrics = metrics.float()
                student_output = student(image_tensors, metrics, masks, epoch, clusters=None)
                teacher_output = teacher(image_tensors, metrics, None, None,
                                         clusters=student_output['zero'], index=student_output['index'])
                # Warp the masks with the first two rows of `metrics` as affine parameters
                # (assumes metrics is (B, >=2, 3) affine data — TODO confirm).
                affine_grid = F.affine_grid(metrics[:, :2, :],
                                            size=(masks.shape[0], 1, masks.shape[1], masks.shape[2]))
                masks_image = F.grid_sample(masks.unsqueeze(1), affine_grid.to(masks.device))
                masks_image = (masks_image > 0.1).float().squeeze()
                student_output['gt'] = [masks, masks_image]
                loss = dino_loss(student_output, teacher_output, epoch)

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                # `force=` relies on the print override presumably installed by
                # utils.init_distributed_mode — TODO confirm.
                print("Loss is {}, stopping training".format(loss_value), force=True)
                sys.exit(1)

            # student update
            optimizer.zero_grad()
            if fp16_scaler is None:
                loss.backward()
                if config.clip_grad:
                    utils.clip_gradients(student, config.clip_grad)
                utils.cancel_gradients_last_layer(epoch, student, config.freeze_last_layer)
                optimizer.step()
            else:
                fp16_scaler.scale(loss).backward()
                if config.clip_grad:
                    fp16_scaler.unscale_(optimizer)  # unscale grads in-place before clipping
                    utils.clip_gradients(student, config.clip_grad)
                utils.cancel_gradients_last_layer(epoch, student, config.freeze_last_layer)
                fp16_scaler.step(optimizer)
                fp16_scaler.update()

            # EMA update for the teacher
            _ema_update_teacher(student, teacher_without_ddp, momentum_schedule[iteration])

            # logging
            torch.cuda.synchronize()
            metric_logger.update(loss=loss_value)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])
            if iteration % config.training_show_iters == 0:
                for loss_name, loss_term in dino_loss.last_losses.items():
                    config.writer.add_scalar(tag='metric/' + loss_name,
                                             scalar_value=loss_term.data.cpu().numpy(),
                                             global_step=iteration)
                config.writer.add_scalar(tag='metric/' + 'lr',
                                         scalar_value=optimizer.param_groups[0]["lr"],
                                         global_step=iteration)
                config.writer.add_scalar(tag='metric/' + 'wd',
                                         scalar_value=optimizer.param_groups[0]["weight_decay"],
                                         global_step=iteration)
            iteration += 1

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
def _parse_arguments():
    """Parse the command line and load the config file it names.

    Only ``-c/--config`` affects the result. Every other option is validated by
    argparse, but its value is discarded, so training settings must come from
    the config file.

    Returns:
        Config: the configuration loaded from the ``--config`` path.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add('-c', '--config', type=str, required=True, help='path to config file')

    # Model parameters
    add('--arch', default='vit_small', type=str,
        help=("Name of architecture to train. For quick experiments with ViTs,\n"
              "we recommend using vit_tiny or vit_small."))
    add('--patch_size', default=4, type=int,
        help=("Size in pixels\n"
              "of input square patches - default 16 (for 16x16 patches). Using smaller\n"
              "values leads to better performance but requires more memory. Applies only\n"
              "for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling\n"
              "mixed precision training (--use_fp16 false) to avoid unstabilities."))
    add('--out_dim', default=65536, type=int,
        help=("Dimensionality of\n"
              "the DINO head output. For complex and large datasets large values (like 65k) work well."))
    add('--norm_last_layer', default=True, type=utils.bool_flag,
        help=("Whether or not to weight normalize the last layer of the DINO head.\n"
              "Not normalizing leads to better performance but can make the training unstable.\n"
              "In our experiments, we typically set this paramater to False with vit_small and True with vit_base."))
    add('--momentum_teacher', default=0.996, type=float,
        help=("Base EMA\n"
              "parameter for teacher update. The value is increased to 1 during training with cosine schedule.\n"
              "We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256."))
    add('--use_bn_in_head', default=False, type=utils.bool_flag,
        help="Whether to use batch normalizations in projection head (Default: False)")

    # Temperature teacher parameters
    add('--warmup_teacher_temp', default=0.04, type=float,
        help=("Initial value for the teacher temperature: 0.04 works well in most cases.\n"
              "Try decreasing it if the training loss does not decrease."))
    add('--teacher_temp', default=0.04, type=float,
        help=("Final value (after linear warmup)\n"
              "of the teacher temperature. For most experiments, anything above 0.07 is unstable. We recommend\n"
              "starting with the default value of 0.04 and increase this slightly if needed."))
    add('--warmup_teacher_temp_epochs', default=0, type=int,
        help='Number of warmup epochs for the teacher temperature (Default: 30).')

    # Training/Optimization parameters
    add('--use_fp16', type=utils.bool_flag, default=True,
        help=("Whether or not\n"
              "to use half precision for training. Improves training time and memory requirements,\n"
              "but can provoke instability and slight decay of performance. We recommend disabling\n"
              "mixed precision if the loss is unstable, if reducing the patch size or if training with bigger ViTs."))
    add('--weight_decay', type=float, default=0.04,
        help=("Initial value of the\n"
              "weight decay. With ViT, a smaller value at the beginning of training works well."))
    add('--weight_decay_end', type=float, default=0.4,
        help=("Final value of the\n"
              "weight decay. We use a cosine schedule for WD and using a larger decay by\n"
              "the end of training improves performance for ViTs."))
    add('--clip_grad', type=float, default=3.0,
        help=("Maximal parameter\n"
              "gradient norm if using gradient clipping. Clipping with norm .3 ~ 1.0 can\n"
              "help optimization for larger ViT architectures. 0 for disabling."))
    add('--batch_size_per_gpu', default=64, type=int,
        help='Per-GPU batch-size : number of distinct images loaded on one GPU.')
    add('--epochs', default=100, type=int, help='Number of epochs of training.')
    add('--freeze_last_layer', default=1, type=int,
        help=("Number of epochs\n"
              "during which we keep the output layer fixed. Typically doing so during\n"
              "the first epoch helps training. Try increasing this value if the loss does not decrease."))
    add("--lr", default=0.0005, type=float,
        help=("Learning rate at the end of\n"
              "linear warmup (highest LR used during training). The learning rate is linearly scaled\n"
              "with the batch size, and specified here for a reference batch size of 256."))
    add("--warmup_epochs", default=10, type=int,
        help="Number of epochs for the linear learning-rate warm up.")
    add('--min_lr', type=float, default=1e-6,
        help=("Target LR at the\n"
              "end of optimization. We use a cosine LR schedule with linear warmup."))
    add('--optimizer', default='adamw', type=str, choices=['adamw', 'sgd', 'lars'],
        help="Type of optimizer. We recommend using adamw with ViTs.")
    add('--drop_path_rate', type=float, default=0.1, help="stochastic depth rate")

    # Multi-crop parameters
    add('--global_crops_scale', type=float, nargs='+', default=(0.4, 1.),
        help=("Scale range of the cropped image before resizing, relatively to the origin image.\n"
              "Used for large global view cropping. When disabling multi-crop (--local_crops_number 0), we\n"
              'recommand using a wider range of scale ("--global_crops_scale 0.14 1." for example)'))
    add('--local_crops_number', type=int, default=8,
        help=("Number of small\n"
              "local views to generate. Set this parameter to 0 to disable multi-crop training.\n"
              'When disabling multi-crop we recommend to use "--global_crops_scale 0.14 1." '))
    add('--local_crops_scale', type=float, nargs='+', default=(0.05, 0.4),
        help=("Scale range of the cropped image before resizing, relatively to the origin image.\n"
              "Used for small local view cropping of multi-crop."))

    # Misc
    add('--data_path', default='/path/to/imagenet/train/', type=str,
        help='Please specify path to the ImageNet training data.')
    add('--output_dir', default=".", type=str, help='Path to save logs and checkpoints.')
    add('--saveckp_freq', default=20, type=int, help='Save checkpoint every x epochs.')
    add('--seed', default=0, type=int, help='Random seed.')
    add('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    add("--dist_url", default="env://", type=str,
        help=("url used to set up\n"
              "distributed training; see https://pytorch.org/docs/stable/distributed.html"))
    add("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")

    cli_args = parser.parse_args()
    return Config(cli_args.config)
def _get_databaunch(config):
    """Build the distributed training DataLoader from the ``config.dataset_*`` settings.

    Each root in ``config.dataset_train_roots`` is expanded recursively:
    - A directory with sub-directories becomes the concatenation of its
      sub-directories' datasets; any plain files directly inside it are ignored.
    - A directory without sub-directories is loaded as one
      ``ImageDatasetSelfSupervisedKmeans``.

    Batches are drawn through a shuffling ``DistributedSampler`` with ``drop_last=True``.

    Raises:
        ValueError: if ``config.dataset_train_roots`` is empty.
    """
    def _get_dataset(ds_type, paths, is_training, config, **kwargs):
        """Recursively build ``ds_type`` datasets under ``paths`` (concatenated if several)."""
        # Constructor kwargs shared by every leaf dataset.
        kwargs.update({
            'img_h': config.dataset_image_height,
            'img_w': config.dataset_image_width,
            'max_length': config.dataset_max_length,
            'case_sensitive': config.dataset_case_sensitive,
            'charset_path': config.dataset_charset_path,
            'data_aug': config.dataset_data_aug,
            'deteriorate_ratio': config.dataset_deteriorate_ratio,
            'multiscales': config.dataset_multiscales,
            'data_portion': config.dataset_portion,
            'filter_single_punctuation': config.dataset_filter_single_punctuation,
            'mask': config.dataset_mask,
            'mask_path': config.dataset_mask_path,
        })
        datasets = []
        for p in paths:
            # Close the scandir handle promptly, and sort: os.scandir yields entries
            # in arbitrary order, which made the concatenated layout filesystem-dependent.
            with os.scandir(p) as entries:
                subfolders = sorted(entry.path for entry in entries if entry.is_dir())
            if subfolders:  # Concat all subfolders
                datasets.append(_get_dataset(ds_type, subfolders, is_training, config, **kwargs))
            else:
                datasets.append(ds_type(path=p, is_training=is_training, **kwargs))
        if not datasets:
            raise ValueError(f"No dataset paths given: {paths!r}")
        if len(datasets) > 1:
            return MyConcatDataset(datasets)
        return datasets[0]

    ds_kwargs = {}
    if config.dataset_augmentation_severity is not None:
        ds_kwargs['augmentation_severity'] = config.dataset_augmentation_severity
    ds_kwargs['supervised_flag'] = ifnone(config.model_contrastive_supervised_flag, False)
    train_ds = _get_dataset(ImageDatasetSelfSupervisedKmeans, config.dataset_train_roots,
                            True, config, **ds_kwargs)
    sampler = torch.utils.data.DistributedSampler(train_ds, shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_ds,
        sampler=sampler,
        batch_size=config.batch_size_per_gpu,
        num_workers=config.dataset_num_workers,
        collate_fn=collate_fn_filter_none,
        pin_memory=config.dataset_pin_memory,
        drop_last=True,
    )
    return train_dataloader
if __name__ == "__main__":
    config = _parse_arguments()
    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)
    os.makedirs(f"./saved_models/{config.global_name}", exist_ok=True)
    # train() saves checkpoint.pth and appends log.txt under <output_dir>/<global_name>.
    # Create that directory up front so the first save cannot fail on a missing path.
    os.makedirs(os.path.join(config.output_dir, config.global_name), exist_ok=True)
    os.makedirs("./tensorboard", exist_ok=True)
    config.writer = SummaryWriter(log_dir=f"./tensorboard/{config.global_name}")
    # Random seeding happens inside train() via utils.fix_random_seeds(config.seed).
    train(config)