Skip to content

Commit

Permalink
Update Python typehints in Beam YAML (apache#33523)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jrmccluskey authored and Naireen committed Jan 17, 2025
1 parent bdce95d commit 28d65bf
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 70 deletions.
15 changes: 7 additions & 8 deletions sdks/python/apache_beam/yaml/json_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
"""

import json
from collections.abc import Callable
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional

import jsonschema
Expand All @@ -49,7 +48,7 @@


def json_schema_to_beam_schema(
json_schema: Dict[str, Any]) -> schema_pb2.Schema:
json_schema: dict[str, Any]) -> schema_pb2.Schema:
"""Returns a Beam schema equivalent for the given Json schema."""
def maybe_nullable(beam_type, nullable):
if nullable:
Expand All @@ -75,7 +74,7 @@ def maybe_nullable(beam_type, nullable):
])


def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType:
def json_type_to_beam_type(json_type: dict[str, Any]) -> schema_pb2.FieldType:
"""Returns a Beam schema type for the given Json (schema) type."""
if not isinstance(json_type, dict) or 'type' not in json_type:
raise ValueError(f'Malformed type {json_type}.')
Expand Down Expand Up @@ -107,7 +106,7 @@ def json_type_to_beam_type(json_type: Dict[str, Any]) -> schema_pb2.FieldType:


def beam_schema_to_json_schema(
beam_schema: schema_pb2.Schema) -> Dict[str, Any]:
beam_schema: schema_pb2.Schema) -> dict[str, Any]:
return {
'type': 'object',
'properties': {
Expand All @@ -118,7 +117,7 @@ def beam_schema_to_json_schema(
}


def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> Dict[str, Any]:
def beam_type_to_json_type(beam_type: schema_pb2.FieldType) -> dict[str, Any]:
type_info = beam_type.WhichOneof("type_info")
if type_info == "atomic_type":
if beam_type.atomic_type in BEAM_ATOMIC_TYPES_TO_JSON:
Expand Down Expand Up @@ -198,7 +197,7 @@ def json_to_row(beam_type: schema_pb2.FieldType) -> Callable[[Any], Any]:

def json_parser(
beam_schema: schema_pb2.Schema,
json_schema: Optional[Dict[str,
json_schema: Optional[dict[str,
Any]] = None) -> Callable[[bytes], beam.Row]:
"""Returns a callable converting Json strings to Beam rows of the given type.
Expand Down Expand Up @@ -307,7 +306,7 @@ def _validate_compatible(weak_schema, strong_schema):


def row_validator(beam_schema: schema_pb2.Schema,
json_schema: Dict[str, Any]) -> Callable[[Any], Any]:
json_schema: dict[str, Any]) -> Callable[[Any], Any]:
"""Returns a callable that will fail on elements not respecting json_schema.
"""
if not json_schema:
Expand Down
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/yaml/yaml_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

"""This module defines the basic Combine operation."""

from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Iterable
from typing import Mapping
from typing import Optional

import apache_beam as beam
Expand Down
3 changes: 1 addition & 2 deletions sdks/python/apache_beam/yaml/yaml_enrichment.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#

from typing import Any
from typing import Dict
from typing import Optional

import apache_beam as beam
Expand All @@ -43,7 +42,7 @@
def enrichment_transform(
pcoll,
enrichment_handler: str,
handler_config: Dict[str, Any],
handler_config: dict[str, Any],
timeout: Optional[float] = 30):
# pylint: disable=line-too-long

Expand Down
16 changes: 7 additions & 9 deletions sdks/python/apache_beam/yaml/yaml_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,11 @@

import io
import os
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple

import fastavro
import yaml
Expand Down Expand Up @@ -110,7 +108,7 @@ def read_from_bigquery(
row_restriction (str): Optional SQL text filtering statement, similar to a
WHERE clause in a query. Aggregates are not supported. Restricted to a
maximum length of 1 MB.
selected_fields (List[str]): Optional List of names of the fields in the
selected_fields (list[str]): Optional List of names of the fields in the
table that should be read. If empty, all fields will be read. If the
specified field is a nested field, all the sub-fields in the field will be
selected. The output field order is unrelated to the order of fields
Expand Down Expand Up @@ -211,7 +209,7 @@ def raise_exception(failed_row_with_error):

def _create_parser(
format,
schema: Any) -> Tuple[schema_pb2.Schema, Callable[[bytes], beam.Row]]:
schema: Any) -> tuple[schema_pb2.Schema, Callable[[bytes], beam.Row]]:

format = format.upper()

Expand Down Expand Up @@ -355,7 +353,7 @@ def read_from_pubsub(
elif not topic and not subscription:
raise TypeError('One of topic or subscription may be specified.')
payload_schema, parser = _create_parser(format, schema)
extra_fields: List[schema_pb2.Field] = []
extra_fields: list[schema_pb2.Field] = []
if not attributes and not attributes_map:
mapper = lambda msg: parser(msg)
else:
Expand Down Expand Up @@ -443,7 +441,7 @@ def write_to_pubsub(
"""
input_schema = schemas.schema_from_element_type(pcoll.element_type)

extra_fields: List[str] = []
extra_fields: list[str] = []
if isinstance(attributes, str):
attributes = [attributes]
if attributes:
Expand Down
8 changes: 3 additions & 5 deletions sdks/python/apache_beam/yaml/yaml_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

"""This module defines the Join operation."""
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -176,9 +174,9 @@ def _SqlJoinTransform(
pcolls,
sql_transform_constructor,
*,
equalities: Union[str, List[Dict[str, str]]],
type: Union[str, Dict[str, List]] = 'inner',
fields: Optional[Dict[str, Any]] = None):
equalities: Union[str, list[dict[str, str]]],
type: Union[str, dict[str, list]] = 'inner',
fields: Optional[dict[str, Any]] = None):
"""Joins two or more inputs using a specified condition.
For example::
Expand Down
22 changes: 10 additions & 12 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
import itertools
import re
from collections import abc
from collections.abc import Callable
from collections.abc import Collection
from collections.abc import Iterable
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Collection
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import TypeVar
from typing import Union
Expand Down Expand Up @@ -349,7 +347,7 @@ def _validator(beam_type: schema_pb2.FieldType) -> Callable[[Any], bool]:

def _as_callable_for_pcoll(
pcoll,
fn_spec: Union[str, Dict[str, str]],
fn_spec: Union[str, dict[str, str]],
msg: str,
language: Optional[str]):
if language == 'javascript':
Expand Down Expand Up @@ -495,7 +493,7 @@ class _Validate(beam.PTransform):
"""
def __init__(
self,
schema: Dict[str, Any],
schema: dict[str, Any],
error_handling: Optional[Mapping[str, Any]] = None):
self._schema = schema
self._exception_handling_args = exception_handling_args(error_handling)
Expand Down Expand Up @@ -615,7 +613,7 @@ def with_exception_handling(self, **kwargs):
@beam.ptransform.ptransform_fn
@maybe_with_exception_handling_transform_fn
def _PyJsFilter(
pcoll, keep: Union[str, Dict[str, str]], language: Optional[str] = None):
pcoll, keep: Union[str, dict[str, str]], language: Optional[str] = None):
"""Keeps only records that satisfy the given criteria.
See more complete documentation on
Expand Down Expand Up @@ -740,8 +738,8 @@ def extract_expr(name, v):
@beam.ptransform.ptransform_fn
def _Partition(
pcoll,
by: Union[str, Dict[str, str]],
outputs: List[str],
by: Union[str, dict[str, str]],
outputs: list[str],
unknown_output: Optional[str] = None,
error_handling: Optional[Mapping[str, Any]] = None,
language: Optional[str] = 'generic'):
Expand Down Expand Up @@ -820,7 +818,7 @@ def split(element):
@maybe_with_exception_handling_transform_fn
def _AssignTimestamps(
pcoll,
timestamp: Union[str, Dict[str, str]],
timestamp: Union[str, dict[str, str]],
language: Optional[str] = None):
"""Assigns a new timestamp to each element of its input.
Expand Down
22 changes: 10 additions & 12 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,8 @@
#

"""This module defines yaml wrappings for some ML transforms."""
from collections.abc import Callable
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional

import apache_beam as beam
Expand All @@ -41,13 +39,13 @@


class ModelHandlerProvider:
handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
handler_types: dict[str, Callable[..., "ModelHandlerProvider"]] = {}

def __init__(
self,
handler,
preprocess: Optional[Dict[str, str]] = None,
postprocess: Optional[Dict[str, str]] = None):
preprocess: Optional[dict[str, str]] = None,
postprocess: Optional[dict[str, str]] = None):
self._handler = handler
self._preprocess_fn = self.parse_processing_transform(
preprocess, 'preprocess') or self.default_preprocess_fn()
Expand Down Expand Up @@ -136,15 +134,15 @@ def __init__(
endpoint_id: str,
project: str,
location: str,
preprocess: Dict[str, str],
postprocess: Optional[Dict[str, str]] = None,
preprocess: dict[str, str],
postprocess: Optional[dict[str, str]] = None,
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None,
env_vars: Optional[Dict[str, Any]] = None):
env_vars: Optional[dict[str, Any]] = None):
"""
ModelHandler for Vertex AI.
Expand Down Expand Up @@ -257,9 +255,9 @@ def get_user_schema_fields(user_type):
@beam.ptransform.ptransform_fn
def run_inference(
pcoll,
model_handler: Dict[str, Any],
model_handler: dict[str, Any],
inference_tag: Optional[str] = 'inference',
inference_args: Optional[Dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long
inference_args: Optional[dict[str, Any]] = None) -> beam.PCollection[beam.Row]: # pylint: disable=line-too-long
"""
A transform that takes the input rows, containing examples (or features), for
use on an ML model. The transform then appends the inferences
Expand Down Expand Up @@ -481,7 +479,7 @@ def ml_transform(
pcoll,
write_artifact_location: Optional[str] = None,
read_artifact_location: Optional[str] = None,
transforms: Optional[List[Any]] = None):
transforms: Optional[list[Any]] = None):
if tft is None:
raise ValueError(
'tensorflow-transform must be installed to use this MLTransform')
Expand Down
22 changes: 10 additions & 12 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,11 @@
import sys
import urllib.parse
import warnings
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
from collections.abc import Mapping
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional

import docstring_parser
Expand Down Expand Up @@ -150,7 +148,7 @@ def as_provider_list(name, lst):

class ExternalProvider(Provider):
"""A Provider implemented via the cross language transform service."""
_provider_types: Dict[str, Callable[..., Provider]] = {}
_provider_types: dict[str, Callable[..., Provider]] = {}

def __init__(self, urns, service):
self._urns = urns
Expand Down Expand Up @@ -689,7 +687,7 @@ def create(elements: Iterable[Any], reshuffle: Optional[bool] = True):
- {first: 0, second: {str: "foo", values: [1, 2, 3]}}
- {first: 1, second: {str: "bar", values: [4, 5, 6]}}
will result in a schema of the form (int, Row(string, List[int])).
will result in a schema of the form (int, Row(string, list[int])).
This can also be expressed as YAML::
Expand Down Expand Up @@ -1027,22 +1025,22 @@ def __init__(
self._base_python = base_python

@classmethod
def _key(cls, base_python: str, packages: List[str]) -> str:
def _key(cls, base_python: str, packages: list[str]) -> str:
return json.dumps({
'binary': base_python, 'packages': sorted(packages)
},
sort_keys=True)

@classmethod
def _path(cls, base_python: str, packages: List[str]) -> str:
def _path(cls, base_python: str, packages: list[str]) -> str:
return os.path.join(
cls.VENV_CACHE,
hashlib.sha256(cls._key(base_python,
packages).encode('utf-8')).hexdigest())

@classmethod
def _create_venv_from_scratch(
cls, base_python: str, packages: List[str]) -> str:
cls, base_python: str, packages: list[str]) -> str:
venv = cls._path(base_python, packages)
if not os.path.exists(venv):
try:
Expand All @@ -1061,7 +1059,7 @@ def _create_venv_from_scratch(

@classmethod
def _create_venv_from_clone(
cls, base_python: str, packages: List[str]) -> str:
cls, base_python: str, packages: list[str]) -> str:
venv = cls._path(base_python, packages)
if not os.path.exists(venv):
try:
Expand Down
Loading

0 comments on commit 28d65bf

Please sign in to comment.