This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Misc updates to requirements and build files #161

Closed
wants to merge 16 commits
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,6 @@ repos:
hooks:
- id: black
files: 'mbrl'
language_version: python3.7

- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
5 changes: 5 additions & 0 deletions README.md
@@ -179,6 +179,11 @@ installation and are specific to models of type
We are planning to extend this in the future; if you have useful suggestions
don't hesitate to raise an issue or submit a pull request!

## Advanced Examples
MBRL-Lib can be used for many different research projects in model-based RL.
Below are some community-contributed examples:
* [Trajectory-based Dynamics Model](https://arxiv.org/abs/2012.09156) Training [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/natolambert/mbrl-lib-dev/blob/main/notebooks/traj_based_model.ipynb)

## Documentation
Please check out our **[documentation](https://facebookresearch.github.io/mbrl-lib/)**
and don't hesitate to raise issues or contribute if anything is unclear!
6 changes: 3 additions & 3 deletions mbrl/algorithms/mbpo.py
@@ -37,7 +37,6 @@ def rollout_model_and_populate_sac_buffer(
rollout_horizon: int,
batch_size: int,
):

batch = replay_buffer.sample(batch_size)
initial_obs, *_ = cast(mbrl.types.TransitionBatch, batch).astuple()
model_state = model_env.reset(
@@ -68,12 +67,12 @@ def evaluate(
num_episodes: int,
video_recorder: VideoRecorder,
) -> float:
avg_episode_reward = 0
avg_episode_reward = 0.0
for episode in range(num_episodes):
obs = env.reset()
video_recorder.init(enabled=(episode == 0))
done = False
episode_reward = 0
episode_reward = 0.0
while not done:
action = agent.act(obs)
obs, reward, done, _ = env.step(action)
@@ -198,6 +197,7 @@ def train(
obs, done = None, False
for steps_epoch in range(cfg.overrides.epoch_length):
if steps_epoch == 0 or done:
steps_epoch = 0
obs, done = env.reset(), False
# --- Doing env step and adding to model dataset ---
next_obs, reward, done, _ = mbrl.util.common.step_env_and_add_to_buffer(
2 changes: 0 additions & 2 deletions mbrl/diagnostics/control_env.py
@@ -41,7 +41,6 @@ def evaluate_all_action_sequences(
pool: mp.Pool, # type: ignore
current_state: Tuple,
) -> torch.Tensor:

res_objs = [
pool.apply_async(evaluate_sequence_fn, (sequence, current_state)) # type: ignore
for sequence in action_sequences
@@ -148,7 +147,6 @@ def get_random_trajectory(horizon):
with mp.Pool(
processes=args.num_processes, initializer=init, initargs=[args.env, args.seed]
) as pool__:

total_reward__ = 0
frames = []
max_population_size = optimizer_cfg.population_size
2 changes: 1 addition & 1 deletion mbrl/env/cartpole_continuous.py
@@ -10,7 +10,7 @@ class CartPoleEnv(gym.Env):
# This is a continuous version of gym's cartpole environment, with the only difference
# being that valid actions are any numbers in the range [-1, 1], and they are applied as
# a multiplicative factor to the total force.
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50}
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": [50]}

def __init__(self):
self.gravity = 9.8
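For context, here is a minimal interaction sketch with this continuous cartpole (a sketch only; it assumes the pre-0.26 gym reset/step API that this repo pins, and samples the action rather than assuming its shape):

```python
from mbrl.env.cartpole_continuous import CartPoleEnv

env = CartPoleEnv()
obs = env.reset()
done = False
while not done:
    # Any action in [-1, 1]; it scales the total force applied to the cart.
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
```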
1 change: 1 addition & 0 deletions mbrl/planning/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .core import Agent, RandomAgent, complete_agent_cfg, load_agent
from .linear_feedback import PIDAgent
from .trajectory_opt import (
CEMOptimizer,
ICEMOptimizer,
122 changes: 122 additions & 0 deletions mbrl/planning/linear_feedback.py
@@ -0,0 +1,122 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional

import numpy as np

from .core import Agent


class PIDAgent(Agent):
"""
Agent that reacts via an internal set of proportional–integral–derivative controllers.

A broad history of the PID controller can be found here:
https://en.wikipedia.org/wiki/PID_controller.

Args:
k_p (np.ndarray): proportional control coefficients (Nx1)
k_i (np.ndarray): integral control coefficients (Nx1)
k_d (np.ndarray): derivative control coefficients (Nx1)
target (np.ndarray): setpoint (Nx1)
state_mapping (np.ndarray): indices of the state vector to apply the PID control to.
E.g., for a system with states [angle, angle_vel, position, position_vel], a state_mapping
of [1, 3] with two control dimensions applies the PID to the angle_vel and position_vel variables.
batch_dim (int): number of samples to compute actions for simultaneously
"""

def __init__(
self,
k_p: np.ndarray,
k_i: np.ndarray,
k_d: np.ndarray,
target: np.ndarray,
state_mapping: Optional[np.ndarray] = None,
batch_dim: Optional[int] = 1,
):
super().__init__()
assert len(k_p) == len(k_i) == len(k_d) == len(target)
self.n_dof = len(k_p)

# State mapping defaults to first N states
if state_mapping is not None:
assert len(state_mapping) == len(target)
self.state_mapping = state_mapping
else:
self.state_mapping = np.arange(0, self.n_dof)

self.batch_dim = batch_dim

self._prev_error = np.zeros((self.n_dof, self.batch_dim))
self._cum_error = np.zeros((self.n_dof, self.batch_dim))

self.k_p = np.repeat(k_p[:, np.newaxis], self.batch_dim, axis=1)
self.k_i = np.repeat(k_i[:, np.newaxis], self.batch_dim, axis=1)
self.k_d = np.repeat(k_d[:, np.newaxis], self.batch_dim, axis=1)
self.target = np.repeat(target[:, np.newaxis], self.batch_dim, axis=1)

def act(self, obs: np.ndarray, **_kwargs) -> np.ndarray:
"""Issues an action given an observation.

This method computes a one-step action for a given observation or batch of
observations.

Args:
obs (np.ndarray): the observation for which an action is needed, either N x 1 or N x B,
where N is the state dimension and B is the batch size.

Returns:
(np.ndarray): the action output by the PID controllers, of shape n_dof x 1 or n_dof x B.
"""
if obs.ndim == 1:
obs = np.expand_dims(obs, -1)
if len(obs) > self.n_dof:
pos = obs[self.state_mapping]
else:
pos = obs

error = self.target - pos
self._cum_error += error

P_value = np.multiply(self.k_p, error)
I_value = np.multiply(self.k_i, self._cum_error)
D_value = np.multiply(self.k_d, (error - self._prev_error))
self._prev_error = error
action = P_value + I_value + D_value
return action

def reset(self):
"""
Reset internal errors.
"""
self._prev_error = np.zeros((self.n_dof, self.batch_dim))
self._cum_error = np.zeros((self.n_dof, self.batch_dim))

def get_errors(self):
return self._prev_error, self._cum_error

def _get_P(self):
return self.k_p

def _get_I(self):
return self.k_i

def _get_D(self):
return self.k_d

def _get_targets(self):
return self.target

def get_parameters(self):
"""
Returns the parameters of the PID agent concatenated.

Returns:
(np.ndarray): the parameters.
"""
return np.stack(
(self._get_P(), self._get_I(), self._get_D(), self._get_targets())
).flatten()
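For reference, a minimal usage sketch of the `PIDAgent` added above (the gains, setpoint, and observation values below are illustrative, not taken from this PR):

```python
import numpy as np

from mbrl.planning import PIDAgent

# Two controlled dimensions of a state [angle, angle_vel, position, position_vel];
# all gain and target values here are made up for illustration.
agent = PIDAgent(
    k_p=np.array([1.0, 0.5]),
    k_i=np.array([0.01, 0.01]),
    k_d=np.array([0.1, 0.05]),
    target=np.array([0.0, 0.0]),
    state_mapping=np.array([1, 3]),  # act on angle_vel and position_vel
)

obs = np.array([0.1, -0.2, 0.3, 0.05])
action = agent.act(obs)  # shape (2, 1): k_p * e + k_i * sum(e) + k_d * (e - e_prev)

agent.reset()                    # clears the accumulated and previous errors
params = agent.get_parameters()  # gains and targets concatenated into a flat array
```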
2 changes: 1 addition & 1 deletion mbrl/third_party/pytorch_sac_pranz24/main.py
@@ -161,7 +161,7 @@
)

for i_episode in itertools.count(1):
episode_reward = 0
episode_reward = 0.0
episode_steps = 0
done = False
state = env.reset()
1 change: 1 addition & 0 deletions mbrl/third_party/pytorch_sac_pranz24/model.py
@@ -8,6 +8,7 @@
LOG_SIG_MIN = -20
epsilon = 1e-6


# Initialize Policy weights
def weights_init_(m):
if isinstance(m, nn.Linear):
1 change: 0 additions & 1 deletion mbrl/util/common.py
@@ -13,7 +13,6 @@
import mbrl.models
import mbrl.planning
import mbrl.types

from .replay_buffer import (
BootstrapIterator,
ReplayBuffer,
2 changes: 1 addition & 1 deletion mbrl/util/logger.py
@@ -94,7 +94,7 @@ def dump(self, step: int, prefix: str, save: bool = True, color: str = "yellow")
if len(self._meters) == 0:
return
if save:
data = dict([(key, meter.value()) for key, meter in self._meters.items()])
data = {key: meter.value() for key, meter in self._meters.items()}
data["step"] = step
self._dump_to_csv(data)
self._dump_to_console(data, prefix, color)
32 changes: 16 additions & 16 deletions mbrl/util/pybullet.py
@@ -124,11 +124,11 @@ def load_state_from_file(p, filename: str) -> None:
@staticmethod
def _get_current_state_default(env: gym.wrappers.TimeLimit) -> Tuple:
"""Returns the internal state of a manipulation / pendulum environment."""
env = env.env
filename = PybulletEnvHandler.save_state_to_file(env._p)
new_env = env.env
filename = PybulletEnvHandler.save_state_to_file(new_env._p)
import pickle

pickle_bytes = pickle.dumps(env)
pickle_bytes = pickle.dumps(new_env)
return ((filename, pickle_bytes),)

@staticmethod
@@ -138,8 +138,8 @@ def _set_env_state_default(state: Tuple, env: gym.wrappers.TimeLimit) -> None:
((filename, pickle_bytes),) = state
new_env = pickle.loads(pickle_bytes)
env.env = new_env
env = env.env
PybulletEnvHandler.load_state_from_file(env._p, filename)
new_env = env.env
PybulletEnvHandler.load_state_from_file(new_env._p, filename)

@staticmethod
def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
@@ -152,15 +152,15 @@ def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
Args:
env (:class:`gym.wrappers.TimeLimit`): the environment.
"""
env = env.env
new_env = env.env
robot = env.robot
if not isinstance(robot, (RSWalkerBase, MJWalkerBase)):
raise RuntimeError("Invalid robot type. Expected a locomotor robot")

filename = PybulletEnvHandler.save_state_to_file(env._p)
ground_ids = env.ground_ids
potential = env.potential
reward = float(env.reward)
filename = PybulletEnvHandler.save_state_to_file(new_env._p)
ground_ids = new_env.ground_ids
potential = new_env.potential
reward = float(new_env.reward)
robot_keys: List[Tuple[str, Callable]] = [
("body_rpy", tuple),
("body_xyz", tuple),
@@ -231,12 +231,12 @@ def _set_env_state_locomotion(state: Tuple, env: gym.wrappers.TimeLimit):
robot_data,
) = state

env = env.env
env.ground_ids = ground_ids
env.potential = potential
env.reward = reward
PybulletEnvHandler.load_state_from_file(env._p, filename)
new_env = env.env
new_env.ground_ids = ground_ids
new_env.potential = potential
new_env.reward = reward
PybulletEnvHandler.load_state_from_file(new_env._p, filename)
for k, v in robot_data.items():
setattr(env.robot, k, v)
setattr(new_env.robot, k, v)
else:
raise RuntimeError("Only pybulletgym environments supported.")
4 changes: 2 additions & 2 deletions mbrl/util/replay_buffer.py
@@ -273,7 +273,7 @@ def _get_indices_valid_starts(
# iterator. It's a good price to pay for now, since it simplifies things
# enormously and it's less error prone
valid_starts = []
for (start, end) in trajectory_indices:
for start, end in trajectory_indices:
if end - start < sequence_length:
continue
valid_starts.extend(list(range(start, end - sequence_length + 1)))
@@ -368,7 +368,7 @@ def _get_indices_valid_starts(
# iterator. It's a good price to pay for now, since it simplifies things
# enormously and it's less error prone
valid_starts = []
for (start, end) in trajectory_indices:
for start, end in trajectory_indices:
if end - start < sequence_length:
continue
valid_starts.extend(list(range(start, end - sequence_length + 1)))
File renamed without changes.
1 change: 1 addition & 0 deletions requirements/dev.txt
@@ -4,6 +4,7 @@ sphinx-rtd-theme>=0.5.0
flake8>=3.8.4
mypy>=0.902
black>=21.4b2
importlib_metadata<5
pytest>=6.0.1
types-pyyaml>=0.1.6
types-termcolor>=0.1.0
2 changes: 1 addition & 1 deletion requirements/main.txt
@@ -5,7 +5,7 @@ tensorboard>=2.4.0
imageio>=2.9.0
numpy>=1.19.1
matplotlib>=3.3.1
gym==0.17.2
gym>=0.20.0,<0.25.0
jupyter>=1.0.0
pytest>=6.0.1
sk-video>=1.1.10