utils.py
import torch
import torch.nn.functional as F
import horovod.torch as hvd
from torch.optim.lr_scheduler import _LRScheduler


def accuracy(output, target):
    """Fraction of samples whose arg-max prediction matches the target label."""
    # get the index of the max log-probability
    pred = output.max(1, keepdim=True)[1]
    return pred.eq(target.view_as(pred)).cpu().float().mean()


def save_checkpoint(model, optimizer, checkpoint_format, epoch):
    """Write model and optimizer state to disk from rank 0 only."""
    if hvd.rank() == 0:
        filepath = checkpoint_format.format(epoch=epoch + 1)
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, filepath)
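
# Example call (illustrative; the format string below is an assumption, not
# taken from this file):
#
#     save_checkpoint(model, optimizer, 'checkpoint-{epoch}.pth.tar', epoch)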


class LabelSmoothLoss(torch.nn.Module):
    """Cross-entropy against label-smoothed targets.

    The soft target assigns (1 - smoothing) to the true class and
    smoothing / (num_classes - 1) to every other class.
    """

    def __init__(self, smoothing=0.0):
        super(LabelSmoothLoss, self).__init__()
        self.smoothing = smoothing

    def forward(self, input, target):
        log_prob = F.log_softmax(input, dim=-1)
        weight = input.new_ones(input.size()) * \
            self.smoothing / (input.size(-1) - 1.)
        weight.scatter_(-1, target.unsqueeze(-1), (1. - self.smoothing))
        loss = (-weight * log_prob).sum(dim=-1).mean()
        return loss
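
# Illustrative usage (an assumption, not shown in the original file):
#
#     criterion = LabelSmoothLoss(smoothing=0.1)
#     loss = criterion(logits, targets)  # logits: (N, K), targets: (N,) int64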


def metric_average(val_tensor):
    """Allreduce (average) a tensor across workers and return it as a Python float."""
    avg_tensor = hvd.allreduce(val_tensor)
    return avg_tensor.item()


# Horovod: average metrics from distributed training.
class Metric(object):
    """Running average of a scalar metric, averaged across Horovod ranks."""

    def __init__(self, name):
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val, n=1):
        self.sum += float(hvd.allreduce(val.detach().cpu(), name=self.name))
        self.n += n

    @property
    def avg(self):
        return self.sum / self.n


def create_lr_schedule(workers, warmup_epochs, decay_schedule, alpha=0.1):
    """Return an epoch -> LR-multiplier function with linear warmup and step decay.

    During warmup the multiplier ramps linearly from 1/workers to 1; afterwards
    it is multiplied by `alpha` for each milestone in `decay_schedule` that the
    current epoch has passed.
    """
    def lr_schedule(epoch):
        lr_adj = 1.
        if epoch < warmup_epochs:
            lr_adj = 1. / workers * (epoch * (workers - 1) / warmup_epochs + 1)
        else:
            decay_schedule.sort(reverse=True)
            for e in decay_schedule:
                if epoch >= e:
                    lr_adj *= alpha
        return lr_adj
    return lr_schedule
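
# Intended use (an assumption based on the function's signature, not stated in
# this file): pass the returned closure to torch.optim.lr_scheduler.LambdaLR,
# which applies it as a multiplicative factor on the base learning rate, e.g.
#
#     lr_schedule = create_lr_schedule(hvd.size(), warmup_epochs=5,
#                                      decay_schedule=[30, 60, 80])
#     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)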


class PolynomialDecay(_LRScheduler):
    """Polynomial decay from the base LR to `end_lr` over `decay_steps` steps."""

    def __init__(self, optimizer, decay_steps, end_lr=0.0001, power=1.0, last_epoch=-1):
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        self.power = power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        return [
            (base_lr - self.end_lr) * ((1 - min(self.last_epoch, self.decay_steps) /
                                        self.decay_steps) ** self.power) + self.end_lr
            for base_lr in self.base_lrs
        ]


class WarmupScheduler(_LRScheduler):
    """Linear warmup for `warmup_epochs` steps, then hand off to `after_scheduler`."""

    def __init__(self, optimizer, warmup_epochs, after_scheduler, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch >= self.warmup_epochs:
            if not self.finished:
                # Hand the base LRs to the wrapped scheduler once warmup completes.
                self.after_scheduler.base_lrs = self.base_lrs
                self.finished = True
            return self.after_scheduler.get_lr()
        return [self.last_epoch / self.warmup_epochs * lr for lr in self.base_lrs]

    def step(self, epoch=None):
        if self.finished:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.warmup_epochs)
        else:
            return super().step(epoch)


class PolynomialWarmup(WarmupScheduler):
    """Linear warmup for `warmup_steps` steps followed by polynomial decay."""

    def __init__(self, optimizer, decay_steps, warmup_steps=0, end_lr=0.0001, power=1.0, last_epoch=-1):
        base_scheduler = PolynomialDecay(
            optimizer, decay_steps - warmup_steps, end_lr=end_lr, power=power, last_epoch=last_epoch)
        super().__init__(optimizer, warmup_steps, base_scheduler, last_epoch=last_epoch)
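

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original module). It assumes
# Horovod can be initialized with a single local process and exercises the
# helpers above with a tiny randomly initialized model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    hvd.init()

    model = torch.nn.Linear(10, 5)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    logits = model(torch.randn(8, 10))
    targets = torch.randint(0, 5, (8,))

    criterion = LabelSmoothLoss(smoothing=0.1)
    loss = criterion(logits, targets)
    print('loss:', loss.item())
    print('accuracy:', accuracy(logits, targets).item())

    # Average a scalar metric across ranks (a no-op average with one process).
    train_loss = Metric('train_loss')
    train_loss.update(loss.detach())
    print('avg loss:', train_loss.avg.item())

    # Linear warmup for 2 steps, then polynomial (power=1.0, i.e. linear) decay.
    scheduler = PolynomialWarmup(optimizer, decay_steps=10, warmup_steps=2)
    for _ in range(5):
        optimizer.step()
        scheduler.step()
    print('lr after 5 steps:', optimizer.param_groups[0]['lr'])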