Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
Co-Authored-By: Guido Imperiale <6213168+crusaderky@users.noreply.github.com>
  • Loading branch information
lithomas1 and crusaderky committed Jan 15, 2025
1 parent fef8607 commit 5aa1333
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 18 deletions.
22 changes: 17 additions & 5 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,25 @@
isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)

def astype(x: Array, dtype: Dtype, /, *, copy: bool = True) -> Array:
def astype(
x: Array,
dtype: Dtype,
/,
*,
copy: bool = True,
device: Device | None = None
) -> Array:
if device is not None:
raise NotImplementedError(
"The device keyword is not implemented yet for "
"array-api-compat wrapped dask"
)
if not copy and dtype == x.dtype:
return x
# dask astype doesn't respect copy=True so copy
# manually via numpy
x = np.array(x, dtype=dtype, copy=copy)
return da.from_array(x)
# dask astype doesn't respect copy=True,
# so call copy manually afterwards
x = x.astype(dtype)
return x.copy() if copy else x

# Common aliases

Expand Down
13 changes: 0 additions & 13 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,3 @@ def test_asarray_copy(library):
assert all(b[0] == 1.0)
else:
assert all(b[0] == 0.0)

@pytest.mark.parametrize("library", wrapped_libraries)
def test_astype_copy(library):
# array-api-tests currently doesn't check copy=True
# makes a copy when dtypes are the same
# so we check that here
xp = import_(library, wrapper=True)
a = xp.asarray([1])
b = xp.astype(a, a.dtype, copy=True)

a[0] = 10

assert b[0] == 1

0 comments on commit 5aa1333

Please sign in to comment.