diff --git a/generator.py b/generator.py index 6b0e271..3913c64 100644 --- a/generator.py +++ b/generator.py @@ -51,6 +51,11 @@ def __getitem__(self, idx): batch_x = np.asarray([self.load_image(x_path) for x_path in batch_x_path]) batch_x = self.transform_batch_images(batch_x) batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size] + if batch_x.shape[0] < self.batch_size: + remaining = self.batch_size - batch_x.shape[0] + batch_x = np.concatenate((batch_x,batch_x[0:remaining]), axis = 0) + batch_x_path = np.concatenate((batch_x_path,batch_x_path[0:remaining]), axis = 0) + batch_y = np.concatenate((batch_y,batch_y[0:remaining]), axis = 0) return batch_x, batch_y, batch_x_path def load_image(self, image_file): @@ -83,7 +88,7 @@ def get_y_true(self): def prepare_dataset(self): df = self.dataset_df.sample(frac=1., random_state=self.random_state) - self.x_path, self.y = df["Image Index"].as_matrix(), self.tokenizer_wrapper.tokenize_sentences(df[self.class_names].as_matrix()) + self.x_path, self.y = df["Image Index"].values, self.tokenizer_wrapper.tokenize_sentences(df[self.class_names].values) def on_epoch_end(self): if self.shuffle: diff --git a/requirements.txt b/requirements.txt index a6831a2..ca9d702 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,25 +1,9 @@ -absl-py==0.7.1 -astor==0.8.0 -boto==2.49.0 -boto3==1.9.214 -botocore==1.12.214 -certifi==2019.6.16 -chardet==3.0.4 -Click==7.0 -cycler==0.10.0 -decorator==4.4.0 docutils==0.15.2 efficientnet==1.0.0 -gast==0.2.2 -google-pasta==0.1.7 -h5py==2.9.0 -idna==2.8 imageio==2.5.0 imgaug==0.3.0 jmespath==0.9.4 joblib==0.14.1 -Keras-Applications==1.0.8 -Keras-Preprocessing==1.1.0 kiwisolver==1.1.0 lxml==4.4.1 Markdown==3.1.1 @@ -29,29 +13,26 @@ nltk==3.4.5 numpy==1.17.0 opencv-python==4.1.0.25 opencv-python-headless==4.1.2.30 -opt-einsum==3.1.0 pandas==0.25.1 Pillow==6.1.0 -protobuf==3.9.1 psutil==5.6.7 pyparsing==2.4.2 python-dateutil==2.8.0 python-docx==0.8.10 -pytz==2019.2 PyWavelets==1.0.3 requests==2.22.0 s3transfer==0.2.1 scikit-image==0.15.0 scikit-learn==0.22.1 -Shapely==1.6.4.post2 -six==1.12.0 +Shapely==1.7.1 smart-open==1.8.4 -tensorflow==2.1.0 +tensorflow==2.3.0 termcolor==1.1.0 -Theano==1.0.4 tqdm==4.41.1 urllib3==1.25.3 -Werkzeug==0.15.5 -wrapt==1.11.2 -xdg==4.0.1 +boto3==1.10.50 +botocore==1.13.50 +pymc3==3.11.0 +theano==1.0.4 +theano-pymc==1.1.0 git+https://github.com/Maluuba/nlg-eval.git@master diff --git a/train.py b/train.py index 23af221..ebed249 100644 --- a/train.py +++ b/train.py @@ -145,7 +145,7 @@ def train_step(tag_predictions, visual_features, target): print('Batches that took long: {}'.format(times_to_get_batch)) ckpt_manager.save() - if epoch % FLAGS.epochs_to_evaluate == 0: + if epoch % FLAGS.epochs_to_evaluate == 0 and epoch > 0: print("Evaluating on test set..") train_enqueuer.stop() current_scores = evaluate_enqueuer(test_enqueuer, test_steps, FLAGS, encoder, decoder, tokenizer_wrapper,