Skip to content

Commit

Permalink
Encode ARNs in RFC compliant mode
Browse files Browse the repository at this point in the history
  • Loading branch information
SergeyRyabinin committed Oct 17, 2023
1 parent 06c16a2 commit 08cfb35
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 53 deletions.
5 changes: 5 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/endpoint/AWSEndpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ namespace Aws
m_uri.AddPathSegments(std::forward<T>(pathSegments));
}

inline void SetRfc3986Encoded(bool rfcEncoded)
{
m_uri.SetRfc3986Encoded(rfcEncoded);
}

using OptionalError = Crt::Optional<Aws::Client::AWSError<Aws::Client::CoreErrors>>;
OptionalError AddPrefixIfMissing(const Aws::String& prefix);

Expand Down
13 changes: 12 additions & 1 deletion src/aws-cpp-sdk-core/include/aws/core/http/URI.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,16 @@ namespace Aws
*/
Aws::String GetURIString(bool includeQueryString = true) const;

/**
* Returns true if this URI is going to be encoded in Rfc3986 compliant mode
*/
inline bool IsRfc3986Encoded() const { return m_useRfcEncoding; }

/**
* Sets Rfc3986 compliant encoding mode. False (i.e. use legacy encoding with some chars unescaped) is the default.
*/
inline void SetRfc3986Encoded(const bool value) { m_useRfcEncoding = value; }

/**
* URLEncodes the path portions of path (doesn't encode the "/" portion)
* Keeps the first and the last "/".
Expand All @@ -189,7 +199,7 @@ namespace Aws
/**
* URLEncodes the path portion of the URI according to RFC3986
*/
static Aws::String URLEncodePathRFC3986(const Aws::String& path);
static Aws::String URLEncodePathRFC3986(const Aws::String& path, bool = false);

private:
void ParseURIParts(const Aws::String& uri);
Expand All @@ -205,6 +215,7 @@ namespace Aws
uint16_t m_port = HTTP_DEFAULT_PORT;
Aws::Vector<Aws::String> m_pathSegments;
bool m_pathHasTrailingSlash = false;
bool m_useRfcEncoding = false;
Aws::String m_queryString;
};

Expand Down
5 changes: 5 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/utils/StringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ namespace Aws
*/
static Aws::String URLEncode(const char* unsafe);

static inline Aws::String URLEncode(const Aws::String& unsafe)
{
return URLEncode(unsafe.c_str());
}

/**
* Http Clients tend to escape some characters but not all. Escaping all of them causes problems, because the client
* will also try to escape them.
Expand Down
10 changes: 5 additions & 5 deletions src/aws-cpp-sdk-core/source/http/URI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ const char* SEPARATOR = "://";
bool s_compliantRfc3986Encoding = false;
void SetCompliantRfc3986Encoding(bool compliant) { s_compliantRfc3986Encoding = compliant; }

Aws::String urlEncodeSegment(const Aws::String& segment)
Aws::String urlEncodeSegment(const Aws::String& segment, bool rfcEncoded = false)
{
// consolidates legacy escaping logic into one local method
if (s_compliantRfc3986Encoding)
if (rfcEncoded || s_compliantRfc3986Encoding)
{
return StringUtils::URLEncode(segment.c_str());
}
Expand Down Expand Up @@ -141,7 +141,7 @@ void URI::SetScheme(Scheme value)
}
}

Aws::String URI::URLEncodePathRFC3986(const Aws::String& path)
Aws::String URI::URLEncodePathRFC3986(const Aws::String& path, bool rfcEncoded)
{
if (path.empty())
{
Expand All @@ -155,7 +155,7 @@ Aws::String URI::URLEncodePathRFC3986(const Aws::String& path)
// escape characters appearing in a URL path according to RFC 3986
for (const auto& segment : pathParts)
{
ss << '/' << urlEncodeSegment(segment);
ss << '/' << urlEncodeSegment(segment, rfcEncoded);
}

// if the last character was also a slash, then add that back here.
Expand Down Expand Up @@ -237,7 +237,7 @@ Aws::String URI::GetURLEncodedPathRFC3986() const
// (mostly; there is some non-standards legacy support that can be disabled)
for (const auto& segment : m_pathSegments)
{
ss << '/' << urlEncodeSegment(segment);
ss << '/' << urlEncodeSegment(segment, m_useRfcEncoding);
}

if (m_pathSegments.empty() || m_pathHasTrailingSlash)
Expand Down
9 changes: 9 additions & 0 deletions tests/aws-cpp-sdk-core-tests/http/URITest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ TEST_F(URITest, TestParseWithColon)
EXPECT_EQ(80, complexUri.GetPort());
EXPECT_STREQ("/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:1234/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:Key", complexUri.GetPath().c_str());
EXPECT_STREQ(strComplexUri, complexUri.GetURIString().c_str());

URI complexUriCompliant(strComplexUri);
complexUriCompliant.SetRfc3986Encoded(true);
EXPECT_STREQ("s3.us-east-1.amazonaws.com", complexUriCompliant.GetAuthority().c_str());
EXPECT_EQ(80, complexUriCompliant.GetPort());
EXPECT_STREQ("/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:1234/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject:Key",
complexUri.GetPath().c_str());
EXPECT_STREQ("http://s3.us-east-1.amazonaws.com/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject%3A1234/awsnativesdkputobjectstestbucket20150702T200059Z/TestObject%3AKey",
complexUriCompliant.GetURIString().c_str());
}

TEST_F(URITest, TestParseWithColonCompliant)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#if($serviceModel.endpointRules)
#if(($serviceNamespace == "S3Crt" && $operation.s3CrtEnabled) || $operation.getRequest().getShape().hasEventStreamMembers())
#set($meterNeeded = true)
#set($indent = "")
Expand Down Expand Up @@ -37,7 +36,6 @@ ${indent} CoreErrors::ENDPOINT_RESOLUTION_FAILURE, "ENDPOINT_RESOLUTION
${indent} return;
${indent} }
#end
#end
#if($operation.http.requestUri.contains("?"))
${indent} Aws::StringStream ss;
#end
Expand All @@ -54,19 +52,17 @@ ${indent} Aws::StringStream ss;
#set($queryStart = true)
#set($pathAndQuery = $operation.http.splitUriPartIntoPathAndQuery($uriPartString))
#if(!$pathAndQuery.get(0).isEmpty())
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}");
#else
${indent} uri.AddPathSegments("${pathAndQuery.get(0)}");
#if($pathAndQuery.get(0).toLowerCase().contains("arn"))
${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true);
#end
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}");
#end
${indent} ss.str("${pathAndQuery.get(1)}");
#elseif(!$uriPartString.equals("/"))
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString");
#else
${indent} uri.AddPathSegments("$uriPartString");
#if($uriPartString.get(0).toLowerCase().contains("arn"))
${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true);
#end
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString");
#end## ---------------------------- if (request uri contains query) end ------
#foreach($var in $uriVars)## for (parameter in request uri parameters) -------
#set($varIndex = $partIndex - 1)
Expand All @@ -83,17 +79,12 @@ ${indent} uri.AddPathSegments("$uriPartString");
${indent} ss << $parameter;
#else
#if($greedySyntax)
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().AddPathSegments($parameter);
#else
${indent} uri.AddPathSegments($parameter);
#if($parameter.toLowerCase().contains("arn"))
${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true);
#end
${indent} endpointResolutionOutcome.GetResult().AddPathSegments($parameter);
#else
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().AddPathSegment($parameter);
#else
${indent} uri.AddPathSegment($parameter);
#end
#end
#end
#if($uriParts.size() > $partIndex)
Expand All @@ -102,19 +93,17 @@ ${indent} uri.AddPathSegment($parameter);
#set($queryStart = true)
#set($pathAndQuery = $operation.http.splitUriPartIntoPathAndQuery($uriPartString))
#if(!$pathAndQuery.get(0).isEmpty())
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}");
#else
${indent} uri.AddPathSegments("${pathAndQuery.get(0)}");
#if($pathAndQuery.get(0).toLowerCase().contains("arn"))
${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true);
#end
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("${pathAndQuery.get(0)}");
#end
${indent} ss.str("${pathAndQuery.get(1)}");
#elseif(!$uriPartString.equals("/"))
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString");
#else
${indent} uri.AddPathSegments("$uriPartString");
#if($uriPartString.toLowerCase().contains("arn"))
${indent} endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true);
#end
${indent} endpointResolutionOutcome.GetResult().AddPathSegments("$uriPartString");
#end
#end
#end## --------------------- if !skipFirst end ---
Expand All @@ -123,9 +112,5 @@ ${indent} uri.AddPathSegments("$uriPartString");
#end## --------------------- if uriParts.size() > startIndex end ---
#end## --------------------- for (parameter in request uri parameters) end ---
#if($queryStart)
#if($serviceModel.endpointRules)
${indent} endpointResolutionOutcome.GetResult().SetQueryString(ss.str());
#else
${indent} uri.SetQueryString(ss.str());
#end
#end
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
${operation.name}Outcome ${className}::${operation.name}() const
{
AWS_OPERATION_GUARD(${operation.name});
#if(!$serviceModel.endpointRules)
Aws::StringStream ss;
#if($metadata.hasEndpointTrait)
ss << m_baseUri << "${operation.http.requestUri}";
#else
ss << m_uri << "${operation.http.requestUri}";
#end
#end
AWS_OPERATION_CHECK_PTR(m_telemetryProvider, ${operation.name}, CoreErrors, CoreErrors::NOT_INITIALIZED);
auto tracer = m_telemetryProvider->getTracer(this->GetServiceClientName(), {});
auto meter = m_telemetryProvider->getMeter(this->GetServiceClientName(), {});
Expand All @@ -19,22 +11,17 @@ AWS_OPERATION_GUARD(${operation.name});
return TracingUtils::MakeCallWithTiming<${operation.name}Outcome>(
[&]()-> ${operation.name}Outcome {
#parse("com/amazonaws/util/awsclientgenerator/velocity/cpp/common/EndpointRulesNoRequestUriComputation.vm")
#if($serviceModel.endpointRules && $operation.http.requestUri && $operation.http.requestUri != "/")
#if($operation.http.requestUri && $operation.http.requestUri != "/")
#if($operation.http.requestUri.toLowerCase().contains("arn"))
endpointResolutionOutcome.GetResult().SetRfc3986Encoded(true);
#end
endpointResolutionOutcome.GetResult().AddPathSegments("${operation.http.requestUri}");
#end
#if($serviceModel.endpointRules)
#if($operation.result && $operation.result.shape.hasStreamMembers())
return ${operation.name}Outcome(MakeRequestWithUnparsedResponse(endpointResolutionOutcome.GetResult(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}"));
#else
return ${operation.name}Outcome(MakeRequest(endpointResolutionOutcome.GetResult(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}"));
#end
#else##-#if($serviceModel.endpointRules)
#if($operation.result && $operation.result.shape.hasStreamMembers())
return ${operation.name}Outcome(MakeRequestWithUnparsedResponse(ss.str(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}"));
#else
return ${operation.name}Outcome(MakeRequest(ss.str(), Aws::Http::HttpMethod::HTTP_${operation.http.method}, ${operation.signerName}, "${operation.name}"));
#end
#end##-#if($serviceModel.endpointRules)
},
TracingUtils::SMITHY_CLIENT_DURATION_METRIC,
*meter,
Expand Down

0 comments on commit 08cfb35

Please sign in to comment.