Merge pull request #2 from fairinternal/add_sac
Add SAC agent to MBPO
luisenp authored Aug 26, 2020
2 parents 4101320 + a29334a commit dc2b62d
Showing 4 changed files with 140 additions and 96 deletions.
67 changes: 32 additions & 35 deletions conf/agent/sac.yaml
@@ -1,41 +1,38 @@
# @package _global_
agent:
name: agent
_target_: agent.agent.SACAgent
params:
obs_dim: ??? # to be specified later
action_dim: ??? # to be specified later
action_range: ??? # to be specified later
device: ${device}
critic_cfg: ${double_q_critic}
actor_cfg: ${diag_gaussian_actor}
discount: 0.99
init_temperature: 0.1
alpha_lr: 1e-4
alpha_betas: [0.9, 0.999]
actor_lr: 1e-4
actor_betas: [0.9, 0.999]
actor_update_frequency: 1
critic_lr: 1e-4
critic_betas: [0.9, 0.999]
critic_tau: 0.005
critic_target_update_frequency: 2
batch_size: 1024
learnable_temperature: true
_target_: pytorch_sac.agent.sac.SACAgent
obs_dim: ??? # to be specified later
action_dim: ??? # to be specified later
action_range: ??? # to be specified later
device: ${device}
critic_cfg: ${double_q_critic}
actor_cfg: ${diag_gaussian_actor}
discount: 0.99
init_temperature: 0.1
alpha_lr: 1e-4
alpha_betas: [0.9, 0.999]
actor_lr: 1e-4
actor_betas: [0.9, 0.999]
actor_update_frequency: 1
critic_lr: 1e-4
critic_betas: [0.9, 0.999]
critic_tau: 0.005
critic_target_update_frequency: 2
batch_size: 1024
learnable_temperature: true
target_entropy: -1

double_q_critic:
class: agent.critic.DoubleQCritic
params:
obs_dim: ${agent.params.obs_dim}
action_dim: ${agent.params.action_dim}
hidden_dim: 1024
hidden_depth: 2
_target_: pytorch_sac.agent.critic.DoubleQCritic
obs_dim: ${agent.obs_dim}
action_dim: ${agent.action_dim}
hidden_dim: 1024
hidden_depth: 2

diag_gaussian_actor:
class: agent.actor.DiagGaussianActor
params:
obs_dim: ${agent.params.obs_dim}
action_dim: ${agent.params.action_dim}
hidden_depth: 2
hidden_dim: 1024
log_std_bounds: [-5, 2]
_target_: pytorch_sac.agent.actor.DiagGaussianActor
obs_dim: ${agent.obs_dim}
action_dim: ${agent.action_dim}
hidden_depth: 2
hidden_dim: 1024
log_std_bounds: [-5, 2]
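
A note on the agent config above: the new version drops the class/params nesting in favor of Hydra's flat _target_ convention, so interpolations read ${agent.obs_dim} instead of ${agent.params.obs_dim}, and every key that sits next to _target_ is passed directly to the constructor by hydra.utils.instantiate. Below is a minimal sketch of that convention, using a made-up TinyAgent class rather than the real pytorch_sac SACAgent:

# Minimal sketch of the _target_ convention (TinyAgent is hypothetical; the keys
# mirror a few fields from the config above).
import hydra
from omegaconf import OmegaConf

class TinyAgent:
    def __init__(self, obs_dim, action_dim, discount):
        self.obs_dim, self.action_dim, self.discount = obs_dim, action_dim, discount

cfg = OmegaConf.create({
    "_target_": "__main__.TinyAgent",   # dotted path to the class to construct
    "obs_dim": 11,
    "action_dim": 3,
    "discount": 0.99,
})
agent = hydra.utils.instantiate(cfg)    # all keys except _target_ become constructor kwargs
print(type(agent).__name__, agent.obs_dim, agent.discount)

This is the pattern mbrl/mbpo.py relies on further down, once it fills in the ??? fields (obs_dim, action_dim, action_range) from the environment and calls hydra.utils.instantiate(cfg.agent).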
20 changes: 16 additions & 4 deletions conf/mbpo.yaml
@@ -23,15 +23,27 @@ env: "hopper--stand"
env_dataset_size: 1000
validation_ratio: 0.1
dynamics_model_batch_size: 256
initial_exploration_steps: 100
initial_exploration_steps: 20
num_epochs: 100
freq_train_dyn_model: 100
patience: 50
rollouts_per_env_step: 40
rollouts_per_step: 40
rollout_horizon: 15 # TODO replace by thresholded linear
rollout_batch_size: 512
rollout_batch_size: 32
sac_buffer_capacity: ???
sac_samples_action: true
num_sac_updates_per_rollout: 100

seed: 0

device: "cuda:0"
device: "cuda:0"

log_frequency: 100
log_save_tb: false


experiment: test_exp

hydra:
run:
dir: ./exp/mbrl/${env}/${now:%Y.%m.%d}/${now:%H%M}_${experiment}
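
A note on mbpo.yaml: sac_buffer_capacity is left as ??? (OmegaConf's mandatory missing-value marker), and in mbrl/mbpo.py below the capacity is computed from rollouts_per_step, rollout_horizon and rollout_batch_size. A small illustrative sketch, assuming the values shown above (not the project's code):

# Sketch of the ??? behaviour and of how the SAC buffer capacity is derived.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "rollouts_per_step": 40,
    "rollout_horizon": 15,
    "rollout_batch_size": 32,
    "sac_buffer_capacity": "???",        # OmegaConf parses "???" as a missing mandatory value
})
print(OmegaConf.is_missing(cfg, "sac_buffer_capacity"))    # True; reading it now would raise
sac_buffer_capacity = cfg.rollouts_per_step * cfg.rollout_horizon * cfg.rollout_batch_size
print(sac_buffer_capacity)                                  # 40 * 15 * 32 = 19200

The hydra.run.dir interpolation at the bottom simply routes each run's outputs to ./exp/mbrl/<env>/<date>/<time>_<experiment>.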
147 changes: 91 additions & 56 deletions mbrl/mbpo.py
@@ -1,3 +1,4 @@
import os
from typing import Callable, Tuple

import gym
@@ -17,8 +18,9 @@ def collect_random_trajectories(
env_dataset_test: replay_buffer.IterableReplayBuffer,
steps_to_collect: int,
val_ratio: float,
rng: np.random.RandomState,
):
indices = np.random.permutation(steps_to_collect)
indices = rng.permutation(steps_to_collect)
n_train = int(steps_to_collect * (1 - val_ratio))
indices_train = set(indices[:n_train])

@@ -39,41 +41,32 @@ def collect_random_trajectories(
return


def rollout_model(
env: gym.Env,
model: models.Model,
def rollout_model_and_populate_sac_buffer(
model_env: models.ModelEnv,
env_dataset: replay_buffer.BootstrapReplayBuffer,
termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
obs_shape: Tuple[int],
act_shape: Tuple[int],
sac_buffer_capacity: int,
num_rollouts: int,
agent: pytorch_sac.SACAgent,
sac_buffer: pytorch_sac.ReplayBuffer,
sac_samples_action: bool,
rollout_horizon: int,
batch_size: int,
device: torch.device,
) -> pytorch_sac.ReplayBuffer:
model_env = models.ModelEnv(env, model, termination_fn)
sac_buffer = pytorch_sac.ReplayBuffer(
obs_shape, act_shape, sac_buffer_capacity, device
)
for _ in range(num_rollouts):
initial_obs, action, *_ = env_dataset.sample(batch_size, ensemble=False)
obs = model_env.reset(initial_obs_batch=initial_obs)
for i in range(rollout_horizon):
pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
# TODO consider changing sac_buffer to vectorize this loop
for j in range(batch_size):
sac_buffer.add(
obs[j],
action[j],
pred_rewards[j],
pred_next_obs[j],
pred_dones[j],
pred_dones[j],
)
obs = pred_next_obs
):

return sac_buffer
initial_obs, action, *_ = env_dataset.sample(batch_size, ensemble=False)
obs = model_env.reset(initial_obs_batch=initial_obs)
for i in range(rollout_horizon):
action = agent.act(obs, sample=sac_samples_action, batched=True)
pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
# TODO change sac_buffer to vectorize this loop (the batch size will be really large)
for j in range(batch_size):
sac_buffer.add(
obs[j],
action[j],
pred_rewards[j],
pred_next_obs[j],
pred_dones[j],
pred_dones[j],
)
obs = pred_next_obs


def train(
@@ -82,13 +75,26 @@ def train(
device: torch.device,
cfg: omegaconf.DictConfig,
):
# ------------------- Initialization -------------------
obs_shape = env.observation_space.shape
act_shape = env.action_space.shape

# Agent
# agent = pytorch_sac.SACAgent()
cfg.agent.obs_dim = obs_shape[0]
cfg.agent.action_dim = act_shape[0]
cfg.agent.action_range = [
float(env.action_space.low.min()),
float(env.action_space.high.max()),
]
agent = hydra.utils.instantiate(cfg.agent)

work_dir = os.getcwd()
logger = pytorch_sac.Logger(
work_dir, save_tb=cfg.log_save_tb, log_frequency=cfg.log_frequency, agent="sac"
)

rng = np.random.RandomState(cfg.seed)

# Creating and populating environment dataset
# -------------- Create initial env. dataset --------------
env_dataset_train = replay_buffer.BootstrapReplayBuffer(
cfg.env_dataset_size,
cfg.dynamics_model_batch_size,
@@ -100,15 +106,18 @@ def train(
env_dataset_val = replay_buffer.IterableReplayBuffer(
val_buffer_capacity, cfg.dynamics_model_batch_size, obs_shape, act_shape
)
# TODO replace this with some exploration policy
collect_random_trajectories(
env,
env_dataset_train,
env_dataset_val,
cfg.initial_exploration_steps,
cfg.validation_ratio,
rng,
)

# Training loop
# ---------------------------------------------------------
# --------------------- Training Loop ---------------------
cfg.model.in_size = obs_shape[0] + act_shape[0]
cfg.model.out_size = obs_shape[0] + 1

@@ -117,26 +126,52 @@ def train(
sac_buffer_capacity = (
cfg.rollouts_per_step * cfg.rollout_horizon * cfg.rollout_batch_size
)

updates_made = 0
env_steps = 0
model_env = models.ModelEnv(env, ensemble, termination_fn)
for epoch in range(cfg.num_epochs):
if epoch % cfg.freq_train_dyn_model == 0:
train_loss, val_score = models.train_dyn_ensemble(
ensemble,
env_dataset_train,
device,
dataset_val=env_dataset_val,
patience=cfg.patience,
obs = env.reset()
done = False
while not done:
# --------------- Env. Step and adding to model dataset -----------------
action = agent.act(obs)
next_obs, reward, done, _ = env.step(action)
if rng.random() < cfg.validation_ratio:
env_dataset_val.add(obs, action, next_obs, reward, done)
else:
env_dataset_train.add(obs, action, next_obs, reward, done)
obs = next_obs

# --------------- Model Training -----------------
if env_steps % cfg.freq_train_dyn_model == 0:
train_loss, val_score = models.train_dyn_ensemble(
ensemble,
env_dataset_train,
device,
dataset_val=env_dataset_val,
patience=cfg.patience,
)

# --------------- Agent Training -----------------
sac_buffer = pytorch_sac.ReplayBuffer(
obs_shape, act_shape, sac_buffer_capacity, device
)
for _ in range(cfg.rollouts_per_step):
rollout_model_and_populate_sac_buffer(
model_env,
env_dataset_train,
agent,
sac_buffer,
cfg.sac_samples_action,
cfg.rollout_horizon,
cfg.rollout_batch_size,
)

for _ in range(cfg.num_sac_updates_per_rollout):
agent.update(sac_buffer, logger, updates_made)
updates_made += 1

logger.dump(updates_made, save=True)

sac_buffer = rollout_model(
env,
ensemble,
env_dataset_train,
termination_fn,
obs_shape,
act_shape,
sac_buffer_capacity,
cfg.rollouts_per_step,
cfg.rollout_horizon,
cfg.roullout_batch_size,
device,
)
env_steps += 1
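
For orientation, the restructured train() above moves from an epoch-level schedule to a per-environment-step one: act in the real environment and grow the model dataset (with a train/validation split), refit the dynamics ensemble every freq_train_dyn_model steps, fill a freshly allocated SAC buffer with short model rollouts started from dataset states, and then run several SAC updates on that imagined data before taking the next real step. The skeleton below is only a structural sketch with stub stand-ins (none of the real mbrl or pytorch_sac calls), meant to show the interleaving:

# Structural sketch of the new MBPO loop; the commented calls name the real functions,
# the executable parts are placeholders.
def mbpo_loop_sketch(num_epochs=2, episode_len=6, freq_train_dyn_model=3,
                     rollouts_per_step=2, num_sac_updates_per_rollout=4):
    env_steps, updates_made = 0, 0
    for _ in range(num_epochs):
        done, step_in_episode = False, 0
        while not done:                                    # one real episode per epoch
            # 1. real env step: agent.act(obs) -> env.step(action) -> add to train/val dataset
            step_in_episode += 1
            done = step_in_episode >= episode_len
            # 2. periodically refit the dynamics ensemble on the env dataset
            if env_steps % freq_train_dyn_model == 0:
                pass                                       # models.train_dyn_ensemble(...)
            # 3. fresh SAC buffer, filled by short model rollouts from dataset states
            sac_buffer = []
            for _ in range(rollouts_per_step):
                sac_buffer.append("imagined transitions")  # rollout_model_and_populate_sac_buffer(...)
            # 4. several SAC updates on the imagined data, then log and count the env step
            for _ in range(num_sac_updates_per_rollout):
                updates_made += 1                          # agent.update(sac_buffer, logger, updates_made)
            env_steps += 1
    return env_steps, updates_made

print(mbpo_loop_sketch())    # (12, 48) with the toy numbers above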
2 changes: 1 addition & 1 deletion mbrl/models.py
@@ -261,7 +261,7 @@ def step(self, actions: np.ndarray):
model_in = torch.from_numpy(
np.concatenate([self._current_obs, actions], axis=1)
).to(self.model.device)
model_out = self.model(model_in).cpu().numpy()[0]
model_out = self.model(model_in)[0].cpu().numpy()
next_observs = model_out[:, :-1]
rewards = model_out[:, -1]
dones = self.termination_fn(actions, next_observs)
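
On the one-line change to mbrl/models.py above: the [0] index now selects from the model's output before the .cpu().numpy() conversion rather than after it. Whether the motivation is correctness (e.g., the model returning more than one output) or simply transferring less data off the device is not visible from this hunk, but the index-then-convert pattern it adopts looks like this (the shapes and the ensemble interpretation are assumptions, not taken from models.py):

# Illustrative only: index on the device, then copy just that slice to NumPy.
import torch

out = torch.randn(5, 256, 12)        # hypothetical (members, batch, obs_dim + 1) model output
first = out[0].cpu().numpy()         # select first, then transfer that slice to host memory
# the previous form, out.cpu().numpy()[0], converts the whole tensor before slicing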
