diff --git a/tests/test_dask.py b/tests/test_dask.py index fa719dbd..ce69d503 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,3 +1,5 @@ +from contextlib import contextmanager + import dask import numpy as np import pytest @@ -12,11 +14,11 @@ def xp(): return array_namespace(da.empty(0)) -@pytest.fixture -def no_compute(): +@contextmanager +def assert_no_compute(): """ - Cause the test to raise if at any point anything calls compute() or persist(), - e.g. as it can be triggered implicitly by __bool__, __array__, etc. + Context manager that raises if at any point inside it anything calls compute() + or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc. """ def get(dsk, *args, **kwargs): raise AssertionError("Called compute() or persist()") @@ -25,59 +27,68 @@ def get(dsk, *args, **kwargs): yield -def test_no_compute(no_compute): - """Test the no_compute fixture""" +def test_no_compute(): + """Test the no_compute context manager""" a = da.asarray(True) with pytest.raises(AssertionError, match="Called compute"): - bool(a) + with assert_no_compute(): + bool(a) + + # Exiting the context manager restores the original scheduler + assert bool(a) is True # Test no_compute for functions that use generic _aliases with xp=np -def test_unary_ops_no_compute(xp, no_compute): - a = xp.asarray([1.5, -1.5]) - xp.ceil(a) - xp.floor(a) - xp.trunc(a) - xp.sign(a) +def test_unary_ops_no_compute(xp): + with assert_no_compute(): + a = xp.asarray([1.5, -1.5]) + xp.ceil(a) + xp.floor(a) + xp.trunc(a) + xp.sign(a) -def test_matmul_tensordot_no_compute(xp, no_compute): +def test_matmul_tensordot_no_compute(xp): A = da.ones((4, 4), chunks=2) B = da.zeros((4, 4), chunks=2) - xp.matmul(A, B) - xp.tensordot(A, B) + with assert_no_compute(): + xp.matmul(A, B) + xp.tensordot(A, B) # Test no_compute for functions that are fully bespoke for dask -def test_asarray_no_compute(xp, no_compute): - a = xp.arange(10) - xp.asarray(a) - xp.asarray(a, dtype=np.int16) - xp.asarray(a, dtype=a.dtype) - xp.asarray(a, copy=True) - xp.asarray(a, copy=True, dtype=np.int16) - xp.asarray(a, copy=True, dtype=a.dtype) - xp.asarray(a, copy=False) - xp.asarray(a, copy=False, dtype=a.dtype) +def test_asarray_no_compute(xp): + with assert_no_compute(): + a = xp.arange(10) + xp.asarray(a) + xp.asarray(a, dtype=np.int16) + xp.asarray(a, dtype=a.dtype) + xp.asarray(a, copy=True) + xp.asarray(a, copy=True, dtype=np.int16) + xp.asarray(a, copy=True, dtype=a.dtype) + xp.asarray(a, copy=False) + xp.asarray(a, copy=False, dtype=a.dtype) @pytest.mark.parametrize("copy", [True, False]) -def test_astype_no_compute(xp, no_compute, copy): - a = xp.arange(10) - xp.astype(a, np.int16, copy=copy) - xp.astype(a, a.dtype, copy=copy) +def test_astype_no_compute(xp, copy): + with assert_no_compute(): + a = xp.arange(10) + xp.astype(a, np.int16, copy=copy) + xp.astype(a, a.dtype, copy=copy) -def test_clip_no_compute(xp, no_compute): - a = xp.arange(10) - xp.clip(a) - xp.clip(a, 1) - xp.clip(a, 1, 8) +def test_clip_no_compute(xp): + with assert_no_compute(): + a = xp.arange(10) + xp.clip(a) + xp.clip(a, 1) + xp.clip(a, 1, 8) -def test_generators_are_lazy(xp, no_compute): +def test_generators_are_lazy(xp): """ Test that generator functions are fully lazy, e.g. that da.ones(n) is not implemented as da.asarray(np.ones(n)) @@ -85,12 +96,13 @@ def test_generators_are_lazy(xp, no_compute): size = 100_000_000_000 # 800 GB chunks = size // 10 # 10x 80 GB chunks - xp.zeros(size, chunks=chunks) - xp.ones(size, chunks=chunks) - xp.empty(size, chunks=chunks) - xp.full(size, fill_value=123, chunks=chunks) - a = xp.arange(size, chunks=chunks) - xp.zeros_like(a) - xp.ones_like(a) - xp.empty_like(a) - xp.full_like(a, fill_value=123) + with assert_no_compute(): + xp.zeros(size, chunks=chunks) + xp.ones(size, chunks=chunks) + xp.empty(size, chunks=chunks) + xp.full(size, fill_value=123, chunks=chunks) + a = xp.arange(size, chunks=chunks) + xp.zeros_like(a) + xp.ones_like(a) + xp.empty_like(a) + xp.full_like(a, fill_value=123)