Skip to content

Commit

Permalink
fix JAX imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Jammy2211 committed Dec 18, 2024
1 parent e7a66fe commit d03ab4c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
22 changes: 11 additions & 11 deletions autoarray/mask/abstract_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ def pixel_scale(self) -> float:
For a mask with dimensions two or above check that are pixel scales are the same, and if so return this
single value as a float.
"""
for pixel_scale in self.pixel_scales:
if abs(pixel_scale - self.pixel_scales[0]) > 1.0e-8:
logger.warning(
f"""
The Mask has different pixel scales in each dimensions, which are {self.pixel_scales}.
This is not expected, and will lead to unexpected behaviour in the grid and mask classes.
The code will continue to run, but you should check that the pixel scales are as you expect and
that associated data structures (e.g. grids) are behaving as you expect.
"""
)
# for pixel_scale in self.pixel_scales:
# if abs(pixel_scale - self.pixel_scales[0]) > 1.0e-8:
# logger.warning(
# f"""
# The Mask has different pixel scales in each dimensions, which are {self.pixel_scales}.
#
# This is not expected, and will lead to unexpected behaviour in the grid and mask classes.
# The code will continue to run, but you should check that the pixel scales are as you expect and
# that associated data structures (e.g. grids) are behaving as you expect.
# """
# )

return self.pixel_scales[0]

Expand Down
5 changes: 4 additions & 1 deletion autoarray/structures/triangles/array/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from .array import ArrayTriangles
from .jax_array import ArrayTriangles as JAXArrayTriangles
try:
from .jax_array import ArrayTriangles as JAXArrayTriangles
except ImportError:
pass
9 changes: 6 additions & 3 deletions autoarray/structures/triangles/coordinate_array/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from .coordinate_array import CoordinateArrayTriangles
from .jax_coordinate_array import (
CoordinateArrayTriangles as JAXCoordinateArrayTriangles,
)
try:
from .jax_coordinate_array import (
CoordinateArrayTriangles as JAXCoordinateArrayTriangles,
)
except ImportError:
pass

0 comments on commit d03ab4c

Please sign in to comment.