Skip to content

Commit

Permalink
-handle batch size issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
omar-mohamed committed Oct 25, 2021
1 parent 1457e96 commit 29e3668
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 28 deletions.
7 changes: 6 additions & 1 deletion generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def __getitem__(self, idx):
batch_x = np.asarray([self.load_image(x_path) for x_path in batch_x_path])
batch_x = self.transform_batch_images(batch_x)
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
+ if batch_x.shape[0] < self.batch_size:
+ remaining = self.batch_size - batch_x.shape[0]
+ batch_x = np.concatenate((batch_x,batch_x[0:remaining]), axis = 0)
+ batch_x_path = np.concatenate((batch_x_path,batch_x_path[0:remaining]), axis = 0)
+ batch_y = np.concatenate((batch_y,batch_y[0:remaining]), axis = 0)
return batch_x, batch_y, batch_x_path

def load_image(self, image_file):
Expand Down Expand Up @@ -83,7 +88,7 @@ def get_y_true(self):

def prepare_dataset(self):
df = self.dataset_df.sample(frac=1., random_state=self.random_state)
- self.x_path, self.y = df["Image Index"].as_matrix(), self.tokenizer_wrapper.tokenize_sentences(df[self.class_names].as_matrix())
+ self.x_path, self.y = df["Image Index"].values, self.tokenizer_wrapper.tokenize_sentences(df[self.class_names].values)

def on_epoch_end(self):
if self.shuffle:
Expand Down
33 changes: 7 additions & 26 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,9 @@
absl-py==0.7.1
astor==0.8.0
boto==2.49.0
boto3==1.9.214
botocore==1.12.214
certifi==2019.6.16
chardet==3.0.4
Click==7.0
cycler==0.10.0
decorator==4.4.0
docutils==0.15.2
efficientnet==1.0.0
gast==0.2.2
google-pasta==0.1.7
h5py==2.9.0
idna==2.8
imageio==2.5.0
imgaug==0.3.0
jmespath==0.9.4
joblib==0.14.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
kiwisolver==1.1.0
lxml==4.4.1
Markdown==3.1.1
Expand All @@ -29,29 +13,26 @@ nltk==3.4.5
numpy==1.17.0
opencv-python==4.1.0.25
opencv-python-headless==4.1.2.30
opt-einsum==3.1.0
pandas==0.25.1
Pillow==6.1.0
protobuf==3.9.1
psutil==5.6.7
pyparsing==2.4.2
python-dateutil==2.8.0
python-docx==0.8.10
pytz==2019.2
PyWavelets==1.0.3
requests==2.22.0
s3transfer==0.2.1
scikit-image==0.15.0
scikit-learn==0.22.1
Shapely==1.6.4.post2
six==1.12.0
Shapely==1.7.1
smart-open==1.8.4
tensorflow==2.1.0
tensorflow==2.3.0
termcolor==1.1.0
Theano==1.0.4
tqdm==4.41.1
urllib3==1.25.3
Werkzeug==0.15.5
wrapt==1.11.2
xdg==4.0.1
boto3==1.10.50
botocore==1.13.50
pymc3==3.11.0
theano==1.0.4
theano-pymc==1.1.0
git+https://github.com/Maluuba/nlg-eval.git@master
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def train_step(tag_predictions, visual_features, target):
print('Batches that took long: {}'.format(times_to_get_batch))

ckpt_manager.save()
- if epoch % FLAGS.epochs_to_evaluate == 0:
+ if epoch % FLAGS.epochs_to_evaluate == 0 and epoch > 0:
print("Evaluating on test set..")
train_enqueuer.stop()
current_scores = evaluate_enqueuer(test_enqueuer, test_steps, FLAGS, encoder, decoder, tokenizer_wrapper,
Expand Down

0 comments on commit 29e3668

Please sign in to comment.