Skip to content

Commit

Permalink
Merge pull request #52 from conneroisu/dev
Browse files Browse the repository at this point in the history
dev
  • Loading branch information
conneroisu authored Sep 20, 2024
2 parents 22b4260 + 5a76a5d commit 89e9f1f
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 153 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ go get github.com/conneroisu/groq-go

For introductory examples, see the [examples](https://github.com/conneroisu/groq-go/tree/main/examples) directory.

[audio](/examples/audio-lex-fridman)
[json](/examples/json-chat)
[moderation](/examples/moderation)
[vision](/examples/llava-blind)
[terminal-chat](/examples/terminal-chat)
[documentor-float](/examples/vhdl-documentor-json)

External Repositories using groq-go:
- [Automatic Git Commit Message Generator](https://github.com/conneroisu/gita)

Expand Down
2 changes: 1 addition & 1 deletion audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (r *AudioResponse) SetHeader(header http.Header) {
// response format is text.
type audioTextResponse struct {
Text string `json:"text"` // Text is the text of the response.
header http.Header // Header is the header of the response.
header http.Header `json:"-"` // Header is the response header.
}

// SetHeader sets the header of the audio text response.
Expand Down
2 changes: 1 addition & 1 deletion builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (

// formBuilder is an interface for building a form.
type formBuilder interface {
io.Closer
CreateFormFile(fieldname string, file *os.File) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
}

Expand Down
39 changes: 12 additions & 27 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,33 +59,22 @@ type ChatMessageImageURL struct {
Detail ImageURLDetail `json:"detail,omitempty"` // Detail is the detail of the image url.
}

// ChatMessagePart represents the chat message part of a chat completion message.
// ChatMessagePart represents the chat message part of a chat completion
// message.
type ChatMessagePart struct {
Type ChatMessagePartType `json:"type,omitempty"`
Text string `json:"text,omitempty"`
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
Text string `json:"text,omitempty"` // Text is the text of the chat message part.
Type ChatMessagePartType `json:"type,omitempty"` // Type is the type of the chat message part.
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` // ImageURL is the image url of the chat message part.
}

// ChatCompletionMessage represents the chat completion message.
type ChatCompletionMessage struct {
Role Role `json:"role"` // Role is the role of the chat completion message.
Content string `json:"content"` // Content is the content of the chat completion message.
MultiContent []ChatMessagePart // MultiContent is the multi content of the chat completion message.

// This property isn't in the official documentation, but it's in
// the documentation for the official library for python:
// - https://github.com/openai/openai-python/blob/main/chatml.md
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
Name string `json:"name,omitempty"`

// FunctionCall setting for Role=assistant prompts this may be set to the function call generated by the model.
FunctionCall *FunctionCall `json:"function_call,omitempty"`

// ToolCalls setting for Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
ToolCalls []ToolCall `json:"tool_calls,omitempty"`

// ToolCallID is setting for Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool.
ToolCallID string `json:"tool_call_id,omitempty"`
Role Role `json:"role"` // Role is the role of the chat completion message.
Content string `json:"content"` // Content is the content of the chat completion message.
MultiContent []ChatMessagePart `json:"-"` // MultiContent is the multi content of the chat completion message.
FunctionCall *FunctionCall `json:"function_call,omitempty"` // FunctionCall setting for Role=assistant prompts this may be set to the function call generated by the model.
ToolCalls []ToolCall `json:"tool_calls,omitempty"` // ToolCalls setting for Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
ToolCallID string `json:"tool_call_id,omitempty"` // ToolCallID is setting for Role=tool prompts this should be set to the ID given in the assistant's prior request to call a tool.
}

// MarshalJSON method implements the json.Marshaler interface.
Expand All @@ -98,7 +87,6 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Role Role `json:"role"`
Content string `json:"-"`
MultiContent []ChatMessagePart `json:"content,omitempty"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Expand All @@ -109,7 +97,6 @@ func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
Role Role `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart `json:"-"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Expand All @@ -123,7 +110,6 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) {
Role Role `json:"role"`
Content string `json:"content"`
MultiContent []ChatMessagePart
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Expand All @@ -137,7 +123,6 @@ func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) (err error) {
Role Role `json:"role"`
Content string
MultiContent []ChatMessagePart `json:"content"`
Name string `json:"name,omitempty"`
FunctionCall *FunctionCall `json:"function_call,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
Expand Down Expand Up @@ -191,7 +176,7 @@ type ChatCompletionResponseFormatJSONSchema struct {
Description string `json:"description,omitempty"`
// description of the chat completion response format json schema.
// Schema is the schema of the chat completion response format json schema.
Schema schema `json:"schema"`
Schema Schema `json:"schema"`
// Strict determines whether to enforce the schema upon the generated
// content.
Strict bool `json:"strict"`
Expand Down
6 changes: 3 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Client struct {
orgID string // OrgID is the organization ID for the client.
baseURL string // Base URL for the client.
client *http.Client // Client is the HTTP client to use
EmptyMessagesLimit uint // EmptyMessagesLimit is the limit for the empty messages.
emptyMessagesLimit uint // EmptyMessagesLimit is the limit for the empty messages.
requestBuilder requestBuilder
requestFormBuilder formBuilder
createFormBuilder func(body io.Writer) formBuilder
Expand All @@ -59,7 +59,7 @@ func NewClient(groqAPIKey string, opts ...Opts) (*Client, error) {
Timestamp().
Logger(),
baseURL: groqAPIURLv1,
EmptyMessagesLimit: 10,
emptyMessagesLimit: 10,
createFormBuilder: func(body io.Writer) formBuilder {
return newFormBuilder(body)
},
Expand Down Expand Up @@ -217,7 +217,7 @@ func sendRequestStream[T streamer](
return new(streamReader[T]), client.handleErrorResp(resp)
}
return &streamReader[T]{
emptyMessagesLimit: client.EmptyMessagesLimit,
emptyMessagesLimit: client.emptyMessagesLimit,
reader: bufio.NewReader(resp.Body),
response: resp,
errAccumulator: newErrorAccumulator(),
Expand Down
3 changes: 2 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ type errorResponse struct {
Error *APIError `json:"error,omitempty"`
}

// Error implements the error interface.
// newErrorAccumulator creates a new error accumulator
func newErrorAccumulator() errorAccumulator {
return &DefaultErrorAccumulator{
Expand Down Expand Up @@ -204,6 +203,7 @@ func (e *APIError) UnmarshalJSON(data []byte) (err error) {
return json.Unmarshal(rawMap["code"], &e.Code)
}

// Error implements the error interface.
func (e *requestError) Error() string {
return fmt.Sprintf(
"error, status code: %d, message: %s",
Expand All @@ -212,6 +212,7 @@ func (e *requestError) Error() string {
)
}

// Unwrap unwraps the error.
func (e *requestError) Unwrap() error {
return e.Err
}
34 changes: 28 additions & 6 deletions internal/test/failer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@ func TestFailingErrorBuffer_Write(t *testing.T) {
t.Errorf("Write(%q) returned n=%d, expected n=0", tc.input, n)
}
if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed{}) {
t.Errorf("Write(%q) returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}", tc.input, err)
t.Errorf(
"Write(%q) returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}",
tc.input,
err,
)
}
})
}
Expand Down Expand Up @@ -68,14 +72,22 @@ func TestFailingErrorBuffer_Bytes(t *testing.T) {

bytes := buf.Bytes()
if len(bytes) != 0 {
t.Errorf("Bytes() returned %v (len=%d), expected empty byte slice", bytes, len(bytes))
t.Errorf(
"Bytes() returned %v (len=%d), expected empty byte slice",
bytes,
len(bytes),
)
}

// After Write calls
_, _ = buf.Write([]byte("test"))
bytes = buf.Bytes()
if len(bytes) != 0 {
t.Errorf("Bytes() after Write returned %v (len=%d), expected empty byte slice", bytes, len(bytes))
t.Errorf(
"Bytes() after Write returned %v (len=%d), expected empty byte slice",
bytes,
len(bytes),
)
}
}

Expand All @@ -89,16 +101,26 @@ func TestFailingErrorBuffer_MultipleWrites(t *testing.T) {
t.Errorf("Write call %d returned n=%d, expected n=0", i+1, n)
}
if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed{}) {
t.Errorf("Write call %d returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}", i+1, err)
t.Errorf(
"Write call %d returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}",
i+1,
err,
)
}
}

if buf.Len() != 0 {
t.Errorf("Len() after multiple Writes returned %d, expected 0", buf.Len())
t.Errorf(
"Len() after multiple Writes returned %d, expected 0",
buf.Len(),
)
}

if len(buf.Bytes()) != 0 {
t.Errorf("Bytes() after multiple Writes returned len=%d, expected 0", len(buf.Bytes()))
t.Errorf(
"Bytes() after multiple Writes returned len=%d, expected 0",
len(buf.Bytes()),
)
}
}

Expand Down
16 changes: 13 additions & 3 deletions internal/test/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ type MockRoundTripper struct {
}

// RoundTrip captures the request and returns a dummy response.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
func (m *MockRoundTripper) RoundTrip(
req *http.Request,
) (*http.Response, error) {
m.LastRequest = req
return &http.Response{
StatusCode: http.StatusOK,
Expand Down Expand Up @@ -104,10 +106,18 @@ func TestTokenRoundTripper(t *testing.T) {
// Perform the HTTP request.
resp, err := client.Do(req)
a.NoError(err, "HTTP request should succeed")
a.Equal(http.StatusOK, resp.StatusCode, "Response status code should be 200")
a.Equal(
http.StatusOK,
resp.StatusCode,
"Response status code should be 200",
)

// Verify that the Authorization header was added.
a.NotNil(mockRT.LastRequest, "LastRequest should be captured")
authHeader := mockRT.LastRequest.Header.Get("Authorization")
a.Equal("Bearer test-token", authHeader, "Authorization header should contain the correct token")
a.Equal(
"Bearer test-token",
authHeader,
"Authorization header should contain the correct token",
)
}
Loading

0 comments on commit 89e9f1f

Please sign in to comment.