-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
39 lines (30 loc) · 1.22 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#pylint: disable=E1101
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
class Net(nn.Module):
    """Encoder/decoder network: VGG16 convolutional features + transposed-conv decoder.

    The encoder is ``torchvision``'s ``vgg16().features``. The decoder upsamples the
    encoder output with three ``ConvTranspose2d`` layers, concatenating (along dim 1)
    the encoder activations saved at feature indices ``_DEEP_SKIP`` and
    ``_SHALLOW_SKIP`` before the second and third layers. A sigmoid is applied and
    the result is flattened to ``(batch, pixels, num_classes)``.
    """

    # Number of output channels of the final layer and size of the last output dim.
    num_classes = 1

    # Indices into ``vgg16.features`` whose outputs are reused as skip connections.
    # Presumably (standard torchvision VGG16 layout) 23 is the max-pool closing
    # block 4 and 15 the ReLU after conv3_3 — TODO confirm against torchvision.
    _DEEP_SKIP = 23
    _SHALLOW_SKIP = 15

    def __init__(self, pretrained=True):
        """Build the VGG16 encoder and the transposed-convolution decoder.

        Args:
            pretrained: if True (the original behaviour), load ImageNet weights into
                the VGG16 backbone; if False, build it randomly initialised, which
                needs no weight download.
        """
        super().__init__()
        vgg16 = self._build_vgg16(pretrained)
        self.encoder = vgg16.features
        # Freeze the first 16 parameter tensors of the encoder; the rest stay trainable.
        # NOTE(review): if each conv layer contributes one weight and one bias, this
        # freezes the first 8 conv layers — confirm that is the intended cut-off.
        for i, param in enumerate(self.encoder.parameters()):
            param.requires_grad = i >= 16
        # Doubles the spatial size: (H - 1) * 2 - 2 * 1 + 4 == 2 * H.
        self.a_convT2d = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        # 768 = 256 (a_convT2d output) + 512 (deep skip). Quadruples spatial size.
        self.b_convT2d = nn.ConvTranspose2d(in_channels=768, out_channels=128, kernel_size=4, stride=4, padding=0)
        # 384 = 128 (b_convT2d output) + 256 (shallow skip). Quadruples spatial size.
        # out_channels follows num_classes so the head and the output reshape agree.
        self.convT2d3 = nn.ConvTranspose2d(in_channels=384, out_channels=self.num_classes, kernel_size=4, stride=4, padding=0)

    @staticmethod
    def _build_vgg16(pretrained):
        """Construct torchvision's VGG16, using the ``weights=`` API when available."""
        weights_enum = getattr(models, "VGG16_Weights", None)
        if weights_enum is None:
            # Older torchvision (< 0.13) only understands the legacy keyword.
            return models.vgg16(pretrained=pretrained)
        # IMAGENET1K_V1 is what the deprecated ``pretrained=True`` maps to.
        return models.vgg16(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x):
        """Run the encoder, then the skip-connected decoder.

        Args:
            x: input batch. Presumably (B, 3, H, W) with H and W divisible by 32, so
                the upsampled decoder maps line up with the saved skips — TODO confirm.

        Returns:
            Sigmoid-activated tensor of shape (B, N, num_classes), where N is the
            number of output pixels. Row n holds pixel n's class scores, with pixels
            in row-major (H, W) order.
        """
        skips = {}
        for i, layer in enumerate(self.encoder):
            x = layer(x)
            if i in (self._DEEP_SKIP, self._SHALLOW_SKIP):
                skips[i] = x

        x = self.a_convT2d(x)
        x = torch.cat((x, skips[self._DEEP_SKIP]), 1)
        x = self.b_convT2d(x)
        x = torch.cat((x, skips[self._SHALLOW_SKIP]), 1)
        x = self.convT2d3(x)
        x = torch.sigmoid(x)
        # Move channels last before flattening so each row is one pixel's class
        # vector; a plain view() on (B, C, H, W) would mix neighbouring pixels when
        # C > 1. For C == 1 this yields exactly the same tensor as view().
        return x.permute(0, 2, 3, 1).reshape(x.size(0), -1, self.num_classes)