Skip to content

Commit

Permalink
reshaped iCEM optimizer for consistency with SAC/PPO
Browse files Browse the repository at this point in the history
  • Loading branch information
Bakeey committed Nov 6, 2024
1 parent 1944ad1 commit 3c99bce
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 235 deletions.
3 changes: 1 addition & 2 deletions mbpo/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from mbpo.optimizers.policy_optimizers.sac.sac import SAC
from mbpo.optimizers.trajectory_optimizers.icem_optimizer import iCemTO, iCemOptimizerState, iCemParams
from mbpo.optimizers.trajectory_optimizers.icem_optimizer import iCemTO, iCemOptimizerState, iCemParams, iCEMOptimizer
from mbpo.optimizers.base_optimizer import BaseOptimizer
from mbpo.optimizers.policy_optimizers.brax_optimizers import PPOOptimizer, SACOptimizer, BraxOptimizer, \
BraxState, BraxOutput
from mbpo.optimizers.policy_optimizers.bptt_optimizer import BPTTOptimizer, BPTTState

187 changes: 57 additions & 130 deletions mbpo/optimizers/trajectory_optimizers/icem_optimizer.py
Original file line number Diff line number Diff line change
@@ -1,151 +1,28 @@
"""Generate colored noise. Taken from: https://github.com/felixpatzelt/colorednoise/blob/master/colorednoise.py"""
from abc import abstractmethod
from functools import partial
from typing import NamedTuple, Generic
from typing import Tuple, NamedTuple, Generic

import chex
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax import vmap

from jax import vmap, jit
from jax.nn import relu
from jax.numpy import sqrt, newaxis
from jax.numpy.fft import irfft, rfftfreq
from jaxtyping import Float, Array, Key, Scalar

from mbpo.optimizers.base_optimizer import BaseOptimizer
from mbpo.systems.base_systems import System, SystemParams
from mbpo.systems.dynamics.base_dynamics import DynamicsParams
from mbpo.systems.rewards.base_rewards import RewardParams
from mbpo.utils.optimizer_utils import rollout_actions
from mbpo.utils.general_utils import powerlaw_psd_gaussian
from mbpo.utils.type_aliases import OptimizerState


@partial(jax.jit, static_argnums=(0, 1, 3))
def powerlaw_psd_gaussian(exponent: float, size: int, rng: jax.random.PRNGKey, fmin: float = 0) -> jax.Array:
    """Sample Gaussian (1/f)**beta noise, normalised to unit variance.

    Based on the algorithm in:
    Timmer, J. and Koenig, M.:
    On generating power law noise.
    Astron. Astrophys. 300, 707-710 (1995)

    Parameters:
    -----------
    exponent : float
        The power-spectrum of the generated noise is proportional to
        S(f) = (1 / f)**beta
        flicker / pink noise:   exponent beta = 1
        brown noise:            exponent beta = 2
        Furthermore, the autocorrelation decays proportional to lag**-gamma
        with gamma = 1 - beta for 0 < beta < 1.
        There may be finite-size issues for beta close to one.
    size : int or iterable of int
        The output has the given shape, and the desired power spectrum in
        the last coordinate. That is, the last dimension is taken as time,
        and all other components are independent.
        This argument is static under ``jax.jit``, so an iterable must be
        hashable (pass a tuple, not a list).
    rng : jax.random.PRNGKey
        Key from which the random Fourier components are drawn.
    fmin : float, optional
        Low-frequency cutoff.
        Default: 0 corresponds to original paper.
        The power-spectrum below fmin is flat. fmin is defined relative
        to a unit sampling rate (see numpy's rfftfreq). For convenience,
        the passed value is mapped to max(fmin, 1/samples) internally
        since 1/samples is the lowest possible finite frequency in the
        sample. The largest possible value is fmin = 0.5, the Nyquist
        frequency. The output for this value is white noise.

    Returns
    -------
    out : jax.Array
        The samples, with shape ``size``.

    Raises
    ------
    ValueError
        If ``fmin`` lies outside [0, 0.5].
    """

    # Make sure size is a list so we can iterate it and assign to it.
    try:
        size = list(size)
    except TypeError:
        size = [size]

    # The number of samples in each time series
    samples = size[-1]

    # Calculate frequencies (we assume a sample rate of one)
    # Use fft functions for real output (-> hermitian spectrum)
    f = rfftfreq(samples)

    # Validate / normalise fmin (fmin is static, so plain Python control flow is fine)
    if 0 <= fmin <= 0.5:
        fmin = max(fmin, 1. / samples)  # Low frequency cutoff
    else:
        raise ValueError("fmin must be chosen between 0 and 0.5.")

    # Build scaling factors for all frequencies
    s_scale = f
    ix = jnp.sum(s_scale < fmin)  # Index of the cutoff

    def cutoff(x, idx):
        # Flatten the spectrum below the cutoff: every entry before idx takes the value at idx.
        x_idx = jax.lax.dynamic_slice(x, start_indices=(idx,), slice_sizes=(1,))
        y = jnp.ones_like(x) * x_idx
        indexes = jnp.arange(0, x.shape[0], step=1)
        first_idx = indexes < idx
        z = (1 - first_idx) * x + first_idx * y
        return z

    def no_cutoff(x, idx):
        return x

    # Only apply the cutoff when the index is non-zero and inside the array.
    s_scale = jax.lax.cond(
        jnp.logical_and(ix < len(s_scale), ix),
        cutoff,
        no_cutoff,
        s_scale,
        ix
    )
    s_scale = s_scale ** (-exponent / 2.)

    # Calculate theoretical output standard deviation from scaling
    w = s_scale[1:]
    w = w.at[-1].set(w[-1] * (1 + (samples % 2)) / 2.)  # correct f = +-0.5
    sigma = 2 * sqrt(jnp.sum(w ** 2)) / samples

    # Adjust size to generate one Fourier component per frequency
    size[-1] = len(f)

    # Add empty dimension(s) to broadcast s_scale along last
    # dimension of generated random power + phase (below)
    dims_to_add = len(size) - 1
    s_scale = s_scale[(newaxis,) * dims_to_add + (Ellipsis,)]

    # Draw the random power + phase with the FULL requested shape. Using
    # s_scale.shape here would collapse every leading dimension to 1, so a
    # multi-dimensional `size` would not yield independent series.
    # The key is still split three ways so that results for an integer
    # `size` stay identical to the previous implementation.
    sample_shape = tuple(size)
    key_sr, key_si, _ = jax.random.split(rng, 3)
    sr = jax.random.normal(key=key_sr, shape=sample_shape) * s_scale
    si = jax.random.normal(key=key_si, shape=sample_shape) * s_scale

    # If the signal length is even, frequencies +/- 0.5 are equal
    # so the coefficient must be real.
    if not (samples % 2):
        si = si.at[..., -1].set(0)
        sr = sr.at[..., -1].set(sr[..., -1] * sqrt(2))  # Fix magnitude

    # Regardless of signal length, the DC component must be real
    si = si.at[..., 0].set(0)
    sr = sr.at[..., 0].set(sr[..., 0] * sqrt(2))  # Fix magnitude

    # Combine power + corrected phase to Fourier components
    s = sr + 1J * si

    # Transform to real time series & scale to unit variance
    y = irfft(s, n=samples, axis=-1) / sigma
    return y


class iCemParams(NamedTuple):
"""
num_particles: int = 10
Expand All @@ -159,7 +36,6 @@ class iCemParams(NamedTuple):
u_min: float | chex.Array = minimal value for action
u_max: float | chex.Array = maximal value for action
warm_start: bool = If we shift the action sequence for one and repeat the last action at initialization
"""
num_particles: int = 10
num_samples: int = 500
Expand Down Expand Up @@ -374,6 +250,57 @@ def colored_sample_fn(rng):
def act(self, obs: chex.Array, opt_state: iCemOptimizerState, evaluate: bool = True):
    """Run one optimization from ``obs`` and return the chosen action with the updated state.

    ``evaluate`` is accepted but not used by this method.
    """
    updated_state = self.optimize(initial_state=obs, opt_state=opt_state)
    return updated_state.action, updated_state


class iCEMOptimizer(BaseOptimizer):
    """
    Wrapper around iCemTO that keeps its interface consistent with the SAC and PPO optimizers.

    The underlying agent is not built in the constructor; it is created in ``init``
    once a system is available.
    """

    def __init__(self,
                 horizon: int,
                 opt_params: iCemParams = iCemParams(),
                 system: System | None = None,
                 key: jr.PRNGKey = jr.PRNGKey(0),
                 **agent_kwargs):
        super().__init__(system, key)
        # Remember everything needed to construct the agent later in init().
        self.agent_class = iCemTO
        self.agent_kwargs = agent_kwargs
        self.opt_params = opt_params
        self.horizon = horizon
        self.key = key
        if system is not None:
            self.set_system(system)

    def set_system(self, system: System):
        # Pure delegation to the base class.
        super().set_system(system)

    def init(self,
             key: chex.PRNGKey,
             true_buffer_state=None) -> iCemOptimizerState:
        """Build the iCemTO agent for the current system and return its initial state.

        When ``true_buffer_state`` is None, a dummy one is created from a key split
        off ``key``; the agent is then initialised with the remaining key.
        """
        assert self.system is not None, "iCEM optimizer requires system to be defined."
        self.agent = self.agent_class(
            horizon=self.horizon,
            action_dim=self.system.u_dim,
            key=self.key,
            opt_params=self.opt_params,
            **self.agent_kwargs,
        )
        self.agent.set_system(self.system)

        # true_buffer_state is used to ensure compatibility with SAC and PPO
        if true_buffer_state is None:
            buffer_key, key = jr.split(key, 2)
            true_buffer_state = self.dummy_true_buffer_state(buffer_key)
        initial_state = self.agent.init(key)
        # NOTE(review): relies on iCemOptimizerState allowing attribute assignment - confirm.
        initial_state.true_buffer_state = true_buffer_state
        return initial_state

    @partial(jit, static_argnums=(0, 3))
    def act(self,
            obs: chex.Array,
            opt_state: iCemOptimizerState,
            evaluate: bool = True) -> Tuple[chex.Array, iCemOptimizerState]:
        """Flatten ``obs``, query the wrapped agent, and return the action as a single row."""
        assert self.system is not None, "iCEM optimizer requires system to be defined."
        flat_obs = obs.reshape(-1)
        action, next_state = self.agent.act(flat_obs, opt_state, evaluate)
        return action.reshape(1, -1), next_state


if __name__ == "__main__":
Expand Down
103 changes: 0 additions & 103 deletions mbpo/optimizers/trajectory_optimizers/icem_wrapper.py

This file was deleted.

Loading

0 comments on commit 3c99bce

Please sign in to comment.