Skip to content

Commit

Permalink
Merge pull request #158 from ickc/feature/jax_wrapper-mask_2d_circula…
Browse files Browse the repository at this point in the history
…r_from

fix jit on mask_2d_circular_from
  • Loading branch information
Jammy2211 authored Jan 22, 2025
2 parents 311f7cd + ba31c1c commit 40be8b1
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions autoarray/mask/mask_2d_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def total_pixels_2d_from(mask_2d: np.ndarray) -> int:
return total_regular_pixels


@numba_util.jit()
@numba_util.jit(static_argnums=0)
def mask_2d_circular_from(
shape_native: Tuple[int, int],
pixel_scales: ty.PixelScales,
Expand Down Expand Up @@ -114,24 +114,12 @@ def mask_2d_circular_from(
mask = mask_circular_from(
shape=(10, 10), pixel_scales=0.1, radius=0.5, centre=(0.0, 0.0))
"""

mask_2d = np.full(shape_native, True)

centres_scaled = mask_2d_centres_from(
shape_native=mask_2d.shape, pixel_scales=pixel_scales, centre=centre
)

for y in range(mask_2d.shape[0]):
for x in range(mask_2d.shape[1]):
y_scaled = (y - centres_scaled[0]) * pixel_scales[0]
x_scaled = (x - centres_scaled[1]) * pixel_scales[1]

r_scaled = np.sqrt(x_scaled**2 + y_scaled**2)

if r_scaled <= radius:
mask_2d[y, x] = False

return mask_2d
centres_scaled = mask_2d_centres_from(shape_native, pixel_scales, centre)
ys, xs = np.indices(shape_native)
return (radius * radius) < (
np.square((ys - centres_scaled[0]) * pixel_scales[0]) +
np.square((xs - centres_scaled[1]) * pixel_scales[1])
)


@numba_util.jit()
Expand Down

0 comments on commit 40be8b1

Please sign in to comment.