Skip to content

Commit

Permalink
DNM ENH: More lazy functions
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 17, 2025
1 parent ea68ed8 commit c77c4f1
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 46 deletions.
3 changes: 3 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
expand_dims
kron
lazy_apply
lazy_raise
lazy_wait_on
lazy_warn
nunique
pad
setdiff1d
Expand Down
86 changes: 43 additions & 43 deletions pixi.lock

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["array-api-compat>=1.10.0,<2"]
# DNM https://github.com/data-apis/array-api-compat/pull/237
dependencies = ["array-api-compat @ git+https://github.com/crusaderky/array-api-compat.git@0c1b7ddf9186e4e719afe41675e0a111fa482257"]
# dependencies = ["array-api-compat>=1.10.0,<2"]

[project.optional-dependencies]
tests = [
Expand Down Expand Up @@ -54,6 +56,9 @@ Changelog = "https://github.com/data-apis/array-api-extra/releases"
[tool.hatch]
version.path = "src/array_api_extra/__init__.py"

# DNM
[tool.hatch.metadata]
allow-direct-references = true # Enable git dependencies

# Pixi

Expand All @@ -63,7 +68,8 @@ platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10,<3.14"
array-api-compat = ">=1.10.0,<2"
# DNM
# array-api-compat = ">=1.10.0,<2"

[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }
Expand Down
302 changes: 301 additions & 1 deletion src/array_api_extra/_lib/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@
from __future__ import annotations

import math
import warnings
from collections.abc import Callable, Sequence
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, Any, cast, overload

from ._utils._compat import array_namespace, is_dask_namespace, is_jax_namespace
from ._utils._compat import (
array_namespace,
is_dask_namespace,
is_jax_namespace,
is_lazy_array,
)
from ._utils._typing import Array, DType

if TYPE_CHECKING:
Expand Down Expand Up @@ -319,3 +325,297 @@ def wrapper( # type: ignore[no-any-decorated,no-any-explicit]
return (xp.asarray(out),)

return wrapper


def lazy_raise(  # numpydoc ignore=SA04
    x: Array,
    cond: bool | Array,
    exc: Exception,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Raise if an eager check fails on a lazy array.

    Consider this snippet::

        >>> def f(x, xp):
        ...     if xp.any(x < 0):
        ...         raise ValueError("Some points are negative")
        ...     return x + 1

    The above code fails to compile when x is a JAX array and the function is wrapped
    by `jax.jit`; it is also extremely slow on Dask. Other lazy backends, e.g. ndonnx,
    are also expected to misbehave.

    `xp.any(x < 0)` is a 0-dimensional array with `dtype=bool`; the `if` statement calls
    `bool()` on the Array to convert it to a Python bool.

    On eager backends such as NumPy, this is not a problem. On Dask, `bool()` implicitly
    triggers a computation of the whole graph so far; what's worse is that the
    intermediate results are discarded to optimize memory usage, so when later on the
    user explicitly calls `compute()` on their final output, `x` is recalculated from
    scratch. On JAX, `bool()` raises if the calling code is wrapped by `jax.jit`,
    because the value is not available yet at trace time.

    You should rewrite the above code as follows::

        >>> def f(x, xp):
        ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
        ...     return x + 1

    When `xp` is eager, this is equivalent to the original code; if the error condition
    resolves to True, the function raises immediately and the next line `return x + 1`
    is never executed.

    When `xp` is lazy, the function always returns a lazy array. When eventually the
    user actually computes it, e.g. in Dask by calling `compute()` and in JAX by having
    their outermost function decorated with `@jax.jit` return, only then the error
    condition is evaluated. If True, the exception is raised and propagated as normal,
    and the following nodes of the graph are never executed (so if the health check was
    in place to prevent not only incorrect results but e.g. a segmentation fault, it's
    still going to achieve its purpose).

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    cond : bool | Array
        Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.
        If True, raise the exception. If False, return x.
    exc : Exception
        The exception instance to be raised.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is
        altered to raise `exc` if `cond` is True.

    Raises
    ------
    type(exc)
        If `cond` evaluates to True.
    ValueError
        If `cond` is a lazy array that is not 0-dimensional.

    Warnings
    --------
    The return value of `lazy_raise` must be used to build the output of your
    function. The following function raises when x is eager, but quietly skips the
    check when x is lazy, because the array returned by `lazy_raise` is discarded::

        >>> def f(x, xp):
        ...     lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
        ...     return x + 1

    And so does this one, as lazy_raise replaces `x` but it does so too late to
    contribute to the return value::

        >>> def f(x, xp):
        ...     y = x + 1
        ...     x = lazy_raise(x, xp.any(x < 0), ValueError("Some points are negative"))
        ...     return y

    See Also
    --------
    lazy_apply
    lazy_warn
    lazy_wait_on
    dask.graph_manipulation.wait_on

    Notes
    -----
    This function will raise if the :doc:`jax:transfer_guard` is active and `cond` is
    a JAX array on a non-CPU device.
    """

    def _lazy_raise(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_raise` running inside the lazy graph."""
        # By the time this runs `cond` is concrete, so truth-testing it is safe.
        if cond:
            raise exc
        return x

    return _lazy_wait_on_impl(x, cond, _lazy_raise, xp=xp)


# Signature of warnings.warn copied from python/typeshed
@overload
def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08
    x: Array,
    cond: bool | Array,
    message: str,
    category: type[Warning] | None = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array: ...
@overload
def lazy_warn(  # type: ignore[no-any-explicit,no-any-decorated]  # numpydoc ignore=GL08
    x: Array,
    cond: bool | Array,
    message: Warning,
    category: Any = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array: ...


def lazy_warn(  # type: ignore[no-any-explicit]  # numpydoc ignore=SA04,PR04
    x: Array,
    cond: bool | Array,
    message: str | Warning,
    category: Any = None,
    stacklevel: int = 1,
    source: Any | None = None,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Call `warnings.warn` if an eager check fails on a lazy array.

    This function works in the same way as `lazy_raise`; refer to it
    for the detailed explanation.

    You should replace::

        >>> def f(x, xp):
        ...     if xp.any(x < 0):
        ...         warnings.warn("Some points are negative", UserWarning, stacklevel=2)
        ...     return x + 1

    with::

        >>> def f(x, xp):
        ...     x = lazy_warn(x, xp.any(x < 0),
        ...                   "Some points are negative", UserWarning, stacklevel=2)
        ...     return x + 1

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    cond : bool | Array
        Must be either a plain Python bool or a 0-dimensional Array with boolean dtype.
        If True, issue the warning. `x` is returned either way.
    message, category, stacklevel, source :
        Parameters to `warnings.warn`. `stacklevel` is automatically increased to
        compensate for the extra wrapper functions, so that on eager backends it
        has the same meaning as if `warnings.warn` had been called directly by the
        caller of `lazy_warn`.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `cond` are lazy arrays, the graph underlying `x` is
        altered to issue the warning if `cond` is True.

    Raises
    ------
    ValueError
        If `cond` is a lazy array that is not 0-dimensional.

    See Also
    --------
    warnings.warn
    lazy_apply
    lazy_raise
    lazy_wait_on
    dask.graph_manipulation.wait_on

    Notes
    -----
    On Dask, the warning is typically going to appear on the log of the
    worker executing the function instead of on the client.
    """

    def _lazy_warn(x: Array, cond: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_warn` running inside the lazy graph."""
        if cond:
            # On the eager path three frames sit between the user's code and
            # this call: _lazy_warn, _lazy_wait_on_impl and lazy_warn.
            # Skip all of them so that stacklevel=1 points at the caller of
            # lazy_warn, exactly like a direct warnings.warn() call would.
            warnings.warn(message, category, stacklevel=stacklevel + 3, source=source)
        return x

    return _lazy_wait_on_impl(x, cond, _lazy_warn, xp=xp)


def lazy_wait_on(
    x: Array, wait_on: object, *, xp: ModuleType | None = None
) -> Array:  # numpydoc ignore=SA04
    """
    Pause materialization of `x` until `wait_on` has been materialized.

    This is typically used to collect multiple calls to `lazy_raise` and/or
    `lazy_warn` from validation functions that would otherwise return None.

    If `wait_on` is not a lazy array, just return `x`.

    Read `lazy_raise` for detailed explanation.

    Parameters
    ----------
    x : Array
        Any one Array, potentially lazy, that is used later on to produce the value
        returned by your function.
    wait_on : object
        Any object. If it's a lazy array, it must be 0-dimensional; the
        materialization of `x` is blocked until `wait_on` has been fully
        materialized.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    Array
        `x`. If both `x` and `wait_on` are lazy arrays, the graph
        underlying `x` is altered to wait until `wait_on` has been materialized.
        If `wait_on` raises, the exception is propagated to `x`.

    Raises
    ------
    ValueError
        If `wait_on` is a lazy array that is not 0-dimensional.

    See Also
    --------
    lazy_apply
    lazy_raise
    lazy_warn
    dask.graph_manipulation.wait_on

    Examples
    --------
    ::

        def validate(x, xp):
            # Future that evaluates the checks. Contents are inconsequential.
            # It must be 0-dimensional to be accepted by lazy_wait_on.
            # Avoid zero-sized arrays, as they may be elided by the graph optimizer.
            future = xp.empty(())
            future = lazy_raise(future, xp.any(x < 10), ValueError("Less than 10"))
            future = lazy_warn(future, xp.any(x > 20), "More than 20", UserWarning)
            return future

        def f(x, xp):
            x = lazy_wait_on(x, validate(x, xp), xp=xp)
            return x + 1
    """

    def _lazy_wait_on(x: Array, _: Array) -> Array:  # numpydoc ignore=PR01,RT01
        """Eager helper of `lazy_wait_on` running inside the lazy graph."""
        # The second argument exists only to create the graph dependency.
        return x

    return _lazy_wait_on_impl(x, wait_on, _lazy_wait_on, xp=xp)


def _lazy_wait_on_impl(  # numpydoc ignore=PR01,RT01
    x: Array,
    wait_on: object,
    eager_func: Callable[[Array, Array], Array],
    xp: ModuleType | None,
) -> Array:
    """
    Shared backend of `lazy_raise`, `lazy_warn`, and `lazy_wait_on`.

    Run `eager_func(x, wait_on)` straight away when `wait_on` is not a lazy
    array; otherwise schedule it inside the lazy graph so that the returned
    array depends on `wait_on`.
    """
    if is_lazy_array(wait_on):
        dependency = cast(Array, wait_on)
        if dependency.shape != ():
            msg = "cond/wait_on must be 0-dimensional"
            raise ValueError(msg)

        namespace = array_namespace(x, dependency) if xp is None else xp

        if not is_dask_namespace(namespace):
            return lazy_apply(
                eager_func,
                x,
                dependency,
                shape=x.shape,
                dtype=x.dtype,
                xp=namespace,
            )

        # Dask goes through map_blocks instead, as lazy_apply would rechunk x
        return namespace.map_blocks(
            eager_func,
            x,
            dependency,
            dtype=x.dtype,
            meta=x._meta,  # pylint: disable=protected-access
        )

    # Nothing to wait for: evaluate the check immediately
    return eager_func(x, wait_on)
3 changes: 3 additions & 0 deletions src/array_api_extra/_lib/_utils/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_lazy_array,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
Expand All @@ -26,6 +27,7 @@
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_lazy_array,
is_numpy_namespace,
is_pydata_sparse_namespace,
is_torch_namespace,
Expand All @@ -41,6 +43,7 @@
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_lazy_array",
"is_numpy_namespace",
"is_pydata_sparse_namespace",
"is_torch_namespace",
Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
def is_jax_array(x: object, /) -> bool: ...
def is_lazy_array(x: object, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
def size(x: Array, /) -> int | None: ...
Loading

0 comments on commit c77c4f1

Please sign in to comment.