-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
100 lines (90 loc) · 3.35 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import numpy as np
import cv2
import random
import torch
def padding(x, y):
    """Zero-pad an image and its mask to a centred square.

    Args:
        x: image array of shape (h, w) or (h, w, c).
        y: mask array of shape (h, w); must share x's height and width.

    Returns:
        (padded_x, padded_y), each with side ``max(h, w)``. The original
        content is centred; when the size difference is odd the extra
        padding row/column ends up at the bottom/right. The dtypes of x and
        y are preserved.
    """
    h, w = x.shape[:2]
    size = max(h, w)
    top = (size - h) // 2
    left = (size - w) // 2
    # Allocate with the inputs' dtypes instead of np.zeros' default float64,
    # so float32 (or uint8) inputs are not silently upcast.
    # x.shape[2:] keeps any channel axis, so 2-D images work too.
    temp_x = np.zeros((size, size) + x.shape[2:], dtype=x.dtype)
    temp_y = np.zeros((size, size), dtype=y.dtype)
    temp_x[top:top + h, left:left + w] = x
    temp_y[top:top + h, left:left + w] = y
    return temp_x, temp_y
def random_crop(x, y):
    """Randomly crop (and maybe horizontally flip) an image and its mask together.

    Up to ``h // 8`` rows and ``w // 8`` columns are trimmed, with a random
    split of the trimmed amount between the two opposite edges. With
    probability 1/2 both arrays are mirrored left-right first.

    Args:
        x: image array of shape (h, w) or (h, w, c).
        y: mask array of shape (h, w).

    Returns:
        (x_crop, y_crop): views into the (possibly flipped) inputs covering the
        same spatial window.
    """
    h, w = y.shape
    # Integer bounds (h/8 was a float). max(..., 1) keeps randint's upper
    # bound positive, so sides shorter than 8 px are left uncropped instead
    # of raising ValueError.
    trim_h = np.random.randint(max(h // 8, 1))
    trim_w = np.random.randint(max(w // 8, 1))
    flip = np.random.randint(10) >= 5
    # Offset drawn from [0, trim] inclusive, so the window can sit flush
    # against the bottom/right edge as well as the top/left edge.
    top = np.random.randint(trim_h + 1)
    left = np.random.randint(trim_w + 1)
    if flip:
        x = x[:, ::-1]
        y = y[:, ::-1]
    bottom = top + h - trim_h
    right = left + w - trim_w
    return x[top:bottom, left:right], y[top:bottom, left:right]
def random_rotate(x, y):
    """Rotate an image and its mask by the same random angle about their centre.

    The angle is an integer number of degrees in [-25, 25]. The output keeps
    the input size (w, h); regions rotated in from outside the frame are
    filled by cv2.warpAffine's default border handling.

    Args:
        x: image array accepted by cv2.warpAffine, same height/width as y.
        y: 2-D mask array of shape (h, w).

    Returns:
        (x_rot, y_rot)
    """
    # randint's upper bound is exclusive; 26 makes the range symmetric.
    angle = np.random.randint(-25, 26)
    h, w = y.shape
    center = (w / 2, h / 2)
    M = cv2.getRotationMatrix2D(center, float(angle), 1.0)
    x_rot = cv2.warpAffine(x, M, (w, h))
    # Nearest-neighbour for the mask so label values are not blended at object
    # borders; this matches the INTER_NEAREST resize applied to masks later.
    y_rot = cv2.warpAffine(y, M, (w, h), flags=cv2.INTER_NEAREST)
    return x_rot, y_rot
def random_light(x):
    """Apply a random contrast/brightness jitter and clip to [0, 255].

    Computes ``contrast * x + light``. ``contrast`` is drawn uniformly from
    [0.5, 1.5) and ``light`` is an integer offset drawn from [-20, 20].
    Returns a new array; a float32 input stays float32.
    """
    # Plain Python scalars (rather than the shape-(1,) array rand(1) returns,
    # or NumPy scalar types) so NumPy keeps x's float dtype instead of
    # upcasting the result to float64.
    contrast = float(np.random.rand()) + 0.5
    # randint's upper bound is exclusive; 21 makes the offset range symmetric.
    light = int(np.random.randint(-20, 21))
    return np.clip(contrast * x + light, 0, 255)
def generateTxt(image_dir, mask_dir, save_name):
    """Write an image/mask pair list for every entry in ``image_dir``.

    Each output line is ``<image_path> <mask_path>`` joined by a single space.
    The mask path is ``mask_dir/<image stem>.png``. Mask files are not checked
    for existence. Because lines are space-separated, paths containing spaces
    will not parse back correctly with a ``split(' ')`` reader.

    Args:
        image_dir: directory whose entries are listed as images.
        mask_dir: directory prefix used to build each mask path.
        save_name: path of the list file to (over)write.

    Returns:
        ``save_name``.
    """
    # List once, before opening the output, so the printed count matches the
    # written lines and a bad image_dir doesn't truncate an existing list file.
    names = os.listdir(image_dir)
    print('There are %d files in %s' % (len(names), image_dir))
    with open(save_name, 'w') as fp:
        for name in names:
            # splitext only drops the final extension ("a.b.jpg" -> "a.b"),
            # unlike split('.')[0], which would give "a".
            stem = os.path.splitext(name)[0]
            image_path = os.path.join(image_dir, name)
            mask_path = os.path.join(mask_dir, stem + '.png')
            fp.write(image_path + ' ' + mask_path + '\n')
    return save_name
def _load_sample(img_path, mask_path, target_size, israndom):
    """Read, optionally augment, normalise and resize one image/mask pair.

    Returns:
        (x, y): x is a (3, H, W) mean-subtracted RGB image, y is a (1, H, W)
        mask scaled to [0, 1], where (W, H) == target_size.

    Raises:
        OSError: if cv2.imread cannot read either file.
    """
    x = cv2.imread(img_path)
    y = cv2.imread(mask_path)
    # cv2.imread signals a missing/unreadable file by returning None.
    if x is None:
        raise OSError('Could not read image: %s' % img_path)
    if y is None:
        raise OSError('Could not read mask: %s' % mask_path)
    x = x.astype(np.float32)
    y = y.astype(np.float32)
    if y.ndim == 3:
        y = y[:, :, 0]
    # Scale the mask to [0, 1]; an all-black mask would otherwise become NaN.
    y_max = y.max()
    if y_max > 0:
        y = y / y_max
    if israndom:
        x, y = random_crop(x, y)
        x, y = random_rotate(x, y)
        x = random_light(x)
    # cv2.imread returns BGR; flip to RGB. Copy so the in-place mean
    # subtraction below writes into a fresh contiguous array.
    x = x[..., ::-1].copy()
    # Zero-center by mean pixel (Caffe/VGG ImageNet means B=103.939,
    # G=116.779, R=123.68). After the flip channel 0 is R, so the means are
    # applied in R, G, B order to match the channels.
    x[..., 0] -= 123.68
    x[..., 1] -= 116.779
    x[..., 2] -= 103.939
    x, y = padding(x, y)
    # target_size is cv2's dsize, i.e. (width, height).
    x = cv2.resize(x, target_size, interpolation=cv2.INTER_LINEAR)
    y = cv2.resize(y, target_size, interpolation=cv2.INTER_NEAREST)
    # HWC -> CHW for PyTorch; give the mask a leading channel axis.
    return x.transpose(2, 0, 1), y[np.newaxis]


def getTrainGenerator(file_path, target_size, batch_size, israndom=False):
    """Yield (images, masks) float tensors forever from an image/mask list file.

    Each non-blank line of ``file_path`` is ``<image_path> <mask_path>`` (as
    written by generateTxt). Every pass over the list is reshuffled. Batches
    hold up to ``batch_size`` samples; the last batch of a pass may be smaller.

    Args:
        file_path: path of the pair list file.
        target_size: (width, height) passed to cv2.resize.
        batch_size: maximum number of samples per yielded batch.
        israndom: apply random crop/flip, rotation and lighting augmentation.

    Yields:
        (images, masks) torch tensors of shape (B, 3, H, W) and (B, 1, H, W).

    Raises:
        ValueError: if the list file contains no pairs.
    """
    with open(file_path, 'r') as f:
        # Skip blank lines (e.g. a trailing newline) so they don't raise IndexError.
        trainlist = [line for line in f.readlines() if line.strip('\r\n')]
    if not trainlist:
        # Without this the infinite loop below would spin without yielding.
        raise ValueError('No image/mask pairs listed in %s' % file_path)
    last = len(trainlist) - 1
    while True:
        random.shuffle(trainlist)
        batch_x = []
        batch_y = []
        for i, line in enumerate(trainlist):
            p = line.strip('\r\n').split(' ')
            x, y = _load_sample(p[0], p[1], target_size, israndom)
            batch_x.append(x)
            batch_y.append(y)
            # Compare against the last valid index (len - 1) so the final
            # partial batch is yielded rather than silently dropped.
            if len(batch_x) == batch_size or i == last:
                yield (torch.Tensor(np.array(batch_x, dtype=np.float32)),
                       torch.Tensor(np.array(batch_y, dtype=np.float32)))
                batch_x = []
                batch_y = []
if __name__ == "__main__":
    # Smoke run: build the DUTS-TR pair list, pull a single un-augmented
    # 32x32 batch and print the smallest value of its image tensor.
    list_file = generateTxt('DUTS-TR/DUTS-TR-Image/', 'DUTS-TR/DUTS-TR-Mask/', 'train.txt')
    batches = getTrainGenerator(list_file, target_size=(32, 32), batch_size=1, israndom=False)
    images, _masks = next(batches)
    print(images.min())