Remove support for multiple feature dimensions
luciaquirke committed Dec 11, 2024
1 parent a78566e commit b3a0e55
Showing 3 changed files with 10 additions and 25 deletions.
sae/config.py: 0 additions and 7 deletions
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import AbstractSet
 
 from simple_parsing import Serializable, list_field
 
@@ -67,12 +66,6 @@ class TrainConfig(Serializable):
 
     save_every: int = 1000
     """Save SAEs every `save_every` steps."""
-
-    sample_dims: AbstractSet[int] = frozenset({0, 1})
-    """Dimensions containing SAE inputs."""
-
-    feature_dims: AbstractSet[int] = frozenset({2})
-    """Dimensions of SAE inputs."""
 
     log_to_wandb: bool = True
     run_name: str | None = None
sae/trainer.py: 8 additions and 15 deletions
@@ -36,10 +36,11 @@ def __init__(
         found_hookpoints = natsorted(raw_hookpoints)
 
         if not found_hookpoints:
-            print(f"No modules matched the pattern(s) {cfg.hookpoints}")
-            print("Available modules:")
-            for name, _ in model.named_modules():
-                print(name)
+            error_msg = f"""No modules matched the pattern(s) {cfg.hookpoints}
+Available modules:
+{chr(10).join(name for name, _ in model.named_modules())}"""
+            raise ValueError(error_msg)
 
         cfg.hookpoints = found_hookpoints
 
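The hookpoint check now raises a ValueError instead of printing the available modules and continuing. chr(10) is just the newline character; it is used because f-string expressions could not contain a backslash before Python 3.12, so the joined module names still come out one per line. A small standalone illustration with hypothetical module names (not from the commit):

    names = ["embed_tokens", "layers.0.attn", "lm_head"]  # hypothetical module names
    assert chr(10) == "\n"
    print(chr(10).join(names))
    # embed_tokens
    # layers.0.attn
    # lm_head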
@@ -63,8 +64,7 @@ def __init__(
 
         device = model.device
         dummy_inputs = dummy_inputs if dummy_inputs is not None else model.dummy_inputs
-        input_widths = resolve_widths(model, cfg.hookpoints, dummy_inputs,
-                                      dims=cfg.feature_dims)
+        input_widths = resolve_widths(model, cfg.hookpoints, dummy_inputs)
         unique_widths = set(input_widths.values())
 
         if cfg.distribute_modules and len(unique_widths) > 1:
@@ -215,11 +215,11 @@ def hook(module: nn.Module, inputs, outputs):
                 outputs = outputs[0]
 
             name = module_to_name[module]
-            output_dict[name] = outputs
+            output_dict[name] = outputs.flatten(0, -2)
 
             # Remember the inputs if we're training a transcoder
             if self.cfg.transcode:
-                input_dict[name] = inputs
+                input_dict[name] = inputs.flatten(0, -2)
 
         for batch in dl:
             input_dict.clear()
@@ -248,13 +248,6 @@ def hook(module: nn.Module, inputs, outputs):
                 inputs = input_dict.get(name, outputs)
                 raw = self.saes[name]  # 'raw' never has a DDP wrapper
 
-                outputs = outputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
-                outputs = outputs.reshape(-1, raw.d_in)
-
-                if self.cfg.transcode:
-                    inputs = inputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
-                    inputs = inputs.reshape(-1, raw.d_in)
-
                 # On the first iteration, initialize the decoder bias
                 if self.global_step == 0:
                     # NOTE: The all-cat here could conceivably cause an OOM in some
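Taken together, the hooks now hand the SAEs pre-flattened activations, so the per-step permute/reshape driven by sample_dims and feature_dims is no longer needed. For the usual [batch, seq, d_model] activations, flatten(0, -2) yields the same [batch * seq, d_model] matrix that the old defaults (sample_dims={0, 1}, feature_dims={2}) produced. A rough sketch with illustrative shapes, not part of the commit:

    import torch

    batch, seq, d_in = 4, 16, 512
    acts = torch.randn(batch, seq, d_in)  # stand-in for a hooked activation

    # Old path, under the removed defaults sample_dims={0, 1}, feature_dims={2}:
    old = acts.permute(0, 1, 2).reshape(-1, d_in)

    # New path, applied directly in the forward hook:
    new = acts.flatten(0, -2)

    assert torch.equal(old, new)  # both are [batch * seq, d_in]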
sae/utils.py: 2 additions and 3 deletions
@@ -1,6 +1,5 @@
 import os
 from typing import Any, Type, TypeVar, cast
-from math import prod
 
 import torch
 from accelerate.utils import send_to_device
@@ -63,7 +62,7 @@ def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]:
 @torch.inference_mode()
 def resolve_widths(
     model: PreTrainedModel, module_names: list[str], dummy_inputs: dict[str, Tensor],
-    dims: set[int] = {-1},
+    dim: int = -1,
 ) -> dict[str, int]:
     """Find number of output dimensions for the specified modules."""
     module_to_name = {
@@ -78,7 +77,7 @@ def hook(module, _, output):
 
         name = module_to_name[module]
 
-        shapes[name] = prod(output.shape[d] for d in dims)
+        shapes[name] = output.shape[dim]
 
     handles = [
         mod.register_forward_hook(hook) for mod in module_to_name
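With the set of feature dimensions gone, resolve_widths reports the size of a single axis of each module's output (the last one by default) instead of a product over several axes. A standalone sketch of the hook's arithmetic with a made-up output shape:

    from math import prod

    import torch

    output = torch.randn(2, 8, 768)  # hypothetical module output: [batch, seq, hidden]

    width_new = output.shape[-1]                    # new behaviour: size of one dim
    width_old = prod(output.shape[d] for d in {2})  # old behaviour with feature_dims={2}

    assert width_new == width_old == 768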
