From 09b1f46d9aa50c0fb6241018a4bef482d5c6db11 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Thu, 4 Aug 2022 19:34:16 -0500 Subject: [PATCH] revert allow_objects behavior for unsupported annotations Revert a regression in how unsupported type annotations for structured config fields are handled when the allow_objects flag is True. See PR https://github.com/omry/omegaconf/pull/993 for details. --- omegaconf/omegaconf.py | 19 +++++- .../structured_conf/test_structured_config.py | 64 +++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 4a20e9788..9db564f9f 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -1050,7 +1050,24 @@ def _node_wrap( node = PathNode(value=value, key=key, parent=parent, is_optional=is_optional) else: if parent is not None and parent._get_flag("allow_objects") is True: - node = AnyNode(value=value, key=key, parent=parent) + if type(value) in (list, tuple): + node = ListConfig( + content=value, + key=key, + parent=parent, + ref_type=ref_type, + is_optional=is_optional, + ) + elif is_primitive_dict(value): + node = DictConfig( + content=value, + key=key, + parent=parent, + ref_type=ref_type, + is_optional=is_optional, + ) + else: + node = AnyNode(value=value, key=key, parent=parent) else: raise ValidationError(f"Unexpected type annotation: {type_str(ref_type)}") return node diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 3ff5bd867..7d5b17406 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -5,6 +5,7 @@ import sys from importlib import import_module from pathlib import Path +from types import LambdaType from typing import Any, Callable, Dict, List, Optional, Tuple, Union from _pytest.python_api import RaisesContext @@ -992,6 +993,69 @@ def test_has_bad_annotation2(self, module: Any) -> None: ): OmegaConf.structured(module.HasBadAnnotation2) + @mark.parametrize( + "input_, expected, expected_type, expected_ref_type, expected_object_type", + [ + param( + lambda module: module.HasBadAnnotation1, + {"data": "???"}, + AnyNode, + Any, + None, + ), + param( + lambda module: module.HasBadAnnotation1(123), + {"data": 123}, + AnyNode, + Any, + None, + ), + param( + lambda module: module.HasBadAnnotation1([1, 2, 3]), + {"data": [1, 2, 3]}, + ListConfig, + object, + list, + ), + param( + lambda module: module.HasBadAnnotation1({1: 2}), + {"data": {1: 2}}, + DictConfig, + object, + dict, + ), + param( + lambda module: module.HasBadAnnotation1(module.UserWithDefaultName), + {"data": {"name": "bob", "age": "???"}}, + DictConfig, + object, + lambda module: module.UserWithDefaultName, + ), + ], + ) + def test_bad_annotation_allow_objects( + self, + module: Any, + input_: Any, + expected: Any, + expected_type: Any, + expected_ref_type: Any, + expected_object_type: Any, + ) -> None: + """ + Test how unsupported annotation types are handled when `allow_objects` is True + """ + input_ = input_(module) + if isinstance(expected_object_type, LambdaType): + expected_object_type = expected_object_type(module) + + cfg = OmegaConf.structured(input_, flags={"allow_objects": True}) + + assert cfg == expected + assert isinstance(cfg._get_node("data"), expected_type) + assert cfg._get_node("data")._metadata.ref_type is expected_ref_type + assert cfg._get_node("data")._metadata.object_type is expected_object_type + def validate_frozen_impl(conf: DictConfig) -> None: with raises(ReadonlyConfigError):