From 5740a9a07b63eea5845c5525f8868ad77ebca34b Mon Sep 17 00:00:00 2001 From: Alan <41682961+alan-cooney@users.noreply.github.com> Date: Mon, 12 Dec 2022 09:14:29 +0000 Subject: [PATCH] Add attention head labels --- python/circuitsvis/attention.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py index 1b47dcb..7245af7 100644 --- a/python/circuitsvis/attention.py +++ b/python/circuitsvis/attention.py @@ -9,6 +9,7 @@ def attention_heads( attention: Union[list, np.ndarray, torch.Tensor], tokens: List[str], + attention_head_names: Optional[List[str]] = None, max_value: Optional[float] = None, min_value: Optional[float] = None, negative_color: Optional[str] = None, @@ -41,12 +42,13 @@ def attention_heads( Html: Attention pattern visualization """ kwargs = { - "tokens": tokens, "attention": attention, - "minValue": min_value, + "attentionHeadNames": attention_head_names, "maxValue": max_value, + "minValue": min_value, "negativeColor": negative_color, "positiveColor": positive_color, + "tokens": tokens, } return render(