Skip to content

Commit

Permalink
Defer backend registration to validation time (#1818)
Browse files Browse the repository at this point in the history
* revert changes made in #1803 due to import loading time performance hit

Signed-off-by: cosmicBboy <[email protected]>

* update env

Signed-off-by: cosmicBboy <[email protected]>

* wip profiling schema def startup time

Signed-off-by: cosmicBboy <[email protected]>

* implement validation-time backend registration

Signed-off-by: cosmicBboy <[email protected]>

* fix unit tests

Signed-off-by: cosmicBboy <[email protected]>

* make linter happy

Signed-off-by: cosmicBboy <[email protected]>

* register backends in strategy method

Signed-off-by: cosmicBboy <[email protected]>

* import dask accessor

Signed-off-by: cosmicBboy <[email protected]>

* register default backends in array strategy

Signed-off-by: cosmicBboy <[email protected]>

---------

Signed-off-by: cosmicBboy <[email protected]>
  • Loading branch information
cosmicBboy authored Sep 24, 2024
1 parent 4935940 commit 3a2ff78
Show file tree
Hide file tree
Showing 17 changed files with 224 additions and 133 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,5 @@ disable=
too-many-locals,
redefined-outer-name,
logging-fstring-interpolation,
multiple-statements
multiple-statements,
cyclic-import
5 changes: 4 additions & 1 deletion pandera/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pandera.backends.base.builtin_checks
import pandera.backends.base.builtin_hypotheses
import pandera.backends.pandas
from pandera import errors
from pandera import errors, external_config
from pandera.accessors import pandas_accessor
from pandera.api import extensions
from pandera.api.checks import Check
Expand Down Expand Up @@ -82,6 +82,9 @@
from pandera.dtypes import Complex256, Float128


external_config._set_pyspark_environment_variables()


__all__ = [
# dtypes
"Bool",
Expand Down
5 changes: 5 additions & 0 deletions pandera/api/base/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ def get_builtin_check_fn(cls, name: str):
"""Gets a built-in check function"""
return cls.CHECK_FUNCTION_REGISTRY[name]

@classmethod
def is_builtin_check(cls, name: str) -> bool:
    """Report whether ``name`` is registered as a built-in check.

    :param name: name of the check to look up.
    :returns: ``True`` if ``name`` is a key of ``CHECK_FUNCTION_REGISTRY``,
        ``False`` otherwise.
    """
    return name in cls.CHECK_FUNCTION_REGISTRY

@classmethod
def from_builtin_check_name(
cls,
Expand Down
16 changes: 10 additions & 6 deletions pandera/api/base/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def __init__(
self.description = description
self.metadata = metadata
self.drop_invalid_rows = drop_invalid_rows
self._register_default_backends()

def validate(
self,
Expand Down Expand Up @@ -107,6 +106,7 @@ def get_backend(
check_type: Optional[Type] = None,
) -> BaseSchemaBackend:
"""Get the backend associated with the type of ``check_obj`` ."""

if check_obj is not None:
check_obj_cls = type(check_obj)
elif check_type is not None:
Expand All @@ -115,6 +115,8 @@ def get_backend(
raise ValueError(
"Must pass in one of `check_obj` or `check_type`."
)

cls.register_default_backends(check_obj_cls)
classes = inspect.getmro(check_obj_cls)
for _class in classes:
try:
Expand All @@ -126,17 +128,19 @@ def get_backend(
f"Looked up the following base classes: {classes}"
)

def _register_default_backends(self):
@staticmethod
def register_default_backends(check_obj_cls: Type):
"""Register default backends.
This method is invoked in the `__init__` method for subclasses that
implement the API for a specific dataframe object, and should be
overridden in those subclasses.
This method is invoked in the `get_backend` method so that the
appropriate validation backend is loaded at validation time instead of
schema-definition time.
This method needs to be implemented by the schema subclass.
"""

def __setstate__(self, state):
self.__dict__ = state
self._register_default_backends()


def inferred_schema_guard(method):
Expand Down
6 changes: 6 additions & 0 deletions pandera/api/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ def __call__(
``failure_cases``: subset of the check_object that failed.
"""
if self.name is not None and self.is_builtin_check(self.name):
# since we use multimethod.multidispatch to dispatch built-in check
# functions, we need to reload the function here in case additional
# type signatures have been registered for a specific built-in
# check.
self._check_fn = self.get_builtin_check_fn(self.name)
backend = self.get_backend(check_obj)(self)
return backend(check_obj, column)

Expand Down
59 changes: 0 additions & 59 deletions pandera/api/dataframe/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
)

from pandera import errors
from pandera.import_utils import strategy_import_error
from pandera.api.base.schema import BaseSchema, inferred_schema_guard
from pandera.api.base.types import CheckList, ParserList, StrictType
from pandera.api.checks import Check
Expand Down Expand Up @@ -297,18 +296,6 @@ def get_dtypes(self, check_obj: TDataObject) -> Dict[str, DataType]:
def coerce_dtype(self, check_obj: TDataObject) -> TDataObject:
return self.get_backend(check_obj).coerce_dtype(check_obj, schema=self)

def validate(
self,
check_obj: TDataObject,
head: Optional[int] = None,
tail: Optional[int] = None,
sample: Optional[int] = None,
random_state: Optional[int] = None,
lazy: bool = False,
inplace: bool = False,
) -> TDataObject:
raise NotImplementedError

def __call__(
self,
dataframe: TDataObject,
Expand Down Expand Up @@ -1284,52 +1271,6 @@ def to_json(

return pandera.io.to_json(self, target, **kwargs)

###########################
# Schema Strategy Methods #
###########################

@strategy_import_error
def strategy(
self, *, size: Optional[int] = None, n_regex_columns: int = 1
):
"""Create a ``hypothesis`` strategy for generating a DataFrame.
:param size: number of elements to generate
:param n_regex_columns: number of regex columns to generate.
:returns: a strategy that generates pandas DataFrame objects.
"""
from pandera import strategies as st

return st.dataframe_strategy(
self.dtype,
columns=self.columns,
checks=self.checks,
unique=self.unique,
index=self.index,
size=size,
n_regex_columns=n_regex_columns,
)

def example(
self, size: Optional[int] = None, n_regex_columns: int = 1
) -> TDataObject:
"""Generate an example of a particular size.
:param size: number of elements in the generated DataFrame.
:returns: pandas DataFrame object.
"""
# pylint: disable=import-outside-toplevel,cyclic-import,import-error
import hypothesis

with warnings.catch_warnings():
warnings.simplefilter(
"ignore",
category=hypothesis.errors.NonInteractiveExampleWarning,
)
return self.strategy(
size=size, n_regex_columns=n_regex_columns
).example()


def _validate_columns(
column_dict: dict[Any, Any], # type: ignore [name-defined]
Expand Down
28 changes: 22 additions & 6 deletions pandera/api/pandas/array.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Core pandas array specification."""

import warnings
from typing import Any, Optional, cast
from typing import Any, Optional, Type, cast

import pandas as pd

Expand All @@ -13,6 +13,7 @@
from pandera.config import get_config_context
from pandera.dtypes import DataType, UniqueSettings
from pandera.engines import pandas_engine
from pandera.errors import BackendNotFoundError


class ArraySchema(ComponentSchema[TDataObject]):
Expand All @@ -31,11 +32,6 @@ def _validate_attributes(self):
"DataFrameSchema dtype."
)

def _register_default_backends(self):
from pandera.backends.pandas.register import register_pandas_backends

register_pandas_backends()

@property
def dtype(self) -> DataType:
"""Get the pandas dtype"""
Expand All @@ -46,6 +42,21 @@ def dtype(self, value: Optional[PandasDtypeInputTypes]) -> None:
"""Set the pandas dtype"""
self._dtype = pandas_engine.Engine.dtype(value) if value else None

@staticmethod
def register_default_backends(check_obj_cls: Type):
    """Register the pandas backends for the class being validated.

    Registration is first attempted with the fully qualified name of
    ``check_obj_cls``. If that raises ``BackendNotFoundError``, each of its
    direct base classes is tried in turn; bases that also raise
    ``BackendNotFoundError`` are ignored.

    :param check_obj_cls: class of the object to be validated.
    """
    from pandera.backends.pandas.register import register_pandas_backends

    def _dotted_name(klass) -> str:
        # "<module>.<class name>", the key format passed to the register fn
        return f"{klass.__module__}.{klass.__name__}"

    try:
        register_pandas_backends(_dotted_name(check_obj_cls))
    except BackendNotFoundError:
        # fall back to the direct parents of the class (not the full MRO)
        for parent in check_obj_cls.__bases__:
            parent_name = _dotted_name(parent)
            try:
                register_pandas_backends(parent_name)
            except BackendNotFoundError:
                continue

###########################
# Schema Strategy Methods #
###########################
Expand All @@ -59,6 +70,8 @@ def strategy(self, *, size=None):
"""
from pandera import strategies as st

self.register_default_backends(pd.DataFrame)

return st.series_strategy(
self.dtype,
checks=self.checks,
Expand Down Expand Up @@ -226,6 +239,9 @@ def validate( # type: ignore [override]
if hasattr(check_obj, "dask"):
# special case for dask series
if inplace:
# pylint: disable=unused-import
from pandera.accessors import dask_accessor

check_obj = check_obj.pandera.add_schema(self)
else:
check_obj = check_obj.copy()
Expand Down
76 changes: 70 additions & 6 deletions pandera/api/pandas/container.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Core pandas dataframe container specification."""

import warnings
from typing import Optional
from typing import Optional, Type

import pandas as pd

Expand All @@ -10,17 +10,14 @@
from pandera.config import get_config_context
from pandera.dtypes import DataType
from pandera.engines import pandas_engine
from pandera.errors import BackendNotFoundError
from pandera.import_utils import strategy_import_error


# pylint: disable=too-many-public-methods,too-many-locals
class DataFrameSchema(_DataFrameSchema[pd.DataFrame]):
"""A light-weight pandas DataFrame validator."""

def _register_default_backends(self):
from pandera.backends.pandas.register import register_pandas_backends

register_pandas_backends()

@property
def dtype(
self,
Expand Down Expand Up @@ -106,6 +103,9 @@ def validate(

if hasattr(check_obj, "dask"):
# special case for dask dataframes
# pylint: disable=unused-import
from pandera.accessors import dask_accessor

if inplace:
check_obj = check_obj.pandera.add_schema(self)
else:
Expand Down Expand Up @@ -143,6 +143,7 @@ def _validate(
lazy: bool = False,
inplace: bool = False,
) -> pd.DataFrame:

if self._is_inferred:
warnings.warn(
f"This {type(self)} is an inferred schema that hasn't been "
Expand All @@ -162,3 +163,66 @@ def _validate(
lazy=lazy,
inplace=inplace,
)

@staticmethod
def register_default_backends(check_obj_cls: Type):
    """Register the pandas backends for the class being validated.

    Registration is first attempted with the fully qualified name of
    ``check_obj_cls``. If that raises ``BackendNotFoundError``, each of its
    direct base classes is tried in turn; bases that also raise
    ``BackendNotFoundError`` are ignored.

    :param check_obj_cls: class of the object to be validated.
    """
    from pandera.backends.pandas.register import register_pandas_backends

    def _dotted_name(klass) -> str:
        # "<module>.<class name>", the key format passed to the register fn
        return f"{klass.__module__}.{klass.__name__}"

    try:
        register_pandas_backends(_dotted_name(check_obj_cls))
    except BackendNotFoundError:
        # fall back to the direct parents of the class (not the full MRO)
        for parent in check_obj_cls.__bases__:
            parent_name = _dotted_name(parent)
            try:
                register_pandas_backends(parent_name)
            except BackendNotFoundError:
                continue

###########################
# Schema Strategy Methods #
###########################

@strategy_import_error
def strategy(
    self, *, size: Optional[int] = None, n_regex_columns: int = 1
):
    """Build a ``hypothesis`` strategy that generates DataFrames.

    The default backends for ``pd.DataFrame`` are registered before the
    strategy is constructed.

    :param size: number of elements to generate
    :param n_regex_columns: number of regex columns to generate.
    :returns: a strategy that generates pandas DataFrame objects.
    """
    from pandera import strategies as st

    self.register_default_backends(pd.DataFrame)

    make_strategy = st.dataframe_strategy
    return make_strategy(
        self.dtype,
        columns=self.columns,
        checks=self.checks,
        unique=self.unique,
        index=self.index,
        size=size,
        n_regex_columns=n_regex_columns,
    )

def example(
    self, size: Optional[int] = None, n_regex_columns: int = 1
) -> pd.DataFrame:
    """Draw a single example DataFrame from this schema's strategy.

    :param size: number of elements in the generated DataFrame.
    :param n_regex_columns: number of regex columns to generate; forwarded
        to :meth:`strategy`.
    :returns: pandas DataFrame object.
    """
    # pylint: disable=import-outside-toplevel,cyclic-import,import-error
    import hypothesis

    with warnings.catch_warnings():
        # calling ``.example()`` outside an interactive session triggers
        # this warning; silence it only for the duration of the call
        warnings.simplefilter(
            "ignore",
            category=hypothesis.errors.NonInteractiveExampleWarning,
        )
        generator = self.strategy(
            size=size, n_regex_columns=n_regex_columns
        )
        return generator.example()
7 changes: 5 additions & 2 deletions pandera/api/polars/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import logging
from typing import Any, Optional
from typing import Any, Optional, Type

import polars as pl

Expand Down Expand Up @@ -105,7 +105,10 @@ def __init__(

self.set_regex()

def _register_default_backends(self):
@staticmethod
def register_default_backends(
check_obj_cls: Type,
): # pylint: disable=unused-argument
register_polars_backends()

def validate(
Expand Down
Loading

0 comments on commit 3a2ff78

Please sign in to comment.