diff --git a/.github/workflows/chron-models.yaml b/.github/workflows/chron-models.yaml index d223d82..f890077 100644 --- a/.github/workflows/chron-models.yaml +++ b/.github/workflows/chron-models.yaml @@ -18,7 +18,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v4 with: - go-version: '1.23' + go-version: '1.23.1' # Step 3: Run go mod download - name: Run go mod download diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1162c9e..5dfa339 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -25,24 +25,14 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: - go-version: '1.22' + go-version: '1.23.1' cache: true - name: Install requirements id: install-lint-requirements run: | go mod download - go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.58.1 - go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest + go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest - name: Lint id: lint run: | golangci-lint run - - name: Vet - id: vet - run: | - sqlc vet - go vet ./... - - name: Run Revive Action by pulling pre-built image - uses: docker://morphy/revive-action:v2 - with: - config: .revive.toml diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index c9383d6..0b16641 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -8,7 +8,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: '1.23' + go-version: '1.23.1' - name: Check out code uses: actions/checkout@v3 - name: Install dependencies diff --git a/.golangci.reference.yaml b/.golangci.reference.yaml new file mode 100644 index 0000000..7e7529f --- /dev/null +++ b/.golangci.reference.yaml @@ -0,0 +1,22 @@ +# +# Options for analysis running. +run: + # See the dedicated "run" documentation section. + option: value +# output configuration options +output: + # See the dedicated "output" documentation section. + option: value +# All available settings of specific linters. +linters-settings: + # See the dedicated "linters-settings" documentation section. + option: value +linters: + # See the dedicated "linters" documentation section. + option: value +issues: + # See the dedicated "issues" documentation section. + option: value +severity: + # See the dedicated "severity" documentation section. + option: value diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..7cfc908 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,14 @@ +linters-settings: + revive: + ignore-generated-header: true + severity: warning + rules: + - name: atomic + - name: line-length-limit + severity: error + arguments: [80] + - name: unhandled-error + arguments : ["fmt.Printf", "myFunction"] + +linters: + disable-all: false diff --git a/audio_test.go b/audio_test.go index 9789b69..98032eb 100644 --- a/audio_test.go +++ b/audio_test.go @@ -1,4 +1,7 @@ -package groq //nolint:testpackage // testing private field +//go:build !test +// +build !test + +package groq import ( "bytes" diff --git a/builders_test.go b/builders_test.go index f5f819c..f7fd8d1 100644 --- a/builders_test.go +++ b/builders_test.go @@ -1,7 +1,7 @@ //go:build !test // +build !test -package groq // testing private field +package groq import ( "bytes" diff --git a/chat_test.go b/chat_test.go index 26642ae..07a7708 100644 --- a/chat_test.go +++ b/chat_test.go @@ -1,594 +1,4 @@ -package groq_test +//go:build !test +// +build !test -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "testing" - - groq "github.com/conneroisu/groq-go" - "github.com/stretchr/testify/assert" -) - -// TestCreateChatCompletionStream tests the CreateChatCompletionStream method. -func TestCreateChatCompletionStream(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Send test responses - dataBytes := []byte{} - dataBytes = append(dataBytes, []byte("event: message\n")...) - data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: message\n")...) - data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("event: done\n")...) - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - - stream, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - }, - ) - a.NoError(err, "CreateCompletionStream returned error") - defer stream.Close() - - expectedResponses := []groq.ChatCompletionStreamResponse{ - { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: groq.Llama38B8192, - SystemFingerprint: "fp_d9767fc5b9", - Choices: []groq.ChatCompletionStreamChoice{ - { - Delta: groq.ChatCompletionStreamChoiceDelta{ - Content: "response1", - }, - FinishReason: "max_tokens", - }, - }, - }, - { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: groq.Llama38B8192, - SystemFingerprint: "fp_d9767fc5b9", - Choices: []groq.ChatCompletionStreamChoice{ - { - Delta: groq.ChatCompletionStreamChoiceDelta{ - Content: "response2", - }, - FinishReason: "max_tokens", - }, - }, - }, - } - - for ix, expectedResponse := range expectedResponses { - b, _ := json.Marshal(expectedResponse) - t.Logf("%d: %s", ix, string(b)) - - receivedResponse, streamErr := stream.Recv() - a.NoError(streamErr, "stream.Recv() failed") - if !compareChatResponses(t, expectedResponse, receivedResponse) { - t.Errorf( - "Stream response %v is %v, expected %v", - ix, - receivedResponse, - expectedResponse, - ) - } - } - - _, streamErr := stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) - } - - _, streamErr = stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) - } - - _, streamErr = stream.Recv() - - a.ErrorIs( - streamErr, - io.EOF, - "stream.Recv() did not return EOF when the stream is finished", - ) - if !errors.Is(streamErr, io.EOF) { - t.Errorf( - "stream.Recv() did not return EOF when the stream is finished: %v", - streamErr, - ) - } -} - -// TestCreateChatCompletionStreamError tests the CreateChatCompletionStream function with an error -// in the response. -func TestCreateChatCompletionStreamError(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Send test responses - dataBytes := []byte{} - dataStr := []string{ - `{`, - `"error": {`, - `"message": "Incorrect API key provided: gsk-***************************************",`, - `"type": "invalid_request_error",`, - `"param": null,`, - `"code": "invalid_api_key"`, - `}`, - `}`, - } - for _, str := range dataStr { - dataBytes = append(dataBytes, []byte(str+"\n")...) - } - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - - stream, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - }, - ) - a.NoError(err, "CreateCompletionStream returned error") - defer stream.Close() - - _, streamErr := stream.Recv() - a.Error(streamErr, "stream.Recv() did not return error") - - var apiErr *groq.APIError - if !errors.As(streamErr, &apiErr) { - t.Errorf("stream.Recv() did not return APIError") - } - t.Logf("%+v\n", apiErr) -} - -func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - xCustomHeader := "x-custom-header" - xCustomHeaderValue := "x-custom-header-value" - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set(xCustomHeader, xCustomHeaderValue) - - // Send test responses - dataBytes := []byte( - `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, - ) - dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - - stream, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - }, - ) - a.NoError(err, "CreateCompletionStream returned error") - defer stream.Close() - - value := stream.Header.Get(xCustomHeader) - if value != xCustomHeaderValue { - t.Errorf("expected %s to be %s", xCustomHeaderValue, value) - } -} - -func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { - client, server, teardown := setupGroqTestServer() - a := assert.New(t) - rateLimitHeaders := map[string]interface{}{ - "x-ratelimit-limit-requests": 100, - "x-ratelimit-limit-tokens": 1000, - "x-ratelimit-remaining-requests": 99, - "x-ratelimit-remaining-tokens": 999, - "x-ratelimit-reset-requests": "1s", - "x-ratelimit-reset-tokens": "1m", - } - defer teardown() - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - for k, v := range rateLimitHeaders { - switch val := v.(type) { - case int: - w.Header().Set(k, strconv.Itoa(val)) - default: - w.Header().Set(k, fmt.Sprintf("%s", v)) - } - } - - // Send test responses - dataBytes := []byte( - `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, - ) - dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - - stream, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - }, - ) - a.NoError(err, "CreateCompletionStream returned error") - defer stream.Close() - - headers := newRateLimitHeaders(stream.Header) - bs1, _ := json.Marshal(headers) - bs2, _ := json.Marshal(rateLimitHeaders) - if string(bs1) != string(bs2) { - t.Errorf("expected rate limit header %s to be %s", bs2, bs1) - } -} - -// newRateLimitHeaders creates a new RateLimitHeaders from an http.Header. -func newRateLimitHeaders(h http.Header) groq.RateLimitHeaders { - limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) - limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) - remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) - remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) - return groq.RateLimitHeaders{ - LimitRequests: limitReq, - LimitTokens: limitTokens, - RemainingRequests: remainingReq, - RemainingTokens: remainingTokens, - ResetRequests: groq.ResetTime(h.Get("x-ratelimit-reset-requests")), - ResetTokens: groq.ResetTime(h.Get("x-ratelimit-reset-tokens")), - } -} - -func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Send test responses - dataBytes := []byte( - `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, - ) - dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - - stream, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - }, - ) - a.NoError(err, "CreateCompletionStream returned error") - defer stream.Close() - - _, streamErr := stream.Recv() - a.Error(streamErr, "stream.Recv() did not return error") - - var apiErr *groq.APIError - if !errors.As(streamErr, &apiErr) { - t.Errorf("stream.Recv() did not return APIError") - } - t.Logf("%+v\n", apiErr) -} - -func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(429) - - // Send test responses - dataBytes := []byte(`{"error":{` + - `"message": "You are sending requests too quickly.",` + - `"type":"rate_limit_reached",` + - `"param":null,` + - `"code":"rate_limit_reached"}}`) - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - _, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - }, - ) - var apiErr *groq.APIError - if !errors.As(err, &apiErr) { - t.Errorf( - "TestCreateChatCompletionStreamRateLimitError did not return APIError", - ) - } - t.Logf("%+v\n", apiErr) -} - -func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - - server.RegisterHandler( - "/v1/chat/completions", - func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - // Send test responses - var dataBytes []byte - data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - data = `{"id":"3","object":"completion","created":1598069256,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` - dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) - - dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) - - _, err := w.Write(dataBytes) - a.NoError(err, "Write error") - }, - ) - - stream, err := client.CreateChatCompletionStream( - context.Background(), - groq.ChatCompletionRequest{ - MaxTokens: 5, - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - Content: "Hello!", - }, - }, - Stream: true, - StreamOptions: &groq.StreamOptions{ - IncludeUsage: true, - }, - }, - ) - a.NoError(err, "CreateCompletionStream returned error") - defer stream.Close() - expectedResponses := []groq.ChatCompletionStreamResponse{ - { - ID: "1", - Object: "completion", - Created: 1598069254, - Model: groq.Llama38B8192, - SystemFingerprint: "fp_d9767fc5b9", - Choices: []groq.ChatCompletionStreamChoice{ - { - Delta: groq.ChatCompletionStreamChoiceDelta{ - Content: "response1", - }, - FinishReason: "max_tokens", - }, - }, - }, - { - ID: "2", - Object: "completion", - Created: 1598069255, - Model: groq.Llama38B8192, - SystemFingerprint: "fp_d9767fc5b9", - Choices: []groq.ChatCompletionStreamChoice{ - { - Delta: groq.ChatCompletionStreamChoiceDelta{ - Content: "response2", - }, - FinishReason: "max_tokens", - }, - }, - }, - { - ID: "3", - Object: "completion", - Created: 1598069256, - Model: groq.Llama38B8192, - SystemFingerprint: "fp_d9767fc5b9", - Choices: []groq.ChatCompletionStreamChoice{}, - Usage: &groq.Usage{ - PromptTokens: 1, - CompletionTokens: 1, - TotalTokens: 2, - }, - }, - } - - for ix, expectedResponse := range expectedResponses { - ix++ - b, _ := json.Marshal(expectedResponse) - t.Logf("%d: %s", ix, string(b)) - - receivedResponse, streamErr := stream.Recv() - if !errors.Is(streamErr, io.EOF) { - a.NoError(streamErr, "stream.Recv() failed") - } - if !compareChatResponses(t, expectedResponse, receivedResponse) { - t.Errorf( - "Stream response %v: %v,BUT expected %v", - ix, - receivedResponse, - expectedResponse, - ) - } - } - - _, streamErr := stream.Recv() - if !errors.Is(streamErr, io.EOF) { - t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) - } - - _, streamErr = stream.Recv() - - a.ErrorIs( - streamErr, - io.EOF, - "stream.Recv() did not return EOF when the stream is finished", - ) - if !errors.Is(streamErr, io.EOF) { - t.Errorf( - "stream.Recv() did not return EOF when the stream is finished: %v", - streamErr, - ) - } -} - -// Helper funcs. -func compareChatResponses( - t *testing.T, - r1, r2 groq.ChatCompletionStreamResponse, -) bool { - if r1.ID != r2.ID { - t.Logf("Not Equal ID: %v", r1.ID) - return false - } - if r1.Object != r2.Object { - t.Logf("Not Equal Object: %v", r1.Object) - return false - } - if r1.Created != r2.Created { - t.Logf("Not Equal Created: %v", r1.Created) - return false - } - if len(r1.Choices) != len(r2.Choices) { - t.Logf("Not Equal Choices: %v", r1.Choices) - return false - } - for i := range r1.Choices { - if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { - t.Logf("Not Equal Choices: %v", r1.Choices[i]) - return false - } - } - if r1.Usage != nil || r2.Usage != nil { - if r1.Usage == nil || r2.Usage == nil { - return false - } - if r1.Usage.PromptTokens != r2.Usage.PromptTokens || - r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || - r1.Usage.TotalTokens != r2.Usage.TotalTokens { - return false - } - } - return true -} - -func compareChatStreamResponseChoices( - c1, c2 groq.ChatCompletionStreamChoice, -) bool { - if c1.Index != c2.Index { - return false - } - if c1.Delta.Content != c2.Delta.Content { - return false - } - if c1.FinishReason != c2.FinishReason { - return false - } - return true -} +package groq diff --git a/chat_unit_test.go b/chat_unit_test.go new file mode 100644 index 0000000..26642ae --- /dev/null +++ b/chat_unit_test.go @@ -0,0 +1,594 @@ +package groq_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "testing" + + groq "github.com/conneroisu/groq-go" + "github.com/stretchr/testify/assert" +) + +// TestCreateChatCompletionStream tests the CreateChatCompletionStream method. +func TestCreateChatCompletionStream(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataBytes = append(dataBytes, []byte("event: message\n")...) + data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: message\n")...) + data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("event: done\n")...) + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + + stream, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }, + ) + a.NoError(err, "CreateCompletionStream returned error") + defer stream.Close() + + expectedResponses := []groq.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: groq.Llama38B8192, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []groq.ChatCompletionStreamChoice{ + { + Delta: groq.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: groq.Llama38B8192, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []groq.ChatCompletionStreamChoice{ + { + Delta: groq.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + a.NoError(streamErr, "stream.Recv() failed") + if !compareChatResponses(t, expectedResponse, receivedResponse) { + t.Errorf( + "Stream response %v is %v, expected %v", + ix, + receivedResponse, + expectedResponse, + ) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + a.ErrorIs( + streamErr, + io.EOF, + "stream.Recv() did not return EOF when the stream is finished", + ) + if !errors.Is(streamErr, io.EOF) { + t.Errorf( + "stream.Recv() did not return EOF when the stream is finished: %v", + streamErr, + ) + } +} + +// TestCreateChatCompletionStreamError tests the CreateChatCompletionStream function with an error +// in the response. +func TestCreateChatCompletionStreamError(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte{} + dataStr := []string{ + `{`, + `"error": {`, + `"message": "Incorrect API key provided: gsk-***************************************",`, + `"type": "invalid_request_error",`, + `"param": null,`, + `"code": "invalid_api_key"`, + `}`, + `}`, + } + for _, str := range dataStr { + dataBytes = append(dataBytes, []byte(str+"\n")...) + } + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + + stream, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }, + ) + a.NoError(err, "CreateCompletionStream returned error") + defer stream.Close() + + _, streamErr := stream.Recv() + a.Error(streamErr, "stream.Recv() did not return error") + + var apiErr *groq.APIError + if !errors.As(streamErr, &apiErr) { + t.Errorf("stream.Recv() did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + +func TestCreateChatCompletionStreamWithHeaders(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + xCustomHeader := "x-custom-header" + xCustomHeaderValue := "x-custom-header-value" + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set(xCustomHeader, xCustomHeaderValue) + + // Send test responses + dataBytes := []byte( + `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, + ) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + + stream, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }, + ) + a.NoError(err, "CreateCompletionStream returned error") + defer stream.Close() + + value := stream.Header.Get(xCustomHeader) + if value != xCustomHeaderValue { + t.Errorf("expected %s to be %s", xCustomHeaderValue, value) + } +} + +func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) { + client, server, teardown := setupGroqTestServer() + a := assert.New(t) + rateLimitHeaders := map[string]interface{}{ + "x-ratelimit-limit-requests": 100, + "x-ratelimit-limit-tokens": 1000, + "x-ratelimit-remaining-requests": 99, + "x-ratelimit-remaining-tokens": 999, + "x-ratelimit-reset-requests": "1s", + "x-ratelimit-reset-tokens": "1m", + } + defer teardown() + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + for k, v := range rateLimitHeaders { + switch val := v.(type) { + case int: + w.Header().Set(k, strconv.Itoa(val)) + default: + w.Header().Set(k, fmt.Sprintf("%s", v)) + } + } + + // Send test responses + dataBytes := []byte( + `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, + ) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + + stream, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }, + ) + a.NoError(err, "CreateCompletionStream returned error") + defer stream.Close() + + headers := newRateLimitHeaders(stream.Header) + bs1, _ := json.Marshal(headers) + bs2, _ := json.Marshal(rateLimitHeaders) + if string(bs1) != string(bs2) { + t.Errorf("expected rate limit header %s to be %s", bs2, bs1) + } +} + +// newRateLimitHeaders creates a new RateLimitHeaders from an http.Header. +func newRateLimitHeaders(h http.Header) groq.RateLimitHeaders { + limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests")) + limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens")) + remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests")) + remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens")) + return groq.RateLimitHeaders{ + LimitRequests: limitReq, + LimitTokens: limitTokens, + RemainingRequests: remainingReq, + RemainingTokens: remainingTokens, + ResetRequests: groq.ResetTime(h.Get("x-ratelimit-reset-requests")), + ResetTokens: groq.ResetTime(h.Get("x-ratelimit-reset-tokens")), + } +} + +func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + dataBytes := []byte( + `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`, + ) + dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + + stream, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }, + ) + a.NoError(err, "CreateCompletionStream returned error") + defer stream.Close() + + _, streamErr := stream.Recv() + a.Error(streamErr, "stream.Recv() did not return error") + + var apiErr *groq.APIError + if !errors.As(streamErr, &apiErr) { + t.Errorf("stream.Recv() did not return APIError") + } + t.Logf("%+v\n", apiErr) +} + +func TestCreateChatCompletionStreamRateLimitError(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(429) + + // Send test responses + dataBytes := []byte(`{"error":{` + + `"message": "You are sending requests too quickly.",` + + `"type":"rate_limit_reached",` + + `"param":null,` + + `"code":"rate_limit_reached"}}`) + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + _, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + }, + ) + var apiErr *groq.APIError + if !errors.As(err, &apiErr) { + t.Errorf( + "TestCreateChatCompletionStreamRateLimitError did not return APIError", + ) + } + t.Logf("%+v\n", apiErr) +} + +func TestCreateChatCompletionStreamStreamOptions(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + + server.RegisterHandler( + "/v1/chat/completions", + func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + // Send test responses + var dataBytes []byte + data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + data = `{"id":"3","object":"completion","created":1598069256,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}` + dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...) + + dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...) + + _, err := w.Write(dataBytes) + a.NoError(err, "Write error") + }, + ) + + stream, err := client.CreateChatCompletionStream( + context.Background(), + groq.ChatCompletionRequest{ + MaxTokens: 5, + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "Hello!", + }, + }, + Stream: true, + StreamOptions: &groq.StreamOptions{ + IncludeUsage: true, + }, + }, + ) + a.NoError(err, "CreateCompletionStream returned error") + defer stream.Close() + expectedResponses := []groq.ChatCompletionStreamResponse{ + { + ID: "1", + Object: "completion", + Created: 1598069254, + Model: groq.Llama38B8192, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []groq.ChatCompletionStreamChoice{ + { + Delta: groq.ChatCompletionStreamChoiceDelta{ + Content: "response1", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "2", + Object: "completion", + Created: 1598069255, + Model: groq.Llama38B8192, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []groq.ChatCompletionStreamChoice{ + { + Delta: groq.ChatCompletionStreamChoiceDelta{ + Content: "response2", + }, + FinishReason: "max_tokens", + }, + }, + }, + { + ID: "3", + Object: "completion", + Created: 1598069256, + Model: groq.Llama38B8192, + SystemFingerprint: "fp_d9767fc5b9", + Choices: []groq.ChatCompletionStreamChoice{}, + Usage: &groq.Usage{ + PromptTokens: 1, + CompletionTokens: 1, + TotalTokens: 2, + }, + }, + } + + for ix, expectedResponse := range expectedResponses { + ix++ + b, _ := json.Marshal(expectedResponse) + t.Logf("%d: %s", ix, string(b)) + + receivedResponse, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + a.NoError(streamErr, "stream.Recv() failed") + } + if !compareChatResponses(t, expectedResponse, receivedResponse) { + t.Errorf( + "Stream response %v: %v,BUT expected %v", + ix, + receivedResponse, + expectedResponse, + ) + } + } + + _, streamErr := stream.Recv() + if !errors.Is(streamErr, io.EOF) { + t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr) + } + + _, streamErr = stream.Recv() + + a.ErrorIs( + streamErr, + io.EOF, + "stream.Recv() did not return EOF when the stream is finished", + ) + if !errors.Is(streamErr, io.EOF) { + t.Errorf( + "stream.Recv() did not return EOF when the stream is finished: %v", + streamErr, + ) + } +} + +// Helper funcs. +func compareChatResponses( + t *testing.T, + r1, r2 groq.ChatCompletionStreamResponse, +) bool { + if r1.ID != r2.ID { + t.Logf("Not Equal ID: %v", r1.ID) + return false + } + if r1.Object != r2.Object { + t.Logf("Not Equal Object: %v", r1.Object) + return false + } + if r1.Created != r2.Created { + t.Logf("Not Equal Created: %v", r1.Created) + return false + } + if len(r1.Choices) != len(r2.Choices) { + t.Logf("Not Equal Choices: %v", r1.Choices) + return false + } + for i := range r1.Choices { + if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) { + t.Logf("Not Equal Choices: %v", r1.Choices[i]) + return false + } + } + if r1.Usage != nil || r2.Usage != nil { + if r1.Usage == nil || r2.Usage == nil { + return false + } + if r1.Usage.PromptTokens != r2.Usage.PromptTokens || + r1.Usage.CompletionTokens != r2.Usage.CompletionTokens || + r1.Usage.TotalTokens != r2.Usage.TotalTokens { + return false + } + } + return true +} + +func compareChatStreamResponseChoices( + c1, c2 groq.ChatCompletionStreamChoice, +) bool { + if c1.Index != c2.Index { + return false + } + if c1.Delta.Content != c2.Delta.Content { + return false + } + if c1.FinishReason != c2.FinishReason { + return false + } + return true +} diff --git a/client_test.go b/client_test.go index 763aef6..07a7708 100644 --- a/client_test.go +++ b/client_test.go @@ -1,36 +1,4 @@ -package groq_test +//go:build !test +// +build !test -import ( - "log" - "testing" - - groq "github.com/conneroisu/groq-go" - "github.com/conneroisu/groq-go/internal/test" - "github.com/stretchr/testify/assert" -) - -func setupGroqTestServer() ( - client *groq.Client, - server *test.ServerTest, - teardown func(), -) { - server = test.NewTestServer() - ts := server.GroqTestServer() - ts.Start() - teardown = ts.Close - client, err := groq.NewClient( - test.GetTestToken(), - groq.WithBaseURL(ts.URL+"/v1"), - ) - if err != nil { - log.Fatal(err) - } - return -} - -func TestEmptyKeyClientCreation(t *testing.T) { - client, err := groq.NewClient("") - a := assert.New(t) - a.Error(err, "NewClient should return error") - a.Nil(client, "NewClient should return nil") -} +package groq diff --git a/client_unit_test.go b/client_unit_test.go new file mode 100644 index 0000000..763aef6 --- /dev/null +++ b/client_unit_test.go @@ -0,0 +1,36 @@ +package groq_test + +import ( + "log" + "testing" + + groq "github.com/conneroisu/groq-go" + "github.com/conneroisu/groq-go/internal/test" + "github.com/stretchr/testify/assert" +) + +func setupGroqTestServer() ( + client *groq.Client, + server *test.ServerTest, + teardown func(), +) { + server = test.NewTestServer() + ts := server.GroqTestServer() + ts.Start() + teardown = ts.Close + client, err := groq.NewClient( + test.GetTestToken(), + groq.WithBaseURL(ts.URL+"/v1"), + ) + if err != nil { + log.Fatal(err) + } + return +} + +func TestEmptyKeyClientCreation(t *testing.T) { + client, err := groq.NewClient("") + a := assert.New(t) + a.Error(err, "NewClient should return error") + a.Nil(client, "NewClient should return nil") +} diff --git a/cmd/models/main.go b/cmd/models/main.go index 195c7f3..7332025 100644 --- a/cmd/models/main.go +++ b/cmd/models/main.go @@ -14,6 +14,7 @@ import ( "log" "net/http" "os" + "sort" "text/template" "time" @@ -24,7 +25,7 @@ const ( outputFile = "models.go" ) -//go:embed models.go.txt +//go:embed models.go.tmpl var outputFileTemplate string type templateParams struct { @@ -135,13 +136,13 @@ func fillTemplate(w io.Writer, models []ResponseModel) error { } return false }, - // isModerationModel returns true if the model is + // notModerationModel returns false if the model is // a model that can be used for moderation. // // llama-guard-3-8b is a moderation model. - "isModerationModel": func(model ResponseModel) bool { + "notModerationModel": func(model ResponseModel) bool { // if the id of the model is llama-guard-3-8b - return model.ID == "llama-guard-3-8b" + return model.ID != "llama-guard-3-8b" }, // getCurrentDate returns the current date in the format // "2006-01-02 15:04:05". @@ -149,7 +150,7 @@ func fillTemplate(w io.Writer, models []ResponseModel) error { return time.Now().Format("2006-01-02 15:04:05") }, } - tmpla, err := template.New("output"). + tmpla, err := template.New("models"). Funcs(funcMap). Parse(outputFileTemplate) if err != nil { @@ -168,4 +169,8 @@ func nameModels(models []ResponseModel) { models[i].Name = lo.PascalCase(models[i].ID) } } + // sort models by name alphabetically + sort.Slice(models, func(i, j int) bool { + return models[i].Name < models[j].Name + }) } diff --git a/cmd/models/models.go.txt b/cmd/models/models.go.tmpl similarity index 62% rename from cmd/models/models.go.txt rename to cmd/models/models.go.tmpl index 2cbf2ef..2eaf968 100644 --- a/cmd/models/models.go.txt +++ b/cmd/models/models.go.tmpl @@ -1,3 +1,4 @@ +{{ define "models" }} // Code generated by groq-modeler DO NOT EDIT. // // Created at: {{ getCurrentDate }} @@ -44,7 +45,7 @@ const ( TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" // TranscriptionTimestampGranularitySegment is the segment timestamp granularity. {{- range $model := .Models }} - {{ $model.Name }} Model = "{{ $model.ID }}" // {{ $model.Name }} is an AI {{if isTextModel $model}}text{{else if isAudioModel $model}}audio{{else if isModerationModel $model}}moderation{{end}} model provided by {{$model.OwnedBy}}. It has {{$model.ContextWindow}} context window. + {{ $model.Name }} Model = "{{ $model.ID }}" // {{ $model.Name }} is an AI {{if isTextModel $model}}text{{else if isAudioModel $model}}audio{{else if notModerationModel $model}}moderation{{end}} model provided by {{$model.OwnedBy}}. It has {{$model.ContextWindow}} context window. {{- end }} ) @@ -66,7 +67,7 @@ var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{ {{ $model.Name }}: true, {{- end }} {{- end }} }, moderationsSuffix: { - {{- range $model := .Models }} {{ if isModerationModel $model }} + {{- range $model := .Models }} {{ if notModerationModel $model }} {{ $model.Name }}: true, {{- end }} {{- end }} }, } @@ -74,3 +75,56 @@ var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{ func endpointSupportsModel(endpoint Endpoint, model Model) bool { return !disabledModelsForEndpoints[endpoint][model] } +{{ end }} + +{{ define "test" }} +package groq_test + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +{{- range $model := .TextModels }} +// Test{{ $model.Name }} tests the {{ $model.Name }} model. +// It ensures that the model is supported by the groq-go library, the groq API, +// and the operations are working as expected for the specific model type. +func Test{{ $model.Name }}(t *testing.T) { +} +{{- end }} + +{{- range $model := .AudioModels }} +// Test{{ $model.Name }} tests the {{ $model.Name }} model. +// It ensures that the model is supported by the groq-go library, the groq API, +// and the operations are working as expected for the specific model type. +func Test{{ $model.Name }}(t *testing.T) { +} +{{- end }} + +{{- range $model := .ModerationModels }} +// Test{{ $model.Name }} tests the {{ $model.Name }} model. +// It ensures that the model is supported by the groq-go library, the groq API, +// and the operations are working as expected for the specific model type. +func Test{{ $model.Name }}(t *testing.T) { +} +{{- end }} + +{{- range $model := .MultiModalModels }} +// Test{{ $model.Name }} tests the {{ $model.Name }} model. +// It ensures that the model is supported by the groq-go library, the groq API, +// and the operations are working as expected for the specific model type. +func Test{{ $model.Name }}(t *testing.T) { +} +{{- end }} +{{ end }} diff --git a/completion_test.go b/completion_test.go index 915e6b0..07a7708 100644 --- a/completion_test.go +++ b/completion_test.go @@ -1,150 +1,4 @@ -package groq_test +//go:build !test +// +build !test -import ( - "context" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "testing" - "time" - - "github.com/conneroisu/groq-go" - "github.com/stretchr/testify/assert" -) - -/* // TestCompletionsWrongModel tests the CreateCompletion method with a wrong model. */ -/* func TestCompletionsWrongModel(t *testing.T) { */ -/* a := assert.New(t) */ -/* client, err := groq.NewClient( */ -/* "whatever", */ -/* groq.WithBaseURL("http://localhost/v1"), */ -/* ) */ -/* a.NoError(err, "NewClient error") */ -/* */ -/* _, err = client.CreateCompletion( */ -/* context.Background(), */ -/* groq.CompletionRequest{ */ -/* MaxTokens: 5, */ -/* Model: groq.Whisper_Large_V3, */ -/* }, */ -/* ) */ -/* if !errors.Is(err, groq.ErrCompletionUnsupportedModel{Model: groq.Whisper_Large_V3}) { */ -/* t.Fatalf( */ -/* "CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", */ -/* err, */ -/* ) */ -/* } */ -/* } */ - -// TestCompletionWithStream tests the CreateCompletion method with a stream. -func TestCompletionWithStream(t *testing.T) { - a := assert.New(t) - client, err := groq.NewClient( - "whatever", - groq.WithBaseURL("http://localhost/v1"), - ) - a.NoError(err, "NewClient error") - - ctx := context.Background() - req := groq.CompletionRequest{Stream: true} - _, err = client.CreateCompletion(ctx, req) - if !errors.Is(err, groq.ErrCompletionStreamNotSupported{}) { - t.Fatalf( - "CreateCompletion didn't return ErrCompletionStreamNotSupported", - ) - } -} - -// TestCompletions Tests the completions endpoint of the API using the mocked server. -func TestCompletions(t *testing.T) { - a := assert.New(t) - client, server, teardown := setupGroqTestServer() - defer teardown() - server.RegisterHandler("/v1/completions", handleCompletionEndpoint) - req := groq.CompletionRequest{ - MaxTokens: 5, - Model: "ada", - Prompt: "Lorem ipsum", - } - _, err := client.CreateCompletion(context.Background(), req) - a.NoError(err, "CreateCompletion error") -} - -// handleCompletionEndpoint Handles the completion endpoint by the test server. -func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // completions only accepts POST requests - if r.Method != "POST" { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - } - var completionReq groq.CompletionRequest - if completionReq, err = getCompletionBody(r); err != nil { - http.Error(w, "could not read request", http.StatusInternalServerError) - return - } - res := groq.CompletionResponse{ - ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", - Created: time.Now().Unix(), - // would be nice to validate Model during testing, but - // this may not be possible with how much upkeep - // would be required / wouldn't make much sense - Model: completionReq.Model, - } - // create completions - n := completionReq.N - if n == 0 { - n = 1 - } - for i := 0; i < n; i++ { - // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) - if completionReq.Echo { - completionStr = completionReq.Prompt.(string) + completionStr - } - res.Choices = append(res.Choices, groq.CompletionChoice{ - Text: completionStr, - Index: i, - }) - } - inputTokens := numTokens(completionReq.Prompt.(string)) * n - completionTokens := completionReq.MaxTokens * n - res.Usage = groq.Usage{ - PromptTokens: inputTokens, - CompletionTokens: completionTokens, - TotalTokens: inputTokens + completionTokens, - } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) -} - -// numTokens Returns the number of GPT-3 encoded tokens in the given text. -// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer -// -// TODO: implement an actual tokenizer for each model available and use that -// instead. -func numTokens(s string) int { - return int(float32(len(s)) / 4) -} - -// getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (groq.CompletionRequest, error) { - completion := groq.CompletionRequest{} - // read the request body - reqBody, err := io.ReadAll(r.Body) - if err != nil { - return groq.CompletionRequest{}, err - } - err = json.Unmarshal(reqBody, &completion) - if err != nil { - return groq.CompletionRequest{}, err - } - return completion, nil -} +package groq diff --git a/completion_unit_test.go b/completion_unit_test.go new file mode 100644 index 0000000..915e6b0 --- /dev/null +++ b/completion_unit_test.go @@ -0,0 +1,150 @@ +package groq_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/conneroisu/groq-go" + "github.com/stretchr/testify/assert" +) + +/* // TestCompletionsWrongModel tests the CreateCompletion method with a wrong model. */ +/* func TestCompletionsWrongModel(t *testing.T) { */ +/* a := assert.New(t) */ +/* client, err := groq.NewClient( */ +/* "whatever", */ +/* groq.WithBaseURL("http://localhost/v1"), */ +/* ) */ +/* a.NoError(err, "NewClient error") */ +/* */ +/* _, err = client.CreateCompletion( */ +/* context.Background(), */ +/* groq.CompletionRequest{ */ +/* MaxTokens: 5, */ +/* Model: groq.Whisper_Large_V3, */ +/* }, */ +/* ) */ +/* if !errors.Is(err, groq.ErrCompletionUnsupportedModel{Model: groq.Whisper_Large_V3}) { */ +/* t.Fatalf( */ +/* "CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", */ +/* err, */ +/* ) */ +/* } */ +/* } */ + +// TestCompletionWithStream tests the CreateCompletion method with a stream. +func TestCompletionWithStream(t *testing.T) { + a := assert.New(t) + client, err := groq.NewClient( + "whatever", + groq.WithBaseURL("http://localhost/v1"), + ) + a.NoError(err, "NewClient error") + + ctx := context.Background() + req := groq.CompletionRequest{Stream: true} + _, err = client.CreateCompletion(ctx, req) + if !errors.Is(err, groq.ErrCompletionStreamNotSupported{}) { + t.Fatalf( + "CreateCompletion didn't return ErrCompletionStreamNotSupported", + ) + } +} + +// TestCompletions Tests the completions endpoint of the API using the mocked server. +func TestCompletions(t *testing.T) { + a := assert.New(t) + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler("/v1/completions", handleCompletionEndpoint) + req := groq.CompletionRequest{ + MaxTokens: 5, + Model: "ada", + Prompt: "Lorem ipsum", + } + _, err := client.CreateCompletion(context.Background(), req) + a.NoError(err, "CreateCompletion error") +} + +// handleCompletionEndpoint Handles the completion endpoint by the test server. +func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { + var err error + var resBytes []byte + + // completions only accepts POST requests + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + var completionReq groq.CompletionRequest + if completionReq, err = getCompletionBody(r); err != nil { + http.Error(w, "could not read request", http.StatusInternalServerError) + return + } + res := groq.CompletionResponse{ + ID: strconv.Itoa(int(time.Now().Unix())), + Object: "test-object", + Created: time.Now().Unix(), + // would be nice to validate Model during testing, but + // this may not be possible with how much upkeep + // would be required / wouldn't make much sense + Model: completionReq.Model, + } + // create completions + n := completionReq.N + if n == 0 { + n = 1 + } + for i := 0; i < n; i++ { + // generate a random string of length completionReq.Length + completionStr := strings.Repeat("a", completionReq.MaxTokens) + if completionReq.Echo { + completionStr = completionReq.Prompt.(string) + completionStr + } + res.Choices = append(res.Choices, groq.CompletionChoice{ + Text: completionStr, + Index: i, + }) + } + inputTokens := numTokens(completionReq.Prompt.(string)) * n + completionTokens := completionReq.MaxTokens * n + res.Usage = groq.Usage{ + PromptTokens: inputTokens, + CompletionTokens: completionTokens, + TotalTokens: inputTokens + completionTokens, + } + resBytes, _ = json.Marshal(res) + fmt.Fprintln(w, string(resBytes)) +} + +// numTokens Returns the number of GPT-3 encoded tokens in the given text. +// This function approximates based on the rule of thumb stated by OpenAI: +// https://beta.openai.com/tokenizer +// +// TODO: implement an actual tokenizer for each model available and use that +// instead. +func numTokens(s string) int { + return int(float32(len(s)) / 4) +} + +// getCompletionBody Returns the body of the request to create a completion. +func getCompletionBody(r *http.Request) (groq.CompletionRequest, error) { + completion := groq.CompletionRequest{} + // read the request body + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return groq.CompletionRequest{}, err + } + err = json.Unmarshal(reqBody, &completion) + if err != nil { + return groq.CompletionRequest{}, err + } + return completion, nil +} diff --git a/errors.go b/errors.go index 32de196..d8580eb 100644 --- a/errors.go +++ b/errors.go @@ -22,7 +22,8 @@ type APIError struct { HTTPStatusCode int `json:"-"` // HTTPStatusCode is the status code of the error. } -// ErrChatCompletionInvalidModel is an error that occurs when the model is not supported with the CreateChatCompletion method. +// ErrChatCompletionInvalidModel is an error that occurs when the model is not +// supported with a specific endpoint. type ErrChatCompletionInvalidModel struct { Model Model Endpoint Endpoint @@ -38,7 +39,8 @@ func (e ErrChatCompletionInvalidModel) Error() string { Error() } -// ErrChatCompletionStreamNotSupported is an error that occurs when streaming is not supported with the CreateChatCompletionStream method. +// ErrChatCompletionStreamNotSupported is an error that occurs when streaming +// is not supported with the CreateChatCompletionStream method. type ErrChatCompletionStreamNotSupported struct { model Model } @@ -49,7 +51,8 @@ func (e ErrChatCompletionStreamNotSupported) Error() string { Error() } -// ErrContentFieldsMisused is an error that occurs when both Content and MultiContent properties are set. +// ErrContentFieldsMisused is an error that occurs when both Content and +// MultiContent properties are set. type ErrContentFieldsMisused struct { field string } @@ -70,7 +73,8 @@ func (e ErrCompletionUnsupportedModel) Error() string { Error() } -// ErrCompletionStreamNotSupported is an error that occurs when streaming is not supported with the CreateCompletionStream method. +// ErrCompletionStreamNotSupported is an error that occurs when streaming is +// not supported with the CreateCompletionStream method. type ErrCompletionStreamNotSupported struct{} // Error implements the error interface. @@ -79,7 +83,8 @@ func (e ErrCompletionStreamNotSupported) Error() string { Error() } -// ErrCompletionRequestPromptTypeNotSupported is an error that occurs when the type of CompletionRequest.Prompt only supports string and []string. +// ErrCompletionRequestPromptTypeNotSupported is an error that occurs when the +// type of CompletionRequest.Prompt only supports string and []string. type ErrCompletionRequestPromptTypeNotSupported struct{} // Error implements the error interface. @@ -88,7 +93,8 @@ func (e ErrCompletionRequestPromptTypeNotSupported) Error() string { Error() } -// ErrTooManyEmptyStreamMessages is returned when the stream has sent too many empty messages. +// ErrTooManyEmptyStreamMessages is returned when the stream has sent too many +// empty messages. type ErrTooManyEmptyStreamMessages struct{} // Error returns the error message. @@ -98,11 +104,11 @@ func (e ErrTooManyEmptyStreamMessages) Error() string { // errorAccumulator is an interface for accumulating errors type errorAccumulator interface { - // Write writes bytes to the error accumulator + // Write method writes bytes to the error accumulator // // It implements the io.Writer interface. Write(p []byte) error - // Bytes returns the bytes of the error accumulator. + // Bytes method returns the bytes of the error accumulator. Bytes() []byte } @@ -131,7 +137,7 @@ func newErrorAccumulator() errorAccumulator { } } -// Write writes bytes to the error accumulator. +// Write method writes bytes to the error accumulator. func (e *DefaultErrorAccumulator) Write(p []byte) error { _, err := e.Buffer.Write(p) if err != nil { @@ -140,7 +146,7 @@ func (e *DefaultErrorAccumulator) Write(p []byte) error { return nil } -// Bytes returns the bytes of the error accumulator. +// Bytes method returns the bytes of the error accumulator. func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { if e.Buffer.Len() == 0 { return @@ -149,7 +155,7 @@ func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) { return } -// Error implements the error interface. +// Error method implements the error interface on APIError. func (e *APIError) Error() string { if e.HTTPStatusCode > 0 { return fmt.Sprintf( diff --git a/examples/llava-blind/main.go b/examples/llava-blind/main.go index d526488..ffdbcd0 100644 --- a/examples/llava-blind/main.go +++ b/examples/llava-blind/main.go @@ -28,27 +28,30 @@ func run( return err } - response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{ - Model: groq.LlavaV157B4096Preview, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - MultiContent: []groq.ChatMessagePart{ - { - Type: groq.ChatMessagePartTypeText, - Text: "What is the contents of the image?", - }, - { - Type: groq.ChatMessagePartTypeImageURL, - ImageURL: &groq.ChatMessageImageURL{ - URL: "https://cdnimg.webstaurantstore.com/images/products/large/87539/251494.jpg", - Detail: "auto", + response, err := client.CreateChatCompletion( + ctx, + groq.ChatCompletionRequest{ + Model: groq.LlavaV157B4096Preview, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + MultiContent: []groq.ChatMessagePart{ + { + Type: groq.ChatMessagePartTypeText, + Text: "What is the contents of the image?", }, - }}, + { + Type: groq.ChatMessagePartTypeImageURL, + ImageURL: &groq.ChatMessageImageURL{ + URL: "https://cdnimg.webstaurantstore.com/images/products/large/87539/251494.jpg", + Detail: "auto", + }, + }}, + }, }, + MaxTokens: 2000, }, - MaxTokens: 2000, - }) + ) if err != nil { return err } diff --git a/examples/moderation/README.md b/examples/moderation/README.md new file mode 100644 index 0000000..3c9e9f3 --- /dev/null +++ b/examples/moderation/README.md @@ -0,0 +1,12 @@ +# moderation + +This is an example of using groq-go to create a chat moderation using the llama-3BGuard model. + +## Usage + +Make sure you have a groq key set in the environment variable `GROQ_KEY`. + +```bash +export GROQ_KEY=your-groq-key +go run . +``` diff --git a/examples/moderation/main.go b/examples/moderation/main.go new file mode 100644 index 0000000..8af45c5 --- /dev/null +++ b/examples/moderation/main.go @@ -0,0 +1,44 @@ +// Package main is an example of using groq-go to create a chat moderation +// using the llama-3BGuard model. +package main + +import ( + "context" + "fmt" + "os" + + "github.com/conneroisu/groq-go" +) + +func main() { + ctx := context.Background() + if err := run(ctx); err != nil { + fmt.Println(err) + os.Exit(1) + } +} + +func run( + ctx context.Context, +) error { + key := os.Getenv("GROQ_KEY") + client, err := groq.NewClient(key) + if err != nil { + return err + } + response, err := client.Moderate(ctx, groq.ModerationRequest{ + Model: groq.LlamaGuard38B, + // Input: "I want to kill them.", + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "I want to kill them.", + }, + }, + }) + if err != nil { + return err + } + fmt.Println(response.Categories) + return nil +} diff --git a/examples/vhdl-documentor-json/main.go b/examples/vhdl-documentor-json/main.go new file mode 100644 index 0000000..8bf757e --- /dev/null +++ b/examples/vhdl-documentor-json/main.go @@ -0,0 +1,382 @@ +// Package main is an example of using groq-go to create a chat completion +// using the llama-70B-tools-preview model to create headers for vhdl projects. +package main + +import ( + "bytes" + "context" + _ "embed" + "fmt" + "os" + "strings" + "text/template" + "time" + + "github.com/charmbracelet/log" + "github.com/conneroisu/groq-go" +) + +var ( + //go:embed template.tmpl + emTempl string + codeTemplate *template.Template + headerTemplate *template.Template +) + +func init() { + var err error + codeTemplate, err = template.New("code").Parse(emTempl) + if err != nil { + log.Fatal(err) + } + headerTemplate, err = template.New("header"). + Funcs(template.FuncMap{}). + Parse(emTempl) + if err != nil { + log.Fatal(err) + } +} +func main() { + log.SetLevel(log.DebugLevel) + log.SetReportCaller(true) + ctx := context.Background() + err := run(ctx, os.Getenv) + if err != nil { + fmt.Println(fmt.Errorf("failed to run: %w", err)) + } +} +func run( + ctx context.Context, + getenv func(string) string, +) error { + client, err := groq.NewClient(getenv("GROQ_KEY")) + if err != nil { + return err + } + log.Debugf("running with %s", getenv("GROQ_KEY")) + for _, val := range fileMap { + retry: + time.Sleep(6 * time.Second) + log.Debugf("processing %s", val.Destination) + filename := strings.Split(val.Destination, "/")[len(strings.Split(val.Destination, "/"))-1] + prompt, err := executeCodeTemplate(CodeTemplateData{ + Source: val.Source, + Name: filename, + Files: fileMap.ToFileArray(val.Destination), + }) + if err != nil { + return err + } + var thoughtThroughCode thoughtThroughCode + err = client.CreateChatCompletionJSON( + ctx, + groq.ChatCompletionRequest{ + Model: groq.Llama3Groq70B8192ToolUsePreview, + Messages: []groq.ChatCompletionMessage{{ + Role: groq.ChatMessageRoleSystem, + Content: prompt, + }}, + }, + &thoughtThroughCode, + ) + if err != nil { + goto retry + } + log.Debugf( + "thoughts for %s: %s", + val.Destination, + thoughtThroughCode.Thoughts, + ) + content, err := executeHeaderTemplate(HeaderTemplateData{ + FileName: filename, + Description: wrapText(thoughtThroughCode.Description), + Code: val.Source, + }) + if err != nil { + return err + } + log.Debugf("Creating file %s", val.Destination) + oF, err := os.Create(val.Destination) + if err != nil { + return err + } + defer oF.Close() + log.Debugf("Writing file %s", val.Destination) + _, err = oF.WriteString(content) + if err != nil { + return err + } + } + return nil +} + +// FileMapper is the map of a source file content to a output/report folder. +type FileMapper []struct { + Source string + Destination string +} + +const ( + // Destinations + muxDest = "./report/Mux/" + fullAdderDest = "./report/Adder/" + adderSubtractorDest = "./report/AddSub/" + nMuxDest = "./report/NMux/" + onesCompDest = "./report/OnesComp/" + tpuElementDest = "./report/MAC/" +) + +var ( + fileMap = FileMapper{ + {Source: tbMux2t1, Destination: muxDest + "tb_Mux2t1.vhd"}, + {Source: mux2t1, Destination: muxDest + "mux2t1.vhd"}, + {Source: mux2t1s, Destination: muxDest + "mux2t1s.vhd"}, + {Source: tbMux2t1s, Destination: muxDest + "tb_Mux2t1s.vhd"}, + {Source: mux2t1N, Destination: nMuxDest + "mux2t1_N.vhd"}, + {Source: tbMux2t1, Destination: nMuxDest + "tb_Mux2t1.vhd"}, + {Source: mux2t1, Destination: nMuxDest + "mux2t1.vhd"}, + {Source: tbNMux2t1, Destination: nMuxDest + "tb_NMux2t1.vhd"}, + {Source: fullAdder, Destination: fullAdderDest + "FullAdder.vhd"}, + {Source: tbFullAdder, Destination: fullAdderDest + "tb_FullAdder.vhd"}, + {Source: tbNBitAdder, Destination: fullAdderDest + "tb_NBitAdder.vhd"}, + {Source: nBitAdder, Destination: fullAdderDest + "nBitAdder.vhd"}, + { + Source: adderSubtractor, + Destination: adderSubtractorDest + "AdderSubtractor.vhd", + }, + { + Source: tbAdderSubtractor, + Destination: adderSubtractorDest + "tb_AdderSubtractor.vhd", + }, + { + Source: nBitInverter, + Destination: adderSubtractorDest + "nBitInverter.vhd", + }, + { + Source: tbNBitInverter, + Destination: adderSubtractorDest + "tb_nBitInverter.vhd", + }, + {Source: mux2t1N, Destination: adderSubtractorDest + "mux2t1_N.vhd"}, + { + Source: nBitAdder, + Destination: adderSubtractorDest + "nBitAdder.vhd", + }, + {Source: xorg2, Destination: onesCompDest + "xorg2.vhd"}, + {Source: org2, Destination: onesCompDest + "org2.vhd"}, + {Source: onesComp, Destination: onesCompDest + "OnesComp.vhd"}, + {Source: tbOnesComp, Destination: onesCompDest + "tb_OnesComp.vhd"}, + { + Source: tpuMvElement, + Destination: tpuElementDest + "tpuMvElement.vhd", + }, + {Source: xorg2, Destination: tpuElementDest + "xorg2.vhd"}, + {Source: org2, Destination: tpuElementDest + "org2.vhd"}, + { + Source: nBitInverter, + Destination: tpuElementDest + "nBitInverter.vhd", + }, + {Source: nBitAdder, Destination: tpuElementDest + "nBitAdder.vhd"}, + {Source: mux2t1, Destination: tpuElementDest + "mux2t1.vhd"}, + {Source: andg2, Destination: tpuElementDest + "andg2.vhd"}, + {Source: regLd, Destination: tpuElementDest + "regLd.vhd"}, + {Source: invg, Destination: tpuElementDest + "invg.vhd"}, + {Source: adder, Destination: tpuElementDest + "adder.vhd"}, + { + Source: adderSubtractor, + Destination: tpuElementDest + "adderSubtractor.vhd", + }, + {Source: fullAdder, Destination: tpuElementDest + "fullAdder.vhd"}, + {Source: multiplier, Destination: tpuElementDest + "multiplier.vhd"}, + {Source: regLd, Destination: tpuElementDest + "regLd.vhd"}, + {Source: reg, Destination: tpuElementDest + "reg.vhd"}, + { + Source: tbTPUElement, + Destination: tpuElementDest + "tb_TPUElement.vhd", + }, + } +) + +//go:embed src/Adder.vhd +var adder string + +//go:embed src/AdderSubtractor.vhd +var adderSubtractor string + +//go:embed src/FullAdder.vhd +var fullAdder string + +//go:embed src/Multiplier.vhd +var multiplier string + +//go:embed src/NBitAdder.vhd +var nBitAdder string + +//go:embed src/NBitInverter.vhd +var nBitInverter string + +//go:embed src/OnesComp.vhd +var onesComp string + +//go:embed src/Reg.vhd +var reg string + +//go:embed src/RegLd.vhd +var regLd string + +//go:embed src/TPU_MV_Element.vhd +var tpuMvElement string + +//go:embed src/andg2.vhd +var andg2 string + +//go:embed src/invg.vhd +var invg string + +//go:embed src/mux2t1.vhd +var mux2t1 string + +//go:embed src/mux2t1s.vhd +var mux2t1s string + +//go:embed src/mux2t1_N.vhd +var mux2t1N string + +//go:embed src/org2.vhd +var org2 string + +//go:embed src/xorg2.vhd +var xorg2 string + +//go:embed test/tb_NMux2t1.vhd +var tbNMux2t1 string + +//go:embed test/tb_NMux2t1.vhd +var tbMux2t1 string + +//go:embed test/tb_nBitAdder.vhd +var tbNBitAdder string + +//go:embed test/tb_mux2t1s.vhd +var tbMux2t1s string + +//go:embed test/tb_nBitInverter.vhd +var tbNBitInverter string + +//go:embed test/tb_AdderSubtractor.vhd +var tbAdderSubtractor string + +//go:embed test/tb_NFullAdder.vhd +var tbFullAdder string + +//go:embed test/tb_OnesComp.vhd +var tbOnesComp string + +//go:embed test/tb_TPU_MV_Element.vhd +var tbTPUElement string + +// File represents a single file with its name and content +type File struct { + Name string + Content string +} + +// ToFileArray converts the FileMapper to a File array +func (fM *FileMapper) ToFileArray(dest string) []File { + var files []File + for _, file := range *fM { + split := strings.Split(file.Destination, "/") + if strings.Contains(dest, "/"+split[2]+"/") { + log.Debugf("adding %s to files", file.Destination) + files = append(files, File{ + Name: file.Destination, + Content: file.Source, + }) + continue + } + } + log.Debugf("n-files: %v", len(files)) + return files +} + +// CodeTemplateData represents the data structure for the code template +type CodeTemplateData struct { + Source string + Name string + Files []File +} + +// HeaderTemplateData represents the data structure for the header template +type HeaderTemplateData struct { + FileName string + Description string + Code string +} +type thoughtThroughCode struct { + Thoughts string `json:"thoughts" jsonschema:"title=Thoughts,description=Thoughts on the code and thinking through exactly how it interacts with other given code in the project."` + Description string `json:"description" jsonschema:"title=Description,description=A description of the code's function, form, etc."` +} + +func executeCodeTemplate(data CodeTemplateData) (string, error) { + buf := new(bytes.Buffer) + err := codeTemplate.Execute(buf, data) + if err != nil { + return "", fmt.Errorf("failed to execute code template: %v", err) + } + return buf.String(), nil +} +func executeHeaderTemplate(data HeaderTemplateData) (string, error) { + var result strings.Builder + err := headerTemplate.Execute(&result, data) + if err != nil { + return "", fmt.Errorf("failed to execute header template: %v", err) + } + return result.String(), nil +} + +// wrapText trims the string to 80 characters per line, +// adding newlines and hyphens if it is longer. +func wrapText(s string) string { + var result strings.Builder + maxLineLength := 80 + words := strings.Fields(s) + lineLength := 0 + for i, word := range words { + wordLength := len(word) + if lineLength+wordLength > maxLineLength { + if wordLength > maxLineLength { + remaining := word + for len(remaining) > 0 { + spaceLeft := maxLineLength - lineLength + if spaceLeft <= 0 { + result.WriteString("\n-- ") + lineLength = 0 + spaceLeft = maxLineLength + } + take := min(spaceLeft, len(remaining)) + if take < len(remaining) { + // Add hyphen when breaking a word + result.WriteString(remaining[:take] + "-\n") + lineLength = 0 + } else { + result.WriteString(remaining[:take]) + lineLength += take + } + remaining = remaining[take:] + } + } else { + // Start a new line + result.WriteString("\n-- ") + result.WriteString(word) + lineLength = wordLength + } + } else { + if i > 0 { + result.WriteString(" ") + lineLength++ + } + result.WriteString(word) + lineLength += wordLength + } + } + return result.String() +} diff --git a/examples/vhdl-documentor-json/report/AddSub/.keep b/examples/vhdl-documentor-json/report/AddSub/.keep new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/report/Adder/.keep b/examples/vhdl-documentor-json/report/Adder/.keep new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/report/Mux/.keep b/examples/vhdl-documentor-json/report/Mux/.keep new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/report/Mux/tb_Mux2t1.vhd b/examples/vhdl-documentor-json/report/Mux/tb_Mux2t1.vhd new file mode 100644 index 0000000..9236b58 --- /dev/null +++ b/examples/vhdl-documentor-json/report/Mux/tb_Mux2t1.vhd @@ -0,0 +1,55 @@ +-------------------------------------------------------------------------------- +-- author: [ Conner Ohnesorge ](https://github.com/conneroisu) +-- file_name: tb_Mux2t1.vhd +-- desc: The file is a part of a larger project that includes a 2-to-1 multiplexer +-- component and its structural implementation. It is used to test the component's +-- behavior under different input conditions. +-------------------------------------------------------------------------------- + +library IEEE; +use IEEE.std_logic_1164.all; + +entity tb_NMux2t1 is end tb_NMux2t1; + +architecture behavior of tb_NMux2t1 is + component nmux2t1 + port ( + AMux : in std_logic; + BMux : in std_logic; + Sel : in std_logic; + Output : out std_logic + ); + + + end component; + + --Inputs + signal s_AMux : std_logic := '0'; + signal s_BMux : std_logic := '0'; + signal s_Out : std_logic := '0'; + + --Outputs + signal f : std_logic; + +begin + DUT0 : nmux2t1 + port map ( + AMux => s_AMux, + BMux => s_BMux, + Sel => s_Out, + Output => f + ); + -- Stimulus process + stim_proc : process + begin + s_AMux <= '0'; + s_BMux <= '0'; + s_Out <= '0'; + wait for 100 ns; + + assert f = '0' report "Test failed for s=0" severity error; + + wait; + end process stim_proc; +end behavior; + diff --git a/examples/vhdl-documentor-json/report/NMux/.keep b/examples/vhdl-documentor-json/report/NMux/.keep new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/report/OnesComp/.keep b/examples/vhdl-documentor-json/report/OnesComp/.keep new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/src/Adder.vhd b/examples/vhdl-documentor-json/src/Adder.vhd new file mode 100644 index 0000000..ccfbb0a --- /dev/null +++ b/examples/vhdl-documentor-json/src/Adder.vhd @@ -0,0 +1,26 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity Adder is + + port( + iCLK : in std_logic; + iA : in integer; + iB : in integer; + oC : out integer + ); + +end Adder; + +architecture behavior of Adder is +begin + + process(iCLK, iA, iB) + begin + if rising_edge(iCLK) then + oC <= iA + iB; + end if; + end process; + +end behavior; diff --git a/examples/vhdl-documentor-json/src/AdderSubtractor.vhd b/examples/vhdl-documentor-json/src/AdderSubtractor.vhd new file mode 100644 index 0000000..92a043f --- /dev/null +++ b/examples/vhdl-documentor-json/src/AdderSubtractor.vhd @@ -0,0 +1,78 @@ +library IEEE; + +use IEEE.STD_LOGIC_1164.all; + +entity AdderSubtractor is + + generic ( + N : integer := 5 + ); + port ( + A : in std_logic_vector (N-1 downto 0); + B : in std_logic_vector (N-1 downto 0); + nAdd_Sub : in std_logic; + Sum : out std_logic_vector (N-1 downto 0); + Carry : out std_logic_vector (N-1 downto 0) + ); + +end AdderSubtractor; + +architecture Structural of AdderSubtractor is + + component NBitInverter + port ( + Input : in std_logic_vector (N-1 downto 0); + Output : out std_logic_vector (N-1 downto 0) + ); + end component; + + component mux2t1_N + generic ( + N : integer := 4 + ); + port ( + i_D0 : in std_logic_vector (N-1 downto 0); + i_D1 : in std_logic_vector (N-1 downto 0); + i_S : in std_logic; + o_O : out std_logic_vector (N-1 downto 0) + ); + end component; + + component NBitAdder + port ( + A : in std_logic_vector (N-1 downto 0); + B : in std_logic_vector (N-1 downto 0); + Sum : out std_logic_vector (N-1 downto 0); + CarryOut : out std_logic_vector (N-1 downto 0) + ); + end component; + + signal s_inverted : std_logic_vector (N-1 downto 0); + signal s_muxed : std_logic_vector (N-1 downto 0); + +begin + + Inv : NBitInverter + port map ( + Input => B, + Output => s_inverted + ); + + Mux : mux2t1_N + port map ( + i_D0 => B, + i_D1 => s_inverted, + i_S => nAdd_Sub, + o_O => s_muxed + ); + + -- Instantiate the N-bit adder + Adder : NBitAdder + port map ( + A => A, + B => s_muxed, + Sum => Sum, + CarryOut => Carry + ); + +end Structural; diff --git a/examples/vhdl-documentor-json/src/FullAdder.vhd b/examples/vhdl-documentor-json/src/FullAdder.vhd new file mode 100644 index 0000000..5b6fe8c --- /dev/null +++ b/examples/vhdl-documentor-json/src/FullAdder.vhd @@ -0,0 +1,29 @@ +library IEEE; + +use IEEE.STD_LOGIC_1164.all; + +entity FullAdder is + + port ( + i_A : in std_logic; + i_B : in std_logic; + Cin : in std_logic; + Sum : out std_logic; + Cout : out std_logic + ); + +end FullAdder; + +architecture Structural of FullAdder is + signal s_Adder, s_C1, s_C2 : std_logic; +begin + + s_Adder <= i_A xor i_B; + s_C1 <= i_A and i_B; + + Sum <= s_Adder xor Cin; + s_C2 <= s_Adder and Cin; + + Cout <= s_C1 or s_C2; + +end Structural; diff --git a/examples/vhdl-documentor-json/src/Multiplier.vhd b/examples/vhdl-documentor-json/src/Multiplier.vhd new file mode 100644 index 0000000..52557be --- /dev/null +++ b/examples/vhdl-documentor-json/src/Multiplier.vhd @@ -0,0 +1,26 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity Multiplier is + + port( + iCLK : in std_logic; + i_A : in integer; + i_B : in integer; + o_P : out integer + ); + +end Multiplier; + +architecture behavior of Multiplier is +begin + + process(iCLK, i_A, i_B) + begin + if rising_edge(iCLK) then + o_P <= i_A * i_B; + end if; + end process; + +end behavior; diff --git a/examples/vhdl-documentor-json/src/NBitAdder.vhd b/examples/vhdl-documentor-json/src/NBitAdder.vhd new file mode 100644 index 0000000..64485ac --- /dev/null +++ b/examples/vhdl-documentor-json/src/NBitAdder.vhd @@ -0,0 +1,40 @@ +library IEEE; + +use IEEE.STD_LOGIC_1164.all; +use IEEE.NUMERIC_STD.all; +use work.FullAdder; + +entity NBitAdder is + generic ( + N : integer := 4 + ); + port ( + A : in std_logic_vector(N-1 downto 0); + B : in std_logic_vector(N-1 downto 0); + Sum : out std_logic_vector(N-1 downto 0); + CarryOut : out std_logic_vector(N-1 downto 0) + ); +end NBitAdder; + + +architecture Behavioral of NBitAdder is + signal carries : std_logic_vector(N downto 0); +begin + -- Initialize the carry-in for the first adder to 0 + carries(0) <= '0'; + + gen_full_adders : for i in 0 to N-1 generate + full_adder_inst : entity FullAdder + port map ( + i_A => A(i), + i_B => B(i), + Cin => carries(i), + Sum => Sum(i), + Cout => carries(i+1) + ); + end generate; + + -- The carry-out of the last full adder is the CarryOut of the N-bit adder + CarryOut <= carries(N-1 downto 0); + +end Behavioral; diff --git a/examples/vhdl-documentor-json/src/NBitInverter.vhd b/examples/vhdl-documentor-json/src/NBitInverter.vhd new file mode 100644 index 0000000..0741404 --- /dev/null +++ b/examples/vhdl-documentor-json/src/NBitInverter.vhd @@ -0,0 +1,31 @@ +library IEEE; + +use IEEE.STD_LOGIC_1164.all; + +entity NBitInverter is + + generic ( + N : integer := 4 -- N is the width of the input/output + ); + port ( + Input : in std_logic_vector(N-1 downto 0); -- Input vector + Output : out std_logic_vector(N-1 downto 0) -- Output vector + ); + +end NBitInverter; + +architecture Behavioral of NBitInverter is +begin + + Complement_Process : process(Input) + begin + for i in 0 to N-1 loop + if Input(i) = '1' then + Output(i) <= '0'; + else + Output(i) <= '1'; + end if; + end loop; + end process; + +end Behavioral; diff --git a/examples/vhdl-documentor-json/src/OnesComp.vhd b/examples/vhdl-documentor-json/src/OnesComp.vhd new file mode 100644 index 0000000..a6a3e75 --- /dev/null +++ b/examples/vhdl-documentor-json/src/OnesComp.vhd @@ -0,0 +1,27 @@ +library ieee; + +use ieee.STD_LOGIC_1164.all; + +entity OnesComp is + + generic ( + N : integer := 8 + ); + port ( + Input : in std_logic_vector(N-1 downto 0); + Output : out std_logic_vector(N-1 downto 0) + ); + +end OnesComp; + +architecture Behavioral of OnesComp is +begin + + Complement_Process : process(Input) + begin + for i in 0 to N-1 loop + Output(i) <= not Input(i); + end loop; + end process; + +end Behavioral; diff --git a/examples/vhdl-documentor-json/src/Reg.vhd b/examples/vhdl-documentor-json/src/Reg.vhd new file mode 100644 index 0000000..cb86e79 --- /dev/null +++ b/examples/vhdl-documentor-json/src/Reg.vhd @@ -0,0 +1,25 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity Reg is + + port( + iCLK : in std_logic; + iD : in integer; + oQ : out integer + ); + +end Reg; + +architecture behavior of Reg is +begin + + process(iCLK, iD) + begin + if rising_edge(iCLK) then + oQ <= iD; + end if; + end process; + +end behavior; diff --git a/examples/vhdl-documentor-json/src/RegLd.vhd b/examples/vhdl-documentor-json/src/RegLd.vhd new file mode 100644 index 0000000..f470274 --- /dev/null +++ b/examples/vhdl-documentor-json/src/RegLd.vhd @@ -0,0 +1,34 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity RegLd is + + port( + iCLK : in std_logic; + iD : in integer; + iLd : in integer; + oQ : out integer + ); + +end RegLd; + +architecture behavior of RegLd is + signal s_Q : integer; +begin + + + process(iCLK, iLd, iD) + begin + if rising_edge(iCLK) then + if (iLd = 1) then + s_Q <= iD; + else + s_Q <= s_Q; + end if; + end if; + end process; + + oQ <= s_Q; -- connect internal storage signal with final output + +end behavior; diff --git a/examples/vhdl-documentor-json/src/TPU_MV_Element.vhd b/examples/vhdl-documentor-json/src/TPU_MV_Element.vhd new file mode 100644 index 0000000..4e645cd --- /dev/null +++ b/examples/vhdl-documentor-json/src/TPU_MV_Element.vhd @@ -0,0 +1,105 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity TPU_MV_Element is + + port( + iCLK : in std_logic; + iX : in integer; + iW : in integer; + iLdW : in integer; + iY : in integer; + oY : out integer; + oX : out integer + ); + +end TPU_MV_Element; + +architecture structure of TPU_MV_Element is + + component Adder + port( + iCLK : in std_logic; + iA : in integer; + iB : in integer; + oC : out integer + ); + end component; + + component Multiplier + port( + iCLK : in std_logic; + iA : in integer; + iB : in integer; + oP : out integer + ); + end component; + + component Reg + port( + iCLK : in std_logic; + iD : in integer; + oQ : out integer + ); + end component; + + component RegLd + port( + iCLK : in std_logic; + iD : in integer; + iLd : in integer; + oQ : out integer + ); + end component; + + signal s_W : integer; + signal s_X1 : integer; + signal s_Y1 : integer; + signal s_WxX : integer; -- Signal to carry stored W*X + +begin + -- Level 0: Conditionally load new W + g_Weight : RegLd + port map( + iCLK => iCLK, + iD => iW, + iLd => iLdW, + oQ => s_W + ); + -- Level 1: Delay X and Y, calculate W*X + g_Delay1 : Reg + port map( + iCLK => iCLK, + iD => iX, + oQ => s_X1 + ); + g_Delay2 : Reg + port map( + iCLK => iCLK, + iD => iY, + oQ => s_Y1 + ); + g_Mult1 : Multiplier + port map( + iCLK => iCLK, + iA => iX, + iB => s_W, + oP => s_WxX + ); + -- Level 2: Delay X, calculate Y += W*X + g_Delay3 : Reg + port map( + iCLK => iCLK, + iD => s_X1, + oQ => oX + ); + g_Add1 : Adder + port map( + iCLK => iCLK, + iA => s_WxX, + iB => s_Y1, + oC => oY + ); + +end structure; diff --git a/examples/vhdl-documentor-json/src/andg2.vhd b/examples/vhdl-documentor-json/src/andg2.vhd new file mode 100644 index 0000000..d38cc09 --- /dev/null +++ b/examples/vhdl-documentor-json/src/andg2.vhd @@ -0,0 +1,20 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity andg2 is + + port( + i_A : in std_logic; + i_B : in std_logic; + o_F : out std_logic + ); + +end andg2; + +architecture dataflow of andg2 is +begin + + o_F <= i_A and i_B; + +end dataflow; diff --git a/examples/vhdl-documentor-json/src/invg.vhd b/examples/vhdl-documentor-json/src/invg.vhd new file mode 100644 index 0000000..a5dbd73 --- /dev/null +++ b/examples/vhdl-documentor-json/src/invg.vhd @@ -0,0 +1,19 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity invg is + + port( + i_A : in std_logic; + o_F : out std_logic + ); + +end invg; + +architecture dataflow of invg is +begin + + o_F <= not i_A; + +end dataflow; diff --git a/examples/vhdl-documentor-json/src/mux2t1.vhd b/examples/vhdl-documentor-json/src/mux2t1.vhd new file mode 100644 index 0000000..e4705fe --- /dev/null +++ b/examples/vhdl-documentor-json/src/mux2t1.vhd @@ -0,0 +1,26 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity mux2t1 is + + port ( + i_D0, i_D1, i_S : in std_logic; + o_O : out std_logic + ); + +end mux2t1; + +architecture behaviour of mux2t1 is +begin + + process (i_D0, i_D1, i_S) + begin + if i_s = '0' then + o_O <= i_D0; + else + o_O <= i_D1; + end if; + end process; + +end behaviour; diff --git a/examples/vhdl-documentor-json/src/mux2t1_N.vhd b/examples/vhdl-documentor-json/src/mux2t1_N.vhd new file mode 100644 index 0000000..48d1f2a --- /dev/null +++ b/examples/vhdl-documentor-json/src/mux2t1_N.vhd @@ -0,0 +1,41 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity mux2t1_N is + + generic( + N : integer := 16 + ); + port( + i_S : in std_logic; + i_D0 : in std_logic_vector(N-1 downto 0); + i_D1 : in std_logic_vector(N-1 downto 0); + o_O : out std_logic_vector(N-1 downto 0) + ); + +end mux2t1_N; + +architecture structural of mux2t1_N is + + component mux2t1 is + port( + i_S : in std_logic; + i_D0 : in std_logic; + i_D1 : in std_logic; + o_O : out std_logic + ); + end component; + +begin + + G_NBit_MUX : for i in 0 to N-1 generate + MUXI : mux2t1 port map( + i_S => i_S, -- All instances share the same select input. + i_D0 => i_D0(i), -- ith instance's data 0 input hooked up to ith data 0 input. + i_D1 => i_D1(i), -- ith instance's data 1 input hooked up to ith data 1 input. + o_O => o_O(i) -- ith instance's data output hooked up to ith data output. + ); + end generate G_NBit_MUX; + +end structural; diff --git a/examples/vhdl-documentor-json/src/mux2t1s.vhd b/examples/vhdl-documentor-json/src/mux2t1s.vhd new file mode 100644 index 0000000..527106a --- /dev/null +++ b/examples/vhdl-documentor-json/src/mux2t1s.vhd @@ -0,0 +1,82 @@ +library IEEE; +use IEEE.std_logic_1164.all; + +entity mux2t1s is + port ( + i_S : in std_logic; -- selector + i_D0 : in std_logic; -- data inputs + i_D1 : in std_logic; -- data inputs + o_O : out std_logic -- output + ); +end mux2t1s; + +architecture structure of mux2t1s is + + component andg2 is + port ( + i_A : in std_logic; -- input A to AND gate + i_B : in std_logic; -- input B to AND gate + o_F : out std_logic -- output of AND gate + ); + + end component; + + component org2 is + port ( + i_A : in std_logic; -- input A to OR gate + i_B : in std_logic; -- input B to OR gate + o_F : out std_logic -- output of OR gate + ); + + end component; + + component invg is + port ( + i_A : in std_logic; -- input to NOT gate + o_F : out std_logic -- output of NOT gate + ); + + end component; + + -- Signal to hold invert of the selector bit + signal s_inv_S1 : std_logic; + -- Signals to hold output valeus from 'AND' gates (needed to wire component to component?) + signal s_oX, s_oY : std_logic; + +begin + --------------------------------------------------------------------------- + -- Level 0: signals go through NOT stage + --------------------------------------------------------------------------- + invg1 : invg + port map( + i_A => i_S, -- input to NOT gate + o_F => s_inv_S1 -- output of NOT gate + ); + --------------------------------------------------------------------------- + -- Level 1: signals go through AND stage + --------------------------------------------------------------------------- + + and1 : andg2 + port map( + i_A => i_D0, -- input to AND gate + i_B => s_inv_S1, -- input to AND gate + o_F => s_oX -- output of AND gate + ); + + and2 : andg2 + port map( + i_A => i_D1, -- input to AND gate + i_B => i_S, -- input to AND gate + o_F => s_oY -- output of AND gate + ); + --------------------------------------------------------------------------- + -- Level 1: signals go through OR stage (and then output) + --------------------------------------------------------------------------- + + org1 : org2 + port map( + i_A => s_oX, -- input to OR gate + i_B => s_oY, -- input to OR gate + o_F => o_O -- output of OR gate + ); +end structure; diff --git a/examples/vhdl-documentor-json/src/org2.vhd b/examples/vhdl-documentor-json/src/org2.vhd new file mode 100644 index 0000000..0d9542d --- /dev/null +++ b/examples/vhdl-documentor-json/src/org2.vhd @@ -0,0 +1,20 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity org2 is + + port( + i_A : in std_logic; + i_B : in std_logic; + o_F : out std_logic + ); + +end org2; + +architecture dataflow of org2 is +begin + + o_F <= i_A or i_B; + +end dataflow; diff --git a/examples/vhdl-documentor-json/src/xorg2.vhd b/examples/vhdl-documentor-json/src/xorg2.vhd new file mode 100644 index 0000000..7afda5c --- /dev/null +++ b/examples/vhdl-documentor-json/src/xorg2.vhd @@ -0,0 +1,20 @@ +library IEEE; + +use IEEE.std_logic_1164.all; + +entity xorg2 is + + port( + i_A : in std_logic; + i_B : in std_logic; + o_F : out std_logic + ); + +end xorg2; + +architecture dataflow of xorg2 is +begin + + o_F <= i_A xor i_B; + +end dataflow; diff --git a/examples/vhdl-documentor-json/template.tmpl b/examples/vhdl-documentor-json/template.tmpl new file mode 100644 index 0000000..6834ffa --- /dev/null +++ b/examples/vhdl-documentor-json/template.tmpl @@ -0,0 +1,35 @@ +{{ define "code" }} +Write a decription of the first file inside a json response. + +Your response should be a just the json object with the following fields: +In your description, do not jump to conclusions, but instead provide a thoughtful description of the file's content. +If you refer to a file, use the entire path, including the file name. + + +{ + "thoughts": "Your thoughts on the task", + "description": "A description of the file in relation to the other files", +} + + +Your file is: + +{{.Source}} + + +Related files: +{{- range $file := .Files }} + +{{ $file.Content }} + +{{- end }} +{{ end }} + +{{ define "header" }}-------------------------------------------------------------------------------- +-- author: [ Conner Ohnesorge ](https://github.com/conneroisu) +-- file_name: {{ .FileName }} +-- desc: {{ .Description }} +-------------------------------------------------------------------------------- + +{{ .Code }} +{{ end }} diff --git a/examples/vhdl-documentor-json/test/tb_Adder.vhd b/examples/vhdl-documentor-json/test/tb_Adder.vhd new file mode 100644 index 0000000..dfb97bf --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_Adder.vhd @@ -0,0 +1,67 @@ +library ieee; + + +use IEEE.STD_LOGIC_1164.all; +use IEEE.STD_LOGIC_ARITH.all; +use IEEE.STD_LOGIC_UNSIGNED.all; +use IEEE.numeric_std.all; + +entity tb_Adder is end tb_Adder; + +architecture behavior of tb_Adder is + component Adder + generic ( + N : integer := 4 + ); + port ( + A : in std_logic_vector (N-1 downto 0); + B : in std_logic_vector (N-1 downto 0); + nAdd_Sub : in std_logic; + Sum : out std_logic_vector (N-1 downto 0); + Carry : out std_logic_vector (N-1 downto 0) + ); + end component; + + signal s_A : std_logic_vector (3 downto 0) := (others => '0'); + signal s_B : std_logic_vector (3 downto 0) := (others => '0'); + signal s_nAddSub : std_logic := '0'; + signal s_Carry : std_logic_vector (3-1 downto 0); + + signal s_Sum : std_logic_vector (3 downto 0); + +begin + DUT0 : Adder generic map (N => 4) + port map ( + A => s_A, + B => s_B, + nAdd_Sub => s_nAddSub, + Sum => s_Sum, + Carry => s_Carry + ); + + s_A <= "0011"; s_B <= "0101"; s_nAddSub <= '0'; + process + begin + wait for 10 ns; + + s_A <= "0110"; s_B <= "0011"; s_nAddSub <= '1'; + wait for 10 ns; + + s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '0'; -- Add zero + wait for 10 ns; + s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '1'; -- Subtract zero + wait for 10 ns; + s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '0'; -- Add max + wait for 10 ns; + s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '1'; -- Subtract max + wait for 10 ns; + + s_A <= "1111"; s_B <= "0001"; s_nAddSub <= '0'; -- Add overflow + wait for 10 ns; + s_A <= "0000"; s_B <= "0001"; s_nAddSub <= '1'; -- Subtract overflow + wait for 10 ns; + + assert false report "Testbench completed" severity note; + wait; + end process; +end; diff --git a/examples/vhdl-documentor-json/test/tb_AdderSubtractor.vhd b/examples/vhdl-documentor-json/test/tb_AdderSubtractor.vhd new file mode 100644 index 0000000..ce04ef2 --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_AdderSubtractor.vhd @@ -0,0 +1,68 @@ +library IEEE; + +use IEEE.STD_LOGIC_1164.all; +use IEEE.STD_LOGIC_ARITH.all; +use IEEE.STD_LOGIC_UNSIGNED.all; +use IEEE.numeric_std.all; + +entity tb_AdderSubtractor is + +end tb_AdderSubtractor; + +architecture behavior of tb_AdderSubtractor is + component AdderSubtractor + generic ( + N : integer := 4 + ); + port ( + A : in std_logic_vector (N-1 downto 0); + B : in std_logic_vector (N-1 downto 0); + nAdd_Sub : in std_logic; + Sum : out std_logic_vector (N-1 downto 0); + Carry : out std_logic_vector (N-1 downto 0) + ); + end component; + + signal s_A : std_logic_vector (3 downto 0) := (others => '0'); + signal s_B : std_logic_vector (3 downto 0) := (others => '0'); + signal s_nAddSub : std_logic := '0'; + + + signal s_Sum : std_logic_vector (3 downto 0); + signal s_Carry : std_logic_vector (3 downto 0); +begin + DUT0 : AdderSubtractor generic map (N => 4) + port map ( + A => s_A, + B => s_B, + nAdd_Sub => s_nAddSub, + Sum => s_Sum, + Carry => s_Carry + ); + + s_A <= "0011"; s_B <= "0101"; s_nAddSub <= '0'; + process + begin + wait for 10 ns; + + s_A <= "0110"; s_B <= "0011"; s_nAddSub <= '1'; + wait for 10 ns; + + s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '0'; -- Add zero + wait for 10 ns; + s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '1'; -- Subtract zero + wait for 10 ns; + s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '0'; -- Add max + wait for 10 ns; + s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '1'; -- Subtract max + wait for 10 ns; + + s_A <= "1111"; s_B <= "0001"; s_nAddSub <= '0'; -- Add overflow + wait for 10 ns; + s_A <= "0000"; s_B <= "0001"; s_nAddSub <= '1'; -- Subtract overflow + wait for 10 ns; + + assert false report "Testbench completed" severity note; + wait; + end process; +end; diff --git a/examples/vhdl-documentor-json/test/tb_NFullAdder.vhd b/examples/vhdl-documentor-json/test/tb_NFullAdder.vhd new file mode 100644 index 0000000..59d4f47 --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_NFullAdder.vhd @@ -0,0 +1,62 @@ +library IEEE; +use IEEE.STD_LOGIC_1164.all; +use IEEE.NUMERIC_STD.all; + +entity tb_NFullAdder is +-- Empty entity as this is a test bench +end tb_NFullAdder; + +architecture Behavioral of tb_NFullAdder is + component NBitAdder + Generic (N : integer := 4); + Port ( + A : in STD_LOGIC_VECTOR(N-1 downto 0); + B : in STD_LOGIC_VECTOR(N-1 downto 0); + Sum : out STD_LOGIC_VECTOR(N-1 downto 0); + CarryOut : out STD_LOGIC + ); + end component; + + signal s_A : STD_LOGIC_VECTOR(3 downto 0) := (others => '0'); + signal s_B : STD_LOGIC_VECTOR(3 downto 0) := (others => '0'); + + signal s_Sum : STD_LOGIC_VECTOR(3 downto 0); + signal s_CarryOut : STD_LOGIC; + + signal clk : STD_LOGIC := '0'; + +begin + DUT0: NBitAdder + port map ( + A => s_A, + B => s_B, + Sum => s_Sum, + CarryOut => s_CarryOut + ); + + -- Clock process + clk_process : process + begin + clk <= '0'; + wait for 5 ns; + clk <= '1'; + wait for 5 ns; + end process; + + -- Test process + stim_proc: process + begin + -- Testing all combinations + for i in 0 to 15 loop + for j in 0 to 15 loop + s_A <= std_logic_vector(to_unsigned(i, 4)); + s_B <= std_logic_vector(to_unsigned(j, 4)); + wait for 10 ns; -- Wait for one clock cycle + end loop; + end loop; + + -- End simulation + wait; + end process; + +end Behavioral; diff --git a/examples/vhdl-documentor-json/test/tb_NMux2t1.vhd b/examples/vhdl-documentor-json/test/tb_NMux2t1.vhd new file mode 100644 index 0000000..458c9e5 --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_NMux2t1.vhd @@ -0,0 +1,46 @@ +library IEEE; +use IEEE.std_logic_1164.all; + +entity tb_NMux2t1 is end tb_NMux2t1; + +architecture behavior of tb_NMux2t1 is + component nmux2t1 + port ( + AMux : in std_logic; + BMux : in std_logic; + Sel : in std_logic; + Output : out std_logic + ); + + + end component; + + --Inputs + signal s_AMux : std_logic := '0'; + signal s_BMux : std_logic := '0'; + signal s_Out : std_logic := '0'; + + --Outputs + signal f : std_logic; + +begin + DUT0 : nmux2t1 + port map ( + AMux => s_AMux, + BMux => s_BMux, + Sel => s_Out, + Output => f + ); + -- Stimulus process + stim_proc : process + begin + s_AMux <= '0'; + s_BMux <= '0'; + s_Out <= '0'; + wait for 100 ns; + + assert f = '0' report "Test failed for s=0" severity error; + + wait; + end process stim_proc; +end behavior; diff --git a/examples/vhdl-documentor-json/test/tb_OnesComp.vhd b/examples/vhdl-documentor-json/test/tb_OnesComp.vhd new file mode 100644 index 0000000..3f1942d --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_OnesComp.vhd @@ -0,0 +1,60 @@ +library IEEE; +use IEEE.STD_LOGIC_1164.all; +use IEEE.NUMERIC_STD.all; + +entity tb_ones_comp is end tb_ones_comp; + +architecture Behavioral of tb_ones_comp is + + constant N : integer := 8; + constant M : integer := 32; + + signal Input : std_logic_vector(N-1 downto 0); + signal Output : std_logic_vector(N-1 downto 0); + + component OnesComplementor + generic (N : integer := 8); + port ( + Input : in std_logic_vector(N-1 downto 0); + Output : out std_logic_vector(N-1 downto 0) + ); + end component; + + component OnesComplementor_2 + generic (M : integer := 32); + port ( + Input2 : in std_logic_vector(M-1 downto 0); + Output2 : out std_logic_vector(M-1 downto 0) + ); + end component; +begin + DUT0 : OnesComplementor + generic map (N => N) + port map ( + Input => Input, + Output => Output + ); + + DUT1 : OnesComplementor_2 + generic map (M => M) + port map ( + Input2 => Input, + Output2 => Output + ); + + stim_proc : process + begin + for i in 0 to 2**N-1 loop + Input <= std_logic_vector(to_unsigned(i, N)); + wait for 10 ns; + assert Output = not Input report "End of testbench simulation" severity failure; + end loop; + for i in 0 to 2**M-1 loop + Input <= std_logic_vector(to_unsigned(i, M)); + wait for 10 ns; + assert Output = not Input report "End of testbench simulation" severity failure; + end loop; + + end process; + +end Behavioral; diff --git a/examples/vhdl-documentor-json/test/tb_TPU_MV_Element.vhd b/examples/vhdl-documentor-json/test/tb_TPU_MV_Element.vhd new file mode 100644 index 0000000..8eebf6e --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_TPU_MV_Element.vhd @@ -0,0 +1,113 @@ +library IEEE; + +use IEEE.std_logic_1164.all; +use IEEE.std_logic_textio.all; -- For logic types I/O + +library std; + +use std.env.all; -- For hierarchical/external signals +use std.textio.all; -- For basic I/O + +entity tb_TPU_MV_Element is + generic(gCLK_HPER : time := 10 ns); -- Generic for half of the clock cycle period +end tb_TPU_MV_Element; + +architecture mixed of tb_TPU_MV_Element is + component TPU_MV_Element is port( + iCLK : in std_logic; + iX : in integer; + iW : in integer; + iLdW : in integer; + iY : in integer; + oY : out integer; + oX : out integer + ); + end component; +-- Create signals for all of the inputs and outputs of the file that you are testing +-- := '0' or := (others => '0') just make all the signals start at an initial value of zero + signal CLK, reset : std_logic := '0'; + signal s_iX : integer := 0; + signal s_iW : integer := 0; + signal s_iLdW : integer := 0; + signal s_iY : integer := 0; + signal s_oY : integer; + signal s_oX : integer; +begin + DUT0 : TPU_MV_Element + port map( + iCLK => CLK, + iX => s_iX, + iW => s_iW, + iLdW => s_iLdW, + iY => s_iY, + oY => s_oY, + oX => s_oX + ); + + P_CLK : process + begin + CLK <= '1'; -- clock starts at 1 + wait for gCLK_HPER; -- after half a cycle + CLK <= '0'; -- clock becomes a 0 (negative edge) + wait for gCLK_HPER; -- after half a cycle, process begins evaluation again + end process; + + P_RST : process + begin + reset <= '0'; + wait for gCLK_HPER/2; + reset <= '1'; + wait for gCLK_HPER*2; + reset <= '0'; + wait; + end process; + P_TEST_CASES : process + begin + + wait for gCLK_HPER/2; -- for waveform clarity, I prefer not to change inputs on clk edges + -- Test case 1: + -- Initialize weight value to 10. + s_iX <= 0; -- Not strictly necessary, but this makes the testcases easier to read + s_iW <= 10; + s_iLdW <= 1; + s_iY <= 0; -- Not strictly necessary, but this makes the testcases easier to read + wait for gCLK_HPER*2; + -- Expect: s_W internal signal to be 10 after positive edge of clock + -- Test case 2: + -- Perform average example of an input activation of 3 and a partial sum of 25. The weight is still 10. + s_iX <= 3; + s_iW <= 0; + s_iLdW <= 0; + s_iY <= 25; + wait for gCLK_HPER*2; + wait for gCLK_HPER*2; + -- Expect: o_Y output signal to be 55 = 3*10+25 and o_X output signal to + -- be 3 after two positive edge of clock. + assert s_oY = 55 report "Test case 2 failed" severity error; + -- Test case 3: + -- Perform one MAC operation with minimum-case values + s_iX <= 0; + s_iW <= 10; + s_iLdW <= 0; + s_iY <= 0; + wait for gCLK_HPER*2; + wait for gCLK_HPER*2; + -- Expect: o_Y output signal to be 0 = 10*0+0 and o_X output signal to be 0 + -- after two positive edge of clock. + assert s_oY = 0 report "Test case 3 failed" severity error; + -- Test case 4: + -- Change the weight and perform a MAC operation + s_iX <= 2; + s_iW <= 5; + s_iLdW <= 1; + s_iY <= 10; + wait for gCLK_HPER*2; + wait for gCLK_HPER*2; + wait for gCLK_HPER*2; + -- Expect: o_Y output signal to be 20 = 5*2+10 and o_X output signal + -- to be 2 after three positive edge of clock. + assert s_oY = 20 report "Test case 4 failed" severity error; + wait; + end process; + +end mixed; diff --git a/examples/vhdl-documentor-json/test/tb_mux2t1.vhd b/examples/vhdl-documentor-json/test/tb_mux2t1.vhd new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/test/tb_mux2t1s.vhd b/examples/vhdl-documentor-json/test/tb_mux2t1s.vhd new file mode 100644 index 0000000..e69de29 diff --git a/examples/vhdl-documentor-json/test/tb_nBitAdder.vhd b/examples/vhdl-documentor-json/test/tb_nBitAdder.vhd new file mode 100644 index 0000000..83b9f93 --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_nBitAdder.vhd @@ -0,0 +1,60 @@ +library IEEE; +use IEEE.STD_LOGIC_1164.all; +use IEEE.NUMERIC_STD.all; + +entity NBitAdder_tb is +end NBitAdder_tb; + +architecture Behavioral of NBitAdder_tb is + + constant N : integer := 4; + + signal A : std_logic_vector(N-1 downto 0) := (others => '0'); + signal B : std_logic_vector(N-1 downto 0) := (others => '0'); + signal Sum : std_logic_vector(N-1 downto 0); + signal CarryOut : std_logic_vector(N-1 downto 0); + +begin + + uut : entity work.NBitAdder + generic map (N => N) + port map ( + A => A, + B => B, + Sum => Sum, + CarryOut => CarryOut + ); + + -- Stimulus process to apply test vectors + stim_proc : process + begin + -- Test Case 1: 0 + 1 + A <= "0000"; + B <= "0001"; + wait for 10 ns; + + -- Test Case 2: 3 + 3 + A <= "0011"; + B <= "0011"; + wait for 10 ns; + + -- Test Case 3: 15 + 1 + A <= "1111"; + B <= "0001"; + wait for 10 ns; + + -- Test Case 4: 10 + 5 + A <= "1010"; + B <= "0101"; + wait for 10 ns; + + -- Test Case 5: Maximum values + A <= (others => '1'); + B <= (others => '1'); + wait for 10 ns; + + -- Finish simulation + wait; + end process; + +end Behavioral; diff --git a/examples/vhdl-documentor-json/test/tb_nBitInverter.vhd b/examples/vhdl-documentor-json/test/tb_nBitInverter.vhd new file mode 100644 index 0000000..b94aa00 --- /dev/null +++ b/examples/vhdl-documentor-json/test/tb_nBitInverter.vhd @@ -0,0 +1,51 @@ + +library IEEE; +use IEEE.STD_LOGIC_1164.all; + +entity NBitInverter_tb is +end NBitInverter_tb; + +architecture Behavioral of NBitInverter_tb is + + constant N : integer := 4; + + signal Input : std_logic_vector(N-1 downto 0) := (others => '0'); + signal Output : std_logic_vector(N-1 downto 0); + +begin + + uut : entity work.NBitInverter + generic map (N => N) + port map ( + Input => Input, + Output => Output + ); + + -- Stimulus process to apply test vectors + stim_proc : process + begin + -- Test Case 1: All zeros + Input <= "0000"; + wait for 10 ns; + + -- Test Case 2: All ones + Input <= "1111"; + wait for 10 ns; + + -- Test Case 3: Alternating bits (1010) + Input <= "1010"; + wait for 10 ns; + + -- Test Case 4: Alternating bits (0101) + Input <= "0101"; + wait for 10 ns; + + -- Test Case 5: Random value + Input <= "1100"; + wait for 10 ns; + + -- Finish simulation + wait; + end process; + +end Behavioral; diff --git a/go.mod b/go.mod index 020312f..f547c28 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/charmbracelet/bubbles v0.20.0 github.com/charmbracelet/bubbletea v1.1.0 github.com/charmbracelet/lipgloss v0.13.0 + github.com/charmbracelet/log v0.4.0 github.com/rs/zerolog v1.33.0 github.com/stretchr/testify v1.9.0 github.com/wk8/go-ordered-map/v2 v2.1.8 @@ -20,6 +21,7 @@ require ( github.com/charmbracelet/x/term v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect + github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect @@ -31,6 +33,7 @@ require ( github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rivo/uniseg v0.4.7 // indirect + golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.25.0 // indirect golang.org/x/text v0.18.0 // indirect diff --git a/go.sum b/go.sum index 4233a70..2613208 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ github.com/charmbracelet/bubbletea v1.1.0 h1:FjAl9eAL3HBCHenhz/ZPjkKdScmaS5SK69J github.com/charmbracelet/bubbletea v1.1.0/go.mod h1:9Ogk0HrdbHolIKHdjfFpyXJmiCzGwy+FesYkZr7hYU4= github.com/charmbracelet/lipgloss v0.13.0 h1:4X3PPeoWEDCMvzDvGmTajSyYPcZM4+y8sCA/SsA3cjw= github.com/charmbracelet/lipgloss v0.13.0/go.mod h1:nw4zy0SBX/F/eAO1cWdcvy6qnkDUxr8Lw7dvFrAIbbY= +github.com/charmbracelet/log v0.4.0 h1:G9bQAcx8rWA2T3pWvx7YtPTPwgqpk7D68BX21IRW8ZM= +github.com/charmbracelet/log v0.4.0/go.mod h1:63bXt/djrizTec0l11H20t8FDSvA4CRZJ1KH22MdptM= github.com/charmbracelet/x/ansi v0.2.3 h1:VfFN0NUpcjBRd4DnKfRaIRo53KRgey/nhOoEqosGDEY= github.com/charmbracelet/x/ansi v0.2.3/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw= github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0= @@ -23,6 +25,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= +github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4= +github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= @@ -58,6 +62,7 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= +golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go new file mode 100644 index 0000000..94ea30b --- /dev/null +++ b/internal/test/failer_test.go @@ -0,0 +1,108 @@ +//go:build !test +// +build !test + +package test + +import ( + "errors" + "testing" +) + +// TestErrTestErrorAccumulatorWriteFailed_Error tests the Error method of ErrTestErrorAccumulatorWriteFailed. +func TestErrTestErrorAccumulatorWriteFailed_Error(t *testing.T) { + err := ErrTestErrorAccumulatorWriteFailed{} + expected := "test error accumulator failed" + + if err.Error() != expected { + t.Errorf("Error() returned %q, expected %q", err.Error(), expected) + } +} + +// TestFailingErrorBuffer_Write tests the Write method of FailingErrorBuffer with various inputs. +func TestFailingErrorBuffer_Write(t *testing.T) { + buf := &FailingErrorBuffer{} + + testCases := []struct { + name string + input []byte + }{ + {"nil slice", nil}, + {"empty slice", []byte{}}, + {"non-empty slice", []byte("test data")}, + {"large slice", make([]byte, 1000)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + n, err := buf.Write(tc.input) + if n != 0 { + t.Errorf("Write(%q) returned n=%d, expected n=0", tc.input, n) + } + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed{}) { + t.Errorf("Write(%q) returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}", tc.input, err) + } + }) + } +} + +// TestFailingErrorBuffer_Len tests the Len method of FailingErrorBuffer. +func TestFailingErrorBuffer_Len(t *testing.T) { + buf := &FailingErrorBuffer{} + + length := buf.Len() + if length != 0 { + t.Errorf("Len() returned %d, expected 0", length) + } + + // After Write calls + _, _ = buf.Write([]byte("test")) + length = buf.Len() + if length != 0 { + t.Errorf("Len() after Write returned %d, expected 0", length) + } +} + +// TestFailingErrorBuffer_Bytes tests the Bytes method of FailingErrorBuffer. +func TestFailingErrorBuffer_Bytes(t *testing.T) { + buf := &FailingErrorBuffer{} + + bytes := buf.Bytes() + if len(bytes) != 0 { + t.Errorf("Bytes() returned %v (len=%d), expected empty byte slice", bytes, len(bytes)) + } + + // After Write calls + _, _ = buf.Write([]byte("test")) + bytes = buf.Bytes() + if len(bytes) != 0 { + t.Errorf("Bytes() after Write returned %v (len=%d), expected empty byte slice", bytes, len(bytes)) + } +} + +// TestFailingErrorBuffer_MultipleWrites tests multiple Write calls to FailingErrorBuffer. +func TestFailingErrorBuffer_MultipleWrites(t *testing.T) { + buf := &FailingErrorBuffer{} + + for i := 0; i < 5; i++ { + n, err := buf.Write([]byte("data")) + if n != 0 { + t.Errorf("Write call %d returned n=%d, expected n=0", i+1, n) + } + if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed{}) { + t.Errorf("Write call %d returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}", i+1, err) + } + } + + if buf.Len() != 0 { + t.Errorf("Len() after multiple Writes returned %d, expected 0", buf.Len()) + } + + if len(buf.Bytes()) != 0 { + t.Errorf("Bytes() after multiple Writes returned len=%d, expected 0", len(buf.Bytes())) + } +} + +var _ error = ErrTestErrorAccumulatorWriteFailed{} +var _ interface{ Write([]byte) (int, error) } = &FailingErrorBuffer{} +var _ interface{ Len() int } = &FailingErrorBuffer{} +var _ interface{ Bytes() []byte } = &FailingErrorBuffer{} diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go new file mode 100644 index 0000000..7d1b6f2 --- /dev/null +++ b/internal/test/helpers_test.go @@ -0,0 +1,113 @@ +//go:build !test +// +build !test + +package test + +import ( + "io" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestCreateTestFile verifies that CreateTestFile correctly creates a file with the expected content. +func TestCreateTestFile(t *testing.T) { + a := assert.New(t) + + // Create a temporary directory and ensure cleanup. + dir, cleanup := CreateTestDirectory(t) + defer cleanup() + + // Define the path for the test file. + filePath := filepath.Join(dir, "testfile.txt") + + // Call the function under test. + CreateTestFile(t, filePath) + + // Check that the file exists. + info, err := os.Stat(filePath) + a.NoError(err, "File should exist") + a.False(info.IsDir(), "Should be a file, not a directory") + + // Read and verify the file content. + content, err := os.ReadFile(filePath) + a.NoError(err, "Should be able to read the file") + a.Equal("hello", string(content), "File content should be 'hello'") +} + +// TestCreateTestDirectory ensures that CreateTestDirectory creates a directory and the cleanup function removes it. +func TestCreateTestDirectory(t *testing.T) { + a := assert.New(t) + + // Create the test directory. + dir, cleanup := CreateTestDirectory(t) + + // Check that the directory exists. + info, err := os.Stat(dir) + a.NoError(err, "Directory should exist") + a.True(info.IsDir(), "Should be a directory") + + // Write a test file inside the directory. + testFilePath := filepath.Join(dir, "test.txt") + err = os.WriteFile(testFilePath, []byte("test content"), 0644) + a.NoError(err, "Should be able to write a file in the directory") + + // Perform cleanup. + cleanup() + + // Verify that the directory has been removed. + _, err = os.Stat(dir) + a.True(os.IsNotExist(err), "Directory should be deleted after cleanup") +} + +// MockRoundTripper is a mock implementation of http.RoundTripper for testing purposes. +type MockRoundTripper struct { + LastRequest *http.Request +} + +// RoundTrip captures the request and returns a dummy response. +func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + m.LastRequest = req + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("OK")), + Header: make(http.Header), + }, nil +} + +// TestTokenRoundTripper verifies that TokenRoundTripper adds the correct Authorization header. +func TestTokenRoundTripper(t *testing.T) { + a := assert.New(t) + + // Prepare the mock fallback RoundTripper. + mockRT := &MockRoundTripper{} + + // Initialize the TokenRoundTripper with a test token. + tokenRT := &TokenRoundTripper{ + Token: "test-token", + Fallback: mockRT, + } + + // Create an HTTP client using the TokenRoundTripper. + client := &http.Client{ + Transport: tokenRT, + } + + // Prepare a test HTTP request. + req, err := http.NewRequest("GET", "http://example.com", nil) + a.NoError(err, "Should be able to create a new request") + + // Perform the HTTP request. + resp, err := client.Do(req) + a.NoError(err, "HTTP request should succeed") + a.Equal(http.StatusOK, resp.StatusCode, "Response status code should be 200") + + // Verify that the Authorization header was added. + a.NotNil(mockRT.LastRequest, "LastRequest should be captured") + authHeader := mockRT.LastRequest.Header.Get("Authorization") + a.Equal("Bearer test-token", authHeader, "Authorization header should contain the correct token") +} diff --git a/models.go b/models.go index 52f8485..c66fe54 100644 --- a/models.go +++ b/models.go @@ -1,6 +1,6 @@ // Code generated by groq-modeler DO NOT EDIT. // -// Created at: 2024-09-06 13:41:20 +// Created at: 2024-09-13 17:06:03 // // groq-modeler Version 1.0.0 @@ -42,19 +42,19 @@ const ( TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" // TranscriptionTimestampGranularityWord is the word timestamp granularity. TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" // TranscriptionTimestampGranularitySegment is the segment timestamp granularity. - Llama370B8192 Model = "llama3-70b-8192" // Llama370B8192 is an AI text model provided by Meta. It has 8192 context window. DistilWhisperLargeV3En Model = "distil-whisper-large-v3-en" // DistilWhisperLargeV3En is an AI audio model provided by Hugging Face. It has 448 context window. + Gemma29BIt Model = "gemma2-9b-it" // Gemma29BIt is an AI text model provided by Google. It has 8192 context window. Gemma7BIt Model = "gemma-7b-it" // Gemma7BIt is an AI text model provided by Google. It has 8192 context window. - LlavaV157B4096Preview Model = "llava-v1.5-7b-4096-preview" // LlavaV157B4096Preview is an AI text model provided by Other. It has 4096 context window. Llama3170BVersatile Model = "llama-3.1-70b-versatile" // Llama3170BVersatile is an AI text model provided by Meta. It has 131072 context window. - Llama38B8192 Model = "llama3-8b-8192" // Llama38B8192 is an AI text model provided by Meta. It has 8192 context window. Llama318BInstant Model = "llama-3.1-8b-instant" // Llama318BInstant is an AI text model provided by Meta. It has 131072 context window. - WhisperLargeV3 Model = "whisper-large-v3" // WhisperLargeV3 is an AI audio model provided by OpenAI. It has 448 context window. + Llama370B8192 Model = "llama3-70b-8192" // Llama370B8192 is an AI text model provided by Meta. It has 8192 context window. + Llama38B8192 Model = "llama3-8b-8192" // Llama38B8192 is an AI text model provided by Meta. It has 8192 context window. + Llama3Groq70B8192ToolUsePreview Model = "llama3-groq-70b-8192-tool-use-preview" // Llama3Groq70B8192ToolUsePreview is an AI text model provided by Groq. It has 8192 context window. Llama3Groq8B8192ToolUsePreview Model = "llama3-groq-8b-8192-tool-use-preview" // Llama3Groq8B8192ToolUsePreview is an AI text model provided by Groq. It has 8192 context window. - Gemma29BIt Model = "gemma2-9b-it" // Gemma29BIt is an AI text model provided by Google. It has 8192 context window. + LlamaGuard38B Model = "llama-guard-3-8b" // LlamaGuard38B is an AI model provided by Meta. It has 8192 context window. + LlavaV157B4096Preview Model = "llava-v1.5-7b-4096-preview" // LlavaV157B4096Preview is an AI text model provided by Other. It has 4096 context window. Mixtral8X7B32768 Model = "mixtral-8x7b-32768" // Mixtral8X7B32768 is an AI text model provided by Mistral AI. It has 32768 context window. - Llama3Groq70B8192ToolUsePreview Model = "llama3-groq-70b-8192-tool-use-preview" // Llama3Groq70B8192ToolUsePreview is an AI text model provided by Groq. It has 8192 context window. - LlamaGuard38B Model = "llama-guard-3-8b" // LlamaGuard38B is an AI moderation model provided by Meta. It has 8192 context window. + WhisperLargeV3 Model = "whisper-large-v3" // WhisperLargeV3 is an AI audio model provided by OpenAI. It has 448 context window. ) var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{ @@ -67,31 +67,42 @@ var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{ WhisperLargeV3: true, }, transcriptionsSuffix: { - Llama370B8192: true, + Gemma29BIt: true, Gemma7BIt: true, - LlavaV157B4096Preview: true, Llama3170BVersatile: true, - Llama38B8192: true, Llama318BInstant: true, + Llama370B8192: true, + Llama38B8192: true, + Llama3Groq70B8192ToolUsePreview: true, Llama3Groq8B8192ToolUsePreview: true, - Gemma29BIt: true, + LlavaV157B4096Preview: true, Mixtral8X7B32768: true, - Llama3Groq70B8192ToolUsePreview: true, }, translationsSuffix: { - Llama370B8192: true, + Gemma29BIt: true, Gemma7BIt: true, - LlavaV157B4096Preview: true, Llama3170BVersatile: true, - Llama38B8192: true, Llama318BInstant: true, + Llama370B8192: true, + Llama38B8192: true, + Llama3Groq70B8192ToolUsePreview: true, Llama3Groq8B8192ToolUsePreview: true, - Gemma29BIt: true, + LlavaV157B4096Preview: true, Mixtral8X7B32768: true, - Llama3Groq70B8192ToolUsePreview: true, }, moderationsSuffix: { - LlamaGuard38B: true, + DistilWhisperLargeV3En: true, + Gemma29BIt: true, + Gemma7BIt: true, + Llama3170BVersatile: true, + Llama318BInstant: true, + Llama370B8192: true, + Llama38B8192: true, + Llama3Groq70B8192ToolUsePreview: true, + Llama3Groq8B8192ToolUsePreview: true, + LlavaV157B4096Preview: true, + Mixtral8X7B32768: true, + WhisperLargeV3: true, }, } diff --git a/moderation.go b/moderation.go index 1502171..78e31c4 100644 --- a/moderation.go +++ b/moderation.go @@ -148,8 +148,9 @@ var ( // ModerationRequest represents a request structure for moderation API. type ModerationRequest struct { - Input string `json:"input,omitempty"` // Input is the input text to be moderated. - Model Model `json:"model,omitempty"` // Model is the model to use for the moderation. + // Input string `json:"input,omitempty"` // Input is the input text to be moderated. + Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model. + Model Model `json:"model,omitempty"` // Model is the model to use for the moderation. } // Moderation represents one of possible moderation results. @@ -182,11 +183,12 @@ func (c *Client) Moderate( if err != nil { return } - content := resp.Choices[0].Message.Content - println(content) - if strings.Contains(content, "unsafe") { + if strings.Contains(resp.Choices[0].Message.Content, "unsafe") { response.Flagged = true - split := strings.Split(strings.Split(content, "\n")[1], ",") + split := strings.Split( + strings.Split(resp.Choices[0].Message.Content, "\n")[1], + ",", + ) for _, s := range split { response.Categories = append( response.Categories, diff --git a/moderation_test.go b/moderation_test.go index 8c83a57..07a7708 100644 --- a/moderation_test.go +++ b/moderation_test.go @@ -1,74 +1,4 @@ -package groq_test +//go:build !test +// +build !test -import ( - "bytes" - "context" - "encoding/json" - "net/http" - "testing" - - groq "github.com/conneroisu/groq-go" - "github.com/stretchr/testify/assert" -) - -// TestModerate tests the Moderate method of the client. -func TestModerate(t *testing.T) { - client, server, teardown := setupGroqTestServer() - defer teardown() - server.RegisterHandler( - "/v1/chat/completions", - handleModerationEndpoint, - ) - mod, err := client.Moderate(context.Background(), groq.ModerationRequest{ - Model: groq.ModerationTextStable, - Input: "I want to kill them.", - }) - a := assert.New(t) - a.NoError(err, "Moderation error") - a.Equal(true, mod.Flagged) - a.Equal( - mod.Categories, - []groq.HarmfulCategory{ - groq.CategoryViolentCrimes, - groq.CategoryNonviolentCrimes, - }, - ) -} - -// handleModerationEndpoint handles the moderation endpoint. -func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { - response := groq.ChatCompletionResponse{ - ID: "chatcmpl-123", - Object: "chat.completion", - Created: 1693721698, - Model: groq.ModerationTextStable, - Choices: []groq.ChatCompletionChoice{ - { - Message: groq.ChatCompletionMessage{ - Role: groq.ChatMessageRoleAssistant, - Content: "unsafe\nS1,S2", - }, - FinishReason: "stop", - }, - }, - } - buf := new(bytes.Buffer) - err := json.NewEncoder(buf).Encode(response) - if err != nil { - http.Error( - w, - "could not encode response", - http.StatusInternalServerError, - ) - return - } - _, err = w.Write(buf.Bytes()) - if err != nil { - http.Error( - w, - "could not write response", - http.StatusInternalServerError, - ) - return - } -} +package groq diff --git a/moderation_unit_test.go b/moderation_unit_test.go new file mode 100644 index 0000000..310636d --- /dev/null +++ b/moderation_unit_test.go @@ -0,0 +1,79 @@ +package groq_test + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "testing" + + groq "github.com/conneroisu/groq-go" + "github.com/stretchr/testify/assert" +) + +// TestModerate tests the Moderate method of the client. +func TestModerate(t *testing.T) { + client, server, teardown := setupGroqTestServer() + defer teardown() + server.RegisterHandler( + "/v1/chat/completions", + handleModerationEndpoint, + ) + mod, err := client.Moderate(context.Background(), groq.ModerationRequest{ + Model: groq.ModerationTextStable, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: "I want to kill them.", + }, + }, + }) + a := assert.New(t) + a.NoError(err, "Moderation error") + a.Equal(true, mod.Flagged) + a.Equal( + mod.Categories, + []groq.HarmfulCategory{ + groq.CategoryViolentCrimes, + groq.CategoryNonviolentCrimes, + }, + ) +} + +// handleModerationEndpoint handles the moderation endpoint. +func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) { + response := groq.ChatCompletionResponse{ + ID: "chatcmpl-123", + Object: "chat.completion", + Created: 1693721698, + Model: groq.ModerationTextStable, + Choices: []groq.ChatCompletionChoice{ + { + Message: groq.ChatCompletionMessage{ + Role: groq.ChatMessageRoleAssistant, + Content: "unsafe\nS1,S2", + }, + FinishReason: "stop", + }, + }, + } + buf := new(bytes.Buffer) + err := json.NewEncoder(buf).Encode(response) + if err != nil { + http.Error( + w, + "could not encode response", + http.StatusInternalServerError, + ) + return + } + _, err = w.Write(buf.Bytes()) + if err != nil { + http.Error( + w, + "could not write response", + http.StatusInternalServerError, + ) + return + } +} diff --git a/schema.go b/schema.go index 4abdcbf..7ff1c50 100644 --- a/schema.go +++ b/schema.go @@ -30,21 +30,27 @@ var ( falseSchema = &schema{boolean: &[]bool{false}[0]} timeType = reflect.TypeOf(time.Time{}) // date-time RFC section 7.3.1 - ipType = reflect.TypeOf(net.IP{}) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5 - uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6 + ipType = reflect.TypeOf( + net.IP{}, + ) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5 + uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6 byteSliceType = reflect.TypeOf([]byte(nil)) rawMessageType = reflect.TypeOf(json.RawMessage{}) - customType = reflect.TypeOf((*customSchemaImpl)(nil)).Elem() - extendType = reflect.TypeOf((*extendSchemaImpl)(nil)).Elem() - customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).Elem() - protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem() - matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") - matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + customType = reflect.TypeOf((*customSchemaImpl)(nil)). + Elem() + extendType = reflect.TypeOf((*extendSchemaImpl)(nil)). + Elem() + customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)). + Elem() + protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem() + matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") + matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") customAliasSchema = reflect.TypeOf((*aliasSchemaImpl)(nil)).Elem() - customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)).Elem() + customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)). + Elem() ) // customSchemaImpl is used to detect if the type provides it's own @@ -328,8 +334,16 @@ func (r *reflector) reflectTypeToSchema( case reflect.Interface: // empty - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + case reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64: st.Type = "integer" case reflect.Float32, reflect.Float64: @@ -437,7 +451,11 @@ func (r *reflector) reflectMap( } switch t.Key().Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + case reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64: st.PatternProperties = map[string]*schema{ "^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()), } @@ -445,7 +463,10 @@ func (r *reflector) reflectMap( return } if t.Elem().Kind() != reflect.Interface { - st.AdditionalProperties = r.refOrReflectTypeToSchema(definitions, t.Elem()) + st.AdditionalProperties = r.refOrReflectTypeToSchema( + definitions, + t.Elem(), + ) } } @@ -533,7 +554,10 @@ func (r *reflector) reflectStructFields( // the provided object's type instead of the field's type. var property *schema if alias := customPropertyMethod(name); alias != nil { - property = r.refOrReflectTypeToSchema(definitions, reflect.TypeOf(alias)) + property = r.refOrReflectTypeToSchema( + definitions, + reflect.TypeOf(alias), + ) } else { property = r.refOrReflectTypeToSchema(definitions, f.Type) } @@ -599,7 +623,11 @@ func (r *reflector) lookupComment(t reflect.Type, name string) string { } // addDefinition will append the provided schema. If needed, an ID and anchor will also be added. -func (r *reflector) addDefinition(definitions schemaDefinitions, t reflect.Type, s *schema) { +func (r *reflector) addDefinition( + definitions schemaDefinitions, + t reflect.Type, + s *schema, +) { name := r.typeName(t) if name == "" { return @@ -608,7 +636,10 @@ func (r *reflector) addDefinition(definitions schemaDefinitions, t reflect.Type, } // refDefinition will provide a schema with a reference to an existing definition. -func (r *reflector) refDefinition(definitions schemaDefinitions, t reflect.Type) *schema { +func (r *reflector) refDefinition( + definitions schemaDefinitions, + t reflect.Type, +) *schema { if r.DoNotReference { return nil } @@ -635,7 +666,11 @@ func (r *reflector) lookupID(t reflect.Type) schemaID { return EmptyID } -func (t *schema) fieldsFromTags(f reflect.StructField, parent *schema, propertyName string) { +func (t *schema) fieldsFromTags( + f reflect.StructField, + parent *schema, + propertyName string, +) { t.Description = f.Tag.Get("jsonschema_description") tags := splitOnUnescapedCommas(f.Tag.Get("jsonschema")) @@ -874,7 +909,10 @@ func (t *schema) arrayfields(tags []string) { case "pattern": t.Items.Pattern = val default: - unprocessed = append(unprocessed, tag) // left for further processing by underlying type + unprocessed = append( + unprocessed, + tag, + ) // left for further processing by underlying type } } } @@ -1024,7 +1062,9 @@ func (r *reflector) fieldNameTag() string { return "json" } -func (r *reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, bool) { +func (r *reflector) reflectFieldName( + f reflect.StructField, +) (string, bool, bool, bool) { jsonTagString := f.Tag.Get(r.fieldNameTag()) jsonTags := strings.Split(jsonTagString, ",") if ignoredByJSONTags(jsonTags) { @@ -1049,7 +1089,8 @@ func (r *reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, } // As per JSON Marshal rules, anonymous pointer to structs are inherited - if f.Type.Kind() == reflect.Ptr && f.Type.Elem().Kind() == reflect.Struct { + if f.Type.Kind() == reflect.Ptr && + f.Type.Elem().Kind() == reflect.Struct { return "", true, false, false } } diff --git a/schema_test.go b/schema_test.go index 8da5a0b..eb9560f 100644 --- a/schema_test.go +++ b/schema_test.go @@ -1,3 +1,6 @@ +//go:build !test +// +build !test + package groq import ( @@ -42,7 +45,11 @@ func TestIDValidation(t *testing.T) { id = "https://encoding/json" if assert.Error(t, id.Validate()) { - assert.Contains(t, id.Validate().Error(), "hostname does not look valid") + assert.Contains( + t, + id.Validate().Error(), + "hostname does not look valid", + ) } id = "time" @@ -82,7 +89,7 @@ type ( // Plant represents the plants the user might have and serves as a test // of structs inside a `type` set. Plant struct { - Variant string `json:"variant" jsonschema:"title=Variant"` // This comment will be used + Variant string `json:"variant" jsonschema:"title=Variant"` // This comment will be used // Multicellular is true if the plant is multicellular Multicellular bool `json:"multicellular,omitempty" jsonschema:"title=Multicellular"` // This comment will be ignored } @@ -92,10 +99,9 @@ type ( // Don't forget to checkout the nested path. type User struct { // Unique sequential identifier. - ID int `json:"id" jsonschema:"required"` - // This comment will be ignored - Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex"` - Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"` + ID int `json:"id" jsonschema:"required"` + Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex"` + Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"` Tags map[string]any `json:"tags,omitempty"` // An array of pets the user cares for. @@ -109,7 +115,12 @@ type User struct { } var updateFixtures = flag.Bool("update", false, "set to update fixtures") -var compareFixtures = flag.Bool("compare", false, "output failed fixtures with .out.json") + +var compareFixtures = flag.Bool( + "compare", + false, + "output failed fixtures with .out.json", +) type GrandfatherType struct { FamilyName string `json:"family_name" jsonschema:"required"` @@ -121,8 +132,8 @@ type SomeBaseType struct { // The jsonschema required tag is nonsensical for private and ignored properties. // Their presence here tests that the fields *will not* be required in the output // schema, even if they are tagged required. - SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"` - SomeSchemaIgnoredProperty string `jsonschema:"-,required"` + SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"` + SomeSchemaIgnoredProperty string ` jsonschema:"-,required"` Grandfather GrandfatherType `json:"grand"` SomeUntaggedBaseProperty bool `jsonschema:"required"` @@ -150,10 +161,10 @@ type TestUser struct { nonExported MapType - ID int `json:"id" jsonschema:"required,minimum=bad,maximum=bad,exclusiveMinimum=bad,exclusiveMaximum=bad,default=bad"` - Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex,readOnly=true"` - Password string `json:"password" jsonschema:"writeOnly=true"` - Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"` + ID int `json:"id" jsonschema:"required,minimum=bad,maximum=bad,exclusiveMinimum=bad,exclusiveMaximum=bad,default=bad"` + Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex,readOnly=true"` + Password string `json:"password" jsonschema:"writeOnly=true"` + Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"` Tags map[string]string `json:"tags,omitempty"` Options map[string]any `json:"options,omitempty"` @@ -168,29 +179,29 @@ type TestUser struct { IPAddress net.IP `json:"network_address,omitempty"` // Tests for RFC draft-wright-json-schema-hyperschema-00, section 4 - Photo []byte `json:"photo,omitempty" jsonschema:"required"` + Photo []byte `json:"photo,omitempty" jsonschema:"required"` Photo2 Bytes `json:"photo2,omitempty" jsonschema:"required"` // Tests for jsonpb enum support Feeling ProtoEnum `json:"feeling,omitempty"` - Age int `json:"age" jsonschema:"minimum=18,maximum=120,exclusiveMaximum=121,exclusiveMinimum=17"` + Age int `json:"age" jsonschema:"minimum=18,maximum=120,exclusiveMaximum=121,exclusiveMinimum=17"` Email string `json:"email" jsonschema:"format=email"` - UUID string `json:"uuid" jsonschema:"format=uuid"` + UUID string `json:"uuid" jsonschema:"format=uuid"` // Test for "extras" support Baz string `jsonschema_extras:"foo=bar,hello=world,foo=bar1"` - BoolExtra string `json:"bool_extra,omitempty" jsonschema_extras:"isTrue=true,isFalse=false"` + BoolExtra string `jsonschema_extras:"isTrue=true,isFalse=false" json:"bool_extra,omitempty"` // Tests for simple enum tags - Color string `json:"color" jsonschema:"enum=red,enum=green,enum=blue"` + Color string `json:"color" jsonschema:"enum=red,enum=green,enum=blue"` Rank int `json:"rank,omitempty" jsonschema:"enum=1,enum=2,enum=3"` Multiplier float64 `json:"mult,omitempty" jsonschema:"enum=1.0,enum=1.5,enum=2.0"` // Tests for enum tags on slices - Roles []string `json:"roles" jsonschema:"enum=admin,enum=moderator,enum=user"` + Roles []string `json:"roles" jsonschema:"enum=admin,enum=moderator,enum=user"` Priorities []int `json:"priorities,omitempty" jsonschema:"enum=-1,enum=0,enum=1,enun=2"` - Offsets []float64 `json:"offsets,omitempty" jsonschema:"enum=1.570796,enum=3.141592,enum=6.283185"` + Offsets []float64 `json:"offsets,omitempty" jsonschema:"enum=1.570796,enum=3.141592,enum=6.283185"` // Test for raw JSON Anything any `json:"anything,omitempty"` @@ -397,7 +408,7 @@ type KeyNamed struct { NotComingFromJSON bool `json:"coming_from_json_tag_not_renamed"` NestedNotRenamed KeyNamedNested `json:"nested_not_renamed"` UnicodeShenanigans string - RenamedByComputation int `jsonschema_description:"Description was preserved"` + RenamedByComputation int ` jsonschema_description:"Description was preserved"` } type SchemaExtendTestBase struct { @@ -425,7 +436,7 @@ type Expression struct { type PatternEqualsTest struct { WithEquals string `jsonschema:"pattern=foo=bar"` - WithEqualsAndCommas string `jsonschema:"pattern=foo\\,=bar"` + WithEqualsAndCommas string `jsonschema:"pattern=foo,=bar"` } func TestReflector(t *testing.T) { @@ -441,7 +452,11 @@ func TestReflectFromType(t *testing.T) { typ := reflect.TypeOf(tu) s := r.ReflectFromType(typ) - assert.EqualValues(t, "https://github.com/conneroisu/groq-go/test-user", s.ID) + assert.EqualValues( + t, + "https://github.com/conneroisu/groq-go/test-user", + s.ID, + ) x := struct { Test string @@ -461,15 +476,51 @@ func TestSchemaGeneration(t *testing.T) { }{ {&TestUser{}, &reflector{}, "testdata/test_user.json"}, {&UserWithAnchor{}, &reflector{}, "testdata/user_with_anchor.json"}, - {&TestUser{}, &reflector{AssignAnchor: true}, "testdata/test_user_assign_anchor.json"}, - {&TestUser{}, &reflector{AllowAdditionalProperties: true}, "testdata/allow_additional_props.json"}, - {&TestUser{}, &reflector{RequiredFromJSONSchemaTags: true}, "testdata/required_from_jsontags.json"}, - {&TestUser{}, &reflector{ExpandedStruct: true}, "testdata/defaults_expanded_toplevel.json"}, - {&TestUser{}, &reflector{IgnoredTypes: []any{GrandfatherType{}}}, "testdata/ignore_type.json"}, - {&TestUser{}, &reflector{DoNotReference: true}, "testdata/no_reference.json"}, - {&TestUser{}, &reflector{DoNotReference: true, AssignAnchor: true}, "testdata/no_reference_anchor.json"}, - {&RootOneOf{}, &reflector{RequiredFromJSONSchemaTags: true}, "testdata/oneof.json"}, - {&RootAnyOf{}, &reflector{RequiredFromJSONSchemaTags: true}, "testdata/anyof.json"}, + { + &TestUser{}, + &reflector{AssignAnchor: true}, + "testdata/test_user_assign_anchor.json", + }, + { + &TestUser{}, + &reflector{AllowAdditionalProperties: true}, + "testdata/allow_additional_props.json", + }, + { + &TestUser{}, + &reflector{RequiredFromJSONSchemaTags: true}, + "testdata/required_from_jsontags.json", + }, + { + &TestUser{}, + &reflector{ExpandedStruct: true}, + "testdata/defaults_expanded_toplevel.json", + }, + { + &TestUser{}, + &reflector{IgnoredTypes: []any{GrandfatherType{}}}, + "testdata/ignore_type.json", + }, + { + &TestUser{}, + &reflector{DoNotReference: true}, + "testdata/no_reference.json", + }, + { + &TestUser{}, + &reflector{DoNotReference: true, AssignAnchor: true}, + "testdata/no_reference_anchor.json", + }, + { + &RootOneOf{}, + &reflector{RequiredFromJSONSchemaTags: true}, + "testdata/oneof.json", + }, + { + &RootAnyOf{}, + &reflector{RequiredFromJSONSchemaTags: true}, + "testdata/anyof.json", + }, {&CustomTypeField{}, &reflector{ Mapper: func(i reflect.Type) *schema { if i == reflect.TypeOf(CustomTime{}) { @@ -481,7 +532,11 @@ func TestSchemaGeneration(t *testing.T) { return nil }, }, "testdata/custom_type.json"}, - {LookupUser{}, &reflector{BaseSchemaID: "https://example.com/schemas"}, "testdata/base_schema_id.json"}, + { + LookupUser{}, + &reflector{BaseSchemaID: "https://example.com/schemas"}, + "testdata/base_schema_id.json", + }, {LookupUser{}, &reflector{ Lookup: func(i reflect.Type) schemaID { switch i { @@ -507,11 +562,31 @@ func TestSchemaGeneration(t *testing.T) { return EmptyID }, }, "testdata/lookup_expanded.json"}, - {&Outer{}, &reflector{ExpandedStruct: true}, "testdata/inlining_inheritance.json"}, - {&OuterNamed{}, &reflector{ExpandedStruct: true}, "testdata/inlining_embedded.json"}, - {&OuterNamed{}, &reflector{ExpandedStruct: true, AssignAnchor: true}, "testdata/inlining_embedded_anchored.json"}, - {&OuterInlined{}, &reflector{ExpandedStruct: true}, "testdata/inlining_tag.json"}, - {&OuterPtr{}, &reflector{ExpandedStruct: true}, "testdata/inlining_ptr.json"}, + { + &Outer{}, + &reflector{ExpandedStruct: true}, + "testdata/inlining_inheritance.json", + }, + { + &OuterNamed{}, + &reflector{ExpandedStruct: true}, + "testdata/inlining_embedded.json", + }, + { + &OuterNamed{}, + &reflector{ExpandedStruct: true, AssignAnchor: true}, + "testdata/inlining_embedded_anchored.json", + }, + { + &OuterInlined{}, + &reflector{ExpandedStruct: true}, + "testdata/inlining_tag.json", + }, + { + &OuterPtr{}, + &reflector{ExpandedStruct: true}, + "testdata/inlining_ptr.json", + }, {&MinValue{}, &reflector{}, "testdata/schema_with_minimum.json"}, {&TestNullable{}, &reflector{}, "testdata/nullable.json"}, {&GrandfatherType{}, &reflector{ @@ -526,12 +601,19 @@ func TestSchemaGeneration(t *testing.T) { } }, }, "testdata/custom_additional.json"}, - {&TestDescriptionOverride{}, &reflector{}, "testdata/test_description_override.json"}, + { + &TestDescriptionOverride{}, + &reflector{}, + "testdata/test_description_override.json", + }, {&CompactDate{}, &reflector{}, "testdata/compact_date.json"}, {&CustomSliceOuter{}, &reflector{}, "testdata/custom_slice_type.json"}, {&CustomMapOuter{}, &reflector{}, "testdata/custom_map_type.json"}, - {&CustomTypeFieldWithInterface{}, &reflector{}, "testdata/custom_type_with_interface.json"}, - {&PatternTest{}, &reflector{}, "testdata/commas_in_pattern.json"}, + { + &CustomTypeFieldWithInterface{}, + &reflector{}, + "testdata/custom_type_with_interface.json", + }, {&RecursiveExample{}, &reflector{}, "testdata/recursive.json"}, {&KeyNamed{}, &reflector{ KeyNamer: func(s string) string { @@ -560,7 +642,7 @@ func TestSchemaGeneration(t *testing.T) { {ArrayType{}, &reflector{}, "testdata/array_type.json"}, {SchemaExtendTest{}, &reflector{}, "testdata/custom_type_extend.json"}, {Expression{}, &reflector{}, "testdata/schema_with_expression.json"}, - {PatternEqualsTest{}, &reflector{}, "testdata/equals_in_pattern.json"}, + {&PatternTest{}, &reflector{}, "testdata/commas_in_pattern.json"}, } for _, tt := range tests { @@ -584,7 +666,11 @@ func compareSchemaOutput(t *testing.T, f string, r *reflector, obj any) { require.NoError(t, err) actualSchema := r.Reflect(obj) - actualJSON, _ := json.MarshalIndent(actualSchema, "", " ") //nolint:errchkjson + actualJSON, _ := json.MarshalIndent( + actualSchema, + "", + " ", + ) //nolint:errchkjson if *updateFixtures { _ = os.WriteFile(f, actualJSON, 0600) @@ -592,7 +678,11 @@ func compareSchemaOutput(t *testing.T, f string, r *reflector, obj any) { if !assert.JSONEq(t, string(expectedJSON), string(actualJSON)) { if *compareFixtures { - _ = os.WriteFile(strings.TrimSuffix(f, ".json")+".out.json", actualJSON, 0600) + _ = os.WriteFile( + strings.TrimSuffix(f, ".json")+".out.json", + actualJSON, + 0600, + ) } } } @@ -609,7 +699,10 @@ func TestSplitOnUnescapedCommas(t *testing.T) { strToSplit string expected []string }{ - {`Hello,this,is\,a\,string,haha`, []string{`Hello`, `this`, `is,a,string`, `haha`}}, + { + `Hello,this,is\,a\,string,haha`, + []string{`Hello`, `this`, `is,a,string`, `haha`}, + }, {`hello,no\\,split`, []string{`hello`, `no\,split`}}, {`string without commas`, []string{`string without commas`}}, {`ünicode,𐂄,Ж\,П,ᠳ`, []string{`ünicode`, `𐂄`, `Ж,П`, `ᠳ`}}, @@ -656,9 +749,9 @@ func TestFieldNameTag(t *testing.T) { func TestFieldOneOfRef(t *testing.T) { type Server struct { - IPAddress any `json:"ip_address,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"` - IPAddresses []any `json:"ip_addresses,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"` - IPAddressAny any `json:"ip_address_any,omitempty" jsonschema:"anyof_ref=#/$defs/ipv4;#/$defs/ipv6"` + IPAddress any `json:"ip_address,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"` + IPAddresses []any `json:"ip_addresses,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"` + IPAddressAny any `json:"ip_address_any,omitempty" jsonschema:"anyof_ref=#/$defs/ipv4;#/$defs/ipv6"` IPAddressesAny []any `json:"ip_addresses_any,omitempty" jsonschema:"anyof_ref=#/$defs/ipv4;#/$defs/ipv6"` } @@ -668,12 +761,17 @@ func TestFieldOneOfRef(t *testing.T) { func TestNumberHandling(t *testing.T) { type NumberHandler struct { - Int64 int64 `json:"int64" jsonschema:"default=12"` + Int64 int64 `json:"int64" jsonschema:"default=12"` Float32 float32 `json:"float32" jsonschema:"default=12.5"` } r := &reflector{} - compareSchemaOutput(t, "testdata/number_handling.json", r, &NumberHandler{}) + compareSchemaOutput( + t, + "testdata/number_handling.json", + r, + &NumberHandler{}, + ) fixtureContains(t, "testdata/number_handling.json", `"default": 12`) fixtureContains(t, "testdata/number_handling.json", `"default": 12.5`) } @@ -692,14 +790,19 @@ func TestArrayHandling(t *testing.T) { func TestUnsignedIntHandling(t *testing.T) { type UnsignedIntHandler struct { - MinLen []string `json:"min_len" jsonschema:"minLength=0"` - MaxLen []string `json:"max_len" jsonschema:"maxLength=0"` + MinLen []string `json:"min_len" jsonschema:"minLength=0"` + MaxLen []string `json:"max_len" jsonschema:"maxLength=0"` MinItems []string `json:"min_items" jsonschema:"minItems=0"` MaxItems []string `json:"max_items" jsonschema:"maxItems=0"` } r := &reflector{} - compareSchemaOutput(t, "testdata/unsigned_int_handling.json", r, &UnsignedIntHandler{}) + compareSchemaOutput( + t, + "testdata/unsigned_int_handling.json", + r, + &UnsignedIntHandler{}, + ) fixtureContains(t, "testdata/unsigned_int_handling.json", `"minLength": 0`) fixtureContains(t, "testdata/unsigned_int_handling.json", `"maxLength": 0`) fixtureContains(t, "testdata/unsigned_int_handling.json", `"minItems": 0`) @@ -709,11 +812,16 @@ func TestUnsignedIntHandling(t *testing.T) { func TestJSONSchemaFormat(t *testing.T) { type WithCustomFormat struct { Dates []string `json:"dates" jsonschema:"format=date"` - Odds []string `json:"odds" jsonschema:"format=odd"` + Odds []string `json:"odds" jsonschema:"format=odd"` } r := &reflector{} - compareSchemaOutput(t, "testdata/with_custom_format.json", r, &WithCustomFormat{}) + compareSchemaOutput( + t, + "testdata/with_custom_format.json", + r, + &WithCustomFormat{}, + ) fixtureContains(t, "testdata/with_custom_format.json", `"format": "date"`) fixtureContains(t, "testdata/with_custom_format.json", `"format": "odd"`) } @@ -744,7 +852,12 @@ func (AliasObjectB) JSONSchemaAlias() any { func TestJSONSchemaProperty(t *testing.T) { r := &reflector{} - compareSchemaOutput(t, "testdata/schema_property_alias.json", r, &AliasPropertyObjectBase{}) + compareSchemaOutput( + t, + "testdata/schema_property_alias.json", + r, + &AliasPropertyObjectBase{}, + ) } func TestJSONSchemaAlias(t *testing.T) { diff --git a/scripts/makefile/fmt.sh b/scripts/makefile/fmt.sh index 9de9ede..1e79135 100644 --- a/scripts/makefile/fmt.sh +++ b/scripts/makefile/fmt.sh @@ -7,4 +7,4 @@ gofmt -w . -golines -w --max-len=79 . +# golines -w --max-len=79 . diff --git a/stream_test.go b/stream_test.go index 1ff37d1..6750024 100644 --- a/stream_test.go +++ b/stream_test.go @@ -1,4 +1,7 @@ -package groq //nolint:testpackage // testing private field +//go:build !test +// +build !test + +package groq import ( "bufio" diff --git a/test/unit_test.go b/test/unit_test.go deleted file mode 100644 index 59dd4dd..0000000 --- a/test/unit_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package test - -import ( - "context" - "errors" - "fmt" - "io" - "math/rand" - "os" - "testing" - - "github.com/conneroisu/groq-go" - "github.com/stretchr/testify/assert" -) - -func TestTestServer(t *testing.T) { - num := rand.Intn(100) - a := assert.New(t) - ctx := context.Background() - client, err := groq.NewClient(os.Getenv("GROQ_KEY")) - a.NoError(err, "NewClient error") - strm, err := client.CreateChatCompletionStream( - ctx, - groq.ChatCompletionRequest{ - Model: groq.Llama38B8192, - Messages: []groq.ChatCompletionMessage{ - { - Role: groq.ChatMessageRoleUser, - // convert the content of a excel file into a csv that can be imported into google calendar: - // - // - // - // Course ListingSectionInstructional FormatDelivery ModeMeeting PatternsInstructor - // - // CPRE 3810 - Computer Organization and Assembly Level ProgrammingCPRE 3810-1 - Computer Organization and Assembly Level ProgrammingLectureIn-PersonMWF | 8:50 AM - 9:40 AM | 0101 CARVER - Carver HallBerk Gulmezoglu - // - // EE 3240 - Signals and Systems IIEE 3240-1 - Signals and Systems IILectureIn-PersonMWF | 2:15 PM - 3:05 PM | 1134 SWEENEY - Sweeney HallRatnesh Kumar - // - // EE 3220 - Probabilistic Methods for Electrical EngineersEE 3220-1 - Probabilistic Methods for Electrical EngineersLectureIn-PersonTR | 11:00 AM - 12:15 PM | 1134 SWEENEY - Sweeney HallJulie A Dickerson - // - // SOC 1340 - Introduction to SociologySOC 1340-2 - Introduction to SociologyLectureIn-PersonMWF | 11:00 AM - 11:50 AM | 0127 CURTISS - Curtiss HallDavid Scott Schweingruber - // - // CPRE 3810 - Computer Organization and Assembly Level ProgrammingCPRE 3810-F - Computer Organization and Assembly Level ProgrammingLaboratoryIn-PersonR | 8:00 AM - 9:50 AM | 2050 COOVER - Coover HallBerk Gulmezoglu - // - // EE 3240 - Signals and Systems IIEE 3240-C - Signals and Systems IILaboratoryIn-PersonR | 4:10 PM - 7:00 PM | 2011 COOVER - Coover HallRatnesh Kumar - // - Content: fmt.Sprintf(` -problem: %d -You have a six-sided die that you roll once. Let $R{i}$ denote the event that the roll is $i$. Let $G{j}$ denote the event that the roll is greater than $j$. Let $E$ denote the event that the roll of the die is even-numbered. -(a) What is $P\left[R{3} \mid G{1}\right]$, the conditional probability that 3 is rolled given that the roll is greater than 1 ? -(b) What is the conditional probability that 6 is rolled given that the roll is greater than 3 ? -(c) What is the $P\left[G_{3} \mid E\right]$, the conditional probability that the roll is greater than 3 given that the roll is even? -(d) Given that the roll is greater than 3, what is the conditional probability that the roll is even? - `, num, - ), - }, - }, - MaxTokens: 2000, - Stream: true, - }, - ) - a.NoError(err, "CreateCompletionStream error") - - i := 0 - for { - i++ - val, err := strm.Recv() - if errors.Is(err, io.EOF) { - break - } - // t.Logf("%d %s\n", i, val.Choices[0].Delta.Content) - print(val.Choices[0].Delta.Content) - } -} diff --git a/unit_unit_test.go b/unit_unit_test.go new file mode 100644 index 0000000..26d0d46 --- /dev/null +++ b/unit_unit_test.go @@ -0,0 +1,59 @@ +//go:build !test +// +build !test + +package groq_test + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "os" + "testing" + + "github.com/conneroisu/groq-go" + "github.com/stretchr/testify/assert" +) + +func TestTestServer(t *testing.T) { + num := rand.Intn(100) + a := assert.New(t) + ctx := context.Background() + client, err := groq.NewClient(os.Getenv("GROQ_KEY")) + a.NoError(err, "NewClient error") + strm, err := client.CreateChatCompletionStream( + ctx, + groq.ChatCompletionRequest{ + Model: groq.Llama38B8192, + Messages: []groq.ChatCompletionMessage{ + { + Role: groq.ChatMessageRoleUser, + Content: fmt.Sprintf(` +problem: %d +You have a six-sided die that you roll once. Let $R{i}$ denote the event that the roll is $i$. Let $G{j}$ denote the event that the roll is greater than $j$. Let $E$ denote the event that the roll of the die is even-numbered. +(a) What is $P\left[R{3} \mid G{1}\right]$, the conditional probability that 3 is rolled given that the roll is greater than 1 ? +(b) What is the conditional probability that 6 is rolled given that the roll is greater than 3 ? +(c) What is the $P\left[G_{3} \mid E\right]$, the conditional probability that the roll is greater than 3 given that the roll is even? +(d) Given that the roll is greater than 3, what is the conditional probability that the roll is even? + `, num, + ), + }, + }, + MaxTokens: 2000, + Stream: true, + }, + ) + a.NoError(err, "CreateCompletionStream error") + + i := 0 + for { + i++ + val, err := strm.Recv() + if errors.Is(err, io.EOF) { + break + } + // t.Logf("%d %s\n", i, val.Choices[0].Delta.Content) + print(val.Choices[0].Delta.Content) + } +}