Skip to content

Commit

Permalink
try to test with jax
Browse files Browse the repository at this point in the history
  • Loading branch information
rhayes777 committed Nov 8, 2024
1 parent 2684069 commit 9a29132
Showing 1 changed file with 51 additions and 29 deletions.
80 changes: 51 additions & 29 deletions test_autoarray/structures/triangles/test_coordinate_jax.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
from autoarray.numpy_wrapper import jit
import pytest
import numpy as np

from autoarray.structures.triangles.jax_coordinate_array import CoordinateArrayTriangles
from autoarray.structures.triangles.shape import Point

try:
from jax import numpy as np
import jax

jax.config.update("jax_log_compiles", True)
from autoarray.structures.triangles.jax_coordinate_array import (
CoordinateArrayTriangles,
)
except ImportError:
import numpy as np
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles


@pytest.fixture
def one_triangle():
Expand Down Expand Up @@ -32,41 +42,53 @@ def test_full_routine(one_triangle, plot):
def test_neighborhood(one_triangle):
assert np.all(
np.array(jit(one_triangle.neighborhood)().triangles)
== [
[
[-0.5, -0.4330126941204071],
[-1.0, 0.4330126941204071],
[0.0, 0.4330126941204071],
],
== np.array(
[
[0.0, -1.299038052558899],
[-0.5, -0.4330126941204071],
[0.5, -0.4330126941204071],
],
[
[0.0, 0.4330126941204071],
[0.5, -0.4330126941204071],
[-0.5, -0.4330126941204071],
],
[
[0.5, -0.4330126941204071],
[0.0, 0.4330126941204071],
[1.0, 0.4330126941204071],
],
]
[
[-0.5, -0.4330126941204071],
[-1.0, 0.4330126941204071],
[0.0, 0.4330126941204071],
],
[
[0.0, -1.299038052558899],
[-0.5, -0.4330126941204071],
[0.5, -0.4330126941204071],
],
[
[0.0, 0.4330126941204071],
[0.5, -0.4330126941204071],
[-0.5, -0.4330126941204071],
],
[
[0.5, -0.4330126941204071],
[0.0, 0.4330126941204071],
[1.0, 0.4330126941204071],
],
]
)
)


def test_up_sample(one_triangle):
up_sampled = jit(one_triangle.up_sample)()
assert np.all(
np.array(up_sampled.triangles)
== [
== np.array(
[
[[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]],
[[0.25, 0.0], [0.5, -0.4330126941204071], [0.0, -0.4330126941204071]],
[[-0.25, 0.0], [0.0, -0.4330126941204071], [-0.5, -0.4330126941204071]],
[[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]],
[
[[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]],
[
[0.25, 0.0],
[0.5, -0.4330126941204071],
[0.0, -0.4330126941204071],
],
[
[-0.25, 0.0],
[0.0, -0.4330126941204071],
[-0.5, -0.4330126941204071],
],
[[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]],
]
]
]
)
)

0 comments on commit 9a29132

Please sign in to comment.