From 9d2f4b5d0581d27ffec90b3c50bd77a3ee80c826 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Anton=20Bj=C3=B6rklund?= Date: Fri, 25 Nov 2022 13:39:12 +0200 Subject: [PATCH] fix plot position --- slisemap/slisemap.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/slisemap/slisemap.py b/slisemap/slisemap.py index 8d796c1..da57119 100644 --- a/slisemap/slisemap.py +++ b/slisemap/slisemap.py @@ -364,13 +364,16 @@ def local_model(self, value: Callable[[torch.Tensor, torch.Tensor], torch.Tensor @property def local_loss( self, - ) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]: + ) -> Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]: # Local model loss function. Takes in Ytilde[n, n, o], Y[n, o], and B[n, q], and returns L[n, n] return self._local_loss @local_loss.setter def local_loss( - self, value: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + self, + value: Callable[ + [torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor + ], ): if self._local_loss != value: _assert(callable(value), "local_loss must be callable", Slisemap.local_loss) @@ -1381,8 +1384,11 @@ def plot_position( loc="lower center" if inside else "upper right", bbox_to_anchor=(1 - w, h * 0.35, w * 0.9, h * 0.6) if inside else None, ) + marker = Line2D( + [], [], linestyle="None", color="#fd8431", marker="X", markersize=5 + ) g.add_legend( - {"": Line2D([], [], None, "None", "#fd8431", "X", 5)}, + {"": marker}, "Selected", loc="upper center" if inside else "lower right", bbox_to_anchor=(1 - w, h * 0.05, w * 0.9, h * 0.3) if inside else None,