Skip to content

Commit

Permalink
- Write 3 short guides in the documentation.
Browse files Browse the repository at this point in the history
- Improve circuit visualization functions.
- Improve convenience functions, like circuit_prune_scores in patchable_model now can take edge names as strings.
  • Loading branch information
UFO-101 committed May 11, 2024
1 parent f2442ec commit 5777057
Show file tree
Hide file tree
Showing 22 changed files with 767 additions and 119 deletions.
1 change: 0 additions & 1 deletion auto_circuit/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ def ioi_circuit_single_template_logit_diff_percent(
patch_type=patch_type,
ablation_type=ablation_type,
render_graph=False,
render_all_edges=False,
)
(
logit_diff_percent_mean,
Expand Down
6 changes: 3 additions & 3 deletions auto_circuit/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def run_circuits(
ablation_type: AblationType = AblationType.RESAMPLE,
reverse_clean_corrupt: bool = False,
render_graph: bool = False,
render_all_edges: bool = False,
render_score_threshold: bool = False,
render_file_path: Optional[str] = None,
) -> CircuitOutputs:
"""Run the model, pruning edges based on the given `prune_scores`. Runs the model
Expand All @@ -46,7 +46,7 @@ def run_circuits(
ablation_type: The type of ablation to use.
reverse_clean_corrupt: Reverse clean and corrupt (for input and patches).
render_graph: Whether to render the graph using `draw_seq_graph`.
render_all_edges: Whether to render all edges, if `render_graph` is `True`.
render_score_threshold: Edge score threshold, if `render_graph` is `True`.
render_file_path: Path to save the rendered graph, if `render_graph` is `True`.
Returns:
Expand Down Expand Up @@ -103,7 +103,7 @@ def run_circuits(
if render_graph:
draw_seq_graph(
model=model,
show_all_edges=render_all_edges,
score_threshold=render_score_threshold,
show_all_seq_pos=False,
seq_labels=dataloader.seq_labels,
file_path=render_file_path,
Expand Down
3 changes: 2 additions & 1 deletion auto_circuit/prune_algos/ACDC.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import torch as t
from ordered_set import OrderedSet
from plotly import graph_objects as go
from torch.nn.functional import log_softmax, mse_loss

from auto_circuit.data import PromptDataLoader
Expand Down Expand Up @@ -35,7 +36,7 @@ def acdc_prune_scores(
test_mode: bool = False,
run_circuits_ref: Optional[Callable[..., CircuitOutputs]] = None,
show_graphs: bool = False,
draw_seq_graph_ref: Optional[Callable[..., None]] = None,
draw_seq_graph_ref: Optional[Callable[..., go.Figure]] = None,
) -> PruneScores:
"""
Run the ACDC algorithm from the paper "Towards Automated Circuit Discovery for
Expand Down
1 change: 1 addition & 0 deletions auto_circuit/utils/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def patch_mode(
if curr_src_outs is None:
curr_src_outs = t.zeros_like(patch_src_outs)

# TODO: Raise an error if one of the edge names doesn't exist.
if edges is not None:
set_all_masks(model, val=0.0)
for edge in model.edges:
Expand Down
38 changes: 30 additions & 8 deletions auto_circuit/utils/patchable_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Set, Tuple
from collections import defaultdict
from typing import Any, Collection, Dict, List, Optional, Set, Tuple

import torch as t
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
Expand Down Expand Up @@ -60,6 +61,7 @@ class PatchableModel(t.nn.Module):
srcs: Set[SrcNode]
dests: Set[DestNode]
edge_dict: Dict[int | None, List[Edge]] # Key is token position or None for all
edge_name_dict: Dict[int | None, Dict[str, Edge]]
edges: Set[Edge]
n_edges: int
seq_dim: int
Expand Down Expand Up @@ -101,6 +103,9 @@ def __init__(
self.edge_dict = edge_dict
self.edges = edges
self.n_edges = len(edges)
self.edge_name_dict = defaultdict(dict)
for edge in edges:
self.edge_name_dict[edge.seq_idx][edge.name] = edge
self.seq_dim = seq_dim
self.seq_len = seq_len
self.wrappers = wrappers
Expand Down Expand Up @@ -179,25 +184,42 @@ def new_prune_scores(self, init_val: float = 0.0) -> PruneScores:
prune_scores[mod_name] = t.full_like(mask.data, init_val)
return prune_scores

def circuit_prune_scores(self, edges: Set[Edge], bool: bool = False) -> PruneScores:
def circuit_prune_scores(
self,
edges: Optional[Collection[Edge | str]] = None,
edge_dict: Optional[Dict[Edge, float] | Dict[str, float]] = None,
bool: bool = False,
) -> PruneScores:
"""
Convert a set of edges to a corresponding
[`PruneScores`][auto_circuit.types.PruneScores] object.
Args:
edges: The set of edges to convert to prune scores.
edges: The set of edges or edge names to convert to prune scores.
bool: Whether to return the prune scores as boolean type tensors.
Returns:
The prune scores corresponding to the set of edges.
"""
prune_scores = self.new_prune_scores()
for edge in edges:
prune_scores[edge.dest.module_name][edge.patch_idx] = 1.0
ps = self.new_prune_scores()
assert not (edges is None and edge_dict is None), "Must specify edges"

# TODO: Raise an error if one of the edge names doesn't exist.
if edges is not None:
for edge in self.edges:
if edge in edges or edge.name in edges:
ps[edge.dest.module_name][edge.patch_idx] = 1.0
else:
assert edge_dict is not None
for e in self.edges:
if e in edge_dict.keys():
ps[e.dest.module_name][e.patch_idx] = edge_dict[e] # type: ignore
if e.name in edge_dict.keys():
ps[e.dest.module_name][e.patch_idx] = edge_dict[e] # type: ignore
if bool:
return dict([(mod, mask.bool()) for (mod, mask) in prune_scores.items()])
return dict([(mod, mask.bool()) for (mod, mask) in ps.items()])
else:
return prune_scores
return ps

def current_patch_masks_as_prune_scores(self) -> PruneScores:
"""
Expand Down
Loading

0 comments on commit 5777057

Please sign in to comment.