Skip to content

Commit

Permalink
Add Stream method for running a model and streaming its output (#29)
Browse files Browse the repository at this point in the history
* Rename replicate_test.go to client_test.go

* Add Stream method

* Support streaming official language models

* Fix streaming implementation

* Consolidate error into errChan

* Handle closed lineChan
  • Loading branch information
mattt authored Dec 3, 2023
1 parent 1e7fbeb commit bb054c1
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 0 deletions.
88 changes: 88 additions & 0 deletions replicate_test.go → client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package replicate_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -1225,3 +1226,90 @@ func TestAutomaticallyRetryPostRequests(t *testing.T) {

assert.ErrorContains(t, err, "Internal server error")
}

// TestStream verifies that Client.Stream creates a prediction with streaming
// enabled and then delivers every server-sent token, in order, on the event
// channel.
func TestStream(t *testing.T) {
	tokens := []string{"Alpha", "Bravo", "Charlie", "Delta", "Echo"}

	mockServer := httptest.NewUnstartedServer(nil)
	mockServer.Config.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The handler runs on a server goroutine, so it must not call
		// t.Fatal/t.FailNow. Failures are recorded with t.Errorf and the
		// request is answered with an error status instead.
		fail := func(format string, args ...interface{}) {
			t.Errorf(format, args...)
			http.Error(w, "mock server failure", http.StatusInternalServerError)
		}

		switch {
		case r.Method == "POST" && r.URL.Path == "/predictions":
			defer r.Body.Close()

			body, err := io.ReadAll(r.Body)
			if err != nil {
				fail("reading request body: %v", err)
				return
			}

			var requestBody map[string]interface{}
			if err := json.Unmarshal(body, &requestBody); err != nil {
				fail("decoding request body: %v", err)
				return
			}

			assert.Equal(t, "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", requestBody["version"])
			assert.Equal(t, map[string]interface{}{"text": "Alice"}, requestBody["input"])
			assert.Equal(t, true, requestBody["stream"])

			response := replicate.Prediction{
				ID:        "ufawqhfynnddngldkgtslldrkq",
				Model:     "replicate/hello-world",
				Version:   "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
				Status:    "starting",
				Input:     map[string]interface{}{"text": "Alice"},
				CreatedAt: "2022-04-26T22:13:06.224088Z",
				URLs: map[string]string{
					"stream": fmt.Sprintf("%s/predictions/ufawqhfynnddngldkgtslldrkq/stream", mockServer.URL),
				},
			}
			responseBytes, err := json.Marshal(response)
			if err != nil {
				fail("encoding response: %v", err)
				return
			}

			w.WriteHeader(http.StatusCreated)
			if _, err := w.Write(responseBytes); err != nil {
				t.Errorf("writing response: %v", err)
			}

		case r.Method == "GET" && r.URL.Path == "/predictions/ufawqhfynnddngldkgtslldrkq/stream":
			flusher, ok := w.(http.Flusher)
			if !ok {
				fail("response writer does not implement http.Flusher")
				return
			}

			w.Header().Set("Content-Type", "text/event-stream")
			w.Header().Set("Cache-Control", "no-cache")
			w.Header().Set("Connection", "keep-alive")

			for _, token := range tokens {
				fmt.Fprintf(w, "data: %s\n\n", token)
				flusher.Flush()
				time.Sleep(time.Millisecond * 10)
			}

		default:
			fail("Unexpected request: %s %s", r.Method, r.URL.Path)
		}
	})

	mockServer.Start()
	defer mockServer.Close()

	client, err := replicate.NewClient(
		replicate.WithToken("test-token"),
		replicate.WithBaseURL(mockServer.URL),
	)
	require.NoError(t, err)

	// The context deadline is the only timeout needed for the receive loop.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	input := replicate.PredictionInput{"text": "Alice"}
	version := "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa"

	sseChan, errChan := client.Stream(ctx, fmt.Sprintf("replicate/hello-world:%s", version), input, nil)

	for _, token := range tokens {
		// Keep selecting until an event for this token arrives, so that a
		// receive on the error channel never silently skips a token.
		for received := false; !received; {
			select {
			case <-ctx.Done():
				t.Fatal("context canceled")
			case event, ok := <-sseChan:
				require.True(t, ok, "event channel closed before %q was received", token)
				assert.Equal(t, token, event.Data)
				received = true
			case err, ok := <-errChan:
				if !ok {
					// A closed channel is always ready; a nil channel is
					// never ready, which removes this case from the select
					// while buffered events are drained.
					errChan = nil
					continue
				}
				require.NoError(t, err)
			}
		}
	}
}
179 changes: 179 additions & 0 deletions stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package replicate

import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"unicode/utf8"

"golang.org/x/sync/errgroup"
)

// ErrInvalidUTF8Data is returned when the data payload of a Server-Sent Event
// is not valid UTF-8.
var ErrInvalidUTF8Data = errors.New("invalid UTF-8 data")

// SSEEvent represents a Server-Sent Event.
type SSEEvent struct {
	Type string // value of the last "event" field seen while decoding
	ID   string // value of the last "id" field seen while decoding
	Data string // values of all "data" fields, joined with "\n"
}

// decode parses the raw text of a single event from b into e.
//
// b is split into lines on "\n"; a trailing "\r" on a line is dropped so that
// CRLF-terminated input yields the same field values as LF-terminated input.
// Each line is split at its first ':' into a field name and a value, and one
// leading space is removed from the value. The "id" and "event" fields set
// e.ID and e.Type, every "data" field contributes one line to e.Data, and all
// other lines (including comments, whose field name is empty) are ignored.
//
// It returns ErrInvalidUTF8Data when the joined data payload is not valid
// UTF-8. In that case e.Data is left untouched, but e.ID and e.Type may
// already have been updated.
func (e *SSEEvent) decode(b []byte) error {
	var data [][]byte
	for _, line := range bytes.Split(b, []byte("\n")) {
		line = bytes.TrimSuffix(line, []byte("\r"))

		// A line without a colon is a field name with an empty value.
		field, value, _ := bytes.Cut(line, []byte(":"))
		value = bytes.TrimPrefix(value, []byte(" "))

		switch string(field) {
		case "id":
			e.ID = string(value)
		case "event":
			e.Type = string(value)
		case "data":
			data = append(data, value)
		default:
			// ignore
		}
	}

	payload := bytes.Join(data, []byte("\n"))
	if !utf8.Valid(payload) {
		return ErrInvalidUTF8Data
	}

	e.Data = string(payload)

	return nil
}

// Stream runs a model with the given input and streams its output as
// Server-Sent Events.
//
// identifier is parsed with ParseIdentifier. When it names a version the
// prediction is created with CreatePrediction, otherwise with
// CreatePredictionWithModel; both are called with true as their final
// argument. The prediction's "stream" URL is then requested and read as an
// event stream.
//
// Events are delivered on the first returned channel. Errors are delivered on
// the second: decode errors and events of type "error" as they occur, and the
// error that ended the stream (if any) last. An event of type "done" ends the
// stream and is not delivered. Both channels are closed once the stream has
// ended, so callers should receive from them until they are closed or cancel
// ctx to stop early.
func (r *Client) Stream(ctx context.Context, identifier string, input PredictionInput, webhook *Webhook) (<-chan SSEEvent, <-chan error) {
	sseChan := make(chan SSEEvent, 64)
	errChan := make(chan error, 64)

	g, ctx := errgroup.WithContext(ctx)

	// emit and report hand a value to the caller without blocking past
	// cancellation, so a caller that stopped receiving cannot wedge the
	// worker.
	emit := func(event SSEEvent) error {
		select {
		case sseChan <- event:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	report := func(err error) error {
		select {
		case errChan <- err:
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	g.Go(func() error {
		id, err := ParseIdentifier(identifier)
		if err != nil {
			return err
		}

		var prediction *Prediction
		if id.Version == nil {
			prediction, err = r.CreatePredictionWithModel(ctx, id.Owner, id.Name, input, webhook, true)
		} else {
			prediction, err = r.CreatePrediction(ctx, *id.Version, input, webhook, true)
		}
		if err != nil {
			return err
		}

		url := prediction.URLs["stream"]
		if url == "" {
			return errors.New("streaming not supported")
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return fmt.Errorf("failed to create request: %w", err)
		}
		req.Header.Set("Accept", "text/event-stream")
		req.Header.Set("Cache-Control", "no-cache")
		req.Header.Set("Connection", "keep-alive")

		resp, err := r.c.Do(req)
		if err != nil {
			// No response body is closed here: resp may be nil when the
			// request fails.
			return fmt.Errorf("failed to send request: %w", err)
		}
		// Closed on every exit path: bad status, "done" event, EOF, read
		// error and cancellation.
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			return fmt.Errorf("received invalid status code %d", resp.StatusCode)
		}

		// The request carries ctx, so a blocked read is expected to fail once
		// ctx is cancelled; no separate reader goroutine is needed to observe
		// cancellation.
		reader := bufio.NewReader(resp.Body)
		var buf bytes.Buffer
		for {
			line, err := reader.ReadBytes('\n')
			if err != nil {
				if ctxErr := ctx.Err(); ctxErr != nil {
					return ctxErr
				}
				if errors.Is(err, io.EOF) {
					// Lines buffered for an event that was never terminated
					// by a blank line are dropped.
					return nil
				}
				return err
			}

			buf.Write(line)

			// A blank line terminates the event accumulated in buf.
			if string(line) != "\n" {
				continue
			}

			event := SSEEvent{Type: "message"}
			decodeErr := event.decode(buf.Bytes())
			buf.Reset()
			if decodeErr != nil {
				if err := report(decodeErr); err != nil {
					return err
				}
				// The event is still dispatched below with whatever fields
				// were decoded before the failure.
			}

			switch event.Type {
			case "error":
				if err := report(unmarshalAPIError([]byte(event.Data))); err != nil {
					return err
				}
			case "done":
				return nil
			default:
				if err := emit(event); err != nil {
					return err
				}
			}
		}
	})

	go func() {
		defer close(sseChan)
		defer close(errChan)

		if err := g.Wait(); err != nil {
			errChan <- err
		}
	}()

	return sseChan, errChan
}

0 comments on commit bb054c1

Please sign in to comment.