diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py index 47e62418c2961..e1dd2b749757c 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/aws_common.py @@ -1,16 +1,19 @@ from functools import reduce -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union import boto3 from boto3.session import Session -from mypy_boto3_glue import GlueClient -from mypy_boto3_s3 import S3Client -from mypy_boto3_sagemaker import SageMakerClient from datahub.configuration import ConfigModel from datahub.configuration.common import AllowDenyPattern from datahub.emitter.mce_builder import DEFAULT_ENV +if TYPE_CHECKING: + + from mypy_boto3_glue import GlueClient + from mypy_boto3_s3 import S3Client + from mypy_boto3_sagemaker import SageMakerClient + def assume_role( role_arn: str, aws_region: str, credentials: Optional[dict] = None @@ -88,13 +91,13 @@ def get_session(self) -> Session: else: return Session(region_name=self.aws_region) - def get_s3_client(self) -> S3Client: + def get_s3_client(self) -> "S3Client": return self.get_session().client("s3") - def get_glue_client(self) -> GlueClient: + def get_glue_client(self) -> "GlueClient": return self.get_session().client("glue") - def get_sagemaker_client(self) -> SageMakerClient: + def get_sagemaker_client(self) -> "SageMakerClient": return self.get_session().client("sagemaker") diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py index 09c10e93a95fd..381ab4ef88af8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/feature_groups.py @@ -1,12 +1,5 @@ from dataclasses import dataclass -from typing import Iterable, List - -from mypy_boto3_sagemaker import SageMakerClient -from mypy_boto3_sagemaker.type_defs import ( - DescribeFeatureGroupResponseTypeDef, - FeatureDefinitionTypeDef, - FeatureGroupSummaryTypeDef, -) +from typing import TYPE_CHECKING, Iterable, List import datahub.emitter.mce_builder as builder from datahub.ingestion.api.workunit import MetadataWorkUnit @@ -27,14 +20,23 @@ MLPrimaryKeyPropertiesClass, ) +if TYPE_CHECKING: + + from mypy_boto3_sagemaker import SageMakerClient + from mypy_boto3_sagemaker.type_defs import ( + DescribeFeatureGroupResponseTypeDef, + FeatureDefinitionTypeDef, + FeatureGroupSummaryTypeDef, + ) + @dataclass class FeatureGroupProcessor: - sagemaker_client: SageMakerClient + sagemaker_client: "SageMakerClient" env: str report: SagemakerSourceReport - def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]: + def get_all_feature_groups(self) -> List["FeatureGroupSummaryTypeDef"]: """ List all feature groups in SageMaker. """ @@ -50,7 +52,7 @@ def get_all_feature_groups(self) -> List[FeatureGroupSummaryTypeDef]: def get_feature_group_details( self, feature_group_name: str - ) -> DescribeFeatureGroupResponseTypeDef: + ) -> "DescribeFeatureGroupResponseTypeDef": """ Get details of a feature group (including list of component features). """ @@ -74,7 +76,7 @@ def get_feature_group_details( return feature_group def get_feature_group_wu( - self, feature_group_details: DescribeFeatureGroupResponseTypeDef + self, feature_group_details: "DescribeFeatureGroupResponseTypeDef" ) -> MetadataWorkUnit: """ Generate an MLFeatureTable workunit for a SageMaker feature group. @@ -146,8 +148,8 @@ def get_feature_type(self, aws_type: str, feature_name: str) -> str: def get_feature_wu( self, - feature_group_details: DescribeFeatureGroupResponseTypeDef, - feature: FeatureDefinitionTypeDef, + feature_group_details: "DescribeFeatureGroupResponseTypeDef", + feature: "FeatureDefinitionTypeDef", ) -> MetadataWorkUnit: """ Generate an MLFeature workunit for a SageMaker feature. diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py index 99057de90a868..a8f6e346c1b65 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py @@ -2,6 +2,7 @@ from dataclasses import dataclass, field from enum import Enum from typing import ( + TYPE_CHECKING, Any, DefaultDict, Dict, @@ -16,8 +17,6 @@ Union, ) -from mypy_boto3_sagemaker import SageMakerClient - from datahub.emitter import mce_builder from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.aws.aws_common import make_s3_urn @@ -47,6 +46,9 @@ JobStatusClass, ) +if TYPE_CHECKING: + from mypy_boto3_sagemaker import SageMakerClient + JobInfo = TypeVar( "JobInfo", AutoMlJobInfo, @@ -151,7 +153,7 @@ class JobProcessor: """ # boto3 SageMaker client - sagemaker_client: SageMakerClient + sagemaker_client: "SageMakerClient" env: str report: SagemakerSourceReport # config filter for specific job types to ingest (see metadata-ingestion README) diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py index e51d5aab6b1e4..771aca5fad114 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/lineage.py @@ -1,19 +1,20 @@ from collections import defaultdict from dataclasses import dataclass, field -from typing import Any, DefaultDict, Dict, List, Set - -from mypy_boto3_sagemaker import SageMakerClient -from mypy_boto3_sagemaker.type_defs import ( - ActionSummaryTypeDef, - ArtifactSummaryTypeDef, - AssociationSummaryTypeDef, - ContextSummaryTypeDef, -) +from typing import TYPE_CHECKING, Any, DefaultDict, Dict, List, Set from datahub.ingestion.source.aws.sagemaker_processors.common import ( SagemakerSourceReport, ) +if TYPE_CHECKING: + from mypy_boto3_sagemaker import SageMakerClient + from mypy_boto3_sagemaker.type_defs import ( + ActionSummaryTypeDef, + ArtifactSummaryTypeDef, + AssociationSummaryTypeDef, + ContextSummaryTypeDef, + ) + @dataclass class LineageInfo: @@ -42,13 +43,13 @@ class LineageInfo: @dataclass class LineageProcessor: - sagemaker_client: SageMakerClient + sagemaker_client: "SageMakerClient" env: str report: SagemakerSourceReport nodes: Dict[str, Dict[str, Any]] = field(default_factory=dict) lineage_info: LineageInfo = field(default_factory=LineageInfo) - def get_all_actions(self) -> List[ActionSummaryTypeDef]: + def get_all_actions(self) -> List["ActionSummaryTypeDef"]: """ List all actions in SageMaker. """ @@ -62,7 +63,7 @@ def get_all_actions(self) -> List[ActionSummaryTypeDef]: return actions - def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]: + def get_all_artifacts(self) -> List["ArtifactSummaryTypeDef"]: """ List all artifacts in SageMaker. """ @@ -76,7 +77,7 @@ def get_all_artifacts(self) -> List[ArtifactSummaryTypeDef]: return artifacts - def get_all_contexts(self) -> List[ContextSummaryTypeDef]: + def get_all_contexts(self) -> List["ContextSummaryTypeDef"]: """ List all contexts in SageMaker. """ @@ -90,7 +91,7 @@ def get_all_contexts(self) -> List[ContextSummaryTypeDef]: return contexts - def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]: + def get_incoming_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]: """ Get all incoming edges for a node in the lineage graph. """ @@ -104,7 +105,7 @@ def get_incoming_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]: return edges - def get_outgoing_edges(self, node_arn: str) -> List[AssociationSummaryTypeDef]: + def get_outgoing_edges(self, node_arn: str) -> List["AssociationSummaryTypeDef"]: """ Get all outgoing edges for a node in the lineage graph. """ diff --git a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py index 17ce027b86391..7e2862a02c717 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py +++ b/metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/models.py @@ -1,16 +1,15 @@ from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime -from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple - -from mypy_boto3_sagemaker import SageMakerClient -from mypy_boto3_sagemaker.type_defs import ( - DescribeEndpointOutputTypeDef, - DescribeModelOutputTypeDef, - DescribeModelPackageGroupOutputTypeDef, - EndpointSummaryTypeDef, - ModelPackageGroupSummaryTypeDef, - ModelSummaryTypeDef, +from typing import ( + TYPE_CHECKING, + DefaultDict, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, ) import datahub.emitter.mce_builder as builder @@ -43,6 +42,17 @@ OwnershipTypeClass, ) +if TYPE_CHECKING: + from mypy_boto3_sagemaker import SageMakerClient + from mypy_boto3_sagemaker.type_defs import ( + DescribeEndpointOutputTypeDef, + DescribeModelOutputTypeDef, + DescribeModelPackageGroupOutputTypeDef, + EndpointSummaryTypeDef, + ModelPackageGroupSummaryTypeDef, + ModelSummaryTypeDef, + ) + ENDPOINT_STATUS_MAP: Dict[str, str] = { "OutOfService": DeploymentStatusClass.OUT_OF_SERVICE, "Creating": DeploymentStatusClass.CREATING, @@ -58,7 +68,7 @@ @dataclass class ModelProcessor: - sagemaker_client: SageMakerClient + sagemaker_client: "SageMakerClient" env: str report: SagemakerSourceReport lineage: LineageInfo @@ -81,7 +91,7 @@ class ModelProcessor: group_arn_to_name: Dict[str, str] = field(default_factory=dict) - def get_all_models(self) -> List[ModelSummaryTypeDef]: + def get_all_models(self) -> List["ModelSummaryTypeDef"]: """ List all models in SageMaker. """ @@ -95,7 +105,7 @@ def get_all_models(self) -> List[ModelSummaryTypeDef]: return models - def get_model_details(self, model_name: str) -> DescribeModelOutputTypeDef: + def get_model_details(self, model_name: str) -> "DescribeModelOutputTypeDef": """ Get details of a model. """ @@ -103,7 +113,7 @@ def get_model_details(self, model_name: str) -> DescribeModelOutputTypeDef: # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_model return self.sagemaker_client.describe_model(ModelName=model_name) - def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]: + def get_all_groups(self) -> List["ModelPackageGroupSummaryTypeDef"]: """ List all model groups in SageMaker. """ @@ -118,7 +128,7 @@ def get_all_groups(self) -> List[ModelPackageGroupSummaryTypeDef]: def get_group_details( self, group_name: str - ) -> DescribeModelPackageGroupOutputTypeDef: + ) -> "DescribeModelPackageGroupOutputTypeDef": """ Get details of a model group. """ @@ -128,7 +138,7 @@ def get_group_details( ModelPackageGroupName=group_name ) - def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]: + def get_all_endpoints(self) -> List["EndpointSummaryTypeDef"]: endpoints = [] @@ -140,7 +150,9 @@ def get_all_endpoints(self) -> List[EndpointSummaryTypeDef]: return endpoints - def get_endpoint_details(self, endpoint_name: str) -> DescribeEndpointOutputTypeDef: + def get_endpoint_details( + self, endpoint_name: str + ) -> "DescribeEndpointOutputTypeDef": # see https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_endpoint return self.sagemaker_client.describe_endpoint(EndpointName=endpoint_name) @@ -162,7 +174,7 @@ def get_endpoint_status( return endpoint_status def get_endpoint_wu( - self, endpoint_details: DescribeEndpointOutputTypeDef + self, endpoint_details: "DescribeEndpointOutputTypeDef" ) -> MetadataWorkUnit: """a Get a workunit for an endpoint. @@ -206,7 +218,7 @@ def get_endpoint_wu( def get_model_endpoints( self, - model_details: DescribeModelOutputTypeDef, + model_details: "DescribeModelOutputTypeDef", endpoint_arn_to_name: Dict[str, str], model_image: Optional[str], model_uri: Optional[str], @@ -235,7 +247,7 @@ def get_model_endpoints( return model_endpoints_sorted def get_group_wu( - self, group_details: DescribeModelPackageGroupOutputTypeDef + self, group_details: "DescribeModelPackageGroupOutputTypeDef" ) -> MetadataWorkUnit: """ Get a workunit for a model group. @@ -285,7 +297,7 @@ def get_group_wu( return MetadataWorkUnit(id=group_name, mce=mce) def match_model_jobs( - self, model_details: DescribeModelOutputTypeDef + self, model_details: "DescribeModelOutputTypeDef" ) -> Tuple[Set[str], Set[str], List[MLHyperParamClass], List[MLMetricClass]]: model_training_jobs: Set[str] = set() @@ -380,7 +392,7 @@ def strip_quotes(string: str) -> str: def get_model_wu( self, - model_details: DescribeModelOutputTypeDef, + model_details: "DescribeModelOutputTypeDef", endpoint_arn_to_name: Dict[str, str], ) -> MetadataWorkUnit: """