Skip to content

Commit

Permalink
Run more tests on array-api-strict and sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jan 11, 2025
1 parent e5dd419 commit ec28ffd
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 23 deletions.
10 changes: 3 additions & 7 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,13 +631,9 @@ def device(x: Array, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the jax array to determine type
try:
import numpy as np
if isinstance(x._meta, np.ndarray):
# Must be on CPU since backed by numpy
return "cpu"
except ImportError:
pass
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
elif is_jax_array(x):
# JAX has .device() as a method, but it is being deprecated so that it
Expand Down
5 changes: 5 additions & 0 deletions docs/supported-array-libraries.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,8 @@ The minimum supported Dask version is 2023.12.0.
## [Sparse](https://sparse.pydata.org/en/stable/)

Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

(array-api-strict-support)=
## [array-api-strict](https://data-apis.org/array-api-strict/)

array-api-strict exists only to test support for the Array API, so it does not need any wrappers.
10 changes: 2 additions & 8 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
from importlib import import_module
import sys

import pytest

wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
all_libraries = wrapped_libraries + ["jax.numpy"]
all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]

# `sparse` added array API support as of Python 3.10.
if sys.version_info >= (3, 10):
all_libraries.append('sparse')

def import_(library, wrapper=False):
if library == 'cupy':
Expand All @@ -20,9 +16,7 @@ def import_(library, wrapper=False):
jax_numpy = import_module("jax.numpy")
if not hasattr(jax_numpy, "__array_api_version__"):
library = 'jax.experimental.array_api'
elif library.startswith('sparse'):
library = 'sparse'
else:
elif library in wrapped_libraries:
library = 'array_api_compat.' + library

return import_module(library)
5 changes: 4 additions & 1 deletion tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
def test_all(library):
import_(library, wrapper=True)
if library == "common":
import array_api_compat.common # noqa: F401
else:
import_(library, wrapper=True)

for mod_name in sys.modules:
if not mod_name.startswith('array_api_compat.' + library):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_array_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

@pytest.mark.parametrize("use_compat", [True, False, None])
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"])
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
@pytest.mark.parametrize("library", all_libraries)
def test_array_namespace(library, api_version, use_compat):
xp = import_(library)

array = xp.asarray([1.0, 2.0, 3.0])
if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
if use_compat and library not in wrapped_libraries:
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
return
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
Expand Down
24 changes: 19 additions & 5 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
is_dask_array, is_jax_array, is_pydata_sparse_array,
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
is_array_api_strict_namespace,
)

from array_api_compat import (
Expand Down Expand Up @@ -31,6 +32,7 @@
'dask.array': 'is_dask_namespace',
'jax.numpy': 'is_jax_namespace',
'sparse': 'is_pydata_sparse_namespace',
'array_api_strict': 'is_array_api_strict_namespace',
}


Expand Down Expand Up @@ -72,7 +74,12 @@ def test_xp_is_array_generics(library):
is_func = globals()[func]
if is_func(x0):
matches.append(library2)
assert matches in ([library], ["numpy"])

if library == "array_api_strict":
# There is no is_array_api_strict_array() function
assert matches == []
else:
assert matches in ([library], ["numpy"])


@pytest.mark.parametrize("library", all_libraries)
Expand Down Expand Up @@ -151,26 +158,33 @@ def test_to_device_host(library):
@pytest.mark.parametrize("target_library", is_array_functions.keys())
@pytest.mark.parametrize("source_library", is_array_functions.keys())
def test_asarray_cross_library(source_library, target_library, request):
if source_library == "dask.array" and target_library == "torch":
def _xfail(reason: str) -> None:
# Allow rest of test to execute instead of immediately xfailing
# xref https://github.com/pandas-dev/pandas/issues/38902
request.node.add_marker(pytest.mark.xfail(reason=reason))

if source_library == "dask.array" and target_library == "torch":
# TODO: remove xfail once
# https://github.com/dask/dask/issues/8260 is resolved
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
if source_library == "cupy" and target_library != "cupy":
_xfail(reason="Bug in dask raising error on conversion")
elif source_library == "jax.numpy" and target_library == "torch":
_xfail(reason="casts int to float")
elif source_library == "cupy" and target_library != "cupy":
# cupy explicitly disallows implicit conversions to CPU
pytest.skip(reason="cupy does not support implicit conversion to CPU")
elif source_library == "sparse" and target_library != "sparse":
pytest.skip(reason="`sparse` does not allow implicit densification")

src_lib = import_(source_library, wrapper=True)
tgt_lib = import_(target_library, wrapper=True)
is_tgt_type = globals()[is_array_functions[target_library]]

a = src_lib.asarray([1, 2, 3])
a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32)
b = tgt_lib.asarray(a)

assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
assert b.dtype == tgt_lib.int32


@pytest.mark.parametrize("library", wrapped_libraries)
def test_asarray_copy(library):
Expand Down

0 comments on commit ec28ffd

Please sign in to comment.