Skip to content

Commit

Permalink
Merge pull request #448 from mit-ll-responsible-ai/improve-zen
Browse files Browse the repository at this point in the history
Drop support for hydra-core < 1.2.0. Improve `zen` to return dataclasses and stdlib containers over omegaconf objects
  • Loading branch information
rsokl authored Apr 14, 2023
2 parents 78dd8e2 + 3f73b6c commit cdf6fb2
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 456 deletions.
17 changes: 16 additions & 1 deletion docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,21 @@ The following parts of the documentation underwent significant revisions:
- `The landing page <https://github.com/mit-ll-responsible-ai/hydra-zen>`_ now has a "hydra-zen at at glance" subsection.
- The docs for `~hydra_zen.ZenStore` were revamped.

.. _v0.11.0:

---------------------
0.11.0rc - 2023-04-09
---------------------
This release drops support for hydra-core 1.1 and for omegaconf 2.1; this enables hydra-zen to remove a lot of complex compatibility logic and to improve the behavior
of :func:`~hydra_zen.zen`.


Compatibility-Breaking Changes
------------------------------
- The auto-instantiation behavior of :class:`~hydra_zen.wrapper.Zen` and :func:`~hydra_zen.zen` have been updated so that nested dataclasses (nested within lists, dicts, and other dataclasses) will no longer be returned as omegaconf configs (see :pull:`448`).
- hydra-core 1.2.0 and omegaconf 2.2.1 are now the minimum supported versions.


.. _v0.10.0:

-------------------
Expand Down Expand Up @@ -228,7 +243,7 @@ New Features
------------
- hydra-zen now supports Python 3.11
- Adds the :func:`~hydra_zen.zen` decorator (see :pull:`310`)
- Adds the :func:`~hydra_zen.wrapper.Zen` decorator-class (see :pull:`310`)
- Adds the :class:`~hydra_zen.wrapper.Zen` decorator-class (see :pull:`310`)
- Adds the :class:`~hydra_zen.ZenStore` class (see :pull:`331`)
- Adds `hyda_zen.store`, which is a pre-initialized instance of :class:`~hydra_zen.ZenStore` (see :pull:`331`)
- The option `hydra_convert='object'` is now supported by all of hydra-zen's config-creation functions. So that an instantiated structured config can be converted to an instance of its backing dataclass. This feature was added by `Hydra 1.3.0 <https://github.com/facebookresearch/hydra/issues/1719>`_.
Expand Down
9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ dynamic = ["version"]
description = "Configurable, reproducible, and scalable workflows in Python, via Hydra"
readme = "README.md"
requires-python = ">=3.7"
dependencies = ["hydra-core >= 1.1.0", "typing-extensions >= 4.1.0"]
dependencies = ["hydra-core >= 1.2.0",
"omegaconf >= 2.2.1",
"typing-extensions >= 4.1.0",
]
license = { text = "MIT" }
keywords = [
"machine learning",
Expand Down Expand Up @@ -147,8 +150,8 @@ commands = pytest tests/ {posargs: -n auto --maxprocesses=4}
[testenv:min-deps]
description = Runs test suite against minimum supported versions of dependencies.
deps = hydra-core==1.1.0
omegaconf==2.1.1
deps = hydra-core==1.2.0
omegaconf==2.2.1
typing-extensions==4.1.0
{[testenv]deps}
basepython = python3.7
Expand Down
82 changes: 32 additions & 50 deletions src/hydra_zen/_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from functools import partial
from pathlib import Path, PosixPath, WindowsPath
from typing import NamedTuple, Set
from typing import FrozenSet, NamedTuple

import hydra
import omegaconf
Expand Down Expand Up @@ -36,58 +36,40 @@ def _get_version(ver_str: str) -> Version:
OMEGACONF_VERSION: Final = _get_version(omegaconf.__version__)
HYDRA_VERSION: Final = _get_version(hydra.__version__)

SUPPORTS_VERSION_BASE = HYDRA_VERSION >= (1, 2, 0)

# OmegaConf issue 830 describes a bug associated with structured configs
# composed via inheritance, where the child's attribute is a default-factory
# and the parent's corresponding attribute is not.
# We provide downstream workarounds until an upstream fix is released.
#
# Uncomment dynamic setting once OmegaConf merges fix:
# https://github.com/omry/omegaconf/pull/832
PATCH_OMEGACONF_830: Final = OMEGACONF_VERSION < Version(2, 2, 1)

# Hydra's instantiate API now supports partial-instantiation, indicated
# by a `_partial_ = True` attribute.
# https://github.com/facebookresearch/hydra/pull/1905
HYDRA_SUPPORTS_PARTIAL: Final = Version(1, 1, 1) < HYDRA_VERSION

HYDRA_SUPPORTS_NESTED_CONTAINER_TYPES: Final = OMEGACONF_VERSION >= Version(2, 2, 0)
HYDRA_SUPPORTS_BYTES: Final = OMEGACONF_VERSION >= Version(2, 2, 0)
HYDRA_SUPPORTS_Path: Final = OMEGACONF_VERSION >= Version(2, 2, 1)

# Indicates primitive types permitted in type-hints of structured configs
HYDRA_SUPPORTED_PRIMITIVE_TYPES: Final = {int, float, bool, str, Enum}
HYDRA_SUPPORTED_PRIMITIVE_TYPES: Final = frozenset(
{int, float, bool, str, Enum, bytes, Path}
)
# Indicates types of primitive values permitted in configs
HYDRA_SUPPORTED_PRIMITIVES = {int, float, bool, str, list, tuple, dict, NoneType}
ZEN_SUPPORTED_PRIMITIVES: Set[type] = {
set,
frozenset,
complex,
partial,
bytearray,
deque,
Counter,
range,
}

HYDRA_SUPPORTS_LIST_INSTANTIATION = HYDRA_VERSION >= Version(1, 1, 2)


if HYDRA_SUPPORTS_BYTES: # pragma: no cover
HYDRA_SUPPORTED_PRIMITIVES.add(bytes)
HYDRA_SUPPORTED_PRIMITIVE_TYPES.add(bytes)
else: # pragma: no cover
ZEN_SUPPORTED_PRIMITIVES.add(bytes)

_path_types = {Path, PosixPath, WindowsPath}

if HYDRA_SUPPORTS_Path: # pragma: no cover
HYDRA_SUPPORTED_PRIMITIVES.update(_path_types)
HYDRA_SUPPORTED_PRIMITIVE_TYPES.add(Path)
else: # pragma: no cover
ZEN_SUPPORTED_PRIMITIVES.update(_path_types)
HYDRA_SUPPORTED_PRIMITIVES = frozenset(
{
int,
float,
bool,
str,
list,
tuple,
dict,
NoneType,
bytes,
Path,
PosixPath,
WindowsPath,
}
)
ZEN_SUPPORTED_PRIMITIVES: FrozenSet[type] = frozenset(
{
set,
frozenset,
complex,
partial,
bytearray,
deque,
Counter,
range,
}
)

del _path_types

HYDRA_SUPPORTS_OBJECT_CONVERT = HYDRA_VERSION >= Version(1, 3, 0)
7 changes: 1 addition & 6 deletions src/hydra_zen/_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing_extensions import Literal, TypeAlias

from hydra_zen._compatibility import SUPPORTS_VERSION_BASE
from hydra_zen._hydra_overloads import instantiate
from hydra_zen.typing._implementations import DataClass_, InstOrType

Expand Down Expand Up @@ -430,11 +429,7 @@ def launch(
with initialize(
config_path=None,
job_name=job_name,
**(
{}
if (not SUPPORTS_VERSION_BASE or version_base is _NotSet)
else {"version_base": version_base}
),
**({} if version_base is _NotSet else {"version_base": version_base}),
):
# taken from hydra.compose with support for MULTIRUN
gh = GlobalHydra.instance()
Expand Down
4 changes: 1 addition & 3 deletions src/hydra_zen/structured_configs/_globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from typing_extensions import Final

from hydra_zen._compatibility import HYDRA_SUPPORTS_PARTIAL
from hydra_zen.funcs import get_obj, zen_processing
from hydra_zen.structured_configs import _utils

Expand All @@ -22,10 +21,9 @@
RECURSIVE_FIELD_NAME,
CONVERT_FIELD_NAME,
POS_ARG_FIELD_NAME,
PARTIAL_FIELD_NAME,
]

if HYDRA_SUPPORTS_PARTIAL: # pragma: no cover
_names.append(PARTIAL_FIELD_NAME)

HYDRA_FIELD_NAMES: FrozenSet[str] = frozenset(_names)

Expand Down
63 changes: 7 additions & 56 deletions src/hydra_zen/structured_configs/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@

from hydra_zen._compatibility import (
HYDRA_SUPPORTED_PRIMITIVES,
HYDRA_SUPPORTS_PARTIAL,
PATCH_OMEGACONF_830,
ZEN_SUPPORTED_PRIMITIVES,
)
from hydra_zen.errors import (
Expand Down Expand Up @@ -441,22 +439,6 @@ def wrapper(decorated_obj: Any) -> Any:
)
decorated_obj = dataclass(**dc_options)(decorated_obj) # type: ignore

if PATCH_OMEGACONF_830 and 2 < len(decorated_obj.__mro__): # pragma: no cover
parents = decorated_obj.__mro__[1:-1]
# this class inherits from a parent
for field_ in fields(decorated_obj):
if field_.default_factory is not MISSING and any(
hasattr(p, field_.name) for p in parents
):
raise HydraZenValidationError(
"This config will not instantiate properly.\nThis is due to a "
"known bug in omegaconf: The config specifies a "
f"default-factory for field {field_.name}, and inherits from a "
"parent that specifies the same field with a non-factory value "
"-- the parent's value will take precedence.\nTo circumvent "
f"this upgrade to omegaconf 2.2.1 or higher."
)

if populate_full_signature:
# we need to ensure that the fields specified via the class definition
# take precedence over the fields that will be auto-populated by builds
Expand Down Expand Up @@ -797,7 +779,6 @@ def sanitized_field(
*,
error_prefix: str = "",
field_name: str = "",
_mutable_default_permitted: bool = True,
convert_dataclass: bool,
) -> Field[Any]:
value = sanitized_default_value(
Expand All @@ -815,16 +796,10 @@ def sanitized_field(
) or (
is_dataclass(value) and not isinstance(value, type) and value.__hash__ is None
):
if _mutable_default_permitted:
return cast(
Field[Any],
mutable_value(value, zen_convert={"dataclass": convert_dataclass}),
)
else: # pragma: no cover
value = builds(
type(value), value, zen_convert={"dataclass": convert_dataclass}
)

return cast(
Field[Any],
mutable_value(value, zen_convert={"dataclass": convert_dataclass}),
)
return _utils.field(default=value, init=init)


Expand Down Expand Up @@ -1594,7 +1569,7 @@ def builds(target, populate_full_signature=False, **kw):

for base in builds_bases:
_set_this_iteration = False
if HYDRA_SUPPORTS_PARTIAL and base_hydra_partial is None:
if base_hydra_partial is None:
base_hydra_partial = safe_getattr(base, PARTIAL_FIELD_NAME, None)
if parent_partial is None:
parent_partial = base_hydra_partial
Expand All @@ -1621,7 +1596,6 @@ def builds(target, populate_full_signature=False, **kw):
bool(zen_meta)
or bool(validated_wrappers)
or any(uses_zen_processing(b) for b in builds_bases)
or (bool(requires_partial_field) and not HYDRA_SUPPORTS_PARTIAL)
)

if base_zen_partial:
Expand Down Expand Up @@ -1665,7 +1639,7 @@ def builds(target, populate_full_signature=False, **kw):
_utils.field(default=bool(zen_partial), init=False),
),
)
if HYDRA_SUPPORTS_PARTIAL and base_hydra_partial:
if base_hydra_partial:
# Must explicitly set _partial_=False to prevent inheritance
target_field.append(
(
Expand Down Expand Up @@ -2148,27 +2122,6 @@ def builds(target, populate_full_signature=False, **kw):
value,
error_prefix=BUILDS_ERROR_PREFIX,
field_name=item[0],
_mutable_default_permitted=_utils.mutable_default_permitted(
builds_bases, name
),
convert_dataclass=zen_convert_settings["dataclass"],
)
elif (
PATCH_OMEGACONF_830
and builds_bases
and value.default_factory is not MISSING
): # pragma: no cover
# Addresses omegaconf #830 https://github.com/omry/omegaconf/issues/830
#
# Value was passed as a field-with-default-factory, we'll
# access the default from the factory and will reconstruct the field
_field = sanitized_field(
value.default_factory(),
error_prefix=BUILDS_ERROR_PREFIX,
field_name=item[0],
_mutable_default_permitted=_utils.mutable_default_permitted(
builds_bases, name
),
convert_dataclass=zen_convert_settings["dataclass"],
)
else:
Expand Down Expand Up @@ -2213,9 +2166,7 @@ def builds(target, populate_full_signature=False, **kw):

# _partial_=True should never be relied on when zen-processing is being used.
assert not (
HYDRA_SUPPORTS_PARTIAL
and requires_zen_processing
and safe_getattr(out, PARTIAL_FIELD_NAME, False)
requires_zen_processing and safe_getattr(out, PARTIAL_FIELD_NAME, False)
)

return cast(Union[Type[Builds[Importable]], Type[BuildsWithSig[Type[R], P]]], out)
Expand Down
Loading

0 comments on commit cdf6fb2

Please sign in to comment.