From 4ea0c4004913c6b68273eb99d5b29db110b5bbaf Mon Sep 17 00:00:00 2001
From: Alex Mallen
Date: Tue, 22 Aug 2023 21:09:37 +0000
Subject: [PATCH 1/3] add no_balance flag to disable label balancing

---
 elk/extraction/extraction.py     | 4 ++++
 elk/extraction/prompt_loading.py | 5 +++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py
index 26e8a7f1..d4b8ad6c 100644
--- a/elk/extraction/extraction.py
+++ b/elk/extraction/extraction.py
@@ -65,6 +65,9 @@ class Extract(Serializable):
     binarize: bool = False
     """Whether to binarize the dataset labels for multi-class datasets."""
 
+    no_balance: bool = False
+    """Whether to disable balancing the dataset by label."""
+
     int8: bool = False
     """Whether to perform inference in mixed int8 precision with `bitsandbytes`."""
 
@@ -189,6 +192,7 @@ def extract_hiddens(
         num_shots=cfg.num_shots,
         split_type=split_type,
         template_path=cfg.template_path,
+        balance=not cfg.no_balance,
         rank=rank,
         world_size=world_size,
         seed=cfg.seed,
diff --git a/elk/extraction/prompt_loading.py b/elk/extraction/prompt_loading.py
index cb42d233..defd42da 100644
--- a/elk/extraction/prompt_loading.py
+++ b/elk/extraction/prompt_loading.py
@@ -21,6 +21,7 @@ def load_prompts(
     seed: int = 42,
     split_type: Literal["train", "val"] = "train",
     template_path: str | None = None,
+    balance: bool = True,
     rank: int = 0,
     world_size: int = 1,
 ) -> Iterator[dict]:
@@ -89,14 +90,14 @@ def load_prompts(
     else:
         fewshot_iter = None
 
-    if label_column in ds.features:
+    if label_column in ds.features and balance:
         ds = BalancedSampler(
             ds.to_iterable_dataset(),
             set(label_choices),
             label_col=label_column,
         )
     else:
-        if rank == 0:
+        if rank == 0 and balance:
             print("No label column found, not balancing")
         ds = ds.to_iterable_dataset()
 

From ec63dfdd5cc4b4e9b00b33bb5fb95919f261e6ad Mon Sep 17 00:00:00 2001
From: Alex Mallen
Date: Tue, 22 Aug 2023 21:11:08 +0000
Subject: [PATCH 2/3] add inlp max iterations

---
 elk/training/classifier.py | 9 ++++-----
 elk/training/supervised.py | 4 ++--
 elk/training/train.py      | 4 ++++
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/elk/training/classifier.py b/elk/training/classifier.py
index 148da939..2394cef4 100644
--- a/elk/training/classifier.py
+++ b/elk/training/classifier.py
@@ -194,8 +194,8 @@ def inlp(
                 the input dimension.
             y: Target tensor of shape (N,) for binary classification or (N, C)
                 for multiclass classification, where C is the number of classes.
-            max_iter: Maximum number of iterations to run. If `None`, run until the data
-                is linearly guarded (no linear classifier can extract information).
+            max_iter: Maximum number of iterations to run. If `None`, run until the data
+                is linearly guarded (no linear classifier can extract information).
             tol: Tolerance for the loss function. The algorithm will stop when the loss
                 is within `tol` of the entropy of the labels.
@@ -212,12 +212,11 @@ def inlp(
         p = y.float().mean()
         H = -p * torch.log(p) - (1 - p) * torch.log(1 - p)
 
-        if max_iter is not None:
-            d = min(d, max_iter)
+        max_iter = min(max_iter, d) if max_iter is not None else d
 
         # Iterate until the loss is within epsilon of the entropy
         result = InlpResult()
-        for _ in range(d):
+        for _ in range(max_iter):
             clf = cls(d, device=x.device, dtype=x.dtype)
             loss = clf.fit(x, y)
             result.classifiers.append(clf)
diff --git a/elk/training/supervised.py b/elk/training/supervised.py
index d2eef5f7..84f4e988 100644
--- a/elk/training/supervised.py
+++ b/elk/training/supervised.py
@@ -6,7 +6,7 @@
 
 
 def train_supervised(
-    data: dict[str, tuple], device: str, mode: str
+    data: dict[str, tuple], device: str, mode: str, max_inlp_iter: int | None = None
 ) -> list[Classifier]:
     Xs, train_labels = [], []
 
@@ -26,7 +26,7 @@ def train_supervised(
         lr_model.fit_cv(X, train_labels)
         return [lr_model]
     elif mode == "inlp":
-        return Classifier.inlp(X, train_labels).classifiers
+        return Classifier.inlp(X, train_labels, max_inlp_iter).classifiers
     elif mode == "single":
         lr_model = Classifier(X.shape[-1], device=device)
         lr_model.fit(X, train_labels)
diff --git a/elk/training/train.py b/elk/training/train.py
index 938fcf79..7e2abeb9 100644
--- a/elk/training/train.py
+++ b/elk/training/train.py
@@ -34,6 +34,9 @@ class Elicit(Run):
     cross-validation. Defaults to "single", which means to train a single
     classifier on the training data. "cv" means to use cross-validation."""
 
+    max_inlp_iter: int | None = None
+    """Maximum number of iterations for Iterative Nullspace Projection (INLP)."""
+
     def create_models_dir(self, out_dir: Path):
         lr_dir = None
         lr_dir = out_dir / "lr_models"
@@ -123,6 +126,7 @@
                 train_dict,
                 device=device,
                 mode=self.supervised,
+                max_inlp_iter=self.max_inlp_iter,
             )
             with open(lr_dir / f"layer_{layer}.pt", "wb") as file:
                 torch.save(lr_models, file)

From 1ed44eff7a9b6a8f59cf640a53a99a46f1465762 Mon Sep 17 00:00:00 2001
From: Alex Mallen
Date: Tue, 22 Aug 2023 21:11:36 +0000
Subject: [PATCH 3/3] add check for 0 val examples

---
 elk/debug_logging.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/elk/debug_logging.py b/elk/debug_logging.py
index 59bea62f..d93df9b1 100644
--- a/elk/debug_logging.py
+++ b/elk/debug_logging.py
@@ -31,6 +31,9 @@ def save_debug_log(datasets: list[DatasetDictWithName], out_dir: Path) -> None:
         else:
             train_split, val_split = select_train_val_splits(ds)
 
+        if len(ds[val_split]) == 0:
+            logging.warning(f"Val split '{val_split}' is empty!")
+            continue
         text_questions = ds[val_split][0]["text_questions"]
         template_ids = ds[val_split][0]["variant_ids"]
         label = ds[val_split][0]["label"]
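
Note on patch 2, for reviewers who want the intuition behind `max_iter`: INLP repeatedly fits a linear probe and then projects the features onto the nullspace of that probe's weight vector, so each iteration removes one usable direction. Below is a minimal standalone sketch of that loop, not elk's implementation: it assumes binary 0/1 labels and substitutes NumPy plus scikit-learn's `LogisticRegression` for elk's torch `Classifier`, and `inlp_sketch` is a hypothetical name.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def inlp_sketch(
    x: np.ndarray,  # (N, d) feature matrix
    y: np.ndarray,  # (N,) binary 0/1 labels
    max_iter: int | None = None,
    tol: float = 0.01,
) -> list[LogisticRegression]:
    """Fit probes, projecting out each one's direction, until the data is guarded."""
    d = x.shape[1]
    # At most d useful iterations: each pass removes one direction.
    max_iter = min(max_iter, d) if max_iter is not None else d

    # Entropy of the labels in nats: the loss of a probe that learned nothing.
    p = y.mean()
    entropy = -p * np.log(p) - (1 - p) * np.log(1 - p)

    classifiers = []
    for _ in range(max_iter):
        clf = LogisticRegression(penalty=None).fit(x, y)
        classifiers.append(clf)

        # Cross-entropy of the current probe on the (projected) data.
        probs = np.clip(clf.predict_proba(x)[:, 1], 1e-9, 1 - 1e-9)
        loss = -np.mean(y * np.log(probs) + (1 - y) * np.log(1 - probs))
        if loss >= entropy - tol:
            break  # no linear probe beats chance: approximately linearly guarded

        # Project onto the nullspace of the probe: x <- x - (x @ w) w^T
        w = clf.coef_.ravel()
        w = w / np.linalg.norm(w)
        x = x - np.outer(x @ w, w)

    return classifiers
```

Capping at `d` is the natural upper bound, since after `d` projections no linear structure remains for a probe to exploit; in practice the `tol` check against the label entropy usually stops the loop well before either cap is reached.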