From e6419042fd707bb8286641ac9d89878ebc5fee69 Mon Sep 17 00:00:00 2001 From: CKrawczyk Date: Fri, 25 Oct 2024 13:14:02 +0100 Subject: [PATCH] Wrap class as PyTree Needed to make the `autolens.Tracer` example work. --- autoarray/operators/over_sampling/uniform.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/autoarray/operators/over_sampling/uniform.py b/autoarray/operators/over_sampling/uniform.py index cd075178..5c8652b0 100644 --- a/autoarray/operators/over_sampling/uniform.py +++ b/autoarray/operators/over_sampling/uniform.py @@ -14,7 +14,10 @@ from autoarray import exc from autoarray.operators.over_sampling import over_sample_util +from autofit.jax_wrapper import register_pytree_node_class + +@register_pytree_node_class class OverSamplingUniform(AbstractOverSampling): def __init__(self, sub_size: Union[int, Array2D]): """ @@ -319,6 +322,15 @@ def over_sampler_from(self, mask: Mask2D) -> "OverSamplerUniform": mask=mask, sub_size=self.sub_size, ) + + def tree_flatten(self): + children = (self.sub_size,) + aux_data = None + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(*children) class OverSamplerUniform(AbstractOverSampler):