Skip to content

Commit

Permalink
Add tag to ExternalPipelineChannel so we can get artifacts by tags.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 630465771
  • Loading branch information
tfx-copybara committed May 22, 2024
1 parent a4d4cbe commit dc7c8dc
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
61 changes: 59 additions & 2 deletions tfx/dsl/compiler/node_inputs_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Compiler submodule specialized for NodeInputs."""

from collections.abc import Iterable
from typing import Type, cast
from typing import Optional, Type, cast

from tfx import types
from tfx.dsl.compiler import compiler_context
Expand Down Expand Up @@ -137,12 +137,16 @@ def compile_op_node(op_node: resolver_op.OpNode):
def _compile_channel_pb_contexts(
    context_types_and_names: Iterable[tuple[str, pipeline_pb2.Value]],
    result: pipeline_pb2.InputSpec.Channel,
    property_predicate: Optional[pipeline_pb2.PropertyPredicate] = None,
):
  """Adds one context query per (type, name) pair to the channel.

  Args:
    context_types_and_names: Pairs of a context type name and the value to use
      as the context name.
    result: The channel proto that receives the new `context_queries` entries.
      It is mutated in place.
    property_predicate: (Optional) Predicate copied onto every context query
      added by this call.
  """
  for context_type, context_value in context_types_and_names:
    ctx = result.context_queries.add()
    ctx.type.name = context_type
    # Copy the name exactly once, and only when a value was supplied.
    # NOTE(review): this relies on the truthiness of the value object; confirm
    # it skips only missing (None) names.
    if context_value:
      ctx.name.CopyFrom(context_value)
    if property_predicate:
      ctx.property_predicate.CopyFrom(property_predicate)


def _compile_channel_pb(
Expand Down Expand Up @@ -249,6 +253,59 @@ def _compile_input_spec(
result_input_channel,
)

if channel.tags:
predicates = []
for tag in channel.tags:
predicates.append(
pipeline_pb2.PropertyPredicate(
value_comparator=pipeline_pb2.PropertyPredicate.ValueComparator(
property_name='__tag_' + tag + '__',
op=pipeline_pb2.PropertyPredicate.ValueComparator.Op.EQ,
target_value=pipeline_pb2.Value(
field_value=metadata_store_pb2.Value(bool_value=True)
),
is_custom_property=True,
)
)
)

if len(predicates) == 1:
_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(''),
)],
result_input_channel,
predicates[0],
)
else:
binary_operator_predicate = pipeline_pb2.PropertyPredicate(
binary_logical_operator=pipeline_pb2.PropertyPredicate.BinaryLogicalOperator(
op=pipeline_pb2.PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND,
lhs=predicates[0],
rhs=predicates[1],
)
)

for i in range(2, len(predicates)):
binary_operator_predicate = pipeline_pb2.PropertyPredicate(
binary_logical_operator=pipeline_pb2.PropertyPredicate.BinaryLogicalOperator(
op=pipeline_pb2.PropertyPredicate.BinaryLogicalOperator.LogicalOp.AND,
lhs=binary_operator_predicate,
rhs=predicates[i],
)
)
_compile_channel_pb_contexts(
[(
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME,
_get_tfx_value(''),
)],
result_input_channel,
binary_operator_predicate,
)

print(result_input_channel)

if pipeline_ctx.pipeline.platform_config:
project_config = (
pipeline_ctx.pipeline.platform_config.project_platform_config
Expand Down
6 changes: 5 additions & 1 deletion tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ def __init__(
producer_component_id: str,
output_key: str,
pipeline_run_id: str = '',
tags: Sequence[str] = (),
):
"""Initialization of ExternalPipelineChannel.
Expand All @@ -733,13 +734,15 @@ def __init__(
output_key: The output key when producer component produces the artifacts
in this Channel.
pipeline_run_id: (Optional) Pipeline run id the artifacts belong to.
tags: (Optional) A list of tags the artifacts must carry. Tags are combined
with AND: only artifacts that have every listed tag are matched.
"""
super().__init__(type=artifact_type)
self.owner = owner
self.pipeline_name = pipeline_name
self.producer_component_id = producer_component_id
self.output_key = output_key
self.pipeline_run_id = pipeline_run_id
self.tags = tags

def get_data_dependent_node_ids(self) -> Set[str]:
return set()
Expand All @@ -751,7 +754,8 @@ def __repr__(self) -> str:
f'pipeline_name={self.pipeline_name}, '
f'producer_component_id={self.producer_component_id}, '
f'output_key={self.output_key}, '
f'pipeline_run_id={self.pipeline_run_id})'
f'pipeline_run_id={self.pipeline_run_id}), '
f'tags={self.tags}'
)


Expand Down
5 changes: 5 additions & 0 deletions tfx/types/channel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def external_pipeline_artifact_query(
producer_component_id: str,
output_key: str,
pipeline_run_id: str = '',
tags: Sequence[str] = (),
) -> channel.ExternalPipelineChannel:
"""Helper function to construct a query to get artifacts from an external pipeline.
Expand All @@ -160,6 +161,9 @@ def external_pipeline_artifact_query(
output_key: The output key when producer component produces the artifacts in
this Channel.
pipeline_run_id: (Optional) Pipeline run id the artifacts belong to.
tags: (Optional) A list of tags the artifacts belong to. It is an AND
relationship between tags. For example, if tags=['tag1', 'tag2'], then
only artifacts with both 'tag1' and 'tag2' will be returned.
Returns:
channel.ExternalPipelineChannel instance.
Expand All @@ -177,6 +181,7 @@ def external_pipeline_artifact_query(
producer_component_id=producer_component_id,
output_key=output_key,
pipeline_run_id=pipeline_run_id,
tags=tags,
)


Expand Down

0 comments on commit dc7c8dc

Please sign in to comment.