From 9858d085433dc4b3474827280a44e21c22a83157 Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Mon, 13 Jan 2025 22:15:10 +0100 Subject: [PATCH] changed: changed seed argument to rng to conform to SPEC-7 --- python/rebop/__init__.py | 20 +++++++++++--------- python/rebop/rebop.pyi | 23 ++++++++++++++++++++++- src/lib.rs | 2 +- tests/test_rebop.py | 18 +++++++++--------- 4 files changed, 43 insertions(+), 20 deletions(-) diff --git a/python/rebop/__init__.py b/python/rebop/__init__.py index 0db0e33..aa69bf0 100644 --- a/python/rebop/__init__.py +++ b/python/rebop/__init__.py @@ -1,17 +1,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Sequence +from typing import TypeAlias +import numpy as np import xarray as xr from .rebop import Gillespie, __version__ # type: ignore[attr-defined] -if TYPE_CHECKING: - from collections.abc import Sequence +SeedLike: TypeAlias = int | np.integer | Sequence[int] | np.random.SeedSequence +RNGLike: TypeAlias = np.random.Generator | np.random.BitGenerator -__all__ = ("Gillespie", "__version__") -og_run = Gillespie.run +__all__ = ("Gillespie", "__version__") def run_xarray( # noqa: PLR0913 too many parameters in function definition @@ -19,8 +20,8 @@ def run_xarray( # noqa: PLR0913 too many parameters in function definition init: dict[str, int], tmax: float, nb_steps: int, - seed: int | None = None, *, + rng: RNGLike | SeedLike | None = None, sparse: bool = False, var_names: Sequence[str] | None = None, ) -> xr.Dataset: @@ -29,12 +30,13 @@ def run_xarray( # noqa: PLR0913 too many parameters in function definition The initial configuration is specified in the dictionary `init`. Returns an xarray Dataset. """ - times, result = og_run( - self, + rng_ = np.random.default_rng(rng) + seed = rng_.integers(np.iinfo(np.uint64).max, dtype=np.uint64) + times, result = self._run( init, tmax, nb_steps, - seed, + seed=seed, sparse=sparse, var_names=var_names, ) diff --git a/python/rebop/rebop.pyi b/python/rebop/rebop.pyi index da33fa3..38bcb4b 100644 --- a/python/rebop/rebop.pyi +++ b/python/rebop/rebop.pyi @@ -1,7 +1,12 @@ from collections.abc import Sequence +from typing import TypeAlias +import numpy as np import xarray +SeedLike: TypeAlias = int | np.integer | Sequence[int] | np.random.SeedSequence +RNGLike: TypeAlias = np.random.Generator | np.random.BitGenerator + class Gillespie: """Reaction system composed of species and reactions.""" @@ -26,13 +31,29 @@ class Gillespie: def nb_species(self, /) -> int: """Number of species currently in the system.""" - def run( + def _run( self, init: dict[str, int], tmax: float, nb_steps: int, + *, seed: int | None = None, + sparse: bool = False, + var_names: Sequence[str] | None = None, + ) -> tuple[np.ndarray, dict[str, np.ndarray]]: + """Run the system until `tmax` with `nb_steps` steps. + + The initial configuration is specified in the dictionary `init`. + Returns numpy arrays. + """ + + def run( + self, + init: dict[str, int], + tmax: float, + nb_steps: int, *, + rng: RNGLike | SeedLike | None = None, sparse: bool = False, var_names: Sequence[str] | None = None, ) -> xarray.Dataset: diff --git a/src/lib.rs b/src/lib.rs index 8050315..fab2883 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -304,7 +304,7 @@ impl Gillespie { /// If `nb_steps` is `0`, then returns all reactions, ending with the first that happens at /// or after `tmax`. #[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false, var_names=None))] - fn run( + fn _run( &self, init: HashMap, tmax: f64, diff --git a/tests/test_rebop.py b/tests/test_rebop.py index f6853df..ce056d3 100644 --- a/tests/test_rebop.py +++ b/tests/test_rebop.py @@ -15,7 +15,7 @@ def sir_model(transmission: float = 1e-4, recovery: float = 0.01) -> rebop.Gille @pytest.mark.parametrize("seed", [None, *range(10)]) def test_sir(seed: int | None) -> None: sir = sir_model() - ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=seed) + ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, rng=seed) assert isinstance(ds, xr.Dataset) npt.assert_array_equal(ds.time, np.arange(251)) assert all(ds.S >= 0) @@ -29,18 +29,18 @@ def test_sir(seed: int | None) -> None: def test_fixed_seed() -> None: sir = sir_model() - ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=42) + ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, rng=42) assert ds.S[-1] == 0 - assert ds.I[-1] == 166 - assert ds.R[-1] == 834 + assert ds.I[-1] == 182 + assert ds.R[-1] == 818 @pytest.mark.parametrize("seed", range(10)) def test_all_reactions(seed: int) -> None: tmax = 250 sir = sir_model() - ds = sir.run({"S": 999, "I": 1}, tmax=tmax, nb_steps=0, seed=seed) + ds = sir.run({"S": 999, "I": 1}, tmax=tmax, nb_steps=0, rng=seed) assert ds.time[0] == 0 assert ds.time[-1] > tmax assert all(ds.time.diff(dim="time") > 0) @@ -58,8 +58,8 @@ def test_dense_vs_sparse() -> None: tmax = 250 nb_steps = 250 seed = 42 - ds_dense = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=False) - ds_sparse = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=True) + ds_dense = sir.run(init, tmax=tmax, nb_steps=nb_steps, rng=seed, sparse=False) + ds_sparse = sir.run(init, tmax=tmax, nb_steps=nb_steps, rng=seed, sparse=True) assert (ds_dense == ds_sparse).all() @@ -74,9 +74,9 @@ def test_var_names(nb_steps: int) -> None: tmax = 250 seed = 0 - ds_all = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=None) + ds_all = sir.run(init, tmax=tmax, nb_steps=nb_steps, rng=seed, var_names=None) ds_subset = sir.run( - init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=subset_to_save + init, tmax=tmax, nb_steps=nb_steps, rng=seed, var_names=subset_to_save ) for s in subset_to_save: