diff --git a/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt b/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt
index 7dc9a5cf8..2aef3ab8b 100644
--- a/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt
+++ b/ci/requirements-py3.10-pandas1.5.3-pydantic1.10.11.txt
@@ -80,7 +80,7 @@ mdurl==0.1.2
 modin==0.22.3
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt b/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt
index 3889e7e98..8a4ea4f2f 100644
--- a/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt
+++ b/ci/requirements-py3.10-pandas1.5.3-pydantic2.3.0.txt
@@ -81,7 +81,7 @@ mdurl==0.1.2
 modin==0.22.3
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt b/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt
index 5084b71c4..52ca0e11b 100644
--- a/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt
+++ b/ci/requirements-py3.10-pandas2.2.2-pydantic1.10.11.txt
@@ -81,7 +81,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt b/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt
index b98e8a73b..1b9be9414 100644
--- a/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt
+++ b/ci/requirements-py3.10-pandas2.2.2-pydantic2.3.0.txt
@@ -82,7 +82,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt b/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt
index eea6e24cd..1d6d27a54 100644
--- a/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt
+++ b/ci/requirements-py3.11-pandas1.5.3-pydantic1.10.11.txt
@@ -79,7 +79,7 @@ mdurl==0.1.2
 modin==0.22.3
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt b/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt
index 7acd63154..9751608e9 100644
--- a/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt
+++ b/ci/requirements-py3.11-pandas1.5.3-pydantic2.3.0.txt
@@ -80,7 +80,7 @@ mdurl==0.1.2
 modin==0.22.3
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt b/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt
index 6e75fbaaf..4f5069fff 100644
--- a/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt
+++ b/ci/requirements-py3.11-pandas2.2.2-pydantic1.10.11.txt
@@ -80,7 +80,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt b/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt
index 2ec41e6d6..22f1cc434 100644
--- a/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt
+++ b/ci/requirements-py3.11-pandas2.2.2-pydantic2.3.0.txt
@@ -81,7 +81,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt b/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt
index 14605ecde..d6e94b179 100644
--- a/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt
+++ b/ci/requirements-py3.9-pandas1.5.3-pydantic1.10.11.txt
@@ -80,7 +80,7 @@ mdurl==0.1.2
 modin==0.22.3
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt b/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt
index 94a84daad..3e64e5d34 100644
--- a/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt
+++ b/ci/requirements-py3.9-pandas1.5.3-pydantic2.3.0.txt
@@ -81,7 +81,7 @@ mdurl==0.1.2
 modin==0.22.3
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt b/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt
index ab8ce642c..ffaf43d9b 100644
--- a/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt
+++ b/ci/requirements-py3.9-pandas2.2.2-pydantic1.10.11.txt
@@ -81,7 +81,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt b/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt
index bd84392a7..4167b5015 100644
--- a/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt
+++ b/ci/requirements-py3.9-pandas2.2.2-pydantic2.3.0.txt
@@ -82,7 +82,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/dev/requirements-3.10.txt b/dev/requirements-3.10.txt
index 817543d9d..fd147196a 100644
--- a/dev/requirements-3.10.txt
+++ b/dev/requirements-3.10.txt
@@ -82,7 +82,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/dev/requirements-3.11.txt b/dev/requirements-3.11.txt
index 1c9bbb43f..47a55b9ab 100644
--- a/dev/requirements-3.11.txt
+++ b/dev/requirements-3.11.txt
@@ -81,7 +81,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/dev/requirements-3.9.txt b/dev/requirements-3.9.txt
index 08ace7e87..c3be160e5 100644
--- a/dev/requirements-3.9.txt
+++ b/dev/requirements-3.9.txt
@@ -82,7 +82,7 @@ mdurl==0.1.2
 modin==0.31.0
 more-itertools==10.4.0
 msgpack==1.0.8
-multimethod==1.10
+multimethod==1.12
 mypy==1.10.0
 mypy-extensions==1.0.0
 myst-nb==1.1.1
diff --git a/environment.yml b/environment.yml
index 611dd9d00..1abd2e3b5 100644
--- a/environment.yml
+++ b/environment.yml
@@ -16,10 +16,10 @@ dependencies:
   - pyyaml >=5.1
   - typing_inspect >= 0.6.0
   - typing_extensions >= 3.7.4.3
-  - frictionless <= 4.40.8 # v5.* introduces breaking changes
+  - frictionless <= 4.40.8 # v5.* introduces breaking changes
   - pyarrow
   - pydantic
-  - multimethod <= 1.10.0
+  - multimethod

   # mypy extra
   - pandas-stubs
diff --git a/pandera/api/pandas/types.py b/pandera/api/pandas/types.py
index 2006359a9..58fb7a902 100644
--- a/pandera/api/pandas/types.py
+++ b/pandera/api/pandas/types.py
@@ -32,35 +32,35 @@ def supported_types() -> SupportedTypes:
     """Get the types supported by pandera schemas."""
     # pylint: disable=import-outside-toplevel
-    table_types = [pd.DataFrame]
-    field_types = [pd.Series]
-    index_types = [pd.Index]
-    multiindex_types = [pd.MultiIndex]
+    table_types: Tuple[type, ...] = (pd.DataFrame,)
+    field_types: Tuple[type, ...] = (pd.Series,)
+    index_types: Tuple[type, ...] = (pd.Index,)
+    multiindex_types: Tuple[type, ...] = (pd.MultiIndex,)

     try:
         import pyspark.pandas as ps

-        table_types.append(ps.DataFrame)
-        field_types.append(ps.Series)
-        index_types.append(ps.Index)
-        multiindex_types.append(ps.MultiIndex)
+        table_types += (ps.DataFrame,)
+        field_types += (ps.Series,)
+        index_types += (ps.Index,)
+        multiindex_types += (ps.MultiIndex,)
     except ImportError:
         pass

     try:  # pragma: no cover
         import modin.pandas as mpd

-        table_types.append(mpd.DataFrame)
-        field_types.append(mpd.Series)
-        index_types.append(mpd.Index)
-        multiindex_types.append(mpd.MultiIndex)
+        table_types += (mpd.DataFrame,)
+        field_types += (mpd.Series,)
+        index_types += (mpd.Index,)
+        multiindex_types += (mpd.MultiIndex,)
     except ImportError:
         pass

     try:
         import dask.dataframe as dd

-        table_types.append(dd.DataFrame)
-        field_types.append(dd.Series)
-        index_types.append(dd.Index)
+        table_types += (dd.DataFrame,)
+        field_types += (dd.Series,)
+        index_types += (dd.Index,)
     except ImportError:
         pass

@@ -72,6 +72,36 @@ def supported_types() -> SupportedTypes:
     )

+
+def supported_type_unions(attribute: str):
+    """Get the type unions for a given attribute."""
+    if attribute == "table_types":
+        return Union[tuple(supported_types().table_types)]
+    if attribute == "field_types":
+        return Union[tuple(supported_types().field_types)]
+    if attribute == "index_types":
+        return Union[tuple(supported_types().index_types)]
+    if attribute == "multiindex_types":
+        return Union[tuple(supported_types().multiindex_types)]
+    if attribute == "table_or_field_types":
+        return Union[
+            tuple(
+                (
+                    *supported_types().table_types,
+                    *supported_types().field_types,
+                )
+            )
+        ]
+    raise ValueError(f"invalid attribute {attribute}")
+
+
+Table = supported_type_unions("table_types")
+Field = supported_type_unions("field_types")
+Index = supported_type_unions("index_types")
+Multiindex = supported_type_unions("multiindex_types")
+TableOrField = supported_type_unions("table_or_field_types")
+Bool = Union[bool, np.bool_]
+
+
 def is_table(obj):
     """Verifies whether an object is table-like.
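Note (illustrative sketch, not part of the diff): the new supported_type_unions helper leans on the fact that typing.Union accepts a tuple subscript, so a union can be assembled at runtime from whichever optional backends (pyspark.pandas, modin, dask) imported successfully, and that union can then serve as a dispatch annotation. A minimal standalone illustration, assuming only plain pandas is installed and using made-up names such as TableOrFieldLike:

from typing import Union

import pandas as pd

# pandera builds these tuples dynamically from the optional backends that
# imported successfully; here only plain pandas is assumed to be available.
table_types = (pd.DataFrame,)
field_types = (pd.Series,)

# Union[tuple(...)] turns a runtime tuple of classes into an ordinary Union,
# which is roughly what supported_type_unions("table_or_field_types") returns.
TableOrFieldLike = Union[tuple((*table_types, *field_types))]

assert TableOrFieldLike == Union[pd.DataFrame, pd.Series]
print(TableOrFieldLike)  # typing.Union[pandas...DataFrame, pandas...Series]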
diff --git a/pandera/api/polars/types.py b/pandera/api/polars/types.py
index f038bcf73..a23464f2c 100644
--- a/pandera/api/polars/types.py
+++ b/pandera/api/polars/types.py
@@ -27,3 +27,11 @@ class CheckResult(NamedTuple):
     type,
     pl.datatypes.classes.DataTypeClass,
 ]
+
+
+def is_bool(x):
+    """Verifies whether an object is a boolean type."""
+    return isinstance(x, (bool, pl.Boolean))
+
+
+Bool = Union[bool, pl.Boolean]
diff --git a/pandera/backends/pandas/checks.py b/pandera/backends/pandas/checks.py
index 3c9cc3d61..fc721c82d 100644
--- a/pandera/backends/pandas/checks.py
+++ b/pandera/backends/pandas/checks.py
@@ -4,15 +4,14 @@
 from typing import Dict, List, Optional, Union, cast

 import pandas as pd
-from multimethod import DispatchError, overload
-
+from multimethod import DispatchError, multidispatch
 from pandera.api.base.checks import CheckResult, GroupbyObject
 from pandera.api.checks import Check
 from pandera.api.pandas.types import (
-    is_bool,
-    is_field,
-    is_table,
-    is_table_or_field,
+    Bool,
+    Field,
+    Table,
+    TableOrField,
 )
 from pandera.backends.base import BaseCheckBackend
@@ -78,18 +77,18 @@ def _format_groupby_input(
         return output  # type: ignore[return-value]

-    @overload
+    @multidispatch
     def preprocess(self, check_obj, key) -> pd.Series:
         """Preprocesses a check object before applying the check function."""
         # This handles the case of Series validation, which has no other context except
         # for the index to groupby on. Right now grouping by the index is not allowed.
         return check_obj

-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    @preprocess.register
+    def _(
         self,
-        check_obj: is_field,  # type: ignore [valid-type]
-        key,
+        check_obj: Field,  # type: ignore [valid-type]
+        _,
     ) -> Union[pd.Series, Dict[str, pd.Series]]:
         if self.check.groupby is None:
             return check_obj
@@ -100,10 +99,10 @@ def preprocess(
             ),
         )

-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    @preprocess.register
+    def _(
         self,
-        check_obj: is_table,  # type: ignore [valid-type]
+        check_obj: Table,  # type: ignore [valid-type]
         key,
     ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
         if self.check.groupby is None:
@@ -115,11 +114,11 @@ def preprocess(
             ),
         )

-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    @preprocess.register
+    def _(
         self,
-        check_obj: is_table,  # type: ignore [valid-type]
-        key: None,
+        check_obj: Table,  # type: ignore [valid-type]
+        _: None,
     ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
         if self.check.groupby is None:
             return check_obj
@@ -130,39 +129,39 @@ def preprocess(
             ),
         )

-    @overload
+    @multidispatch
     def apply(self, check_obj):
         """Apply the check function to a check object."""
         raise NotImplementedError

-    @overload  # type: ignore [no-redef]
-    def apply(self, check_obj: dict):
+    @apply.register
+    def _(self, check_obj: dict):
         return self.check_fn(check_obj)

-    @overload  # type: ignore [no-redef]
-    def apply(self, check_obj: is_field):  # type: ignore [valid-type]
+    @apply.register
+    def _(self, check_obj: Field):  # type: ignore [valid-type]
         if self.check.element_wise:
             return check_obj.map(self.check_fn)
         return self.check_fn(check_obj)

-    @overload  # type: ignore [no-redef]
-    def apply(self, check_obj: is_table):  # type: ignore [valid-type]
+    @apply.register
+    def _(self, check_obj: Table):  # type: ignore [valid-type]
         if self.check.element_wise:
             return check_obj.apply(self.check_fn, axis=1)
         return self.check_fn(check_obj)

-    @overload
+    @multidispatch
     def postprocess(self, check_obj, check_output):
         """Postprocesses the result of applying the check function."""
         raise TypeError(
             f"output type of check_fn not recognized: {type(check_output)}"
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
         check_obj,
-        check_output: is_bool,  # type: ignore [valid-type]
+        check_output: Bool,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         return CheckResult(
@@ -198,11 +197,11 @@ def _get_series_failure_cases(
         )
         return failure_cases

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
-        check_obj: is_field,  # type: ignore [valid-type]
-        check_output: is_field,  # type: ignore [valid-type]
+        check_obj: Field,  # type: ignore [valid-type]
+        check_output: Field,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         if check_obj.index.equals(check_output.index) and self.check.ignore_na:
@@ -214,11 +213,11 @@ def postprocess(
             self._get_series_failure_cases(check_obj, check_output),
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
-        check_obj: is_table,  # type: ignore [valid-type]
-        check_output: is_field,  # type: ignore [valid-type]
+        check_obj: Table,  # type: ignore [valid-type]
+        check_output: Field,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         if check_obj.index.equals(check_output.index) and self.check.ignore_na:
@@ -230,11 +229,11 @@ def postprocess(
             self._get_series_failure_cases(check_obj, check_output),
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
-        check_obj: is_table,  # type: ignore [valid-type]
-        check_output: is_table,  # type: ignore [valid-type]
+        check_obj: Table,  # type: ignore [valid-type]
+        check_output: Table,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         assert check_obj.shape == check_output.shape
@@ -244,19 +243,19 @@ def postprocess(
         # collect failure cases across all columns. False values in check_output
         # are nulls.
         select_failure_cases = check_obj[~check_output]
-        failure_cases = []
+        _failure_cases = []
         for col in select_failure_cases.columns:
             cases = select_failure_cases[col].rename("failure_case").dropna()
             if len(cases) == 0:
                 continue
-            failure_cases.append(
+            _failure_cases.append(
                 cases.to_frame()
                 .assign(column=col)
                 .rename_axis("index")
                 .reset_index()
             )
-        if failure_cases:
-            failure_cases = pd.concat(failure_cases, axis=0)
+        if _failure_cases:
+            failure_cases = pd.concat(_failure_cases, axis=0)
             # convert to a dataframe where each row is a failure case at
             # a particular index, and failure case values are dictionaries
             # indicating which column and value failed in that row.
@@ -279,11 +278,11 @@ def postprocess(
             failure_cases,
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
-        check_obj: is_table_or_field,  # type: ignore [valid-type]
-        check_output: is_bool,  # type: ignore [valid-type]
+        check_obj: TableOrField,  # type: ignore [valid-type]
+        check_output: Bool,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         check_output = bool(check_output)
@@ -294,11 +293,11 @@ def postprocess(
             None,
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
         check_obj: dict,
-        check_output: is_field,  # type: ignore [valid-type]
+        check_output: Field,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         return CheckResult(
diff --git a/pandera/backends/pandas/hypotheses.py b/pandera/backends/pandas/hypotheses.py
index a9bdc6f8d..2c7b9ebd5 100644
--- a/pandera/backends/pandas/hypotheses.py
+++ b/pandera/backends/pandas/hypotheses.py
@@ -4,11 +4,11 @@
 from typing import Any, Callable, Dict, Union, cast

 import pandas as pd
-from multimethod import overload
+from multimethod import multidispatch

 from pandera import errors
 from pandera.api.hypotheses import Hypothesis
-from pandera.api.pandas.types import is_field, is_table
+from pandera.api.pandas.types import is_field, Table
 from pandera.backends.pandas.checks import PandasCheckBackend

@@ -48,6 +48,8 @@ def equal(stat, pvalue, alpha=DEFAULT_ALPHA) -> bool:
 class PandasHypothesisBackend(PandasCheckBackend):
     """Hypothesis backend implementation for pandas."""

+    check: Hypothesis
+
     RELATIONSHIP_FUNCTIONS = {
         "greater_than": greater_than,
         "less_than": less_than,
@@ -106,15 +108,15 @@ def is_one_sample_test(self):
         """Return True if hypothesis is a one-sample test."""
         return len(self.check.samples) <= 1

-    @overload  # type: ignore [no-redef]
+    @multidispatch
     def preprocess(self, check_obj, key) -> Any:
         self.check.groups = self.check.samples
         return super().preprocess(check_obj, key)

-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    @preprocess.register
+    def _(
         self,
-        check_obj: is_table,  # type: ignore [valid-type]
+        check_obj: Table,  # type: ignore [valid-type]
         key,
     ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
         if self.check.groupby is None:
@@ -126,10 +128,10 @@ def preprocess(
             ),
         )

-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    @preprocess.register
+    def _(
         self,
-        check_obj: is_table,  # type: ignore [valid-type]
+        check_obj: Table,  # type: ignore [valid-type]
         key: None,
     ) -> pd.Series:
         """Preprocesses a check object before applying the check function."""
diff --git a/pandera/backends/pandas/parsers.py b/pandera/backends/pandas/parsers.py
index 5d5ac1368..a80517468 100644
--- a/pandera/backends/pandas/parsers.py
+++ b/pandera/backends/pandas/parsers.py
@@ -4,10 +4,10 @@
 from typing import Dict, Optional, Union

 import pandas as pd
-from multimethod import overload
+from multimethod import multidispatch

 from pandera.api.base.parsers import ParserResult
-from pandera.api.pandas.types import is_field, is_table
+from pandera.api.pandas.types import Field, Table
 from pandera.api.parsers import Parser
 from pandera.backends.base import BaseParserBackend

@@ -22,40 +22,40 @@ def __init__(self, parser: Parser):
         self.parser = parser
         self.parser_fn = partial(parser._parser_fn, **parser._parser_kwargs)

-    @overload
+    @multidispatch
     def preprocess(
         self, parse_obj, key  # pylint:disable=unused-argument
     ) -> pd.Series:  # pylint:disable=unused-argument
         """Preprocesses a parser object before applying the parse function."""
         return parse_obj

-    @overload  # type: ignore [no-redef]
-    def preprocess(
+    @preprocess.register
+    def _(
         self,
-        parse_obj: is_table,  # type: ignore [valid-type]
+        parse_obj: Table,  # type: ignore [valid-type]
         key,
     ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
         return parse_obj[key]

-    @overload  # type: ignore [no-redef]
-    def preprocess(
-        self, parse_obj: is_table, key: None  # type: ignore [valid-type] # pylint:disable=unused-argument
+    @preprocess.register
+    def _(
+        self, parse_obj: Table, key: None  # type: ignore [valid-type] # pylint:disable=unused-argument
     ) -> Union[pd.DataFrame, Dict[str, pd.DataFrame]]:
         return parse_obj

-    @overload
+    @multidispatch
     def apply(self, parse_obj):
         """Apply the parse function to a parser object."""
         raise NotImplementedError

-    @overload  # type: ignore [no-redef]
-    def apply(self, parse_obj: is_field):  # type: ignore [valid-type]
+    @apply.register
+    def _(self, parse_obj: Field):  # type: ignore [valid-type]
         if self.parser.element_wise:
             return parse_obj.map(self.parser_fn)
         return self.parser_fn(parse_obj)

-    @overload  # type: ignore [no-redef]
-    def apply(self, parse_obj: is_table):  # type: ignore [valid-type]
+    @apply.register
+    def _(self, parse_obj: Table):  # type: ignore [valid-type]
         if self.parser.element_wise:
             return getattr(parse_obj, "map", parse_obj.applymap)(
                 self.parser_fn
diff --git a/pandera/backends/polars/checks.py b/pandera/backends/polars/checks.py
index 26b599690..203a7e8df 100644
--- a/pandera/backends/polars/checks.py
+++ b/pandera/backends/polars/checks.py
@@ -4,12 +4,12 @@
 from typing import Optional

 import polars as pl
-from multimethod import overload
+from multimethod import multidispatch
 from polars.lazyframe.group_by import LazyGroupBy

 from pandera.api.base.checks import CheckResult
 from pandera.api.checks import Check
-from pandera.api.polars.types import PolarsData
+from pandera.api.polars.types import PolarsData, Bool
 from pandera.api.polars.utils import (
     get_lazyframe_schema,
     get_lazyframe_column_names,
@@ -76,15 +76,15 @@ def apply(self, check_obj: PolarsData):

         return out

-    @overload
+    @multidispatch
     def postprocess(self, check_obj, check_output):
         """Postprocesses the result of applying the check function."""
         raise TypeError(  # pragma: no cover
             f"output type of check_fn not recognized: {type(check_output)}"
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
         check_obj: PolarsData,
         check_output: pl.LazyFrame,
@@ -105,11 +105,11 @@ def postprocess(
             failure_cases=failure_cases,
         )

-    @overload  # type: ignore [no-redef]
-    def postprocess(
+    @postprocess.register
+    def _(
         self,
         check_obj: PolarsData,
-        check_output: bool,
+        check_output: Bool,  # type: ignore [valid-type]
     ) -> CheckResult:
         """Postprocesses the result of applying the check function."""
         ldf_output = pl.LazyFrame({CHECK_OUTPUT_KEY: [check_output]})
diff --git a/reqs-test.txt b/reqs-test.txt
index 8e51b5f14..cefc87bfe 100644
--- a/reqs-test.txt
+++ b/reqs-test.txt
@@ -272,7 +272,7 @@ msgpack==1.0.5
     # via
     #   distributed
     #   ray
-multimethod==1.9.1
+multimethod==1.12
     # via -r requirements.in
 mypy==0.982
     # via -r requirements.in
diff --git a/requirements.in b/requirements.in
index df53991bf..f4e2217e5 100644
--- a/requirements.in
+++ b/requirements.in
@@ -14,7 +14,7 @@ typing_extensions >= 3.7.4.3
 frictionless <= 4.40.8
 pyarrow
 pydantic
-multimethod <= 1.10.0
+multimethod
 pandas-stubs
 pyspark[connect] >= 3.2.0
 polars >= 0.20.0
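Note (illustrative sketch, not part of the diff): lifting the multimethod <= 1.10.0 pin goes hand in hand with the import changes above — the predicate-based @overload decorator appears to be unavailable in newer multimethod releases, so the backends now register annotated implementations on a multidispatch base function, with the base acting as the fallback. A rough standalone illustration of that registration pattern, with an invented TinyCheckBackend class and assuming multimethod 1.12 semantics:

import pandas as pd
from multimethod import multidispatch


class TinyCheckBackend:
    """Toy stand-in for the pandera backends; not pandera's actual API."""

    @multidispatch
    def postprocess(self, check_obj, check_output):
        # The base implementation doubles as the fallback when no
        # registered annotation matches the argument types.
        raise TypeError(f"output type not recognized: {type(check_output)}")

    @postprocess.register
    def _(self, check_obj, check_output: bool):
        return {"passed": check_output, "failure_cases": None}

    @postprocess.register
    def _(self, check_obj: pd.Series, check_output: pd.Series):
        return {
            "passed": bool(check_output.all()),
            "failure_cases": check_obj[~check_output],
        }


backend = TinyCheckBackend()
s = pd.Series([1, 2, -1])
print(backend.postprocess(s, s > 0))  # dispatches on (Series, Series)
print(backend.postprocess(s, True))   # dispatches on the bool overload

Each implementation is named `_` so it does not shadow the dispatcher; registration keys off the annotations, which is why the mypy `# type: ignore [no-redef]` comments in the old code are no longer needed.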
diff --git a/setup.py b/setup.py
index 749915fe6..d5bff6a07 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@
     packages=find_packages(include=["pandera*"]),
     package_data={"pandera": ["py.typed"]},
     install_requires=[
-        "multimethod <= 1.10.0",
+        "multimethod",
         "numpy >= 1.19.0",
         "packaging >= 20.0",
         "pandas >= 1.2.0",