From d178c85f2ca09f4bf6710b630b5be621a717a566 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Garriga-Alonso?= Date: Thu, 9 May 2024 22:54:11 -0400 Subject: [PATCH] Fix #47: Can serialize Union with Literal --- databind/src/databind/json/converters.py | 19 +++++++++++++++---- .../databind/json/tests/converters_test.py | 13 +++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/databind/src/databind/json/converters.py b/databind/src/databind/json/converters.py index e04df80..18b8488 100644 --- a/databind/src/databind/json/converters.py +++ b/databind/src/databind/json/converters.py @@ -763,13 +763,19 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) -> def convert(self, ctx: Context) -> t.Any: datatype = ctx.datatype union: t.Optional[Union] + literal_types: list[TypeHint] = [] + if isinstance(datatype, UnionTypeHint): if datatype.has_none_type(): raise NotImplementedError("unable to handle Union type with None in it") - if not all(isinstance(a, ClassTypeHint) for a in datatype): - raise NotImplementedError(f"members of plain Union must be concrete types: {datatype}") - members = {t.cast(ClassTypeHint, a).type.__name__: a for a in datatype} - if len(members) != len(datatype): + + literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)] + non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)] + if not all(isinstance(a, ClassTypeHint) for a in non_literal_types): + raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}") + + members = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types} + if len(members) != len(non_literal_types): raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}") union = Union(members, Union.BEST_MATCH) elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)): @@ -788,6 +794,11 @@ def convert(self, ctx: Context) -> t.Any: return ctx.spawn(ctx.value, member_type, None).convert() except ConversionError as exc: errors.append((exc.origin, exc)) + for literal_type in literal_types: + try: + return ctx.spawn(ctx.value, literal_type, None).convert() + except ConversionError as exc: + errors.append((exc.origin, exc)) raise ConversionError( self, ctx, diff --git a/databind/src/databind/json/tests/converters_test.py b/databind/src/databind/json/tests/converters_test.py index 0de421c..9903605 100644 --- a/databind/src/databind/json/tests/converters_test.py +++ b/databind/src/databind/json/tests/converters_test.py @@ -713,3 +713,16 @@ def of(cls, v: str) -> "MyCls": mapper = make_mapper([JsonConverterSupport()]) assert mapper.serialize(MyCls(), MyCls) == "MyCls" assert mapper.deserialize("MyCls", MyCls) == MyCls() + + +def test_union_literal(): + mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()]) + + IntType = int | t.Literal["hi", "bye"] + StrType = str | t.Literal["hi", "bye"] + + assert mapper.serialize("hi", IntType) == "hi" + assert mapper.serialize(2, IntType) == 2 + + assert mapper.serialize("bye", StrType) == "bye" + assert mapper.serialize("other", StrType) == "other"