diff --git a/sae/config.py b/sae/config.py
index 12d91d2..31c6e1d 100644
--- a/sae/config.py
+++ b/sae/config.py
@@ -67,14 +67,15 @@ class TrainConfig(Serializable):
     save_every: int = 1000
     """Save SAEs every `save_every` steps."""
 
-    instance_dims: tuple = (0, 1)
+    sample_dims: tuple[int, ...] = (0, 1)
+    """Dimensions of the activation tensor that index individual SAE inputs."""
 
-    feature_dims: tuple = (2,)
+    feature_dims: tuple[int, ...] = (2,)
+    """Dimensions of the activation tensor that make up each SAE input vector."""
 
     log_to_wandb: bool = True
     run_name: str | None = None
     wandb_log_frequency: int = 1
 
-
     def __post_init__(self):
         assert not (
diff --git a/sae/trainer.py b/sae/trainer.py
index f277f65..0026e04 100644
--- a/sae/trainer.py
+++ b/sae/trainer.py
@@ -248,11 +248,11 @@ def hook(module: nn.Module, inputs, outputs):
             inputs = input_dict.get(name, outputs)
             raw = self.saes[name]  # 'raw' never has a DDP wrapper
 
-            outputs = outputs.permute(*self.cfg.instance_dims, *self.cfg.feature_dims)
+            outputs = outputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
             outputs = outputs.reshape(-1, raw.d_in)
 
             if self.cfg.transcode:
-                inputs = inputs.permute(*self.cfg.instance_dims, *self.cfg.feature_dims)
+                inputs = inputs.permute(*self.cfg.sample_dims, *self.cfg.feature_dims)
                 inputs = inputs.reshape(-1, raw.d_in)
 
             # On the first iteration, initialize the decoder bias
diff --git a/sae/utils.py b/sae/utils.py
index cac2f0a..4d5e483 100644
--- a/sae/utils.py
+++ b/sae/utils.py
@@ -63,7 +63,7 @@ def get_layer_list(model: PreTrainedModel) -> tuple[str, nn.ModuleList]:
 @torch.inference_mode()
 def resolve_widths(
     model: PreTrainedModel, module_names: list[str], dummy_inputs: dict[str, Tensor],
-    dims: tuple[int] = (-1,),
+    dims: tuple[int, ...] = (-1,),
 ) -> dict[str, int]:
     """Find number of output dimensions for the specified modules."""
     module_to_name = {