Skip to content

Commit

Permalink
fix: Improve sklearn automatic batch size (#125)
Browse files Browse the repository at this point in the history
* refactor: Improve safety of automatic batch size computation

* build: Upgrade version, update changelog
  • Loading branch information
lorenzomammana authored Jul 9, 2024
1 parent 8a2ab23 commit 1640df6
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
# Changelog
All notable changes to this project will be documented in this file.

### [2.1.13]

- Improve safe batch size computation for sklearn based classification tasks

### [2.1.12]

#### Fixed
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "quadra"
version = "2.1.12"
version = "2.1.13"
description = "Deep Learning experiment orchestration library"
authors = [
"Federico Belotti <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion quadra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.1.12"
__version__ = "2.1.13"


def get_version():
Expand Down
16 changes: 6 additions & 10 deletions quadra/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from quadra.trainers.classification import SklearnClassificationTrainer
from quadra.utils import utils
from quadra.utils.classification import (
automatic_batch_size_computation,
get_results,
save_classification_result,
)
Expand Down Expand Up @@ -539,15 +538,6 @@ def prepare(self) -> None:
self.datamodule.prepare_data()
self.datamodule.setup(stage="fit")

if not self.automatic_batch_size.disable and self.device != "cpu":
self.datamodule.batch_size = automatic_batch_size_computation(
datamodule=self.datamodule,
backbone=self.backbone,
starting_batch_size=self.automatic_batch_size.starting_batch_size,
)

self.train_dataloader_list = list(self.datamodule.train_dataloader())
self.test_dataloader_list = list(self.datamodule.val_dataloader())
self.trainer = self.config.trainer

@property
Expand Down Expand Up @@ -601,6 +591,7 @@ def trainer(self, trainer_config: DictConfig) -> None:
self._trainer = trainer

@typing.no_type_check
@automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
def train(self) -> None:
"""Train the model."""
log.info("Starting training...!")
Expand All @@ -609,6 +600,9 @@ def train(self) -> None:

class_to_keep = None

self.train_dataloader_list = list(self.datamodule.train_dataloader())
self.test_dataloader_list = list(self.datamodule.val_dataloader())

if hasattr(self.datamodule, "class_to_keep_training") and self.datamodule.class_to_keep_training is not None:
class_to_keep = self.datamodule.class_to_keep_training

Expand Down Expand Up @@ -729,6 +723,7 @@ def extract_model_summary(

break

@automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
def train_full_data(self):
"""Train the model on train + validation."""
# Reinit classifier
Expand All @@ -743,6 +738,7 @@ def test(self) -> None:
# train module to handle cross validation

@typing.no_type_check
@automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
def test_full_data(self) -> None:
"""Test model trained on full dataset."""
self.config.datamodule.class_to_idx = self.datamodule.full_dataset.class_to_idx
Expand Down
59 changes: 41 additions & 18 deletions quadra/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,28 @@ def decorator(func: Callable):
def wrapper(self, *args, **kwargs):
"""Wrapper function."""
is_func_finished = False
starting_batch_size = None
automatic_batch_size_completed = False

if hasattr(self, "automatic_batch_size_completed"):
automatic_batch_size_completed = self.automatic_batch_size_completed

if hasattr(self, "automatic_batch_size"):
if not hasattr(self.automatic_batch_size, "disable") or not hasattr(
self.automatic_batch_size, "starting_batch_size"
):
raise ValueError(
"The automatic_batch_size attribute should have the disable and starting_batch_size attributes"
)
starting_batch_size = (
self.automatic_batch_size.starting_batch_size if not self.automatic_batch_size.disable else None
)

if starting_batch_size is not None and not automatic_batch_size_completed:
# If we already tried to reduce the batch size, we will start from the last batch size
log.info("Performing automatic batch size scaling from %d", starting_batch_size)
setattr(self.datamodule, batch_size_attribute_name, starting_batch_size)

while not is_func_finished:
valid_exceptions = (RuntimeError,)

Expand All @@ -426,25 +448,26 @@ def wrapper(self, *args, **kwargs):

try:
func(self, *args, **kwargs)
is_func_finished = True
self.automatic_batch_size_completed = True
if torch.cuda.is_available():
torch.cuda.empty_cache()
except valid_exceptions as e:
if "out of memory" in str(e) or "Failed to allocate memory" in str(e):
current_batch_size = getattr(self.datamodule, batch_size_attribute_name)
setattr(self.datamodule, batch_size_attribute_name, current_batch_size // 2)
log.warning(
"The function %s went out of memory, trying to reduce the batch size to %d",
func.__name__,
self.datamodule.batch_size,
)

if self.datamodule.batch_size == 0:
raise RuntimeError(
f"Unable to run {func.__name__} with batch size 1, the program will exit"
) from e
continue

raise e

is_func_finished = True
current_batch_size = getattr(self.datamodule, batch_size_attribute_name)
setattr(self.datamodule, batch_size_attribute_name, current_batch_size // 2)
log.warning(
"The function %s went out of memory, trying to reduce the batch size to %d",
func.__name__,
self.datamodule.batch_size,
)

if self.datamodule.batch_size == 0:
raise RuntimeError(
f"Unable to run {func.__name__} with batch size 1, the program will exit"
) from e

if torch.cuda.is_available():
torch.cuda.empty_cache()

return wrapper

Expand Down

0 comments on commit 1640df6

Please sign in to comment.