Skip to content

Commit

Permalink
Merge pull request #497 from MAIF/hotfix/interactions_values
Browse files Browse the repository at this point in the history
fix interaction_values values and list_ind
  • Loading branch information
ThomasBouche authored Nov 3, 2023
2 parents 7d59855 + b347d69 commit 96950d1
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
10 changes: 6 additions & 4 deletions shapash/explainer/smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1896,12 +1896,12 @@ def _select_indices_interactions_plot(self, selection, max_points):
list_ind = random.sample(self.explainer.x_init.index.tolist(), max_points)
addnote = "Length of random Subset : "
elif isinstance(selection, list):
if hasattr(self, 'interaction_selection'):
if set(self.interaction_selection).issubset(set(selection)):
list_ind = self.interaction_selection
elif len(selection) <= max_points:
if len(selection) <= max_points:
list_ind = selection
addnote = "Length of user-defined Subset : "
elif hasattr(self, 'interaction_selection'):
if set(selection).issubset(set(self.interaction_selection)):
list_ind = self.interaction_selection
else:
list_ind = random.sample(selection, max_points)
addnote = "Length of random Subset : "
Expand Down Expand Up @@ -1985,6 +1985,8 @@ def interactions_plot(self,
feature_values2 = self.explainer.x_init.loc[list_ind, col_name2].to_frame()

interaction_values = self.explainer.get_interaction_values(selection=list_ind)[:, col_id1, col_id2]
if col_id1 != col_id2:
interaction_values = interaction_values * 2

# selecting the best plot : Scatter, Violin?
if col_value_count1 > violin_maxf:
Expand Down
26 changes: 13 additions & 13 deletions tests/unit_tests/explainer/test_smart_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1798,7 +1798,7 @@ def test_interactions_plot_1(self):
color=self.x_init[col1])

assert np.array_equal(output.data[0].x, expected_output.data[0].x)
assert np.array_equal(output.data[1].y, expected_output.data[1].y)
assert np.array_equal(output.data[1].y, expected_output.data[1].y * 2)
assert output.data[0].showlegend is True
assert len(output.data) == 2

Expand All @@ -1819,8 +1819,8 @@ def test_interactions_plot_2(self):
smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float)

interaction_values = np.array([
[[0.1, -0.7], [-0.6, 0.3]],
[[0.2, -0.1], [-0.2, 0.1]]
[[0.1, -0.7], [-0.7, 0.3]],
[[0.2, -0.1], [-0.1, 0.1]]
])

smart_explainer.interaction_values = interaction_values
Expand All @@ -1829,7 +1829,7 @@ def test_interactions_plot_2(self):
output = smart_explainer.plot.interactions_plot(col1, col2, violin_maxf=0)

assert np.array_equal(output.data[0].x, ['PhD', 'Master'])
assert np.array_equal(output.data[0].y, [-0.7, -0.1])
assert np.array_equal(output.data[0].y, [-1.4, -0.2])
assert np.array_equal(output.data[0].marker.color, [34., 27.])
assert len(output.data) == 1

Expand All @@ -1852,8 +1852,8 @@ def test_interactions_plot_3(self):
smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float)

interaction_values = np.array([
[[0.1, -0.7], [-0.6, 0.3]],
[[0.2, -0.1], [-0.2, 0.1]]
[[0.1, -0.7], [-0.7, 0.3]],
[[0.2, -0.1], [-0.1, 0.1]]
])

smart_explainer.interaction_values = interaction_values
Expand All @@ -1862,7 +1862,7 @@ def test_interactions_plot_3(self):
output = smart_explainer.plot.interactions_plot(col2, col1, violin_maxf=0)

assert np.array_equal(output.data[0].x, [34.])
assert np.array_equal(output.data[0].y, [-0.6])
assert np.array_equal(output.data[0].y, [-1.4])
assert output.data[0].name == 'PhD'

assert np.array_equal(output.data[1].x, [27.])
Expand Down Expand Up @@ -1892,8 +1892,8 @@ def test_interactions_plot_4(self):
smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float)

interaction_values = np.array([
[[0.1, -0.7], [-0.6, 0.3]],
[[0.2, -0.1], [-0.2, 0.1]]
[[0.1, -0.7], [-0.7, 0.3]],
[[0.2, -0.1], [-0.1, 0.1]]
])

smart_explainer.interaction_values = interaction_values
Expand All @@ -1902,7 +1902,7 @@ def test_interactions_plot_4(self):
output = smart_explainer.plot.interactions_plot(col1, col2, violin_maxf=0)

assert np.array_equal(output.data[0].x, [520, 12800])
assert np.array_equal(output.data[0].y, [-0.7, -0.1])
assert np.array_equal(output.data[0].y, [-1.4, -0.2])
assert np.array_equal(output.data[0].marker.color, [34., 27.])

assert len(output.data) == 1
Expand All @@ -1926,8 +1926,8 @@ def test_interactions_plot_5(self):
smart_explainer.x_encoded['X2'] = smart_explainer.x_encoded['X2'].astype(float)

interaction_values = np.array([
[[0.1, -0.7], [-0.6, 0.3]],
[[0.2, -0.1], [-0.2, 0.1]]
[[0.1, -0.7], [-0.7, 0.3]],
[[0.2, -0.1], [-0.1, 0.1]]
])

smart_explainer.interaction_values = interaction_values
Expand All @@ -1942,7 +1942,7 @@ def test_interactions_plot_5(self):
assert output.data[2].type == 'scatter'

assert np.array_equal(output.data[2].x, ['PhD', 'Master'])
assert np.array_equal(output.data[2].y, [-0.7, -0.1])
assert np.array_equal(output.data[2].y, [-1.4, -0.2])
assert np.array_equal(output.data[2].marker.color, [34., 27.])

self.setUp()
Expand Down

0 comments on commit 96950d1

Please sign in to comment.