-
Notifications
You must be signed in to change notification settings - Fork 1
/
input_data.py
31 lines (25 loc) · 980 Bytes
/
input_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import tensorflow as tf
from os import listdir
import numpy as np
# Directory holding the training images. parse_image builds each path as
# file_path + filename, so this must keep its trailing slash.
file_path: str = "data/voc_pascal/"
# Number of images grouped into each batch by input_data().
batch_size: int = 1
def parse_image(filename):
    """Read one image and return it as a standardized 64x64, 3-channel tensor.

    Args:
        filename: Scalar string tensor holding a name relative to the
            module-level ``file_path`` (it is concatenated directly onto it).

    Returns:
        The decoded image, resized to 64x64 and passed through
        ``tf.image.per_image_standardization`` (per-image mean/variance
        normalization).
    """
    raw_bytes = tf.read_file(file_path + filename)
    # channels=3 forces a 3-channel decode regardless of the source format.
    small = tf.image.resize_images(
        tf.image.decode_jpeg(raw_bytes, channels=3), [64, 64]
    )
    return tf.image.per_image_standardization(small)
def get_filenames():
    """Return the names of the candidate image files in ``file_path``.

    Only regular files are returned. Subdirectories and hidden entries
    (names starting with ".", e.g. ".DS_Store" or ".gitkeep") are skipped,
    because ``parse_image`` reads and JPEG-decodes every name it is given
    and would fail on them at runtime.

    Returns:
        A list of bare file names (no directory prefix), in the arbitrary
        order produced by ``os.listdir``.
    """
    import os  # local import: the module itself only imports ``listdir``

    return [
        name
        for name in listdir(file_path)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(file_path, name))
    ]
def input_data(shuffle_buffer=None, num_parallel_calls=4):
    """Build an endlessly repeating, shuffled, batched image pipeline.

    Lists the files in ``file_path`` via ``get_filenames``, shuffles the
    filename strings, repeats forever, then decodes and standardizes each
    image with ``parse_image`` and groups the results into batches of
    ``batch_size``.

    Args:
        shuffle_buffer: Shuffle buffer size applied to the filename strings.
            ``None`` (default) uses the number of files. Buffer elements are
            just short strings, so a full-size buffer is cheap and yields a
            uniform shuffle instead of a local one.
        num_parallel_calls: Number of images decoded concurrently by ``map``.

    Returns:
        A one-shot iterator over the batched dataset (TF1 ``Dataset`` API).

    Raises:
        ValueError: If ``file_path`` contains no files. Without this check,
            ``from_tensor_slices([])`` produces a float32 dataset and
            ``parse_image`` fails with an unrelated dtype error.
    """
    filenames = get_filenames()
    print(len(filenames))
    if not filenames:
        raise ValueError("no image files found in %r" % (file_path,))
    if shuffle_buffer is None:
        shuffle_buffer = len(filenames)

    train_dataset = tf.data.Dataset.from_tensor_slices(filenames)
    # Shuffle before repeat so each pass over the data is reshuffled
    # independently (reshuffle_each_iteration defaults to True).
    train_dataset = train_dataset.shuffle(shuffle_buffer).repeat()
    train_dataset = train_dataset.map(
        parse_image, num_parallel_calls=num_parallel_calls
    ).batch(batch_size)
    return train_dataset.make_one_shot_iterator()
if __name__ == '__main__':
    # Smoke check: build the pipeline and print what get_next() returns.
    # NOTE(review): with the TF1 graph-mode API used here this prints the
    # symbolic batch tensor (name/shape/dtype), not decoded pixel values.
    next_batch = input_data().get_next()
    print(next_batch)