Skip to content

Commit

Permalink
add tests for polars decorators (#1615)
Browse files Browse the repository at this point in the history
Signed-off-by: cosmicBboy <[email protected]>
  • Loading branch information
cosmicBboy authored May 6, 2024
1 parent b11cc4d commit 612d25c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 8 deletions.
16 changes: 8 additions & 8 deletions pandera/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@

from pandera import errors
from pandera.api.base.error_handler import ErrorHandler
from pandera.api.pandas.array import SeriesSchema
from pandera.api.pandas.container import DataFrameSchema
from pandera.api.pandas.model import DataFrameModel
from pandera.api.dataframe.components import ComponentSchema
from pandera.api.dataframe.container import DataFrameSchema
from pandera.api.dataframe.model import DataFrameModel
from pandera.inspection_utils import (
is_classmethod_from_meta,
is_decorated_classmethod,
)
from pandera.typing import AnnotationInfo
from pandera.validation_depth import validation_type

Schemas = Union[DataFrameSchema, SeriesSchema]
Schemas = Union[DataFrameSchema, ComponentSchema]
InputGetter = Union[str, int]
OutputGetter = Union[str, int, Callable]
F = TypeVar("F", bound=Callable)
Expand Down Expand Up @@ -84,7 +84,7 @@ def _get_fn_argnames(fn: Callable) -> List[str]:
def _handle_schema_error(
decorator_name,
fn: Callable,
schema: Union[DataFrameSchema, SeriesSchema],
schema: Union[DataFrameSchema, ComponentSchema],
data_obj: Any,
schema_error: errors.SchemaError,
) -> NoReturn:
Expand All @@ -110,7 +110,7 @@ def _handle_schema_error(
def _parse_schema_error(
decorator_name,
fn: Callable,
schema: Union[DataFrameSchema, SeriesSchema],
schema: Union[DataFrameSchema, ComponentSchema],
data_obj: Any,
schema_error: errors.SchemaError,
reason_code: errors.SchemaErrorReason,
Expand Down Expand Up @@ -355,7 +355,7 @@ def check_output(
# pylint: disable=too-many-boolean-expressions
if callable(obj_getter) and (
schema.coerce
or (schema.index is not None and schema.index.coerce)
or (schema.index is not None and schema.index.coerce) # type: ignore[union-attr]
or (
isinstance(schema, DataFrameSchema)
and any(col.coerce for col in schema.columns.values())
Expand Down Expand Up @@ -490,7 +490,7 @@ def _wrapper(
out_schemas = out
if isinstance(out, list):
out_schemas = out
elif isinstance(out, (DataFrameSchema, SeriesSchema)):
elif isinstance(out, (DataFrameSchema, ComponentSchema)):
out_schemas = [(None, out)] # type: ignore
elif isinstance(out, tuple):
out_schemas = [out]
Expand Down
7 changes: 7 additions & 0 deletions pandera/typing/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ class LazyFrame(DataFrameBase, pl.LazyFrame, Generic[T]):
*new in 0.19.0*
"""

class DataFrame(DataFrameBase, pl.DataFrame, Generic[T]):
    """
    Pandera generic for pl.DataFrame, only used for type annotation.

    Subscript it with a schema model (``DataFrame[Model]``) when annotating
    function parameters or return values.

    *new in 0.19.0*
    """

# pylint: disable=too-few-public-methods
class Series(SeriesBase, pl.Series, Generic[T]):
"""
Expand Down
96 changes: 96 additions & 0 deletions tests/polars/test_polars_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Unit tests for using schemas with polars and function decorators."""

import polars as pl
import pytest

import pandera.polars as pa
import pandera.typing.polars as pa_typing


@pytest.fixture
def data() -> pl.DataFrame:
    """Provide a polars frame holding one column, "a", with values 1, 2, 3."""
    column_values = [1, 2, 3]
    return pl.DataFrame({"a": column_values})


@pytest.fixture
def invalid_data(data) -> pl.DataFrame:
    """Provide the ``data`` frame with its "a" column renamed to "b"."""
    column_mapping = {"a": "b"}
    renamed = data.rename(column_mapping)
    return renamed


def test_polars_dataframe_check_io(data, invalid_data):
    """Run check_input, check_output and check_io against a polars schema.

    Each decorator is called with a frame that matches the schema (must not
    raise) and with a mismatching frame or a function that renames the
    column (must raise ``SchemaError``).
    """
    # pylint: disable=unused-argument

    schema = pa.DataFrameSchema({"a": pa.Column(int)})

    @pa.check_input(schema)
    def fn_check_input(x):
        return None

    @pa.check_output(schema)
    def fn_check_output(x):
        return x

    @pa.check_io(x=schema, out=schema)
    def fn_check_io(x):
        return x

    @pa.check_io(x=schema, out=schema)
    def fn_check_io_invalid(x):
        renamed = x.rename({"a": "b"})
        return renamed

    # A frame matching the schema must pass through every decorator.
    for accepting_fn in (fn_check_input, fn_check_output, fn_check_io):
        accepting_fn(data)

    # A mismatching frame, or a function whose output no longer matches the
    # schema, must raise SchemaError.
    failing_calls = (
        (fn_check_input, invalid_data),
        (fn_check_output, invalid_data),
        (fn_check_io_invalid, data),
    )
    for rejecting_fn, frame in failing_calls:
        with pytest.raises(pa.errors.SchemaError):
            rejecting_fn(frame)


def test_polars_dataframe_check_types(data, invalid_data):
    """Run check_types with ``DataFrame[Model]`` annotations on polars frames.

    Annotated inputs and outputs are called with a frame that matches the
    model (must not raise) and with a mismatching frame or a function that
    renames the column (must raise ``SchemaError``).
    """
    # pylint: disable=unused-argument

    class Model(pa.DataFrameModel):
        a: int

    @pa.check_types
    def fn_check_input(x: pa_typing.DataFrame[Model]):
        return None

    @pa.check_types
    def fn_check_output(x) -> pa_typing.DataFrame[Model]:
        return x

    @pa.check_types
    def fn_check_io(
        x: pa_typing.DataFrame[Model],
    ) -> pa_typing.DataFrame[Model]:
        return x

    @pa.check_types
    def fn_check_io_invalid(
        x: pa_typing.DataFrame[Model],
    ) -> pa_typing.DataFrame[Model]:
        renamed = x.rename({"a": "b"})
        return renamed

    # A frame matching the model must pass through every decorated function.
    for accepting_fn in (fn_check_input, fn_check_output, fn_check_io):
        accepting_fn(data)

    # A mismatching frame, or a function whose output no longer matches the
    # model, must raise SchemaError.
    failing_calls = (
        (fn_check_input, invalid_data),
        (fn_check_output, invalid_data),
        (fn_check_io_invalid, data),
    )
    for rejecting_fn, frame in failing_calls:
        with pytest.raises(pa.errors.SchemaError):
            rejecting_fn(frame)

0 comments on commit 612d25c

Please sign in to comment.