Skip to content

Commit

Permalink
ENH: changes in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Nov 30, 2024
1 parent 0b65b46 commit e94df73
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions notebooks/introduction/3_model_clogit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1361,7 +1361,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "tf_env",
"language": "python",
"name": "python3"
},
Expand All @@ -1375,7 +1375,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.4"
}
},
"nbformat": 4,
Expand Down
9 changes: 4 additions & 5 deletions tests/integration_tests/models/test_latent_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_latent_simple_mnl():
"""Test the simple latent class model fit() method."""
tf.config.run_functions_eagerly(True)
lc_model = LatentClassSimpleMNL(
n_latent_classes=3, fit_method="mle", optimizer="lbfgs", epochs=1000, lbfgs_tolerance=1e-20
n_latent_classes=3, fit_method="mle", optimizer="lbfgs", epochs=1000, lbfgs_tolerance=1e-8
)
_, _ = lc_model.fit(elec_dataset)

Expand All @@ -25,7 +25,7 @@ def test_latent_clogit():
"""Test the conditional logit latent class model fit() method."""
tf.config.run_functions_eagerly(True)
lc_model = LatentClassConditionalLogit(
n_latent_classes=3, fit_method="mle", optimizer="lbfgs", epochs=1000, lbfgs_tolerance=1e-12
n_latent_classes=3, fit_method="mle", optimizer="lbfgs", epochs=100, lbfgs_tolerance=1e-8
)
lc_model.add_shared_coefficient(
coefficient_name="pf", feature_name="pf", items_indexes=[0, 1, 2, 3]
Expand Down Expand Up @@ -60,7 +60,7 @@ def test_manual_lc():
fit_method="mle",
epochs=1000,
optimizer="lbfgs",
lbfgs_tolerance=1e-12,
lbfgs_tolerance=1e-8,
)

manual_lc.instantiate(n_items=4, n_shared_features=0, n_items_features=6)
Expand All @@ -79,8 +79,7 @@ def test_manual_lc_gd():
epochs=1000,
optimizer="Adam",
)
nll_before = manual_lc.evaluate(elec_dataset)
manual_lc.instantiate(n_items=4, n_shared_features=0, n_items_features=6)
nll_before = manual_lc.evaluate(elec_dataset)
_ = manual_lc.fit(elec_dataset)
manual_lc.compute_report(elec_dataset)
assert manual_lc.evaluate(elec_dataset) < nll_before
10 changes: 5 additions & 5 deletions tests/integration_tests/models/test_simple_mnl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
dataset = load_swissmetro()


def test_simple_mnl_lbfgs_fit_with_lbfgs():
def test_simple_mnl_fit_with_lbfgs():
"""Tests that SimpleMNL can fit with LBFGS."""
tf.config.run_functions_eagerly(True)
global dataset

model = SimpleMNL(epochs=20)
model.fit(dataset)
model.evaluate(dataset)
_ = model.fit(dataset, get_report=True)
_ = model.evaluate(dataset)
assert model.evaluate(dataset) < 1.0


Expand All @@ -39,8 +39,8 @@ def test_that_endpoints_run():
global dataset

model = SimpleMNL(epochs=20)
model.fit(dataset)
model.compute_report(dataset)
_ = model.fit(dataset)
_ = model.compute_report(dataset)
model.evaluate(dataset)
model.predict_probas(dataset)
assert True

0 comments on commit e94df73

Please sign in to comment.