Skip to content

Commit

Permalink
fix plot position
Browse files Browse the repository at this point in the history
  • Loading branch information
Aggrathon committed Nov 25, 2022
1 parent a1585b8 commit 9d2f4b5
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions slisemap/slisemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,16 @@ def local_model(self, value: Callable[[torch.Tensor, torch.Tensor], torch.Tensor
@property
def local_loss(
self,
) -> Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]:
) -> Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor]:
# Local model loss function. Takes in Ytilde[n, n, o], Y[n, o], and B[n, q], and returns L[n, n]
return self._local_loss

@local_loss.setter
def local_loss(
self, value: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
self,
value: Callable[
[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor
],
):
if self._local_loss != value:
_assert(callable(value), "local_loss must be callable", Slisemap.local_loss)
Expand Down Expand Up @@ -1381,8 +1384,11 @@ def plot_position(
loc="lower center" if inside else "upper right",
bbox_to_anchor=(1 - w, h * 0.35, w * 0.9, h * 0.6) if inside else None,
)
marker = Line2D(
[], [], linestyle="None", color="#fd8431", marker="X", markersize=5
)
g.add_legend(
{"": Line2D([], [], None, "None", "#fd8431", "X", 5)},
{"": marker},
"Selected",
loc="upper center" if inside else "lower right",
bbox_to_anchor=(1 - w, h * 0.05, w * 0.9, h * 0.3) if inside else None,
Expand Down

0 comments on commit 9d2f4b5

Please sign in to comment.