Skip to content

Commit

Permalink
first iteration of all and any
Browse files Browse the repository at this point in the history
  • Loading branch information
jkanche committed May 24, 2024
1 parent 268d8dc commit ddd6957
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 2 deletions.
83 changes: 81 additions & 2 deletions src/delayedarray/DelayedArray.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Sequence, Tuple, Union, Optional, List, Callable
import numpy
from numpy import array, dtype, integer, issubdtype, ndarray, prod, array2string
from numpy import dtype, ndarray, array2string
from collections import namedtuple

from .SparseNdarray import SparseNdarray
Expand All @@ -24,7 +24,7 @@

from ._subset import _getitem_subset_preserves_dimensions, _getitem_subset_discards_dimensions, _repr_subset
from ._isometric import translate_ufunc_to_op_simple, translate_ufunc_to_op_with_args
from ._statistics import array_mean, array_var, array_sum, _create_offset_multipliers
from ._statistics import array_mean, array_var, array_sum, _create_offset_multipliers, array_any, array_all

__author__ = "ltla"
__copyright__ = "ltla"
Expand Down Expand Up @@ -264,6 +264,12 @@ def __array_function__(self, func, types, args, kwargs) -> "DelayedArray":
if func == numpy.var:
return self.var(**kwargs)

if func == numpy.any:
return self.any(**kwargs)

if func == numpy.all:
return self.all(**kwargs)

if func == numpy.shape:
return self.shape

Expand Down Expand Up @@ -901,6 +907,79 @@ def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optiona
masked=is_masked(self),
)

def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[numpy.dtype] = None, buffer_size: int = 1e8) -> numpy.ndarray:
"""Test whether any array element along a given axis evaluates to True.
Compute this test across the ``DelayedArray``, possibly over a
given axis or set of axes. If the seed has a ``any()`` method, that
method is called directly with the supplied arguments.
Args:
axis:
A single integer specifying the axis over which to calculate
the mean. Alternatively, a tuple (multiple axes) or None (no
axes), see :py:func:`~numpy.mean` for details.
dtype:
NumPy type for the output array. If None, this is automatically
chosen based on the type of the ``DelayedArray``, see
:py:func:`~numpy.mean` for details.
buffer_size:
Buffer size in bytes to use for block processing. Larger values
generally improve speed at the cost of memory.
Returns:
A NumPy array containing the boolean values. If ``axis = None``, this will
be a NumPy scalar instead.
"""
if hasattr(self._seed, "any"):
return self._seed.any(axis=axis, dtype=dtype)
else:
return array_any(
self,
axis=axis,
dtype=dtype,
reduce_over_x=lambda x, axes, op : _reduce(x, axes, op, buffer_size),
masked=is_masked(self),
)

def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[numpy.dtype] = None, buffer_size: int = 1e8) -> numpy.ndarray:
"""Test whether all array elements along a given axis evaluate to True.
Compute this test across the ``DelayedArray``, possibly over a
given axis or set of axes. If the seed has a ``any()`` method, that
method is called directly with the supplied arguments.
Args:
axis:
A single integer specifying the axis over which to calculate
the mean. Alternatively, a tuple (multiple axes) or None (no
axes), see :py:func:`~numpy.mean` for details.
dtype:
NumPy type for the output array. If None, this is automatically
chosen based on the type of the ``DelayedArray``, see
:py:func:`~numpy.mean` for details.
buffer_size:
Buffer size in bytes to use for block processing. Larger values
generally improve speed at the cost of memory.
Returns:
A NumPy array containing the boolean values. If ``axis = None``, this will
be a NumPy scalar instead.
"""
if hasattr(self._seed, "all"):
return self._seed.all(axis=axis, dtype=dtype)
else:
return array_all(
self,
axis=axis,
dtype=dtype,
reduce_over_x=lambda x, axes, op : _reduce(x, axes, op, buffer_size),
masked=is_masked(self),
)

@extract_dense_array.register
def extract_dense_array_DelayedArray(x: DelayedArray, subset: Tuple[Sequence[int], ...]) -> numpy.ndarray:
Expand Down
58 changes: 58 additions & 0 deletions src/delayedarray/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,61 @@ def op(offset, value):
return sumsq[0]
else:
return sumsq


def array_any(x, axis: Optional[Union[int, Tuple[int, ...]]], dtype: Optional[numpy.dtype], reduce_over_x: Callable, masked: bool) -> numpy.ndarray:
axes = _find_useful_axes(len(x.shape), axis)
if dtype is None:
dtype = _choose_output_type(x.dtype, preserve_integer = True)
output = _allocate_output_array(x.shape, axes, dtype)
buffer = output.ravel(order="F")

if masked:
masked = numpy.zeros(output.shape, dtype=numpy.uint, order="F")
mask_buffer = masked.ravel(order="F")
def op(offset, value):
if value is not numpy.ma.masked:
buffer[offset] = numpy.any(value)
else:
mask_buffer[offset] = True
reduce_over_x(x, axes, op)
size = _expected_sample_size(x.shape, axes)
output = numpy.ma.MaskedArray(output, mask=(masked == size))
else:
def op(offset, value):
buffer[offset] = numpy.any(value)
reduce_over_x(x, axes, op)

if len(axes) == 0:
return output[0]
else:
return output


def array_all(x, axis: Optional[Union[int, Tuple[int, ...]]], dtype: Optional[numpy.dtype], reduce_over_x: Callable, masked: bool) -> numpy.ndarray:
axes = _find_useful_axes(len(x.shape), axis)
if dtype is None:
dtype = _choose_output_type(x.dtype, preserve_integer = True)
output = _allocate_output_array(x.shape, axes, dtype)
buffer = output.ravel(order="F")

if masked:
masked = numpy.zeros(output.shape, dtype=numpy.uint, order="F")
mask_buffer = masked.ravel(order="F")
def op(offset, value):
if value is not numpy.ma.masked:
buffer[offset] = numpy.all(value)
else:
mask_buffer[offset] = True
reduce_over_x(x, axes, op)
size = _expected_sample_size(x.shape, axes)
output = numpy.ma.MaskedArray(output, mask=(masked == size))
else:
def op(offset, value):
buffer[offset] = numpy.all(value)
reduce_over_x(x, axes, op)

if len(axes) == 0:
return output[0]
else:
return output

0 comments on commit ddd6957

Please sign in to comment.