From 55b1f131dd646b3949c7838f47129c1b7a5ce42e Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 5 Jan 2018 14:39:16 +1300 Subject: [PATCH] addded decode_predictions2() to utils.py Added a version of decode_predictions specific to vggface2 that returns results in an identical format as the similarily named function in keras.applications.imagenet_utils Sample response: [('n001138', 'Bong_Joon-ho', 0.98794019), ('n009064', 'Yang_Yuanqing', 0.0016884505), ('n007811', 'Satoru_Iwata', 0.0016516199), ('n000161', 'Akio_Toyoda', 0.00053648232), ('n007343', 'Richard_Stallman', 0.00044841404)] Note: depends on vggface2_class_index.json file. --- keras_vggface/utils.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/keras_vggface/utils.py b/keras_vggface/utils.py index cf64f98..af16c17 100644 --- a/keras_vggface/utils.py +++ b/keras_vggface/utils.py @@ -7,6 +7,7 @@ +import json import numpy as np from keras import backend as K from keras.utils.data_utils import get_file @@ -27,6 +28,8 @@ VGGFACE_DIR = 'models/vggface' +VGGFACE2_CLASS_INDEX = None +VGGFACE2_CLASS_INDEX_PATH = 'https://tbd/vggface2_class_index.json' def preprocess_input(x, data_format=None, version=1): if data_format is None: @@ -62,6 +65,42 @@ def preprocess_input(x, data_format=None, version=1): return x +def decode_predictions2(preds, top=5): + """Decodes the prediction of a vggface2 model. + + # Arguments + preds: Numpy tensor encoding a batch of predictions. + top: integer, how many top-guesses to return. + + # Returns + A list of lists of top class prediction tuples + `(class_name, class_description, score)`. + One list of tuples per sample in batch input. + + # Raises + ValueError: in case of invalid shape of the `pred` array + (must be 2D). + """ + global VGGFACE2_CLASS_INDEX + if len(preds.shape) != 2 or preds.shape[1] != 8631: + raise ValueError('`decode_predictions` expects ' + 'a batch of predictions ' + '(i.e. a 2D array of shape (samples, 8631)). ' + 'Found array with shape: ' + str(preds.shape)) + if VGGFACE2_CLASS_INDEX is None: + fpath = get_file('vggface2_class_index.json', + VGGFACE2_CLASS_INDEX_PATH, + cache_subdir=VGGFACE_DIR) + VGGFACE2_CLASS_INDEX = json.load(open(fpath)) + results = [] + for pred in preds: + top_indices = pred.argsort()[-top:][::-1] + result = [tuple(VGGFACE2_CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices] + result.sort(key=lambda x: x[2], reverse=True) + results.append(result) + return results + + def decode_predictions(preds, top=5): LABELS = None if len(preds.shape) == 2: