-
Notifications
You must be signed in to change notification settings - Fork 5
/
chexnet_wrapper.py
23 lines (20 loc) · 1.13 KB
/
chexnet_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from utility import load_model
import numpy as np
from tensorflow.keras.models import Model
import tensorflow as tf
class ChexnetWrapper:
    """Run a loaded CheXNet-style model and expose an intermediate layer.

    The wrapped Keras ``Model`` has two outputs: the original model output
    and the output of an earlier layer chosen by ``pop_conv_layers``. One
    forward pass therefore yields both the (thresholded) predictions and the
    intermediate "visual features".
    """

    def __init__(self, model_path, model_name, pop_conv_layers):
        """Load the base model and build the two-output feature extractor.

        Args:
            model_path: Passed unchanged to ``utility.load_model``.
            model_name: Passed unchanged to ``utility.load_model``.
            pop_conv_layers: Number of layers to step back from the last one.
                The layer at ``model.layers[-pop_conv_layers - 1]`` supplies
                the visual features (``0`` selects the final layer itself).
        """
        model = load_model(model_path, model_name)
        # Tap an intermediate layer as a second output so predictions and
        # features come from the same forward pass.
        feature_layer = model.layers[-pop_conv_layers - 1]
        self.model = Model(
            inputs=model.input,
            outputs=[model.output, feature_layer.output],
        )
        # Prints the architecture to stdout on every construction.
        self.model.summary()

    def get_visual_features(self, images, threshold):
        """Predict on ``images`` and return binarized predictions plus features.

        Args:
            images: Batch passed unchanged to ``Model.predict``; the leading
                axis is treated as the batch axis.
            threshold: Scores ``>= threshold`` become 1.0, all others 0.0.

        Returns:
            A ``(predictions, visual_features)`` tuple of NumPy arrays:
            ``predictions`` is float32 0/1 with shape ``(batch, 1, K)``, and
            ``visual_features`` has shape ``(batch, 1, F)``. ``K`` and ``F``
            are the number of elements per example in the respective model
            outputs (everything after the batch axis is flattened).
        """
        previous_phase = tf.keras.backend.learning_phase()
        # Force inference mode (0) for the prediction. Restore in `finally` so
        # an exception inside `predict` cannot leave the *global* Keras
        # learning phase stuck in inference mode for the rest of the process.
        tf.keras.backend.set_learning_phase(0)
        try:
            predictions, visual_features = self.model.predict(images)
        finally:
            tf.keras.backend.set_learning_phase(previous_phase)

        # Flatten everything after the batch axis and insert a singleton middle
        # axis: (batch, ...) -> (batch, 1, prod(rest)). This is exactly the shape
        # the previous flatten-then-reshape pair produced, kept for callers.
        # NOTE(review): if per-location features (batch, H*W, C) were intended
        # for attention, the flatten discards that split -- confirm with callers.
        predictions = np.reshape(predictions, (predictions.shape[0], 1, -1))
        visual_features = np.reshape(
            visual_features, (visual_features.shape[0], 1, -1)
        )
        predictions = np.array(predictions >= threshold, dtype=np.float32)
        return predictions, visual_features