diff --git a/test_autolens/point/model/test_andrew_implementation.py b/test_autolens/point/model/test_andrew_implementation.py index f09bd6a77..a84c6c45b 100644 --- a/test_autolens/point/model/test_andrew_implementation.py +++ b/test_autolens/point/model/test_andrew_implementation.py @@ -1,3 +1,10 @@ +try: + import jax + + JAX_INSTALLED = True +except ImportError: + JAX_INSTALLED = False + import numpy as np import pytest @@ -14,10 +21,8 @@ def noise_map(): return np.array([1.0, 1.0]) -def test_andrew_implementation( - data, - noise_map, -): +@pytest.fixture +def fit(data, noise_map): model_positions = np.array( [ (-1.0749, -1.1), @@ -25,12 +30,14 @@ def test_andrew_implementation( ] ) - fit = Fit( + return Fit( data=data, noise_map=noise_map, model_positions=model_positions, ) + +def test_andrew_implementation(fit): assert np.allclose( fit.all_permutations_log_likelihoods(), [ @@ -41,6 +48,11 @@ def test_andrew_implementation( assert fit.log_likelihood() == -4.40375330990644 +@pytest.mark.skipif(not JAX_INSTALLED, reason="JAX is not installed") +def test_jax(fit): + assert jax.jit(fit.log_likelihood)() == -4.40375330990644 + + def test_nan_model_positions( data, noise_map,