Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

addded decode_predictions2() to utils.py #25

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions keras_vggface/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@



import json
import numpy as np
from keras import backend as K
from keras.utils.data_utils import get_file
Expand All @@ -27,6 +28,8 @@

VGGFACE_DIR = 'models/vggface'

VGGFACE2_CLASS_INDEX = None
VGGFACE2_CLASS_INDEX_PATH = 'https://tbd/vggface2_class_index.json'

def preprocess_input(x, data_format=None, version=1):
if data_format is None:
Expand Down Expand Up @@ -62,6 +65,42 @@ def preprocess_input(x, data_format=None, version=1):
return x


def decode_predictions2(preds, top=5):
"""Decodes the prediction of a vggface2 model.

# Arguments
preds: Numpy tensor encoding a batch of predictions.
top: integer, how many top-guesses to return.

# Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.

# Raises
ValueError: in case of invalid shape of the `pred` array
(must be 2D).
"""
global VGGFACE2_CLASS_INDEX
if len(preds.shape) != 2 or preds.shape[1] != 8631:
raise ValueError('`decode_predictions` expects '
'a batch of predictions '
'(i.e. a 2D array of shape (samples, 8631)). '
'Found array with shape: ' + str(preds.shape))
if VGGFACE2_CLASS_INDEX is None:
fpath = get_file('vggface2_class_index.json',
VGGFACE2_CLASS_INDEX_PATH,
cache_subdir=VGGFACE_DIR)
VGGFACE2_CLASS_INDEX = json.load(open(fpath))
results = []
for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [tuple(VGGFACE2_CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
result.sort(key=lambda x: x[2], reverse=True)
results.append(result)
return results


def decode_predictions(preds, top=5):
LABELS = None
if len(preds.shape) == 2:
Expand Down