Skip to content

Commit

Permalink
Memory leaks (#352)
Browse files Browse the repository at this point in the history
Potential solutions to
#349
  • Loading branch information
mzouink authored Dec 10, 2024
2 parents e9697ad + 2467cd8 commit 0b3e2d0
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ def array(self, mode="r") -> Array:
assert num_channels is None, "Input labels cannot have a channel dimension"

def group_array(data):
out = da.zeros((len(self.groupings), *array.physical_shape), dtype=np.uint8)
for i, (_, group_ids) in enumerate(self.groupings):
if len(group_ids) == 0:
out[i] = data != self.background
else:
out[i] = da.isin(data, group_ids)
groups = [
da.isin(data, group_ids)
if len(group_ids) > 0
else data != self.background
for _, group_ids in self.groupings
]
out = da.stack(groups, axis=0)
return out

data = group_array(array.data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from typing import List, Dict, Optional
from funlib.persistence import Array
import numpy as np
import dask.array as da


Expand Down Expand Up @@ -45,18 +44,15 @@ class ConcatArrayConfig(ArrayConfig):
def array(self, mode: str = "r") -> Array:
arrays = [config.array(mode) for _, config in self.source_array_configs.items()]

out_data = da.stack([array.data for array in arrays], axis=0)
out_array = Array(
da.zeros(len(arrays), *arrays[0].physical_shape, dtype=arrays[0].dtype),
out_data,
offset=arrays[0].offset,
voxel_size=arrays[0].voxel_size,
axis_names=["c^"] + arrays[0].axis_names,
units=arrays[0].units,
)

def set_channels(data):
for i, array in enumerate(arrays):
data[i] = array.data[:]
return data

out_array.lazy_op(set_channels)
# callable lazy op so funlib.persistence doesn't try to recoginize this data as writable
out_array.lazy_op(lambda data: data)
return out_array
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ def array(self, mode: str = "r") -> Array:
assert num_channels_from_array(array) is not None

out_array = Array(
da.zeros(*array.physical_shape, dtype=array.dtype),
da.zeros(array.physical_shape, dtype=array.dtype),
offset=array.offset,
voxel_size=array.voxel_size,
axis_names=array.axis_names[1:],
units=array.units,
)

out_array.data = da.maximum(array.data, axis=0)
out_array.data = da.max(array.data, axis=0)

# mark data as non-writable
out_array.lazy_op(lambda data: data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@
from funlib.geometry import Coordinate
from funlib.persistence import Array

from xarray_multiscale.multiscale import downscale_dask
from xarray_multiscale import windowed_mean
import numpy as np
import dask.array as da

from typing import Sequence


def adjust_shape(array: da.Array, scale_factors: Sequence[int]) -> da.Array:
"""
Crop array to a shape that is a multiple of the scale factors.
This allows for clean downsampling.
"""
misalignment = np.any(np.mod(array.shape, scale_factors))
if misalignment:
new_shape = np.subtract(array.shape, np.mod(array.shape, scale_factors))
slices = tuple(slice(0, s) for s in new_shape)
array = array[slices]
return array


@attr.s
class ResampledArrayConfig(ArrayConfig):
Expand Down Expand Up @@ -37,7 +57,27 @@ class ResampledArrayConfig(ArrayConfig):
metadata={"help_text": "The order of the interpolation!"}
)

def preprocess(self, array: Array) -> Array:
"""
Preprocess an array by resampling it to the desired voxel size.
"""
if self.downsample is not None:
downsample = Coordinate(self.downsample)
return Array(
data=downscale_dask(
adjust_shape(array.data, downsample),
windowed_mean,
scale_factors=downsample,
),
offset=array.offset,
voxel_size=array.voxel_size * downsample,
axis_names=array.axis_names,
units=array.units,
)
elif self.upsample is not None:
raise NotImplementedError("Upsampling not yet implemented")

def array(self, mode: str = "r") -> Array:
# This is non trivial. We want to upsample or downsample the source
# array lazily. Not entirely sure how to do this with dask arrays.
raise NotImplementedError()
source_array = self.source_array_config.array(mode)

return self.preprocess(source_array)
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dacapo.experiments.arraytypes.probabilities import ProbabilityArray
from .predictor import Predictor
from dacapo.experiments import Model
from dacapo.experiments.arraytypes import DistanceArray
from dacapo.tmp import np_to_funlib_array
from dacapo.utils.balance_weights import balance_weights

Expand Down Expand Up @@ -394,6 +393,7 @@ def __find_boundaries(self, labels):
# bound.: 00000001000100000001000 2n - 1

logger.debug(f"computing boundaries for {labels.shape}")
labels = labels.astype(np.uint8)

dims = len(labels.shape)
in_shape = labels.shape
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ dependencies = [
"upath",
"boto3",
"matplotlib",
"xarray-multiscale",
]

# extras
Expand Down Expand Up @@ -201,6 +202,7 @@ module = [
"napari.*",
"empanada.*",
"IPython.*",
"xarray_multiscale.*"
]
ignore_missing_imports = true

Expand Down
1 change: 0 additions & 1 deletion tests/components/test_gp_arraysource.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,5 @@ def test_gp_dacapo_array_source(array_config):
batch = source_node.request_batch(request)
data = batch[key].data
if data.dtype == bool:
raise ValueError("Data should not be bools")
data = data.astype(np.uint8)
assert (data - array[array.roi]).sum() == 0
32 changes: 32 additions & 0 deletions tests/components/test_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from dacapo.experiments.datasplits.datasets.arrays.resampled_array_config import (
ResampledArrayConfig,
)

import numpy as np
from funlib.persistence import Array
from funlib.geometry import Coordinate


def test_resample():
# test downsampling arrays with shape 10 and 11 by a factor of 2 to test croping works
for top in [11, 12]:
arr = Array(np.array(np.arange(1, top)), offset=(0,), voxel_size=(3,))
resample_config = ResampledArrayConfig(
"test_resample", None, upsample=None, downsample=(2,), interp_order=1
)
resampled = resample_config.preprocess(arr)
assert resampled.voxel_size == Coordinate((6,))
assert resampled.shape == (5,)
assert np.allclose(resampled[:], np.array([1.5, 3.5, 5.5, 7.5, 9.5]))

# test 2D array
arr = Array(
np.array(np.arange(1, 11).reshape(5, 2).T), offset=(0, 0), voxel_size=(3, 3)
)
resample_config = ResampledArrayConfig(
"test_resample", None, upsample=None, downsample=(2, 1), interp_order=1
)
resampled = resample_config.preprocess(arr)
assert resampled.voxel_size == Coordinate(6, 3)
assert resampled.shape == (1, 5)
assert np.allclose(resampled[:], np.array([[1.5, 3.5, 5.5, 7.5, 9.5]]))

0 comments on commit 0b3e2d0

Please sign in to comment.