Skip to content

Commit

Permalink
Rename instance dims -> sample dims
Browse files: browse the repository at this point in the history
  • Loading branch information
luciaquirke committed Dec 3, 2024
1 parent d250458 commit dab69e7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions sae/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,15 @@ class TrainConfig(Serializable):
save_every: int = 1000
"""Save SAEs every `save_every` steps."""

instance_dims: tuple = (0, 1)
sample_dims: set[int] = {0, 1}
"""Dimensions containing SAE inputs."""

feature_dims: tuple = (2,)
feature_dims: set[int] = {2}
"""Dimensions of SAE inputs."""

log_to_wandb: bool = True
run_name: str | None = None
wandb_log_frequency: int = 1


def __post_init__(self):
assert not (
Expand Down
4 changes: 2 additions & 2 deletions sae/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,11 @@ def hook(module: nn.Module, inputs, outputs):
inputs = input_dict.get(name, outputs)
raw = self.saes[name] # 'raw' never has a DDP wrapper

outputs = outputs.permute(*self.cfg.instance_dims, *self.cfg.feature_dims)
outputs = outputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
outputs = outputs.reshape(-1, raw.d_in)

if self.cfg.transcode:
inputs = inputs.permute(*self.cfg.instance_dims, *self.cfg.feature_dims)
inputs = inputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
inputs = inputs.reshape(-1, raw.d_in)

# On the first iteration, initialize the decoder bias
Expand Down
2 changes: 1 addition & 1 deletion sae/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]:
@torch.inference_mode()
def resolve_widths(
model: PreTrainedModel, module_names: list[str], dummy_inputs: dict[str, Tensor],
dims: tuple[int] = (-1,),
dims: set[int] = {-1},
) -> dict[str, int]:
"""Find number of output dimensions for the specified modules."""
module_to_name = {
Expand Down

0 comments on commit dab69e7

Please sign in to comment.