Skip to content

Commit

Permalink
NIFI-14057 Return HTTP 405 for TRACE and OPTIONS on all paths for Lis…
Browse files Browse the repository at this point in the history
…tenHTTP (#9563)

- Added Filter to handle OPTIONS and TRACE methods for returning HTTP 405

Signed-off-by: David Handermann <[email protected]>
  • Loading branch information
mark-bathori authored Dec 19, 2024
1 parent 116f2f7 commit 3dabc84
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.nifi.processors.standard;

import jakarta.servlet.DispatcherType;
import jakarta.servlet.Servlet;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.ws.rs.Path;
Expand Down Expand Up @@ -44,6 +45,7 @@
import org.apache.nifi.processor.Relationship;
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.standard.filters.HttpMethodFilter;
import org.apache.nifi.processors.standard.http.HttpProtocolStrategy;
import org.apache.nifi.processors.standard.servlets.ContentAcknowledgmentServlet;
import org.apache.nifi.processors.standard.servlets.HealthCheckServlet;
Expand Down Expand Up @@ -71,6 +73,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -476,6 +479,8 @@ synchronized private void createHttpServerFromService(final ProcessContext conte
}
}

contextHandler.addFilter(HttpMethodFilter.class, "/*", EnumSet.allOf(DispatcherType.class));

contextHandler.setAttribute(CONTEXT_ATTRIBUTE_PROCESSOR, this);
contextHandler.setAttribute(CONTEXT_ATTRIBUTE_LOGGER, getLogger());
contextHandler.setAttribute(CONTEXT_ATTRIBUTE_SESSION_FACTORY_HOLDER, sessionFactoryReference);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.processors.standard.filters;

import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;

import java.io.IOException;

public class HttpMethodFilter implements Filter {

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {

HttpServletRequest request = (HttpServletRequest) servletRequest;
HttpServletResponse response = (HttpServletResponse) servletResponse;

if (request.getMethod().equals("OPTIONS") || request.getMethod().equals("TRACE")) {
response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, "Method Not Allowed");
} else {
filterChain.doFilter(servletRequest, servletResponse);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,22 +158,6 @@ protected void doHead(final HttpServletRequest request, final HttpServletRespons
}
}

private void notAllowed(final HttpServletRequest request, final HttpServletResponse response) throws IOException {
response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED, "Method Not Allowed");
}

@Override
protected void doTrace(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
notAllowed(request, response);
logger.debug("Denying TRACE request; method not allowed.");
}

@Override
protected void doOptions(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {
notAllowed(request, response);
logger.debug("Denying OPTIONS request; method not allowed.");
}

@Override
protected void doPost(final HttpServletRequest request, final HttpServletResponse response) throws ServletException, IOException {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import javax.security.auth.x500.X500Principal;

import jakarta.servlet.http.HttpServletResponse;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -345,14 +346,14 @@ public void testSecureServerSupportsCurrentTlsProtocolVersion() throws Exception
public void testSecureServerTrustStoreConfiguredClientAuthenticationRequired() throws Exception {
configureProcessorSslContextService(ListenHTTP.ClientAuthentication.REQUIRED);
final int port = startSecureServer();
assertThrows(IOException.class, () -> sendMessage(null, true, port, false, HTTP_POST));
assertThrows(IOException.class, () -> sendMessage(null, true, port, HTTP_BASE_PATH, false, HTTP_POST));
}

@Test
public void testSecureServerTrustStoreNotConfiguredClientAuthenticationNotRequired() throws Exception {
configureProcessorSslContextService(ListenHTTP.ClientAuthentication.AUTO);
final int port = startSecureServer();
final int responseCode = sendMessage(null, true, port, true, HTTP_POST);
final int responseCode = sendMessage(null, true, port, HTTP_BASE_PATH, true, HTTP_POST);
assertEquals(HttpServletResponse.SC_NO_CONTENT, responseCode);
}

Expand Down Expand Up @@ -464,7 +465,7 @@ public void testPostContentEncodingGzipAccepted() throws IOException {

final OkHttpClient okHttpClient = getOkHttpClient(false, false);
final Request.Builder requestBuilder = new Request.Builder();
final String url = buildUrl(false, port);
final String url = buildUrl(false, port, HTTP_BASE_PATH);
requestBuilder.url(url);

final String message = String.class.getSimpleName();
Expand Down Expand Up @@ -496,6 +497,14 @@ public void testOptionsNotAllowed() throws Exception {
startWebServerAndSendMessages(messages, HttpServletResponse.SC_METHOD_NOT_ALLOWED, false, false, HTTP_OPTIONS);
}

@Test
public void testOptionsNotAllowedOnNonBasePath() throws Exception {
final int port = startWebServer();
final int statusCode = sendMessage("payload 1", false, port, "randomPath", false, HTTP_OPTIONS);

assertEquals(HttpServletResponse.SC_METHOD_NOT_ALLOWED, statusCode, "HTTP Status Code not matched");
}

@Test
public void testTraceNotAllowed() throws Exception {
final List<String> messages = new ArrayList<>();
Expand All @@ -504,6 +513,14 @@ public void testTraceNotAllowed() throws Exception {
startWebServerAndSendMessages(messages, HttpServletResponse.SC_METHOD_NOT_ALLOWED, false, false, HTTP_TRACE);
}

@Test
public void testTraceNotAllowedOnNonBasePath() throws Exception {
final int port = startWebServer();
final int statusCode = sendMessage("payload 1", false, port, "randomPath", false, HTTP_TRACE);

assertEquals(HttpServletResponse.SC_METHOD_NOT_ALLOWED, statusCode, "HTTP Status Code not matched");
}

private MockRecordParser setupRecordReaderTest() throws InitializationException {
final MockRecordParser parser = new MockRecordParser();
final MockRecordWriter writer = new MockRecordWriter();
Expand All @@ -527,10 +544,10 @@ private int startSecureServer() {
return startWebServer();
}

private int sendMessage(final String message, final boolean secure, final int port, boolean clientAuthRequired, final String httpMethod) throws IOException {
private int sendMessage(final String message, final boolean secure, final int port, final String basePath, boolean clientAuthRequired, final String httpMethod) throws IOException {
final byte[] bytes = message == null ? new byte[]{} : message.getBytes(StandardCharsets.UTF_8);
final RequestBody requestBody = RequestBody.create(bytes, APPLICATION_OCTET_STREAM);
final String url = buildUrl(secure, port);
final String url = buildUrl(secure, port, basePath);
final Request.Builder requestBuilder = new Request.Builder();
final Request request = requestBuilder.method(httpMethod, requestBody)
.url(url)
Expand All @@ -557,8 +574,8 @@ private OkHttpClient getOkHttpClient(final boolean secure, final boolean clientA
return builder.build();
}

private String buildUrl(final boolean secure, final int port) {
return String.format("%s://localhost:%s/%s", secure ? "https" : "http", port, HTTP_BASE_PATH);
private String buildUrl(final boolean secure, final int port, String basePath) {
return String.format("%s://localhost:%s/%s", secure ? "https" : "http", port, basePath);
}

private void testPOSTRequestsReceived(int returnCode, boolean secure, boolean twoWaySsl) throws Exception {
Expand Down Expand Up @@ -623,7 +640,7 @@ private void startWebServerAndSendMessages(final List<String> messages, final in
final int port = startWebServer();

for (final String message : messages) {
final int statusCode = sendMessage(message, secure, port, clientAuthRequired, httpMethod);
final int statusCode = sendMessage(message, secure, port, HTTP_BASE_PATH, clientAuthRequired, httpMethod);
assertEquals(expectedStatusCode, statusCode, "HTTP Status Code not matched");
}
}
Expand Down Expand Up @@ -669,7 +686,7 @@ public void testMultipartFormDataRequest() throws IOException {
.build();

final Request request = new Request.Builder()
.url(buildUrl(isSecure, port))
.url(buildUrl(isSecure, port, HTTP_BASE_PATH))
.post(multipartBody)
.build();

Expand Down Expand Up @@ -741,7 +758,7 @@ public void testLargeHTTPRequestHeader() throws Exception {

final int port = startWebServer();
OkHttpClient client = getOkHttpClient(false, false);
final String url = buildUrl(false, port);
final String url = buildUrl(false, port, HTTP_BASE_PATH);
Request request = new Request.Builder()
.url(url)
.addHeader("Large-Header", largeHeaderValue)
Expand Down

0 comments on commit 3dabc84

Please sign in to comment.