Skip to content

Commit

Permalink
changed: changed seed argument to rng to conform to SPEC-7
Browse files Browse the repository at this point in the history
  • Loading branch information
Armavica committed Jan 13, 2025
1 parent f3e4ce6 commit 9858d08
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 20 deletions.
20 changes: 11 additions & 9 deletions python/rebop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from collections.abc import Sequence
from typing import TypeAlias

import numpy as np
import xarray as xr

from .rebop import Gillespie, __version__ # type: ignore[attr-defined]

if TYPE_CHECKING:
from collections.abc import Sequence
SeedLike: TypeAlias = int | np.integer | Sequence[int] | np.random.SeedSequence
RNGLike: TypeAlias = np.random.Generator | np.random.BitGenerator

__all__ = ("Gillespie", "__version__")

og_run = Gillespie.run
__all__ = ("Gillespie", "__version__")


def run_xarray( # noqa: PLR0913 too many parameters in function definition
self: Gillespie,
init: dict[str, int],
tmax: float,
nb_steps: int,
seed: int | None = None,
*,
rng: RNGLike | SeedLike | None = None,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> xr.Dataset:
Expand All @@ -29,12 +30,13 @@ def run_xarray( # noqa: PLR0913 too many parameters in function definition
The initial configuration is specified in the dictionary `init`.
Returns an xarray Dataset.
"""
times, result = og_run(
self,
rng_ = np.random.default_rng(rng)
seed = rng_.integers(np.iinfo(np.uint64).max, dtype=np.uint64)
times, result = self._run(
init,
tmax,
nb_steps,
seed,
seed=seed,
sparse=sparse,
var_names=var_names,
)
Expand Down
23 changes: 22 additions & 1 deletion python/rebop/rebop.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from collections.abc import Sequence
from typing import TypeAlias

import numpy as np
import xarray

SeedLike: TypeAlias = int | np.integer | Sequence[int] | np.random.SeedSequence
RNGLike: TypeAlias = np.random.Generator | np.random.BitGenerator

class Gillespie:
"""Reaction system composed of species and reactions."""

Expand All @@ -26,13 +31,29 @@ class Gillespie:
def nb_species(self, /) -> int:
"""Number of species currently in the system."""

def run(
def _run(
self,
init: dict[str, int],
tmax: float,
nb_steps: int,
*,
seed: int | None = None,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> tuple[np.ndarray, dict[str, np.ndarray]]:
"""Run the system until `tmax` with `nb_steps` steps.
The initial configuration is specified in the dictionary `init`.
Returns numpy arrays.
"""

def run(
self,
init: dict[str, int],
tmax: float,
nb_steps: int,
*,
rng: RNGLike | SeedLike | None = None,
sparse: bool = False,
var_names: Sequence[str] | None = None,
) -> xarray.Dataset:
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ impl Gillespie {
/// If `nb_steps` is `0`, then returns all reactions, ending with the first that happens at
/// or after `tmax`.
#[pyo3(signature = (init, tmax, nb_steps, seed=None, sparse=false, var_names=None))]
fn run(
fn _run(
&self,
init: HashMap<String, usize>,
tmax: f64,
Expand Down
18 changes: 9 additions & 9 deletions tests/test_rebop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def sir_model(transmission: float = 1e-4, recovery: float = 0.01) -> rebop.Gille
@pytest.mark.parametrize("seed", [None, *range(10)])
def test_sir(seed: int | None) -> None:
sir = sir_model()
ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=seed)
ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, rng=seed)
assert isinstance(ds, xr.Dataset)
npt.assert_array_equal(ds.time, np.arange(251))
assert all(ds.S >= 0)
Expand All @@ -29,18 +29,18 @@ def test_sir(seed: int | None) -> None:

def test_fixed_seed() -> None:
sir = sir_model()
ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, seed=42)
ds = sir.run({"S": 999, "I": 1}, tmax=250, nb_steps=250, rng=42)

assert ds.S[-1] == 0
assert ds.I[-1] == 166
assert ds.R[-1] == 834
assert ds.I[-1] == 182
assert ds.R[-1] == 818


@pytest.mark.parametrize("seed", range(10))
def test_all_reactions(seed: int) -> None:
tmax = 250
sir = sir_model()
ds = sir.run({"S": 999, "I": 1}, tmax=tmax, nb_steps=0, seed=seed)
ds = sir.run({"S": 999, "I": 1}, tmax=tmax, nb_steps=0, rng=seed)
assert ds.time[0] == 0
assert ds.time[-1] > tmax
assert all(ds.time.diff(dim="time") > 0)
Expand All @@ -58,8 +58,8 @@ def test_dense_vs_sparse() -> None:
tmax = 250
nb_steps = 250
seed = 42
ds_dense = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=False)
ds_sparse = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, sparse=True)
ds_dense = sir.run(init, tmax=tmax, nb_steps=nb_steps, rng=seed, sparse=False)
ds_sparse = sir.run(init, tmax=tmax, nb_steps=nb_steps, rng=seed, sparse=True)
assert (ds_dense == ds_sparse).all()


Expand All @@ -74,9 +74,9 @@ def test_var_names(nb_steps: int) -> None:
tmax = 250
seed = 0

ds_all = sir.run(init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=None)
ds_all = sir.run(init, tmax=tmax, nb_steps=nb_steps, rng=seed, var_names=None)
ds_subset = sir.run(
init, tmax=tmax, nb_steps=nb_steps, seed=seed, var_names=subset_to_save
init, tmax=tmax, nb_steps=nb_steps, rng=seed, var_names=subset_to_save
)

for s in subset_to_save:
Expand Down

0 comments on commit 9858d08

Please sign in to comment.