diff --git a/.test-infra/jenkins/dependency_check/dependency_check_report_generator.py b/.test-infra/jenkins/dependency_check/dependency_check_report_generator.py index 833800db1985..19e602cf4be0 100644 --- a/.test-infra/jenkins/dependency_check/dependency_check_report_generator.py +++ b/.test-infra/jenkins/dependency_check/dependency_check_report_generator.py @@ -231,7 +231,7 @@ def find_release_time_from_python_compatibility_checking_service(dep_name, versi def request_session_with_retries(): """ - Create a http session with retries + Create an http session with retries """ session = requests.Session() retries = Retry(total=3) diff --git a/build.gradle b/build.gradle index 6456a5ebb46b..1b633f814af0 100644 --- a/build.gradle +++ b/build.gradle @@ -294,3 +294,41 @@ release { pushToRemote = '' } } + +// Reports linkage errors across multiple Apache Beam artifact ids. +// +// To use (from the root of project): +// ./gradlew -Ppublishing -PjavaLinkageArtifactIds=artifactId1,artifactId2,... :checkJavaLinkage +// +// For example: +// ./gradlew -Ppublishing -PjavaLinkageArtifactIds=beam-sdks-java-core,beam-sdks-java-io-jdbc :checkJavaLinkage +// +// Note that this task publishes artifacts into your local Maven repository. +if (project.hasProperty('javaLinkageArtifactIds')) { + if (!project.hasProperty('publishing')) { + throw new GradleException('You can only check linkage of Java artifacts if you specify -Ppublishing on the command line as well.') + } + + configurations { linkageCheckerJava } + dependencies { + linkageCheckerJava "com.google.cloud.tools:dependencies:1.0.1" + } + + // We need to evaluate all the projects first so that we can find depend on all the + // publishMavenJavaPublicationToMavenLocal tasks below. + for (p in rootProject.subprojects) { + if (!p.path.equals(project.path)) { + evaluationDependsOn(p.path) + } + } + + project.task('checkJavaLinkage', type: JavaExec) { + dependsOn project.getTasksByName('publishMavenJavaPublicationToMavenLocal', true /* recursively */) + classpath = project.configurations.linkageCheckerJava + main = 'com.google.cloud.tools.opensource.classpath.LinkageCheckerMain' + args '-a', project.javaLinkageArtifactIds.split(',').collect({"${project.ext.mavenGroupId}:${it}:${project.version}"}).join(',') + doLast { + println "NOTE: This task published artifacts into your local Maven repository. You may want to remove them manually." 
+ } + } +} diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 74f60c54e715..ebfe4b22ef64 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -421,6 +421,7 @@ class BeamModulePlugin implements Plugin { aws_java_sdk2_dynamodb : "software.amazon.awssdk:dynamodb:$aws_java_sdk2_version", aws_java_sdk2_sdk_core : "software.amazon.awssdk:sdk-core:$aws_java_sdk2_version", aws_java_sdk2_sns : "software.amazon.awssdk:sns:$aws_java_sdk2_version", + aws_java_sdk2_sqs : "software.amazon.awssdk:sqs:$aws_java_sdk2_version", bigdataoss_gcsio : "com.google.cloud.bigdataoss:gcsio:$google_cloud_bigdataoss_version", bigdataoss_util : "com.google.cloud.bigdataoss:util:$google_cloud_bigdataoss_version", cassandra_driver_core : "com.datastax.cassandra:cassandra-driver-core:$cassandra_driver_version", @@ -1924,7 +1925,6 @@ class BeamModulePlugin implements Plugin { "--input=/etc/profile", "--output=/tmp/py-wordcount-direct", "--runner=${runner}", - "--experiments=worker_threads=100", "--parallelism=2", "--shutdown_sources_on_final_watermark", "--sdk_worker_parallelism=1", diff --git a/model/fn-execution/src/main/proto/beam_fn_api.proto b/model/fn-execution/src/main/proto/beam_fn_api.proto index 1c9c13b59ca2..c868babf14f6 100644 --- a/model/fn-execution/src/main/proto/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/beam_fn_api.proto @@ -88,7 +88,7 @@ service BeamFnControl { // matching instruction id. // Stable message InstructionRequest { - // (Required) An unique identifier provided by the runner which represents + // (Required) A unique identifier provided by the runner which represents // this requests execution. The InstructionResponse MUST have the matching id. string instruction_id = 1; @@ -564,7 +564,7 @@ service BeamFnData { */ message StateRequest { - // (Required) An unique identifier provided by the SDK which represents this + // (Required) A unique identifier provided by the SDK which represents this // requests execution. The StateResponse corresponding with this request // will have the matching id. string id = 1; diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/CounterCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/CounterCell.java index 9ca5cdb28b31..7b6177e9c1d6 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/CounterCell.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/CounterCell.java @@ -51,6 +51,12 @@ public CounterCell(MetricName name) { this.name = name; } + @Override + public void reset() { + dirty.afterModification(); + value.set(0L); + } + /** * Increment the counter by the given amount. 
* diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DistributionCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DistributionCell.java index ca85de2393a8..31430ecb88fb 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DistributionCell.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/DistributionCell.java @@ -52,6 +52,12 @@ public DistributionCell(MetricName name) { this.name = name; } + @Override + public void reset() { + dirty.afterModification(); + value.set(DistributionData.EMPTY); + } + /** Increment the distribution by the given amount. */ @Override public void update(long n) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateSampler.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateSampler.java index 7bc8f68b4a3b..5582483c3915 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateSampler.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateSampler.java @@ -69,6 +69,11 @@ public static ExecutionStateSampler newForTest(MillisProvider clock) { @Nullable private Future executionSamplerFuture = null; + /** Reset the state sampler. */ + public void reset() { + lastSampleTimeMillis = 0; + } + /** * Called to start the ExecutionStateSampler. Until the returned {@link Closeable} is closed, the * state sampler will periodically sample the current state of all the threads it has been asked diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateTracker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateTracker.java index e14d59742d5c..58e3f7d1060e 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateTracker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/ExecutionStateTracker.java @@ -137,6 +137,17 @@ public ExecutionStateTracker(ExecutionStateSampler sampler) { this.sampler = sampler; } + /** Reset the execution status. */ + public void reset() { + trackedThread = null; + currentState = null; + numTransitions = 0; + millisSinceLastTransition = 0; + transitionsAtLastSample = 0; + nextLullReportMs = LULL_REPORT_MS; + CURRENT_TRACKERS.entrySet().removeIf(entry -> entry.getValue() == this); + } + @VisibleForTesting public static ExecutionStateTracker newForTest() { return new ExecutionStateTracker(ExecutionStateSampler.newForTest()); @@ -261,6 +272,16 @@ public long getMillisSinceLastTransition() { return millisSinceLastTransition; } + /** Return the number of transitions since the last sample. */ + public long getTransitionsAtLastSample() { + return transitionsAtLastSample; + } + + /** Return the time of the next lull report. */ + public long getNextLullReportMs() { + return nextLullReportMs; + } + protected void takeSample(long millisSinceLastSample) { // These variables are read by Sampler thread, and written by Execution and Progress Reporting // threads. 
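The hunks above add a reset() to the metric cells and to the execution-state sampler/tracker so that a runner can reuse the same objects across bundles: each cell flags itself as modified and falls back to its identity value (0 for counters, DistributionData.EMPTY for distributions). A minimal sketch of that reset-to-identity pattern, using hypothetical names rather than Beam's actual CounterCell and DirtyState types:

    import java.util.concurrent.atomic.AtomicLong;

    // Illustrative only: a counter-like cell following the same pattern the diff
    // adds to CounterCell, DistributionCell and GaugeCell.
    class ResettableCounter {
      private final AtomicLong value = new AtomicLong();
      private volatile boolean dirty; // stands in for Beam's DirtyState

      void inc(long n) {
        value.addAndGet(n);
        dirty = true; // afterModification() in the real cells
      }

      long getCumulative() {
        return value.get();
      }

      void reset() {
        dirty = true;  // a reset is itself a change that must be reported
        value.set(0L); // back to the identity element
      }
    }

MetricsContainerImpl.reset() and MetricsContainerStepMap.reset() further down in this diff simply fan this call out to every registered cell and container.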
diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GaugeCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GaugeCell.java index f0d9d726469a..eb70469a495a 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GaugeCell.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/GaugeCell.java @@ -50,6 +50,12 @@ public GaugeCell(MetricName name) { this.name = name; } + @Override + public void reset() { + dirty.afterModification(); + gaugeValue.set(GaugeData.empty()); + } + /** Set the gauge to the given value. */ @Override public void set(long value) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricCell.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricCell.java index e0673bc02722..0700bac7fe63 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricCell.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricCell.java @@ -36,4 +36,7 @@ public interface MetricCell extends Serializable { /** Return the cumulative value of this metric. */ DataT getCumulative(); + + /** Reset this metric. */ + void reset(); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java index ac471caf63ef..fd7ee72a4072 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerImpl.java @@ -79,6 +79,19 @@ public MetricsContainerImpl(@Nullable String stepName) { this.stepName = stepName; } + /** Reset the metrics. */ + public void reset() { + reset(counters); + reset(distributions); + reset(gauges); + } + + private void reset(MetricsMap> cells) { + for (MetricCell cell : cells.values()) { + cell.reset(); + } + } + /** * Return a {@code CounterCell} named {@code metricName}. If it doesn't exist, create a {@code * Metric} with the specified name. diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java index 5a8e89c898e4..db18f787bb85 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMap.java @@ -84,6 +84,14 @@ public void update(String step, MetricsContainerImpl container) { getContainer(step).update(container); } + /** Reset the metric containers. 
*/ + public void reset() { + for (MetricsContainerImpl metricsContainer : metricsContainers.values()) { + metricsContainer.reset(); + } + unboundContainer.reset(); + } + @Override public boolean equals(Object object) { if (object instanceof MetricsContainerStepMap) { diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleExecutionState.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleExecutionState.java index f0b6c46963db..07eeeba25f29 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleExecutionState.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleExecutionState.java @@ -69,6 +69,11 @@ public SimpleExecutionState( } } + /** Reset the totalMillis spent in the state. */ + public void reset() { + this.totalMillis = 0; + } + public String getUrn() { return this.urn; } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleStateRegistry.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleStateRegistry.java index 19f279170565..17a3fcf8176c 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleStateRegistry.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/metrics/SimpleStateRegistry.java @@ -32,6 +32,13 @@ public void register(SimpleExecutionState state) { this.executionStates.add(state); } + /** Reset the registered SimpleExecutionStates. */ + public void reset() { + for (SimpleExecutionState state : executionStates) { + state.reset(); + } + } + /** @return Execution Time MonitoringInfos based on the tracked start or finish function. */ public List getExecutionTimeMonitoringInfos() { List monitoringInfos = new ArrayList(); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/CounterCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/CounterCellTest.java index fe4d9863e443..1c9a34e651fc 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/CounterCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/CounterCellTest.java @@ -80,4 +80,15 @@ public void testNotEquals() { Assert.assertNotEquals(counterCell, differentName); Assert.assertNotEquals(counterCell.hashCode(), differentName.hashCode()); } + + @Test + public void testReset() { + CounterCell counterCell = new CounterCell(MetricName.named("namespace", "name")); + counterCell.inc(1); + assertThat(counterCell.getCumulative(), equalTo(1L)); + + counterCell.reset(); + assertThat(counterCell.getCumulative(), equalTo(0L)); + assertThat(counterCell.getDirty(), equalTo(new DirtyState())); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/DistributionCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/DistributionCellTest.java index 8bfa614ef26f..4e0b6b8ca268 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/DistributionCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/DistributionCellTest.java @@ -81,4 +81,15 @@ public void testNotEquals() { Assert.assertNotEquals(distributionCell, differentName); Assert.assertNotEquals(distributionCell.hashCode(), differentName.hashCode()); } + + @Test + public void testReset() { + DistributionCell distributionCell = new DistributionCell(MetricName.named("namespace", "name")); + 
distributionCell.update(2); + assertThat(distributionCell.getCumulative(), equalTo(DistributionData.create(2, 1, 2, 2))); + + distributionCell.reset(); + assertThat(distributionCell.getCumulative(), equalTo(DistributionData.EMPTY)); + assertThat(distributionCell.getDirty(), equalTo(new DirtyState())); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateSamplerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateSamplerTest.java index 073cfd1ccfab..6ce515f07c92 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateSamplerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateSamplerTest.java @@ -123,6 +123,13 @@ public void testLullDetectionOccurs() throws Exception { assertThat(step1act1.lullReported, equalTo(true)); } + @Test + public void testReset() throws Exception { + sampler.lastSampleTimeMillis = 100L; + sampler.reset(); + assertThat(sampler.lastSampleTimeMillis, equalTo(0L)); + } + private ExecutionStateTracker createTracker() { return new ExecutionStateTracker(sampler); } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateTrackerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateTrackerTest.java new file mode 100644 index 000000000000..5bc4f04a0ade --- /dev/null +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/ExecutionStateTrackerTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.core.metrics; + +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; + +import java.io.Closeable; +import java.util.concurrent.TimeUnit; +import org.apache.beam.runners.core.metrics.ExecutionStateTracker.ExecutionState; +import org.joda.time.DateTimeUtils.MillisProvider; +import org.junit.Before; +import org.junit.Test; + +/** Tests for {@link ExecutionStateTracker}. 
*/ +public class ExecutionStateTrackerTest { + + private MillisProvider clock; + private ExecutionStateSampler sampler; + + @Before + public void setUp() { + clock = mock(MillisProvider.class); + sampler = ExecutionStateSampler.newForTest(clock); + } + + private static class TestExecutionState extends ExecutionState { + + private long totalMillis = 0; + + public TestExecutionState(String stateName) { + super(stateName); + } + + @Override + public void takeSample(long millisSinceLastSample) { + totalMillis += millisSinceLastSample; + } + + @Override + public void reportLull(Thread trackedThread, long millis) {} + } + + private final TestExecutionState testExecutionState = new TestExecutionState("activity"); + + @Test + public void testReset() throws Exception { + ExecutionStateTracker tracker = createTracker(); + try (Closeable c1 = tracker.activate(new Thread())) { + try (Closeable c2 = tracker.enterState(testExecutionState)) { + sampler.doSampling(400); + assertThat(testExecutionState.totalMillis, equalTo(400L)); + } + } + + tracker.reset(); + assertThat(tracker.getTrackedThread(), equalTo(null)); + assertThat(tracker.getCurrentState(), equalTo(null)); + assertThat(tracker.getNumTransitions(), equalTo(0L)); + assertThat(tracker.getMillisSinceLastTransition(), equalTo(0L)); + assertThat(tracker.getTransitionsAtLastSample(), equalTo(0L)); + assertThat(tracker.getNextLullReportMs(), equalTo(TimeUnit.MINUTES.toMillis(5))); + } + + private ExecutionStateTracker createTracker() { + return new ExecutionStateTracker(sampler); + } +} diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/GaugeCellTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/GaugeCellTest.java index 19d711c17188..174042dddaff 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/GaugeCellTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/GaugeCellTest.java @@ -74,4 +74,15 @@ public void testNotEquals() { Assert.assertNotEquals(gaugeCell, differentName); Assert.assertNotEquals(gaugeCell.hashCode(), differentName.hashCode()); } + + @Test + public void testReset() { + GaugeCell gaugeCell = new GaugeCell(MetricName.named("namespace", "name")); + gaugeCell.set(2); + assertThat(gaugeCell.getCumulative().value(), equalTo(GaugeData.create(2).value())); + + gaugeCell.reset(); + assertThat(gaugeCell.getCumulative(), equalTo(GaugeData.empty())); + assertThat(gaugeCell.getDirty(), equalTo(new DirtyState())); + } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java index 9be564ac9bf3..dcebb31852d5 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/MetricsContainerStepMapTest.java @@ -328,6 +328,50 @@ public void testNotEquals() { metricsContainerStepMap.hashCode(), differentUnboundedContainer.hashCode()); } + @Test + public void testReset() { + MetricsContainerStepMap attemptedMetrics = new MetricsContainerStepMap(); + attemptedMetrics.update(STEP1, metricsContainer); + attemptedMetrics.update(STEP2, metricsContainer); + attemptedMetrics.update(STEP2, metricsContainer); + + MetricResults metricResults = asAttemptedOnlyMetricResults(attemptedMetrics); + MetricQueryResults allres = metricResults.allMetrics(); + 
assertCounter(COUNTER_NAME, allres, STEP1, VALUE, false); + assertDistribution( + DISTRIBUTION_NAME, + allres, + STEP1, + DistributionResult.create(VALUE * 3, 2, VALUE, VALUE * 2), + false); + assertGauge(GAUGE_NAME, allres, STEP1, GaugeResult.create(VALUE, Instant.now()), false); + + assertCounter(COUNTER_NAME, allres, STEP2, VALUE * 2, false); + assertDistribution( + DISTRIBUTION_NAME, + allres, + STEP2, + DistributionResult.create(VALUE * 6, 4, VALUE, VALUE * 2), + false); + assertGauge(GAUGE_NAME, allres, STEP2, GaugeResult.create(VALUE, Instant.now()), false); + + attemptedMetrics.reset(); + metricResults = asAttemptedOnlyMetricResults(attemptedMetrics); + allres = metricResults.allMetrics(); + + // Check that the metrics container for STEP1 is reset + assertCounter(COUNTER_NAME, allres, STEP1, 0L, false); + assertDistribution( + DISTRIBUTION_NAME, allres, STEP1, DistributionResult.IDENTITY_ELEMENT, false); + assertGauge(GAUGE_NAME, allres, STEP1, GaugeResult.empty(), false); + + // Check that the metrics container for STEP2 is reset + assertCounter(COUNTER_NAME, allres, STEP2, 0L, false); + assertDistribution( + DISTRIBUTION_NAME, allres, STEP2, DistributionResult.IDENTITY_ELEMENT, false); + assertGauge(GAUGE_NAME, allres, STEP2, GaugeResult.empty(), false); + } + private void assertIterableSize(Iterable iterable, int size) { assertThat(iterable, IsIterableWithSize.iterableWithSize(size)); } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleExecutionStateTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleExecutionStateTest.java index 2faf7365ca86..f774930b958a 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleExecutionStateTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleExecutionStateTest.java @@ -54,6 +54,9 @@ public void testTakeSampleIncrementsTotal() { assertEquals(10, testObject.getTotalMillis()); testObject.takeSample(5); assertEquals(15, testObject.getTotalMillis()); + + testObject.reset(); + assertEquals(0, testObject.getTotalMillis()); } @Test diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleStateRegistryTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleStateRegistryTest.java index 205fde0def6f..2042e3f434e9 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleStateRegistryTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/metrics/SimpleStateRegistryTest.java @@ -18,6 +18,9 @@ package org.apache.beam.runners.core.metrics; import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import java.util.ArrayList; import java.util.HashMap; @@ -81,4 +84,18 @@ public void testExecutionTimeUrnsBuildMonitoringInfos() throws Exception { assertThat(testOutput, Matchers.hasItem(matcher)); } } + + @Test + public void testResetRegistry() { + SimpleExecutionState state1 = mock(SimpleExecutionState.class); + SimpleExecutionState state2 = mock(SimpleExecutionState.class); + + SimpleStateRegistry testObject = new SimpleStateRegistry(); + testObject.register(state1); + testObject.register(state2); + + testObject.reset(); + verify(state1, times(1)).reset(); + verify(state2, times(1)).reset(); + } } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaCounterCell.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaCounterCell.java index 44153740665b..09f42c02add1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaCounterCell.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaCounterCell.java @@ -37,6 +37,11 @@ public DeltaCounterCell(MetricName name) { this.name = name; } + @Override + public void reset() { + value.set(0L); + } + @Override public void inc(long n) { value.addAndGet(n); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaDistributionCell.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaDistributionCell.java index f930aa96e501..6e75d4b9eb37 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaDistributionCell.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DeltaDistributionCell.java @@ -52,6 +52,11 @@ void update(DistributionData data) { } while (!value.compareAndSet(original, original.combine(data))); } + @Override + public void reset() { + value.set(DistributionData.EMPTY); + } + @Override public void update(long sum, long count, long min, long max) { update(DistributionData.create(sum, count, min, max)); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/CounterUpdateAggregators.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/CounterUpdateAggregators.java index 32f99e73a259..d1a1df83c3de 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/CounterUpdateAggregators.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/counters/CounterUpdateAggregators.java @@ -18,11 +18,11 @@ package org.apache.beam.runners.dataflow.worker.counters; import com.google.api.services.dataflow.model.CounterUpdate; -import com.google.common.collect.ImmutableMap; import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.beam.runners.dataflow.worker.MetricsToCounterUpdateConverter.Kind; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; public class CounterUpdateAggregators { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiver.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiver.java index 7f31c4875c56..286ea8a7161e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiver.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiver.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker.fn.control; -import com.google.common.annotations.VisibleForTesting; import java.io.IOException; import java.util.Collections; import java.util.HashMap; @@ -41,6 +40,7 @@ import 
org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/data/BeamFnDataGrpcService.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/data/BeamFnDataGrpcService.java index 7c69bf1e008b..dcde1043849c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/data/BeamFnDataGrpcService.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/fn/data/BeamFnDataGrpcService.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker.fn.data; -import java.util.Collections; -import java.util.List; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; @@ -41,10 +39,8 @@ import org.apache.beam.sdk.fn.data.InboundDataClient; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; -import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -63,22 +59,21 @@ public class BeamFnDataGrpcService extends BeamFnDataGrpc.BeamFnDataImplBase implements BeamFnService { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcService.class); - private static final String BEAM_FN_API_DATA_BUFFER_LIMIT = "beam_fn_api_data_buffer_limit="; private final Endpoints.ApiServiceDescriptor apiServiceDescriptor; private final ConcurrentMap> connectedClients; + private final PipelineOptions options; private final Function, StreamObserver> streamObserverFactory; private final HeaderAccessor headerAccessor; - private final Optional outboundBufferLimit; public BeamFnDataGrpcService( PipelineOptions options, Endpoints.ApiServiceDescriptor descriptor, Function, StreamObserver> streamObserverFactory, HeaderAccessor headerAccessor) { - this.outboundBufferLimit = getOutboundBufferLimit(options); + this.options = options; this.streamObserverFactory = streamObserverFactory; this.headerAccessor = headerAccessor; this.connectedClients = new ConcurrentHashMap<>(); @@ -86,17 +81,6 @@ public BeamFnDataGrpcService( LOG.info("Launched Beam Fn Data service {}", this.apiServiceDescriptor); } - private static final Optional getOutboundBufferLimit(PipelineOptions options) { - List experiments = options.as(ExperimentalOptions.class).getExperiments(); - for (String experiment : experiments == null ? 
Collections.emptyList() : experiments) { - if (experiment.startsWith(BEAM_FN_API_DATA_BUFFER_LIMIT)) { - return Optional.of( - Integer.parseInt(experiment.substring(BEAM_FN_API_DATA_BUFFER_LIMIT.length()))); - } - } - return Optional.absent(); - } - @Override public Endpoints.ApiServiceDescriptor getApiServiceDescriptor() { return apiServiceDescriptor; @@ -218,16 +202,11 @@ public InboundDataClient receive( public CloseableFnDataReceiver send(LogicalEndpoint outputLocation, Coder coder) { LOG.debug("Creating output consumer for {}", outputLocation); try { - if (outboundBufferLimit.isPresent()) { - return BeamFnDataBufferingOutboundObserver.forLocationWithBufferLimit( - outboundBufferLimit.get(), - outputLocation, - coder, - getClientFuture(clientId).get().getOutboundObserver()); - } else { - return BeamFnDataBufferingOutboundObserver.forLocation( - outputLocation, coder, getClientFuture(clientId).get().getOutboundObserver()); - } + return BeamFnDataBufferingOutboundObserver.forLocation( + options, + outputLocation, + coder, + getClientFuture(clientId).get().getOutboundObserver()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiverTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiverTest.java index 6ea43801b0ea..2067a3a94d2c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiverTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/fn/control/TimerReceiverTest.java @@ -101,7 +101,10 @@ public void setUp() throws Exception { InProcessServerFactory serverFactory = InProcessServerFactory.create(); dataServer = GrpcFnServer.allocatePortAndCreateFor( - GrpcDataService.create(serverExecutor, OutboundObserverFactory.serverDirect()), + GrpcDataService.create( + PipelineOptionsFactory.create(), + serverExecutor, + OutboundObserverFactory.serverDirect()), serverFactory); loggingServer = GrpcFnServer.allocatePortAndCreateFor( diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactRetrievalService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactRetrievalService.java index 93ae6577353a..72af9e81c7b0 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactRetrievalService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactRetrievalService.java @@ -95,15 +95,19 @@ public void getManifest( LOG.info("GetManifest for {}", token); try { - ArtifactApi.ProxyManifest proxyManifest = getManifestProxy(token); + final ArtifactApi.Manifest manifest; + if (AbstractArtifactStagingService.NO_ARTIFACTS_STAGED_TOKEN.equals(token)) { + manifest = ArtifactApi.Manifest.newBuilder().build(); + } else { + ArtifactApi.ProxyManifest proxyManifest = getManifestProxy(token); + LOG.info( + "GetManifest for {} -> {} artifacts", + token, + proxyManifest.getManifest().getArtifactCount()); + manifest = proxyManifest.getManifest(); + } ArtifactApi.GetManifestResponse response = - ArtifactApi.GetManifestResponse.newBuilder() - .setManifest(proxyManifest.getManifest()) - .build(); - LOG.info( - "GetManifest for 
{} -> {} artifacts", - token, - proxyManifest.getManifest().getArtifactCount()); + ArtifactApi.GetManifestResponse.newBuilder().setManifest(manifest).build(); responseObserver.onNext(response); responseObserver.onCompleted(); } catch (Exception e) { diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java index 25f09a3e234b..86e79a5b9eab 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/artifact/AbstractArtifactStagingService.java @@ -50,6 +50,8 @@ public abstract class AbstractArtifactStagingService extends ArtifactStagingServiceImplBase implements FnService { + public static final String NO_ARTIFACTS_STAGED_TOKEN = "__no_artifacts_staged__"; + private static final Logger LOG = LoggerFactory.getLogger(AbstractArtifactStagingService.class); private static final Charset CHARSET = StandardCharsets.UTF_8; @@ -77,25 +79,29 @@ public StreamObserver putArtifact( public void commitManifest( CommitManifestRequest request, StreamObserver responseObserver) { try { - String stagingSessionToken = request.getStagingSessionToken(); - ProxyManifest.Builder proxyManifestBuilder = - ProxyManifest.newBuilder().setManifest(request.getManifest()); - for (ArtifactMetadata artifactMetadata : request.getManifest().getArtifactList()) { - proxyManifestBuilder.addLocation( - Location.newBuilder() - .setName(artifactMetadata.getName()) - .setUri(getArtifactUri(stagingSessionToken, encodedFileName(artifactMetadata))) - .build()); - } - try (WritableByteChannel manifestWritableByteChannel = openManifest(stagingSessionToken)) { - manifestWritableByteChannel.write( - CHARSET.encode(JsonFormat.printer().print(proxyManifestBuilder.build()))); + final String retrievalToken; + if (request.getManifest().getArtifactCount() > 0) { + String stagingSessionToken = request.getStagingSessionToken(); + ProxyManifest.Builder proxyManifestBuilder = + ProxyManifest.newBuilder().setManifest(request.getManifest()); + for (ArtifactMetadata artifactMetadata : request.getManifest().getArtifactList()) { + proxyManifestBuilder.addLocation( + Location.newBuilder() + .setName(artifactMetadata.getName()) + .setUri(getArtifactUri(stagingSessionToken, encodedFileName(artifactMetadata))) + .build()); + } + try (WritableByteChannel manifestWritableByteChannel = openManifest(stagingSessionToken)) { + manifestWritableByteChannel.write( + CHARSET.encode(JsonFormat.printer().print(proxyManifestBuilder.build()))); + } + retrievalToken = getRetrievalToken(stagingSessionToken); + // TODO: Validate integrity of staged files. + } else { + retrievalToken = NO_ARTIFACTS_STAGED_TOKEN; } - // TODO: Validate integrity of staged files. responseObserver.onNext( - CommitManifestResponse.newBuilder() - .setRetrievalToken(getRetrievalToken(stagingSessionToken)) - .build()); + CommitManifestResponse.newBuilder().setRetrievalToken(retrievalToken).build()); responseObserver.onCompleted(); } catch (Exception e) { // TODO: Cleanup all the artifacts. 
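The artifact-service hunks above introduce a sentinel retrieval token, __no_artifacts_staged__: commitManifest skips writing a proxy manifest when the request carries zero artifacts and returns the sentinel instead, and getManifest answers that token with an empty manifest rather than reading anything from storage. A rough sketch of the round trip under those assumptions (plain methods standing in for the actual gRPC service classes):

    import java.util.Collections;
    import java.util.List;

    // Simplified stand-in for the staging/retrieval services; the real code builds
    // ArtifactApi protos and writes proxy manifests through the staging filesystem.
    class ArtifactTokenSketch {
      static final String NO_ARTIFACTS_STAGED_TOKEN = "__no_artifacts_staged__";

      // commitManifest: only stage a manifest when there is at least one artifact.
      static String commitManifest(List<String> artifactNames, String stagingSessionToken) {
        if (artifactNames.isEmpty()) {
          return NO_ARTIFACTS_STAGED_TOKEN; // nothing written, sentinel handed back
        }
        // ... write the proxy manifest and derive the real retrieval token ...
        return "retrieval-token-for-" + stagingSessionToken;
      }

      // getManifest: the sentinel short-circuits to an empty manifest.
      static List<String> getManifest(String retrievalToken) {
        if (NO_ARTIFACTS_STAGED_TOKEN.equals(retrievalToken)) {
          return Collections.emptyList();
        }
        // ... otherwise load and parse the staged proxy manifest ...
        return Collections.singletonList("artifact-from-" + retrievalToken);
      }
    }

The noArtifactsTest added to BeamFileSystemArtifactServicesTest later in this diff exercises exactly this path: committing an empty manifest yields the sentinel token, leaves nothing in the staging directory, and retrieves an empty manifest.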
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java index eaf1f3b4ef9e..a3f6f01fb929 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/DefaultJobBundleFactory.java @@ -437,7 +437,8 @@ private ServerInfo createServerInfo(JobInfo jobInfo, ServerFactory serverFactory StaticGrpcProvisionService.create(jobInfo.toProvisionInfo()), serverFactory); GrpcFnServer dataServer = GrpcFnServer.allocatePortAndCreateFor( - GrpcDataService.create(executor, OutboundObserverFactory.serverDirect()), + GrpcDataService.create( + portableOptions, executor, OutboundObserverFactory.serverDirect()), serverFactory); GrpcFnServer stateServer = GrpcFnServer.allocatePortAndCreateFor(GrpcStateService.create(), serverFactory); diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java index 69d378f3a892..5d58c5c8c368 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/data/GrpcDataService.java @@ -35,6 +35,7 @@ import org.apache.beam.sdk.fn.data.InboundDataClient; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.SettableFuture; import org.slf4j.Logger; @@ -53,8 +54,10 @@ public class GrpcDataService extends BeamFnDataGrpc.BeamFnDataImplBase private static final Logger LOG = LoggerFactory.getLogger(GrpcDataService.class); public static GrpcDataService create( - ExecutorService executor, OutboundObserverFactory outboundObserverFactory) { - return new GrpcDataService(executor, outboundObserverFactory); + PipelineOptions options, + ExecutorService executor, + OutboundObserverFactory outboundObserverFactory) { + return new GrpcDataService(options, executor, outboundObserverFactory); } private final SettableFuture connectedClient; @@ -67,13 +70,17 @@ public static GrpcDataService create( */ private final Queue additionalMultiplexers; + private final PipelineOptions options; private final ExecutorService executor; private final OutboundObserverFactory outboundObserverFactory; private GrpcDataService( - ExecutorService executor, OutboundObserverFactory outboundObserverFactory) { + PipelineOptions options, + ExecutorService executor, + OutboundObserverFactory outboundObserverFactory) { this.connectedClient = SettableFuture.create(); this.additionalMultiplexers = new LinkedBlockingQueue<>(); + this.options = options; this.executor = executor; this.outboundObserverFactory = outboundObserverFactory; } @@ -83,6 +90,7 @@ private GrpcDataService( public GrpcDataService() { this.connectedClient = null; this.additionalMultiplexers = null; + this.options = null; this.executor = null; this.outboundObserverFactory = null; } @@ -168,7 +176,10 @@ public CloseableFnDataReceiver send(LogicalEndpoint outputLocation, Coder 
outputLocation.getTransformId()); try { return BeamFnDataBufferingOutboundObserver.forLocation( - outputLocation, coder, connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver()); + options, + outputLocation, + coder, + connectedClient.get(3, TimeUnit.MINUTES).getOutboundObserver()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new RuntimeException(e); diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/EmbeddedSdkHarness.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/EmbeddedSdkHarness.java index b7dd034365e2..a095b93a50fd 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/EmbeddedSdkHarness.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/EmbeddedSdkHarness.java @@ -79,7 +79,8 @@ protected void before() throws Exception { GrpcLoggingService.forWriter(Slf4jLogWriter.getDefault()), serverFactory); dataServer = GrpcFnServer.allocatePortAndCreateFor( - GrpcDataService.create(executor, OutboundObserverFactory.serverDirect()), + GrpcDataService.create( + PipelineOptionsFactory.create(), executor, OutboundObserverFactory.serverDirect()), serverFactory); controlServer = GrpcFnServer.allocatePortAndCreateFor(clientPoolService, serverFactory); diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/artifact/BeamFileSystemArtifactServicesTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/artifact/BeamFileSystemArtifactServicesTest.java index 40807cac8a17..95855308b3ad 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/artifact/BeamFileSystemArtifactServicesTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/artifact/BeamFileSystemArtifactServicesTest.java @@ -205,6 +205,24 @@ void checkCleanup(String stagingSessionToken, String stagingSession) throws Exce Files.exists(Paths.get(stagingDir.toAbsolutePath().toString(), stagingSession))); } + @Test + public void noArtifactsTest() throws Exception { + String stagingSession = "123"; + String stagingSessionToken = + BeamFileSystemArtifactStagingService.generateStagingSessionToken( + stagingSession, stagingDir.toUri().getPath()); + String stagingToken = commitManifest(stagingSessionToken, Collections.emptyList()); + Assert.assertEquals(AbstractArtifactStagingService.NO_ARTIFACTS_STAGED_TOKEN, stagingToken); + Assert.assertFalse( + Files.exists(Paths.get(stagingDir.toAbsolutePath().toString(), stagingSession))); + + GetManifestResponse retrievedManifest = + retrievalBlockingStub.getManifest( + GetManifestRequest.newBuilder().setRetrievalToken(stagingToken).build()); + Assert.assertEquals( + "Manifest with 0 artifacts", 0, retrievedManifest.getManifest().getArtifactCount()); + } + @Test public void putArtifactsSingleSmallFileTest() throws Exception { String fileName = "file1"; diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java index 6582a3df2dcd..dc732037ea83 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java @@ -165,7 +165,10 @@ public void setup() throws Exception { 
InProcessServerFactory serverFactory = InProcessServerFactory.create(); dataServer = GrpcFnServer.allocatePortAndCreateFor( - GrpcDataService.create(serverExecutor, OutboundObserverFactory.serverDirect()), + GrpcDataService.create( + PipelineOptionsFactory.create(), + serverExecutor, + OutboundObserverFactory.serverDirect()), serverFactory); loggingServer = GrpcFnServer.allocatePortAndCreateFor( diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactoryTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactoryTest.java index 93a90bb12296..872720140c75 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactoryTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/SingleEnvironmentInstanceJobBundleFactoryTest.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.fn.IdGenerators; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.junit.After; @@ -79,7 +80,8 @@ public void setup() throws Exception { InProcessServerFactory serverFactory = InProcessServerFactory.create(); dataServer = GrpcFnServer.allocatePortAndCreateFor( - GrpcDataService.create(executor, OutboundObserverFactory.serverDirect()), + GrpcDataService.create( + PipelineOptionsFactory.create(), executor, OutboundObserverFactory.serverDirect()), serverFactory); stateServer = GrpcFnServer.allocatePortAndCreateFor(GrpcStateService.create(), serverFactory); diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java index a4458c0e5eb0..be08b5883df5 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/data/GrpcDataServiceTest.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ManagedChannel; @@ -68,7 +69,9 @@ public void testMessageReceivedBySingleClientWhenThereAreMultipleClients() throw final CountDownLatch waitForInboundElements = new CountDownLatch(1); GrpcDataService service = GrpcDataService.create( - Executors.newCachedThreadPool(), OutboundObserverFactory.serverDirect()); + PipelineOptionsFactory.create(), + Executors.newCachedThreadPool(), + OutboundObserverFactory.serverDirect()); try (GrpcFnServer server = GrpcFnServer.allocatePortAndCreateFor(service, InProcessServerFactory.create())) { Collection> clientFutures = new ArrayList<>(); @@ -116,7 +119,9 @@ public void testMultipleClientsSendMessagesAreDirectedToProperConsumers() throws final CountDownLatch waitForInboundElements = new CountDownLatch(1); GrpcDataService service = 
GrpcDataService.create( - Executors.newCachedThreadPool(), OutboundObserverFactory.serverDirect()); + PipelineOptionsFactory.create(), + Executors.newCachedThreadPool(), + OutboundObserverFactory.serverDirect()); try (GrpcFnServer server = GrpcFnServer.allocatePortAndCreateFor(service, InProcessServerFactory.create())) { Collection> clientFutures = new ArrayList<>(); diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaExecutionContext.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaExecutionContext.java index 6518da36200c..dd18c9f02aca 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaExecutionContext.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/SamzaExecutionContext.java @@ -105,7 +105,8 @@ public void start() { fnDataServer = GrpcFnServer.allocatePortAndCreateFor( - GrpcDataService.create(dataExecutor, OutboundObserverFactory.serverDirect()), + GrpcDataService.create( + options, dataExecutor, OutboundObserverFactory.serverDirect()), ServerFactory.createDefault()); LOG.info("Started data server on port {}", fnDataServer.getServer().getPort()); diff --git a/runners/spark/build.gradle b/runners/spark/build.gradle index 2a1062de8a0c..7c11b87c7d39 100644 --- a/runners/spark/build.gradle +++ b/runners/spark/build.gradle @@ -187,7 +187,6 @@ task validatesStructuredStreamingRunnerBatch(type: Test) { maxParallelForks 4 useJUnit { includeCategories 'org.apache.beam.sdk.testing.ValidatesRunner' - excludeCategories 'org.apache.beam.sdk.testing.UsesCustomWindowMerging' // Unbounded excludeCategories 'org.apache.beam.sdk.testing.UsesUnboundedPCollections' excludeCategories 'org.apache.beam.sdk.testing.UsesTestStream' diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java index a839725f2ce3..03ec68b450c7 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java @@ -61,6 +61,12 @@ public interface SparkCommonPipelineOptions void setFilesToStage(List value); + @Description("Enable/disable sending aggregator values to Spark's metric sinks") + @Default.Boolean(true) + Boolean getEnableSparkMetricSinks(); + + void setEnableSparkMetricSinks(Boolean enableSparkMetricSinks); + /** * Returns the default checkpoint directory of /tmp/${job.name}. For testing purposes only. * Production applications should use a reliable filesystem such as HDFS/S3/GS. diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java index 0b25aab11d26..100227c6c1f7 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java @@ -84,12 +84,6 @@ public interface SparkPipelineOptions extends SparkCommonPipelineOptions { @Experimental void setBundleSize(Long value); - @Description("Enable/disable sending aggregator values to Spark's metric sinks") - @Default.Boolean(true) - Boolean getEnableSparkMetricSinks(); - - void setEnableSparkMetricSinks(Boolean enableSparkMetricSinks); - @Description( "If the spark runner will be initialized with a provided Spark Context. 
" + "The Spark Context should be provided with SparkContextOptions.") diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java index 5573f78e7f0b..bc585d8a31e1 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineOptions.java @@ -19,7 +19,6 @@ import org.apache.beam.runners.spark.SparkCommonPipelineOptions; import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; /** @@ -27,11 +26,6 @@ * master address, and other user-related knobs. */ public interface SparkStructuredStreamingPipelineOptions extends SparkCommonPipelineOptions { - @Description("Enable/disable sending aggregator values to Spark's metric sinks") - @Default.Boolean(true) - Boolean getEnableSparkMetricSinks(); - - void setEnableSparkMetricSinks(Boolean enableSparkMetricSinks); /** Set to true to run the job in test mode. */ @Default.Boolean(false) diff --git a/sdks/go/pkg/beam/model/fnexecution_v1/beam_fn_api.pb.go b/sdks/go/pkg/beam/model/fnexecution_v1/beam_fn_api.pb.go index 236b518f12f3..8b7cac0850fd 100644 --- a/sdks/go/pkg/beam/model/fnexecution_v1/beam_fn_api.pb.go +++ b/sdks/go/pkg/beam/model/fnexecution_v1/beam_fn_api.pb.go @@ -137,7 +137,7 @@ func (m *RemoteGrpcPort) GetCoderId() string { // matching instruction id. // Stable type InstructionRequest struct { - // (Required) An unique identifier provided by the runner which represents + // (Required) A unique identifier provided by the runner which represents // this requests execution. The InstructionResponse MUST have the matching id. InstructionId string `protobuf:"bytes,1,opt,name=instruction_id,json=instructionId,proto3" json:"instruction_id,omitempty"` // (Required) A request that the SDK Harness needs to interpret. @@ -2556,7 +2556,7 @@ func (m *Elements_Data) GetData() []byte { } type StateRequest struct { - // (Required) An unique identifier provided by the SDK which represents this + // (Required) A unique identifier provided by the SDK which represents this // requests execution. The StateResponse corresponding with this request // will have the matching id. Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/JsonMatcher.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/JsonMatcher.java new file mode 100644 index 000000000000..1f66dad32535 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/testing/JsonMatcher.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.testing; + +import static org.hamcrest.Matchers.is; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.Map; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +/** + * Matcher to compare a string or byte[] representing a JSON Object, independent of field order. + * + *
+ * <pre>
+ *   assertThat("{\"name\": \"person\", \"height\": 80}",
+ *              jsonStringLike("{\"height\": 80, \"name\": \"person\"}"));
+ * </pre>
+ */ +public abstract class JsonMatcher extends TypeSafeMatcher { + private Matcher> mapMatcher; + private static final ObjectMapper MAPPER = new ObjectMapper(); + private Map actualMap; + + public JsonMatcher(Map expectedMap) { + this.mapMatcher = is(expectedMap); + } + + protected abstract Map parse(T json) throws IOException; + + public static Matcher jsonBytesLike(String json) throws IOException { + Map fields = + MAPPER.readValue(json, new TypeReference>() {}); + return jsonBytesLike(fields); + } + + public static Matcher jsonBytesLike(Map fields) throws IOException { + return new JsonMatcher(fields) { + @Override + protected Map parse(byte[] json) throws IOException { + return MAPPER.readValue(json, new TypeReference>() {}); + } + }; + } + + public static Matcher jsonStringLike(String json) throws IOException { + Map fields = + MAPPER.readValue(json, new TypeReference>() {}); + return jsonStringLike(fields); + } + + public static Matcher jsonStringLike(Map fields) throws IOException { + return new JsonMatcher(fields) { + @Override + protected Map parse(String json) throws IOException { + return MAPPER.readValue(json, new TypeReference>() {}); + } + }; + } + + @Override + protected boolean matchesSafely(T actual) { + try { + actualMap = parse(actual); + } catch (IOException e) { + return false; + } + return mapMatcher.matches(actualMap); + } + + @Override + public void describeTo(Description description) { + mapMatcher.describeTo(description); + } + + @Override + protected void describeMismatchSafely(T item, Description mismatchDescription) { + mapMatcher.describeMismatch(actualMap, mismatchDescription); + } +} diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/RowJsonTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/RowJsonTest.java index d9ab4106225c..85b9aae605ed 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/RowJsonTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/RowJsonTest.java @@ -17,10 +17,12 @@ */ package org.apache.beam.sdk.util; +import static org.apache.beam.sdk.testing.JsonMatcher.jsonStringLike; +import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasProperty; import static org.hamcrest.Matchers.stringContainsInOrder; -import static org.junit.Assert.assertEquals; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.module.SimpleModule; @@ -239,18 +241,22 @@ private static Object[] makeNullsTestCase() { public void testDeserialize() throws IOException { Row parsedRow = newObjectMapperFor(schema).readValue(serializedString, Row.class); - assertEquals(row, parsedRow); + assertThat(row, equalTo(parsedRow)); + } + + @Test + public void testSerialize() throws IOException { + String str = newObjectMapperFor(schema).writeValueAsString(row); + + assertThat(str, jsonStringLike(serializedString)); } - // This serves to validate RowJsonSerializer. We don't have tests to check that the output - // string matches exactly what we expect, just that the string we produced can be deserialized - // again into an equal row. 
@Test public void testRoundTrip() throws IOException { ObjectMapper objectMapper = newObjectMapperFor(schema); Row parsedRow = objectMapper.readValue(objectMapper.writeValueAsString(row), Row.class); - assertEquals(row, parsedRow); + assertThat(row, equalTo(parsedRow)); } } @@ -416,7 +422,7 @@ private void testSupportedConversion( Row parsedRow = jsonParser.readValue(jsonObjectWith(fieldName, jsonFieldValue), Row.class); - assertEquals(expectedRow, parsedRow); + assertThat(expectedRow, equalTo(parsedRow)); } @Test diff --git a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java index 8759f8e34404..5089e5088844 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/main/java/org/apache/beam/sdk/extensions/gcp/util/RetryHttpRequestInitializer.java @@ -42,7 +42,7 @@ /** * Implements a request initializer that adds retry handlers to all HttpRequests. * - *

Also can take a HttpResponseInterceptor to be applied to the responses. + *

Also can take an HttpResponseInterceptor to be applied to the responses. */ public class RetryHttpRequestInitializer implements HttpRequestInitializer { diff --git a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/UploadIdResponseInterceptorTest.java b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/UploadIdResponseInterceptorTest.java index d0bab1045f25..630a08491480 100644 --- a/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/UploadIdResponseInterceptorTest.java +++ b/sdks/java/extensions/google-cloud-platform-core/src/test/java/org/apache/beam/sdk/extensions/gcp/util/UploadIdResponseInterceptorTest.java @@ -39,7 +39,7 @@ public class UploadIdResponseInterceptorTest { @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(UploadIdResponseInterceptor.class); /** - * Builds a HttpResponse with the given string response. + * Builds an HttpResponse with the given string response. * * @param header header value to provide or null if none. * @param uploadId upload id to provide in the url upload id param or null if none. diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle index 44bd77691e75..f0371c240a20 100644 --- a/sdks/java/extensions/sql/build.gradle +++ b/sdks/java/extensions/sql/build.gradle @@ -158,7 +158,6 @@ task integrationTest(type: Test) { include '**/*IT.class' exclude '**/KafkaCSVTableIT.java' - exclude '**/MongoDbReadWriteIT.java' maxParallelForks 4 classpath = project(":sdks:java:extensions:sql") .sourceSets diff --git a/sdks/java/extensions/sql/jdbc/build.gradle b/sdks/java/extensions/sql/jdbc/build.gradle index acddedfdeef5..02af0e6e6df2 100644 --- a/sdks/java/extensions/sql/jdbc/build.gradle +++ b/sdks/java/extensions/sql/jdbc/build.gradle @@ -36,7 +36,6 @@ dependencies { compile "jline:jline:2.14.6" compile "sqlline:sqlline:1.4.0" compile library.java.slf4j_jdk14 - compile library.java.guava testCompile project(path: ":sdks:java:io:google-cloud-platform", configuration: "testRuntime") testCompile library.java.junit testCompile library.java.hamcrest_core diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubIOJsonTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubIOJsonTable.java index 551e5a029e48..9e639e6e7566 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubIOJsonTable.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubIOJsonTable.java @@ -20,14 +20,13 @@ import static org.apache.beam.sdk.extensions.sql.meta.provider.pubsub.PubsubMessageToRow.DLQ_TAG; import static org.apache.beam.sdk.extensions.sql.meta.provider.pubsub.PubsubMessageToRow.MAIN_TAG; -import com.google.auto.value.AutoValue; import java.io.Serializable; -import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics; import org.apache.beam.sdk.extensions.sql.meta.BaseBeamTable; import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable; +import org.apache.beam.sdk.extensions.sql.meta.provider.pubsub.PubsubJsonTableProvider.PubsubIOTableConfiguration; import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; import 
org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; import org.apache.beam.sdk.options.PipelineOptions; @@ -65,7 +64,7 @@ * } * * - *

Then SQL statements to declare and query such topic will look like this: + *

Then SQL statements to declare and query such a topic will look like this: * *

  *  CREATE TABLE topic_table (
@@ -85,49 +84,54 @@
  * the value of that attribute. If it is not specified, then message publish time will be used as
  * event timestamp. 'attributes' map contains Pubsub message attributes map unchanged and can be
  * referenced in the queries as well.
+ *
+ * 

Alternatively, one can use a flattened schema to model the pubsub messages (meaning {@link + * PubsubIOTableConfiguration#getUseFlatSchema()} is set). + * + *

In this configuration, only {@code event_timestamp} is required to be specified in the table + * schema. All other fields are assumed to be part of the message payload. SQL statements to declare + * and query the same topic as above will look like this: + * + *

+ *  CREATE TABLE topic_table (
+ *        event_timestamp TIMESTAMP,
+ *        name VARCHAR,
+ *        age INTEGER
+ *     )
+ *     TYPE 'pubsub'
+ *     LOCATION 'projects/<GCP project id>/topics/<topic name>'
+ *     TBLPROPERTIES '{ \"timestampAttributeKey\" : <timestamp attribute> }';
+ *
+ *  SELECT event_timestamp, name FROM topic_table;
+ * 
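+ *
+ * For reference, the table above corresponds to a Beam schema along these lines (an
+ * illustrative sketch, not part of the DDL):
+ *
+ *   Schema.builder()
+ *       .addDateTimeField("event_timestamp")
+ *       .addStringField("name")
+ *       .addInt32Field("age")
+ *       .build();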
+ * + *

If 'timestampAttributeKey' is specified in TBLPROPERTIES, then 'event_timestamp' will be set to + * the value of that attribute. If it is not specified, then the message publish time will be used as + * the event timestamp. + * + *

In order to write to the same table, you can use an INSERT statement like this: + * + *

+ *   INSERT INTO topic_table VALUES (TIMESTAMP '2019-11-13 10:14:14', 'Brian', 30)
+ * 
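+ *
+ * An explicit column list that omits {@code event_timestamp} also works (illustrative sketch,
+ * mirroring the integration tests):
+ *
+ *   INSERT INTO topic_table (name, age) VALUES ('Brian', 30)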
+ * + *

Note that when writing, the value for {@code event_timestamp} is ignored by default, since the + * Pubsub-managed publish time will be used to populate {@code event_timestamp} on read. In order to + * ensure the {@code event_timestamp} you specified is used, you should specify + * 'timestampAttributeKey' in TBLPROPERTIES. */ -@AutoValue @Internal @Experimental -abstract class PubsubIOJsonTable extends BaseBeamTable implements Serializable { - - /** - * Optional attribute key of the Pubsub message from which to extract the event timestamp. - * - *

This attribute has to conform to the same requirements as in {@link - * PubsubIO.Read.Builder#withTimestampAttribute}. - * - *

Short version: it has to be either millis since epoch or string in RFC 3339 format. - * - *

If the attribute is specified then event timestamps will be extracted from the specified - * attribute. If it is not specified then message publish timestamp will be used. - */ - @Nullable - abstract String getTimestampAttribute(); - - /** - * Optional topic path which will be used as a dead letter queue. - * - *

Messages that cannot be processed will be sent to this topic. If it is not specified then - * exception will be thrown for errors during processing causing the pipeline to crash. - */ - @Nullable - abstract String getDeadLetterQueue(); - - private boolean useDlq() { - return getDeadLetterQueue() != null; - } +class PubsubIOJsonTable extends BaseBeamTable implements Serializable { - /** - * Pubsub topic name. - * - *

Topic is the only way to specify the Pubsub source. Explicitly specifying the subscription - * is not supported at the moment. Subscriptions are automatically created (but not deleted). - */ - abstract String getTopic(); + protected final PubsubIOTableConfiguration config; + + private PubsubIOJsonTable(PubsubIOTableConfiguration config) { + this.config = config; + } - static Builder builder() { - return new AutoValue_PubsubIOJsonTable.Builder(); + static PubsubIOJsonTable withConfiguration(PubsubIOTableConfiguration config) { + return new PubsubIOJsonTable(config); } @Override @@ -135,14 +139,10 @@ public PCollection.IsBounded isBounded() { return PCollection.IsBounded.UNBOUNDED; } - /** - * Table schema, describes Pubsub message schema. - * - *

Includes fields 'event_timestamp', 'attributes, and 'payload'. See {@link - * PubsubMessageToRow}. - */ @Override - public abstract Schema getSchema(); + public Schema getSchema() { + return config.getSchema(); + } @Override public PCollection buildIOReader(PBegin begin) { @@ -152,7 +152,7 @@ public PCollection buildIOReader(PBegin begin) { .apply("parseMessageToRow", createParserParDo()); rowsWithDlq.get(MAIN_TAG).setRowSchema(getSchema()); - if (useDlq()) { + if (config.useDlq()) { rowsWithDlq.get(DLQ_TAG).apply(writeMessagesToDlq()); } @@ -163,47 +163,52 @@ private ParDo.MultiOutput createParserParDo() { return ParDo.of( PubsubMessageToRow.builder() .messageSchema(getSchema()) - .useDlq(getDeadLetterQueue() != null) + .useDlq(config.useDlq()) + .useFlatSchema(config.getUseFlatSchema()) .build()) - .withOutputTags(MAIN_TAG, useDlq() ? TupleTagList.of(DLQ_TAG) : TupleTagList.empty()); + .withOutputTags( + MAIN_TAG, config.useDlq() ? TupleTagList.of(DLQ_TAG) : TupleTagList.empty()); } private PubsubIO.Read readMessagesWithAttributes() { - PubsubIO.Read read = PubsubIO.readMessagesWithAttributes().fromTopic(getTopic()); + PubsubIO.Read read = + PubsubIO.readMessagesWithAttributes().fromTopic(config.getTopic()); - return (getTimestampAttribute() == null) - ? read - : read.withTimestampAttribute(getTimestampAttribute()); + return config.useTimestampAttribute() + ? read.withTimestampAttribute(config.getTimestampAttribute()) + : read; } private PubsubIO.Write writeMessagesToDlq() { - PubsubIO.Write write = PubsubIO.writeMessages().to(getDeadLetterQueue()); + PubsubIO.Write write = PubsubIO.writeMessages().to(config.getDeadLetterQueue()); - return (getTimestampAttribute() == null) - ? write - : write.withTimestampAttribute(getTimestampAttribute()); + return config.useTimestampAttribute() + ? 
write.withTimestampAttribute(config.getTimestampAttribute()) + : write; } @Override public POutput buildIOWriter(PCollection input) { - throw new UnsupportedOperationException("Writing to a Pubsub topic is not supported"); + if (!config.getUseFlatSchema()) { + throw new UnsupportedOperationException( + "Writing to a Pubsub topic is only supported for flattened schemas"); + } + + return input + .apply(RowToPubsubMessage.fromTableConfig(config)) + .apply(createPubsubMessageWrite()); + } + + private PubsubIO.Write createPubsubMessageWrite() { + PubsubIO.Write write = PubsubIO.writeMessages().to(config.getTopic()); + if (config.useTimestampAttribute()) { + write = write.withTimestampAttribute(config.getTimestampAttribute()); + } + return write; } @Override public BeamTableStatistics getTableStatistics(PipelineOptions options) { return BeamTableStatistics.UNBOUNDED_UNKNOWN; } - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setSchema(Schema schema); - - abstract Builder setTimestampAttribute(String timestampAttribute); - - abstract Builder setDeadLetterQueue(String deadLetterQueue); - - abstract Builder setTopic(String topic); - - abstract PubsubIOJsonTable build(); - } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProvider.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProvider.java index dc49771bf35f..2d32f319eb7f 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProvider.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProvider.java @@ -26,6 +26,9 @@ import com.alibaba.fastjson.JSONObject; import com.google.auto.service.AutoService; +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable; @@ -51,40 +54,42 @@ public String getTableType() { @Override public BeamSqlTable buildBeamSqlTable(Table tableDefintion) { - validatePubsubMessageSchema(tableDefintion); - JSONObject tableProperties = tableDefintion.getProperties(); String timestampAttributeKey = tableProperties.getString("timestampAttributeKey"); String deadLetterQueue = tableProperties.getString("deadLetterQueue"); validateDlq(deadLetterQueue); - return PubsubIOJsonTable.builder() - .setSchema(tableDefintion.getSchema()) - .setTimestampAttribute(timestampAttributeKey) - .setDeadLetterQueue(deadLetterQueue) - .setTopic(tableDefintion.getLocation()) - .build(); - } + Schema schema = tableDefintion.getSchema(); + validateEventTimestamp(schema); - private void validatePubsubMessageSchema(Table tableDefinition) { - Schema schema = tableDefinition.getSchema(); + PubsubIOTableConfiguration config = + PubsubIOTableConfiguration.builder() + .setSchema(schema) + .setTimestampAttribute(timestampAttributeKey) + .setDeadLetterQueue(deadLetterQueue) + .setTopic(tableDefintion.getLocation()) + .setUseFlatSchema(!definesAttributeAndPayload(schema)) + .build(); - if (schema.getFieldCount() != 3 - || !fieldPresent(schema, TIMESTAMP_FIELD, TIMESTAMP) - || !fieldPresent( - schema, ATTRIBUTES_FIELD, Schema.FieldType.map(VARCHAR.withNullable(false), VARCHAR)) - || !(schema.hasField(PAYLOAD_FIELD) - && 
ROW.equals(schema.getField(PAYLOAD_FIELD).getType().getTypeName()))) { + return PubsubIOJsonTable.withConfiguration(config); + } + private void validateEventTimestamp(Schema schema) { + if (!fieldPresent(schema, TIMESTAMP_FIELD, TIMESTAMP)) { throw new IllegalArgumentException( - "Unsupported schema specified for Pubsub source in CREATE TABLE. " - + "CREATE TABLE for Pubsub topic should define exactly the following fields: " - + "'event_timestamp' field of type 'TIMESTAMP', 'attributes' field of type " - + "MAP, and 'payload' field of type 'ROW<...>' which matches the " - + "payload JSON format."); + "Unsupported schema specified for Pubsub source in CREATE TABLE." + + "CREATE TABLE for Pubsub topic must include at least 'event_timestamp' field of " + + "type 'TIMESTAMP'"); } } + private boolean definesAttributeAndPayload(Schema schema) { + return fieldPresent( + schema, ATTRIBUTES_FIELD, Schema.FieldType.map(VARCHAR.withNullable(false), VARCHAR)) + && (schema.hasField(PAYLOAD_FIELD) + && ROW.equals(schema.getField(PAYLOAD_FIELD).getType().getTypeName())); + } + private boolean fieldPresent(Schema schema, String field, Schema.FieldType expectedType) { return schema.hasField(field) && expectedType.equivalent( @@ -96,4 +101,77 @@ private void validateDlq(String deadLetterQueue) { throw new IllegalArgumentException("Dead letter queue topic name is not specified"); } } + + @AutoValue + public abstract static class PubsubIOTableConfiguration implements Serializable { + public boolean useDlq() { + return getDeadLetterQueue() != null; + } + + public boolean useTimestampAttribute() { + return getTimestampAttribute() != null; + } + + /** Determines whether or not the messages should be represented with a flattened schema. */ + abstract boolean getUseFlatSchema(); + + /** + * Optional attribute key of the Pubsub message from which to extract the event timestamp. + * + *

This attribute has to conform to the same requirements as in {@link + * PubsubIO.Read.Builder#withTimestampAttribute}. + * + *

Short version: it has to be either millis since epoch or string in RFC 3339 format. + * + *

If the attribute is specified, then event timestamps will be extracted from the specified + * attribute. If it is not specified, then the message publish timestamp will be used. + */ + @Nullable + abstract String getTimestampAttribute(); + + /** + * Optional topic path which will be used as a dead letter queue. + * + *

Messages that cannot be processed will be sent to this topic. If it is not specified, then an + * exception will be thrown for errors during processing, causing the pipeline to crash. + */ + @Nullable + abstract String getDeadLetterQueue(); + + /** + * Pubsub topic name. + * + *

Topic is the only way to specify the Pubsub source. Explicitly specifying the subscription + * is not supported at the moment. Subscriptions are automatically created (but not deleted). + */ + abstract String getTopic(); + + /** + * Table schema; describes the Pubsub message schema. + * + *

If {@link #getUseFlatSchema()} is not set, the schema must contain exactly the fields + * 'event_timestamp', 'attributes', and 'payload'. Otherwise, it must contain just 'event_timestamp'. + * See {@link PubsubMessageToRow} for details. + */ + public abstract Schema getSchema(); + + static Builder builder() { + return new AutoValue_PubsubJsonTableProvider_PubsubIOTableConfiguration.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setUseFlatSchema(boolean useFlatSchema); + + abstract Builder setSchema(Schema schema); + + abstract Builder setTimestampAttribute(String timestampAttribute); + + abstract Builder setDeadLetterQueue(String deadLetterQueue); + + abstract Builder setTopic(String topic); + + abstract PubsubIOTableConfiguration build(); + } + } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRow.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRow.java index 3d9a71217566..64d4bc389789 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRow.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRow.java @@ -24,6 +24,8 @@ import com.google.auto.value.AutoValue; import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Internal; @@ -54,10 +56,12 @@ public abstract class PubsubMessageToRow extends DoFn { /** * Schema of the Pubsub message. * - *

Required to have exactly 3 top level fields at the moment: + *

Required to have at least an 'event_timestamp' field of type {@link Schema.FieldType#DATETIME}. + * + *

If {@code useFlatSchema()} is set, every other field is assumed to be part of the payload. + * Otherwise, the schema must contain exactly: + * + *

    - *
  • 'event_timestamp' of type {@link Schema.FieldType#DATETIME} *
  • 'attributes' of type {@link TypeName#MAP MAP<VARCHAR,VARCHAR>} *
  • 'payload' of type {@link TypeName#ROW ROW<...>} *
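+ *
+ * An illustrative sketch of such a non-flat message schema ({@code payloadSchema} is a
+ * placeholder for the user-defined payload row schema):
+ *
+ *   Schema.builder()
+ *       .addDateTimeField("event_timestamp")
+ *       .addMapField("attributes", FieldType.STRING, FieldType.STRING)
+ *       .addRowField("payload", payloadSchema)
+ *       .build();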
@@ -68,8 +72,18 @@ public abstract class PubsubMessageToRow extends DoFn { public abstract boolean useDlq(); + public abstract boolean useFlatSchema(); + private Schema payloadSchema() { - return messageSchema().getField(PAYLOAD_FIELD).getType().getRowSchema(); + if (!useFlatSchema()) { + return messageSchema().getField(PAYLOAD_FIELD).getType().getRowSchema(); + } else { + // The payload contains every field in the schema except event_timestamp + return new Schema( + messageSchema().getFields().stream() + .filter(f -> !f.getName().equals(TIMESTAMP_FIELD)) + .collect(Collectors.toList())); + } } public static Builder builder() { @@ -95,28 +109,40 @@ public void processElement(ProcessContext context) { * payload, and attributes. */ private List getFieldValues(ProcessContext context) { + Row payload = parsePayloadJsonRow(context.element()); return messageSchema().getFields().stream() - .map(field -> getValueForField(field, context.timestamp(), context.element())) + .map( + field -> + getValueForField( + field, context.timestamp(), context.element().getAttributeMap(), payload)) .collect(toList()); } private Object getValueForField( - Schema.Field field, Instant timestamp, PubsubMessage pubsubMessage) { - - switch (field.getName()) { - case TIMESTAMP_FIELD: + Schema.Field field, Instant timestamp, Map attributeMap, Row payload) { + // TODO(BEAM-8801): do this check once at construction time, rather than for every element. + if (useFlatSchema()) { + if (field.getName().equals(TIMESTAMP_FIELD)) { return timestamp; - case ATTRIBUTES_FIELD: - return pubsubMessage.getAttributeMap(); - case PAYLOAD_FIELD: - return parsePayloadJsonRow(pubsubMessage); - default: - throw new IllegalArgumentException( - "Unexpected field '" - + field.getName() - + "' in top level schema" - + " for Pubsub message. Top level schema should only contain " - + "'timestamp', 'attributes', and 'payload' fields"); + } else { + return payload.getValue(field.getName()); + } + } else { + switch (field.getName()) { + case TIMESTAMP_FIELD: + return timestamp; + case ATTRIBUTES_FIELD: + return attributeMap; + case PAYLOAD_FIELD: + return payload; + default: + throw new IllegalArgumentException( + "Unexpected field '" + + field.getName() + + "' in top level schema" + + " for Pubsub message. Top level schema should only contain " + + "'timestamp', 'attributes', and 'payload' fields"); + } } } @@ -136,6 +162,8 @@ abstract static class Builder { public abstract Builder useDlq(boolean useDlq); + public abstract Builder useFlatSchema(boolean useFlatSchema); + public abstract PubsubMessageToRow build(); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/RowToPubsubMessage.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/RowToPubsubMessage.java new file mode 100644 index 000000000000..dce5b1fb206f --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/RowToPubsubMessage.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.meta.provider.pubsub; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import java.nio.charset.StandardCharsets; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; +import org.apache.beam.sdk.schemas.transforms.DropFields; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ToJson; +import org.apache.beam.sdk.transforms.WithTimestamps; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; + +/** + * A {@link PTransform} to convert {@link Row} to {@link PubsubMessage} with JSON payload. + * + *

Currently only supports writing a flat schema into a JSON payload. This means that all Row + * field values are written to the {@link PubsubMessage} JSON payload, except for {@code + * event_timestamp}, which is either ignored or written to the message attributes, depending on + * whether {@link PubsubJsonTableProvider.PubsubIOTableConfiguration#getTimestampAttribute()} is + * set. + */ +@Experimental +public class RowToPubsubMessage extends PTransform, PCollection> { + private final PubsubJsonTableProvider.PubsubIOTableConfiguration config; + + private RowToPubsubMessage(PubsubJsonTableProvider.PubsubIOTableConfiguration config) { + checkArgument( + config.getUseFlatSchema(), "RowToPubsubMessage is only supported for flattened schemas."); + + this.config = config; + } + + public static RowToPubsubMessage fromTableConfig( + PubsubJsonTableProvider.PubsubIOTableConfiguration config) { + return new RowToPubsubMessage(config); + } + + @Override + public PCollection expand(PCollection input) { + PCollection withTimestamp = + (config.useTimestampAttribute()) + ? input.apply( + WithTimestamps.of((row) -> row.getDateTime("event_timestamp").toInstant())) + : input; + + return withTimestamp + .apply(DropFields.fields("event_timestamp")) + .apply(ToJson.of()) + .apply( + MapElements.into(TypeDescriptor.of(PubsubMessage.class)) + .via( + (String json) -> + new PubsubMessage( + json.getBytes(StandardCharsets.ISO_8859_1), ImmutableMap.of()))); + } +} diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/PubsubToBigqueryIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/PubsubToBigqueryIT.java index dc73b20bbd74..24eba2e804cd 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/PubsubToBigqueryIT.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/PubsubToBigqueryIT.java @@ -111,6 +111,57 @@ public void testSimpleInsert() throws Exception { .pollFor(Duration.standardMinutes(5)); } + @Test + public void testSimpleInsertFlat() throws Exception { + BeamSqlEnv sqlEnv = + BeamSqlEnv.inMemory(new PubsubJsonTableProvider(), new BigQueryTableProvider()); + + String createTableString = + "CREATE EXTERNAL TABLE pubsub_topic (\n" + + "event_timestamp TIMESTAMP, \n" + + "id INTEGER, \n" + + "name VARCHAR \n" + + ") \n" + + "TYPE 'pubsub' \n" + + "LOCATION '" + + pubsub.topicPath() + + "' \n" + + "TBLPROPERTIES '{ \"timestampAttributeKey\" : \"ts\" }'"; + sqlEnv.executeDdl(createTableString); + + String createTableStatement = + "CREATE EXTERNAL TABLE bq_table( \n" + + " id BIGINT, \n" + + " name VARCHAR \n " + + ") \n" + + "TYPE 'bigquery' \n" + + "LOCATION '" + + bigQuery.tableSpec() + + "'"; + sqlEnv.executeDdl(createTableStatement); + + String insertStatement = + "INSERT INTO bq_table \n" + "SELECT \n" + " id, \n" + " name \n" + "FROM pubsub_topic"; + + BeamSqlRelUtils.toPCollection(pipeline, sqlEnv.parseQuery(insertStatement)); + + pipeline.run(); + + List messages = + ImmutableList.of( + message(ts(1), 3, "foo"), message(ts(2), 5, "bar"), message(ts(3), 7, "baz")); + pubsub.publish(messages); + + bigQuery + .assertThatAllRows(SOURCE_SCHEMA) + .eventually( + containsInAnyOrder( + row(SOURCE_SCHEMA, 3L, "foo"), + row(SOURCE_SCHEMA, 5L, "bar"), + row(SOURCE_SCHEMA, 7L, "baz"))) + .pollFor(Duration.standardMinutes(5)); + } + private Row row(Schema schema, Object... 
values) { return Row.withSchema(schema).addValues(values).build(); } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java index 82cafb931e57..aa11690b3aad 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/mongodb/MongoDbReadWriteIT.java @@ -28,12 +28,21 @@ import static org.junit.Assert.assertEquals; import com.mongodb.MongoClient; +import de.flapdoodle.embed.mongo.MongodExecutable; +import de.flapdoodle.embed.mongo.MongodProcess; +import de.flapdoodle.embed.mongo.MongodStarter; +import de.flapdoodle.embed.mongo.config.IMongodConfig; +import de.flapdoodle.embed.mongo.config.MongoCmdOptionsBuilder; +import de.flapdoodle.embed.mongo.config.MongodConfigBuilder; +import de.flapdoodle.embed.mongo.config.Net; +import de.flapdoodle.embed.mongo.config.Storage; +import de.flapdoodle.embed.mongo.distribution.Version; +import de.flapdoodle.embed.process.runtime.Network; import java.util.Arrays; import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils; -import org.apache.beam.sdk.io.mongodb.MongoDBIOIT.MongoDBPipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.io.common.NetworkTestHelper; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.schemas.Schema.FieldType; import org.apache.beam.sdk.testing.PAssert; @@ -42,35 +51,22 @@ import org.apache.beam.sdk.values.Row; import org.junit.AfterClass; import org.junit.BeforeClass; -import org.junit.Ignore; +import org.junit.ClassRule; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A test of {@link org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbTable} on an * independent Mongo instance. - * - *

This test requires a running instance of MongoDB. Pass in connection information using - * PipelineOptions: - * - *

- *  ./gradlew integrationTest -p sdks/java/extensions/sql/integrationTest -DintegrationTestPipelineOptions='[
- *  "--mongoDBHostName=1.2.3.4",
- *  "--mongoDBPort=27017",
- *  "--mongoDBDatabaseName=mypass",
- *  "--numberOfRecords=1000" ]'
- *  --tests org.apache.beam.sdk.extensions.sql.meta.provider.mongodb.MongoDbReadWriteIT
- *  -DintegrationTestRunner=direct
- * 
- * - * A database, specified in the pipeline options, will be created implicitly if it does not exist - * already. And dropped upon completing tests. - * - *

Please see 'build_rules.gradle' file for instructions regarding running this test using Beam - * performance testing framework. */ -@Ignore("https://issues.apache.org/jira/browse/BEAM-8586") +@RunWith(JUnit4.class) public class MongoDbReadWriteIT { + private static final Logger LOG = LoggerFactory.getLogger(MongoDbReadWriteIT.class); private static final Schema SOURCE_SCHEMA = Schema.builder() .addNullableField("_id", STRING) @@ -84,38 +80,57 @@ public class MongoDbReadWriteIT { .addNullableField("c_varchar", STRING) .addNullableField("c_arr", FieldType.array(STRING)) .build(); + private static final String hostname = "localhost"; + private static final String database = "beam"; private static final String collection = "collection"; - private static MongoDBPipelineOptions options; + private static int port; + + @ClassRule public static final TemporaryFolder MONGODB_LOCATION = new TemporaryFolder(); + + private static final MongodStarter mongodStarter = MongodStarter.getDefaultInstance(); + private static MongodExecutable mongodExecutable; + private static MongodProcess mongodProcess; + private static MongoClient client; @Rule public final TestPipeline writePipeline = TestPipeline.create(); @Rule public final TestPipeline readPipeline = TestPipeline.create(); @BeforeClass public static void setUp() throws Exception { - PipelineOptionsFactory.register(MongoDBPipelineOptions.class); - options = TestPipeline.testingPipelineOptions().as(MongoDBPipelineOptions.class); + port = NetworkTestHelper.getAvailableLocalPort(); + LOG.info("Starting MongoDB embedded instance on {}", port); + IMongodConfig mongodConfig = + new MongodConfigBuilder() + .version(Version.Main.PRODUCTION) + .configServer(false) + .replication(new Storage(MONGODB_LOCATION.getRoot().getPath(), null, 0)) + .net(new Net(hostname, port, Network.localhostIsIPv6())) + .cmdOptions( + new MongoCmdOptionsBuilder() + .syncDelay(10) + .useNoPrealloc(true) + .useSmallFiles(true) + .useNoJournal(true) + .verbose(false) + .build()) + .build(); + mongodExecutable = mongodStarter.prepare(mongodConfig); + mongodProcess = mongodExecutable.start(); + client = new MongoClient(hostname, port); } @AfterClass public static void tearDown() throws Exception { - dropDatabase(); - } - - private static void dropDatabase() throws Exception { - new MongoClient(options.getMongoDBHostName()) - .getDatabase(options.getMongoDBDatabaseName()) - .drop(); + client.dropDatabase(database); + client.close(); + mongodProcess.stop(); + mongodExecutable.stop(); } @Test public void testWriteAndRead() { final String mongoSqlUrl = - String.format( - "mongodb://%s:%d/%s/%s", - options.getMongoDBHostName(), - options.getMongoDBPort(), - options.getMongoDBDatabaseName(), - collection); + String.format("mongodb://%s:%d/%s/%s", hostname, port, database, collection); Row testRow = row( diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonIT.java index 430338faa613..0b5a4906f4b5 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonIT.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonIT.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.pubsub; import static java.nio.charset.StandardCharsets.UTF_8; +import static 
org.apache.beam.sdk.testing.JsonMatcher.jsonBytesLike; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; @@ -26,6 +27,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; import java.io.Serializable; import java.nio.charset.StandardCharsets; import java.sql.ResultSet; @@ -95,9 +97,11 @@ public class PubsubJsonIT implements Serializable { private static volatile Boolean checked = false; @Rule public transient TestPubsub eventsTopic = TestPubsub.create(); + @Rule public transient TestPubsub filteredEventsTopic = TestPubsub.create(); @Rule public transient TestPubsub dlqTopic = TestPubsub.create(); @Rule public transient TestPubsubSignal resultSignal = TestPubsubSignal.create(); @Rule public transient TestPipeline pipeline = TestPipeline.create(); + @Rule public transient TestPipeline filterPipeline = TestPipeline.create(); /** * HACK: we need an objectmapper to turn pipelineoptions back into a map. We need to use @@ -108,7 +112,7 @@ public class PubsubJsonIT implements Serializable { .registerModules(ObjectMapper.findModules(ReflectHelpers.findClassLoader())); @Test - public void testSelectsPayloadContent() throws Exception { + public void testSQLSelectsPayloadContent() throws Exception { String createTableString = "CREATE EXTERNAL TABLE message (\n" + "event_timestamp TIMESTAMP, \n" @@ -342,6 +346,231 @@ public void testWritesJsonRowsToPubsub() throws Exception { .waitForUpTo(Duration.standardSeconds(20)); } + @Test + public void testSQLSelectsPayloadContentFlat() throws Exception { + String createTableString = + "CREATE EXTERNAL TABLE message (\n" + + "event_timestamp TIMESTAMP, \n" + + "id INTEGER, \n" + + "name VARCHAR \n" + + ") \n" + + "TYPE 'pubsub' \n" + + "LOCATION '" + + eventsTopic.topicPath() + + "' \n" + + "TBLPROPERTIES '{ \"timestampAttributeKey\" : \"ts\" }'"; + + String queryString = "SELECT message.id, message.name from message"; + + // Prepare messages to send later + List messages = + ImmutableList.of( + message(ts(1), 3, "foo"), message(ts(2), 5, "bar"), message(ts(3), 7, "baz")); + + // Initialize SQL environment and create the pubsub table + BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new PubsubJsonTableProvider()); + sqlEnv.executeDdl(createTableString); + + // Apply the PTransform to query the pubsub topic + PCollection queryOutput = query(sqlEnv, pipeline, queryString); + + // Observe the query results and send success signal after seeing the expected messages + queryOutput.apply( + "waitForSuccess", + resultSignal.signalSuccessWhen( + SchemaCoder.of(PAYLOAD_SCHEMA), + observedRows -> + observedRows.equals( + ImmutableSet.of( + row(PAYLOAD_SCHEMA, 3, "foo"), + row(PAYLOAD_SCHEMA, 5, "bar"), + row(PAYLOAD_SCHEMA, 7, "baz"))))); + + // Send the start signal to make sure the signaling topic is initialized + Supplier start = resultSignal.waitForStart(Duration.standardMinutes(5)); + pipeline.begin().apply(resultSignal.signalStart()); + + // Start the pipeline + pipeline.run(); + + // Wait until got the start response from the signalling topic + start.get(); + + // Start publishing the messages when main pipeline is started and signaling topic is ready + eventsTopic.publish(messages); + + // Poll the signaling topic for success message + resultSignal.waitForSuccess(Duration.standardSeconds(60)); + } + + @Test + public void testSQLInsertJsonRowsToPubsubFlat() throws Exception { + String 
createTableString = + "CREATE EXTERNAL TABLE message (\n" + + "event_timestamp TIMESTAMP, \n" + + "name VARCHAR, \n" + + "height INTEGER, \n" + + "knowsJavascript BOOLEAN \n" + + ") \n" + + "TYPE 'pubsub' \n" + + "LOCATION '" + + eventsTopic.topicPath() + + "' \n" + + "TBLPROPERTIES " + + " '{ " + + " \"deadLetterQueue\" : \"" + + dlqTopic.topicPath() + + "\"" + + " }'"; + + // Initialize SQL environment and create the pubsub table + BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new PubsubJsonTableProvider()); + sqlEnv.executeDdl(createTableString); + + // TODO(BEAM-8741): Ideally we could write this query without specifying a column list, because + // it shouldn't be possible to write to event_timestamp when it's mapped to publish time. + String queryString = + "INSERT INTO message (name, height, knowsJavascript) \n" + + "VALUES \n" + + "('person1', 80, TRUE), \n" + + "('person2', 70, FALSE)"; + + // Apply the PTransform to insert the rows + PCollection queryOutput = query(sqlEnv, pipeline, queryString); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(5)); + + eventsTopic + .assertThatTopicEventuallyReceives( + jsonMessageLike("{\"name\":\"person1\", \"height\": 80, \"knowsJavascript\": true}"), + jsonMessageLike("{\"name\":\"person2\", \"height\": 70, \"knowsJavascript\": false}")) + .waitForUpTo(Duration.standardSeconds(20)); + } + + @Test + public void testSQLInsertJsonRowsToPubsubWithTimestampAttributeFlat() throws Exception { + String createTableString = + "CREATE EXTERNAL TABLE message (\n" + + " event_timestamp TIMESTAMP, \n" + + " name VARCHAR, \n" + + " height INTEGER, \n" + + " knowsJavascript BOOLEAN \n" + + ") \n" + + "TYPE 'pubsub' \n" + + "LOCATION '" + + eventsTopic.topicPath() + + "' \n" + + "TBLPROPERTIES " + + " '{ " + + " \"deadLetterQueue\" : \"" + + dlqTopic.topicPath() + + "\"," + + " \"timestampAttributeKey\" : \"ts\"" + + " }'"; + + // Initialize SQL environment and create the pubsub table + BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new PubsubJsonTableProvider()); + sqlEnv.executeDdl(createTableString); + + String queryString = + "INSERT INTO message " + + "VALUES " + + "(TIMESTAMP '1970-01-01 00:00:00.001', 'person1', 80, TRUE), " + + "(TIMESTAMP '1970-01-01 00:00:00.002', 'person2', 70, FALSE)"; + PCollection queryOutput = query(sqlEnv, pipeline, queryString); + + pipeline.run().waitUntilFinish(Duration.standardMinutes(5)); + + eventsTopic + .assertThatTopicEventuallyReceives( + jsonMessageLike( + ts(1), "{\"name\":\"person1\", \"height\": 80, \"knowsJavascript\": true}"), + jsonMessageLike( + ts(2), "{\"name\":\"person2\", \"height\": 70, \"knowsJavascript\": false}")) + .waitForUpTo(Duration.standardSeconds(20)); + } + + @Test + public void testSQLReadAndWriteWithSameFlatTableDefinition() throws Exception { + // This test verifies that the same pubsub table definition can be used for both reading and + // writing + // pipeline: Use SQL to insert data into `people` + // filterPipeline: Use SQL to read from `people`, filter the rows, and write to + // `javascript_people` + + String createTableString = + "CREATE EXTERNAL TABLE people (\n" + + "event_timestamp TIMESTAMP, \n" + + "name VARCHAR, \n" + + "height INTEGER, \n" + + "knowsJavascript BOOLEAN \n" + + ") \n" + + "TYPE 'pubsub' \n" + + "LOCATION '" + + eventsTopic.topicPath() + + "' \n"; + + String createFilteredTableString = + "CREATE EXTERNAL TABLE javascript_people (\n" + + "event_timestamp TIMESTAMP, \n" + + "name VARCHAR, \n" + + "height INTEGER \n" + + ") \n" + + "TYPE 'pubsub' \n" + + "LOCATION '" 
+ + filteredEventsTopic.topicPath() + + "' \n"; + + // Initialize SQL environment and create the pubsub table + BeamSqlEnv sqlEnv = BeamSqlEnv.inMemory(new PubsubJsonTableProvider()); + sqlEnv.executeDdl(createTableString); + sqlEnv.executeDdl(createFilteredTableString); + + // TODO(BEAM-8741): Ideally we could write these queries without specifying a column list, + // because + // it shouldn't be possible to write to event_timestamp when it's mapped to publish time. + String filterQueryString = + "INSERT INTO javascript_people (name, height) (\n" + + " SELECT \n" + + " name, \n" + + " height \n" + + " FROM people \n" + + " WHERE knowsJavascript \n" + + ")"; + + String injectQueryString = + "INSERT INTO people (name, height, knowsJavascript) VALUES \n" + + "('person1', 80, TRUE), \n" + + "('person2', 70, FALSE), \n" + + "('person3', 60, TRUE), \n" + + "('person4', 50, FALSE), \n" + + "('person5', 40, TRUE)"; + + // Apply the PTransform to do the filtering + query(sqlEnv, filterPipeline, filterQueryString); + + // Apply the PTransform to inject the input data + query(sqlEnv, pipeline, injectQueryString); + + // Send the start signal to make sure the signaling topic is initialized + Supplier start = resultSignal.waitForStart(Duration.standardMinutes(5)); + filterPipeline.begin().apply("signal filter pipeline started", resultSignal.signalStart()); + + // Start the filter pipeline and wait until it has started. + filterPipeline.run(); + start.get(); + + // .. then run the injector pipeline + pipeline.run().waitUntilFinish(Duration.standardMinutes(5)); + + filteredEventsTopic + .assertThatTopicEventuallyReceives( + jsonMessageLike("{\"name\":\"person1\", \"height\": 80}"), + jsonMessageLike("{\"name\":\"person3\", \"height\": 60}"), + jsonMessageLike("{\"name\":\"person5\", \"height\": 40}")) + .waitForUpTo(Duration.standardMinutes(5)); + } + private static String toArg(Object o) { try { String jsonRepr = MAPPER.writeValueAsString(o); @@ -415,6 +644,17 @@ private Matcher messageLike(String jsonPayload) { return hasProperty("payload", equalTo(jsonPayload.getBytes(StandardCharsets.US_ASCII))); } + private Matcher jsonMessageLike(Instant timestamp, String jsonPayload) + throws IOException { + return allOf( + hasProperty("payload", jsonBytesLike(jsonPayload)), + hasProperty("attributeMap", hasEntry("ts", String.valueOf(timestamp.getMillis())))); + } + + private Matcher jsonMessageLike(String jsonPayload) throws IOException { + return hasProperty("payload", jsonBytesLike(jsonPayload)); + } + private String jsonString(int id, String name) { return "{ \"id\" : " + id + ", \"name\" : \"" + name + "\" }"; } diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProviderTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProviderTest.java index e7f5fc77276d..ccfb613979d3 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProviderTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubJsonTableProviderTest.java @@ -75,53 +75,30 @@ public void testThrowsIfTimestampFieldNotProvided() { } @Test - public void testThrowsIfAttributesFieldNotProvided() { + public void testCreatesTableWithJustTimestamp() { PubsubJsonTableProvider provider = new PubsubJsonTableProvider(); - Schema messageSchema = - Schema.builder() - 
.addDateTimeField("event_timestamp") - .addRowField("payload", Schema.builder().build()) - .build(); + Schema messageSchema = Schema.builder().addDateTimeField("event_timestamp").build(); Table tableDefinition = tableDefinition().schema(messageSchema).build(); - thrown.expectMessage("Unsupported"); - thrown.expectMessage("'attributes'"); - provider.buildBeamSqlTable(tableDefinition); - } - - @Test - public void testThrowsIfPayloadFieldNotProvided() { - PubsubJsonTableProvider provider = new PubsubJsonTableProvider(); - Schema messageSchema = - Schema.builder() - .addDateTimeField("event_timestamp") - .addMapField("attributes", VARCHAR, VARCHAR) - .build(); - - Table tableDefinition = tableDefinition().schema(messageSchema).build(); + BeamSqlTable pubsubTable = provider.buildBeamSqlTable(tableDefinition); - thrown.expectMessage("Unsupported"); - thrown.expectMessage("'payload'"); - provider.buildBeamSqlTable(tableDefinition); + assertNotNull(pubsubTable); + assertEquals(messageSchema, pubsubTable.getSchema()); } @Test - public void testThrowsIfExtraFieldsExist() { + public void testCreatesFlatTable() { PubsubJsonTableProvider provider = new PubsubJsonTableProvider(); Schema messageSchema = - Schema.builder() - .addDateTimeField("event_timestamp") - .addMapField("attributes", VARCHAR, VARCHAR) - .addStringField("someField") - .addRowField("payload", Schema.builder().build()) - .build(); + Schema.builder().addDateTimeField("event_timestamp").addStringField("someField").build(); Table tableDefinition = tableDefinition().schema(messageSchema).build(); - thrown.expectMessage("Unsupported"); - thrown.expectMessage("'event_timestamp'"); - provider.buildBeamSqlTable(tableDefinition); + BeamSqlTable pubsubTable = provider.buildBeamSqlTable(tableDefinition); + + assertNotNull(pubsubTable); + assertEquals(messageSchema, pubsubTable.getSchema()); } private static Table.Builder tableDefinition() { diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRowTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRowTest.java index 370c2146aac7..cc6ae5fbfa1c 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRowTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/pubsub/PubsubMessageToRowTest.java @@ -85,6 +85,7 @@ public void testConvertsMessages() { PubsubMessageToRow.builder() .messageSchema(messageSchema) .useDlq(false) + .useFlatSchema(false) .build())); PAssert.that(rows) @@ -134,6 +135,7 @@ public void testSendsInvalidToDLQ() { PubsubMessageToRow.builder() .messageSchema(messageSchema) .useDlq(true) + .useFlatSchema(false) .build()) .withOutputTags(MAIN_TAG, TupleTagList.of(DLQ_TAG))); @@ -166,6 +168,113 @@ public void testSendsInvalidToDLQ() { pipeline.run(); } + @Test + public void testConvertsMessagesToFlatRow() { + Schema messageSchema = + Schema.builder() + .addDateTimeField("event_timestamp") + .addNullableField("id", FieldType.INT32) + .addNullableField("name", FieldType.STRING) + .build(); + + PCollection rows = + pipeline + .apply( + "create", + Create.timestamped( + message(1, map("attr", "val"), "{ \"id\" : 3, \"name\" : \"foo\" }"), + message(2, map("bttr", "vbl"), "{ \"name\" : \"baz\", \"id\" : 5 }"), + message(3, map("cttr", "vcl"), "{ \"id\" : 7, \"name\" : \"bar\" }"), + message(4, map("dttr", "vdl"), "{ \"name\" : \"qaz\", \"id\" 
: 8 }"), + message(4, map("dttr", "vdl"), "{ \"name\" : null, \"id\" : null }"))) + .apply( + "convert", + ParDo.of( + PubsubMessageToRow.builder() + .messageSchema(messageSchema) + .useDlq(false) + .useFlatSchema(true) + .build())); + + PAssert.that(rows) + .containsInAnyOrder( + Row.withSchema(messageSchema) + .addValues(ts(1), /* map("attr", "val"), */ 3, "foo") + .build(), + Row.withSchema(messageSchema) + .addValues(ts(2), /* map("bttr", "vbl"), */ 5, "baz") + .build(), + Row.withSchema(messageSchema) + .addValues(ts(3), /* map("cttr", "vcl"), */ 7, "bar") + .build(), + Row.withSchema(messageSchema) + .addValues(ts(4), /* map("dttr", "vdl"), */ 8, "qaz") + .build(), + Row.withSchema(messageSchema) + .addValues(ts(4), /* map("dttr", "vdl"), */ null, null) + .build()); + + pipeline.run(); + } + + @Test + public void testSendsFlatRowInvalidToDLQ() { + Schema messageSchema = + Schema.builder() + .addDateTimeField("event_timestamp") + .addInt32Field("id") + .addStringField("name") + .build(); + + PCollectionTuple outputs = + pipeline + .apply( + "create", + Create.timestamped( + message(1, map("attr1", "val1"), "{ \"invalid1\" : \"sdfsd\" }"), + message(2, map("attr2", "val2"), "{ \"invalid2"), + message(3, map("attr", "val"), "{ \"id\" : 3, \"name\" : \"foo\" }"), + message(4, map("bttr", "vbl"), "{ \"name\" : \"baz\", \"id\" : 5 }"))) + .apply( + "convert", + ParDo.of( + PubsubMessageToRow.builder() + .messageSchema(messageSchema) + .useDlq(true) + .useFlatSchema(true) + .build()) + .withOutputTags( + PubsubMessageToRow.MAIN_TAG, TupleTagList.of(PubsubMessageToRow.DLQ_TAG))); + + PCollection rows = outputs.get(PubsubMessageToRow.MAIN_TAG); + PCollection dlqMessages = outputs.get(PubsubMessageToRow.DLQ_TAG); + + PAssert.that(dlqMessages) + .satisfies( + messages -> { + assertEquals(2, size(messages)); + assertEquals( + ImmutableSet.of(map("attr1", "val1"), map("attr2", "val2")), + convertToSet(messages, m -> m.getAttributeMap())); + + assertEquals( + ImmutableSet.of("{ \"invalid1\" : \"sdfsd\" }", "{ \"invalid2"), + convertToSet(messages, m -> new String(m.getPayload(), UTF_8))); + return null; + }); + + PAssert.that(rows) + .containsInAnyOrder( + Row.withSchema(messageSchema) + .addValues(ts(3), /* map("attr", "val"), */ 3, "foo") + .build(), + Row.withSchema(messageSchema) + .addValues(ts(4), /* map("bttr", "vbl"), */ 5, "baz") + .build()); + + pipeline.run(); + } + private Row row(Schema schema, Object... 
objects) { return Row.withSchema(schema).addValues(objects).build(); } diff --git a/sdks/java/extensions/sql/zetasql/build.gradle b/sdks/java/extensions/sql/zetasql/build.gradle index 59ee83180a51..560b45435c1a 100644 --- a/sdks/java/extensions/sql/zetasql/build.gradle +++ b/sdks/java/extensions/sql/zetasql/build.gradle @@ -31,6 +31,7 @@ dependencies { compile project(":sdks:java:core") compile project(":sdks:java:extensions:sql") compile library.java.vendored_calcite_1_20_0 + compile library.java.guava compile library.java.grpc_all compile library.java.protobuf_java compile library.java.protobuf_java_util diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java index 3a4a4e246114..5c65f4ba6666 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/SqlAnalyzer.java @@ -86,6 +86,7 @@ static Builder withQueryParams(Map params) { ResolvedStatement analyze(String sql) { AnalyzerOptions options = initAnalyzerOptions(builder.queryParams); List> tables = Analyzer.extractTableNamesFromStatement(sql); + SimpleCatalog catalog = createPopulatedCatalog(builder.topLevelSchema.getName(), options, tables); @@ -177,8 +178,7 @@ private void addTableToLeafCatalog( SimpleCatalog leafCatalog = createNestedCatalogs(topLevelCatalog, tablePath); org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table calciteTable = - TableResolution.resolveCalciteTable( - builder.calciteContext, builder.topLevelSchema, tablePath); + TableResolution.resolveCalciteTable(builder.topLevelSchema, tablePath); if (calciteTable == null) { throw new RuntimeException( @@ -190,8 +190,7 @@ private void addTableToLeafCatalog( RelDataType rowType = calciteTable.getRowType(builder.typeFactory); - SimpleTableWithPath tableWithPath = - SimpleTableWithPath.of(builder.topLevelSchema.getName(), tablePath); + SimpleTableWithPath tableWithPath = SimpleTableWithPath.of(tablePath); trait.addResolvedTable(tableWithPath); addFieldsToTable(tableWithPath, rowType); @@ -236,7 +235,6 @@ static class Builder { private Map queryParams; private QueryTrait queryTrait; - private Context calciteContext; private SchemaPlus topLevelSchema; private JavaTypeFactory typeFactory; @@ -262,7 +260,6 @@ Builder withTopLevelSchema(SchemaPlus schema) { /** Calcite parsing context, can have name resolution and other configuration. 
*/ Builder withCalciteContext(Context context) { - this.calciteContext = context; return this; } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolution.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolution.java index a982d93b212d..d87eaf96d93e 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolution.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolution.java @@ -19,7 +19,12 @@ import com.google.zetasql.SimpleTable; import java.util.List; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Context; +import java.util.stream.Collectors; +import org.apache.beam.sdk.extensions.sql.impl.BeamCalciteSchema; +import org.apache.beam.sdk.extensions.sql.impl.TableName; +import org.apache.beam.sdk.extensions.sql.meta.CustomTableResolver; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.CalciteSchema; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Schema; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -28,34 +33,54 @@ public class TableResolution { /** - * Returns Calcite Table by consulting the schema. + * Resolves {@code tablePath} according to the given {@code schemaPlus}. * - *

The way the schema is queried is defined by the name resolution strategey implemented by a - * TableResolver and stored as a TableResolutionContext in the context. - * - *

If no custom table resolution logic is provided, default one is used, which is: drill down - * the getSubschema() path until the second-to-last path element. We expect the path to be a table - * path, so the last element should be a valid table id, we don't expect anything else there. - * - *

This resembles a default Calcite planner strategy. One difference is that Calcite doesn't - * assume the last element is a table and will continue to call getSubschema(), making it - * impossible for a table provider to understand the context. + *

{@code tablePath} represents a structured table name where the last component is the name of + * the table and all the preceding components are sub-schemas / namespaces within {@code + * schemaPlus}. */ - public static Table resolveCalciteTable( - Context context, SchemaPlus schemaPlus, List tablePath) { - TableResolutionContext tableResolutionContext = context.unwrap(TableResolutionContext.class); - TableResolver tableResolver = getTableResolver(tableResolutionContext, schemaPlus.getName()); - return tableResolver.resolveCalciteTable(schemaPlus, tablePath); + public static Table resolveCalciteTable(SchemaPlus schemaPlus, List tablePath) { + Schema subSchema = schemaPlus; + + // subSchema.getSubschema() for all except last + for (int i = 0; i < tablePath.size() - 1; i++) { + subSchema = subSchema.getSubSchema(tablePath.get(i)); + if (subSchema == null) { + throw new IllegalStateException( + String.format( + "While resolving table path %s, no sub-schema found for component %s (\"%s\")", + tablePath, i, tablePath.get(i))); + } + } + + // for the final one call getTable() + return subSchema.getTable(Iterables.getLast(tablePath)); } - static TableResolver getTableResolver( - TableResolutionContext tableResolutionContext, String schemaName) { - if (tableResolutionContext == null - || !tableResolutionContext.hasCustomResolutionFor(schemaName)) { - return TableResolver.DEFAULT_ASSUME_LEAF_IS_TABLE; + /** + * Registers tables that will be resolved during query analysis, so table providers can eagerly + * pre-load metadata. + */ + // TODO(https://issues.apache.org/jira/browse/BEAM-8817): share this logic between dialects + public static void registerTables(SchemaPlus schemaPlus, List> tables) { + Schema defaultSchema = CalciteSchema.from(schemaPlus).schema; + if (defaultSchema instanceof BeamCalciteSchema + && ((BeamCalciteSchema) defaultSchema).getTableProvider() instanceof CustomTableResolver) { + ((CustomTableResolver) ((BeamCalciteSchema) defaultSchema).getTableProvider()) + .registerKnownTableNames( + tables.stream().map(TableName::create).collect(Collectors.toList())); } - return tableResolutionContext.getTableResolver(schemaName); + for (String subSchemaName : schemaPlus.getSubSchemaNames()) { + Schema subSchema = CalciteSchema.from(schemaPlus.getSubSchema(subSchemaName)).schema; + + if (subSchema instanceof BeamCalciteSchema + && ((BeamCalciteSchema) subSchema).getTableProvider() instanceof CustomTableResolver) { + ((CustomTableResolver) ((BeamCalciteSchema) subSchema).getTableProvider()) + .registerKnownTableNames( + tables.stream().map(TableName::create).collect(Collectors.toList())); + } + } } /** @@ -66,13 +91,11 @@ static class SimpleTableWithPath { SimpleTable table; List path; - String topLevelSchema; - static SimpleTableWithPath of(String topLevelSchema, List path) { + static SimpleTableWithPath of(List path) { SimpleTableWithPath tableWithPath = new SimpleTableWithPath(); tableWithPath.table = new SimpleTable(Iterables.getLast(path)); tableWithPath.path = path; - tableWithPath.topLevelSchema = topLevelSchema; return tableWithPath; } @@ -83,9 +106,5 @@ SimpleTable getTable() { List getPath() { return path; } - - String getTopLevelSchema() { - return topLevelSchema; - } } } diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolutionContext.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolutionContext.java deleted file mode 100644 index 3aed2c1fd323..000000000000 --- 
a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolutionContext.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.extensions.sql.zetasql; - -import java.util.Map; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Context; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; -import org.codehaus.commons.nullanalysis.Nullable; - -/** - * Calcite parser context to pass the configuration down to planner and rules so that we can - * configure custom name resolution. - */ -class TableResolutionContext implements Context { - - /** Table resolvers, associating top-level schema to a custom name resolution logic. */ - private final Map resolvers; - - /** Assigns a custom table resolver to the given schema. */ - static TableResolutionContext of(String topLevelSchema, TableResolver resolver) { - return new TableResolutionContext(ImmutableMap.of(topLevelSchema, resolver)); - } - - /** - * Uses the resolution logic that joins the table path into a single compound identifier and then - * queries the schema once, instead of drilling down into subschemas. - */ - static TableResolutionContext joinCompoundIds(String topLevelSchema) { - return of(topLevelSchema, TableResolver.JOIN_INTO_COMPOUND_ID); - } - - TableResolutionContext with(String topLevelSchema, TableResolver resolver) { - return new TableResolutionContext( - ImmutableMap.builder() - .putAll(this.resolvers) - .put(topLevelSchema, resolver) - .build()); - } - - boolean hasCustomResolutionFor(String schemaName) { - return resolvers.containsKey(schemaName); - } - - @Nullable - TableResolver getTableResolver(String schemaName) { - return resolvers.get(schemaName); - } - - private TableResolutionContext(Map resolvers) { - this.resolvers = resolvers; - } - - @Override - @SuppressWarnings("unchecked") - public T unwrap(Class c) { - return c.isAssignableFrom(TableResolutionContext.class) ? (T) this : null; - } -} diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolverImpl.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolverImpl.java deleted file mode 100644 index ad2f7ce57125..000000000000 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolverImpl.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
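To make the new resolution contract concrete, a small illustration (the schema and table names are invented and schemaPlus is assumed to be in scope; only resolveCalciteTable's behaviour is taken from the change above):

// Assuming schemaPlus has a sub-schema "datacatalog" containing a sub-schema
// "my_project" that holds the table "events", the call below is equivalent to
//   schemaPlus.getSubSchema("datacatalog").getSubSchema("my_project").getTable("events")
Table events =
    TableResolution.resolveCalciteTable(
        schemaPlus, ImmutableList.of("datacatalog", "my_project", "events"));
// If an intermediate component cannot be found, the method now fails fast with an
// IllegalStateException naming the offending component instead of drilling further down.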
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.extensions.sql.zetasql; - -import java.util.List; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Schema; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; - -/** A couple of implementations of TableResolver. */ -class TableResolverImpl { - - /** - * Uses the logic similar to Calcite's EmptyScope.resolve_(...) except assumes the last element in - * the table path is a table name (which is assumed by ZetaSQL API getTableNames()). - * - *

This is the default. - * - *

I.e. drills down into schema.getSubschema() until the second last element of the table path, - * then calls schema.getTable(). - */ - static Table assumeLeafIsTable(Schema schema, List tablePath) { - Schema subSchema = schema; - - // subSchema.getSubschema() for all except last - for (int i = 0; i < tablePath.size() - 1; i++) { - subSchema = subSchema.getSubSchema(tablePath.get(i)); - } - - // for the final one call getTable() - return subSchema.getTable(Iterables.getLast(tablePath)); - } - - /** - * Joins the table name parts into a single ZetaSQL-compatible compound identifier, then calls - * schema.getTable(). - * - *

This is the input expected, for example, by Data Catalog. - * - *

Escapes slashes, backticks, quotes, for details see {@link - * ZetaSqlIdUtils#escapeAndJoin(List)}. - */ - static Table joinIntoCompoundId(Schema schema, List tablePath) { - return schema.getTable(ZetaSqlIdUtils.escapeAndJoin(tablePath)); - } -} diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java index 43b1d71ed12e..ae77cbd92e29 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLPlannerImpl.java @@ -17,11 +17,13 @@ */ package org.apache.beam.sdk.extensions.sql.zetasql; +import com.google.zetasql.Analyzer; import com.google.zetasql.LanguageOptions; import com.google.zetasql.Value; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedQueryStmt; import com.google.zetasql.resolvedast.ResolvedNodes.ResolvedStatement; import java.io.Reader; +import java.util.List; import java.util.Map; import java.util.logging.Logger; import org.apache.beam.sdk.extensions.sql.zetasql.translation.ConversionContext; @@ -134,6 +136,11 @@ public RelRoot rel(String sql, Map params) { QueryTrait trait = new QueryTrait(); + // Set up table providers that need to be pre-registered + // TODO(https://issues.apache.org/jira/browse/BEAM-8817): share this logic between dialects + List> tables = Analyzer.extractTableNamesFromStatement(sql); + TableResolution.registerTables(this.defaultSchemaPlus, tables); + ResolvedStatement statement = SqlAnalyzer.withQueryParams(params) .withQueryTrait(trait) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java index 6c755f127883..6ec56aeeb8fb 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSQLQueryPlanner.java @@ -31,7 +31,6 @@ import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.config.CalciteConnectionConfig; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.jdbc.CalciteSchema; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Contexts; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitDef; import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet; @@ -150,16 +149,10 @@ private FrameworkConfig defaultConfig(JdbcConnection connection, RuleSet[] ruleS final SqlOperatorTable opTab0 = connection.config().fun(SqlOperatorTable.class, SqlStdOperatorTable.instance()); - Object[] contexts = - org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList.of( - connection.config(), TableResolutionContext.joinCompoundIds("datacatalog")) - .toArray(); - return Frameworks.newConfigBuilder() .parserConfig(parserConfig.build()) .defaultSchema(defaultSchema) .traitDefs(traitDefs) - .context(Contexts.of(contexts)) .ruleSets(ruleSets) .costFactory(BeamCostModel.FACTORY) 
.typeSystem(connection.getTypeFactory().getTypeSystem()) diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/TableScanConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/TableScanConverter.java index 2f0c1e617f90..4ebf6343de07 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/TableScanConverter.java +++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/TableScanConverter.java @@ -61,8 +61,7 @@ public RelNode convert(ResolvedTableScan zetaNode, List inputs) { SchemaPlus defaultSchemaPlus = getConfig().getDefaultSchema(); // TODO: reject incorrect top-level schema - Table calciteTable = - TableResolution.resolveCalciteTable(getConfig().getContext(), defaultSchemaPlus, tablePath); + Table calciteTable = TableResolution.resolveCalciteTable(defaultSchemaPlus, tablePath); // we already resolved the table before passing the query to Analyzer, so it should be there checkNotNull( diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/JoinCompoundIdentifiersTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/JoinCompoundIdentifiersTest.java deleted file mode 100644 index 20f2b048fcf6..000000000000 --- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/JoinCompoundIdentifiersTest.java +++ /dev/null @@ -1,343 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
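Putting the planner-side changes together, the pre-analysis flow now looks roughly like the fragment below (the example query and names are invented; extractTableNamesFromStatement and registerTables are the calls added by this patch):

// For a statement such as
//   SELECT t.Key FROM datacatalog.my_project.events AS t
// the ZetaSQL Analyzer reports the referenced table paths, roughly
//   [["datacatalog", "my_project", "events"]]
List<List<String>> tables = Analyzer.extractTableNamesFromStatement(sql);
// Each path is turned into a TableName and offered to every schema whose
// TableProvider implements CustomTableResolver, so providers can pre-load metadata
// before TableScanConverter later resolves the same paths via resolveCalciteTable.
TableResolution.registerTables(defaultSchemaPlus, tables);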
- */ -package org.apache.beam.sdk.extensions.sql.zetasql; - -import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.BASIC_TABLE_ONE; -import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.BASIC_TABLE_TWO; -import static org.apache.beam.sdk.extensions.sql.zetasql.TestInput.TABLE_WITH_STRUCT; - -import java.util.List; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.impl.JdbcConnection; -import org.apache.beam.sdk.extensions.sql.impl.JdbcDriver; -import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel; -import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRuleSets; -import org.apache.beam.sdk.extensions.sql.impl.rel.BeamRelNode; -import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils; -import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable; -import org.apache.beam.sdk.extensions.sql.meta.provider.ReadOnlyTableProvider; -import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.Row; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.Contexts; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.ConventionTraitDef; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitDef; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.FrameworkConfig; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.Frameworks; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; -import org.joda.time.Duration; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** Tests for identifiers. 
*/ -@RunWith(JUnit4.class) -public class JoinCompoundIdentifiersTest { - - private static final Long TWO_MINUTES = 2L; - private static final String DEFAULT_SCHEMA = "beam"; - private static final String FULL_ON_ID = - "a.`b-\\`c`.d.`httz://d.e-f.g:233333/blah\\?yes=1&true=false`"; - private static final String TABLE_WITH_STRUCTS_ID = "a.`table:com`.`..::with-struct::..`"; - - private static final TableProvider TEST_TABLES = - new ReadOnlyTableProvider( - "test_table_provider", - ImmutableMap.builder() - .put("KeyValue", BASIC_TABLE_ONE) - .put("a.b", BASIC_TABLE_ONE) - .put("c.d.e", BASIC_TABLE_ONE) - .put("c.d.f", BASIC_TABLE_TWO) - .put("c.g.e", BASIC_TABLE_TWO) - .put("weird.`\\n\\t\\r\\f`", BASIC_TABLE_ONE) - .put("a.`b-\\`c`.d", BASIC_TABLE_TWO) - .put(FULL_ON_ID, BASIC_TABLE_TWO) - .put(TABLE_WITH_STRUCTS_ID, TABLE_WITH_STRUCT) - .build()); - - @Rule public transient TestPipeline pipeline = TestPipeline.create(); - @Rule public ExpectedException thrown = ExpectedException.none(); - - @Test - public void testComplexTableName() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT Key FROM a.b"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testComplexTableName3Levels() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT Key FROM c.d.e"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testOnePartWithBackticks() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT RowKey FROM a.`b-\\`c`.d"); - - PAssert.that(result).containsInAnyOrder(singleValue(16L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testNewLinesAndOtherWhitespace() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform(pipeline, cfg, "SELECT Key FROM weird.`\\n\\t\\r\\f`"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testFullOnWithBackticks() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT RowKey FROM " + FULL_ON_ID); - - PAssert.that(result).containsInAnyOrder(singleValue(16L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testJoinWithFullOnWithBackticks() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT t1.RowKey FROM " - + FULL_ON_ID - + " AS t1 \n" - + " INNER JOIN a.`b-\\`c`.d t2 on t1.RowKey = t2.RowKey"); - - PAssert.that(result).containsInAnyOrder(singleValue(16L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithAliasedComplexTableName() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT t.Key FROM a.b AS t"); - - 
PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithAliasedComplexTableName3Levels() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT t.Key FROM c.d.e AS t"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithUnaliasedComplexTableName() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT b.Key FROM a.b"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithUnaliasedComplexTableName3Levels() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT e.Key FROM c.d.e"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithUnaliasedComplexTableName3Levels2() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = applySqlTransform(pipeline, cfg, "SELECT e.Key FROM c.d.e"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithJoinOfAliasedComplexTableNames() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT t1.Key FROM a.b AS t1 INNER JOIN c.d.e AS t2 ON t1.Key = t2.Key"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testJoinTwoTablesWithLastPartIdDifferent() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT t1.Key FROM c.d.e AS t1 INNER JOIN c.d.f AS t2 ON t1.Key = t2.RowKey"); - - PAssert.that(result).containsInAnyOrder(singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testJoinTwoTablesWithMiddlePartIdDifferent() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT t1.Key FROM c.d.e AS t1 INNER JOIN c.g.e AS t2 ON t1.Key = t2.RowKey"); - - PAssert.that(result).containsInAnyOrder(singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedFieldAccessWithJoinOfUnaliasedComplexTableNames() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform(pipeline, cfg, "SELECT b.Key FROM a.b INNER JOIN c.d.e ON b.Key = e.Key"); - - PAssert.that(result).containsInAnyOrder(singleValue(14L), singleValue(15L)); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testStructFieldAccess() throws Exception { - FrameworkConfig cfg = 
initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT struct_col.struct_col_str FROM a.`table:com`.`..::with-struct::..`"); - - PAssert.that(result).containsInAnyOrder(singleValue("row_one"), singleValue("row_two")); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testQualifiedStructFieldAccess() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT `..::with-struct::..`.struct_col.struct_col_str \n" - + " FROM a.`table:com`.`..::with-struct::..`"); - - PAssert.that(result).containsInAnyOrder(singleValue("row_one"), singleValue("row_two")); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @Test - public void testAliasedStructFieldAccess() throws Exception { - FrameworkConfig cfg = initializeCalcite(); - - PCollection result = - applySqlTransform( - pipeline, - cfg, - "SELECT t.struct_col.struct_col_str FROM " + TABLE_WITH_STRUCTS_ID + " t"); - - PAssert.that(result).containsInAnyOrder(singleValue("row_one"), singleValue("row_two")); - pipeline.run().waitUntilFinish(Duration.standardMinutes(TWO_MINUTES)); - } - - @SuppressWarnings("unxchecked") - private static FrameworkConfig initializeCalcite() { - JdbcConnection jdbcConnection = - JdbcDriver.connect(TEST_TABLES, PipelineOptionsFactory.create()); - SchemaPlus defaultSchemaPlus = jdbcConnection.getCurrentSchemaPlus(); - List traitDefs = ImmutableList.of(ConventionTraitDef.INSTANCE); - - Object[] contexts = - ImmutableList.of( - Contexts.of(jdbcConnection.config()), - TableResolutionContext.joinCompoundIds(DEFAULT_SCHEMA)) - .toArray(); - - return Frameworks.newConfigBuilder() - .defaultSchema(defaultSchemaPlus) - .traitDefs(traitDefs) - .context(Contexts.of(contexts)) - .ruleSets(BeamRuleSets.getRuleSets()) - .costFactory(BeamCostModel.FACTORY) - .typeSystem(jdbcConnection.getTypeFactory().getTypeSystem()) - .build(); - } - - private PCollection applySqlTransform( - Pipeline pipeline, FrameworkConfig config, String query) throws Exception { - - BeamRelNode beamRelNode = new ZetaSQLQueryPlanner(config).parseQuery(query); - return BeamSqlRelUtils.toPCollection(pipeline, beamRelNode); - } - - private Row singleValue(long value) { - return Row.withSchema(singleLongField()).addValue(value).build(); - } - - private Row singleValue(String value) { - return Row.withSchema(singleStringField()).addValue(value).build(); - } - - private Schema singleLongField() { - return Schema.builder().addInt64Field("field1").build(); - } - - private Schema singleStringField() { - return Schema.builder().addStringField("field1").build(); - } -} diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolutionTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolutionTest.java new file mode 100644 index 000000000000..ca20bda34367 --- /dev/null +++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolutionTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.zetasql; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.when; + +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.SchemaPlus; +import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.hamcrest.Matchers; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Unit tests for {@link TableResolution}. */ +@RunWith(JUnit4.class) +public class TableResolutionTest { + + // A simple in-memory SchemaPlus would be fine + @Mock SchemaPlus mockSchemaPlus; + @Mock SchemaPlus innerSchemaPlus; + + // A table whose identity is not important + @Mock Table mockTable; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + } + + /** Unit test for resolving a table with no hierarchy. */ + @Test + public void testResolveFlat() { + String tableName = "fake_table"; + when(mockSchemaPlus.getTable(tableName)).thenReturn(mockTable); + Table table = TableResolution.resolveCalciteTable(mockSchemaPlus, ImmutableList.of(tableName)); + assertThat(table, Matchers.is(mockTable)); + } + + /** Unit test for resolving a table with no hierarchy but dots in its actual name. */ + @Test + public void testResolveWithDots() { + String tableName = "fake.table"; + when(mockSchemaPlus.getTable(tableName)).thenReturn(mockTable); + Table table = TableResolution.resolveCalciteTable(mockSchemaPlus, ImmutableList.of(tableName)); + assertThat(table, Matchers.is(mockTable)); + } + + /** Unit test for failing to resolve a table with no subschemas. */ + @Test + public void testMissingFlat() { + String tableName = "fake_table"; + when(mockSchemaPlus.getTable(tableName)).thenReturn(null); + Table table = TableResolution.resolveCalciteTable(mockSchemaPlus, ImmutableList.of(tableName)); + assertThat(table, Matchers.nullValue()); + } + + /** Unit test for resolving a table with some hierarchy. */ + @Test + public void testResolveNested() { + String subSchema = "fake_schema"; + String tableName = "fake_table"; + when(mockSchemaPlus.getSubSchema(subSchema)).thenReturn(innerSchemaPlus); + when(innerSchemaPlus.getTable(tableName)).thenReturn(mockTable); + Table table = + TableResolution.resolveCalciteTable(mockSchemaPlus, ImmutableList.of(subSchema, tableName)); + assertThat(table, Matchers.is(mockTable)); + } + + /** Unit test for resolving a table with dots in the subschema names and the table name. 
*/ + @Test + public void testResolveNestedWithDots() { + String subSchema = "fake.schema"; + String tableName = "fake.table"; + when(mockSchemaPlus.getSubSchema(subSchema)).thenReturn(innerSchemaPlus); + when(innerSchemaPlus.getTable(tableName)).thenReturn(mockTable); + Table table = + TableResolution.resolveCalciteTable(mockSchemaPlus, ImmutableList.of(subSchema, tableName)); + assertThat(table, Matchers.is(mockTable)); + } + + /** Unit test for resolving a table with some hierarchy that is missing. */ + @Test + public void testMissingSubschema() { + String subSchema = "fake_schema"; + String tableName = "fake_table"; + when(mockSchemaPlus.getSubSchema(subSchema)).thenReturn(null); + + Assert.assertThrows( + IllegalStateException.class, + () -> { + TableResolution.resolveCalciteTable( + mockSchemaPlus, ImmutableList.of(subSchema, tableName)); + }); + } + + /** Unit test for resolving a table with some hierarchy and the table is missing. */ + @Test + public void testMissingTableInSubschema() { + String subSchema = "fake_schema"; + String tableName = "fake_table"; + when(mockSchemaPlus.getSubSchema(subSchema)).thenReturn(innerSchemaPlus); + when(innerSchemaPlus.getTable(tableName)).thenReturn(null); + Table table = + TableResolution.resolveCalciteTable(mockSchemaPlus, ImmutableList.of(subSchema, tableName)); + assertThat(table, Matchers.nullValue()); + } +} diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserver.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserver.java index 02460bfaa0bb..72ab5d683f1e 100644 --- a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserver.java +++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserver.java @@ -17,14 +17,14 @@ */ package org.apache.beam.sdk.fn.data; -import java.io.IOException; +import java.util.Collections; +import java.util.List; import org.apache.beam.model.fnexecution.v1.BeamFnApi; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * A buffering outbound {@link FnDataReceiver} for the Beam Fn Data API. @@ -32,117 +32,59 @@ *

Encodes individually consumed elements with the provided {@link Coder} producing a single * {@link BeamFnApi.Elements} message when the buffer threshold is surpassed. * - *

The default buffer threshold can be overridden by specifying the experiment {@code - * beam_fn_api_data_buffer_limit=} + *

The default size-based buffer threshold can be overridden by specifying the experiment {@code + * beam_fn_api_data_buffer_size_limit=} * - *

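As a usage sketch (it mirrors the updated tests later in this patch), the size threshold is read from the pipeline's experiments; the 100000 value is only an example:

PipelineOptions options = PipelineOptionsFactory.create();
options
    .as(ExperimentalOptions.class)
    .setExperiments(Arrays.asList("beam_fn_api_data_buffer_size_limit=100000"));
// An observer created from these options (see forLocation below) flushes once the
// encoded buffer reaches 100000 bytes instead of the 1_000_000-byte default.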
TODO: Handle outputting large elements (> 2GiBs). Note that this also applies to the input - * side as well. - * - *

TODO: Handle outputting elements that are zero bytes by outputting a single byte as a marker, - * detect on the input side that no bytes were read and force reading a single byte. + *

The default time-based buffer threshold can be overridden by specifying the experiment {@code + * beam_fn_api_data_buffer_time_limit=} */ -public class BeamFnDataBufferingOutboundObserver implements CloseableFnDataReceiver { - // TODO: Consider moving this constant out of this class - public static final String BEAM_FN_API_DATA_BUFFER_LIMIT = "beam_fn_api_data_buffer_limit="; - @VisibleForTesting static final int DEFAULT_BUFFER_LIMIT_BYTES = 1_000_000; - private static final Logger LOG = - LoggerFactory.getLogger(BeamFnDataBufferingOutboundObserver.class); - - public static BeamFnDataBufferingOutboundObserver forLocation( - LogicalEndpoint endpoint, - Coder coder, - StreamObserver outboundObserver) { - return forLocationWithBufferLimit( - DEFAULT_BUFFER_LIMIT_BYTES, endpoint, coder, outboundObserver); - } +public interface BeamFnDataBufferingOutboundObserver extends CloseableFnDataReceiver { + // TODO: Consider moving this constant out of this interface + /** @deprecated Use BEAM_FN_API_DATA_BUFFER_SIZE_LIMIT instead. */ + @Deprecated String BEAM_FN_API_DATA_BUFFER_LIMIT = "beam_fn_api_data_buffer_limit="; - public static BeamFnDataBufferingOutboundObserver forLocationWithBufferLimit( - int bufferLimit, - LogicalEndpoint endpoint, - Coder coder, - StreamObserver outboundObserver) { - return new BeamFnDataBufferingOutboundObserver<>( - bufferLimit, endpoint, coder, outboundObserver); - } + String BEAM_FN_API_DATA_BUFFER_SIZE_LIMIT = "beam_fn_api_data_buffer_size_limit="; + @VisibleForTesting int DEFAULT_BUFFER_LIMIT_BYTES = 1_000_000; - private long byteCounter; - private long counter; - private boolean closed; - private final int bufferLimit; - private final Coder coder; - private final LogicalEndpoint outputLocation; - private final StreamObserver outboundObserver; - private final ByteString.Output bufferedElements; + String BEAM_FN_API_DATA_BUFFER_TIME_LIMIT = "beam_fn_api_data_buffer_time_limit="; + long DEFAULT_BUFFER_LIMIT_TIME_MS = -1L; - private BeamFnDataBufferingOutboundObserver( - int bufferLimit, - LogicalEndpoint outputLocation, + static BeamFnDataSizeBasedBufferingOutboundObserver forLocation( + PipelineOptions options, + LogicalEndpoint endpoint, Coder coder, StreamObserver outboundObserver) { - this.bufferLimit = bufferLimit; - this.outputLocation = outputLocation; - this.coder = coder; - this.outboundObserver = outboundObserver; - this.bufferedElements = ByteString.newOutput(); - this.closed = false; - } - - @Override - public void close() throws Exception { - if (closed) { - throw new IllegalStateException("Already closed."); - } - closed = true; - BeamFnApi.Elements.Builder elements = convertBufferForTransmission(); - // This will add an empty data block representing the end of stream. 
- elements - .addDataBuilder() - .setInstructionId(outputLocation.getInstructionId()) - .setTransformId(outputLocation.getTransformId()); - - LOG.debug( - "Closing stream for instruction {} and " - + "transform {} having transmitted {} values {} bytes", - outputLocation.getInstructionId(), - outputLocation.getTransformId(), - counter, - byteCounter); - outboundObserver.onNext(elements.build()); - } - - @Override - public void flush() throws IOException { - if (bufferedElements.size() > 0) { - outboundObserver.onNext(convertBufferForTransmission().build()); + int sizeLimit = getSizeLimit(options); + long timeLimit = getTimeLimit(options); + if (timeLimit > 0) { + return new BeamFnDataTimeBasedBufferingOutboundObserver<>( + sizeLimit, timeLimit, endpoint, coder, outboundObserver); + } else { + return new BeamFnDataSizeBasedBufferingOutboundObserver<>( + sizeLimit, endpoint, coder, outboundObserver); } } - @Override - public void accept(T t) throws IOException { - if (closed) { - throw new IllegalStateException("Already closed."); - } - coder.encode(t, bufferedElements); - counter += 1; - if (bufferedElements.size() >= bufferLimit) { - flush(); + static int getSizeLimit(PipelineOptions options) { + List experiments = options.as(ExperimentalOptions.class).getExperiments(); + for (String experiment : experiments == null ? Collections.emptyList() : experiments) { + if (experiment.startsWith(BEAM_FN_API_DATA_BUFFER_SIZE_LIMIT)) { + return Integer.parseInt(experiment.substring(BEAM_FN_API_DATA_BUFFER_SIZE_LIMIT.length())); + } + if (experiment.startsWith(BEAM_FN_API_DATA_BUFFER_LIMIT)) { + return Integer.parseInt(experiment.substring(BEAM_FN_API_DATA_BUFFER_LIMIT.length())); + } } + return DEFAULT_BUFFER_LIMIT_BYTES; } - private BeamFnApi.Elements.Builder convertBufferForTransmission() { - BeamFnApi.Elements.Builder elements = BeamFnApi.Elements.newBuilder(); - if (bufferedElements.size() == 0) { - return elements; + static long getTimeLimit(PipelineOptions options) { + List experiments = options.as(ExperimentalOptions.class).getExperiments(); + for (String experiment : experiments == null ? Collections.emptyList() : experiments) { + if (experiment.startsWith(BEAM_FN_API_DATA_BUFFER_TIME_LIMIT)) { + return Long.parseLong(experiment.substring(BEAM_FN_API_DATA_BUFFER_TIME_LIMIT.length())); + } } - - elements - .addDataBuilder() - .setInstructionId(outputLocation.getInstructionId()) - .setTransformId(outputLocation.getTransformId()) - .setData(bufferedElements.toByteString()); - - byteCounter += bufferedElements.size(); - bufferedElements.reset(); - return elements; + return DEFAULT_BUFFER_LIMIT_TIME_MS; } } diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataSizeBasedBufferingOutboundObserver.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataSizeBasedBufferingOutboundObserver.java new file mode 100644 index 000000000000..c0215aeaf2bc --- /dev/null +++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataSizeBasedBufferingOutboundObserver.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.fn.data; + +import java.io.IOException; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A size-based buffering outbound {@link FnDataReceiver} for the Beam Fn Data API. + * + *

TODO: Handle outputting large elements (> 2 GiB). Note that this also applies to the input + * side. + * + *

TODO: Handle outputting elements that are zero bytes by outputting a single byte as a marker, + * detect on the input side that no bytes were read and force reading a single byte. + */ +public class BeamFnDataSizeBasedBufferingOutboundObserver + implements BeamFnDataBufferingOutboundObserver { + private static final Logger LOG = + LoggerFactory.getLogger(BeamFnDataSizeBasedBufferingOutboundObserver.class); + + private long byteCounter; + private long counter; + private boolean closed; + private final int sizeLimit; + private final Coder coder; + private final LogicalEndpoint outputLocation; + private final StreamObserver outboundObserver; + private final ByteString.Output bufferedElements; + + BeamFnDataSizeBasedBufferingOutboundObserver( + int sizeLimit, + LogicalEndpoint outputLocation, + Coder coder, + StreamObserver outboundObserver) { + this.sizeLimit = sizeLimit; + this.outputLocation = outputLocation; + this.coder = coder; + this.outboundObserver = outboundObserver; + this.bufferedElements = ByteString.newOutput(); + this.closed = false; + } + + @Override + public void close() throws Exception { + if (closed) { + throw new IllegalStateException("Already closed."); + } + closed = true; + BeamFnApi.Elements.Builder elements = convertBufferForTransmission(); + // This will add an empty data block representing the end of stream. + elements + .addDataBuilder() + .setInstructionId(outputLocation.getInstructionId()) + .setTransformId(outputLocation.getTransformId()); + + LOG.debug( + "Closing stream for instruction {} and " + + "transform {} having transmitted {} values {} bytes", + outputLocation.getInstructionId(), + outputLocation.getTransformId(), + counter, + byteCounter); + outboundObserver.onNext(elements.build()); + } + + @Override + public void flush() throws IOException { + if (bufferedElements.size() > 0) { + outboundObserver.onNext(convertBufferForTransmission().build()); + } + } + + @Override + public void accept(T t) throws IOException { + if (closed) { + throw new IllegalStateException("Already closed."); + } + coder.encode(t, bufferedElements); + counter += 1; + if (bufferedElements.size() >= sizeLimit) { + flush(); + } + } + + private BeamFnApi.Elements.Builder convertBufferForTransmission() { + BeamFnApi.Elements.Builder elements = BeamFnApi.Elements.newBuilder(); + if (bufferedElements.size() == 0) { + return elements; + } + + elements + .addDataBuilder() + .setInstructionId(outputLocation.getInstructionId()) + .setTransformId(outputLocation.getTransformId()) + .setData(bufferedElements.toByteString()); + + byteCounter += bufferedElements.size(); + bufferedElements.reset(); + return elements; + } +} diff --git a/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataTimeBasedBufferingOutboundObserver.java b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataTimeBasedBufferingOutboundObserver.java new file mode 100644 index 000000000000..3595fbd55971 --- /dev/null +++ b/sdks/java/fn-execution/src/main/java/org/apache/beam/sdk/fn/data/BeamFnDataTimeBasedBufferingOutboundObserver.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.fn.data; + +import java.io.IOException; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import org.apache.beam.model.fnexecution.v1.BeamFnApi; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.stub.StreamObserver; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; + +/** + * A buffering outbound {@link FnDataReceiver} with both size-based buffer and time-based buffer + * enabled for the Beam Fn Data API. + */ +public class BeamFnDataTimeBasedBufferingOutboundObserver + extends BeamFnDataSizeBasedBufferingOutboundObserver { + + @VisibleForTesting final ScheduledFuture flushFuture; + + BeamFnDataTimeBasedBufferingOutboundObserver( + int sizeLimit, + long timeLimit, + LogicalEndpoint outputLocation, + Coder coder, + StreamObserver outboundObserver) { + super(sizeLimit, outputLocation, coder, outboundObserver); + this.flushFuture = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("DataBufferOutboundFlusher-thread") + .build()) + .scheduleAtFixedRate(this::periodicFlush, timeLimit, timeLimit, TimeUnit.MILLISECONDS); + } + + @Override + public void close() throws Exception { + checkFlushThreadException(); + flushFuture.cancel(false); + try { + flushFuture.get(); + } catch (ExecutionException ee) { + unwrapExecutionException(ee); + } catch (CancellationException ce) { + // expected + } + super.close(); + } + + @Override + public synchronized void flush() throws IOException { + super.flush(); + } + + @Override + public void accept(T t) throws IOException { + checkFlushThreadException(); + super.accept(t); + } + + private void periodicFlush() { + try { + flush(); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + + /** Check if the flush thread failed with an exception. 
*/ + private void checkFlushThreadException() throws IOException { + if (flushFuture.isDone()) { + try { + flushFuture.get(); + throw new IOException("Periodic flushing thread finished unexpectedly."); + } catch (ExecutionException ee) { + unwrapExecutionException(ee); + } catch (CancellationException ce) { + throw new IOException(ce); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + throw new IOException(ie); + } + } + } + + private void unwrapExecutionException(ExecutionException ee) throws IOException { + // the cause is always RuntimeException + RuntimeException re = (RuntimeException) ee.getCause(); + if (re.getCause() instanceof IOException) { + throw (IOException) re.getCause(); + } else { + throw new IOException(re.getCause()); + } + } +} diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserverTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataSizeBasedBufferingOutboundObserverTest.java similarity index 80% rename from sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserverTest.java rename to sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataSizeBasedBufferingOutboundObserverTest.java index cb1752cde47f..0e53b26c3411 100644 --- a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataBufferingOutboundObserverTest.java +++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataSizeBasedBufferingOutboundObserverTest.java @@ -20,11 +20,13 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; @@ -34,6 +36,9 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.LengthPrefixCoder; import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -41,9 +46,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** Tests for {@link BeamFnDataBufferingOutboundObserver}. */ +/** Tests for {@link BeamFnDataSizeBasedBufferingOutboundObserver}. 
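A corresponding sketch for the time-based path, again mirroring the new tests (endpoint, coder and outboundObserver are assumed to be in scope; 100 ms is an arbitrary example value):

PipelineOptions options = PipelineOptionsFactory.create();
options
    .as(ExperimentalOptions.class)
    .setExperiments(Arrays.asList("beam_fn_api_data_buffer_time_limit=100"));
CloseableFnDataReceiver<WindowedValue<byte[]>> receiver =
    BeamFnDataBufferingOutboundObserver.forLocation(options, endpoint, coder, outboundObserver);
// Because the time limit is positive, forLocation returns a
// BeamFnDataTimeBasedBufferingOutboundObserver: a daemon "DataBufferOutboundFlusher"
// thread flushes the buffer every 100 ms, and any exception it raises is rethrown on
// the next accept() or on close().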
*/ @RunWith(JUnit4.class) -public class BeamFnDataBufferingOutboundObserverTest { +public class BeamFnDataSizeBasedBufferingOutboundObserverTest { private static final LogicalEndpoint OUTPUT_LOCATION = LogicalEndpoint.of("777L", "555L"); private static final Coder> CODER = LengthPrefixCoder.of(WindowedValue.getValueOnlyCoder(ByteArrayCoder.of())); @@ -54,37 +59,43 @@ public void testWithDefaultBuffer() throws Exception { final AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); CloseableFnDataReceiver> consumer = BeamFnDataBufferingOutboundObserver.forLocation( + PipelineOptionsFactory.create(), OUTPUT_LOCATION, CODER, TestStreams.withOnNext(addToValuesConsumer(values)) .withOnCompleted(setBooleanToTrue(onCompletedWasCalled)) .build()); + // Test that the time-based flush is disabled by default. + assertFalse(consumer instanceof BeamFnDataTimeBasedBufferingOutboundObserver); + // Test that nothing is emitted till the default buffer size is surpassed. consumer.accept( valueInGlobalWindow( - new byte[BeamFnDataBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50])); + new byte + [BeamFnDataSizeBasedBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50])); assertThat(values, empty()); // Test that when we cross the buffer, we emit. consumer.accept(valueInGlobalWindow(new byte[50])); assertEquals( messageWithData( - new byte[BeamFnDataBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50], + new byte[BeamFnDataSizeBasedBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50], new byte[50]), Iterables.get(values, 0)); // Test that nothing is emitted till the default buffer size is surpassed after a reset consumer.accept( valueInGlobalWindow( - new byte[BeamFnDataBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50])); + new byte + [BeamFnDataSizeBasedBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50])); assertEquals(1, values.size()); // Test that when we cross the buffer, we emit. consumer.accept(valueInGlobalWindow(new byte[50])); assertEquals( messageWithData( - new byte[BeamFnDataBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50], + new byte[BeamFnDataSizeBasedBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50], new byte[50]), Iterables.get(values, 1)); @@ -96,7 +107,8 @@ public void testWithDefaultBuffer() throws Exception { try { consumer.accept( valueInGlobalWindow( - new byte[BeamFnDataBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50])); + new byte + [BeamFnDataSizeBasedBufferingOutboundObserver.DEFAULT_BUFFER_LIMIT_BYTES - 50])); fail("Writing after close should be prohibited."); } catch (IllegalStateException exn) { // expected @@ -115,9 +127,13 @@ public void testWithDefaultBuffer() throws Exception { public void testConfiguredBufferLimit() throws Exception { Collection values = new ArrayList<>(); AtomicBoolean onCompletedWasCalled = new AtomicBoolean(); + PipelineOptions options = PipelineOptionsFactory.create(); + options + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("beam_fn_api_data_buffer_size_limit=100")); CloseableFnDataReceiver> consumer = - BeamFnDataBufferingOutboundObserver.forLocationWithBufferLimit( - 100, + BeamFnDataBufferingOutboundObserver.forLocation( + options, OUTPUT_LOCATION, CODER, TestStreams.withOnNext(addToValuesConsumer(values)) @@ -146,7 +162,7 @@ public void testConfiguredBufferLimit() throws Exception { Iterables.get(values, 1)); } - private static BeamFnApi.Elements messageWithData(byte[]... datum) throws IOException { + static BeamFnApi.Elements messageWithData(byte[]... 
datum) throws IOException { ByteString.Output output = ByteString.newOutput(); for (byte[] data : datum) { CODER.encode(valueInGlobalWindow(data), output); diff --git a/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataTimeBasedBufferingOutboundObserverTest.java b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataTimeBasedBufferingOutboundObserverTest.java new file mode 100644 index 000000000000..f4effa846147 --- /dev/null +++ b/sdks/java/fn-execution/src/test/java/org/apache/beam/sdk/fn/data/BeamFnDataTimeBasedBufferingOutboundObserverTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.fn.data; + +import static org.apache.beam.sdk.fn.data.BeamFnDataSizeBasedBufferingOutboundObserverTest.messageWithData; +import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.concurrent.CountDownLatch; +import java.util.function.Consumer; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.Elements; +import org.apache.beam.sdk.coders.ByteArrayCoder; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.fn.test.TestStreams; +import org.apache.beam.sdk.options.ExperimentalOptions; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link BeamFnDataTimeBasedBufferingOutboundObserver}. 
*/ +@RunWith(JUnit4.class) +public class BeamFnDataTimeBasedBufferingOutboundObserverTest { + private static final LogicalEndpoint OUTPUT_LOCATION = LogicalEndpoint.of("777L", "555L"); + private static final Coder> CODER = + LengthPrefixCoder.of(WindowedValue.getValueOnlyCoder(ByteArrayCoder.of())); + + @Test + public void testConfiguredTimeLimit() throws Exception { + Collection values = new ArrayList<>(); + PipelineOptions options = PipelineOptionsFactory.create(); + options + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("beam_fn_api_data_buffer_time_limit=1")); + final CountDownLatch waitForFlush = new CountDownLatch(1); + CloseableFnDataReceiver> consumer = + BeamFnDataBufferingOutboundObserver.forLocation( + options, + OUTPUT_LOCATION, + CODER, + TestStreams.withOnNext( + (Consumer) + e -> { + values.add(e); + waitForFlush.countDown(); + }) + .build()); + + // Test that it emits when time passed the time limit + consumer.accept(valueInGlobalWindow(new byte[1])); + waitForFlush.await(); // wait the flush thread to flush the buffer + assertEquals(messageWithData(new byte[1]), Iterables.get(values, 0)); + } + + @Test + public void testConfiguredTimeLimitExceptionPropagation() throws Exception { + PipelineOptions options = PipelineOptionsFactory.create(); + options + .as(ExperimentalOptions.class) + .setExperiments(Arrays.asList("beam_fn_api_data_buffer_time_limit=1")); + BeamFnDataTimeBasedBufferingOutboundObserver> consumer = + (BeamFnDataTimeBasedBufferingOutboundObserver>) + BeamFnDataBufferingOutboundObserver.forLocation( + options, + OUTPUT_LOCATION, + CODER, + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); + + // Test that it emits when time passed the time limit + consumer.accept(valueInGlobalWindow(new byte[1])); + // wait the flush thread to flush the buffer + while (!consumer.flushFuture.isDone()) { + Thread.sleep(1); + } + try { + // Test that the exception caught in the flush thread is propagate to + // the main thread when processing the next element + consumer.accept(valueInGlobalWindow(new byte[1])); + fail(); + } catch (Exception e) { + // expected + } + + consumer = + (BeamFnDataTimeBasedBufferingOutboundObserver>) + BeamFnDataBufferingOutboundObserver.forLocation( + options, + OUTPUT_LOCATION, + CODER, + TestStreams.withOnNext( + (Consumer) + e -> { + throw new RuntimeException(""); + }) + .build()); + consumer.accept(valueInGlobalWindow(new byte[1])); + // wait the flush thread to flush the buffer + while (!consumer.flushFuture.isDone()) { + Thread.sleep(1); + } + try { + // Test that the exception caught in the flush thread is propagate to + // the main thread when closing + consumer.close(); + fail(); + } catch (Exception e) { + // expected + } + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java index 48dbcf1629a2..e93b2ba0c7c0 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataReadRunner.java @@ -22,6 +22,7 @@ import com.google.auto.service.AutoService; import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -42,6 +43,7 @@ import 
org.apache.beam.sdk.fn.data.InboundDataClient; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -86,6 +88,7 @@ public BeamFnDataReadRunner createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) throws IOException { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java index 9d2e35229193..dbe5d98947db 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BeamFnDataWriteRunner.java @@ -22,6 +22,7 @@ import com.google.auto.service.AutoService; import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -42,6 +43,7 @@ import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -86,6 +88,7 @@ public BeamFnDataWriteRunner createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) throws IOException { RunnerApi.Coder coderSpec; diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java index 6342e1d9d838..a632aa28769d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/BoundedSourceRunner.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Collection; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.control.ProcessBundleHandler; @@ -36,6 +37,7 @@ import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ReadTranslation; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.Source.Reader; import org.apache.beam.sdk.options.PipelineOptions; @@ -80,6 +82,7 @@ public BoundedSourceRunner createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, 
BundleSplitListener splitListener) { ImmutableList.Builder>> consumers = ImmutableList.builder(); for (String pCollectionId : pTransform.getOutputsMap().values()) { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java index 71a807caba08..fbdb95da21a4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/CombineRunners.java @@ -20,6 +20,7 @@ import com.google.auto.service.AutoService; import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -37,6 +38,7 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.function.ThrowingFunction; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.util.SerializableUtils; @@ -125,6 +127,7 @@ public PrecombineRunner createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) throws IOException { // Get objects needed to create the runner. diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java index d548fc5e7c7c..433572b705d5 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/DoFnPTransformRunnerFactory.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -40,6 +41,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.schemas.SchemaCoder; import org.apache.beam.sdk.state.TimeDomain; @@ -76,6 +78,8 @@ void processTimer( String timerId, TimeDomain timeDomain, WindowedValue> input); void finishBundle() throws Exception; + + void tearDown() throws Exception; } @Override @@ -92,6 +96,7 @@ public final RunnerT createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) { Context context = new Context<>( @@ -139,6 +144,7 @@ public final RunnerT createRunnerForPTransform( } finishFunctionRegistry.register(pTransformId, runner::finishBundle); + tearDownFunctions.accept(runner::tearDown); return runner; } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java index 
020a8e82f356..702751c6bbc6 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FlattenRunner.java @@ -22,6 +22,7 @@ import com.google.auto.service.AutoService; import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -33,6 +34,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -65,6 +67,7 @@ public FlattenRunner createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) throws IOException { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index 61483360081d..a44629f6ccea 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -225,6 +225,11 @@ public void finishBundle() { this.stateAccessor = null; } + @Override + public void tearDown() { + doFnInvoker.invokeTeardown(); + } + /** Outputs the given element to the specified set of consumers wrapping any exceptions. 
*/ private void outputTo( Collection>> consumers, WindowedValue output) { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java index 252d902d7379..1aa5ba58ad2d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnHarness.java @@ -206,6 +206,7 @@ public static void main( LOG.info("Entering instruction processing loop"); control.processInstructionRequests(executorService); + processBundleHandler.shutdown(); } finally { System.out.println("Shutting SDK harness down."); executorService.shutdown(); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java index 4345d29a5b97..0cd5dcae0854 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/MapFnRunners.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -32,6 +33,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.function.ThrowingFunction; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; @@ -107,6 +109,7 @@ public Mapper createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) throws IOException { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java index ecd0f64844eb..389331f87c21 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/PTransformRunnerFactory.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.control.BundleSplitListener; import org.apache.beam.fn.harness.data.BeamFnDataClient; @@ -29,6 +30,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.Coder; import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; /** A factory able to instantiate an appropriate handler for a given PTransform. */ @@ -55,6 +57,7 @@ public interface PTransformRunnerFactory { * registered within this multimap. * @param startFunctionRegistry A class to register a start bundle handler with. * @param finishFunctionRegistry A class to register a finish bundle handler with. + * @param addTearDownFunction A consumer to register a tear down handler with. 
* @param splitListener A listener to be invoked when the PTransform splits itself. */ T createRunnerForPTransform( @@ -70,6 +73,7 @@ T createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer addTearDownFunction, BundleSplitListener splitListener) throws IOException; diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java index ec2875ee3a3f..7670a9a49d3a 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/SplittableProcessElementsRunner.java @@ -265,6 +265,11 @@ public void finishBundle() { doFnInvoker.invokeFinishBundle(finishBundleContext); } + @Override + public void tearDown() { + doFnInvoker.invokeTeardown(); + } + /** Outputs the given element to the specified set of consumers wrapping any exceptions. */ private void outputTo( Collection>> consumers, WindowedValue output) { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index 8b1c5fdedeef..11222859c495 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -17,15 +17,19 @@ */ package org.apache.beam.fn.harness.control; +import com.google.auto.value.AutoValue; import java.io.Closeable; import java.io.IOException; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.ServiceLoader; import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Phaser; +import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; @@ -41,7 +45,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.DelayedBundleApplication; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleRequest; -import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.Builder; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; @@ -66,6 +69,7 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.SetMultimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; @@ -87,7 +91,7 @@ public class ProcessBundleHandler { public static final String JAVA_SOURCE_URN = "beam:source:java:0.1"; private static final Logger LOG = 
LoggerFactory.getLogger(ProcessBundleHandler.class); - private static final Map REGISTERED_RUNNER_FACTORIES; + @VisibleForTesting static final Map REGISTERED_RUNNER_FACTORIES; static { Set pipelineRunnerRegistrars = @@ -109,6 +113,7 @@ public class ProcessBundleHandler { private final BeamFnStateGrpcClientCache beamFnStateGrpcClientCache; private final Map urnToPTransformRunnerFactoryMap; private final PTransformRunnerFactory defaultPTransformRunnerFactory; + @VisibleForTesting final BundleProcessorCache bundleProcessorCache; public ProcessBundleHandler( PipelineOptions options, @@ -120,7 +125,8 @@ public ProcessBundleHandler( fnApiRegistry, beamFnDataClient, beamFnStateGrpcClientCache, - REGISTERED_RUNNER_FACTORIES); + REGISTERED_RUNNER_FACTORIES, + new BundleProcessorCache()); } @VisibleForTesting @@ -129,7 +135,8 @@ public ProcessBundleHandler( Function fnApiRegistry, BeamFnDataClient beamFnDataClient, BeamFnStateGrpcClientCache beamFnStateGrpcClientCache, - Map urnToPTransformRunnerFactoryMap) { + Map urnToPTransformRunnerFactoryMap, + BundleProcessorCache bundleProcessorCache) { this.options = options; this.fnApiRegistry = fnApiRegistry; this.beamFnDataClient = beamFnDataClient; @@ -137,6 +144,7 @@ public ProcessBundleHandler( this.urnToPTransformRunnerFactoryMap = urnToPTransformRunnerFactoryMap; this.defaultPTransformRunnerFactory = new UnknownPTransformRunnerFactory(urnToPTransformRunnerFactoryMap.keySet()); + this.bundleProcessorCache = bundleProcessorCache; } private void createRunnerAndConsumersForPTransformRecursively( @@ -151,6 +159,7 @@ private void createRunnerAndConsumersForPTransformRecursively( Set processedPTransformIds, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer addTearDownFunction, BundleSplitListener splitListener) throws IOException { @@ -172,6 +181,7 @@ private void createRunnerAndConsumersForPTransformRecursively( processedPTransformIds, startFunctionRegistry, finishFunctionRegistry, + addTearDownFunction, splitListener); } } @@ -204,6 +214,7 @@ private void createRunnerAndConsumersForPTransformRecursively( pCollectionConsumerRegistry, startFunctionRegistry, finishFunctionRegistry, + addTearDownFunction, splitListener); processedPTransformIds.add(pTransformId); } @@ -215,12 +226,86 @@ private void createRunnerAndConsumersForPTransformRecursively( */ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.InstructionRequest request) throws Exception { + BeamFnApi.ProcessBundleResponse.Builder response = BeamFnApi.ProcessBundleResponse.newBuilder(); + + BundleProcessor bundleProcessor = + bundleProcessorCache.get( + request.getProcessBundle().getProcessBundleDescriptorId(), + () -> { + try { + return createBundleProcessor( + request.getProcessBundle().getProcessBundleDescriptorId(), + request.getProcessBundle()); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + bundleProcessor.setInstructionId(request.getInstructionId()); + PTransformFunctionRegistry startFunctionRegistry = bundleProcessor.getStartFunctionRegistry(); + PTransformFunctionRegistry finishFunctionRegistry = bundleProcessor.getFinishFunctionRegistry(); + Multimap allResiduals = bundleProcessor.getAllResiduals(); + PCollectionConsumerRegistry pCollectionConsumerRegistry = + bundleProcessor.getpCollectionConsumerRegistry(); + MetricsContainerStepMap metricsContainerRegistry = + bundleProcessor.getMetricsContainerRegistry(); + ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); + 
QueueingBeamFnDataClient queueingClient = bundleProcessor.getQueueingClient(); + + try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { + try (Closeable closeTracker = stateTracker.activate()) { + // Already in reverse topological order so we don't need to do anything. + for (ThrowingRunnable startFunction : startFunctionRegistry.getFunctions()) { + LOG.debug("Starting function {}", startFunction); + startFunction.run(); + } + + queueingClient.drainAndBlock(); + + // Need to reverse this since we want to call finish in topological order. + for (ThrowingRunnable finishFunction : + Lists.reverse(finishFunctionRegistry.getFunctions())) { + LOG.debug("Finishing function {}", finishFunction); + finishFunction.run(); + } + if (!allResiduals.isEmpty()) { + response.addAllResidualRoots(allResiduals.values()); + } + } + // Get start bundle Execution Time Metrics. + for (MonitoringInfo mi : startFunctionRegistry.getExecutionTimeMonitoringInfos()) { + response.addMonitoringInfos(mi); + } + // Get process bundle Execution Time Metrics. + for (MonitoringInfo mi : pCollectionConsumerRegistry.getExecutionTimeMonitoringInfos()) { + response.addMonitoringInfos(mi); + } + + // Get finish bundle Execution Time Metrics. + for (MonitoringInfo mi : finishFunctionRegistry.getExecutionTimeMonitoringInfos()) { + response.addMonitoringInfos(mi); + } + // Extract all other MonitoringInfos other than the execution time monitoring infos. + for (MonitoringInfo mi : metricsContainerRegistry.getMonitoringInfos()) { + response.addMonitoringInfos(mi); + } + bundleProcessorCache.release( + request.getProcessBundle().getProcessBundleDescriptorId(), bundleProcessor); + } + return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response); + } + + /** Shutdown the bundles, running the tearDown() functions. */ + public void shutdown() throws Exception { + bundleProcessorCache.shutdown(); + } + + private BundleProcessor createBundleProcessor( + String bundleId, BeamFnApi.ProcessBundleRequest processBundleRequest) throws IOException { // Note: We must create one instance of the QueueingBeamFnDataClient as it is designed to // handle the life of a bundle. It will insert elements onto a queue and drain them off so all // process() calls will execute on this thread when queueingClient.drainAndBlock() is called. QueueingBeamFnDataClient queueingClient = new QueueingBeamFnDataClient(this.beamFnDataClient); - String bundleId = request.getProcessBundle().getProcessBundleDescriptorId(); BeamFnApi.ProcessBundleDescriptor bundleDescriptor = (BeamFnApi.ProcessBundleDescriptor) fnApiRegistry.apply(bundleId); @@ -238,6 +323,7 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( metricsContainerRegistry, stateTracker, ExecutionStateTracker.FINISH_STATE_NAME); + List tearDownFunctions = new ArrayList<>(); // Build a multimap of PCollection ids to PTransform ids which consume said PCollections for (Map.Entry entry : @@ -247,99 +333,191 @@ public BeamFnApi.InstructionResponse.Builder processBundle(BeamFnApi.Instruction } } - ProcessBundleResponse.Builder response = ProcessBundleResponse.newBuilder(); + Multimap allResiduals = ArrayListMultimap.create(); // Instantiate a State API call handler depending on whether a State Api service descriptor // was specified. 
- try (HandleStateCallsForBundle beamFnStateClient = + HandleStateCallsForBundle beamFnStateClient = bundleDescriptor.hasStateApiServiceDescriptor() ? new BlockTillStateCallsFinish( beamFnStateGrpcClientCache.forApiServiceDescriptor( bundleDescriptor.getStateApiServiceDescriptor())) - : new FailAllStateCallsForBundle(request.getProcessBundle())) { - Multimap allPrimaries = ArrayListMultimap.create(); - Multimap allResiduals = ArrayListMultimap.create(); - BundleSplitListener splitListener = - (List primaries, List residuals) -> { - // Reset primaries and accumulate residuals. - Multimap newPrimaries = ArrayListMultimap.create(); - for (BundleApplication primary : primaries) { - newPrimaries.put(primary.getTransformId(), primary); - } - allPrimaries.clear(); - allPrimaries.putAll(newPrimaries); - - for (DelayedBundleApplication residual : residuals) { - allResiduals.put(residual.getApplication().getTransformId(), residual); - } - }; - - // Create a BeamFnStateClient - for (Map.Entry entry : - bundleDescriptor.getTransformsMap().entrySet()) { - - // Skip anything which isn't a root - // TODO: Remove source as a root and have it be triggered by the Runner. - if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn()) - && !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn()) - && !PTransformTranslation.READ_TRANSFORM_URN.equals( - entry.getValue().getSpec().getUrn())) { - continue; - } - - createRunnerAndConsumersForPTransformRecursively( - beamFnStateClient, - queueingClient, - entry.getKey(), - entry.getValue(), - request::getInstructionId, - bundleDescriptor, - pCollectionIdsToConsumingPTransforms, - pCollectionConsumerRegistry, - processedPTransformIds, + : new FailAllStateCallsForBundle(processBundleRequest); + Multimap allPrimaries = ArrayListMultimap.create(); + BundleSplitListener splitListener = + (List primaries, List residuals) -> { + // Reset primaries and accumulate residuals. + Multimap newPrimaries = ArrayListMultimap.create(); + for (BundleApplication primary : primaries) { + newPrimaries.put(primary.getTransformId(), primary); + } + allPrimaries.clear(); + allPrimaries.putAll(newPrimaries); + + for (DelayedBundleApplication residual : residuals) { + allResiduals.put(residual.getApplication().getTransformId(), residual); + } + }; + + BundleProcessor bundleProcessor = + BundleProcessor.create( startFunctionRegistry, finishFunctionRegistry, - splitListener); + tearDownFunctions, + allResiduals, + pCollectionConsumerRegistry, + metricsContainerRegistry, + stateTracker, + beamFnStateClient, + queueingClient); + + // Create a BeamFnStateClient + for (Map.Entry entry : + bundleDescriptor.getTransformsMap().entrySet()) { + + // Skip anything which isn't a root + // TODO: Remove source as a root and have it be triggered by the Runner. + if (!DATA_INPUT_URN.equals(entry.getValue().getSpec().getUrn()) + && !JAVA_SOURCE_URN.equals(entry.getValue().getSpec().getUrn()) + && !PTransformTranslation.READ_TRANSFORM_URN.equals( + entry.getValue().getSpec().getUrn())) { + continue; } - try (Closeable closeTracker = stateTracker.activate()) { - // Already in reverse topological order so we don't need to do anything. 
- for (ThrowingRunnable startFunction : startFunctionRegistry.getFunctions()) { - LOG.debug("Starting function {}", startFunction); - startFunction.run(); - } + createRunnerAndConsumersForPTransformRecursively( + beamFnStateClient, + queueingClient, + entry.getKey(), + entry.getValue(), + bundleProcessor::getInstructionId, + bundleDescriptor, + pCollectionIdsToConsumingPTransforms, + pCollectionConsumerRegistry, + processedPTransformIds, + startFunctionRegistry, + finishFunctionRegistry, + tearDownFunctions::add, + splitListener); + } + return bundleProcessor; + } - queueingClient.drainAndBlock(); + /** A cache for {@link BundleProcessor}s. */ + public static class BundleProcessorCache { - // Need to reverse this since we want to call finish in topological order. - for (ThrowingRunnable finishFunction : - Lists.reverse(finishFunctionRegistry.getFunctions())) { - LOG.debug("Finishing function {}", finishFunction); - finishFunction.run(); - } - if (!allResiduals.isEmpty()) { - response.addAllResidualRoots(allResiduals.values()); - } - } - // Get start bundle Execution Time Metrics. - for (MonitoringInfo mi : startFunctionRegistry.getExecutionTimeMonitoringInfos()) { - response.addMonitoringInfos(mi); - } - // Get process bundle Execution Time Metrics. - for (MonitoringInfo mi : pCollectionConsumerRegistry.getExecutionTimeMonitoringInfos()) { - response.addMonitoringInfos(mi); - } + private final Map> cachedBundleProcessors; - // Get finish bundle Execution Time Metrics. - for (MonitoringInfo mi : finishFunctionRegistry.getExecutionTimeMonitoringInfos()) { - response.addMonitoringInfos(mi); + BundleProcessorCache() { + this.cachedBundleProcessors = Maps.newConcurrentMap(); + } + + Map> getCachedBundleProcessors() { + return cachedBundleProcessors; + } + + /** + * Get a {@link BundleProcessor} from the cache if it's available. Otherwise, create one using + * the specified bundleProcessorSupplier. + */ + BundleProcessor get( + String bundleDescriptorId, Supplier bundleProcessorSupplier) { + ConcurrentLinkedQueue bundleProcessors = + cachedBundleProcessors.computeIfAbsent( + bundleDescriptorId, descriptorId -> new ConcurrentLinkedQueue<>()); + BundleProcessor bundleProcessor = bundleProcessors.poll(); + if (bundleProcessor != null) { + return bundleProcessor; } - // Extract all other MonitoringInfos other than the execution time monitoring infos. - for (MonitoringInfo mi : metricsContainerRegistry.getMonitoringInfos()) { - response.addMonitoringInfos(mi); + + return bundleProcessorSupplier.get(); + } + + /** + * Add a {@link BundleProcessor} to cache. The {@link BundleProcessor} will be reset before + * being added to the cache. + */ + void release(String bundleDescriptorId, BundleProcessor bundleProcessor) { + bundleProcessor.reset(); + cachedBundleProcessors.get(bundleDescriptorId).add(bundleProcessor); + } + + /** Shutdown all the cached {@link BundleProcessor}s, running the tearDown() functions. */ + void shutdown() throws Exception { + for (ConcurrentLinkedQueue bundleProcessors : + cachedBundleProcessors.values()) { + for (BundleProcessor bundleProcessor : bundleProcessors) { + for (ThrowingRunnable tearDownFunction : bundleProcessor.getTearDownFunctions()) { + LOG.debug("Tearing down function {}", tearDownFunction); + tearDownFunction.run(); + } + } } + cachedBundleProcessors.clear(); + } + } + + /** A container for the reusable information used to process a bundle. 
*/ + @AutoValue + public abstract static class BundleProcessor { + public static BundleProcessor create( + PTransformFunctionRegistry startFunctionRegistry, + PTransformFunctionRegistry finishFunctionRegistry, + List tearDownFunctions, + Multimap allResiduals, + PCollectionConsumerRegistry pCollectionConsumerRegistry, + MetricsContainerStepMap metricsContainerRegistry, + ExecutionStateTracker stateTracker, + HandleStateCallsForBundle beamFnStateClient, + QueueingBeamFnDataClient queueingClient) { + return new AutoValue_ProcessBundleHandler_BundleProcessor( + startFunctionRegistry, + finishFunctionRegistry, + tearDownFunctions, + allResiduals, + pCollectionConsumerRegistry, + metricsContainerRegistry, + stateTracker, + beamFnStateClient, + queueingClient); + } + + private String instructionId; + + abstract PTransformFunctionRegistry getStartFunctionRegistry(); + + abstract PTransformFunctionRegistry getFinishFunctionRegistry(); + + abstract List getTearDownFunctions(); + + abstract Multimap getAllResiduals(); + + abstract PCollectionConsumerRegistry getpCollectionConsumerRegistry(); + + abstract MetricsContainerStepMap getMetricsContainerRegistry(); + + abstract ExecutionStateTracker getStateTracker(); + + abstract HandleStateCallsForBundle getBeamFnStateClient(); + + abstract QueueingBeamFnDataClient getQueueingClient(); + + String getInstructionId() { + return this.instructionId; + } + + void setInstructionId(String instructionId) { + this.instructionId = instructionId; + } + + void reset() { + getStartFunctionRegistry().reset(); + getFinishFunctionRegistry().reset(); + getAllResiduals().clear(); + getpCollectionConsumerRegistry().reset(); + getMetricsContainerRegistry().reset(); + getStateTracker().reset(); + ExecutionStateSampler.instance().reset(); } - return BeamFnApi.InstructionResponse.newBuilder().setProcessBundle(response); } /** @@ -408,8 +586,7 @@ public void handle(Builder requestBuilder, CompletableFuture resp } } - private abstract static class HandleStateCallsForBundle - implements AutoCloseable, BeamFnStateClient {} + abstract static class HandleStateCallsForBundle implements AutoCloseable, BeamFnStateClient {} private static class UnknownPTransformRunnerFactory implements PTransformRunnerFactory { private final Set knownUrns; @@ -432,6 +609,7 @@ public Object createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer tearDownFunctions, BundleSplitListener splitListener) { String message = String.format( diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java index 375c9af9fda8..61c4580595f3 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClient.java @@ -17,9 +17,6 @@ */ package org.apache.beam.fn.harness.data; -import java.util.Collections; -import java.util.List; -import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.function.Function; @@ -35,7 +32,6 @@ import org.apache.beam.sdk.fn.data.InboundDataClient; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.stream.OutboundObserverFactory; -import org.apache.beam.sdk.options.ExperimentalOptions; import 
org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ManagedChannel; import org.slf4j.Logger; @@ -48,7 +44,6 @@ */ public class BeamFnDataGrpcClient implements BeamFnDataClient { private static final Logger LOG = LoggerFactory.getLogger(BeamFnDataGrpcClient.class); - private static final String BEAM_FN_API_DATA_BUFFER_LIMIT = "beam_fn_api_data_buffer_limit="; private final ConcurrentMap cache; private final Function channelFactory; @@ -112,26 +107,8 @@ public CloseableFnDataReceiver send( "Creating output consumer for instruction {} and transform {}", outputLocation.getInstructionId(), outputLocation.getTransformId()); - Optional bufferLimit = getBufferLimit(options); - if (bufferLimit.isPresent()) { - return BeamFnDataBufferingOutboundObserver.forLocationWithBufferLimit( - bufferLimit.get(), outputLocation, coder, client.getOutboundObserver()); - } else { - return BeamFnDataBufferingOutboundObserver.forLocation( - outputLocation, coder, client.getOutboundObserver()); - } - } - - /** Returns the {@code beam_fn_api_data_buffer_limit=} experiment value if set. */ - private static Optional getBufferLimit(PipelineOptions options) { - List experiments = options.as(ExperimentalOptions.class).getExperiments(); - for (String experiment : experiments == null ? Collections.emptyList() : experiments) { - if (experiment.startsWith(BEAM_FN_API_DATA_BUFFER_LIMIT)) { - return Optional.of( - Integer.parseInt(experiment.substring(BEAM_FN_API_DATA_BUFFER_LIMIT.length()))); - } - } - return Optional.empty(); + return BeamFnDataBufferingOutboundObserver.forLocation( + options, outputLocation, coder, client.getOutboundObserver()); } private BeamFnDataGrpcMultiplexer getClientFor( diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java index 0bbdd6787ad7..80d270fbb5b4 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java @@ -109,6 +109,11 @@ public void register( pCollectionIdsToConsumers.put(pCollectionId, (FnDataReceiver) wrapAndEnableMetricContainer); } + /** Reset the execution states of the registered functions. */ + public void reset() { + executionStates.reset(); + } + /** @return the list of pcollection ids. */ public Set keySet() { return pCollectionIdsToConsumers.keySet(); diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java index 26f1ade4c420..10d421cabd2d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java @@ -112,6 +112,11 @@ public void register(String pTransformId, ThrowingRunnable runnable) { runnables.add(wrapped); } + /** Reset the execution states of the registered functions. */ + public void reset() { + executionStates.reset(); + } + /** @return Execution Time MonitoringInfos based on the tracked start or finish function. 
*/ public List getExecutionTimeMonitoringInfos() { return executionStates.getExecutionTimeMonitoringInfos(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java index fc5e797e217c..e6e5297a07e1 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/AssignWindowsRunnerTest.java @@ -205,6 +205,7 @@ public Coder windowCoder() { pCollectionConsumerRegistry, null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ + null, /* tearDownRegistry */ null /* splitListener */); WindowedValue value = diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java index 18078107331e..e8a814c006ac 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataReadRunnerTest.java @@ -20,6 +20,7 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; @@ -59,6 +60,7 @@ import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead; import org.apache.beam.sdk.fn.test.TestExecutors; import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; @@ -144,6 +146,7 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); RunnerApi.PTransform pTransform = RemoteGrpcPortRead.readFromPort(PORT_SPEC, localOutputId).toPTransform(); @@ -164,8 +167,11 @@ public void testCreatingAndProcessingBeamFnDataReadRunner() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); + assertThat(teardownFunctions, empty()); + verifyZeroInteractions(mockBeamFnDataClient); InboundDataClient completionFuture = CompletableFutureInboundDataClient.create(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java index 0c04aab93436..5a77acf87557 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BeamFnDataWriteRunnerTest.java @@ -20,6 +20,7 @@ import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ 
-54,6 +55,7 @@ import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver; import org.apache.beam.sdk.fn.data.LogicalEndpoint; import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; @@ -125,6 +127,7 @@ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); String localInputId = "inputPC"; RunnerApi.PTransform pTransform = @@ -145,8 +148,11 @@ public void testCreatingAndProcessingBeamFnDataWriteRunner() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); + assertThat(teardownFunctions, empty()); + verifyZeroInteractions(mockBeamFnDataClient); List> outputValues = new ArrayList<>(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java index e5252d6dbd6a..bc31e9c1d771 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/BoundedSourceRunnerTest.java @@ -38,6 +38,7 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.CountingSource; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -141,6 +142,7 @@ public void testCreatingAndProcessingSourceFromFactory() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); RunnerApi.FunctionSpec functionSpec = RunnerApi.FunctionSpec.newBuilder() @@ -170,6 +172,7 @@ public void testCreatingAndProcessingSourceFromFactory() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); // This is testing a deprecated way of running sources and should be removed @@ -188,6 +191,7 @@ public void testCreatingAndProcessingSourceFromFactory() throws Exception { assertThat(outputValues, contains(valueInGlobalWindow(0L), valueInGlobalWindow(1L))); assertThat(finishFunctionRegistry.getFunctions(), Matchers.empty()); + assertThat(teardownFunctions, Matchers.empty()); } @Test diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java index f7adb6c62846..7f5be398a148 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/CombineRunnersTest.java @@ -156,6 +156,7 @@ public void testPrecombine() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + null, null); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -230,6 
+231,7 @@ public void testMergeAccumulators() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); @@ -292,6 +294,7 @@ public void testExtractOutputs() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); @@ -354,6 +357,7 @@ public void testCombineGroupedValues() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + null, null); assertThat(startFunctionRegistry.getFunctions(), empty()); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java index cb43c449c55b..dc78ea2b22c8 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FlattenRunnerTest.java @@ -91,6 +91,7 @@ public void testCreatingAndProcessingDoFlatten() throws Exception { consumers, null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ + null, /* tearDownRegistry */ null /* splitListener */); mainOutputValues.clear(); @@ -158,6 +159,7 @@ public void testFlattenWithDuplicateInputCollectionProducesMultipleOutputs() thr consumers, null /* startFunctionRegistry */, null, /* finishFunctionRegistry */ + null, /* tearDownRegistry */ null /* splitListener */); mainOutputValues.clear(); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java index 7a14b38012a4..45332836d583 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.fn.data.FnDataReceiver; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.MetricKey; import org.apache.beam.sdk.metrics.MetricName; @@ -211,6 +212,7 @@ public void testUsingUserState() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); new FnApiDoFnRunner.Factory<>() .createRunnerForPTransform( @@ -226,6 +228,7 @@ public void testUsingUserState() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -261,6 +264,9 @@ public void testUsingUserState() throws Exception { Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run(); assertThat(mainOutputValues, empty()); + Iterables.getOnlyElement(teardownFunctions).run(); + assertThat(mainOutputValues, empty()); + assertEquals( ImmutableMap.builder() .put(bagUserStateKey("value", "X"), encode("X2")) @@ -382,6 +388,7 @@ public void testBasicWithSideInputsAndOutputs() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); new 
FnApiDoFnRunner.Factory<>() .createRunnerForPTransform( @@ -397,6 +404,7 @@ public void testBasicWithSideInputsAndOutputs() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -433,6 +441,9 @@ public void testBasicWithSideInputsAndOutputs() throws Exception { Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run(); assertThat(mainOutputValues, empty()); + Iterables.getOnlyElement(teardownFunctions).run(); + assertThat(mainOutputValues, empty()); + // Assert that state data did not change assertEquals(stateData, fakeClient.getData()); mainOutputValues.clear(); @@ -516,6 +527,7 @@ public void testSideInputIsAccessibleForDownstreamCallers() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); new FnApiDoFnRunner.Factory<>() .createRunnerForPTransform( @@ -531,6 +543,7 @@ public void testSideInputIsAccessibleForDownstreamCallers() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -551,6 +564,13 @@ public void testSideInputIsAccessibleForDownstreamCallers() throws Exception { assertThat( mainOutputValues.get(1).getValue(), contains("iterableValue1B", "iterableValue2B", "iterableValue3B")); + mainOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run(); + assertThat(mainOutputValues, empty()); + + Iterables.getOnlyElement(teardownFunctions).run(); + assertThat(mainOutputValues, empty()); // Assert that state data did not change assertEquals(stateData, fakeClient.getData()); @@ -622,6 +642,7 @@ public void testUsingMetrics() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); new FnApiDoFnRunner.Factory<>() .createRunnerForPTransform( @@ -637,6 +658,7 @@ public void testUsingMetrics() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -650,6 +672,13 @@ public void testUsingMetrics() throws Exception { consumers.getMultiplexingConsumer(inputPCollectionId); mainInput.accept(valueInWindow("X", windowA)); mainInput.accept(valueInWindow("Y", windowB)); + mainOutputValues.clear(); + + Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run(); + assertThat(mainOutputValues, empty()); + + Iterables.getOnlyElement(teardownFunctions).run(); + assertThat(mainOutputValues, empty()); MetricsContainer mc = MetricsEnvironment.getCurrentContainer(); @@ -810,6 +839,7 @@ public void testTimers() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); new FnApiDoFnRunner.Factory<>() .createRunnerForPTransform( @@ -837,6 +867,7 @@ public void testTimers() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener 
*/); Iterables.getOnlyElement(startFunctionRegistry.getFunctions()).run(); @@ -914,6 +945,9 @@ public void testTimers() throws Exception { Iterables.getOnlyElement(finishFunctionRegistry.getFunctions()).run(); assertThat(mainOutputValues, empty()); + Iterables.getOnlyElement(teardownFunctions).run(); + assertThat(mainOutputValues, empty()); + assertEquals( ImmutableMap.builder() .put(bagUserStateKey("bag", "X"), encode("X0", "X1", "X2", "processing")) diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java index c408ee5d015e..956e295dd31e 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/MapFnRunnersTest.java @@ -36,6 +36,7 @@ import org.apache.beam.runners.core.metrics.ExecutionStateTracker; import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.sdk.function.ThrowingFunction; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.IntervalWindow; @@ -74,6 +75,7 @@ public void testValueOnlyMapping() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( metricsContainerRegistry, mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); ValueMapFnFactory factory = (ptId, pt) -> String::toUpperCase; MapFnRunners.forValueMapFnFactory(factory) @@ -90,10 +92,12 @@ public void testValueOnlyMapping() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); + assertThat(teardownFunctions, empty()); assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); @@ -117,6 +121,7 @@ public void testFullWindowedValueMapping() throws Exception { PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( metricsContainerRegistry, mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); MapFnRunners.forWindowedValueMapFnFactory(this::createMapFunctionForPTransform) .createRunnerForPTransform( @@ -132,10 +137,12 @@ public void testFullWindowedValueMapping() throws Exception { consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); + assertThat(teardownFunctions, empty()); assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); @@ -158,6 +165,7 @@ public void testFullWindowedValueMappingWithCompressedWindow() throws Exception PTransformFunctionRegistry finishFunctionRegistry = new PTransformFunctionRegistry( mock(MetricsContainerStepMap.class), mock(ExecutionStateTracker.class), "finish"); + List teardownFunctions = new ArrayList<>(); MapFnRunners.forWindowedValueMapFnFactory(this::createMapFunctionForPTransform) .createRunnerForPTransform( @@ -173,10 +181,12 @@ public void testFullWindowedValueMappingWithCompressedWindow() throws Exception consumers, startFunctionRegistry, finishFunctionRegistry, + teardownFunctions::add, null /* splitListener */); 
assertThat(startFunctionRegistry.getFunctions(), empty()); assertThat(finishFunctionRegistry.getFunctions(), empty()); + assertThat(teardownFunctions, empty()); assertThat(consumers.keySet(), containsInAnyOrder("inputPC", "outputPC")); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 54c1d1e3a17f..1a460e6b0019 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -17,25 +17,35 @@ */ package org.apache.beam.fn.harness.control; +import static org.apache.beam.fn.harness.control.ProcessBundleHandler.REGISTERED_RUNNER_FACTORIES; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import java.io.IOException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.fn.harness.PTransformRunnerFactory; +import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor; +import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessorCache; import org.apache.beam.fn.harness.data.BeamFnDataClient; import org.apache.beam.fn.harness.data.PCollectionConsumerRegistry; import org.apache.beam.fn.harness.data.PTransformFunctionRegistry; +import org.apache.beam.fn.harness.data.QueueingBeamFnDataClient; import org.apache.beam.fn.harness.state.BeamFnStateClient; import org.apache.beam.fn.harness.state.BeamFnStateGrpcClientCache; import org.apache.beam.model.fnexecution.v1.BeamFnApi; @@ -47,12 +57,29 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform; import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy; +import org.apache.beam.runners.core.construction.CoderTranslation; +import org.apache.beam.runners.core.construction.ModelCoders; +import org.apache.beam.runners.core.construction.PTransformTranslation; +import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.core.metrics.ExecutionStateTracker; +import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.function.ThrowingConsumer; +import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.DoFnWithExecutionInformation; +import org.apache.beam.sdk.util.SerializableUtils; import 
org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.Message; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles; import org.junit.Before; import org.junit.Rule; @@ -82,6 +109,124 @@ public void setUp() { MockitoAnnotations.initMocks(this); } + private static class TestDoFn extends DoFn { + private static final TupleTag mainOutput = new TupleTag<>("mainOutput"); + + static List orderOfOperations = new ArrayList<>(); + + private enum State { + NOT_SET_UP, + SET_UP, + START_BUNDLE, + FINISH_BUNDLE, + TEAR_DOWN + } + + private TestDoFn.State state = TestDoFn.State.NOT_SET_UP; + + @Setup + public void setUp() { + checkState(TestDoFn.State.NOT_SET_UP.equals(state), "Unexpected state: %s", state); + state = TestDoFn.State.SET_UP; + orderOfOperations.add("setUp"); + } + + @Teardown + public void tearDown() { + checkState(!TestDoFn.State.TEAR_DOWN.equals(state), "Unexpected state: %s", state); + state = TestDoFn.State.TEAR_DOWN; + orderOfOperations.add("tearDown"); + } + + @StartBundle + public void startBundle() { + state = TestDoFn.State.START_BUNDLE; + orderOfOperations.add("startBundle"); + } + + @ProcessElement + public void processElement(ProcessContext context, BoundedWindow window) { + checkState(TestDoFn.State.START_BUNDLE.equals(state), "Unexpected state: %s", state); + } + + @FinishBundle + public void finishBundle(FinishBundleContext context) { + checkState(TestDoFn.State.START_BUNDLE.equals(state), "Unexpected state: %s", state); + state = TestDoFn.State.FINISH_BUNDLE; + orderOfOperations.add("finishBundle"); + } + } + + private static class TestBundleProcessor extends BundleProcessor { + static int resetCnt = 0; + + private BundleProcessor wrappedBundleProcessor; + + TestBundleProcessor(BundleProcessor wrappedBundleProcessor) { + this.wrappedBundleProcessor = wrappedBundleProcessor; + } + + @Override + PTransformFunctionRegistry getStartFunctionRegistry() { + return wrappedBundleProcessor.getStartFunctionRegistry(); + } + + @Override + PTransformFunctionRegistry getFinishFunctionRegistry() { + return wrappedBundleProcessor.getFinishFunctionRegistry(); + } + + @Override + List getTearDownFunctions() { + return wrappedBundleProcessor.getTearDownFunctions(); + } + + @Override + Multimap getAllResiduals() { + return wrappedBundleProcessor.getAllResiduals(); + } + + @Override + PCollectionConsumerRegistry getpCollectionConsumerRegistry() { + return wrappedBundleProcessor.getpCollectionConsumerRegistry(); + } + + @Override + MetricsContainerStepMap getMetricsContainerRegistry() { + return wrappedBundleProcessor.getMetricsContainerRegistry(); + } + + @Override + ExecutionStateTracker getStateTracker() { + return wrappedBundleProcessor.getStateTracker(); + } + + @Override + ProcessBundleHandler.HandleStateCallsForBundle getBeamFnStateClient() { + return wrappedBundleProcessor.getBeamFnStateClient(); + } + + @Override + QueueingBeamFnDataClient getQueueingClient() { + return wrappedBundleProcessor.getQueueingClient(); + } + + @Override + void reset() { + resetCnt++; + wrappedBundleProcessor.reset(); + } + } + + private static class 
TestBundleProcessorCache extends BundleProcessorCache { + + @Override + BundleProcessor get( + String bundleDescriptorId, Supplier bundleProcessorSupplier) { + return new TestBundleProcessor(super.get(bundleDescriptorId, bundleProcessorSupplier)); + } + } + @Test public void testOrderOfStartAndFinishCalls() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = @@ -118,14 +263,21 @@ public void testOrderOfStartAndFinishCalls() throws Exception { pCollectionConsumerRegistry, startFunctionRegistry, finishFunctionRegistry, + addTearDownFunction, splitListener) -> { - assertThat(processBundleInstructionId.get(), equalTo("999L")); - transformsProcessed.add(pTransform); startFunctionRegistry.register( - pTransformId, () -> orderOfOperations.add("Start" + pTransformId)); + pTransformId, + () -> { + assertThat(processBundleInstructionId.get(), equalTo("999L")); + orderOfOperations.add("Start" + pTransformId); + }); finishFunctionRegistry.register( - pTransformId, () -> orderOfOperations.add("Finish" + pTransformId)); + pTransformId, + () -> { + assertThat(processBundleInstructionId.get(), equalTo("999L")); + orderOfOperations.add("Finish" + pTransformId); + }); return null; }; @@ -137,7 +289,8 @@ public void testOrderOfStartAndFinishCalls() throws Exception { null /* beamFnStateClient */, ImmutableMap.of( DATA_INPUT_URN, startFinishRecorder, - DATA_OUTPUT_URN, startFinishRecorder)); + DATA_OUTPUT_URN, startFinishRecorder), + new BundleProcessorCache()); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() @@ -156,6 +309,217 @@ public void testOrderOfStartAndFinishCalls() throws Exception { assertThat(orderOfOperations, contains("Start3L", "Start2L", "Finish2L", "Finish3L")); } + @Test + public void testOrderOfSetupTeardownCalls() throws Exception { + DoFnWithExecutionInformation doFnWithExecutionInformation = + DoFnWithExecutionInformation.of( + new TestDoFn(), + TestDoFn.mainOutput, + Collections.emptyMap(), + DoFnSchemaInformation.create()); + RunnerApi.FunctionSpec functionSpec = + RunnerApi.FunctionSpec.newBuilder() + .setUrn(ParDoTranslation.CUSTOM_JAVA_DO_FN_URN) + .setPayload( + ByteString.copyFrom( + SerializableUtils.serializeToByteArray(doFnWithExecutionInformation))) + .build(); + RunnerApi.ParDoPayload parDoPayload = + RunnerApi.ParDoPayload.newBuilder() + .setDoFn(RunnerApi.SdkFunctionSpec.newBuilder().setSpec(functionSpec)) + .build(); + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .putTransforms( + "2L", + PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .putOutputs("2L-output", "2L-output-pc") + .build()) + .putTransforms( + "3L", + PTransform.newBuilder() + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN) + .setPayload(parDoPayload.toByteString())) + .putInputs("3L-input", "2L-output-pc") + .build()) + .putPcollections( + "2L-output-pc", + PCollection.newBuilder() + .setWindowingStrategyId("window-strategy") + .setCoderId("2L-output-coder") + .build()) + .putWindowingStrategies( + "window-strategy", + WindowingStrategy.newBuilder() + .setWindowCoderId("window-strategy-coder") + .setWindowFn( + RunnerApi.SdkFunctionSpec.newBuilder() + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn("beam:windowfn:global_windows:v0.1")) + .build()) + .setOutputTime(RunnerApi.OutputTime.Enum.END_OF_WINDOW) + .setAccumulationMode(RunnerApi.AccumulationMode.Enum.ACCUMULATING) + 
.setTrigger( + RunnerApi.Trigger.newBuilder() + .setAlways(RunnerApi.Trigger.Always.getDefaultInstance())) + .setClosingBehavior(RunnerApi.ClosingBehavior.Enum.EMIT_ALWAYS) + .setOnTimeBehavior(RunnerApi.OnTimeBehavior.Enum.FIRE_ALWAYS) + .build()) + .putCoders("2L-output-coder", CoderTranslation.toProto(StringUtf8Coder.of()).getCoder()) + .putCoders( + "window-strategy-coder", + Coder.newBuilder() + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(ModelCoders.GLOBAL_WINDOW_CODER_URN) + .build()) + .build()) + .build(); + Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); + + Map urnToPTransformRunnerFactoryMap = + Maps.newHashMap(REGISTERED_RUNNER_FACTORIES); + urnToPTransformRunnerFactoryMap.put( + DATA_INPUT_URN, + (pipelineOptions, + beamFnDataClient, + beamFnStateClient, + pTransformId, + pTransform, + processBundleInstructionId, + pCollections, + coders, + windowingStrategies, + pCollectionConsumerRegistry, + startFunctionRegistry, + finishFunctionRegistry, + addTearDownFunction, + splitListener) -> null); + + ProcessBundleHandler handler = + new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + null /* beamFnStateClient */, + urnToPTransformRunnerFactoryMap, + new BundleProcessorCache()); + + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId("998L") + .setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) + .build()); + + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId("999L") + .setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) + .build()); + + handler.shutdown(); + + // setup and teardown should occur only once when processing multiple bundles for the same + // descriptor + assertThat( + TestDoFn.orderOfOperations, + contains( + "setUp", "startBundle", "finishBundle", "startBundle", "finishBundle", "tearDown")); + } + + @Test + public void testBundleProcessorIsResetWhenAddedBackToCache() throws Exception { + BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = + BeamFnApi.ProcessBundleDescriptor.newBuilder() + .putTransforms( + "2L", + RunnerApi.PTransform.newBuilder() + .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn(DATA_INPUT_URN).build()) + .build()) + .build(); + Map fnApiRegistry = ImmutableMap.of("1L", processBundleDescriptor); + + ProcessBundleHandler handler = + new ProcessBundleHandler( + PipelineOptionsFactory.create(), + fnApiRegistry::get, + beamFnDataClient, + null /* beamFnStateGrpcClientCache */, + ImmutableMap.of( + DATA_INPUT_URN, + (pipelineOptions, + beamFnDataClient, + beamFnStateClient, + pTransformId, + pTransform, + processBundleInstructionId, + pCollections, + coders, + windowingStrategies, + pCollectionConsumerRegistry, + startFunctionRegistry, + finishFunctionRegistry, + addTearDownFunction, + splitListener) -> null), + new TestBundleProcessorCache()); + + assertThat(TestBundleProcessor.resetCnt, equalTo(0)); + + handler.processBundle( + BeamFnApi.InstructionRequest.newBuilder() + .setInstructionId("998L") + .setProcessBundle( + BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) + .build()); + + // Check that BundleProcessor is reset when added back to the cache + assertThat(TestBundleProcessor.resetCnt, equalTo(1)); + + // BundleProcessor is added back to the BundleProcessorCache + assertThat(handler.bundleProcessorCache.getCachedBundleProcessors().size(), equalTo(1)); 
+ assertThat( + handler.bundleProcessorCache.getCachedBundleProcessors().get("1L").size(), equalTo(1)); + } + + @Test + public void testBundleProcessorReset() { + PTransformFunctionRegistry startFunctionRegistry = mock(PTransformFunctionRegistry.class); + PTransformFunctionRegistry finishFunctionRegistry = mock(PTransformFunctionRegistry.class); + Multimap allResiduals = mock(Multimap.class); + PCollectionConsumerRegistry pCollectionConsumerRegistry = + mock(PCollectionConsumerRegistry.class); + MetricsContainerStepMap metricsContainerRegistry = mock(MetricsContainerStepMap.class); + ExecutionStateTracker stateTracker = mock(ExecutionStateTracker.class); + ProcessBundleHandler.HandleStateCallsForBundle beamFnStateClient = + mock(ProcessBundleHandler.HandleStateCallsForBundle.class); + QueueingBeamFnDataClient queueingClient = mock(QueueingBeamFnDataClient.class); + BundleProcessor bundleProcessor = + BundleProcessor.create( + startFunctionRegistry, + finishFunctionRegistry, + new ArrayList<>(), + allResiduals, + pCollectionConsumerRegistry, + metricsContainerRegistry, + stateTracker, + beamFnStateClient, + queueingClient); + + bundleProcessor.reset(); + verify(startFunctionRegistry, times(1)).reset(); + verify(finishFunctionRegistry, times(1)).reset(); + verify(allResiduals, times(1)).clear(); + verify(pCollectionConsumerRegistry, times(1)).reset(); + verify(metricsContainerRegistry, times(1)).reset(); + verify(stateTracker, times(1)).reset(); + } + @Test public void testCreatingPTransformExceptionsArePropagated() throws Exception { BeamFnApi.ProcessBundleDescriptor processBundleDescriptor = @@ -188,11 +552,13 @@ public void testCreatingPTransformExceptionsArePropagated() throws Exception { pCollectionConsumerRegistry, startFunctionRegistry, finishFunctionRegistry, + addTearDownFunction, splitListener) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); throw new IllegalStateException("TestException"); - })); + }), + new BundleProcessorCache()); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setProcessBundle( @@ -233,18 +599,25 @@ public void testPTransformStartExceptionsArePropagated() throws Exception { pCollectionConsumerRegistry, startFunctionRegistry, finishFunctionRegistry, + addTearDownFunction, splitListener) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); startFunctionRegistry.register( pTransformId, ProcessBundleHandlerTest::throwException); return null; - })); + }), + new BundleProcessorCache()); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setProcessBundle( BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) .build()); + + // BundleProcessor is not re-added back to the BundleProcessorCache in case of an exception + // during bundle processing + assertThat( + handler.bundleProcessorCache.getCachedBundleProcessors(), equalTo(Collections.EMPTY_MAP)); } @Test @@ -280,18 +653,25 @@ public void testPTransformFinishExceptionsArePropagated() throws Exception { pCollectionConsumerRegistry, startFunctionRegistry, finishFunctionRegistry, + addTearDownFunction, splitListener) -> { thrown.expect(IllegalStateException.class); thrown.expectMessage("TestException"); finishFunctionRegistry.register( pTransformId, ProcessBundleHandlerTest::throwException); return null; - })); + }), + new BundleProcessorCache()); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setProcessBundle( 
BeamFnApi.ProcessBundleRequest.newBuilder().setProcessBundleDescriptorId("1L")) .build()); + + // BundleProcessor is not re-added back to the BundleProcessorCache in case of an exception + // during bundle processing + assertThat( + handler.bundleProcessorCache.getCachedBundleProcessors(), equalTo(Collections.EMPTY_MAP)); } @Test @@ -365,6 +745,7 @@ public Object createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer addTearDownFunction, BundleSplitListener splitListener) throws IOException { startFunctionRegistry.register( @@ -378,7 +759,8 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { beamFnStateClient.handle( StateRequest.newBuilder().setInstructionId("FAIL"), unsuccessfulResponse); } - })); + }), + new BundleProcessorCache()); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setProcessBundle( @@ -424,6 +806,7 @@ public Object createRunnerForPTransform( PCollectionConsumerRegistry pCollectionConsumerRegistry, PTransformFunctionRegistry startFunctionRegistry, PTransformFunctionRegistry finishFunctionRegistry, + Consumer addTearDownFunction, BundleSplitListener splitListener) throws IOException { startFunctionRegistry.register( @@ -438,7 +821,8 @@ private void doStateCalls(BeamFnStateClient beamFnStateClient) { StateRequest.newBuilder().setInstructionId("SUCCESS"), new CompletableFuture<>()); } - })); + }), + new BundleProcessorCache()); handler.processBundle( BeamFnApi.InstructionRequest.newBuilder() .setProcessBundle( diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java index 6ebd9615b8ca..672d41befaa4 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/BeamFnDataGrpcClientTest.java @@ -292,7 +292,7 @@ public StreamObserver data( BeamFnDataGrpcClient clientFactory = new BeamFnDataGrpcClient( PipelineOptionsFactory.fromArgs( - new String[] {"--experiments=beam_fn_api_data_buffer_limit=20"}) + new String[] {"--experiments=beam_fn_api_data_buffer_size_limit=20"}) .create(), (Endpoints.ApiServiceDescriptor descriptor) -> channel, OutboundObserverFactory.trivial()); diff --git a/sdks/java/io/amazon-web-services2/build.gradle b/sdks/java/io/amazon-web-services2/build.gradle index a52c827a393c..37465a284215 100644 --- a/sdks/java/io/amazon-web-services2/build.gradle +++ b/sdks/java/io/amazon-web-services2/build.gradle @@ -33,13 +33,16 @@ dependencies { compile library.java.aws_java_sdk2_dynamodb compile library.java.aws_java_sdk2_sdk_core compile library.java.aws_java_sdk2_sns + compile library.java.aws_java_sdk2_sqs compile library.java.jackson_core compile library.java.jackson_annotations compile library.java.jackson_databind compile library.java.slf4j_api + testCompile project(path: ":sdks:java:core", configuration: "shadowTest") testCompile library.java.hamcrest_core testCompile library.java.junit + testCompile 'org.elasticmq:elasticmq-rest-sqs_2.12:0.14.1' testCompile 'org.testcontainers:testcontainers:1.11.3' testRuntimeOnly library.java.slf4j_jdk14 testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") diff --git 
a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/BasicSqsClientProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/BasicSqsClientProvider.java new file mode 100644 index 000000000000..de5de2dd75a8 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/BasicSqsClientProvider.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import java.net.URI; +import javax.annotation.Nullable; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.SqsClientBuilder; + +/** Basic implementation of {@link SqsClientProvider} used by default in {@link SqsIO}. */ +class BasicSqsClientProvider implements SqsClientProvider { + private final AwsCredentialsProvider awsCredentialsProvider; + private final String region; + @Nullable private final URI serviceEndpoint; + + BasicSqsClientProvider( + AwsCredentialsProvider awsCredentialsProvider, String region, @Nullable URI serviceEndpoint) { + checkArgument(awsCredentialsProvider != null, "awsCredentialsProvider can not be null"); + checkArgument(region != null, "region can not be null"); + this.awsCredentialsProvider = awsCredentialsProvider; + this.region = region; + this.serviceEndpoint = serviceEndpoint; + } + + @Override + public SqsClient getSqsClient() { + SqsClientBuilder builder = + SqsClient.builder().credentialsProvider(awsCredentialsProvider).region(Region.of(region)); + + if (serviceEndpoint != null) { + builder.endpointOverride(serviceEndpoint); + } + + return builder.build(); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java new file mode 100644 index 000000000000..4cf5da34a0fd --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import software.amazon.awssdk.services.sqs.model.Message; + +/** Custom Coder for handling SendMessageRequest for using in Write. */ +public class MessageCoder extends AtomicCoder implements Serializable { + private static final MessageCoder INSTANCE = new MessageCoder(); + + private MessageCoder() {} + + static MessageCoder of() { + return INSTANCE; + } + + @Override + public void encode(Message value, OutputStream outStream) throws IOException { + StringUtf8Coder.of().encode(value.messageId(), outStream); + StringUtf8Coder.of().encode(value.body(), outStream); + } + + @Override + public Message decode(InputStream inStream) throws IOException { + final String messageId = StringUtf8Coder.of().decode(inStream); + final String body = StringUtf8Coder.of().decode(inStream); + return Message.builder().messageId(messageId).body(body).build(); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoderRegistrar.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoderRegistrar.java new file mode 100644 index 000000000000..0b72338dc141 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoderRegistrar.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import com.google.auto.service.AutoService; +import java.util.List; +import org.apache.beam.sdk.coders.CoderProvider; +import org.apache.beam.sdk.coders.CoderProviderRegistrar; +import org.apache.beam.sdk.coders.CoderProviders; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import software.amazon.awssdk.services.sqs.model.Message; + +/** A {@link CoderProviderRegistrar} for standard types used with {@link SqsIO}. 
*/ +@AutoService(CoderProviderRegistrar.class) +public class MessageCoderRegistrar implements CoderProviderRegistrar { + @Override + public List getCoderProviders() { + return ImmutableList.of( + CoderProviders.forCoder(TypeDescriptor.of(Message.class), MessageCoder.of())); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java new file mode 100644 index 000000000000..e8c0283317bc --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.Serializable; +import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; + +/** Custom Coder for handling SendMessageRequest for using in Write. */ +public class SendMessageRequestCoder extends AtomicCoder + implements Serializable { + private static final SendMessageRequestCoder INSTANCE = new SendMessageRequestCoder(); + + private SendMessageRequestCoder() {} + + static SendMessageRequestCoder of() { + return INSTANCE; + } + + @Override + public void encode(SendMessageRequest value, OutputStream outStream) throws IOException { + StringUtf8Coder.of().encode(value.queueUrl(), outStream); + StringUtf8Coder.of().encode(value.messageBody(), outStream); + } + + @Override + public SendMessageRequest decode(InputStream inStream) throws IOException { + final String queueUrl = StringUtf8Coder.of().decode(inStream); + final String message = StringUtf8Coder.of().decode(inStream); + return SendMessageRequest.builder().queueUrl(queueUrl).messageBody(message).build(); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java new file mode 100644 index 000000000000..814f1f34ad87 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import com.google.auto.service.AutoService; +import java.util.List; +import org.apache.beam.sdk.coders.CoderProvider; +import org.apache.beam.sdk.coders.CoderProviderRegistrar; +import org.apache.beam.sdk.coders.CoderProviders; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; + +/** A {@link CoderProviderRegistrar} for standard types used with {@link SqsIO}. */ +@AutoService(CoderProviderRegistrar.class) +public class SendMessageRequestCoderRegistrar implements CoderProviderRegistrar { + @Override + public List getCoderProviders() { + return ImmutableList.of( + CoderProviders.forCoder( + TypeDescriptor.of(SendMessageRequest.class), SendMessageRequestCoder.of())); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsCheckpointMark.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsCheckpointMark.java new file mode 100644 index 000000000000..a0ec56b7b528 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsCheckpointMark.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.io.Serializable; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Objects; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import software.amazon.awssdk.services.sqs.model.Message; + +class SqsCheckpointMark implements UnboundedSource.CheckpointMark, Serializable { + + private final List messagesToDelete; + private final transient Optional reader; + + SqsCheckpointMark(SqsUnboundedReader reader, Collection messagesToDelete) { + this.reader = Optional.of(reader); + this.messagesToDelete = ImmutableList.copyOf(messagesToDelete); + } + + @Override + public void finalizeCheckpoint() { + reader.ifPresent(r -> r.delete(messagesToDelete)); + } + + List getMessagesToDelete() { + return messagesToDelete; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SqsCheckpointMark that = (SqsCheckpointMark) o; + return Objects.equal(messagesToDelete, that.messagesToDelete); + } + + @Override + public int hashCode() { + return Objects.hashCode(messagesToDelete); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsClientProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsClientProvider.java new file mode 100644 index 000000000000..5d7d1525f4ef --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsClientProvider.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.io.Serializable; +import software.amazon.awssdk.services.sqs.SqsClient; + +/** + * Provides instances of Sqs clients. + * + *
<p>
Please note, that any instance of {@link SqsClientProvider} must be {@link Serializable} to + * ensure it can be sent to worker machines. + */ +public interface SqsClientProvider extends Serializable { + SqsClient getSqsClient(); +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java new file mode 100644 index 000000000000..7d92f892a9f3 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java @@ -0,0 +1,272 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.value.AutoValue; +import java.net.URI; +import javax.annotation.Nullable; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.io.aws2.options.AwsOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.joda.time.Duration; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; + +/** + * An unbounded source for Amazon Simple Queue Service (SQS). + * + *
<h3>Reading from an SQS queue</h3> + * + * <p>The {@link SqsIO} {@link Read} returns an unbounded {@link PCollection} of {@link + * software.amazon.awssdk.services.sqs.model.Message} containing the received messages. Note: This + * source does not currently advance the watermark when no new messages are received. + * + * <p>To configure an SQS source, you have to provide the queueUrl to connect to. The following + * example illustrates how to configure the source: + * + * <pre>{@code
+ * pipeline.apply(SqsIO.read().withQueueUrl(queueUrl))
+ * }</pre>
+ * + * <h3>Writing to an SQS queue</h3> + * + * <p>The following example illustrates how to use the sink: + * + * <pre>{@code
+ * pipeline
+ *   .apply(...) // returns PCollection<SendMessageRequest>
+ *   .apply(SqsIO.write())
+ * }</pre>
+ * + * <h3>Additional Configuration</h3> + * + * <p>Additional configuration can be provided via {@link AwsCredentialsProvider} in code. For + * example, if you wanted to provide a secret access key via code: + * + * <pre>{@code
+ * AwsCredentialsProvider provider = StaticCredentialsProvider.create(
+ *    AwsBasicCredentials.create(ACCESS_KEY_ID, SECRET_ACCESS_KEY));
+ * pipeline
+ *   .apply(...) // returns PCollection<SendMessageRequest>
+ *   .apply(SqsIO.write().withSqsClientProvider(provider))
+ * }</pre>
+ * + * <p>
For more information on the available options see {@link AwsOptions}. + */ +@Experimental(Experimental.Kind.SOURCE_SINK) +public class SqsIO { + + public static Read read() { + return new AutoValue_SqsIO_Read.Builder().setMaxNumRecords(Long.MAX_VALUE).build(); + } + + public static Write write() { + return new AutoValue_SqsIO_Write.Builder().build(); + } + + private SqsIO() {} + + /** + * A {@link PTransform} to read/receive messages from SQS. See {@link SqsIO} for more information + * on usage and configuration. + */ + @AutoValue + public abstract static class Read extends PTransform> { + + @Nullable + abstract String queueUrl(); + + abstract long maxNumRecords(); + + @Nullable + abstract Duration maxReadTime(); + + @Nullable + abstract SqsClientProvider sqsClientProvider(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setQueueUrl(String queueUrl); + + abstract Builder setMaxNumRecords(long maxNumRecords); + + abstract Builder setMaxReadTime(Duration maxReadTime); + + abstract Builder setSqsClientProvider(SqsClientProvider sqsClientProvider); + + abstract Read build(); + } + + /** + * Define the max number of records received by the {@link Read}. When the max number of records + * is lower than {@code Long.MAX_VALUE}, the {@link Read} will provide a bounded {@link + * PCollection}. + */ + public Read withMaxNumRecords(long maxNumRecords) { + return builder().setMaxNumRecords(maxNumRecords).build(); + } + + /** + * Define the max read time (duration) while the {@link Read} will receive messages. When this + * max read time is not null, the {@link Read} will provide a bounded {@link PCollection}. + */ + public Read withMaxReadTime(Duration maxReadTime) { + return builder().setMaxReadTime(maxReadTime).build(); + } + + /** Define the queueUrl used by the {@link Read} to receive messages from SQS. */ + public Read withQueueUrl(String queueUrl) { + checkArgument(queueUrl != null, "queueUrl can not be null"); + checkArgument(!queueUrl.isEmpty(), "queueUrl can not be empty"); + return builder().setQueueUrl(queueUrl).build(); + } + + /** + * Allows to specify custom {@link SqsClientProvider}. {@link SqsClientProvider} creates new + * {@link SqsClient} which is later used for writing to a SqS queue. + */ + public Read withSqsClientProvider(SqsClientProvider awsClientsProvider) { + return builder().setSqsClientProvider(awsClientsProvider).build(); + } + + /** + * Specify {@link software.amazon.awssdk.auth.credentials.AwsCredentialsProvider} and region to + * be used to read from SQS. If you need more sophisticated credential protocol, then you should + * look at {@link Read#withSqsClientProvider(SqsClientProvider)}. + */ + public Read withSqsClientProvider(AwsCredentialsProvider credentialsProvider, String region) { + return withSqsClientProvider(credentialsProvider, region, null); + } + + /** + * Specify {@link AwsCredentialsProvider} and region to be used to write to SQS. If you need + * more sophisticated credential protocol, then you should look at {@link + * Read#withSqsClientProvider(SqsClientProvider)}. + * + *
<p>
The {@code serviceEndpoint} sets an alternative service host. This is useful to execute + * the tests with Kinesis service emulator. + */ + public Read withSqsClientProvider( + AwsCredentialsProvider credentialsProvider, String region, URI serviceEndpoint) { + return withSqsClientProvider( + new BasicSqsClientProvider(credentialsProvider, region, serviceEndpoint)); + } + + @Override + public PCollection expand(PBegin input) { + + org.apache.beam.sdk.io.Read.Unbounded unbounded = + org.apache.beam.sdk.io.Read.from(new SqsUnboundedSource(this)); + + PTransform> transform = unbounded; + + if (maxNumRecords() < Long.MAX_VALUE || maxReadTime() != null) { + transform = unbounded.withMaxReadTime(maxReadTime()).withMaxNumRecords(maxNumRecords()); + } + + return input.getPipeline().apply(transform); + } + } + // TODO: Add write batch api to improve performance + /** + * A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage + * and configuration. + */ + @AutoValue + public abstract static class Write extends PTransform, PDone> { + + @Nullable + abstract SqsClientProvider getSqsClientProvider(); + + abstract Builder builder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setSqsClientProvider(SqsClientProvider sqsClientProvider); + + abstract Write build(); + } + + /** + * Allows to specify custom {@link SqsClientProvider}. {@link SqsClientProvider} creates new + * {@link SqsClient} which is later used for writing to a SqS queue. + */ + public Write withSqsClientProvider(SqsClientProvider awsClientsProvider) { + return builder().setSqsClientProvider(awsClientsProvider).build(); + } + + /** + * Specify {@link software.amazon.awssdk.auth.credentials.AwsCredentialsProvider} and region to + * be used to write to SQS. If you need more sophisticated credential protocol, then you should + * look at {@link Write#withSqsClientProvider(SqsClientProvider)}. + */ + public Write withSqsClientProvider(AwsCredentialsProvider credentialsProvider, String region) { + return withSqsClientProvider(credentialsProvider, region, null); + } + + /** + * Specify {@link AwsCredentialsProvider} and region to be used to write to SQS. If you need + * more sophisticated credential protocol, then you should look at {@link + * Write#withSqsClientProvider(SqsClientProvider)}. + * + *
<p>
The {@code serviceEndpoint} sets an alternative service host. This is useful to execute + * the tests with Kinesis service emulator. + */ + public Write withSqsClientProvider( + AwsCredentialsProvider credentialsProvider, String region, URI serviceEndpoint) { + return withSqsClientProvider( + new BasicSqsClientProvider(credentialsProvider, region, serviceEndpoint)); + } + + @Override + public PDone expand(PCollection input) { + input.apply(ParDo.of(new SqsWriteFn(this))); + return PDone.in(input.getPipeline()); + } + } + + private static class SqsWriteFn extends DoFn { + private final Write spec; + private transient SqsClient sqs; + + SqsWriteFn(Write write) { + this.spec = write; + } + + @Setup + public void setup() { + sqs = spec.getSqsClientProvider().getSqsClient(); + } + + @ProcessElement + public void processElement(ProcessContext processContext) throws Exception { + sqs.sendMessage(processContext.element()); + } + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsMessage.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsMessage.java new file mode 100644 index 000000000000..7d28dc546f48 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsMessage.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.aws2.sqs; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import javax.annotation.Nullable; + +@AutoValue +public abstract class SqsMessage implements Serializable { + + @Nullable + abstract String getBody(); + + @Nullable + abstract String getMessageId(); + + @Nullable + abstract String getTimeStamp(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setBody(String body); + + abstract Builder setMessageId(String messageId); + + abstract Builder setTimeStamp(String timeStamp); + + abstract SqsMessage build(); + } + + static SqsMessage create(String body, String messageId, String timeStamp) { + checkArgument(body != null, "body can not be null"); + checkArgument(messageId != null, "messageId can not be null"); + checkArgument(timeStamp != null, "timeStamp can not be null"); + + return new AutoValue_SqsMessage.Builder() + .setBody(body) + .setMessageId(messageId) + .setTimeStamp(timeStamp) + .build(); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsUnboundedReader.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsUnboundedReader.java new file mode 100644 index 000000000000..72e8d7a55637 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsUnboundedReader.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Queue; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.joda.time.Instant; +import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeName; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; + +class SqsUnboundedReader extends UnboundedSource.UnboundedReader + implements Serializable { + + public static final int MAX_NUMBER_OF_MESSAGES = 10; + private final SqsUnboundedSource source; + private SqsMessage current; + private final Queue messagesNotYetRead; + private List messagesToDelete; + private Instant oldestPendingTimestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + + public SqsUnboundedReader(SqsUnboundedSource source, SqsCheckpointMark sqsCheckpointMark) { + this.source = source; + this.current = null; + + this.messagesNotYetRead = new ArrayDeque<>(); + this.messagesToDelete = new ArrayList<>(); + + if (sqsCheckpointMark != null) { + this.messagesToDelete.addAll(sqsCheckpointMark.getMessagesToDelete()); + } + } + + @Override + public Instant getWatermark() { + return oldestPendingTimestamp; + } + + @Override + public SqsMessage getCurrent() throws NoSuchElementException { + if (current == null) { + throw new NoSuchElementException(); + } + return current; + } + + @Override + public Instant getCurrentTimestamp() throws NoSuchElementException { + if (current == null) { + throw new NoSuchElementException(); + } + + return getTimestamp(current.getTimeStamp()); + } + + @Override + public byte[] getCurrentRecordId() throws NoSuchElementException { + if (current == null) { + throw new NoSuchElementException(); + } + return current.getMessageId().getBytes(StandardCharsets.UTF_8); + } + + @Override + public CheckpointMark getCheckpointMark() { + return new SqsCheckpointMark(this, messagesToDelete); + } + + @Override + public SqsUnboundedSource getCurrentSource() { + return source; + } + + @Override + public boolean start() { + return advance(); + } + + @Override + public boolean advance() { + if (messagesNotYetRead.isEmpty()) { + pull(); + } + + Message orgMsg = messagesNotYetRead.poll(); + if (orgMsg != null) { + String timeStamp = + orgMsg.attributes().get(MessageSystemAttributeName.APPROXIMATE_FIRST_RECEIVE_TIMESTAMP); + current = SqsMessage.create(orgMsg.body(), orgMsg.messageId(), timeStamp); + } else { + return false; + } + + messagesToDelete.add(orgMsg); + + Instant currentMessageTimestamp = getCurrentTimestamp(); + if (getCurrentTimestamp().isBefore(oldestPendingTimestamp)) { + oldestPendingTimestamp = currentMessageTimestamp; + } + + return true; + } + + @Override + public void close() {} + + void delete(final Collection messages) { + for (Message message : messages) { + if (messagesToDelete.contains(message)) { + DeleteMessageRequest deleteMessageRequest = + DeleteMessageRequest.builder() + .queueUrl(source.getRead().queueUrl()) + .receiptHandle(message.receiptHandle()) + .build(); + + 
source.getSqs().deleteMessage(deleteMessageRequest); + Instant currentMessageTimestamp = + getTimestamp( + message + .attributes() + .get(MessageSystemAttributeName.APPROXIMATE_FIRST_RECEIVE_TIMESTAMP)); + if (currentMessageTimestamp.isAfter(oldestPendingTimestamp)) { + oldestPendingTimestamp = currentMessageTimestamp; + } + } + } + } + + private void pull() { + final ReceiveMessageRequest receiveMessageRequest = + ReceiveMessageRequest.builder() + .maxNumberOfMessages(MAX_NUMBER_OF_MESSAGES) + .attributeNamesWithStrings( + MessageSystemAttributeName.APPROXIMATE_FIRST_RECEIVE_TIMESTAMP.toString()) + .queueUrl(source.getRead().queueUrl()) + .build(); + + final ReceiveMessageResponse receiveMessageResponse = + source.getSqs().receiveMessage(receiveMessageRequest); + + final List messages = receiveMessageResponse.messages(); + + if (messages == null || messages.isEmpty()) { + return; + } + + messagesNotYetRead.addAll(messages); + } + + private Instant getTimestamp(String timeStamp) { + return new Instant(Long.parseLong(timeStamp)); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsUnboundedSource.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsUnboundedSource.java new file mode 100644 index 000000000000..b3c10db5cfd8 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsUnboundedSource.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.io.UnboundedSource; +import org.apache.beam.sdk.io.aws2.sqs.SqsIO.Read; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Suppliers; +import software.amazon.awssdk.services.sqs.SqsClient; + +class SqsUnboundedSource extends UnboundedSource { + private final Read read; + private final Supplier sqs; + + public SqsUnboundedSource(Read read) { + this.read = read; + sqs = + Suppliers.memoize( + (Supplier & Serializable) () -> read.sqsClientProvider().getSqsClient()); + } + + @Override + public List split(int desiredNumSplits, PipelineOptions options) { + List sources = new ArrayList<>(); + for (int i = 0; i < Math.max(1, desiredNumSplits); ++i) { + sources.add(new SqsUnboundedSource(read)); + } + return sources; + } + + @Override + public UnboundedReader createReader( + PipelineOptions options, @Nullable SqsCheckpointMark checkpointMark) { + return new SqsUnboundedReader(this, checkpointMark); + } + + @Override + public Coder getCheckpointMarkCoder() { + return SerializableCoder.of(SqsCheckpointMark.class); + } + + @Override + public Coder getOutputCoder() { + return SerializableCoder.of(SqsMessage.class); + } + + public Read getRead() { + return read; + } + + public SqsClient getSqs() { + return sqs.get(); + } + + @Override + public boolean requiresDeduping() { + return true; + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/package-info.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/package-info.java new file mode 100644 index 000000000000..601844d70c3b --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/package-info.java @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** Defines IO connectors for Amazon Web Services SQS. 
*/ +package org.apache.beam.sdk.io.aws2.sqs; diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/EmbeddedSqsServer.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/EmbeddedSqsServer.java new file mode 100644 index 000000000000..2b826931297f --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/EmbeddedSqsServer.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import java.net.URI; +import org.elasticmq.rest.sqs.SQSRestServer; +import org.elasticmq.rest.sqs.SQSRestServerBuilder; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.CreateQueueRequest; +import software.amazon.awssdk.services.sqs.model.CreateQueueResponse; + +class EmbeddedSqsServer { + private static SQSRestServer sqsRestServer; + private static SqsClient client; + private static String queueUrl; + private static int port = 9234; + private static String endPoint = String.format("http://localhost:%d", port); + private static String queueName = "test"; + + static void start() { + sqsRestServer = SQSRestServerBuilder.withPort(port).start(); + + client = + SqsClient.builder() + .credentialsProvider( + StaticCredentialsProvider.create(AwsBasicCredentials.create("x", "x"))) + .endpointOverride(URI.create(endPoint)) + .region(Region.US_WEST_2) + .build(); + + CreateQueueRequest createQueueRequest = + CreateQueueRequest.builder().queueName(queueName).build(); + final CreateQueueResponse queue = client.createQueue(createQueueRequest); + queueUrl = queue.queueUrl(); + } + + static SqsClient getClient() { + return client; + } + + static String getQueueUrl() { + return queueUrl; + } + + static void stop() { + sqsRestServer.stopAndWait(); + } +} diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolver.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsClientProviderMock.java similarity index 54% rename from sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolver.java rename to sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsClientProviderMock.java index b7e516e7d3d6..bd341526bd3c 100644 --- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/TableResolver.java +++ 
b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsClientProviderMock.java @@ -15,22 +15,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.extensions.sql.zetasql; +package org.apache.beam.sdk.io.aws2.sqs; -import java.util.List; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Schema; -import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.schema.Table; +import software.amazon.awssdk.services.sqs.SqsClient; -/** An interface to implement a custom resolution strategy. */ -interface TableResolver { +/** Mocking AwsClientProvider. */ +public class SqsClientProviderMock implements SqsClientProvider { - TableResolver DEFAULT_ASSUME_LEAF_IS_TABLE = TableResolverImpl::assumeLeafIsTable; - TableResolver JOIN_INTO_COMPOUND_ID = TableResolverImpl::joinIntoCompoundId; + private static SqsClientProviderMock instance = new SqsClientProviderMock(); + private static SqsClient sqsClient; - /** - * Returns a resolved table given a table path. - * - *
<p>
Returns null if table is not found. - */ - Table resolveCalciteTable(Schema calciteSchema, List tablePath); + private SqsClientProviderMock() {} + + public static SqsClientProviderMock of(SqsClient sqs) { + sqsClient = sqs; + return instance; + } + + @Override + public SqsClient getSqsClient() { + return sqsClient; + } } diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOTest.java new file mode 100644 index 000000000000..a3416dc28103 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.sqs; + +import static org.junit.Assert.assertEquals; + +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import software.amazon.awssdk.services.sqs.SqsClient; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; + +/** Tests on {@link SqsIO}. 
*/ +@RunWith(JUnit4.class) +public class SqsIOTest { + + @Rule public TestPipeline pipeline = TestPipeline.create(); + + @Test + public void testRead() { + final SqsClient client = EmbeddedSqsServer.getClient(); + final String queueUrl = EmbeddedSqsServer.getQueueUrl(); + + final PCollection output = + pipeline.apply( + SqsIO.read() + .withSqsClientProvider(SqsClientProviderMock.of(client)) + .withQueueUrl(queueUrl) + .withMaxNumRecords(100)); + + PAssert.thatSingleton(output.apply(Count.globally())).isEqualTo(100L); + + for (int i = 0; i < 100; i++) { + SendMessageRequest sendMessageRequest = + SendMessageRequest.builder().queueUrl(queueUrl).messageBody("This is a test").build(); + client.sendMessage(sendMessageRequest); + } + pipeline.run(); + } + + @Test + public void testWrite() { + final SqsClient client = EmbeddedSqsServer.getClient(); + final String queueUrl = EmbeddedSqsServer.getQueueUrl(); + + List messages = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + final SendMessageRequest request = + SendMessageRequest.builder() + .queueUrl(queueUrl) + .messageBody("This is a test " + i) + .build(); + messages.add(request); + } + + pipeline + .apply(Create.of(messages)) + .apply(SqsIO.write().withSqsClientProvider(SqsClientProviderMock.of(client))); + pipeline.run().waitUntilFinish(); + + List received = new ArrayList<>(); + while (received.size() < 100) { + ReceiveMessageRequest receiveMessageRequest = + ReceiveMessageRequest.builder().queueUrl(queueUrl).build(); + final ReceiveMessageResponse receiveMessageResponse = + client.receiveMessage(receiveMessageRequest); + + if (receiveMessageResponse != null) { + for (Message message : receiveMessageResponse.messages()) { + received.add(message.body()); + } + } + } + + assertEquals(100, received.size()); + for (int i = 0; i < 100; i++) { + received.contains("This is a test " + i); + } + } + + @BeforeClass + public static void before() { + EmbeddedSqsServer.start(); + } + + @AfterClass + public static void after() { + EmbeddedSqsServer.stop(); + } +} diff --git a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java index 8368982dc407..2588f5cbd142 100644 --- a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java +++ b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java @@ -409,7 +409,7 @@ public List> split( spec, desiredBundleSizeBytes, getEstimatedSizeBytes(pipelineOptions), cluster); } else { LOG.warn( - "Only Murmur3Partitioner is supported for splitting, using an unique source for " + "Only Murmur3Partitioner is supported for splitting, using a unique source for " + "the read"); return Collections.singletonList( new CassandraIO.CassandraSource<>(spec, Collections.singletonList(buildQuery(spec)))); diff --git a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle index 482fc7dd835d..0c4411d81920 100644 --- a/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle +++ b/sdks/java/io/elasticsearch-tests/elasticsearch-tests-2/build.gradle @@ -45,7 +45,6 @@ dependencies { testCompile library.java.commons_io_1x testCompile library.java.junit testCompile "org.elasticsearch.client:elasticsearch-rest-client:5.6.3" - testCompile library.java.vendored_guava_26_0_jre testCompile "org.elasticsearch:elasticsearch:$elastic_search_version" testRuntimeOnly 
library.java.slf4j_jdk14 testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java index 23c81c5337ab..f097e47cb88c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BatchLoads.java @@ -22,7 +22,9 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import com.google.api.services.bigquery.model.TableRow; +import java.util.Collections; import java.util.List; +import java.util.Set; import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; @@ -36,6 +38,7 @@ import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; import org.apache.beam.sdk.io.gcp.bigquery.WriteBundlesToFiles.Result; import org.apache.beam.sdk.options.PipelineOptions; @@ -115,6 +118,7 @@ class BatchLoads private BigQueryServices bigQueryServices; private final WriteDisposition writeDisposition; private final CreateDisposition createDisposition; + private Set schemaUpdateOptions; private final boolean ignoreUnknownValues; // Indicates that we are writing to a constant single table. If this is the case, we will create // the table, even if there is no data in it. 
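For orientation, the options carried by the new BatchLoads field are eventually copied, by name, onto the BigQuery load job configuration (the actual wiring appears in the WriteTables change further down). A rough, illustrative sketch of the job-level effect, not part of this patch, with the two supported values spelled out:

    JobConfigurationLoad loadConfig =
        new JobConfigurationLoad()
            .setWriteDisposition("WRITE_APPEND")
            .setSchemaUpdateOptions(
                Arrays.asList("ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"));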
@@ -166,6 +170,11 @@ class BatchLoads this.elementCoder = elementCoder; this.kmsKey = kmsKey; this.rowWriterFactory = rowWriterFactory; + schemaUpdateOptions = Collections.emptySet(); + } + + void setSchemaUpdateOptions(Set schemaUpdateOptions) { + this.schemaUpdateOptions = schemaUpdateOptions; } void setTestServices(BigQueryServices bigQueryServices) { @@ -587,7 +596,8 @@ private PCollection> writeTempTables( maxRetryJobs, ignoreUnknownValues, kmsKey, - rowWriterFactory.getSourceFormat())); + rowWriterFactory.getSourceFormat(), + schemaUpdateOptions)); } // In the case where the files fit into a single load job, there's no need to write temporary @@ -621,7 +631,8 @@ void writeSinglePartition( maxRetryJobs, ignoreUnknownValues, kmsKey, - rowWriterFactory.getSourceFormat())); + rowWriterFactory.getSourceFormat(), + schemaUpdateOptions)); } private WriteResult writeResult(Pipeline p) { diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index 2059467fa2d0..3bd9d8c81813 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -43,8 +43,10 @@ import com.google.cloud.bigquery.storage.v1beta1.Storage.ReadSession; import com.google.cloud.bigquery.storage.v1beta1.Storage.Stream; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.regex.Pattern; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -1635,6 +1637,7 @@ public static Write write() { .setBigQueryServices(new BigQueryServicesImpl()) .setCreateDisposition(Write.CreateDisposition.CREATE_IF_NEEDED) .setWriteDisposition(Write.WriteDisposition.WRITE_EMPTY) + .setSchemaUpdateOptions(Collections.emptySet()) .setNumFileShards(0) .setMethod(Write.Method.DEFAULT) .setExtendedErrorInfo(false) @@ -1729,6 +1732,8 @@ public enum Method { abstract CreateDisposition getCreateDisposition(); abstract WriteDisposition getWriteDisposition(); + + abstract Set getSchemaUpdateOptions(); /** Table description. Default is empty. */ @Nullable abstract String getTableDescription(); @@ -1807,6 +1812,8 @@ abstract Builder setAvroSchemaFactory( abstract Builder setWriteDisposition(WriteDisposition writeDisposition); + abstract Builder setSchemaUpdateOptions(Set schemaUpdateOptions); + abstract Builder setTableDescription(String tableDescription); abstract Builder setValidate(boolean validate); @@ -1910,6 +1917,25 @@ public enum WriteDisposition { WRITE_EMPTY } + /** + * An enumeration type for the BigQuery schema update options strings. + * + *

Note from the BigQuery API doc -- Schema update options are supported in two cases: when + * writeDisposition is WRITE_APPEND; when writeDisposition is WRITE_TRUNCATE and the destination + * table is a partition of a table, specified by partition decorators. + * + * @see + * configuration.query.schemaUpdateOptions in the BigQuery Jobs API + */ + public enum SchemaUpdateOption { + /** Allow adding a nullable field to the schema. */ + ALLOW_FIELD_ADDITION, + + /** Allow relaxing a required field in the original schema to nullable. */ + ALLOW_FIELD_RELAXATION + } + /** * Writes to the given table, specified in the format described in {@link * BigQueryHelpers#parseTableSpec}. @@ -2098,6 +2124,12 @@ public Write withWriteDisposition(WriteDisposition writeDisposition) { return toBuilder().setWriteDisposition(writeDisposition).build(); } + /** Allows the schema of the destination table to be updated as a side effect of the write. */ + public Write withSchemaUpdateOptions(Set schemaUpdateOptions) { + checkArgument(schemaUpdateOptions != null, "schemaUpdateOptions can not be null"); + return toBuilder().setSchemaUpdateOptions(schemaUpdateOptions).build(); + } + /** Specifies the table description. */ public Write withTableDescription(String tableDescription) { checkArgument(tableDescription != null, "tableDescription can not be null"); @@ -2589,6 +2621,9 @@ private WriteResult continueExpandTyped( rowWriterFactory, getKmsKey()); batchLoads.setTestServices(getBigQueryServices()); + if (getSchemaUpdateOptions() != null) { + batchLoads.setSchemaUpdateOptions(getSchemaUpdateOptions()); + } if (getMaxFilesPerBundle() != null) { batchLoads.setMaxNumWritersPerBundle(getMaxFilesPerBundle()); } @@ -2634,6 +2669,9 @@ public void populateDisplayData(DisplayData.Builder builder) { .add( DisplayData.item("writeDisposition", getWriteDisposition().toString()) .withLabel("Table WriteDisposition")) + .add( + DisplayData.item("schemaUpdateOptions", getSchemaUpdateOptions().toString()) + .withLabel("Table SchemaUpdateOptions")) .addIfNotDefault( DisplayData.item("validation", getValidate()).withLabel("Validation Enabled"), true) .addIfNotNull( diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java index 10f368fd5cad..632373c034cc 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/WriteTables.java @@ -29,6 +29,8 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import javax.annotation.Nullable; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -38,6 +40,7 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.PendingJob; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.PendingJobManager; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService; @@ -90,6 +93,7 @@ class WriteTables private final PCollectionView loadJobIdPrefixView; 
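Taken together with the BigQueryIO.Write changes above, a caller opts in to schema updates roughly as follows. This is a minimal sketch, not part of the patch: the table spec and updatedSchema are placeholders, the options only take effect on the file-loads (batch load job) path, and per the BigQuery docs the write disposition must be WRITE_APPEND, or WRITE_TRUNCATE against a partition decorator:

    BigQueryIO.writeTableRows()
        .to("my-project:my_dataset.my_table")
        .withSchema(updatedSchema)
        .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND)
        .withSchemaUpdateOptions(
            EnumSet.of(BigQueryIO.Write.SchemaUpdateOption.ALLOW_FIELD_ADDITION));

The BigQuerySchemaUpdateOptionsIT added later in this patch exercises the same path end to end.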
private final WriteDisposition firstPaneWriteDisposition; private final CreateDisposition firstPaneCreateDisposition; + private final Set schemaUpdateOptions; private final DynamicDestinations dynamicDestinations; private final List> sideInputs; private final TupleTag> mainOutputTag; @@ -219,7 +223,8 @@ public void processElement(ProcessContext c, BoundedWindow window) throws Except tableSchema, partitionFiles, writeDisposition, - createDisposition); + createDisposition, + schemaUpdateOptions); pendingJobs.add( new PendingJobData(window, retryJob, partitionFiles, tableDestination, tableReference)); } @@ -288,7 +293,9 @@ public WriteTables( int maxRetryJobs, boolean ignoreUnknownValues, String kmsKey, - String sourceFormat) { + String sourceFormat, + Set schemaUpdateOptions) { + this.tempTable = tempTable; this.bqServices = bqServices; this.loadJobIdPrefixView = loadJobIdPrefixView; @@ -303,6 +310,7 @@ public WriteTables( this.ignoreUnknownValues = ignoreUnknownValues; this.kmsKey = kmsKey; this.sourceFormat = sourceFormat; + this.schemaUpdateOptions = schemaUpdateOptions; } @Override @@ -346,7 +354,8 @@ private PendingJob startLoad( @Nullable TableSchema schema, List gcsUris, WriteDisposition writeDisposition, - CreateDisposition createDisposition) { + CreateDisposition createDisposition, + Set schemaUpdateOptions) { JobConfigurationLoad loadConfig = new JobConfigurationLoad() .setDestinationTable(ref) @@ -356,6 +365,11 @@ private PendingJob startLoad( .setCreateDisposition(createDisposition.name()) .setSourceFormat(sourceFormat) .setIgnoreUnknownValues(ignoreUnknownValues); + if (schemaUpdateOptions != null) { + List options = + schemaUpdateOptions.stream().map(Enum::name).collect(Collectors.toList()); + loadConfig.setSchemaUpdateOptions(options); + } if (timePartitioning != null) { loadConfig.setTimePartitioning(timePartitioning); // only set clustering if timePartitioning is set diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/TestPubsub.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/TestPubsub.java index 9b18333e59b1..1e75d4377abf 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/TestPubsub.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/TestPubsub.java @@ -202,9 +202,11 @@ public List pull() throws IOException { public List pull(int maxBatchSize) throws IOException { List messages = pubsub.pull(0, subscriptionPath, maxBatchSize, true); - pubsub.acknowledge( - subscriptionPath, - messages.stream().map(msg -> msg.ackId).collect(ImmutableList.toImmutableList())); + if (!messages.isEmpty()) { + pubsub.acknowledge( + subscriptionPath, + messages.stream().map(msg -> msg.ackId).collect(ImmutableList.toImmutableList())); + } return messages.stream() .map(msg -> new PubsubMessage(msg.elementBytes, msg.attributes, msg.recordId)) @@ -225,7 +227,7 @@ public List waitForNMessages(int n, Duration timeoutDuration) receivedMessages.addAll(pull(n - receivedMessages.size())); while (receivedMessages.size() < n - && Seconds.secondsBetween(new DateTime(), startTime).getSeconds() < timeoutSeconds) { + && Seconds.secondsBetween(startTime, new DateTime()).getSeconds() < timeoutSeconds) { Thread.sleep(1000); receivedMessages.addAll(pull(n - receivedMessages.size())); } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java index f0b4cd7f6a90..0c9ecd7688b6 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java @@ -51,10 +51,12 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; import org.apache.avro.Schema; import org.apache.avro.file.DataFileReader; import org.apache.avro.file.DataFileWriter; @@ -264,6 +266,12 @@ public JobStatistics dryRunQuery(String projectId, JobConfigurationQuery query, throw new UnsupportedOperationException(); } + public Collection getAllJobs() { + synchronized (allJobs) { + return allJobs.values().stream().map(j -> j.job).collect(Collectors.toList()); + } + } + @Override public Job getJob(JobReference jobRef) { try { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java index da6c5e7d2d3b..99e3fbaec130 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOWriteTest.java @@ -36,6 +36,8 @@ import com.google.api.services.bigquery.model.Clustering; import com.google.api.services.bigquery.model.ErrorProto; +import com.google.api.services.bigquery.model.Job; +import com.google.api.services.bigquery.model.JobConfigurationLoad; import com.google.api.services.bigquery.model.Table; import com.google.api.services.bigquery.model.TableDataInsertAllResponse; import com.google.api.services.bigquery.model.TableFieldSchema; @@ -54,11 +56,14 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.EnumSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.ThreadLocalRandom; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.coders.AtomicCoder; @@ -68,6 +73,7 @@ import org.apache.beam.sdk.io.GenerateSequence; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption; import org.apache.beam.sdk.io.gcp.testing.FakeBigQueryServices; import org.apache.beam.sdk.io.gcp.testing.FakeDatasetService; import org.apache.beam.sdk.io.gcp.testing.FakeJobService; @@ -1179,6 +1185,8 @@ public void testBuildWriteDisplayData() { .withSchema(schema) .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withSchemaUpdateOptions( + EnumSet.of(BigQueryIO.Write.SchemaUpdateOption.ALLOW_FIELD_ADDITION)) .withTableDescription(tblDescription) .withoutValidation(); @@ -1194,6 +1202,11 @@ public void testBuildWriteDisplayData() { displayData, hasDisplayItem( "writeDisposition", 
BigQueryIO.Write.WriteDisposition.WRITE_APPEND.toString())); + assertThat( + displayData, + hasDisplayItem( + "schemaUpdateOptions", + EnumSet.of(BigQueryIO.Write.SchemaUpdateOption.ALLOW_FIELD_ADDITION).toString())); assertThat(displayData, hasDisplayItem("tableDescription", tblDescription)); assertThat(displayData, hasDisplayItem("validation", false)); } @@ -1571,7 +1584,8 @@ public void testWriteTables() throws Exception { 4, false, null, - "NEWLINE_DELIMITED_JSON"); + "NEWLINE_DELIMITED_JSON", + Collections.emptySet()); PCollection> writeTablesOutput = writeTablesInput.apply(writeTables); @@ -1854,4 +1868,59 @@ public void testWrongErrorConfigs() { + "uses extended errors information. Use getFailedInsertsWithErr instead")); } } + + void schemaUpdateOptionsTest( + BigQueryIO.Write.Method insertMethod, Set schemaUpdateOptions) + throws Exception { + TableRow row = new TableRow().set("date", "2019-01-01").set("number", "1"); + + TableSchema schema = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName("date") + .setType("DATE") + .setName("number") + .setType("INTEGER"))); + + Write writeTransform = + BigQueryIO.writeTableRows() + .to("project-id:dataset-id.table-id") + .withTestServices(fakeBqServices) + .withMethod(insertMethod) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withSchemaUpdateOptions(schemaUpdateOptions); + + p.apply(Create.of(row)).apply(writeTransform); + p.run(); + + List expectedOptions = + schemaUpdateOptions.stream().map(Enum::name).collect(Collectors.toList()); + + for (Job job : fakeJobService.getAllJobs()) { + JobConfigurationLoad configuration = job.getConfiguration().getLoad(); + assertEquals(expectedOptions, configuration.getSchemaUpdateOptions()); + } + } + + @Test + public void testWriteFileSchemaUpdateOptionAllowFieldAddition() throws Exception { + Set options = EnumSet.of(SchemaUpdateOption.ALLOW_FIELD_ADDITION); + schemaUpdateOptionsTest(BigQueryIO.Write.Method.FILE_LOADS, options); + } + + @Test + public void testWriteFileSchemaUpdateOptionAllowFieldRelaxation() throws Exception { + Set options = EnumSet.of(SchemaUpdateOption.ALLOW_FIELD_RELAXATION); + schemaUpdateOptionsTest(BigQueryIO.Write.Method.FILE_LOADS, options); + } + + @Test + public void testWriteFileSchemaUpdateOptionAll() throws Exception { + Set options = EnumSet.allOf(SchemaUpdateOption.class); + schemaUpdateOptionsTest(BigQueryIO.Write.Method.FILE_LOADS, options); + } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java new file mode 100644 index 000000000000..8d030ca98b46 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySchemaUpdateOptionsIT.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import static org.junit.Assert.assertEquals; + +import com.google.api.services.bigquery.model.QueryResponse; +import com.google.api.services.bigquery.model.Table; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.SchemaUpdateOption; +import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Integration test for BigqueryIO with DataflowRunner and DirectRunner. */ +@RunWith(JUnit4.class) +public class BigQuerySchemaUpdateOptionsIT { + private static final Logger LOG = LoggerFactory.getLogger(BigQuerySchemaUpdateOptionsIT.class); + private static String project; + + private static final BigqueryClient BQ_CLIENT = + new BigqueryClient("BigQuerySchemaUpdateOptionsIT"); + + private static final String BIG_QUERY_DATASET_ID = + "bq_query_schema_update_options_" + + System.currentTimeMillis() + + "_" + + (new SecureRandom().nextInt(32)); + + private static final String TEST_TABLE_NAME_BASE = "test_table_"; + + private static final TableSchema BASE_TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("optional_field").setType("STRING"), + new TableFieldSchema() + .setName("required_field") + .setType("STRING") + .setMode("REQUIRED"))); + + @BeforeClass + public static void setupTestEnvironment() throws Exception { + project = TestPipeline.testingPipelineOptions().as(GcpOptions.class).getProject(); + BQ_CLIENT.createNewDataset(project, BIG_QUERY_DATASET_ID); + } + + @AfterClass + public static void cleanup() { + LOG.info("Start to clean up tables and datasets."); + BQ_CLIENT.deleteDataset(project, BIG_QUERY_DATASET_ID); + } + + public interface Options extends TestPipelineOptions, BigQueryOptions {} + + /** + * Make a new table for use in a test. 
+ * + * @return The name of the table + * @throws Exception if anything goes awry + */ + public String makeTestTable() throws Exception { + String tableName = + TEST_TABLE_NAME_BASE + System.currentTimeMillis() + "_" + (new SecureRandom().nextInt(32)); + + BQ_CLIENT.createNewTable( + project, + BIG_QUERY_DATASET_ID, + new Table() + .setSchema(BASE_TABLE_SCHEMA) + .setTableReference( + new TableReference() + .setTableId(tableName) + .setDatasetId(BIG_QUERY_DATASET_ID) + .setProjectId(project))); + + return tableName; + } + + /** + * Runs a write test against a BigQuery table to check that SchemaUpdateOption sets are taking + * effect. + * + *

Attempts to write a row via BigQueryIO.writeTableRows with the given params, then runs the given + * query, and finally checks the results of the query. + * + * @param schemaUpdateOptions The SchemaUpdateOption set to use + * @param tableName The table to write to + * @param schema The schema to use for the table + * @param rowToInsert The row to insert + * @param testQuery A testing SQL query to run after writing the row + * @param expectedResult The expected result of the query as a nested list of column values (one + * list per result row) + */ + private void runWriteTest( + Set<SchemaUpdateOption> schemaUpdateOptions, + String tableName, + TableSchema schema, + TableRow rowToInsert, + String testQuery, + List<List<String>> expectedResult) + throws Exception { + Options options = TestPipeline.testingPipelineOptions().as(Options.class); + options.setTempLocation(options.getTempRoot() + "/bq_it_temp"); + + Pipeline p = Pipeline.create(options); + Create.Values<TableRow> input = Create.of(rowToInsert); + + Write writer = + BigQueryIO.writeTableRows() + .to(String.format("%s:%s.%s", options.getProject(), BIG_QUERY_DATASET_ID, tableName)) + .withSchema(schema) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withSchemaUpdateOptions(schemaUpdateOptions); + + p.apply(input).apply(writer); + p.run().waitUntilFinish(); + + QueryResponse response = BQ_CLIENT.queryWithRetries(testQuery, project); + + List<List<String>> result = + response.getRows().stream() + .map( + row -> + row.getF().stream() + .map(cell -> cell.getV().toString()) + .collect(Collectors.toList())) + .collect(Collectors.toList()); + + assertEquals(expectedResult, result); + } + + @Test + public void testAllowFieldAddition() throws Exception { + String tableName = makeTestTable(); + + Set<SchemaUpdateOption> schemaUpdateOptions = + EnumSet.of(BigQueryIO.Write.SchemaUpdateOption.ALLOW_FIELD_ADDITION); + + TableSchema newSchema = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("new_field").setType("STRING"), + new TableFieldSchema().setName("optional_field").setType("STRING"), + new TableFieldSchema() + .setName("required_field") + .setType("STRING") + .setMode("REQUIRED"))); + + String[] values = {"meow", "bark"}; + TableRow rowToInsert = + new TableRow().set("new_field", values[0]).set("required_field", values[1]); + + String testQuery = + String.format( + "SELECT new_field, required_field FROM [%s.%s];", BIG_QUERY_DATASET_ID, tableName); + + List<List<String>> expectedResult = Arrays.asList(Arrays.asList(values)); + runWriteTest(schemaUpdateOptions, tableName, newSchema, rowToInsert, testQuery, expectedResult); + } + + @Test + public void testAllowFieldRelaxation() throws Exception { + String tableName = makeTestTable(); + + Set<SchemaUpdateOption> schemaUpdateOptions = + EnumSet.of(BigQueryIO.Write.SchemaUpdateOption.ALLOW_FIELD_RELAXATION); + + TableSchema newSchema = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("optional_field").setType("STRING"))); + + String value = "hellooo"; + TableRow rowToInsert = new TableRow().set("optional_field", value); + + String testQuery = + String.format("SELECT optional_field FROM [%s.%s];", BIG_QUERY_DATASET_ID, tableName); + + List<List<String>> expectedResult = Arrays.asList(Arrays.asList(value)); + runWriteTest(schemaUpdateOptions, tableName, newSchema, rowToInsert, testQuery, expectedResult); + } +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java 
b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java index 8aadf9803cb4..9da6dabd0a08 100644 --- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java @@ -419,7 +419,7 @@ public List> split(int desiredNumSplits, PipelineOptions o throws Exception { List> sources = new ArrayList<>(); if (spec.getTopic() != null) { - // in the case of a topic, we create a single source, so an unique subscriber, to avoid + // in the case of a topic, we create a single source, so a unique subscriber, to avoid // element duplication sources.add(new UnboundedJmsSource(spec)); } else { diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java index fb5534b0ed68..01cd428a2bda 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java @@ -264,7 +264,7 @@ public void testSplitForTopic() throws Exception { int desiredNumSplits = 5; JmsIO.UnboundedJmsSource initialSource = new JmsIO.UnboundedJmsSource(read); List splits = initialSource.split(desiredNumSplits, pipelineOptions); - // in the case of a topic, we can have only an unique subscriber on the topic per pipeline + // in the case of a topic, we can have only a unique subscriber on the topic per pipeline // else it means we can have duplicate messages (all subscribers on the topic receive every // message). // So, whatever the desizedNumSplits is, the actual number of splits should be 1. diff --git a/sdks/java/io/kinesis/build.gradle b/sdks/java/io/kinesis/build.gradle index 6cdaf3bef3ce..2144d10c36b5 100644 --- a/sdks/java/io/kinesis/build.gradle +++ b/sdks/java/io/kinesis/build.gradle @@ -34,6 +34,7 @@ dependencies { compile library.java.slf4j_api compile library.java.joda_time compile library.java.jackson_dataformat_cbor + compile library.java.guava compile library.java.aws_java_sdk_cloudwatch compile library.java.aws_java_sdk_core compile library.java.aws_java_sdk_kinesis diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java index 9fd06d34f2e3..184589062d00 100644 --- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java +++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java @@ -489,7 +489,7 @@ public List> split( } if (splitKeys.size() < 1) { - LOG.debug("Split keys is low, using an unique source"); + LOG.debug("Split keys is low, using a unique source"); return Collections.singletonList(this); } diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java index 94638942fd9d..24c89832c5c3 100644 --- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java +++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java @@ -152,7 +152,7 @@ abstract static class Builder { } /** - * Describe a connection configuration to the MQTT broker. This method creates an unique random + * Describe a connection configuration to the MQTT broker. This method creates a unique random * MQTT client ID. * * @param serverUri The MQTT broker URI. 
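The MqttIO connection configuration described above is typically built once and handed to the read transform; a minimal sketch, not part of the patch, using only methods visible in this file (the server URI and topic are placeholder values, and withClientId, shown in the next hunk, is optional):

    MqttIO.ConnectionConfiguration connectionConfiguration =
        MqttIO.ConnectionConfiguration.create("tcp://localhost:1883", "READ_TOPIC")
            .withClientId("beam-reader");
    pipeline.apply(MqttIO.read().withConnectionConfiguration(connectionConfiguration));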
@@ -173,7 +173,7 @@ public static ConnectionConfiguration create(String serverUri, String topic) { * * @param serverUri The MQTT broker URI. * @param topic The MQTT getTopic pattern. - * @param clientId A client ID prefix, used to construct an unique client ID. + * @param clientId A client ID prefix, used to construct a unique client ID. * @return A connection configuration to the MQTT broker. * @deprecated This constructor will be removed in a future version of Beam, please use * #create(String, String)} and {@link #withClientId(String)} instead. @@ -196,7 +196,7 @@ public ConnectionConfiguration withTopic(String topic) { return builder().setTopic(topic).build(); } - /** Set up the client ID prefix, which is used to construct an unique client ID. */ + /** Set up the client ID prefix, which is used to construct a unique client ID. */ public ConnectionConfiguration withClientId(String clientId) { checkArgument(clientId != null, "clientId can not be null"); return builder().setClientId(clientId).build(); diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index 0280c6108837..61c93ff0189c 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -728,7 +728,7 @@ def __init__( self._max_buffered_rows = (max_buffered_rows or BigQueryWriteFn.DEFAULT_MAX_BUFFERED_ROWS) self._retry_strategy = ( - retry_strategy or bigquery_tools.RetryStrategy.RETRY_ON_TRANSIENT_ERROR) + retry_strategy or bigquery_tools.RetryStrategy.RETRY_ALWAYS) self.additional_bq_parameters = additional_bq_parameters or {} @@ -868,7 +868,9 @@ def _flush_batch(self, destination): insert_ids=insert_ids, skip_invalid_rows=True) - _LOGGER.debug("Passed: %s. Errors are %s", passed, errors) + if not passed: + _LOGGER.info("There were errors inserting to BigQuery: %s", + errors) failed_rows = [rows[entry.index] for entry in errors] should_retry = any( bigquery_tools.RetryStrategy.should_retry( @@ -1066,6 +1068,10 @@ def __init__(self, FILE_LOADS on Batch pipelines. insert_retry_strategy: The strategy to use when retrying streaming inserts into BigQuery. Options are shown in bigquery_tools.RetryStrategy attrs. + Default is to retry always. This means that whenever there are rows + that fail to be inserted to BigQuery, they will be retried indefinitely. + Other retry strategy settings will produce a deadletter PCollection + as output. additional_bq_parameters (callable): A function that returns a dictionary with additional parameters to pass to BQ when creating / loading data into a table. These can be 'timePartitioning', 'clustering', etc. They diff --git a/sdks/python/apache_beam/io/gcp/gcsio_test.py b/sdks/python/apache_beam/io/gcp/gcsio_test.py index 3076f56e5a07..5dcfbd8d355a 100644 --- a/sdks/python/apache_beam/io/gcp/gcsio_test.py +++ b/sdks/python/apache_beam/io/gcp/gcsio_test.py @@ -108,7 +108,7 @@ def get_last_generation(self, bucket, obj): def Get(self, get_request, download=None): # pylint: disable=invalid-name f = self.get_file(get_request.bucket, get_request.object) if f is None: - # Failing with a HTTP 404 if file does not exist. + # Failing with an HTTP 404 if file does not exist. 
raise HttpError({'status': 404}, None, None) if download is None: return f.get_metadata() diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index 4de4b519adfa..8ab71374d76b 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -919,6 +919,48 @@ def _add_argparse_args(cls, parser): 'the pipeline.')) +class JobServerOptions(PipelineOptions): + """Options for starting a Beam job server. Roughly corresponds to + JobServerDriver.ServerConfiguration in Java. + """ + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument('--artifacts_dir', default=None, + help='The location to store staged artifact files. ' + 'Any Beam-supported file system is allowed. ' + 'If unset, the local temp dir will be used.') + parser.add_argument('--job_port', default=0, + help='Port to use for the job service. 0 to use a ' + 'dynamic port.') + parser.add_argument('--artifact_port', default=0, + help='Port to use for artifact staging. 0 to use a ' + 'dynamic port.') + parser.add_argument('--expansion_port', default=0, + help='Port to use for artifact staging. 0 to use a ' + 'dynamic port.') + + +class FlinkRunnerOptions(PipelineOptions): + + PUBLISHED_FLINK_VERSIONS = ['1.7', '1.8', '1.9'] + + @classmethod + def _add_argparse_args(cls, parser): + parser.add_argument('--flink_master', + default='[auto]', + help='Flink master address (http://host:port)' + ' Use "[local]" to start a local cluster' + ' for the execution. Use "[auto]" if you' + ' plan to either execute locally or let the' + ' Flink job server infer the cluster address.') + parser.add_argument('--flink_version', + default=cls.PUBLISHED_FLINK_VERSIONS[-1], + choices=cls.PUBLISHED_FLINK_VERSIONS, + help='Flink version to use.') + parser.add_argument('--flink_job_server_jar', + help='Path or URL to a flink jobserver jar.') + + class TestOptions(PipelineOptions): @classmethod diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py index aabf95908dd9..4d9a46e1bf78 100644 --- a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph.py @@ -25,21 +25,29 @@ from __future__ import print_function import collections +import logging import threading import pydot import apache_beam as beam from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners.interactive import interactive_environment as ie +from apache_beam.runners.interactive import pipeline_instrument as inst +from apache_beam.runners.interactive.display import pipeline_graph_renderer + +# pylint does not understand context +# pylint:disable=dangerous-default-value class PipelineGraph(object): - """Creates a DOT representation of the pipeline. Thread-safe.""" + """Creates a DOT representing the pipeline. Thread-safe. Runner agnostic.""" def __init__(self, pipeline, - default_vertex_attrs=None, - default_edge_attrs=None): + default_vertex_attrs={'shape': 'box'}, + default_edge_attrs=None, + render_option=None): """Constructor of PipelineGraph. Examples: @@ -55,9 +63,17 @@ def __init__(self, pipeline: (Pipeline proto) or (Pipeline) pipeline to be rendered. 
default_vertex_attrs: (Dict[str, str]) a dict of default vertex attributes default_edge_attrs: (Dict[str, str]) a dict of default edge attributes + render_option: (str) this parameter decides how the pipeline graph is + rendered. See display.pipeline_graph_renderer for available options. """ self._lock = threading.Lock() self._graph = None + self._pipeline_instrument = None + if isinstance(pipeline, beam.Pipeline): + self._pipeline_instrument = inst.PipelineInstrument(pipeline) + # The pre-process links user pipeline to runner pipeline through analysis + # but without mutating runner pipeline. + self._pipeline_instrument.preprocess() if isinstance(pipeline, beam_runner_api_pb2.Pipeline): self._pipeline_proto = pipeline @@ -79,8 +95,7 @@ def __init__(self, for pcoll_id in transform_proto.outputs.values(): self._producers[pcoll_id] = transform_id - # Set the default vertex color to blue. - default_vertex_attrs = default_vertex_attrs or {} + default_vertex_attrs = default_vertex_attrs or {'shape': 'box'} if 'color' not in default_vertex_attrs: default_vertex_attrs['color'] = 'blue' if 'fontcolor' not in default_vertex_attrs: @@ -92,9 +107,22 @@ def __init__(self, default_vertex_attrs, default_edge_attrs) + self._renderer = pipeline_graph_renderer.get_renderer(render_option) + def get_dot(self): return self._get_graph().to_string() + def display_graph(self): + rendered_graph = self._renderer.render_pipeline_graph(self) + if ie.current_env().is_in_notebook: + try: + from IPython.core import display + display.display(display.HTML(rendered_graph)) + except ImportError: # Unlikely to happen when is_in_notebook. + logging.warning('Failed to import IPython display module when current ' + 'environment is in a notebook. Cannot display the ' + 'pipeline graph.') + def _top_level_transforms(self): """Yields all top level PTransforms (subtransforms of the root PTransform). @@ -107,6 +135,20 @@ def _top_level_transforms(self): top_level_transform_proto = transforms[top_level_transform_id] yield top_level_transform_id, top_level_transform_proto + def _decorate(self, value): + """Decorates label-ish values used for rendering in dot language. + + Escapes special characters in the given str value for dot language. All + PTransform unique names are escaped implicitly in this module when building + dot representation. Otherwise, special characters will break the graph + rendered or cause runtime errors. + """ + # Replace py str literal `\\` which is `\` in dot with py str literal + # `\\\\` which is `\\` in dot so that dot `\\` can be rendered as `\`. Then + # replace `"` with `\\"` so that the dot generated will be `\"` and be + # rendered as `"`. + return '"{}"'.format(value.replace('\\', '\\\\').replace('"', '\\"')) + def _generate_graph_dicts(self): """From pipeline_proto and other info, generate the graph. @@ -126,24 +168,40 @@ def _generate_graph_dicts(self): self._edge_to_vertex_pairs = collections.defaultdict(list) for _, transform in self._top_level_transforms(): - vertex_dict[transform.unique_name] = {} + vertex_dict[self._decorate(transform.unique_name)] = {} for pcoll_id in transform.outputs.values(): - # For PCollections without consuming PTransforms, we add an invisible - # PTransform node as the consumer. + pcoll_node = None + if self._pipeline_instrument: + pcoll_node = self._pipeline_instrument.cacheable_var_by_pcoll_id( + pcoll_id) + # If no PipelineInstrument is available or the PCollection is not + # watched. 
+ if not pcoll_node: + pcoll_node = 'pcoll%s' % (hash(pcoll_id) % 10000) + vertex_dict[pcoll_node] = { + 'shape': 'circle', + 'label': '', # The pcoll node has no name. + } + # There is PipelineInstrument and the PCollection is watched with an + # assigned variable. + else: + vertex_dict[pcoll_node] = {'shape': 'circle'} if pcoll_id not in self._consumers: - invisible_leaf = 'leaf%s' % (hash(pcoll_id) % 10000) - vertex_dict[invisible_leaf] = {'style': 'invis'} self._edge_to_vertex_pairs[pcoll_id].append( - (transform.unique_name, invisible_leaf)) - edge_dict[(transform.unique_name, invisible_leaf)] = {} + (self._decorate(transform.unique_name), pcoll_node)) + edge_dict[(self._decorate(transform.unique_name), + pcoll_node)] = {} else: for consumer in self._consumers[pcoll_id]: - producer_name = transform.unique_name - consumer_name = transforms[consumer].unique_name + producer_name = self._decorate(transform.unique_name) + consumer_name = self._decorate(transforms[consumer].unique_name) + self._edge_to_vertex_pairs[pcoll_id].append( + (producer_name, pcoll_node)) + edge_dict[(producer_name, pcoll_node)] = {} self._edge_to_vertex_pairs[pcoll_id].append( - (producer_name, consumer_name)) - edge_dict[(producer_name, consumer_name)] = {} + (pcoll_node, consumer_name)) + edge_dict[(pcoll_node, consumer_name)] = {} return vertex_dict, edge_dict diff --git a/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py new file mode 100644 index 000000000000..12cc3efa4ce3 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/display/pipeline_graph_test.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for apache_beam.runners.interactive.display.pipeline_graph.""" +from __future__ import absolute_import + +import unittest + +import apache_beam as beam +from apache_beam.runners.interactive import interactive_beam as ib +from apache_beam.runners.interactive import interactive_environment as ie +from apache_beam.runners.interactive import interactive_runner as ir +from apache_beam.runners.interactive.display import pipeline_graph + +# pylint: disable=range-builtin-not-iterating,unused-variable,possibly-unused-variable +# Reason: +# Disable pylint for pipelines built for testing. Not all PCollections are +# used but they need to be assigned to variables so that we can test how +# interactive beam applies the magic around user-defined variables. + + +# The tests need graphviz to work. 
+@unittest.skipIf(not ie.current_env().is_interactive_ready, + '[interactive] dependency is not installed.') +class PipelineGraphTest(unittest.TestCase): + + def setUp(self): + ie.new_env() + + def test_decoration(self): + p = beam.Pipeline(ir.InteractiveRunner()) + # We are examining if literal `"` and trailing literal `\` are decorated + # correctly. + pcoll = p | '"Cell 1": "Create\\"' >> beam.Create(range(1000)) + ib.watch(locals()) + + self.assertEqual( + ('digraph G {\n' + 'node [color=blue, fontcolor=blue, shape=box];\n' + # The py string literal from `\\\\\\"` is `\\\"` in dot and will be + # rendered as `\"` because they are enclosed by `"`. + '"\\"Cell 1\\": \\"Create\\\\\\"";\n' + 'pcoll [shape=circle];\n' + '"\\"Cell 1\\": \\"Create\\\\\\"" -> pcoll;\n' + '}\n'), + pipeline_graph.PipelineGraph(p).get_dot()) + + def test_get_dot(self): + p = beam.Pipeline(ir.InteractiveRunner()) + init_pcoll = p | 'Init' >> beam.Create(range(10)) + squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x) + cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x ** 3) + ib.watch(locals()) + + self.assertEqual( + ('digraph G {\n' + 'node [color=blue, fontcolor=blue, shape=box];\n' + '"Init";\n' + 'init_pcoll [shape=circle];\n' + '"Square";\n' + 'squares [shape=circle];\n' + '"Cube";\n' + 'cubes [shape=circle];\n' + '"Init" -> init_pcoll;\n' + 'init_pcoll -> "Square";\n' + 'init_pcoll -> "Cube";\n' + '"Square" -> squares;\n' + '"Cube" -> cubes;\n' + '}\n'), + pipeline_graph.PipelineGraph(p).get_dot()) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/runners/interactive/examples/Interactive Beam Example.ipynb b/sdks/python/apache_beam/runners/interactive/examples/Interactive Beam Example.ipynb index 28448c2d23ce..a4486eec2145 100644 --- a/sdks/python/apache_beam/runners/interactive/examples/Interactive Beam Example.ipynb +++ b/sdks/python/apache_beam/runners/interactive/examples/Interactive Beam Example.ipynb @@ -47,66 +47,70 @@ "\n", "\n", - "\n", "\n", - "\n", - "\n", + "\n", + "\n", "G\n", - "\n", - "\n", - "\n", + "\n", "\n", - "Create\n", - "\n", - "Create\n", + "\n", + "Create\n", + "\n", + "Create\n", "\n", - "\n", - "Square\n", - "\n", - "Square\n", - "\n", - "\n", - "Create->Square\n", - "\n", - "\n", - "\n", - "{8, ...}\n", - "\n", + "\n", + "\n", + "diverge6871\n", + "\n", "\n", + "\n", + "\n", + "Create->diverge6871\n", + "\n", + "init_pcoll\n", "\n", - "\n", - "Cube\n", - "\n", - "Cube\n", - "\n", - "\n", - "Create->Cube\n", - "\n", - "\n", - "\n", - "{8, ...}\n", - "\n", + "\n", + "\n", + "Square\n", + "\n", + "Square\n", "\n", + "\n", + "\n", + "diverge6871->Square\n", + "\n", + "\n", "\n", - "\n", - "Square->leaf2469\n", - "\n", - "\n", - "\n", - "{36, ...}\n", - "\n", + "\n", + "\n", + "Cube\n", + "\n", + "Cube\n", "\n", + "\n", + "\n", + "diverge6871->Cube\n", + "\n", + "\n", "\n", - "\n", - "Cube->leaf2468\n", - "\n", - "\n", - "\n", - "{27, ...}\n", - "\n", + "\n", + "\n", + "\n", + "Square->leaf6244\n", + "\n", + "\n", + "squares\n", "\n", + "\n", + "\n", + "\n", + "Cube->leaf6690\n", + "\n", + "\n", + "cubes\n", "\n", "\n", "\n" @@ -117,52 +121,6 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Running..." - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Using 0 cached PCollections\n", - "Executing 8 of 3 transforms." 
- ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Cube produced {27, 729, 64, 343, 512, ...}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Square produced {36, 1, 9, 25, 81, ...}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Create produced {8, 6, 1, 9, 7, ...}" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ @@ -179,9 +137,23 @@ "execution_count": 3, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: matplotlib in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (3.1.1)\n", + "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from matplotlib) (2.4.2)\n", + "Requirement already satisfied: cycler>=0.10 in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from matplotlib) (0.10.0)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from matplotlib) (2.8.0)\n", + "Requirement already satisfied: numpy>=1.11 in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from matplotlib) (1.17.3)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from matplotlib) (1.1.0)\n", + "Requirement already satisfied: six in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.12.0)\n", + "Requirement already satisfied: setuptools in /Users/ningk/workspace/p3_ib_venv/lib/python3.7/site-packages (from kiwisolver>=1.0.1->matplotlib) (41.6.0)\n" + ] + }, { "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGm5JREFUeJzt3X+QVPWZ7/H3MwM42+gVgVlLGZjGElEhjsCoeDUryvUGwQqkWI1Wo8SiMjdGl3g1rmQndW9uFVMlVYkmVHZJdUU3cOnrYlgtSLRcFbTib3bAn4DGURgYRBlRR2VCBOa5f5wzMIPAdDPdfZrTn1fV1Dnn6dPdz3QNn/7y7dPnmLsjIiLxVRF1AyIiUlgKehGRmFPQi4jEnIJeRCTmFPQiIjGnoBcRiTkFvYhIzCnoRURiTkEvIhJzA6JuAGD48OGeTCajbkNE5ISyfv36j929uq/9SiLok8kkzc3NUbchInJCMbPWbPbT1I2ISMwp6EVEYk5BLyIScyUxR38k+/bto62tjb1790bdSsmrqqqipqaGgQMHRt2KiJSgkg36trY2TjnlFJLJJGYWdTsly93ZvXs3bW1tjB49Oup2RKQElezUzd69exk2bJhCvg9mxrBhw/Q/HxE5qpINekAhnyW9TiInhkwGkkmoqAiWmUxxnrdkp25EROIkk4GGBujsDLZbW4NtgFSqsM9d0iN6EZG4aGw8FPLdOjuDeqEp6EvEgQMHom5BRApo27bc6vkUn6AvwOTXnj17mDFjBnV1dYwfP54VK1bwxBNPcO655zJx4kTmz5/PtddeC8DPfvYzfv7znx+87/jx49m6dSsAs2bNYtKkSYwbN450On1wn5NPPpm77rqLuro6XnrpJdavX88VV1zBpEmT+Na3vsXOnTsBWLx4Meeffz4XXHABN9xwQ79/LxEpvlGjcqvnUzzm6As0+fXEE09w5pln8thjjwHQ0dHB+PHjWbt2LWeffTbf/e53s3qcBx98kKFDh/KXv/yFiy66iNmzZzNs2DD27NnDJZdcwi9+8Qv27dvHFVdcwapVq6iurmbFihU0Njby4IMPcu+997JlyxZOOukkPvvss+P+fUQkOk1NvWMKIJEI6oUWjxF9gSa/vvGNb/DUU09xzz338Nxzz7FlyxZGjx7NmDFjMDPmzJmT1eMsXryYuro6Jk+ezPbt23n33XcBqKysZPbs2QC88847vPXWW1x99dVceOGFLFy4kLa2NgAuuOACUqkUy5cvZ8CAeLw3i5SbVArSaaitBbNgmU4X/oNYyGJEb2ZjgRU9SmcB/wtYFtaTwFbgenf/1IJj/X4FTAc6ge+5+4b8tn2YAk1+nXPOOWzYsIHHH3+cn/70p0ydOvWo+w4YMICurq6D293HtT/77LM8/fTTvPTSSyQSCaZMmXLwtqqqKiorK4Hgi0/jxo3jpZde+tpjP/bYY/zpT3/iD3/4A01NTbz55psKfJETUCpVnGA/XJ8jend/x90vdPcLgUkE4f0osABY4+5jgDXhNsA1wJjwpwFYUojGeynQ5NcHH3xAIpFgzpw53H333bz44ots3bqV9957D4CHHnro4L7JZJING4L3sw0bNrBlyxYgmO457bTTSCQSvP3227z88stHfK6xY8fS3t5+MOj37dvHxo0b6erqYvv27Vx55ZUsWrSIjo4Ovvzyy379XiJSXnIdFk4F3nP3VjObCUwJ60uBZ4F7gJnAMnd34GUzG2JmZ7j7zjz1/HUFmvx68803ufvuu6moqGDgwIEsWbKEjz/+mBkzZpBIJPjmN7/JF198AcDs2bNZtmwZ48aN45JLLuGcc84BYNq0afzmN7/hvPPOY+zYsUyePPmIzzVo0CBWrlzJ/Pnz6ejoYP/+/dxxxx2cc845zJkzh46ODtyd+fPnM2TIkH79XiJSZtw96x/gQeD2cP2zHnXr3gb+CFze47Y1QP0RHqsBaAaaR40a5YfbtGnT12rHtHy5e22tu1mwXL48t/sfh2eeecZnzJhR8OfJRs6vl4ic8IBmzyK7sx7Rm9kg4NvAT47wZuFm5jm+waSBNEB9fX1O9z2iqCa/RERKXC5TN9cAG9z9o3D7o+4pGTM7A9gV1ncAI3vcryasxc6UKVOYMmVK1G2IiBxTLodX3gg81GN7NTA3XJ8LrOpRv9kCk4EOL+T8vIiIHFNWI3ozGwxcDfyPHuV7gYfNbB7QClwf1h8nOLSyheAInVvy1q2IiOQsq6B39z3AsMNquwmOwjl8Xwduy0t3IiLSb/H4ZqyIiByVgj5Pnn322YMnOBMRKSUKehGRmItN0BfqEl3Lli3jggsuoK6ujptuuonvfe97rFy58uDtJ5988sH1zz//nBkzZjB27Fh+8IMfHDz3zZNPPsmll17KxIkTue666w6ewmDBggUHTz/84x//OD8Ni4gcJhZnxirUJbo2btzIwoULefHFFxk+fDiffPIJd95551H3X7duHZs2baK2tpZp06bxyCOPMGXKFBYuXMjTTz/N4MGDWbRoEffddx+33XYbjz76KG+//TZmptMPi0jBxCLoj3WW4v4E/dq1a7nuuusYPnw4AEOHDj3m/hdffDFnnXUWADfeeCPPP/88VVVVbNq0icsuuwyAr776iksvvZRTTz2Vqqoq5s2bx7XXXqv5fREpmFgEfTEv0dXzdMRdXV189dVXB28LztBMr2135+qrr+51pstu69atY82aNaxcuZJf//rXrF27Nv8Ni0jZi8UcfaEu0XXVVVfx+9//nt27dwPwySefkEwmWb9+PQCrV69m3759B/dft24dW7ZsoaurixUrVnD55ZczefJkXnjhBVpaWoDg8oR//vOf+fLLL+no6GD69Oncf//9vP766/1rVkTkKGIxoi/UJbrGjRtHY2MjV1xxBZWVlUyYMIFFixYxc+ZM6urqmDZtGoMHDz64/0UXXcTtt99OS0sLV155Jd/5zneoqKjgd7/7HTfeeCN//etfAVi4cCGnnHIKM2fOZO/evbg79913X/+aFRE5Cgu+yBqt+vp6b25u7lXbvHkz5513XtaPkckEc/LbtgUj+aam8jqZZa6vl4ic+MxsvbvX97VfLEb0oLMUi4gcTSzm6EVE5OhKOuhLYVrpRKDXSUSOpWSDvqqqit27dyvE+uDu7N69m6qqqqhbEZESVbJz9DU1NbS1tdHe3h51KyWvqqqKmpqaqNsQkRJVskE/cOBARo8eHXUbIiInvJKduhERkfxQ0IuIxFxWQW9mQ8xspZm9bWabzexSMxtqZk+Z2bvh8rRwXzOzxWbWYmZvmNnEwv4KIiJyLNmO6H8FPOHu5wJ1wGZgAbDG3ccAa8JtgGuAMeFPA7Akrx2LiEhO+gx6MzsV+DvgAQB3/8rdPwNmAkvD3ZYCs8L1mcAyD7wMDDGzM/LeuYiIZCWbEf1ooB34VzN71cx+a2aDgdPdfWe4z4fA6eH6CGB7j/u3hbVezKzBzJrNrFmHUIqIFE42QT8AmAgscfcJwB4OTdMA4MG3mnL6ZpO7p9293t
[… base64-encoded PNG data for the notebook's previous matplotlib plot output truncated …]\n",
+ "image/png": "[… base64-encoded PNG data for the updated matplotlib plot output omitted …]\n",
 "text/plain": [ "
" ] @@ -239,6 +211,16 @@ "metadata": { "scrolled": true }, + "outputs": [], + "source": [ + "average_square = squares | 'Average Square' >> beam.CombineGlobally(AverageFn())\n", + "average_cube = cubes | 'Average Cube' >> beam.CombineGlobally(AverageFn())" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, "outputs": [ { "data": { @@ -246,94 +228,96 @@ "\n", "\n", - "\n", "\n", - "\n", - "\n", + "\n", + "\n", "G\n", - "\n", - "\n", - "\n", - "Square\n", - "\n", - "Square\n", - "\n", - "\n", - "Average Square\n", - "\n", - "Average Square\n", - "\n", - "\n", - "Square->Average Square\n", - "\n", - "\n", - "\n", - "{36, ...}\n", - "\n", + "\n", + "\n", + "\n", + "Create\n", + "\n", + "Create\n", "\n", + "\n", + "\n", + "diverge6871\n", + "\n", "\n", - "\n", - "\n", - "Create\n", - "\n", - "Create\n", + "\n", + "\n", + "Create->diverge6871\n", + "\n", + "init_pcoll\n", "\n", - "\n", - "Create->Square\n", - "\n", - "\n", - "\n", - "{8, ...}\n", - "\n", + "\n", + "\n", + "Square\n", + "\n", + "Square\n", "\n", + "\n", + "\n", + "diverge6871->Square\n", + "\n", + "\n", "\n", "\n", - "Cube\n", - "\n", - "Cube\n", - "\n", - "\n", - "Create->Cube\n", - "\n", - "\n", - "\n", - "{8, ...}\n", - "\n", + "\n", + "Cube\n", + "\n", + "Cube\n", "\n", + "\n", + "\n", + "diverge6871->Cube\n", + "\n", + "\n", "\n", - "\n", - "Average Square->leaf7582\n", - "\n", - "\n", - "\n", - "{28.5}\n", - "\n", + "\n", + "\n", + "Average Square\n", + "\n", + "Average Square\n", "\n", + "\n", + "\n", + "Square->Average Square\n", + "\n", + "\n", + "squares\n", "\n", "\n", - "Average Cube\n", - "\n", - "Average Cube\n", + "\n", + "Average Cube\n", + "\n", + "Average Cube\n", "\n", "\n", - "Cube->Average Cube\n", - "\n", - "\n", - "\n", - "{27, ...}\n", - "\n", + "\n", + "Cube->Average Cube\n", + "\n", + "\n", + "cubes\n", "\n", + "\n", + "\n", + "\n", + "Average Square->leaf1333\n", + "\n", + "\n", + "average_square\n", "\n", - "\n", - "Average Cube->leaf7574\n", - "\n", - "\n", - "\n", - "{202.5}\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "Average Cube->leaf6554\n", + "\n", + "\n", + "average_cube\n", "\n", "\n", "\n" @@ -344,89 +328,32 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Running..." - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Using 2 cached PCollections\n", - "Executing 8 of 5 transforms." 
- ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Average Cube produced {202.5}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Cube produced {27, 729, 64, 343, 512, ...}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Square produced {36, 1, 9, 25, 81, ...}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "Average Square produced {28.5}" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ - "average_square = squares | 'Average Square' >> beam.CombineGlobally(AverageFn())\n", - "average_cube = cubes | 'Average Cube' >> beam.CombineGlobally(AverageFn())\n", - "result = p.run(False)" + "result = p.run()" ] } ], "metadata": { "kernelspec": { - "display_name": "Python (beam_venv)", + "display_name": "Python3 (ib_venv)", "language": "python", - "name": "beam_venv_kernel" + "name": "p3_ib_venv" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.7.4" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner.py b/sdks/python/apache_beam/runners/interactive/interactive_runner.py index db101ce6131b..b0222c3977f8 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_runner.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_runner.py @@ -30,9 +30,9 @@ from apache_beam import runners from apache_beam.runners.direct import direct_runner from apache_beam.runners.interactive import cache_manager as cache -from apache_beam.runners.interactive import pipeline_analyzer -from apache_beam.runners.interactive.display import display_manager -from apache_beam.runners.interactive.display import pipeline_graph_renderer +from apache_beam.runners.interactive import interactive_environment as ie +from apache_beam.runners.interactive import pipeline_instrument as inst +from apache_beam.runners.interactive.display import pipeline_graph # size of PCollection samples cached. SAMPLE_SIZE = 8 @@ -68,8 +68,12 @@ def __init__(self, """ self._underlying_runner = (underlying_runner or direct_runner.DirectRunner()) - self._cache_manager = cache.FileBasedCacheManager(cache_dir, cache_format) - self._renderer = pipeline_graph_renderer.get_renderer(render_option) + if not ie.current_env().cache_manager(): + ie.current_env().set_cache_manager( + cache.FileBasedCacheManager(cache_dir, + cache_format)) + self._cache_manager = ie.current_env().cache_manager() + self._render_option = render_option self._in_session = False self._skip_display = skip_display @@ -84,7 +88,7 @@ def set_render_option(self, render_option): render_option: (str) this parameter decides how the pipeline graph is rendered. See display.pipeline_graph_renderer for available options. """ - self._renderer = pipeline_graph_renderer.get_renderer(render_option) + self._render_option = render_option def start_session(self): """Start the session that keeps back-end managers and workers alive. 
@@ -120,123 +124,53 @@ def apply(self, transform, pvalueish, options): return self._underlying_runner.apply(transform, pvalueish, options) def run_pipeline(self, pipeline, options): - if not hasattr(self, '_desired_cache_labels'): - self._desired_cache_labels = set() - - # Invoke a round trip through the runner API. This makes sure the Pipeline - # proto is stable. - pipeline = beam.pipeline.Pipeline.from_runner_api( - pipeline.to_runner_api(use_fake_coders=True), - pipeline.runner, - options) - - # Snapshot the pipeline in a portable proto before mutating it. - pipeline_proto, original_context = pipeline.to_runner_api( - return_context=True, use_fake_coders=True) - pcolls_to_pcoll_id = self._pcolls_to_pcoll_id(pipeline, original_context) - - analyzer = pipeline_analyzer.PipelineAnalyzer(self._cache_manager, - pipeline_proto, - self._underlying_runner, - options, - self._desired_cache_labels) - # Should be only accessed for debugging purpose. - self._analyzer = analyzer + pipeline_instrument = inst.pin(pipeline, options) pipeline_to_execute = beam.pipeline.Pipeline.from_runner_api( - analyzer.pipeline_proto_to_execute(), + pipeline_instrument.instrumented_pipeline_proto(), self._underlying_runner, options) if not self._skip_display: - display = display_manager.DisplayManager( - pipeline_proto=pipeline_proto, - pipeline_analyzer=analyzer, - cache_manager=self._cache_manager, - pipeline_graph_renderer=self._renderer) - display.start_periodic_update() + a_pipeline_graph = pipeline_graph.PipelineGraph( + pipeline_instrument.original_pipeline, + render_option=self._render_option) + a_pipeline_graph.display_graph() result = pipeline_to_execute.run() result.wait_until_finish() - if not self._skip_display: - display.stop_periodic_update() + return PipelineResult(result, pipeline_instrument) - return PipelineResult(result, self, self._analyzer.pipeline_info(), - self._cache_manager, pcolls_to_pcoll_id) - def _pcolls_to_pcoll_id(self, pipeline, original_context): - """Returns a dict mapping PCollections string to PCollection IDs. +class PipelineResult(beam.runners.runner.PipelineResult): + """Provides access to information about a pipeline.""" - Using a PipelineVisitor to iterate over every node in the pipeline, - records the mapping from PCollections to PCollections IDs. This mapping - will be used to query cached PCollections. + def __init__(self, underlying_result, pipeline_instrument): + """Constructor of PipelineResult. Args: - pipeline: (pipeline.Pipeline) - original_context: (pipeline_context.PipelineContext) - - Returns: - (dict from str to str) a dict mapping str(pcoll) to pcoll_id. + underlying_result: (PipelineResult) the result returned by the underlying + runner running the pipeline. + pipeline_instrument: (PipelineInstrument) pipeline instrument describing + the pipeline being executed with interactivity applied and related + metadata including where the interactivity-backing cache lies. """ - pcolls_to_pcoll_id = {} - - from apache_beam.pipeline import PipelineVisitor # pylint: disable=import-error - - class PCollVisitor(PipelineVisitor): # pylint: disable=used-before-assignment - """"A visitor that records input and output values to be replaced. - - Input and output values that should be updated are recorded in maps - input_replacements and output_replacements respectively. - - We cannot update input and output values while visiting since that - results in validation errors. 
- """ - - def enter_composite_transform(self, transform_node): - self.visit_transform(transform_node) - - def visit_transform(self, transform_node): - for pcoll in transform_node.outputs.values(): - pcolls_to_pcoll_id[str(pcoll)] = original_context.pcollections.get_id( - pcoll) - - pipeline.visit(PCollVisitor()) - return pcolls_to_pcoll_id - - -class PipelineResult(beam.runners.runner.PipelineResult): - """Provides access to information about a pipeline.""" - - def __init__(self, underlying_result, runner, pipeline_info, cache_manager, - pcolls_to_pcoll_id): super(PipelineResult, self).__init__(underlying_result.state) - self._runner = runner - self._pipeline_info = pipeline_info - self._cache_manager = cache_manager - self._pcolls_to_pcoll_id = pcolls_to_pcoll_id - - def _cache_label(self, pcoll): - pcoll_id = self._pcolls_to_pcoll_id[str(pcoll)] - return self._pipeline_info.cache_label(pcoll_id) + self._underlying_result = underlying_result + self._pipeline_instrument = pipeline_instrument def wait_until_finish(self): # PipelineResult is not constructed until pipeline execution is finished. return def get(self, pcoll): - cache_label = self._cache_label(pcoll) - if self._cache_manager.exists('full', cache_label): - pcoll_list, _ = self._cache_manager.read('full', cache_label) + key = self._pipeline_instrument.cache_key(pcoll) + if ie.current_env().cache_manager().exists('full', key): + pcoll_list, _ = ie.current_env().cache_manager().read('full', key) return pcoll_list else: - self._runner._desired_cache_labels.add(cache_label) # pylint: disable=protected-access raise ValueError('PCollection not available, please run the pipeline.') - def sample(self, pcoll): - cache_label = self._cache_label(pcoll) - if self._cache_manager.exists('sample', cache_label): - return self._cache_manager.read('sample', cache_label) - else: - self._runner._desired_cache_labels.add(cache_label) # pylint: disable=protected-access - raise ValueError('PCollection not available, please run the pipeline.') + def cancel(self): + self._underlying_result.cancel() diff --git a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py index 9958d218570b..36ebce8e724f 100644 --- a/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py +++ b/sdks/python/apache_beam/runners/interactive/interactive_runner_test.py @@ -28,6 +28,7 @@ import apache_beam as beam from apache_beam.runners.direct import direct_runner +from apache_beam.runners.interactive import interactive_beam as ib from apache_beam.runners.interactive import interactive_runner @@ -79,6 +80,9 @@ def process(self, element): | 'group' >> beam.GroupByKey() | 'count' >> beam.Map(lambda wordones: (wordones[0], sum(wordones[1])))) + # Watch the local scope for Interactive Beam so that counts will be cached. + ib.watch(locals()) + result = p.run() result.wait_until_finish() diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_analyzer.py b/sdks/python/apache_beam/runners/interactive/pipeline_analyzer.py index 860bbe263fa8..ab56921a83d4 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_analyzer.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_analyzer.py @@ -79,7 +79,7 @@ def _analyze_pipeline(self): 1. Start from target PCollections and recursively insert the producing PTransforms of those PCollections, where the producing PTransforms are either ReadCache or PTransforms in the original pipeline. - 2. 
Append WriteCache PTransforsm in the pipeline. + 2. Append WriteCache PTransforms in the pipeline. After running this function, the following variables will be set: self._pipeline_proto_to_execute @@ -343,7 +343,7 @@ def _top_level_producer(self, pcoll): pcoll: (PCollection) Returns: - (PTransform) top level producing PTransform of pcoll. + (AppliedPTransform) top level producing AppliedPTransform of pcoll. """ top_level_transform = pcoll.producer while top_level_transform.parent.parent: @@ -354,10 +354,10 @@ def _include_subtransforms(self, transform): """Depth-first yield the PTransform itself and its sub transforms. Args: - transform: (PTransform) + transform: (AppliedPTransform) Yields: - The input PTransform itself and all its sub transforms. + The input AppliedPTransform itself and all its sub transforms. """ yield transform for subtransform in transform.parts[::-1]: diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py index 63664ed29872..503ff26dd17f 100644 --- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py +++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py @@ -80,7 +80,10 @@ def __init__(self, pipeline, options=None): # A mapping from PCollection id to python id() value in user defined # pipeline instance. (self._pcoll_version_map, - self._cacheables) = cacheables(self.pcolls_to_pcoll_id) + self._cacheables, + # A dict from pcoll_id to variable name of the referenced PCollection. + # (Dict[str, str]) + self._cacheable_var_by_pcoll_id) = cacheables(self.pcolls_to_pcoll_id) # A dict from cache key to PCollection that is read from cache. # If exists, caller should reuse the PCollection read. If not, caller @@ -109,10 +112,10 @@ def cacheables(self): """Finds cacheable PCollections from the pipeline. The function only treats the result as cacheables since there is no - guarantee whether the cache desired PCollection has been cached or - not. A PCollection desires caching when it's bound to a user defined - variable in source code. Otherwise, the PCollection is not reusale - nor introspectable which nullifying the need of cache. + guarantee whether PCollections that need to be cached have been cached or + not. A PCollection needs to be cached when it's bound to a user defined + variable in the source code. Otherwise, the PCollection is not reusable + nor introspectable which nullifies the need of cache. """ return self._cacheables @@ -142,7 +145,6 @@ def instrument(self): Modifies: self._pipeline """ - self._preprocess() cacheable_inputs = set() class InstrumentVisitor(PipelineVisitor): @@ -169,7 +171,7 @@ def visit_transform(self, transform_node): self._write_cache(cacheable['pcoll']) # TODO(BEAM-7760): prune sub graphs that doesn't need to be executed. - def _preprocess(self): + def preprocess(self): """Pre-processes the pipeline. Since the pipeline instance in the class might not be the same instance @@ -314,29 +316,48 @@ def cache_key(self, pcoll): cacheable['producer_version'])) return '' + def cacheable_var_by_pcoll_id(self, pcoll_id): + """Retrieves the variable name of a PCollection. + + In source code, PCollection variables are defined in the user pipeline. When + it's converted to the runner api representation, each PCollection referenced + in the user pipeline is assigned a unique-within-pipeline pcoll_id. Given + such pcoll_id, retrieves the str variable name defined in user pipeline for + that referenced PCollection. 
If the PCollection is not watched, returns None. + """ + return self._cacheable_var_by_pcoll_id.get(pcoll_id, None) + def pin(pipeline, options=None): - """Creates PipelineInstrument for a pipeline and its options with cache.""" + """Creates PipelineInstrument for a pipeline and its options with cache. + + This is shorthand for three steps: 1) compute metadata once for the + given runner pipeline and everything watched from user pipelines; 2) associate + information between the runner pipeline and its corresponding user pipeline, + eliminating data from other user pipelines if there is any; 3) mutate the + runner pipeline to apply interactivity. + """ pi = PipelineInstrument(pipeline, options) + pi.preprocess() pi.instrument() # Instruments the pipeline only once. return pi def cacheables(pcolls_to_pcoll_id): - """Finds cache desired PCollections from the instrumented pipeline. - - The function only treats the result as cacheables since whether the cache - desired PCollection has been cached depends on whether the pipeline has been - executed in current interactive environment. A PCollection desires caching - when it's bound to a user defined variable in source code. Otherwise, the - PCollection is not reusable nor introspectable which nullifies the need of - cache. There might be multiple pipelines defined and watched, this will - return for PCollections from the ones with pcolls_to_pcoll_id analyzed. The - check is not strict because pcoll_id is not unique across multiple pipelines. - Additional check needs to be done during instrument. + """Finds the PCollections that need to be cached among the analyzed PCollections. + + The function only treats the result as cacheables since there is no guarantee + whether PCollections that need to be cached have actually been cached. A + PCollection needs to be cached when it's bound to a user-defined variable in + the source code; otherwise, the PCollection is neither reusable nor + introspectable, which nullifies the need for a cache. There might be multiple + pipelines defined and watched; this will only return results for PCollections + covered by the given pcolls_to_pcoll_id. The check is not strict because a + pcoll_id is not unique across multiple pipelines, so an additional check needs + to be done during instrument. + """ pcoll_version_map = {} cacheables = {} + cacheable_var_by_pcoll_id = {} for watching in ie.current_env().watching(): for key, val in watching: # TODO(BEAM-8288): cleanup the attribute check when py2 is not supported. 
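As an editorial aside (not part of the patch): a minimal sketch of the user-facing workflow these pipeline_instrument hunks support, using only the InteractiveRunner, interactive_beam.watch, and PipelineResult.get pieces shown elsewhere in this diff. The pipeline shape and element values are made up for illustration.

```python
import apache_beam as beam
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_runner

# Build a pipeline on the interactive runner.
p = beam.Pipeline(runner=interactive_runner.InteractiveRunner())
squares = (p
           | 'Create' >> beam.Create(range(8))
           | 'Square' >> beam.Map(lambda x: x * x))

# Watching the calling scope binds the variable name 'squares' to its
# PCollection; only watched, variable-bound PCollections become cacheables.
ib.watch(locals())

result = p.run()
result.wait_until_finish()

# Reads the cached contents of the watched PCollection; as implemented above,
# this raises ValueError if the PCollection was never cached.
print(result.get(squares))
```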
@@ -353,7 +374,8 @@ def cacheables(pcolls_to_pcoll_id): cacheable['producer_version'] = str(id(val.producer)) cacheables[cacheable_key(val, pcolls_to_pcoll_id)] = cacheable pcoll_version_map[cacheable['pcoll_id']] = cacheable['version'] - return pcoll_version_map, cacheables + cacheable_var_by_pcoll_id[cacheable['pcoll_id']] = key + return pcoll_version_map, cacheables, cacheable_var_by_pcoll_id def cacheable_key(pcoll, pcolls_to_pcoll_id, pcoll_version_map=None): diff --git a/sdks/python/apache_beam/runners/portability/flink_runner.py b/sdks/python/apache_beam/runners/portability/flink_runner.py index 0743fedbcdf7..085aa8b53aed 100644 --- a/sdks/python/apache_beam/runners/portability/flink_runner.py +++ b/sdks/python/apache_beam/runners/portability/flink_runner.py @@ -29,7 +29,6 @@ from apache_beam.runners.portability import job_server from apache_beam.runners.portability import portable_runner -PUBLISHED_FLINK_VERSIONS = ['1.7', '1.8', '1.9'] MAGIC_HOST_NAMES = ['[local]', '[auto]'] _LOGGER = logging.getLogger(__name__) @@ -38,24 +37,27 @@ class FlinkRunner(portable_runner.PortableRunner): def run_pipeline(self, pipeline, options): portable_options = options.view_as(pipeline_options.PortableOptions) - if (options.view_as(FlinkRunnerOptions).flink_master in MAGIC_HOST_NAMES + flink_options = options.view_as(pipeline_options.FlinkRunnerOptions) + if (flink_options.flink_master in MAGIC_HOST_NAMES and not portable_options.environment_type and not portable_options.output_executable_path): portable_options.environment_type = 'LOOPBACK' return super(FlinkRunner, self).run_pipeline(pipeline, options) def default_job_server(self, options): + flink_options = options.view_as(pipeline_options.FlinkRunnerOptions) flink_master = self.add_http_scheme( - options.view_as(FlinkRunnerOptions).flink_master) - options.view_as(FlinkRunnerOptions).flink_master = flink_master + flink_options.flink_master) + flink_options.flink_master = flink_master if flink_master in MAGIC_HOST_NAMES or sys.version_info < (3, 6): return job_server.StopOnExitJobServer(FlinkJarJobServer(options)) else: # This has to be changed [auto], otherwise we will attempt to submit a # the pipeline remotely on the Flink JobMaster which will _fail_. # DO NOT CHANGE the following line, unless you have tested this. - options.view_as(FlinkRunnerOptions).flink_master = '[auto]' - return flink_uber_jar_job_server.FlinkUberJarJobServer(flink_master) + flink_options.flink_master = '[auto]' + return flink_uber_jar_job_server.FlinkUberJarJobServer( + flink_master, options) @staticmethod def add_http_scheme(flink_master): @@ -69,33 +71,13 @@ def add_http_scheme(flink_master): return flink_master -class FlinkRunnerOptions(pipeline_options.PipelineOptions): - @classmethod - def _add_argparse_args(cls, parser): - parser.add_argument('--flink_master', - default='[auto]', - help='Flink master address (http://host:port)' - ' Use "[local]" to start a local cluster' - ' for the execution. 
Use "[auto]" if you' - ' plan to either execute locally or let the' - ' Flink job server infer the cluster address.') - parser.add_argument('--flink_version', - default=PUBLISHED_FLINK_VERSIONS[-1], - choices=PUBLISHED_FLINK_VERSIONS, - help='Flink version to use.') - parser.add_argument('--flink_job_server_jar', - help='Path or URL to a flink jobserver jar.') - parser.add_argument('--artifacts_dir', default=None) - - class FlinkJarJobServer(job_server.JavaJarJobServer): def __init__(self, options): - super(FlinkJarJobServer, self).__init__() - options = options.view_as(FlinkRunnerOptions) + super(FlinkJarJobServer, self).__init__(options) + options = options.view_as(pipeline_options.FlinkRunnerOptions) self._jar = options.flink_job_server_jar self._master_url = options.flink_master self._flink_version = options.flink_version - self._artifacts_dir = options.artifacts_dir def path_to_jar(self): if self._jar: @@ -104,12 +86,12 @@ def path_to_jar(self): return self.path_to_beam_jar( 'runners:flink:%s:job-server:shadowJar' % self._flink_version) - def java_arguments(self, job_port, artifacts_dir): + def java_arguments( + self, job_port, artifact_port, expansion_port, artifacts_dir): return [ '--flink-master', self._master_url, - '--artifacts-dir', (self._artifacts_dir - if self._artifacts_dir else artifacts_dir), + '--artifacts-dir', artifacts_dir, '--job-port', job_port, - '--artifact-port', 0, - '--expansion-port', 0 + '--artifact-port', artifact_port, + '--expansion-port', expansion_port ] diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py index 8b9aee461d0c..c67edc46277a 100644 --- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py +++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server.py @@ -33,6 +33,7 @@ import requests from google.protobuf import json_format +from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_artifact_api_pb2_grpc from apache_beam.portability.api import beam_job_api_pb2 from apache_beam.portability.api import endpoints_pb2 @@ -50,10 +51,13 @@ class FlinkUberJarJobServer(abstract_job_service.AbstractJobServiceServicer): the pipeline artifacts. 
""" - def __init__(self, master_url, executable_jar=None): + def __init__(self, master_url, options): super(FlinkUberJarJobServer, self).__init__() self._master_url = master_url - self._executable_jar = executable_jar + self._executable_jar = (options.view_as(pipeline_options.FlinkRunnerOptions) + .flink_job_server_jar) + self._artifact_port = (options.view_as(pipeline_options.JobServerOptions) + .artifact_port) self._temp_dir = tempfile.mkdtemp(prefix='apache-beam-flink') def start(self): @@ -63,9 +67,10 @@ def stop(self): pass def executable_jar(self): - return self._executable_jar or job_server.JavaJarJobServer.local_jar( - job_server.JavaJarJobServer.path_to_beam_jar( - 'runners:flink:%s:job-server:shadowJar' % self.flink_version())) + url = (self._executable_jar or + job_server.JavaJarJobServer.path_to_beam_jar( + 'runners:flink:%s:job-server:shadowJar' % self.flink_version())) + return job_server.JavaJarJobServer.local_jar(url) def flink_version(self): full_version = requests.get( @@ -80,7 +85,8 @@ def create_beam_job(self, job_id, job_name, pipeline, options): job_id, job_name, pipeline, - options) + options, + artifact_port=self._artifact_port) class FlinkBeamJob(abstract_job_service.AbstractBeamJob): @@ -101,11 +107,13 @@ class FlinkBeamJob(abstract_job_service.AbstractBeamJob): [PIPELINE_FOLDER, PIPELINE_NAME, 'artifact-manifest.json']) def __init__( - self, master_url, executable_jar, job_id, job_name, pipeline, options): + self, master_url, executable_jar, job_id, job_name, pipeline, options, + artifact_port=0): super(FlinkBeamJob, self).__init__(job_id, job_name, pipeline, options) self._master_url = master_url self._executable_jar = executable_jar self._jar_uploaded = False + self._artifact_port = artifact_port def prepare(self): # Copy the executable jar, injecting the pipeline and options as resources. 
@@ -122,13 +130,14 @@ def prepare(self): with z.open(self.PIPELINE_MANIFEST, 'w') as fout: fout.write(json.dumps( {'defaultJobName': self.PIPELINE_NAME}).encode('utf-8')) - self._start_artifact_service(self._jar) + self._start_artifact_service(self._jar, self._artifact_port) - def _start_artifact_service(self, jar): + def _start_artifact_service(self, jar, requested_port): self._artifact_staging_service = artifact_service.ZipFileArtifactService( jar) self._artifact_staging_server = grpc.server(futures.ThreadPoolExecutor()) - port = self._artifact_staging_server.add_insecure_port('[::]:0') + port = self._artifact_staging_server.add_insecure_port( + '[::]:%s' % requested_port) beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server( self._artifact_staging_service, self._artifact_staging_server) self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor( diff --git a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py index 96867c66a662..dba328e87be7 100644 --- a/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py +++ b/sdks/python/apache_beam/runners/portability/flink_uber_jar_job_server_test.py @@ -28,6 +28,7 @@ import grpc import requests_mock +from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_artifact_api_pb2 from apache_beam.portability.api import beam_artifact_api_pb2_grpc from apache_beam.portability.api import beam_job_api_pb2 @@ -51,7 +52,7 @@ class FlinkUberJarJobServerTest(unittest.TestCase): def test_flink_version(self, http_mock): http_mock.get('http://flink/v1/config', json={'flink-version': '3.1.4.1'}) job_server = flink_uber_jar_job_server.FlinkUberJarJobServer( - 'http://flink', None) + 'http://flink', pipeline_options.FlinkRunnerOptions()) self.assertEqual(job_server.flink_version(), "3.1") @requests_mock.mock() @@ -62,8 +63,10 @@ def test_end_to_end(self, http_mock): with zip.open('FakeClass.class', 'w') as fout: fout.write(b'[original_contents]') + options = pipeline_options.FlinkRunnerOptions() + options.flink_job_server_jar = fake_jar job_server = flink_uber_jar_job_server.FlinkUberJarJobServer( - 'http://flink', fake_jar) + 'http://flink', options) # Prepare the job. 
prepare_response = job_server.Prepare( diff --git a/sdks/python/apache_beam/runners/portability/job_server.py b/sdks/python/apache_beam/runners/portability/job_server.py index 7cf8d4321937..0d9b82bf833d 100644 --- a/sdks/python/apache_beam/runners/portability/job_server.py +++ b/sdks/python/apache_beam/runners/portability/job_server.py @@ -28,6 +28,7 @@ import grpc +from apache_beam.options import pipeline_options from apache_beam.portability.api import beam_job_api_pb2_grpc from apache_beam.runners.portability import local_job_service from apache_beam.utils import subprocess_server @@ -124,7 +125,16 @@ class JavaJarJobServer(SubprocessJobServer): MAVEN_REPOSITORY = 'https://repo.maven.apache.org/maven2/org/apache/beam' JAR_CACHE = os.path.expanduser("~/.apache_beam/cache") - def java_arguments(self, job_port, artifacts_dir): + def __init__(self, options): + super(JavaJarJobServer, self).__init__() + options = options.view_as(pipeline_options.JobServerOptions) + self._job_port = options.job_port + self._artifact_port = options.artifact_port + self._expansion_port = options.expansion_port + self._artifacts_dir = options.artifacts_dir + + def java_arguments( + self, job_port, artifact_port, expansion_port, artifacts_dir): raise NotImplementedError(type(self)) def path_to_jar(self): @@ -140,11 +150,15 @@ def local_jar(url): def subprocess_cmd_and_endpoint(self): jar_path = self.local_jar(self.path_to_jar()) - artifacts_dir = self.local_temp_dir(prefix='artifacts') - job_port, = subprocess_server.pick_port(None) + artifacts_dir = (self._artifacts_dir if self._artifacts_dir + else self.local_temp_dir(prefix='artifacts')) + job_port, = subprocess_server.pick_port(self._job_port) return ( ['java', '-jar', jar_path] + list( - self.java_arguments(job_port, artifacts_dir)), + self.java_arguments(job_port, + self._artifact_port, + self._expansion_port, + artifacts_dir)), 'localhost:%s' % job_port) diff --git a/sdks/python/apache_beam/runners/portability/spark_runner.py b/sdks/python/apache_beam/runners/portability/spark_runner.py index ca033103cc2e..8c3939e7c4c4 100644 --- a/sdks/python/apache_beam/runners/portability/spark_runner.py +++ b/sdks/python/apache_beam/runners/portability/spark_runner.py @@ -56,16 +56,14 @@ def _add_argparse_args(cls, parser): 'the execution.') parser.add_argument('--spark_job_server_jar', help='Path or URL to a Beam Spark jobserver jar.') - parser.add_argument('--artifacts_dir', default=None) class SparkJarJobServer(job_server.JavaJarJobServer): def __init__(self, options): - super(SparkJarJobServer, self).__init__() + super(SparkJarJobServer, self).__init__(options) options = options.view_as(SparkRunnerOptions) self._jar = options.spark_job_server_jar self._master_url = options.spark_master_url - self._artifacts_dir = options.artifacts_dir def path_to_jar(self): if self._jar: @@ -73,12 +71,12 @@ def path_to_jar(self): else: return self.path_to_beam_jar('runners:spark:job-server:shadowJar') - def java_arguments(self, job_port, artifacts_dir): + def java_arguments( + self, job_port, artifact_port, expansion_port, artifacts_dir): return [ '--spark-master-url', self._master_url, - '--artifacts-dir', (self._artifacts_dir - if self._artifacts_dir else artifacts_dir), + '--artifacts-dir', artifacts_dir, '--job-port', job_port, - '--artifact-port', 0, - '--expansion-port', 0 + '--artifact-port', artifact_port, + '--expansion-port', expansion_port ] diff --git a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml 
b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml index cac0c7404a2a..fdda05c57073 100644 --- a/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml +++ b/sdks/python/apache_beam/testing/data/trigger_transcripts.yaml @@ -54,6 +54,24 @@ transcript: - {window: [10, 19], values: [10, 11], timestamp: 10} - {window: [20, 29], values: [25], timestamp: 25, late: false} +--- +name: timestamp_combiner_earliest_separate_bundles +window_fn: FixedWindows(10) +trigger_fn: Default +timestamp_combiner: OUTPUT_AT_EARLIEST +transcript: + - input: [1] + - input: [2] + - input: [3] + - input: [10] + - input: [11] + - input: [25] + - watermark: 100 + - expect: + - {window: [0, 9], values: [1, 2, 3], timestamp: 1, final: false} + - {window: [10, 19], values: [10, 11], timestamp: 10} + - {window: [20, 29], values: [25], timestamp: 25, late: false} + --- name: timestamp_combiner_latest window_fn: FixedWindows(10) diff --git a/sdks/python/apache_beam/transforms/timeutil.py b/sdks/python/apache_beam/transforms/timeutil.py index a5f729c75623..88305cb2767c 100644 --- a/sdks/python/apache_beam/transforms/timeutil.py +++ b/sdks/python/apache_beam/transforms/timeutil.py @@ -64,11 +64,11 @@ class TimestampCombinerImpl(with_metaclass(ABCMeta, object)): @abstractmethod def assign_output_time(self, window, input_timestamp): - pass + raise NotImplementedError @abstractmethod def combine(self, output_timestamp, other_output_timestamp): - pass + raise NotImplementedError def combine_all(self, merging_timestamps): """Apply combine to list of timestamps.""" @@ -76,7 +76,7 @@ def combine_all(self, merging_timestamps): for output_time in merging_timestamps: if combined_output_time is None: combined_output_time = output_time - else: + elif output_time is not None: combined_output_time = self.combine( combined_output_time, output_time) return combined_output_time @@ -89,9 +89,6 @@ def merge(self, unused_result_window, merging_timestamps): class DependsOnlyOnWindow(with_metaclass(ABCMeta, TimestampCombinerImpl)): """TimestampCombinerImpl that only depends on the window.""" - def combine(self, output_timestamp, other_output_timestamp): - return output_timestamp - def merge(self, result_window, unused_merging_timestamps): # Since we know that the result only depends on the window, we can ignore # the given timestamps. 
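As an aside (not from the patch): a rough sketch of a pipeline that opts into the OUTPUT_AT_EARLIEST behavior exercised by the new transcript and the TimestampCombinerImpl changes above; the keys and timestamp values are placeholders.

```python
import apache_beam as beam
from apache_beam.transforms import window

with beam.Pipeline() as p:
    _ = (p
         | beam.Create([('k', 1), ('k', 2), ('k', 3), ('k', 10)])
         # Use each element's integer value as its event timestamp.
         | beam.Map(lambda kv: window.TimestampedValue(kv, kv[1]))
         # Fixed 10-second windows; each grouped output is stamped with the
         # earliest input timestamp that fell into its window.
         | beam.WindowInto(
             window.FixedWindows(10),
             timestamp_combiner=window.TimestampCombiner.OUTPUT_AT_EARLIEST)
         | beam.GroupByKey())
```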
@@ -137,3 +134,6 @@ class OutputAtEndOfWindowImpl(DependsOnlyOnWindow): def assign_output_time(self, window, unused_input_timestamp): return window.max_timestamp() + + def combine(self, output_timestamp, other_output_timestamp): + return max(output_timestamp, other_output_timestamp) diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index e5bc20d2b88b..6f59f2158910 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -1147,7 +1147,6 @@ def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument for unused_value, timestamp in elements) if element_output_time >= output_watermark)) if output_time is not None: - state.clear_state(window, self.WATERMARK_HOLD) state.add_state(window, self.WATERMARK_HOLD, output_time) context = state.at(window, self.clock) diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 6ac05d09abdf..74829e544c24 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -32,6 +32,7 @@ # patches unittest.TestCase to be python3 compatible import future.tests.base # pylint: disable=unused-import +from nose.plugins.attrib import attr import apache_beam as beam from apache_beam import WindowInto @@ -54,6 +55,8 @@ from apache_beam.transforms.window import SlidingWindows from apache_beam.transforms.window import TimestampedValue from apache_beam.utils import timestamp +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import MIN_TIMESTAMP from apache_beam.utils.windowed_value import WindowedValue @@ -330,7 +333,6 @@ def process(self, element): with self.assertRaisesRegex(ValueError, r'window.*None.*add_timestamps2'): pipeline.run() - class ReshuffleTest(unittest.TestCase): def test_reshuffle_contents_unchanged(self): @@ -477,6 +479,59 @@ def test_reshuffle_streaming_global_window(self): label='after reshuffle') pipeline.run() + @attr('ValidatesRunner') + def test_reshuffle_preserves_timestamps(self): + with TestPipeline() as pipeline: + + # Create a PCollection and assign each element with a different timestamp. + before_reshuffle = (pipeline + | beam.Create([ + {'name': 'foo', 'timestamp': MIN_TIMESTAMP}, + {'name': 'foo', 'timestamp': 0}, + {'name': 'bar', 'timestamp': 33}, + {'name': 'bar', 'timestamp': MAX_TIMESTAMP}, + ]) + | beam.Map( + lambda element: beam.window.TimestampedValue( + element, element['timestamp']))) + + # Reshuffle the PCollection above and assign the timestamp of an element + # to that element again. + after_reshuffle = before_reshuffle | beam.Reshuffle() + + # Given an element, emits a string which contains the timestamp and the + # name field of the element. + def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): + t = str(timestamp) + if timestamp == MIN_TIMESTAMP: + t = 'MIN_TIMESTAMP' + elif timestamp == MAX_TIMESTAMP: + t = 'MAX_TIMESTAMP' + return '{} - {}'.format(t, element['name']) + + # Combine each element in before_reshuffle with its timestamp. + formatted_before_reshuffle = (before_reshuffle + | "Get before_reshuffle timestamp" >> + beam.Map(format_with_timestamp)) + + # Combine each element in after_reshuffle with its timestamp. 
+ formatted_after_reshuffle = (after_reshuffle + | "Get after_reshuffle timestamp" >> + beam.Map(format_with_timestamp)) + + expected_data = ['MIN_TIMESTAMP - foo', + 'Timestamp(0) - foo', + 'Timestamp(33) - bar', + 'MAX_TIMESTAMP - bar'] + + # Can't compare formatted_before_reshuffle and formatted_after_reshuffle + # directly, because they are deferred PCollections while equal_to only + # takes a concrete argument. + assert_that(formatted_before_reshuffle, equal_to(expected_data), + label="formatted_before_reshuffle") + assert_that(formatted_after_reshuffle, equal_to(expected_data), + label="formatted_after_reshuffle") + class WithKeysTest(unittest.TestCase): diff --git a/sdks/python/test-suites/portable/py2/build.gradle b/sdks/python/test-suites/portable/py2/build.gradle index 5d967e40c65b..d0b4a64dc2a2 100644 --- a/sdks/python/test-suites/portable/py2/build.gradle +++ b/sdks/python/test-suites/portable/py2/build.gradle @@ -105,7 +105,6 @@ task crossLanguagePythonJavaFlink { def testServiceExpansionJar = project(":sdks:java:testing:expansion-service:").buildTestExpansionServiceJar.archivePath def options = [ "--runner=PortableRunner", - "--experiments=worker_threads=100", "--parallelism=2", "--shutdown_sources_on_final_watermark", "--environment_cache_millis=10000", @@ -132,7 +131,6 @@ task crossLanguagePortableWordCount { "--input=/etc/profile", "--output=/tmp/py-wordcount-portable", "--runner=PortableRunner", - "--experiments=worker_threads=100", "--parallelism=2", "--shutdown_sources_on_final_watermark", "--environment_cache_millis=10000", diff --git a/website/src/_includes/section-menu/community.html b/website/src/_includes/section-menu/community.html index 5e2c11aff16b..39d1ceb0c556 100644 --- a/website/src/_includes/section-menu/community.html +++ b/website/src/_includes/section-menu/community.html @@ -15,6 +15,7 @@

  • Contact Us
  • Policies
  • YouTube channel
  • + Twitter Handle
  • In Person
  • Promotion diff --git a/website/src/community/twitter-handle.md b/website/src/community/twitter-handle.md new file mode 100644 index 000000000000..6ee646ce9ce2 --- /dev/null +++ b/website/src/community/twitter-handle.md @@ -0,0 +1,41 @@ +--- +layout: section +title: 'Beam Twitter handle' +section_menu: section-menu/community.html +permalink: /community/twitter-handle/ +--- + +# Beam Twitter handle +## What is it and what are the goals? +The Apache Beam community strives to be inclusive of everyone. As part of this effort, we want to enable any community member to share anything interesting and fun related to Beam via its [official Twitter handle](https://twitter.com/ApacheBeam). + +The Twitter feed found here is owned by the Apache Beam PMC, and the process of proposing new tweets is documented below. + + + +## Process to propose new tweets for the Apache Beam Twitter handle + +- Compose a tweet with the news you want to share. It can be anything that is relevant to Apache Beam. For example, it can be a tweet to welcome new committers, announce new Beam features, share and recognize contributors publicly, promote events and meetups, share trending articles about batch and stream processing of big data, etc. +- Go to [s.apache.org/beam-tweets](https://s.apache.org/beam-tweets) and request “Edit” access. +- After you are able to edit the document, fill out the necessary fields: + - Date - the date when you are filling out the document + - Author - your name + - Topic - what the tweet is about + - Content - the actual text of the tweet, or a link to the tweet that needs to be retweeted + - Links/media - anything you want to add to the tweet, e.g. photos, videos, references + - Deadline - the date by which the tweet needs to go out, or the preferred date for it to go out +- After you type in your tweets, PMC members who are subscribed to the document will get notified and review the content. +- After approval, a PMC member with access to the Twitter handle will publish the tweet. +- Follow the [Apache Beam handle](https://twitter.com/ApacheBeam) to see if your content is published! diff --git a/website/src/documentation/runners/dataflow.md b/website/src/documentation/runners/dataflow.md index 1cd28baff4ed..51fe6259f287 100644 --- a/website/src/documentation/runners/dataflow.md +++ b/website/src/documentation/runners/dataflow.md @@ -195,7 +195,7 @@ java -jar target/beam-examples-bundled-1.0.0.jar \ sdk_location - Override the default location from where the Beam SDK is downloaded. This value can be an URL, a Cloud Storage path, or a local path to an SDK tarball. Workflow submissions will download or copy the SDK tarball from this location. If set to the string default, a standard SDK location is used. If empty, no SDK is copied. + Override the default location from where the Beam SDK is downloaded. This value can be a URL, a Cloud Storage path, or a local path to an SDK tarball. Workflow submissions will download or copy the SDK tarball from this location. If set to the string default, a standard SDK location is used. If empty, no SDK is copied. default
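As a closing aside (not part of the patch): one hypothetical way the sdk_location option documented above could be set programmatically; the tarball path is a placeholder, and SetupOptions is assumed to be the options view that carries this flag.

```python
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions

# Placeholder flags; a real Dataflow pipeline would also pass project,
# region, and staging/temp locations.
options = PipelineOptions(['--runner=DataflowRunner'])
options.view_as(SetupOptions).sdk_location = '/path/to/apache-beam-sdk.tar.gz'
```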