From 9ba7197d44d31918d779ef732a8e0eb800adc752 Mon Sep 17 00:00:00 2001 From: Richard Date: Fri, 29 Nov 2024 15:32:31 +0000 Subject: [PATCH] more int32 --- .../coordinate_array/jax_coordinate_array.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py b/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py index 89ac2447..6e66e555 100644 --- a/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py +++ b/autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py @@ -134,10 +134,14 @@ def up_sample(self) -> "CoordinateArrayTriangles": n = coordinates.shape[0] - shift0 = np.zeros((n, 2)) - shift3 = np.tile(np.array([0, 1]), (n, 1)) - shift1 = np.stack([np.ones(n), np.where(flip_mask, 1.0, 0.0)], axis=1) - shift2 = np.stack([-np.ones(n), np.where(flip_mask, 1.0, 0.0)], axis=1) + shift0 = np.zeros((n, 2), dtype=np.int32) + shift3 = np.tile(np.array([0, 1], dtype=np.int32), (n, 1)) + shift1 = np.stack( + [np.ones(n, dtype=np.int32), np.where(flip_mask, 1, 0)], axis=1 + ) + shift2 = np.stack( + [-np.ones(n, dtype=np.int32), np.where(flip_mask, 1, 0)], axis=1 + ) shifts = np.stack([shift0, shift1, shift2, shift3], axis=1) coordinates_expanded = coordinates[:, None, :] @@ -164,14 +168,14 @@ def neighborhood(self) -> "CoordinateArrayTriangles": coordinates = self.coordinates flip_mask = self.flip_mask - shift0 = np.zeros((coordinates.shape[0], 2)) - shift1 = np.tile(np.array([1, 0]), (coordinates.shape[0], 1)) - shift2 = np.tile(np.array([-1, 0]), (coordinates.shape[0], 1)) + shift0 = np.zeros((coordinates.shape[0], 2), dtype=np.int32) + shift1 = np.tile(np.array([1, 0], dtype=np.int32), (coordinates.shape[0], 1)) + shift2 = np.tile(np.array([-1, 0], dtype=np.int32), (coordinates.shape[0], 1)) shift3 = np.where( flip_mask[:, None], - np.tile(np.array([0, 1]), (coordinates.shape[0], 1)), - np.tile(np.array([0, -1]), (coordinates.shape[0], 1)), + np.tile(np.array([0, 1], dtype=np.int32), (coordinates.shape[0], 1)), + np.tile(np.array([0, -1], dtype=np.int32), (coordinates.shape[0], 1)), ) shifts = np.stack([shift0, shift1, shift2, shift3], axis=1)