Skip to content

Commit

Permalink
Ibis check backend (unionai-oss#1831)
Browse files Browse the repository at this point in the history
* [wip] add minimal ibis check backend implementation

Signed-off-by: cosmicBboy <[email protected]>

* support scalar, column, and table check output types

Signed-off-by: cosmicBboy <[email protected]>

* support scalar, column, and table check output types

Signed-off-by: cosmicBboy <[email protected]>

* Ibis check backend suggestions (unionai-oss#1855)

* Apply suggestions from code review

Signed-off-by: Deepyaman Datta <[email protected]>

* Fix row-order-dependent order by adding table cols

Signed-off-by: Deepyaman Datta <[email protected]>

---------

Signed-off-by: Deepyaman Datta <[email protected]>

* fix lint

Signed-off-by: cosmicBboy <[email protected]>

* fix unit tests

Signed-off-by: cosmicBboy <[email protected]>

---------

Signed-off-by: cosmicBboy <[email protected]>
Signed-off-by: Deepyaman Datta <[email protected]>
Co-authored-by: Deepyaman Datta <[email protected]>
Signed-off-by: Deepyaman Datta <[email protected]>
  • Loading branch information
cosmicBboy and deepyaman committed Dec 24, 2024
1 parent 9886e87 commit e960c9a
Show file tree
Hide file tree
Showing 9 changed files with 461 additions and 6 deletions.
9 changes: 8 additions & 1 deletion pandera/api/ibis/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""Ibis types."""

from typing import NamedTuple, Union
from typing import NamedTuple, Optional, Union

import ibis.expr.datatypes as dt
import ibis.expr.types as ir


class IbisData(NamedTuple):
    """Bundle of an Ibis table and an optional string key."""

    # Ibis table expression carried by this record.
    table: ir.Table
    # Optional string key accompanying the table; defaults to ``None``.
    # NOTE(review): the Ibis check backend passes this to a column selector,
    # so it is presumably a column name -- confirm against callers.
    key: Optional[str] = None


class CheckResult(NamedTuple):
"""Check result for user-defined checks."""

Expand All @@ -15,6 +20,8 @@ class CheckResult(NamedTuple):
failure_cases: ir.Table


# Type alias covering an Ibis table expression or an Ibis column expression.
# NOTE(review): presumably the kinds of object a check may be applied to --
# confirm against the places that use this alias.
IbisCheckObjects = Union[ir.Table, ir.Column]

IbisDtypeInputTypes = Union[
str,
type,
Expand Down
160 changes: 160 additions & 0 deletions pandera/backends/ibis/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Check backend for Ibis."""

from functools import partial
from typing import Optional


import ibis
import ibis.expr.types as ir
from ibis import _, selectors as s
from ibis.expr.types.groupby import GroupedTable
from multimethod import overload

from pandera.api.base.checks import CheckResult
from pandera.api.checks import Check
from pandera.api.ibis.types import IbisData
from pandera.backends.base import BaseCheckBackend
from pandera.backends.ibis.utils import select_column
from pandera.constants import CHECK_OUTPUT_KEY

# Suffix appended to a column name to label the column holding that column's
# check output; it is CHECK_OUTPUT_KEY wrapped in double underscores.
CHECK_OUTPUT_SUFFIX = f"__{CHECK_OUTPUT_KEY}__"


class IbisCheckBackend(BaseCheckBackend):
    """Check backend for Ibis.

    Calling an instance runs ``preprocess`` -> ``apply`` -> ``postprocess``
    on an Ibis table and returns a ``CheckResult``. ``postprocess`` is
    defined several times with ``multimethod.overload``; the definition
    order is significant and must be preserved.
    """

    def __init__(self, check: Check):
        """Initializes a check backend object.

        :param check: the check to execute; its ``_check_fn`` must be set.
        """
        super().__init__(check)
        assert check._check_fn is not None, "Check._check_fn must be set."
        self.check = check
        # Bind the check's keyword arguments once so that ``check_fn`` can
        # later be called with the check object alone.
        self.check_fn = partial(check._check_fn, **check._check_kwargs)

    def groupby(self, check_obj) -> GroupedTable:
        """Implements groupby behavior for check object.

        :raises NotImplementedError: always; not implemented for Ibis yet.
        """
        raise NotImplementedError

    def query(self, check_obj: ir.Table):
        """Implements querying behavior to produce subset of check object.

        :raises NotImplementedError: always; not implemented for Ibis yet.
        """
        raise NotImplementedError

    def aggregate(self, check_obj: ir.Table):
        """Implements aggregation behavior for check object.

        :raises NotImplementedError: always; not implemented for Ibis yet.
        """
        raise NotImplementedError

    def preprocess(self, check_obj: ir.Table, key: Optional[str]):
        """Preprocesses a check object before applying the check function.

        :param check_obj: table to be checked.
        :param key: optional key; currently unused.
        :returns: ``check_obj`` unchanged.
        """
        # No preprocessing is performed for Ibis tables at the moment (the
        # groupby/query/aggregate hooks above are not implemented).
        return check_obj

    def apply(self, check_obj: IbisData):
        """Apply the check function to a check object.

        :param check_obj: table plus optional key to apply the check to.
        :returns: either a boolean Ibis value returned by the check function,
            or a table with an extra ``CHECK_OUTPUT_KEY`` column that is the
            conjunction of all columns ending in ``CHECK_OUTPUT_SUFFIX``.
        :raises TypeError: if the check function's output is neither a table
            nor a boolean Ibis value.
        """
        if self.check.element_wise:
            # Apply the check function to the keyed column, or to every
            # column when no key is given; each result lands in a new column
            # named ``<col><CHECK_OUTPUT_SUFFIX>``.
            selector = (
                select_column(check_obj.key)
                if check_obj.key is not None
                else s.all()
            )
            out = check_obj.table.mutate(
                s.across(
                    selector, self.check_fn, f"{{col}}{CHECK_OUTPUT_SUFFIX}"
                )
            ).select(selector | s.endswith(CHECK_OUTPUT_SUFFIX))
        else:
            out = self.check_fn(check_obj)
            if isinstance(out, dict):
                # A dict result maps names to expressions: attach each as a
                # suffixed column on the original table.
                out = check_obj.table.mutate(
                    **{f"{k}{CHECK_OUTPUT_SUFFIX}": v for k, v in out.items()}
                )

        if isinstance(out, ir.Table):
            # for checks that return a boolean dataframe, make sure all columns
            # are boolean and reduce to a single boolean column.
            acc = ibis.literal(True)
            for col in out.columns:
                if col.endswith(CHECK_OUTPUT_SUFFIX):
                    assert out[col].type().is_boolean(), (
                        f"column '{col[: -len(CHECK_OUTPUT_SUFFIX)]}' "
                        "is not boolean. If check function returns a "
                        "dataframe, it must contain only boolean columns."
                    )
                    acc = acc & out[col]
            return out.mutate({CHECK_OUTPUT_KEY: acc})

        # Only Ibis value expressions have a ``.type()``; checking the class
        # first lets any other return value (e.g. a plain Python bool) reach
        # the TypeError below instead of failing with AttributeError.
        if isinstance(out, ir.Value) and out.type().is_boolean():
            return out

        raise TypeError(  # pragma: no cover
            f"output type of check_fn not recognized: {type(out)}"
        )

    @overload
    def postprocess(self, check_obj, check_output):
        """Postprocesses the result of applying the check function.

        Un-annotated fallback definition.

        :raises TypeError: always, reporting the type of ``check_output``.
        """
        raise TypeError(  # pragma: no cover
            f"output type of check_fn not recognized: {type(check_output)}"
        )

    @overload  # type: ignore [no-redef]
    def postprocess(
        self,
        check_obj: IbisData,
        check_output: ir.BooleanScalar,
    ) -> CheckResult:
        """Postprocesses the result of applying the check function.

        Scalar case: the scalar is used both as the check output and as the
        pass/fail value; no failure cases are reported.
        """
        return CheckResult(
            check_output=check_output,
            check_passed=check_output,
            checked_object=check_obj,
            failure_cases=None,
        )

    @overload  # type: ignore [no-redef]
    def postprocess(
        self,
        check_obj: IbisData,
        check_output: ir.BooleanColumn,
    ) -> CheckResult:
        """Postprocesses the result of applying the check function.

        Column case: the check passes when every value is true; failure cases
        are the table rows where the output is false, narrowed to the key
        column when a key is set.
        """
        check_output = check_output.name(CHECK_OUTPUT_KEY)
        failure_cases = check_obj.table.filter(~check_output)
        if check_obj.key is not None:
            failure_cases = failure_cases.select(check_obj.key)
        return CheckResult(
            check_output=check_output,
            check_passed=check_output.all(),
            checked_object=check_obj,
            failure_cases=failure_cases,
        )

    @overload  # type: ignore [no-redef]
    def postprocess(
        self,
        check_obj: IbisData,
        check_output: ir.Table,
    ) -> CheckResult:
        """Postprocesses the result of applying the check function.

        Table case: ``check_output`` must contain a ``CHECK_OUTPUT_KEY``
        column. The check passes when that column is all true; failure cases
        are the rows where it is false, with the check-output columns
        dropped and narrowed to the key column when a key is set.
        """
        passed = check_output[CHECK_OUTPUT_KEY].all()
        # Drop the per-column outputs and the combined output column so that
        # only the original data columns remain in the failure cases.
        failure_cases = check_output.filter(~_[CHECK_OUTPUT_KEY]).drop(
            s.endswith(CHECK_OUTPUT_SUFFIX) | select_column(CHECK_OUTPUT_KEY)
        )

        if check_obj.key is not None:
            failure_cases = failure_cases.select(check_obj.key)
        return CheckResult(
            check_output=check_output.select(CHECK_OUTPUT_KEY),
            check_passed=passed,
            checked_object=check_obj,
            failure_cases=failure_cases,
        )

    def __call__(
        self,
        check_obj: ir.Table,
        key: Optional[str] = None,
    ) -> CheckResult:
        """Run the check against a table.

        :param check_obj: table to check.
        :param key: optional key forwarded to ``preprocess`` and stored on
            the ``IbisData`` passed to ``apply`` and ``postprocess``.
        :returns: the ``CheckResult`` produced by ``postprocess``.
        """
        check_obj = self.preprocess(check_obj, key)
        ibis_data = IbisData(check_obj, key)
        check_output = self.apply(ibis_data)
        return self.postprocess(ibis_data, check_output)
37 changes: 35 additions & 2 deletions pandera/backends/ibis/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def validate(

# run the checks
core_checks = [
(self.check_dtype, (sample, schema)),
self.check_dtype,
self.run_checks,
]

for check, args in core_checks:
args = (sample, schema)
for check in core_checks:
results = check(*args)
if isinstance(results, CoreCheckResult):
results = [results]
Expand Down Expand Up @@ -114,3 +116,34 @@ def check_dtype(
message=msg,
failure_cases=failure_cases,
)

@validate_scope(scope=ValidationScope.DATA)
def run_checks(self, check_obj, schema) -> List[CoreCheckResult]:
    """Run every check in ``schema.checks`` against ``check_obj``.

    Each check is executed through ``self.run_check``. If running a check
    raises, the exception is recorded as a failed ``CoreCheckResult`` with
    reason ``CHECK_ERROR`` rather than propagated, so one result is produced
    per check.

    :param check_obj: object to run the checks on.
    :param schema: schema providing ``checks`` and ``selector``.
    :returns: list of results, in the order of ``schema.checks``.
    """
    collected: List[CoreCheckResult] = []
    for position, current in enumerate(schema.checks):
        try:
            outcome = self.run_check(
                check_obj,
                schema,
                current,
                position,
                schema.selector,
            )
        except Exception as err:  # pylint: disable=broad-except
            # catch other exceptions that may occur when executing the Check
            detail = f'"{err.args[0]}"' if err.args else ""
            description = f"{err.__class__.__name__}({detail})"
            outcome = CoreCheckResult(
                passed=False,
                check=current,
                check_index=position,
                reason_code=SchemaErrorReason.CHECK_ERROR,
                message=description,
                failure_cases=description,
                original_exc=err,
            )
        collected.append(outcome)
    return collected
2 changes: 2 additions & 0 deletions pandera/backends/ibis/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def register_ibis_backends():
from pandera.api.ibis.container import DataFrameSchema
from pandera.backends.ibis.components import ColumnBackend
from pandera.backends.ibis.container import DataFrameSchemaBackend
from pandera.backends.ibis.checks import IbisCheckBackend

DataFrameSchema.register_backend(ir.Table, DataFrameSchemaBackend)
Column.register_backend(ir.Table, ColumnBackend)
Check.register_backend(ir.Table, IbisCheckBackend)
10 changes: 10 additions & 0 deletions pandera/backends/ibis/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Utility functions for the Ibis backend."""

from ibis import selectors as s


def select_column(*names):
    """Build an Ibis selector for the given column names.

    Uses ``s.cols`` when the ``ibis.selectors`` module provides it and falls
    back to ``s.c`` otherwise.
    """
    # Look up the fallback lazily: ``s.c`` is only touched when ``s.cols``
    # is missing.
    make_selector = getattr(s, "cols", None)
    if make_selector is None:
        make_selector = s.c
    return make_selector(*names)
2 changes: 1 addition & 1 deletion pandera/backends/pandas/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class PandasCheckBackend(BaseCheckBackend):
"""Check backend ofr pandas."""
"""Check backend for pandas."""

def __init__(self, check: Check):
"""Initializes a check backend object."""
Expand Down
4 changes: 2 additions & 2 deletions pandera/backends/polars/checks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Check backend for pandas."""
"""Check backend for polars."""

from functools import partial
from typing import Optional
Expand All @@ -18,7 +18,7 @@


class PolarsCheckBackend(BaseCheckBackend):
"""Check backend ofr pandas."""
"""Check backend for polars."""

def __init__(self, check: Check):
"""Initializes a check backend object."""
Expand Down
1 change: 1 addition & 0 deletions pandera/ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from pandera.api.ibis.components import Column
from pandera.api.ibis.container import DataFrameSchema
from pandera.api.ibis.model import DataFrameModel
from pandera.api.ibis.types import IbisData
Loading

0 comments on commit e960c9a

Please sign in to comment.