Skip to content

Commit

Permalink
Add assert statement to assume batch and feature dimensions are well ordered; remove permutes
Browse files Browse the repository at this point in the history
  • Loading branch information
luciaquirke committed Dec 4, 2024
1 parent a78566e commit c700c96
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 12 deletions.
8 changes: 2 additions & 6 deletions sae/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import AbstractSet

from simple_parsing import Serializable, list_field

Expand Down Expand Up @@ -68,11 +67,8 @@ class TrainConfig(Serializable):
save_every: int = 1000
"""Save SAEs every `save_every` steps."""

sample_dims: AbstractSet[int] = frozenset({0, 1})
"""Dimensions containing SAE inputs."""

feature_dims: AbstractSet[int] = frozenset({2})
"""Dimensions of SAE inputs."""
feature_dims: list[int] = list_field()
"""Dimensions of SAE inputs features."""

log_to_wandb: bool = True
run_name: str | None = None
Expand Down
6 changes: 1 addition & 5 deletions sae/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,8 @@ def hook(module: nn.Module, inputs, outputs):
inputs = input_dict.get(name, outputs)
raw = self.saes[name] # 'raw' never has a DDP wrapper

outputs = outputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
outputs = outputs.reshape(-1, raw.d_in)

if self.cfg.transcode:
inputs = inputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
inputs = inputs.reshape(-1, raw.d_in)
inputs = inputs.reshape(-1, raw.d_in)

# On the first iteration, initialize the decoder bias
if self.global_step == 0:
Expand Down
6 changes: 5 additions & 1 deletion sae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]:
@torch.inference_mode()
def resolve_widths(
model: PreTrainedModel, module_names: list[str], dummy_inputs: dict[str, Tensor],
dims: set[int] = {-1},
dims: list[int] = [-1],
) -> dict[str, int]:
"""Find number of output dimensions for the specified modules."""
module_to_name = {
Expand All @@ -78,6 +78,10 @@ def hook(module, _, output):

name = module_to_name[module]

pos_dims = [d if d >= 0 else output.ndim + d for d in dims]
assert all(i in pos_dims for i in range(min(pos_dims), max(pos_dims) + 1)) and max(pos_dims) == output.ndim - 1, \
f"Feature dimensions {dims} must be contiguous and include the final dimension"

shapes[name] = prod(output.shape[d] for d in dims)

handles = [
Expand Down

0 comments on commit c700c96

Please sign in to comment.