Skip to content

Commit

Permalink
Add attention head labels
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Dec 12, 2022
1 parent b1753ed commit 5740a9a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/circuitsvis/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def attention_heads(
attention: Union[list, np.ndarray, torch.Tensor],
tokens: List[str],
attention_head_names: Optional[List[str]] = None,
max_value: Optional[float] = None,
min_value: Optional[float] = None,
negative_color: Optional[str] = None,
Expand Down Expand Up @@ -41,12 +42,13 @@ def attention_heads(
Html: Attention pattern visualization
"""
kwargs = {
"tokens": tokens,
"attention": attention,
"minValue": min_value,
"attentionHeadNames": attention_head_names,
"maxValue": max_value,
"minValue": min_value,
"negativeColor": negative_color,
"positiveColor": positive_color,
"tokens": tokens,
}

return render(
Expand Down

0 comments on commit 5740a9a

Please sign in to comment.