Skip to content

Commit

Permalink
Fix NiklasRosenstein#47: Can serialize Union with Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
rhaps0dy committed May 10, 2024
1 parent 572c79c commit d178c85
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
19 changes: 15 additions & 4 deletions databind/src/databind/json/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,13 +763,19 @@ def _check_style_compatibility(self, ctx: Context, style: str, value: t.Any) ->
def convert(self, ctx: Context) -> t.Any:
datatype = ctx.datatype
union: t.Optional[Union]
literal_types: list[TypeHint] = []

if isinstance(datatype, UnionTypeHint):
if datatype.has_none_type():
raise NotImplementedError("unable to handle Union type with None in it")
if not all(isinstance(a, ClassTypeHint) for a in datatype):
raise NotImplementedError(f"members of plain Union must be concrete types: {datatype}")
members = {t.cast(ClassTypeHint, a).type.__name__: a for a in datatype}
if len(members) != len(datatype):

literal_types = [a for a in datatype if isinstance(a, LiteralTypeHint)]
non_literal_types = [a for a in datatype if not isinstance(a, LiteralTypeHint)]
if not all(isinstance(a, ClassTypeHint) for a in non_literal_types):
raise NotImplementedError(f"members of plain Union must be concrete or Literal types: {datatype}")

members = {t.cast(ClassTypeHint, a).type.__name__: a for a in non_literal_types}
if len(members) != len(non_literal_types):
raise NotImplementedError(f"members of plain Union cannot have overlapping type names: {datatype}")
union = Union(members, Union.BEST_MATCH)
elif isinstance(datatype, (AnnotatedTypeHint, ClassTypeHint)):
Expand All @@ -788,6 +794,11 @@ def convert(self, ctx: Context) -> t.Any:
return ctx.spawn(ctx.value, member_type, None).convert()
except ConversionError as exc:
errors.append((exc.origin, exc))
for literal_type in literal_types:
try:
return ctx.spawn(ctx.value, literal_type, None).convert()
except ConversionError as exc:
errors.append((exc.origin, exc))
raise ConversionError(
self,
ctx,
Expand Down
13 changes: 13 additions & 0 deletions databind/src/databind/json/tests/converters_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,3 +713,16 @@ def of(cls, v: str) -> "MyCls":
mapper = make_mapper([JsonConverterSupport()])
assert mapper.serialize(MyCls(), MyCls) == "MyCls"
assert mapper.deserialize("MyCls", MyCls) == MyCls()


def test_union_literal():
mapper = make_mapper([UnionConverter(), PlainDatatypeConverter()])

IntType = int | t.Literal["hi", "bye"]
StrType = str | t.Literal["hi", "bye"]

assert mapper.serialize("hi", IntType) == "hi"
assert mapper.serialize(2, IntType) == 2

assert mapper.serialize("bye", StrType) == "bye"
assert mapper.serialize("other", StrType) == "other"

0 comments on commit d178c85

Please sign in to comment.