diff --git a/choice_learn/datasets/base.py b/choice_learn/datasets/base.py index b4d864ac..d6da3e1e 100644 --- a/choice_learn/datasets/base.py +++ b/choice_learn/datasets/base.py @@ -645,7 +645,7 @@ def load_modecanada( if row.choice == 1: named_choice[n_row - 1] = row.alt - canada_df["choice"] = named_choice + canada_df["named_choice"] = named_choice if as_frame: if split_features: @@ -667,6 +667,9 @@ def load_modecanada( items_features_by_choice, choices, ) + if choice_format == "items_id": + canada_df["choice"] = canada_df["named_choice"] + canada_df = canada_df.drop("named_choice", axis=1) return canada_df if split_features: @@ -687,7 +690,7 @@ def load_modecanada( cf.append(context_df.loc[context_df.alt == item][items_features].to_numpy()[0]) cav.append(1) else: - cf.append([0.0, 0.0, 0.0, 0.0]) + cf.append([0.0 for _ in range(len(items_features))]) cav.append(0) cif.append(cf) ci_av.append(cav) diff --git a/tests/data/test_data.csv b/tests/data/test_data.csv new file mode 100644 index 00000000..ff80f335 --- /dev/null +++ b/tests/data/test_data.csv @@ -0,0 +1,21 @@ +,srch_id,prop_id,site_id +0,1,893,12 +1,1,10404,12 +2,1,21315,12 +3,1,27348,12 +4,1,29604,12 +5,1,30184,12 +6,1,44147,12 +7,1,50984,12 +8,1,53341,12 +9,1,56880,12 +10,1,59267,12 +11,1,59526,12 +12,1,68914,12 +13,1,74474,12 +14,1,81437,12 +15,1,85728,12 +16,1,88096,12 +17,1,88127,12 +18,1,88218,12 +19,1,89073,12 diff --git a/tests/unit_tests/datasets/test_expedia.py b/tests/unit_tests/datasets/test_expedia.py new file mode 100644 index 00000000..85187c30 --- /dev/null +++ b/tests/unit_tests/datasets/test_expedia.py @@ -0,0 +1,11 @@ +"""Unit testing for Expedia loader.""" + +import pytest + +from choice_learn.datasets import load_expedia + + +def test_raise_filenotfound(): + """Test that error raised if no file exist.""" + with pytest.raises(FileNotFoundError): + load_expedia() diff --git a/tests/unit_tests/test_os_datasets.py b/tests/unit_tests/test_os_datasets.py index 20372e32..87ac5b72 100644 --- a/tests/unit_tests/test_os_datasets.py +++ b/tests/unit_tests/test_os_datasets.py @@ -15,6 +15,7 @@ load_tafeng, load_train, ) +from choice_learn.datasets.base import load_csv, load_gzip, slice_from_names def test_swissmetro_loader(): @@ -29,15 +30,73 @@ def test_swissmetro_loader(): assert isinstance(swissmetro, ChoiceDataset) +def test_swissmetro_long_format(): + """Test loading the Swissmetro dataset in long format.""" + swissmetro = load_swissmetro(as_frame=True, preprocessing="long_format") + assert isinstance(swissmetro, pd.DataFrame) + assert swissmetro.shape == (30474, 7) + + +def test_swissmetro_tastenet(): + """Test TasteNet preprocessing of dataset.""" + _ = load_swissmetro(preprocessing="tastenet") + + +def test_swissmetro_tutorial(): + """Test tutorial preprocessing of dataset.""" + _ = load_swissmetro(preprocessing="tutorial") + + +def test_biogeme_nested_tutorial(): + """Test biogeme_nested preprocessing of dataset.""" + _ = load_swissmetro(preprocessing="biogeme_nested") + + +def test_rumnet_tutorial(): + """Test rumnet preprocessing of dataset.""" + _ = load_swissmetro(preprocessing="rumnet") + + def test_modecanada_loader(): """Test loading the Canada dataset.""" - canada = load_modecanada(as_frame=True) + canada = load_modecanada(as_frame=True, choice_format="items_id") assert isinstance(canada, pd.DataFrame) assert canada.shape == (15520, 11) canada = load_modecanada() assert isinstance(canada, ChoiceDataset) + ca, na, da = load_modecanada( + as_frame=True, + add_items_one_hot=True, + add_is_public=True, + choice_format="items_id", + split_features=True, + ) + assert ca.shape == (4324, 4) + assert na.shape == (15520, 11) + assert da.shape == (4324, 2) + + +def test_modecanada_features_split(): + """Test that features are split well.""" + ( + o, + ca, + na, + da, + ) = load_modecanada(add_items_one_hot=True, add_is_public=True, split_features=True) + assert o.shape == (4324, 3) + assert ca.shape == (4324, 4, 9) + assert na.shape == (4324, 4) + assert da.shape == (4324,) + + +def test_modecanada_loader_2(): + """Test loading the Canada dataset w/ preprocessing.""" + canada = load_modecanada(preprocessing="tutorial", add_items_one_hot=True) + assert isinstance(canada, ChoiceDataset) + def test_electricity_loader(): """Test loading the Electricity dataset.""" @@ -324,3 +383,23 @@ def test_londonpassenger_loader(): "distance", ] assert londonpassenger.shared_features_by_choice_names[0] == expected_shared_features_names + + +def test_description(): + """Test getting description.""" + _ = load_swissmetro(return_desc=True) + _ = load_modecanada(return_desc=True) + _ = load_heating(return_desc=True) + _ = load_electricity(return_desc=True) + _ = load_train(return_desc=True) + _ = load_car_preferences(return_desc=True) + _ = load_hc(return_desc=True) + _ = load_londonpassenger(return_desc=True) + _ = load_tafeng(return_desc=True) + + +def test_load_csv(): + """Test csv file loader.""" + _ = load_csv(data_file_name="test_data.csv", data_module="tests/data") + names, data = load_gzip("swissmetro.csv.gz", data_module="choice_learn/datasets/data") + _ = slice_from_names(data, names[:4], names)