From 08cfb35895d1b1d2aeeb89885348a6975fc5d82c Mon Sep 17 00:00:00 2001 From: SergeyRyabinin Date: Tue, 17 Oct 2023 21:53:05 +0000 Subject: [PATCH] Encode ARNs in RFC compliant mode --- .../include/aws/core/endpoint/AWSEndpoint.h | 5 +++ .../include/aws/core/http/URI.h | 13 +++++- .../include/aws/core/utils/StringUtils.h | 5 +++ src/aws-cpp-sdk-core/source/http/URI.cpp | 10 ++--- tests/aws-cpp-sdk-core-tests/http/URITest.cpp | 9 ++++ .../cpp/common/UriRequestQueryParams.vm | 45 +++++++------------ .../withoutrequest/OperationOutcome.vm | 21 ++------- 7 files changed, 55 insertions(+), 53 deletions(-) diff --git a/src/aws-cpp-sdk-core/include/aws/core/endpoint/AWSEndpoint.h b/src/aws-cpp-sdk-core/include/aws/core/endpoint/AWSEndpoint.h index 311ac71426f..77a4a6621d5 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/endpoint/AWSEndpoint.h +++ b/src/aws-cpp-sdk-core/include/aws/core/endpoint/AWSEndpoint.h @@ -44,6 +44,11 @@ namespace Aws m_uri.AddPathSegments(std::forward(pathSegments)); } + inline void SetRfc3986Encoded(bool rfcEncoded) + { + m_uri.SetRfc3986Encoded(rfcEncoded); + } + using OptionalError = Crt::Optional>; OptionalError AddPrefixIfMissing(const Aws::String& prefix); diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/URI.h b/src/aws-cpp-sdk-core/include/aws/core/http/URI.h index 724c24379f2..9dfbb115f4f 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/URI.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/URI.h @@ -180,6 +180,16 @@ namespace Aws */ Aws::String GetURIString(bool includeQueryString = true) const; + /** + * Returns true if this URI is going to be encoded in Rfc3986 compliant mode + */ + inline bool IsRfc3986Encoded() const { return m_useRfcEncoding; } + + /** + * Sets Rfc3986 compliant encoding mode. False (i.e. use legacy encoding with some chars unescaped) is the default. + */ + inline void SetRfc3986Encoded(const bool value) { m_useRfcEncoding = value; } + /** * URLEncodes the path portions of path (doesn't encode the "/" portion) * Keeps the first and the last "/". @@ -189,7 +199,7 @@ namespace Aws /** * URLEncodes the path portion of the URI according to RFC3986 */ - static Aws::String URLEncodePathRFC3986(const Aws::String& path); + static Aws::String URLEncodePathRFC3986(const Aws::String& path, bool = false); private: void ParseURIParts(const Aws::String& uri); @@ -205,6 +215,7 @@ namespace Aws uint16_t m_port = HTTP_DEFAULT_PORT; Aws::Vector m_pathSegments; bool m_pathHasTrailingSlash = false; + bool m_useRfcEncoding = false; Aws::String m_queryString; }; diff --git a/src/aws-cpp-sdk-core/include/aws/core/utils/StringUtils.h b/src/aws-cpp-sdk-core/include/aws/core/utils/StringUtils.h index 0281a8fe06c..04e99e25187 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/utils/StringUtils.h +++ b/src/aws-cpp-sdk-core/include/aws/core/utils/StringUtils.h @@ -49,6 +49,11 @@ namespace Aws */ static Aws::String URLEncode(const char* unsafe); + static inline Aws::String URLEncode(const Aws::String& unsafe) + { + return URLEncode(unsafe.c_str()); + } + /** * Http Clients tend to escape some characters but not all. Escaping all of them causes problems, because the client * will also try to escape them. diff --git a/src/aws-cpp-sdk-core/source/http/URI.cpp b/src/aws-cpp-sdk-core/source/http/URI.cpp index 0bc3c092458..75637c42f11 100644 --- a/src/aws-cpp-sdk-core/source/http/URI.cpp +++ b/src/aws-cpp-sdk-core/source/http/URI.cpp @@ -27,10 +27,10 @@ const char* SEPARATOR = "://"; bool s_compliantRfc3986Encoding = false; void SetCompliantRfc3986Encoding(bool compliant) { s_compliantRfc3986Encoding = compliant; } -Aws::String urlEncodeSegment(const Aws::String& segment) +Aws::String urlEncodeSegment(const Aws::String& segment, bool rfcEncoded = false) { // consolidates legacy escaping logic into one local method - if (s_compliantRfc3986Encoding) + if (rfcEncoded || s_compliantRfc3986Encoding) { return StringUtils::URLEncode(segment.c_str()); } @@ -141,7 +141,7 @@ void URI::SetScheme(Scheme value) } } -Aws::String URI::URLEncodePathRFC3986(const Aws::String& path) +Aws::String URI::URLEncodePathRFC3986(const Aws::String& path, bool rfcEncoded) { if (path.empty()) { @@ -155,7 +155,7 @@ Aws::String URI::URLEncodePathRFC3986(const Aws::String& path) // escape characters appearing in a URL path according to RFC 3986 for (const auto& segment : pathParts) { - ss << '/' << urlEncodeSegment(segment); + ss << '/' << urlEncodeSegment(segment, rfcEncoded); } // if the last character was also a slash, then add that back here. @@ -237,7 +237,7 @@ Aws::String URI::GetURLEncodedPathRFC3986() const // (mostly; there is some non-standards legacy support that can be disabled) for (const auto& segment : m_pathSegments) { - ss << '/' << urlEncodeSegment(segment); + ss << '/' << urlEncodeSegment(segment, m_useRfcEncoding); } if (m_pathSegments.empty() || m_pathHasTrailingSlash) diff --git a/tests/aws-cpp-sdk-core-tests/http/URITest.cpp b/tests/aws-cpp-sdk-core-tests/http/URITest.cpp index 24099840666..5857ff57622 100644 --- a/tests/aws-cpp-sdk-core-tests/http/URITest.cpp +++ b/tests/aws-cpp-sdk-core-tests/http/URITest.cpp @@ -286,6 +286,15 @@ TEST_F(URITest, TestParseWithColon) EXPECT_EQ(80, complexUri.GetPort()); EXPECT_STREQ("/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:1234/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:Key", complexUri.GetPath().c_str()); EXPECT_STREQ(strComplexUri, complexUri.GetURIString().c_str()); + + URI complexUriCompliant(strComplexUri); + complexUriCompliant.SetRfc3986Encoded(true); + EXPECT_STREQ("s3.us-east-1.amazonaws.com", complexUriCompliant.GetAuthority().c_str()); + EXPECT_EQ(80, complexUriCompliant.GetPort()); + EXPECT_STREQ("/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:1234/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:Key", + complexUri.GetPath().c_str()); + EXPECT_STREQ("http://s3.us-east-1.amazonaws.com/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject%3A1234/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject%3AKey", + complexUriCompliant.GetURIString().c_str()); } TEST_F(URITest, TestParseWithColonCompliant) diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/common/UriRequestQueryParams.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/common/UriRequestQueryParams.vm index 70647c69fcc..ecbb3e358e7 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/common/UriRequestQueryParams.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/common/UriRequestQueryParams.vm @@ -1,4 +1,3 @@ -#if($serviceModel.endpointRules) #if(($serviceNamespace == "S3Crt" && $operation.s3CrtEnabled) || $operation.getRequest().getShape().hasEventStreamMembers()) #set($meterNeeded = true) #set($indent = "") @@ -37,7 +36,6 @@ ${indent} CoreErrors::ENDPOINT_RESOLUTION_FAILURE, "ENDPOINT_RESOLUTION ${indent} return; ${indent} } #end -#end #if($operation.http.requestUri.contains("?")) ${indent} Aws::StringStream ss; #end @@ -54,19 +52,17 @@ ${indent} Aws::StringStream ss; #set($queryStart = true) #set($pathAndQuery = $operation.http.splitUriPartIntoPathAndQuery($uriPartString)) #if(!$pathAndQuery.get(0).isEmpty()) -#if($serviceModel.endpointRules) -${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}"); -#else -${indent} uri.AddPathSegments("${pathAndQuery.get(0)}"); +#if($pathAndQuery.get(0).toLowerCase().contains("arn")) +${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true); #end +${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}"); #end ${indent} ss.str("${pathAndQuery.get(1)}"); #elseif(!$uriPartString.equals("/")) -#if($serviceModel.endpointRules) -${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString"); -#else -${indent} uri.AddPathSegments("$uriPartString"); +#if($uriPartString.get(0).toLowerCase().contains("arn")) +${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true); #end +${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString"); #end## ---------------------------- if (request uri contains query) end ------ #foreach($var in $uriVars)## for (parameter in request uri parameters) ------- #set($varIndex = $partIndex - 1) @@ -83,17 +79,12 @@ ${indent} uri.AddPathSegments("$uriPartString"); ${indent} ss << $parameter; #else #if($greedySyntax) -#if($serviceModel.endpointRules) -${indent} endpointResolutionOutcome.GetResult().AddPathSegments($parameter); -#else -${indent} uri.AddPathSegments($parameter); +#if($parameter.toLowerCase().contains("arn")) +${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true); #end +${indent} endpointResolutionOutcome.GetResult().AddPathSegments($parameter); #else -#if($serviceModel.endpointRules) ${indent} endpointResolutionOutcome.GetResult().AddPathSegment($parameter); -#else -${indent} uri.AddPathSegment($parameter); -#end #end #end #if($uriParts.size() > $partIndex) @@ -102,19 +93,17 @@ ${indent} uri.AddPathSegment($parameter); #set($queryStart = true) #set($pathAndQuery = $operation.http.splitUriPartIntoPathAndQuery($uriPartString)) #if(!$pathAndQuery.get(0).isEmpty()) -#if($serviceModel.endpointRules) -${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}"); -#else -${indent} uri.AddPathSegments("${pathAndQuery.get(0)}"); +#if($pathAndQuery.get(0).toLowerCase().contains("arn")) +${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true); #end +${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}"); #end ${indent} ss.str("${pathAndQuery.get(1)}"); #elseif(!$uriPartString.equals("/")) -#if($serviceModel.endpointRules) -${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString"); -#else -${indent} uri.AddPathSegments("$uriPartString"); +#if($uriPartString.toLowerCase().contains("arn")) +${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true); #end +${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString"); #end #end #end## --------------------- if !skipFirst end --- @@ -123,9 +112,5 @@ ${indent} uri.AddPathSegments("$uriPartString"); #end## --------------------- if uriParts.size() > startIndex end --- #end## --------------------- for (parameter in request uri parameters) end --- #if($queryStart) -#if($serviceModel.endpointRules) ${indent} endpointResolutionOutcome.GetResult().SetQueryString(ss.str()); -#else -${indent} uri.SetQueryString(ss.str()); -#end #end \ No newline at end of file diff --git a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/serviceoperations/withoutrequest/OperationOutcome.vm b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/serviceoperations/withoutrequest/OperationOutcome.vm index 339342bfcef..b88a9119463 100644 --- a/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/serviceoperations/withoutrequest/OperationOutcome.vm +++ b/tools/code-generation/generator/src/main/resources/com/amazonaws/util/awsclientgenerator/velocity/cpp/json/serviceoperations/withoutrequest/OperationOutcome.vm @@ -1,14 +1,6 @@ ${operation.name}Outcome ${className}::${operation.name}() const { AWS_OPERATION_GUARD(${operation.name}); -#if(!$serviceModel.endpointRules) - Aws::StringStream ss; -#if($metadata.hasEndpointTrait) - ss << m_baseUri << "${operation.http.requestUri}"; -#else - ss << m_uri << "${operation.http.requestUri}"; -#end -#end AWS_OPERATION_CHECK_PTR(m_telemetryProvider, ${operation.name}, CoreErrors, CoreErrors::NOT_INITIALIZED); auto tracer = m_telemetryProvider->getTracer(this->GetServiceClientName(), {}); auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {}); @@ -19,22 +11,17 @@ AWS_OPERATION_GUARD(${operation.name}); return TracingUtils::MakeCallWithTiming<${operation.name}Outcome>( [&]()-> ${operation.name}Outcome { #parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/common/EndpointRulesNoRequestUriComputation.vm") -#if($serviceModel.endpointRules && $operation.http.requestUri && $operation.http.requestUri != "/") +#if($operation.http.requestUri && $operation.http.requestUri != "/") +#if($operation.http.requestUri.toLowerCase().contains("arn")) + endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true); +#end endpointResolutionOutcome.GetResult().AddPathSegments("${operation.http.requestUri}"); #end -#if($serviceModel.endpointRules) #if($operation.result && $operation.result.shape.hasStreamMembers()) return ${operation.name}Outcome(MakeRequestWithUnparsedResponse(endpointResolutionOutcome.GetResult(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}")); #else return ${operation.name}Outcome(MakeRequest(endpointResolutionOutcome.GetResult(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}")); #end -#else##-#if($serviceModel.endpointRules) -#if($operation.result && $operation.result.shape.hasStreamMembers()) - return ${operation.name}Outcome(MakeRequestWithUnparsedResponse(ss.str(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}")); -#else - return ${operation.name}Outcome(MakeRequest(ss.str(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}")); -#end -#end##-#if($serviceModel.endpointRules) }, TracingUtils::SMITHY_CLIENT_DURATION_METRIC, *meter,