From 242eaeab887a2ddb20b076d0e30a53d52178da4d Mon Sep 17 00:00:00 2001 From: gabriel Date: Mon, 2 Sep 2024 11:39:36 +0200 Subject: [PATCH 1/4] Support Enum Signed-off-by: gabriel --- pandera/engines/polars_engine.py | 42 ++++++++++++++++++++++++++++-- tests/polars/test_polars_dtypes.py | 6 +++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/pandera/engines/polars_engine.py b/pandera/engines/polars_engine.py index d1e21b6c1..4cead2cae 100644 --- a/pandera/engines/polars_engine.py +++ b/pandera/engines/polars_engine.py @@ -680,11 +680,49 @@ def __init__( # pylint:disable=super-init-not-called @classmethod def from_parametrized_dtype(cls, polars_dtype: pl.Categorical): - """Convert a :class:`polars.Decimal` to - a Pandera :class:`pandera.engines.polars_engine.Decimal`.""" + """Convert a :class:`polars.Categorical` to + a Pandera :class:`pandera.engines.polars_engine.Categorical`.""" return cls(ordering=polars_dtype.ordering) +@Engine.register_dtype(equivalents=[pl.Enum]) +@immutable(init=True) +class Enum(DataType): + """Polars enum data type.""" + + type = pl.Enum + + categories: pl.Series + + def __init__( # pylint:disable=super-init-not-called + self, + categories: pl.Series | Iterable[str] | None = None, + ) -> None: + object.__setattr__(self, "categories", categories) + object.__setattr__(self, "type", pl.Enum(categories=categories)) + + @classmethod + def from_parametrized_dtype(cls, polars_dtype: pl.Enum): + """Convert a :class:`polars.Enum` to + a Pandera :class:`pandera.engines.polars_engine.Enum`.""" + return cls(categories=polars_dtype.categories) + + def check( + self, + pandera_dtype: dtypes.DataType, + data_container: Optional[PolarsDataContainer] = None, + ) -> Union[bool, Iterable[bool]]: + try: + pandera_dtype = Engine.dtype(pandera_dtype) + except TypeError: + return False + + return ( + self.type == pandera_dtype.type + and (self.type.categories == pandera_dtype.categories).all() + ) + + @Engine.register_dtype( equivalents=["category", dtypes.Category, dtypes.Category()] ) diff --git a/tests/polars/test_polars_dtypes.py b/tests/polars/test_polars_dtypes.py index d61c772fd..34cafe0be 100644 --- a/tests/polars/test_polars_dtypes.py +++ b/tests/polars/test_polars_dtypes.py @@ -409,6 +409,12 @@ def test_polars_struct_nested_type(inner_dtype_cls): pl.List(pl.Object()), pl.LazyFrame({"0": [[1.0, 2.0, 3.0]]}), ], + # Enum + [ + pl.Enum(categories=["yes", "no"]), + pl.Enum(categories=["yes", "no", "?"]), + pl.LazyFrame({"0": ["yes", "yes", "no"]}), + ], # Struct [ pl.Struct({"a": pl.Utf8(), "b": pl.Int64(), "c": pl.Float64()}), From dc4d7f0b96f8a96c2fc1360ec1e8a69f62592d2e Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Sat, 21 Sep 2024 09:04:21 -0400 Subject: [PATCH 2/4] Update polars_engine.py --- pandera/engines/polars_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandera/engines/polars_engine.py b/pandera/engines/polars_engine.py index 4cead2cae..570785c0d 100644 --- a/pandera/engines/polars_engine.py +++ b/pandera/engines/polars_engine.py @@ -719,7 +719,7 @@ def check( return ( self.type == pandera_dtype.type - and (self.type.categories == pandera_dtype.categories).all() + and (self.type.categories == getattr(pandera_dtype, "categories", None)).all() ) From 6e8a657f553b32e8041f31495ab22a6128513914 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Sat, 21 Sep 2024 09:31:39 -0400 Subject: [PATCH 3/4] Update polars_engine.py --- pandera/engines/polars_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandera/engines/polars_engine.py b/pandera/engines/polars_engine.py index 570785c0d..5f2391ded 100644 --- a/pandera/engines/polars_engine.py +++ b/pandera/engines/polars_engine.py @@ -719,7 +719,7 @@ def check( return ( self.type == pandera_dtype.type - and (self.type.categories == getattr(pandera_dtype, "categories", None)).all() + and (self.type.categories == pandera_dtype.categories).all() # type: ignore ) From c3a3715653fa77302d03d8601e6e38fe09ce1016 Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Sat, 21 Sep 2024 11:48:02 -0400 Subject: [PATCH 4/4] Update polars_engine.py --- pandera/engines/polars_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandera/engines/polars_engine.py b/pandera/engines/polars_engine.py index 5f2391ded..9ba452c0e 100644 --- a/pandera/engines/polars_engine.py +++ b/pandera/engines/polars_engine.py @@ -696,7 +696,7 @@ class Enum(DataType): def __init__( # pylint:disable=super-init-not-called self, - categories: pl.Series | Iterable[str] | None = None, + categories: Union[pl.Series, Iterable[str], None] = None, ) -> None: object.__setattr__(self, "categories", categories) object.__setattr__(self, "type", pl.Enum(categories=categories))