diff --git a/dreamerv3/configs.yaml b/configs.yaml similarity index 100% rename from dreamerv3/configs.yaml rename to configs.yaml diff --git a/dreamerv3/jaxagent.py b/dreamerv3/jaxagent.py index f2ad2695..09e0878c 100644 --- a/dreamerv3/jaxagent.py +++ b/dreamerv3/jaxagent.py @@ -41,8 +41,8 @@ def __init__(self, agent_cls, obs_space, act_space, config): self.keys = [k for k in self.spaces if ( not k.startswith('_') and not k.startswith('log_') and k != 'reset')] - available = jax.devices(self.jaxcfg.platform) - embodied.print(f'JAX devices ({jax.local_device_count()}):', available) + available = jax.devices('cpu') + embodied.print(f'JAX devices (CPU):', available) if self.jaxcfg.assert_num_devices > 0: assert len(available) == self.jaxcfg.assert_num_devices, ( available, len(available), self.jaxcfg.assert_num_devices) @@ -80,7 +80,7 @@ def __init__(self, agent_cls, obs_space, act_space, config): self.should_sync = embodied.when.Every(self.jaxcfg.sync_every) self.policy_params = jax.device_put( {k: self.params[k].copy() for k in self.policy_keys}, - self.policy_mirrored) + device=jax.devices('cpu')[0]) self._lower_train() self._lower_report() @@ -187,7 +187,7 @@ def train(self, data, carry): if self.should_sync(self.updates) and not self.pending_sync: self.pending_sync = jax.device_put( - {k: allo[k] for k in self.policy_keys}, self.policy_mirrored) + {k: allo[k] for k in self.policy_keys}, device=jax.devices('cpu')[0]) else: jax.tree.map(lambda x: x.delete(), allo) @@ -249,10 +249,10 @@ def load(self, state): chex.assert_trees_all_equal_shapes(self.params, state) jax.tree.map(lambda x: x.delete(), self.params) jax.tree.map(lambda x: x.delete(), self.policy_params) - self.params = jax.device_put(state, self.train_mirrored) + self.params = jax.device_put(state, device=jax.devices('cpu')[0]) self.policy_params = jax.device_put( {k: self.params[k].copy() for k in self.policy_keys}, - self.policy_mirrored) + device=jax.devices('cpu')[0]) def _setup(self): try: @@ -283,7 +283,7 @@ def _setup(self): xla_flags.append('--xla_dump_hlo_as_long_text') if xla_flags: os.environ['XLA_FLAGS'] = ' '.join(xla_flags) - jax.config.update('jax_platform_name', self.jaxcfg.platform) + jax.config.update('jax_platform_name', 'cpu') jax.config.update('jax_disable_jit', not self.jaxcfg.jit) if self.jaxcfg.transfer_guard: jax.config.update('jax_transfer_guard', 'disallow') @@ -379,19 +379,17 @@ def _take_outs(self, outs): def _init_params(self, obs_space, act_space): B, T = self.config.batch_size, self.config.batch_length - seed = jax.device_put(np.array([self.config.seed, 0], np.uint32)) - data = jax.device_put(self._dummy_batch(self.spaces, (B, T))) - params = nj.init(self.agent.init_train, static_argnums=[1])( - {}, B, seed=seed) - _, carry = jax.jit(nj.pure(self.agent.init_train), static_argnums=[1])( - params, B, seed=seed) + seed = jax.device_put(np.array([self.config.seed, 0], np.uint32), device=jax.devices('cpu')[0]) # Ensure seed uses CPU + data = jax.device_put(self._dummy_batch(self.spaces, (B, T)), device=jax.devices('cpu')[0]) # Ensure data uses CPU + params = nj.init(self.agent.init_train, static_argnums=[1])({}, B, seed=seed) + _, carry = jax.jit(nj.pure(self.agent.init_train), static_argnums=[1])(params, B, seed=seed) params = nj.init(self.agent.train)(params, data, carry, seed=seed) - return jax.device_put(params, self.train_mirrored) + return jax.device_put(params, device=jax.devices('cpu')[0]) # Ensure params uses CPU def _next_seeds(self, sharding): shape = [2 * x for x in sharding.mesh.devices.shape] seeds = self.rng.integers(0, np.iinfo(np.uint32).max, shape, np.uint32) - return jax.device_put(seeds, sharding) + return jax.device_put(seeds, device=jax.devices('cpu')[0]) def _filter_data(self, data): return {k: v for k, v in data.items() if k in self.keys} diff --git a/example.py b/example.py index 11a4b0dc..d2e183c0 100644 --- a/example.py +++ b/example.py @@ -11,7 +11,7 @@ def main(): config = embodied.Config(dreamerv3.Agent.configs['defaults']) config = config.update({ - **dreamerv3.Agent.configs['size100m'], + **dreamerv3.Agent.configs['size12m'], #changed from 100m to 12m , select simple available model 'logdir': f'~/logdir/{embodied.timestamp()}-example', 'run.train_ratio': 32, })