diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 5c9e4bf8ea6b5..e560725220dfa 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -294,6 +294,29 @@ def testExtensionValueInDifferentFile(self): self.assertEqual(234, m.Extensions[ext1].setting) self.assertEqual(345, m.Extensions[ext2].setting) + def testNestedDefinition(self): + f = descriptor_pb2.FileDescriptorProto( + name='google/protobuf/internal/meta_class.proto', + package='google.protobuf.python.internal') + msg_proto = f.message_type.add(name='Empty') + msg_proto.nested_type.add(name='Nested') + msg_proto.field.add(name='nested_field', + number=1, + label=descriptor.FieldDescriptor.LABEL_REPEATED, + type=descriptor.FieldDescriptor.TYPE_MESSAGE, + type_name='Nested') + + msg_proto.nested_type[0].nested_type.add(name='DoublyNested') + msg_proto.nested_type[0].field.add(name='doubly_nested_field', + number=2, + label=descriptor.FieldDescriptor.LABEL_REPEATED, + type=descriptor.FieldDescriptor.TYPE_MESSAGE, + type_name='DoublyNested') + + messages = message_factory.GetMessages([f]) + self.assertIn('google.protobuf.python.internal.Empty.Nested', messages) + self.assertIn('google.protobuf.python.internal.Empty.Nested.DoublyNested', messages) + def testDescriptorKeepConcreteClass(self): def loadFile(): f= descriptor_pb2.FileDescriptorProto( @@ -310,6 +333,8 @@ def loadFile(): messages = loadFile() for des, meta_class in messages.items(): + if des == "google.protobuf.python.internal.Empty.Nested": + continue message = meta_class() nested_des = message.DESCRIPTOR.nested_types_by_name['Nested'] nested_msg = nested_des._concrete_class() diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 9b64ff05b846c..c8a98fc6e181a 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -56,6 +56,8 @@ def GetMessageClassesForFiles(files, pool): This will find and resolve dependencies, failing if the descriptor pool cannot satisfy them. + This will also recursively find any nested definitions. + Args: files: The file names to extract messages from. pool: The descriptor pool to find the files including the dependent files. @@ -69,6 +71,13 @@ def GetMessageClassesForFiles(files, pool): for desc in file_desc.message_types_by_name.values(): result[desc.full_name] = GetMessageClass(desc) + # Recursively load protos for nested definitions. + nested_descriptions = list(desc.nested_types_by_name.values()) + while nested_descriptions: + nested_desc = nested_descriptions.pop() + result[nested_desc.full_name] = GetMessageClass(nested_desc) + nested_descriptions.extend(nested_desc.nested_types_by_name.values()) + # While the extension FieldDescriptors are created by the descriptor pool, # the python classes created in the factory need them to be registered # explicitly, which is done below.