[RFC] Engine Refactor Proposal #3752
base: develop
**New file:** a limited PoC CLI built on `jsonargparse`.

```python
from jsonargparse import ActionConfigFile, ArgumentParser, namespace_to_dict

from otx.cli.utils.jsonargparse import get_short_docstring
from otx.engine.engine_poc import Engine


class CLI:
    """CLI.

    Limited CLI to show how the API does not change externally while retaining the ability to expose models from
    the adapters.
    """

    def __init__(self):
        self.parser = ArgumentParser()
        self.parser.add_argument(
            "--config",
            action=ActionConfigFile,
            help="Configuration file in JSON format.",
        )
        self.add_subcommands()
        self.run()

    def subcommands(self):
        return ["train", "test"]

    def add_subcommands(self):
        parser_subcommand = self.parser.add_subcommands()
        for subcommand in self.subcommands():
            subparser = ArgumentParser()
            # Expose Engine.train / Engine.test signatures as CLI arguments.
            subparser.add_method_arguments(Engine, subcommand)
            fn = getattr(Engine, subcommand)
            description = get_short_docstring(fn)
            parser_subcommand.add_subcommand(subcommand, subparser, help=description)

    def run(self):
        args = self.parser.parse_args()
        args_dict = namespace_to_dict(args)
        engine = Engine()
        # do something here


if __name__ == "__main__":
    CLI()
```
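`run()` above stops at a placeholder. One way the dispatch could be completed — a sketch only, assuming `jsonargparse` records the chosen subcommand under the `"subcommand"` key of the parsed namespace (its default behavior for `add_subcommands()`):

```python
def run(self):
    """Hypothetical completion of the stub above."""
    args = self.parser.parse_args()
    args_dict = namespace_to_dict(args)
    subcommand = args_dict["subcommand"]  # "train" or "test"
    engine = Engine()
    # Forward the arguments parsed for that subcommand, so e.g.
    # `python cli.py train --max_epochs 5` becomes engine.train(max_epochs=5).
    getattr(engine, subcommand)(**args_dict[subcommand])
```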
**New file:** the Anomalib adapter.

```python
from anomalib.data import AnomalibDataModule
from anomalib.engine import Engine as AnomalibEngine
from anomalib.models import AnomalyModule

from otx.core.data.module import OTXDataModule
from otx.engine.base import METRICS, Adapter


def wrap_to_anomalib_datamodule(datamodule: OTXDataModule) -> AnomalibDataModule:
    """Mock function to wrap OTXDataModule to AnomalibDataModule."""
    return AnomalibDataModule(
        train=datamodule.train,
        val=datamodule.val,
        test=datamodule.test,
        batch_size=datamodule.batch_size,
        num_workers=datamodule.num_workers,
        pin_memory=datamodule.pin_memory,
        shuffle=datamodule.shuffle,
    )


class AnomalibAdapter(Adapter):
    def __init__(self):
        self._engine = AnomalibEngine()

    def train(
        self,
        model: AnomalyModule,
        datamodule: OTXDataModule | AnomalibDataModule,
        max_epochs: int = 1,
        **kwargs,
    ) -> METRICS:
        if not isinstance(datamodule, AnomalibDataModule):
            datamodule = wrap_to_anomalib_datamodule(datamodule)
        self._engine = AnomalibEngine(max_epochs=max_epochs, **kwargs)
        return self._engine.train(model=model, datamodule=datamodule)

    def test(
        self,
        model: AnomalyModule,
        datamodule: OTXDataModule | AnomalibDataModule,
        max_epochs: int = 1,
        **kwargs,
    ) -> METRICS:
        if not isinstance(datamodule, AnomalibDataModule):
            datamodule = wrap_to_anomalib_datamodule(datamodule)
        self._engine = AnomalibEngine(max_epochs=max_epochs, **kwargs)
        return self._engine.test(model=model, datamodule=datamodule)
```

> **Review comment** (on `wrap_to_anomalib_datamodule`): Do we really need this?

> **Review comment** (on `self._engine = AnomalibEngine()`): It feels like the …
>
> **Reply:** Yeah, the naming is super tricky here. The problem is that internally, the Anomalib Engine also has a trainer. So currently, it is …
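For illustration, a hypothetical direct use of this adapter, bypassing the top-level `Engine`; `anomaly_model` and `otx_datamodule` are placeholders for objects built elsewhere:

```python
adapter = AnomalibAdapter()

# An OTXDataModule is wrapped into an AnomalibDataModule transparently before
# being handed to the underlying AnomalibEngine.
metrics = adapter.train(model=anomaly_model, datamodule=otx_datamodule, max_epochs=3)
print(metrics)  # METRICS is an alias for dict[str, float]
```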
**New file:** the abstract `Adapter` interface (`otx/engine/base.py`, per the imports above), with only `train`/`test` enabled for the PoC.

```python
from abc import ABC, abstractmethod
from typing import Any

METRICS = dict[str, float]
ANNOTATIONS = Any


class Adapter(ABC):
    @abstractmethod
    def train(self, **kwargs) -> METRICS:
        pass

    @abstractmethod
    def test(self, **kwargs) -> METRICS:
        pass

    # @abstractmethod
    # def predict(self, **kwargs) -> ANNOTATIONS:
    #     pass

    # @abstractmethod
    # def export(self, **kwargs) -> Path:
    #     pass

    # @abstractmethod
    # def optimize(self, **kwargs) -> Path:
    #     pass

    # @abstractmethod
    # def explain(self, **kwargs) -> list[Tensor]:
    #     pass

    # @classmethod
    # @abstractmethod
    # def from_config(cls, **kwargs) -> "Adapter":
    #     pass

    # @classmethod
    # @abstractmethod
    # def from_model_name(cls, **kwargs) -> "Adapter":
    #     pass
```
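The point of the interface is that a new framework can be plugged in without touching the `Engine`. A hypothetical third-party adapter, sketched under the assumption that only `train`/`test` are required by the current ABC — everything below is illustrative and not part of the PR:

```python
from otx.engine.base import METRICS, Adapter


class MyFrameworkAdapter(Adapter):
    """Hypothetical adapter for some additional training framework."""

    def __init__(self, **kwargs) -> None:
        # Cache trainer-level arguments, mirroring the other adapters.
        self._trainer_args = kwargs

    def train(self, model, datamodule, **kwargs) -> METRICS:
        # Translate OTX-side objects into the framework's native types here,
        # mirroring wrap_to_anomalib_datamodule() above, then run training.
        return {"train/loss": 0.0}

    def test(self, model, datamodule, **kwargs) -> METRICS:
        return {"test/accuracy": 0.0}
```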
**New file:** the top-level `Engine` (`otx/engine/engine_poc.py`, per the CLI import). It dispatches each call to an adapter selected from the model type.

```python
import logging

from anomalib.data import AnomalibDataModule
from anomalib.models import AnomalyModule
from ultralytics.data import ClassificationDataset as UltralyticsDataset
from ultralytics.engine.model import Model

from otx.core.data.module import OTXDataModule
from otx.core.model.base import OTXModel
from otx.core.utils.cache import TrainerArgumentsCache
from otx.engine.base import METRICS, Adapter

from .anomalib import AnomalibAdapter
from .lightning import LightningAdapter
from .ultralytics import UltralyticsAdapter

logger = logging.getLogger(__name__)


class Engine:
    """Automatically selects the adapter based on the model passed to the engine."""

    def __init__(
        self,
        **kwargs,
    ) -> None:
        self._adapter: Adapter | None = None
        self._cache = TrainerArgumentsCache(**kwargs)

    @property
    def adapter(self) -> Adapter:
        return self._adapter

    @adapter.setter
    def adapter(self, adapter: Adapter) -> None:
        self._adapter = adapter

    def get_adapter(self, model: OTXModel | AnomalyModule | Model) -> Adapter:
        # Build (or rebuild) the adapter only when the current one does not
        # match the model type.
        if isinstance(model, AnomalyModule) and (
            self.adapter is None or not isinstance(self.adapter, AnomalibAdapter)
        ):
            self.adapter = AnomalibAdapter(**self._cache.args)
        elif isinstance(model, OTXModel) and (self.adapter is None or not isinstance(self.adapter, LightningAdapter)):
            self.adapter = LightningAdapter(**self._cache.args)
        elif isinstance(model, Model) and (self.adapter is None or not isinstance(self.adapter, UltralyticsAdapter)):
            self.adapter = UltralyticsAdapter(**self._cache.args)

        return self.adapter

    def train(
        self,
        model: OTXModel | AnomalyModule | Model,
        datamodule: OTXDataModule | AnomalibDataModule | UltralyticsDataset,
        **kwargs,
    ) -> METRICS:
        """Train the model."""
        adapter: Adapter = self.get_adapter(model)
        return adapter.train(model=model, datamodule=datamodule, **kwargs)

    def test(
        self,
        model: OTXModel | AnomalyModule | Model,
        datamodule: OTXDataModule | AnomalibDataModule | UltralyticsDataset,
    ) -> METRICS:
        """Test the model."""
        adapter = self.get_adapter(model)
        return adapter.test(model=model, datamodule=datamodule)
```

> **Review comment** (on `class Engine:`): I believe the current structure is an implementation that doesn't take auto-configuration into account at all. Is auto-configuration considered in the design?

> **Review comment** (on `get_adapter()`): It's weird to couple adapters to model types like this. I think it would be better to distinguish between adapters and models, and have the backend be received in Engine's `__init__`:
>
> ```python
> def __init__(
>     self,
>     ...
>     backends: Literal["base", "anomalib", "..."] = "base",
>     ...
> ) -> None:
> ```
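For concreteness, a sketch of that alternative — selecting the adapter explicitly at construction time instead of inferring it from the model type. The `backend` parameter and the registry are hypothetical, not part of the PR; the adapter classes are the ones imported in `engine_poc` above:

```python
from typing import Literal

# Hypothetical registry mapping backend names to adapter classes.
ADAPTERS: dict[str, type[Adapter]] = {
    "anomalib": AnomalibAdapter,
    "lightning": LightningAdapter,
    "ultralytics": UltralyticsAdapter,
}


class ExplicitBackendEngine:
    """Variant of Engine where the adapter is chosen up front, not per call."""

    def __init__(
        self,
        backend: Literal["lightning", "anomalib", "ultralytics"] = "lightning",
        **kwargs,
    ) -> None:
        self._adapter: Adapter = ADAPTERS[backend](**kwargs)

    def train(self, model, datamodule, **kwargs) -> METRICS:
        return self._adapter.train(model=model, datamodule=datamodule, **kwargs)
```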
> **Review comment** (on lines +50 to +55, the `train()` signature): Currently, the arguments for each function in …

> **Review comment:** For reference, the current OTX `Engine.train()` defines its options explicitly:
>
> ```python
> def train(
>     self,
>     max_epochs: int = 10,
>     seed: int | None = None,
>     deterministic: bool | Literal["warn"] = False,
>     precision: _PRECISION_INPUT | None = "32",
>     callbacks: list[Callback] | Callback | None = None,
>     logger: Logger | Iterable[Logger] | bool | None = None,
>     resume: bool = False,
>     metric: MetricCallable | None = None,
>     run_hpo: bool = False,
>     hpo_config: HpoConfig = HpoConfig(),  # noqa: B008 https://github.com/omni-us/jsonargparse/issues/423
>     checkpoint: PathLike | None = None,
>     adaptive_bs: Literal["None", "Safe", "Full"] = "None",
>     **kwargs,
> ) -> dict[str, Any]:
> ```
>
> I like the current way of doing things better, where we clearly define what we need for each command as an argument to each function, as shown above. Also, I don't see any reason to change the model through the function. Current OTX: the Engine is created (the model is determined) -> the E2E flow uses this model.

> **Reply:** Yeah, I prefer explicit arguments as well. I didn't include all of them here because I wanted to focus on the adapter. I think we should have explicit arguments, also so that they will be serialised by the CLI when creating the YAML file. Maybe this will get polished after a few iterations of discussion.

> **Reply:** Yes, I see. The reason I commented was that I thought it would be a good idea to avoid breaking existing Engine usage as much as possible.
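To make the dispatch concrete, a hypothetical end-to-end usage sketch. `Padim` is an anomalib model; `anomaly_datamodule` stands in for a datamodule built elsewhere, since its construction depends on the dataset:

```python
from anomalib.models import Padim

from otx.engine.engine_poc import Engine

engine = Engine()
model = Padim()

# get_adapter() sees an AnomalyModule, so it lazily builds an AnomalibAdapter
# (reusing any cached trainer kwargs) and keeps it for subsequent calls.
metrics = engine.train(model=model, datamodule=anomaly_datamodule, max_epochs=3)
test_metrics = engine.test(model=model, datamodule=anomaly_datamodule)
```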
**New file:** the Lightning adapter.

```python
import logging
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path

from lightning.pytorch import Trainer

from otx.core.data.module import OTXDataModule
from otx.core.metrics import MetricCallable
from otx.core.model.base import OTXModel
from otx.core.types.task import OTXTaskType
from otx.core.utils.cache import TrainerArgumentsCache

from .base import Adapter


@contextmanager
def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallable | None) -> Iterator[OTXModel]:
    """Override `OTXModel.metric_callable` to change the evaluation metric.

    Args:
        model: Model whose metric callable is overridden.
        new_metric_callable: If not None, override the model's metric callable with this one.
            Otherwise, do not override.
    """
    if new_metric_callable is None:
        yield model
        return

    orig_metric_callable = model.metric_callable
    try:
        model.metric_callable = new_metric_callable
        yield model
    finally:
        model.metric_callable = orig_metric_callable


class LightningAdapter(Adapter):
    """OTX Engine.

    This is a temporary name and we can change it later. It is basically a subset of what is currently present
    in the original OTX Engine class (engine.py).
    """

    def __init__(
        self,
        datamodule: OTXDataModule | None = None,
        model: OTXModel | str | None = None,
        task: OTXTaskType | None = None,
        **kwargs,
    ):
        self._cache = TrainerArgumentsCache(**kwargs)
        self.task = task
        self._trainer: Trainer | None = None
        self._datamodule: OTXDataModule | None = datamodule
        self._model: OTXModel | str | None = model

    def train(
        self,
        model: OTXModel | None = None,
        datamodule: OTXDataModule | None = None,
        max_epochs: int = 10,
        deterministic: bool = True,
        val_check_interval: int | float | None = 1,
        metric: MetricCallable | None = None,
    ) -> dict[str, float]:
        if model is not None:
            self.model = model
        if datamodule is not None:
            self.datamodule = datamodule
        self._build_trainer(
            logger=None,
            callbacks=None,
            max_epochs=max_epochs,
            deterministic=deterministic,
            val_check_interval=val_check_interval,
        )

        # NOTE: The model's label info should be converted to the datamodule's label info before ckpt loading.
        # This is because smart weight loading checks label names as well as the number of classes.
        if self.model.label_info != self.datamodule.label_info:
            msg = (
                "Model label_info is not equal to the Datamodule label_info. "
                f"It will be overridden: {self.model.label_info} => {self.datamodule.label_info}"
            )
            logging.warning(msg)
            self.model.label_info = self.datamodule.label_info

        with override_metric_callable(model=self.model, new_metric_callable=metric) as model:
            self.trainer.fit(
                model=model,
                datamodule=self.datamodule,
            )
        self.checkpoint = self.trainer.checkpoint_callback.best_model_path

        if not isinstance(self.checkpoint, (Path, str)):
            msg = "self.checkpoint should be Path or str at this time."
            raise TypeError(msg)

        best_checkpoint_symlink = Path(self.work_dir) / "best_checkpoint.ckpt"
        if best_checkpoint_symlink.is_symlink():
            best_checkpoint_symlink.unlink()
        best_checkpoint_symlink.symlink_to(self.checkpoint)

        return self.trainer.callback_metrics

    def test(self, **kwargs) -> dict[str, float]:
        pass

    @property
    def trainer(self) -> Trainer:
        """Returns the trainer object associated with the engine.

        To get this property, you should execute `Engine.train()` first.

        Returns:
            Trainer: The trainer object.
        """
        if self._trainer is None:
            msg = "Please run train() first"
            raise RuntimeError(msg)
        return self._trainer

    def _build_trainer(self, **kwargs) -> None:
        """Instantiate the trainer based on the model parameters."""
        if self._cache.requires_update(**kwargs) or self._trainer is None:
            self._cache.update(**kwargs)

            kwargs = self._cache.args
            self._trainer = Trainer(**kwargs)
            self._cache.is_trainer_args_identical = True
            self._trainer.task = self.task
            self.work_dir = self._trainer.default_root_dir

    @property
    def model(self) -> OTXModel:
        return self._model

    @model.setter
    def model(self, model: OTXModel) -> None:
        self._model = model

    @property
    def datamodule(self) -> OTXDataModule:
        return self._datamodule

    @datamodule.setter
    def datamodule(self, datamodule: OTXDataModule) -> None:
        self._datamodule = datamodule
```
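Since `override_metric_callable` only touches the `metric_callable` attribute, its scoping behavior can be shown with a duck-typed stand-in (illustrative only, not OTX API):

```python
from types import SimpleNamespace

# Duck-typed stand-in: anything with a `metric_callable` attribute works here.
model = SimpleNamespace(metric_callable="original-metric")

with override_metric_callable(model=model, new_metric_callable="new-metric"):
    assert model.metric_callable == "new-metric"  # override active inside the block

assert model.metric_callable == "original-metric"  # restored on exit (via finally)
```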
**New file:** the `__init__.py` of the Ultralytics adapter subpackage, re-exporting its adapter.

```python
from .adapter import UltralyticsAdapter

__all__ = ["UltralyticsAdapter"]
```
> **Review comment:** In my opinion, the main focus of this PR is refactoring the Engine-related interfaces for scalability. Our CLI already has the functionality of this code, so if this PoC is going to be merged into develop, we need to look at compatibility with the existing CLI code. With these changes, I think the following should be discussed for the existing CLI:
>
> - Supporting a `DataModule` other than `OTXDataModule` -> I think this is possible with some tweaks to `add_argument`.
> - Supporting a `Dataset` from an external framework, rather than a `DataModule`, in the CLI.
>
> Anyway, for the CLI part, I think it's better to wait until the Engine's design is finalized and see if there are any issues.

> **Reply:** Yeah, the main intention of this PR is to propose the change to the Engine class. I included the CLI for completeness, and also to ensure that the new design does not break it.