diff --git a/mbpo/optimizers/policy_optimizers/sac/sac.py b/mbpo/optimizers/policy_optimizers/sac/sac.py
index bef1a7e..c49bed9 100644
--- a/mbpo/optimizers/policy_optimizers/sac/sac.py
+++ b/mbpo/optimizers/policy_optimizers/sac/sac.py
@@ -91,6 +91,11 @@ def __init__(self,
                  eval_environment: envs.Env | None = None,
                  episode_length_eval: int | None = None,
                  eval_key_fixed: bool = False,
+                 non_equidistant_time: bool = False,
+                 continuous_discounting: float = 0,
+                 min_time_between_switches: float = 0,
+                 max_time_between_switches: float = 0,
+                 env_dt: float = 0,
                  ):
         if min_replay_size >= num_timesteps:
             raise ValueError(
@@ -201,7 +206,13 @@ def normalize_fn(batch: PyTree, _: PyTree) -> PyTree:
         # Setup optimization
         self.losses = SACLosses(sac_network=self.sac_networks_model.get_sac_networks(),
                                 reward_scaling=reward_scaling,
-                                discounting=discounting, u_dim=self.u_dim, target_entropy=self.target_entropy)
+                                discounting=discounting, u_dim=self.u_dim, target_entropy=self.target_entropy,
+                                non_equidistant_time=non_equidistant_time,
+                                continuous_discounting=continuous_discounting,
+                                min_time_between_switches=min_time_between_switches,
+                                max_time_between_switches=max_time_between_switches,
+                                env_dt=env_dt,
+                                )
         self.alpha_update = gradient_update_fn(
             self.losses.alpha_loss, self.alpha_optimizer, pmap_axis_name=self._PMAP_AXIS_NAME)
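
For context, a minimal sketch (not taken from the repository) of one way the new arguments could enter a TD target once non_equidistant_time is enabled: transitions are discounted by the actual elapsed time between action switches rather than by a fixed per-step factor. The exponential-discount convention, the dt clipping, and the function names below are assumptions for illustration only; the real SACLosses implementation may differ.

# Illustrative sketch only; not the repository's actual loss code.
import jax.numpy as jnp


def effective_discount(continuous_discounting: float, dt: jnp.ndarray) -> jnp.ndarray:
    # Assumed continuous-time convention: gamma_eff = exp(-continuous_discounting * dt),
    # so equidistant steps with dt == env_dt recover a constant per-step discount.
    return jnp.exp(-continuous_discounting * dt)


def td_target(reward, next_value, continuous_discounting,
              dt, min_time_between_switches, max_time_between_switches):
    # Hypothetical TD(0) target: bound the time between switches, then discount the
    # bootstrapped value by the duration the chosen action was actually applied.
    dt = jnp.clip(dt, min_time_between_switches, max_time_between_switches)
    return reward + effective_discount(continuous_discounting, dt) * next_value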