diff --git a/src/awkward/_meta/numpymeta.py b/src/awkward/_meta/numpymeta.py index 8ffa1597e5..830519401e 100644 --- a/src/awkward/_meta/numpymeta.py +++ b/src/awkward/_meta/numpymeta.py @@ -2,6 +2,8 @@ from __future__ import annotations +from functools import cached_property + from awkward._meta.meta import Meta from awkward._nplikes.shape import ShapeItem from awkward._typing import JSONSerializable @@ -31,13 +33,13 @@ def purelist_depth(self) -> int: def is_identity_like(self) -> bool: return False - @property - def minmax_depth(self) -> tuple[int, int]: + @cached_property + def minmax_depth(self) -> tuple[int, int]: # type: ignore[override] depth = len(self.inner_shape) + 1 return (depth, depth) - @property - def branch_depth(self) -> tuple[bool, int]: + @cached_property + def branch_depth(self) -> tuple[bool, int]: # type: ignore[override] return (False, len(self.inner_shape) + 1) @property diff --git a/src/awkward/_meta/recordmeta.py b/src/awkward/_meta/recordmeta.py index 8e43215cf5..966cfa2ca2 100644 --- a/src/awkward/_meta/recordmeta.py +++ b/src/awkward/_meta/recordmeta.py @@ -2,6 +2,8 @@ from __future__ import annotations +from functools import cached_property + from awkward._meta.meta import Meta from awkward._regularize import is_integer from awkward._typing import Generic, JSONSerializable, TypeVar @@ -39,19 +41,21 @@ def purelist_depth(self) -> int: def is_identity_like(self) -> bool: return False - @property - def minmax_depth(self) -> tuple[int, int]: + @cached_property + def minmax_depth(self) -> tuple[int, int]: # type: ignore[override] if len(self._contents) == 0: return (1, 1) - mins, maxs = [], [] - for content in self._contents: - mindepth, maxdepth = content.minmax_depth - mins.append(mindepth) - maxs.append(maxdepth) - return (min(mins), max(maxs)) - - @property - def branch_depth(self) -> tuple[bool, int]: + mindepth, maxdepth = self._contents[0].minmax_depth + for content in self._contents[1:]: + mindepth_, maxdepth_ = content.minmax_depth + if mindepth_ < mindepth: + mindepth = mindepth_ + if maxdepth_ > maxdepth: + maxdepth = maxdepth_ + return (mindepth, maxdepth) + + @cached_property + def branch_depth(self) -> tuple[bool, int]: # type: ignore[override] if len(self._contents) == 0: return False, 1 @@ -80,8 +84,8 @@ def is_leaf(self) -> bool: # type: ignore[override] def contents(self) -> list[T]: return self._contents - @property - def fields(self) -> list[str]: + @cached_property + def fields(self) -> list[str]: # type: ignore[override] if self._fields is None: return [str(i) for i in range(len(self._contents))] else: diff --git a/src/awkward/_meta/unionmeta.py b/src/awkward/_meta/unionmeta.py index eefce689b2..60061aee41 100644 --- a/src/awkward/_meta/unionmeta.py +++ b/src/awkward/_meta/unionmeta.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections import Counter +from functools import cached_property from awkward._meta.meta import Meta from awkward._typing import Generic, JSONSerializable, TypeVar @@ -31,15 +32,15 @@ def purelist_parameters(self, *keys: str) -> JSONSerializable: return None - @property - def purelist_isregular(self) -> bool: + @cached_property + def purelist_isregular(self) -> bool: # type: ignore[override] for content in self._contents: if not content.purelist_isregular: return False return True - @property - def purelist_depth(self) -> int: + @cached_property + def purelist_depth(self) -> int: # type: ignore[override] out = None for content in self._contents: if out is None: @@ -53,19 +54,21 @@ def purelist_depth(self) -> int: def is_identity_like(self) -> bool: return False - @property - def minmax_depth(self) -> tuple[int, int]: + @cached_property + def minmax_depth(self) -> tuple[int, int]: # type: ignore[override] if len(self._contents) == 0: return (0, 0) - mins, maxs = [], [] - for content in self._contents: - mindepth, maxdepth = content.minmax_depth - mins.append(mindepth) - maxs.append(maxdepth) - return (min(mins), max(maxs)) - - @property - def branch_depth(self) -> tuple[bool, int]: + mindepth, maxdepth = self._contents[0].minmax_depth + for content in self._contents[1:]: + mindepth_, maxdepth_ = content.minmax_depth + if mindepth_ < mindepth: + mindepth = mindepth_ + if maxdepth_ > maxdepth: + maxdepth = maxdepth_ + return (mindepth, maxdepth) + + @cached_property + def branch_depth(self) -> tuple[bool, int]: # type: ignore[override] if len(self._contents) == 0: return False, 1 @@ -83,8 +86,8 @@ def branch_depth(self) -> tuple[bool, int]: assert min_depth is not None return any_branch, min_depth - @property - def fields(self) -> list[str]: + @cached_property + def fields(self) -> list[str]: # type: ignore[override] field_counts = Counter([f for c in self._contents for f in c.fields]) return [f for f, n in field_counts.items() if n == len(self._contents)] @@ -102,6 +105,6 @@ def dimension_optiontype(self) -> bool: def content(self, index: int) -> T: return self._contents[index] - @property + @cached_property def contents(self) -> list[T]: return self._contents diff --git a/src/awkward/_nplikes/shape.py b/src/awkward/_nplikes/shape.py index 9020c88157..bbf148a8a8 100644 --- a/src/awkward/_nplikes/shape.py +++ b/src/awkward/_nplikes/shape.py @@ -63,6 +63,9 @@ def __str__(self) -> str: def __repr__(self): return self._instance_name + def __hash__(self): + return hash(self._instance_name) + def __eq__(self, other) -> bool: if other is self: return True diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index c4ad914c85..3aadf57ecd 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -3,6 +3,7 @@ from __future__ import annotations from collections.abc import Collection, Iterator, Sequence, Set +from functools import lru_cache from numbers import Number from typing import Callable @@ -55,7 +56,9 @@ def is_unknown_scalar(array: Any) -> TypeGuard[TypeTracerArray]: def is_unknown_integer(array: Any) -> TypeGuard[TypeTracerArray]: - return is_unknown_scalar(array) and np.issubdtype(array.dtype, np.integer) + return cast( + bool, is_unknown_scalar(array) and np.issubdtype(array.dtype, np.integer) + ) def is_unknown_array(array: Any) -> TypeGuard[TypeTracerArray]: @@ -1147,38 +1150,7 @@ def derive_slice_for_length( return start, stop, step, self.index_as_shape_item(slice_length) def broadcast_shapes(self, *shapes: tuple[ShapeItem, ...]) -> tuple[ShapeItem, ...]: - ndim = max((len(s) for s in shapes), default=0) - result: list[ShapeItem] = [1] * ndim - - for shape in shapes: - # Right broadcasting - missing_dim = ndim - len(shape) - if missing_dim > 0: - head: tuple[int, ...] = (1,) * missing_dim - shape = head + shape - - # Fail if we absolutely know the shapes aren't compatible - for i, item in enumerate(shape): - # Item is unknown, take it - if is_unknown_length(item): - result[i] = item - # Existing item is unknown, keep it - elif is_unknown_length(result[i]): - continue - # Items match, continue - elif result[i] == item: - continue - # Item is broadcastable, take existing - elif item == 1: - continue - # Existing is broadcastable, take it - elif result[i] == 1: - result[i] = item - else: - raise ValueError( - "known component of shape does not match broadcast result" - ) - return tuple(result) + return _broadcast_shapes(*shapes) def broadcast_arrays(self, *arrays: TypeTracerArray) -> list[TypeTracerArray]: for x in arrays: @@ -1706,6 +1678,42 @@ def __dlpack__(self, stream=None): raise NotImplementedError +@lru_cache +def _broadcast_shapes(*shapes): + ndim = max((len(s) for s in shapes), default=0) + result: list[ShapeItem] = [1] * ndim + + for shape in shapes: + # Right broadcasting + missing_dim = ndim - len(shape) + if missing_dim > 0: + head: tuple[int, ...] = (1,) * missing_dim + shape = head + shape + + # Fail if we absolutely know the shapes aren't compatible + for i, item in enumerate(shape): + # Item is unknown, take it + if is_unknown_length(item): + result[i] = item + # Existing item is unknown, keep it + elif is_unknown_length(result[i]): + continue + # Items match, continue + elif result[i] == item: + continue + # Item is broadcastable, take existing + elif item == 1: + continue + # Existing is broadcastable, take it + elif result[i] == 1: + result[i] = item + else: + raise ValueError( + "known component of shape does not match broadcast result" + ) + return tuple(result) + + def _attach_report( layout: Content, form: Form, diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index ef7019bf68..b77c9db9df 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -8,6 +8,7 @@ from collections import defaultdict from collections.abc import Callable, Iterable, Mapping from fnmatch import fnmatchcase +from functools import lru_cache from glob import escape as escape_glob import awkward as ak @@ -202,6 +203,7 @@ def from_dict(input: Mapping) -> Form: ) +@lru_cache def from_json(input: str) -> Form: return from_dict(json.loads(input)) diff --git a/src/awkward/types/numpytype.py b/src/awkward/types/numpytype.py index b366d628d9..4a35c5f256 100644 --- a/src/awkward/types/numpytype.py +++ b/src/awkward/types/numpytype.py @@ -5,6 +5,7 @@ import json import re from collections.abc import Mapping +from functools import lru_cache from awkward._behavior import find_array_typestr from awkward._nplikes.numpy_like import NumpyMetadata @@ -25,6 +26,7 @@ def is_primitive(primitive): return primitive in _primitive_to_dtype_dict +@lru_cache def primitive_to_dtype(primitive): if _primitive_to_dtype_datetime.match(primitive) is not None: return np.dtype(primitive) @@ -42,6 +44,7 @@ def primitive_to_dtype(primitive): return out +@lru_cache def dtype_to_primitive(dtype): if dtype.kind.upper() == "M" and dtype == dtype.newbyteorder("="): return str(dtype)