Skip to content

Commit

Permalink
bugfix: check_input decorator handles functions with kwargs
Browse files Browse the repository at this point in the history
Signed-off-by: Niels Bantilan <[email protected]>
  • Loading branch information
cosmicBboy committed Dec 26, 2024
1 parent 2f269a4 commit a97f020
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
48 changes: 18 additions & 30 deletions pandera/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,41 +258,29 @@ def _wrapper(*args, **kwargs):
pos_args[obj_getter], *validate_args
)
args = list(pos_args.values())
elif obj_getter is None and kwargs:
# get the first key in the same order specified in the
# function argument.
args_names = _get_fn_argnames(wrapped)

try:
kwargs[args_names[0]] = schema.validate(
kwargs[args_names[0]], *validate_args
)
except errors.SchemaError as e:
_handle_schema_error(
"check_input",
wrapped,
schema,
kwargs[args_names[0]],
e,
)
elif obj_getter is None and args:
elif obj_getter is None:
try:
_fn = (
wrapped
if not hasattr(wrapped, "__wrapped__")
else wrapped.__wrapped__
)
_fn = _unwrap_fn(wrapped)
obj_arg_name, *_ = _get_fn_argnames(wrapped)
arg_spec_args = inspect.getfullargspec(_fn).args
if arg_spec_args[0] in ("self", "cls"):
arg_idx = 0 if len(args) == 1 else 1

arg_idx = arg_spec_args.index(obj_arg_name)

if obj_arg_name in kwargs:
obj = kwargs[obj_arg_name]
kwargs[obj_arg_name] = schema.validate(
obj, *validate_args
)
elif obj_arg_name in pos_args:
obj = args[arg_idx]
args[arg_idx] = schema.validate(obj, *validate_args)
else:
arg_idx = 0
args[arg_idx] = schema.validate(
args[arg_idx], *validate_args
)
raise ValueError(
f"argument {obj_arg_name} not found in args or kwargs"
)
except errors.SchemaError as e:
_handle_schema_error(
"check_input", wrapped, schema, args[0], e
"check_input", wrapped, schema, obj, e
)
else:
raise TypeError(
Expand Down
23 changes: 21 additions & 2 deletions tests/core/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,25 @@ def _assert_expectation(result_df):
)


class DfModel(DataFrameModel):
col: int


# pylint: disable=unused-argument
@check_input(DfModel.to_schema())
def fn_with_check_input(data: DataFrame[DfModel], *, kwarg: bool = False):
return data


def test_check_input_on_fn_with_kwarg():
"""
That that a check_input correctly validates a function where the first arg
is the dataframe and the function has other kwargs.
"""
df = pd.DataFrame({"col": [1]})
fn_with_check_input(df, kwarg=True)


def test_check_io() -> None:
# pylint: disable=too-many-locals
"""Test that check_io correctly validates/invalidates data."""
Expand Down Expand Up @@ -777,13 +796,13 @@ def test_check_types_with_literal_type(arg_examples):
"""Test that using typing module types works with check_types"""

for example in arg_examples:
arg_type = Literal[example]
arg_type = Literal[example] # type: ignore

@check_types
def transform_with_literal(
df: DataFrame[InSchema],
# pylint: disable=unused-argument,cell-var-from-loop
arg: arg_type,
arg: arg_type, # type: ignore
) -> DataFrame[OutSchema]:
return df.assign(b=100) # type: ignore

Expand Down

0 comments on commit a97f020

Please sign in to comment.