Skip to content

Commit

Permalink
Resolve default value of Handler.typestr if method missing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 710747415
  • Loading branch information
niketkumar authored and Orbax Authors committed Dec 30, 2024
1 parent c227788 commit 32f2c03
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 14 deletions.
9 changes: 9 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.11.0] - 2024-12-30

### Fixed
- Resolve default value of Handler.typestr if method missing.
### Added
- Add announcement for grain version compatibility. See
https://github.com/google/orbax/issues/1456.


## [0.10.3] - 2024-12-19

### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,33 @@ def get(


def register_handler_type(handler_cls):
_GLOBAL_HANDLER_TYPE_REGISTRY.add(handler_cls.typestr(), handler_cls)
"""Registers a checkpoint handler type in the global registry.
The registry is keyed by the handler's typestr. If the handler does not
provide a typestr, the default typestr is resolved from the handler's
module and class name.
Args:
handler_cls: The checkpoint handler class to register.
Returns:
The registered checkpoint handler class.
"""
# TODO(adamcogdell): Change HandlerTypeRegistry.add(typestr, type) to
# HandlerTypeRegistry.add(handler_type) and move following logic into
# HandlerTypeRegistry.add(). It will help to drop unit tests based on the
# singleton HandlerTypeRegistry, which can be flaky.
try:
typestr = handler_cls.typestr()
except AttributeError:
typestr = f'{handler_cls.__module__}.{handler_cls.__qualname__}'
logging.warning(
'Handler class %s does not have a typestr method. '
'Using the default typestr value "%s" instead.',
handler_cls,
typestr,
)
_GLOBAL_HANDLER_TYPE_REGISTRY.add(typestr, handler_cls)
return handler_cls


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Tests for CheckpointerHandler type registry."""

import copy
from absl.testing import absltest
from absl.testing import parameterized
from etils import epath
Expand All @@ -26,6 +27,7 @@


class TestHandler(checkpoint_handler.CheckpointHandler):

def save(self, directory: epath.Path, *args, **kwargs):
pass

Expand All @@ -34,7 +36,9 @@ def restore(self, directory: epath.Path, *args, **kwargs):


class ParentHandler(checkpoint_handler.CheckpointHandler):

class TestHandler(checkpoint_handler.CheckpointHandler):

def save(self, directory: epath.Path, *args, **kwargs):
pass

Expand All @@ -43,6 +47,7 @@ def restore(self, directory: epath.Path, *args, **kwargs):


class StandardCheckpointHandler(checkpoint_handler.CheckpointHandler):

def save(self, directory: epath.Path, *args, **kwargs):
pass

Expand All @@ -57,11 +62,16 @@ class ChildStandardCheckpointHandler(


class TypestrOverrideHandler(checkpoint_handler.CheckpointHandler):

@classmethod
def typestr(cls) -> str:
return 'typestr_override'


class NoTypestrHandler:
pass


class HandlerTypeRegistryTest(parameterized.TestCase):

def test_register_and_get(self):
Expand All @@ -78,13 +88,11 @@ def test_register_and_get(self):
)
self.assertTrue(
'__main__.TestHandler' in registry._registry
or
'handler_type_registry_test.TestHandler' in registry._registry
or 'handler_type_registry_test.TestHandler' in registry._registry
)
self.assertTrue(
'__main__.ParentHandler.TestHandler' in registry._registry
or
'handler_type_registry_test.ParentHandler.TestHandler'
or 'handler_type_registry_test.ParentHandler.TestHandler'
in registry._registry
)

Expand All @@ -97,7 +105,7 @@ def test_register_different_modules(self):
)
registry.add(
standard_checkpoint_handler.StandardCheckpointHandler.typestr(),
standard_checkpoint_handler.StandardCheckpointHandler
standard_checkpoint_handler.StandardCheckpointHandler,
)
self.assertEqual(
registry.get(
Expand All @@ -107,14 +115,13 @@ def test_register_different_modules(self):
)
self.assertTrue(
'__main__.StandardCheckpointHandler' in registry._registry
or
'handler_type_registry_test.StandardCheckpointHandler'
or 'handler_type_registry_test.StandardCheckpointHandler'
in registry._registry
)
self.assertIn(
'orbax.checkpoint._src.handlers.standard_checkpoint_handler.'
'StandardCheckpointHandler',
registry._registry
registry._registry,
)

def test_register_duplicate_handler_type(self):
Expand All @@ -133,7 +140,7 @@ def test_register_duplicate_handler_type(self):
r'<class \'(?:__main__|handler_type_registry_test)\.TestHandler\'>. '
'Cannot add type '
r'<class \'(?:__main__|handler_type_registry_test)\.'
'ParentHandler.TestHandler\'>.',
"ParentHandler.TestHandler'>.",
):
registry.add(TestHandler.typestr(), ParentHandler.TestHandler)

Expand All @@ -151,11 +158,10 @@ def test_register_subclass_handler_type(self):
registry = HandlerTypeRegistry()
registry.add(
standard_checkpoint_handler.StandardCheckpointHandler.typestr(),
standard_checkpoint_handler.StandardCheckpointHandler
standard_checkpoint_handler.StandardCheckpointHandler,
)
registry.add(
ChildStandardCheckpointHandler.typestr(),
ChildStandardCheckpointHandler
ChildStandardCheckpointHandler.typestr(), ChildStandardCheckpointHandler
)
self.assertEqual(
registry.get(
Expand All @@ -176,5 +182,23 @@ def test_typestr_override(self):
TypestrOverrideHandler,
)

def test_no_typestr(self):
backup = copy.deepcopy(handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY)
try:
# Clear the global registry to avoid side effects from other tests.
handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY._registry.clear()

handler_type_registry.register_handler_type(NoTypestrHandler)
registry = handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY._registry

expected_registry0 = {
'handler_type_registry_test.NoTypestrHandler': NoTypestrHandler
}
expected_registry1 = {'__main__.NoTypestrHandler': NoTypestrHandler}
self.assertIn(registry, [expected_registry0, expected_registry1])
finally:
handler_type_registry._GLOBAL_HANDLER_TYPE_REGISTRY = backup


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# A new PyPI release will be pushed everytime `__version__` is increased.
# Also modify version and date in CHANGELOG.
__version__ = '0.10.3'
__version__ = '0.11.0'


# TODO: b/362813406 - Add latest change timestamp and commit number.
Expand Down
6 changes: 6 additions & 0 deletions docs/guides/checkpoint/orbax_checkpoint_announcements.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Announcements

## 2024-12-30
orbax-checkpoint version `0.10.3` and
[grain](https://pypi.org/project/grain/) version `0.2.2` are not compatible.
Either upgrade `grain>=0.2.3` or `orbax-checkpoint>=0.11.0`. Please see
https://github.com/google/orbax/issues/1456 for error details.

## 2024-10-25
A new option, `strict` has been added to `ArrayRestoreArgs` (and will be
present in the next version release). The option defaults to True. This
Expand Down

0 comments on commit 32f2c03

Please sign in to comment.