Skip to content

Commit

Permalink
Add attention heads visualisation (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Dec 11, 2022
1 parent 8065614 commit ab0dd58
Show file tree
Hide file tree
Showing 7 changed files with 391 additions and 52 deletions.
50 changes: 50 additions & 0 deletions python/circuitsvis/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,56 @@
from circuitsvis.utils.render import RenderedHTML, render


def attention_heads(
    attention: Union[list, np.ndarray, torch.Tensor],
    tokens: List[str],
    max_value: Optional[float] = None,
    min_value: Optional[float] = None,
    negative_color: Optional[str] = None,
    positive_color: Optional[str] = None,
) -> RenderedHTML:
    """Attention Heads

    Attention patterns from destination to source tokens, for a group of heads.

    Displays a small heatmap for each attention head. When one is selected, it
    is then shown in full size.

    Args:
        attention: Attention head activations of the shape [heads x
            dest_tokens x src_tokens]
        tokens: List of tokens (e.g. `["A", "person"]`). Must be the same
            length as the attention pattern's source/destination dimensions.
        max_value: Maximum value. Used to determine how dark the token color is
            when positive (i.e. based on how close it is to the maximum value).
        min_value: Minimum value. Used to determine how dark the token color is
            when negative (i.e. based on how close it is to the minimum value).
        negative_color: Color for negative values. This can be any valid CSS
            color string. Be mindful of color blindness if not using the
            default here.
        positive_color: Color for positive values. This can be any valid CSS
            color string. Be mindful of color blindness if not using the
            default here.

    Returns:
        Html: Attention pattern visualization
    """
    # Map the snake_case Python arguments onto the camelCase props expected by
    # the AttentionHeads React component.
    return render(
        "AttentionHeads",
        tokens=tokens,
        attention=attention,
        minValue=min_value,
        maxValue=max_value,
        negativeColor=negative_color,
        positiveColor=positive_color,
    )



def attention_patterns(
tokens: List[str],
attention: Union[list, np.ndarray, torch.Tensor],
Expand Down
26 changes: 26 additions & 0 deletions react/src/attention/AttentionHeads.stories.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { ComponentStory, ComponentMeta } from "@storybook/react";
import React from "react";

import { AttentionHeads } from "./AttentionHeads";
import { mockAttention, mockTokens } from "./mocks/attention";

export default {
  component: AttentionHeads,
  argTypes: {
    // Color pickers for the diverging color scale
    negativeColor: { control: "color" },
    positiveColor: { control: "color" },
    tokens: { control: { type: "object", raw: true } },
    // Fixed: the component prop is `attention` — `values` is not a prop of
    // AttentionHeads, so the control never attached to anything
    attention: { control: { type: "object", raw: true } }
  }
} as ComponentMeta<typeof AttentionHeads>;

// Storybook template: forwards all story args straight to the component
const Template: ComponentStory<typeof AttentionHeads> = (args) => {
  return <AttentionHeads {...args} />;
};

/**
 * Example story: a layer of induction heads, rendered from mock data.
 */
export const InductionHeadsLayer: ComponentStory<typeof AttentionHeads> =
  Template.bind({});
InductionHeadsLayer.args = {
  attention: mockAttention,
  tokens: mockTokens
};
260 changes: 260 additions & 0 deletions react/src/attention/AttentionHeads.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import React from "react";
import { Col, Container, Row } from "react-grid-system";
import { AttentionPattern } from "./AttentionPattern";
import { useHoverLock, UseHoverLockState } from "./components/useHoverLock";

/**
 * Pick a distinct color for an attention head.
 *
 * Heads are spread evenly around the HSL hue wheel, so each head gets a
 * visually distinguishable hue at fixed saturation/lightness.
 *
 * @param idx Head index
 * @param numberOfHeads Number of heads
 * @param alpha Opaqueness (0% = fully transparent, 100% = fully opaque)
 * @returns CSS `hsla()` color string
 */
export function attentionHeadColor(
  idx: number,
  numberOfHeads: number,
  alpha: string = "100%"
): string {
  // Evenly space hues around the 360° color wheel
  const hueDegrees = Math.round((360 * idx) / numberOfHeads);
  return `hsla(${hueDegrees}, 70%, 50%, ${alpha})`;
}

/**
 * Attention Heads Selector
 *
 * Grid of mini attention-pattern heatmaps, one per head. Hover/click events
 * are forwarded to the useHoverLock callbacks so the parent can track which
 * head is focused.
 */
export function AttentionHeadsSelector({
  attention,
  attentionHeadNames,
  focused,
  maxValue,
  minValue,
  negativeColor,
  onClick,
  onMouseEnter,
  onMouseLeave,
  positiveColor,
  tokens
}: AttentionHeadsProps & {
  attentionHeadNames: string[];
} & UseHoverLockState) {
  const numberOfHeads = attention.length;

  return (
    <Row style={{ marginBottom: 15 }}>
      {attention.map((headPattern, headIndex) => {
        // Per-head accent color, plus a glow shown when this head is focused
        const headColor = attentionHeadColor(headIndex, numberOfHeads);
        const focusShadow = `0px 0px 4px 3px ${attentionHeadColor(
          headIndex,
          numberOfHeads,
          "60%"
        )}`;

        return (
          <Col
            lg={1}
            md={2}
            xs={3}
            style={{ margin: 0, padding: 0 }}
            key={headIndex}
          >
            <div
              style={{ padding: 3 }}
              onClick={() => onClick(headIndex)}
              onMouseEnter={() => onMouseEnter(headIndex)}
              onMouseLeave={onMouseLeave}
            >
              <div
                style={{
                  position: "relative",
                  borderStyle: "solid",
                  borderWidth: 1,
                  borderColor: headColor,
                  boxShadow: focused === headIndex ? focusShadow : undefined
                }}
              >
                {/* Head label, overlaid on the top-right corner */}
                <h4
                  style={{
                    position: "absolute",
                    top: 0,
                    right: 0,
                    zIndex: 100,
                    margin: 0,
                    padding: 1,
                    background: headColor,
                    color: "white"
                  }}
                >
                  {attentionHeadNames[headIndex]}
                </h4>

                {/* Mini heatmap (axis labels hidden to save space) */}
                <AttentionPattern
                  attention={headPattern}
                  tokens={tokens}
                  showAxisLabels={false}
                  maxValue={maxValue}
                  minValue={minValue}
                  negativeColor={negativeColor}
                  positiveColor={positiveColor}
                />
              </div>
            </div>
          </Col>
        );
      })}
    </Row>
  );
}

/**
 * Attention patterns from destination to source tokens, for a group of heads.
 *
 * Displays a small heatmap for each attention head (via
 * AttentionHeadsSelector). The hovered/clicked head is then shown in full
 * size below.
 */
export function AttentionHeads({
  attention,
  attentionHeadNames,
  maxValue,
  minValue,
  negativeColor,
  positiveColor,
  tokens
}: AttentionHeadsProps) {
  // Attention head focussed state (hover to preview, click to lock); head 0
  // is focused by default
  const { focused, onClick, onMouseEnter, onMouseLeave } = useHoverLock(0);

  // Default to "Head 0", "Head 1", … when no names are provided (`??` so only
  // a missing prop triggers the default)
  const headNames =
    attentionHeadNames ?? attention.map((_, idx) => `Head ${idx}`);

  return (
    <Container>
      <h3 style={{ marginBottom: 15 }}>
        Head Selector (hover to view, click to lock)
      </h3>

      <AttentionHeadsSelector
        attention={attention}
        attentionHeadNames={headNames}
        focused={focused}
        maxValue={maxValue}
        minValue={minValue}
        negativeColor={negativeColor}
        onClick={onClick}
        onMouseEnter={onMouseEnter}
        onMouseLeave={onMouseLeave}
        positiveColor={positiveColor}
        tokens={tokens}
      />

      <Row>
        <Col xs={12}>
          <h3 style={{ marginBottom: 10 }}>{headNames[focused]} Zoomed</h3>

          {/* Cap the width so short prompts don't stretch full-bleed */}
          <div
            style={{
              position: "relative",
              maxWidth: `${tokens.length * 4}em`
            }}
          >
            {/* Name badge overlaid on the zoomed heatmap */}
            <h2
              style={{
                position: "absolute",
                top: 0,
                right: 0,
                zIndex: 1000,
                margin: 6,
                padding: "5px 10px",
                background: attentionHeadColor(focused, attention.length),
                color: "white"
              }}
            >
              {headNames[focused]}
            </h2>
            <AttentionPattern
              attention={attention[focused]}
              maxValue={maxValue}
              minValue={minValue}
              negativeColor={negativeColor}
              positiveColor={positiveColor}
              tokens={tokens}
            />
          </div>
        </Col>
      </Row>
    </Container>
  );
}

export interface AttentionHeadsProps {
  /**
   * Attention heads activations
   *
   * Of the shape [ heads x dest_pos x src_pos ]
   */
  attention: number[][][];

  /**
   * Names for each attention head
   *
   * Useful if e.g. you want to label the heads with the layer they are from.
   *
   * @default ["Head 0", "Head 1", ...]
   */
  attentionHeadNames?: string[];

  /**
   * Maximum value
   *
   * Used to determine how dark the token color is when positive (i.e. based on
   * how close it is to the maximum value).
   *
   * @default the maximum value in `attention`
   */
  maxValue?: number;

  /**
   * Minimum value
   *
   * Used to determine how dark the token color is when negative (i.e. based on
   * how close it is to the minimum value).
   *
   * @default the minimum value in `attention`
   */
  minValue?: number;

  /**
   * Negative color
   *
   * Color to use for negative values. This can be any valid CSS color string.
   *
   * Be mindful of color blindness if not using the default here.
   *
   * @default red
   *
   * @example rgb(255, 0, 0)
   *
   * @example #ff0000
   */
  negativeColor?: string;

  /**
   * Positive color
   *
   * Color to use for positive values. This can be any valid CSS color string.
   *
   * Be mindful of color blindness if not using the default here.
   *
   * @default blue
   *
   * @example rgb(0, 0, 255)
   *
   * @example #0000ff
   */
  positiveColor?: string;

  /**
   * Show axis labels
   *
   * NOTE(review): declared but not read by `AttentionHeads` (the selector
   * mini-maps hard-code `showAxisLabels={false}`) — confirm whether this
   * should be forwarded to the zoomed pattern or removed.
   */
  showAxisLabels?: boolean;

  /**
   * List of tokens
   *
   * Must be the same length as the `dest_pos`/`src_pos` dimensions of
   * `attention`.
   */
  tokens: string[];
}
Loading

0 comments on commit ab0dd58

Please sign in to comment.