Skip to content

Commit

Permalink
[Go] Added support for flow auth and Firebase auth plugin. (#722)
Browse files Browse the repository at this point in the history
  • Loading branch information
apascal07 authored Aug 8, 2024
1 parent a67f111 commit abd71c1
Show file tree
Hide file tree
Showing 13 changed files with 763 additions and 65 deletions.
2 changes: 2 additions & 0 deletions docs-go/_guides.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ toc:
path: /docs/genkit-go/models
- title: Creating flows
path: /docs/genkit-go/flows
- title: Adding authentication to flows
path: /docs/genkit-go/auth
- title: Prompting models
path: /docs/genkit-go/prompts
- title: Managing prompts
Expand Down
57 changes: 57 additions & 0 deletions docs-go/auth.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Flow Authentication

Genkit supports flow-level authentication, allowing you to secure your flows and ensure that only authorized users can execute them. This is particularly useful when deploying flows as HTTP endpoints.

## Configuring Flow Authentication

To add authentication to a flow, you can use the `WithFlowAuth` option when defining the flow. This option takes an implementation of the `FlowAuth` interface, which provides methods for handling authentication and authorization.

Here's an example of how to define a flow with authentication:

```golang
{% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth" adjust_indentation="auto" %}
```

In this example, we're using the Firebase auth plugin to handle authentication. The `policy` function defines the authorization logic, checking if the user ID in the auth context matches the input user ID.

## Using the Firebase Auth Plugin

The Firebase auth plugin provides an easy way to integrate Firebase Authentication with your Genkit flows. Here's how to use it:

1. Import the Firebase plugin:

```golang
import "github.com/firebase/genkit/go/plugins/firebase"
```

2. Create a Firebase auth provider:

```golang
{% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth-create" adjust_indentation="auto" %}
```

The `NewAuth` function takes three arguments:

- `ctx`: The context for Firebase initialization.
- `policy`: A function that defines your authorization logic.
- `required`: A boolean indicating whether authentication is required for direct calls.

3. Use the auth provider when defining your flow:

```golang
{% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth-define" adjust_indentation="auto" %}
```

## Handling Authentication in HTTP Requests

When your flow is deployed as an HTTP endpoint, the Firebase auth plugin will automatically handle authentication for incoming requests. It expects a Bearer token in the Authorization header of the HTTP request.

## Running Authenticated Flows Locally

When running authenticated flows locally or from within other flows, you can provide local authentication context using the `WithLocalAuth` option:

```golang
{% includecode github_path="firebase/genkit/go/internal/doc-snippets/flows.go" region_tag="auth-run" adjust_indentation="auto" %}
```

This allows you to test authenticated flows without needing to provide a valid Firebase token.
2 changes: 1 addition & 1 deletion docs-go/flows.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ then call `Init()`:
```

`Init` starts a `net/http` server that exposes your flows as HTTP
endpoints (for example, `http://localhost:3400/menuSuggestionFlow`).
endpoints (for example, `http://localhost:3400/menuSuggestionFlow`).

The second parameter is an optional `Options` that specifies the following:

Expand Down
148 changes: 131 additions & 17 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -98,28 +99,85 @@ type Flow[In, Out, Stream any] struct {
tstate *tracing.State // set from the action when the flow is defined
inputSchema *jsonschema.Schema // Schema of the input to the flow
outputSchema *jsonschema.Schema // Schema of the output out of the flow
auth FlowAuth // Auth provider and policy checker for the flow.
// TODO: scheduler
// TODO: experimentalDurable
// TODO: authPolicy
// TODO: middleware
}

// runOptions configures a single flow run.
type runOptions struct {
authContext AuthContext // Auth context to pass to auth policy checker when calling a flow directly.
}

// flowOptions configures a flow.
type flowOptions struct {
auth FlowAuth // Auth provider and policy checker for the flow.
}

type noStream = func(context.Context, struct{}) error

// AuthContext is the type of the auth context passed to the auth policy checker.
type AuthContext map[string]any

// FlowAuth configures an auth context provider and an auth policy check for a flow.
type FlowAuth interface {
// ProvideAuthContext sets the auth context on the given context by parsing an auth header.
// The parsing logic is provided by the auth provider.
ProvideAuthContext(ctx context.Context, authHeader string) (context.Context, error)

// NewContext sets the auth context on the given context. This is used when
// the auth context is provided by the user, rather than by the auth provider.
NewContext(ctx context.Context, authContext AuthContext) context.Context

// FromContext retrieves the auth context from the given context.
FromContext(ctx context.Context) AuthContext

// CheckAuthPolicy checks the auth context against policy.
CheckAuthPolicy(ctx context.Context, input any) error
}

// streamingCallback is the type of streaming callbacks.
type streamingCallback[Stream any] func(context.Context, Stream) error

// FlowOption modifies the flow with the provided option.
type FlowOption func(opts *flowOptions)

// FlowRunOption modifies a flow run with the provided option.
type FlowRunOption func(opts *runOptions)

// WithFlowAuth sets an auth provider and policy checker for the flow.
func WithFlowAuth(auth FlowAuth) FlowOption {
return func(f *flowOptions) {
if f.auth != nil {
log.Panic("auth already set in flow")
}
f.auth = auth
}
}

// WithLocalAuth configures an option to run or stream a flow with a local auth value.
func WithLocalAuth(authContext AuthContext) FlowRunOption {
return func(opts *runOptions) {
if opts.authContext != nil {
log.Panic("authContext already set in runOptions")
}
opts.authContext = authContext
}
}

// DefineFlow creates a Flow that runs fn, and registers it as an action.
//
// fn takes an input of type In and returns an output of type Out.
func DefineFlow[In, Out any](
name string,
fn func(ctx context.Context, input In) (Out, error),
opts ...FlowOption,
) *Flow[In, Out, struct{}] {
return defineFlow(registry.Global, name, core.Func[In, Out, struct{}](
func(ctx context.Context, input In, cb func(ctx context.Context, _ struct{}) error) (Out, error) {
return fn(ctx, input)
}))
}), opts...)
}

// DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action.
Expand All @@ -134,11 +192,12 @@ func DefineFlow[In, Out any](
func DefineStreamingFlow[In, Out, Stream any](
name string,
fn func(ctx context.Context, input In, callback func(context.Context, Stream) error) (Out, error),
opts ...FlowOption,
) *Flow[In, Out, Stream] {
return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn))
return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn), opts...)
}

func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream]) *Flow[In, Out, Stream] {
func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream], opts ...FlowOption) *Flow[In, Out, Stream] {
var i In
var o Out
f := &Flow[In, Out, Stream]{
Expand All @@ -148,12 +207,27 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
outputSchema: base.InferJSONSchema(o),
// TODO: set stateStore?
}
flowOpts := &flowOptions{}
for _, opt := range opts {
opt(flowOpts)
}
f.auth = flowOpts.auth
metadata := map[string]any{
"inputSchema": f.inputSchema,
"outputSchema": f.outputSchema,
"requiresAuth": f.auth != nil,
}
afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) {
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
// Only non-durable flows have an auth policy so can safely assume Start.Input.
if inst.Start != nil {
if f.auth != nil {
ctx = f.auth.NewContext(ctx, inst.Auth)
}
if err := f.checkAuthPolicy(ctx, any(inst.Start.Input)); err != nil {
return nil, err
}
}
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc)
Expand All @@ -167,18 +241,19 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
// A flowInstruction is an instruction to follow with a flow.
// It is the input for the flow's action.
// Exactly one field will be non-nil.
type flowInstruction[I any] struct {
Start *startInstruction[I] `json:"start,omitempty"`
type flowInstruction[In any] struct {
Start *startInstruction[In] `json:"start,omitempty"`
Resume *resumeInstruction `json:"resume,omitempty"`
Schedule *scheduleInstruction[I] `json:"schedule,omitempty"`
Schedule *scheduleInstruction[In] `json:"schedule,omitempty"`
RunScheduled *runScheduledInstruction `json:"runScheduled,omitempty"`
State *stateInstruction `json:"state,omitempty"`
Retry *retryInstruction `json:"retry,omitempty"`
Auth map[string]any `json:"auth,omitempty"`
}

// A startInstruction starts a flow.
type startInstruction[I any] struct {
Input I `json:"input,omitempty"`
type startInstruction[In any] struct {
Input In `json:"input,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
}

Expand All @@ -189,9 +264,9 @@ type resumeInstruction struct {
}

// A scheduleInstruction schedules a flow to start at a later time.
type scheduleInstruction[I any] struct {
type scheduleInstruction[In any] struct {
DelaySecs float64 `json:"delay,omitempty"`
Input I `json:"input,omitempty"`
Input In `json:"input,omitempty"`
}

// A runScheduledInstruction starts a scheduled flow.
Expand Down Expand Up @@ -324,7 +399,7 @@ func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowIn
// Name returns the name that the flow was defined with.
func (f *Flow[In, Out, Stream]) Name() string { return f.name }

func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := base.ValidateJSON(input, f.inputSchema); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
Expand All @@ -333,6 +408,13 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
if err := json.Unmarshal(input, &in); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
newCtx, err := f.provideAuthContext(ctx, authHeader)
if err != nil {
return nil, &base.HTTPError{Code: http.StatusUnauthorized, Err: err}
}
if err := f.checkAuthPolicy(newCtx, in); err != nil {
return nil, &base.HTTPError{Code: http.StatusForbidden, Err: err}
}
// If there is a callback, wrap it to turn an S into a json.RawMessage.
var callback streamingCallback[Stream]
if cb != nil {
Expand Down Expand Up @@ -361,6 +443,28 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
return json.Marshal(res.Response)
}

// provideAuthContext provides auth context for the given auth header if flow auth is configured.
func (f *Flow[In, Out, Stream]) provideAuthContext(ctx context.Context, authHeader string) (context.Context, error) {
if f.auth != nil {
newCtx, err := f.auth.ProvideAuthContext(ctx, authHeader)
if err != nil {
return nil, fmt.Errorf("unauthorized: %w", err)
}
return newCtx, nil
}
return ctx, nil
}

// checkAuthPolicy checks auth context against the policy if flow auth is configured.
func (f *Flow[In, Out, Stream]) checkAuthPolicy(ctx context.Context, input any) error {
if f.auth != nil {
if err := f.auth.CheckAuthPolicy(ctx, input); err != nil {
return fmt.Errorf("permission denied for resource: %w", err)
}
}
return nil
}

// start starts executing the flow with the given input.
func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamingCallback[Stream]) (_ *flowState[In, Out], err error) {
flowID, err := generateFlowID()
Expand Down Expand Up @@ -569,11 +673,21 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out,

// Run runs the flow in the context of another flow. The flow must run to completion when started
// (that is, it must not have interrupts).
func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) {
return f.run(ctx, input, nil)
func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In, opts ...FlowRunOption) (Out, error) {
return f.run(ctx, input, nil, opts...)
}

func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) {
func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error, opts ...FlowRunOption) (Out, error) {
runOpts := &runOptions{}
for _, opt := range opts {
opt(runOpts)
}
if runOpts.authContext != nil && f.auth != nil {
ctx = f.auth.NewContext(ctx, runOpts.authContext)
}
if err := f.checkAuthPolicy(ctx, input); err != nil {
return base.Zero[Out](), err
}
state, err := f.start(ctx, input, cb)
if err != nil {
return base.Zero[Out](), err
Expand Down Expand Up @@ -602,7 +716,7 @@ type StreamFlowValue[Out, Stream any] struct {
// again.
//
// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result.
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamFlowValue[Out, Stream], error) bool) {
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...FlowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) {
return func(yield func(*StreamFlowValue[Out, Stream], error) bool) {
cb := func(ctx context.Context, s Stream) error {
if ctx.Err() != nil {
Expand All @@ -613,7 +727,7 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(
}
return nil
}
output, err := f.run(ctx, input, cb)
output, err := f.run(ctx, input, cb, opts...)
if err != nil {
yield(nil, err)
} else {
Expand Down
Loading

0 comments on commit abd71c1

Please sign in to comment.