Add --no_balance flag to not balance datasets #287
base: main
@@ -194,8 +194,8 @@ def inlp(
                 the input dimension.
             y: Target tensor of shape (N,) for binary classification or (N, C) for
                 multiclass classification, where C is the number of classes.
-            max_iter: Maximum number of iterations to run. If `None`, run for the full
-                dimension of the input.
+            max_iter: Maximum number of iterations to run. If `None`, run until the data
+                is linearly guarded (no linear classifier can extract information).
             tol: Tolerance for the loss function. The algorithm will stop when the loss
                 is within `tol` of the entropy of the labels.
@@ -212,12 +212,11 @@ def inlp(
         p = y.float().mean()
         H = -p * torch.log(p) - (1 - p) * torch.log(1 - p)

-        if max_iter is not None:
-            d = min(d, max_iter)
+        max_iter = max_iter or d

         # Iterate until the loss is within epsilon of the entropy
         result = InlpResult()
-        for _ in range(d):
+        for _ in range(max_iter):
             clf = cls(d, device=x.device, dtype=x.dtype)
             loss = clf.fit(x, y)
             result.classifiers.append(clf)

Reviewer: That's just some refactoring which has nothing to do with the balancing, I guess?

Author: Right, I also added a `max_iter` flag and this was a necessary refactoring.
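As an aside on the refactoring discussed above, here is a minimal, self-contained sketch of the new iteration-count logic. Only the `max_iter = max_iter or d` fallback and the loop bound come from the diff; the function name and the dummy loop body are hypothetical.

```python
def count_inlp_iterations(d: int, max_iter: int | None = None) -> int:
    """Return how many iterations the refactored INLP loop would run."""
    # Fall back to the input dimension d when max_iter is None.
    # (Note that `or` also maps max_iter=0 to d.)
    max_iter = max_iter or d
    n = 0
    for _ in range(max_iter):
        n += 1  # the real loop fits a Classifier and records it in InlpResult
    return n


assert count_inlp_iterations(d=8) == 8               # None -> full input dimension
assert count_inlp_iterations(d=8, max_iter=3) == 3   # explicit cap
```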
@@ -6,7 +6,7 @@


 def train_supervised(
-    data: dict[str, tuple], device: str, mode: str
+    data: dict[str, tuple], device: str, mode: str, max_inlp_iter: int | None = None
 ) -> list[Classifier]:
     Xs, train_labels = [], []

Reviewer: That's a new feature not related to the balancing either, right?
@@ -26,7 +26,7 @@ def train_supervised(
         lr_model.fit_cv(X, train_labels)
         return [lr_model]
     elif mode == "inlp":
-        return Classifier.inlp(X, train_labels).classifiers
+        return Classifier.inlp(X, train_labels, max_inlp_iter).classifiers
     elif mode == "single":
         lr_model = Classifier(X.shape[-1], device=device)
         lr_model.fit(X, train_labels)
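To make the data flow concrete, here is a rough, self-contained sketch of how `max_inlp_iter` travels from `train_supervised` into `Classifier.inlp`. Everything except the parameter name and the forwarding call is a simplified stand-in for the repo's real classes, not the project's implementation.

```python
from dataclasses import dataclass, field


@dataclass
class InlpResult:
    # Stand-in for the repo's InlpResult: just holds the fitted classifiers.
    classifiers: list = field(default_factory=list)


class Classifier:
    # Stand-in exposing an inlp() classmethod with the call shape used above.
    @classmethod
    def inlp(cls, x, y, max_iter: int | None = None) -> InlpResult:
        d = len(x[0])                # input dimension
        max_iter = max_iter or d     # same fallback as in the first file's diff
        return InlpResult(classifiers=[f"clf_{i}" for i in range(max_iter)])


def train_supervised(data, mode: str, max_inlp_iter: int | None = None) -> list:
    X, y = data                      # simplified: the real code builds X from splits
    if mode == "inlp":
        # The new keyword is simply forwarded, as in the diff above.
        return Classifier.inlp(X, y, max_inlp_iter).classifiers
    raise ValueError(f"unknown mode: {mode}")


toy = ([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], [0, 1])
assert len(train_supervised(toy, "inlp")) == 3                  # None -> input dim
assert len(train_supervised(toy, "inlp", max_inlp_iter=1)) == 1
```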
Reviewer: Why not just make it `balance: bool = True`? That would also avoid having `balance=not cfg.no_balance`.

Author: Because it would be unclear how to use the flag to disable balancing from the CLI. `--balance False` or something is weirder than `--no_balance`.

Reviewer: `--balance False` does not seem weirder than `--no_balance True` to me. But okay, it's fine for me.

Author: Yeah, I think I agree with you now.
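To illustrate the trade-off being discussed, here is a small argparse sketch of the two flag styles. The repo's actual CLI layer may look quite different (the thread suggests a `cfg` object with a `no_balance` field), so treat this as an illustration rather than the project's code.

```python
import argparse

# Style chosen in this PR: keep the internal parameter positive
# (balance: bool = True) and expose a --no_balance switch on the CLI,
# at the cost of one inversion at the call site.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_balance",
    action="store_true",
    help="Do not balance the datasets (balancing remains the default).",
)
cfg = parser.parse_args(["--no_balance"])
balance = not cfg.no_balance      # the `balance=not cfg.no_balance` from the thread
assert balance is False

# A related standard-library option (not raised in the thread): a single boolean
# flag that auto-generates a --balance / --no-balance pair, which avoids both
# `--balance False` on the command line and the inversion above.
alt = argparse.ArgumentParser()
alt.add_argument("--balance", action=argparse.BooleanOptionalAction, default=True)
assert alt.parse_args([]).balance is True
assert alt.parse_args(["--no-balance"]).balance is False
```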