Skip to content

Commit

Permalink
TST: faster reslogit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
julesdesir committed Oct 1, 2024
1 parent 23e084e commit 9469f9d
Showing 1 changed file with 89 additions and 70 deletions.
159 changes: 89 additions & 70 deletions tests/integration_tests/models/test_reslogit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import numpy as np

from choice_learn.datasets import load_swissmetro
from choice_learn.models import ResLogit, SimpleMNL

# from choice_learn.models import ResLogit, SimpleMNL
from choice_learn.models import ResLogit

dataset = load_swissmetro()
dataset = dataset[:10] # Reduce the dataset size for faster testing
n_items = np.shape(dataset.items_features_by_choice)[2]
n_shared_features = np.shape(dataset.shared_features_by_choice)[2]
n_items_features = np.shape(dataset.items_features_by_choice)[3]
Expand All @@ -16,29 +19,35 @@ def test_reslogit_fit_with_sgd():
global dataset

model = ResLogit(lr=1e-6, epochs=30, optimizer="SGD", batch_size=32)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_adam():
"""Tests that ResLogit can fit with Adam."""
global dataset

model = ResLogit(epochs=20, optimizer="Adam", batch_size=32)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_adamax():
"""Tests that ResLogit can fit with Adamax."""
global dataset

model = ResLogit(epochs=20, optimizer="Adamax", batch_size=32)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_optimizer_not_implemented():
Expand All @@ -49,9 +58,11 @@ def test_reslogit_fit_with_optimizer_not_implemented():
global dataset

model = ResLogit(epochs=20, optimizer="xyz_not_implemented", batch_size=32)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_none_intercept():
Expand All @@ -65,9 +76,11 @@ def test_reslogit_fit_with_none_intercept():
)
assert "intercept" not in indexes

model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_item_intercept():
Expand All @@ -81,9 +94,10 @@ def test_reslogit_fit_with_item_intercept():
)
assert "intercept" in indexes

eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_item_full_intercept():
Expand All @@ -97,9 +111,10 @@ def test_reslogit_fit_with_item_full_intercept():
)
assert "intercept" in indexes

eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


def test_reslogit_fit_with_other_intercept():
Expand All @@ -115,49 +130,51 @@ def test_reslogit_fit_with_other_intercept():
)
assert "intercept" in indexes

model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0


def test_reslogit_comparison_with_simple_mnl():
"""Tests that ResLogit can fit better than SimpleMNL."""
global dataset

reslogit = ResLogit(
intercept="item", lr=1e-6, n_layers=0, epochs=100, optimizer="SGD", batch_size=32
)
reslogit_indexes, reslogit_initial_weights = reslogit.instantiate(
n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
)
reslogit.fit(dataset)
reslogit_final_weights = reslogit.trainable_weights
reslogit_score = reslogit.evaluate(dataset)

simple_mnl = SimpleMNL(intercept="item", lr=1e-6, epochs=100, optimizer="SGD", batch_size=32)
simple_mnl_indexes, simple_mnl_initial_weights = simple_mnl.instantiate(
n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
)
simple_mnl.fit(dataset)
simple_mnl_final_weights = simple_mnl.trainable_weights
simple_mnl_score = simple_mnl.evaluate(dataset)

assert reslogit_indexes == simple_mnl_indexes
for i in range(len(reslogit_initial_weights)):
assert np.allclose(
simple_mnl_initial_weights[i].numpy(),
reslogit_initial_weights[i].numpy(),
rtol=0,
atol=0.01,
)
assert np.abs(simple_mnl_score - reslogit_score) < 0.05
for i in range(len(reslogit_final_weights)):
assert np.allclose(
simple_mnl_final_weights[i].numpy(),
reslogit_final_weights[i].numpy(),
rtol=0,
atol=0.01,
)
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before


# def test_reslogit_comparison_with_simple_mnl():
# """Tests that ResLogit can fit better than SimpleMNL."""
# full_dataset = load_swissmetro() # Use the full dataset to compare the models

# reslogit = ResLogit(
# intercept="item", lr=1e-6, n_layers=0, epochs=100, optimizer="SGD", batch_size=32
# )
# reslogit_indexes, reslogit_initial_weights = reslogit.instantiate(
# n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
# )
# reslogit.fit(full_dataset)
# reslogit_final_weights = reslogit.trainable_weights
# reslogit_score = reslogit.evaluate(full_dataset)

# simple_mnl = SimpleMNL(intercept="item", lr=1e-6, epochs=100, optimizer="SGD", batch_size=32)
# simple_mnl_indexes, simple_mnl_initial_weights = simple_mnl.instantiate(
# n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
# )
# simple_mnl.fit(full_dataset)
# simple_mnl_final_weights = simple_mnl.trainable_weights
# simple_mnl_score = simple_mnl.evaluate(full_dataset)

# assert reslogit_indexes == simple_mnl_indexes
# for i in range(len(reslogit_initial_weights)):
# assert np.allclose(
# simple_mnl_initial_weights[i].numpy(),
# reslogit_initial_weights[i].numpy(),
# rtol=0,
# atol=0.01,
# )
# assert np.abs(simple_mnl_score - reslogit_score) < 0.05
# for i in range(len(reslogit_final_weights)):
# assert np.allclose(
# simple_mnl_final_weights[i].numpy(),
# reslogit_final_weights[i].numpy(),
# rtol=0,
# atol=0.01,
# )


def test_reslogit_different_n_layers():
Expand All @@ -166,11 +183,12 @@ def test_reslogit_different_n_layers():

for n_layers in [0, 1, 4, 16]:
model = ResLogit(n_layers=n_layers, lr=1e-6, epochs=20, optimizer="SGD", batch_size=32)
model.fit(dataset)
model.evaluate(dataset)

# The model can fit
assert model.evaluate(dataset) < 1.0
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before

# The global shape of the residual weights corresponds to the number of layers
assert len(model.resnet_model.trainable_variables) == n_layers
Expand All @@ -197,12 +215,12 @@ def test_reslogit_different_layers_width():
optimizer="SGD",
batch_size=32,
)
model.fit(dataset)
model.evaluate(dataset)

# The model can fit
# (We don't check the exact value because the model is not optimized for this test)
assert model.evaluate(dataset) < 1e3
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before

# The global shape of the residual weights corresponds to the number of layers
assert len(model.resnet_model.trainable_variables) == n_layers
Expand Down Expand Up @@ -268,11 +286,12 @@ def test_reslogit_different_activation():
optimizer="SGD",
batch_size=32,
)
model.fit(dataset)
model.evaluate(dataset)

# The model can fit
assert model.evaluate(dataset) < 1
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
eval_after = model.evaluate(dataset)
assert eval_after <= eval_before

# Check if the ValueError is raised when the activation is not implemented
model = ResLogit(
Expand Down

0 comments on commit 9469f9d

Please sign in to comment.