Skip to content

Commit

Permalink
Add coloured tokens multi visualisation (#41)
Browse files Browse the repository at this point in the history
Co-authored-by: Alan <41682961+alan-cooney@users.noreply.github.com>
  • Loading branch information
neelnanda-io and alan-cooney authored Jan 10, 2023
1 parent 5563526 commit f80525e
Show file tree
Hide file tree
Showing 8 changed files with 540 additions and 5 deletions.
85 changes: 85 additions & 0 deletions python/circuitsvis/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import torch
from circuitsvis.utils.render import RenderedHTML, render

ArrayRank1 = Union[List[float], np.ndarray, torch.Tensor]
ArrayRank2 = Union[List[List[float]], np.ndarray, torch.Tensor]
ArrayRank3 = Union[List[List[List[float]]], np.ndarray, torch.Tensor]
IntArrayRank1 = Union[List[int], np.ndarray, torch.Tensor]


def colored_tokens(
tokens: List[str],
Expand Down Expand Up @@ -40,3 +45,83 @@ def colored_tokens(
"ColoredTokens",
**kwargs
)


def colored_tokens_multi(
tokens: List[str],
values: torch.Tensor,
labels: Optional[List[str]] = None,
) -> RenderedHTML:
"""Shows a sequence of tokens colored by their value.
Takes in a tensor of values of shape [S, K] (S tokens, K different types of
value).
The user can hover or click on a button for each of the K types to color the
token with those values.
The user can hover over a token to see a list of the K values for that
token.
Args:
tokens: List of string tokens, one for each token in the prompt. Length [S]
values: The tensor of values to color tokens by. Shape [S, K]
labels: The names of the values. Length [K].
Returns:
Html: Log prob visualization
"""
assert len(tokens) == values.size(0), \
f"Number of tokens ({len(tokens)}) must equal first dimension of values tensor, " + \
f"shape {values.shape}"
if labels:
assert len(labels) == values.size(1), \
f"Number of labels ({len(labels)}) must equal second dimension of values tensor, " + \
f"shape {values.shape}"

return render(
"ColoredTokensMulti",
tokens=tokens,
values=values,
labels=labels,
)


def visualize_model_performance(
tokens: torch.Tensor,
str_tokens: List[str],
logits: torch.Tensor,
):
"""Visualizes model performance on some text
Shows logits, log probs, and probabilities for predicting each token (from
the previous tokens), colors the tokens according to one of logits, log
probs and probabilities, according to user input.
Allows the user to enter custom bounds for the values (eg, saturate color of
probability at 0.01)
"""
if len(tokens.shape) == 2:
assert tokens.shape[0] == 1, \
f"tokens must be rank 1, or rank 2 with a dummy batch dimension. Shape: {tokens.shape}"
tokens = tokens[0]
if len(logits.shape) == 3:
assert logits.shape[0] == 1, \
f"logits must be rank 2, or rank 3 with a dummy batch dimension. Shape: {logits.shape}"
logits = logits[0]
assert len(str_tokens) == len(tokens), \
"Must have same number of tokens and str_tokens"
assert len(tokens) == logits.shape[0], \
"Must have the same number of tokens and logit vectors"

# We remove the final vector of logits, as it can't predict anything.
logits = logits[:-1]
log_probs = logits.log_softmax(dim=-1)
probs = logits.softmax(dim=-1)
values = torch.stack([
logits.gather(-1, tokens[1:, None])[:, 0],
log_probs.gather(-1, tokens[1:, None])[:, 0],
probs.gather(-1, tokens[1:, None])[:, 0],
], dim=1)
labels = ["logits", "log_probs", "probs"]
return colored_tokens_multi(str_tokens[1:], values, labels)
6 changes: 5 additions & 1 deletion react/.eslintrc.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ module.exports = {
],
rules: {
"import/prefer-default-export": "off",
"import/no-extraneous-dependencies": ["error", { "devDependencies": ["**/*.test.*", "**/*.stories.*"] }],
"import/no-extraneous-dependencies": [
"error",
{ devDependencies: ["**/*.test.*", "**/*.stories.*"] }
],
"no-plusplus": ["error", { allowForLoopAfterthoughts: true }],
"react/jsx-filename-extension": "off",
"react/require-default-props": "off", // Not needed with TS
"react/react-in-jsx-scope": "off" // Esbuild injects this for us
Expand Down
5 changes: 3 additions & 2 deletions react/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ export { AttentionHeads } from "./attention/AttentionHeads";
export { AttentionPattern } from "./attention/AttentionPattern";
export { AttentionPatterns } from "./attention/AttentionPatterns";
export { ColoredTokens } from "./tokens/ColoredTokens";
export { ColoredTokensMulti } from "./tokens/ColoredTokensMulti";
export { Hello } from "./examples/Hello";
export { render } from "./render-helper";
export { TextNeuronActivations } from "./activations/TextNeuronActivations";
export { TopkTokens } from "./topk/TopkTokens";
export { TopkSamples } from "./topk/TopkSamples";
export { TokenLogProbs } from "./logits/TokenLogProbs";
export { TopkSamples } from "./topk/TopkSamples";
export { TopkTokens } from "./topk/TopkTokens";
113 changes: 113 additions & 0 deletions react/src/tokens/ColoredTokensCustomTooltips.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import { AnyColor } from "colord";
import React from "react";
import { TokenCustomTooltip } from "./utils/TokenCustomTooltip";

/**
* Display tokens with a background representing how negative (close to
* `negativeColor`) or positive (close to `positiveColor`) the token is. Zero is
* always displayed as white.
*
* Hover over a token, to view its value.
*/
export function ColoredTokensCustomTooltips({
maxValue,
minValue,
negativeColor,
positiveColor,
tokens,
values,
tooltips
}: ColoredTokensCustomTooltipsProps) {
const tokenMin = minValue ?? Math.min(...values);
const tokenMax = maxValue ?? Math.max(...values);

return (
<div className="colored-tokens" style={{ paddingBottom: 30 }}>
{tokens.map((token, key) => (
<TokenCustomTooltip
key={key}
token={token}
value={values[key]}
min={tokenMin}
max={tokenMax}
negativeColor={negativeColor}
positiveColor={positiveColor}
tooltip={tooltips[key]}
/>
))}
</div>
);
}

export interface ColoredTokensCustomTooltipsProps {
/**
* Maximum value
*
* Used to determine how dark the token color is when positive (i.e. based on
* how close it is to the minimum value).
*
* @default Math.max(...values)
*/
maxValue?: number;

/**
* Minimum value
*
* Used to determine how dark the token color is when negative (i.e. based on
* how close it is to the minimum value).
*
* @default Math.min(...values)
*/
minValue?: number;

/**
* Negative color
*
* Color to use for negative values. This can be any valid CSS color string.
*
* Be mindful of color blindness if not using the default here.
*
* @default "red"
*
* @example rgb(255, 0, 0)
*
* @example #ff0000
*/
negativeColor?: AnyColor;

/**
* Positive color
*
* Color to use for positive values. This can be any valid CSS color string.
*
* Be mindful of color blindness if not using the default here.
*
* @default "blue"
*
* @example rgb(0, 0, 255)
*
* @example #0000ff
*/
positiveColor?: AnyColor;

/**
* List of tokens
*
* Must be the same length as the list of values.
*/
tokens: string[];

/**
* Values for each token
*
* Must be the same length as the list of tokens.
*/
values: number[];

/**
* Tooltips for each token
*
* Must be the same length as the list of tokens.
*/
tooltips: React.ReactNode[];
}
19 changes: 19 additions & 0 deletions react/src/tokens/ColoredTokensMulti.stories.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { ComponentStory, ComponentMeta } from "@storybook/react";
import React from "react";
import { mockTokens, mockValues, mockLabels } from "./mocks/coloredTokensMulti";
import { ColoredTokensMulti } from "./ColoredTokensMulti";

export default {
component: ColoredTokensMulti
} as ComponentMeta<typeof ColoredTokensMulti>;

const Template: ComponentStory<typeof ColoredTokensMulti> = (args) => (
<ColoredTokensMulti {...args} />
);

export const SmallModelExample = Template.bind({});
SmallModelExample.args = {
tokens: mockTokens,
values: mockValues,
labels: mockLabels
};
Loading

0 comments on commit f80525e

Please sign in to comment.