diff --git a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenHTTP.java b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenHTTP.java index 8c6f5dff78cc..7d8d9f7a0c70 100644 --- a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenHTTP.java +++ b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/ListenHTTP.java @@ -16,6 +16,7 @@ */ package org.apache.nifi.processors.standard; +import jakarta.servlet.DispatcherType; import jakarta.servlet.Servlet; import jakarta.servlet.http.HttpServletResponse; import jakarta.ws.rs.Path; @@ -44,6 +45,7 @@ import org.apache.nifi.processor.Relationship; import org.apache.nifi.processor.exception.ProcessException; import org.apache.nifi.processor.util.StandardValidators; +import org.apache.nifi.processors.standard.filters.HttpMethodFilter; import org.apache.nifi.processors.standard.http.HttpProtocolStrategy; import org.apache.nifi.processors.standard.servlets.ContentAcknowledgmentServlet; import org.apache.nifi.processors.standard.servlets.HealthCheckServlet; @@ -71,6 +73,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.EnumSet; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -476,6 +479,8 @@ synchronized private void createHttpServerFromService(final ProcessContext conte } } + contextHandler.addFilter(HttpMethodFilter.class, "/*", EnumSet.allOf(DispatcherType.class)); + contextHandler.setAttribute(CONTEXT_ATTRIBUTE_PROCESSOR, this); contextHandler.setAttribute(CONTEXT_ATTRIBUTE_LOGGER, getLogger()); contextHandler.setAttribute(CONTEXT_ATTRIBUTE_SESSION_FACTORY_HOLDER, sessionFactoryReference); diff --git a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/filters/HttpMethodFilter.java b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/filters/HttpMethodFilter.java new file mode 100644 index 000000000000..d0888fb25494 --- /dev/null +++ b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/filters/HttpMethodFilter.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.nifi.processors.standard.filters; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import java.io.IOException; + +public class HttpMethodFilter implements Filter { + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { + + HttpServletRequest request = (HttpServletRequest) servletRequest; + HttpServletResponse response = (HttpServletResponse) servletResponse; + + if (request.getMethod().equals("OPTIONS") || request.getMethod().equals("TRACE")) { + response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, "Method Not Allowed"); + } else { + filterChain.doFilter(servletRequest, servletResponse); + } + } +} diff --git a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/servlets/ListenHTTPServlet.java b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/servlets/ListenHTTPServlet.java index 429b3f174e16..31b011158fff 100644 --- a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/servlets/ListenHTTPServlet.java +++ b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/main/java/org/apache/nifi/processors/standard/servlets/ListenHTTPServlet.java @@ -158,22 +158,6 @@ protected void doHead(final HttpServletRequest request, final HttpServletRespons } } - private void notAllowed(final HttpServletRequest request, final HttpServletResponse response) throws IOException { - response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, "Method Not Allowed"); - } - - @Override - protected void doTrace(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { - notAllowed(request, response); - logger.debug("Denying TRACE request; method not allowed."); - } - - @Override - protected void doOptions(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { - notAllowed(request, response); - logger.debug("Denying OPTIONS request; method not allowed."); - } - @Override protected void doPost(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException { diff --git a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenHTTP.java b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenHTTP.java index e27e57997f6e..c00788dbecc5 100644 --- a/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenHTTP.java +++ b/nifi-extension-bundles/nifi-standard-bundle/nifi-standard-processors/src/test/java/org/apache/nifi/processors/standard/TestListenHTTP.java @@ -58,6 +58,7 @@ import javax.security.auth.x500.X500Principal; import jakarta.servlet.http.HttpServletResponse; + import java.io.ByteArrayOutputStream; import java.io.File; import java.io.IOException; @@ -345,14 +346,14 @@ public void testSecureServerSupportsCurrentTlsProtocolVersion() throws Exception public void testSecureServerTrustStoreConfiguredClientAuthenticationRequired() throws Exception { configureProcessorSslContextService(ListenHTTP.ClientAuthentication.REQUIRED); final int port = startSecureServer(); - assertThrows(IOException.class, () -> sendMessage(null, true, port, false, HTTP_POST)); + assertThrows(IOException.class, () -> sendMessage(null, true, port, HTTP_BASE_PATH, false, HTTP_POST)); } @Test public void testSecureServerTrustStoreNotConfiguredClientAuthenticationNotRequired() throws Exception { configureProcessorSslContextService(ListenHTTP.ClientAuthentication.AUTO); final int port = startSecureServer(); - final int responseCode = sendMessage(null, true, port, true, HTTP_POST); + final int responseCode = sendMessage(null, true, port, HTTP_BASE_PATH, true, HTTP_POST); assertEquals(HttpServletResponse.SC_NO_CONTENT, responseCode); } @@ -464,7 +465,7 @@ public void testPostContentEncodingGzipAccepted() throws IOException { final OkHttpClient okHttpClient = getOkHttpClient(false, false); final Request.Builder requestBuilder = new Request.Builder(); - final String url = buildUrl(false, port); + final String url = buildUrl(false, port, HTTP_BASE_PATH); requestBuilder.url(url); final String message = String.class.getSimpleName(); @@ -496,6 +497,14 @@ public void testOptionsNotAllowed() throws Exception { startWebServerAndSendMessages(messages, HttpServletResponse.SC_METHOD_NOT_ALLOWED, false, false, HTTP_OPTIONS); } + @Test + public void testOptionsNotAllowedOnNonBasePath() throws Exception { + final int port = startWebServer(); + final int statusCode = sendMessage("payload 1", false, port, "randomPath", false, HTTP_OPTIONS); + + assertEquals(HttpServletResponse.SC_METHOD_NOT_ALLOWED, statusCode, "HTTP Status Code not matched"); + } + @Test public void testTraceNotAllowed() throws Exception { final List messages = new ArrayList<>(); @@ -504,6 +513,14 @@ public void testTraceNotAllowed() throws Exception { startWebServerAndSendMessages(messages, HttpServletResponse.SC_METHOD_NOT_ALLOWED, false, false, HTTP_TRACE); } + @Test + public void testTraceNotAllowedOnNonBasePath() throws Exception { + final int port = startWebServer(); + final int statusCode = sendMessage("payload 1", false, port, "randomPath", false, HTTP_TRACE); + + assertEquals(HttpServletResponse.SC_METHOD_NOT_ALLOWED, statusCode, "HTTP Status Code not matched"); + } + private MockRecordParser setupRecordReaderTest() throws InitializationException { final MockRecordParser parser = new MockRecordParser(); final MockRecordWriter writer = new MockRecordWriter(); @@ -527,10 +544,10 @@ private int startSecureServer() { return startWebServer(); } - private int sendMessage(final String message, final boolean secure, final int port, boolean clientAuthRequired, final String httpMethod) throws IOException { + private int sendMessage(final String message, final boolean secure, final int port, final String basePath, boolean clientAuthRequired, final String httpMethod) throws IOException { final byte[] bytes = message == null ? new byte[]{} : message.getBytes(StandardCharsets.UTF_8); final RequestBody requestBody = RequestBody.create(bytes, APPLICATION_OCTET_STREAM); - final String url = buildUrl(secure, port); + final String url = buildUrl(secure, port, basePath); final Request.Builder requestBuilder = new Request.Builder(); final Request request = requestBuilder.method(httpMethod, requestBody) .url(url) @@ -557,8 +574,8 @@ private OkHttpClient getOkHttpClient(final boolean secure, final boolean clientA return builder.build(); } - private String buildUrl(final boolean secure, final int port) { - return String.format("%s://localhost:%s/%s", secure ? "https" : "http", port, HTTP_BASE_PATH); + private String buildUrl(final boolean secure, final int port, String basePath) { + return String.format("%s://localhost:%s/%s", secure ? "https" : "http", port, basePath); } private void testPOSTRequestsReceived(int returnCode, boolean secure, boolean twoWaySsl) throws Exception { @@ -623,7 +640,7 @@ private void startWebServerAndSendMessages(final List messages, final in final int port = startWebServer(); for (final String message : messages) { - final int statusCode = sendMessage(message, secure, port, clientAuthRequired, httpMethod); + final int statusCode = sendMessage(message, secure, port, HTTP_BASE_PATH, clientAuthRequired, httpMethod); assertEquals(expectedStatusCode, statusCode, "HTTP Status Code not matched"); } } @@ -669,7 +686,7 @@ public void testMultipartFormDataRequest() throws IOException { .build(); final Request request = new Request.Builder() - .url(buildUrl(isSecure, port)) + .url(buildUrl(isSecure, port, HTTP_BASE_PATH)) .post(multipartBody) .build(); @@ -741,7 +758,7 @@ public void testLargeHTTPRequestHeader() throws Exception { final int port = startWebServer(); OkHttpClient client = getOkHttpClient(false, false); - final String url = buildUrl(false, port); + final String url = buildUrl(false, port, HTTP_BASE_PATH); Request request = new Request.Builder() .url(url) .addHeader("Large-Header", largeHeaderValue)