Skip to content

Commit

Permalink
Merge pull request #2 from lasgroup/main
Browse files Browse the repository at this point in the history
update dev/kiten from main
  • Loading branch information
Bakeey authored Nov 7, 2024
2 parents 92283fb + 58329d7 commit f18419b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 4 additions & 0 deletions mbpo/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def __init__(self, system: System | None = None, key: jr.PRNGKey = jr.PRNGKey(0)
def set_system(self, system: System):
self.system = system

@property
def can_act_in_batches(self):
return True

@abstractmethod
def act(self,
obs: chex.Array,
Expand Down
11 changes: 7 additions & 4 deletions mbpo/optimizers/trajectory_optimizers/icem_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@

from jax import vmap, jit
from jax.nn import relu
from jax.numpy import sqrt, newaxis
from jax.numpy.fft import irfft, rfftfreq
from brax.training.replay_buffers import ReplayBufferState
from jaxtyping import Float, Array, Key, Scalar

from mbpo.optimizers.base_optimizer import BaseOptimizer
from mbpo.systems.base_systems import System, SystemParams
from mbpo.systems.base_systems import System
from mbpo.systems.dynamics.base_dynamics import DynamicsParams
from mbpo.systems.rewards.base_rewards import RewardParams
from mbpo.utils.optimizer_utils import rollout_actions
Expand Down Expand Up @@ -119,7 +118,7 @@ def __init__(self,
else:
self.summarize_cost_samples = jnp.mean

def init(self, key: chex.Array) -> iCemOptimizerState:
def init(self, key: chex.Array, true_buffer_state: ReplayBufferState | None = None) -> iCemOptimizerState:
assert self.system is not None, "iCem optimizer requires system to be defined."
init_key, dummy_buffer_key, key = jax.random.split(key, 3)
system_params = self.system.init_params(init_key)
Expand Down Expand Up @@ -280,6 +279,10 @@ def __init__(self,
def set_system(self, system: System):
super().set_system(system)

@property
def can_act_in_batches(self):
return False

def init(self,
key: chex.PRNGKey,
true_buffer_state = None) -> iCemOptimizerState:
Expand Down

0 comments on commit f18419b

Please sign in to comment.