From c0baf914fd6c8e3194ccd5d4975e7293a5ec9f9d Mon Sep 17 00:00:00 2001 From: Thomas BOUCHE Date: Fri, 22 Sep 2023 17:14:14 +0200 Subject: [PATCH] fix interaction_values values and list_ind --- shapash/explainer/smart_plotter.py | 10 ++++--- .../explainer/test_smart_plotter.py | 26 +++++++++---------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/shapash/explainer/smart_plotter.py b/shapash/explainer/smart_plotter.py index be6d0bd0..0eaacb74 100644 --- a/shapash/explainer/smart_plotter.py +++ b/shapash/explainer/smart_plotter.py @@ -1896,12 +1896,12 @@ def _select_indices_interactions_plot(self, selection, max_points): list_ind = random.sample(self.explainer.x_init.index.tolist(), max_points) addnote = "Length of random Subset : " elif isinstance(selection, list): - if hasattr(self, 'interaction_selection'): - if set(self.interaction_selection).issubset(set(selection)): - list_ind = self.interaction_selection - elif len(selection) <= max_points: + if len(selection) <= max_points: list_ind = selection addnote = "Length of user-defined Subset : " + elif hasattr(self, 'interaction_selection'): + if set(selection).issubset(set(self.interaction_selection)): + list_ind = self.interaction_selection else: list_ind = random.sample(selection, max_points) addnote = "Length of random Subset : " @@ -1985,6 +1985,8 @@ def interactions_plot(self, feature_values2 = self.explainer.x_init.loc[list_ind, col_name2].to_frame() interaction_values = self.explainer.get_interaction_values(selection=list_ind)[:, col_id1, col_id2] + if col_id1 != col_id2: + interaction_values = interaction_values * 2 # selecting the best plot : Scatter, Violin? if col_value_count1 > violin_maxf: diff --git a/tests/unit_tests/explainer/test_smart_plotter.py b/tests/unit_tests/explainer/test_smart_plotter.py index 27b04cea..f0f8f50e 100644 --- a/tests/unit_tests/explainer/test_smart_plotter.py +++ b/tests/unit_tests/explainer/test_smart_plotter.py @@ -1798,7 +1798,7 @@ def test_interactions_plot_1(self): color=self.x_init[col1]) assert np.array_equal(output.data[0].x, expected_output.data[0].x) - assert np.array_equal(output.data[1].y, expected_output.data[1].y) + assert np.array_equal(output.data[1].y, expected_output.data[1].y * 2) assert output.data[0].showlegend is True assert len(output.data) == 2 @@ -1819,8 +1819,8 @@ def test_interactions_plot_2(self): smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float) interaction_values = np.array([ - [[0.1, -0.7], [-0.6, 0.3]], - [[0.2, -0.1], [-0.2, 0.1]] + [[0.1, -0.7], [-0.7, 0.3]], + [[0.2, -0.1], [-0.1, 0.1]] ]) smart_explainer.interaction_values = interaction_values @@ -1829,7 +1829,7 @@ def test_interactions_plot_2(self): output = smart_explainer.plot.interactions_plot(col1, col2, violin_maxf=0) assert np.array_equal(output.data[0].x, ['PhD', 'Master']) - assert np.array_equal(output.data[0].y, [-0.7, -0.1]) + assert np.array_equal(output.data[0].y, [-1.4, -0.2]) assert np.array_equal(output.data[0].marker.color, [34., 27.]) assert len(output.data) == 1 @@ -1852,8 +1852,8 @@ def test_interactions_plot_3(self): smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float) interaction_values = np.array([ - [[0.1, -0.7], [-0.6, 0.3]], - [[0.2, -0.1], [-0.2, 0.1]] + [[0.1, -0.7], [-0.7, 0.3]], + [[0.2, -0.1], [-0.1, 0.1]] ]) smart_explainer.interaction_values = interaction_values @@ -1862,7 +1862,7 @@ def test_interactions_plot_3(self): output = smart_explainer.plot.interactions_plot(col2, col1, violin_maxf=0) assert np.array_equal(output.data[0].x, [34.]) - assert np.array_equal(output.data[0].y, [-0.6]) + assert np.array_equal(output.data[0].y, [-1.4]) assert output.data[0].name == 'PhD' assert np.array_equal(output.data[1].x, [27.]) @@ -1892,8 +1892,8 @@ def test_interactions_plot_4(self): smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float) interaction_values = np.array([ - [[0.1, -0.7], [-0.6, 0.3]], - [[0.2, -0.1], [-0.2, 0.1]] + [[0.1, -0.7], [-0.7, 0.3]], + [[0.2, -0.1], [-0.1, 0.1]] ]) smart_explainer.interaction_values = interaction_values @@ -1902,7 +1902,7 @@ def test_interactions_plot_4(self): output = smart_explainer.plot.interactions_plot(col1, col2, violin_maxf=0) assert np.array_equal(output.data[0].x, [520, 12800]) - assert np.array_equal(output.data[0].y, [-0.7, -0.1]) + assert np.array_equal(output.data[0].y, [-1.4, -0.2]) assert np.array_equal(output.data[0].marker.color, [34., 27.]) assert len(output.data) == 1 @@ -1926,8 +1926,8 @@ def test_interactions_plot_5(self): smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float) interaction_values = np.array([ - [[0.1, -0.7], [-0.6, 0.3]], - [[0.2, -0.1], [-0.2, 0.1]] + [[0.1, -0.7], [-0.7, 0.3]], + [[0.2, -0.1], [-0.1, 0.1]] ]) smart_explainer.interaction_values = interaction_values @@ -1942,7 +1942,7 @@ def test_interactions_plot_5(self): assert output.data[2].type == 'scatter' assert np.array_equal(output.data[2].x, ['PhD', 'Master']) - assert np.array_equal(output.data[2].y, [-0.7, -0.1]) + assert np.array_equal(output.data[2].y, [-1.4, -0.2]) assert np.array_equal(output.data[2].marker.color, [34., 27.]) self.setUp()