Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python ExternalTransformProvider improvements #33359

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions sdks/python/apache_beam/transforms/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,14 +962,14 @@ def __init__(
self, path_to_jar, extra_args=None, classpath=None, append_args=None):
if extra_args and append_args:
raise ValueError('Only one of extra_args or append_args may be provided')
self._path_to_jar = path_to_jar
self.path_to_jar = path_to_jar
self._extra_args = extra_args
self._classpath = classpath or []
self._service_count = 0
self._append_args = append_args or []

def is_existing_service(self):
return subprocess_server.is_service_endpoint(self._path_to_jar)
return subprocess_server.is_service_endpoint(self.path_to_jar)

@staticmethod
def _expand_jars(jar):
Expand Down Expand Up @@ -997,7 +997,7 @@ def _expand_jars(jar):
def _default_args(self):
"""Default arguments to be used by `JavaJarExpansionService`."""

to_stage = ','.join([self._path_to_jar] + sum((
to_stage = ','.join([self.path_to_jar] + sum((
JavaJarExpansionService._expand_jars(jar)
for jar in self._classpath or []), []))
args = ['{{PORT}}', f'--filesToStage={to_stage}']
Expand All @@ -1009,24 +1009,24 @@ def _default_args(self):

def __enter__(self):
if self._service_count == 0:
self._path_to_jar = subprocess_server.JavaJarServer.local_jar(
self._path_to_jar)
self.path_to_jar = subprocess_server.JavaJarServer.local_jar(
self.path_to_jar)
if self._extra_args is None:
self._extra_args = self._default_args() + self._append_args
# Consider memoizing these servers (with some timeout).
logging.info(
'Starting a JAR-based expansion service from JAR %s ' + (
'and with classpath: %s' %
self._classpath if self._classpath else ''),
self._path_to_jar)
self.path_to_jar)
classpath_urls = [
subprocess_server.JavaJarServer.local_jar(path)
for jar in self._classpath
for path in JavaJarExpansionService._expand_jars(jar)
]
self._service_provider = subprocess_server.JavaJarServer(
ExpansionAndArtifactRetrievalStub,
self._path_to_jar,
self.path_to_jar,
self._extra_args,
classpath=classpath_urls)
self._service = self._service_provider.__enter__()
Expand Down
65 changes: 44 additions & 21 deletions sdks/python/apache_beam/transforms/external_transform_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from apache_beam.transforms import PTransform
from apache_beam.transforms.external import BeamJarExpansionService
from apache_beam.transforms.external import JavaJarExpansionService
from apache_beam.transforms.external import SchemaAwareExternalTransform
from apache_beam.transforms.external import SchemaTransformsConfig
from apache_beam.typehints.schemas import named_tuple_to_schema
Expand Down Expand Up @@ -133,37 +134,57 @@ class ExternalTransformProvider:
(see the `urn_pattern` parameter).

These classes are generated when :class:`ExternalTransformProvider` is
initialized. We need to give it one or more expansion service addresses that
are already up and running:
>>> provider = ExternalTransformProvider(["localhost:12345",
... "localhost:12121"])
We can also give it the gradle target of a standard Beam expansion service:
>>> provider = ExternalTransform(BeamJarExpansionService(
... "sdks:java:io:google-cloud-platform:expansion-service:shadowJar"))
Let's take a look at the output of :func:`get_available()` to know the
available transforms in the expansion service(s) we provided:
initialized. You can give it an expansion service address that is already
up and running:

>>> provider = ExternalTransformProvider("localhost:12345")

Or you can give it the path to an expansion service Jar file:

>>> provider = ExternalTransformProvider(JavaJarExpansionService(
...     "path/to/expansion-service.jar"))

Or you can give it the gradle target of a standard Beam expansion service:

>>> provider = ExternalTransformProvider(BeamJarExpansionService(
...     "sdks:java:io:google-cloud-platform:expansion-service:shadowJar"))

Note that you can provide a list of these services:

>>> provider = ExternalTransformProvider([
...     "localhost:12345",
...     JavaJarExpansionService("path/to/expansion-service.jar"),
...     BeamJarExpansionService(
...         "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")])

The output of :func:`get_available()` provides a list of available transforms
in the provided expansion service(s):

>>> provider.get_available()
[('JdbcWrite', 'beam:schematransform:org.apache.beam:jdbc_write:v1'),
('BigtableRead', 'beam:schematransform:org.apache.beam:bigtable_read:v1'),
...]

Then retrieve a transform by :func:`get()`, :func:`get_urn()`, or by directly
accessing it as an attribute of :class:`ExternalTransformProvider`.
All of the following commands do the same thing:
You can retrieve a transform with :func:`get()`, :func:`get_urn()`, or by
directly accessing it as an attribute. The following lines all do the same
thing:

>>> provider.get('BigqueryStorageRead')
>>> provider.get_urn(
... 'beam:schematransform:org.apache.beam:bigquery_storage_read:v1')
'beam:schematransform:org.apache.beam:bigquery_storage_read:v1')
>>> provider.BigqueryStorageRead

You can inspect the transform's documentation to know more about it. This
returns some documentation only IF the underlying SchemaTransform
implementation provides any.
You can inspect the transform's documentation for more details. The following
returns the documentation provided by the underlying SchemaTransform. If no
such documentation is provided, this will be empty.

>>> import inspect
>>> inspect.getdoc(provider.BigqueryStorageRead)

Similarly, you can inspect the transform's signature to know more about its
parameters, including their names, types, and any documentation that the
underlying SchemaTransform may provide:

>>> inspect.signature(provider.BigqueryStorageRead)
(query: 'typing.Union[str, NoneType]: The SQL query to be executed to...',
row_restriction: 'typing.Union[str, NoneType]: Read only rows that match...',
Expand All @@ -178,8 +199,6 @@ class ExternalTransformProvider:
query=query,
row_restriction=restriction)
| 'Some processing' >> beam.Map(...))

Experimental; no backwards compatibility guarantees.
"""
def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN):
f"""Initialize an ExternalTransformProvider
Expand All @@ -188,6 +207,7 @@ def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN):
A list of expansion services to discover transforms from.
Supported forms:
* a string representing the expansion service address
* a :class:`JavaJarExpansionService` pointing to the path of a Java Jar
* a :class:`BeamJarExpansionService` pointing to a gradle target
:param urn_pattern:
The regular expression used to match valid transforms. In addition to
Expand All @@ -213,11 +233,14 @@ def _create_wrappers(self):
target = service
if isinstance(service, BeamJarExpansionService):
target = service.gradle_target
if isinstance(service, JavaJarExpansionService):
target = service.path_to_jar
try:
schematransform_configs = SchemaAwareExternalTransform.discover(service)
except Exception as e:
logging.exception(
"Encountered an error while discovering expansion service %s:\n%s",
"Encountered an error while discovering "
"expansion service at '%s':\n%s",
target,
e)
continue
Expand Down Expand Up @@ -249,7 +272,7 @@ def _create_wrappers(self):

if skipped_urns:
logging.info(
"Skipped URN(s) in %s that don't follow the pattern \"%s\": %s",
"Skipped URN(s) in '%s' that don't follow the pattern \"%s\": %s",
target,
self._urn_pattern,
skipped_urns)
Expand All @@ -262,7 +285,7 @@ def get_available(self) -> List[Tuple[str, str]]:
return list(self._name_to_urn.items())

def get_all(self) -> Dict[str, ExternalTransform]:
"""Get all ExternalTransform"""
"""Get all ExternalTransforms"""
return self._transforms

def get(self, name) -> ExternalTransform:
Expand Down
Loading