From e0f70315488c23ea64ab63ef7b8530ce01783024 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Wed, 20 May 2020 18:19:02 +0100 Subject: [PATCH] Add support for streaming request payload (#237) - Add support for streaming data to requests - `AWSPayload` contains an internal enum that can be initialised with a stream case which takes a closure supply `ByteBuffers` - Add helper method `Body.asByteBuffer` - AWSSignerV4: Added `BodyData.unsignedPayload` - Use unsigned-payload when streaming request payloads - Added `AWSPayload.fileHandle` used to create payload from `NIOFileHandle` - `AWSPayload.stream` now has an optional size parameter as does `AWSPayload.fileHandle` - Use one eventloop for request streaming - Added `.empty` case to `AWSPayload` enum - Added "chunked" output for streamed requests without a size - Added flags to `AWSShapeWithPayload` indicate whether streaming is allowed Co-authored-by: Fabian Fett --- Sources/AWSSDKSwiftCore/AWSClient.swift | 25 ++- .../AWSSDKSwiftCore/AWSShapes/Payload.swift | 123 ++++++++++-- Sources/AWSSDKSwiftCore/Doc/AWSShape.swift | 29 ++- .../AWSSDKSwiftCore/HTTP/AWSHTTPClient.swift | 10 +- .../HTTP/AsyncHTTPClient.swift | 73 ++++++- .../HTTP/NIOTSHTTPClient.swift | 16 +- .../Message/AWSMiddleware.swift | 4 +- .../AWSSDKSwiftCore/Message/AWSRequest.swift | 36 +++- .../AWSSDKSwiftCore/Message/AWSResponse.swift | 14 +- Sources/AWSSDKSwiftCore/Message/Body.swift | 33 ++-- Sources/AWSSDKSwiftCore/MetaDataService.swift | 2 +- Sources/AWSSignerV4/signer.swift | 3 + .../AWSSDKSwiftCoreTests/AWSClientTests.swift | 185 +++++++++++++++++- .../HTTPClientTests.swift | 14 +- Tests/AWSSDKSwiftCoreTests/TestServer.swift | 1 + Tests/AWSSDKSwiftCoreTests/TestUtils.swift | 20 +- 16 files changed, 505 insertions(+), 83 deletions(-) diff --git a/Sources/AWSSDKSwiftCore/AWSClient.swift b/Sources/AWSSDKSwiftCore/AWSClient.swift index cd78c1192..53f28c9ac 100644 --- a/Sources/AWSSDKSwiftCore/AWSClient.swift +++ b/Sources/AWSSDKSwiftCore/AWSClient.swift @@ -33,6 +33,7 @@ public final class AWSClient { public enum ClientError: Swift.Error { case invalidURL(String) + case tooMuchData } enum InternalError: Swift.Error { @@ -385,6 +386,12 @@ extension AWSClient { ).applyMiddlewares(serviceConfig.middlewares + middlewares) } + internal func verifyStream(operation: String, payload: AWSPayload, input: AWSShapeWithPayload.Type) { + guard case .stream(let size,_) = payload.payload else { return } + precondition(input.options.contains(.allowStreaming), "\(operation) does not allow streaming of data") + precondition(size != nil || input.options.contains(.allowChunkedStreaming), "\(operation) does not allow chunked streaming of data. Please supply a data size.") + } + internal func createAWSRequest(operation operationName: String, path: String, httpMethod: String, input: Input) throws -> AWSRequest { var headers: [String: Any] = [:] var path = path @@ -438,11 +445,13 @@ extension AWSClient { switch serviceConfig.serviceProtocol { case .json, .restjson: - if let payload = (Input.self as? AWSShapeWithPayload.Type)?.payloadPath { + if let shapeWithPayload = Input.self as? AWSShapeWithPayload.Type { + let payload = shapeWithPayload.payloadPath if let payloadBody = mirror.getAttribute(forKey: payload) { switch payloadBody { case let awsPayload as AWSPayload: - body = .buffer(awsPayload.byteBuffer) + verifyStream(operation: operationName, payload: awsPayload, input: shapeWithPayload) + body = .raw(awsPayload) case let shape as AWSEncodableShape: body = .json(try shape.encodeAsJSON()) default: @@ -474,11 +483,13 @@ extension AWSClient { } case .restxml: - if let payload = (Input.self as? AWSShapeWithPayload.Type)?.payloadPath { + if let shapeWithPayload = Input.self as? AWSShapeWithPayload.Type { + let payload = shapeWithPayload.payloadPath if let payloadBody = mirror.getAttribute(forKey: payload) { switch payloadBody { case let awsPayload as AWSPayload: - body = .buffer(awsPayload.byteBuffer) + verifyStream(operation: operationName, payload: awsPayload, input: shapeWithPayload) + body = .raw(awsPayload) case let shape as AWSEncodableShape: var rootName: String? = nil // extract custom payload name @@ -633,9 +644,9 @@ extension AWSClient { } return try XMLDecoder().decode(Output.self, from: outputNode) - case .buffer(let byteBuffer): + case .raw(let payload): if let payloadKey = payloadKey { - outputDict[payloadKey] = AWSPayload.byteBuffer(byteBuffer) + outputDict[payloadKey] = payload } default: @@ -788,6 +799,8 @@ extension AWSClient.ClientError: CustomStringConvertible { The request url \(urlString) is invalid format. This error is internal. So please make a issue on https://github.com/swift-aws/aws-sdk-swift/issues to solve it. """ + case .tooMuchData: + return "You have supplied too much data for the Request." } } } diff --git a/Sources/AWSSDKSwiftCore/AWSShapes/Payload.swift b/Sources/AWSSDKSwiftCore/AWSShapes/Payload.swift index d855f1a13..6951b02c7 100644 --- a/Sources/AWSSDKSwiftCore/AWSShapes/Payload.swift +++ b/Sources/AWSSDKSwiftCore/AWSShapes/Payload.swift @@ -16,44 +16,127 @@ import struct Foundation.Data import NIO import NIOFoundationCompat -/// Object storing request/response payload +/// Holds a request or response payload. A request payload can be in the form of either a ByteBuffer or a stream function that will supply ByteBuffers to the HTTP client. +/// A response payload only comes in the form of a ByteBuffer public struct AWSPayload { - - /// Construct a payload from a ByteBuffer - public static func byteBuffer(_ byteBuffer: ByteBuffer) -> Self { - return Self(byteBuffer: byteBuffer) + + /// Internal enum + enum Payload { + case byteBuffer(ByteBuffer) + case stream(size: Int?, stream: (EventLoop)->EventLoopFuture) + case empty } - - /// Construct a payload from a Data - public static func data(_ data: Data) -> Self { - var byteBuffer = ByteBufferAllocator().buffer(capacity: data.count) + + internal let payload: Payload + + /// construct a payload from a ByteBuffer + public static func byteBuffer(_ buffer: ByteBuffer) -> Self { + return AWSPayload(payload: .byteBuffer(buffer)) + } + + /// construct a payload from a stream function. If you supply a size the stream function will be called repeated until you supply the number of bytes specified. If you + /// don't supply a size the stream function will be called repeatedly until you supply an empty `ByteBuffer` + public static func stream(size: Int? = nil, stream: @escaping (EventLoop)->EventLoopFuture) -> Self { + return AWSPayload(payload: .stream(size: size, stream: stream)) + } + + /// construct an empty payload + public static var empty: Self { + return AWSPayload(payload: .empty) + } + + /// Construct a payload from `Data` + public static func data(_ data: Data, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) -> Self { + var byteBuffer = byteBufferAllocator.buffer(capacity: data.count) byteBuffer.writeBytes(data) - return Self(byteBuffer: byteBuffer) + return AWSPayload(payload: .byteBuffer(byteBuffer)) } - /// Construct a payload from a String - public static func string(_ string: String) -> Self { - var byteBuffer = ByteBufferAllocator().buffer(capacity: string.utf8.count) + /// Construct a payload from a `String` + public static func string(_ string: String, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) -> Self { + var byteBuffer = byteBufferAllocator.buffer(capacity: string.utf8.count) byteBuffer.writeString(string) - return Self(byteBuffer: byteBuffer) + return AWSPayload(payload: .byteBuffer(byteBuffer)) + } + + /// Construct a stream payload from a `NIOFileHandle` + public static func fileHandle(_ fileHandle: NIOFileHandle, size: Int? = nil, fileIO: NonBlockingFileIO, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) -> Self { + let blockSize = 64*1024 + var leftToRead = size + func stream(_ eventLoop: EventLoop) -> EventLoopFuture { + // calculate how much data is left to read, if a file size was indicated + var blockSize = blockSize + if let leftToRead2 = leftToRead { + blockSize = min(blockSize, leftToRead2) + leftToRead = leftToRead2 - blockSize + } + let futureByteBuffer = fileIO.read(fileHandle: fileHandle, byteCount: blockSize, allocator: byteBufferAllocator, eventLoop: eventLoop) + + if leftToRead != nil { + return futureByteBuffer.map { byteBuffer in + precondition(byteBuffer.readableBytes == blockSize, "File did not have enough data") + return byteBuffer + } + } + return futureByteBuffer + } + + return AWSPayload(payload: .stream(size: size, stream: stream)) + } + + /// Return the size of the payload. If the payload is a stream it is always possible to return a size + var size: Int? { + switch payload { + case .byteBuffer(let byteBuffer): + return byteBuffer.readableBytes + case .stream(let size,_): + return size + case .empty: + return 0 + } } /// return payload as Data public func asData() -> Data? { - return byteBuffer.getData(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, byteTransferStrategy: .noCopy) + switch payload { + case .byteBuffer(let byteBuffer): + return byteBuffer.getData(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, byteTransferStrategy: .noCopy) + default: + return nil + } } /// return payload as String public func asString() -> String? { - return byteBuffer.getString(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, encoding: .utf8) + switch payload { + case .byteBuffer(let byteBuffer): + return byteBuffer.getString(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes) + default: + return nil + } } /// return payload as ByteBuffer - public func asBytebuffer() -> ByteBuffer { - return byteBuffer + public func asByteBuffer() -> ByteBuffer? { + switch payload { + case .byteBuffer(let byteBuffer): + return byteBuffer + default: + return nil + } + } + + /// does payload consist of zero bytes + public var isEmpty: Bool { + switch payload { + case .byteBuffer(let buffer): + return buffer.readableBytes == 0 + case .stream: + return false + case .empty: + return true + } } - - let byteBuffer: ByteBuffer } extension AWSPayload: Decodable { diff --git a/Sources/AWSSDKSwiftCore/Doc/AWSShape.swift b/Sources/AWSSDKSwiftCore/Doc/AWSShape.swift index 7d020324d..3992b44f9 100644 --- a/Sources/AWSSDKSwiftCore/Doc/AWSShape.swift +++ b/Sources/AWSSDKSwiftCore/Doc/AWSShape.swift @@ -106,10 +106,18 @@ public extension AWSEncodableShape { guard value.count <= max else { throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(value.count)) is greater than the maximum allowed value \(max).") } } func validate(_ value: AWSPayload, name: String, parent: String, min: Int) throws { - guard value.byteBuffer.readableBytes >= min else { throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(value.byteBuffer.readableBytes)) is less than minimum allowed value \(min).") } + if let size = value.size { + guard size >= min else { + throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(size)) is less than minimum allowed value \(min).") + } + } } func validate(_ value: AWSPayload, name: String, parent: String, max: Int) throws { - guard value.byteBuffer.readableBytes <= max else { throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(value.byteBuffer.readableBytes)) is greater than the maximum allowed value \(max).") } + if let size = value.size { + guard size <= max else { + throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(size)) is greater than the maximum allowed value \(max).") + } + } } func validate(_ value: String, name: String, parent: String, pattern: String) throws { let regularExpression = try NSRegularExpression(pattern: pattern, options: []) @@ -158,8 +166,25 @@ public extension AWSEncodableShape { /// AWSShape that can be decoded public protocol AWSDecodableShape: AWSShape & Decodable {} +/// AWSShapeWithPayload options. +public struct PayloadOptions: OptionSet { + public var rawValue: Int + + public init(rawValue: Int) { + self.rawValue = rawValue + } + + public static let allowStreaming = PayloadOptions(rawValue: 1<<0) + public static let allowChunkedStreaming = PayloadOptions(rawValue: 1<<1) +} + /// Root AWSShape which include a payload public protocol AWSShapeWithPayload { /// The path to the object that is included in the request body static var payloadPath: String { get } + static var options: PayloadOptions { get } +} + +extension AWSShapeWithPayload { + public static var options: PayloadOptions { return [] } } diff --git a/Sources/AWSSDKSwiftCore/HTTP/AWSHTTPClient.swift b/Sources/AWSSDKSwiftCore/HTTP/AWSHTTPClient.swift index 05878a18c..ef318c312 100644 --- a/Sources/AWSSDKSwiftCore/HTTP/AWSHTTPClient.swift +++ b/Sources/AWSSDKSwiftCore/HTTP/AWSHTTPClient.swift @@ -21,9 +21,9 @@ public struct AWSHTTPRequest { public let url: URL public let method: HTTPMethod public let headers: HTTPHeaders - public let body: ByteBuffer? - - public init(url: URL, method: HTTPMethod, headers: HTTPHeaders = [:], body: ByteBuffer? = nil) { + public let body: AWSPayload + + public init(url: URL, method: HTTPMethod, headers: HTTPHeaders = [:], body: AWSPayload = .empty) { self.url = url self.method = method self.headers = headers @@ -42,10 +42,10 @@ public protocol AWSHTTPResponse { public protocol AWSHTTPClient { /// Execute HTTP request and return a future holding a HTTP Response func execute(request: AWSHTTPRequest, timeout: TimeAmount, on eventLoop: EventLoop?) -> EventLoopFuture - + /// This should be called before an HTTP Client can be de-initialised func syncShutdown() throws - + /// Event loop group used by client var eventLoopGroup: EventLoopGroup { get } } diff --git a/Sources/AWSSDKSwiftCore/HTTP/AsyncHTTPClient.swift b/Sources/AWSSDKSwiftCore/HTTP/AsyncHTTPClient.swift index 749a644e9..30a2f7371 100644 --- a/Sources/AWSSDKSwiftCore/HTTP/AsyncHTTPClient.swift +++ b/Sources/AWSSDKSwiftCore/HTTP/AsyncHTTPClient.swift @@ -18,19 +18,80 @@ import NIO /// comply with AWSHTTPClient protocol extension AsyncHTTPClient.HTTPClient: AWSHTTPClient { + + /// write stream to StreamWriter + private func writeToStreamWriter( + writer: HTTPClient.Body.StreamWriter, + size: Int?, + on eventLoop: EventLoop, + getData: @escaping (EventLoop)->EventLoopFuture) -> EventLoopFuture { + let promise = eventLoop.makePromise(of: Void.self) + + func _writeToStreamWriter(_ amountLeft: Int?) { + // get byte buffer from closure, write to StreamWriter, if there are still bytes to write then call + // _writeToStreamWriter again. + _ = getData(eventLoop) + .map { (byteBuffer)->() in + // if no amount was set and the byte buffer has no readable bytes then this is assumed to mean + // there will be no more data + if amountLeft == nil && byteBuffer.readableBytes == 0 { + promise.succeed(()) + return + } + // calculate amount left to write + let newAmountLeft = amountLeft.map { $0 - byteBuffer.readableBytes } + // write chunk. If amountLeft is nil assume we are writing chunked output + let writeFuture: EventLoopFuture = writer.write(.byteBuffer(byteBuffer)) + _ = writeFuture.flatMap { ()->EventLoopFuture in + if let newAmountLeft = newAmountLeft { + if newAmountLeft == 0 { + promise.succeed(()) + } else if newAmountLeft < 0 { + promise.fail(AWSClient.ClientError.tooMuchData) + } else { + _writeToStreamWriter(newAmountLeft) + } + } else { + _writeToStreamWriter(nil) + } + return promise.futureResult + }.cascadeFailure(to: promise) + }.cascadeFailure(to: promise) + } + _writeToStreamWriter(size) + return promise.futureResult + } + + /// Execute HTTP request + /// - Parameters: + /// - request: HTTP request + /// - timeout: If execution is idle for longer than timeout then throw error + /// - eventLoop: eventLoop to run request on + /// - Returns: EventLoopFuture that will be fulfilled with request response public func execute(request: AWSHTTPRequest, timeout: TimeAmount, on eventLoop: EventLoop?) -> EventLoopFuture { if let eventLoop = eventLoop { precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "EventLoop provided to AWSClient must be part of the HTTPClient's EventLoopGroup.") - } + } + let eventLoop = eventLoop ?? eventLoopGroup.next() let requestBody: AsyncHTTPClient.HTTPClient.Body? - if let body = request.body { - requestBody = .byteBuffer(body) - } else { + var requestHeaders = request.headers + + switch request.body.payload { + case .byteBuffer(let byteBuffer): + requestBody = .byteBuffer(byteBuffer) + case .stream(let size, let getData): + // add "Transfer-Encoding" header if streaming with unknown size + if size == nil { + requestHeaders.add(name: "Transfer-Encoding", value: "chunked") + } + requestBody = .stream(length: size) { writer in + return self.writeToStreamWriter(writer: writer, size: size, on: eventLoop, getData: getData) + } + case .empty: requestBody = nil } do { - let eventLoop = eventLoop ?? eventLoopGroup.next() - let asyncRequest = try AsyncHTTPClient.HTTPClient.Request(url: request.url, method: request.method, headers: request.headers, body: requestBody) + let asyncRequest = try AsyncHTTPClient.HTTPClient.Request(url: request.url, method: request.method, headers: requestHeaders, body: requestBody) return execute(request: asyncRequest, eventLoop: .delegate(on: eventLoop), deadline: .now() + timeout).map { $0 } } catch { return eventLoopGroup.next().makeFailedFuture(error) diff --git a/Sources/AWSSDKSwiftCore/HTTP/NIOTSHTTPClient.swift b/Sources/AWSSDKSwiftCore/HTTP/NIOTSHTTPClient.swift index 97618c295..00fa47dce 100644 --- a/Sources/AWSSDKSwiftCore/HTTP/NIOTSHTTPClient.swift +++ b/Sources/AWSSDKSwiftCore/HTTP/NIOTSHTTPClient.swift @@ -157,9 +157,7 @@ public final class NIOTSHTTPClient { head.headers.replaceOrAdd(name: "Host", value: hostname) head.headers.replaceOrAdd(name: "User-Agent", value: "AWS SDK Swift Core") - if let body = request.body { - head.headers.replaceOrAdd(name: "Content-Length", value: body.readableBytes.description) - } + head.headers.replaceOrAdd(name: "Content-Length", value: request.body?.readableBytes.description ?? "0") head.headers.replaceOrAdd(name: "Connection", value: "Close") @@ -254,7 +252,17 @@ extension NIOTSHTTPClient: AWSHTTPClient { uri: request.url.absoluteString ) head.headers = request.headers - let request = Request(head: head, body: request.body) + + let requestBody: ByteBuffer? + switch request.body.payload { + case .byteBuffer(let byteBuffer): + requestBody = byteBuffer + case .stream: + preconditionFailure("Request streaming isnt supported") + case .empty: + requestBody = nil + } + let request = Request(head: head, body: requestBody) return connect(request, timeout: timeout, on: eventLoop).map { return $0 } } diff --git a/Sources/AWSSDKSwiftCore/Message/AWSMiddleware.swift b/Sources/AWSSDKSwiftCore/Message/AWSMiddleware.swift index 23dc93532..1759ad94a 100644 --- a/Sources/AWSSDKSwiftCore/Message/AWSMiddleware.swift +++ b/Sources/AWSSDKSwiftCore/Message/AWSMiddleware.swift @@ -51,8 +51,8 @@ public struct AWSLoggingMiddleware : AWSServiceMiddleware { case .json(let data): output += "\n " output += String(data: data, encoding: .utf8) ?? "Failed to convert JSON response to UTF8" - case .buffer(let byteBuffer): - output += "data (\(byteBuffer.readableBytes) bytes)" + case .raw(let payload): + output += "raw (\(payload.size?.description ?? "unknown") bytes)" case .text(let string): output += "\n \(string)" case .empty: diff --git a/Sources/AWSSDKSwiftCore/Message/AWSRequest.swift b/Sources/AWSSDKSwiftCore/Message/AWSRequest.swift index 8f90d9bf0..353bcaccf 100644 --- a/Sources/AWSSDKSwiftCore/Message/AWSRequest.swift +++ b/Sources/AWSSDKSwiftCore/Message/AWSRequest.swift @@ -67,7 +67,7 @@ public struct AWSRequest { case "GET","HEAD": break default: - if case .restjson = serviceProtocol, case .buffer(_) = body { + if case .restjson = serviceProtocol, case .raw(_) = body { headers["Content-Type"] = "binary/octet-stream" } else { headers["Content-Type"] = serviceProtocol.contentType @@ -76,7 +76,7 @@ public struct AWSRequest { } return HTTPHeaders(headers.map { ($0, $1) }) } - + func createHTTPRequest(signer: AWSSigner) -> AWSHTTPRequest { // if credentials are empty don't sign request if signer.credentials.isEmpty() { @@ -100,23 +100,41 @@ public struct AWSRequest { /// Create HTTP Client request from AWSRequest func toHTTPRequest() -> AWSHTTPRequest { - return AWSHTTPRequest.init(url: url, method: HTTPMethod(rawValue: httpMethod), headers: getHttpHeaders(), body: body.asByteBuffer()) + return AWSHTTPRequest.init(url: url, method: HTTPMethod(rawValue: httpMethod), headers: getHttpHeaders(), body: body.asPayload()) } /// Create HTTP Client request with signed URL from AWSRequest func toHTTPRequestWithSignedURL(signer: AWSSigner) -> AWSHTTPRequest { let method = HTTPMethod(rawValue: httpMethod) - let bodyData = body.asByteBuffer() - let signedURL = signer.signURL(url: url, method: method, body: bodyData != nil ? .byteBuffer(bodyData!) : nil, date: Date(), expires: 86400) - return AWSHTTPRequest.init(url: signedURL, method: method, headers: getHttpHeaders(), body: bodyData) + let payload = self.body.asPayload() + let bodyDataForSigning: AWSSigner.BodyData? + switch payload.payload { + case .byteBuffer(let buffer): + bodyDataForSigning = .byteBuffer(buffer) + case .stream: + bodyDataForSigning = .unsignedPayload + case .empty: + bodyDataForSigning = nil + } + let signedURL = signer.signURL(url: url, method: method, body: bodyDataForSigning, date: Date(), expires: 86400) + return AWSHTTPRequest.init(url: signedURL, method: method, headers: getHttpHeaders(), body: payload) } /// Create HTTP Client request with signed headers from AWSRequest func toHTTPRequestWithSignedHeader(signer: AWSSigner) -> AWSHTTPRequest { let method = HTTPMethod(rawValue: httpMethod) - let bodyData = body.asByteBuffer() - let signedHeaders = signer.signHeaders(url: url, method: method, headers: getHttpHeaders(), body: bodyData != nil ? .byteBuffer(bodyData!) : nil, date: Date()) - return AWSHTTPRequest.init(url: url, method: method, headers: signedHeaders, body: bodyData) + let payload = self.body.asPayload() + let bodyDataForSigning: AWSSigner.BodyData? + switch payload.payload { + case .byteBuffer(let buffer): + bodyDataForSigning = .byteBuffer(buffer) + case .stream: + bodyDataForSigning = .unsignedPayload + case .empty: + bodyDataForSigning = nil + } + let signedHeaders = signer.signHeaders(url: url, method: method, headers: getHttpHeaders(), body: bodyDataForSigning, date: Date()) + return AWSHTTPRequest.init(url: url, method: method, headers: signedHeaders, body: payload) } // return new request with middleware applied diff --git a/Sources/AWSSDKSwiftCore/Message/AWSResponse.swift b/Sources/AWSSDKSwiftCore/Message/AWSResponse.swift index 48e270435..27b3718f0 100644 --- a/Sources/AWSSDKSwiftCore/Message/AWSResponse.swift +++ b/Sources/AWSSDKSwiftCore/Message/AWSResponse.swift @@ -42,14 +42,18 @@ public struct AWSResponse { // body guard let body = response.body, - body.readableBytes > 0, - let data = body.getData(at: body.readerIndex, length: body.readableBytes, byteTransferStrategy: .noCopy) else { - self.body = .empty - return + body.readableBytes > 0 else { + self.body = .empty + return } if raw { - self.body = .buffer(body) + self.body = .raw(.byteBuffer(body)) + return + } + + guard let data = body.getData(at: body.readerIndex, length: body.readableBytes, byteTransferStrategy: .noCopy) else { + self.body = .empty return } diff --git a/Sources/AWSSDKSwiftCore/Message/Body.swift b/Sources/AWSSDKSwiftCore/Message/Body.swift index 7bcd74acc..eb06138d8 100644 --- a/Sources/AWSSDKSwiftCore/Message/Body.swift +++ b/Sources/AWSSDKSwiftCore/Message/Body.swift @@ -22,7 +22,7 @@ public enum Body { /// text case text(String) /// raw data - case buffer(ByteBuffer) + case raw(AWSPayload) /// json data case json(Data) /// xml @@ -38,8 +38,12 @@ extension Body { case .text(let text): return text - case .buffer(let byteBuffer): - return byteBuffer.getString(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, encoding: .utf8) + case .raw(let payload): + if let byteBuffer = payload.asByteBuffer() { + return byteBuffer.getString(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, encoding: .utf8) + } else { + return nil + } case .json(let data): return String(data: data, encoding: .utf8) @@ -53,24 +57,24 @@ extension Body { } } - /// return as bytebuffer - public func asByteBuffer() -> ByteBuffer? { + /// return as payload + public func asPayload() -> AWSPayload { switch self { case .text(let text): var buffer = ByteBufferAllocator().buffer(capacity: text.utf8.count) buffer.writeString(text) - return buffer + return .byteBuffer(buffer) - case .buffer(let byteBuffer): - return byteBuffer + case .raw(let payload): + return payload case .json(let data): if data.isEmpty { - return nil + return .empty } else { var buffer = ByteBufferAllocator().buffer(capacity: data.count) buffer.writeBytes(data) - return buffer + return .byteBuffer(buffer) } case .xml(let node): @@ -78,10 +82,15 @@ extension Body { let text = xmlDocument.xmlString var buffer = ByteBufferAllocator().buffer(capacity: text.utf8.count) buffer.writeString(text) - return buffer + return .byteBuffer(buffer) case .empty: - return nil + return .empty } } + + // return as ByteBuffer + public func asByteBuffer() -> ByteBuffer? { + return asPayload().asByteBuffer() + } } diff --git a/Sources/AWSSDKSwiftCore/MetaDataService.swift b/Sources/AWSSDKSwiftCore/MetaDataService.swift index 9fdcfac98..64ae5d234 100644 --- a/Sources/AWSSDKSwiftCore/MetaDataService.swift +++ b/Sources/AWSSDKSwiftCore/MetaDataService.swift @@ -63,7 +63,7 @@ extension MetaDataServiceProvider { /// make HTTP request func request(url: String, method: HTTPMethod = .GET, headers: [String:String] = [:], timeout: TimeInterval, httpClient: AWSHTTPClient, on eventLoop: EventLoop) -> EventLoopFuture { - let request = AWSHTTPRequest(url: URL(string: url)!, method: method, headers: HTTPHeaders(headers.map {($0.key, $0.value) }), body: nil) + let request = AWSHTTPRequest(url: URL(string: url)!, method: method, headers: HTTPHeaders(headers.map {($0.key, $0.value) })) let futureResponse = httpClient.execute(request: request, timeout: TimeAmount.seconds(2), on: eventLoop) return futureResponse } diff --git a/Sources/AWSSignerV4/signer.swift b/Sources/AWSSignerV4/signer.swift index f73a60977..c38aefc2f 100644 --- a/Sources/AWSSignerV4/signer.swift +++ b/Sources/AWSSignerV4/signer.swift @@ -48,6 +48,7 @@ public struct AWSSigner { case string(String) case data(Data) case byteBuffer(ByteBuffer) + case unsignedPayload } /// Generate signed headers, for a HTTP request @@ -202,6 +203,8 @@ public struct AWSSigner { hash = byteBufferView.withContiguousStorageIfAvailable { bytes in return SHA256.hash(data: bytes).hexDigest() } + case .unsignedPayload: + return "UNSIGNED-PAYLOAD" } if let hash = hash { return hash diff --git a/Tests/AWSSDKSwiftCoreTests/AWSClientTests.swift b/Tests/AWSSDKSwiftCoreTests/AWSClientTests.swift index 3a41b1c7d..dee8810d5 100644 --- a/Tests/AWSSDKSwiftCoreTests/AWSClientTests.swift +++ b/Tests/AWSSDKSwiftCoreTests/AWSClientTests.swift @@ -246,7 +246,7 @@ class AWSClientTests: XCTestCase { XCTAssertEqual(awsRequest.url.absoluteString, "https://s3.ca-central-1.amazonaws.com/Bucket?list-type=2") let nioRequest: AWSHTTPRequest = awsRequest.toHTTPRequest() XCTAssertEqual(nioRequest.method, HTTPMethod.GET) - XCTAssertEqual(nioRequest.body, nil) + XCTAssertTrue(nioRequest.body.isEmpty) } catch { XCTFail(error.localizedDescription) } @@ -303,7 +303,7 @@ class AWSClientTests: XCTestCase { let awsRequest = try client.createAWSRequest(operation: "test", path: "/", httpMethod: "GET") XCTAssertEqual(awsRequest.url.absoluteString, "https://service.aws.amazon.com/") } - + func testCreateAwsRequestWithKeywordInHeader() { struct KeywordRequest: AWSEncodableShape { static var _encoding: [AWSMemberEncoding] = [ @@ -315,7 +315,7 @@ class AWSClientTests: XCTestCase { let request = KeywordRequest(repeat: "Repeat") let awsRequest = try s3Client.createAWSRequest(operation: "Keyword", path: "/", httpMethod: "POST", input: request) XCTAssertEqual(awsRequest.httpHeaders["repeat"] as? String, "Repeat") - XCTAssertEqual(awsRequest.body.asByteBuffer(), nil) + XCTAssertTrue(awsRequest.body.asPayload().isEmpty) } catch { XCTFail(error.localizedDescription) } @@ -1012,6 +1012,185 @@ class AWSClientTests: XCTestCase { } } + func testRequestStreaming() { + struct Input : AWSEncodableShape & AWSShapeWithPayload { + static var payloadPath: String = "payload" + static var options: PayloadOptions = [.allowStreaming] + let payload: AWSPayload + private enum CodingKeys: CodingKey {} + } + + let awsServer = AWSTestServer(serviceProtocol: .json) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + XCTAssertNoThrow(try awsServer.stop()) + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + do { + let client = createAWSClient(accessKeyId: "", secretAccessKey: "", endpoint: awsServer.address, httpClientProvider: .shared(httpClient)) + + // supply buffer in 16k blocks + let bufferSize = 1024*1024 + let blockSize = 16*1024 + let data = createRandomBuffer(45,9182, size: bufferSize) + + var i = 0 + let payload = AWSPayload.stream(size: bufferSize) { eventLoop in + var buffer = ByteBufferAllocator().buffer(capacity: blockSize) + buffer.writeBytes(data[i..<(i+blockSize)]) + i = i + blockSize + return eventLoop.makeSucceededFuture(buffer) + } + let input = Input(payload: payload) + let response = client.send(operation: "test", path: "/", httpMethod: "POST", input: input) + + try awsServer.process { request in + let bytes = request.body.getBytes(at: 0, length: request.body.readableBytes) + XCTAssertEqual(bytes, data) + let response = AWSTestServer.Response(httpStatus: .ok, headers: [:], body: nil) + return AWSTestServer.Result(output: response, continueProcessing: false) + } + + try response.wait() + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testRequestStreamingTooMuchData() { + struct Input : AWSEncodableShape & AWSShapeWithPayload { + static var payloadPath: String = "payload" + static var options: PayloadOptions = [.allowStreaming] + let payload: AWSPayload + private enum CodingKeys: CodingKey {} + } + + let awsServer = AWSTestServer(serviceProtocol: .json) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + try? awsServer.stop() + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + do { + let client = createAWSClient(accessKeyId: "", secretAccessKey: "", endpoint: awsServer.address, httpClientProvider: .shared(httpClient)) + + // set up stream of 8 bytes but supply more than that + let payload = AWSPayload.stream(size: 8) { eventLoop in + var buffer = ByteBufferAllocator().buffer(capacity: 0) + buffer.writeString("String longer than 8 bytes") + return eventLoop.makeSucceededFuture(buffer) + } + let input = Input(payload: payload) + let response = client.send(operation: "test", path: "/", httpMethod: "POST", input: input) + try response.wait() + } catch AWSClient.ClientError.tooMuchData { + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testRequestStreamingFile() { + struct Input : AWSEncodableShape & AWSShapeWithPayload { + static var payloadPath: String = "payload" + static var options: PayloadOptions = [.allowStreaming] + let payload: AWSPayload + private enum CodingKeys: CodingKey {} + } + + let awsServer = AWSTestServer(serviceProtocol: .json) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + try? awsServer.stop() + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + do { + let client = createAWSClient(accessKeyId: "", secretAccessKey: "", endpoint: awsServer.address, httpClientProvider: .shared(httpClient)) + + let bufferSize = 208*1024 + let data = Data(createRandomBuffer(45,9182, size: bufferSize)) + let filename = "testRequestStreamingFile" + let fileURL = URL(fileURLWithPath: filename) + try data.write(to: fileURL) + defer { + XCTAssertNoThrow(try FileManager.default.removeItem(at: fileURL)) + } + + let threadPool = NIOThreadPool(numberOfThreads: 3) + threadPool.start() + let fileIO = NonBlockingFileIO(threadPool: threadPool) + let fileHandle = try fileIO.openFile(path: filename, mode: .read, eventLoop: httpClient.eventLoopGroup.next()).wait() + defer { + XCTAssertNoThrow(try fileHandle.close()) + XCTAssertNoThrow(try threadPool.syncShutdownGracefully()) + } + + let input = Input(payload: .fileHandle(fileHandle, size: bufferSize, fileIO: fileIO)) + let response = client.send(operation: "test", path: "/", httpMethod: "POST", input: input) + + try awsServer.process { request in + XCTAssertNil(request.headers["transfer-encoding"]) + XCTAssertEqual(request.headers["Content-Length"], bufferSize.description) + let requestData = request.body.getData(at: 0, length: request.body.readableBytes) + XCTAssertEqual(requestData, data) + let response = AWSTestServer.Response(httpStatus: .ok, headers: [:], body: nil) + return AWSTestServer.Result(output: response, continueProcessing: false) + } + + try response.wait() + } catch AWSClient.ClientError.tooMuchData { + } catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testRequestChunkedStreaming() { + struct Input : AWSEncodableShape & AWSShapeWithPayload { + static var payloadPath: String = "payload" + static var options: PayloadOptions = [.allowStreaming, .allowChunkedStreaming] + let payload: AWSPayload + private enum CodingKeys: CodingKey {} + } + + let awsServer = AWSTestServer(serviceProtocol: .json) + let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + defer { + XCTAssertNoThrow(try awsServer.stop()) + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + do { + let client = createAWSClient(accessKeyId: "", secretAccessKey: "", endpoint: awsServer.address, httpClientProvider: .shared(httpClient)) + + // supply buffer in 16k blocks + let bufferSize = 145*1024 + let blockSize = 16*1024 + let data = createRandomBuffer(45,9182, size: bufferSize) + var byteBuffer = ByteBufferAllocator().buffer(capacity: bufferSize) + byteBuffer.writeBytes(data) + + let payload = AWSPayload.stream { eventLoop in + let size = min(blockSize, byteBuffer.readableBytes) + if size == 0 { + return eventLoop.makeSucceededFuture((byteBuffer)) + } else { + return eventLoop.makeSucceededFuture(byteBuffer.readSlice(length: size)!) + } + } + let input = Input(payload: payload) + let response = client.send(operation: "test", path: "/", httpMethod: "POST", input: input) + + try awsServer.process { request in + let bytes = request.body.getBytes(at: 0, length: request.body.readableBytes) + XCTAssertTrue(bytes == data) + let response = AWSTestServer.Response(httpStatus: .ok, headers: [:], body: nil) + return AWSTestServer.Result(output: response, continueProcessing: false) + } + + try response.wait() + } catch { + XCTFail("Unexpected error: \(error)") + } + } + func testProvideHTTPClient() { do { // By default AsyncHTTPClient will follow redirects. This test creates an HTTP client that doesn't follow redirects and diff --git a/Tests/AWSSDKSwiftCoreTests/HTTPClientTests.swift b/Tests/AWSSDKSwiftCoreTests/HTTPClientTests.swift index 9a62af088..12d7c9f13 100644 --- a/Tests/AWSSDKSwiftCoreTests/HTTPClientTests.swift +++ b/Tests/AWSSDKSwiftCoreTests/HTTPClientTests.swift @@ -50,7 +50,7 @@ class NIOTSHTTPClientTests: XCTestCase { func testInitWithInvalidURL() { do { - let request = AWSHTTPRequest(url: URL(string:"no_protocol.com")!, method: .GET, headers: HTTPHeaders(), body: nil) + let request = AWSHTTPRequest(url: URL(string:"no_protocol.com")!, method: .GET, headers: HTTPHeaders()) _ = try client.execute(request: request, timeout: .seconds(5), on: client.eventLoopGroup.next()).wait() XCTFail("Should throw malformedURL error") } catch { @@ -63,7 +63,7 @@ class NIOTSHTTPClientTests: XCTestCase { func testConnectGet() { do { - let request = AWSHTTPRequest(url: awsServer.addressURL, method: .GET, headers: HTTPHeaders(), body: nil) + let request = AWSHTTPRequest(url: awsServer.addressURL, method: .GET, headers: HTTPHeaders()) let future = client.execute(request: request, timeout: .seconds(5), on: client.eventLoopGroup.next()) try awsServer.httpBin() _ = try future.wait() @@ -74,7 +74,7 @@ class NIOTSHTTPClientTests: XCTestCase { func testConnectPost() { do { - let request = AWSHTTPRequest(url: awsServer.addressURL, method: .POST, headers: HTTPHeaders(), body: nil) + let request = AWSHTTPRequest(url: awsServer.addressURL, method: .POST, headers: HTTPHeaders()) let future = client.execute(request: request, timeout: .seconds(5), on: client.eventLoopGroup.next()) try awsServer.httpBin() _ = try future.wait() @@ -153,7 +153,7 @@ class HTTPClientTests { XCTAssertNoThrow(try awsServer.stop()) } let headers: HTTPHeaders = [:] - let request = AWSHTTPRequest(url: URL(string: "\(awsServer.address)/get?test=2")!, method: .GET, headers: headers, body: nil) + let request = AWSHTTPRequest(url: URL(string: "\(awsServer.address)/get?test=2")!, method: .GET, headers: headers, body: .empty) let responseFuture = execute(request) try awsServer.httpBin() @@ -170,7 +170,7 @@ class HTTPClientTests { func testHTTPS() { do { let headers: HTTPHeaders = [:] - let request = AWSHTTPRequest(url: URL(string:"https://httpbin.org/get")!, method: .GET, headers: headers, body: nil) + let request = AWSHTTPRequest(url: URL(string:"https://httpbin.org/get")!, method: .GET, headers: headers) let response = try execute(request).wait() XCTAssertEqual(response.url, "https://httpbin.org/get") @@ -188,7 +188,7 @@ class HTTPClientTests { let headers: HTTPHeaders = [ "Test-Header": "testValue" ] - let request = AWSHTTPRequest(url: awsServer.addressURL, method: .POST, headers: headers, body: nil) + let request = AWSHTTPRequest(url: awsServer.addressURL, method: .POST, headers: headers, body: .empty) let responseFuture = execute(request) try awsServer.httpBin() @@ -216,7 +216,7 @@ class HTTPClientTests { let text = "thisisatest" var body = ByteBufferAllocator().buffer(capacity: text.utf8.count) body.writeString(text) - let request = AWSHTTPRequest(url: awsServer.addressURL, method: .POST, headers: headers, body: body) + let request = AWSHTTPRequest(url: awsServer.addressURL, method: .POST, headers: headers, body: .byteBuffer(body)) let responseFuture = execute(request) try awsServer.httpBin() diff --git a/Tests/AWSSDKSwiftCoreTests/TestServer.swift b/Tests/AWSSDKSwiftCoreTests/TestServer.swift index 0a2db2dca..4ca8c86db 100644 --- a/Tests/AWSSDKSwiftCoreTests/TestServer.swift +++ b/Tests/AWSSDKSwiftCoreTests/TestServer.swift @@ -28,6 +28,7 @@ class AWSTestServer { case notEnd case emptyBody case noXMLBody + case corruptChunkedData } // what are we returning enum ServiceProtocol { diff --git a/Tests/AWSSDKSwiftCoreTests/TestUtils.swift b/Tests/AWSSDKSwiftCoreTests/TestUtils.swift index e01a56b44..db63454fb 100644 --- a/Tests/AWSSDKSwiftCoreTests/TestUtils.swift +++ b/Tests/AWSSDKSwiftCoreTests/TestUtils.swift @@ -23,7 +23,7 @@ import Foundation self.defaultValue = `default` self.variableName = variableName } - + public var wrappedValue: Value { get { guard let value = Environment[variableName] else { return defaultValue } @@ -71,3 +71,21 @@ func createAWSClient( httpClientProvider: httpClientProvider ) } + +// create a buffer of random values. Will always create the same given you supply the same z and w values +// Random number generator from https://www.codeproject.com/Articles/25172/Simple-Random-Number-Generation +func createRandomBuffer(_ w: UInt, _ z: UInt, size: Int) -> [UInt8] { + var z = z + var w = w + func getUInt8() -> UInt8 + { + z = 36969 * (z & 65535) + (z >> 16); + w = 18000 * (w & 65535) + (w >> 16); + return UInt8(((z << 16) + w) & 0xff); + } + var data = Array(repeating: 0, count: size) + for i in 0..