-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_model.py
140 lines (114 loc) · 5.21 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import time
import torch
from model import SingleViewto3D, ImplicitMLPDecoder
from r2n2_custom import R2N2
from pytorch3d.datasets.r2n2.utils import collate_batched_R2N2
import dataset_location
from pytorch3d.ops import sample_points_from_meshes
import losses
import matplotlib.pyplot as plt
import numpy as np
def get_args_parser():
    """Build the command-line parser for single-view-to-3D training.

    The parser is created with ``add_help=False`` so that it can be supplied
    through ``parents=[...]`` to another ``ArgumentParser`` without declaring
    ``-h/--help`` twice.

    Returns:
        argparse.ArgumentParser: parser exposing the architecture,
        optimisation, loss-weight, checkpointing and device options.
    """
    parser = argparse.ArgumentParser('Singleto3D', add_help=False)

    # Model parameters
    parser.add_argument('--arch', default='resnet18', type=str)

    # Numeric options need numeric converters: argparse applies ``type`` only
    # to strings, so non-string defaults pass through untouched while any
    # value typed on the command line would otherwise remain a ``str``.
    parser.add_argument('--lr', default=4e-4, type=float)
    parser.add_argument('--max_iter', default=10000, type=int)
    parser.add_argument('--log_freq', default=1000, type=int)
    parser.add_argument('--batch_size', default=8, type=int)
    parser.add_argument('--num_workers', default=2, type=int)

    # Which 3D representation to predict.
    parser.add_argument('--type', default='vox', choices=['vox', 'point', 'mesh'], type=str)
    parser.add_argument('--n_points', default=10000, type=int)

    # Weights of the two terms combined for the mesh loss.
    parser.add_argument('--w_chamfer', default=1.0, type=float)
    parser.add_argument('--w_smooth', default=0.1, type=float)

    parser.add_argument('--save_freq', default=500, type=int)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--load_feat', action='store_true')
    parser.add_argument('--load_checkpoint', action='store_true')
    return parser
def preprocess(feed_dict, args):
    """Split one collated batch into ``(network_input, ground_truth)``.

    Both returned objects are moved to ``args.device``.

    Args:
        feed_dict: Batch dictionary. ``'voxels'`` or ``'mesh'`` is read
            depending on ``args.type``; ``'feats'`` is read when
            ``args.load_feat`` is set, otherwise ``'images'``.
        args: Namespace providing ``type``, ``load_feat``, ``device`` and,
            for ``type == 'point'``, ``n_points``.

    Returns:
        tuple: ``(inputs, ground_truth_3d)``. ``inputs`` is the stacked
        ``'feats'`` entries when ``args.load_feat`` is set, else the images
        with dimension 1 squeezed. ``ground_truth_3d`` is the float voxel
        tensor (``'vox'``), points sampled from the mesh (``'point'``) or the
        mesh object itself (``'mesh'``).

    Raises:
        ValueError: If ``args.type`` is not ``'vox'``, ``'point'`` or
            ``'mesh'``.
    """
    if args.type == "vox":
        ground_truth_3d = feed_dict['voxels'].float()
    elif args.type == "point":
        # The target point cloud is drawn from the ground-truth mesh.
        ground_truth_3d = sample_points_from_meshes(feed_dict['mesh'], args.n_points)
    elif args.type == "mesh":
        ground_truth_3d = feed_dict["mesh"]
    else:
        # Fail with a clear message instead of an unbound-local error below.
        raise ValueError(
            f"Unsupported --type {args.type!r}; expected 'vox', 'point' or 'mesh'"
        )

    if args.load_feat:
        # 'feats' is stacked into a single tensor for the model.
        inputs = torch.stack(feed_dict['feats'])
    else:
        # presumably dim 1 is a singleton view axis — TODO confirm against R2N2
        inputs = feed_dict['images'].squeeze(1)

    return inputs.to(args.device), ground_truth_3d.to(args.device)
def calculate_loss(predictions, ground_truth, args):
    """Compute the training loss for the representation chosen by ``args.type``.

    Args:
        predictions: Model output for the current batch.
        ground_truth: Matching target as produced by ``preprocess``.
        args: Namespace providing ``type`` and, depending on it, ``n_points``,
            ``w_chamfer`` and ``w_smooth``.

    Returns:
        The value returned by ``losses.voxel_loss`` (``'vox'``) or
        ``losses.chamfer_loss`` (``'point'``). For ``'mesh'`` it is
        ``w_chamfer * chamfer + w_smooth * smoothness``, where the chamfer
        term is evaluated on points sampled from both meshes.

    Raises:
        ValueError: If ``args.type`` is not ``'vox'``, ``'point'`` or
            ``'mesh'``.
    """
    if args.type == 'vox':
        return losses.voxel_loss(predictions, ground_truth)

    if args.type == 'point':
        return losses.chamfer_loss(predictions, ground_truth, args.n_points)

    if args.type == 'mesh':
        # Both meshes are sampled to point sets before the chamfer term is
        # computed.
        sample_trg = sample_points_from_meshes(ground_truth, args.n_points)
        sample_pred = sample_points_from_meshes(predictions, args.n_points)

        loss_reg = losses.chamfer_loss(sample_pred, sample_trg, args.n_points)
        loss_smooth = losses.smoothness_loss(predictions)
        return args.w_chamfer * loss_reg + args.w_smooth * loss_smooth

    # Fail with a clear message instead of an unbound-local error.
    raise ValueError(
        f"Unsupported --type {args.type!r}; expected 'vox', 'point' or 'mesh'"
    )
# def train_model(i, args):
def train_model(args):
    """Train ``SingleViewto3D`` on the R2N2 training split.

    Builds the dataset and data loader, optionally resumes from
    ``checkpoint_{args.type}.pth`` in the current working directory, then runs
    Adam updates until ``args.max_iter`` is reached. A checkpoint is written
    every ``args.save_freq`` steps and after the final step, and one progress
    line is printed per step.

    Args:
        args: Parsed options. Reads ``load_feat``, ``batch_size``,
            ``num_workers``, ``device``, ``lr``, ``load_checkpoint``,
            ``type``, ``max_iter`` and ``save_freq``; the namespace is also
            forwarded to the model, ``preprocess`` and ``calculate_loss``.

    Raises:
        StopIteration: If the data loader yields no batches at all.
    """
    r2n2_dataset = R2N2(
        "train",
        dataset_location.SHAPENET_PATH,
        dataset_location.R2N2_PATH,
        dataset_location.SPLITS_PATH,
        return_voxels=True,
        return_feats=args.load_feat,
    )

    loader = torch.utils.data.DataLoader(
        r2n2_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=collate_batched_R2N2,
        pin_memory=True,
        drop_last=True,
    )
    train_loader = iter(loader)

    model = SingleViewto3D(args)
    # model = ImplicitMLPDecoder(args)
    model.to(args.device)
    model.train()

    # ============ preparing optimizer ... ============
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)  # to use with ViTs
    start_iter = 0
    start_time = time.time()

    if args.load_checkpoint:
        # map_location lets a checkpoint saved on one device be restored on
        # another (e.g. saved on CUDA, resumed on CPU).
        checkpoint = torch.load(
            f'checkpoint_{args.type}.pth', map_location=args.device
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        loaded_step = checkpoint['step']
        # The checkpoint is written after the update of ``loaded_step``, so
        # training continues with the following step instead of repeating it.
        start_iter = loaded_step + 1
        print(f"Succesfully loaded iter {loaded_step}")

    print("Starting training !")
    for step in range(start_iter, args.max_iter):
        iter_start_time = time.time()

        read_start_time = time.time()
        try:
            feed_dict = next(train_loader)
        except StopIteration:
            # Epoch exhausted: start a fresh pass over the dataset. If the
            # loader is empty, the second next() raises StopIteration.
            train_loader = iter(loader)
            feed_dict = next(train_loader)

        images_gt, ground_truth_3d = preprocess(feed_dict, args)
        read_time = time.time() - read_start_time

        prediction_3d = model(images_gt, args)
        loss = calculate_loss(prediction_3d, ground_truth_3d, args)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_time = time.time() - start_time
        iter_time = time.time() - iter_start_time

        loss_vis = loss.cpu().item()

        # Also save on the very last step so the final weights are never lost
        # when max_iter - 1 is not a multiple of save_freq.
        is_last_step = step == args.max_iter - 1
        if (step % args.save_freq) == 0 or is_last_step:
            torch.save({
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, f'checkpoint_{args.type}.pth')

        print("[%4d/%4d]; time: %.0f (%.2f, %.2f); loss: %.3f" % (step, args.max_iter, total_time, read_time, iter_time, loss_vis))

    print('Done!')
if __name__ == '__main__':
    # Script entry point: wrap the shared option parser in a top-level parser
    # (which contributes -h/--help), parse the command line and start training.
    cli_parser = argparse.ArgumentParser(
        'Singleto3D',
        parents=[get_args_parser()],
    )
    train_model(cli_parser.parse_args())