Skip to content

Commit

Permalink
Merge pull request #3748 from sungchul2/enable-data-config-override
Browse files Browse the repository at this point in the history
Enable to override data configurations
  • Loading branch information
sungchul2 authored Jul 23, 2024
2 parents a15e79a + 3f8598b commit 632f5c6
Show file tree
Hide file tree
Showing 72 changed files with 1,077 additions and 916 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ All notable changes to this project will be documented in this file.

### Enhancements

- Enable to override data configurations
(<https://github.com/openvinotoolkit/training_extensions/pull/3748>)

### Bug fixes

## \[v2.1.0\]
Expand Down
96 changes: 88 additions & 8 deletions src/otx/cli/utils/jsonargparse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions related to jsonargparse."""

# Copyright (C) 2023 Intel Corporation
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations
Expand Down Expand Up @@ -136,22 +136,48 @@ def apply_config(self: ActionConfigFile, parser: ArgumentParser, cfg: Namespace,
cfg.__dict__.update(cfg_merged.__dict__)
overrides = cfg.__dict__.pop("overrides", None)
if overrides is not None:
# This is a feature to handle the callbacks & logger override for user-convinience
list_override(configs=cfg, key="callbacks", overrides=overrides.pop("callbacks", []))
list_override(configs=cfg, key="logger", overrides=overrides.pop("logger", []))
apply_override(cfg, overrides)
cfg.update(overrides)
if cfg.get(dest) is None:
cfg[dest] = []
cfg[dest].append(cfg_path)


def list_override(configs: Namespace, key: str, overrides: list) -> None:
def namespace_override(
configs: Namespace,
key: str,
overrides: Namespace,
convert_dict_to_namespace: bool = True,
) -> None:
"""Overrides the nested namespace type in the given configs with the provided overrides.
Args:
configs (Namespace): The configuration object containing the key.
key (str): key of the configs want to override.
overrides (Namespace): The configuration object to override the existing ones.
convert_dict_to_namespace (bool): Whether to convert the dictionary to Namespace. Defaults to True.
"""
for sub_key, sub_value in overrides.items():
if isinstance(sub_value, list) and all(isinstance(sv, dict) for sv in sub_value):
# only enable list of dictionary items
list_override(
configs=configs[key],
key=sub_key,
overrides=sub_value,
convert_dict_to_namespace=convert_dict_to_namespace,
)
else:
configs[key].update(sub_value, sub_key)


def list_override(configs: Namespace, key: str, overrides: list, convert_dict_to_namespace: bool = True) -> None:
"""Overrides the nested list type in the given configs with the provided override_list.
Args:
configs (Namespace): The configuration object containing the key.
key (str): key of the configs want to override.
overrides (list): The list of dictionary item to override the existing ones.
convert_dict_to_namespace (bool): Whether to convert the dictionary to Namespace. Defaults to True.
Example:
>>> configs = [
Expand Down Expand Up @@ -179,6 +205,28 @@ def list_override(configs: Namespace, key: str, overrides: list) -> None:
... ),
... ...
... ]
>>> append_callbacks = [
... {
... 'class_path': 'new_callbacks',
... },
... ]
>>> list_override(configs=configs, key="callbacks", overrides=append_callbacks)
>>> configs = [
... ...
... Namespace(class_path='new_callbacks'),
... ]
>>> append_callbacks_as_dict = [
... {
... 'class_path': 'new_callbacks1',
... },
... ]
>>> list_override(
... configs=configs, key="callbacks", overrides=append_callbacks_as_dict, convert_dict_to_namespace=False
... )
>>> configs = [
... ...
... {'class_path': 'new_callbacks1'},
... ]
"""
if key not in configs or configs[key] is None:
return
Expand All @@ -192,7 +240,40 @@ def list_override(configs: Namespace, key: str, overrides: list) -> None:
if item is not None:
Namespace(item).update(target)
else:
configs[key].append(dict_to_namespace(target))
converted_target = dict_to_namespace(target) if convert_dict_to_namespace else target
configs[key].append(converted_target)


def apply_override(cfg: Namespace, overrides: Namespace) -> None:
"""Overrides the provided overrides in the given configs.
Args:
configs (Namespace): The configuration object containing the key.
overrides (Namespace): The configuration object to override the existing ones.
"""
# replace the config with the overrides for keys in reset list
reset = overrides.pop("reset", [])
if isinstance(reset, str):
reset = [reset]
for key in reset:
if key in overrides:
# callbacks, logger -> update to namespace
# rest -> use dict as is
cfg[key] = (
[dict_to_namespace(o) for o in overrides.pop(key)]
if key in ("callbacks", "logger")
else overrides.pop(key)
)

# This is a feature to handle the callbacks, logger, and data override for user-convinience
list_override(configs=cfg, key="callbacks", overrides=overrides.pop("callbacks", []))
list_override(configs=cfg, key="logger", overrides=overrides.pop("logger", []))
namespace_override(
configs=cfg,
key="data",
overrides=overrides.pop("data", Namespace()),
convert_dict_to_namespace=False,
)


# [FIXME] harimkang: have to see if there's a better way to do it. (For now, Added 2 lines to existing function)
Expand Down Expand Up @@ -235,8 +316,7 @@ def get_defaults_with_overrides(self: ArgumentParser, skip_check: bool = False)
cfg_file = self._load_config_parser_mode(default_config_file.get_content(), key=key)
cfg = self.merge_config(cfg_file, cfg)
overrides = cfg.__dict__.pop("overrides", {})
list_override(configs=cfg, key="callbacks", overrides=overrides.pop("callbacks", []))
list_override(configs=cfg, key="logger", overrides=overrides.pop("logger", []))
apply_override(cfg, overrides)
if overrides is not None:
cfg.update(overrides)
try:
Expand Down
47 changes: 44 additions & 3 deletions src/otx/recipe/_base_/data/detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,24 @@ train_subset:
num_workers: 2
to_tv_image: false
transforms:
- class_path: torchvision.transforms.v2.ToImage
- class_path: otx.core.data.transform_libs.torchvision.MinIoURandomCrop
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale:
- 800
- 992
transform_bbox: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler

Expand All @@ -23,7 +40,19 @@ val_subset:
num_workers: 2
to_tv_image: false
transforms:
- class_path: torchvision.transforms.v2.ToImage
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale:
- 800
- 992
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler

Expand All @@ -34,6 +63,18 @@ test_subset:
num_workers: 2
to_tv_image: false
transforms:
- class_path: torchvision.transforms.v2.ToImage
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
scale:
- 800
- 992
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [0.0, 0.0, 0.0]
std: [255.0, 255.0, 255.0]
sampler:
class_path: torch.utils.data.RandomSampler
72 changes: 63 additions & 9 deletions src/otx/recipe/_base_/data/instance_segmentation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,86 @@ unannotated_items_ratio: 0.0
train_subset:
subset_name: train
transform_lib_type: TORCHVISION
to_tv_image: true
transforms:
- class_path: torchvision.transforms.v2.ToImage
batch_size: 1
num_workers: 2
to_tv_image: true
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: true
transform_mask: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
transform_mask: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
sampler:
class_path: torch.utils.data.RandomSampler

val_subset:
subset_name: val
transform_lib_type: TORCHVISION
to_tv_image: true
transforms:
- class_path: torchvision.transforms.v2.ToImage
batch_size: 1
num_workers: 2
to_tv_image: true
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
sampler:
class_path: torch.utils.data.RandomSampler

test_subset:
subset_name: test
transform_lib_type: TORCHVISION
to_tv_image: true
transforms:
- class_path: torchvision.transforms.v2.ToImage
batch_size: 1
num_workers: 2
to_tv_image: true
transforms:
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
pad_to_square: true
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
sampler:
class_path: torch.utils.data.RandomSampler
60 changes: 57 additions & 3 deletions src/otx/recipe/_base_/data/rotated_detection.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,29 @@ train_subset:
transform_lib_type: TORCHVISION
to_tv_image: false
transforms:
- class_path: torchvision.transforms.v2.ToImage
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
transform_bbox: true
transform_mask: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
transform_mask: true
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
init_args:
prob: 0.5
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
batch_size: 1
num_workers: 2
sampler:
Expand All @@ -22,7 +44,23 @@ val_subset:
transform_lib_type: TORCHVISION
to_tv_image: false
transforms:
- class_path: torchvision.transforms.v2.ToImage
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
batch_size: 1
num_workers: 2
sampler:
Expand All @@ -33,7 +71,23 @@ test_subset:
transform_lib_type: TORCHVISION
to_tv_image: false
transforms:
- class_path: torchvision.transforms.v2.ToImage
- class_path: otx.core.data.transform_libs.torchvision.Resize
init_args:
keep_ratio: true
scale:
- 1024
- 1024
- class_path: otx.core.data.transform_libs.torchvision.Pad
init_args:
size_divisor: 32
is_numpy_to_tvtensor: true
- class_path: torchvision.transforms.v2.ToDtype
init_args:
dtype: ${as_torch_dtype:torch.float32}
- class_path: torchvision.transforms.v2.Normalize
init_args:
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
batch_size: 1
num_workers: 2
sampler:
Expand Down
Loading

0 comments on commit 632f5c6

Please sign in to comment.