Skip to content

Commit

Permalink
Allow for more CORS configuration (#1594)
Browse files Browse the repository at this point in the history
Motivation:

We added some level of CORS configuration support in #1583. This change adds
further flexibility.

Modifications:

- Add an 'originBased' mode where the value of the origin header is
  returned in the response head.
- Add a custom fallback where the user can specify a callback which
  is passed the value of the origin header and returns the value to
  return in the 'access-control-allow-origin' response header (or nil,
  if the origin is not allowed).

Result:

More flexibility for CORS.
  • Loading branch information
glbrntt authored May 9, 2023
1 parent 2d5795d commit ef8ffb9
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 0 deletions.
58 changes: 58 additions & 0 deletions Sources/GRPC/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,9 @@ extension Server.Configuration.CORS {
public struct AllowedOrigins: Hashable, Sendable {
enum Wrapped: Hashable, Sendable {
case all
case originBased
case only([String])
case custom(AnyCustomCORSAllowedOrigin)
}

private(set) var wrapped: Wrapped
Expand All @@ -500,10 +502,23 @@ extension Server.Configuration.CORS {
/// Allow all origin values.
public static let all = Self(.all)

/// Allow all origin values; similar to `all` but returns the value of the origin header field
/// in the 'access-control-allow-origin' response header (rather than "*").
public static let originBased = Self(.originBased)

/// Allow only the given origin values.
public static func only(_ allowed: [String]) -> Self {
return Self(.only(allowed))
}

/// Provide a custom CORS origin check.
///
/// - Parameter checkOrigin: A closure which is called with the value of the 'origin' header
/// and returns the value to use in the 'access-control-allow-origin' response header,
/// or `nil` if the origin is not allowed.
public static func custom<C: GRPCCustomCORSAllowedOrigin>(_ custom: C) -> Self {
return Self(.custom(AnyCustomCORSAllowedOrigin(custom)))
}
}
}

Expand All @@ -530,3 +545,46 @@ extension Comparable {
return min(max(self, range.lowerBound), range.upperBound)
}
}

public protocol GRPCCustomCORSAllowedOrigin: Sendable, Hashable {
/// Returns the value to use for the 'access-control-allow-origin' response header for the given
/// value of the 'origin' request header.
///
/// - Parameter origin: The value of the 'origin' request header field.
/// - Returns: The value to use for the 'access-control-allow-origin' header field or `nil` if no
/// CORS related headers should be returned.
func check(origin: String) -> String?
}

extension Server.Configuration.CORS.AllowedOrigins {
struct AnyCustomCORSAllowedOrigin: GRPCCustomCORSAllowedOrigin {
private var checkOrigin: @Sendable (String) -> String?
private let hashInto: @Sendable (inout Hasher) -> Void
#if swift(>=5.7)
private let isEqualTo: @Sendable (any GRPCCustomCORSAllowedOrigin) -> Bool
#else
private let isEqualTo: @Sendable (Any) -> Bool
#endif

init<W: GRPCCustomCORSAllowedOrigin>(_ wrap: W) {
self.checkOrigin = { wrap.check(origin: $0) }
self.hashInto = { wrap.hash(into: &$0) }
self.isEqualTo = { wrap == ($0 as? W) }
}

func check(origin: String) -> String? {
return self.checkOrigin(origin)
}

func hash(into hasher: inout Hasher) {
self.hashInto(&hasher)
}

static func == (
lhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin,
rhs: Server.Configuration.CORS.AllowedOrigins.AnyCustomCORSAllowedOrigin
) -> Bool {
return lhs.isEqualTo(rhs)
}
}
}
4 changes: 4 additions & 0 deletions Sources/GRPC/WebCORSHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,12 @@ extension Server.Configuration.CORS.AllowedOrigins {
switch self.wrapped {
case .all:
return "*"
case .originBased:
return origin
case let .only(allowed):
return allowed.contains(origin) ? origin : nil
case let .custom(custom):
return custom.check(origin: origin)
}
}
}
44 changes: 44 additions & 0 deletions Tests/GRPCTests/WebCORSHandlerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,50 @@ internal final class WebCORSHandlerTests: XCTestCase {
try self.runPreflightRequestTest(spec: spec)
}

func testOptionsPreflightOriginBased() throws {
let spec = PreflightRequestSpec(
configuration: .init(
allowedOrigins: .originBased,
allowedHeaders: ["x-grpc-web"],
allowCredentialedRequests: false,
preflightCacheExpiration: 60
),
requestOrigin: "foo",
expectOrigin: "foo",
expectAllowedHeaders: ["x-grpc-web"],
expectAllowCredentials: false,
expectMaxAge: "60"
)
try self.runPreflightRequestTest(spec: spec)
}

func testOptionsPreflightCustom() throws {
struct Wrapper: GRPCCustomCORSAllowedOrigin {
func check(origin: String) -> String? {
if origin == "foo" {
return "bar"
} else {
return nil
}
}
}

let spec = PreflightRequestSpec(
configuration: .init(
allowedOrigins: .custom(Wrapper()),
allowedHeaders: ["x-grpc-web"],
allowCredentialedRequests: false,
preflightCacheExpiration: 60
),
requestOrigin: "foo",
expectOrigin: "bar",
expectAllowedHeaders: ["x-grpc-web"],
expectAllowCredentials: false,
expectMaxAge: "60"
)
try self.runPreflightRequestTest(spec: spec)
}

func testOptionsPreflightAllowSomeOrigins() throws {
let spec = PreflightRequestSpec(
configuration: .init(
Expand Down

0 comments on commit ef8ffb9

Please sign in to comment.