From 64f511d4f839a0751483424a5b30053af850bdb8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 01:50:30 +0000 Subject: [PATCH 1/9] chore(deps): bump codecov/codecov-action from 4 to 5 Bumps [codecov/codecov-action](https://github.com/codecov/codecov-action) from 4 to 5. - [Release notes](https://github.com/codecov/codecov-action/releases) - [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md) - [Commits](https://github.com/codecov/codecov-action/compare/v4...v5) --- updated-dependencies: - dependency-name: codecov/codecov-action dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7451bef5..414f5a39 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -73,7 +73,7 @@ jobs: - name: Upload coverage reports to Codecov with GitHub Action if: ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest'}} - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} with: From cb771f277057b7d1930a2fb991f5083cd2a101be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 17:54:53 +0000 Subject: [PATCH 2/9] chore: update pre-commit hooks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.2 → v0.7.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.2...v0.7.4) --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04875535..9422d7e8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,7 +51,7 @@ repos: args: [--prose-wrap=always] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.7.2" + rev: "v0.7.4" hooks: - id: ruff args: ["--fix", "--show-fixes"] From 48cf6c117f82868df5b5e5716d88240306845e42 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:33:29 -0500 Subject: [PATCH 3/9] chore(deps): bump pypa/gh-action-pypi-publish from 1.11.0 to 1.12.2 (#284) Bumps [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) from 1.11.0 to 1.12.2. - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.11.0...v1.12.2) --- updated-dependencies: - dependency-name: pypa/gh-action-pypi-publish dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Connor Stone, PhD --- .github/workflows/cd.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 6237f0eb..1777a364 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -57,7 +57,7 @@ jobs: ls -ltrh ls -ltrh dist - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.11.0 + uses: pypa/gh-action-pypi-publish@v1.12.2 with: repository-url: https://test.pypi.org/legacy/ verbose: true @@ -95,5 +95,5 @@ jobs: name: artifact path: dist - - uses: pypa/gh-action-pypi-publish@v1.11.0 + - uses: pypa/gh-action-pypi-publish@v1.12.2 if: startsWith(github.ref, 'refs/tags') From e3633045a06be97f13209cb14d28e2a2376a6f73 Mon Sep 17 00:00:00 2001 From: "Connor Stone, PhD" Date: Thu, 21 Nov 2024 07:42:06 -0800 Subject: [PATCH 4/9] refactor: Change Parametrized system over to caskade (#275) * first pass convert to caskade * clear pydantic, some tests pass * all tests pass * params by args * fix requirements * add param metadata * get tutorials to work * adding batched plane * get tutorials to run * fix qso fit example * less hacky batched lens model * with math expressions * add chunk_size option in batched plane * shift order of parameters in lenssource * minor changes from comments. better batch plane, mass sheet kappa * jacobian of deflection angle now far more memory efficient at small compute cost * make tests pass for NFW TNFW * lenssource now handles larger grids * add padding mode for pixelated light source --- .../source/examples/Example_ImageFit_LM.ipynb | 8 +- docs/source/examples/Example_QSOLensFit.ipynb | 21 +- .../tutorials/InterfaceIntroduction_oop.ipynb | 14 +- .../InterfaceIntroduction_yaml.ipynb | 10 +- docs/source/tutorials/Introduction.ipynb | 83 +-- .../source/tutorials/InvertLensEquation.ipynb | 2 +- docs/source/tutorials/LensZoo.ipynb | 59 +- docs/source/tutorials/Microlensing.ipynb | 38 +- docs/source/tutorials/MultiplaneDemo.ipynb | 8 +- docs/source/tutorials/Parameters.ipynb | 242 ------- docs/source/tutorials/VisualizeCaustics.ipynb | 15 +- requirements.txt | 1 + src/caustics/__init__.py | 12 +- src/caustics/cosmology/FlatLambdaCDM.py | 64 +- src/caustics/cosmology/base.py | 76 +-- src/caustics/func.py | 4 + src/caustics/io.py | 126 ---- src/caustics/lenses/__init__.py | 2 + src/caustics/lenses/base.py | 311 ++++----- src/caustics/lenses/batchedplane.py | 211 ++++++ src/caustics/lenses/enclosed_mass.py | 57 +- src/caustics/lenses/epl.py | 71 +- src/caustics/lenses/external_shear.py | 49 +- src/caustics/lenses/func/__init__.py | 10 +- src/caustics/lenses/func/mass_sheet.py | 28 +- src/caustics/lenses/func/point.py | 75 +++ src/caustics/lenses/mass_sheet.py | 60 +- src/caustics/lenses/multiplane.py | 92 +-- src/caustics/lenses/multipole.py | 62 +- src/caustics/lenses/nfw.py | 105 ++- src/caustics/lenses/pixelated_convergence.py | 47 +- src/caustics/lenses/pixelated_potential.py | 47 +- src/caustics/lenses/point.py | 100 ++- src/caustics/lenses/pseudo_jaffe.py | 119 ++-- src/caustics/lenses/sie.py | 63 +- src/caustics/lenses/singleplane.py | 28 +- src/caustics/lenses/sis.py | 45 +- src/caustics/lenses/tnfw.py | 167 ++--- src/caustics/light/base.py | 20 +- src/caustics/light/light_stack.py | 13 +- src/caustics/light/pixelated.py | 30 +- src/caustics/light/pixelated_time.py | 20 +- src/caustics/light/sersic.py | 38 +- src/caustics/light/star_source.py 
| 28 +- src/caustics/models/__init__.py | 0 src/caustics/models/api.py | 37 -- src/caustics/models/base_models.py | 90 --- src/caustics/models/registry.py | 127 ---- src/caustics/models/utils.py | 285 -------- src/caustics/namespace_dict.py | 193 ------ src/caustics/packed.py | 9 - src/caustics/parameter.py | 102 --- src/caustics/parametrized.py | 610 ------------------ src/caustics/sims/__init__.py | 4 +- src/caustics/sims/lens_source.py | 57 +- src/caustics/sims/microlens.py | 23 +- src/caustics/sims/simulator.py | 131 ++-- src/caustics/sims/state_dict.py | 310 --------- src/caustics/tests.py | 18 +- src/caustics/utils.py | 141 +++- tests/conftest.py | 25 +- tests/models/test_mod_api.py | 236 ------- tests/models/test_mod_registry.py | 97 --- tests/models/test_mod_utils.py | 118 ---- tests/sims/conftest.py | 40 -- tests/sims/test_simulator.py | 92 --- tests/sims/test_state_dict.py | 189 ------ tests/test_batchedplane.py | 47 ++ tests/test_batching.py | 79 --- tests/test_epl.py | 10 +- tests/test_external_shear.py | 17 +- tests/test_io.py | 63 -- tests/test_lens_potential.py | 4 +- tests/test_masssheet.py | 11 +- tests/test_multiplane.py | 83 +-- tests/test_multipole.py | 14 +- tests/test_namespace_dict.py | 96 --- tests/test_nfw.py | 24 +- tests/test_parameter.py | 41 -- tests/test_parametrized.py | 283 -------- tests/test_point.py | 14 +- tests/test_pseudo_jaffe.py | 11 +- tests/test_sersic.py | 11 +- tests/test_sie.py | 10 +- tests/test_simulator_runs.py | 42 +- tests/test_sis.py | 18 +- tests/test_tnfw.py | 15 +- tests/utils/__init__.py | 233 ++----- tests/utils/models.py | 177 ----- 89 files changed, 1648 insertions(+), 5270 deletions(-) delete mode 100644 docs/source/tutorials/Parameters.ipynb delete mode 100644 src/caustics/io.py create mode 100644 src/caustics/lenses/batchedplane.py delete mode 100644 src/caustics/models/__init__.py delete mode 100644 src/caustics/models/api.py delete mode 100644 src/caustics/models/base_models.py delete mode 100644 src/caustics/models/registry.py delete mode 100644 src/caustics/models/utils.py delete mode 100644 src/caustics/namespace_dict.py delete mode 100644 src/caustics/packed.py delete mode 100644 src/caustics/parameter.py delete mode 100644 src/caustics/parametrized.py delete mode 100644 src/caustics/sims/state_dict.py delete mode 100644 tests/models/test_mod_api.py delete mode 100644 tests/models/test_mod_registry.py delete mode 100644 tests/models/test_mod_utils.py delete mode 100644 tests/sims/conftest.py delete mode 100644 tests/sims/test_simulator.py delete mode 100644 tests/sims/test_state_dict.py create mode 100644 tests/test_batchedplane.py delete mode 100644 tests/test_batching.py delete mode 100644 tests/test_io.py delete mode 100644 tests/test_namespace_dict.py delete mode 100644 tests/test_parameter.py delete mode 100644 tests/test_parametrized.py delete mode 100644 tests/utils/models.py diff --git a/docs/source/examples/Example_ImageFit_LM.ipynb b/docs/source/examples/Example_ImageFit_LM.ipynb index d06aa40d..717b5029 100644 --- a/docs/source/examples/Example_ImageFit_LM.ipynb +++ b/docs/source/examples/Example_ImageFit_LM.ipynb @@ -461,6 +461,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -470,7 +475,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git 
a/docs/source/examples/Example_QSOLensFit.ipynb b/docs/source/examples/Example_QSOLensFit.ipynb index c9febffa..15783f25 100644 --- a/docs/source/examples/Example_QSOLensFit.ipynb +++ b/docs/source/examples/Example_QSOLensFit.ipynb @@ -75,10 +75,15 @@ "sp_x = torch.tensor(0.2)\n", "sp_y = torch.tensor(0.2)\n", "\n", - "# true parameters x0 y0 q phi b\n", + "# true parameters x0 y0 q phi b\n", "params = torch.tensor([0.0, 0.0, 0.4, np.pi / 5, 1.0])\n", "# Points in image plane\n", - "x, y = lens.forward_raytrace(sp_x, sp_y, z_s, params)" + "x, y = lens.forward_raytrace(sp_x, sp_y, z_s, params)\n", + "# get magnifications\n", + "mu = lens.magnification(x, y, z_s, params)\n", + "# remove heavily demagnified points\n", + "x = x[mu > 1e-2]\n", + "y = y[mu > 1e-2]" ] }, { @@ -233,7 +238,7 @@ "detA = torch.linalg.det(A)\n", "\n", "CS = ax.contour(\n", - " thx, thy, detA, levels=[0.0], colors=\"green\", linestyles=\"dashed\", zorder=1\n", + " thx, thy, detA, levels=[0.0], colors=\"orange\", linestyles=\"dashed\", zorder=1\n", ")\n", "# Get the path from the matplotlib contour plot of the critical line\n", "paths = CS.allsegs[0]\n", @@ -246,7 +251,7 @@ " y1, y2 = lens.raytrace(x1, x2, z_s, params)\n", "\n", " # Plot the caustic\n", - " ax.plot(y1, y2, color=\"orange\", linestyle=\"--\", label=\"Fit\", zorder=1)\n", + " ax.plot(y1, y2, color=\"green\", linestyle=\"--\", label=\"Fit\", zorder=1)\n", "plt.legend()\n", "plt.show()" ] @@ -261,6 +266,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -270,7 +280,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/InterfaceIntroduction_oop.ipynb b/docs/source/tutorials/InterfaceIntroduction_oop.ipynb index dbb251a4..b2f4e639 100644 --- a/docs/source/tutorials/InterfaceIntroduction_oop.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_oop.ipynb @@ -141,7 +141,7 @@ "source": [ "# Print out the order of model parameters\n", "# Note that parameters with values are \"static\" so they don't need to be provided by you\n", - "sim.state_dict()" + "print(sim)" ] }, { @@ -240,7 +240,7 @@ "outputs": [], "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", - "labels = tuple(sim.state_dict().keys())[3:]\n", + "labels = tuple(p.name for p in sim.dynamic_params)\n", "for i, ax in enumerate(axarr.flatten()):\n", " ax.imshow(J[..., i], origin=\"lower\")\n", " ax.set_title(labels[i])\n", @@ -265,7 +265,7 @@ "outputs": [], "source": [ "# Substitute sim with sim for the yaml method\n", - "sim.graph(True, True)" + "sim.graphviz()" ] }, { @@ -301,6 +301,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -310,7 +315,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb b/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb index d262ea12..ca90620f 100644 --- a/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb +++ b/docs/source/tutorials/InterfaceIntroduction_yaml.ipynb @@ -76,7 +76,7 @@ "source": [ "# Print out the order of model 
parameters\n", "# Note that parameters with values are \"static\" so they don't need to be provided by you\n", - "sim.state_dict()" + "print(sim)" ] }, { @@ -177,7 +177,7 @@ "outputs": [], "source": [ "fig, axarr = plt.subplots(3, 7, figsize=(20, 9))\n", - "labels = tuple(sim.state_dict().keys())[3:]\n", + "labels = tuple(p.name for p in sim.dynamic_params)\n", "for i, ax in enumerate(axarr.flatten()):\n", " ax.imshow(J[..., i], origin=\"lower\")\n", " ax.set_title(labels[i])\n", @@ -202,7 +202,7 @@ "outputs": [], "source": [ "# The simulator is internally represented as a directed acyclic graph of operations\n", - "sim.graph(True, True)" + "sim.graphviz()" ] }, { @@ -384,7 +384,7 @@ ], "metadata": { "kernelspec": { - "display_name": "caustic", + "display_name": "PY39", "language": "python", "name": "python3" }, @@ -398,7 +398,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/Introduction.ipynb b/docs/source/tutorials/Introduction.ipynb index 00fda241..570d85dd 100644 --- a/docs/source/tutorials/Introduction.ipynb +++ b/docs/source/tutorials/Introduction.ipynb @@ -170,7 +170,7 @@ "metadata": {}, "outputs": [], "source": [ - "simulator.graph(True, True)" + "simulator.graphviz()" ] }, { @@ -200,7 +200,7 @@ "source": [ "# Making a parameter static\n", "simulator.z_s = 1.0\n", - "simulator.graph(False, True) # z_s turns grey" + "simulator.graphviz() # z_s turns grey" ] }, { @@ -211,9 +211,7 @@ "source": [ "# Making a parameter dynamic\n", "simulator.z_s = None\n", - "simulator.graph(\n", - " False, True\n", - ") # z_s turns white, which makes it disappear when we don't show the dynamic parameters (first option False)" + "simulator.graphviz() # z_s turns white, which makes it disappear when we don't show the dynamic parameters (first option False)" ] }, { @@ -315,7 +313,7 @@ "metadata": {}, "source": [ "### Flattened Tensor\n", - "To make sure the order of the parameter is correct, print the simulator. Order of dynamic parameters is shown in the `x_order` field" + "To make sure the order of the parameter is correct, print the simulator. Order of dynamic parameters is read top to bottom." ] }, { @@ -324,7 +322,7 @@ "metadata": {}, "outputs": [], "source": [ - "simulator" + "print(simulator)" ] }, { @@ -367,7 +365,7 @@ "simulator.lens.x0 = None\n", "simulator.lens.cosmology.h0 = None\n", "\n", - "simulator" + "print(simulator)" ] }, { @@ -377,9 +375,9 @@ "outputs": [], "source": [ "B = 5\n", + "cosmo_params = torch.rand(B, 1) # h0\n", "lens_params = torch.randn(B, 2) # x0 and b\n", "source_params = torch.rand(B, 1) # Ie\n", - "cosmo_params = torch.rand(B, 1) # h0\n", "\n", "x = [lens_params, cosmo_params, source_params]\n", "ys = vmap(simulator)(x)" @@ -552,7 +550,7 @@ ")\n", "simulator = LensSource(lens, source, pixelscale=pixelscale, pixels_x=pixels, z_s=z_s)\n", "\n", - "simulator.graph(True, True)" + "simulator.graphviz()" ] }, { @@ -731,19 +729,21 @@ "source": [ "## Creating your own Simulator\n", "\n", - "Here, we only introduce the general design principles to create a Simulator. Worked examples can be found in [this notebook](./Simulators.ipynb). \n", + "Here, we only introduce the general design principles to create a simulator. More comprehensive explanations can be found in the [caskade docs](https://caskade.readthedocs.io/). 
\n", "\n", "### A Simulator is very much like a neural network in Pytorch\n", - "A simulator inherits from the super class `Simulator`, similar to how a neural network inherits from the `nn.Module` class in `Pytorch`\n", + "A simulator inherits from the super class `caskade.Module`, similar to how a neural network inherits from the `nn.Module` class in `Pytorch`\n", "\n", "```python\n", - "from caustics import Simulator\n", + "from caustics import Module, Param, forward\n", "\n", - "class MySim(Simulator):\n", + "class MySim(Module):\n", " def __init__(self):\n", " super().__init__()\n", + " self.p = Param(\"p\")\n", "\n", - " def forward(self, x):\n", + " @forward\n", + " def myfunction(self, x, p):\n", " ...\n", "```\n", "\n", @@ -766,23 +766,25 @@ "\n", "### How to feed parameters to the different modules\n", "\n", - "This is probably the easiest part of building a Simulator. Just feed `x` at the end of each method. And you are done. \n", + "This is probably the easiest part of building a Simulator, you only provide the values when calling the top level simulator.\n", " \n", - " Here is a minimal example that shows how to feed the parameters `x` to different modules in the `forward` method\n", + " Here is a minimal example that shows how to feed the parameters the `forward` method\n", " ```python\n", - " def forward(self, x):\n", - " alpha_x, alpha_y = self.lens.reduced_deflection_angle(self.theta_x, self.theta_y, self.z_s, x)\n", - " beta_x = self.theta_x - alpha_x # lens equation\n", - " ...\n", - " lensed_image = self.source.brightness(beta_x, beta_y, x)\n", + " @forward\n", + " def raytrace(self, x, y):\n", + " alpha_x, alpha_y = self.lens.reduced_deflection_angle(x, y, self.z_s)\n", + " beta_x = x - alpha_x # lens equation\n", + " beta_y = y - alpha_y\n", + " return self.source.brightness(beta_x, beta_y)\n", + "sim.raytrace(xgrid, ygrid, params)\n", " ``` \n", "\n", - "You might worry that `x` can have a relatively complex structure (flattened tensor, semantic lict, low-level dictionary). \n", + "You might worry that `params` can have a relatively complex structure (flattened tensor, semantic list, low-level dictionary). \n", "`caustics` handles this complexity for you. \n", - "You only need to make sure that `x` contains all the **dynamic** parameters required by your custom simulator. \n", - "This design works for every `caustics` module and each of their methods, meaning that `x` is always the last argument in a `caustics` method call signature. \n", + "You only need to make sure that `params` contains all the **dynamic** parameters required by your custom simulator. \n", + "This design works for every `caustics` module and each of their methods, meaning that `params` is always the last argument in a `caustics` method call signature. \n", " \n", - "The only details that you need to handle explicitly in your own simulator are stuff like the camera pixel position (`theta_x` and `theta_y`), and source redshifts (`z_s`). Those are often constructed in the `__init__` method because they can be assumed fixed. Thus, the example above assumed that they can be retrieved from the `self` registry. A Simulator is often an abstraction of an instrument with many fixed variables to describe it, or aimed at a specific observation. \n", + "The only details that you need to handle explicitly in your own simulator are stuff like the camera pixel position (`xgrid` and `ygrid`), and source redshifts (`z_s`). Those are often constructed in the `__init__` method because they can be assumed fixed. 
Thus, the example above assumed that they can be retrieved from the `self` registry. A Simulator is often an abstraction of an instrument with many fixed variables to describe it, or aimed at a specific observation. \n", "\n", "Of course, you could have more complex workflows for which this assumption is not true. For example, you might want to infer the PSF parameters of your instrument and need to feed this to the simulator as a dynamic parameter. \n", "The next section has what you need to customize completely your simulator\n", @@ -798,36 +800,43 @@ "metadata": {}, "outputs": [], "source": [ - "from caustics import Simulator\n", + "from caustics import Module, forward, Param\n", "\n", "\n", - "class MySim(Simulator):\n", + "class MySim(Module):\n", " def __init__(self):\n", " super().__init__() # Don't forget to use super!!\n", " # shape has to be a tuple, e.g. shape=(1,). This can be any shape you need.\n", - " self.add_param(\n", + " self.my_dynamic_arg = Param(\n", " \"my_dynamic_arg\", value=None, shape=(1,)\n", " ) # register a dynamic parameter in the DAG\n", - " self.add_param(\n", - " \"my_static_arg\", value=1.0, shape=(1,)\n", + " self.my_static_arg = Param(\n", + " \"my_static_arg\", value=1.0, shape=()\n", " ) # register a static parameter in the DAG\n", "\n", - " def forward(self, x):\n", - " my_arg1, my_arg2 = self.unpack(x) # retrieve your arguments\n", + " @forward\n", + " def forward(self, x, my_dynamic_arg, my_static_arg):\n", "\n", " # My very complex workflow\n", " ...\n", - " return my_arg1 + my_arg2\n", + " return my_dynamic_arg * x + my_static_arg\n", "\n", "\n", "sim = MySim()\n", - "sim.graph(True, True)" + "sim.graphviz()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "caustic", + "display_name": "PY39", "language": "python", "name": "python3" }, @@ -841,7 +850,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/InvertLensEquation.ipynb b/docs/source/tutorials/InvertLensEquation.ipynb index edd6dca9..df764dc4 100644 --- a/docs/source/tutorials/InvertLensEquation.ipynb +++ b/docs/source/tutorials/InvertLensEquation.ipynb @@ -261,7 +261,7 @@ " # Plot the caustic\n", " ax.plot(y1, y2, color=\"r\", zorder=1)\n", "ax.imshow(\n", - " sim_wide({}),\n", + " sim_wide(),\n", " origin=\"lower\",\n", " extent=(\n", " -sim_wide.pixelscale * sim_wide.pixels_x / 2,\n", diff --git a/docs/source/tutorials/LensZoo.ipynb b/docs/source/tutorials/LensZoo.ipynb index 51d9cdc5..6afdc588 100644 --- a/docs/source/tutorials/LensZoo.ipynb +++ b/docs/source/tutorials/LensZoo.ipynb @@ -86,6 +86,7 @@ " x0=0.0,\n", " y0=0.0,\n", " th_ein=1.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -112,7 +113,7 @@ "# axarr[0].imshow(np.log10(convergence.numpy()), origin = \"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"Point Convergence not defined\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -140,6 +141,7 @@ " x0=0.0,\n", " y0=0.0,\n", " th_ein=1.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -164,12 +166,12 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", 
"convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"SIS Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -199,6 +201,7 @@ " q=0.6,\n", " phi=np.pi / 2,\n", " b=1.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -223,12 +226,12 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"SIE Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -259,6 +262,7 @@ " phi=np.pi / 2,\n", " b=1.0,\n", " t=0.5,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -283,12 +287,12 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"EPL Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -319,6 +323,7 @@ " y0=0.0,\n", " m=1e13,\n", " c=20.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -343,12 +348,12 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"NFW Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -382,6 +387,7 @@ " mass=1e12,\n", " scale_radius=1.0,\n", " tau=3.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -406,12 +412,12 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", 
"axarr[0].set_title(\"Truncated NFW Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -445,6 +451,7 @@ " mass=1e13,\n", " core_radius=5e-1,\n", " scale_radius=15.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -469,12 +476,12 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"Pseudo Jaffe Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -503,6 +510,7 @@ " y0=0.0,\n", " gamma_1=1.0,\n", " gamma_2=-1.0,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -530,7 +538,7 @@ "# axarr[0].imshow(np.log10(convergence.numpy()), origin = \"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"External Shear Convergence not defined\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" @@ -557,7 +565,8 @@ " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", - " surface_density=1.5,\n", + " sd=1.5,\n", + " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", " lens=lens,\n", @@ -582,19 +591,32 @@ "source": [ "fig, axarr = plt.subplots(1, 2, figsize=(14, 7))\n", "convergence = avg_pool2d(\n", - " lens.convergence(thx, thy, z_s, z_l=z_l).squeeze()[None, None], upsample_factor\n", + " lens.convergence(thx, thy, z_s).squeeze()[None, None], upsample_factor\n", ").squeeze()\n", "axarr[0].imshow(np.log10(convergence.numpy()), origin=\"lower\")\n", "axarr[0].axis(\"off\")\n", "axarr[0].set_title(\"Mass Sheet Convergence\")\n", - "axarr[1].imshow(np.log10(sim([z_l]).numpy()), origin=\"lower\")\n", + "axarr[1].imshow(np.log10(sim().numpy()), origin=\"lower\")\n", "axarr[1].axis(\"off\")\n", "axarr[1].set_title(\"Lensed Sersic\")\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4e229ae1", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -604,7 +626,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/Microlensing.ipynb b/docs/source/tutorials/Microlensing.ipynb index 61dae44e..e8b1ad2c 100644 --- a/docs/source/tutorials/Microlensing.ipynb +++ b/docs/source/tutorials/Microlensing.ipynb @@ -14,6 +14,7 @@ "import matplotlib.animation as animation\n", "import torch\n", "import caustics\n", + "from caustics import Module, forward\n", "from IPython.display import HTML" ] }, @@ -86,26 +87,24 @@ "metadata": {}, "outputs": [], "source": [ - "class 
Microlens(caustics.Simulator):\n", + "class Microlens(Module):\n", "\n", - " theta_x = theta_x\n", - " theta_y = theta_y\n", - " z_s = 0.0\n", - "\n", - " def __init__(self, lens, src, z_s=None, name: str = \"sim\"):\n", + " def __init__(self, lens, src, z_s=0.0, name: str = \"sim\"):\n", " super().__init__(name)\n", " self.lens = lens\n", " self.src = src\n", + " self.z_s = z_s\n", + " self.theta_x = theta_x\n", + " self.theta_y = theta_y\n", "\n", - " def forward(self, params):\n", + " @forward\n", + " def __call__(self):\n", " # Compute the observed positions of the source\n", - " beta_x, beta_y = self.lens.raytrace(\n", - " self.theta_x, self.theta_y, self.z_s, params\n", - " )\n", + " beta_x, beta_y = self.lens.raytrace(self.theta_x, self.theta_y, self.z_s)\n", " # Compute the brightness of the source at the observed positions (the image)\n", - " brightness = self.src.brightness(beta_x, beta_y, params)\n", + " brightness = self.src.brightness(beta_x, beta_y)\n", " # Compute the baseline (unlensed) brightness of the source\n", - " baseline_brightness = self.src.brightness(theta_x, theta_y, params)\n", + " baseline_brightness = self.src.brightness(theta_x, theta_y)\n", " # Return the lensed image [n_pix x n_pix], and magnification\n", " return brightness, brightness.mean() / baseline_brightness.mean()" ] @@ -153,7 +152,7 @@ "sim.src.Ie = 5.0\n", "sim.src.gamma = gamma\n", "\n", - "sim.graph(True, True)" + "sim.graphviz()" ] }, { @@ -169,7 +168,7 @@ "metadata": {}, "outputs": [], "source": [ - "sim.x_order" + "print(sim)" ] }, { @@ -257,11 +256,18 @@ "# Or display the animation inline\n", "HTML(ani.to_jshtml())" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "PY39", "language": "python", "name": "python3" }, @@ -275,7 +281,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/MultiplaneDemo.ipynb b/docs/source/tutorials/MultiplaneDemo.ipynb index 4f3d4296..0d488512 100644 --- a/docs/source/tutorials/MultiplaneDemo.ipynb +++ b/docs/source/tutorials/MultiplaneDemo.ipynb @@ -281,6 +281,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -290,7 +295,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/docs/source/tutorials/Parameters.ipynb b/docs/source/tutorials/Parameters.ipynb deleted file mode 100644 index 6859c9f7..00000000 --- a/docs/source/tutorials/Parameters.ipynb +++ /dev/null @@ -1,242 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "0", - "metadata": {}, - "source": [ - "# Parameters\n", - "\n", - "This notebook will walk you through the various ways to interact with parameters in caustics. For each lens and light model there are certain parameters which are given special priority as these parameters are the ones that would be sampled in a simulator. This allows for taking the machinery of caustics and converting it into a function which can use all the power of pytorch, or other sampling/optimization frameworks. 
\n", - "\n", - "Throughout the tutorial, keep in mind that parameters are stored in a directed acyclic graph (DAG). This gives a unique way to access each parameter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import torch\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import caustics" - ] - }, - { - "cell_type": "markdown", - "id": "2", - "metadata": {}, - "source": [ - "## Setting static/dynamic parameters\n", - "\n", - "Let's see how we can set static and dynamic parameters. In caustics, a dynamic parameter is one which will be involved in sampling and must be provided on evaluation of a function. A static parameter has a fixed value and so \"disappears\" from the graph so that you don't need to worry about it anymore." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], - "source": [ - "# Flat cosmology with all dynamic parameters\n", - "cosmo = caustics.FlatLambdaCDM(name=\"cosmo\", h0=None, Om0=None)\n", - "\n", - "# SIE lens with q and b as static parameters\n", - "lens = caustics.SIE(cosmology=cosmo, q=0.4, b=1.0)\n", - "\n", - "# Sersic with all dynamic parameters except the sersic index, effective radius, and effective brightness\n", - "source = caustics.Sersic(name=\"source\", n=2.0, Re=1.0, Ie=1.0)\n", - "\n", - "# Sersic with all dynamic parameters except the x position, position angle, and effective radius\n", - "lens_light = caustics.Sersic(name=\"lenslight\", x0=0.0, phi=1.3, Re=1.0)\n", - "\n", - "# A simulator which captures all these parameters into a single DAG\n", - "sim = caustics.LensSource(\n", - " lens=lens, source=source, lens_light=lens_light, pixelscale=0.05, pixels_x=100\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "4", - "metadata": {}, - "source": [ - "We can have the simulator print a graph of the DAG from it's perspective. 
Note that the white boxes are dynamic parameters while the grey boxes are static parameters" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "sim.graph(True, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "# Accessing a parameter and giving it a value will turn it into a static parameter\n", - "sim.SIE.phi = 0.4\n", - "sim.graph(True, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [ - "# Accessing a parameter and setting it to None will turn it into a dynamic parameter\n", - "sim.lenslight.x0 = None\n", - "sim.graph(True, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": {}, - "outputs": [], - "source": [ - "# This also gives us the order of parameters for a vector that can be an input to the sim function\n", - "x_tens = torch.tensor(\n", - " [\n", - " 1.5, # z_s\n", - " 0.5, # sie z_l\n", - " 0.1, # sie x0\n", - " -0.1, # sie y0\n", - " 0.7, # sie cosmo h0\n", - " 0.31, # sie cosmo Om0\n", - " 0.0, # source x0\n", - " 0.0, # source y0\n", - " 0.7, # source q\n", - " 1.4, # source phi\n", - " 0.1, # lenslight x0\n", - " -0.1, # lenslight y0\n", - " 0.6, # lenslight q\n", - " 3.0, # lenslight n\n", - " 1.0, # lenslight Ie\n", - " ]\n", - ")\n", - "res_tens = sim(x_tens)\n", - "\n", - "# alternatively we can construct a dictionary\n", - "x_dict = {\n", - " \"sim\": {\n", - " \"z_s\": torch.tensor(1.5),\n", - " },\n", - " \"SIE\": {\n", - " \"z_l\": torch.tensor(0.5),\n", - " \"x0\": torch.tensor(0.1),\n", - " \"y0\": torch.tensor(-0.1),\n", - " },\n", - " \"cosmo\": {\n", - " \"h0\": torch.tensor(0.7),\n", - " \"Om0\": torch.tensor(0.31),\n", - " },\n", - " \"source\": {\n", - " \"x0\": torch.tensor(0.0),\n", - " \"y0\": torch.tensor(0.0),\n", - " \"q\": torch.tensor(0.7),\n", - " \"phi\": torch.tensor(1.4),\n", - " },\n", - " \"lenslight\": {\n", - " \"x0\": torch.tensor(0.1),\n", - " \"y0\": torch.tensor(-0.1),\n", - " \"q\": torch.tensor(0.6),\n", - " \"n\": torch.tensor(3.0),\n", - " \"Ie\": torch.tensor(1.0),\n", - " },\n", - "}\n", - "res_dict = sim(x_dict)\n", - "\n", - "fig, axarr = plt.subplots(1, 2, figsize=(16, 8))\n", - "axarr[0].imshow(res_tens, origin=\"lower\")\n", - "axarr[0].set_title(\"Simulator from tensor\")\n", - "axarr[1].imshow(res_dict, origin=\"lower\")\n", - "axarr[1].set_title(\"Simulator from dictionary\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "id": "9", - "metadata": {}, - "source": [ - "## Manual Inputs\n", - "\n", - "We have now seen the standard `pack` method of passing dynamic parameters to a caustics function/simulator. This is very powerful at scale, but can be tedious to enter by hand while prototyping and doing tests. Now lets see a more manual way to pass parameters to a function. For this lets try getting the exact position of each of the 4 images of the background source." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "# First find the position of each of the images\n", - "x, y = lens.forward_raytrace(\n", - " torch.tensor(0.0), # First three arguments are regular function arguments\n", - " torch.tensor(0.0),\n", - " torch.tensor(1.5),\n", - " z_l=torch.tensor(0.5), # Next three are kwargs which give the SIE parameters\n", - " x0=torch.tensor(0.1),\n", - " y0=torch.tensor(-0.1),\n", - " cosmo_h0=torch.tensor(\n", - " 0.7\n", - " ), # Next two are parameters needed for \"cosmo\" and so they are named as such\n", - " cosmo_Om0=torch.tensor(0.31),\n", - " fov=0.05 * 100, # Next two are kwargs for the \"forward_raytrace\" method\n", - " n_init=100,\n", - ")\n", - "\n", - "fig, ax = plt.subplots(figsize=(8, 8))\n", - "ax.imshow(\n", - " res_tens,\n", - " extent=(-0.05 * 100 / 2, 0.05 * 100 / 2, -0.05 * 100 / 2, 0.05 * 100 / 2),\n", - " origin=\"lower\",\n", - ")\n", - "ax.scatter(x.detach().cpu().numpy(), y.detach().cpu().numpy(), color=\"r\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "11", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/source/tutorials/VisualizeCaustics.ipynb b/docs/source/tutorials/VisualizeCaustics.ipynb index d487a2fe..70fdb037 100644 --- a/docs/source/tutorials/VisualizeCaustics.ipynb +++ b/docs/source/tutorials/VisualizeCaustics.ipynb @@ -61,8 +61,7 @@ " np.pi / 5, # sie phi\n", " 1.0, # sie b\n", " ]\n", - ")\n", - "packparams = sie.pack(x)" + ")" ] }, { @@ -83,7 +82,7 @@ "outputs": [], "source": [ "# Conveniently caustic has a function to compute the jacobian of the lens equation\n", - "A = sie.jacobian_lens_equation(thx, thy, z_s, packparams)\n", + "A = sie.jacobian_lens_equation(thx, thy, z_s, x)\n", "# Note that if this is too slow you can set `method = \"finitediff\"` to run a faster version. 
You will also need to provide `pixelscale` then\n", "\n", "# Here we compute A's determinant at every point\n", @@ -128,7 +127,7 @@ " x1 = torch.tensor(list(float(vs[0]) for vs in path))\n", " x2 = torch.tensor(list(float(vs[1]) for vs in path))\n", " # raytrace the points to the source plane\n", - " y1, y2 = sie.raytrace(x1, x2, z_s, packparams)\n", + " y1, y2 = sie.raytrace(x1, x2, z_s, x)\n", "\n", " # Plot the caustic\n", " plt.plot(y1, y2)\n", @@ -145,6 +144,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "PY39", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -154,7 +158,8 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3" + "pygments_lexer": "ipython3", + "version": "3.9.5" } }, "nbformat": 4, diff --git a/requirements.txt b/requirements.txt index 91113d22..ed257fc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ astropy>=5.2.1,<6.0.0 +caskade graphviz==0.20.1 h5py>=3.8.0 mpmath>=1.3.0,<1.4.0 diff --git a/src/caustics/__init__.py b/src/caustics/__init__.py index ffeb85d5..2d2b3466 100644 --- a/src/caustics/__init__.py +++ b/src/caustics/__init__.py @@ -1,5 +1,7 @@ from ._version import version as VERSION # noqa +from caskade import forward, Module, Param, ValidContext + from .cosmology import ( Cosmology, FlatLambdaCDM, @@ -21,6 +23,7 @@ SIE, SIS, SinglePlane, + BatchedPlane, MassSheet, TNFW, Multipole, @@ -35,15 +38,18 @@ StarSource, ) from . import utils -from .sims import LensSource, Microlens, Simulator +from .sims import LensSource, Microlens, build_simulator from .tests import test -from .models.api import build_simulator from . import func __version__ = VERSION __author__ = "Ciela Institute" __all__ = [ + "Module", + "Param", + "ValidContext", + "forward", "Cosmology", "FlatLambdaCDM", "h0_default", @@ -62,6 +68,7 @@ "SIE", "SIS", "SinglePlane", + "BatchedPlane", "MassSheet", "TNFW", "Multipole", @@ -75,7 +82,6 @@ "utils", "LensSource", "Microlens", - "Simulator", "test", "build_simulator", "func", diff --git a/src/caustics/cosmology/FlatLambdaCDM.py b/src/caustics/cosmology/FlatLambdaCDM.py index 9aaaa671..81d044c4 100644 --- a/src/caustics/cosmology/FlatLambdaCDM.py +++ b/src/caustics/cosmology/FlatLambdaCDM.py @@ -3,13 +3,11 @@ import torch from torch import Tensor - +from caskade import forward, Param from astropy.cosmology import default_cosmology from scipy.special import hyp2f1 from ..utils import interp1d -from ..parametrized import unpack -from ..packed import Packed from ..constants import c_Mpc_s, km_to_Mpc from .base import Cosmology, NameType @@ -66,9 +64,14 @@ def __init__( """ super().__init__(name) - self.add_param("h0", h0) - self.add_param("critical_density_0", critical_density_0) - self.add_param("Om0", Om0) + self.h0 = Param("h0", h0, units="unitless", valid=(0, None)) + self.critical_density_0 = Param( + "critical_density_0", + critical_density_0, + units="Msun/Mpc^3", + valid=(0, None), + ) + self.Om0 = Param("Om0", Om0, units="unitless", valid=(0, 1)) self._comoving_distance_helper_x_grid = _comoving_distance_helper_x_grid.to( dtype=torch.float32 @@ -90,7 +93,7 @@ def to( return self - def hubble_distance(self, h0): + def hubble_distance(self, h0: Annotated[Tensor, "Param"]): """ Calculate the Hubble distance. 
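The hunks above and below show the heart of this migration, the same pattern described in the Introduction tutorial changes earlier in this patch: each model quantity becomes a `caskade.Param` attribute carrying units and a validity range, and methods decorated with `@forward` receive parameter values by injection rather than unpacking a `Packed` container. A minimal sketch of the pattern follows; `ToyCosmo` is hypothetical, and the trailing-argument convention for supplying dynamic values is assumed from the patched examples (`sim.raytrace(xgrid, ygrid, params)`).

```python
from typing import Annotated

import torch
from torch import Tensor
from caskade import Module, Param, forward


class ToyCosmo(Module):
    def __init__(self, h0=None, name=None):
        super().__init__(name)
        # value=None leaves h0 dynamic; a number here would make it static
        self.h0 = Param("h0", h0, units="unitless", valid=(0, None))

    @forward
    def hubble_distance(self, h0: Annotated[Tensor, "Param"]) -> Tensor:
        # h0 is injected by caskade, whether static or supplied at call time
        return 2997.92458 / h0  # Mpc, for H0 = 100 h km/s/Mpc


cosmo = ToyCosmo()  # h0 left dynamic
d_h = cosmo.hubble_distance(torch.tensor([0.7]))  # dynamic values passed last
```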
@@ -106,16 +109,12 @@ def hubble_distance(self, h0): """ return c_Mpc_s / (100 * km_to_Mpc) / h0 - @unpack + @forward def critical_density( self, - z: Tensor, - *args, - params: Optional["Packed"] = None, - h0: Optional[Tensor] = None, - critical_density_0: Optional[Tensor] = None, - Om0: Optional[Tensor] = None, - **kwargs, + z: Annotated[Tensor, "Param"], + critical_density_0: Annotated[Tensor, "Param"], + Om0: Annotated[Tensor, "Param"], ) -> torch.Tensor: """ Calculate the critical density at redshift z. @@ -135,10 +134,8 @@ def critical_density( Ode0 = 1 - Om0 return critical_density_0 * (Om0 * (1 + z) ** 3 + Ode0) # fmt: skip - @unpack - def _comoving_distance_helper( - self, x: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def _comoving_distance_helper(self, x: Tensor) -> Tensor: """ Helper method for computing comoving distances. @@ -158,16 +155,12 @@ def _comoving_distance_helper( torch.atleast_1d(x), ).reshape(x.shape) - @unpack + @forward def comoving_distance( self, z: Tensor, - *args, - params: Optional["Packed"] = None, - h0: Optional[Tensor] = None, - critical_density_0: Optional[Tensor] = None, - Om0: Optional[Tensor] = None, - **kwargs, + h0: Annotated[Tensor, "Param"], + Om0: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the comoving distance to redshift z. @@ -187,19 +180,10 @@ def comoving_distance( Ode0 = 1 - Om0 ratio = (Om0 / Ode0) ** (1 / 3) DH = self.hubble_distance(h0) - DC1z = self._comoving_distance_helper((1 + z) * ratio, params) - DC = self._comoving_distance_helper(ratio, params) + DC1z = self._comoving_distance_helper((1 + z) * ratio) + DC = self._comoving_distance_helper(ratio) return DH * (DC1z - DC) / (Om0 ** (1 / 3) * Ode0 ** (1 / 6)) # fmt: skip - @unpack - def transverse_comoving_distance( - self, - z: Tensor, - *args, - params: Optional["Packed"] = None, - h0: Optional[Tensor] = None, - critical_density_0: Optional[Tensor] = None, - Om0: Optional[Tensor] = None, - **kwargs, - ) -> Tensor: - return self.comoving_distance(z, params, **kwargs) + @forward + def transverse_comoving_distance(self, z: Tensor) -> Tensor: + return self.comoving_distance(z) diff --git a/src/caustics/cosmology/base.py b/src/caustics/cosmology/base.py index 152814ad..73ea66e2 100644 --- a/src/caustics/cosmology/base.py +++ b/src/caustics/cosmology/base.py @@ -4,15 +4,14 @@ from typing import Optional, Annotated from torch import Tensor +from caskade import Module, forward from ..constants import G_over_c2 -from ..parametrized import Parametrized, unpack -from ..packed import Packed NameType = Annotated[Optional[str], "Name of the cosmology"] -class Cosmology(Parametrized): +class Cosmology(Module): """ Abstract base class for cosmological models. @@ -45,7 +44,8 @@ def __init__(self, name: NameType = None): super().__init__(name) @abstractmethod - def critical_density(self, z: Tensor, params: Optional["Packed"] = None) -> Tensor: + @forward + def critical_density(self, z: Tensor) -> Tensor: """ Compute the critical density at redshift z. @@ -70,10 +70,8 @@ def critical_density(self, z: Tensor, params: Optional["Packed"] = None) -> Tens ... @abstractmethod - @unpack - def comoving_distance( - self, z: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def comoving_distance(self, z: Tensor, *args, **kwargs) -> Tensor: """ Compute the comoving distance to redshift z. @@ -98,10 +96,8 @@ def comoving_distance( ... 
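Since the base class builds every other distance out of `comoving_distance` (see the derived methods in the following hunks), a subclass only needs to supply the abstract members. A hypothetical Einstein-de Sitter toy illustrates the new `@forward` interface; the normalizations below are assumptions for the sketch, and a real subclass must implement every abstract member, not just the three shown here.

```python
import torch
from torch import Tensor
from caskade import forward
from caustics import Cosmology


class ToyEdS(Cosmology):
    DH = 2997.92458 / 0.7  # Mpc, c/H0 assuming h = 0.7

    @forward
    def critical_density(self, z: Tensor) -> Tensor:
        return 1.36e11 * (1 + z) ** 3  # Msun/Mpc^3, assumed normalization

    @forward
    def comoving_distance(self, z: Tensor) -> Tensor:
        # analytic Einstein-de Sitter result: D_C = 2 (c/H0) (1 - 1/sqrt(1+z))
        return 2 * self.DH * (1 - (1 + z) ** -0.5)

    @forward
    def transverse_comoving_distance(self, z: Tensor) -> Tensor:
        return self.comoving_distance(z)  # flat geometry
```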
@abstractmethod - @unpack - def transverse_comoving_distance( - self, z: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def transverse_comoving_distance(self, z: Tensor, *args, **kwargs) -> Tensor: """ Compute the transverse comoving distance to redshift z (Mpc). @@ -125,10 +121,8 @@ def transverse_comoving_distance( """ ... - @unpack - def comoving_distance_z1z2( - self, z1: Tensor, z2: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def comoving_distance_z1z2(self, z1: Tensor, z2: Tensor) -> Tensor: """ Compute the comoving distance between two redshifts. @@ -155,12 +149,10 @@ def comoving_distance_z1z2( *Unit: Mpc* """ - return self.comoving_distance(z2, params) - self.comoving_distance(z1, params) + return self.comoving_distance(z2) - self.comoving_distance(z1) - @unpack - def transverse_comoving_distance_z1z2( - self, z1: Tensor, z2: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def transverse_comoving_distance_z1z2(self, z1: Tensor, z2: Tensor) -> Tensor: """ Compute the transverse comoving distance between two redshifts (Mpc). @@ -188,13 +180,11 @@ def transverse_comoving_distance_z1z2( """ return self.transverse_comoving_distance( - z2, params - ) - self.transverse_comoving_distance(z1, params) + z2 + ) - self.transverse_comoving_distance(z1) - @unpack - def angular_diameter_distance( - self, z: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def angular_diameter_distance(self, z: Tensor) -> Tensor: """ Compute the angular diameter distance to redshift z. @@ -216,12 +206,10 @@ def angular_diameter_distance( *Unit: Mpc* """ - return self.comoving_distance(z, params, **kwargs) / (1 + z) + return self.comoving_distance(z) / (1 + z) - @unpack - def angular_diameter_distance_z1z2( - self, z1: Tensor, z2: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def angular_diameter_distance_z1z2(self, z1: Tensor, z2: Tensor) -> Tensor: """ Compute the angular diameter distance between two redshifts. @@ -248,16 +236,13 @@ def angular_diameter_distance_z1z2( *Unit: Mpc* """ - return self.comoving_distance_z1z2(z1, z2, params, **kwargs) / (1 + z2) + return self.comoving_distance_z1z2(z1, z2) / (1 + z2) - @unpack + @forward def time_delay_distance( self, z_l: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> Tensor: """ Compute the time delay distance between lens and source planes. @@ -285,19 +270,16 @@ def time_delay_distance( *Unit: Mpc* """ - d_l = self.angular_diameter_distance(z_l, params) - d_s = self.angular_diameter_distance(z_s, params) - d_ls = self.angular_diameter_distance_z1z2(z_l, z_s, params) + d_l = self.angular_diameter_distance(z_l) + d_s = self.angular_diameter_distance(z_s) + d_ls = self.angular_diameter_distance_z1z2(z_l, z_s) return (1 + z_l) * d_l * d_s / d_ls - @unpack + @forward def critical_surface_density( self, z_l: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> Tensor: """ Compute the critical surface density between lens and source planes. 
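With the `params` plumbing gone, the composed quantities in the surrounding hunks read directly off their formulas: `time_delay_distance` is $D_{\Delta t} = (1+z_l)\,D_l D_s / D_{ls}$ and `critical_surface_density` is $\Sigma_{cr} = D_s / (4\pi \frac{G}{c^2} D_l D_{ls})$. A hedged usage sketch, assuming `FlatLambdaCDM`'s default parameter values are static so no dynamic-parameter vector is needed:

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")  # defaults assumed static
z_l, z_s = torch.tensor(0.5), torch.tensor(1.5)

D_dt = cosmo.time_delay_distance(z_l, z_s)           # Mpc
Sigma_cr = cosmo.critical_surface_density(z_l, z_s)  # Msun/Mpc^2
```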
@@ -325,7 +307,7 @@ def critical_surface_density( *Unit: Msun/Mpc^2* """ - d_l = self.angular_diameter_distance(z_l, params) - d_s = self.angular_diameter_distance(z_s, params) - d_ls = self.angular_diameter_distance_z1z2(z_l, z_s, params) + d_l = self.angular_diameter_distance(z_l) + d_s = self.angular_diameter_distance(z_s) + d_ls = self.angular_diameter_distance_z1z2(z_l, z_s) return d_s / (4 * pi * G_over_c2 * d_l * d_ls) # fmt: skip diff --git a/src/caustics/func.py b/src/caustics/func.py index 7e16934d..efc5d314 100644 --- a/src/caustics/func.py +++ b/src/caustics/func.py @@ -9,6 +9,8 @@ reduced_deflection_angle_point, potential_point, convergence_point, + mass_to_rein_point, + rein_to_mass_point, reduced_deflection_angle_mass_sheet, potential_mass_sheet, convergence_mass_sheet, @@ -64,6 +66,8 @@ "reduced_deflection_angle_point", "potential_point", "convergence_point", + "mass_to_rein_point", + "rein_to_mass_point", "reduced_deflection_angle_mass_sheet", "potential_mass_sheet", "convergence_mass_sheet", diff --git a/src/caustics/io.py b/src/caustics/io.py deleted file mode 100644 index ddb517c1..00000000 --- a/src/caustics/io.py +++ /dev/null @@ -1,126 +0,0 @@ -from pathlib import Path -import json -import struct - -DEFAULT_ENCODING = "utf-8" -SAFETENSORS_METADATA = "__metadata__" - - -def _normalize_path(path: "str | Path") -> Path: - # Convert string path to Path object - if isinstance(path, str): - path = Path(path) - - # Get absolute path - return path.absolute() - - -def to_file( - path: "str | Path", data: "str | bytes", encoding: str = DEFAULT_ENCODING -) -> str: - """ - Save data string or bytes to specified file path - - Parameters - ---------- - path : str or Path - The path to save the data to - data : str | bytes - The data string or bytes to save to file - encoding : str, optional - The string encoding to use, by default "utf-8" - - Returns - ------- - str - The path string where the data is saved - """ - # TODO: Update to allow for remote paths saving - - # Convert string data to bytes - if isinstance(data, str): - data = data.encode(encoding) - - # Normalize path to pathlib.Path object - path = _normalize_path(path) - - with open(path, "wb") as f: - f.write(data) - - return str(path.absolute()) - - -def from_file(path: "str | Path") -> bytes: - """ - Load data from specified file path - - Parameters - ---------- - path : str or Path - The path to load the data from - - Returns - ------- - bytes - The data bytes loaded from the file - """ - # TODO: Update to allow for remote paths loading - - # Normalize path to pathlib.Path object - path = _normalize_path(path) - - return path.read_bytes() - - -def _get_safetensors_header(path: "str | Path") -> dict: - """ - Read specified file header to a dictionary - - Parameters - ---------- - path : str or Path - The path to get header from - - Returns - ------- - dict - The header dictionary - """ - # TODO: Update to allow for remote paths loading of header - - # Normalize path to pathlib.Path object - path = _normalize_path(path) - - # Doing this avoids reading the whole safetensors - # file in case that it's large - with open(path, "rb") as f: - # Get the size of the header by reading first 8 bytes - (length_of_header,) = struct.unpack(" dict: - """ - Get the metadata from the specified file path - - Parameters - ---------- - path : str or Path - The path to get the metadata from - - Returns - ------- - dict - The metadata dictionary - """ - header = _get_safetensors_header(path) - - # Only return the metadata - # if it's not 
even there, just return blank dict - return header.get(SAFETENSORS_METADATA, {}) diff --git a/src/caustics/lenses/__init__.py b/src/caustics/lenses/__init__.py index 462d53c8..d939150d 100644 --- a/src/caustics/lenses/__init__.py +++ b/src/caustics/lenses/__init__.py @@ -9,6 +9,7 @@ from .sie import SIE from .sis import SIS from .singleplane import SinglePlane +from .batchedplane import BatchedPlane from .mass_sheet import MassSheet from .tnfw import TNFW from .multiplane import Multiplane @@ -30,6 +31,7 @@ "SIE", "SIS", "SinglePlane", + "BatchedPlane", "MassSheet", "TNFW", "Multipole", diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index 3eb9ebe4..c1420481 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -6,11 +6,10 @@ import torch from torch import Tensor +from caskade import Module, Param, forward from ..cosmology import Cosmology -from ..parametrized import Parametrized, unpack from .utils import magnification -from ..packed import Packed from . import func __all__ = ("ThinLens", "ThickLens") @@ -24,7 +23,7 @@ LensesType = Annotated[List["ThinLens"], "A list of ThinLens objects"] -class Lens(Parametrized): +class Lens(Module): """ Base class for all lenses """ @@ -45,14 +44,12 @@ def __init__(self, cosmology: CosmologyType, name: NameType = None): super().__init__(name) self.cosmology = cosmology - @unpack + @forward def jacobian_lens_equation( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, method="autograd", pixelscale=None, **kwargs, @@ -65,7 +62,7 @@ def jacobian_lens_equation( """ if method == "autograd": - return self._jacobian_lens_equation_autograd(x, y, z_s, params, **kwargs) + return self._jacobian_lens_equation_autograd(x, y, z_s, **kwargs) elif method == "finitediff": if pixelscale is None: raise ValueError( @@ -74,45 +71,37 @@ def jacobian_lens_equation( "Please include the pixelscale argument" ) return self._jacobian_lens_equation_finitediff( - x, y, z_s, pixelscale, params, **kwargs + x, y, z_s, pixelscale, **kwargs ) else: raise ValueError("method should be one of: autograd, finitediff") - @unpack + @forward def shear( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, method="autograd", pixelscale: Optional[Tensor] = None, - **kwargs, ): """ General shear calculation for a lens model using the jacobian of the lens equation. Individual lenses may implement more efficient methods. """ - A = self.jacobian_lens_equation( - x, y, z_s, params=params, method=method, pixelscale=pixelscale - ) + A = self.jacobian_lens_equation(x, y, z_s, method=method, pixelscale=pixelscale) I = torch.eye(2, device=A.device, dtype=A.dtype).reshape( # noqa E741 *[1] * len(A.shape[:-2]), 2, 2 ) negPsi = 0.5 * (A[..., 0, 0] + A[..., 1, 1]).unsqueeze(-1).unsqueeze(-1) * I - A return 0.5 * (negPsi[..., 0, 0] - negPsi[..., 1, 1]), negPsi[..., 0, 1] - @unpack + @forward def magnification( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> Tensor: """ Compute the gravitational magnification at the given coordinates. 
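The `magnification` and `forward_raytrace` changes here are exercised by the Example_QSOLensFit update earlier in this patch, which now filters heavily demagnified images. Restated as a standalone sketch (the static `z_l = 0.5` and `z_s = 1.5` are assumed values for illustration):

```python
import numpy as np
import torch
import caustics

cosmology = caustics.FlatLambdaCDM(name="cosmo")
lens = caustics.SIE(cosmology=cosmology, name="sie", z_l=0.5)

z_s = torch.tensor(1.5)
sp_x, sp_y = torch.tensor(0.2), torch.tensor(0.2)  # source-plane position
params = torch.tensor([0.0, 0.0, 0.4, np.pi / 5, 1.0])  # x0 y0 q phi b

# Points in the image plane, then their magnifications
x, y = lens.forward_raytrace(sp_x, sp_y, z_s, params)
mu = lens.magnification(x, y, z_s, params)
x, y = x[mu > 1e-2], y[mu > 1e-2]  # remove heavily demagnified points
```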
@@ -145,22 +134,19 @@ def magnification( *Unit: unitless* """ - return magnification(partial(self.raytrace, params=params), x, y, z_s) + return magnification(self.raytrace, x, y, z_s) - @unpack + @forward def forward_raytrace( self, bx: Tensor, by: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, epsilon: float = 1e-3, x0: Optional[Tensor] = None, y0: Optional[Tensor] = None, fov: float = 5.0, divisions: int = 100, - **kwargs, ) -> tuple[Tensor, Tensor]: """ Perform a forward ray-tracing operation which maps from the source plane @@ -213,7 +199,7 @@ def forward_raytrace( *Unit: arcsec* """ - raytrace = partial(self.raytrace, params=params, z_s=z_s) + raytrace = partial(self.raytrace, z_s=z_s) if x0 is None: x0 = torch.zeros((), device=bx.device, dtype=bx.dtype) if y0 is None: @@ -238,14 +224,12 @@ class ThickLens(Lens): the cosmological parameters of the model. """ - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -278,21 +262,19 @@ def reduced_deflection_angle( "ThickLens objects do not have a reduced deflection angle " "since they have no unique lens redshift. " "The distance D_{ls} is undefined in the equation " - "$\alpha_{reduced} = \frac{D_{ls}}{D_s}\alpha_{physical}$." + "$\\alpha_{reduced} = \\frac{D_{ls}}{D_s}\\alpha_{physical}$." "See `effective_reduced_deflection_angle`. " "Now using effective_reduced_deflection_angle, " "please switch functions to remove this warning" ) - return self.effective_reduced_deflection_angle(x, y, z_s, params, **kwargs) + return self.effective_reduced_deflection_angle(x, y, z_s, **kwargs) - @unpack + @forward def effective_reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ThickLens objects do not have a reduced deflection angle since the @@ -324,17 +306,16 @@ def effective_reduced_deflection_angle( Dynamic parameter container for the lens model. Defaults to None. """ - bx, by = self.raytrace(x, y, z_s, params, **kwargs) + bx, by = self.raytrace(x, y, z_s, **kwargs) return x - bx, y - by - @unpack + @forward def physical_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """Physical deflection angles are computed with respect to a lensing @@ -381,14 +362,13 @@ def physical_deflection_angle( ) @abstractmethod - @unpack + @forward def raytrace( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """Performs ray tracing by computing the angular position on the @@ -431,14 +411,13 @@ def raytrace( ... @abstractmethod - @unpack + @forward def surface_density( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -476,14 +455,13 @@ def surface_density( ... @abstractmethod - @unpack + @forward def time_delay( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -519,23 +497,20 @@ def time_delay( """ ... - @unpack + @forward def _jacobian_effective_deflection_angle_finitediff( self, x: Tensor, y: Tensor, z_s: Tensor, pixelscale: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ Return the jacobian of the effective reduced deflection angle vector field. 
This equates to a (2,2) matrix at each (x,y) point. """ # Compute deflection angles - ax, ay = self.effective_reduced_deflection_angle(x, y, z_s, params) + ax, ay = self.effective_reduced_deflection_angle(x, y, z_s) # Build Jacobian J = torch.zeros((*ax.shape, 2, 2), device=ax.device, dtype=ax.dtype) @@ -543,51 +518,59 @@ def _jacobian_effective_deflection_angle_finitediff( J[..., 1, 1], J[..., 1, 0] = torch.gradient(ay, spacing=pixelscale) return J - @unpack + @forward def _jacobian_effective_deflection_angle_autograd( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, + chunk_size: int = 10000, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ Return the jacobian of the effective reduced deflection angle vector field. This equates to a (2,2) matrix at each (x,y) point. """ - # Ensure the x,y coordinates track gradients - x = x.detach().requires_grad_() - y = y.detach().requires_grad_() - - # Compute deflection angles - ax, ay = self.effective_reduced_deflection_angle(x, y, z_s, params) # Build Jacobian - J = torch.zeros((*ax.shape, 2, 2), device=ax.device, dtype=ax.dtype) - (J[..., 0, 0],) = torch.autograd.grad( - ax, x, grad_outputs=torch.ones_like(ax), create_graph=True + J = torch.zeros((*x.shape, 2, 2), device=x.device, dtype=x.dtype) + + # Compute deflection angle gradients + dax_dx = torch.func.grad( + lambda *a: self.effective_reduced_deflection_angle(*a)[0], argnums=0 ) - (J[..., 0, 1],) = torch.autograd.grad( - ax, y, grad_outputs=torch.ones_like(ax), create_graph=True + J[..., 0, 0] = torch.vmap(dax_dx, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + + dax_dy = torch.func.grad( + lambda *a: self.effective_reduced_deflection_angle(*a)[0], argnums=1 ) - (J[..., 1, 0],) = torch.autograd.grad( - ay, x, grad_outputs=torch.ones_like(ay), create_graph=True + J[..., 0, 1] = torch.vmap(dax_dy, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + + day_dx = torch.func.grad( + lambda *a: self.effective_reduced_deflection_angle(*a)[1], argnums=0 ) - (J[..., 1, 1],) = torch.autograd.grad( - ay, y, grad_outputs=torch.ones_like(ay), create_graph=True + J[..., 1, 0] = torch.vmap(day_dx, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + + day_dy = torch.func.grad( + lambda *a: self.effective_reduced_deflection_angle(*a)[1], argnums=1 ) + J[..., 1, 1] = torch.vmap(day_dy, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + return J.detach() - @unpack + @forward def jacobian_effective_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, method="autograd", pixelscale=None, **kwargs, @@ -600,7 +583,9 @@ def jacobian_effective_deflection_angle( """ if method == "autograd": - return self._jacobian_effective_deflection_angle_autograd(x, y, z_s, params) + return self._jacobian_effective_deflection_angle_autograd( + x, y, z_s, **kwargs + ) elif method == "finitediff": if pixelscale is None: raise ValueError( @@ -609,20 +594,18 @@ def jacobian_effective_deflection_angle( "Please include the pixelscale argument" ) return self._jacobian_effective_deflection_angle_finitediff( - x, y, z_s, pixelscale, params + x, y, z_s, pixelscale, **kwargs ) else: raise ValueError("method should be one of: autograd, finitediff") - @unpack + @forward def _jacobian_lens_equation_finitediff( self, x: Tensor, y: Tensor, 
z_s: Tensor, pixelscale: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ @@ -631,18 +614,16 @@ def _jacobian_lens_equation_finitediff( """ # Build Jacobian J = self._jacobian_effective_deflection_angle_finitediff( - x, y, z_s, pixelscale, params, **kwargs + x, y, z_s, pixelscale, **kwargs ) return torch.eye(2).to(J.device) - J - @unpack + @forward def _jacobian_lens_equation_autograd( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ @@ -650,19 +631,15 @@ def _jacobian_lens_equation_autograd( This equates to a (2,2) matrix at each (x,y) point. """ # Build Jacobian - J = self._jacobian_effective_deflection_angle_autograd( - x, y, z_s, params, **kwargs - ) + J = self._jacobian_effective_deflection_angle_autograd(x, y, z_s, **kwargs) return torch.eye(2).to(J.device) - J.detach() - @unpack + @forward def effective_convergence_div( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -675,17 +652,15 @@ def effective_convergence_div( See: https://arxiv.org/pdf/2006.07383.pdf see also the `effective_convergence_curl` method. """ - J = self.jacobian_effective_deflection_angle(x, y, z_s, params, **kwargs) + J = self.jacobian_effective_deflection_angle(x, y, z_s, **kwargs) return 0.5 * (J[..., 0, 0] + J[..., 1, 1]) - @unpack + @forward def effective_convergence_curl( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -697,7 +672,7 @@ def effective_convergence_curl( See: https://arxiv.org/pdf/2006.07383.pdf """ - J = self.jacobian_effective_deflection_angle(x, y, z_s, params, **kwargs) + J = self.jacobian_effective_deflection_angle(x, y, z_s, **kwargs) return 0.5 * (J[..., 1, 0] - J[..., 0, 1]) @@ -732,18 +707,15 @@ def __init__( name: NameType = None, ): super().__init__(cosmology=cosmology, name=name) - self.add_param("z_l", z_l) + self.z_l = Param("z_l", z_l, units="unitless", valid=(0, None)) - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Computes the reduced deflection angle of the lens at given coordinates [arcsec]. @@ -781,25 +753,22 @@ def reduced_deflection_angle( *Unit: arcsec* """ - d_s = self.cosmology.angular_diameter_distance(z_s, params) - d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s, params) + d_s = self.cosmology.angular_diameter_distance(z_s) + d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) deflection_angle_x, deflection_angle_y = self.physical_deflection_angle( - x, y, z_s, params + x, y, z_s ) return func.reduced_from_physical_deflection_angle( deflection_angle_x, deflection_angle_y, d_s, d_ls ) - @unpack + @forward def physical_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Computes the physical deflection angle immediately after passing through this lens's plane. 
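The physical/reduced round trip above is a pure distance-ratio rescaling. The real helpers live in caustics.lenses.func, but the relationship is short enough to sketch here with illustrative distances (angular diameter distances, in Mpc):

import torch

def reduced_from_physical(ax, ay, d_s, d_ls):
    # alpha_reduced = (D_ls / D_s) * alpha_physical
    return d_ls / d_s * ax, d_ls / d_s * ay

def physical_from_reduced(ax, ay, d_s, d_ls):
    # inverse of the rescaling above
    return d_s / d_ls * ax, d_s / d_ls * ay

ax_phys = torch.tensor(1.2)   # arcsec
ay_phys = torch.tensor(-0.4)  # arcsec
d_s, d_ls = torch.tensor(1750.0), torch.tensor(950.0)
ax_red, ay_red = reduced_from_physical(ax_phys, ay_phys, d_s, d_ls)
assert torch.allclose(physical_from_reduced(ax_red, ay_red, d_s, d_ls)[0], ax_phys)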
@@ -837,24 +806,23 @@ def physical_deflection_angle( *Unit: arcsec* """ - d_s = self.cosmology.angular_diameter_distance(z_s, params) - d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s, params) + d_s = self.cosmology.angular_diameter_distance(z_s) + d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) deflection_angle_x, deflection_angle_y = self.reduced_deflection_angle( - x, y, z_s, params + x, y, z_s ) return func.physical_from_reduced_deflection_angle( deflection_angle_x, deflection_angle_y, d_s, d_ls ) @abstractmethod - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -891,14 +859,13 @@ def convergence( ... @abstractmethod - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -934,16 +901,13 @@ def potential( """ ... - @unpack + @forward def surface_density( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], ) -> Tensor: """ Computes the surface mass density of the lens at given coordinates. @@ -976,19 +940,15 @@ def surface_density( *Unit: Msun/Mpc^2* """ - critical_surface_density = self.cosmology.critical_surface_density( - z_l, z_s, params - ) - return self.convergence(x, y, z_s, params) * critical_surface_density # fmt: skip + critical_surface_density = self.cosmology.critical_surface_density(z_l, z_s) + return self.convergence(x, y, z_s) * critical_surface_density # fmt: skip - @unpack + @forward def raytrace( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[Tensor, Tensor]: """ @@ -1028,31 +988,28 @@ def raytrace( *Unit: arcsec* """ - ax, ay = self.reduced_deflection_angle(x, y, z_s, params, **kwargs) + ax, ay = self.reduced_deflection_angle(x, y, z_s, **kwargs) return x - ax, y - ay - def _arcsec2_to_days(self, z_l, z_s, params): + def _arcsec2_to_days(self, z_l, z_s): """ This method is used by :func:`caustics.lenses.ThinLens.time_delay` to convert arcsec^2 to days in the context of gravitational time delays. """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - d_s = self.cosmology.angular_diameter_distance(z_s, params) - d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s, params) + d_l = self.cosmology.angular_diameter_distance(z_l) + d_s = self.cosmology.angular_diameter_distance(z_s) + d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) return func.time_delay_arcsec2_to_days(d_l, d_s, d_ls, z_l) - @unpack + @forward def time_delay( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, + z_l: Annotated[Tensor, "Param"], shapiro_time_delay: bool = True, geometric_time_delay: bool = True, - **kwargs, ) -> Tensor: """ Computes the gravitational time delay for light passing through the lens at given coordinates. @@ -1062,7 +1019,7 @@ def time_delay( .. 
math:: - \\Delta t = \\frac{1 + z_l}{c} \\frac{D_s}{D_l D_{ls}} \\left[ \\frac{1}{2}|\\vec{\\alpha}(\\vec{\\theta})|^2 - \psi(\\vec{\\theta}) \\right] + \\Delta t = \\frac{1 + z_l}{c} \\frac{D_s}{D_l D_{ls}} \\left[ \\frac{1}{2}|\\vec{\\alpha}(\\vec{\\theta})|^2 - \\psi(\\vec{\\theta}) \\right] where :math:`\\vec{\\alpha}(\\vec{\\theta})` is the deflection angle, :math:`\\psi(\\vec{\\theta})` is the lensing potential, @@ -1116,34 +1073,31 @@ def time_delay( TD = torch.zeros_like(x) if shapiro_time_delay: - potential = self.potential(x, y, z_s, params) + potential = self.potential(x, y, z_s) TD = TD - potential if geometric_time_delay: - ax, ay = self.physical_deflection_angle(x, y, z_s, params) + ax, ay = self.physical_deflection_angle(x, y, z_s) fp = 0.5 * (ax**2 + ay**2) TD = TD + fp - factor = self._arcsec2_to_days(z_l, z_s, params) + factor = self._arcsec2_to_days(z_l, z_s) return factor * TD - @unpack + @forward def _jacobian_deflection_angle_finitediff( self, x: Tensor, y: Tensor, z_s: Tensor, pixelscale: Tensor, - *args, - params: Optional[Packed] = None, - **kwargs, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ Return the jacobian of the deflection angle vector. This equates to a (2,2) matrix at each (x,y) point. """ # Compute deflection angles - ax, ay = self.reduced_deflection_angle(x, y, z_s, params) + ax, ay = self.reduced_deflection_angle(x, y, z_s) # Build Jacobian J = torch.zeros((*ax.shape, 2, 2), device=ax.device, dtype=ax.dtype) @@ -1151,54 +1105,61 @@ def _jacobian_deflection_angle_finitediff( J[..., 1, 1], J[..., 1, 0] = torch.gradient(ay, spacing=pixelscale) return J - @unpack + @forward def _jacobian_deflection_angle_autograd( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, + chunk_size: int = 10000, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ Return the jacobian of the deflection angle vector. This equates to a (2,2) matrix at each (x,y) point. 
""" - # Ensure the x,y coordinates track gradients - x = x.detach().requires_grad_() - y = y.detach().requires_grad_() - - # Compute deflection angles - ax, ay = self.reduced_deflection_angle(x, y, z_s, params) - # Build Jacobian - J = torch.zeros((*ax.shape, 2, 2), device=ax.device, dtype=ax.dtype) - (J[..., 0, 0],) = torch.autograd.grad( - ax, x, grad_outputs=torch.ones_like(ax), create_graph=True + J = torch.zeros((*x.shape, 2, 2), device=x.device, dtype=x.dtype) + + # Compute deflection angle gradients + dax_dx = torch.func.grad( + lambda *a: self.reduced_deflection_angle(*a)[0], argnums=0 ) - (J[..., 0, 1],) = torch.autograd.grad( - ax, y, grad_outputs=torch.ones_like(ax), create_graph=True + J[..., 0, 0] = torch.vmap(dax_dx, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + + dax_dy = torch.func.grad( + lambda *a: self.reduced_deflection_angle(*a)[0], argnums=1 ) - (J[..., 1, 0],) = torch.autograd.grad( - ay, x, grad_outputs=torch.ones_like(ay), create_graph=True + J[..., 0, 1] = torch.vmap(dax_dy, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + + day_dx = torch.func.grad( + lambda *a: self.reduced_deflection_angle(*a)[1], argnums=0 ) - (J[..., 1, 1],) = torch.autograd.grad( - ay, y, grad_outputs=torch.ones_like(ay), create_graph=True + J[..., 1, 0] = torch.vmap(day_dx, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + + day_dy = torch.func.grad( + lambda *a: self.reduced_deflection_angle(*a)[1], argnums=1 ) + J[..., 1, 1] = torch.vmap(day_dy, in_dims=(0, 0, None), chunk_size=chunk_size)( + x.flatten(), y.flatten(), z_s + ).reshape(x.shape) + return J.detach() - @unpack + @forward def jacobian_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, method="autograd", pixelscale=None, - **kwargs, + chunk_size: int = 10000, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ Return the jacobian of the deflection angle vector. @@ -1208,28 +1169,24 @@ def jacobian_deflection_angle( """ if method == "autograd": - return self._jacobian_deflection_angle_autograd(x, y, z_s, params) + return self._jacobian_deflection_angle_autograd(x, y, z_s, chunk_size) elif method == "finitediff": if pixelscale is None: raise ValueError( "Finite differences lensing jacobian requires regular grid " "and known pixelscale. Please include the pixelscale argument" ) - return self._jacobian_deflection_angle_finitediff( - x, y, z_s, pixelscale, params - ) + return self._jacobian_deflection_angle_finitediff(x, y, z_s, pixelscale) else: raise ValueError("method should be one of: autograd, finitediff") - @unpack + @forward def _jacobian_lens_equation_finitediff( self, x: Tensor, y: Tensor, z_s: Tensor, pixelscale: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ @@ -1237,19 +1194,15 @@ def _jacobian_lens_equation_finitediff( This equates to a (2,2) matrix at each (x,y) point. 
""" # Build Jacobian - J = self._jacobian_deflection_angle_finitediff( - x, y, z_s, pixelscale, params, **kwargs - ) + J = self._jacobian_deflection_angle_finitediff(x, y, z_s, pixelscale, **kwargs) return torch.eye(2).to(J.device) - J - @unpack + @forward def _jacobian_lens_equation_autograd( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, **kwargs, ) -> tuple[tuple[Tensor, Tensor], tuple[Tensor, Tensor]]: """ @@ -1257,5 +1210,5 @@ def _jacobian_lens_equation_autograd( This equates to a (2,2) matrix at each (x,y) point. """ # Build Jacobian - J = self._jacobian_deflection_angle_autograd(x, y, z_s, params, **kwargs) + J = self._jacobian_deflection_angle_autograd(x, y, z_s, **kwargs) return torch.eye(2).to(J.device) - J.detach() diff --git a/src/caustics/lenses/batchedplane.py b/src/caustics/lenses/batchedplane.py new file mode 100644 index 00000000..5c4fe314 --- /dev/null +++ b/src/caustics/lenses/batchedplane.py @@ -0,0 +1,211 @@ +from typing import Optional + +import torch +from torch import Tensor +from caskade import forward + +from .base import ThinLens, CosmologyType, NameType, ZLType +from ..utils import vmap_reduce + +__all__ = ("BatchedPlane",) + + +class BatchedPlane(ThinLens): + """ + A class for combining multiple thin lenses into a single lensing plane. It + is assumed that the lens parameters will have a batch dimension, internally + this class will vmap over the batch dimension and return the combined + lensing quantity. This class can only handle a single lens type, if you want + to combine different lens types, use the `SinglePlane` class. + + Attributes + ---------- + name: str + The name of the single plane lens. + + cosmology: Cosmology + An instance of the Cosmology class. + + lens: ThinLens + A ThinLens object that will be vmapped over into a single lensing plane. + + """ + + def __init__( + self, + cosmology: CosmologyType, + lens: ThinLens, + name: NameType = None, + z_l: ZLType = None, + chunk_size: Optional[int] = None, + ): + """ + Initialize the SinglePlane lens model. + """ + super().__init__(cosmology, z_l=z_l, name=name) + self.lens = lens + self.chunk_size = chunk_size + + @forward + def reduced_deflection_angle( + self, + x: Tensor, + y: Tensor, + z_s: Tensor, + ) -> tuple[Tensor, Tensor]: + """ + Calculate the total deflection angle by summing + the deflection angles of all individual lenses. + + Parameters + ---------- + x: Tensor + The x-coordinate of the lens. + + *Unit: arcsec* + + y: Tensor + The y-coordinate of the lens. + + *Unit: arcsec* + + z_s: Tensor + The source redshift. + + *Unit: unitless* + + params: Packed, optional + Dynamic parameter container. + + Returns + ------- + x_component: Tensor + The x-component of the deflection angle. + + *Unit: arcsec* + + y_component: Tensor + The y-component of the deflection angle. 
+ + *Unit: arcsec* + + """ + + # Collect the dynamic parameters to vmap over + params = dict((p.name, p.value) for p in self.lens.local_dynamic_params) + batchdims = dict( + (p.name, -(len(p.shape) + 1)) for p in self.lens.local_dynamic_params + ) + batchdims["x"] = None + batchdims["y"] = None + batchdims["z_s"] = None + vr_deflection_angle = vmap_reduce( + lambda p: self.lens.reduced_deflection_angle(**p), + reduce_func=lambda x: (x[0].sum(dim=0), x[1].sum(dim=0)), + chunk_size=self.chunk_size, + in_dims=batchdims, + out_dims=(0, 0), + ) + return vr_deflection_angle(x=x, y=y, z_s=z_s, **params) + + @forward + def convergence( + self, + x: Tensor, + y: Tensor, + z_s: Tensor, + ) -> Tensor: + """ + Calculate the total projected mass density by + summing the mass densities of all individual lenses. + + Parameters + ---------- + x: Tensor + The x-coordinate of the lens. + + *Unit: arcsec* + + y: Tensor + The y-coordinate of the lens. + + *Unit: arcsec* + + z_s: Tensor + The source redshift. + + *Unit: unitless* + + params: Packed, optional + Dynamic parameter container. + + Returns + ------- + Tensor + The total projected mass density. + + *Unit: unitless* + + """ + # Collect the dynamic parameters to vmap over + params = dict((p.name, p.value) for p in self.lens.local_dynamic_params) + batchdims = dict( + (p.name, -(len(p.shape) + 1)) for p in self.lens.local_dynamic_params + ) + convergence = torch.vmap( + lambda p: self.lens.convergence(x, y, z_s, **p), + in_dims=(batchdims,), + chunk_size=self.chunk_size, + )(params) + return convergence.sum(dim=0) + + @forward + def potential( + self, + x: Tensor, + y: Tensor, + z_s: Tensor, + ) -> Tensor: + """ + Compute the total lensing potential by summing + the lensing potentials of all individual lenses. + + Parameters + ----------- + x: Tensor + The x-coordinate of the lens. + + *Unit: arcsec* + + y: Tensor + The y-coordinate of the lens. + + *Unit: arcsec* + + z_s: Tensor + The source redshift. + + *Unit: unitless* + + params: Packed, optional + Dynamic parameter container. + + Returns + ------- + Tensor + The total lensing potential. 
+ + *Unit: arcsec^2* + + """ + # Collect the dynamic parameters to vmap over + params = dict((p.name, p.value) for p in self.lens.local_dynamic_params) + batchdims = dict( + (p.name, -(len(p.shape) + 1)) for p in self.lens.local_dynamic_params + ) + potential = torch.vmap( + lambda p: self.lens.potential(x, y, z_s, **p), + in_dims=(batchdims,), + chunk_size=self.chunk_size, + )(params) + return potential.sum(dim=0) diff --git a/src/caustics/lenses/enclosed_mass.py b/src/caustics/lenses/enclosed_mass.py index f37c37f7..247a3b8b 100644 --- a/src/caustics/lenses/enclosed_mass.py +++ b/src/caustics/lenses/enclosed_mass.py @@ -1,11 +1,10 @@ # mypy: disable-error-code="operator,union-attr,dict-item" from typing import Optional, Union, Annotated, Callable -from torch import Tensor +from torch import Tensor, pi +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from .func import physical_deflection_angle_enclosed_mass, convergence_enclosed_mass __all__ = ("EnclosedMass",) @@ -109,29 +108,25 @@ def __init__( super().__init__(cosmology, z_l, name=name, **kwargs) self.enclosed_mass = enclosed_mass - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("q", q) - self.add_param("phi", phi) - self.add_param("p", p) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.q = Param("q", q, units="unitless", valid=(0, 1)) + self.phi = Param("phi", phi, units="radians", valid=(0, pi), cyclic=True) + self.p = Param("p", p, units="user-defined") self.s = s - @unpack + @forward def physical_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - p: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + p: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Calculate the physical deflection angle of the lens at a given position. @@ -184,39 +179,31 @@ def physical_deflection_angle( x0, y0, q, phi, lambda r: self.enclosed_mass(r, p), x, y, self.s ) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - p: Optional[Tensor] = None, **kwargs, ) -> Tensor: raise NotImplementedError( "Potential is not implemented for enclosed mass profiles." ) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - p: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + p: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the dimensionless convergence of the lens at a given position. 
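The pattern in EnclosedMass.__init__ above repeats across this refactor: each Param carries units, an optional valid range, and a cyclic flag for wrap-around angles, and @forward-decorated methods pull parameter values in via the Annotated[Tensor, "Param"] annotation. A sketch of the idiom with an illustrative toy module (not part of caustics); it assumes caskade's Module/Param/forward behave exactly as used in this patch:

from typing import Annotated

from torch import Tensor, pi
from caskade import Module, Param, forward

class ToyEllipse(Module):
    """Illustrative module showing the Param/forward idiom."""

    def __init__(self, q=None, phi=None, name=None):
        super().__init__(name)
        # valid constrains the allowed range; cyclic marks wrap-around angles
        self.q = Param("q", q, units="unitless", valid=(0, 1))
        self.phi = Param("phi", phi, units="radians", valid=(0, pi), cyclic=True)

    @forward
    def ellipticity(self, q: Annotated[Tensor, "Param"]) -> Tensor:
        # e = (1 - q) / (1 + q); q is injected by the forward decorator
        return (1 - q) / (1 + q)

# e = ToyEllipse(q=0.7, phi=0.4).ellipticity()  # -> ~0.176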
@@ -266,7 +253,7 @@ def convergence( *Unit: unitless* """ - csd = self.cosmology.critical_surface_density(z_l, z_s, params) + csd = self.cosmology.critical_surface_density(z_l, z_s) return convergence_enclosed_mass( x0, y0, diff --git a/src/caustics/lenses/epl.py b/src/caustics/lenses/epl.py index f1f7d694..485e0321 100644 --- a/src/caustics/lenses/epl.py +++ b/src/caustics/lenses/epl.py @@ -2,11 +2,10 @@ from typing import Optional, Union, Annotated import torch -from torch import Tensor +from torch import Tensor, pi +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("EPL",) @@ -180,32 +179,28 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("q", q) - self.add_param("phi", phi) - self.add_param("b", b) - self.add_param("t", t) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.q = Param("q", q, units="unitless", valid=(0, 1)) + self.phi = Param("phi", phi, units="radians", valid=(0, pi), cyclic=True) + self.b = Param("b", b, units="arcsec", valid=(0, None)) + self.t = Param("t", t, units="unitless", valid=(0, 2)) self.s = s self.n_iter = n_iter - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - b: Optional[Tensor] = None, - t: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + b: Annotated[Tensor, "Param"], + t: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Compute the reduced deflection angles of the lens. @@ -291,22 +286,18 @@ def _r_omega(self, z, t, q): return part_sum - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - b: Optional[Tensor] = None, - t: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + b: Annotated[Tensor, "Param"], + t: Annotated[Tensor, "Param"], ): """ Compute the lensing potential of the lens. @@ -341,22 +332,18 @@ def potential( """ return func.potential_epl(x0, y0, q, phi, b, t, x, y, self.n_iter) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - b: Optional[Tensor] = None, - t: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + b: Annotated[Tensor, "Param"], + t: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the convergence of the lens, which describes the local density of the lens. 
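With parameters registered at construction, calling a converted lens no longer involves packing a params container (compare the removed `params`/`unpack` plumbing above). A sketch of the post-refactor call style for EPL; it assumes the top-level `caustics` exports and default cosmology arguments, and all numbers are illustrative:

import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
lens = caustics.EPL(
    cosmology=cosmo, z_l=0.5,
    x0=0.0, y0=0.0, q=0.7, phi=0.4, b=1.2, t=1.0,
)
theta_x, theta_y = torch.meshgrid(
    torch.linspace(-2, 2, 32), torch.linspace(-2, 2, 32), indexing="xy"
)
# All parameters are static here, so no extra arguments are needed
ax, ay = lens.reduced_deflection_angle(theta_x, theta_y, z_s=torch.tensor(1.5))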
diff --git a/src/caustics/lenses/external_shear.py b/src/caustics/lenses/external_shear.py index 31cdfb3d..545c6ebb 100644 --- a/src/caustics/lenses/external_shear.py +++ b/src/caustics/lenses/external_shear.py @@ -3,10 +3,9 @@ from torch import Tensor import torch +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("ExternalShear",) @@ -79,26 +78,22 @@ def __init__( ): super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("gamma_1", gamma_1) - self.add_param("gamma_2", gamma_2) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.gamma_1 = Param("gamma_1", gamma_1, units="unitless") + self.gamma_2 = Param("gamma_2", gamma_2, units="unitless") self.s = s - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - gamma_1: Optional[Tensor] = None, - gamma_2: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + gamma_1: Annotated[Tensor, "Param"], + gamma_2: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Calculates the reduced deflection angle. @@ -140,20 +135,16 @@ def reduced_deflection_angle( x0, y0, gamma_1, gamma_2, x, y ) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - gamma_1: Optional[Tensor] = None, - gamma_2: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + gamma_1: Annotated[Tensor, "Param"], + gamma_2: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculates the lensing potential. @@ -188,20 +179,12 @@ def potential( """ return func.potential_external_shear(x0, y0, gamma_1, gamma_2, x, y) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - gamma_1: Optional[Tensor] = None, - gamma_2: Optional[Tensor] = None, - **kwargs, ) -> Tensor: """ The convergence is undefined for an external shear. 
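For reference, the standard external-shear expressions are short; a self-contained sketch follows (coordinates already centered on x0, y0; the func helpers above implement the library's actual conventions, which may differ in sign or centering):

import torch

def shear_potential(gamma_1, gamma_2, x, y):
    # psi = (gamma_1 / 2) (x^2 - y^2) + gamma_2 x y
    return 0.5 * gamma_1 * (x**2 - y**2) + gamma_2 * x * y

def shear_deflection(gamma_1, gamma_2, x, y):
    # (d psi / dx, d psi / dy)
    return gamma_1 * x + gamma_2 * y, gamma_2 * x - gamma_1 * y

# quick consistency check: deflection equals the autograd gradient
x = torch.tensor(0.8, requires_grad=True)
y = torch.tensor(0.2, requires_grad=True)
psi = shear_potential(0.05, -0.02, x, y)
gx, gy = torch.autograd.grad(psi, (x, y))
assert torch.allclose(gx, torch.tensor(0.05 * 0.8 + -0.02 * 0.2))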
diff --git a/src/caustics/lenses/func/__init__.py b/src/caustics/lenses/func/__init__.py index a264a779..81d7b80c 100644 --- a/src/caustics/lenses/func/__init__.py +++ b/src/caustics/lenses/func/__init__.py @@ -12,7 +12,13 @@ time_delay_arcsec2_to_days, ) from .sie import reduced_deflection_angle_sie, potential_sie, convergence_sie -from .point import reduced_deflection_angle_point, potential_point, convergence_point +from .point import ( + reduced_deflection_angle_point, + potential_point, + convergence_point, + mass_to_rein_point, + rein_to_mass_point, +) from .mass_sheet import ( reduced_deflection_angle_mass_sheet, potential_mass_sheet, @@ -91,6 +97,8 @@ "reduced_deflection_angle_point", "potential_point", "convergence_point", + "mass_to_rein_point", + "rein_to_mass_point", "reduced_deflection_angle_mass_sheet", "potential_mass_sheet", "convergence_mass_sheet", diff --git a/src/caustics/lenses/func/mass_sheet.py b/src/caustics/lenses/func/mass_sheet.py index 0921d9cc..65af704e 100644 --- a/src/caustics/lenses/func/mass_sheet.py +++ b/src/caustics/lenses/func/mass_sheet.py @@ -3,9 +3,9 @@ from ...utils import translate_rotate -def reduced_deflection_angle_mass_sheet(x0, y0, surface_density, x, y): +def reduced_deflection_angle_mass_sheet(x0, y0, kappa, x, y): """ - Compute the reduced deflection angles. Here we use the Meneeghetti lecture + Compute the reduced deflection angles. Here we use the Meneghetti lecture notes equation 3.84. Parameters @@ -20,8 +20,8 @@ def reduced_deflection_angle_mass_sheet(x0, y0, surface_density, x, y): *Unit: arcsec* - surface_density: Optional[Union[Tensor, float]] - Surface density normalized by the critical surface density. + kappa: Optional[Union[Tensor, float]] + Convergence. Surface density normalized by the critical surface density. *Unit: unitless* @@ -50,12 +50,12 @@ def reduced_deflection_angle_mass_sheet(x0, y0, surface_density, x, y): """ x, y = translate_rotate(x, y, x0, y0) # Meneghetti eq 3.84 - ax = x * surface_density - ay = y * surface_density + ax = x * kappa + ay = y * kappa return ax, ay -def potential_mass_sheet(x0, y0, surface_density, x, y): +def potential_mass_sheet(x0, y0, kappa, x, y): """ Compute the lensing potential. Here we use the Meneghetti lecture notes equation 3.81. @@ -72,8 +72,8 @@ def potential_mass_sheet(x0, y0, surface_density, x, y): *Unit: arcsec* - surface_density: Optional[Union[Tensor, float]] - Surface density normalized by the critical surface density. + kappa: Optional[Union[Tensor, float]] + Convergence. Surface density normalized by the critical surface density. *Unit: unitless* @@ -97,18 +97,18 @@ def potential_mass_sheet(x0, y0, surface_density, x, y): """ x, y = translate_rotate(x, y, x0, y0) # Meneghetti eq 3.81 - return (surface_density / 2) * (x**2 + y**2) + return (kappa / 2) * (x**2 + y**2) -def convergence_mass_sheet(surface_density, x): +def convergence_mass_sheet(kappa, x): """ Compute the lensing convergence. In the case of a mass sheet, this is just the convergence value mapped to the input shape. Parameters ---------- - surface_density: Optional[Union[Tensor, float]] - Surface density normalized by the critical surface density. + kappa: Optional[Union[Tensor, float]] + Convergence. Surface density normalized by the critical surface density. 
*Unit: unitless* @@ -126,4 +126,4 @@ def convergence_mass_sheet(surface_density, x): """ # By definition - return surface_density * torch.ones_like(x) + return kappa * torch.ones_like(x) diff --git a/src/caustics/lenses/func/point.py b/src/caustics/lenses/func/point.py index 2a3aa031..f71e3de7 100644 --- a/src/caustics/lenses/func/point.py +++ b/src/caustics/lenses/func/point.py @@ -1,6 +1,7 @@ import torch from ...utils import translate_rotate +from ...constants import G_over_c2, rad_to_arcsec def reduced_deflection_angle_point(x0, y0, th_ein, x, y, s=0.0): @@ -147,3 +148,77 @@ def convergence_point(x0, y0, x, y): """ x, y = translate_rotate(x, y, x0, y0) return torch.where((x == 0) & (y == 0), torch.inf, 0.0) + + +def mass_to_rein_point(M, d_ls, d_l, d_s): + """ + Compute the Einstein radius of a point mass. See Meneghetti lecture notes equation 1.39 + + Parameters + ---------- + M: Tensor + Mass of the lens. + + *Unit: solar masses* + + d_ls: Tensor + Distance between the lens and the source. + + *Unit: Mpc* + + d_l: Tensor + Distance between the observer and the lens. + + *Unit: Mpc* + + d_s: Tensor + Distance between the observer and the source. + + *Unit: Mpc* + + Returns + ------- + Tensor + The Einstein radius. + + *Unit: arcsec* + + """ + return rad_to_arcsec * (4 * G_over_c2 * M * d_ls / (d_l * d_s)).sqrt() + + +def rein_to_mass_point(r, d_ls, d_l, d_s): + """ + Compute the Einstein radius of a point mass. See Meneghetti lecture notes equation 1.39 + + Parameters + ---------- + r: Tensor + Einstein radius of the lens. + + *Unit: arcsec* + + d_ls: Tensor + Distance between the lens and the source. + + *Unit: Mpc* + + d_l: Tensor + Distance between the observer and the lens. + + *Unit: Mpc* + + d_s: Tensor + Distance between the observer and the source. + + *Unit: Mpc* + + Returns + ------- + Tensor + The mass of the lens + + *Unit: solar masses* + + """ + return (r / rad_to_arcsec) ** 2 * d_l * d_s / (4 * G_over_c2 * d_ls) diff --git a/src/caustics/lenses/mass_sheet.py b/src/caustics/lenses/mass_sheet.py index 2c0f6331..2637d7b1 100644 --- a/src/caustics/lenses/mass_sheet.py +++ b/src/caustics/lenses/mass_sheet.py @@ -2,10 +2,9 @@ from typing import Optional, Union, Annotated from torch import Tensor +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("MassSheet",) @@ -38,8 +37,8 @@ class MassSheet(ThinLens): *Unit: arcsec* - surface_density: Optional[Union[Tensor, float]] - Surface density normalized by the critical surface density. + kappa: Optional[Union[Tensor, float]] + Convergence. Surface density normalized by the critical surface density. 
*Unit: unitless* """ @@ -47,7 +46,7 @@ class MassSheet(ThinLens): _null_params = { "x0": 0.0, "y0": 0.0, - "surface_density": 0.1, + "kappa": 0.1, } def __init__( @@ -64,30 +63,26 @@ def __init__( "y-coordinate of the shear center in the lens plane", True, ] = None, - surface_density: Annotated[ + kappa: Annotated[ Optional[Union[Tensor, float]], "Surface density", True ] = None, name: NameType = None, ): super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("surface_density", surface_density) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.kappa = Param("kappa", kappa, units="unitless") - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - surface_density: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + kappa: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Calculates the reduced deflection angle. @@ -109,9 +104,6 @@ def reduced_deflection_angle( *Unit: unitless* - params: (Packed, optional) - Dynamic parameter container. - Returns ------- x_component: Tensor @@ -125,38 +117,28 @@ def reduced_deflection_angle( *Unit: arcsec* """ - return func.reduced_deflection_angle_mass_sheet(x0, y0, surface_density, x, y) + return func.reduced_deflection_angle_mass_sheet(x0, y0, kappa, x, y) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - surface_density: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + kappa: Annotated[Tensor, "Param"], ) -> Tensor: # Meneghetti eq 3.81 - return func.potential_mass_sheet(x0, y0, surface_density, x, y) + return func.potential_mass_sheet(x0, y0, kappa, x, y) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - surface_density: Optional[Tensor] = None, - **kwargs, + kappa: Annotated[Tensor, "Param"], ) -> Tensor: # Essentially by definition - return func.convergence_mass_sheet(surface_density, x) + return func.convergence_mass_sheet(kappa, x) diff --git a/src/caustics/lenses/multiplane.py b/src/caustics/lenses/multiplane.py index 3e74c1fa..0db647ec 100644 --- a/src/caustics/lenses/multiplane.py +++ b/src/caustics/lenses/multiplane.py @@ -1,13 +1,11 @@ from operator import itemgetter -from typing import Optional import torch from torch import Tensor +from caskade import forward from ..constants import arcsec_to_rad, rad_to_arcsec, c_Mpc_s, days_to_seconds from .base import ThickLens, NameType, CosmologyType, LensesType -from ..parametrized import unpack -from ..packed import Packed __all__ = ("Multiplane",) @@ -37,12 +35,10 @@ def __init__( super().__init__(cosmology, name=name) self.lenses = lenses for lens in lenses: - self.add_parametrized(lens) + self.link(lens.name, lens) - @unpack - def get_z_ls( - self, *args, params: Optional["Packed"] = None, **kwargs - ) -> list[Tensor]: + @forward + def get_z_ls(self) -> list[Tensor]: """ Get the redshifts of each lens in the multiplane. 
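The mass-sheet rename above (surface_density to kappa) is purely cosmetic; the underlying forms are easy to sanity-check, since the deflection of eq 3.84 is exactly the gradient of the potential of eq 3.81:

import torch

kappa = torch.tensor(0.1)
x = torch.tensor(0.7, requires_grad=True)
y = torch.tensor(-0.3, requires_grad=True)
psi = 0.5 * kappa * (x**2 + y**2)   # Meneghetti eq 3.81
ax, ay = torch.autograd.grad(psi, (x, y))
# Meneghetti eq 3.84: alpha = kappa * theta
assert torch.allclose(ax, kappa * x) and torch.allclose(ay, kappa * y)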
@@ -60,28 +56,25 @@ def get_z_ls( """ # Relies on z_l being the first element to be unpacked, which should always # be the case for a ThinLens - return [lens.unpack(params)[0] for lens in self.lenses] + return [lens.z_l.value for lens in self.lenses] - @unpack + @forward def _raytrace_helper( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, shapiro_time_delay: bool = True, geometric_time_delay: bool = True, ray_coords: bool = True, - **kwargs, ): # Collect lens redshifts and ensure proper order - z_ls = self.get_z_ls(params) + z_ls = self.get_z_ls() lens_planes = [i for i, _ in sorted(enumerate(z_ls), key=itemgetter(1))] - D_s = self.cosmology.transverse_comoving_distance(z_s, params) + D_s = self.cosmology.transverse_comoving_distance(z_s) # Compute physical position on first lens plane - D = self.cosmology.transverse_comoving_distance(z_ls[lens_planes[0]], params) + D = self.cosmology.transverse_comoving_distance(z_ls[lens_planes[0]]) X, Y = x * arcsec_to_rad * D, y * arcsec_to_rad * D # fmt: skip # Initial angles are observation angles @@ -94,19 +87,14 @@ def _raytrace_helper( for i in lens_planes: z_next = z_ls[i + 1] if i != lens_planes[-1] else z_s # Compute deflection angle at current ray positions - D_l = self.cosmology.transverse_comoving_distance(z_ls[i], params) - D = self.cosmology.transverse_comoving_distance_z1z2( - z_ls[i], z_next, params - ) - D_is = self.cosmology.transverse_comoving_distance_z1z2( - z_ls[i], z_s, params - ) - D_next = self.cosmology.transverse_comoving_distance(z_next, params) + D_l = self.cosmology.transverse_comoving_distance(z_ls[i]) + D = self.cosmology.transverse_comoving_distance_z1z2(z_ls[i], z_next) + D_is = self.cosmology.transverse_comoving_distance_z1z2(z_ls[i], z_s) + D_next = self.cosmology.transverse_comoving_distance(z_next) alpha_x, alpha_y = self.lenses[i].physical_deflection_angle( X * rad_to_arcsec / D_l, Y * rad_to_arcsec / D_l, z_s, - params, ) # Update angle of rays after passing through lens (sum in eq 18) @@ -118,7 +106,7 @@ def _raytrace_helper( if shapiro_time_delay: beta_ij = D * D_s / (D_next * D_is) potential = self.lenses[i].potential( - X * rad_to_arcsec / D_l, Y * rad_to_arcsec / D_l, z_s, params + X * rad_to_arcsec / D_l, Y * rad_to_arcsec / D_l, z_s ) TD += (-tau_ij * beta_ij * arcsec_to_rad**2) * potential if geometric_time_delay: @@ -129,7 +117,7 @@ def _raytrace_helper( Y = Y + D * theta_y * arcsec_to_rad # Convert from physical position to angular position on the source plane - D_end = self.cosmology.transverse_comoving_distance(z_s, params) + D_end = self.cosmology.transverse_comoving_distance(z_s) if ray_coords and not (shapiro_time_delay or geometric_time_delay): return X * rad_to_arcsec / D_end, Y * rad_to_arcsec / D_end elif ray_coords and (shapiro_time_delay or geometric_time_delay): @@ -140,15 +128,12 @@ def _raytrace_helper( "No return value specified. Must choose one or more of: ray_coords, shapiro_time_delay, or geometric_time_delay to be True." ) - @unpack + @forward def raytrace( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> tuple[Tensor, Tensor]: """Calculate the angular source positions corresponding to the observer positions x,y. See Margarita et al. 2013 for the @@ -159,20 +144,20 @@ def raytrace( .. 
math:: - \vec{x}^{i+1} = \vec{x}^i + D_{i+1,i}\left[\vec{\theta} - \sum_{j=1}^{i}\bf{\alpha}^j(\vec{x}^j)\right] + \\vec{x}^{i+1} = \\vec{x}^i + D_{i+1,i}\\left[\\vec{\\theta} - \\sum_{j=1}^{i}\\bf{\\alpha}^j(\\vec{x}^j)\\right] - As an initialization we set the physical positions at the first lensing plane to be :math:`\vec{\theta}D_{1,0}` which is just propagation through regular space to the first plane. - Note that :math:`\vec{\alpha}` is a physical deflection angle. The equation above converts straightforwardly into a recursion formula: + As an initialization we set the physical positions at the first lensing plane to be :math:`\\vec{\\theta}D_{1,0}` which is just propagation through regular space to the first plane. + Note that :math:`\\vec{\\alpha}` is a physical deflection angle. The equation above converts straightforwardly into a recursion formula: .. math:: - \vec{x}^{i+1} = \vec{x}^i + D_{i+1,i}\vec{\theta}^{i} - \vec{\theta}^{i+1} = \vec{\theta}^{i} - \alpha^i(\vec{x}^{i+1}) + \\vec{x}^{i+1} = \\vec{x}^i + D_{i+1,i}\\vec{\theta}^{i} + \\vec{\\theta}^{i+1} = \\vec{\\theta}^{i} - \\alpha^i(\\vec{x}^{i+1}) - Here we set as initialization :math:`\vec{\theta}^0 = theta` the observation angular coordinates and :math:`\vec{x}^0 = 0` the initial physical coordinates (i.e. the observation rays come from a point at the observer). - The indexing of :math:`\vec{x}^i` and :math:`\vec{\theta}^i` indicates the properties at the plane :math:`i`, + Here we set as initialization :math:`\\vec{\theta}^0 = theta` the observation angular coordinates and :math:`\\vec{x}^0 = 0` the initial physical coordinates (i.e. the observation rays come from a point at the observer). + The indexing of :math:`\\vec{x}^i` and :math:`\\vec{\\theta}^i` indicates the properties at the plane :math:`i`, and 0 means the observer, 1 is the first lensing plane (infinitesimally after the plane since the deflection has been applied), - and so on. Note that in the actual implementation we start at :math:`\vec{x}^1` and :math:`\vec{\theta}^0` + and so on. Note that in the actual implementation we start at :math:`\\vec{x}^1` and :math:`\\vec{\\theta}^0` and begin at the second step in the recursion formula. Parameters @@ -216,34 +201,28 @@ def raytrace( x, y, z_s, - params, shapiro_time_delay=False, geometric_time_delay=False, ray_coords=True, - **kwargs, ) - @unpack + @forward def effective_reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> tuple[Tensor, Tensor]: - bx, by = self.raytrace(x, y, z_s, params) + bx, by = self.raytrace(x, y, z_s) return x - bx, y - by - @unpack + @forward def surface_density( self, x: Tensor, y: Tensor, z_s: Tensor, *args, - params: Optional["Packed"] = None, **kwargs, ) -> Tensor: """ @@ -284,29 +263,26 @@ def surface_density( # TODO: rescale mass densities of each lens and sum raise NotImplementedError() - @unpack + @forward def time_delay( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, shapiro_time_delay: bool = True, geometric_time_delay: bool = True, - **kwargs, ) -> Tensor: """ Compute the time delay of light caused by the lensing. This is based on equation 6.22 in Petters et al. 2001. 
For the time delay of a light path from the observer to the source, the following equation is used:: - \Delta t = \sum_{i=1}^{N-1} \tau_{i,i+1} \left[ \frac{1}{2} \left( \vec{\alpha}^i \right)^2 - \beta_{i,i+1} \psi^i \right] \\ - \tau_{i,j} = (1 + z_i) \frac{D_i D_{j}}{D_{i,j} c} \\ - \beta_{i,j} = \frac{D_{i,j} D_s}{D_{j} D_{i,s}} \\ + \\Delta t = \\sum_{i=1}^{N-1} \\tau_{i,i+1} \\left[ \\frac{1}{2} \\left( \\vec{\\alpha}^i \\right)^2 - \\beta_{i,i+1} \\psi^i \\right] \\\\ + \\tau_{i,j} = (1 + z_i) \\frac{D_i D_{j}}{D_{i,j} c} \\\\ + \\beta_{i,j} = \\frac{D_{i,j} D_s}{D_{j} D_{i,s}} \\\\ - where :math:`\vec{\alpha}^i` is the deflection angle at the i-th lens plane, - :math:`\psi^i` is the lensing potential at the i-th lens plane, + where :math:`\\vec{\\alpha}^i` is the deflection angle at the i-th lens plane, + :math:`\\psi^i` is the lensing potential at the i-th lens plane, :math:`D_i` is the comoving distance to the i-th lens plane, :math:`D_{i,j}` is the comoving distance between the i-th and j-th lens plane, :math:`D_s` is the comoving distance to the source, @@ -357,9 +333,7 @@ def time_delay( x, y, z_s, - params, shapiro_time_delay=shapiro_time_delay, geometric_time_delay=geometric_time_delay, ray_coords=False, - **kwargs, ) diff --git a/src/caustics/lenses/multipole.py b/src/caustics/lenses/multipole.py index 48f866a8..d6eb97e8 100644 --- a/src/caustics/lenses/multipole.py +++ b/src/caustics/lenses/multipole.py @@ -2,11 +2,10 @@ from typing import Optional, Union, Annotated import torch -from torch import Tensor +from torch import Tensor, pi +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("Multipole",) @@ -60,12 +59,19 @@ def __init__( ): super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") self.m = torch.as_tensor(m, dtype=torch.int32) assert torch.all(self.m >= 2).item(), "Multipole order must be >= 2" - self.add_param("a_m", a_m, self.m.shape) - self.add_param("phi_m", phi_m, self.m.shape) + self.a_m = Param("a_m", a_m, self.m.shape, units="unitless") + self.phi_m = Param( + "phi_m", + phi_m, + self.m.shape, + units="radians", + valid=(0, 2 * pi), + cyclic=True, + ) def to(self, device: torch.device = None, dtype: torch.dtype = None): """ @@ -89,20 +95,16 @@ def to(self, device: torch.device = None, dtype: torch.dtype = None): self.m = self.m.to(device, torch.int32) return self - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - a_m: Optional[Tensor] = None, - phi_m: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + a_m: Annotated[Tensor, "Param"], + phi_m: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Calculate the deflection angle of the multipole. 
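The multipole profile now checks its orders eagerly and registers a_m and phi_m with a shape matching m. A sketch of constructing one with two orders; it assumes a top-level `caustics.Multipole` export mirroring the lenses __init__ above, and the amplitudes and angles are illustrative:

import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
lens = caustics.Multipole(
    cosmology=cosmo, z_l=0.5,
    x0=0.0, y0=0.0,
    m=(3, 4),                        # orders, must all be >= 2
    a_m=torch.tensor([0.02, 0.01]),  # one amplitude per order
    phi_m=torch.tensor([0.3, 1.1]),  # one angle per order, cyclic in (0, 2pi)
)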
@@ -144,20 +146,16 @@ def reduced_deflection_angle( """ return func.reduced_deflection_angle_multipole(x0, y0, self.m, a_m, phi_m, x, y) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - a_m: Optional[Tensor] = None, - phi_m: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + a_m: Annotated[Tensor, "Param"], + phi_m: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential of the multiplane. @@ -194,20 +192,16 @@ def potential( """ return func.potential_multipole(x0, y0, self.m, a_m, phi_m, x, y) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - a_m: Optional[Tensor] = None, - phi_m: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + a_m: Annotated[Tensor, "Param"], + phi_m: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the projected mass density of the multipole. diff --git a/src/caustics/lenses/nfw.py b/src/caustics/lenses/nfw.py index 8c03740b..3b6119a4 100644 --- a/src/caustics/lenses/nfw.py +++ b/src/caustics/lenses/nfw.py @@ -2,10 +2,9 @@ from typing import Optional, Union, Annotated, Literal from torch import Tensor +from caskade import forward, Param from .base import ThinLens, NameType, CosmologyType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func DELTA = 200.0 @@ -156,10 +155,10 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("m", m) - self.add_param("c", c) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.m = Param("m", m, units="Msun") + self.c = Param("c", c, units="unitless") self.s = s if use_case == "batchable": self._f = func._f_batchable_nfw @@ -172,17 +171,12 @@ def __init__( else: raise ValueError("use case should be one of: batchable, differentiable") - @unpack + @forward def get_scale_radius( self, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - m: Optional[Tensor] = None, - c: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + m: Annotated[Tensor, "Param"], + c: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the scale radius of the lens. @@ -217,20 +211,14 @@ def get_scale_radius( *Unit: Mpc* """ - critical_density = self.cosmology.critical_density(z_l, params) + critical_density = self.cosmology.critical_density(z_l) return func.scale_radius_nfw(critical_density, m, c, DELTA) - @unpack + @forward def get_scale_density( self, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - m: Optional[Tensor] = None, - c: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + c: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the scale density of the lens. 
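get_scale_radius above defers to func.scale_radius_nfw; under the standard overdensity definition (M = (4/3) pi DELTA rho_c r_200^3, then r_s = r_200 / c) the computation is a one-liner. A sketch under that assumption, with illustrative numbers:

import torch

DELTA = 200.0

def scale_radius_nfw(critical_density, m, c, DELTA=DELTA):
    # r_200 from M = (4/3) pi DELTA rho_c r_200^3, then r_s = r_200 / c
    r_200 = (3 * m / (4 * torch.pi * DELTA * critical_density)) ** (1 / 3)
    return r_200 / c

rho_c = torch.tensor(1.3e11)  # Msun / Mpc^3, illustrative critical density
m = torch.tensor(1e12)        # Msun
c = torch.tensor(8.0)
print(scale_radius_nfw(rho_c, m, c))  # ~0.026 Mpc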
@@ -258,23 +246,20 @@ def get_scale_density( *Unit: Msun/Mpc^3* """ - critical_density = self.cosmology.critical_density(z_l, params) + critical_density = self.cosmology.critical_density(z_l) return func.scale_density_nfw(critical_density, c, DELTA) - @unpack + @forward def physical_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - m: Optional[Tensor] = None, - c: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + m: Annotated[Tensor, "Param"], + c: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Compute the physical deflection angle. @@ -312,26 +297,23 @@ def physical_deflection_angle( *Unit: arcsec* """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - critical_density = self.cosmology.critical_density(z_l, params) + d_l = self.cosmology.angular_diameter_distance(z_l) + critical_density = self.cosmology.critical_density(z_l) return func.physical_deflection_angle_nfw( x0, y0, m, c, critical_density, d_l, x, y, _h=self._h, DELTA=DELTA, s=self.s ) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - m: Optional[Tensor] = None, - c: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + m: Annotated[Tensor, "Param"], + c: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the convergence (dimensionless surface mass density). @@ -364,11 +346,9 @@ def convergence( *Unit: unitless* """ - critical_surface_density = self.cosmology.critical_surface_density( - z_l, z_s, params - ) - critical_density = self.cosmology.critical_density(z_l, params) - d_l = self.cosmology.angular_diameter_distance(z_l, params) + critical_surface_density = self.cosmology.critical_surface_density(z_l, z_s) + critical_density = self.cosmology.critical_density(z_l) + d_l = self.cosmology.angular_diameter_distance(z_l) return func.convergence_nfw( critical_surface_density, critical_density, @@ -384,20 +364,17 @@ def convergence( s=self.s, ) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - m: Optional[Tensor] = None, - c: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + m: Annotated[Tensor, "Param"], + c: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential. 
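The convergence above divides the physical surface density by the critical surface density, which per the cosmology hunk at the top of this patch is Sigma_cr = D_s / (4 pi (G/c^2) D_l D_ls). A standalone sketch with an approximate G/c^2 (caustics.constants carries the exact value) and illustrative distances:

import torch

G_over_c2 = 4.785e-20  # Mpc / Msun, approximate

def critical_surface_density(d_l, d_s, d_ls):
    # Sigma_cr in Msun / Mpc^2
    return d_s / (4 * torch.pi * G_over_c2 * d_l * d_ls)

print(critical_surface_density(
    torch.tensor(1280.0), torch.tensor(1750.0), torch.tensor(950.0)
))  # ~2.4e15 Msun / Mpc^2, a typical lensing value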
@@ -430,11 +407,9 @@ def potential( *Unit: arcsec^2* """ - critical_surface_density = self.cosmology.critical_surface_density( - z_l, z_s, params - ) - critical_density = self.cosmology.critical_density(z_l, params) - d_l = self.cosmology.angular_diameter_distance(z_l, params) + critical_surface_density = self.cosmology.critical_surface_density(z_l, z_s) + critical_density = self.cosmology.critical_density(z_l) + d_l = self.cosmology.angular_diameter_distance(z_l) return func.potential_nfw( critical_surface_density, critical_density, diff --git a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index c11529a8..bf725c02 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -4,11 +4,10 @@ import torch from torch import Tensor import numpy as np +from caskade import forward, Param from ..utils import interp2d from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("PixelatedConvergence",) @@ -141,9 +140,11 @@ def __init__( elif shape is not None and len(shape) != 2: raise ValueError(f"shape must specify a 2D tensor. Received shape={shape}") - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("convergence_map", convergence_map, shape) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.convergence_map = Param( + "convergence_map", convergence_map, shape, units="unitless" + ) if convergence_map is not None: self.n_pix = convergence_map.shape[0] @@ -245,19 +246,15 @@ def convolution_mode(self, convolution_mode: str): self._convolution_mode = convolution_mode - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - convergence_map: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + convergence_map: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Compute the deflection angles at the specified positions using the given convergence map. @@ -310,19 +307,15 @@ def reduced_deflection_angle( self.convolution_mode, ) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - convergence_map: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + convergence_map: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential at the specified positions using the given convergence map. @@ -369,19 +362,15 @@ def potential( self.convolution_mode, ) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - convergence_map: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + convergence_map: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the convergence at the specified positions. This method is not implemented. 
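The same pattern applies to map-based lenses, where the `convergence_map` itself is now a `Param` and can therefore be swapped or optimized like any scalar parameter. A sketch, assuming `pixelscale` remains an `__init__` argument as in earlier releases:

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
# Toy 32x32 Gaussian convergence map; in practice this comes from data or simulation.
profile = torch.exp(-torch.linspace(-2.0, 2.0, 32) ** 2 / 2)
kappa_map = profile[:, None] * profile[None, :]

lens = caustics.PixelatedConvergence(
    cosmology=cosmo,
    z_l=0.5,
    x0=0.0,
    y0=0.0,
    convergence_map=kappa_map,
    pixelscale=0.05,  # arcsec / pixel; assumed to still be a constructor argument
)
theta = torch.linspace(-0.5, 0.5, 16)
X, Y = torch.meshgrid(theta, theta, indexing="xy")
ax, ay = lens.reduced_deflection_angle(X, Y, z_s=torch.tensor(1.0))
```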
diff --git a/src/caustics/lenses/pixelated_potential.py b/src/caustics/lenses/pixelated_potential.py index 816176c7..ab786b85 100644 --- a/src/caustics/lenses/pixelated_potential.py +++ b/src/caustics/lenses/pixelated_potential.py @@ -4,11 +4,10 @@ import torch from torch import Tensor import numpy as np +from caskade import forward, Param from ..utils import interp_bicubic from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed __all__ = ("PixelatedPotential",) @@ -101,9 +100,11 @@ def __init__( elif shape is not None and len(shape) != 2: raise ValueError(f"shape must specify a 2D tensor. Received shape={shape}") - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("potential_map", potential_map, shape) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.potential_map = Param( + "potential_map", potential_map, shape, units="unitless" + ) self.pixelscale = pixelscale if potential_map is not None: @@ -114,19 +115,15 @@ def __init__( raise ValueError("Either potential_map or shape must be provided") self.fov = self.n_pix * self.pixelscale - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - potential_map: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + potential_map: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Compute the deflection angles at the specified positions using the given convergence map. @@ -177,19 +174,15 @@ def reduced_deflection_angle( ) ) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - potential_map: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + potential_map: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential at the specified positions using the given convergence map. @@ -231,19 +224,15 @@ def potential( get_ddY=False, )[0].reshape(x.shape) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - potential_map: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + potential_map: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the convergence at the specified positions. This method is not implemented. diff --git a/src/caustics/lenses/point.py b/src/caustics/lenses/point.py index 699c37f3..13f29b73 100644 --- a/src/caustics/lenses/point.py +++ b/src/caustics/lenses/point.py @@ -2,10 +2,9 @@ from typing import Optional, Union, Annotated from torch import Tensor +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . 
import func __all__ = ("Point",) @@ -117,24 +116,76 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("th_ein", th_ein) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.th_ein = Param("th_ein", th_ein, units="arcsec", valid=(0, None)) self.s = s - @unpack + @forward + def mass_to_rein( + self, mass: Tensor, z_s: Tensor, z_l: Annotated[Tensor, "Param"] + ) -> Tensor: + """ + Convert mass to the Einstein radius. + + Parameters + ---------- + mass: Tensor + The mass of the lens + + *Unit: solar mass* + + Returns + ------- + Tensor + The Einstein radius. + + *Unit: arcsec* + + """ + + Dls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) + Dl = self.cosmology.angular_diameter_distance(z_l) + Ds = self.cosmology.angular_diameter_distance(z_s) + return func.mass_to_rein_point(mass, Dls, Dl, Ds) + + @forward + def rein_to_mass( + self, r: Tensor, z_s: Tensor, z_l: Annotated[Tensor, "Param"] + ) -> Tensor: + """ + Convert Einstein radius to mass. + + Parameters + ---------- + r: Tensor + The Einstein radius. + + *Unit: arcsec* + + Returns + ------- + Tensor + The mass of the lens + + *Unit: solar mass* + + """ + + Dls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) + Dl = self.cosmology.angular_diameter_distance(z_l) + Ds = self.cosmology.angular_diameter_distance(z_s) + return func.rein_to_mass_point(r, Dls, Dl, Ds) + + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - th_ein: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + th_ein: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Compute the deflection angles. @@ -174,19 +225,15 @@ def reduced_deflection_angle( """ return func.reduced_deflection_angle_point(x0, y0, th_ein, x, y, self.s) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - th_ein: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + th_ein: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential. @@ -221,19 +268,14 @@ def potential( """ return func.potential_point(x0, y0, th_ein, x, y, self.s) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - th_ein: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the convergence (dimensionless surface mass density). diff --git a/src/caustics/lenses/pseudo_jaffe.py b/src/caustics/lenses/pseudo_jaffe.py index 64f40e83..2869bcfa 100644 --- a/src/caustics/lenses/pseudo_jaffe.py +++ b/src/caustics/lenses/pseudo_jaffe.py @@ -4,11 +4,10 @@ import torch from torch import Tensor +from caskade import forward, Param from ..constants import arcsec_to_rad from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . 
import func __all__ = ("PseudoJaffe",) @@ -151,45 +150,38 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("mass", mass) - self.add_param("core_radius", core_radius) - self.add_param("scale_radius", scale_radius) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.mass = Param("mass", mass, units="Msun", valid=(0, None)) + self.core_radius = Param( + "core_radius", core_radius, units="arcsec", valid=(0, None) + ) + self.scale_radius = Param( + "scale_radius", scale_radius, units="arcsec", valid=(0, None) + ) self.s = s - @unpack + @forward def get_convergence_0( self, z_s, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - core_radius: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + core_radius: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ): - d_l = self.cosmology.angular_diameter_distance(z_l, params) - sigma_crit = self.cosmology.critical_surface_density(z_l, z_s, params) + d_l = self.cosmology.angular_diameter_distance(z_l) + sigma_crit = self.cosmology.critical_surface_density(z_l, z_s) return mass / (2 * torch.pi * sigma_crit * core_radius * scale_radius * (d_l * arcsec_to_rad) ** 2) # fmt: skip - @unpack + @forward def mass_enclosed_2d( self, theta, z_s, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - core_radius: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - **kwargs, + mass: Annotated[Tensor, "Param"], + core_radius: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ): """ Calculate the mass enclosed within a two-dimensional radius. Using equation A10 from `Eliasdottir et al 2007 `_. @@ -273,21 +265,18 @@ def central_convergence( """ return pi * rho_0 * core_radius * scale_radius / ((core_radius + scale_radius) * critical_surface_density) # fmt: skip - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - core_radius: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + core_radius: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """Calculate the deflection angle. 
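Helper methods such as `mass_enclosed_2d` are `@forward` too, so they resolve `mass`, `core_radius`, and `scale_radius` exactly the way the lensing methods do. A sketch, assuming `mass` is the total mass of the profile so the enclosed projected mass should approach it at large aperture:

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
pj = caustics.PseudoJaffe(
    cosmology=cosmo,
    z_l=0.5,
    x0=0.0,
    y0=0.0,
    mass=1e11,          # Msun
    core_radius=0.1,    # arcsec
    scale_radius=1.0,   # arcsec
)
z_s = torch.tensor(1.5)
# Projected mass inside growing apertures; expected to approach the total mass.
for theta in (0.5, 2.0, 10.0):
    print(theta, float(pj.mass_enclosed_2d(torch.tensor(theta), z_s)))
```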
@@ -324,27 +313,22 @@ def reduced_deflection_angle( *Unit: arcsec* """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - critical_surface_density = self.cosmology.critical_surface_density( - z_l, z_s, params - ) + d_l = self.cosmology.angular_diameter_distance(z_l) + critical_surface_density = self.cosmology.critical_surface_density(z_l, z_s) return func.reduced_deflection_angle_pseudo_jaffe(x0, y0, mass, core_radius, scale_radius, x, y, d_l, critical_surface_density) # fmt: skip - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - core_radius: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + core_radius: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential. This calculation is based on equation A18 from `Eliasdottir et al 2007 `_. @@ -378,27 +362,24 @@ def potential( """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) # Mpc - d_s = self.cosmology.angular_diameter_distance(z_s, params) # Mpc - d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s, params) # Mpc + d_l = self.cosmology.angular_diameter_distance(z_l) # Mpc + d_s = self.cosmology.angular_diameter_distance(z_s) # Mpc + d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) # Mpc return func.potential_pseudo_jaffe(x0, y0, mass, core_radius, scale_radius, x, y, d_l, d_s, d_ls) # fmt: skip - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - core_radius: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + core_radius: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the projected mass density, based on equation A6. @@ -431,10 +412,8 @@ def convergence( *Unit: unitless* """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - critical_surface_density = self.cosmology.critical_surface_density( - z_l, z_s, params - ) + d_l = self.cosmology.angular_diameter_distance(z_l) + critical_surface_density = self.cosmology.critical_surface_density(z_l, z_s) return func.convergence_pseudo_jaffe( x0, y0, mass, core_radius, scale_radius, x, y, d_l, critical_surface_density ) diff --git a/src/caustics/lenses/sie.py b/src/caustics/lenses/sie.py index 7a3e8f3d..1968d777 100644 --- a/src/caustics/lenses/sie.py +++ b/src/caustics/lenses/sie.py @@ -1,11 +1,10 @@ # mypy: disable-error-code="operator,union-attr,dict-item" from typing import Optional, Union, Annotated -from torch import Tensor +from torch import Tensor, pi +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . 
import func __all__ = ("SIE",) @@ -98,11 +97,11 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("q", q) - self.add_param("phi", phi) - self.add_param("b", b) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.q = Param("q", q, units="unitless", valid=(0, 1)) + self.phi = Param("phi", phi, units="radians", valid=(0, pi), cyclic=True) + self.b = Param("b", b, units="arcsec", valid=(0, None)) self.s = s def _get_potential(self, x, y, q): @@ -136,21 +135,17 @@ def _get_potential(self, x, y, q): """ return (q**2 * (x**2 + self.s**2) + y**2).sqrt() # fmt: skip - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - b: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + b: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Calculate the physical deflection angle. @@ -190,21 +185,17 @@ def reduced_deflection_angle( """ return func.reduced_deflection_angle_sie(x0, y0, q, phi, b, x, y, self.s) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - x0: Optional[Tensor] = None, - z_l: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - b: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + b: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential. @@ -239,21 +230,17 @@ def potential( """ return func.potential_sie(x0, y0, q, phi, b, x, y, self.s) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - b: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + b: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the projected mass density. diff --git a/src/caustics/lenses/singleplane.py b/src/caustics/lenses/singleplane.py index 2f6e2ea4..5b2cb465 100644 --- a/src/caustics/lenses/singleplane.py +++ b/src/caustics/lenses/singleplane.py @@ -1,11 +1,8 @@ -from typing import Optional - import torch from torch import Tensor +from caskade import forward from .base import ThinLens, CosmologyType, NameType, LensesType, ZLType -from ..parametrized import unpack -from ..packed import Packed __all__ = ("SinglePlane",) @@ -41,18 +38,15 @@ def __init__( super().__init__(cosmology, z_l=z_l, name=name) self.lenses = lenses for lens in lenses: - self.add_parametrized(lens) + self.link(lens.name, lens) # TODO: assert all z_l are the same? 
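Note the new `valid` and `cyclic` metadata on the SIE parameters above (`q` restricted to (0, 1), `phi` cyclic on (0, pi)). A sketch of how that metadata might be inspected after construction; the `units` and `cyclic` attribute names on `Param` are assumptions inferred from the constructor keywords.

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
sie = caustics.SIE(cosmology=cosmo, z_l=0.5, x0=0.0, y0=0.0, q=0.7, phi=0.4, b=1.0)

# Metadata declared on the Params should now be introspectable
# (attribute names below are assumed, not shown in this patch).
print(sie.q.units)     # "unitless"
print(sie.phi.cyclic)  # True

ax, ay = sie.reduced_deflection_angle(
    torch.tensor(0.5), torch.tensor(0.2), z_s=torch.tensor(1.0)
)
```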
- @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> tuple[Tensor, Tensor]: """ Calculate the total deflection angle by summing @@ -94,20 +88,17 @@ def reduced_deflection_angle( ax = torch.zeros_like(x) ay = torch.zeros_like(x) for lens in self.lenses: - ax_cur, ay_cur = lens.reduced_deflection_angle(x, y, z_s, params) + ax_cur, ay_cur = lens.reduced_deflection_angle(x, y, z_s) ax = ax + ax_cur ay = ay + ay_cur return ax, ay - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> Tensor: """ Calculate the total projected mass density by @@ -143,19 +134,16 @@ def convergence( """ convergence = torch.zeros_like(x) for lens in self.lenses: - convergence_cur = lens.convergence(x, y, z_s, params) + convergence_cur = lens.convergence(x, y, z_s) convergence = convergence + convergence_cur return convergence - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - **kwargs, ) -> Tensor: """ Compute the total lensing potential by summing @@ -191,6 +179,6 @@ def potential( """ potential = torch.zeros_like(x) for lens in self.lenses: - potential_cur = lens.potential(x, y, z_s, params) + potential_cur = lens.potential(x, y, z_s) potential = potential + potential_cur return potential diff --git a/src/caustics/lenses/sis.py b/src/caustics/lenses/sis.py index 13aa9a57..7d25beec 100644 --- a/src/caustics/lenses/sis.py +++ b/src/caustics/lenses/sis.py @@ -2,10 +2,9 @@ from typing import Optional, Union, Annotated from torch import Tensor +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("SIS",) @@ -76,24 +75,20 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("th_ein", th_ein) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.th_ein = Param("th_ein", th_ein, units="arcsec", valid=(0, None)) self.s = s - @unpack + @forward def reduced_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - th_ein: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + th_ein: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """ Calculate the deflection angle of the SIS lens. @@ -133,19 +128,15 @@ def reduced_deflection_angle( """ return func.reduced_deflection_angle_sis(x0, y0, th_ein, x, y, self.s) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - th_ein: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + th_ein: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential of the SIS lens. 
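`SinglePlane` now composes its sub-lenses with `self.link(...)` and simply sums their contributions inside each `@forward` method, so the loop above makes the following equivalence hold by construction. A sketch, assuming the constructor takes a `lenses` list as shown in `__init__`:

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
sie = caustics.SIE(cosmology=cosmo, z_l=0.5, x0=0.0, y0=0.0, q=0.7, phi=0.4, b=1.0)
pt = caustics.Point(cosmology=cosmo, z_l=0.5, x0=0.3, y0=-0.2, th_ein=0.1)
plane = caustics.SinglePlane(cosmology=cosmo, lenses=[sie, pt], z_l=0.5)

x, y, z_s = torch.tensor(0.4), torch.tensor(0.1), torch.tensor(1.2)
ax, ay = plane.reduced_deflection_angle(x, y, z_s)
# By construction this equals the sum over the component lenses:
ax1, ay1 = sie.reduced_deflection_angle(x, y, z_s)
ax2, ay2 = pt.reduced_deflection_angle(x, y, z_s)
assert torch.allclose(ax, ax1 + ax2) and torch.allclose(ay, ay1 + ay2)
```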
@@ -180,19 +171,15 @@ def potential( """ return func.potential_sis(x0, y0, th_ein, x, y, self.s) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional["Packed"] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - th_ein: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + th_ein: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the projected mass density of the SIS lens. diff --git a/src/caustics/lenses/tnfw.py b/src/caustics/lenses/tnfw.py index 4cc85841..69bc3a76 100644 --- a/src/caustics/lenses/tnfw.py +++ b/src/caustics/lenses/tnfw.py @@ -2,10 +2,9 @@ from typing import Optional, Union, Literal, Annotated from torch import Tensor +from caskade import forward, Param from .base import ThinLens, CosmologyType, NameType, ZLType -from ..parametrized import unpack -from ..packed import Packed from . import func DELTA = 200.0 @@ -154,29 +153,25 @@ def __init__( """ super().__init__(cosmology, z_l, name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("mass", mass) - self.add_param("scale_radius", scale_radius) - self.add_param("tau", tau) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.mass = Param("mass", mass, units="Msun", valid=(0, None)) + self.scale_radius = Param( + "scale_radius", scale_radius, units="arcsec", valid=(0, None) + ) + self.tau = Param("tau", tau, units="unitless", valid=(0, None)) self.s = s self.interpret_m_total_mass = interpret_m_total_mass self._F_mode = use_case if use_case not in ["batchable", "differentiable"]: raise ValueError("use case should be one of: batchable, differentiable") - @unpack + @forward def get_concentration( self, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the concentration parameter "c" for a TNFW profile. @@ -224,22 +219,15 @@ def get_concentration( *Unit: unitless* """ - critical_density = self.cosmology.critical_density(z_l, params) - d_l = self.cosmology.angular_diameter_distance(z_l, params) + critical_density = self.cosmology.critical_density(z_l) + d_l = self.cosmology.angular_diameter_distance(z_l) return func.concentration_tnfw(mass, scale_radius, critical_density, d_l, DELTA) - @unpack + @forward def get_truncation_radius( self, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + scale_radius: Annotated[Tensor, "Param"], + tau: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the truncation radius of the TNFW lens. 
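For the SIS the refactor is mechanical, and the profile keeps its defining property that the deflection magnitude equals `th_ein` at every position (with the default softening `s = 0`). A quick sketch to check it:

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
sis = caustics.SIS(cosmology=cosmo, z_l=0.5, x0=0.0, y0=0.0, th_ein=1.2)

ax, ay = sis.reduced_deflection_angle(
    torch.tensor(0.9), torch.tensor(-0.7), z_s=torch.tensor(2.0)
)
print(torch.hypot(ax, ay))  # ~1.2, independent of position
```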
@@ -289,18 +277,13 @@ def get_truncation_radius( """ return tau * scale_radius - @unpack + @forward def M0( self, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], + tau: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the reference mass. @@ -353,25 +336,19 @@ def M0( if self.interpret_m_total_mass: return func.M0_totmass_tnfw(mass, tau) else: - d_l = self.cosmology.angular_diameter_distance(z_l, params) - critical_density = self.cosmology.critical_density(z_l, params) + d_l = self.cosmology.angular_diameter_distance(z_l) + critical_density = self.cosmology.critical_density(z_l) c = func.concentration_tnfw( mass, scale_radius, critical_density, d_l, DELTA ) return func.M0_scalemass_tnfw(scale_radius, c, critical_density, d_l, DELTA) - @unpack + @forward def get_scale_density( self, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], ) -> Tensor: """ Calculate the scale density of the lens. @@ -419,26 +396,23 @@ def get_scale_density( *Unit: Msun/Mpc^3* """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - critical_density = self.cosmology.critical_density(z_l, params) + d_l = self.cosmology.angular_diameter_distance(z_l) + critical_density = self.cosmology.critical_density(z_l) c = func.concentration_tnfw(mass, scale_radius, critical_density, d_l, DELTA) return func.scale_density_tnfw(c, critical_density, DELTA) - @unpack + @forward def convergence( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], + tau: Annotated[Tensor, "Param"], ) -> Tensor: """ TNFW convergence as given in Baltz et al. 2009. 
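`M0` branches on `interpret_m_total_mass`, and with static `Param` values the TNFW helper methods can be invoked with no arguments at all. A sketch (the printed values are what the formulas above imply, not outputs from a test run):

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
tnfw = caustics.TNFW(
    cosmology=cosmo,
    z_l=0.5,
    x0=0.0,
    y0=0.0,
    mass=1e12,          # Msun
    scale_radius=1.0,   # arcsec
    tau=3.0,
    interpret_m_total_mass=True,  # mass is read as the total halo mass
)
print(float(tnfw.get_truncation_radius()))  # tau * scale_radius = 3.0 arcsec
print(float(tnfw.M0()))                     # reference mass from the total-mass branch
```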
@@ -488,9 +462,9 @@ def convergence( """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - critical_density = self.cosmology.critical_surface_density(z_l, z_s, params) - M0 = self.M0(params) + d_l = self.cosmology.angular_diameter_distance(z_l) + critical_density = self.cosmology.critical_surface_density(z_l, z_s) + M0 = self.M0(z_l=z_l, mass=mass, scale_radius=scale_radius, tau=tau) return func.convergence_tnfw( x0, y0, @@ -505,20 +479,13 @@ def convergence( self.s, ) - @unpack + @forward def mass_enclosed_2d( self, r: Tensor, z_s: Tensor, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + scale_radius: Annotated[Tensor, "Param"], + tau: Annotated[Tensor, "Param"], ) -> Tensor: """ Total projected mass (Msun) within a radius r (arcsec). @@ -567,24 +534,21 @@ def mass_enclosed_2d( """ - M0 = self.M0(params) + M0 = self.M0() return func.mass_enclosed_2d_tnfw(r, scale_radius, tau, M0, self._F_mode) - @unpack + @forward def physical_deflection_angle( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], + tau: Annotated[Tensor, "Param"], ) -> tuple[Tensor, Tensor]: """Compute the physical deflection angle (arcsec) for this lens at the requested position. Note that the NFW/TNFW profile is more @@ -639,27 +603,24 @@ def physical_deflection_angle( *Unit: arcsec* """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - M0 = self.M0(params) + d_l = self.cosmology.angular_diameter_distance(z_l) + M0 = self.M0(z_l=z_l, mass=mass, scale_radius=scale_radius, tau=tau) return func.physical_deflection_angle_tnfw( x0, y0, scale_radius, tau, x, y, M0, d_l, self._F_mode, self.s ) - @unpack + @forward def potential( self, x: Tensor, y: Tensor, z_s: Tensor, - *args, - params: Optional[Packed] = None, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - mass: Optional[Tensor] = None, - scale_radius: Optional[Tensor] = None, - tau: Optional[Tensor] = None, - **kwargs, + z_l: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + mass: Annotated[Tensor, "Param"], + scale_radius: Annotated[Tensor, "Param"], + tau: Annotated[Tensor, "Param"], ) -> Tensor: """ Compute the lensing potential. 
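Inside `convergence` the nested call now spells out `self.M0(z_l=z_l, mass=mass, scale_radius=scale_radius, tau=tau)`, i.e. a `@forward` method accepts explicit keyword overrides for its `Param` arguments, while `mass_enclosed_2d` relies on plain `self.M0()` resolution. A sketch exercising the latter, assuming `mass` is the total halo mass when `interpret_m_total_mass=True`, so the ratio should tend toward 1:

```python
import torch
import caustics

cosmo = caustics.FlatLambdaCDM(name="cosmo")
tnfw = caustics.TNFW(
    cosmology=cosmo, z_l=0.5, x0=0.0, y0=0.0,
    mass=1e12, scale_radius=1.0, tau=3.0, interpret_m_total_mass=True,
)
z_s = torch.tensor(1.5)
# Enclosed projected mass over total mass at growing radii (arcsec).
for r in (1.0, 10.0, 100.0):
    print(r, float(tnfw.mass_enclosed_2d(torch.tensor(r), z_s)) / 1e12)
```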
@@ -712,11 +673,11 @@ def potential( """ - d_l = self.cosmology.angular_diameter_distance(z_l, params) - d_s = self.cosmology.angular_diameter_distance(z_s, params) - d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s, params) + d_l = self.cosmology.angular_diameter_distance(z_l) + d_s = self.cosmology.angular_diameter_distance(z_s) + d_ls = self.cosmology.angular_diameter_distance_z1z2(z_l, z_s) - M0 = self.M0(params) + M0 = self.M0(z_l=z_l, mass=mass, scale_radius=scale_radius, tau=tau) return func.potential_tnfw( x0, y0, scale_radius, tau, x, y, M0, d_l, d_s, d_ls, self._F_mode, self.s ) diff --git a/src/caustics/light/base.py b/src/caustics/light/base.py index 3d636ca2..fb9490e3 100644 --- a/src/caustics/light/base.py +++ b/src/caustics/light/base.py @@ -2,22 +2,20 @@ from typing import Optional, Annotated from torch import Tensor - -from ..parametrized import Parametrized, unpack -from ..packed import Packed +from caskade import Module, forward __all__ = ("Source",) NameType = Annotated[Optional[str], "Name of the source"] -class Source(Parametrized): +class Source(Module): """ This is an abstract base class used to represent a source in a strong gravitational lensing system. It provides the basic structure and required methods that any derived source class should implement. - The Source class inherits from the Parametrized class, + The Source class inherits from the Module class, implying that it contains parameters that can be optimized or manipulated. @@ -28,10 +26,8 @@ class Source(Parametrized): """ @abstractmethod - @unpack - def brightness( - self, x: Tensor, y: Tensor, *args, params: Optional["Packed"] = None, **kwargs - ) -> Tensor: + @forward + def brightness(self, x: Tensor, y: Tensor, *args, **kwargs) -> Tensor: """ Abstract method that calculates the brightness of the source at the given coordinates. This method is expected to be implemented in any class that derives from Source. @@ -52,12 +48,6 @@ def brightness( *Unit: arcsec* - params: Packed, optional - Dynamic parameter container that might be required - to calculate the brightness. - The exact contents will depend on the specific - implementation in derived classes. 
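`Source` subclasses now inherit from `caskade.Module` and implement `brightness` as a `@forward` method. A minimal hypothetical subclass following the pattern in this patch; the `FlatDisk` class, its parameters, and its units are invented for illustration:

```python
from typing import Annotated

import torch
from torch import Tensor
from caskade import Param, forward

from caustics.light.base import Source


class FlatDisk(Source):
    """Hypothetical source: uniform brightness I0 inside a circular radius."""

    def __init__(self, radius=None, I0=None, name=None):
        super().__init__(name=name)
        self.radius = Param("radius", radius, units="arcsec", valid=(0, None))
        self.I0 = Param("I0", I0, units="flux", valid=(0, None))

    @forward
    def brightness(
        self,
        x: Tensor,
        y: Tensor,
        radius: Annotated[Tensor, "Param"],
        I0: Annotated[Tensor, "Param"],
    ) -> Tensor:
        return torch.where(x**2 + y**2 <= radius**2, I0, torch.zeros_like(x))


disk = FlatDisk(radius=1.0, I0=2.0, name="disk")
print(disk.brightness(torch.tensor(0.5), torch.tensor(0.0)))  # tensor(2.)
```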
- Returns ------- Tensor diff --git a/src/caustics/light/light_stack.py b/src/caustics/light/light_stack.py index 340e560c..45a5acf8 100644 --- a/src/caustics/light/light_stack.py +++ b/src/caustics/light/light_stack.py @@ -1,11 +1,10 @@ # mypy: disable-error-code="operator,union-attr" -from typing import Optional, Annotated, List +from typing import Annotated, List import torch +from caskade import forward from .base import Source, NameType -from ..parametrized import unpack -from ..packed import Packed __all__ = ("LightStack",) @@ -44,15 +43,13 @@ def __init__( super().__init__(name=name) self.light_models = light_models for model in light_models: - self.add_parametrized(model) + self.link(model.name, model) - @unpack + @forward def brightness( self, x, y, - *args, - params: Optional["Packed"] = None, **kwargs, ): """ @@ -88,5 +85,5 @@ def brightness( brightness = torch.zeros_like(x) for light_model in self.light_models: - brightness += light_model.brightness(x, y, params=params, **kwargs) + brightness += light_model.brightness(x, y, **kwargs) return brightness diff --git a/src/caustics/light/pixelated.py b/src/caustics/light/pixelated.py index dfcb3805..f4fa629a 100644 --- a/src/caustics/light/pixelated.py +++ b/src/caustics/light/pixelated.py @@ -2,11 +2,10 @@ from typing import Optional, Union, Annotated from torch import Tensor +from caskade import forward, Param from ..utils import interp2d from .base import Source, NameType -from ..parametrized import unpack -from ..packed import Packed __all__ = ("Pixelated",) @@ -116,23 +115,23 @@ def __init__( f"shape must be specify 2D or 3D tensors. Received shape={shape}" ) super().__init__(name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("image", image, shape) - self.add_param("pixelscale", pixelscale) - - @unpack + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.image = Param("image", image, shape, units="flux") + self.pixelscale = Param( + "pixelscale", pixelscale, units="arcsec/pixel", valid=(0, None) + ) + + @forward def brightness( self, x, y, - *args, - params: Optional["Packed"] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - image: Optional[Tensor] = None, - pixelscale: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + image: Annotated[Tensor, "Param"], + pixelscale: Annotated[Tensor, "Param"], + padding_mode: str = "zeros", ): """ Implements the `brightness` method for `Pixelated`. @@ -173,4 +172,5 @@ def brightness( image, (x - x0).view(-1) / fov_x * 2, (y - y0).view(-1) / fov_y * 2, # make coordinates bounds at half the fov + padding_mode=padding_mode, ).reshape(x.shape) diff --git a/src/caustics/light/pixelated_time.py b/src/caustics/light/pixelated_time.py index a7332fec..295303c8 100644 --- a/src/caustics/light/pixelated_time.py +++ b/src/caustics/light/pixelated_time.py @@ -2,11 +2,10 @@ from typing import Optional, Union, Annotated from torch import Tensor +from caskade import forward, Param from ..utils import interp3d from .base import Source, NameType -from ..parametrized import unpack -from ..packed import Packed __all__ = ("PixelatedTime",) @@ -127,24 +126,21 @@ def __init__( f"shape must be specify 3D or 4D tensors. 
Received shape={shape}" ) super().__init__(name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("cube", cube, shape) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.cube = Param("cube", cube, shape, units="flux") self.pixelscale = pixelscale self.t_end = t_end - @unpack + @forward def brightness( self, x, y, t, - *args, - params: Optional["Packed"] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - cube: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + cube: Annotated[Tensor, "Param"], ): """ Implements the `brightness` method for `Pixelated`. diff --git a/src/caustics/light/sersic.py b/src/caustics/light/sersic.py index 4423c96d..536358c3 100644 --- a/src/caustics/light/sersic.py +++ b/src/caustics/light/sersic.py @@ -1,11 +1,10 @@ # mypy: disable-error-code="operator,union-attr" from typing import Optional, Union, Annotated -from torch import Tensor +from torch import Tensor, pi +from caskade import forward, Param from .base import Source, NameType -from ..parametrized import unpack -from ..packed import Packed from . import func __all__ = ("Sersic",) @@ -165,32 +164,29 @@ def __init__( """ super().__init__(name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("q", q) - self.add_param("phi", phi) - self.add_param("n", n) - self.add_param("Re", Re) - self.add_param("Ie", Ie) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.q = Param("q", q, units="unitless", valid=(0, 1)) + self.phi = Param("phi", phi, units="radians", valid=(0, pi), cyclic=True) + self.n = Param("n", n, units="unitless", valid=(0.36, 10)) + self.Re = Param("Re", Re, units="arcsec", valid=(0, None)) + self.Ie = Param("Ie", Ie, units="flux", valid=(0, None)) self.s = s self.lenstronomy_k_mode = use_lenstronomy_k - @unpack + @forward def brightness( self, x, y, - *args, - params: Optional["Packed"] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - q: Optional[Tensor] = None, - phi: Optional[Tensor] = None, - n: Optional[Tensor] = None, - Re: Optional[Tensor] = None, - Ie: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + q: Annotated[Tensor, "Param"], + phi: Annotated[Tensor, "Param"], + n: Annotated[Tensor, "Param"], + Re: Annotated[Tensor, "Param"], + Ie: Annotated[Tensor, "Param"], ): """ Implements the `brightness` method for `Sersic`. The brightness at a given point is diff --git a/src/caustics/light/star_source.py b/src/caustics/light/star_source.py index 7dd2f3c1..a7c95658 100644 --- a/src/caustics/light/star_source.py +++ b/src/caustics/light/star_source.py @@ -2,10 +2,9 @@ from typing import Optional, Union, Annotated from torch import Tensor +from caskade import forward, Param from .base import Source, NameType -from ..parametrized import unpack -from ..packed import Packed from . 
import func __all__ = ("StarSource",) @@ -111,25 +110,22 @@ def __init__( """ super().__init__(name=name) - self.add_param("x0", x0) - self.add_param("y0", y0) - self.add_param("theta_s", theta_s) - self.add_param("Ie", Ie) - self.add_param("gamma", gamma) + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") + self.theta_s = Param("theta_s", theta_s, units="arcsec", valid=(0, None)) + self.Ie = Param("Ie", Ie, units="flux", valid=(0, None)) + self.gamma = Param("gamma", gamma, units="unitless", valid=(0, 1)) - @unpack + @forward def brightness( self, x, y, - *args, - params: Optional["Packed"] = None, - x0: Optional[Tensor] = None, - y0: Optional[Tensor] = None, - theta_s: Optional[Tensor] = None, - Ie: Optional[Tensor] = None, - gamma: Optional[Tensor] = None, - **kwargs, + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + theta_s: Annotated[Tensor, "Param"], + Ie: Annotated[Tensor, "Param"], + gamma: Annotated[Tensor, "Param"], ): """ Implements the `brightness` method for `star`. This method calculates the diff --git a/src/caustics/models/__init__.py b/src/caustics/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/caustics/models/api.py b/src/caustics/models/api.py deleted file mode 100644 index 9f5a03ee..00000000 --- a/src/caustics/models/api.py +++ /dev/null @@ -1,37 +0,0 @@ -# mypy: disable-error-code="import-untyped" -import yaml -from pathlib import Path -from typing import Union - -from ..sims.simulator import Simulator -from ..io import from_file - - -def build_simulator(config_path: Union[str, Path]) -> Simulator: - """ - Build a simulator from the configuration - """ - # Imports using Pydantic are placed here to make Pydantic a weak dependency - from caustics.models.utils import setup_simulator_models, create_model, Field - from caustics.models.base_models import StateConfig - - simulators = setup_simulator_models() - Config = create_model( - "Config", __base__=StateConfig, simulator=(simulators, Field(...)) - ) - - # Load the yaml config - yaml_bytes = from_file(config_path) - config_dict = yaml.safe_load(yaml_bytes) - # Create config model - config = Config(**config_dict) - - # Get the simulator - sim = config.simulator.model_obj() - - # Load state if available - simulator_state = config.state - if simulator_state is not None: - sim.load_state_dict(simulator_state.load.path) - - return sim diff --git a/src/caustics/models/base_models.py b/src/caustics/models/base_models.py deleted file mode 100644 index 3a973b1f..00000000 --- a/src/caustics/models/base_models.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Optional, Any, Dict -from pydantic import BaseModel, Field, ConfigDict -from ..parametrized import Parametrized - - -class Parameters(BaseModel): - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class InitKwargs(Parameters): - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class Base(BaseModel): - name: str = Field(..., description="Name of the object") - kind: str = Field(..., description="Kind of the object") - params: Optional[Parameters] = Field(None, description="Parameters of the object") - init_kwargs: Optional[InitKwargs] = Field( - None, description="Initiation keyword arguments for object creation" - ) - - # internal - _cls: Any - - def __init__(self, **data): - super().__init__(**data) - - def _get_init_kwargs_dump(self, init_kwargs: InitKwargs) -> Dict[str, Any]: - """ - Get the model dump of the class parameters, - if the field is a model 
then get the model object. - - Parameters - ---------- - init_kwargs : ClassParams - The class parameters to dump - - Returns - ------- - dict - The model dump of the class parameters - """ - model_dict = {} - for f in init_kwargs.model_fields_set: - model = getattr(init_kwargs, f) - if isinstance(model, Base): - model_dict[f] = model.model_obj() - elif isinstance(model, list): - model_dict[f] = [m.model_obj() for m in model] - else: - model_dict[f] = getattr(init_kwargs, f) - return model_dict - - @classmethod - def _set_class(cls, parametrized_cls: Parametrized) -> type["Base"]: - """ - Set the class of the object. - - Parameters - ---------- - cls : Parametrized - The class to set. - """ - cls._cls = parametrized_cls - return cls - - def model_obj(self) -> Any: - if not self._cls: - raise ValueError( - "The class is not set. Please set the class before calling this method." - ) - init_kwargs = ( - self._get_init_kwargs_dump(self.init_kwargs) if self.init_kwargs else {} - ) # Capture None case - params = self.params.model_dump() if self.params else {} # Capture None case - return self._cls(name=self.name, **init_kwargs, **params) - - -class FileInput(BaseModel): - path: str = Field(..., description="The path to the file") - - -class StateDict(BaseModel): - load: FileInput - - -class StateConfig(BaseModel): - state: Optional[StateDict] = Field( - None, description="State safetensor for the simulator" - ) diff --git a/src/caustics/models/registry.py b/src/caustics/models/registry.py deleted file mode 100644 index 118d0a95..00000000 --- a/src/caustics/models/registry.py +++ /dev/null @@ -1,127 +0,0 @@ -from functools import lru_cache -from collections import ChainMap -from typing import MutableMapping, Iterator, Optional - -from caustics.parametrized import Parametrized -from caustics.utils import _import_func_or_class - - -class _KindRegistry(MutableMapping[str, "Parametrized | str"]): - cosmology = { - "FlatLambdaCDM": "caustics.cosmology.FlatLambdaCDM.FlatLambdaCDM", - } - single_lenses = { - "EPL": "caustics.lenses.epl.EPL", - "ExternalShear": "caustics.lenses.external_shear.ExternalShear", - "PixelatedConvergence": "caustics.lenses.pixelated_convergence.PixelatedConvergence", - "NFW": "caustics.lenses.nfw.NFW", - "Point": "caustics.lenses.point.Point", - "PseudoJaffe": "caustics.lenses.pseudo_jaffe.PseudoJaffe", - "SIE": "caustics.lenses.sie.SIE", - "SIS": "caustics.lenses.sis.SIS", - "TNFW": "caustics.lenses.tnfw.TNFW", - "MassSheet": "caustics.lenses.mass_sheet.MassSheet", - "SinglePlane": "caustics.lenses.singleplane.SinglePlane", - "Multipole": "caustics.lenses.multipole.Multipole", - } - multi_lenses = { - "Multiplane": "caustics.lenses.multiplane.Multiplane", - } - light = { - "Pixelated": "caustics.light.pixelated.Pixelated", - "Sersic": "caustics.light.sersic.Sersic", - } - simulators = {"LensSource": "caustics.sims.lens_source.LensSource"} - - known_kinds = { - **cosmology, - **single_lenses, - **multi_lenses, - **light, - **simulators, - } - - def __init__(self) -> None: - self._m: ChainMap[str, "Parametrized | str"] = ChainMap({}, self.known_kinds) # type: ignore - - def __getitem__(self, item: str) -> Parametrized: - kind_mod: "str | Parametrized | None" = self._m.get(item, None) - if kind_mod is None: - raise KeyError(f"{item} not in registry") - if isinstance(kind_mod, str): - cls = _import_func_or_class(kind_mod) - else: - cls = kind_mod # type: ignore - return cls # type: ignore - - def __setitem__(self, item: str, value: "Parametrized | str") -> None: - if not ( - 
(isinstance(value, type) and issubclass(value, Parametrized)) - or isinstance(value, str) - ): - raise ValueError( - f"expected Parametrized subclass, got: {type(value).__name__!r}" - ) - self._m[item] = value - - def __delitem__(self, __v: str) -> None: - raise NotImplementedError("removal is unsupported") - - def __len__(self) -> int: - return len(set(self._m)) - - def __iter__(self) -> Iterator[str]: - return iter(set(self._m)) - - -_registry = _KindRegistry() - - -def available_kinds() -> list[str]: - """ - Return a list of classes that are available in the registry. - """ - return list(_registry) - - -def register_kind( - name: str, - cls: "Parametrized | str", - *, - clobber: bool = False, -) -> None: - """register a UPath implementation with a protocol - - Parameters - ---------- - name : str - Protocol name to associate with the class - cls : Parametrized or str - The caustics parametrized subclass or a str representing the - full path to the class like package.module.class. - clobber: - Whether to overwrite a protocol with the same name; if False, - will raise instead. - """ - if not clobber and name in _registry: - raise ValueError(f"{name!r} is already in registry and clobber is False!") - _registry[name] = cls - - -@lru_cache -def get_kind( - name: str, -) -> Optional[Parametrized]: - """Get a class from the registry by name. - - Parameters - ---------- - kind : str - The name of the kind to get. - - Returns - ------- - cls : Parametrized - The class associated with the given name. - """ - return _registry[name] diff --git a/src/caustics/models/utils.py b/src/caustics/models/utils.py deleted file mode 100644 index 6626e691..00000000 --- a/src/caustics/models/utils.py +++ /dev/null @@ -1,285 +0,0 @@ -# mypy: disable-error-code="union-attr, valid-type, has-type, assignment, arg-type, dict-item, return-value, misc" -import typing -import inspect -import torch -from typing import List, Literal, Dict, Annotated, Union, Any, Tuple - -try: - from pydantic import Field, create_model, field_validator, ValidationInfo -except ImportError: - raise ImportError( - "The `pydantic` package is required to use this feature. " - "You can install it using `pip install pydantic==2.7`. This package requires rust. Make sure you have the permissions to install the dependencies.\n " - "Otherwise, the maintainer can install the package for you, you can then use `pip install --no-index pydantic`" - ) - -from ..parametrized import Parametrized -from .base_models import Base, Parameters, InitKwargs -from .registry import get_kind, _registry -from ..parametrized import ClassParam -from ..utils import _import_func_or_class, _eval_expression - -PARAMS = "params" -INIT_KWARGS = "init_kwargs" - - -def _get_kwargs_field_definitions( - parametrized_class: Parametrized, dependant_models: Dict[str, Any] = {} -) -> Dict[str, Dict[str, Any]]: - """ - Get the field definitions for the parameters and init_kwargs of a Parametrized class - - Parameters - ---------- - parametrized_class : Parametrized - The Parametrized class to get the field definitions for. 
- dependant_models : Dict[str, Any], optional - The dependent models to use, by default {} - See: https://docs.pydantic.dev/latest/concepts/unions/#nested-discriminated-unions - - Returns - ------- - dict - The resulting field definitions dictionary - """ - cls_signature = inspect.signature(parametrized_class) - kwargs_field_definitions: Dict[str, Dict[str, Any]] = {PARAMS: {}, INIT_KWARGS: {}} - for k, v in cls_signature.parameters.items(): - if k != "name": - anno = v.annotation - dtype = anno.__origin__ - cls_param = ClassParam(*anno.__metadata__) - if cls_param.isParam: - kwargs_field_definitions[PARAMS][k] = ( - dtype, - Field(default=v.default, description=cls_param.description), - ) - # Below is to handle cases for init kwargs - elif k in dependant_models: - dependant_model = dependant_models[k] - if isinstance(dependant_model, list): - # For the multi lens case - # dependent model is wrapped in a list - dependant_model = dependant_model[0] - kwargs_field_definitions[INIT_KWARGS][k] = ( - List[dependant_model], - Field([], description=cls_param.description), - ) - else: - kwargs_field_definitions[INIT_KWARGS][k] = ( - dependant_model, - Field(..., description=cls_param.description), - ) - elif v.default == inspect._empty: - kwargs_field_definitions[INIT_KWARGS][k] = ( - dtype, - Field(..., description=cls_param.description), - ) - else: - kwargs_field_definitions[INIT_KWARGS][k] = ( - dtype, - Field(v.default, description=cls_param.description), - ) - return kwargs_field_definitions - - -def create_pydantic_model( - cls: "Parametrized | str", dependant_models: Dict[str, type] = {} -) -> Base: - """ - Create a pydantic model from a Parametrized class. - - Parameters - ---------- - cls : Parametrized | str - The Parametrized class to create the model from. - dependant_models : Dict[str, type], optional - The dependent models to use, by default {} - See: https://docs.pydantic.dev/latest/concepts/unions/#nested-discriminated-unions - - Returns - ------- - Base - The pydantic model of the Parametrized class. 
- """ - if isinstance(cls, str): - parametrized_class = get_kind(cls) # type: ignore - - # Get the field definitions for parameters and init_kwargs - kwargs_field_definitions = _get_kwargs_field_definitions( - parametrized_class, dependant_models - ) - - # Create the model field definitions - field_definitions = { - "kind": (Literal[parametrized_class.__name__], Field(parametrized_class.__name__)), # type: ignore - } - - if kwargs_field_definitions[PARAMS]: - - def _param_field_tensor_check(cls, v): - """Checks the ``params`` fields input - and converts to tensor if necessary""" - if not isinstance(v, torch.Tensor): - if isinstance(v, str): - v = _eval_expression(v) - v = torch.as_tensor(v) - return v - - # Setup the pydantic models for the parameters and init_kwargs - ParamsModel = create_model( - f"{parametrized_class.__name__}_Params", - __base__=Parameters, - __validators__={ - # Convert to tensor before passing to the model for additional validation - "field_tensor_check": field_validator( - "*", mode="before", check_fields=True - )(_param_field_tensor_check) - }, - **kwargs_field_definitions[PARAMS], - ) - field_definitions["params"] = ( - ParamsModel, - Field(ParamsModel(), description="Parameters of the object"), - ) - - if kwargs_field_definitions[INIT_KWARGS]: - - def _init_kwargs_field_check(cls, v, info: ValidationInfo): - """Checks the ``init_kwargs`` fields input""" - field_name = info.field_name - field = cls.model_fields[field_name] - anno_args = typing.get_args(field.annotation) - if len(anno_args) == 2 and anno_args[1] is type(None): - # This means that the anno is optional - expected_type = next( - filter(lambda x: x is not None, typing.get_args(field.annotation)) - ) - if not isinstance(v, expected_type): - if isinstance(v, dict): - if all(k in ["func", "kwargs"] for k in v.keys()): - # Special case for the init_kwargs - # this is to allow for creating tensor with some - # caustics utils function, such as - # `caustics.utils.gaussian` - func = _import_func_or_class(v["func"]) - v = func(**v["kwargs"]) # type: ignore - else: - raise ValueError( - f"Dictionary with keys 'func' and 'kwargs' expected, got: {v.keys()}" - ) - elif expected_type == torch.Tensor: - # Try to cast to tensor if expected type is tensor - v = torch.as_tensor(v) - else: - # Try to cast to the expected type - v = expected_type(v) - return v - - InitKwargsModel = create_model( - f"{parametrized_class.__name__}_Init_Kwargs", - __base__=InitKwargs, - **kwargs_field_definitions[INIT_KWARGS], - __validators__={ - "field_check": field_validator("*", mode="before", check_fields=True)( - _init_kwargs_field_check - ) - }, - ) - field_definitions["init_kwargs"] = ( - InitKwargsModel, - Field({}, description="Initiation keyword arguments of the object"), - ) - - # Create the model - model = create_model( - parametrized_class.__name__, __base__=Base, **field_definitions - ) - # Set the imported parametrized class to the model - # this will be accessible as `model._cls` - model = model._set_class(parametrized_class) - return model - - -def setup_pydantic_models() -> Tuple[type[Annotated], type[Annotated]]: - """ - Setup the pydantic models for the light sources and lenses. 
- - Returns - ------- - light_sources : type[Annotated] - The annotated union of the light source pydantic models - lenses : type[Annotated] - The annotated union of the lens pydantic models - """ - # Cosmology - cosmology_models = [create_pydantic_model(cosmo) for cosmo in _registry.cosmology] - cosmology = Annotated[Union[tuple(cosmology_models)], Field(discriminator="kind")] - # Light - light_models = [create_pydantic_model(light) for light in _registry.light] - light_sources = Annotated[Union[tuple(light_models)], Field(discriminator="kind")] - # Single Lens - lens_dependant_models = {"cosmology": cosmology} - single_lens_models = [ - create_pydantic_model(lens, dependant_models=lens_dependant_models) - for lens in _registry.single_lenses - if lens != "SinglePlane" # make exception for single plane - ] - single_lenses = Annotated[ - Union[tuple(single_lens_models)], Field(discriminator="kind") - ] - # Single plane - # this is a special case since single plane - # is a multi lens system - # but this is an option for multi lens - single_plane_model = create_pydantic_model( - "SinglePlane", - dependant_models={"lenses": [single_lenses], **lens_dependant_models}, - ) - single_lenses_and_plane = Annotated[ - Union[tuple([single_plane_model, *single_lens_models])], - Field(discriminator="kind"), - ] - # Multi Lens - multi_lens_models = [ - create_pydantic_model( - lens, - dependant_models={ - "lenses": [single_lenses_and_plane], - **lens_dependant_models, - }, - ) - for lens in _registry.multi_lenses - ] - lenses = Annotated[ - Union[tuple([single_plane_model, *single_lens_models, *multi_lens_models])], - Field(discriminator="kind"), - ] - return light_sources, lenses - - -def setup_simulator_models() -> type[Annotated]: - """ - Setup the pydantic models for the simulators - - Returns - ------- - type[Annotated] - The annotated union of the simulator pydantic models - """ - light_sources, lenses = setup_pydantic_models() - # Hard code the dependants for now - # there's currently only one simulator - # in the system. 
- dependents = { - "LensSource": { - "source": light_sources, - "lens_light": light_sources, - "lens": lenses, - } - } - simulators_models = [ - create_pydantic_model(sim, dependant_models=dependents.get(sim)) - for sim in _registry.simulators - ] - return Annotated[Union[tuple(simulators_models)], Field(discriminator="kind")] diff --git a/src/caustics/namespace_dict.py b/src/caustics/namespace_dict.py deleted file mode 100644 index e96026bb..00000000 --- a/src/caustics/namespace_dict.py +++ /dev/null @@ -1,193 +0,0 @@ -from collections import OrderedDict -import pprint - - -class NamespaceDict(OrderedDict): - """ - Add support for attributes on top of an OrderedDict - """ - - def __getattr__(self, key): - if key in self: - return self[key] - else: - raise AttributeError(f"'NamespaceDict' object has no attribute '{key}'") - - def __setattr__(self, key, value): - self[key] = value - - def __delattr__(self, key): - if key in self: - del self[key] - else: - raise AttributeError(f"'NamespaceDict' object has no attribute '{key}'") - - def __repr__(self): - return pprint.pformat(dict(self)) - - def __str__(self): - return pprint.pformat(dict(self)) - - -class _NestedNamespaceDict(NamespaceDict): - """ - Abstract method for NestedNamespaceDict and its Proxy - """ - - def flatten(self) -> NamespaceDict: - """ - Flatten the nested dictionary into a NamespaceDict - - Returns - ------- - NamespaceDict - Flattened dictionary as a NamespaceDict - """ - flattened_dict = NamespaceDict() - - def _flatten_dict(dictionary, parent_key=""): - for key, value in dictionary.items(): - new_key = f"{parent_key}.{key}" if parent_key else key - if isinstance(value, dict): - _flatten_dict(value, new_key) - else: - flattened_dict[new_key] = value - - _flatten_dict(self) - return flattened_dict - - def collapse(self) -> NamespaceDict: - """ - Flatten the nested dictionary and collapse keys into the first level - of the NamespaceDict - - Returns - ------- - NamespaceDict - Flattened dictionary as a NamespaceDict - """ - flattened_dict = NamespaceDict() - - def _flatten_dict(dictionary): - for key, value in dictionary.items(): - if isinstance(value, dict): - _flatten_dict(value) - else: - flattened_dict[key] = value - - _flatten_dict(self) - return flattened_dict - - -class _NestedNamespaceProxy(_NestedNamespaceDict): - """ - Proxy for NestedNamespaceDict in order to allow recursion in - the class attributes - """ - - def __init__(self, parent, key_path): - # Add new private keys to give us a ladder back to root node - self._parent = parent - self._key_path = key_path - super().__init__(parent[key_path]) - - def __setattr__(self, key, value): - if key.startswith("_"): - # We are in a child node, we need to recurse up - super().__setattr__(key, value) - else: - # We are at the root node, call the __setitem__ to set record value - self._parent.__setitem__(f"{self._key_path}.{key}", value) - - # Hide the private keys from common usage - def keys(self): - return [key for key in super().keys() if not key.startswith("_")] - - def items(self): - for key, value in super().items(): - if not key.startswith("_"): - yield (key, value) - - def values(self): - return [v for k, v in super().items() if not k.startswith("_")] - - def __len__(self): - # make sure hidden keys don't count in the length of the object - return len(self.keys()) - - -class NestedNamespaceDict(_NestedNamespaceDict): - """ - Example usage - ```python - nested_namespace = NestedNamespaceDict() - nested_namespace.foo = 'Hello' - nested_namespace.bar = {'baz': 
'World'} - nested_namespace.bar.qux = 42 - # works also in the following way - nested_namespace["bar.qux"] = 42 - - print(nested_namespace) - # Output: - # {'foo': 'Hello', 'bar': {'baz': 'World', 'qux': 42 }} - - #============================== - # Flattened key access - #============================== - print(nested_dict['foo']) # Output: Hello - print(nested_dict['bar.baz']) # Output: World - print(nested_dict['bar.qux']) # Output: 42 - - #============================== - # Nested namespace access - #============================== - print(nested_dict.bar.qux) # Output: 42 - - #============================== - # Flatten and collapse method - #============================== - print(nested_dict.flatten()) - # Output: - # {'foo': 'Hello', 'bar.baz': 'World', 'bar.qux': 42} - - print(nested_dict.collapse() - # Output: - # {'foo': 'Hello', 'baz': 'World', 'qux': 42} - - """ - - def __getattr__(self, key): - if key in self: - value = super().__getitem__(key) - if isinstance(value, dict): - return _NestedNamespaceProxy(self, key) - else: - return value - else: - raise AttributeError( - f"'NestedNamespaceDict' object has no attribute '{key}'" - ) - - def __getitem__(self, key): - if "." in key: - root, childs = key.split(".", 1) - if root not in self: - raise KeyError(f"'NestedNamespaceDict' object has no key '{key}'") - return self[root].__getitem__(childs) - else: - return super().__getitem__(key) - - def __setitem__(self, key, value): - if isinstance(value, dict) and not isinstance(value, NestedNamespaceDict): - value = NestedNamespaceDict(value) - if "." in key: - root, childs = key.split(".", 1) - if root not in self: - self[root] = NestedNamespaceDict() - elif not isinstance(self[root], dict): - raise ValueError( - "Can't assign a NestedNamespaceDict to a non-dict entry" - ) - self[root].__setitem__(childs, value) - else: - super().__setitem__(key, value) diff --git a/src/caustics/packed.py b/src/caustics/packed.py deleted file mode 100644 index eaad6ff2..00000000 --- a/src/caustics/packed.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections import OrderedDict - - -class Packed(OrderedDict): - """ - Dummy wrapper for `x` so other functions can check its type. - """ - - ... diff --git a/src/caustics/parameter.py b/src/caustics/parameter.py deleted file mode 100644 index f86c874b..00000000 --- a/src/caustics/parameter.py +++ /dev/null @@ -1,102 +0,0 @@ -# mypy: disable-error-code="union-attr" -from typing import Optional, Union - -import torch -from torch import Tensor - -__all__ = ("Parameter",) - - -class Parameter: - """ - Represents a static or dynamic parameter used - for strong gravitational lensing simulations - in the caustics codebase. - - A static parameter has a fixed value, - while a dynamic parameter must be passed - in each time it's required. - - Attributes - ---------- - value: (Optional[Tensor]) - The value of the parameter. - shape: (tuple[int, ...]) - The shape of the parameter. 
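-
-    Example
-    -------
-    An illustrative sketch of the two modes::
-
-        static_p = Parameter(1.0)          # fixed value; static_p.static is True
-        dynamic_p = Parameter(shape=(2,))  # no value; supplied at call time
-        assert dynamic_p.dynamic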
- """ - - def __init__( - self, - value: Optional[Union[Tensor, float]] = None, - shape: Optional[tuple[int, ...]] = (), - ): - # Must assign one of value or shape - if value is None: - if shape is None: - raise ValueError("If value is None, a shape must be provided") - if not isinstance(shape, tuple): - raise TypeError("The shape of a parameter must be a tuple") - self._shape = shape - else: - value = torch.as_tensor(value) - self._shape = value.shape - self._value = value - self._dtype = None if value is None else value.dtype - - @property - def static(self) -> bool: - return not self.dynamic - - @property - def dynamic(self) -> bool: - return self._value is None - - @property - def value(self) -> Optional[Tensor]: - return self._value - - @value.setter - def value(self, value: Union[None, Tensor, float]): - if value is not None: - value = torch.as_tensor(value) - if value.shape != self.shape: - raise ValueError( - "Cannot set Parameter value with a different shape. " - f"Received {value.shape}, expected {self.shape}" - ) - self._value = value - self._dtype = None if value is None else value.dtype - - @property - def dtype(self): - return self._dtype - - @property - def shape(self) -> tuple[int, ...]: - return self._shape - - def set_static(self): - self.value = None - - def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None - ): - """ - Moves and/or casts the values of the parameter. - - Parameters - ---------- - device: (Optional[torch.device], optional) - The device to move the values to. Defaults to None. - dtype: (Optional[torch.dtype], optional) - The desired data type. Defaults to None. - """ - if self.static: - self.value = self._value.to(device=device, dtype=dtype) - return self - - def __repr__(self) -> str: - if self.static: - return f"Param(value={self.value}, dtype={str(self.dtype)})" - else: - return f"Param(shape={self.shape})" diff --git a/src/caustics/parametrized.py b/src/caustics/parametrized.py deleted file mode 100644 index 205bf954..00000000 --- a/src/caustics/parametrized.py +++ /dev/null @@ -1,610 +0,0 @@ -# mypy: disable-error-code="var-annotated,index,type-arg" -from collections import OrderedDict -from math import prod -from typing import Optional, Union, List -from dataclasses import dataclass - -import functools -import itertools as it -import inspect -import textwrap - -import torch -import re -import keyword -from torch import Tensor -import graphviz - -from .packed import Packed -from .namespace_dict import NamespaceDict, NestedNamespaceDict -from .parameter import Parameter - -__all__ = ("Parametrized", "unpack") - - -@dataclass -class ClassParam: - description: str - isParam: bool = False - unit: Optional[str] = None - - -def check_valid_name(name): - if keyword.iskeyword(name) or not bool(re.match("^[a-zA-Z_][a-zA-Z0-9_]*$", name)): - raise NameError( - f"The string {name} contains illegal characters (like space or '-'). " - "Please use snake case or another valid python variable naming style." - ) - - -class Parametrized: - """ - Represents a class with Param and Parametrized attributes, - typically used to construct parts of a simulator - that have parameters which need to be tracked during MCMC sampling. - - This class can contain Params, Parametrized, - tensor buffers or normal attributes as its attributes. - It provides functionalities to manage these attributes, - ensuring that an attribute of one type isn't rebound - to be of a different type. 
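-
-    A minimal illustrative sketch of a subclass::
-
-        class MyModule(Parametrized):  # hypothetical module
-            def __init__(self, a=None, name=None):
-                super().__init__(name)
-                self.add_param("a", a)  # static if a is given, dynamic if None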
- - TODO - - Attributes can be Params, Parametrized, tensor buffers or just normal attributes. - - Need to make sure an attribute of one of those types isn't rebound to be of a different type. - - Attributes - ---------- - name: str - The name of the Parametrized object. Default to class name. - parents: NestedNamespaceDict - Nested dictionary of parent Parametrized objects (higher level, more abstract modules). - params: OrderedDict[str, Parameter] - Dictionary of parameters. - childs: NestedNamespaceDict - Nested dictionary of childs Parametrized objects (lower level, more specialized modules). - dynamic_size: int - Size of dynamic parameters. - n_dynamic: int - Number of dynamic parameters. - n_static: int - Number of static parameters. - """ - - def __init__(self, name: Optional[str] = None): - if name is None: - name = self._default_name() - check_valid_name(name) - if not isinstance(name, str): - raise ValueError(f"name must be a string (received {name})") - self._name = name - self._parents: OrderedDict[str, Parametrized] = NamespaceDict() - self._params: OrderedDict[str, Parameter] = NamespaceDict() - self._childs: OrderedDict[str, Parametrized] = NamespaceDict() - self._module_key_map = {} - - def _default_name(self): - return re.search("([A-Z])\w+", str(self.__class__)).group() - - def __getattribute__(self, key): - try: - return super().__getattribute__(key) - except AttributeError as e: - # Check if key refers to a parametrized module name (different from its attribute key) - _map = super().__getattribute__( - "_module_key_map" - ) # use super to avoid recursion error - if key in _map.keys(): - return super().__getattribute__(_map[key]) - else: - raise e - - def __setattr__(self, key, value): - try: - if key in self._params.keys(): - # Redefine parameter value instead of making a new attribute - self._params[key].value = value - elif isinstance(value, Parameter): - # Create new parameter and attach it as an attribute - self.add_param(key, value.value, value.shape) - elif isinstance(value, Parametrized): - # Update map from attribute key to module name - # for __getattribute__ method - self._module_key_map[value.name] = key - self.add_parametrized(value, set_attr=False) - # set attr only to user defined key, - # not module name (self.{module.name} is still accessible, - # see __getattribute__ method) - super().__setattr__(key, value) - else: - super().__setattr__(key, value) - except AttributeError: # _params or another attribute in here do not exist yet - super().__setattr__(key, value) - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, new_name: str): - check_valid_name(new_name) - old_name = self.name - for parent in self._parents.values(): - del parent._childs[old_name] - parent._childs[new_name] = self - for child in self._childs.values(): - del child._parents[old_name] - child._parents[new_name] = self - self._name = new_name - - def to( - self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None - ): - """ - Moves static Params for this component and its childs - to the specified device and casts them to the specified data type. 
- """ - for name, p in self._params.items(): - self._params[name] = p.to(device, dtype) - for child in self._childs.values(): - child.to(device, dtype) - return self - - @staticmethod - def _generate_unique_name(name, module_names): - i = 1 - while f"{name}_{i}" in module_names: - i += 1 - return f"{name}_{i}" - - def add_parametrized(self, p: "Parametrized", set_attr=True): - """ - Add a child to this module, and create edges for the DAG - """ - # If self.name is already in the module parents, we need to update self.name - if self.name in p._parents.keys(): - new_name = self._generate_unique_name(self.name, p._parents.keys()) - self.name = new_name # from name.setter, this updates the DAG edges as well - p._parents[self.name] = self - # If the module name is already in self._childs, we need to update module name - if p.name is self._childs.keys(): - new_name = self._generate_unique_name(p.name, self._childs.keys()) - p.name = new_name - self._childs[p.name] = p - if set_attr: - super().__setattr__(p.name, p) - - def add_param( - self, - name: str, - value: Optional[Union[Tensor, float]] = None, - shape: Optional[tuple[int, ...]] = (), - ): - """ - Stores a parameter in the _params dictionary and records its size. - - Parameters - ---------- - name: str - The name of the parameter. - value: (Optional[Tensor], optional) - The value of the parameter. Defaults to None. - shape: (Optional[tuple[int, ...]], optional) - The shape of the parameter. Defaults to an empty tuple. - """ - self._params[name] = Parameter(value, shape) - # __setattr__ inside add_param to catch all uses of this method - super().__setattr__(name, self._params[name]) - - @property - def n_dynamic(self) -> int: - return len(self.module_params.dynamic) - - @property - def n_static(self) -> int: - return len(self.module_params.static) - - @property - def dynamic_size(self) -> int: - return sum(prod(dyn.shape) for dyn in self.module_params.dynamic.values()) - - @property - def x_keys(self) -> OrderedDict[str, List[str]]: - return OrderedDict( - [ - (module.name, list(module.module_params.dynamic.keys())) - for module in self.dynamic_modules.values() - ] - ) - - @property - def x_order(self) -> List[str]: - merged_keys = [ - [".".join([key, v]) for v in values] for key, values in self.x_keys.items() - ] - return list(it.chain.from_iterable(merged_keys)) - - def pack( - self, - x: Union[ - list[Tensor], - dict[str, Union[list[Tensor], Tensor, dict[str, Tensor]]], - Tensor, - ] = Packed(), - ) -> Packed: - """ - Converts a list or tensor into a dict that can subsequently be unpacked - into arguments to this component and its childs. - Also, add a batch dimension to each Tensor - without such a dimension. - - Parameters - ---------- - x : list of tensor, dict of tensor, or tensor - The input to be packed. - - Returns - ------- - Packed - The packed input, and whether or not the input was batched. - - Raises - ------ - ValueError - If the input is not a list, dictionary, or tensor. - ValueError - If the input is a dictionary and some keys are missing. - ValueError - If the number of dynamic arguments does not match the expected number. - ValueError - If the input is a tensor and the shape does not match the expected shape. - """ - if isinstance(x, (dict, Packed)): - missing_names = [ - name for name in self.params.dynamic.keys() if name not in x - ] - if len(missing_names) > 0: - raise ValueError(f"missing x keys for {missing_names}") - - # TODO: check structure! 
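-            # A Packed dict maps each dynamic module's name to that module's
-            # dynamic parameter values (given as a dict, list/tuple, or tensor).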
- return Packed(x) - elif isinstance(x, (list, tuple)): - n_passed = len(x) - n_dynamic_params = len(self.params.dynamic.flatten()) - n_dynamic_modules = len(self.dynamic_modules) - x_repacked = {} - if n_passed == n_dynamic_params: - cur_offset = 0 - for name, module in self.dynamic_modules.items(): - x_repacked[name] = x[cur_offset : cur_offset + module.n_dynamic] - cur_offset += module.n_dynamic - elif n_passed == n_dynamic_modules: - for i, name in enumerate(self.dynamic_modules.keys()): - x_repacked[name] = x[i] - else: - raise ValueError( - f"{n_passed} dynamic args were passed, but {n_dynamic_params} parameters or " - f"{n_dynamic_modules} Tensor (1 per dynamic module) are required" - ) - return Packed(x_repacked) - - elif isinstance(x, Tensor): - n_passed = x.shape[-1] - n_expected = sum( - [module.dynamic_size for module in self.dynamic_modules.values()] - ) - if n_passed != n_expected: - # TODO: give component and arg names - raise ValueError( - f"{n_passed} flattened dynamic args were passed, but {n_expected} " - f"are required" - ) - - cur_offset = 0 - x_repacked = {} - for name, module in self.dynamic_modules.items(): - x_repacked[name] = x[..., cur_offset : cur_offset + module.dynamic_size] - cur_offset += module.dynamic_size - return Packed(x_repacked) - - else: - raise ValueError("Data structure not supported") - - def unpack( - self, x: Optional[dict[str, Union[list[Tensor], dict[str, Tensor], Tensor]]] - ) -> list[Tensor]: - """ - Unpacks a dict of kwargs, list of args or flattened vector of args to retrieve - this object's static and dynamic parameters. - - Parameters - ---------- - x: (Optional[dict[str, Union[list[Tensor], dict[str, Tensor], Tensor]]]) - The packed object to be unpacked. - - Returns - ------- - list[Tensor] - Unpacked static and dynamic parameters of the object. Note that - parameters will have an added batch dimension from the pack method. - - Raises - ------ - ValueError - If the input is not a dict, list, tuple or tensor. - ValueError - If the argument type is invalid. It must be a dict containing key {self.name} - and value containing args as list or flattened tensor, or kwargs. - """ - # Check if module has dynamic parameters - if self.module_params.dynamic: - dynamic_x = x[self.name] - else: # all parameters are static and module is not present in x - dynamic_x = [] - if isinstance(x, dict): - if self.name in x.keys() and x.get(self.name, {}): - print( - f"Module {self.name} is static, " - f"the parameters {' '.join(x[self.name].keys())} " - "passed dynamically will be ignored." - ) - unpacked_x = [] - offset = 0 - for name, param in self._params.items(): - if param.dynamic: - if isinstance(dynamic_x, dict): - param_value = dynamic_x[name] - elif isinstance(dynamic_x, (list, tuple)): - param_value = dynamic_x[offset] - offset += 1 - elif isinstance(dynamic_x, Tensor): - size = prod(param.shape) - param_value = dynamic_x[..., offset : offset + size].reshape( - param.shape - ) - offset += size - else: - raise ValueError( - f"Invalid data type found when unpacking parameters for {self.name}." - "Expected argument of unpack to be a list/tuple/dict of Tensor, " - "or simply a flattened tensor" - f"but found {type(dynamic_x)}." - ) - else: # param is static - param_value = param.value - if not isinstance(param_value, Tensor): - raise ValueError( - f"Invalid data type found when unpacking parameters for {self.name}." 
- f"Argument of unpack must contain Tensor, but found {type(param_value)}" - ) - unpacked_x.append(param_value) - return unpacked_x - - @property - def module_params(self) -> NestedNamespaceDict: - static = NestedNamespaceDict() - dynamic = NestedNamespaceDict() - for name, param in self._params.items(): - if param.static: - static[name] = param - else: - dynamic[name] = param - return NestedNamespaceDict([("static", static), ("dynamic", dynamic)]) - - @property - def params(self) -> NestedNamespaceDict: - # todo make this an ordinary dict and reorder at the end. - static = NestedNamespaceDict() - dynamic = NestedNamespaceDict() - - def _get_params(module): - mp = module.module_params - if mp.static: - static[module.name] = mp.static - if mp.dynamic: - dynamic[module.name] = mp.dynamic - for child in module._childs.values(): - _get_params(child) - - _get_params(self) - # TODO reorder - return NestedNamespaceDict([("static", static), ("dynamic", dynamic)]) - - @property - def dynamic_modules(self) -> NamespaceDict[str, "Parametrized"]: - # Only catch modules with dynamic parameters - modules = ( - NamespaceDict() - ) # todo make this an ordinary dict and reorder at the end. - - def _get_childs(module): - # Start from root, and move down the DAG - if module.module_params.dynamic: - modules[module.name] = module - if module._childs != {}: - for child in module._childs.values(): - _get_childs(child) - - _get_childs(self) - # TODO reorder - return modules - - @property - def static(self): - return list(self.module_params.static.keys()) - - @property - def dynamic(self): - return list(self.module_params.dynamic.keys()) - - def __repr__(self) -> str: - # TODO: change - return str(self) - - def __str__(self) -> str: - static_str = ", ".join(self.static) - dynamic_str = ", ".join(self.dynamic) - desc_dynamic_str = textwrap.shorten( - ", ".join(self.x_order), width=70, placeholder="..." - ) - - return ( - f"{self.__class__.__name__}(\n" - f" name='{self.name}',\n" - f" static=[{static_str}],\n" - f" dynamic=[{dynamic_str}],\n" - f" x_order=[{desc_dynamic_str}]\n" - f")" - ) - - def graph( - self, show_dynamic_params: bool = False, show_static_params: bool = False - ) -> "graphviz.Digraph": # type: ignore - """ - Returns a graph representation of the object and its parameters. - - Parameters - ---------- - show_dynamic_params: (bool, optional) - If true, the dynamic parameters are shown in the graph. Defaults to False. - show_static_params: (bool, optional) - If true, the static parameters are shown in the graph. Defaults to False. - - Returns - ------- - graphviz.Digraph - The graph representation of the object. 
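-
-        Example
-        -------
-        An illustrative call on a Parametrized instance ``sim``, rendering
-        the DAG with all parameters shown::
-
-            dot = sim.graph(show_dynamic_params=True, show_static_params=True)
-            dot.render("simulator_dag")  # standard graphviz Digraph API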
- """ - - components = {} - params = [] - - def add_params(p: Parametrized, dot): - static = p.module_params.static.keys() - dynamic = p.module_params.dynamic.keys() - - dot.attr("node", style="solid", color="black", shape="box") - for n in dynamic: - pname = f"{p.name}/{n}" - if pname in params: - continue - params.append(pname) - if show_dynamic_params: - dot.node(pname, n) - dot.edge(p.name, pname) - - dot.attr("node", style="filled", color="lightgrey", shape="box") - for n in static: - pname = f"{p.name}/{n}" - if pname in params: - continue - params.append(pname) - if show_static_params: - dot.node(pname, n) - dot.edge(p.name, pname) - - def add_component(p: Parametrized, dot): - if p.name in components: - return - dot.attr("node", style="solid", color="black", shape="ellipse") - dot.node(p.name, f"{p.__class__.__name__}('{p.name}')") - components[p.name] = p - add_params(p, dot) - for child in p._childs.values(): - add_component(child, dot) - dot.edge(p.name, child.name) - - dot = graphviz.Digraph(strict=True) - add_component(self, dot) - - return dot - - -def count_args_before_varargs(function): - """ - Counts the number of arguments before the *args argument in a function. - """ - signature = inspect.signature(function) - count = 0 - - for param in signature.parameters.values(): - if param.kind == inspect.Parameter.VAR_POSITIONAL: - break - if param.name == "self": - continue - count += 1 - - return count - - -def unpack(method): - """ - Decorator that unpacks the "params" argument of a method. - There are a number of ways to interact with this method. - Let's consider a hypothetical lens with a function ``func`` that takes a position ``x`` and ``y`` and two parameters ``a`` and ``b`` and a key word argument ``c``. - The following are all valid ways to call ``func``:: python - - lens.func(x, y, a=a, b=b, c=c) # a and b are Tensors - lens.func(x, y, [a, b], c=c) # a and b are Tensors, or even [a, b] is a tensor - lens.func(x, y, params=[a, b], c=c) # a and b are Tensors, or even [a, b] is a tensor - - # If the ``a`` parameter has been set at a static value like this: - lens.a = a # a is a Tensor - # then the following is also valid: - lens.func(x, y, b=b, c=c) # b is a Tensor - lens.func(x, y, [b], c=c) # a and b are Tensors - lens.func(x, y, params=[b], c=c) # a and b are Tensors - - # If lens.func calls another method from a different parametrized object (say cosmo) and that method takes a dynamic parameter ``d``, then the following is also valid: - lens.func(x, y, a=a, b=b, cosmo_d=d, c=c) # a, b and d are Tensors - lens.func(x, y, [a, b, d], c=c) # a, b and d are Tensors - lens.func(x, y, params=[a, b, d], c=c) # a, b and d are Tensors - - In all cases any valid way to construct the ``Packed`` object also works (i.e. by tensor, dict, or tuple). - This gives a great deal of flexibility in how the parameters are passed to the method. - However, the following is not a valid way to pass the parameters:: python - - lens.func(x, y, a, b, c=c) - - This is because the ``a`` and ``b`` parameters are not named and so cannot be recognized by the unpack method. 
- - - """ - - nargs = count_args_before_varargs(method) - - @functools.wraps(method) - def wrapped(self, *args, **kwargs): - # Extract "params" regardless of how it is/they are passed - # --------------------------------------------------------- - if len(args) > nargs: - # Params is given as last argument - x = self.pack(args[-1]) - args = args[:-1] - elif "params" in kwargs: - # Params is given as a keyword argument - x = self.pack(kwargs.pop("params")) - elif self.params.dynamic: - # Params are given individually and are collected into a packed object - all_keys = self.params.dynamic - keys = list(all_keys.pop(self.name).keys()) - if all_keys: - keys += list(k.replace(".", "_") for k in all_keys.flatten().keys()) - try: - x = self.pack([kwargs.pop(name) for name in keys]) - except KeyError as e: - raise KeyError( - f"Missing parameter {e} in method '{method.__name__}' of module '{self.name}'" - ) - else: - # No dynamic parameters, so no packing is needed - x = Packed() - - # Fill kwargs with module params - # --------------------------------------------------------- - for name, value in zip(self._params.keys(), self.unpack(x)): - kwargs[name] = value - - return method(self, *args, params=x, **kwargs) - - return wrapped diff --git a/src/caustics/sims/__init__.py b/src/caustics/sims/__init__.py index ca28e5f1..7e13bef2 100644 --- a/src/caustics/sims/__init__.py +++ b/src/caustics/sims/__init__.py @@ -1,5 +1,5 @@ from .lens_source import LensSource from .microlens import Microlens -from .simulator import Simulator +from .simulator import build_simulator -__all__ = ("LensSource", "Microlens", "Simulator") +__all__ = ("LensSource", "Microlens", "build_simulator") diff --git a/src/caustics/sims/lens_source.py b/src/caustics/sims/lens_source.py index 8235ee2a..9c05370f 100644 --- a/src/caustics/sims/lens_source.py +++ b/src/caustics/sims/lens_source.py @@ -5,8 +5,9 @@ from typing import Optional, Annotated, Literal, Union import torch from torch import Tensor +from caskade import Module, forward, Param -from .simulator import Simulator, NameType +from .simulator import NameType from ..utils import ( meshgrid, gaussian_quadrature_grid, @@ -19,7 +20,7 @@ __all__ = ("LensSource",) -class LensSource(Simulator): +class LensSource(Module): """Lens image of a source. Straightforward simulator to sample a lensed image of a source object. 
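
Note on the new call convention: with this refactor LensSource becomes a caskade
Module whose @forward-decorated __call__ receives its Param values directly, so
callers no longer build an explicit params dict. A minimal sketch, assuming
caskade accepts a flattened dynamic-parameter tensor as the updated tests below
do (the parameter count and values here are illustrative only):

    import torch
    import caustics

    cosmology = caustics.FlatLambdaCDM()
    lens = caustics.SIE(cosmology=cosmology, name="lens")
    source = caustics.Sersic(name="source")
    lenslight = caustics.Sersic(name="lenslight")
    sim = caustics.LensSource(
        lens=lens, source=source, lens_light=lenslight, pixelscale=0.05, pixels_x=100
    )

    x = torch.rand(21)  # flattened dynamic parameters (illustrative)
    image = sim(x)      # option flags pass as keywords, e.g. sim(x, lens_light=False)
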
@@ -149,11 +150,6 @@ def __init__( ): super().__init__(name) - # Lensing models - self.lens = lens - self.source = source - self.lens_light = lens_light - # Configure PSF self._psf_mode = psf_mode if psf is not None: @@ -161,12 +157,17 @@ def __init__( self._psf_shape = psf.shape if psf is not None else psf_shape # Build parameters - self.add_param("z_s", z_s) - self.add_param("psf", psf, self.psf_shape) - self.add_param("x0", x0) - self.add_param("y0", y0) + self.z_s = Param("z_s", z_s, units="unitless", valid=(0, None)) + self.psf = Param("psf", psf, self.psf_shape, units="unitless") + self.x0 = Param("x0", x0, units="arcsec") + self.y0 = Param("y0", y0, units="arcsec") self._pixelscale = pixelscale + # Lensing models + self.lens = lens + self.source = source + self.lens_light = lens_light + # Image grid self._pixels_x = pixels_x self._pixels_y = pixels_x if pixels_y is None else pixels_y @@ -306,13 +307,18 @@ def _unpad_fft(self, x): ..., : self._s[0], : self._s[1] ] - def forward( + @forward + def __call__( self, - params, - source_light=True, - lens_light=True, - lens_source=True, - psf_convolve=True, + z_s: Annotated[Tensor, "Param"], + psf: Annotated[Tensor, "Param"], + x0: Annotated[Tensor, "Param"], + y0: Annotated[Tensor, "Param"], + source_light: bool = True, + lens_light: bool = True, + lens_source: bool = True, + psf_convolve: bool = True, + chunk_size: int = 10000, ): """ forward function @@ -330,7 +336,6 @@ def forward( psf_convolve: boolean when true the image will be convolved with the psf """ - z_s, psf, x0, y0 = self.unpack(params) # Automatically turn off light for missing objects if self.source is None: @@ -346,12 +351,18 @@ def forward( if source_light: if lens_source: # Source is lensed by the lens mass distribution - bx, by = self.lens.raytrace(*grid, z_s, params) - mu_fine = self.source.brightness(bx, by, params) + bx, by = torch.vmap( + self.lens.raytrace, in_dims=(0, 0, None), chunk_size=chunk_size + )(grid[0].flatten(), grid[1].flatten(), z_s) + mu_fine = torch.vmap(self.source.brightness, chunk_size=chunk_size)( + bx, by + ).reshape(grid[0].shape) mu = gaussian_quadrature_integrator(mu_fine, self._weights) else: # Source is imaged without lensing - mu_fine = self.source.brightness(*grid, params) + mu_fine = torch.vmap(self.source.brightness, chunk_size=chunk_size)( + grid[0].flatten(), grid[1].flatten() + ).reshape(grid[0].shape) mu = gaussian_quadrature_integrator(mu_fine, self._weights) else: # Source is not added to the scene @@ -359,7 +370,9 @@ def forward( # Sample the lens light if lens_light and self.lens_light is not None: - mu_fine = self.lens_light.brightness(*grid, params) + mu_fine = torch.vmap(self.lens_light.brightness, chunk_size=chunk_size)( + grid[0].flatten(), grid[1].flatten() + ).reshape(grid[0].shape) mu += gaussian_quadrature_integrator(mu_fine, self._weights) # Convolve the PSF diff --git a/src/caustics/sims/microlens.py b/src/caustics/sims/microlens.py index 1cfd4190..7d328e3a 100644 --- a/src/caustics/sims/microlens.py +++ b/src/caustics/sims/microlens.py @@ -1,8 +1,9 @@ from typing import Optional, Annotated, Union, Literal import torch from torch import Tensor +from caskade import Module, forward, Param -from .simulator import Simulator, NameType +from .simulator import NameType from ..lenses.base import Lens from ..light.base import Source @@ -10,7 +11,7 @@ __all__ = ("Microlens",) -class Microlens(Simulator): +class Microlens(Module): """Computes the total flux from a microlens system within an fov. 
Straightforward simulator to compute the total flux a lensed image of a @@ -57,15 +58,16 @@ def __init__( ): super().__init__(name) + self.z_s = Param("z_s", z_s, units="unitless") + self.lens = lens self.source = source - self.add_param("z_s", z_s) - - def forward( + @forward + def __call__( self, - params, fov: Tensor, + z_s: Annotated[Tensor, "Param"], method: Literal["mcmc", "grid"] = "mcmc", N_mcmc: int = 10000, N_grid: int = 100, @@ -94,14 +96,13 @@ def forward( Error estimate on the total flux """ - (z_s,) = self.unpack(params) if method == "mcmc": # Sample the source using MCMC sample_x = torch.rand(N_mcmc) * (fov[1] - fov[0]) + fov[0] sample_y = torch.rand(N_mcmc) * (fov[3] - fov[2]) + fov[2] - bx, by = self.lens.raytrace(sample_x, sample_y, z_s, params) - mu = self.source.brightness(bx, by, params) + bx, by = self.lens.raytrace(sample_x, sample_y, z_s) + mu = self.source.brightness(bx, by) A = (fov[1] - fov[0]) * (fov[3] - fov[2]) return mu.mean() * A, mu.std() * A / N_mcmc**0.5 elif method == "grid": @@ -109,8 +110,8 @@ def forward( x = torch.linspace(fov[0], fov[1], N_grid) y = torch.linspace(fov[2], fov[3], N_grid) sample_x, sample_y = torch.meshgrid(x, y, indexing="ij") - bx, by = self.lens.raytrace(sample_x, sample_y, z_s, params) - mu = self.source.brightness(bx, by, params) + bx, by = self.lens.raytrace(sample_x, sample_y, z_s) + mu = self.source.brightness(bx, by) A = (fov[1] - fov[0]) * (fov[3] - fov[2]) return mu.mean() * A, mu.std() * A / N_grid else: diff --git a/src/caustics/sims/simulator.py b/src/caustics/sims/simulator.py index 0dfcc258..5ebd1101 100644 --- a/src/caustics/sims/simulator.py +++ b/src/caustics/sims/simulator.py @@ -1,99 +1,44 @@ -from typing import Dict, Annotated, Optional -from torch import Tensor +# mypy: disable-error-code="import-untyped,var-annotated" +from typing import Annotated, Optional, Union, TextIO +from inspect import signature -from ..parametrized import Parametrized -from .state_dict import StateDict -from ..namespace_dict import NestedNamespaceDict +from caskade import Module +import yaml -__all__ = ("Simulator",) +import caustics -NameType = Annotated[Optional[str], "Name of the simulator"] - - -class Simulator(Parametrized): - """A caustics simulator using Parametrized framework. - - Defines a simulator class which is a callable function that - operates on the Parametrized framework. Users define the `forward` - method which takes as its first argument an object which can be - packed, all other args and kwargs are simply passed to the forward - method. - - See `Parametrized` for details on how to add/access parameters. 
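-
-    Example (illustrative; ``MySimulator`` is a hypothetical subclass
-    implementing ``forward``)::
-
-        sim = MySimulator()
-        image = sim(x)  # x is packed via self.pack, then passed to forward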
- - """ - - def __call__(self, *args, **kwargs): - if len(args) > 0: - packed_args = self.pack(args[0]) - rest_args = args[1:] - else: - packed_args = self.pack() - rest_args = tuple() - - return self.forward(packed_args, *rest_args, **kwargs) - - @staticmethod - def __set_module_params(module: Parametrized, params: Dict[str, Tensor]): - for k, v in params.items(): - setattr(module, k, v) - - def state_dict(self) -> StateDict: - return StateDict.from_params(self.params) +__all__ = ("NameType", "build_simulator") - def load_state_dict(self, file_path: str) -> "Simulator": - """ - Loads and then sets the state of the simulator from a file - - Parameters - ---------- - file_path : str | Path - The file path to a safetensors file - to load the state from - - Returns - ------- - Simulator - The simulator with the loaded state - """ - loaded_state_dict = StateDict.load(file_path) - self.set_state_dict(loaded_state_dict) - return self - - def set_state_dict(self, state_dict: StateDict) -> "Simulator": - """ - Sets the state of the simulator from a state dict - - Parameters - ---------- - state_dict : StateDict - The state dict to load from - - Returns - ------- - Simulator - The simulator with the loaded state - """ - # TODO: Do some checks for the state dict metadata - - # Convert to nested namespace dict - param_dicts = NestedNamespaceDict(state_dict) - - # Grab params for the current module - self_params = param_dicts.pop(self.name) - - def _set_params(module): - # Start from root, and move down the DAG - if module.name in param_dicts: - module_params = param_dicts[module.name] - self.__set_module_params(module, module_params) - if module._childs != {}: - for child in module._childs.values(): - _set_params(child) +NameType = Annotated[Optional[str], "Name of the simulator"] - # Set the parameters of the current module - self.__set_module_params(self, self_params) - # Set the parameters of the children modules - _set_params(self) - return self +def build_simulator(config: Union[str, TextIO]) -> Module: + + if isinstance(config, str): + with open(config, "r") as f: + config_dict = yaml.safe_load(f) + else: + config_dict = yaml.safe_load(config) + + modules = {} + for name, obj in config_dict.items(): + kwargs = obj.get("init_kwargs", {}) + for kwarg in kwargs: + for subname, subobj in config_dict.items(): + if subname == name: # only look at previous objects + break + if subobj == kwargs[kwarg] and isinstance(kwargs[kwarg], dict): + # fill already constructed object + kwargs[kwarg] = modules[subname] + + # Get the caustics object, using a "." path if given + base = caustics + for part in obj["kind"].split("."): + base = getattr(base, part) + if "name" in signature(base).parameters: # type: ignore[arg-type] + kwargs["name"] = name + # Instantiate the caustics object + modules[name] = base(**kwargs) # type: ignore[operator] + + # return the last object + return modules[tuple(modules.keys())[-1]] diff --git a/src/caustics/sims/state_dict.py b/src/caustics/sims/state_dict.py deleted file mode 100644 index e8ca97c6..00000000 --- a/src/caustics/sims/state_dict.py +++ /dev/null @@ -1,310 +0,0 @@ -from datetime import datetime as dt -from collections import OrderedDict -from typing import Any, Dict, Optional -from pathlib import Path - -from torch import Tensor -import torch -from .._version import __version__ -from ..namespace_dict import NamespaceDict, NestedNamespaceDict -from .. 
import io - -from safetensors.torch import save, load_file - -IMMUTABLE_ERR = TypeError("'StateDict' cannot be modified after creation.") -PARAM_KEYS = ["dynamic", "static"] - - -def _sanitize(tensors_dict: Dict[str, Optional[Tensor]]) -> Dict[str, Tensor]: - """ - Sanitize the input dictionary of tensors by - replacing Nones with tensors of size 0. - - Parameters - ---------- - tensors_dict : dict - A dictionary of tensors, including None. - - Returns - ------- - dict - A dictionary of tensors, with empty tensors - replaced by tensors of size 0. - """ - return { - k: v if isinstance(v, Tensor) else torch.ones(0) - for k, v in tensors_dict.items() - } - - -def _merge_and_flatten(params: "NamespaceDict | NestedNamespaceDict") -> NamespaceDict: - """ - Extract the parameters from a nested dictionary - of parameters and merge them into a single - dictionary of parameters. - - Parameters - ---------- - params : NamespaceDict | NestedNamespaceDict - The nested dictionary of parameters - that includes both "static" and "dynamic". - - Returns - ------- - NamespaceDict - The merged dictionary of parameters. - - Raises - ------ - TypeError - If the input ``params`` is not a - ``NamespaceDict`` or ``NestedNamespaceDict``. - ValueError - If the input ``params`` is a ``NestedNamespaceDict`` - but does not have the keys ``"static"`` and ``"dynamic"``. - """ - if not isinstance(params, (NamespaceDict, NestedNamespaceDict)): - raise TypeError("params must be a NamespaceDict or NestedNamespaceDict") - - if isinstance(params, NestedNamespaceDict): - # In this case, params is the full parameters - # with both "static" and "dynamic" keys - if sorted(params.keys()) != PARAM_KEYS: - raise ValueError(f"params must have keys {PARAM_KEYS}") - - # Extract the "static" and "dynamic" parameters - param_dicts = list(params.values()) - - # Merge the "static" and "dynamic" dictionaries - # to a single merged dictionary - final_dict = NestedNamespaceDict() - for pdict in param_dicts: - for k, v in pdict.items(): - if k not in final_dict: - final_dict[k] = v - else: - final_dict[k] = {**final_dict[k], **v} - - # Flatten the dictionary to a single level - params = final_dict.flatten() - return params - - -def _get_param_values(flat_params: "NamespaceDict") -> Dict[str, Optional[Tensor]]: - """ - Get the values of the parameters from a - flattened dictionary of parameters. - - Parameters - ---------- - flat_params : NamespaceDict - A flattened dictionary of parameters. - - Returns - ------- - Dict[str, Optional[Tensor]] - A dictionary of parameter values, - these values can be a tensor or None. - """ - return {k: v.value for k, v in flat_params.items()} - - -def _extract_tensors_dict( - params: "NamespaceDict | NestedNamespaceDict", -) -> Dict[str, Optional[Tensor]]: - """ - Extract the tensors from a nested dictionary - of parameters and merge them into a single - dictionary of parameters. Then return a - dictionary of tensors by getting the parameter - tensor values. - - Parameters - ---------- - params : NestedNamespaceDict - The nested dictionary of parameters - that includes both "static" and "dynamic" - export_params : bool, optional - Whether to return the merged parameters as well, - not just the dictionary of tensors, - by default False. 
- - Returns - ------- - dict - A dictionary of tensors - """ - all_params = _merge_and_flatten(params) - return _get_param_values(all_params) - - -class ImmutableODict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._created = True - - def __delitem__(self, _) -> None: - raise IMMUTABLE_ERR - - def __setitem__(self, key: str, value: Any) -> None: - if hasattr(self, "_created"): - raise IMMUTABLE_ERR - super().__setitem__(key, value) - - def __setattr__(self, name, value) -> None: - if hasattr(self, "_created"): - raise IMMUTABLE_ERR - return super().__setattr__(name, value) - - -class StateDict(ImmutableODict): - """A dictionary object that is immutable after creation. - This is used to store the parameters of a simulator at a given - point in time. - - Methods - ------- - to_params() - Convert the state dict to a dictionary of parameters. - """ - - __slots__ = ("_metadata", "_created", "_created_time") - - def __init__(self, metadata=None, *args, **kwargs): - # Get created time - self._created_time = dt.utcnow() - # Create metadata - _meta = { - "software_version": __version__, - "created_time": self._created_time.isoformat(), - } - if metadata: - _meta.update(metadata) - - # Set metadata - self._metadata = ImmutableODict(_meta) - - # Now create the object, this will set _created - # to True, and prevent any further modification - super().__init__(*args, **kwargs) - - def __delitem__(self, _) -> None: - raise IMMUTABLE_ERR - - def __setitem__(self, key: str, value: Any) -> None: - if hasattr(self, "_created"): - raise IMMUTABLE_ERR - super().__setitem__(key, value) - - @classmethod - def from_params(cls, params: "NestedNamespaceDict | NamespaceDict"): - """Class method to create a StateDict - from a dictionary of parameters - - Parameters - ---------- - params : NamespaceDict - A dictionary of parameters, - can either be the full parameters - that are "static" and "dynamic", - or "static" only. - - Returns - ------- - StateDict - A state dictionary object - """ - tensors_dict = _extract_tensors_dict(params) - return cls(**tensors_dict) - - def to_params(self) -> NestedNamespaceDict: - """ - Convert the state dict to - a nested dictionary of parameters. - - Returns - ------- - NestedNamespaceDict - A nested dictionary of parameters. - """ - from ..parameter import Parameter - - params = NamespaceDict() - for k, v in self.items(): - if v.nelement() == 0: - # Set to None if the tensor is empty - v = None - params[k] = Parameter(v) - return NestedNamespaceDict(params) - - def save(self, file_path: "str | Path | None" = None) -> str: - """ - Saves the state dictionary to an optional - ``file_path`` as safetensors format. - If ``file_path`` is not given, - this will default to a file in - the current working directory. - - *Note: The path specified must - have a '.st' extension.* - - Parameters - ---------- - file_path : str, optional - The file path to save the - state dictionary to, by default None - - Returns - ------- - str - The final path of the saved file - """ - input_path: Path - - if not file_path: - input_path = Path.cwd() / self.__st_file - elif isinstance(file_path, str): - input_path = Path(file_path) - else: - input_path = file_path - - ext = ".st" - if input_path.suffix != ext: - raise ValueError(f"File must have '{ext}' extension") - - return io.to_file(input_path, self._to_safetensors()) - - @classmethod - def load(cls, file_path: str) -> "StateDict": - """ - Loads the state dictionary from a - specified ``file_path``. 
- - Parameters - ---------- - file_path : str - The file path to load the - state dictionary from. - - Returns - ------- - StateDict - The loaded state dictionary - """ - # TODO: Need to rethink this for remote paths - - # Load just the metadata - metadata = io.get_safetensors_metadata(file_path) - - # Load the full data to cpu first - st_dict = load_file(file_path) - st_dict = {k: v if v.nelement() > 0 else None for k, v in st_dict.items()} - return cls(metadata=metadata, **st_dict) - - @property - def __st_file(self) -> str: - file_format = "%Y%m%dT%H%M%S_caustics.st" - return self._created_time.strftime(file_format) - - def _to_safetensors(self) -> bytes: - return save(_sanitize(self), metadata=self._metadata) diff --git a/src/caustics/tests.py b/src/caustics/tests.py index 15a90457..d5c29698 100644 --- a/src/caustics/tests.py +++ b/src/caustics/tests.py @@ -58,7 +58,6 @@ def _test_simulator_runs(device=DEVICE): assert torch.all( torch.isfinite( sim( - {}, source_light=True, lens_light=True, lens_source=True, @@ -69,7 +68,6 @@ def _test_simulator_runs(device=DEVICE): assert torch.all( torch.isfinite( sim( - {}, source_light=True, lens_light=True, lens_source=False, @@ -80,7 +78,6 @@ def _test_simulator_runs(device=DEVICE): assert torch.all( torch.isfinite( sim( - {}, source_light=True, lens_light=False, lens_source=True, @@ -91,7 +88,6 @@ def _test_simulator_runs(device=DEVICE): assert torch.all( torch.isfinite( sim( - {}, source_light=False, lens_light=True, lens_source=True, @@ -115,9 +111,9 @@ def _test_jacobian_autograd_vs_finitediff(device=DEVICE): lens = lens.to(device=device) # Evaluate Jacobian - J_autograd = lens.jacobian_lens_equation(thx, thy, z_s, lens.pack(x)) + J_autograd = lens.jacobian_lens_equation(thx, thy, z_s, x) J_finitediff = lens.jacobian_lens_equation( - thx, thy, z_s, lens.pack(x), method="finitediff", pixelscale=torch.tensor(0.01) + thx, thy, z_s, x, method="finitediff", pixelscale=torch.tensor(0.01) ) assert ( @@ -154,7 +150,7 @@ def _test_multiplane_jacobian(device=DEVICE): # Parameters z_s = torch.tensor(1.2, device=device) x = torch.tensor(xs, device=device).flatten() - A = lens.jacobian_lens_equation(thx, thy, z_s, lens.pack(x)) + A = lens.jacobian_lens_equation(thx, thy, z_s, x) assert A.shape == (10, 10, 2, 2) @@ -188,9 +184,9 @@ def _test_multiplane_jacobian_autograd_vs_finitediff(device=DEVICE): x = torch.tensor(xs, device=device).flatten() # Evaluate Jacobian - J_autograd = lens.jacobian_lens_equation(thx, thy, z_s, lens.pack(x)) + J_autograd = lens.jacobian_lens_equation(thx, thy, z_s, x) J_finitediff = lens.jacobian_lens_equation( - thx, thy, z_s, lens.pack(x), method="finitediff", pixelscale=torch.tensor(0.01) + thx, thy, z_s, x, method="finitediff", pixelscale=torch.tensor(0.01) ) assert ( @@ -227,9 +223,9 @@ def _test_multiplane_effective_convergence(device=DEVICE): # Parameters z_s = torch.tensor(1.2, device=device) x = torch.tensor(xs, device=device).flatten() - C = lens.effective_convergence_div(thx, thy, z_s, lens.pack(x)) + C = lens.effective_convergence_div(thx, thy, z_s, x) assert C.shape == (10, 10) - curl = lens.effective_convergence_curl(thx, thy, z_s, lens.pack(x)) + curl = lens.effective_convergence_curl(thx, thy, z_s, x) assert curl.shape == (10, 10) diff --git a/src/caustics/utils.py b/src/caustics/utils.py index f6399979..c0efde48 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -1,6 +1,6 @@ # mypy: disable-error-code="misc", disable-error-code="attr-defined" -from math import pi -from typing import Callable, 
Optional, Tuple, Union, Any, Literal
+from math import pi, ceil
+from typing import Callable, Optional, Tuple, Dict, Union, Any, Literal
 from importlib import import_module
 from functools import partial, lru_cache
 
@@ -470,8 +470,8 @@ def interp2d(
     padding_mode: str = "zeros",
 ) -> Tensor:
     """
-    Interpolates a 2D image at specified coordinates.
-    Similar to `torch.nn.functional.grid_sample` with `align_corners=False`.
+    Interpolates a 2D image at specified coordinates. Similar to
+    `torch.nn.functional.grid_sample` with `align_corners=False`.
 
     Parameters
     ----------
@@ -482,10 +482,15 @@ def interp2d(
     y: Tensor
         A 0D or 1D tensor of y coordinates at which to interpolate.
     method: (str, optional)
-        Interpolation method. Either 'nearest' or 'linear'. Defaults to 'linear'.
+        Interpolation method. Either 'nearest' or 'linear'. Defaults to
+        'linear'.
     padding_mode: (str, optional)
         Defines the padding mode when out-of-bound indices are encountered.
-        Either 'zeros' or 'extrapolate'. Defaults to 'zeros'.
+        Either 'zeros', 'clamp', or 'extrapolate'. Defaults to 'zeros' which
+        fills padded coordinates with zeros. The 'clamp' mode clamps the
+        coordinates to the image boundaries (essentially taking the border
+        values out to infinity). The 'extrapolate' mode extrapolates the outer
+        linear interpolation beyond the last pixel boundary.
 
     Raises
     ------
@@ -503,7 +508,8 @@ def interp2d(
     Returns
     -------
     Tensor
-        Tensor with the same shape as `x` and `y` containing the interpolated values.
+        Tensor with the same shape as `x` and `y` containing the interpolated
+        values.
     """
     if im.ndim != 2:
         raise ValueError(f"im must be 2D (received {im.ndim}D tensor)")
@@ -514,7 +520,12 @@
-    if padding_mode not in ["extrapolate", "zeros"]:
+    if padding_mode not in ["extrapolate", "clamp", "zeros"]:
         raise ValueError(f"{padding_mode} is not a valid padding mode")
 
-    idxs_out_of_bounds = (y < -1) | (y > 1) | (x < -1) | (x > 1)
+    if padding_mode == "clamp":
+        x = x.clamp(-1, 1)
+        y = y.clamp(-1, 1)
+    else:
+        idxs_out_of_bounds = (y < -1) | (y > 1) | (x < -1) | (x > 1)
+
     # Convert coordinates to pixel indices
     h, w = im.shape
     x = 0.5 * ((x + 1) * w - 1)
@@ -960,6 +971,120 @@ def vmap_n(
     return vmapd_func
 
 
+def _chunk_input(x, k, in_dims, chunk_size):
+    if isinstance(in_dims, tuple):
+        if chunk_size is None:
+            n_chunks = 1
+        else:
+            i = 0
+            while in_dims[i] is None:
+                i += 1
+            B = x[i].shape[in_dims[i]]
+            n_chunks = ceil(B / chunk_size)
+
+        # Break data into chunks
+        chunks = [[] for _ in range(n_chunks)]
+        for subx, in_dim in zip(x, in_dims):
+            if in_dim is None:
+                subchunking = [subx] * n_chunks
+            else:
+                subchunking = subx.chunk(n_chunks, dim=in_dim)
+            for j, subchunk in enumerate(subchunking):
+                chunks[j].append(subchunk)
+    else:  # isinstance(in_dims, dict)
+        if chunk_size is None:
+            n_chunks = 1
+        else:
+            for key, value in in_dims.items():
+                if value is not None:
+                    B = k[key].shape[value]
+                    n_chunks = ceil(B / chunk_size)
+                    break
+
+        # Break data into chunks
+        chunks = [{} for _ in range(n_chunks)]
+        for key, value in in_dims.items():
+            if value is None:
+                subchunking = [k[key]] * n_chunks
+            else:
+                subchunking = k[key].chunk(n_chunks, dim=value)
+            for j, subchunk in enumerate(subchunking):
+                chunks[j][key] = subchunk
+    return chunks
+
+
+def vmap_reduce(
+    func: Callable,
+    reduce_func: Callable = lambda x: x.sum(dim=0),
+    chunk_size: Optional[int] = None,
+    in_dims: Union[Tuple[int, ...], Dict[str, int]] = (0,),
+    out_dims: Union[int, Tuple[int, ...]] = 0,
+    **kwargs,
+) -> Tensor:
+    """
+    Applies `torch.vmap` to `func` and then reduces the output using
+    `reduce_func` along 
the appropriate dimensions. This saves on memory + management if the dimension being reduced can cause the intermediate tensor + (before reduction) to be large. + + Note + ---- + The chunking and reduction is only "one level deep". If the output of `func` + is still large even after chunking, this function will not completely solve + the problem. Essentially if the batch dimension divided by chunk_size is + still larger than chunk_size, then you will still have a large intermediate + tensor. + + Parameters + ---------- + func: Callable + The function to transform. + reduce_func: Callable + The function to reduce the output of `func`. + in_dims: Tuple[int,...] + The dimensions to vectorize over in the input. + out_dims: Tuple[int,...] + The dimension to stack the output over. + chunk_size: (Optional[int]) + The size of the chunks to process. If None, the entire input is + processed at once. + kwargs: Dict + Additional keyword arguments to pass to `torch.vmap`. + + Returns + ------- + Tensor + The reduced output. + """ + if isinstance(in_dims, tuple): + vfunc = torch.vmap(func, in_dims, **kwargs) + else: # isinstance(in_dims, dict) + vfunc = torch.vmap(func, (in_dims,), **kwargs) + + def wrapped(*x, **k): + # Determine chunks + chunks = _chunk_input(x, k, in_dims, chunk_size) + + # Process and reduce the chunks + if isinstance(in_dims, tuple): + out = tuple(reduce_func(vfunc(*chunk)) for chunk in chunks) + else: # isinstance(in_dims, dict) + out = tuple(reduce_func(vfunc(chunk)) for chunk in chunks) + + # Stack the output + if isinstance(out_dims, int): + out = torch.stack(out, dim=out_dims) + else: + out = tuple( + torch.stack([o[i] for o in out], dim=d) for i, d in enumerate(out_dims) + ) + + # Reduce the output + return reduce_func(out) + + return wrapped + + def cluster_means(xs: Tensor, k: int): """ Computes cluster means using the k-means++ initialization algorithm. diff --git a/tests/conftest.py b/tests/conftest.py index 5b1267fe..b148a9d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,37 +2,14 @@ import os import torch import pytest -import typing # Add the helpers directory to the path so we can import the helpers sys.path.append(os.path.join(os.path.dirname(__file__), "utils")) -from caustics.models.utils import setup_pydantic_models +# from caustics.models.utils import setup_pydantic_models CUDA_AVAILABLE = torch.cuda.is_available() -LIGHT_ANNOTATED, LENSES_ANNOTATED = setup_pydantic_models() - - -def _get_models(annotated): - typehint = typing.get_args(annotated)[0] - pydantic_models = typing.get_args(typehint) - if isinstance(pydantic_models, tuple): - pydantic_models = {m.__name__: m for m in pydantic_models} - else: - pydantic_models = {pydantic_models.__name__: pydantic_models} - return pydantic_models - - -@pytest.fixture -def light_models(): - return _get_models(LIGHT_ANNOTATED) - - -@pytest.fixture -def lens_models(): - return _get_models(LENSES_ANNOTATED) - @pytest.fixture(params=["yaml", "no_yaml"]) def sim_source(request): diff --git a/tests/models/test_mod_api.py b/tests/models/test_mod_api.py deleted file mode 100644 index a8cb6a01..00000000 --- a/tests/models/test_mod_api.py +++ /dev/null @@ -1,236 +0,0 @@ -from tempfile import NamedTemporaryFile -import os -import yaml - -import pytest -import torch - -try: - from pydantic import create_model -except ImportError: - raise ImportError( - "The `pydantic` package is required to use this feature. " - "You can install it using `pip install pydantic==2.7`. This package requires rust. 
Make sure you have the permissions to install the dependencies.\n " - "Otherwise, the maintainer can install the package for you, you can then use `pip install --no-index pydantic`" - ) - -import caustics -from caustics.models.utils import setup_simulator_models -from caustics.models.base_models import StateConfig, Field -from utils.models import setup_complex_multiplane_yaml -import textwrap - - -@pytest.fixture -def ConfigModel(): - simulators = setup_simulator_models() - return create_model( - "Config", __base__=StateConfig, simulator=(simulators, Field(...)) - ) - - -@pytest.fixture -def x_input(): - return torch.tensor([ - # z_s z_l x0 y0 q phi b x0 y0 q phi n Re - 1.5, 0.5, -0.2, 0.0, 0.4, 1.5708, 1.7, 0.0, 0.0, 0.5, -0.985, 1.3, 1.0, - # Ie x0 y0 q phi n Re Ie - 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 - ]) # fmt: skip - - -@pytest.fixture -def sim_yaml(): - return textwrap.dedent( - """\ - cosmology: &cosmo - name: cosmo - kind: FlatLambdaCDM - - lens: &lens - name: lens - kind: SIE - init_kwargs: - cosmology: *cosmo - - src: &src - name: source - kind: Sersic - - lnslt: &lnslt - name: lenslight - kind: Sersic - - simulator: - name: minisim - kind: LensSource - init_kwargs: - # Single lense - lens: *lens - source: *src - lens_light: *lnslt - pixelscale: 0.05 - pixels_x: 100 - """ - ) - - -def _write_temp_yaml(yaml_str: str): - # Create temp file - f = NamedTemporaryFile("w", delete=False) - f.write(yaml_str) - f.flush() - f.close() - - return f.name - - -@pytest.fixture -def sim_yaml_file(sim_yaml): - temp_file = _write_temp_yaml(sim_yaml) - - yield temp_file - - if os.path.exists(temp_file): - os.unlink(temp_file) - - -@pytest.fixture -def simple_config_dict(sim_yaml): - return yaml.safe_load(sim_yaml) - - -@pytest.fixture -def sim_obj(): - cosmology = caustics.FlatLambdaCDM() - sie = caustics.SIE(cosmology=cosmology, name="lens") - src = caustics.Sersic(name="source") - lnslt = caustics.Sersic(name="lenslight") - return caustics.LensSource( - lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 - ) - - -def test_build_simulator(sim_yaml_file, sim_obj, x_input): - sim = caustics.build_simulator(sim_yaml_file) - - result = sim(x_input) - expected_result = sim_obj(x_input) - assert sim.graph(True, True) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected_result) - - -def test_complex_build_simulator(): - yaml_str = setup_complex_multiplane_yaml() - x = torch.tensor( - [ - # z_s x0 y0 q phi n Re - 1.5, - 0.0, - 0.0, - 0.5, - -0.985, - 1.3, - 1.0, - # Ie x0 y0 q phi n Re Ie - 5.0, - -0.2, - 0.0, - 0.8, - 0.0, - 1.0, - 1.0, - 10.0, - ] - ) - # Create temp file - temp_file = _write_temp_yaml(yaml_str) - - # Open the temp file and build the simulator - sim = caustics.build_simulator(temp_file) - image = sim(x) - assert isinstance(image, torch.Tensor) - - # Remove the temp file - if os.path.exists(temp_file): - os.unlink(temp_file) - - -def test_build_simulator_w_state(sim_yaml_file, sim_obj, x_input): - sim = caustics.build_simulator(sim_yaml_file) - params = dict(zip(sim.x_order, x_input)) - - # Set the parameters from x input - # using set attribute to the module objects - # this makes the the params to be static - for k, v in params.items(): - n, p = k.split(".") - if n == sim.name: - setattr(sim, p, v) - continue - key = sim._module_key_map[n] - mod = getattr(sim, key) - setattr(mod, p, v) - - state_dict = sim.state_dict() - - # Save the state - state_path = None - with NamedTemporaryFile("wb", suffix=".st", delete=False) as f: - 
state_path = f.name - state_dict.save(state_path) - - # Add the path to state to the sim yaml - with open(sim_yaml_file, "a") as f: - f.write( - textwrap.dedent( - f""" - state: - load: - path: {state_path} - """ - ) - ) - - # Load the state - # First remove the original sim - del sim - newsim = caustics.build_simulator(sim_yaml_file) - result = newsim() - expected_result = sim_obj(x_input) - assert newsim.graph(True, True) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected_result) - - -@pytest.mark.parametrize( - "psf", - [ - { - "func": "caustics.utils.gaussian", - "kwargs": { - "pixelscale": 0.05, - "nx": 11, - "ny": 12, - "sigma": 0.2, - "upsample": 2, - }, - }, - # {"function": "caustics.utils.gaussian", "sigma": 0.2}, - [[2.0], [2.0]], - ], -) -@pytest.mark.parametrize("pixels_y", ["50", 50.3]) # will get casted to int -def test_init_kwargs_validate(ConfigModel, simple_config_dict, psf, pixels_y): - # Add psf - test_config_dict = {**simple_config_dict} - test_config_dict["simulator"]["init_kwargs"]["psf"] = psf - test_config_dict["simulator"]["init_kwargs"]["pixels_y"] = pixels_y - if isinstance(psf, dict) and "func" not in psf: - with pytest.raises(ValueError): - ConfigModel(**test_config_dict) - else: - # Test that the init_kwargs are validated - config = ConfigModel(**test_config_dict) - assert config.simulator.model_obj() diff --git a/tests/models/test_mod_registry.py b/tests/models/test_mod_registry.py deleted file mode 100644 index 2e10c0a0..00000000 --- a/tests/models/test_mod_registry.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest - -import caustics -from caustics.models.registry import ( - _KindRegistry, - available_kinds, - register_kind, - get_kind, - _registry, -) -from caustics.parameter import Parameter -from caustics.parametrized import Parametrized - - -class TestKindRegistry: - expected_attrs = [ - "cosmology", - "single_lenses", - "multi_lenses", - "light", - "simulators", - "known_kinds", - "_m", - ] - - def test_constructor(self): - registry = _KindRegistry() - - for attr in self.expected_attrs: - assert hasattr(registry, attr) - - @pytest.mark.parametrize("kind", ["NonExistingClass", "SIE", caustics.Sersic]) - def test_getitem(self, kind, mocker): - registry = _KindRegistry() - - if kind == "NonExistingClass": - with pytest.raises(KeyError): - registry[kind] - elif isinstance(kind, str): - cls = registry[kind] - assert cls == getattr(caustics, kind) - else: - test_key = "TestSersic" - registry.known_kinds[test_key] = kind - cls = registry[test_key] - assert cls == kind - - @pytest.mark.parametrize("kind", [Parameter, caustics.Sersic, "caustics.SIE"]) - def test_setitem(self, kind): - registry = _KindRegistry() - key = "TestSersic" - if isinstance(kind, str): - registry[key] = kind - assert key in registry._m - elif issubclass(kind, Parametrized): - registry[key] = kind - assert registry[key] == kind - else: - with pytest.raises(ValueError): - registry[key] = kind - - def test_delitem(self): - registry = _KindRegistry() - with pytest.raises(NotImplementedError): - del registry["Sersic"] - - def test_len(self): - registry = _KindRegistry() - assert len(registry) == len(set(registry._m)) - - def test_iter(self): - registry = _KindRegistry() - assert set(registry) == set(registry._m) - - -def test_available_kinds(): - assert available_kinds() == list(_registry) - - -def test_register_kind(): - key = "TestSersic2" - value = caustics.Sersic - register_kind(key, value) - assert key in _registry._m - assert _registry[key] == value - - 
with pytest.raises(ValueError): - register_kind("SIE", "caustics.SIE") - - -def test_get_kind(): - kind = "Sersic" - cls = get_kind(kind) - assert cls == caustics.Sersic - kind = "NonExistingClass" - with pytest.raises(KeyError): - cls = get_kind(kind) diff --git a/tests/models/test_mod_utils.py b/tests/models/test_mod_utils.py deleted file mode 100644 index edd88a7d..00000000 --- a/tests/models/test_mod_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -import pytest -import inspect -import typing -from typing import Annotated, Dict -from caustics.models.registry import _registry, get_kind -from caustics.models.utils import ( - create_pydantic_model, - setup_pydantic_models, - setup_simulator_models, - _get_kwargs_field_definitions, - PARAMS, - INIT_KWARGS, -) -from caustics.models.base_models import Base -from caustics.parametrized import ClassParam - - -@pytest.fixture(params=_registry.known_kinds) -def kind(request): - return request.param - - -@pytest.fixture -def parametrized_class(kind): - return get_kind(kind) - - -def test_create_pydantic_model(kind): - model = create_pydantic_model(kind) - kind_cls = get_kind(kind) - expected_fields = {"kind", "name", "params", "init_kwargs"} - - assert model.__base__ == Base - assert model.__name__ == kind - assert model._cls == kind_cls - assert set(model.model_fields.keys()) == expected_fields - - -def test__get_kwargs_field_definitions(parametrized_class): - kwargs_fd = _get_kwargs_field_definitions(parametrized_class) - - cls_signature = inspect.signature(parametrized_class) - class_metadata = { - k: { - "dtype": v.annotation.__origin__, - "default": v.default, - "class_param": ClassParam(*v.annotation.__metadata__), - } - for k, v in cls_signature.parameters.items() - } - - for k, v in class_metadata.items(): - if k != "name": - if v["class_param"].isParam: - assert k in kwargs_fd[PARAMS] - assert isinstance(kwargs_fd[PARAMS][k], tuple) - assert kwargs_fd[PARAMS][k][0] == v["dtype"] - field_info = kwargs_fd[PARAMS][k][1] - else: - assert k in kwargs_fd[INIT_KWARGS] - assert isinstance(kwargs_fd[INIT_KWARGS][k], tuple) - assert kwargs_fd[INIT_KWARGS][k][0] == v["dtype"] - field_info = kwargs_fd[INIT_KWARGS][k][1] - - if v["default"] == inspect._empty: - # Skip empty defaults - continue - assert field_info.default == v["default"] - - -def _check_nested_discriminated_union( - input_anno: type[Annotated], class_paths: Dict[str, str] -): - # Check to see if the model selection is Annotated type - assert typing.get_origin(input_anno) == Annotated - # Check to see if the discriminator is "kind" - assert input_anno.__metadata__[0].discriminator == "kind" - - if typing.get_origin(input_anno.__origin__) == typing.Union: - models = input_anno.__origin__.__args__ - else: - # For single models - models = [input_anno.__origin__] - - # Check to see if the models are in the registry - assert len(models) == len(class_paths) - # Go through each model and check that it's pointing to the right class - for model in models: - assert model.__name__ in class_paths - assert model._cls == get_kind(model.__name__) - - -def test_setup_pydantic_models(): - # light, lenses - pydantic_models_annotated = setup_pydantic_models() - - registry_dict = { - "light": _registry.light, - "lenses": { - **_registry.single_lenses, - **_registry.multi_lenses, - }, - } - - pm_anno_dict = { - k: v for (k, v) in zip(list(registry_dict.keys()), pydantic_models_annotated) - } - - for key, pydantic_model_anno in pm_anno_dict.items(): - class_paths = registry_dict[key] - 
_check_nested_discriminated_union(pydantic_model_anno, class_paths) - - -def test_setup_simulator_models(): - simulators = setup_simulator_models() - - class_paths = _registry.simulators - _check_nested_discriminated_union(simulators, class_paths) diff --git a/tests/sims/conftest.py b/tests/sims/conftest.py deleted file mode 100644 index 7f45cce7..00000000 --- a/tests/sims/conftest.py +++ /dev/null @@ -1,40 +0,0 @@ -import pytest - -from caustics.sims.simulator import Simulator -from caustics.lenses import EPL -from caustics.light import Sersic -from caustics.cosmology import FlatLambdaCDM - - -@pytest.fixture -def test_epl_values(): - return { - "z_l": 0.5, - "phi": 0.0, - "b": 1.0, - "t": 1.0, - } - - -@pytest.fixture -def test_sersic_values(): - return { - "q": 0.9, - "phi": 0.3, - "n": 1.0, - } - - -@pytest.fixture -def simple_common_sim(test_epl_values, test_sersic_values): - class Sim(Simulator): - def __init__(self): - super().__init__() - self.cosmo = FlatLambdaCDM(h0=None) - self.epl = EPL(self.cosmo, **test_epl_values) - self.sersic = Sersic(**test_sersic_values) - self.add_param("z_s", 1.0) - - sim = Sim() - yield sim - del sim diff --git a/tests/sims/test_simulator.py b/tests/sims/test_simulator.py deleted file mode 100644 index 611f54b6..00000000 --- a/tests/sims/test_simulator.py +++ /dev/null @@ -1,92 +0,0 @@ -import pytest -from pathlib import Path -import sys - -import torch - -from caustics.sims.state_dict import ( - StateDict, - _extract_tensors_dict, -) - - -@pytest.fixture -def state_dict(simple_common_sim): - return simple_common_sim.state_dict() - - -@pytest.fixture -def expected_tensors(simple_common_sim): - tensors_dict = _extract_tensors_dict(simple_common_sim.params) - return tensors_dict - - -class TestSimulator: - def test_state_dict(self, state_dict, expected_tensors): - # Check state_dict type and default keys - assert isinstance(state_dict, StateDict) - - # Trying to modify state_dict should raise TypeError - with pytest.raises(TypeError): - state_dict["params"] = -1 - - # Check _metadata keys - assert "software_version" in state_dict._metadata - assert "created_time" in state_dict._metadata - - # Check params - assert dict(state_dict) == expected_tensors - - def test_set_module_params(self, simple_common_sim): - params = {"param1": torch.as_tensor(1), "param2": torch.as_tensor(2)} - # Call the __set_module_params method - simple_common_sim._Simulator__set_module_params(simple_common_sim, params) - - # Check if the module attributes have been set correctly - assert simple_common_sim.param1 == params["param1"] - assert simple_common_sim.param2 == params["param2"] - - @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Built-in open has different behavior on Windows", - ) - def test_load_state_dict(self, simple_common_sim): - simple_common_sim.epl.x0 = 0.0 - simple_common_sim.sersic.x0 = 1.0 - - fpath = simple_common_sim.state_dict().save() - loaded_state_dict = StateDict.load(fpath) - - # Change a value in the simulator - simple_common_sim.z_s = 3.0 - simple_common_sim.epl.x0 = None - simple_common_sim.sersic.x0 = None - - # Ensure that the simulator has been changed - assert ( - loaded_state_dict[f"{simple_common_sim.name}.z_s"] - != simple_common_sim.z_s.value - ) - assert ( - loaded_state_dict[f"{simple_common_sim.epl.name}.x0"] - != simple_common_sim.epl.x0.value - ) - assert ( - loaded_state_dict[f"{simple_common_sim.sersic.name}.x0"] - != simple_common_sim.sersic.x0.value - ) - - # Load the state dict form file - 
simple_common_sim.load_state_dict(fpath) - - # Once loaded now the values should be the same - assert ( - loaded_state_dict[f"{simple_common_sim.name}.z_s"] - == simple_common_sim.z_s.value - ) - - assert simple_common_sim.epl.x0.value == torch.as_tensor(0.0) - assert simple_common_sim.sersic.x0.value == torch.as_tensor(1.0) - - # Cleanup after - Path(fpath).unlink() diff --git a/tests/sims/test_state_dict.py b/tests/sims/test_state_dict.py deleted file mode 100644 index a309c270..00000000 --- a/tests/sims/test_state_dict.py +++ /dev/null @@ -1,189 +0,0 @@ -from pathlib import Path -from tempfile import TemporaryDirectory -import sys - -import pytest -import torch -from collections import OrderedDict -from safetensors.torch import save, load -from datetime import datetime as dt -from caustics.parameter import Parameter -from caustics.namespace_dict import NamespaceDict, NestedNamespaceDict -from caustics.sims.state_dict import ( - ImmutableODict, - StateDict, - IMMUTABLE_ERR, - _sanitize, - _merge_and_flatten, - _get_param_values, -) -from caustics import __version__ - - -class TestImmutableODict: - def test_constructor(self): - odict = ImmutableODict(a=1, b=2, c=3) - assert isinstance(odict, OrderedDict) - assert odict == {"a": 1, "b": 2, "c": 3} - assert hasattr(odict, "_created") - assert odict._created is True - - def test_setitem(self): - odict = ImmutableODict() - with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): - odict["key"] = "value" - - def test_delitem(self): - odict = ImmutableODict(key="value") - with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): - del odict["key"] - - def test_setattr(self): - odict = ImmutableODict() - with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): - odict.meta = {"key": "value"} - - -class TestStateDict: - simple_tensors = {"var1": torch.as_tensor(1.0), "var2": torch.as_tensor(2.0)} - - @pytest.fixture(scope="class") - def simple_state_dict(self): - return StateDict(**self.simple_tensors) - - def test_constructor(self): - time_format = "%Y-%m-%dT%H:%M:%S" - time_str_now = dt.utcnow().strftime(time_format) - state_dict = StateDict(**self.simple_tensors) - - # Get the created time and format to nearest seconds - sd_ct_dt = dt.fromisoformat(state_dict._metadata["created_time"]) - sd_ct_str = sd_ct_dt.strftime(time_format) - - # Check the default metadata and content - assert hasattr(state_dict, "_metadata") - assert state_dict._created is True - assert state_dict._metadata["software_version"] == __version__ - assert sd_ct_str == time_str_now - assert dict(state_dict) == self.simple_tensors - - def test_constructor_with_metadata(self): - time_format = "%Y-%m-%dT%H:%M:%S" - time_str_now = dt.utcnow().strftime(time_format) - metadata = {"created_time": time_str_now, "software_version": "0.0.1"} - state_dict = StateDict(metadata=metadata, **self.simple_tensors) - - assert isinstance(state_dict._metadata, ImmutableODict) - assert dict(state_dict._metadata) == dict(metadata) - - def test_setitem(self, simple_state_dict): - with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): - simple_state_dict["var1"] = torch.as_tensor(3.0) - - def test_delitem(self, simple_state_dict): - with pytest.raises(type(IMMUTABLE_ERR), match=str(IMMUTABLE_ERR)): - del simple_state_dict["var1"] - - def test_from_params(self, simple_common_sim): - params: NestedNamespaceDict = simple_common_sim.params - all_params = _merge_and_flatten(params) - tensors_dict = _get_param_values(all_params) - - expected_state_dict = 
StateDict(**tensors_dict) - - # Full parameters - state_dict = StateDict.from_params(params) - assert state_dict == expected_state_dict - - # Static only - state_dict = StateDict.from_params(all_params) - assert state_dict == expected_state_dict - - # Check for TypeError when passing a NamespaceDict or NestedNamespaceDict - with pytest.raises(TypeError): - StateDict.from_params({"a": 1, "b": 2}) - - # Check for TypeError when passing a NestedNamespaceDict - # without the "static" and "dynamic" keys - with pytest.raises(ValueError): - StateDict.from_params(NestedNamespaceDict({"a": 1, "b": 2})) - - def test_to_params(self): - params_with_none = {"var3": torch.ones(0), **self.simple_tensors} - state_dict = StateDict(**params_with_none) - params = StateDict(**params_with_none).to_params() - assert isinstance(params, NamespaceDict) - - for k, v in params.items(): - tensor_value = state_dict[k] - if tensor_value.nelement() > 0: - assert isinstance(v, Parameter) - assert v.value == tensor_value - - def test__to_safetensors(self): - state_dict = StateDict(**self.simple_tensors) - # Save to safetensors - tensors_bytes = state_dict._to_safetensors() - expected_bytes = save(_sanitize(state_dict), metadata=state_dict._metadata) - - # Reload to back to tensors dict - # this is done because the information - # might be stored in different arrangements - # within the safetensors bytes - loaded_tensors = load(tensors_bytes) - loaded_expected_tensors = load(expected_bytes) - assert loaded_tensors == loaded_expected_tensors - - def test_st_file_string(self, simple_state_dict): - file_format = "%Y%m%dT%H%M%S_caustics.st" - expected_file = simple_state_dict._created_time.strftime(file_format) - - assert simple_state_dict._StateDict__st_file == expected_file - - @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Built-in open has different behavior on Windows", - ) - def test_save(self, simple_state_dict): - # Check for default save path - expected_fpath = Path.cwd() / simple_state_dict._StateDict__st_file - default_fpath = simple_state_dict.save() - - assert Path(default_fpath).exists() - assert default_fpath == str(expected_fpath.absolute()) - - # Cleanup after - Path(default_fpath).unlink() - - # Check for specified save path - with TemporaryDirectory() as tempdir: - tempdir = Path(tempdir) - # Correct extension and path in a tempdir - fpath = tempdir / "test.st" - saved_path = simple_state_dict.save(str(fpath.absolute())) - - assert Path(saved_path).exists() - assert saved_path == str(fpath.absolute()) - - # Test save Path - fpath1 = tempdir / "test1.st" - saved_path = simple_state_dict.save(fpath1) - assert Path(saved_path).exists() - assert saved_path == str(fpath1.absolute()) - - # Wrong extension - wrong_fpath = tempdir / "test.txt" - with pytest.raises(ValueError): - saved_path = simple_state_dict.save(str(wrong_fpath.absolute())) - - @pytest.mark.skipif( - sys.platform.startswith("win"), - reason="Built-in open has different behavior on Windows", - ) - def test_load(self, simple_state_dict): - fpath = simple_state_dict.save() - loaded_state_dict = StateDict.load(fpath) - assert loaded_state_dict == simple_state_dict - - # Cleanup after - Path(fpath).unlink() diff --git a/tests/test_batchedplane.py b/tests/test_batchedplane.py new file mode 100644 index 00000000..a7ced6a0 --- /dev/null +++ b/tests/test_batchedplane.py @@ -0,0 +1,47 @@ +from math import pi + +import torch + +from caustics.cosmology import FlatLambdaCDM +from caustics.lenses import SIE, BatchedPlane +from caustics.utils 
import meshgrid + + +def test_batchedplane(): + + z_s = torch.tensor(1.2) + z_l = torch.tensor(0.5) + cosmology = FlatLambdaCDM(name="cosmo") + internallens = SIE(name="sie", cosmology=cosmology, z_l=z_l) + + lens = BatchedPlane(name="lens", lens=internallens, cosmology=cosmology, z_l=z_l) + x = torch.tensor([0.912, -0.442, 0.5, pi / 3, 1.0]).reshape(1, 5).repeat(10, 1) + + n_pix = 10 + res = 0.05 + upsample_factor = 2 + thx, thy = meshgrid( + res / upsample_factor, + upsample_factor * n_pix, + upsample_factor * n_pix, + dtype=torch.float32, + ) + + ax, ay = lens.reduced_deflection_angle(thx, thy, z_s, x) + + in_ax, in_ay = internallens.reduced_deflection_angle(thx, thy, z_s, x[0]) + + assert torch.allclose(ax, 10 * in_ax) + assert torch.allclose(ay, 10 * in_ay) + + kappa = lens.convergence(thx, thy, z_s, x) + + in_kappa = internallens.convergence(thx, thy, z_s, x[0]) + + assert torch.allclose(kappa, 10 * in_kappa) + + potential = lens.potential(thx, thy, z_s, x) + + in_potential = internallens.potential(thx, thy, z_s, x[0]) + + assert torch.allclose(potential, 10 * in_potential) diff --git a/tests/test_batching.py b/tests/test_batching.py deleted file mode 100644 index b151e938..00000000 --- a/tests/test_batching.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -from torch import vmap -from utils import setup_image_simulator, setup_simulator - - -def test_vmapped_simulator(device): - sim, (sim_params, cosmo_params, lens_params, source_params) = setup_simulator( - batched_params=True, - device=device, - ) - n_pix = sim.n_pix - print(sim.params) - - # test list input - x = sim_params + cosmo_params + lens_params + source_params - print(x[0].shape) - assert vmap(sim)(x).shape == torch.Size([2, n_pix, n_pix]) - - # test tensor input - x_tensor = torch.stack(x, dim=1) - print(x_tensor.shape) - assert vmap(sim)(x_tensor).shape == torch.Size([2, n_pix, n_pix]) - - # Test dictionary input: Only module with dynamic parameters are required - x_dict = { - "simulator": sim_params, - "cosmo": cosmo_params, - "source": source_params, - "lens": lens_params, - } - print(x_dict) - assert vmap(sim)(x_dict).shape == torch.Size([2, n_pix, n_pix]) - - # Test semantic list (one tensor per module) - x_semantic = [sim_params, cosmo_params, lens_params, source_params] - assert vmap(sim)(x_semantic).shape == torch.Size([2, n_pix, n_pix]) - - -def test_vmapped_simulator_with_pixelated_modules(device): - sim, (cosmo_params, lens_params, kappa, source) = setup_image_simulator( - batched_params=True, device=device - ) - n_pix = sim.n_pix - print(sim.params) - - # test list input - x = cosmo_params + lens_params + kappa + source - print(x[2].shape) - assert vmap(sim)(x).shape == torch.Size([2, n_pix, n_pix]) - - # test tensor input: Does not work well with images since it would require - # unflattening the images in caustics - # x_tensor = torch.concat([_x.view(2, -1) for _x in x], dim=1) - # print(x_tensor.shape) - # assert vmap(sim)(x_tensor).shape == torch.Size([2, n_pix, n_pix]) - - # Test dictionary input: Only module with dynamic parameters are required - x_dict = { - "cosmo": cosmo_params, - "source": source, - "lens": lens_params, - "kappa": kappa, - } - print(x_dict) - assert vmap(sim)(x_dict).shape == torch.Size([2, n_pix, n_pix]) - - # Test passing tensor in source and kappa instead of list - x_dict = { - "cosmo": cosmo_params, - "source": source[0], - "lens": lens_params, - "kappa": kappa[0], - } - print(x_dict) - assert vmap(sim)(x_dict).shape == torch.Size([2, n_pix, n_pix]) - - # Test semantic list (one 
tensor per module) - x_semantic = [cosmo_params, lens_params, kappa, source] - assert vmap(sim)(x_semantic).shape == torch.Size([2, n_pix, n_pix]) diff --git a/tests/test_epl.py b/tests/test_epl.py index 1c854e13..58e43759 100644 --- a/tests/test_epl.py +++ b/tests/test_epl.py @@ -1,5 +1,5 @@ from math import pi -import yaml +from io import StringIO import lenstronomy.Util.param_util as param_util import torch @@ -8,6 +8,7 @@ from caustics.cosmology import FlatLambdaCDM from caustics.lenses import EPL +from caustics.sims import build_simulator import numpy as np import pytest @@ -17,7 +18,7 @@ @pytest.mark.parametrize("phi", [pi / 3, -pi / 4]) @pytest.mark.parametrize("b", [0.1, 1.0]) @pytest.mark.parametrize("t", [0.1, 1.0, 1.9]) -def test_lenstronomy_epl(sim_source, device, lens_models, q, phi, b, t): +def test_lenstronomy_epl(sim_source, device, q, phi, b, t): if sim_source == "yaml": yaml_str = """\ cosmology: &cosmology @@ -29,9 +30,8 @@ def test_lenstronomy_epl(sim_source, device, lens_models, q, phi, b, t): init_kwargs: cosmology: *cosmology """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = lens_models.get("EPL") - lens = mod(**yaml_dict["lens"]).model_obj() + with StringIO(yaml_str) as f: + lens = build_simulator(f) else: # Models cosmology = FlatLambdaCDM(name="cosmo") diff --git a/tests/test_external_shear.py b/tests/test_external_shear.py index 8af65ee4..5fb3bc8c 100644 --- a/tests/test_external_shear.py +++ b/tests/test_external_shear.py @@ -1,13 +1,15 @@ +from io import StringIO + import torch -import yaml from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import lens_test_helper from caustics.cosmology import FlatLambdaCDM from caustics.lenses import ExternalShear +from caustics.sims import build_simulator -def test(sim_source, device, lens_models): +def test(sim_source, device): atol = 1e-5 rtol = 1e-5 @@ -22,9 +24,8 @@ def test(sim_source, device, lens_models): init_kwargs: cosmology: *cosmology """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = lens_models.get("ExternalShear") - lens = mod(**yaml_dict["lens"]).model_obj() + with StringIO(yaml_str) as f: + lens = build_simulator(f) else: # Models cosmology = FlatLambdaCDM(name="cosmo") @@ -49,7 +50,3 @@ def test(sim_source, device, lens_models): lens_test_helper( lens, lens_ls, z_s, x, kwargs_ls, rtol, atol, test_kappa=False, device=device ) - - -if __name__ == "__main__": - test(None) diff --git a/tests/test_io.py b/tests/test_io.py deleted file mode 100644 index d3116b7c..00000000 --- a/tests/test_io.py +++ /dev/null @@ -1,63 +0,0 @@ -from pathlib import Path -import tempfile -import struct -import json -import torch -from safetensors.torch import save -from caustics.io import ( - _get_safetensors_header, - _normalize_path, - to_file, - from_file, - get_safetensors_metadata, -) - - -def test_normalize_path(): - path_obj = Path().joinpath("path", "to", "file.txt") - # Test with a string path - path_str = str(path_obj) - normalized_path = _normalize_path(path_str) - assert normalized_path == path_obj.absolute() - assert str(normalized_path) == str(path_obj.absolute()) - - # Test with a Path object - normalized_path = _normalize_path(path_obj) - assert normalized_path == path_obj.absolute() - - -def test_to_and_from_file(): - with tempfile.TemporaryDirectory() as tmpdir: - fpath = Path(tmpdir) / "test.txt" - data = "test data" - - # Test to file - ffile = to_file(fpath, data) - - assert Path(ffile).exists() - assert ffile == 
str(fpath.absolute())
-        assert Path(ffile).read_text() == data
-
-        # Test from file
-        assert from_file(fpath) == data.encode("utf-8")
-
-
-def test_get_safetensors_metadata():
-    with tempfile.TemporaryDirectory() as tmpdir:
-        fpath = Path(tmpdir) / "test.st"
-        meta_dict = {"meta": "data"}
-        tensors_bytes = save({"test1": torch.as_tensor(1.0)}, metadata=meta_dict)
-        fpath.write_bytes(tensors_bytes)
-
-        # Manually get header
-        first_bytes_length = 8
-        (length_of_header,) = struct.unpack("<Q", ...)
...
-# TODO: ... parameters and make it more intuitive
-def test_to_method():
-    sim, (sim_params, cosmo_params, lens_params, source_params) = setup_simulator(
-        batched_params=True
-    )
-
-    # Check that static params have correct type
-    module = Sersic(x0=0.5)
-    assert module.x0.dtype == torch.float32
-
-    module = Sersic(x0=torch.tensor(0.5))
-    assert module.x0.dtype == torch.float32
-
-    module = Sersic(x0=np.array(0.5))
-    assert (
-        module.x0.dtype == torch.float64
-    )  # Decided against default type, so gets numpy type here
-
-    # Check that all parameters are converted to correct type
-    sim.to(dtype=torch.float16)
-    assert sim.z_s.dtype is None  # dynamic parameter
-    assert sim.lens.cosmo.Om0.dtype == torch.float16
-    assert sim.cosmo.Om0.dtype == torch.float16
-
-
-def test_parameter_redefinition():
-    sim, _ = setup_simulator()
-
-    # Make sure the __getattribute__ method works as intended
-    assert isinstance(sim.z_s, Parameter)
-    # Now test __setattr__ method, we need to catch the redefinition here and keep z_s a parameter
-    sim.z_s = 42
-    # make sure z_s is still a parameter
-    assert sim.z_s.value == torch.tensor(42).float()
-    assert sim.z_s.static is True
-    sim.z_s = None
-    assert sim.z_s.value is None
-    assert sim.z_s.dynamic is True
diff --git a/tests/test_point.py b/tests/test_point.py
index 022e3fcc..02d9466f 100644
--- a/tests/test_point.py
+++ b/tests/test_point.py
@@ -1,16 +1,18 @@
+from io import StringIO
+
 import torch
-import yaml
 from lenstronomy.LensModel.lens_model import LensModel
 from utils import lens_test_helper

 from caustics.cosmology import FlatLambdaCDM
 from caustics.lenses import Point
+from caustics.sims import build_simulator

 import pytest


 @pytest.mark.parametrize("th_ein", [0.1, 1.0, 2.0])
-def test_point_lens(sim_source, device, lens_models, th_ein):
+def test_point_lens(sim_source, device, th_ein):
     atol = 1e-5
     rtol = 1e-5
     z_l = torch.tensor(0.9)
@@ -23,14 +25,12 @@ def test_point_lens(sim_source, device, lens_models, th_ein):
         lens: &lens
             name: point
             kind: Point
-            params:
-                z_l: {float(z_l)}
             init_kwargs:
+                z_l: {float(z_l)}
                 cosmology: *cosmology
         """
-        yaml_dict = yaml.safe_load(yaml_str.encode("utf-8"))
-        mod = lens_models.get("Point")
-        lens = mod(**yaml_dict["lens"]).model_obj()
+        with StringIO(yaml_str) as f:
+            lens = build_simulator(f)
     else:
         # Models
         cosmology = FlatLambdaCDM(name="cosmo")
diff --git a/tests/test_pseudo_jaffe.py b/tests/test_pseudo_jaffe.py
index 6e45b84e..bf758603 100644
--- a/tests/test_pseudo_jaffe.py
+++ b/tests/test_pseudo_jaffe.py
@@ -1,17 +1,19 @@
+from io import StringIO
+
 import torch
-import yaml
 from lenstronomy.LensModel.lens_model import LensModel
 from utils import lens_test_helper

 from caustics.cosmology import FlatLambdaCDM
 from caustics.lenses import PseudoJaffe
+from caustics.sims import build_simulator

 import pytest


 @pytest.mark.parametrize("mass", [1e8, 1e10, 1e12])
 @pytest.mark.parametrize("Rc,Rs", [[1.0, 10.0], [1e-2, 1.0], [0.5, 1.0]])
-def test_pseudo_jaffe(sim_source, device, lens_models, mass, Rc, Rs):
+def test_pseudo_jaffe(sim_source, device, mass,
Rc, Rs): atol = 1e-5 rtol = 1e-5 @@ -26,9 +28,8 @@ def test_pseudo_jaffe(sim_source, device, lens_models, mass, Rc, Rs): init_kwargs: cosmology: *cosmology """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = lens_models.get("PseudoJaffe") - lens = mod(**yaml_dict["lens"]).model_obj() + with StringIO(yaml_str) as f: + lens = build_simulator(f) cosmology = lens.cosmology else: # Models diff --git a/tests/test_sersic.py b/tests/test_sersic.py index 0c8089e9..5ab66a7f 100644 --- a/tests/test_sersic.py +++ b/tests/test_sersic.py @@ -1,12 +1,14 @@ +from io import StringIO + import lenstronomy.Util.param_util as param_util import numpy as np -import yaml import torch from lenstronomy.Data.pixel_grid import PixelGrid from lenstronomy.LightModel.light_model import LightModel from caustics.light import Sersic from caustics.utils import meshgrid +from caustics.sims import build_simulator import pytest @@ -14,7 +16,7 @@ @pytest.mark.parametrize("q", [0.2, 0.7]) @pytest.mark.parametrize("n", [1.0, 2.0, 3.0]) @pytest.mark.parametrize("th_e", [1.0, 10.0]) -def test_sersic(sim_source, device, light_models, q, n, th_e): +def test_sersic(sim_source, device, q, n, th_e): # Caustics setup res = 0.05 nx = 200 @@ -29,9 +31,8 @@ def test_sersic(sim_source, device, light_models, q, n, th_e): init_kwargs: use_lenstronomy_k: true """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = light_models.get("Sersic") - sersic = mod(**yaml_dict["light"]).model_obj() + with StringIO(yaml_str) as f: + sersic = build_simulator(f) else: sersic = Sersic(name="sersic", use_lenstronomy_k=True) sersic.to(device=device) diff --git a/tests/test_sie.py b/tests/test_sie.py index 8b379936..16f7a879 100644 --- a/tests/test_sie.py +++ b/tests/test_sie.py @@ -1,5 +1,5 @@ from math import pi -import yaml +from io import StringIO import lenstronomy.Util.param_util as param_util import torch @@ -9,6 +9,7 @@ from caustics.cosmology import FlatLambdaCDM from caustics.lenses import SIE from caustics.utils import meshgrid +from caustics.sims import build_simulator import pytest @@ -16,7 +17,7 @@ @pytest.mark.parametrize("q", [0.5, 0.7, 0.9]) @pytest.mark.parametrize("phi", [pi / 3, -pi / 4, pi / 6]) @pytest.mark.parametrize("th_ein", [0.1, 1.0, 2.5]) -def test_sie(sim_source, device, lens_models, q, phi, th_ein): +def test_sie(sim_source, device, q, phi, th_ein): atol = 1e-5 rtol = 1e-3 @@ -31,9 +32,8 @@ def test_sie(sim_source, device, lens_models, q, phi, th_ein): init_kwargs: cosmology: *cosmology """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = lens_models.get("SIE") - lens = mod(**yaml_dict["lens"]).model_obj() + with StringIO(yaml_str) as f: + lens = build_simulator(f) else: # Models cosmology = FlatLambdaCDM(name="cosmo") diff --git a/tests/test_simulator_runs.py b/tests/test_simulator_runs.py index faf470e3..4a5baa2d 100644 --- a/tests/test_simulator_runs.py +++ b/tests/test_simulator_runs.py @@ -1,3 +1,4 @@ +from io import StringIO from math import pi import torch @@ -9,10 +10,8 @@ from caustics.utils import gaussian from caustics import build_simulator -from utils import mock_from_file - -def test_simulator_runs(sim_source, device, mocker): +def test_simulator_runs(sim_source, device): if sim_source == "yaml": yaml_str = """\ cosmology: &cosmology @@ -22,24 +21,23 @@ def test_simulator_runs(sim_source, device, mocker): lensmass: &lensmass name: lens kind: SIE - params: + init_kwargs: z_l: 1.0 x0: 0.0 y0: 0.01 q: 0.5 - phi: pi / 3.0 + phi: 1.05 b: 1.0 - init_kwargs: cosmology: *cosmology 
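     # NOTE: the phi angles in this config are plain floats (1.05 ≈ pi/3,
     # ±0.785 ≈ ±pi/4) because YAML cannot evaluate expressions like pi / 3.0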
source: &source name: source kind: Sersic - params: + init_kwargs: x0: 0.01 y0: -0.03 q: 0.6 - phi: -pi / 4 + phi: -0.785 n: 1.5 Re: 0.5 Ie: 1.0 @@ -47,18 +45,18 @@ def test_simulator_runs(sim_source, device, mocker): lenslight: &lenslight name: lenslight kind: Sersic - params: + init_kwargs: x0: 0.0 y0: 0.01 q: 0.7 - phi: pi / 4 + phi: 0.785 n: 3.0 Re: 0.7 Ie: 1.0 psf: &psf - func: caustics.utils.gaussian - kwargs: + kind: utils.gaussian + init_kwargs: pixelscale: 0.05 nx: 11 ny: 11 @@ -68,23 +66,19 @@ def test_simulator_runs(sim_source, device, mocker): simulator: name: simulator kind: LensSource - params: - z_s: 2.0 init_kwargs: # Single lens + z_s: 2.0 lens: *lensmass source: *source lens_light: *lenslight pixelscale: 0.05 - pixels_x: 50 - psf: *psf{quad_level} + pixels_x: 50{quad_level} """ - mock_from_file( - mocker, yaml_str.format(quad_level="") - ) # fixme, yaml should be able to accept None - sim = build_simulator("/path/to/sim.yaml") # Path doesn't actually exists - mock_from_file(mocker, yaml_str.format(quad_level="\n quad_level: 3")) - sim_q3 = build_simulator("/path/to/sim.yaml") # Path doesn't actually exists + with StringIO(yaml_str.format(quad_level="")) as f: + sim = build_simulator(f) + with StringIO(yaml_str.format(quad_level="\n quad_level: 3")) as f: + sim_q3 = build_simulator(f) else: # Model cosmology = FlatLambdaCDM(name="cosmo") @@ -147,7 +141,6 @@ def test_simulator_runs(sim_source, device, mocker): assert torch.all( torch.isfinite( sim( - {}, source_light=True, lens_light=True, lens_source=True, @@ -158,7 +151,6 @@ def test_simulator_runs(sim_source, device, mocker): assert torch.all( torch.isfinite( sim( - {}, source_light=True, lens_light=True, lens_source=False, @@ -169,7 +161,6 @@ def test_simulator_runs(sim_source, device, mocker): assert torch.all( torch.isfinite( sim( - {}, source_light=True, lens_light=False, lens_source=True, @@ -180,7 +171,6 @@ def test_simulator_runs(sim_source, device, mocker): assert torch.all( torch.isfinite( sim( - {}, source_light=False, lens_light=True, lens_source=True, diff --git a/tests/test_sis.py b/tests/test_sis.py index cf170620..2b184713 100644 --- a/tests/test_sis.py +++ b/tests/test_sis.py @@ -1,16 +1,18 @@ +from io import StringIO + import torch from lenstronomy.LensModel.lens_model import LensModel from utils import lens_test_helper -import yaml from caustics.cosmology import FlatLambdaCDM from caustics.lenses import SIS +from caustics.sims import build_simulator import pytest @pytest.mark.parametrize("th_ein", [0.1, 1.0, 2.0]) -def test(sim_source, device, lens_models, th_ein): +def test(sim_source, device, th_ein): atol = 1e-5 rtol = 1e-5 z_l = torch.tensor(0.5) @@ -23,14 +25,12 @@ def test(sim_source, device, lens_models, th_ein): lens: &lens name: sis kind: SIS - params: - z_l: {float(z_l)} init_kwargs: + z_l: {float(z_l)} cosmology: *cosmology """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = lens_models.get("SIS") - lens = mod(**yaml_dict["lens"]).model_obj() + with StringIO(yaml_str) as f: + lens = build_simulator(f) else: # Models cosmology = FlatLambdaCDM(name="cosmo") @@ -46,7 +46,3 @@ def test(sim_source, device, lens_models, th_ein): ] lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol, atol, device=device) - - -if __name__ == "__main__": - test(None) diff --git a/tests/test_tnfw.py b/tests/test_tnfw.py index 2751c3d6..c9aeffbf 100644 --- a/tests/test_tnfw.py +++ b/tests/test_tnfw.py @@ -1,7 +1,5 @@ -# from math import pi +from io import StringIO -# import lenstronomy.Util.param_util 
as param_util -import yaml import torch from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_AP from astropy.cosmology import default_cosmology @@ -13,6 +11,7 @@ from caustics.cosmology import FlatLambdaCDM as CausticFlatLambdaCDM from caustics.lenses import TNFW +from caustics.sims import build_simulator import pytest @@ -26,7 +25,7 @@ ) # Note with m=1e14 the test fails, due to the Rs_angle becoming too large (pytorch is unstable) @pytest.mark.parametrize("c", [1.0, 8.0, 40.0]) @pytest.mark.parametrize("t", [2.0, 5.0, 20.0]) -def test(sim_source, device, lens_models, m, c, t): +def test(sim_source, device, m, c, t): atol = 1e-5 rtol = 3e-2 z_l = torch.tensor(0.1) @@ -39,15 +38,13 @@ def test(sim_source, device, lens_models, m, c, t): lens: &lens name: tnfw kind: TNFW - params: - z_l: {float(z_l)} init_kwargs: + z_l: {float(z_l)} cosmology: *cosmology interpret_m_total_mass: false """ - yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) - mod = lens_models.get("TNFW") - lens = mod(**yaml_dict["lens"]).model_obj() + with StringIO(yaml_str) as f: + lens = build_simulator(f) else: # Models cosmology = CausticFlatLambdaCDM(name="cosmo") diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index b9da902f..e87171ff 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -4,180 +4,68 @@ from typing import Any, Dict, List, Union -import torch import numpy as np from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_AP from lenstronomy.Data.pixel_grid import PixelGrid from lenstronomy.LensModel.lens_model import LensModel -from caustics.lenses import ThinLens, EPL, NFW, ThickLens, PixelatedConvergence -from caustics.light import Sersic, Pixelated -from caustics.utils import meshgrid -from caustics.sims import Simulator -from caustics.cosmology import FlatLambdaCDM -from .models import mock_from_file +import caustics + +import pytest +import textwrap __all__ = ("mock_from_file",) -def setup_simulator( - cosmo_static=False, use_nfw=True, simulator_static=False, batched_params=False, device=None -): - n_pix = 20 - - class Sim(Simulator): - def __init__(self, name="simulator"): - super().__init__(name) - if simulator_static: - self.add_param("z_s", 1.0) - else: - self.add_param("z_s", None) - z_l = 0.5 - self.cosmo = FlatLambdaCDM(h0=0.7 if cosmo_static else None, name="cosmo") - if use_nfw: - self.lens = NFW( - self.cosmo, z_l=z_l, name="lens" - ) # NFW wactually depend on cosmology, so a better test for Parametrized - else: - self.lens = EPL(self.cosmo, z_l=z_l, name="lens") - self.sersic = Sersic(name="source") - self.thx, self.thy = meshgrid(0.04, n_pix, device=device) - self.n_pix = n_pix - self.to(device=device) - - def forward(self, params): - (z_s,) = self.unpack(params) - alphax, alphay = self.lens.reduced_deflection_angle( - x=self.thx, y=self.thy, z_s=z_s, params=params - ) - bx = self.thx - alphax - by = self.thy - alphay - return self.sersic.brightness(bx, by, params) - - # default simulator params - z_s = torch.tensor([1.0, 1.5]) - sim_params = [z_s] - # default cosmo params - h0 = torch.tensor([0.68, 0.75]) - cosmo_params = [h0] - # default lens params - if use_nfw: - x0 = torch.tensor([0.0, 0.1]) - y0 = torch.tensor([0.0, 0.1]) - m = torch.tensor([1e12, 1e13]) - c = torch.tensor([10, 5]) - lens_params = [x0, y0, m, c] - else: - x0 = torch.tensor([0, 0.1]) - y0 = torch.tensor([0, 0.1]) - q = torch.tensor([0.9, 0.8]) - phi = torch.tensor([-0.56, 0.8]) - b = torch.tensor([1.5, 1.2]) - t = torch.tensor([1.2, 1.0]) - lens_params = [x0, y0, q, phi, b, 
t] - # default source params - x0s = torch.tensor([0, 0.1]) - y0s = torch.tensor([0, 0.1]) - qs = torch.tensor([0.9, 0.8]) - phis = torch.tensor([-0.56, 0.8]) - n = torch.tensor([1.0, 4.0]) - Re = torch.tensor([0.2, 0.5]) - Ie = torch.tensor([1.2, 10.0]) - source_params = [x0s, y0s, qs, phis, n, Re, Ie] - - if not batched_params: - sim_params = [_x[0] for _x in sim_params] - cosmo_params = [_x[0] for _x in cosmo_params] - lens_params = [_x[0] for _x in lens_params] - source_params = [_x[0] for _x in source_params] - - sim = Sim() - # Set device when not None - if device is not None: - sim = sim.to(device=device) - sim_params = [_p.to(device=device) for _p in sim_params] - cosmo_params = [_p.to(device=device) for _p in cosmo_params] - lens_params = [_p.to(device=device) for _p in lens_params] - source_params = [_p.to(device=device) for _p in source_params] - - return sim, (sim_params, cosmo_params, lens_params, source_params) - - -def setup_image_simulator(cosmo_static=False, batched_params=False, device=None): - n_pix = 20 - - class Sim(Simulator): - def __init__(self, name="test"): - super().__init__(name) - pixel_scale = 0.04 - z_l = 0.5 - self.z_s = torch.tensor(1.0) - self.cosmo = FlatLambdaCDM(h0=0.7 if cosmo_static else None, name="cosmo") - self.epl = EPL(self.cosmo, z_l=z_l, name="lens") - self.kappa = PixelatedConvergence( - pixel_scale, - self.cosmo, - z_l=z_l, - shape=(n_pix, n_pix), - name="kappa", - ) - self.source = Pixelated( - x0=0.0, - y0=0.0, - pixelscale=pixel_scale / 2, - shape=(n_pix, n_pix), - name="source", - ) - self.thx, self.thy = meshgrid(pixel_scale, n_pix, device=device) - self.n_pix = n_pix - self.to(device=device) - - def forward(self, params): - alphax, alphay = self.epl.reduced_deflection_angle( - x=self.thx, y=self.thy, z_s=self.z_s, params=params - ) - alphax_h, alphay_h = self.kappa.reduced_deflection_angle( - x=self.thx, y=self.thy, z_s=self.z_s, params=params - ) - bx = self.thx - alphax - alphax_h - by = self.thy - alphay - alphay_h - return self.source.brightness(bx, by, params) - - # default cosmo params - h0 = torch.tensor([0.68, 0.75]) - # default lens params - x0 = torch.tensor([0, 0.1]) - y0 = torch.tensor([0, 0.1]) - q = torch.tensor([0.9, 0.8]) - phi = torch.tensor([-0.56, 0.8]) - b = torch.tensor([1.5, 1.2]) - t = torch.tensor([1.2, 1.0]) - # default kappa params - kappa = torch.randn([2, n_pix, n_pix]) - source = torch.randn([2, n_pix, n_pix]) - - cosmo_params = [h0] - lens_params = [x0, y0, q, phi, b, t] - if not batched_params: - cosmo_params = [_x[0] for _x in cosmo_params] - lens_params = [_x[0] for _x in lens_params] - kappa = kappa[0] - source = source[0] - - sim = Sim() - # Set device when not None - if device is not None: - sim = sim.to(device=device) - cosmo_params = [_p.to(device=device) for _p in cosmo_params] - lens_params = [_p.to(device=device) for _p in lens_params] - kappa = kappa.to(device=device) - source = source.to(device=device) - - return sim, (cosmo_params, lens_params, [kappa], [source]) +@pytest.fixture +def sim_yaml(): + return textwrap.dedent( + """\ + cosmology: &cosmo + name: cosmo + kind: FlatLambdaCDM + + lens: &lens + name: lens + kind: SIE + init_kwargs: + cosmology: *cosmo + + src: &src + name: source + kind: Sersic + + lnslt: &lnslt + name: lenslight + kind: Sersic + + simulator: + name: minisim + kind: LensSource + init_kwargs: + # Single lense + lens: *lens + source: *src + lens_light: *lnslt + pixelscale: 0.05 + pixels_x: 100 + """ + ) + + +@pytest.fixture +def sim_obj(): + cosmology = 
caustics.FlatLambdaCDM() + sie = caustics.SIE(cosmology=cosmology, name="lens") + src = caustics.Sersic(name="source") + lnslt = caustics.Sersic(name="lenslight") + return caustics.LensSource( + lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 + ) def get_default_cosmologies(device=None): - cosmology = FlatLambdaCDM("cosmo") + cosmology = caustics.FlatLambdaCDM("cosmo") cosmology_ap = FlatLambdaCDM_AP(100 * cosmology.h0.value, cosmology.Om0.value, Tcmb0=0) if device is not None: @@ -187,7 +75,7 @@ def get_default_cosmologies(device=None): def setup_grids(res=0.05, n_pix=100, device=None): # Caustics setup - thx, thy = meshgrid(res, n_pix, device=device) + thx, thy = caustics.utils.meshgrid(res, n_pix, device=device) if device is not None: thx = thx.to(device=device) thy = thy.to(device=device) @@ -236,20 +124,27 @@ def kappa_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=None) def shear_test_helper( lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, just_egregious=False, device=None ): - thx, thy, thx_ls, thy_ls = setup_grids(device=device) - gamma1, gamma2 = lens.shear(thx, thy, z_s, x) + thx, thy, thx_ls, thy_ls = setup_grids(device=device, n_pix=1000 if just_egregious else 100) + gamma1, gamma2 = lens.shear( + thx, + thy, + z_s, + x, + method="finitediff" if just_egregious else "autograd", + pixelscale=thx[0][1] - thx[0][0], + ) gamma1_ls, gamma2_ls = lens_ls.gamma(thx_ls, thy_ls, kwargs_ls) - if just_egregious: + if just_egregious: # only for NFW and TNFW, this needs more attention print(np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1)) - assert np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1) < 1000 - assert np.sum(np.abs(np.log10(np.abs(1 - gamma2.cpu().numpy() / gamma2_ls))) < 1) < 1000 + assert np.sum(np.abs(np.log10(np.abs(1 - gamma1.cpu().numpy() / gamma1_ls))) < 1) < 100000 + assert np.sum(np.abs(np.log10(np.abs(1 - gamma2.cpu().numpy() / gamma2_ls))) < 1) < 100000 else: assert np.allclose(gamma1.cpu().numpy(), gamma1_ls, rtol, atol) assert np.allclose(gamma2.cpu().numpy(), gamma2_ls, rtol, atol) def lens_test_helper( - lens: Union[ThinLens, ThickLens], + lens: Union[caustics.ThinLens, caustics.ThickLens], lens_ls: LensModel, z_s, x, diff --git a/tests/utils/models.py b/tests/utils/models.py deleted file mode 100644 index 14f90590..00000000 --- a/tests/utils/models.py +++ /dev/null @@ -1,177 +0,0 @@ -import yaml -import torch -import numpy as np - -import caustics - - -def mock_from_file(mocker, yaml_str): - # Mock the from_file function - # this way, we don't need to use a real file - mocker.patch("caustics.models.api.from_file", return_value=yaml_str.encode("utf-8")) - - -def obj_to_yaml(obj_dict: dict): - yaml_string = yaml.safe_dump(obj_dict, sort_keys=False) - string_list = yaml_string.split("\n") - id_str = string_list[0] + f" &{string_list[0]}".strip(":") - string_list[0] = id_str - return "\n".join(string_list).replace("'", "") - - -def setup_complex_multiplane_yaml(): - # initialization stuff for lenses - cosmology = caustics.FlatLambdaCDM(name="cosmo") - cosmo = { - cosmology.name: { - "name": cosmology.name, - "kind": cosmology.__class__.__name__, - } - } - cosmology.to(dtype=torch.float32) - n_pix = 100 - res = 0.05 - upsample_factor = 2 - fov = res * n_pix - thx, thy = caustics.utils.meshgrid( - res / upsample_factor, - upsample_factor * n_pix, - dtype=torch.float32, - ) - z_s = torch.tensor(1.5, dtype=torch.float32) - all_lenses = [] - all_single_planes = [] - - N_planes = 10 - 
N_lenses = 2 # per plane - - z_plane = np.linspace(0.1, 1.0, N_planes) - planes = [] - - for p, z_p in enumerate(z_plane): - lenses = [] - lens_keys = [] - - if p == N_planes // 2: - lens = caustics.NFW( - cosmology=cosmology, - z_l=z_p, - x0=torch.tensor(0.0), - y0=torch.tensor(0.0), - m=torch.tensor(10**11), - c=torch.tensor(10.0), - s=torch.tensor(0.001), - ) - lenses.append(lens) - all_lenses.append( - { - lens.name: { - "name": lens.name, - "kind": lens.__class__.__name__, - "params": { - k: float(v.value) - for k, v in lens.module_params.static.items() - }, - "init_kwargs": {"cosmology": f"*{cosmology.name}"}, - } - } - ) - lens_keys.append(f"*{lens.name}") - else: - for _ in range(N_lenses): - lens = caustics.NFW( - cosmology=cosmology, - z_l=z_p, - x0=torch.tensor(np.random.uniform(-fov / 2.0, fov / 2.0)), - y0=torch.tensor(np.random.uniform(-fov / 2.0, fov / 2.0)), - m=torch.tensor(10 ** np.random.uniform(8, 9)), - c=torch.tensor(np.random.uniform(4, 40)), - s=torch.tensor(0.001), - ) - lenses.append(lens) - all_lenses.append( - { - lens.name: { - "name": lens.name, - "kind": lens.__class__.__name__, - "params": { - k: float(v.value) - for k, v in lens.module_params.static.items() - }, - "init_kwargs": {"cosmology": f"*{cosmology.name}"}, - } - } - ) - lens_keys.append(f"*{lens.name}") - - single_plane = caustics.lenses.SinglePlane( - z_l=z_p, cosmology=cosmology, lenses=lenses, name=f"plane_{p}" - ) - planes.append(single_plane) - all_single_planes.append( - { - single_plane.name: { - "name": single_plane.name, - "kind": single_plane.__class__.__name__, - "params": { - k: float(v.value) - for k, v in single_plane.module_params.static.items() - }, - "init_kwargs": { - "lenses": lens_keys, - "cosmology": f"*{cosmology.name}", - }, - } - } - ) - - lens = caustics.lenses.Multiplane( - name="multiplane", cosmology=cosmology, lenses=planes - ) - multi_dict = { - lens.name: { - "name": lens.name, - "kind": lens.__class__.__name__, - "init_kwargs": { - "lenses": [f"*{p.name}" for p in planes], - "cosmology": f"*{cosmology.name}", - }, - } - } - lenses_yaml = ( - [obj_to_yaml(cosmo)] - + [obj_to_yaml(lens) for lens in all_lenses] - + [obj_to_yaml(plane) for plane in all_single_planes] - + [obj_to_yaml(multi_dict)] - ) - - source_yaml = obj_to_yaml({ - "source": { - "name": "source", - "kind": "Sersic", - } - }) - - lenslight_yaml = obj_to_yaml({ - "lnslight": { - "name": "lnslight", - "kind": "Sersic", - } - }) - - sim_yaml = obj_to_yaml({ - "simulator": { - "name": "sim", - "kind": "LensSource", - "init_kwargs": { - "lens": f"*{lens.name}", - "source": "*source", - "lens_light": "*lnslight", - "pixelscale": 0.05, - "pixels_x": 100, - } - } - }) - - all_yaml_list = lenses_yaml + [source_yaml, lenslight_yaml, sim_yaml] - return "\n".join(all_yaml_list) From 50b0c840a64daf3c83b70d2db351dbc71ac21d0b Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 21 Nov 2024 13:42:25 -0500 Subject: [PATCH 5/9] docs: Minor corrections --- .../source/examples/Example_ImageFit_LM.ipynb | 2 +- docs/source/frequently_asked_questions.rst | 33 ++++++++++++++----- docs/source/tutorials/LensZoo.ipynb | 2 +- src/caustics/utils.py | 1 + 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/docs/source/examples/Example_ImageFit_LM.ipynb b/docs/source/examples/Example_ImageFit_LM.ipynb index 717b5029..377f981d 100644 --- a/docs/source/examples/Example_ImageFit_LM.ipynb +++ b/docs/source/examples/Example_ImageFit_LM.ipynb @@ -234,7 +234,7 @@ "batch_inits = batch_inits.to(dtype=torch.float32)\n", "res = 
caustics.utils.batch_lm(\n", " batch_inits,\n", - " obs_system.reshape(-1).repeat(10, 1),\n", + " obs_system.reshape(-1).repeat(10, 1).to(dtype=torch.float32),\n", " lambda x: sim(x).reshape(-1),\n", " C=variance.reshape(-1).repeat(10, 1),\n", ")\n", diff --git a/docs/source/frequently_asked_questions.rst b/docs/source/frequently_asked_questions.rst index cf69bcca..870ff8a9 100644 --- a/docs/source/frequently_asked_questions.rst +++ b/docs/source/frequently_asked_questions.rst @@ -1,10 +1,27 @@ FAQs - Frequently asked questions ================================= -| **Q:** How do I know what order to put the parameter values in the pytorch tensor which gets passed to the simulator? -| **A:** If you are using a simulator, then you can get the parameters using ``your_simulator.state_dict()``. The parameters whose values are set dynamically will say "None", while the static parameters will have their values shown. The order of the dynamical parameters corresponds to the order you should use in your parameter value tensor. -| -| **Q:** Why can I put the lens redshift at higher values than the source redshift or to negative values for some parametric models? -| **A:** We can calculate everything for those profiles with reduced deflection angles where the redshifts do not actually play into the calculation. If you use a profile defined by the lens mass, like a NFW lens, or a Multiplane lens then it does matter that the redshifts make sense and you will very likely get errors for those. Similarly, if you call the ``lens.physical_deflection_angle`` you will encounter errors. -| -| **Q:** I do (multiplane-)lensing with pixelated convergence using the pixelated kappa map of a parametric profile. The lensing effect differs from directly using the parametric lens. Why is the lensing effect different? -| **A:** Since you do pixelated convergence your mass is binned in pixels in a finite field of view (FOV) so you are missing some mass. At the limit of infinite resolution and infinite FOV the pixelated profile gives you the parametric profile. If the difference is above your error tolerance then you have to increase the resolution and/or FOV of your pixelated convergence map. Especially for SIE or EPL profiles (which go to infinity density in the center, and have infinite mass outside any FOV) you will miss infinite mass when pixelating. +| **Q:** How do I know what order to put the parameter values in the pytorch + tensor which gets passed to the simulator? +| **A:** If you are using any ``Module``` object (so a simulator), then you can +get the parameters using ``print(simulator)``. The order of the dynamical +parameters (top to bottom) corresponds to the order you should use in your +parameter value tensor. Note that you can ignore the static parameters. | +| **Q:** Why can I put the lens redshift at higher values than the source + redshift or to negative values for some parametric models? +| **A:** We can calculate everything for those profiles with reduced deflection +angles where the redshifts do not actually play into the calculation. If you use +a profile defined by the lens mass, like a NFW lens, or a Multiplane lens then +it does matter that the redshifts make sense and you will very likely get errors +for those. Similarly, if you call the ``lens.physical_deflection_angle`` you +will encounter errors. | +| **Q:** I do (multiplane-)lensing with pixelated convergence using the + pixelated kappa map of a parametric profile. The lensing effect differs from + directly using the parametric lens. 
Why is the lensing effect different? +| **A:** Since you do pixelated convergence your mass is binned in pixels in a + finite field of view (FOV) so you are missing some mass. At the limit of + infinite resolution and infinite FOV the pixelated profile gives you the + parametric profile. If the difference is above your error tolerance then you + have to increase the resolution and/or FOV of your pixelated convergence map. + Especially for SIE or EPL profiles (which go to infinity density in the + center, and have infinite mass outside any FOV) you will miss infinite mass + when pixelating. | diff --git a/docs/source/tutorials/LensZoo.ipynb b/docs/source/tutorials/LensZoo.ipynb index 6afdc588..2a6ecff1 100644 --- a/docs/source/tutorials/LensZoo.ipynb +++ b/docs/source/tutorials/LensZoo.ipynb @@ -565,7 +565,7 @@ " cosmology=cosmology,\n", " x0=0.0,\n", " y0=0.0,\n", - " sd=1.5,\n", + " kappa=1.5,\n", " z_l=z_l,\n", ")\n", "sim = caustics.LensSource(\n", diff --git a/src/caustics/utils.py b/src/caustics/utils.py index c0efde48..3f1ea111 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -1201,6 +1201,7 @@ def batch_lm( Cinv = 1 / C else: Cinv = torch.linalg.inv(C) + Cinv = Cinv.to(dtype=X.dtype) v_lm_step = torch.vmap( partial( From 6b286516b5ff407dcefb79c8bc5311d2e3987e15 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 21 Nov 2024 14:07:19 -0500 Subject: [PATCH 6/9] docs: fix LM example --- docs/source/examples/Example_ImageFit_LM.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/examples/Example_ImageFit_LM.ipynb b/docs/source/examples/Example_ImageFit_LM.ipynb index 377f981d..3a497672 100644 --- a/docs/source/examples/Example_ImageFit_LM.ipynb +++ b/docs/source/examples/Example_ImageFit_LM.ipynb @@ -423,7 +423,7 @@ "source": [ "J = J.reshape(-1, len(best_fit))\n", "# Compute Hessian\n", - "H = J.T @ (J / variance.reshape(-1, 1))\n", + "H = J.T @ (J / variance.reshape(-1, 1).to(dtype=torch.float32))\n", "# Compute covariance matrix\n", "C = torch.linalg.inv(H)\n", "plt.imshow(np.log10(np.abs(C.detach().cpu().numpy())))\n", From 28d6a659e7a4b5cc0471c350e5c4edc207d3b36d Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 21 Nov 2024 14:08:29 -0500 Subject: [PATCH 7/9] docs: fix faq --- docs/source/frequently_asked_questions.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/frequently_asked_questions.rst b/docs/source/frequently_asked_questions.rst index 870ff8a9..2cc82761 100644 --- a/docs/source/frequently_asked_questions.rst +++ b/docs/source/frequently_asked_questions.rst @@ -5,7 +5,7 @@ FAQs - Frequently asked questions | **A:** If you are using any ``Module``` object (so a simulator), then you can get the parameters using ``print(simulator)``. The order of the dynamical parameters (top to bottom) corresponds to the order you should use in your -parameter value tensor. Note that you can ignore the static parameters. | +parameter value tensor. Note that you can ignore the static parameters. | **Q:** Why can I put the lens redshift at higher values than the source redshift or to negative values for some parametric models? | **A:** We can calculate everything for those profiles with reduced deflection @@ -13,7 +13,7 @@ angles where the redshifts do not actually play into the calculation. If you use a profile defined by the lens mass, like a NFW lens, or a Multiplane lens then it does matter that the redshifts make sense and you will very likely get errors for those. 
Similarly, if you call the ``lens.physical_deflection_angle`` you -will encounter errors. | +will encounter errors. | **Q:** I do (multiplane-)lensing with pixelated convergence using the pixelated kappa map of a parametric profile. The lensing effect differs from directly using the parametric lens. Why is the lensing effect different? @@ -24,4 +24,4 @@ will encounter errors. | have to increase the resolution and/or FOV of your pixelated convergence map. Especially for SIE or EPL profiles (which go to infinity density in the center, and have infinite mass outside any FOV) you will miss infinite mass - when pixelating. | + when pixelating. From 34b75e8d28abc38879b7612916a51fabd3ab5525 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 21 Nov 2024 14:31:13 -0500 Subject: [PATCH 8/9] docs: fix faq again --- docs/source/frequently_asked_questions.rst | 39 ++++++++++++---------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/docs/source/frequently_asked_questions.rst b/docs/source/frequently_asked_questions.rst index 2cc82761..7c247205 100644 --- a/docs/source/frequently_asked_questions.rst +++ b/docs/source/frequently_asked_questions.rst @@ -1,27 +1,32 @@ FAQs - Frequently asked questions ================================= -| **Q:** How do I know what order to put the parameter values in the pytorch - tensor which gets passed to the simulator? -| **A:** If you are using any ``Module``` object (so a simulator), then you can + +**Q:** How do I know what order to put the parameter values in the pytorch tensor which gets passed to the simulator? +----------- + +**A:** If you are using any ``Module``` object (so a simulator), then you can get the parameters using ``print(simulator)``. The order of the dynamical parameters (top to bottom) corresponds to the order you should use in your parameter value tensor. Note that you can ignore the static parameters. -| **Q:** Why can I put the lens redshift at higher values than the source - redshift or to negative values for some parametric models? -| **A:** We can calculate everything for those profiles with reduced deflection + +**Q:** Why can I put the lens redshift at higher values than the source redshift or to negative values for some parametric models? +----------- + +**A:** We can calculate everything for those profiles with reduced deflection angles where the redshifts do not actually play into the calculation. If you use a profile defined by the lens mass, like a NFW lens, or a Multiplane lens then it does matter that the redshifts make sense and you will very likely get errors for those. Similarly, if you call the ``lens.physical_deflection_angle`` you will encounter errors. -| **Q:** I do (multiplane-)lensing with pixelated convergence using the - pixelated kappa map of a parametric profile. The lensing effect differs from - directly using the parametric lens. Why is the lensing effect different? -| **A:** Since you do pixelated convergence your mass is binned in pixels in a - finite field of view (FOV) so you are missing some mass. At the limit of - infinite resolution and infinite FOV the pixelated profile gives you the - parametric profile. If the difference is above your error tolerance then you - have to increase the resolution and/or FOV of your pixelated convergence map. - Especially for SIE or EPL profiles (which go to infinity density in the - center, and have infinite mass outside any FOV) you will miss infinite mass - when pixelating. 
+ +**Q:** I do (multiplane-)lensing with pixelated convergence using the pixelated kappa map of a parametric profile. The lensing effect differs from directly using the parametric lens. Why is the lensing effect different? +----------- + +**A:** Since you do pixelated convergence your mass is binned in pixels in a +finite field of view (FOV) so you are missing some mass. At the limit of +infinite resolution and infinite FOV the pixelated profile gives you the +parametric profile. If the difference is above your error tolerance then you +have to increase the resolution and/or FOV of your pixelated convergence map. +Especially for SIE or EPL profiles (which go to infinity density in the +center, and have infinite mass outside any FOV) you will miss infinite mass +when pixelating. From 7e5782a37cc61391e12f9e4b8e57f2521d97d51c Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 21 Nov 2024 16:34:18 -0500 Subject: [PATCH 9/9] fix citation.cff title --- CITATION.cff | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CITATION.cff b/CITATION.cff index 7212cf8a..ea2e174b 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -37,7 +37,7 @@ authors: - family-names: "Hezaveh" given-names: "Yashar" orcid: "https://orcid.org/0000-0002-8669-5733" -title: "caustics" +title: "Caustics: A Python Package for Accelerated Strong Gravitational Lensing Simulations" doi: 10.5281/zenodo.10806382 abstract: "The lensing pipeline of the future: GPU-accelerated, automatically-differentiable, highly modular. Currently under heavy development." repository-code: "https://github.com/Ciela-Institute/caustics"