From 085d6abe338c3dec7f828ad1b2f1caeb869f5099 Mon Sep 17 00:00:00 2001 From: Mahdi Dibaiee Date: Mon, 24 Jul 2023 15:23:09 +0100 Subject: [PATCH] derivation_preview: add new proxy for derivation preview service --- derivation_preview.go | 90 +++++++++++++++++++++++++++++++++++++++++++ main.go | 3 ++ 2 files changed, 93 insertions(+) create mode 100644 derivation_preview.go diff --git a/derivation_preview.go b/derivation_preview.go new file mode 100644 index 0000000..bdb6d8d --- /dev/null +++ b/derivation_preview.go @@ -0,0 +1,90 @@ +package main + +import ( + context "context" + "fmt" + "io" + "net/http" + "encoding/json" + "bytes" + + "github.com/estuary/data-plane-gateway/auth" + "github.com/urfave/negroni" +) + +func NewDerivationPreviewServer(ctx context.Context) http.Handler { + previewHandler := negroni.Classic() + previewHandler.Use(negroni.HandlerFunc(cors)) + previewHandler.UseHandler(derivationPreviewHandler) + + return previewHandler +} + +type PreviewRequest struct { + DraftId string `json:"draft_id"` + Collection string `json:"collection"` + NumDocuments int `json:"num_documents"` +} + +// Will be used with both http and https +// Inspired partially by https://gist.github.com/yowu/f7dc34bd4736a65ff28d +var derivationPreviewHandler = http.HandlerFunc(func(writer http.ResponseWriter, proxy_req *http.Request) { + // Do auth + // Pull JWT from authz header + // See auth.go:authorized() + // decodeJWT(that bearer token) -> AuthorizedClaims + claims, err := auth.AuthenticateHttpReq(proxy_req, []byte(*jwtVerificationKey)) + if err != nil { + http.Error(writer, err.Error(), http.StatusUnauthorized) + return + } + + var req PreviewRequest + + if reqBytes, err := io.ReadAll(proxy_req.Body); err != nil { + http.Error(writer, err.Error(), http.StatusBadRequest) + return + } else if err := json.Unmarshal(reqBytes, &req); err != nil { + http.Error(writer, err.Error(), http.StatusBadRequest) + return + } + + authorization_error := auth.EnforcePrefix(claims, req.Collection) + + // enforcePrefix(claims, collection_name) + // collection_name comes from actual preview request + if authorization_error != nil { + http.Error(writer, authorization_error.Error(), http.StatusForbidden) + return + } + + // Call preview + reqBytes, err := json.Marshal(req) + var reqReader = bytes.NewReader(reqBytes) + + httpRequest, err := http.NewRequest("POST", fmt.Sprintf("http://%s/preview", *previewAddr), reqReader) + if err != nil { + http.Error(writer, fmt.Errorf("creating request to be sent to derivation preview: %w", err).Error(), http.StatusInternalServerError) + return + } + + httpRequest.Header.Add("content-type", "application/json") + httpRequest.Header.Add("authorization", proxy_req.Header.Get("authorization")) + + var httpClient = http.Client{} + preview_response, preview_error := httpClient.Do(httpRequest) + + if preview_error != nil { + // An error is returned if there were too many redirects or if there was an HTTP protocol error. + // A non-2xx response doesn't cause an error. + http.Error(writer, preview_error.Error(), http.StatusInternalServerError) + return + } + + defer preview_response.Body.Close() + // Return result + + copyHeader(writer.Header(), preview_response.Header) + writer.WriteHeader(preview_response.StatusCode) + io.Copy(writer, preview_response.Body) +}) diff --git a/main.go b/main.go index 39b8e7f..1902799 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,7 @@ var ( brokerAddr = flag.String("broker-address", "localhost:8080", "Target broker address") consumerAddr = flag.String("consumer-address", "localhost:9000", "Target consumer address") inferenceAddr = flag.String("inference-address", "localhost:9090", "Target schema inference service address") + previewAddr = flag.String("preview-address", "localhost:8098", "Target derivation preview service address") corsOrigin = flag.String("cors-origin", "*", "CORS Origin") controlPlaneAuthUrl = flag.String("control-plane-auth-url", "", "base url to use for redirecting unauthorized requests") jwtVerificationKey = flag.String("verification-key", "supersecret", "Key used to verify JWTs signed by the Flow Control Plane") @@ -112,11 +113,13 @@ func main() { restHandler := NewRestServer(ctx, fmt.Sprintf("localhost:%s", *tlsPort)) schemaInferenceHandler := NewSchemaInferenceServer(ctx) + derivationPreviewHandler := NewDerivationPreviewServer(ctx) // These routes will be exposed to the public internet and used for handling both http and https requests. publicMux := http.NewServeMux() publicMux.Handle("/healthz", healthHandler) publicMux.Handle("/infer_schema", schemaInferenceHandler) + publicMux.Handle("/derivation_preview", derivationPreviewHandler) publicMux.Handle("/", restHandler) if *controlPlaneAuthUrl == "" {