-
Notifications
You must be signed in to change notification settings - Fork 2
/
get_train_data.py
47 lines (30 loc) · 1.03 KB
/
get_train_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import get_cluster_centres
import numpy as np
def get_images_and_labels_1(img_dict, n_clusters=50, cluster_fn=None):
    """Build a label vector and a per-image feature matrix from descriptors.

    Every item yielded by each value of ``img_dict`` (one image's
    descriptors) is reduced to a set of cluster centres. Each image's centres
    are flattened into a single feature row, so row ``i`` of the returned
    matrix corresponds to label ``i``.

    Args:
        img_dict: Mapping of label -> iterable of per-image descriptor
            collections. Presumably each item is a 2-D descriptor array
            (keypoints x dims) -- TODO confirm against callers.
        n_clusters: Number of cluster centres requested per image. Defaults
            to 50, the value previously hard-coded here.
        cluster_fn: Optional callable ``(descriptors, n_clusters) -> centres``.
            Defaults to ``get_cluster_centres.get_cluster_centres``.

    Returns:
        A tuple ``(labels, features)``:

        * ``labels`` is a 1-D array with one entry per image.
        * ``features`` is a 2-D array of shape ``(num_images, m * n)``, where
          ``(m, n)`` is the shape of each image's centre array.

        For empty input, returns an empty label array and a ``(0, 0)``
        feature array.

    Raises:
        ValueError: If the stacked centre arrays are not 3-D, e.g. because
            centre arrays differ in shape between images.
    """
    if cluster_fn is None:
        cluster_fn = get_cluster_centres.get_cluster_centres

    img_labels = []
    img_list = []
    for key, descriptor_sets in img_dict.items():
        for des in descriptor_sets:
            # One label per image: the image's centres are flattened into a
            # single row below, so the label does not need repeating per
            # centre.
            img_labels.append(key)
            img_list.append(cluster_fn(des, n_clusters))

    img_labels_np = np.asarray(img_labels)
    if not img_list:
        # np.asarray([]) has shape (0,), which cannot be unpacked into
        # (x, m, n) below.
        return (img_labels_np, np.empty((0, 0)))

    img_list_np = np.asarray(img_list)
    (x, m, n) = img_list_np.shape
    # The previous reshape(x, (m, n)) mixed an int with a tuple, which NumPy
    # rejects. Flatten each image's (m, n) centres into one row of length
    # m * n instead.
    img_list_np_reshaped = img_list_np.reshape(x, m * n)
    return (img_labels_np, img_list_np_reshaped)
"""
The idea is to vstack the NumPy array each time a new image's
descriptor list is added.
"""
def get_images_and_labels(img_dict):
    """Flatten a label -> images mapping into two parallel lists.

    Args:
        img_dict: Mapping whose values are iterables of images (or
            per-image descriptors) sharing that key's label.

    Returns:
        A tuple ``(images, labels)`` of equal-length lists, where
        ``labels[i]`` is the key under which ``images[i]`` was found.
        Order follows the mapping's iteration order, then each value's
        order.

        Note: the order is images first, then labels, which is the reverse
        of ``get_images_and_labels_1``.
    """
    pairs = [(label, image) for label in img_dict for image in img_dict[label]]
    images = [image for _, image in pairs]
    labels = [label for label, _ in pairs]
    return (images, labels)