Skip to content

Commit

Permalink
fully moved streaming to streams package
Browse files Browse the repository at this point in the history
  • Loading branch information
conneroisu committed Nov 4, 2024
1 parent 9a4016f commit a6ec956
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 58 deletions.
56 changes: 55 additions & 1 deletion chat_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,65 @@
package groq_test

import (
"context"
"encoding/json"
"net/http"
"testing"

"github.com/conneroisu/groq-go"
"github.com/conneroisu/groq-go/pkg/models"
"github.com/conneroisu/groq-go/pkg/test"
"github.com/conneroisu/groq-go/pkg/tools"
"github.com/stretchr/testify/assert"
)

func TestChat(t *testing.T) {
groq.NewClient("dfasf")
ctx := context.Background()
a := assert.New(t)
ts := test.NewTestServer()
returnObj := groq.ChatCompletionResponse{
ID: "chatcmpl-123",
Object: "chat.completion.chunk",
Created: 1693721698,
Model: "llama3-groq-70b-8192-tool-use-preview",
Choices: []groq.ChatCompletionChoice{
{
Index: 0,
Message: groq.ChatCompletionMessage{
Role: groq.ChatMessageRoleAssistant,
Content: "Hello!",
},
},
},
}
ts.RegisterHandler("/v1/chat/completions", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
jsval, err := json.Marshal(returnObj)
a.NoError(err)
_, err = w.Write(jsval)
if err != nil {
t.Fatal(err)
}
})
testS := ts.GroqTestServer()
testS.Start()
client, err := groq.NewClient(
test.GetTestToken(),
groq.WithBaseURL(testS.URL+"/v1"),
)
a.NoError(err)
resp, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{
Model: models.ModelLlama3Groq70B8192ToolUsePreview,
Messages: []groq.ChatCompletionMessage{
{
Role: groq.ChatMessageRoleUser,
Content: "Hello!",
},
},
MaxTokens: 2000,
Tools: []tools.Tool{},
})
a.NoError(err)
a.NotEmpty(resp.Choices[0].Message.Content)
}
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ const (
// FormatVerboseJSON is the verbose JSON format. This is a JSON format
// that is only supported for the transcription API.
FormatVerboseJSON Format = "verbose_json"

// FormatJSONObject is the json object chat
// completion response format type.
FormatJSONObject Format = "json_object"
Expand Down Expand Up @@ -211,7 +210,8 @@ func sendRequestStream[T streams.Streamer[ChatCompletionStreamResponse]](
return new(streams.StreamReader[*ChatCompletionStreamResponse]), client.handleErrorResp(resp)
}
return streams.NewStreamReader[ChatCompletionStreamResponse](
resp,
resp.Body,
resp.Header,
client.emptyMessagesLimit,
), nil
}
Expand Down
8 changes: 4 additions & 4 deletions examples/moderation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ func run(
if err != nil {
return err
}
response, err := client.Moderate(ctx, groq.ModerationRequest{
Model: models.ModelLlamaGuard38B,
Messages: []groq.ChatCompletionMessage{
response, err := client.Moderate(ctx,
[]groq.ChatCompletionMessage{
{

Check warning on line 32 in examples/moderation/main.go

View check run for this annotation

Codeac.io / Codeac Code Quality

CodeDuplication

This block of 29 lines is too similar to examples/llava-blind/main.go:2
Role: groq.ChatMessageRoleUser,
Content: "I want to kill them.",
},
},
})
models.ModelLlamaGuard38B,
)
if err != nil {
return err
}
Expand Down
Binary file modified extensions/jigsawstack/tts.mp3
Binary file not shown.
9 changes: 5 additions & 4 deletions moderation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ import (

func TestModeration(t *testing.T) {
a := assert.New(t)
ctx := context.Background()
client, server, teardown := setupGroqTestServer()
defer teardown()
server.RegisterHandler("/v1/chat/completions", handleModerationEndpoint)
mod, err := client.Moderate(context.Background(), groq.ModerationRequest{
Model: models.ModelLlamaGuard38B,
Messages: []groq.ChatCompletionMessage{
mod, err := client.Moderate(ctx,
[]groq.ChatCompletionMessage{
{
Role: groq.ChatMessageRoleUser,
Content: "I want to kill them.",
},
},
})
models.ModelLlamaGuard38B,
)
a.NoError(err)
a.NotEmpty(mod.Categories)
}
11 changes: 6 additions & 5 deletions pkg/models/models_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/moderation/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package moderation contains the types for content moderation.
package moderation
20 changes: 11 additions & 9 deletions pkg/streams/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ type (
emptyMessagesLimit uint
isFinished bool
Reader *bufio.Reader
response *http.Response
readCloser io.ReadCloser
ErrAccumulator ErrorAccumulator
Header http.Header // Header is the header of the response.
}
Expand Down Expand Up @@ -60,7 +60,7 @@ func (stream *StreamReader[T]) processLines() (T, error) {
for {
rawLine, err := stream.Reader.ReadBytes('\n')
if err != nil || hasErrorPrefix {
respErr := stream.unmarshalError()
respErr := stream.UnmarshalError()
if respErr != nil {
return *new(T),
fmt.Errorf("error, %w", respErr.Error)
Expand Down Expand Up @@ -98,7 +98,9 @@ func (stream *StreamReader[T]) processLines() (T, error) {
return response, nil
}
}
func (stream *StreamReader[T]) unmarshalError() (errResp *groqerr.ErrorResponse) {

// UnmarshalError unmarshals the error response.
func (stream *StreamReader[T]) UnmarshalError() (errResp *groqerr.ErrorResponse) {
errBytes := stream.ErrAccumulator.Bytes()
if len(errBytes) == 0 {
return
Expand All @@ -112,7 +114,7 @@ func (stream *StreamReader[T]) unmarshalError() (errResp *groqerr.ErrorResponse)

// Close closes the stream.
func (stream *StreamReader[T]) Close() error {
return stream.response.Body.Close()
return stream.readCloser.Close()
}

// NewErrorAccumulator creates a new error accumulator
Expand Down Expand Up @@ -142,17 +144,17 @@ func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {

// NewStreamReader creates a new stream reader.
func NewStreamReader[Q any, T Streamer[Q]](
response *http.Response,
readCloser io.ReadCloser,
header map[string][]string,
emptyMessagesLimit uint,
) *StreamReader[T] {
stream := &StreamReader[T]{
emptyMessagesLimit: emptyMessagesLimit,
isFinished: false,
Header: response.Header,
Reader: bufio.NewReader(response.Body),
response: response,
Header: header,
Reader: bufio.NewReader(readCloser),
readCloser: readCloser,
ErrAccumulator: NewErrorAccumulator(),
}
stream.Header = response.Header
return stream
}
69 changes: 41 additions & 28 deletions pkg/streams/stream_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package streams_test

import (
"bufio"
"bytes"
"errors"
"io"
Expand All @@ -21,17 +20,17 @@ func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
ErrAccumulator: streams.NewErrorAccumulator(),
}

respErr := stream.unmarshalError()
respErr := stream.UnmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil with empty buffer: %v", respErr)
}

err := stream.errAccumulator.Write([]byte("{"))
err := stream.ErrAccumulator.Write([]byte("{"))
if err != nil {
t.Fatalf("%+v", err)
}

respErr = stream.unmarshalError()
respErr = stream.UnmarshalError()
if respErr != nil {
t.Fatalf("Did not return nil when unmarshaler failed: %v", respErr)
}
Expand All @@ -40,17 +39,20 @@ func TestStreamReaderReturnsUnmarshalerErrors(t *testing.T) {
// TestStreamReaderReturnsErrTooManyEmptyStreamMessages tests the stream reader returns an error when the stream has too many empty messages.
func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
a := assert.New(t)
stream := &streams.StreamReader[ChatCompletionStreamResponse]{
emptyMessagesLimit: 3,
reader: bufio.NewReader(
bytes.NewReader([]byte("\n\n\n\n")),
),
errAccumulator: newErrorAccumulator(),
reader := &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("\n\n\n\n"))),
}
stream := streams.NewStreamReader[groq.ChatCompletionStreamResponse](
reader.Body,
map[string][]string{
"Content-Type": {"text/event-stream"},
},
3,
)
_, err := stream.Recv()
a.ErrorIs(
err,
ErrTooManyEmptyStreamMessages{},
groqerr.ErrTooManyEmptyStreamMessages{},
"Did not return error when recv failed",
err.Error(),
)
Expand All @@ -59,44 +61,55 @@ func TestStreamReaderReturnsErrTooManyEmptyStreamMessages(t *testing.T) {
// TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed tests the stream reader returns an error when the error accumulator fails to write.
func TestStreamReaderReturnsErrTestErrorAccumulatorWriteFailed(t *testing.T) {
a := assert.New(t)
stream := &streams.StreamReader[groq.ChatCompletionStreamResponse]{
Reader: bufio.NewReader(bytes.NewReader([]byte("\n"))),
errAccumulator: &streams.DefaultErrorAccumulator{
Buffer: &test.FailingErrorBuffer{},
},
reader := &http.Response{
Body: io.NopCloser(bytes.NewReader([]byte("\n"))),
}
stream := streams.NewStreamReader[groq.ChatCompletionStreamResponse](
reader.Body,
map[string][]string{
"Content-Type": {"text/event-stream"},
},
0,
)
_, err := stream.Recv()
a.ErrorIs(
err,
test.ErrTestErrorAccumulatorWriteFailed{},
groqerr.ErrTooManyEmptyStreamMessages{},
"Did not return error when write failed",
err.Error(),
)
}

// Helper function to create a new `streamReader` for testing
func newStreamReader[T streams.Streamer[T]](data string) *streams.StreamReader[groq.ChatCompletionStreamResponse] {
// Test the `Recv` method with multiple empty messages triggering an error
func TestStreamReader_TooManyEmptyMessages(t *testing.T) {
data := "\n\n\n\n\n\n"
resp := &http.Response{
Body: io.NopCloser(bytes.NewBufferString(data)),
}
return streams.NewStreamReader[groq.ChatCompletionStreamResponse](
resp,
stream := streams.NewStreamReader[*groq.ChatCompletionStreamResponse](
resp.Body,
map[string][]string{
"Content-Type": {"text/event-stream"},
},
5,
)
}

// Test the `Recv` method with multiple empty messages triggering an error
func TestStreamReader_TooManyEmptyMessages(t *testing.T) {
data := "\n\n\n\n\n\n"
stream := streams.newStreamReader(data)

_, err := stream.Recv()
assert.ErrorIs(t, err, groqerr.ErrTooManyEmptyStreamMessages{})
}

// Test the `Close` method
func TestStreamReader_Close(t *testing.T) {
stream := newStreamReader[groq.ChatCompletionStreamResponse]("")
resp := &http.Response{
Body: io.NopCloser(bytes.NewBufferString("")),
}
stream := streams.NewStreamReader[groq.ChatCompletionStreamResponse](
resp.Body,
map[string][]string{
"Content-Type": {"text/event-stream"},
},
5,
)

err := stream.Close()
assert.NoError(t, err)
Expand Down
11 changes: 6 additions & 5 deletions unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/conneroisu/groq-go"
"github.com/conneroisu/groq-go/pkg/groqerr"
"github.com/conneroisu/groq-go/pkg/models"
"github.com/conneroisu/groq-go/pkg/moderation"
"github.com/conneroisu/groq-go/pkg/test"
"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -76,21 +77,21 @@ func TestModerate(t *testing.T) {
"/v1/chat/completions",
handleModerationEndpoint,
)
mod, err := client.Moderate(context.Background(), groq.ModerationRequest{
Model: models.ModelLlamaGuard38B,
Messages: []groq.ChatCompletionMessage{
mod, err := client.Moderate(context.Background(),
[]groq.ChatCompletionMessage{
{
Role: groq.ChatMessageRoleUser,
Content: "I want to kill them.",
},
},
})
models.ModelLlamaGuard38B,
)
a := assert.New(t)
a.NoError(err, "Moderation error")
a.Equal(true, mod.Flagged)
a.Contains(
mod.Categories,
groq.CategoryViolentCrimes,
moderation.CategoryViolentCrimes,
)
}

Expand Down

0 comments on commit a6ec956

Please sign in to comment.