Skip to content

Commit

Permalink
fix: Change the update function wrappers so they no longer capture the caller-provided arguments; these are now passed at call time as func_args/func_kwargs.
Browse files Browse the repository at this point in the history
Refs: #12
  • Loading branch information
Thomas Zilio committed Nov 19, 2024
1 parent 7faa803 commit a0f329c
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 54 deletions.
29 changes: 15 additions & 14 deletions zcollection/collection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,24 +556,21 @@ def update(
lock=True))

local_func: WrappedPartitionCallable = _wrap_update_func(
*args,
delayed=delayed,
func=func,
fs=self.fs,
immutable=self._immutable,
selected_variables=selected_variables
) if depth == 0 else _wrap_update_func_with_overlap(
delayed=delayed,
depth=depth,
dim=self.partition_properties.dim,
func=func,
fs=self.fs,
immutable=self._immutable,
selected_partitions=selected_partitions,
selected_variables=selected_variables,
**kwargs) if depth == 0 else _wrap_update_func_with_overlap(
*args,
delayed=delayed,
depth=depth,
dim=self.partition_properties.dim,
func=func,
fs=self.fs,
immutable=self._immutable,
selected_partitions=selected_partitions,
selected_variables=selected_variables,
trim=trim,
**kwargs)
trim=trim)

client: dask.distributed.Client = dask_utils.get_client()

Expand All @@ -582,7 +579,11 @@ def update(
or dask_utils.dask_workers(client, cores_only=True))
storage.execute_transaction(
client, self.synchronizer,
client.map(local_func, tuple(batches), key=func.__name__))
client.map(local_func,
tuple(batches),
key=func.__name__,
func_args=args,
func_kwargs=kwargs))
tuple(map(self.fs.invalidate_cache, selected_partitions))

def drop_variable(
Expand Down
3 changes: 2 additions & 1 deletion zcollection/collection/callable_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

#: Function type to load and call a callback function of type
#: :class:`PartitionCallable`.
WrappedPartitionCallable = Callable[[Sequence[str]], None]
WrappedPartitionCallable = Callable[[Sequence[str], list[Any], dict[str, Any]],
None]


#: pylint: disable=too-few-public-methods
Expand Down
21 changes: 8 additions & 13 deletions zcollection/collection/detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,13 +252,11 @@ def calculate_slice(


def _wrap_update_func(
*args,
delayed: bool,
func: UpdateCallable,
fs: fsspec.AbstractFileSystem,
immutable: str | None,
selected_variables: Iterable[str] | None,
**kwargs,
) -> WrappedPartitionCallable:
"""Wrap an update function taking a partition's dataset as input and
returning variable's values as a numpy array.
Expand All @@ -271,21 +269,21 @@ def _wrap_update_func(
selected_variables: Name of the variables to load from the dataset.
If None, all variables are loaded.
trim: Whether to trim the overlap.
*args: Positional arguments to pass to the function.
**kwargs: Keyword arguments to pass to the function.
Returns:
The wrapped function that takes a set of dataset partitions and the
variable name as input and returns the variable's values as a numpy
array.
"""

def wrap_function(partitions: Iterable[str]) -> None:
def wrap_function(partitions: Iterable[str], func_args: list[Any],
func_kwargs: dict[str, Any]) -> None:
# Applying function for each partition's data
for partition in partitions:
zds: dataset.Dataset = _load_dataset(delayed, fs, immutable,
partition, selected_variables)
dictionary: dict[str, ArrayLike] = func(zds, *args, **kwargs)
dictionary: dict[str, ArrayLike] = func(zds, *func_args,
**func_kwargs)
tuple(
update_zarr_array( # type: ignore[func-returns-value]
dirname=join_path(partition, varname),
Expand All @@ -297,7 +295,6 @@ def wrap_function(partitions: Iterable[str]) -> None:


def _wrap_update_func_with_overlap(
*args,
delayed: bool,
depth: int,
dim: str,
Expand All @@ -307,7 +304,6 @@ def _wrap_update_func_with_overlap(
selected_partitions: Sequence[str],
selected_variables: Iterable[str] | None,
trim: bool,
**kwargs,
) -> WrappedPartitionCallable:
"""Wrap an update function taking a partition's dataset as input and
returning variable's values as a numpy array.
Expand All @@ -323,8 +319,6 @@ def _wrap_update_func_with_overlap(
selected_variables: Name of the variables to load from the dataset.
If None, all variables are loaded.
trim: Whether to trim the overlap.
*args: Positional arguments to pass to the function.
**kwargs: Keyword arguments to pass to the function.
Returns:
The wrapped function that takes a set of dataset partitions and the
Expand All @@ -334,7 +328,8 @@ def _wrap_update_func_with_overlap(
if depth < 0:
raise ValueError('Depth must be non-negative.')

def wrap_function(partitions: Sequence[str]) -> None:
def wrap_function(partitions: Sequence[str], func_args: list[Any],
func_kwargs: dict[str, Any]) -> None:
# Applying function for each partition's data
for partition in partitions:

Expand All @@ -353,15 +348,15 @@ def wrap_function(partitions: Sequence[str]) -> None:
selected_variables=selected_variables)
# pylint: enable=duplicate-code

_update_with_overlap(*args,
_update_with_overlap(*func_args,
func=func,
zds=zds,
indices=indices,
dim=dim,
fs=fs,
path=partition,
trim=trim,
**kwargs)
**func_kwargs)

return wrap_function

Expand Down
13 changes: 7 additions & 6 deletions zcollection/merging/tests/test_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def test_update_fs(
"""Test the _update_fs function."""
generator = data.create_test_dataset(delayed=False)
zds = next(generator)
zds_sc = dask_client.scatter(zds)

partition_folder = local_fs.root.joinpath('variable=1')

zattrs = str(partition_folder.joinpath('.zattrs'))
future = dask_client.submit(_update_fs, str(partition_folder),
dask_client.scatter(zds), local_fs.fs)
future = dask_client.submit(_update_fs, str(partition_folder), zds_sc,
local_fs.fs)
dask_client.gather(future)
assert local_fs.exists(zattrs)

Expand All @@ -60,7 +61,7 @@ def test_update_fs(
try:
future = dask_client.submit(_update_fs,
str(partition_folder),
dask_client.scatter(zds),
zds_sc,
local_fs.fs,
synchronizer=ThrowError())
dask_client.gather(future)
Expand All @@ -83,13 +84,13 @@ def test_perform(
zds = next(generator)

path = str(local_fs.root.joinpath('variable=1'))
zds_sc = dask_client.scatter(zds)

future = dask_client.submit(_update_fs, path, dask_client.scatter(zds),
local_fs.fs)
future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs)
dask_client.gather(future)

future = dask_client.submit(perform,
dask_client.scatter(zds),
zds_sc,
path,
'time',
local_fs.fs,
Expand Down
8 changes: 3 additions & 5 deletions zcollection/view/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,6 @@ def update(
wrap_function = _wrap_update_func(
func,
self.fs,
*args,
**kwargs,
)
else:
if selected_variables is not None and len(
Expand All @@ -521,8 +519,6 @@ def update(
self.fs,
self.view_ref,
trim,
*args,
**kwargs,
)

batchs: Iterator[Sequence[Any]] = dask_utils.split_sequence(
Expand All @@ -532,7 +528,9 @@ def update(
wrap_function,
tuple(batchs),
key=func.__name__,
base_dir=self.base_dir)
base_dir=self.base_dir,
func_args=args,
func_kwargs=kwargs)
storage.execute_transaction(client, self.synchronizer, awaitables)

# pylint: disable=duplicate-code
Expand Down
27 changes: 12 additions & 15 deletions zcollection/view/detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""
from __future__ import annotations

from typing import Any
import base64
from collections.abc import Callable, Iterable, Iterator, Sequence
import dataclasses
Expand Down Expand Up @@ -37,8 +38,9 @@
from ..type_hints import ArrayLike, NDArray

#: Type of the function used to update a view.
ViewUpdateCallable = \
Callable[[Iterable[tuple[dataset.Dataset, str]], str], None]
ViewUpdateCallable = Callable[
[Iterable[tuple[dataset.Dataset, str]], str, list[Any], dict[str,
Any]], None]

#: Name of the file that contains the checksum of the view.
CHECKSUM_FILE = '.checksum'
Expand Down Expand Up @@ -350,28 +352,26 @@ def calculate_slice(
def _wrap_update_func(
func: collection.UpdateCallable,
fs: fsspec.AbstractFileSystem,
*args,
**kwargs,
) -> ViewUpdateCallable:
"""Wrap an update function taking a list of partition's dataset and
partition's path as input and returning None.
Args:
func: The update function.
fs: The file system used to access the variables in the view.
*args: The arguments of the update function.
**kwargs: The keyword arguments of the update function.
Returns:
The wrapped function.
"""

def wrap_function(parameters: Iterable[tuple[dataset.Dataset, str]],
base_dir: str) -> None:
base_dir: str, func_args: list[Any],
func_kwargs: dict[str, Any]) -> None:
"""Wrap the function to be applied to the dataset."""
for zds, partition in parameters:
# Applying function on partition's data
dictionary: dict[str, ArrayLike] = func(zds, *args, **kwargs)
dictionary: dict[str, ArrayLike] = func(zds, *func_args,
**func_kwargs)
tuple(
update_zarr_array( # type: ignore[func-returns-value]
dirname=join_path(base_dir, partition, varname),
Expand All @@ -389,8 +389,6 @@ def _wrap_update_func_overlap(
fs: fsspec.AbstractFileSystem,
view_ref: collection.Collection,
trim: bool,
*args,
**kwargs,
) -> ViewUpdateCallable:
"""Wrap an update function taking a list of partition's dataset and
partition's path as input and returning None.
Expand All @@ -402,8 +400,6 @@ def _wrap_update_func_overlap(
fs: The file system used to access the variables in the view.
view_ref: The view reference.
trim: If True, trim the dataset to the overlap.
*args: The arguments of the update function.
**kwargs: The keyword arguments of the update function.
Returns:
The wrapped function.
Expand All @@ -414,7 +410,8 @@ def _wrap_update_func_overlap(
raise ValueError('The depth must be positive')

def wrap_function(parameters: Iterable[tuple[dataset.Dataset, str]],
base_dir: str) -> None:
base_dir: str, func_args: list[Any],
func_kwargs: dict[str, Any]) -> None:
"""Wrap the function to be applied to the dataset."""
zds: dataset.Dataset
indices: slice
Expand All @@ -425,15 +422,15 @@ def wrap_function(parameters: Iterable[tuple[dataset.Dataset, str]],
# pylint: disable=duplicate-code
# False positive with the function _wrap_update_func_with_overlap
# defined in the module zcollection.collection.detail
_update_with_overlap(*args,
_update_with_overlap(*func_args,
func=func,
zds=zds,
indices=indices,
dim=dim,
fs=fs,
path=join_path(base_dir, partition),
trim=trim,
**kwargs)
**func_kwargs)
# pylint: enable=duplicate-code

return wrap_function
Expand Down

0 comments on commit a0f329c

Please sign in to comment.