From cbdc53a5655787177ba1a666cc2d5f664f8f732d Mon Sep 17 00:00:00 2001 From: Jim Steinebrey Date: Mon, 19 Aug 2024 12:50:33 -0400 Subject: [PATCH] NIFI-13543 Add HttpRecordSink NIFI-13543 Changed default batch size to 0 and changed to throw IOException from sendHttpRequest() Signed-off-by: Matt Burgess This closes #9185 --- .../nifi-record-sink-service/pom.xml | 40 +++ .../nifi/record/sink/HttpRecordSink.java | 281 ++++++++++++++++ ...g.apache.nifi.controller.ControllerService | 1 + .../nifi/record/sink/TestHttpRecordSink.java | 304 ++++++++++++++++++ 4 files changed, 626 insertions(+) create mode 100644 nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/java/org/apache/nifi/record/sink/HttpRecordSink.java create mode 100644 nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/test/java/org/apache/nifi/record/sink/TestHttpRecordSink.java diff --git a/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/pom.xml b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/pom.xml index a62c6ef9dafa..34f99cc5ddf5 100644 --- a/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/pom.xml +++ b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/pom.xml @@ -44,6 +44,24 @@ nifi-event-transport 2.0.0-SNAPSHOT + + org.apache.nifi + nifi-web-client-provider-api + + + org.apache.nifi + nifi-ssl-context-service-api + + + org.apache.nifi + nifi-proxy-configuration-api + test + + + com.squareup.okhttp3 + mockwebserver + test + org.apache.nifi nifi-mock-record-utils @@ -59,5 +77,27 @@ 2.0.1 test + + org.apache.nifi + nifi-web-client-api + compile + + + org.apache.nifi + nifi-oauth2-provider-api + compile + + + org.apache.nifi + nifi-web-client-provider-service + 2.0.0-SNAPSHOT + test + + + org.apache.nifi + nifi-record-serialization-services + 2.0.0-SNAPSHOT + test + diff --git a/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/java/org/apache/nifi/record/sink/HttpRecordSink.java b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/java/org/apache/nifi/record/sink/HttpRecordSink.java new file mode 100644 index 000000000000..9c64655aecf1 --- /dev/null +++ b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/java/org/apache/nifi/record/sink/HttpRecordSink.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.record.sink; + +import org.apache.commons.io.IOUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.nifi.annotation.documentation.CapabilityDescription; +import org.apache.nifi.annotation.documentation.Tags; +import org.apache.nifi.annotation.lifecycle.OnEnabled; +import org.apache.nifi.components.PropertyDescriptor; +import org.apache.nifi.components.ValidationResult; +import org.apache.nifi.controller.AbstractControllerService; +import org.apache.nifi.controller.ConfigurationContext; +import org.apache.nifi.expression.ExpressionLanguageScope; +import org.apache.nifi.oauth2.OAuth2AccessTokenProvider; +import org.apache.nifi.processor.util.StandardValidators; +import org.apache.nifi.schema.access.SchemaNotFoundException; +import org.apache.nifi.serialization.RecordSetWriter; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.WriteResult; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.serialization.record.RecordSet; +import org.apache.nifi.web.client.api.HttpRequestBodySpec; +import org.apache.nifi.web.client.api.HttpResponseEntity; +import org.apache.nifi.web.client.api.HttpUriBuilder; +import org.apache.nifi.web.client.provider.api.WebClientServiceProvider; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.stream.Collectors; + +@Tags({"http", "post", "record", "sink"}) +@CapabilityDescription("Format and send Records to a configured uri using HTTP post. The Record Writer formats the records which are sent as the body of the HTTP post request. " + + "JsonRecordSetWriter is often used with this processor because many HTTP posts require a JSON body.") +public class HttpRecordSink extends AbstractControllerService implements RecordSinkService { + protected static final String HEADER_AUTHORIZATION = "Authorization"; + protected static final String HEADER_CONTENT_TYPE = "Content-Type"; + + public static final PropertyDescriptor API_URL = new PropertyDescriptor.Builder() + .name("API URL") + .description("The URL which receives the HTTP requests.") + .addValidator(StandardValidators.NON_BLANK_VALIDATOR) + .addValidator(StandardValidators.URL_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT) + .required(true) + .build(); + + static final PropertyDescriptor MAX_BATCH_SIZE = new PropertyDescriptor.Builder() + .name("Maximum Batch Size") + .description("Specifies the maximum number of records to send in the body of each HTTP request. Zero means the batch size is not limited, " + + "and all records are sent together in a single HTTP request.") + .defaultValue("0") + .addValidator(StandardValidators.NON_NEGATIVE_INTEGER_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT) + .required(true) + .build(); + + public static final PropertyDescriptor WEB_SERVICE_CLIENT_PROVIDER = new PropertyDescriptor.Builder() + .name("Web Service Client Provider") + .description("Controller service to provide the HTTP client for sending the HTTP requests.") + .required(true) + .identifiesControllerService(WebClientServiceProvider.class) + .build(); + + public static final PropertyDescriptor OAUTH2_ACCESS_TOKEN_PROVIDER = new PropertyDescriptor.Builder() + .name("OAuth2 Access Token Provider") + .description("OAuth2 service that provides the access tokens for the HTTP requests.") + .identifiesControllerService(OAuth2AccessTokenProvider.class) + .required(false) + .build(); + + private String apiUrl; + private int maxBatchSize; + private volatile RecordSetWriterFactory writerFactory; + private WebClientServiceProvider webClientServiceProvider; + private volatile Optional oauth2AccessTokenProviderOptional; + Map dynamicHttpHeaders; + + public static final List PROPERTIES = Collections.unmodifiableList(Arrays.asList( + API_URL, + MAX_BATCH_SIZE, + RECORD_WRITER_FACTORY, + WEB_SERVICE_CLIENT_PROVIDER, + OAUTH2_ACCESS_TOKEN_PROVIDER + )); + + @Override + public List getSupportedPropertyDescriptors() { + return PROPERTIES; + } + + /** + * Returns a PropertyDescriptor for the given name. This is for the user to be able to define their own properties + * which will sent as HTTP headers on the HTTP request + * + * @param propertyDescriptorName used to lookup if any property descriptors exist for that name + * @return a PropertyDescriptor object corresponding to the specified dynamic property name + */ + @Override + protected PropertyDescriptor getSupportedDynamicPropertyDescriptor(final String propertyDescriptorName) { + if (hasProhibitedName(propertyDescriptorName, HEADER_CONTENT_TYPE)) { + // Content-Type is case-sensitive for overriding our default Content-Type header, so prevent any other combination of upper/lower case letters + return getInvalidDynamicPropertyDescriptor(propertyDescriptorName, "is not allowed. Only exact case of Content-Type is allowed."); + } + + if (hasProhibitedName(propertyDescriptorName, HEADER_AUTHORIZATION)) { + // Authorization is case-sensitive for overriding our default Authorization header, so prevent any other combination of upper/lower case letters + return getInvalidDynamicPropertyDescriptor(propertyDescriptorName, "is not allowed. Only exact case of Authorization is allowed."); + } + + return new PropertyDescriptor.Builder() + .name(propertyDescriptorName) + .required(false) + .addValidator(StandardValidators.NON_EMPTY_VALIDATOR) + .expressionLanguageSupported(ExpressionLanguageScope.ENVIRONMENT) + .dynamic(true) + .build(); + } + + private static boolean hasProhibitedName(String userInput, String correctName) { + // Do not allow : in any case be + // cause it is not the correct name + // 'correctName' header is case-sensitive for overriding our default 'correctName' header, so prevent any other combination of upper/lower case letters + return (correctName + ":").equalsIgnoreCase(userInput) + || (correctName.equalsIgnoreCase(userInput) && !correctName.equals(userInput)); + } + + private static PropertyDescriptor getInvalidDynamicPropertyDescriptor(String propertyDescriptorName, String explanation) { + return new PropertyDescriptor.Builder() + .name(propertyDescriptorName) + .addValidator((subject, input, context) -> new ValidationResult.Builder() + .explanation(explanation) + .valid(false) + .subject(subject) + .build()) + .dynamic(true) + .build(); + } + + @OnEnabled + public void onEnabled(final ConfigurationContext context) { + apiUrl = context.getProperty(API_URL).evaluateAttributeExpressions().getValue(); + maxBatchSize = context.getProperty(MAX_BATCH_SIZE).evaluateAttributeExpressions().asInteger(); + writerFactory = context.getProperty(RECORD_WRITER_FACTORY).asControllerService(RecordSetWriterFactory.class); + webClientServiceProvider = context + .getProperty(WEB_SERVICE_CLIENT_PROVIDER).asControllerService(WebClientServiceProvider.class); + + if (context.getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).isSet()) { + OAuth2AccessTokenProvider oauth2AccessTokenProvider = context + .getProperty(OAUTH2_ACCESS_TOKEN_PROVIDER).asControllerService(OAuth2AccessTokenProvider.class); + oauth2AccessTokenProvider.getAccessDetails(); + oauth2AccessTokenProviderOptional = Optional.of(oauth2AccessTokenProvider); + } else { + oauth2AccessTokenProviderOptional = Optional.empty(); + } + + // Dynamic properties are sent as http headers on the post request. + dynamicHttpHeaders = context.getProperties().keySet().stream() + .filter(PropertyDescriptor::isDynamic) + .collect(Collectors.toMap( + PropertyDescriptor::getName, + p -> context.getProperty(p).evaluateAttributeExpressions().getValue())); + } + + @Override + public WriteResult sendData(RecordSet recordSet, Map attributes, boolean sendZeroResults) throws IOException { + WriteResult writeResult; + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final RecordSetWriter writer = writerFactory.createWriter(getLogger(), recordSet.getSchema(), baos, attributes)) { + writeResult = sendRecords(recordSet, writer, baos, maxBatchSize); + } catch (SchemaNotFoundException e) { + final String errorMessage = String.format("RecordSetWriter could not be created because the schema was not found. The schema name for the RecordSet to write is %s", + recordSet.getSchema().getSchemaName()); + throw new IOException(errorMessage, e); + } + + return writeResult; + } + + private WriteResult sendRecords(final RecordSet recordSet, final RecordSetWriter writer, final ByteArrayOutputStream baos, int maxBatchSize) throws IOException { + WriteResult writeResult = WriteResult.EMPTY; + Record r = recordSet.next(); + if (r != null) { + int batchCount = 0; + do { + if (maxBatchSize != 1 && batchCount == 0) { + // If maxBatchSize is one, then do NOT write record set begin or end markers because + // each single record is sent alone without being in an array. + writer.beginRecordSet(); + } + + writeResult = writer.write(r); + batchCount++; + + r = recordSet.next(); + + // If this is last record, then send current group of records. + // OR if we have processed maxBatchSize records, then send current group of records. + // Unless batchCount is 0, which means to send all records together in one batch at the end. + if (r == null || (maxBatchSize > 0 && batchCount >= maxBatchSize)) { + if (maxBatchSize != 1) { + writeResult = writer.finishRecordSet(); + } + writer.flush(); + sendHttpRequest(baos.toByteArray(), writer.getMimeType()); + baos.reset(); + batchCount = 0; + } + } while (r != null); + } + return writeResult; + } + + public void sendHttpRequest(final byte[] body, String mimeType) throws IOException { + final URI apiUri = URI.create(apiUrl); + final HttpUriBuilder uriBuilder = webClientServiceProvider.getHttpUriBuilder() + .scheme(apiUri.getScheme()) + .host(apiUri.getHost()) + .encodedPath(apiUri.getPath()); + if (apiUri.getPort() != -1) { + uriBuilder.port(apiUri.getPort()); + } + final URI uri = uriBuilder.build(); + + HttpRequestBodySpec requestBodySpec = webClientServiceProvider.getWebClientService() + .post() + .uri(uri); + + dynamicHttpHeaders.forEach(requestBodySpec::header); + + if (StringUtils.isNotBlank(mimeType) && !dynamicHttpHeaders.containsKey(HEADER_CONTENT_TYPE)) { + requestBodySpec.header(HEADER_CONTENT_TYPE, mimeType); + } + + if (!dynamicHttpHeaders.containsKey(HEADER_AUTHORIZATION)) { + oauth2AccessTokenProviderOptional.ifPresent(oauth2AccessTokenProvider -> + requestBodySpec.header(HEADER_AUTHORIZATION, "Bearer " + oauth2AccessTokenProvider.getAccessDetails().getAccessToken())); + } + + final InputStream requestBodyInputStream = new ByteArrayInputStream(body); + + try (final HttpResponseEntity response = requestBodySpec + .body(requestBodyInputStream, OptionalLong.of(requestBodyInputStream.available())) + .retrieve()) { + final int statusCode = response.statusCode(); + if (!(statusCode >= 200 && statusCode < 300)) { + throw new IOException(String.format("HTTP request failed with status code: %s for url: %s and returned response body: %s", + statusCode, uri.toString(), response.body() == null ? "none" : IOUtils.toString(response.body(), StandardCharsets.UTF_8))); + } + } catch (final IOException ioe) { + throw ioe; + } catch (final Exception e) { + throw new IOException(String.format("HttpRecordSink HTTP request transmission failed for url: %s", uri.toString()), e); + } + } +} diff --git a/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService index 4ca4073bf2f5..011851f6bf17 100644 --- a/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService +++ b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/main/resources/META-INF/services/org.apache.nifi.controller.ControllerService @@ -16,3 +16,4 @@ org.apache.nifi.record.sink.lookup.RecordSinkServiceLookup org.apache.nifi.record.sink.LoggingRecordSink org.apache.nifi.record.sink.EmailRecordSink org.apache.nifi.record.sink.event.UDPEventRecordSink +org.apache.nifi.record.sink.HttpRecordSink diff --git a/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/test/java/org/apache/nifi/record/sink/TestHttpRecordSink.java b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/test/java/org/apache/nifi/record/sink/TestHttpRecordSink.java new file mode 100644 index 000000000000..8c5ada3f3bd9 --- /dev/null +++ b/nifi-extension-bundles/nifi-standard-services/nifi-record-sink-service-bundle/nifi-record-sink-service/src/test/java/org/apache/nifi/record/sink/TestHttpRecordSink.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.nifi.record.sink; + +import com.fasterxml.jackson.databind.ObjectMapper; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; +import org.apache.nifi.json.JsonRecordSetWriter; +import org.apache.nifi.oauth2.OAuth2AccessTokenProvider; +import org.apache.nifi.reporting.InitializationException; +import org.apache.nifi.serialization.RecordSetWriterFactory; +import org.apache.nifi.serialization.SimpleRecordSchema; +import org.apache.nifi.serialization.WriteResult; +import org.apache.nifi.serialization.record.MapRecord; +import org.apache.nifi.serialization.record.Record; +import org.apache.nifi.serialization.record.RecordField; +import org.apache.nifi.serialization.record.RecordFieldType; +import org.apache.nifi.serialization.record.RecordSchema; +import org.apache.nifi.serialization.record.RecordSet; +import org.apache.nifi.util.NoOpProcessor; +import org.apache.nifi.util.TestRunner; +import org.apache.nifi.util.TestRunners; +import org.apache.nifi.web.client.provider.api.WebClientServiceProvider; +import org.apache.nifi.web.client.provider.service.StandardWebClientServiceProvider; +import org.eclipse.jetty.http.HttpHeader; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Answers; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestHttpRecordSink { + public static final String ID = "id"; + public static final String NAME = "name"; + public static final String ACTIVE = "active"; + + private TestRunner testRunner; + private MockWebServer mockWebServer; + private HttpRecordSink httpRecordSink; + private RecordSetWriterFactory writerFactory; + final private String OAUTH_ACCESS_TOKEN = "access_token"; + + private static RecordSchema schema; + private static Record[] records; + private ObjectMapper mapper; + + @BeforeAll + public static void setupOnce() { + final List fields = new ArrayList<>(); + fields.add(new RecordField(ID, RecordFieldType.INT.getDataType())); + fields.add(new RecordField(NAME, RecordFieldType.STRING.getDataType())); + fields.add(new RecordField(ACTIVE, RecordFieldType.BOOLEAN.getDataType())); + + schema = new SimpleRecordSchema(fields); + + final Record record0 = createRecord(schema, 0); + final Record record1 = createRecord(schema, 1); + final Record record2 = createRecord(schema, 2); + final Record record3 = createRecord(schema, 3); + final Record record4 = createRecord(schema, 4); + records = new Record[] {record0, record1, record2, record3, record4 }; + } + + private static Record createRecord(final RecordSchema schema, final int index) { + final Map valueMap = new HashMap<>(); + valueMap.put(ID, index); + valueMap.put(NAME, "Name_äöü_こんにちは世界_" + index); + valueMap.put(ACTIVE, index % 2 == 0); + return new MapRecord(schema, valueMap); + } + + private static RecordSet createRecordSetWithSize(final int size) { + return RecordSet.of(schema, Arrays.copyOf(records, size)); + } + + @BeforeEach + public void setupEachTest() throws InitializationException, IOException { + mapper = new ObjectMapper(); + + mockWebServer = new MockWebServer(); + mockWebServer.start(); + String url = mockWebServer.url("/api/test").toString(); + + testRunner = TestRunners.newTestRunner(NoOpProcessor.class); + + final WebClientServiceProvider webClientServiceProvider = new StandardWebClientServiceProvider(); + testRunner.addControllerService("webClientServiceProvider", webClientServiceProvider); + testRunner.enableControllerService(webClientServiceProvider); + + httpRecordSink = new HttpRecordSink(); + + testRunner.addControllerService("httpRecordSink", httpRecordSink); + testRunner.setProperty(httpRecordSink, HttpRecordSink.API_URL, url); + testRunner.setProperty(httpRecordSink, HttpRecordSink.WEB_SERVICE_CLIENT_PROVIDER, "webClientServiceProvider"); + + writerFactory = new JsonRecordSetWriter(); + testRunner.addControllerService("writer", writerFactory); + testRunner.setProperty(httpRecordSink, HttpRecordSink.RECORD_WRITER_FACTORY, "writer"); + + setupOAuth2TokenProvider(); + } + + private void setupOAuth2TokenProvider() throws InitializationException { + String oauth2AccessTokenProviderId = "oauth2AccessTokenProviderId"; + + OAuth2AccessTokenProvider oauth2AccessTokenProvider = mock(OAuth2AccessTokenProvider.class, Answers.RETURNS_DEEP_STUBS); + when(oauth2AccessTokenProvider.getIdentifier()).thenReturn(oauth2AccessTokenProviderId); + when(oauth2AccessTokenProvider.getAccessDetails().getAccessToken()).thenReturn(OAUTH_ACCESS_TOKEN); + + testRunner.addControllerService(oauth2AccessTokenProviderId, oauth2AccessTokenProvider); + testRunner.enableControllerService(oauth2AccessTokenProvider); + + testRunner.setProperty(httpRecordSink, HttpRecordSink.OAUTH2_ACCESS_TOKEN_PROVIDER, oauth2AccessTokenProviderId); + } + + @AfterEach + public void cleanUpEachTest() throws IOException { + mockWebServer.shutdown(); + } + + @Test + public void testInvalidIfApiUrlEmpty() { + testRunner.setProperty(httpRecordSink, HttpRecordSink.API_URL, ""); + + testRunner.enableControllerService(writerFactory); + testRunner.assertNotValid(httpRecordSink); + } + + @Test + public void testInvalidIfWebClientServiceDoesNotExist() { + testRunner.setProperty(httpRecordSink, HttpRecordSink.WEB_SERVICE_CLIENT_PROVIDER, "nonexistent"); + + testRunner.enableControllerService(writerFactory); + testRunner.assertNotValid(httpRecordSink); + } + + @Test + public void testValidContentTypeHeader() throws Exception { + testRunner.setProperty(httpRecordSink, "Content-Type", "my_content_type"); + testRunner.setProperty(httpRecordSink, "RandomHeader", "random_value"); + + testRunner.enableControllerService(writerFactory); + testRunner.assertValid(httpRecordSink); + + testRunner.disableControllerService(writerFactory); + testSendData(5, 2, "my_content_type", null); + } + + @Test + public void testInvalidContentTypeHeader() { + testRunner.setProperty(httpRecordSink, "content-type", "anything"); + + testRunner.enableControllerService(writerFactory); + testRunner.assertNotValid(httpRecordSink); + } + + @Test + public void testValidAuthorizationDynamicHeader() throws Exception { + testRunner.setProperty(httpRecordSink, "Authorization", "Bearer my_authorization"); + + testRunner.enableControllerService(writerFactory); + testRunner.assertValid(httpRecordSink); + + testRunner.disableControllerService(writerFactory); + testSendData(3, 1, null, "my_authorization"); + } + + @Test + public void testInvalidAuthorizationDynamicHeader() { + testRunner.setProperty(httpRecordSink, "authorization", "anything"); + + testRunner.enableControllerService(writerFactory); + testRunner.assertNotValid(httpRecordSink); + } + + @Test + public void testSendDataBatchSize0() throws Exception { + testSendData(5, 0); + } + + @Test + public void testSendDataBatchSize1() throws Exception { + testSendData(4, 1); + } + + @Test + public void testSendDataBatchSize2() throws Exception { + testSendData(2, 2); + } + + @Test + public void testSendDataBatchSize3() throws Exception { + testSendData(2, 3); + } + + @Test + public void testSendDataBatchSize4() throws Exception { + testSendData(5, 4); + } + + @Test + public void testSendDataBatchSize5() throws Exception { + testSendData(2, 5); + } + + public void testSendData(int recordCount, int maxBatchSize) throws Exception { + testSendData(recordCount, maxBatchSize, null, null); + } + + public void testSendData(int recordCount, int maxBatchSize, + String expectedContentType, String expectedAuthorization) throws Exception { + RecordSet recordSetIn = createRecordSetWithSize(recordCount); + int expectedRequestCount = maxBatchSize == 0 + ? 1 + : recordCount / maxBatchSize + ((recordCount % maxBatchSize == 0) ? 0 : 1); + testRunner.setProperty(httpRecordSink, HttpRecordSink.MAX_BATCH_SIZE, String.valueOf(maxBatchSize)); + testRunner.enableControllerService(writerFactory); + testRunner.assertValid(httpRecordSink); + testRunner.enableControllerService(httpRecordSink); + + for (int i = 0; i < expectedRequestCount; i++) { + mockWebServer.enqueue(new MockResponse()); + } + + final WriteResult writeResult = httpRecordSink.sendData(recordSetIn, Collections.emptyMap(), false); + + assertNotNull(writeResult); + assertEquals(recordCount, writeResult.getRecordCount()); + assertEquals(Collections.EMPTY_MAP, writeResult.getAttributes()); + + assertEquals(expectedRequestCount, mockWebServer.getRequestCount()); + + for (int i = 0; i < expectedRequestCount; i++) { + RecordedRequest recordedRequest = mockWebServer.takeRequest(); + String requestBody = recordedRequest.getBody().readString(StandardCharsets.UTF_8); + Person[] people = + (maxBatchSize == 1) + ? new Person[] { + // For maxBatchSize 1, person is not in a Json array + mapper.readValue(requestBody, Person.class) + } + : mapper.readValue(requestBody, Person[].class); // Otherwise the body is a json array + + for (int personIndex = 0; personIndex < people.length; personIndex++) { + final int compareIndex = i * maxBatchSize + personIndex; + assertTrue(people[personIndex].equals(records[compareIndex]), "Mismatch - Expected: " + records[compareIndex].toMap().toString() + + " Actual: {" + people[personIndex].toString() + "} order of fields can be ignored."); + } + String actualContentTypeHeader = recordedRequest.getHeader(HttpHeader.CONTENT_TYPE.toString()); + assertEquals(expectedContentType != null ? expectedContentType : "application/json", actualContentTypeHeader); + + String actualAuthorizationHeader = recordedRequest.getHeader(HttpHeader.AUTHORIZATION.toString()); + assertEquals("Bearer " + (expectedAuthorization != null ? expectedAuthorization : OAUTH_ACCESS_TOKEN), + actualAuthorizationHeader); + } + } + + static public class Person { + public int id; + public String name; + public boolean active; + + public boolean equals(Record record) { + return id == record.getAsInt(ID) + && name.equals(record.getAsString(NAME)) + && active == record.getAsBoolean(ACTIVE); + } + + public String toString() { + return ID + "=" + id + ", " + NAME + "=" + name + ", " + ACTIVE + "=" + active; + } + } +}