Commit
add continuous-time learning to sac.py
lenarttreven committed May 3, 2024
1 parent 5bae48e commit f3827e7
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion mbpo/optimizers/policy_optimizers/sac/sac.py
@@ -91,6 +91,11 @@ def __init__(self,
                  eval_environment: envs.Env | None = None,
                  episode_length_eval: int | None = None,
                  eval_key_fixed: bool = False,
+                 non_equidistant_time: bool = False,
+                 continuous_discounting: float = 0,
+                 min_time_between_switches: float = 0,
+                 max_time_between_switches: float = 0,
+                 env_dt: float = 0,
                  ):
         if min_replay_size >= num_timesteps:
             raise ValueError(
@@ -201,7 +206,13 @@ def normalize_fn(batch: PyTree, _: PyTree) -> PyTree:
 
         # Setup optimization
         self.losses = SACLosses(sac_network=self.sac_networks_model.get_sac_networks(), reward_scaling=reward_scaling,
-                                discounting=discounting, u_dim=self.u_dim, target_entropy=self.target_entropy)
+                                discounting=discounting, u_dim=self.u_dim, target_entropy=self.target_entropy,
+                                non_equidistant_time=non_equidistant_time,
+                                continuous_discounting=continuous_discounting,
+                                min_time_between_switches=min_time_between_switches,
+                                max_time_between_switches=max_time_between_switches,
+                                env_dt=env_dt,
+                                )
 
         self.alpha_update = gradient_update_fn(
             self.losses.alpha_loss, self.alpha_optimizer, pmap_axis_name=self._PMAP_AXIS_NAME)
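
Below is a minimal sketch of how the new keyword arguments introduced by this commit could be supplied when constructing the optimizer. Only the five keyword names (non_equidistant_time, continuous_discounting, min_time_between_switches, max_time_between_switches, env_dt) come from the diff above; the concrete values and the surrounding construction are illustrative assumptions, not confirmed by this change.

# Hypothetical usage sketch (Python). Keyword names match the diff; the
# values and how the SAC optimizer is constructed elsewhere are assumptions.
continuous_time_kwargs = dict(
    non_equidistant_time=True,       # train on transitions with varying time gaps
    continuous_discounting=0.1,      # discount rate per unit of elapsed time (assumed semantics)
    min_time_between_switches=0.05,  # shortest duration a chosen control is held
    max_time_between_switches=1.0,   # longest duration a chosen control is held
    env_dt=0.05,                     # base integration step of the environment
)

# These would be forwarded to the constructor changed in this commit, e.g.
# SAC(..., **continuous_time_kwargs), alongside its usual arguments such as
# num_timesteps and min_replay_size (both visible in the surrounding context).

Under this reading, one common convention is to discount a transition spanning Δt seconds by exp(-continuous_discounting * Δt) instead of a fixed per-step factor; whether SACLosses applies exactly this form is an assumption, not something shown in the diff.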
