From a5109917559ec2d7e6e458206490eda434b52052 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Mon, 15 Jan 2024 15:53:54 +0100 Subject: [PATCH] ADD: tests with new signature --- notebooks/features_storage_example.ipynb | 1 - tests/unit_tests/data/test_choice_dataset.py | 226 +++++++++--------- tests/unit_tests/data/test_store.py | 237 ++++++++++++------- 3 files changed, 263 insertions(+), 201 deletions(-) diff --git a/notebooks/features_storage_example.ipynb b/notebooks/features_storage_example.ipynb index 43200b9b..5e8df68a 100644 --- a/notebooks/features_storage_example.ipynb +++ b/notebooks/features_storage_example.ipynb @@ -36,7 +36,6 @@ "outputs": [], "source": [ "features = {\"customerA\": [1, 2, 3], \"customerB\": [4, 5, 6], \"customerC\": [7, 8, 9]}\n", - "\n", "storage = FeaturesStorage(values=features, values_names=[\"age\", \"income\", \"children_nb\"], name=\"customers\")" ] }, diff --git a/tests/unit_tests/data/test_choice_dataset.py b/tests/unit_tests/data/test_choice_dataset.py index 0ef46728..b288cde3 100644 --- a/tests/unit_tests/data/test_choice_dataset.py +++ b/tests/unit_tests/data/test_choice_dataset.py @@ -4,7 +4,7 @@ from choice_learn.data.choice_dataset import ChoiceDataset -items_features = [ +fixed_items_features = [ [1, 2], # item 1 [size, weight] [2, 4], # item 2 [size, weight] [1.5, 1.5], # item 3 [size, weight] @@ -17,19 +17,19 @@ # Customer 2 bought item 3 at session 2 choices = [0, 2, 1] -sessions_items_availabilities = [ +contexts_items_availabilities = [ [1, 1, 1], # All items available at session 1 [1, 1, 1], # All items available at session 2 [0, 1, 1], # Item 1 not available at session 3 ] -sessions_features = [ +contexts_features = [ [100, 20], # session 1, customer 1 [budget, age] [200, 40], # session 2, customer 2 [budget, age] [80, 20], # session 3, customer 1 [budget, age] ] -sessions_items_features = [ +contexts_items_features = [ [ [100, 0], # Session 1, Item 1 [price, promotion] [140, 0], # Session 1, Item 2 [price, promotion] @@ -53,19 +53,10 @@ def test_instantiate_len(): """Test the __init__ method.""" choices = [0, 2, 1] dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, - choices=choices, - ) - assert len(dataset) == 3 - choices = [[0], [1, 2], [2, 1, 1, 1]] - dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) assert len(dataset) == 3 @@ -76,10 +67,10 @@ def test_fail_instantiate(): choices = [0, 1] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -93,10 +84,10 @@ def test_fail_instantiate_2(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -109,10 +100,10 @@ def test_fail_instantiate_3(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -125,10 +116,10 @@ def test_fail_instantiate_10(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=sessions_items_availabilities, choices=choices, ) @@ -142,10 +133,10 @@ def test_fail_instantiate_4(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=sessions_items_availabilities, choices=choices, ) @@ -158,10 +149,10 @@ def test_fail_instantiate_5(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=sessions_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -175,10 +166,10 @@ def test_fail_instantiate_6(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=sessions_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -199,10 +190,10 @@ def test_fail_instantiate_7(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=sessions_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -226,10 +217,10 @@ def test_fail_instantiate_8(): ] with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=sessions_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -239,10 +230,10 @@ def test_fail_instantiate_9(): choices = [0, 4, 2] # choices higher than nb of items with pytest.raises(ValueError): ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) @@ -250,15 +241,14 @@ def test_fail_instantiate_9(): def test_shape(): """Tests get shape methods.""" dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) assert dataset.get_num_items() == 3 - assert dataset.get_num_sessions() == 3 assert dataset.get_num_choices() == 3 @@ -269,7 +259,7 @@ def test_from_df(): "item_id": [0, 1, 2, 0, 1, 2, 1, 2], "items_feat_1": [1, 2, 1.5, 1, 2, 1.5, 2, 1.5], "items_feat_2": [2, 4, 1.5, 2, 4, 1.5, 4, 1.5], - "session_id": [0, 0, 0, 1, 1, 1, 2, 2], + "context_id": [0, 0, 0, 1, 1, 1, 2, 2], "session_feat_1": [100, 100, 100, 200, 200, 200, 80, 80], "session_feat_2": [20, 20, 20, 40, 40, 40, 20, 20], "session_item_feat_1": [100, 140, 200, 100, 120, 200, 120, 180], @@ -279,27 +269,29 @@ def test_from_df(): ) cd_test = ChoiceDataset.from_single_df( features_df, - items_features_columns=["items_feat_1", "items_feat_2"], - sessions_features_columns=["session_feat_1", "session_feat_2"], - sessions_items_features_columns=["session_item_feat_1", "session_item_feat_2"], - choice_mode="item_id", + fixed_items_features_columns=["items_feat_1", "items_feat_2"], + contexts_features_columns=["session_feat_1", "session_feat_2"], + contexts_items_features_columns=["session_item_feat_1", "session_item_feat_2"], + choice_mode="items_id", ) ground_truth_cd = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) - assert (cd_test.items_features[0] == ground_truth_cd.items_features[0]).all() - assert (cd_test.sessions_features[0] == ground_truth_cd.sessions_features[0]).all() + assert (cd_test.fixed_items_features[0] == ground_truth_cd.fixed_items_features[0]).all() + assert (cd_test.contexts_features[0] == ground_truth_cd.contexts_features[0]).all() assert ( - cd_test.sessions_items_features[0].astype("float32") - == ground_truth_cd.sessions_items_features[0].astype("float32") + cd_test.contexts_items_features[0].astype("float32") + == ground_truth_cd.contexts_items_features[0].astype("float32") ).all() assert ( - cd_test.sessions_items_availabilities == ground_truth_cd.sessions_items_availabilities + cd_test.contexts_items_availabilities == ground_truth_cd.contexts_items_availabilities ).all() + print(cd_test.choices) + print(cd_test.fixed_items_features) assert (cd_test.choices == ground_truth_cd.choices).all() features_df = pd.DataFrame( @@ -307,7 +299,7 @@ def test_from_df(): "item_id": [0, 1, 2, 0, 1, 2, 1, 2], "items_feat_1": [1, 2, 1.5, 1, 2, 1.5, 2, 1.5], "items_feat_2": [2, 4, 1.5, 2, 4, 1.5, 4, 1.5], - "session_id": [0, 0, 0, 1, 1, 1, 2, 2], + "context_id": [0, 0, 0, 1, 1, 1, 2, 2], "session_feat_1": [100, 100, 100, 200, 200, 200, 80, 80], "session_feat_2": [20, 20, 20, 40, 40, 40, 20, 20], "session_item_feat_1": [100, 140, 200, 100, 120, 200, 120, 180], @@ -317,26 +309,26 @@ def test_from_df(): ) cd_test = ChoiceDataset.from_single_df( features_df, - items_features_columns=["items_feat_1", "items_feat_2"], - sessions_features_columns=["session_feat_1", "session_feat_2"], - sessions_items_features_columns=["session_item_feat_1", "session_item_feat_2"], + fixed_items_features_columns=["items_feat_1", "items_feat_2"], + contexts_features_columns=["session_feat_1", "session_feat_2"], + contexts_items_features_columns=["session_item_feat_1", "session_item_feat_2"], choice_mode="one_zero", ) ground_truth_cd = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) - assert (cd_test.items_features[0] == ground_truth_cd.items_features[0]).all() - assert (cd_test.sessions_features[0] == ground_truth_cd.sessions_features[0]).all() + assert (cd_test.fixed_items_features[0] == ground_truth_cd.fixed_items_features[0]).all() + assert (cd_test.contexts_features[0] == ground_truth_cd.contexts_features[0]).all() assert ( - cd_test.sessions_items_features[0].astype("float32") - == ground_truth_cd.sessions_items_features[0].astype("float32") + cd_test.contexts_items_features[0].astype("float32") + == ground_truth_cd.contexts_items_features[0].astype("float32") ).all() assert ( - cd_test.sessions_items_availabilities == ground_truth_cd.sessions_items_availabilities + cd_test.contexts_items_availabilities == ground_truth_cd.contexts_items_availabilities ).all() assert (cd_test.choices == ground_truth_cd.choices).all() @@ -344,10 +336,10 @@ def test_from_df(): def test_summary(): """Tests summary method.""" dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) dataset.summary() @@ -357,21 +349,21 @@ def test_summary(): def test_getitem(): """Tests getitem method.""" dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) sub_dataset = dataset[[0, 1]] - assert (sub_dataset.items_features[0] == dataset.items_features[0]).all() - assert (sub_dataset.sessions_features[0] == dataset.sessions_features[0][[0, 1]]).all() + assert (sub_dataset.fixed_items_features[0] == dataset.fixed_items_features[0]).all() + assert (sub_dataset.contexts_features[0] == dataset.contexts_features[0][[0, 1]]).all() assert ( - sub_dataset.sessions_items_features[0] == dataset.sessions_items_features[0][[0, 1]] + sub_dataset.contexts_items_features[0] == dataset.contexts_items_features[0][[0, 1]] ).all() assert ( - sub_dataset.sessions_items_availabilities == dataset.sessions_items_availabilities[[0, 1]] + sub_dataset.contexts_items_availabilities == dataset.contexts_items_availabilities[[0, 1]] ).all() assert (sub_dataset.choices == dataset.choices[[0, 1]]).all() assert (sub_dataset.choices == [0, 2]).all() @@ -380,17 +372,17 @@ def test_getitem(): def test_batch(): """Tests the batch method.""" dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) batch = dataset.batch[[0, 1]] - assert (batch[0] == items_features).all() - assert (batch[1] == sessions_features[:2]).all() - assert (batch[2] == sessions_items_features[:2]).all() - assert (batch[3] == sessions_items_availabilities[:2]).all() + assert (batch[0] == fixed_items_features).all() + assert (batch[1] == contexts_features[:2]).all() + assert (batch[2] == contexts_items_features[:2]).all() + assert (batch[3] == contexts_items_availabilities[:2]).all() assert (batch[4] == choices[:2]).all() sliced_batch = dataset.batch[:2] @@ -411,10 +403,10 @@ def test_batch(): def test_iter_batch(): """Tests the iter_batch method.""" dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) for batch_nb, batch in enumerate(dataset.iter_batch(batch_size=2)): @@ -429,10 +421,10 @@ def test_iter_batch(): def test_filter(): """Tests the filter method.""" dataset = ChoiceDataset( - items_features=items_features, - sessions_features=sessions_features, - sessions_items_features=sessions_items_features, - sessions_items_availabilities=sessions_items_availabilities, + fixed_items_features=fixed_items_features, + contexts_features=contexts_features, + contexts_items_features=contexts_items_features, + contexts_items_availabilities=contexts_items_availabilities, choices=choices, ) filtered_dataset = dataset.filter([True, False, True]) diff --git a/tests/unit_tests/data/test_store.py b/tests/unit_tests/data/test_store.py index d1fd3df8..9fdd1767 100644 --- a/tests/unit_tests/data/test_store.py +++ b/tests/unit_tests/data/test_store.py @@ -1,133 +1,204 @@ """Test the store module.""" -from choice_learn.data.store import FeaturesStore, OneHotStore, Store +import numpy as np +import pandas as pd + +from choice_learn.data.storage import FeaturesStorage, OneHotStorage def test_len_store(): - """Test the __len__ method of Store.""" - store = Store(values=[1, 2, 3, 4], sequence=[0, 1, 2, 3, 0, 1, 2, 3]) - assert len(store) == 8 + """Test the __len__ method of Storage.""" + features = {"customerA": [1, 2], "customerB": [4, 5], "customerC": [7, 8]} + storage = FeaturesStorage( + values=features, values_names=["age", "income", "children_nb"], name="customers" + ) + assert len(storage) == 3 + assert storage.shape == (3, 2) def test_get_store_element(): """Test the _get_store_element method of Store.""" - store = Store(values=[1, 2, 3, 4], sequence=[0, 1, 2, 3, 0, 1, 2, 3]) - assert store._get_store_element(0) == 1 - assert store._get_store_element([0, 1, 2]) == [1, 2, 3] + features = {"customerA": [1, 2], "customerB": [4, 5], "customerC": [7, 8]} + storage = FeaturesStorage( + values=features, values_names=["age", "income", "children_nb"], name="customers" + ) + assert (storage.get_element_from_index(0) == np.array([1, 2])).all() + assert (storage.get_element_from_index([0, 1, 2]) == np.array([[1, 2], [4, 5], [7, 8]])).all() def test_store_batch(): """Test the batch method of Store.""" - store = Store(values=[1, 2, 3, 4], sequence=[0, 1, 2, 3, 0, 1, 2, 3]) - assert store.batch[1] == 2 - assert store.batch[2:4] == [3, 4] - assert store.batch[[2, 3, 6, 7]] == [3, 4, 3, 4] + features = {"customerA": [1, 2], "customerB": [4, 5], "customerC": [7, 8]} + storage = FeaturesStorage( + values=features, values_names=["age", "income", "children_nb"], name="customers" + ) + assert (storage.batch["customerA"] == np.array([1, 2])).all() + assert ( + storage.batch[["customerA", "customerC", "customerA", "customerC"]] + == np.array([[1, 2], [7, 8], [1, 2], [7, 8]]) + ).all() def test_featuresstore_instantiation(): """Test the instantiation of FeaturesStore.""" - store = FeaturesStore( - values=[[10, 10], [4, 4], [2, 2], [8, 8]], - sequence=[0, 1, 2, 3, 0, 1, 2, 3], - indexes=[0, 1, 2, 3], + features = {"customerA": [1, 2], "customerB": [4, 5], "customerC": [7, 8]} + storage = FeaturesStorage( + values=features, values_names=["age", "income", "children_nb"], name="customers" ) - assert store.shape == (8, 2) - assert [store.sequence[i] == [0, 1, 2, 3, 0, 1, 2, 3][i] for i in range(8)] - assert store.store == {0: [10, 10], 1: [4, 4], 2: [2, 2], 3: [8, 8]} + + for k, v in storage.storage.items(): + assert ( + v + == { + "customerA": np.array([1, 2]), + "customerB": np.array([4, 5]), + "customerC": np.array([7, 8]), + }[k] + ).all() def test_featuresstore_instantiation_indexless(): """Test the instantiation of FeaturesStore.""" - store = FeaturesStore( - values=[[10, 10], [4, 4], [2, 2], [8, 8]], sequence=[0, 1, 2, 3, 0, 1, 2, 3] + features = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + ids = ["customerA", "customerB", "customerC"] + + storage = FeaturesStorage( + ids=ids, values=features, values_names=["age", "income", "children_nb"], name="customers" ) - assert store.shape == (8, 2) - assert [store.sequence[i] == [0, 1, 2, 3, 0, 1, 2, 3][i] for i in range(8)] - assert store.store == {0: [10, 10], 1: [4, 4], 2: [2, 2], 3: [8, 8]} + assert storage.shape == (3, 3) + for k, v in storage.storage.items(): + assert ( + v + == { + "customerA": np.array([1, 2, 3]), + "customerB": np.array([4, 5, 6]), + "customerC": np.array([7, 8, 9]), + }[k] + ).all() def test_featuresstore_instantiation_from_list(): """Test the instantiation of FeaturesStore.""" - store = FeaturesStore.from_list( - values_list=[[10, 10], [4, 4], [2, 2], [8, 8]], sequence=[0, 1, 2, 3, 0, 1, 2, 3] + features = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + storage = FeaturesStorage( + values=features, values_names=["age", "income", "children_nb"], name="customers" ) - assert store.shape == (8, 2) - assert [store.sequence[i] == [0, 1, 2, 3, 0, 1, 2, 3][i] for i in range(8)] - assert store.store == {0: [10, 10], 1: [4, 4], 2: [2, 2], 3: [8, 8]} + storage.batch[[0, 2, 0, 2]] + assert storage.shape == (3, 3) + for k, v in storage.storage.items(): + assert ( + v == {0: np.array([1, 2, 3]), 1: np.array([4, 5, 6]), 2: np.array([7, 8, 9])}[k] + ).all() def test_featuresstore_instantiation_fromdict(): """Test the instantiation of FeaturesStore.""" - store = FeaturesStore.from_dict( - values_dict={0: [10, 10], 1: [4, 4], 2: [2, 2], 3: [8, 8]}, - sequence=[0, 1, 2, 3, 0, 1, 2, 3], - ) - assert store.shape == (8, 2) - assert [store.sequence[i] == [0, 1, 2, 3, 0, 1, 2, 3][i] for i in range(8)] - assert store.store == {0: [10, 10], 1: [4, 4], 2: [2, 2], 3: [8, 8]} + features = { + "age": [1, 4, 7], + "income": [2, 5, 8], + "children_nb": [3, 6, 9], + "id": ["customerA", "customerB", "customerC"], + } + features = pd.DataFrame(features) + storage = FeaturesStorage(values=features, name="customers") + assert storage.shape == (3, 3) + for k, v in storage.storage.items(): + assert ( + v + == { + "customerA": np.array([1, 2, 3]), + "customerB": np.array([4, 5, 6]), + "customerC": np.array([7, 8, 9]), + }[k] + ).all() + + +def test_featuresstore_instantiation_fromdf(): + """Test the instantiation of FeaturesStore.""" + features = {"age": [1, 4, 7], "income": [2, 5, 8], "children_nb": [3, 6, 9]} + features = pd.DataFrame(features, index=["customerA", "customerB", "customerC"]) + storage = FeaturesStorage(values=features, name="customers") + assert storage.shape == (3, 3) + for k, v in storage.storage.items(): + assert ( + v + == { + "customerA": np.array([1, 2, 3]), + "customerB": np.array([4, 5, 6]), + "customerC": np.array([7, 8, 9]), + }[k] + ).all() def test_featuresstore_getitem(): """Test the __getitem__ method of FeaturesStore.""" - store = FeaturesStore.from_dict( - values_dict={0: [10, 10], 1: [4, 4], 2: [2, 2], 3: [8, 8]}, - sequence=[0, 1, 2, 3, 0, 1, 2, 3], + features = {"customerA": [1, 2], "customerB": [4, 5], "customerC": [7, 8]} + storage = FeaturesStorage( + values=features, values_names=["age", "income", "children_nb"], name="customers" ) - sub_store = store[0:3] - assert sub_store.shape == (3, 2) - assert [sub_store.sequence[i] == [0, 1, 2][i] for i in range(3)] - assert sub_store.store == {0: [10, 10], 1: [4, 4], 2: [2, 2]} + sub_storage = storage[["customerA", "customerC"]] + assert sub_storage.shape == (2, 2) + for k, v in {"customerA": np.array([1, 2]), "customerC": np.array([7, 8])}.items(): + print(v, sub_storage.storage[k]) + assert (v == sub_storage.storage[k]).all() def test_onehotstore_instantiation(): """Test the instantiation of OneHotStore.""" - indexes = [0, 1, 2, 4] - values = [0, 1, 2, 3] - sequence = [0, 1, 2, 4, 0, 1, 2, 4] - store = OneHotStore(indexes=indexes, values=values, sequence=sequence) - assert store.shape == (8, 4) - assert [store.sequence[i] == [0, 1, 2, 4, 0, 1, 2, 4][i] for i in range(8)] - assert store.store == {0: 0, 1: 1, 2: 2, 4: 3} + ids = [0, 1, 2, 3, 4] + values = [4, 3, 2, 1, 0] + storage = OneHotStorage(ids=ids, values=values, name="OneHotTest") + assert storage.shape == (5, 5) + assert storage.storage == {0: 4, 1: 3, 2: 2, 3: 1, 4: 0} def test_onehotstore_instantiation_from_sequence(): """Test the instantiation; from_sequence of OneHotStore.""" - sequence = [0, 1, 2, 3, 0, 1, 2, 3] - store = OneHotStore.from_sequence(sequence=sequence) - assert store.shape == (8, 4) - assert [store.sequence[i] == [0, 1, 2, 3, 0, 1, 2, 3][i] for i in range(8)] - assert store.store == {0: 0, 1: 1, 2: 2, 3: 3} + values = [4, 3, 2, 1, 0] + storage = OneHotStorage(values=values, name="OneHotTest") + assert ( + storage.batch[[0, 2, 4]] == np.array([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]) + ).all() + assert storage.storage == {4: 0, 3: 1, 2: 2, 1: 3, 0: 4} -def test_onehotstore_getitem(): - """Test the getitem of OneHotStore.""" - indexes = [0, 1, 2, 4] - values = [0, 1, 2, 3] - sequence = [0, 1, 2, 4, 0, 1, 2, 4] - store = OneHotStore(indexes=indexes, values=values, sequence=sequence) - sub_store = store[0:3] - assert sub_store.shape == (3, 3) - assert [ - sub_store.sequence[i] == [0, 1, 2, 3, 0, 1, 2, 3][i] for i in range(len(sub_store.sequence)) - ] - assert sub_store.store == { - 0: 0, - 1: 1, - 2: 2, - } +def test_onehotstore_instantiation_from_ids(): + """Test the instantiation; from_sequence of OneHotStore.""" + ids = [0, 1, 2, 3, 4] + storage = OneHotStorage(ids=ids, name="OneHotTest") + assert ( + storage.batch[[0, 2, 4]] == np.array([[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]) + ).all() + assert storage.storage == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} -def test_onehotstore_batch(): - """Test the getitem of OneHotStore.""" - indexes = [0, 1, 2, 4] - values = [0, 1, 2, 3] - sequence = [0, 1, 2, 4, 0, 1, 2, 4] - store = OneHotStore(indexes=indexes, values=values, sequence=sequence) - - batch = store.batch[0] - assert (batch == [1, 0, 0, 0]).all() +def test_onehotstore_instantiation_from_dict(): + """Test the instantiation; from_sequence of OneHotStore.""" + ids = [0, 1, 2, 3, 4] + values = [4, 3, 2, 1, 0] + values_dict = {k: v for k, v in zip(ids, values)} + storage = OneHotStorage(values=values_dict, name="OneHotTest") + assert ( + storage.batch[[0, 2, 4]] == np.array([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]) + ).all() + assert storage.storage == {4: 0, 3: 1, 2: 2, 1: 3, 0: 4} - batch = store.batch[0:4] - assert (batch == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).all() - batch = store.batch[[3, 6, 7]] - assert (batch == [[0, 0, 0, 1], [0, 0, 1, 0], [0, 0, 0, 1]]).all() +def test_onehotstore_getitem(): + """Test the getitem of OneHotStore.""" + ids = [0, 1, 2, 3, 4] + values = [4, 3, 2, 1, 0] + storage = OneHotStorage(ids=ids, values=values, name="OneHotTest") + assert ( + storage.batch[[0, 2, 4]] == np.array([[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], [1, 0, 0, 0, 0]]) + ).all() + assert storage.get_element_from_index(0) == 4 + + +def test_fail_instantiation(): + """Testing failed instantiation.""" + try: + _ = OneHotStorage(name="OneHotTest") + assert False + except ValueError: + assert True