Skip to content

Commit

Permalink
Add support for streaming request payload (#237)
Browse files Browse the repository at this point in the history
- Add support for streaming data to requests
- `AWSPayload` contains an internal enum that can be initialised with a stream case which takes a closure supply `ByteBuffers`
- Add helper method `Body.asByteBuffer`
- AWSSignerV4: Added `BodyData.unsignedPayload`
- Use unsigned-payload when streaming request payloads
- Added `AWSPayload.fileHandle` used to create payload from `NIOFileHandle`
- `AWSPayload.stream` now has an optional size parameter as does `AWSPayload.fileHandle`
- Use one eventloop for request streaming
- Added `.empty` case to `AWSPayload` enum
- Added "chunked" output for streamed requests without a size
- Added flags to `AWSShapeWithPayload` indicate whether streaming is allowed

Co-authored-by: Fabian Fett <[email protected]>
  • Loading branch information
adam-fowler and fabianfett authored May 20, 2020
1 parent 86301ca commit e0f7031
Show file tree
Hide file tree
Showing 16 changed files with 505 additions and 83 deletions.
25 changes: 19 additions & 6 deletions Sources/AWSSDKSwiftCore/AWSClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public final class AWSClient {

public enum ClientError: Swift.Error {
case invalidURL(String)
case tooMuchData
}

enum InternalError: Swift.Error {
Expand Down Expand Up @@ -385,6 +386,12 @@ extension AWSClient {
).applyMiddlewares(serviceConfig.middlewares + middlewares)
}

internal func verifyStream(operation: String, payload: AWSPayload, input: AWSShapeWithPayload.Type) {
guard case .stream(let size,_) = payload.payload else { return }
precondition(input.options.contains(.allowStreaming), "\(operation) does not allow streaming of data")
precondition(size != nil || input.options.contains(.allowChunkedStreaming), "\(operation) does not allow chunked streaming of data. Please supply a data size.")
}

internal func createAWSRequest<Input: AWSEncodableShape>(operation operationName: String, path: String, httpMethod: String, input: Input) throws -> AWSRequest {
var headers: [String: Any] = [:]
var path = path
Expand Down Expand Up @@ -438,11 +445,13 @@ extension AWSClient {

switch serviceConfig.serviceProtocol {
case .json, .restjson:
if let payload = (Input.self as? AWSShapeWithPayload.Type)?.payloadPath {
if let shapeWithPayload = Input.self as? AWSShapeWithPayload.Type {
let payload = shapeWithPayload.payloadPath
if let payloadBody = mirror.getAttribute(forKey: payload) {
switch payloadBody {
case let awsPayload as AWSPayload:
body = .buffer(awsPayload.byteBuffer)
verifyStream(operation: operationName, payload: awsPayload, input: shapeWithPayload)
body = .raw(awsPayload)
case let shape as AWSEncodableShape:
body = .json(try shape.encodeAsJSON())
default:
Expand Down Expand Up @@ -474,11 +483,13 @@ extension AWSClient {
}

case .restxml:
if let payload = (Input.self as? AWSShapeWithPayload.Type)?.payloadPath {
if let shapeWithPayload = Input.self as? AWSShapeWithPayload.Type {
let payload = shapeWithPayload.payloadPath
if let payloadBody = mirror.getAttribute(forKey: payload) {
switch payloadBody {
case let awsPayload as AWSPayload:
body = .buffer(awsPayload.byteBuffer)
verifyStream(operation: operationName, payload: awsPayload, input: shapeWithPayload)
body = .raw(awsPayload)
case let shape as AWSEncodableShape:
var rootName: String? = nil
// extract custom payload name
Expand Down Expand Up @@ -633,9 +644,9 @@ extension AWSClient {
}
return try XMLDecoder().decode(Output.self, from: outputNode)

case .buffer(let byteBuffer):
case .raw(let payload):
if let payloadKey = payloadKey {
outputDict[payloadKey] = AWSPayload.byteBuffer(byteBuffer)
outputDict[payloadKey] = payload
}

default:
Expand Down Expand Up @@ -788,6 +799,8 @@ extension AWSClient.ClientError: CustomStringConvertible {
The request url \(urlString) is invalid format.
This error is internal. So please make a issue on https://github.com/swift-aws/aws-sdk-swift/issues to solve it.
"""
case .tooMuchData:
return "You have supplied too much data for the Request."
}
}
}
Expand Down
123 changes: 103 additions & 20 deletions Sources/AWSSDKSwiftCore/AWSShapes/Payload.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,127 @@ import struct Foundation.Data
import NIO
import NIOFoundationCompat

/// Object storing request/response payload
/// Holds a request or response payload. A request payload can be in the form of either a ByteBuffer or a stream function that will supply ByteBuffers to the HTTP client.
/// A response payload only comes in the form of a ByteBuffer
public struct AWSPayload {

/// Construct a payload from a ByteBuffer
public static func byteBuffer(_ byteBuffer: ByteBuffer) -> Self {
return Self(byteBuffer: byteBuffer)

/// Internal enum
enum Payload {
case byteBuffer(ByteBuffer)
case stream(size: Int?, stream: (EventLoop)->EventLoopFuture<ByteBuffer>)
case empty
}

/// Construct a payload from a Data
public static func data(_ data: Data) -> Self {
var byteBuffer = ByteBufferAllocator().buffer(capacity: data.count)

internal let payload: Payload

/// construct a payload from a ByteBuffer
public static func byteBuffer(_ buffer: ByteBuffer) -> Self {
return AWSPayload(payload: .byteBuffer(buffer))
}

/// construct a payload from a stream function. If you supply a size the stream function will be called repeated until you supply the number of bytes specified. If you
/// don't supply a size the stream function will be called repeatedly until you supply an empty `ByteBuffer`
public static func stream(size: Int? = nil, stream: @escaping (EventLoop)->EventLoopFuture<ByteBuffer>) -> Self {
return AWSPayload(payload: .stream(size: size, stream: stream))
}

/// construct an empty payload
public static var empty: Self {
return AWSPayload(payload: .empty)
}

/// Construct a payload from `Data`
public static func data(_ data: Data, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) -> Self {
var byteBuffer = byteBufferAllocator.buffer(capacity: data.count)
byteBuffer.writeBytes(data)
return Self(byteBuffer: byteBuffer)
return AWSPayload(payload: .byteBuffer(byteBuffer))
}

/// Construct a payload from a String
public static func string(_ string: String) -> Self {
var byteBuffer = ByteBufferAllocator().buffer(capacity: string.utf8.count)
/// Construct a payload from a `String`
public static func string(_ string: String, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) -> Self {
var byteBuffer = byteBufferAllocator.buffer(capacity: string.utf8.count)
byteBuffer.writeString(string)
return Self(byteBuffer: byteBuffer)
return AWSPayload(payload: .byteBuffer(byteBuffer))
}

/// Construct a stream payload from a `NIOFileHandle`
public static func fileHandle(_ fileHandle: NIOFileHandle, size: Int? = nil, fileIO: NonBlockingFileIO, byteBufferAllocator: ByteBufferAllocator = ByteBufferAllocator()) -> Self {
let blockSize = 64*1024
var leftToRead = size
func stream(_ eventLoop: EventLoop) -> EventLoopFuture<ByteBuffer> {
// calculate how much data is left to read, if a file size was indicated
var blockSize = blockSize
if let leftToRead2 = leftToRead {
blockSize = min(blockSize, leftToRead2)
leftToRead = leftToRead2 - blockSize
}
let futureByteBuffer = fileIO.read(fileHandle: fileHandle, byteCount: blockSize, allocator: byteBufferAllocator, eventLoop: eventLoop)

if leftToRead != nil {
return futureByteBuffer.map { byteBuffer in
precondition(byteBuffer.readableBytes == blockSize, "File did not have enough data")
return byteBuffer
}
}
return futureByteBuffer
}

return AWSPayload(payload: .stream(size: size, stream: stream))
}

/// Return the size of the payload. If the payload is a stream it is always possible to return a size
var size: Int? {
switch payload {
case .byteBuffer(let byteBuffer):
return byteBuffer.readableBytes
case .stream(let size,_):
return size
case .empty:
return 0
}
}

/// return payload as Data
public func asData() -> Data? {
return byteBuffer.getData(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, byteTransferStrategy: .noCopy)
switch payload {
case .byteBuffer(let byteBuffer):
return byteBuffer.getData(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, byteTransferStrategy: .noCopy)
default:
return nil
}
}

/// return payload as String
public func asString() -> String? {
return byteBuffer.getString(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes, encoding: .utf8)
switch payload {
case .byteBuffer(let byteBuffer):
return byteBuffer.getString(at: byteBuffer.readerIndex, length: byteBuffer.readableBytes)
default:
return nil
}
}

/// return payload as ByteBuffer
public func asBytebuffer() -> ByteBuffer {
return byteBuffer
public func asByteBuffer() -> ByteBuffer? {
switch payload {
case .byteBuffer(let byteBuffer):
return byteBuffer
default:
return nil
}
}

/// does payload consist of zero bytes
public var isEmpty: Bool {
switch payload {
case .byteBuffer(let buffer):
return buffer.readableBytes == 0
case .stream:
return false
case .empty:
return true
}
}

let byteBuffer: ByteBuffer
}

extension AWSPayload: Decodable {
Expand Down
29 changes: 27 additions & 2 deletions Sources/AWSSDKSwiftCore/Doc/AWSShape.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,18 @@ public extension AWSEncodableShape {
guard value.count <= max else { throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(value.count)) is greater than the maximum allowed value \(max).") }
}
func validate(_ value: AWSPayload, name: String, parent: String, min: Int) throws {
guard value.byteBuffer.readableBytes >= min else { throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(value.byteBuffer.readableBytes)) is less than minimum allowed value \(min).") }
if let size = value.size {
guard size >= min else {
throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(size)) is less than minimum allowed value \(min).")
}
}
}
func validate(_ value: AWSPayload, name: String, parent: String, max: Int) throws {
guard value.byteBuffer.readableBytes <= max else { throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(value.byteBuffer.readableBytes)) is greater than the maximum allowed value \(max).") }
if let size = value.size {
guard size <= max else {
throw AWSClientError(.validationError, message: "Length of \(parent).\(name) (\(size)) is greater than the maximum allowed value \(max).")
}
}
}
func validate(_ value: String, name: String, parent: String, pattern: String) throws {
let regularExpression = try NSRegularExpression(pattern: pattern, options: [])
Expand Down Expand Up @@ -158,8 +166,25 @@ public extension AWSEncodableShape {
/// AWSShape that can be decoded
public protocol AWSDecodableShape: AWSShape & Decodable {}

/// AWSShapeWithPayload options.
public struct PayloadOptions: OptionSet {
public var rawValue: Int

public init(rawValue: Int) {
self.rawValue = rawValue
}

public static let allowStreaming = PayloadOptions(rawValue: 1<<0)
public static let allowChunkedStreaming = PayloadOptions(rawValue: 1<<1)
}

/// Root AWSShape which include a payload
public protocol AWSShapeWithPayload {
/// The path to the object that is included in the request body
static var payloadPath: String { get }
static var options: PayloadOptions { get }
}

extension AWSShapeWithPayload {
public static var options: PayloadOptions { return [] }
}
10 changes: 5 additions & 5 deletions Sources/AWSSDKSwiftCore/HTTP/AWSHTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ public struct AWSHTTPRequest {
public let url: URL
public let method: HTTPMethod
public let headers: HTTPHeaders
public let body: ByteBuffer?
public init(url: URL, method: HTTPMethod, headers: HTTPHeaders = [:], body: ByteBuffer? = nil) {
public let body: AWSPayload

public init(url: URL, method: HTTPMethod, headers: HTTPHeaders = [:], body: AWSPayload = .empty) {
self.url = url
self.method = method
self.headers = headers
Expand All @@ -42,10 +42,10 @@ public protocol AWSHTTPResponse {
public protocol AWSHTTPClient {
/// Execute HTTP request and return a future holding a HTTP Response
func execute(request: AWSHTTPRequest, timeout: TimeAmount, on eventLoop: EventLoop?) -> EventLoopFuture<AWSHTTPResponse>

/// This should be called before an HTTP Client can be de-initialised
func syncShutdown() throws

/// Event loop group used by client
var eventLoopGroup: EventLoopGroup { get }
}
73 changes: 67 additions & 6 deletions Sources/AWSSDKSwiftCore/HTTP/AsyncHTTPClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,80 @@ import NIO

/// comply with AWSHTTPClient protocol
extension AsyncHTTPClient.HTTPClient: AWSHTTPClient {

/// write stream to StreamWriter
private func writeToStreamWriter(
writer: HTTPClient.Body.StreamWriter,
size: Int?,
on eventLoop: EventLoop,
getData: @escaping (EventLoop)->EventLoopFuture<ByteBuffer>) -> EventLoopFuture<Void> {
let promise = eventLoop.makePromise(of: Void.self)

func _writeToStreamWriter(_ amountLeft: Int?) {
// get byte buffer from closure, write to StreamWriter, if there are still bytes to write then call
// _writeToStreamWriter again.
_ = getData(eventLoop)
.map { (byteBuffer)->() in
// if no amount was set and the byte buffer has no readable bytes then this is assumed to mean
// there will be no more data
if amountLeft == nil && byteBuffer.readableBytes == 0 {
promise.succeed(())
return
}
// calculate amount left to write
let newAmountLeft = amountLeft.map { $0 - byteBuffer.readableBytes }
// write chunk. If amountLeft is nil assume we are writing chunked output
let writeFuture: EventLoopFuture<Void> = writer.write(.byteBuffer(byteBuffer))
_ = writeFuture.flatMap { ()->EventLoopFuture<Void> in
if let newAmountLeft = newAmountLeft {
if newAmountLeft == 0 {
promise.succeed(())
} else if newAmountLeft < 0 {
promise.fail(AWSClient.ClientError.tooMuchData)
} else {
_writeToStreamWriter(newAmountLeft)
}
} else {
_writeToStreamWriter(nil)
}
return promise.futureResult
}.cascadeFailure(to: promise)
}.cascadeFailure(to: promise)
}
_writeToStreamWriter(size)
return promise.futureResult
}

/// Execute HTTP request
/// - Parameters:
/// - request: HTTP request
/// - timeout: If execution is idle for longer than timeout then throw error
/// - eventLoop: eventLoop to run request on
/// - Returns: EventLoopFuture that will be fulfilled with request response
public func execute(request: AWSHTTPRequest, timeout: TimeAmount, on eventLoop: EventLoop?) -> EventLoopFuture<AWSHTTPResponse> {
if let eventLoop = eventLoop {
precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "EventLoop provided to AWSClient must be part of the HTTPClient's EventLoopGroup.")
}
}
let eventLoop = eventLoop ?? eventLoopGroup.next()
let requestBody: AsyncHTTPClient.HTTPClient.Body?
if let body = request.body {
requestBody = .byteBuffer(body)
} else {
var requestHeaders = request.headers

switch request.body.payload {
case .byteBuffer(let byteBuffer):
requestBody = .byteBuffer(byteBuffer)
case .stream(let size, let getData):
// add "Transfer-Encoding" header if streaming with unknown size
if size == nil {
requestHeaders.add(name: "Transfer-Encoding", value: "chunked")
}
requestBody = .stream(length: size) { writer in
return self.writeToStreamWriter(writer: writer, size: size, on: eventLoop, getData: getData)
}
case .empty:
requestBody = nil
}
do {
let eventLoop = eventLoop ?? eventLoopGroup.next()
let asyncRequest = try AsyncHTTPClient.HTTPClient.Request(url: request.url, method: request.method, headers: request.headers, body: requestBody)
let asyncRequest = try AsyncHTTPClient.HTTPClient.Request(url: request.url, method: request.method, headers: requestHeaders, body: requestBody)
return execute(request: asyncRequest, eventLoop: .delegate(on: eventLoop), deadline: .now() + timeout).map { $0 }
} catch {
return eventLoopGroup.next().makeFailedFuture(error)
Expand Down
Loading

0 comments on commit e0f7031

Please sign in to comment.