Skip to content

Commit

Permalink
ENH: efficiency of EM algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Dec 26, 2024
1 parent 1597ca3 commit 2eaf427
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
3 changes: 1 addition & 2 deletions choice_learn/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def fit(
self.callbacks.on_train_end(logs=temps_logs)
return losses_history

@tf.function
@tf.function(reduce_retracing=True)
def batch_predict(
self,
shared_features_by_choice,
Expand Down Expand Up @@ -731,7 +731,6 @@ def f(params_1d):
# calculate gradients and convert to 1D tf.Tensor
grads = tape.gradient(loss_value, self.trainable_weights)
grads = tf.dynamic_stitch(idx, grads)
# print out iteration & loss
f.iter.assign_add(1)

# store loss value so we can retrieve later
Expand Down
31 changes: 23 additions & 8 deletions choice_learn/models/latent_class_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,18 @@ def instantiate(self, **kwargs):
name="Latent-Logits",
)
self.latent_logits = init_logit
self.models = [self.model_class(**mp) for mp in self.model_parameters]
for model in self.models:
model.instantiate(**kwargs)

self.models = self.instantiate_latent_models(**kwargs)
self.instantiated = True

def instantiate_latent_models(self, **kwargs):
"""Instantiate latent models."""
models = [self.model_class(**mp) for mp in self.model_parameters]
for model in models:
model.instantiate(**kwargs)

return models

# @tf.function
def batch_predict(
self,
Expand Down Expand Up @@ -824,7 +830,7 @@ def _expectation(self, choice_dataset):
)

return tf.clip_by_value(
predicted_probas / np.sum(predicted_probas, axis=1, keepdims=True), 1e-10, 1
predicted_probas / np.sum(predicted_probas, axis=1, keepdims=True), 1e-6, 1
), loss

def _maximization(self, choice_dataset, verbose=0):
Expand All @@ -842,10 +848,17 @@ def _maximization(self, choice_dataset, verbose=0):
np.ndarray
latent probabilities resulting of maximization step
"""
self.models = [self.model_class(**mp) for mp in self.model_parameters]
# models = [self.model_class(**mp) for mp in self.model_parameters]
# for i in range(len(models)):
# for j, var in enumerate(self.models[i].trainable_weights):
# models[i]._trainable_weights[j] = var
# self.instantiate_latent_models(choice_dataset)

# M-step: MNL estimation
for q in range(self.n_latent_classes):
self.models[q].fit(choice_dataset, sample_weight=self.weights[:, q], verbose=verbose)
self.models[q].fit(
choice_dataset, sample_weight=self.weights[:, q].numpy(), verbose=verbose
)

# M-step: latent probability estimation
latent_probas = np.sum(self.weights, axis=0)
Expand Down Expand Up @@ -876,7 +889,9 @@ def _em_fit(self, choice_dataset, sample_weight=None, verbose=0):

# Initialization
init_sample_weight = np.random.rand(self.n_latent_classes, len(choice_dataset))
init_sample_weight = init_sample_weight / np.sum(init_sample_weight, axis=0, keepdims=True)
init_sample_weight = np.clip(
init_sample_weight / np.sum(init_sample_weight, axis=0, keepdims=True), 1e-6, 1
)
for i, model in enumerate(self.models):
# model.instantiate()
model.fit(choice_dataset, sample_weight=init_sample_weight[i], verbose=verbose)
Expand All @@ -888,7 +903,7 @@ def _em_fit(self, choice_dataset, sample_weight=None, verbose=0):
if np.sum(np.isnan(self.latent_logits)) > 0:
print("Nan in logits")
break
return hist_logits, hist_loss
return hist_loss, hist_logits

def predict_probas(self, choice_dataset, batch_size=-1):
"""Predicts the choice probabilities for each choice and each product of a ChoiceDataset.
Expand Down
14 changes: 13 additions & 1 deletion choice_learn/models/latent_class_mnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import tensorflow as tf

import choice_learn.tf_ops as tf_ops

from .conditional_logit import ConditionalLogit, MNLCoefficients
from .latent_class_base_model import BaseLatentClassModel
from .simple_mnl import SimpleMNL
Expand All @@ -23,6 +25,7 @@ def __init__(
intercept=None,
optimizer="Adam",
lr=0.001,
epochs_maximization=1000,
**kwargs,
):
"""Initialize model.
Expand Down Expand Up @@ -56,7 +59,7 @@ def __init__(
"batch_size": batch_size,
"lbfgs_tolerance": lbfgs_tolerance,
"lr": lr,
"epochs": 1000,
"epochs": epochs_maximization,
}

super().__init__(
Expand Down Expand Up @@ -88,6 +91,15 @@ def instantiate_latent_models(self, n_items, n_shared_features, n_items_features
model.indexes, model.weights = model.instantiate(
n_items, n_shared_features, n_items_features
)
model.exact_nll = tf_ops.CustomCategoricalCrossEntropy(
from_logits=False,
label_smoothing=0.0,
sparse=False,
axis=-1,
epsilon=1e-25,
name="exact_categorical_crossentropy",
reduction="sum_over_batch_size",
)
model.instantiated = True

def instantiate(self, n_items, n_shared_features, n_items_features):
Expand Down

0 comments on commit 2eaf427

Please sign in to comment.