Skip to content

Commit

Permalink
TST: fix test_reslogit.py (#165)
Browse files Browse the repository at this point in the history
* TST: small change

* TST: change dataset size in test_reslogit.py

* TST: other small changes

* TST: fix optimizer error

* TST: fix typo
  • Loading branch information
julesdesir authored Oct 3, 2024
1 parent c011a58 commit b82ee45
Showing 1 changed file with 63 additions and 35 deletions.
98 changes: 63 additions & 35 deletions tests/integration_tests/models/test_reslogit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,20 @@
from choice_learn.models import ResLogit

dataset = load_swissmetro()
dataset = dataset[:10] # Reduce the dataset size for faster testing
dataset = dataset[:100] # Reduce the dataset size for faster testing
n_items = np.shape(dataset.items_features_by_choice)[2]
n_shared_features = np.shape(dataset.shared_features_by_choice)[2]
n_items_features = np.shape(dataset.items_features_by_choice)[3]

nb_epochs = 100
batch_size = -1


def test_reslogit_fit_with_sgd():
"""Tests that ResLogit can fit with SGD."""
global dataset

model = ResLogit(lr=1e-6, epochs=30, optimizer="SGD", batch_size=32)
model = ResLogit(lr=1e-3, epochs=nb_epochs, optimizer="SGD", batch_size=batch_size)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
tf.config.run_functions_eagerly(True) # To help with the coverage calculation
Expand All @@ -32,7 +35,7 @@ def test_reslogit_fit_with_adam():
"""Tests that ResLogit can fit with Adam."""
global dataset

model = ResLogit(epochs=20, optimizer="Adam", batch_size=32)
model = ResLogit(lr=1e-3, epochs=nb_epochs, optimizer="Adam", batch_size=batch_size)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
Expand All @@ -44,7 +47,7 @@ def test_reslogit_fit_with_adamax():
"""Tests that ResLogit can fit with Adamax."""
global dataset

model = ResLogit(epochs=20, optimizer="Adamax", batch_size=32)
model = ResLogit(lr=1e-3, epochs=nb_epochs, optimizer="Adamax", batch_size=batch_size)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
Expand All @@ -59,7 +62,9 @@ def test_reslogit_fit_with_optimizer_not_implemented():
"""
global dataset

model = ResLogit(epochs=20, optimizer="xyz_not_implemented", batch_size=32)
model = ResLogit(
lr=1e-3, epochs=nb_epochs, optimizer="xyz_not_implemented", batch_size=batch_size
)
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
model.fit(dataset)
Expand All @@ -71,7 +76,9 @@ def test_reslogit_fit_with_none_intercept():
"""Tests that ResLogit can fit with intercept=None."""
global dataset

model = ResLogit(intercept=None, lr=1e-6, epochs=20, optimizer="SGD", batch_size=32)
model = ResLogit(
intercept=None, lr=1e-3, epochs=nb_epochs, optimizer="Adam", batch_size=batch_size
)

indexes, weights = model.instantiate(
n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
Expand All @@ -89,7 +96,9 @@ def test_reslogit_fit_with_item_intercept():
"""Tests that ResLogit can fit with intercept="item"."""
global dataset

model = ResLogit(intercept="item", lr=1e-6, epochs=20, optimizer="SGD", batch_size=32)
model = ResLogit(
intercept="item", lr=1e-3, epochs=nb_epochs, optimizer="Adam", batch_size=batch_size
)

indexes, weights = model.instantiate(
n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
Expand All @@ -106,7 +115,9 @@ def test_reslogit_fit_with_item_full_intercept():
"""Tests that ResLogit can fit with intercept="item-full"."""
global dataset

model = ResLogit(intercept="item-full", lr=1e-6, epochs=20, optimizer="SGD", batch_size=32)
model = ResLogit(
intercept="item-full", lr=1e-3, epochs=nb_epochs, optimizer="Adam", batch_size=batch_size
)

indexes, weights = model.instantiate(
n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
Expand All @@ -124,7 +135,11 @@ def test_reslogit_fit_with_other_intercept():
global dataset

model = ResLogit(
intercept="xyz_other_intercept", lr=1e-6, epochs=20, optimizer="SGD", batch_size=32
intercept="xyz_other_intercept",
lr=1e-3,
epochs=nb_epochs,
optimizer="Adam",
batch_size=batch_size,
)

indexes, weights = model.instantiate(
Expand All @@ -144,7 +159,12 @@ def test_reslogit_fit_with_other_intercept():
# full_dataset = load_swissmetro() # Use the full dataset to compare the models

# reslogit = ResLogit(
# intercept="item", lr=1e-6, n_layers=0, epochs=100, optimizer="SGD", batch_size=32
# intercept="item",
# n_layers=0,
# lr=1e-3,
# epochs=nb_epochs,
# optimizer="Adam",
# batch_size=batch_size
# )
# reslogit_indexes, reslogit_initial_weights = reslogit.instantiate(
# n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
Expand All @@ -153,7 +173,13 @@ def test_reslogit_fit_with_other_intercept():
# reslogit_final_weights = reslogit.trainable_weights
# reslogit_score = reslogit.evaluate(full_dataset)

# simple_mnl = SimpleMNL(intercept="item", lr=1e-6, epochs=100, optimizer="SGD", batch_size=32)
# simple_mnl = SimpleMNL(
# intercept="item",
# lr=1e-3,
# epochs=nb_epochs,
# optimizer="Adam",
# batch_size=batch_size
# )
# simple_mnl_indexes, simple_mnl_initial_weights = simple_mnl.instantiate(
# n_items=n_items, n_shared_features=n_shared_features, n_items_features=n_items_features
# )
Expand Down Expand Up @@ -183,8 +209,10 @@ def test_reslogit_different_n_layers():
"""Tests that ResLogit can fit with different n_layers."""
global dataset

for n_layers in [0, 1, 4, 16]:
model = ResLogit(n_layers=n_layers, lr=1e-6, epochs=20, optimizer="SGD", batch_size=32)
for n_layers in [0, 1, 4]:
model = ResLogit(
n_layers=n_layers, lr=1e-3, epochs=nb_epochs, optimizer="Adam", batch_size=batch_size
)
# The model can fit
model.instantiate(n_items, n_shared_features, n_items_features)
eval_before = model.evaluate(dataset)
Expand All @@ -205,17 +233,17 @@ def test_reslogit_different_layers_width():
"""Tests that ResLogit can fit with different custom widths for its residual layers."""
global dataset

list_n_layers = [0, 1, 4, 16]
list_res_layers_width = [[], [], [128, 256, n_items], [2, 4, 8, 16] * 3 + [32, 64, n_items]]
list_n_layers = [0, 1, 4]
list_res_layers_width = [[], [], [128, 256, n_items]]

for n_layers, res_layers_width in zip(list_n_layers, list_res_layers_width):
model = ResLogit(
n_layers=n_layers,
res_layers_width=res_layers_width,
lr=1e-4,
epochs=20,
lr=1e-3,
epochs=nb_epochs,
optimizer="Adam",
batch_size=-1,
batch_size=batch_size,
)
# The model can fit
model.instantiate(n_items, n_shared_features, n_items_features)
Expand Down Expand Up @@ -246,10 +274,10 @@ def test_reslogit_different_layers_width():
model = ResLogit(
n_layers=4,
res_layers_width=[2, 4, 8, n_items],
lr=1e-6,
epochs=20,
optimizer="SGD",
batch_size=32,
lr=1e-3,
epochs=nb_epochs,
optimizer="Adam",
batch_size=batch_size,
)
try:
model.fit(dataset)
Expand All @@ -261,10 +289,10 @@ def test_reslogit_different_layers_width():
model = ResLogit(
n_layers=4,
res_layers_width=[2, 4, 8, 16],
lr=1e-6,
epochs=20,
optimizer="SGD",
batch_size=32,
lr=1e-3,
epochs=nb_epochs,
optimizer="Adam",
batch_size=batch_size,
)
try:
model.fit(dataset)
Expand All @@ -284,10 +312,10 @@ def test_reslogit_different_activation():
model = ResLogit(
n_layers=2,
activation=activation_str,
lr=1e-6,
epochs=20,
optimizer="SGD",
batch_size=32,
lr=1e-3,
epochs=nb_epochs,
optimizer="Adam",
batch_size=batch_size,
)
# The model can fit
model.instantiate(n_items, n_shared_features, n_items_features)
Expand All @@ -300,10 +328,10 @@ def test_reslogit_different_activation():
model = ResLogit(
n_layers=2,
activation="xyz_not_implemented",
lr=1e-6,
epochs=20,
optimizer="SGD",
batch_size=32,
lr=1e-3,
epochs=nb_epochs,
optimizer="Adam",
batch_size=batch_size,
)
try:
model.fit(dataset)
Expand All @@ -320,7 +348,7 @@ def test_that_endpoints_run():
"""
global dataset

model = ResLogit(epochs=20)
model = ResLogit(epochs=nb_epochs)
model.fit(dataset)
model.evaluate(dataset)
model.predict_probas(dataset)
Expand Down

0 comments on commit b82ee45

Please sign in to comment.