From 4d5869b281141513690e8ec4f0d556d3b91de50f Mon Sep 17 00:00:00 2001 From: Andy Arditi Date: Tue, 22 Aug 2023 04:50:54 -0400 Subject: [PATCH] Add bidirectional attention support to attention pattern (#75) --- python/circuitsvis/attention.py | 10 ++++++++++ react/src/attention/AttentionHeads.tsx | 16 ++++++++++++++++ react/src/attention/AttentionPattern.tsx | 21 ++++++++++++++++++--- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/python/circuitsvis/attention.py b/python/circuitsvis/attention.py index 7245af7..3da78d4 100644 --- a/python/circuitsvis/attention.py +++ b/python/circuitsvis/attention.py @@ -14,6 +14,7 @@ def attention_heads( min_value: Optional[float] = None, negative_color: Optional[str] = None, positive_color: Optional[str] = None, + mask_upper_tri: Optional[bool] = None, ) -> RenderedHTML: """Attention Heads @@ -37,6 +38,9 @@ def attention_heads( positive_color: Color for positive values. This can be any valid CSS color string. Be mindful of color blindness if not using the default here. + mask_upper_tri: Whether or not to mask the upper triangular portion of + the attention patterns. Should be true for causal attention, false for + bidirectional attention. Returns: Html: Attention pattern visualization @@ -49,6 +53,7 @@ def attention_heads( "negativeColor": negative_color, "positiveColor": positive_color, "tokens": tokens, + "maskUpperTri": mask_upper_tri, } return render( @@ -90,6 +95,7 @@ def attention_pattern( negative_color: Optional[str] = None, show_axis_labels: Optional[bool] = None, positive_color: Optional[str] = None, + mask_upper_tri: Optional[bool] = None, ) -> RenderedHTML: """Attention Pattern @@ -112,6 +118,9 @@ def attention_pattern( positive_color: Color for positive values. This can be any valid CSS color string. Be mindful of color blindness if not using the default here. + mask_upper_tri: Whether or not to mask the upper triangular portion of + the attention patterns. Should be true for causal attention, false for + bidirectional attention. Returns: Html: Attention pattern visualization @@ -124,6 +133,7 @@ def attention_pattern( "negativeColor": negative_color, "positiveColor": positive_color, "showAxisLabels": show_axis_labels, + "maskUpperTri": mask_upper_tri, } return render( diff --git a/react/src/attention/AttentionHeads.tsx b/react/src/attention/AttentionHeads.tsx index 6c601d0..6725104 100644 --- a/react/src/attention/AttentionHeads.tsx +++ b/react/src/attention/AttentionHeads.tsx @@ -34,6 +34,7 @@ export function AttentionHeadsSelector({ onMouseEnter, onMouseLeave, positiveColor, + maskUpperTri, tokens }: AttentionHeadsProps & { attentionHeadNames: string[]; @@ -89,6 +90,7 @@ export function AttentionHeadsSelector({ minValue={minValue} negativeColor={negativeColor} positiveColor={positiveColor} + maskUpperTri={maskUpperTri} /> @@ -112,6 +114,7 @@ export function AttentionHeads({ minValue, negativeColor, positiveColor, + maskUpperTri = true, tokens }: AttentionHeadsProps) { // Attention head focussed state @@ -137,6 +140,7 @@ export function AttentionHeads({ onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave} positiveColor={positiveColor} + maskUpperTri={maskUpperTri} tokens={tokens} /> @@ -165,6 +169,7 @@ export function AttentionHeads({ negativeColor={negativeColor} positiveColor={positiveColor} zoomed={true} + maskUpperTri={maskUpperTri} tokens={tokens} /> @@ -241,6 +246,17 @@ export interface AttentionHeadsProps { */ positiveColor?: string; + /** + * Mask upper triangular + * + * Whether or not to mask the upper triangular portion of the attention patterns. + * + * Should be true for causal attention, false for bidirectional attention. + * + * @default true + */ + maskUpperTri?: boolean; + /** * Show axis labels */ diff --git a/react/src/attention/AttentionPattern.tsx b/react/src/attention/AttentionPattern.tsx index 0bdcfd7..c6708ac 100644 --- a/react/src/attention/AttentionPattern.tsx +++ b/react/src/attention/AttentionPattern.tsx @@ -62,6 +62,7 @@ export function AttentionPattern({ upperTriColor = DefaultUpperTriColor, showAxisLabels = true, zoomed = false, + maskUpperTri = true, tokens }: AttentionPatternProps) { // Tokens must be unique (for the categories), so we add an index prefix @@ -96,7 +97,7 @@ export function AttentionPattern({ // Set the background color for each block, based on the attention value backgroundColor(context: ScriptableContext<"matrix">) { const block = context.dataset.data[context.dataIndex] as any as Block; - if (block.srcIdx > block.destIdx) { + if (maskUpperTri && block.srcIdx > block.destIdx) { // Color the upper triangular part separately return colord(upperTriColor).toRgbString(); } @@ -130,7 +131,10 @@ export function AttentionPattern({ title: () => "", // Hide the title label({ raw }: TooltipItem<"matrix">) { const block = raw as Block; - if (block.destIdx < block.srcIdx) return "N/A"; // Just show N/A for the upper triangular part + if (maskUpperTri && block.destIdx < block.srcIdx) { + // Just show N/A for the upper triangular part + return "N/A"; + } return [ `(${block.destIdx}, ${block.srcIdx})`, `Src: ${block.srcToken}`, @@ -259,11 +263,22 @@ export interface AttentionPatternProps { */ positiveColor?: string; + /** + * Mask upper triangular + * + * Whether or not to mask the upper triangular portion of the attention patterns. + * + * Should be true for causal attention, false for bidirectional attention. + * + * @default true + */ + maskUpperTri?: boolean; + /** * Upper triangular color * * Color to use for the upper triangular part of the attention pattern to make visualization slightly nicer. - * The upper triangular part is irrelevant because of the causal mask. + * Only applied if maskUpperTri is set to true. * * @default rgb(200, 200, 200) *