Skip to content

Commit

Permalink
Fix AttentionHead resizing bug (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
UFO-101 authored Aug 1, 2023
1 parent a4f6a72 commit 8605b16
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 20 deletions.
13 changes: 2 additions & 11 deletions react/src/attention/AttentionHeads.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,7 @@ export function AttentionHeads({
<Row>
<Col xs={12}>
<h3 style={{ marginBottom: 10 }}>{headNames[focused]} Zoomed</h3>

<div
style={{
position: "relative",
// Set the maximum width such that a head with just a few tokens
// doesn't have crazy large boxes per token. Note this is the
// width of the full chart (including axis labels) so it also
// needs a sensible lowest maximum.
maxWidth: `${Math.max(Math.round(tokens.length * 2.4), 20)}em`
}}
>
<div>
<h2
style={{
position: "absolute",
Expand All @@ -174,6 +164,7 @@ export function AttentionHeads({
minValue={minValue}
negativeColor={negativeColor}
positiveColor={positiveColor}
zoomed={true}
tokens={tokens}
/>
</div>
Expand Down
46 changes: 37 additions & 9 deletions react/src/attention/AttentionPattern.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export function AttentionPattern({
positiveColor,
upperTriColor = DefaultUpperTriColor,
showAxisLabels = true,
zoomed = false,
tokens
}: AttentionPatternProps) {
// Tokens must be unique (for the categories), so we add an index prefix
Expand Down Expand Up @@ -164,15 +165,37 @@ export function AttentionPattern({

return (
<Col>
<Row style={{ aspectRatio: showAxisLabels ? undefined : "1/1" }}>
<Chart
type="matrix"
options={options}
data={data}
width={1000}
height={1000}
updateMode="none"
/>
<Row>
<div
style={{
// Chart.js charts resizing is weird.
// Responsive chart elements (which all are by default) require the
// parent element to have position: 'relative' and no sibling elements.
// There were previously issues that only occured at particular display
// sizes and zoom levels. See:
// https://github.com/alan-cooney/CircuitsVis/pull/63
// https://www.chartjs.org/docs/latest/configuration/responsive.html#important-note
// https://stackoverflow.com/a/48770978/7086623
position: "relative",
// Set the maximum width of zoomed heads such that a head with just a
// few tokens doesn't have crazy large boxes per token and the chart
// doesn't overflow the screen. Other heads fill their width.
maxWidth: zoomed
? `min(100%, ${Math.round(tokens.length * 8)}em)`
: "initial",
width: zoomed ? "initial" : "100%",
aspectRatio: "1/1"
}}
>
<Chart
type="matrix"
options={options}
data={data}
width={1000}
height={1000}
updateMode="none"
/>
</div>
</Row>
</Col>
);
Expand Down Expand Up @@ -255,6 +278,11 @@ export interface AttentionPatternProps {
*/
showAxisLabels?: boolean;

/**
* Is this a zoomed in view?
*/
zoomed?: boolean;

/**
* List of tokens
*
Expand Down

0 comments on commit 8605b16

Please sign in to comment.