Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for tooling improvement branch #1

Merged
merged 2 commits into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mbrl/algorithms/mbpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ def evaluate(
num_episodes: int,
video_recorder: VideoRecorder,
) -> float:
avg_episode_reward = 0
avg_episode_reward = 0.0
for episode in range(num_episodes):
obs = env.reset()
video_recorder.init(enabled=(episode == 0))
done = False
episode_reward = 0
episode_reward = 0.0
while not done:
action = agent.act(obs)
obs, reward, done, _ = env.step(action)
Expand Down
2 changes: 1 addition & 1 deletion mbrl/env/cartpole_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class CartPoleEnv(gym.Env):
# This is a continuous version of gym's cartpole environment, with the only difference
being valid actions are any numbers in the range [-1, 1], and they are applied as
# a multiplicative factor to the total force.
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": 50}
metadata = {"render.modes": ["human", "rgb_array"], "video.frames_per_second": [50]}

def __init__(self):
self.gravity = 9.8
Expand Down
2 changes: 1 addition & 1 deletion mbrl/third_party/pytorch_sac_pranz24/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
)

for i_episode in itertools.count(1):
episode_reward = 0
episode_reward = 0.0
episode_steps = 0
done = False
state = env.reset()
Expand Down
1 change: 0 additions & 1 deletion mbrl/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import mbrl.models
import mbrl.planning
import mbrl.types

from .replay_buffer import (
BootstrapIterator,
ReplayBuffer,
Expand Down
32 changes: 16 additions & 16 deletions mbrl/util/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ def load_state_from_file(p, filename: str) -> None:
@staticmethod
def _get_current_state_default(env: gym.wrappers.TimeLimit) -> Tuple:
"""Returns the internal state of a manipulation / pendulum environment."""
env = env.env
filename = PybulletEnvHandler.save_state_to_file(env._p)
new_env = env.env
filename = PybulletEnvHandler.save_state_to_file(new_env._p)
import pickle

pickle_bytes = pickle.dumps(env)
pickle_bytes = pickle.dumps(new_env)
return ((filename, pickle_bytes),)

@staticmethod
Expand All @@ -138,8 +138,8 @@ def _set_env_state_default(state: Tuple, env: gym.wrappers.TimeLimit) -> None:
((filename, pickle_bytes),) = state
new_env = pickle.loads(pickle_bytes)
env.env = new_env
env = env.env
PybulletEnvHandler.load_state_from_file(env._p, filename)
new_env = env.env
PybulletEnvHandler.load_state_from_file(new_env._p, filename)

@staticmethod
def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
Expand All @@ -152,15 +152,15 @@ def _get_current_state_locomotion(env: gym.wrappers.TimeLimit) -> Tuple:
Args:
env (:class:`gym.wrappers.TimeLimit`): the environment.
"""
env = env.env
new_env = env.env
robot = env.robot
if not isinstance(robot, (RSWalkerBase, MJWalkerBase)):
raise RuntimeError("Invalid robot type. Expected a locomotor robot")

filename = PybulletEnvHandler.save_state_to_file(env._p)
ground_ids = env.ground_ids
potential = env.potential
reward = float(env.reward)
filename = PybulletEnvHandler.save_state_to_file(new_env._p)
ground_ids = new_env.ground_ids
potential = new_env.potential
reward = float(new_env.reward)
robot_keys: List[Tuple[str, Callable]] = [
("body_rpy", tuple),
("body_xyz", tuple),
Expand Down Expand Up @@ -231,12 +231,12 @@ def _set_env_state_locomotion(state: Tuple, env: gym.wrappers.TimeLimit):
robot_data,
) = state

env = env.env
env.ground_ids = ground_ids
env.potential = potential
env.reward = reward
PybulletEnvHandler.load_state_from_file(env._p, filename)
new_env = env.env
new_env.ground_ids = ground_ids
new_env.potential = potential
new_env.reward = reward
PybulletEnvHandler.load_state_from_file(new_env._p, filename)
for k, v in robot_data.items():
setattr(env.robot, k, v)
setattr(new_env.robot, k, v)
else:
raise RuntimeError("Only pybulletgym environments supported.")