From 4f8bdbf47ff820863512603afd8aebb18efac5cc Mon Sep 17 00:00:00 2001 From: Niels Bantilan Date: Mon, 30 Sep 2024 16:04:54 -0400 Subject: [PATCH 1/2] Reduce import overhead to improve runtime (#1821) * consolidate how to get supported datatypes Signed-off-by: cosmicBboy * fix get backend logic Signed-off-by: cosmicBboy * import modin and pyspark in builtin_checks conditionally Signed-off-by: cosmicBboy * fix unit test pa.typing.pyspark Signed-off-by: cosmicBboy --------- Signed-off-by: cosmicBboy --- pandera/__init__.py | 5 +- pandera/api/extensions.py | 2 + pandera/api/pandas/types.py | 191 +++++++++++++----- pandera/backends/pandas/builtin_checks.py | 16 +- pandera/backends/pandas/hypotheses.py | 8 - pandera/backends/pandas/register.py | 115 +---------- tests/core/test_extension_modules.py | 11 +- tests/hypotheses/test_hypotheses.py | 12 +- .../pyspark/test_schemas_on_pyspark_pandas.py | 25 ++- 9 files changed, 184 insertions(+), 201 deletions(-) diff --git a/pandera/__init__.py b/pandera/__init__.py index 24536eaf1..8780f52d5 100644 --- a/pandera/__init__.py +++ b/pandera/__init__.py @@ -1,6 +1,7 @@ # pylint: disable=wrong-import-position """A flexible and expressive pandas validation library.""" +import os import platform from pandera._patch_numpy2 import _patch_numpy2 @@ -12,7 +13,6 @@ import pandera.backends.base.builtin_hypotheses import pandera.backends.pandas from pandera import errors, external_config -from pandera.accessors import pandas_accessor from pandera.api import extensions from pandera.api.checks import Check from pandera.api.dataframe.model_components import ( @@ -82,9 +82,6 @@ from pandera.dtypes import Complex256, Float128 -external_config._set_pyspark_environment_variables() - - __all__ = [ # dtypes "Bool", diff --git a/pandera/api/extensions.py b/pandera/api/extensions.py index 42322856d..58de01eb7 100644 --- a/pandera/api/extensions.py +++ b/pandera/api/extensions.py @@ -27,6 +27,7 @@ def register_builtin_check( fn=None, strategy: Optional[Callable] = None, _check_cls: Type = Check, + aliases: Optional[List[str]] = None, **outer_kwargs, ): """Register a check method to the Check namespace. 
@@ -41,6 +42,7 @@ def register_builtin_check( register_builtin_check, strategy=strategy, _check_cls=_check_cls, + aliases=aliases, **outer_kwargs, ) diff --git a/pandera/api/pandas/types.py b/pandera/api/pandas/types.py index 2006359a9..2b4fd5d6c 100644 --- a/pandera/api/pandas/types.py +++ b/pandera/api/pandas/types.py @@ -1,12 +1,15 @@ +# pylint: disable=unused-import """Utility functions for pandas validation.""" from functools import lru_cache -from typing import NamedTuple, Tuple, Type, Union +from typing import Any, NamedTuple, Type, TypeVar, Union, Optional import numpy as np import pandas as pd from pandera.dtypes import DataType +from pandera.errors import BackendNotFoundError + PandasDtypeInputTypes = Union[ str, @@ -17,68 +20,146 @@ np.dtype, ] -SupportedTypes = NamedTuple( - "SupportedTypes", - ( - ("table_types", Tuple[type, ...]), - ("field_types", Tuple[type, ...]), - ("index_types", Tuple[type, ...]), - ("multiindex_types", Tuple[type, ...]), - ), +PANDAS_LIKE_CLS_NAMES = frozenset( + [ + "DataFrame", + "Series", + "Index", + "MultiIndex", + "GeoDataFrame", + "GeoSeries", + ] ) -@lru_cache(maxsize=None) -def supported_types() -> SupportedTypes: - """Get the types supported by pandera schemas.""" - # pylint: disable=import-outside-toplevel - table_types = [pd.DataFrame] - field_types = [pd.Series] - index_types = [pd.Index] - multiindex_types = [pd.MultiIndex] +class BackendTypes(NamedTuple): - try: - import pyspark.pandas as ps + # list of datatypes available + dataframe_datatypes: tuple + series_datatypes: tuple + index_datatypes: tuple + multiindex_datatypes: tuple + check_backend_types: tuple - table_types.append(ps.DataFrame) - field_types.append(ps.Series) - index_types.append(ps.Index) - multiindex_types.append(ps.MultiIndex) - except ImportError: - pass - try: # pragma: no cover - import modin.pandas as mpd - table_types.append(mpd.DataFrame) - field_types.append(mpd.Series) - index_types.append(mpd.Index) - multiindex_types.append(mpd.MultiIndex) - except ImportError: - pass - try: +@lru_cache +def get_backend_types(check_cls_fqn: str): + + dataframe_datatypes = [] + series_datatypes = [] + index_datatypes = [] + multiindex_datatypes = [] + + mod_name, *mod_path, cls_name = check_cls_fqn.split(".") + if mod_name != "pandera": + if cls_name not in PANDAS_LIKE_CLS_NAMES: + raise BackendNotFoundError( + f"cls_name {cls_name} not in {PANDAS_LIKE_CLS_NAMES}" + ) + + if mod_name == "pandera": + # assume mod_path e.g. 
["typing", "pandas"] + assert mod_path[0] == "typing" + *_, mod_name = mod_path + else: + mod_name = mod_name.split(".")[-1] + + def register_pandas_backend(): + from pandera.accessors import pandas_accessor + + dataframe_datatypes.append(pd.DataFrame) + series_datatypes.append(pd.Series) + index_datatypes.append(pd.Index) + multiindex_datatypes.append(pd.MultiIndex) + + def register_dask_backend(): import dask.dataframe as dd + from pandera.accessors import dask_accessor + + dataframe_datatypes.append(dd.DataFrame) + series_datatypes.append(dd.Series) + index_datatypes.append(dd.Index) + + def register_modin_backend(): + import modin.pandas as mpd + from pandera.accessors import modin_accessor + + dataframe_datatypes.append(mpd.DataFrame) + series_datatypes.append(mpd.Series) + index_datatypes.append(mpd.Index) + multiindex_datatypes.append(mpd.MultiIndex) - table_types.append(dd.DataFrame) - field_types.append(dd.Series) - index_types.append(dd.Index) - except ImportError: - pass - - return SupportedTypes( - tuple(table_types), - tuple(field_types), - tuple(index_types), - tuple(multiindex_types), + def register_pyspark_backend(): + import pyspark.pandas as ps + from pandera.accessors import pyspark_accessor + + dataframe_datatypes.append(ps.DataFrame) + series_datatypes.append(ps.Series) + index_datatypes.append(ps.Index) + multiindex_datatypes.append(ps.MultiIndex) + + def register_geopandas_backend(): + import geopandas as gpd + + register_pandas_backend() + dataframe_datatypes.append(gpd.GeoDataFrame) + series_datatypes.append(gpd.GeoSeries) + + register_fn = { + "pandas": register_pandas_backend, + "dask_expr": register_dask_backend, + "modin": register_modin_backend, + "pyspark": register_pyspark_backend, + "geopandas": register_geopandas_backend, + "pandera": lambda: None, + }[mod_name] + + register_fn() + + check_backend_types = [ + *dataframe_datatypes, + *series_datatypes, + *index_datatypes, + ] + + return BackendTypes( + dataframe_datatypes=tuple(dataframe_datatypes), + series_datatypes=tuple(series_datatypes), + index_datatypes=tuple(index_datatypes), + multiindex_datatypes=tuple(multiindex_datatypes), + check_backend_types=tuple(check_backend_types), ) +T = TypeVar("T") + + +def _get_fullname(_cls: Type) -> str: + return f"{_cls.__module__}.{_cls.__name__}" + + +def get_backend_types_from_mro(_cls: Type) -> Optional[BackendTypes]: + try: + return get_backend_types(_get_fullname(_cls)) + except BackendNotFoundError: + for base_cls in _cls.__bases__: + try: + return get_backend_types(_get_fullname(base_cls)) + except BackendNotFoundError: + pass + return None + + def is_table(obj): """Verifies whether an object is table-like. Where a table is a 2-dimensional data matrix of rows and columns, which can be indexed in multiple different ways. """ - return isinstance(obj, supported_types().table_types) + backend_types = get_backend_types_from_mro(type(obj)) + return backend_types is not None and isinstance( + obj, backend_types.dataframe_datatypes + ) def is_field(obj): @@ -87,17 +168,26 @@ def is_field(obj): Where a field is a columnar representation of data in a table-like data structure. 
""" - return isinstance(obj, supported_types().field_types) + backend_types = get_backend_types_from_mro(type(obj)) + return backend_types is not None and isinstance( + obj, backend_types.series_datatypes + ) def is_index(obj): """Verifies whether an object is a table index.""" - return isinstance(obj, supported_types().index_types) + backend_types = get_backend_types_from_mro(type(obj)) + return backend_types is not None and isinstance( + obj, backend_types.index_datatypes + ) def is_multiindex(obj): """Verifies whether an object is a multi-level table index.""" - return isinstance(obj, supported_types().multiindex_types) + backend_types = get_backend_types_from_mro(type(obj)) + return backend_types is not None and isinstance( + obj, backend_types.multiindex_datatypes + ) def is_table_or_field(obj): @@ -105,9 +195,6 @@ def is_table_or_field(obj): return is_table(obj) or is_field(obj) -is_supported_check_obj = is_table_or_field - - def is_bool(x): """Verifies whether an object is a boolean type.""" return isinstance(x, (bool, np.bool_)) diff --git a/pandera/backends/pandas/builtin_checks.py b/pandera/backends/pandas/builtin_checks.py index 3a8bf603a..16ff117cc 100644 --- a/pandera/backends/pandas/builtin_checks.py +++ b/pandera/backends/pandas/builtin_checks.py @@ -1,5 +1,6 @@ """Pandas implementation of built-in checks""" +import sys import operator import re from typing import Any, Iterable, Optional, TypeVar, Union, cast @@ -10,18 +11,21 @@ from pandera.api.extensions import register_builtin_check -from pandera.typing.modin import MODIN_INSTALLED -from pandera.typing.pyspark import PYSPARK_INSTALLED +MODIN_IMPORTED = "modin" in sys.modules +PYSPARK_IMPORTED = "pyspark" in sys.modules -if MODIN_INSTALLED and not PYSPARK_INSTALLED: # pragma: no cover + +# TODO: create a separate module for each framework: dask, modin, pyspark +# so checks are registered for the correct framework. +if MODIN_IMPORTED and not PYSPARK_IMPORTED: # pragma: no cover import modin.pandas as mpd PandasData = Union[pd.Series, pd.DataFrame, mpd.Series, mpd.DataFrame] -elif not MODIN_INSTALLED and PYSPARK_INSTALLED: # pragma: no cover +elif not MODIN_IMPORTED and PYSPARK_IMPORTED: # pragma: no cover import pyspark.pandas as ppd PandasData = Union[pd.Series, pd.DataFrame, ppd.Series, ppd.DataFrame] # type: ignore[misc] -elif MODIN_INSTALLED and PYSPARK_INSTALLED: # pragma: no cover +elif MODIN_IMPORTED and PYSPARK_IMPORTED: # pragma: no cover import modin.pandas as mpd import pyspark.pandas as ppd @@ -40,6 +44,8 @@ T = TypeVar("T") +# TODO: remove aliases, it's not needed anymore since aliases are defined in the +# Check class. 
@register_builtin_check( aliases=["eq"], strategy=st.eq_strategy, diff --git a/pandera/backends/pandas/hypotheses.py b/pandera/backends/pandas/hypotheses.py index a9bdc6f8d..1bbd72ad4 100644 --- a/pandera/backends/pandas/hypotheses.py +++ b/pandera/backends/pandas/hypotheses.py @@ -12,14 +12,6 @@ from pandera.backends.pandas.checks import PandasCheckBackend -try: - from scipy import stats # pylint: disable=unused-import -except ImportError: # pragma: no cover - HAS_SCIPY = False -else: - HAS_SCIPY = True - - DEFAULT_ALPHA = 0.01 diff --git a/pandera/backends/pandas/register.py b/pandera/backends/pandas/register.py index 4fee3b128..d14150829 100644 --- a/pandera/backends/pandas/register.py +++ b/pandera/backends/pandas/register.py @@ -1,8 +1,7 @@ -# pylint: disable=unused-import """Register pandas backends.""" from functools import lru_cache -from typing import NamedTuple, Optional +from typing import Optional from pandera.backends.pandas.array import SeriesSchemaBackend from pandera.backends.pandas.checks import PandasCheckBackend @@ -14,117 +13,6 @@ from pandera.backends.pandas.container import DataFrameSchemaBackend from pandera.backends.pandas.hypotheses import PandasHypothesisBackend from pandera.backends.pandas.parsers import PandasParserBackend -from pandera.errors import BackendNotFoundError - - -class BackendTypes(NamedTuple): - - # list of datatypes available - dataframe_datatypes: list - series_datatypes: list - index_datatypes: list - multiindex_datatypes: list - check_backend_types: list - - -PANDAS_LIKE_CLS_NAMES = frozenset( - [ - "DataFrame", - "Series", - "Index", - "MultiIndex", - "GeoDataFrame", - "GeoSeries", - ] -) - - -@lru_cache -def get_backend_types(check_cls_fqn: str): - - dataframe_datatypes = [] - series_datatypes = [] - index_datatypes = [] - multiindex_datatypes = [] - - mod_name, *mod_path, cls_name = check_cls_fqn.split(".") - if mod_name != "pandera": - if cls_name not in PANDAS_LIKE_CLS_NAMES: - raise BackendNotFoundError( - f"cls_name {cls_name} not in {PANDAS_LIKE_CLS_NAMES}" - ) - - if mod_name == "pandera": - # assume mod_path e.g. 
["typing", "pandas"] - assert mod_path[0] == "typing" - *_, mod_name = mod_path - - def register_pandas_backend(): - import pandas as pd - from pandera.accessors import pandas_accessor - - dataframe_datatypes.append(pd.DataFrame) - series_datatypes.append(pd.Series) - index_datatypes.append(pd.Index) - multiindex_datatypes.append(pd.MultiIndex) - - def register_dask_backend(): - import dask.dataframe as dd - from pandera.accessors import dask_accessor - - dataframe_datatypes.append(dd.DataFrame) - series_datatypes.append(dd.Series) - index_datatypes.append(dd.Index) - - def register_modin_backend(): - import modin.pandas as mpd - from pandera.accessors import modin_accessor - - dataframe_datatypes.append(mpd.DataFrame) - series_datatypes.append(mpd.Series) - index_datatypes.append(mpd.Index) - multiindex_datatypes.append(mpd.MultiIndex) - - def register_pyspark_backend(): - import pyspark.pandas as ps - from pandera.accessors import pyspark_accessor - - dataframe_datatypes.append(ps.DataFrame) - series_datatypes.append(ps.Series) - index_datatypes.append(ps.Index) - multiindex_datatypes.append(ps.MultiIndex) - - def register_geopandas_backend(): - import geopandas as gpd - - register_pandas_backend() - dataframe_datatypes.append(gpd.GeoDataFrame) - series_datatypes.append(gpd.GeoSeries) - - register_fn = { - "pandas": register_pandas_backend, - "dask_expr": register_dask_backend, - "modin": register_modin_backend, - "pyspark": register_pyspark_backend, - "geopandas": register_geopandas_backend, - "pandera": lambda: None, - }[mod_name] - - register_fn() - - check_backend_types = [ - *dataframe_datatypes, - *series_datatypes, - *index_datatypes, - ] - - return BackendTypes( - dataframe_datatypes=dataframe_datatypes, - series_datatypes=series_datatypes, - index_datatypes=index_datatypes, - multiindex_datatypes=multiindex_datatypes, - check_backend_types=check_backend_types, - ) @lru_cache @@ -152,6 +40,7 @@ def register_pandas_backends( from pandera.api.pandas.components import Column, Index, MultiIndex from pandera.api.pandas.container import DataFrameSchema from pandera.api.parsers import Parser + from pandera.api.pandas.types import get_backend_types assert check_cls_fqn is not None, ( "pandas backend registration requires passing in the fully qualified " diff --git a/tests/core/test_extension_modules.py b/tests/core/test_extension_modules.py index f124a09fd..e54a0d841 100644 --- a/tests/core/test_extension_modules.py +++ b/tests/core/test_extension_modules.py @@ -4,13 +4,20 @@ import pytest from pandera.api.hypotheses import Hypothesis -from pandera.backends.pandas.hypotheses import HAS_SCIPY + + +try: + from scipy import stats # pylint: disable=unused-import +except ImportError: # pragma: no cover + SCIPY_INSTALLED = False +else: + SCIPY_INSTALLED = True def test_hypotheses_module_import() -> None: """Test that Hypothesis built-in methods raise import error.""" data = pd.Series([1, 2, 3]) - if not HAS_SCIPY: + if not SCIPY_INSTALLED: for fn, check_args in [ ( lambda: Hypothesis.two_sample_ttest("sample1", "sample2"), diff --git a/tests/hypotheses/test_hypotheses.py b/tests/hypotheses/test_hypotheses.py index 80b5b9bc0..d992cd882 100644 --- a/tests/hypotheses/test_hypotheses.py +++ b/tests/hypotheses/test_hypotheses.py @@ -12,15 +12,19 @@ String, errors, ) -from pandera.backends.pandas.hypotheses import HAS_SCIPY -if HAS_SCIPY: - from scipy import stats # pylint: disable=import-error + +try: + from scipy import stats # pylint: disable=unused-import +except ImportError: # pragma: no cover 
+ SCIPY_INSTALLED = False +else: + SCIPY_INSTALLED = True # skip all tests in module if "hypotheses" depends aren't installed pytestmark = pytest.mark.skipif( - not HAS_SCIPY, reason='needs "hypotheses" module dependencies' + not SCIPY_INSTALLED, reason='needs "hypotheses" module dependencies' ) diff --git a/tests/pyspark/test_schemas_on_pyspark_pandas.py b/tests/pyspark/test_schemas_on_pyspark_pandas.py index b1c9542ee..2d8c8b9fd 100644 --- a/tests/pyspark/test_schemas_on_pyspark_pandas.py +++ b/tests/pyspark/test_schemas_on_pyspark_pandas.py @@ -12,6 +12,7 @@ from packaging import version import pandera as pa +from pandera.typing import pyspark as pyspark_typing from pandera import dtypes, extensions, system from pandera.engines import numpy_engine, pandas_engine, geopandas_engine from pandera.typing import DataFrame, Index, Series @@ -554,11 +555,9 @@ def test_schema_model(): # pylint: disable=too-few-public-methods class Schema(pa.DataFrameModel): - int_field: pa.typing.pyspark.Series[int] = pa.Field(gt=0) - float_field: pa.typing.pyspark.Series[float] = pa.Field(lt=0) - str_field: pa.typing.pyspark.Series[str] = pa.Field( - isin=["a", "b", "c"] - ) + int_field: pyspark_typing.Series[int] = pa.Field(gt=0) + float_field: pyspark_typing.Series[float] = pa.Field(lt=0) + str_field: pyspark_typing.Series[str] = pa.Field(isin=["a", "b", "c"]) valid_df = ps.DataFrame( { @@ -621,10 +620,10 @@ def test_check_decorators(): # pylint: disable=too-few-public-methods class InSchema(pa.DataFrameModel): - a: pa.typing.pyspark.Series[int] + a: pyspark_typing.Series[int] class OutSchema(InSchema): - b: pa.typing.pyspark.Series[int] + b: pyspark_typing.Series[int] @pa.check_input(in_schema) @pa.check_output(out_schema) @@ -648,16 +647,16 @@ def function_check_io_invalid(df: ps.DataFrame) -> ps.DataFrame: @pa.check_types def function_check_types( - df: pa.typing.pyspark.DataFrame[InSchema], - ) -> pa.typing.pyspark.DataFrame[OutSchema]: + df: pyspark_typing.DataFrame[InSchema], + ) -> pyspark_typing.DataFrame[OutSchema]: df["b"] = df["a"] + 1 - return typing.cast(pa.typing.pyspark.DataFrame[OutSchema], df) + return typing.cast(pyspark_typing.DataFrame[OutSchema], df) @pa.check_types def function_check_types_invalid( - df: pa.typing.pyspark.DataFrame[InSchema], - ) -> pa.typing.pyspark.DataFrame[OutSchema]: - return typing.cast(pa.typing.pyspark.DataFrame[OutSchema], df) + df: pyspark_typing.DataFrame[InSchema], + ) -> pyspark_typing.DataFrame[OutSchema]: + return typing.cast(pyspark_typing.DataFrame[OutSchema], df) valid_df = ps.DataFrame({"a": [1, 2, 3]}) invalid_df = ps.DataFrame({"b": [1, 2, 3]}) From ea4538d2f71795bba09e602d568d673798c92b35 Mon Sep 17 00:00:00 2001 From: gab23r <106454081+gab23r@users.noreply.github.com> Date: Fri, 1 Nov 2024 14:55:45 +0100 Subject: [PATCH 2/2] accept expr in default value (#1820) * accept expr in default value Signed-off-by: gabriel * add test --------- Signed-off-by: gabriel Co-authored-by: gabriel --- pandera/backends/polars/components.py | 5 ++++- tests/polars/test_polars_components.py | 20 ++++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/pandera/backends/polars/components.py b/pandera/backends/polars/components.py index 5c40a106f..751ffa53b 100644 --- a/pandera/backends/polars/components.py +++ b/pandera/backends/polars/components.py @@ -387,7 +387,10 @@ def set_default(self, check_obj: pl.LazyFrame, schema) -> pl.LazyFrame: if hasattr(schema, "default") and schema.default is None: return check_obj - default_value = 
pl.lit(schema.default, dtype=schema.dtype.type) + if isinstance(schema.default, pl.Expr): + default_value = schema.default + else: + default_value = pl.lit(schema.default, dtype=schema.dtype.type) expr = pl.col(schema.selector) if is_float_dtype(check_obj, schema.selector): expr = expr.fill_nan(default_value) diff --git a/tests/polars/test_polars_components.py b/tests/polars/test_polars_components.py index c1a325a9a..1f5b60847 100644 --- a/tests/polars/test_polars_components.py +++ b/tests/polars/test_polars_components.py @@ -256,4 +256,24 @@ def test_set_default(data, dtype, default): assert validated_data.select(pl.col("column").eq(default).any()).item() +def test_expr_as_default(): + schema = pa.DataFrameSchema( + columns={ + "a": pa.Column(int), + "b": pa.Column(float, default=1), + "c": pa.Column(str, default=pl.lit("foo")), + "d": pa.Column(int, nullable=True, default=pl.col("a")), + }, + add_missing_columns=True, + coerce=True, + ) + df = pl.LazyFrame({"a": [1, 2, 3]}) + assert schema.validate(df).collect().to_dict(as_series=False) == { + "a": [1, 2, 3], + "b": [1.0, 1.0, 1.0], + "c": ["foo", "foo", "foo"], + "d": [1, 2, 3], + } + + def test_column_schema_on_lazyframe_coerce(): ...
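
The second patch above makes `Column(default=...)` in the polars backend accept either a scalar or a `pl.Expr`: expressions are passed through unchanged, while scalars are still wrapped in `pl.lit` with the column's dtype. A minimal usage sketch of the new behavior, mirroring the test added in the patch (it assumes pandera's polars namespace, `pandera.polars`, as the import path; adjust if your entry point differs):

import polars as pl
import pandera.polars as pa  # assumed import path for the polars schema API

schema = pa.DataFrameSchema(
    columns={
        "a": pa.Column(int),
        "b": pa.Column(float, default=1),            # scalar default, as before
        "c": pa.Column(str, default=pl.lit("foo")),  # literal expression default
        "d": pa.Column(int, nullable=True, default=pl.col("a")),  # default derived from another column
    },
    add_missing_columns=True,
    coerce=True,
)

df = pl.LazyFrame({"a": [1, 2, 3]})
validated = schema.validate(df).collect()
# "b" is filled with 1.0, "c" with "foo", and "d" mirrors column "a".

Keeping scalar defaults wrapped in `pl.lit(schema.default, dtype=schema.dtype.type)` preserves the previous behavior, so only expression defaults take the new code path added by the `isinstance(schema.default, pl.Expr)` branch.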