Skip to content

Commit

Permalink
ADD: basic unit tools tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Dec 28, 2024
1 parent f104688 commit 7434773
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions tests/unit_tests/tools/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,47 @@

import numpy as np

from choice_learn.toolbox.assortment_optimizer import MNLAssortmentOptimizer, LatentClassAssortmentOptimizer, LatentClassPricingOptimizer
from choice_learn.toolbox.assortment_optimizer import (
LatentClassAssortmentOptimizer,
LatentClassPricingOptimizer,
MNLAssortmentOptimizer,
)

solvers = ["or-tools"]


def test_mnl_assort_instantiate():
"""Test instantiation with both solvers."""
for solv in solvers:
MNLAssortmentOptimizer(
solver=solv,
utilities=np.array([1., 2., 3.]),
itemwise_values=np.array([0.5, 0.5, 0.5]),
assortment_size=2)
solver=solv,
utilities=np.array([1.0, 2.0, 3.0]),
itemwise_values=np.array([0.5, 0.5, 0.5]),
assortment_size=2,
)


def test_lc_assort_instantiate():
"""Test instantiation with both solvers."""
for solv in solvers:
LatentClassAssortmentOptimizer(
solver=solv,
class_weights=np.array([.2, .8]),
class_utilities=np.array([[1., 2., 3.], [3., 2., 1.]]),
itemwise_values=np.array([0.5, 0.5, 0.5]),
assortment_size=2)
solver=solv,
class_weights=np.array([0.2, 0.8]),
class_utilities=np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]]),
itemwise_values=np.array([0.5, 0.5, 0.5]),
assortment_size=2,
)


def test_lc_pricing_instantiate():
"""Test instantiation with both solvers."""
for solv in solvers:
LatentClassPricingOptimizer(
solver=solv,
class_weights=np.array([.2, .8]),
class_utilities=np.array([[[1., 1.1], [2., 2.1], [3., 3.1]],
[[3., 3.1], [2., 2.1], [1., 1.1]]]),
itemwise_values=np.array([[0.5, 1.2], [0.5, 1.2], [0.5, 1.2]]),
assortment_size=2)
solver=solv,
class_weights=np.array([0.2, 0.8]),
class_utilities=np.array(
[[[1.0, 1.1], [2.0, 2.1], [3.0, 3.1]], [[3.0, 3.1], [2.0, 2.1], [1.0, 1.1]]]
),
itemwise_values=np.array([[0.5, 1.2], [0.5, 1.2], [0.5, 1.2]]),
assortment_size=2,
)

0 comments on commit 7434773

Please sign in to comment.