From 612d25c6541d426355195e0e3eb613dbc1a6c67d Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Mon, 6 May 2024 10:02:16 -0400 Subject: [PATCH] add tests for polars decorators (#1615) Signed-off-by: cosmicBboy --- pandera/decorators.py | 16 ++--- pandera/typing/polars.py | 7 ++ tests/polars/test_polars_decorators.py | 96 ++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 tests/polars/test_polars_decorators.py diff --git a/pandera/decorators.py b/pandera/decorators.py index 1f7068bbc..fd044694d 100644 --- a/pandera/decorators.py +++ b/pandera/decorators.py @@ -27,9 +27,9 @@ from pandera import errors from pandera.api.base.error_handler import ErrorHandler -from pandera.api.pandas.array import SeriesSchema -from pandera.api.pandas.container import DataFrameSchema -from pandera.api.pandas.model import DataFrameModel +from pandera.api.dataframe.components import ComponentSchema +from pandera.api.dataframe.container import DataFrameSchema +from pandera.api.dataframe.model import DataFrameModel from pandera.inspection_utils import ( is_classmethod_from_meta, is_decorated_classmethod, @@ -37,7 +37,7 @@ from pandera.typing import AnnotationInfo from pandera.validation_depth import validation_type -Schemas = Union[DataFrameSchema, SeriesSchema] +Schemas = Union[DataFrameSchema, ComponentSchema] InputGetter = Union[str, int] OutputGetter = Union[str, int, Callable] F = TypeVar("F", bound=Callable) @@ -84,7 +84,7 @@ def _get_fn_argnames(fn: Callable) -> List[str]: def _handle_schema_error( decorator_name, fn: Callable, - schema: Union[DataFrameSchema, SeriesSchema], + schema: Union[DataFrameSchema, ComponentSchema], data_obj: Any, schema_error: errors.SchemaError, ) -> NoReturn: @@ -110,7 +110,7 @@ def _handle_schema_error( def _parse_schema_error( decorator_name, fn: Callable, - schema: Union[DataFrameSchema, SeriesSchema], + schema: Union[DataFrameSchema, ComponentSchema], data_obj: Any, schema_error: errors.SchemaError, 
reason_code: errors.SchemaErrorReason, @@ -355,7 +355,7 @@ def check_output( # pylint: disable=too-many-boolean-expressions if callable(obj_getter) and ( schema.coerce - or (schema.index is not None and schema.index.coerce) + or (schema.index is not None and schema.index.coerce) # type: ignore[union-attr] or ( isinstance(schema, DataFrameSchema) and any(col.coerce for col in schema.columns.values()) @@ -490,7 +490,7 @@ def _wrapper( out_schemas = out if isinstance(out, list): out_schemas = out - elif isinstance(out, (DataFrameSchema, SeriesSchema)): + elif isinstance(out, (DataFrameSchema, ComponentSchema)): out_schemas = [(None, out)] # type: ignore elif isinstance(out, tuple): out_schemas = [out] diff --git a/pandera/typing/polars.py b/pandera/typing/polars.py index 061a797f1..3493be714 100644 --- a/pandera/typing/polars.py +++ b/pandera/typing/polars.py @@ -35,6 +35,13 @@ class LazyFrame(DataFrameBase, pl.LazyFrame, Generic[T]): *new in 0.19.0* """ + class DataFrame(DataFrameBase, pl.DataFrame, Generic[T]): + """ + Pandera generic for pl.DataFrame, only used for type annotation. 
+ + *new in 0.19.0* + """ + # pylint: disable=too-few-public-methods class Series(SeriesBase, pl.Series, Generic[T]): """ diff --git a/tests/polars/test_polars_decorators.py b/tests/polars/test_polars_decorators.py new file mode 100644 index 000000000..5c5bfc6f7 --- /dev/null +++ b/tests/polars/test_polars_decorators.py @@ -0,0 +1,96 @@ +"""Unit tests for using schemas with polars and function decorators.""" + +import polars as pl +import pytest + +import pandera.polars as pa +import pandera.typing.polars as pa_typing + + +@pytest.fixture +def data() -> pl.DataFrame: + return pl.DataFrame({"a": [1, 2, 3]}) + + +@pytest.fixture +def invalid_data(data) -> pl.DataFrame: + return data.rename({"a": "b"}) + + +def test_polars_dataframe_check_io(data, invalid_data): + # pylint: disable=unused-argument + + schema = pa.DataFrameSchema({"a": pa.Column(int)}) + + @pa.check_input(schema) + def fn_check_input(x): + ... + + @pa.check_output(schema) + def fn_check_output(x): + return x + + @pa.check_io(x=schema, out=schema) + def fn_check_io(x): + return x + + @pa.check_io(x=schema, out=schema) + def fn_check_io_invalid(x): + return x.rename({"a": "b"}) + + # valid data should pass + fn_check_input(data) + fn_check_output(data) + fn_check_io(data) + + # invalid data or invalid function should not pass + with pytest.raises(pa.errors.SchemaError): + fn_check_input(invalid_data) + + with pytest.raises(pa.errors.SchemaError): + fn_check_output(invalid_data) + + with pytest.raises(pa.errors.SchemaError): + fn_check_io_invalid(data) + + +def test_polars_dataframe_check_types(data, invalid_data): + # pylint: disable=unused-argument + + class Model(pa.DataFrameModel): + a: int + + @pa.check_types + def fn_check_input(x: pa_typing.DataFrame[Model]): + ... 
+ + @pa.check_types + def fn_check_output(x) -> pa_typing.DataFrame[Model]: + return x + + @pa.check_types + def fn_check_io( + x: pa_typing.DataFrame[Model], + ) -> pa_typing.DataFrame[Model]: + return x + + @pa.check_types + def fn_check_io_invalid( + x: pa_typing.DataFrame[Model], + ) -> pa_typing.DataFrame[Model]: + return x.rename({"a": "b"}) + + # valid data should pass + fn_check_input(data) + fn_check_output(data) + fn_check_io(data) + + # invalid data or invalid function should not pass + with pytest.raises(pa.errors.SchemaError): + fn_check_input(invalid_data) + + with pytest.raises(pa.errors.SchemaError): + fn_check_output(invalid_data) + + with pytest.raises(pa.errors.SchemaError): + fn_check_io_invalid(data)