This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Add SAC agent to MBPO #2

Merged · 2 commits · Aug 26, 2020
Showing changes from 1 commit.
66 changes: 31 additions & 35 deletions conf/agent/sac.yaml
@@ -1,41 +1,37 @@
 # @package _global_
 agent:
-  name: agent
-  _target_: agent.agent.SACAgent
-  params:
-    obs_dim: ??? # to be specified later
-    action_dim: ??? # to be specified later
-    action_range: ??? # to be specified later
-    device: ${device}
-    critic_cfg: ${double_q_critic}
-    actor_cfg: ${diag_gaussian_actor}
-    discount: 0.99
-    init_temperature: 0.1
-    alpha_lr: 1e-4
-    alpha_betas: [0.9, 0.999]
-    actor_lr: 1e-4
-    actor_betas: [0.9, 0.999]
-    actor_update_frequency: 1
-    critic_lr: 1e-4
-    critic_betas: [0.9, 0.999]
-    critic_tau: 0.005
-    critic_target_update_frequency: 2
-    batch_size: 1024
-    learnable_temperature: true
+  _target_: pytorch_sac.agent.sac.SACAgent
+  obs_dim: ??? # to be specified later
+  action_dim: ??? # to be specified later
+  action_range: ??? # to be specified later
+  device: ${device}
+  critic_cfg: ${double_q_critic}
+  actor_cfg: ${diag_gaussian_actor}
+  discount: 0.99
+  init_temperature: 0.1
+  alpha_lr: 1e-4
+  alpha_betas: [0.9, 0.999]
+  actor_lr: 1e-4
+  actor_betas: [0.9, 0.999]
+  actor_update_frequency: 1
+  critic_lr: 1e-4
+  critic_betas: [0.9, 0.999]
+  critic_tau: 0.005
+  critic_target_update_frequency: 2
+  batch_size: 1024
+  learnable_temperature: true
 
 double_q_critic:
-  class: agent.critic.DoubleQCritic
-  params:
-    obs_dim: ${agent.params.obs_dim}
-    action_dim: ${agent.params.action_dim}
-    hidden_dim: 1024
-    hidden_depth: 2
+  _target_: pytorch_sac.agent.critic.DoubleQCritic
+  obs_dim: ${agent.obs_dim}
+  action_dim: ${agent.action_dim}
+  hidden_dim: 1024
+  hidden_depth: 2
 
 diag_gaussian_actor:
-  class: agent.actor.DiagGaussianActor
-  params:
-    obs_dim: ${agent.params.obs_dim}
-    action_dim: ${agent.params.action_dim}
-    hidden_depth: 2
-    hidden_dim: 1024
-    log_std_bounds: [-5, 2]
+  _target_: pytorch_sac.agent.actor.DiagGaussianActor
+  obs_dim: ${agent.obs_dim}
+  action_dim: ${agent.action_dim}
+  hidden_depth: 2
+  hidden_dim: 1024
+  log_std_bounds: [-5, 2]
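
Note (context, not part of the change): moving from the `class`/`params` layout to a flat `_target_` node changes how Hydra consumes this file. Every key under `agent` other than `_target_` is forwarded as a keyword argument, which is also why the interpolations drop the `.params` segment (`${agent.obs_dim}` instead of `${agent.params.obs_dim}`). A minimal sketch of the resulting instantiation, assuming the `???` fields are filled in first the way `train()` does below; the helper name is illustrative:

```python
import hydra
import omegaconf


def build_agent(cfg: omegaconf.DictConfig, obs_dim: int, action_dim: int, action_range):
    # Resolve the ??? placeholders before instantiating.
    cfg.agent.obs_dim = obs_dim
    cfg.agent.action_dim = action_dim
    cfg.agent.action_range = action_range
    # Hydra builds pytorch_sac.agent.sac.SACAgent and passes the remaining keys
    # (discount, init_temperature, learning rates, betas, ...) as keyword arguments.
    return hydra.utils.instantiate(cfg.agent)
```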
20 changes: 16 additions & 4 deletions conf/mbpo.yaml
@@ -23,15 +23,27 @@ env: "hopper--stand"
 env_dataset_size: 1000
 validation_ratio: 0.1
 dynamics_model_batch_size: 256
-initial_exploration_steps: 100
+initial_exploration_steps: 20
 num_epochs: 100
 freq_train_dyn_model: 100
 patience: 50
-rollouts_per_env_step: 40
+rollouts_per_step: 40
 rollout_horizon: 15 # TODO replace by thresholded linear
-rollout_batch_size: 512
+rollout_batch_size: 32
 sac_buffer_capacity: ???
+sac_samples_action: true
+num_sac_updates_per_rollout: 100
 
 seed: 0
 
-device: "cuda:0"
+device: "cuda:0"
+
+log_frequency: 100
+log_save_tb: false
+
+
+experiment: test_exp
+
+hydra:
+  run:
+    dir: ./exp/mbrl/${env}/${now:%Y.%m.%d}/${now:%H%M}_${experiment}
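
Note (context, not part of the change): `sac_buffer_capacity` is left as `???` because `train()` in `mbrl/mbpo.py` derives it from the rollout settings above. With these defaults the arithmetic works out to:

```python
# Mirrors the computation in mbrl/mbpo.py:
#   sac_buffer_capacity = cfg.rollouts_per_step * cfg.rollout_horizon * cfg.rollout_batch_size
rollouts_per_step = 40
rollout_horizon = 15
rollout_batch_size = 32

sac_buffer_capacity = rollouts_per_step * rollout_horizon * rollout_batch_size
print(sac_buffer_capacity)  # 19200 model-generated transitions per environment step
```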
147 changes: 91 additions & 56 deletions mbrl/mbpo.py
@@ -1,3 +1,4 @@
+import os
 from typing import Callable, Tuple
 
 import gym
@@ -17,8 +18,9 @@ def collect_random_trajectories(
     env_dataset_test: replay_buffer.IterableReplayBuffer,
     steps_to_collect: int,
     val_ratio: float,
+    rng: np.random.RandomState,
 ):
-    indices = np.random.permutation(steps_to_collect)
+    indices = rng.permutation(steps_to_collect)
     n_train = int(steps_to_collect * (1 - val_ratio))
     indices_train = set(indices[:n_train])
 
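
Note (context, not part of the change): threading a seeded `np.random.RandomState` through `collect_random_trajectories` makes the train/validation split reproducible across runs. A small standalone sketch of the split logic that `rng` now controls (simplified; the real function also steps the environment and routes each transition into the train or validation buffer):

```python
import numpy as np


def split_indices(steps_to_collect: int, val_ratio: float, seed: int = 0):
    rng = np.random.RandomState(seed)
    indices = rng.permutation(steps_to_collect)
    n_train = int(steps_to_collect * (1 - val_ratio))
    return set(indices[:n_train]), set(indices[n_train:])


train_idx, val_idx = split_indices(steps_to_collect=20, val_ratio=0.1)
print(len(train_idx), len(val_idx))  # 18 2
```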
@@ -39,41 +41,32 @@ def collect_random_trajectories(
     return
 
 
-def rollout_model(
-    env: gym.Env,
-    model: models.Model,
+def rollout_model_and_populate_sac_buffer(
+    model_env: models.ModelEnv,
     env_dataset: replay_buffer.BootstrapReplayBuffer,
-    termination_fn: Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],
-    obs_shape: Tuple[int],
-    act_shape: Tuple[int],
-    sac_buffer_capacity: int,
-    num_rollouts: int,
+    agent: pytorch_sac.SACAgent,
+    sac_buffer: pytorch_sac.ReplayBuffer,
+    sac_samples_action: bool,
     rollout_horizon: int,
     batch_size: int,
-    device: torch.device,
-) -> pytorch_sac.ReplayBuffer:
-    model_env = models.ModelEnv(env, model, termination_fn)
-    sac_buffer = pytorch_sac.ReplayBuffer(
-        obs_shape, act_shape, sac_buffer_capacity, device
-    )
-    for _ in range(num_rollouts):
-        initial_obs, action, *_ = env_dataset.sample(batch_size, ensemble=False)
-        obs = model_env.reset(initial_obs_batch=initial_obs)
-        for i in range(rollout_horizon):
-            pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
-            # TODO consider changing sac_buffer to vectorize this loop
-            for j in range(batch_size):
-                sac_buffer.add(
-                    obs[j],
-                    action[j],
-                    pred_rewards[j],
-                    pred_next_obs[j],
-                    pred_dones[j],
-                    pred_dones[j],
-                )
-            obs = pred_next_obs
-
-    return sac_buffer
+):
+
+    initial_obs, action, *_ = env_dataset.sample(batch_size, ensemble=False)
+    obs = model_env.reset(initial_obs_batch=initial_obs)
+    for i in range(rollout_horizon):
+        action = agent.act(obs, sample=sac_samples_action, batched=True)
+        pred_next_obs, pred_rewards, pred_dones, _ = model_env.step(action)
+        # TODO change sac_buffer to vectorize this loop (the batch size will be really large)
+        for j in range(batch_size):
+            sac_buffer.add(
+                obs[j],
+                action[j],
+                pred_rewards[j],
+                pred_next_obs[j],
+                pred_dones[j],
+                pred_dones[j],
+            )
+        obs = pred_next_obs
 
 
 def train(
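
Note on the TODO above (a suggestion, not something this PR implements): the inner `for j in range(batch_size)` loop inserts transitions one at a time, which will get slow as the rollout batch grows. One possible direction is a buffer that accepts whole batches via array slicing; the class below is a hypothetical sketch, not the existing `pytorch_sac.ReplayBuffer` API:

```python
import numpy as np


class BatchReplayBuffer:
    """Hypothetical sketch of a buffer that stores whole batches of rollout data."""

    def __init__(self, capacity: int, obs_dim: int, act_dim: int):
        self.obs = np.empty((capacity, obs_dim), dtype=np.float32)
        self.action = np.empty((capacity, act_dim), dtype=np.float32)
        self.reward = np.empty((capacity, 1), dtype=np.float32)
        self.next_obs = np.empty((capacity, obs_dim), dtype=np.float32)
        self.done = np.empty((capacity, 1), dtype=np.float32)
        self.idx = 0
        self.capacity = capacity

    def add_batch(self, obs, action, reward, next_obs, done):
        # Write the whole batch with slice assignment instead of one add() per row.
        n = obs.shape[0]
        end = self.idx + n
        assert end <= self.capacity, "sketch ignores wrap-around for brevity"
        self.obs[self.idx:end] = obs
        self.action[self.idx:end] = action
        self.reward[self.idx:end] = reward.reshape(-1, 1)
        self.next_obs[self.idx:end] = next_obs
        self.done[self.idx:end] = done.reshape(-1, 1)
        self.idx = end
```

With something along these lines, the inner loop would collapse to a single `add_batch` call per model step.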
@@ -82,13 +75,26 @@ def train(
     device: torch.device,
     cfg: omegaconf.DictConfig,
 ):
+    # ------------------- Initialization -------------------
     obs_shape = env.observation_space.shape
     act_shape = env.action_space.shape
 
-    # Agent
-    # agent = pytorch_sac.SACAgent()
+    cfg.agent.obs_dim = obs_shape[0]
+    cfg.agent.action_dim = act_shape[0]
+    cfg.agent.action_range = [
+        float(env.action_space.low.min()),
+        float(env.action_space.high.max()),
+    ]
+    agent = hydra.utils.instantiate(cfg.agent)
+
+    work_dir = os.getcwd()
+    logger = pytorch_sac.Logger(
+        work_dir, save_tb=cfg.log_save_tb, log_frequency=cfg.log_frequency, agent="sac"
+    )
+
+    rng = np.random.RandomState(cfg.seed)
 
-    # Creating and populating environment dataset
+    # -------------- Create initial env. dataset --------------
     env_dataset_train = replay_buffer.BootstrapReplayBuffer(
         cfg.env_dataset_size,
         cfg.dynamics_model_batch_size,
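
Note (context, not part of the change): `obs_dim`, `action_dim`, and `action_range` are read off the environment's spaces before the agent is instantiated. The same derivation on a stand-in `Box` action space (illustrative values only):

```python
import gym
import numpy as np

action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(3,), dtype=np.float32)
action_range = [float(action_space.low.min()), float(action_space.high.max())]
print(action_space.shape[0], action_range)  # 3 [-2.0, 2.0]
```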
@@ -100,15 +106,18 @@ def train(
     env_dataset_val = replay_buffer.IterableReplayBuffer(
         val_buffer_capacity, cfg.dynamics_model_batch_size, obs_shape, act_shape
     )
+    # TODO replace this with some exploration policy
     collect_random_trajectories(
         env,
         env_dataset_train,
         env_dataset_val,
         cfg.initial_exploration_steps,
         cfg.validation_ratio,
+        rng,
     )
 
-    # Training loop
+    # ---------------------------------------------------------
+    # --------------------- Training Loop ---------------------
     cfg.model.in_size = obs_shape[0] + act_shape[0]
     cfg.model.out_size = obs_shape[0] + 1
 
@@ -117,26 +126,52 @@ def train(
     sac_buffer_capacity = (
         cfg.rollouts_per_step * cfg.rollout_horizon * cfg.rollout_batch_size
     )
 
+    updates_made = 0
+    env_steps = 0
+    model_env = models.ModelEnv(env, ensemble, termination_fn)
     for epoch in range(cfg.num_epochs):
-        if epoch % cfg.freq_train_dyn_model == 0:
-            train_loss, val_score = models.train_dyn_ensemble(
-                ensemble,
-                env_dataset_train,
-                device,
-                dataset_val=env_dataset_val,
-                patience=cfg.patience,
-            )
-        sac_buffer = rollout_model(
-            env,
-            ensemble,
-            env_dataset_train,
-            termination_fn,
-            obs_shape,
-            act_shape,
-            sac_buffer_capacity,
-            cfg.rollouts_per_step,
-            cfg.rollout_horizon,
-            cfg.roullout_batch_size,
-            device,
-        )
+        obs = env.reset()
+        done = False
+        while not done:
+            # --------------- Env. Step and adding to model dataset -----------------
+            action = agent.act(obs)
+            next_obs, reward, done, _ = env.step(action)
+            if rng.random() < cfg.validation_ratio:
+                env_dataset_val.add(obs, action, next_obs, reward, done)
+            else:
+                env_dataset_train.add(obs, action, next_obs, reward, done)
+            obs = next_obs
+
+            # --------------- Model Training -----------------
+            if env_steps % cfg.freq_train_dyn_model == 0:
+                train_loss, val_score = models.train_dyn_ensemble(
+                    ensemble,
+                    env_dataset_train,
+                    device,
+                    dataset_val=env_dataset_val,
+                    patience=cfg.patience,
+                )
+
+            # --------------- Agent Training -----------------
+            sac_buffer = pytorch_sac.ReplayBuffer(
+                obs_shape, act_shape, sac_buffer_capacity, device
+            )
+            for _ in range(cfg.rollouts_per_step):
+                rollout_model_and_populate_sac_buffer(
+                    model_env,
+                    env_dataset_train,
+                    agent,
+                    sac_buffer,
+                    cfg.sac_samples_action,
+                    cfg.rollout_horizon,
+                    cfg.rollout_batch_size,
+                )
+
+            for _ in range(cfg.num_sac_updates_per_rollout):
+                agent.update(sac_buffer, logger, updates_made)
+                updates_made += 1
+
+            logger.dump(updates_made, save=True)
+
+            env_steps += 1
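
Note (context, not part of the change): a rough per-environment-step budget under the default configs, assuming the freshly built SAC buffer fills to capacity, shows each imagined transition being sampled roughly five times before the buffer is rebuilt:

```python
# Values taken from conf/mbpo.yaml and conf/agent/sac.yaml.
sac_buffer_capacity = 40 * 15 * 32   # rollouts_per_step * rollout_horizon * rollout_batch_size
sampled_per_env_step = 100 * 1024    # num_sac_updates_per_rollout * SAC batch_size

print(sac_buffer_capacity)                                   # 19200
print(round(sampled_per_env_step / sac_buffer_capacity, 1))  # 5.3
```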
2 changes: 1 addition & 1 deletion mbrl/models.py
@@ -261,7 +261,7 @@ def step(self, actions: np.ndarray):
         model_in = torch.from_numpy(
             np.concatenate([self._current_obs, actions], axis=1)
         ).to(self.model.device)
-        model_out = self.model(model_in).cpu().numpy()[0]
+        model_out = self.model(model_in)[0].cpu().numpy()
         next_observs = model_out[:, :-1]
         rewards = model_out[:, -1]
         dones = self.termination_fn(actions, next_observs)
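
Note on the one-line fix (a reading of the hunk, not stated in the diff): it is unclear from this change alone whether `self.model(...)` returns a tuple (e.g. mean and log-variance) or a stacked ensemble tensor; either way, indexing element 0 before `.cpu().numpy()` converts only the slice that is actually used. For the stacked-tensor reading the two orderings give the same result:

```python
import torch

# Stand-in for an ensemble forward pass: 5 members, batch of 4, output dim 12.
ensemble_out = torch.randn(5, 4, 12)

old_order = ensemble_out.cpu().numpy()[0]   # convert everything, then index
new_order = ensemble_out[0].cpu().numpy()   # index first, convert only member 0
assert (old_order == new_order).all()
```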