-
Notifications
You must be signed in to change notification settings - Fork 0
/
deform_dataset.py
67 lines (55 loc) · 1.94 KB
/
deform_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import glob
import os
import pickle as pkl
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class DeformDataset(Dataset):
    """Dataset of (rgb, depth, normals, offsets) samples read from disk.

    A sample is keyed by a ``<stem>_offset.pkl`` file directly inside
    ``<data_dir>/<mode>``; its images are the sibling files
    ``<stem>_rgb.png`` and ``<stem>_normals.png`` in the same directory.

    Samples are ordered by sorted pickle path so that a given ``idx`` refers
    to the same sample across runs and machines (raw ``glob`` order is
    filesystem-dependent).
    """

    _OFFSET_SUFFIX = "_offset.pkl"
    _RGB_SUFFIX = "_rgb.png"
    _NORMALS_SUFFIX = "_normals.png"

    def __init__(self, data_dir, mode):
        """Index every sample of one split.

        Args:
            data_dir: Root directory containing one sub-directory per split.
            mode: Name of the split sub-directory (e.g. ``"train"``).

        Raises:
            FileNotFoundError: If ``<data_dir>/<mode>`` is not a directory.
        """
        split_dir = os.path.join(data_dir, mode)
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        if not os.path.isdir(split_dir):
            raise FileNotFoundError("Data folder does not exist.")

        self.pkl_files = self._find_offset_files(split_dir)
        self.rgb_files = [
            self._sibling_path(fn, self._RGB_SUFFIX) for fn in self.pkl_files
        ]
        self.normals_files = [
            self._sibling_path(fn, self._NORMALS_SUFFIX) for fn in self.pkl_files
        ]
        # Maps an image path to a float tensor (C, H, W); ToTensor scales
        # 8-bit images to [0, 1].
        # NOTE(review): a PNG saved with an alpha channel yields C == 4 —
        # confirm the rgb/normals files are 3-channel.
        self.img2tensor = transforms.Compose(
            [
                Image.open,
                transforms.ToTensor(),
            ]
        )

    @classmethod
    def _find_offset_files(cls, split_dir):
        """Return the sorted ``*_offset.pkl`` paths directly inside *split_dir*."""
        # Escape the directory so glob metacharacters in it ("[", "*", "?")
        # are matched literally instead of being treated as a pattern.
        pattern = os.path.join(glob.escape(split_dir), "*" + cls._OFFSET_SUFFIX)
        return sorted(glob.glob(pattern))

    @classmethod
    def _sibling_path(cls, pkl_path, suffix):
        """Replace the trailing ``_offset.pkl`` of *pkl_path* with *suffix*."""
        # Slice off only the trailing suffix; ``str.replace`` would also
        # rewrite any earlier occurrence elsewhere in the path.
        return pkl_path[: -len(cls._OFFSET_SUFFIX)] + suffix

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.pkl_files)

    def __getitem__(self, idx):
        """Load sample *idx*.

        Returns:
            Tuple ``(rgb, depths, normals, offsets)``: ``rgb`` is the RGB
            image tensor from ``img2tensor``; ``depths`` is the pickled depth
            permuted channel-first and tiled 3x along that axis; ``normals``
            is the normals image remapped from [0, 1] to [-1, 1]; ``offsets``
            is returned exactly as stored in the pickle.
        """
        # SECURITY: pickle.load can execute arbitrary code — only load
        # files you generated or otherwise trust.
        with open(self.pkl_files[idx], "rb") as fp:
            data = pkl.load(fp)
        offsets = data["offsets"]
        # Assumes data["depth"] is a torch tensor shaped (H, W, 1) — TODO
        # confirm; permute -> (1, H, W), repeat -> (3, H, W) to match rgb.
        depths = data["depth"].permute([2, 0, 1]).repeat([3, 1, 1])
        rgb = self.img2tensor(self.rgb_files[idx])
        normals = self.img2tensor(self.normals_files[idx])
        # Map [0, 1] pixel values to [-1, 1] (presumably the usual
        # normal-map encoding).
        normals = normals * 2.0 - 1.0
        return rgb, depths, normals, offsets
if __name__ == "__main__":
    # Smoke test: load the train split and print the shapes of the first
    # few batches.  Usage: python deform_dataset.py [DATA_DIR]
    import sys

    from torch.utils.data import DataLoader

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    batch_size = 8
    # Optional command-line override; the default is the original author's
    # local data path.
    data_dir = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/home/zyuwei/Projects/cloth_shape_estimation/data"
    )
    # Stop after batch index 10, i.e. after printing 11 batches.
    max_batch_idx = 10

    train_ds = DeformDataset(data_dir=data_dir, mode="train")
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    for idx, batch in enumerate(train_dl):
        batch = [tensor.to(device) for tensor in batch]
        rgb, depths, normals, offsets = batch
        print(
            rgb.shape, depths.shape, normals.shape, offsets.shape
        )
        if idx >= max_batch_idx:
            break