diff --git a/mbpo/optimizers/base_optimizer.py b/mbpo/optimizers/base_optimizer.py index 00b1a8c..02b511f 100644 --- a/mbpo/optimizers/base_optimizer.py +++ b/mbpo/optimizers/base_optimizer.py @@ -20,6 +20,10 @@ def __init__(self, system: System | None = None, key: jr.PRNGKey = jr.PRNGKey(0) def set_system(self, system: System): self.system = system + @property + def can_act_in_batches(self): + return True + @abstractmethod def act(self, obs: chex.Array, diff --git a/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py b/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py index 57bc0c4..36fa2f7 100755 --- a/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py +++ b/mbpo/optimizers/trajectory_optimizers/icem_optimizer.py @@ -10,12 +10,11 @@ from jax import vmap, jit from jax.nn import relu -from jax.numpy import sqrt, newaxis -from jax.numpy.fft import irfft, rfftfreq +from brax.training.replay_buffers import ReplayBufferState from jaxtyping import Float, Array, Key, Scalar from mbpo.optimizers.base_optimizer import BaseOptimizer -from mbpo.systems.base_systems import System, SystemParams +from mbpo.systems.base_systems import System from mbpo.systems.dynamics.base_dynamics import DynamicsParams from mbpo.systems.rewards.base_rewards import RewardParams from mbpo.utils.optimizer_utils import rollout_actions @@ -119,7 +118,7 @@ def __init__(self, else: self.summarize_cost_samples = jnp.mean - def init(self, key: chex.Array) -> iCemOptimizerState: + def init(self, key: chex.Array, true_buffer_state: ReplayBufferState | None = None) -> iCemOptimizerState: assert self.system is not None, "iCem optimizer requires system to be defined." init_key, dummy_buffer_key, key = jax.random.split(key, 3) system_params = self.system.init_params(init_key) @@ -280,6 +279,10 @@ def __init__(self, def set_system(self, system: System): super().set_system(system) + @property + def can_act_in_batches(self): + return False + def init(self, key: chex.PRNGKey, true_buffer_state = None) -> iCemOptimizerState: