Skip to content

Commit

Permalink
improve error logging in orbax checkpoint loading
Browse files Browse the repository at this point in the history
  • Loading branch information
thorben-frank committed Aug 22, 2024
1 parent a2e394e commit ba31537
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions mlff/io/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,29 @@

def load_params_from_ckpt_dir(ckpt_dir):
try:
loaded_mngr = checkpoint.CheckpointManager(
pathlib.Path(ckpt_dir).resolve(),
item_names=('state',),
item_handlers={'state': checkpoint.StandardCheckpointHandler()},
options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
)

mngr_state = loaded_mngr.restore(
loaded_mngr.latest_step()
)

state = mngr_state.get('state')

return state['valid_params']
except FileNotFoundError:
return load_state_from_ckpt_dir(ckpt_dir)['valid_params']
except ValueError:
try:
loaded_mngr = checkpoint.CheckpointManager(
pathlib.Path(ckpt_dir).resolve(),
item_names=('state',),
item_handlers={'state': checkpoint.StandardCheckpointHandler()},
options=checkpoint.CheckpointManagerOptions(step_prefix="ckpt"),
)

mngr_state = loaded_mngr.restore(
loaded_mngr.latest_step()
)

state = mngr_state.get('state')

return state['valid_params']
except ValueError:
raise RuntimeError(
f'Loading model parameters from checkpoint saved at {ckpt_dir} failed. '
'This error typically occurs if within the ckpt_XXX directory there is another folder. '
'Consider moving the folder somewhere else.'
)


def load_state_from_ckpt_dir(ckpt_dir: str):
Expand Down

0 comments on commit ba31537

Please sign in to comment.