diff --git a/sae/config.py b/sae/config.py
index acf0ab6..eedb268 100644
--- a/sae/config.py
+++ b/sae/config.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import AbstractSet
 
 from simple_parsing import Serializable, list_field
 
@@ -68,12 +67,6 @@ class TrainConfig(Serializable):
     save_every: int = 1000
     """Save SAEs every `save_every` steps."""
 
-    sample_dims: AbstractSet[int] = frozenset({0, 1})
-    """Dimensions containing SAE inputs."""
-
-    feature_dims: AbstractSet[int] = frozenset({2})
-    """Dimensions of SAE inputs."""
-
     log_to_wandb: bool = True
     run_name: str | None = None
     wandb_log_frequency: int = 1
diff --git a/sae/trainer.py b/sae/trainer.py
index 0026e04..758dcb8 100644
--- a/sae/trainer.py
+++ b/sae/trainer.py
@@ -36,10 +36,11 @@ def __init__(
             found_hookpoints = natsorted(raw_hookpoints)
 
             if not found_hookpoints:
-                print(f"No modules matched the pattern(s) {cfg.hookpoints}")
-                print("Available modules:")
-                for name, _ in model.named_modules():
-                    print(name)
+                error_msg = f"""No modules matched the pattern(s) {cfg.hookpoints}
+
+Available modules:
+{chr(10).join(name for name, _ in model.named_modules())}"""
+                raise ValueError(error_msg)
 
             cfg.hookpoints = found_hookpoints
 
@@ -63,8 +64,7 @@ def __init__(
         device = model.device
         dummy_inputs = dummy_inputs if dummy_inputs is not None else model.dummy_inputs
 
-        input_widths = resolve_widths(model, cfg.hookpoints, dummy_inputs,
-                                      dims=cfg.feature_dims)
+        input_widths = resolve_widths(model, cfg.hookpoints, dummy_inputs)
         unique_widths = set(input_widths.values())
 
         if cfg.distribute_modules and len(unique_widths) > 1:
@@ -215,11 +215,11 @@ def hook(module: nn.Module, inputs, outputs):
                 outputs = outputs[0]
 
             name = module_to_name[module]
-            output_dict[name] = outputs
+            output_dict[name] = outputs.flatten(0, -2)
 
             # Remember the inputs if we're training a transcoder
             if self.cfg.transcode:
-                input_dict[name] = inputs
+                input_dict[name] = inputs.flatten(0, -2)
 
         for batch in dl:
             input_dict.clear()
@@ -248,13 +248,6 @@ def hook(module: nn.Module, inputs, outputs):
                 inputs = input_dict.get(name, outputs)
                 raw = self.saes[name]  # 'raw' never has a DDP wrapper
 
-                outputs = outputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
-                outputs = outputs.reshape(-1, raw.d_in)
-
-                if self.cfg.transcode:
-                    inputs = inputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
-                    inputs = inputs.reshape(-1, raw.d_in)
-
                 # On the first iteration, initialize the decoder bias
                 if self.global_step == 0:
                     # NOTE: The all-cat here could conceivably cause an OOM in some
diff --git a/sae/utils.py b/sae/utils.py
index 4d5e483..5f39d20 100644
--- a/sae/utils.py
+++ b/sae/utils.py
@@ -1,6 +1,5 @@
 import os
 from typing import Any, Type, TypeVar, cast
-from math import prod
 
 import torch
 from accelerate.utils import send_to_device
@@ -63,7 +62,7 @@ def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]:
 @torch.inference_mode()
 def resolve_widths(
     model: PreTrainedModel, module_names: list[str], dummy_inputs: dict[str, Tensor],
-    dims: set[int] = {-1},
+    dim: int = -1,
 ) -> dict[str, int]:
     """Find number of output dimensions for the specified modules."""
     module_to_name = {
@@ -77,8 +76,7 @@ def hook(module, _, output):
             output, *_ = output
 
         name = module_to_name[module]
-
-        shapes[name] = prod(output.shape[d] for d in dims)
+        shapes[name] = output.shape[dim]
 
     handles = [
         mod.register_forward_hook(hook) for mod in module_to_name