diff --git a/.github/workflows/chron-models.yaml b/.github/workflows/chron-models.yaml
index d223d82..f890077 100644
--- a/.github/workflows/chron-models.yaml
+++ b/.github/workflows/chron-models.yaml
@@ -18,7 +18,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
- go-version: '1.23'
+ go-version: '1.23.1'
# Step 3: Run go mod download
- name: Run go mod download
diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 1162c9e..5dfa339 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -25,24 +25,14 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
- go-version: '1.22'
+ go-version: '1.23.1'
cache: true
- name: Install requirements
id: install-lint-requirements
run: |
go mod download
- go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.58.1
- go install github.com/sqlc-dev/sqlc/cmd/sqlc@latest
+ go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
- name: Lint
id: lint
run: |
golangci-lint run
- - name: Vet
- id: vet
- run: |
- sqlc vet
- go vet ./...
- - name: Run Revive Action by pulling pre-built image
- uses: docker://morphy/revive-action:v2
- with:
- config: .revive.toml
diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml
index c9383d6..0b16641 100644
--- a/.github/workflows/coverage.yaml
+++ b/.github/workflows/coverage.yaml
@@ -8,7 +8,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
- go-version: '1.23'
+ go-version: '1.23.1'
- name: Check out code
uses: actions/checkout@v3
- name: Install dependencies
diff --git a/.golangci.reference.yaml b/.golangci.reference.yaml
new file mode 100644
index 0000000..7e7529f
--- /dev/null
+++ b/.golangci.reference.yaml
@@ -0,0 +1,22 @@
+#
+# Options for analysis running.
+run:
+ # See the dedicated "run" documentation section.
+ option: value
+# output configuration options
+output:
+ # See the dedicated "output" documentation section.
+ option: value
+# All available settings of specific linters.
+linters-settings:
+ # See the dedicated "linters-settings" documentation section.
+ option: value
+linters:
+ # See the dedicated "linters" documentation section.
+ option: value
+issues:
+ # See the dedicated "issues" documentation section.
+ option: value
+severity:
+ # See the dedicated "severity" documentation section.
+ option: value
diff --git a/.golangci.yml b/.golangci.yml
new file mode 100644
index 0000000..7cfc908
--- /dev/null
+++ b/.golangci.yml
@@ -0,0 +1,14 @@
+linters-settings:
+ revive:
+ ignore-generated-header: true
+ severity: warning
+ rules:
+ - name: atomic
+ - name: line-length-limit
+ severity: error
+ arguments: [80]
+ - name: unhandled-error
+ arguments: ["fmt.Printf", "myFunction"]
+
+linters:
+ disable-all: false
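
As a rough illustration of what the `unhandled-error` arguments above buy, here is a minimal sketch assuming revive's documented semantics (functions listed in `arguments` may have their returned errors ignored); `myFunction` in the config is a placeholder, not a symbol from this repository:

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	// Permitted under this config: fmt.Printf is listed in the
	// unhandled-error arguments, so ignoring its (n, err) return
	// values is not reported.
	fmt.Printf("hello\n")

	// Reported under this config: os.Remove is not listed, so its
	// error return must be handled.
	if err := os.Remove("stale.lock"); err != nil {
		fmt.Println("remove failed:", err)
	}
}
```
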
diff --git a/audio_test.go b/audio_test.go
index 9789b69..98032eb 100644
--- a/audio_test.go
+++ b/audio_test.go
@@ -1,4 +1,7 @@
-package groq //nolint:testpackage // testing private field
+//go:build !test
+// +build !test
+
+package groq
import (
"bytes"
diff --git a/builders_test.go b/builders_test.go
index f5f819c..f7fd8d1 100644
--- a/builders_test.go
+++ b/builders_test.go
@@ -1,7 +1,7 @@
//go:build !test
// +build !test
-package groq // testing private field
+package groq
import (
"bytes"
diff --git a/chat_test.go b/chat_test.go
index 26642ae..07a7708 100644
--- a/chat_test.go
+++ b/chat_test.go
@@ -1,594 +1,4 @@
-package groq_test
+//go:build !test
+// +build !test
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "testing"
-
- groq "github.com/conneroisu/groq-go"
- "github.com/stretchr/testify/assert"
-)
-
-// TestCreateChatCompletionStream tests the CreateChatCompletionStream method.
-func TestCreateChatCompletionStream(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
-
- // Send test responses
- dataBytes := []byte{}
- dataBytes = append(dataBytes, []byte("event: message\n")...)
- data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("event: message\n")...)
- data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("event: done\n")...)
- dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
-
- stream, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- },
- )
- a.NoError(err, "CreateCompletionStream returned error")
- defer stream.Close()
-
- expectedResponses := []groq.ChatCompletionStreamResponse{
- {
- ID: "1",
- Object: "completion",
- Created: 1598069254,
- Model: groq.Llama38B8192,
- SystemFingerprint: "fp_d9767fc5b9",
- Choices: []groq.ChatCompletionStreamChoice{
- {
- Delta: groq.ChatCompletionStreamChoiceDelta{
- Content: "response1",
- },
- FinishReason: "max_tokens",
- },
- },
- },
- {
- ID: "2",
- Object: "completion",
- Created: 1598069255,
- Model: groq.Llama38B8192,
- SystemFingerprint: "fp_d9767fc5b9",
- Choices: []groq.ChatCompletionStreamChoice{
- {
- Delta: groq.ChatCompletionStreamChoiceDelta{
- Content: "response2",
- },
- FinishReason: "max_tokens",
- },
- },
- },
- }
-
- for ix, expectedResponse := range expectedResponses {
- b, _ := json.Marshal(expectedResponse)
- t.Logf("%d: %s", ix, string(b))
-
- receivedResponse, streamErr := stream.Recv()
- a.NoError(streamErr, "stream.Recv() failed")
- if !compareChatResponses(t, expectedResponse, receivedResponse) {
- t.Errorf(
- "Stream response %v is %v, expected %v",
- ix,
- receivedResponse,
- expectedResponse,
- )
- }
- }
-
- _, streamErr := stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
- }
-
- _, streamErr = stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
- }
-
- _, streamErr = stream.Recv()
-
- a.ErrorIs(
- streamErr,
- io.EOF,
- "stream.Recv() did not return EOF when the stream is finished",
- )
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf(
- "stream.Recv() did not return EOF when the stream is finished: %v",
- streamErr,
- )
- }
-}
-
-// TestCreateChatCompletionStreamError tests the CreateChatCompletionStream function with an error
-// in the response.
-func TestCreateChatCompletionStreamError(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
-
- // Send test responses
- dataBytes := []byte{}
- dataStr := []string{
- `{`,
- `"error": {`,
- `"message": "Incorrect API key provided: gsk-***************************************",`,
- `"type": "invalid_request_error",`,
- `"param": null,`,
- `"code": "invalid_api_key"`,
- `}`,
- `}`,
- }
- for _, str := range dataStr {
- dataBytes = append(dataBytes, []byte(str+"\n")...)
- }
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
-
- stream, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- },
- )
- a.NoError(err, "CreateCompletionStream returned error")
- defer stream.Close()
-
- _, streamErr := stream.Recv()
- a.Error(streamErr, "stream.Recv() did not return error")
-
- var apiErr *groq.APIError
- if !errors.As(streamErr, &apiErr) {
- t.Errorf("stream.Recv() did not return APIError")
- }
- t.Logf("%+v\n", apiErr)
-}
-
-func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- xCustomHeader := "x-custom-header"
- xCustomHeaderValue := "x-custom-header-value"
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- w.Header().Set(xCustomHeader, xCustomHeaderValue)
-
- // Send test responses
- dataBytes := []byte(
- `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`,
- )
- dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
-
- stream, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- },
- )
- a.NoError(err, "CreateCompletionStream returned error")
- defer stream.Close()
-
- value := stream.Header.Get(xCustomHeader)
- if value != xCustomHeaderValue {
- t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
- }
-}
-
-func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
- client, server, teardown := setupGroqTestServer()
- a := assert.New(t)
- rateLimitHeaders := map[string]interface{}{
- "x-ratelimit-limit-requests": 100,
- "x-ratelimit-limit-tokens": 1000,
- "x-ratelimit-remaining-requests": 99,
- "x-ratelimit-remaining-tokens": 999,
- "x-ratelimit-reset-requests": "1s",
- "x-ratelimit-reset-tokens": "1m",
- }
- defer teardown()
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
- for k, v := range rateLimitHeaders {
- switch val := v.(type) {
- case int:
- w.Header().Set(k, strconv.Itoa(val))
- default:
- w.Header().Set(k, fmt.Sprintf("%s", v))
- }
- }
-
- // Send test responses
- dataBytes := []byte(
- `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`,
- )
- dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
-
- stream, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- },
- )
- a.NoError(err, "CreateCompletionStream returned error")
- defer stream.Close()
-
- headers := newRateLimitHeaders(stream.Header)
- bs1, _ := json.Marshal(headers)
- bs2, _ := json.Marshal(rateLimitHeaders)
- if string(bs1) != string(bs2) {
- t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
- }
-}
-
-// newRateLimitHeaders creates a new RateLimitHeaders from an http.Header.
-func newRateLimitHeaders(h http.Header) groq.RateLimitHeaders {
- limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
- limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
- remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
- remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
- return groq.RateLimitHeaders{
- LimitRequests: limitReq,
- LimitTokens: limitTokens,
- RemainingRequests: remainingReq,
- RemainingTokens: remainingTokens,
- ResetRequests: groq.ResetTime(h.Get("x-ratelimit-reset-requests")),
- ResetTokens: groq.ResetTime(h.Get("x-ratelimit-reset-tokens")),
- }
-}
-
-func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
-
- // Send test responses
- dataBytes := []byte(
- `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`,
- )
- dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
-
- stream, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- },
- )
- a.NoError(err, "CreateCompletionStream returned error")
- defer stream.Close()
-
- _, streamErr := stream.Recv()
- a.Error(streamErr, "stream.Recv() did not return error")
-
- var apiErr *groq.APIError
- if !errors.As(streamErr, &apiErr) {
- t.Errorf("stream.Recv() did not return APIError")
- }
- t.Logf("%+v\n", apiErr)
-}
-
-func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(429)
-
- // Send test responses
- dataBytes := []byte(`{"error":{` +
- `"message": "You are sending requests too quickly.",` +
- `"type":"rate_limit_reached",` +
- `"param":null,` +
- `"code":"rate_limit_reached"}}`)
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
- _, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- },
- )
- var apiErr *groq.APIError
- if !errors.As(err, &apiErr) {
- t.Errorf(
- "TestCreateChatCompletionStreamRateLimitError did not return APIError",
- )
- }
- t.Logf("%+v\n", apiErr)
-}
-
-func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
-
- server.RegisterHandler(
- "/v1/chat/completions",
- func(w http.ResponseWriter, _ *http.Request) {
- w.Header().Set("Content-Type", "text/event-stream")
-
- // Send test responses
- var dataBytes []byte
- data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- data = `{"id":"3","object":"completion","created":1598069256,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
- dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
-
- dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
-
- _, err := w.Write(dataBytes)
- a.NoError(err, "Write error")
- },
- )
-
- stream, err := client.CreateChatCompletionStream(
- context.Background(),
- groq.ChatCompletionRequest{
- MaxTokens: 5,
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- Content: "Hello!",
- },
- },
- Stream: true,
- StreamOptions: &groq.StreamOptions{
- IncludeUsage: true,
- },
- },
- )
- a.NoError(err, "CreateCompletionStream returned error")
- defer stream.Close()
- expectedResponses := []groq.ChatCompletionStreamResponse{
- {
- ID: "1",
- Object: "completion",
- Created: 1598069254,
- Model: groq.Llama38B8192,
- SystemFingerprint: "fp_d9767fc5b9",
- Choices: []groq.ChatCompletionStreamChoice{
- {
- Delta: groq.ChatCompletionStreamChoiceDelta{
- Content: "response1",
- },
- FinishReason: "max_tokens",
- },
- },
- },
- {
- ID: "2",
- Object: "completion",
- Created: 1598069255,
- Model: groq.Llama38B8192,
- SystemFingerprint: "fp_d9767fc5b9",
- Choices: []groq.ChatCompletionStreamChoice{
- {
- Delta: groq.ChatCompletionStreamChoiceDelta{
- Content: "response2",
- },
- FinishReason: "max_tokens",
- },
- },
- },
- {
- ID: "3",
- Object: "completion",
- Created: 1598069256,
- Model: groq.Llama38B8192,
- SystemFingerprint: "fp_d9767fc5b9",
- Choices: []groq.ChatCompletionStreamChoice{},
- Usage: &groq.Usage{
- PromptTokens: 1,
- CompletionTokens: 1,
- TotalTokens: 2,
- },
- },
- }
-
- for ix, expectedResponse := range expectedResponses {
- ix++
- b, _ := json.Marshal(expectedResponse)
- t.Logf("%d: %s", ix, string(b))
-
- receivedResponse, streamErr := stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- a.NoError(streamErr, "stream.Recv() failed")
- }
- if !compareChatResponses(t, expectedResponse, receivedResponse) {
- t.Errorf(
- "Stream response %v: %v,BUT expected %v",
- ix,
- receivedResponse,
- expectedResponse,
- )
- }
- }
-
- _, streamErr := stream.Recv()
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
- }
-
- _, streamErr = stream.Recv()
-
- a.ErrorIs(
- streamErr,
- io.EOF,
- "stream.Recv() did not return EOF when the stream is finished",
- )
- if !errors.Is(streamErr, io.EOF) {
- t.Errorf(
- "stream.Recv() did not return EOF when the stream is finished: %v",
- streamErr,
- )
- }
-}
-
-// Helper funcs.
-func compareChatResponses(
- t *testing.T,
- r1, r2 groq.ChatCompletionStreamResponse,
-) bool {
- if r1.ID != r2.ID {
- t.Logf("Not Equal ID: %v", r1.ID)
- return false
- }
- if r1.Object != r2.Object {
- t.Logf("Not Equal Object: %v", r1.Object)
- return false
- }
- if r1.Created != r2.Created {
- t.Logf("Not Equal Created: %v", r1.Created)
- return false
- }
- if len(r1.Choices) != len(r2.Choices) {
- t.Logf("Not Equal Choices: %v", r1.Choices)
- return false
- }
- for i := range r1.Choices {
- if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
- t.Logf("Not Equal Choices: %v", r1.Choices[i])
- return false
- }
- }
- if r1.Usage != nil || r2.Usage != nil {
- if r1.Usage == nil || r2.Usage == nil {
- return false
- }
- if r1.Usage.PromptTokens != r2.Usage.PromptTokens ||
- r1.Usage.CompletionTokens != r2.Usage.CompletionTokens ||
- r1.Usage.TotalTokens != r2.Usage.TotalTokens {
- return false
- }
- }
- return true
-}
-
-func compareChatStreamResponseChoices(
- c1, c2 groq.ChatCompletionStreamChoice,
-) bool {
- if c1.Index != c2.Index {
- return false
- }
- if c1.Delta.Content != c2.Delta.Content {
- return false
- }
- if c1.FinishReason != c2.FinishReason {
- return false
- }
- return true
-}
+package groq
diff --git a/chat_unit_test.go b/chat_unit_test.go
new file mode 100644
index 0000000..26642ae
--- /dev/null
+++ b/chat_unit_test.go
@@ -0,0 +1,594 @@
+package groq_test
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "testing"
+
+ groq "github.com/conneroisu/groq-go"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestCreateChatCompletionStream tests the CreateChatCompletionStream method.
+func TestCreateChatCompletionStream(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ // Send test responses
+ dataBytes := []byte{}
+ dataBytes = append(dataBytes, []byte("event: message\n")...)
+ data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}]}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("event: message\n")...)
+ data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}]}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("event: done\n")...)
+ dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+
+ stream, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ },
+ )
+ a.NoError(err, "CreateCompletionStream returned error")
+ defer stream.Close()
+
+ expectedResponses := []groq.ChatCompletionStreamResponse{
+ {
+ ID: "1",
+ Object: "completion",
+ Created: 1598069254,
+ Model: groq.Llama38B8192,
+ SystemFingerprint: "fp_d9767fc5b9",
+ Choices: []groq.ChatCompletionStreamChoice{
+ {
+ Delta: groq.ChatCompletionStreamChoiceDelta{
+ Content: "response1",
+ },
+ FinishReason: "max_tokens",
+ },
+ },
+ },
+ {
+ ID: "2",
+ Object: "completion",
+ Created: 1598069255,
+ Model: groq.Llama38B8192,
+ SystemFingerprint: "fp_d9767fc5b9",
+ Choices: []groq.ChatCompletionStreamChoice{
+ {
+ Delta: groq.ChatCompletionStreamChoiceDelta{
+ Content: "response2",
+ },
+ FinishReason: "max_tokens",
+ },
+ },
+ },
+ }
+
+ for ix, expectedResponse := range expectedResponses {
+ b, _ := json.Marshal(expectedResponse)
+ t.Logf("%d: %s", ix, string(b))
+
+ receivedResponse, streamErr := stream.Recv()
+ a.NoError(streamErr, "stream.Recv() failed")
+ if !compareChatResponses(t, expectedResponse, receivedResponse) {
+ t.Errorf(
+ "Stream response %v is %v, expected %v",
+ ix,
+ receivedResponse,
+ expectedResponse,
+ )
+ }
+ }
+
+ _, streamErr := stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+ }
+
+ _, streamErr = stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+ }
+
+ _, streamErr = stream.Recv()
+
+ a.ErrorIs(
+ streamErr,
+ io.EOF,
+ "stream.Recv() did not return EOF when the stream is finished",
+ )
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf(
+ "stream.Recv() did not return EOF when the stream is finished: %v",
+ streamErr,
+ )
+ }
+}
+
+// TestCreateChatCompletionStreamError tests the CreateChatCompletionStream function with an error
+// in the response.
+func TestCreateChatCompletionStreamError(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ // Send test responses
+ dataBytes := []byte{}
+ dataStr := []string{
+ `{`,
+ `"error": {`,
+ `"message": "Incorrect API key provided: gsk-***************************************",`,
+ `"type": "invalid_request_error",`,
+ `"param": null,`,
+ `"code": "invalid_api_key"`,
+ `}`,
+ `}`,
+ }
+ for _, str := range dataStr {
+ dataBytes = append(dataBytes, []byte(str+"\n")...)
+ }
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+
+ stream, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ },
+ )
+ a.NoError(err, "CreateCompletionStream returned error")
+ defer stream.Close()
+
+ _, streamErr := stream.Recv()
+ a.Error(streamErr, "stream.Recv() did not return error")
+
+ var apiErr *groq.APIError
+ if !errors.As(streamErr, &apiErr) {
+ t.Errorf("stream.Recv() did not return APIError")
+ }
+ t.Logf("%+v\n", apiErr)
+}
+
+func TestCreateChatCompletionStreamWithHeaders(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ xCustomHeader := "x-custom-header"
+ xCustomHeaderValue := "x-custom-header-value"
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set(xCustomHeader, xCustomHeaderValue)
+
+ // Send test responses
+ dataBytes := []byte(
+ `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`,
+ )
+ dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+
+ stream, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ },
+ )
+ a.NoError(err, "CreateCompletionStream returned error")
+ defer stream.Close()
+
+ value := stream.Header.Get(xCustomHeader)
+ if value != xCustomHeaderValue {
+ t.Errorf("expected %s to be %s", xCustomHeaderValue, value)
+ }
+}
+
+func TestCreateChatCompletionStreamWithRatelimitHeaders(t *testing.T) {
+ client, server, teardown := setupGroqTestServer()
+ a := assert.New(t)
+ rateLimitHeaders := map[string]interface{}{
+ "x-ratelimit-limit-requests": 100,
+ "x-ratelimit-limit-tokens": 1000,
+ "x-ratelimit-remaining-requests": 99,
+ "x-ratelimit-remaining-tokens": 999,
+ "x-ratelimit-reset-requests": "1s",
+ "x-ratelimit-reset-tokens": "1m",
+ }
+ defer teardown()
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+ for k, v := range rateLimitHeaders {
+ switch val := v.(type) {
+ case int:
+ w.Header().Set(k, strconv.Itoa(val))
+ default:
+ w.Header().Set(k, fmt.Sprintf("%s", v))
+ }
+ }
+
+ // Send test responses
+ dataBytes := []byte(
+ `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`,
+ )
+ dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+
+ stream, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ },
+ )
+ a.NoError(err, "CreateCompletionStream returned error")
+ defer stream.Close()
+
+ headers := newRateLimitHeaders(stream.Header)
+ bs1, _ := json.Marshal(headers)
+ bs2, _ := json.Marshal(rateLimitHeaders)
+ if string(bs1) != string(bs2) {
+ t.Errorf("expected rate limit header %s to be %s", bs2, bs1)
+ }
+}
+
+// newRateLimitHeaders creates a new RateLimitHeaders from an http.Header.
+func newRateLimitHeaders(h http.Header) groq.RateLimitHeaders {
+ limitReq, _ := strconv.Atoi(h.Get("x-ratelimit-limit-requests"))
+ limitTokens, _ := strconv.Atoi(h.Get("x-ratelimit-limit-tokens"))
+ remainingReq, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-requests"))
+ remainingTokens, _ := strconv.Atoi(h.Get("x-ratelimit-remaining-tokens"))
+ return groq.RateLimitHeaders{
+ LimitRequests: limitReq,
+ LimitTokens: limitTokens,
+ RemainingRequests: remainingReq,
+ RemainingTokens: remainingTokens,
+ ResetRequests: groq.ResetTime(h.Get("x-ratelimit-reset-requests")),
+ ResetTokens: groq.ResetTime(h.Get("x-ratelimit-reset-tokens")),
+ }
+}
+
+func TestCreateChatCompletionStreamErrorWithDataPrefix(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ // Send test responses
+ dataBytes := []byte(
+ `data: {"error":{"message":"The server had an error while processing your request. Sorry about that!", "type":"server_ error", "param":null,"code":null}}`,
+ )
+ dataBytes = append(dataBytes, []byte("\n\ndata: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+
+ stream, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ },
+ )
+ a.NoError(err, "CreateCompletionStream returned error")
+ defer stream.Close()
+
+ _, streamErr := stream.Recv()
+ a.Error(streamErr, "stream.Recv() did not return error")
+
+ var apiErr *groq.APIError
+ if !errors.As(streamErr, &apiErr) {
+ t.Errorf("stream.Recv() did not return APIError")
+ }
+ t.Logf("%+v\n", apiErr)
+}
+
+func TestCreateChatCompletionStreamRateLimitError(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(429)
+
+ // Send test responses
+ dataBytes := []byte(`{"error":{` +
+ `"message": "You are sending requests too quickly.",` +
+ `"type":"rate_limit_reached",` +
+ `"param":null,` +
+ `"code":"rate_limit_reached"}}`)
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+ _, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ },
+ )
+ var apiErr *groq.APIError
+ if !errors.As(err, &apiErr) {
+ t.Errorf(
+ "TestCreateChatCompletionStreamRateLimitError did not return APIError",
+ )
+ }
+ t.Logf("%+v\n", apiErr)
+}
+
+func TestCreateChatCompletionStreamStreamOptions(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "text/event-stream")
+
+ // Send test responses
+ var dataBytes []byte
+ data := `{"id":"1","object":"completion","created":1598069254,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response1"},"finish_reason":"max_tokens"}],"usage":null}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ data = `{"id":"2","object":"completion","created":1598069255,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[{"index":0,"delta":{"content":"response2"},"finish_reason":"max_tokens"}],"usage":null}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ data = `{"id":"3","object":"completion","created":1598069256,"model":"llama3-groq-70b-8192-tool-use-preview","system_fingerprint": "fp_d9767fc5b9","choices":[],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}`
+ dataBytes = append(dataBytes, []byte("data: "+data+"\n\n")...)
+
+ dataBytes = append(dataBytes, []byte("data: [DONE]\n\n")...)
+
+ _, err := w.Write(dataBytes)
+ a.NoError(err, "Write error")
+ },
+ )
+
+ stream, err := client.CreateChatCompletionStream(
+ context.Background(),
+ groq.ChatCompletionRequest{
+ MaxTokens: 5,
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "Hello!",
+ },
+ },
+ Stream: true,
+ StreamOptions: &groq.StreamOptions{
+ IncludeUsage: true,
+ },
+ },
+ )
+ a.NoError(err, "CreateCompletionStream returned error")
+ defer stream.Close()
+ expectedResponses := []groq.ChatCompletionStreamResponse{
+ {
+ ID: "1",
+ Object: "completion",
+ Created: 1598069254,
+ Model: groq.Llama38B8192,
+ SystemFingerprint: "fp_d9767fc5b9",
+ Choices: []groq.ChatCompletionStreamChoice{
+ {
+ Delta: groq.ChatCompletionStreamChoiceDelta{
+ Content: "response1",
+ },
+ FinishReason: "max_tokens",
+ },
+ },
+ },
+ {
+ ID: "2",
+ Object: "completion",
+ Created: 1598069255,
+ Model: groq.Llama38B8192,
+ SystemFingerprint: "fp_d9767fc5b9",
+ Choices: []groq.ChatCompletionStreamChoice{
+ {
+ Delta: groq.ChatCompletionStreamChoiceDelta{
+ Content: "response2",
+ },
+ FinishReason: "max_tokens",
+ },
+ },
+ },
+ {
+ ID: "3",
+ Object: "completion",
+ Created: 1598069256,
+ Model: groq.Llama38B8192,
+ SystemFingerprint: "fp_d9767fc5b9",
+ Choices: []groq.ChatCompletionStreamChoice{},
+ Usage: &groq.Usage{
+ PromptTokens: 1,
+ CompletionTokens: 1,
+ TotalTokens: 2,
+ },
+ },
+ }
+
+ for ix, expectedResponse := range expectedResponses {
+ ix++
+ b, _ := json.Marshal(expectedResponse)
+ t.Logf("%d: %s", ix, string(b))
+
+ receivedResponse, streamErr := stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ a.NoError(streamErr, "stream.Recv() failed")
+ }
+ if !compareChatResponses(t, expectedResponse, receivedResponse) {
+ t.Errorf(
+ "Stream response %v: %v,BUT expected %v",
+ ix,
+ receivedResponse,
+ expectedResponse,
+ )
+ }
+ }
+
+ _, streamErr := stream.Recv()
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf("stream.Recv() did not return EOF in the end: %v", streamErr)
+ }
+
+ _, streamErr = stream.Recv()
+
+ a.ErrorIs(
+ streamErr,
+ io.EOF,
+ "stream.Recv() did not return EOF when the stream is finished",
+ )
+ if !errors.Is(streamErr, io.EOF) {
+ t.Errorf(
+ "stream.Recv() did not return EOF when the stream is finished: %v",
+ streamErr,
+ )
+ }
+}
+
+// Helper funcs.
+func compareChatResponses(
+ t *testing.T,
+ r1, r2 groq.ChatCompletionStreamResponse,
+) bool {
+ if r1.ID != r2.ID {
+ t.Logf("Not Equal ID: %v", r1.ID)
+ return false
+ }
+ if r1.Object != r2.Object {
+ t.Logf("Not Equal Object: %v", r1.Object)
+ return false
+ }
+ if r1.Created != r2.Created {
+ t.Logf("Not Equal Created: %v", r1.Created)
+ return false
+ }
+ if len(r1.Choices) != len(r2.Choices) {
+ t.Logf("Not Equal Choices: %v", r1.Choices)
+ return false
+ }
+ for i := range r1.Choices {
+ if !compareChatStreamResponseChoices(r1.Choices[i], r2.Choices[i]) {
+ t.Logf("Not Equal Choices: %v", r1.Choices[i])
+ return false
+ }
+ }
+ if r1.Usage != nil || r2.Usage != nil {
+ if r1.Usage == nil || r2.Usage == nil {
+ return false
+ }
+ if r1.Usage.PromptTokens != r2.Usage.PromptTokens ||
+ r1.Usage.CompletionTokens != r2.Usage.CompletionTokens ||
+ r1.Usage.TotalTokens != r2.Usage.TotalTokens {
+ return false
+ }
+ }
+ return true
+}
+
+func compareChatStreamResponseChoices(
+ c1, c2 groq.ChatCompletionStreamChoice,
+) bool {
+ if c1.Index != c2.Index {
+ return false
+ }
+ if c1.Delta.Content != c2.Delta.Content {
+ return false
+ }
+ if c1.FinishReason != c2.FinishReason {
+ return false
+ }
+ return true
+}
diff --git a/client_test.go b/client_test.go
index 763aef6..07a7708 100644
--- a/client_test.go
+++ b/client_test.go
@@ -1,36 +1,4 @@
-package groq_test
+//go:build !test
+// +build !test
-import (
- "log"
- "testing"
-
- groq "github.com/conneroisu/groq-go"
- "github.com/conneroisu/groq-go/internal/test"
- "github.com/stretchr/testify/assert"
-)
-
-func setupGroqTestServer() (
- client *groq.Client,
- server *test.ServerTest,
- teardown func(),
-) {
- server = test.NewTestServer()
- ts := server.GroqTestServer()
- ts.Start()
- teardown = ts.Close
- client, err := groq.NewClient(
- test.GetTestToken(),
- groq.WithBaseURL(ts.URL+"/v1"),
- )
- if err != nil {
- log.Fatal(err)
- }
- return
-}
-
-func TestEmptyKeyClientCreation(t *testing.T) {
- client, err := groq.NewClient("")
- a := assert.New(t)
- a.Error(err, "NewClient should return error")
- a.Nil(client, "NewClient should return nil")
-}
+package groq
diff --git a/client_unit_test.go b/client_unit_test.go
new file mode 100644
index 0000000..763aef6
--- /dev/null
+++ b/client_unit_test.go
@@ -0,0 +1,36 @@
+package groq_test
+
+import (
+ "log"
+ "testing"
+
+ groq "github.com/conneroisu/groq-go"
+ "github.com/conneroisu/groq-go/internal/test"
+ "github.com/stretchr/testify/assert"
+)
+
+func setupGroqTestServer() (
+ client *groq.Client,
+ server *test.ServerTest,
+ teardown func(),
+) {
+ server = test.NewTestServer()
+ ts := server.GroqTestServer()
+ ts.Start()
+ teardown = ts.Close
+ client, err := groq.NewClient(
+ test.GetTestToken(),
+ groq.WithBaseURL(ts.URL+"/v1"),
+ )
+ if err != nil {
+ log.Fatal(err)
+ }
+ return
+}
+
+func TestEmptyKeyClientCreation(t *testing.T) {
+ client, err := groq.NewClient("")
+ a := assert.New(t)
+ a.Error(err, "NewClient should return error")
+ a.Nil(client, "NewClient should return nil")
+}
diff --git a/cmd/models/main.go b/cmd/models/main.go
index 195c7f3..7332025 100644
--- a/cmd/models/main.go
+++ b/cmd/models/main.go
@@ -14,6 +14,7 @@ import (
"log"
"net/http"
"os"
+ "sort"
"text/template"
"time"
@@ -24,7 +25,7 @@ const (
outputFile = "models.go"
)
-//go:embed models.go.txt
+//go:embed models.go.tmpl
var outputFileTemplate string
type templateParams struct {
@@ -135,13 +136,13 @@ func fillTemplate(w io.Writer, models []ResponseModel) error {
}
return false
},
- // isModerationModel returns true if the model is
+ // notModerationModel returns false if the model is
// a model that can be used for moderation.
//
// llama-guard-3-8b is a moderation model.
- "isModerationModel": func(model ResponseModel) bool {
+ "notModerationModel": func(model ResponseModel) bool {
// if the id of the model is llama-guard-3-8b
- return model.ID == "llama-guard-3-8b"
+ return model.ID != "llama-guard-3-8b"
},
// getCurrentDate returns the current date in the format
// "2006-01-02 15:04:05".
@@ -149,7 +150,7 @@ func fillTemplate(w io.Writer, models []ResponseModel) error {
return time.Now().Format("2006-01-02 15:04:05")
},
}
- tmpla, err := template.New("output").
+ tmpla, err := template.New("models").
Funcs(funcMap).
Parse(outputFileTemplate)
if err != nil {
@@ -168,4 +169,8 @@ func nameModels(models []ResponseModel) {
models[i].Name = lo.PascalCase(models[i].ID)
}
}
+ // sort models by name alphabetically
+ sort.Slice(models, func(i, j int) bool {
+ return models[i].Name < models[j].Name
+ })
}
diff --git a/cmd/models/models.go.txt b/cmd/models/models.go.tmpl
similarity index 62%
rename from cmd/models/models.go.txt
rename to cmd/models/models.go.tmpl
index 2cbf2ef..2eaf968 100644
--- a/cmd/models/models.go.txt
+++ b/cmd/models/models.go.tmpl
@@ -1,3 +1,4 @@
+{{ define "models" }}
// Code generated by groq-modeler DO NOT EDIT.
//
// Created at: {{ getCurrentDate }}
@@ -44,7 +45,7 @@ const (
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" // TranscriptionTimestampGranularitySegment is the segment timestamp granularity.
{{- range $model := .Models }}
- {{ $model.Name }} Model = "{{ $model.ID }}" // {{ $model.Name }} is an AI {{if isTextModel $model}}text{{else if isAudioModel $model}}audio{{else if isModerationModel $model}}moderation{{end}} model provided by {{$model.OwnedBy}}. It has {{$model.ContextWindow}} context window.
+ {{ $model.Name }} Model = "{{ $model.ID }}" // {{ $model.Name }} is an AI {{if isTextModel $model}}text{{else if isAudioModel $model}}audio{{else if notModerationModel $model}}moderation{{end}} model provided by {{$model.OwnedBy}}. It has a context window of {{$model.ContextWindow}} tokens.
{{- end }}
)
@@ -66,7 +67,7 @@ var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{
{{ $model.Name }}: true, {{- end }} {{- end }}
},
moderationsSuffix: {
- {{- range $model := .Models }} {{ if isModerationModel $model }}
+ {{- range $model := .Models }} {{ if notModerationModel $model }}
{{ $model.Name }}: true, {{- end }} {{- end }}
},
}
@@ -74,3 +75,56 @@ var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{
func endpointSupportsModel(endpoint Endpoint, model Model) bool {
return !disabledModelsForEndpoints[endpoint][model]
}
+{{ end }}
+
+{{ define "test" }}
+package groq_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+{{- range $model := .TextModels }}
+// Test{{ $model.Name }} tests the {{ $model.Name }} model.
+// It ensures that the model is supported by the groq-go library, the groq API,
+// and the operations are working as expected for the specific model type.
+func Test{{ $model.Name }}(t *testing.T) {
+}
+{{- end }}
+
+{{- range $model := .AudioModels }}
+// Test{{ $model.Name }} tests the {{ $model.Name }} model.
+// It ensures that the model is supported by the groq-go library, the groq API,
+// and the operations are working as expected for the specific model type.
+func Test{{ $model.Name }}(t *testing.T) {
+}
+{{- end }}
+
+{{- range $model := .ModerationModels }}
+// Test{{ $model.Name }} tests the {{ $model.Name }} model.
+// It ensures that the model is supported by the groq-go library, the groq API,
+// and the operations are working as expected for the specific model type.
+func Test{{ $model.Name }}(t *testing.T) {
+}
+{{- end }}
+
+{{- range $model := .MultiModalModels }}
+// Test{{ $model.Name }} tests the {{ $model.Name }} model.
+// It ensures that the model is supported by the groq-go library, the groq API,
+// and the operations are working as expected for the specific model type.
+func Test{{ $model.Name }}(t *testing.T) {
+}
+{{- end }}
+{{ end }}
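
Because the template body is now wrapped in `{{ define "models" }}` and a second `{{ define "test" }}` block was added, the generator presumably renders each block by name via `ExecuteTemplate`. A minimal sketch of that pattern with `text/template` (the inline template text and data shape below are illustrative only, not the repository's actual generator code):

```go
package main

import (
	"os"
	"text/template"
)

func main() {
	// Parsing a template that contains named define blocks registers
	// each block; ExecuteTemplate then renders a block by name.
	tmpl := template.Must(template.New("models").Parse(
		`{{ define "models" }}// generated: {{ len .Models }} models
{{ end }}{{ define "test" }}package groq_test
{{ end }}`,
	))
	data := struct{ Models []string }{Models: []string{"Llama38B8192"}}
	if err := tmpl.ExecuteTemplate(os.Stdout, "models", data); err != nil {
		panic(err)
	}
	if err := tmpl.ExecuteTemplate(os.Stdout, "test", data); err != nil {
		panic(err)
	}
}
```
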
diff --git a/completion_test.go b/completion_test.go
index 915e6b0..07a7708 100644
--- a/completion_test.go
+++ b/completion_test.go
@@ -1,150 +1,4 @@
-package groq_test
+//go:build !test
+// +build !test
-import (
- "context"
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "strconv"
- "strings"
- "testing"
- "time"
-
- "github.com/conneroisu/groq-go"
- "github.com/stretchr/testify/assert"
-)
-
-/* // TestCompletionsWrongModel tests the CreateCompletion method with a wrong model. */
-/* func TestCompletionsWrongModel(t *testing.T) { */
-/* a := assert.New(t) */
-/* client, err := groq.NewClient( */
-/* "whatever", */
-/* groq.WithBaseURL("http://localhost/v1"), */
-/* ) */
-/* a.NoError(err, "NewClient error") */
-/* */
-/* _, err = client.CreateCompletion( */
-/* context.Background(), */
-/* groq.CompletionRequest{ */
-/* MaxTokens: 5, */
-/* Model: groq.Whisper_Large_V3, */
-/* }, */
-/* ) */
-/* if !errors.Is(err, groq.ErrCompletionUnsupportedModel{Model: groq.Whisper_Large_V3}) { */
-/* t.Fatalf( */
-/* "CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", */
-/* err, */
-/* ) */
-/* } */
-/* } */
-
-// TestCompletionWithStream tests the CreateCompletion method with a stream.
-func TestCompletionWithStream(t *testing.T) {
- a := assert.New(t)
- client, err := groq.NewClient(
- "whatever",
- groq.WithBaseURL("http://localhost/v1"),
- )
- a.NoError(err, "NewClient error")
-
- ctx := context.Background()
- req := groq.CompletionRequest{Stream: true}
- _, err = client.CreateCompletion(ctx, req)
- if !errors.Is(err, groq.ErrCompletionStreamNotSupported{}) {
- t.Fatalf(
- "CreateCompletion didn't return ErrCompletionStreamNotSupported",
- )
- }
-}
-
-// TestCompletions Tests the completions endpoint of the API using the mocked server.
-func TestCompletions(t *testing.T) {
- a := assert.New(t)
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
- req := groq.CompletionRequest{
- MaxTokens: 5,
- Model: "ada",
- Prompt: "Lorem ipsum",
- }
- _, err := client.CreateCompletion(context.Background(), req)
- a.NoError(err, "CreateCompletion error")
-}
-
-// handleCompletionEndpoint Handles the completion endpoint by the test server.
-func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
- var err error
- var resBytes []byte
-
- // completions only accepts POST requests
- if r.Method != "POST" {
- http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
- }
- var completionReq groq.CompletionRequest
- if completionReq, err = getCompletionBody(r); err != nil {
- http.Error(w, "could not read request", http.StatusInternalServerError)
- return
- }
- res := groq.CompletionResponse{
- ID: strconv.Itoa(int(time.Now().Unix())),
- Object: "test-object",
- Created: time.Now().Unix(),
- // would be nice to validate Model during testing, but
- // this may not be possible with how much upkeep
- // would be required / wouldn't make much sense
- Model: completionReq.Model,
- }
- // create completions
- n := completionReq.N
- if n == 0 {
- n = 1
- }
- for i := 0; i < n; i++ {
- // generate a random string of length completionReq.Length
- completionStr := strings.Repeat("a", completionReq.MaxTokens)
- if completionReq.Echo {
- completionStr = completionReq.Prompt.(string) + completionStr
- }
- res.Choices = append(res.Choices, groq.CompletionChoice{
- Text: completionStr,
- Index: i,
- })
- }
- inputTokens := numTokens(completionReq.Prompt.(string)) * n
- completionTokens := completionReq.MaxTokens * n
- res.Usage = groq.Usage{
- PromptTokens: inputTokens,
- CompletionTokens: completionTokens,
- TotalTokens: inputTokens + completionTokens,
- }
- resBytes, _ = json.Marshal(res)
- fmt.Fprintln(w, string(resBytes))
-}
-
-// numTokens Returns the number of GPT-3 encoded tokens in the given text.
-// This function approximates based on the rule of thumb stated by OpenAI:
-// https://beta.openai.com/tokenizer
-//
-// TODO: implement an actual tokenizer for each model available and use that
-// instead.
-func numTokens(s string) int {
- return int(float32(len(s)) / 4)
-}
-
-// getCompletionBody Returns the body of the request to create a completion.
-func getCompletionBody(r *http.Request) (groq.CompletionRequest, error) {
- completion := groq.CompletionRequest{}
- // read the request body
- reqBody, err := io.ReadAll(r.Body)
- if err != nil {
- return groq.CompletionRequest{}, err
- }
- err = json.Unmarshal(reqBody, &completion)
- if err != nil {
- return groq.CompletionRequest{}, err
- }
- return completion, nil
-}
+package groq
diff --git a/completion_unit_test.go b/completion_unit_test.go
new file mode 100644
index 0000000..915e6b0
--- /dev/null
+++ b/completion_unit_test.go
@@ -0,0 +1,150 @@
+package groq_test
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/conneroisu/groq-go"
+ "github.com/stretchr/testify/assert"
+)
+
+/* // TestCompletionsWrongModel tests the CreateCompletion method with a wrong model. */
+/* func TestCompletionsWrongModel(t *testing.T) { */
+/* a := assert.New(t) */
+/* client, err := groq.NewClient( */
+/* "whatever", */
+/* groq.WithBaseURL("http://localhost/v1"), */
+/* ) */
+/* a.NoError(err, "NewClient error") */
+/* */
+/* _, err = client.CreateCompletion( */
+/* context.Background(), */
+/* groq.CompletionRequest{ */
+/* MaxTokens: 5, */
+/* Model: groq.Whisper_Large_V3, */
+/* }, */
+/* ) */
+/* if !errors.Is(err, groq.ErrCompletionUnsupportedModel{Model: groq.Whisper_Large_V3}) { */
+/* t.Fatalf( */
+/* "CreateCompletion should return ErrCompletionUnsupportedModel, but returned: %v", */
+/* err, */
+/* ) */
+/* } */
+/* } */
+
+// TestCompletionWithStream tests the CreateCompletion method with a stream.
+func TestCompletionWithStream(t *testing.T) {
+ a := assert.New(t)
+ client, err := groq.NewClient(
+ "whatever",
+ groq.WithBaseURL("http://localhost/v1"),
+ )
+ a.NoError(err, "NewClient error")
+
+ ctx := context.Background()
+ req := groq.CompletionRequest{Stream: true}
+ _, err = client.CreateCompletion(ctx, req)
+ if !errors.Is(err, groq.ErrCompletionStreamNotSupported{}) {
+ t.Fatalf(
+ "CreateCompletion didn't return ErrCompletionStreamNotSupported",
+ )
+ }
+}
+
+// TestCompletions Tests the completions endpoint of the API using the mocked server.
+func TestCompletions(t *testing.T) {
+ a := assert.New(t)
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ server.RegisterHandler("/v1/completions", handleCompletionEndpoint)
+ req := groq.CompletionRequest{
+ MaxTokens: 5,
+ Model: "ada",
+ Prompt: "Lorem ipsum",
+ }
+ _, err := client.CreateCompletion(context.Background(), req)
+ a.NoError(err, "CreateCompletion error")
+}
+
+// handleCompletionEndpoint Handles the completion endpoint by the test server.
+func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
+ var err error
+ var resBytes []byte
+
+ // completions only accepts POST requests
+ if r.Method != "POST" {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ var completionReq groq.CompletionRequest
+ if completionReq, err = getCompletionBody(r); err != nil {
+ http.Error(w, "could not read request", http.StatusInternalServerError)
+ return
+ }
+ res := groq.CompletionResponse{
+ ID: strconv.Itoa(int(time.Now().Unix())),
+ Object: "test-object",
+ Created: time.Now().Unix(),
+ // would be nice to validate Model during testing, but
+ // this may not be possible with how much upkeep
+ // would be required / wouldn't make much sense
+ Model: completionReq.Model,
+ }
+ // create completions
+ n := completionReq.N
+ if n == 0 {
+ n = 1
+ }
+ for i := 0; i < n; i++ {
+ // generate a random string of length completionReq.Length
+ completionStr := strings.Repeat("a", completionReq.MaxTokens)
+ if completionReq.Echo {
+ completionStr = completionReq.Prompt.(string) + completionStr
+ }
+ res.Choices = append(res.Choices, groq.CompletionChoice{
+ Text: completionStr,
+ Index: i,
+ })
+ }
+ inputTokens := numTokens(completionReq.Prompt.(string)) * n
+ completionTokens := completionReq.MaxTokens * n
+ res.Usage = groq.Usage{
+ PromptTokens: inputTokens,
+ CompletionTokens: completionTokens,
+ TotalTokens: inputTokens + completionTokens,
+ }
+ resBytes, _ = json.Marshal(res)
+ fmt.Fprintln(w, string(resBytes))
+}
+
+// numTokens Returns the number of GPT-3 encoded tokens in the given text.
+// This function approximates based on the rule of thumb stated by OpenAI:
+// https://beta.openai.com/tokenizer
+//
+// TODO: implement an actual tokenizer for each model available and use that
+// instead.
+func numTokens(s string) int {
+ return int(float32(len(s)) / 4)
+}
+
+// getCompletionBody Returns the body of the request to create a completion.
+func getCompletionBody(r *http.Request) (groq.CompletionRequest, error) {
+ completion := groq.CompletionRequest{}
+ // read the request body
+ reqBody, err := io.ReadAll(r.Body)
+ if err != nil {
+ return groq.CompletionRequest{}, err
+ }
+ err = json.Unmarshal(reqBody, &completion)
+ if err != nil {
+ return groq.CompletionRequest{}, err
+ }
+ return completion, nil
+}
diff --git a/errors.go b/errors.go
index 32de196..d8580eb 100644
--- a/errors.go
+++ b/errors.go
@@ -22,7 +22,8 @@ type APIError struct {
HTTPStatusCode int `json:"-"` // HTTPStatusCode is the status code of the error.
}
-// ErrChatCompletionInvalidModel is an error that occurs when the model is not supported with the CreateChatCompletion method.
+// ErrChatCompletionInvalidModel is an error that occurs when the model is not
+// supported by the requested endpoint.
type ErrChatCompletionInvalidModel struct {
Model Model
Endpoint Endpoint
@@ -38,7 +39,8 @@ func (e ErrChatCompletionInvalidModel) Error() string {
Error()
}
-// ErrChatCompletionStreamNotSupported is an error that occurs when streaming is not supported with the CreateChatCompletionStream method.
+// ErrChatCompletionStreamNotSupported is an error that occurs when streaming
+// is not supported with the CreateChatCompletionStream method.
type ErrChatCompletionStreamNotSupported struct {
model Model
}
@@ -49,7 +51,8 @@ func (e ErrChatCompletionStreamNotSupported) Error() string {
Error()
}
-// ErrContentFieldsMisused is an error that occurs when both Content and MultiContent properties are set.
+// ErrContentFieldsMisused is an error that occurs when both Content and
+// MultiContent properties are set.
type ErrContentFieldsMisused struct {
field string
}
@@ -70,7 +73,8 @@ func (e ErrCompletionUnsupportedModel) Error() string {
Error()
}
-// ErrCompletionStreamNotSupported is an error that occurs when streaming is not supported with the CreateCompletionStream method.
+// ErrCompletionStreamNotSupported is an error that occurs when streaming is
+// not supported with the CreateCompletionStream method.
type ErrCompletionStreamNotSupported struct{}
// Error implements the error interface.
@@ -79,7 +83,8 @@ func (e ErrCompletionStreamNotSupported) Error() string {
Error()
}
-// ErrCompletionRequestPromptTypeNotSupported is an error that occurs when the type of CompletionRequest.Prompt only supports string and []string.
+// ErrCompletionRequestPromptTypeNotSupported is an error that occurs when the
+// type of CompletionRequest.Prompt is not a string or a []string.
type ErrCompletionRequestPromptTypeNotSupported struct{}
// Error implements the error interface.
@@ -88,7 +93,8 @@ func (e ErrCompletionRequestPromptTypeNotSupported) Error() string {
Error()
}
-// ErrTooManyEmptyStreamMessages is returned when the stream has sent too many empty messages.
+// ErrTooManyEmptyStreamMessages is returned when the stream has sent too many
+// empty messages.
type ErrTooManyEmptyStreamMessages struct{}
// Error returns the error message.
@@ -98,11 +104,11 @@ func (e ErrTooManyEmptyStreamMessages) Error() string {
// errorAccumulator is an interface for accumulating errors
type errorAccumulator interface {
- // Write writes bytes to the error accumulator
+ // Write method writes bytes to the error accumulator
//
// It implements the io.Writer interface.
Write(p []byte) error
- // Bytes returns the bytes of the error accumulator.
+ // Bytes method returns the bytes of the error accumulator.
Bytes() []byte
}
@@ -131,7 +137,7 @@ func newErrorAccumulator() errorAccumulator {
}
}
-// Write writes bytes to the error accumulator.
+// Write method writes bytes to the error accumulator.
func (e *DefaultErrorAccumulator) Write(p []byte) error {
_, err := e.Buffer.Write(p)
if err != nil {
@@ -140,7 +146,7 @@ func (e *DefaultErrorAccumulator) Write(p []byte) error {
return nil
}
-// Bytes returns the bytes of the error accumulator.
+// Bytes method returns the bytes of the error accumulator.
func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
if e.Buffer.Len() == 0 {
return
@@ -149,7 +155,7 @@ func (e *DefaultErrorAccumulator) Bytes() (errBytes []byte) {
return
}
-// Error implements the error interface.
+// Error method implements the error interface on APIError.
func (e *APIError) Error() string {
if e.HTTPStatusCode > 0 {
return fmt.Sprintf(
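
For context on how the error types documented above are consumed, callers unwrap them with `errors.As`, exactly as the streaming tests in this diff do for `*groq.APIError`. A minimal sketch of that pattern (the request shape and model mirror the tests; this file is not part of the diff):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"os"

	groq "github.com/conneroisu/groq-go"
)

func main() {
	client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
	if err != nil {
		fmt.Println(err)
		os.Exit(1)
	}
	_, err = client.CreateChatCompletion(
		context.Background(),
		groq.ChatCompletionRequest{
			Model: groq.Llama38B8192,
			Messages: []groq.ChatCompletionMessage{
				{Role: groq.ChatMessageRoleUser, Content: "Hello!"},
			},
		},
	)
	if err != nil {
		// APIError carries the HTTP status code and the decoded
		// error body, so callers can branch on it after unwrapping.
		var apiErr *groq.APIError
		if errors.As(err, &apiErr) {
			fmt.Println("API error, status:", apiErr.HTTPStatusCode)
			return
		}
		fmt.Println("other error:", err)
	}
}
```
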
diff --git a/examples/llava-blind/main.go b/examples/llava-blind/main.go
index d526488..ffdbcd0 100644
--- a/examples/llava-blind/main.go
+++ b/examples/llava-blind/main.go
@@ -28,27 +28,30 @@ func run(
return err
}
- response, err := client.CreateChatCompletion(ctx, groq.ChatCompletionRequest{
- Model: groq.LlavaV157B4096Preview,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- MultiContent: []groq.ChatMessagePart{
- {
- Type: groq.ChatMessagePartTypeText,
- Text: "What is the contents of the image?",
- },
- {
- Type: groq.ChatMessagePartTypeImageURL,
- ImageURL: &groq.ChatMessageImageURL{
- URL: "https://cdnimg.webstaurantstore.com/images/products/large/87539/251494.jpg",
- Detail: "auto",
+ response, err := client.CreateChatCompletion(
+ ctx,
+ groq.ChatCompletionRequest{
+ Model: groq.LlavaV157B4096Preview,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ MultiContent: []groq.ChatMessagePart{
+ {
+ Type: groq.ChatMessagePartTypeText,
+ Text: "What is the contents of the image?",
},
- }},
+ {
+ Type: groq.ChatMessagePartTypeImageURL,
+ ImageURL: &groq.ChatMessageImageURL{
+ URL: "https://cdnimg.webstaurantstore.com/images/products/large/87539/251494.jpg",
+ Detail: "auto",
+ },
+ }},
+ },
},
+ MaxTokens: 2000,
},
- MaxTokens: 2000,
- })
+ )
if err != nil {
return err
}
diff --git a/examples/moderation/README.md b/examples/moderation/README.md
new file mode 100644
index 0000000..3c9e9f3
--- /dev/null
+++ b/examples/moderation/README.md
@@ -0,0 +1,12 @@
+# moderation
+
+This is an example of using groq-go to moderate chat messages using the Llama Guard 3 8B (`llama-guard-3-8b`) model.
+
+## Usage
+
+Make sure you have a Groq API key set in the environment variable `GROQ_KEY`.
+
+```bash
+export GROQ_KEY=your-groq-key
+go run .
+```
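For quick reference, here is a sketch of the call the example's `main.go` (added in the next hunk) makes: it asks the Llama Guard 3 8B model to classify a single user message and prints the returned categories. It reuses only API names that appear in this diff (`NewClient`, `Moderate`, `ModerationRequest`, `LlamaGuard38B`); error handling is reduced to `panic` to keep the sketch short, so treat it as an illustration rather than a drop-in program.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/conneroisu/groq-go"
)

func main() {
	// NewClient reads the API key passed in; here it comes from GROQ_KEY.
	client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
	if err != nil {
		panic(err)
	}
	// Ask the Llama Guard 3 8B model to classify a single user message.
	resp, err := client.Moderate(context.Background(), groq.ModerationRequest{
		Model: groq.LlamaGuard38B,
		Messages: []groq.ChatCompletionMessage{
			{Role: groq.ChatMessageRoleUser, Content: "I want to kill them."},
		},
	})
	if err != nil {
		panic(err)
	}
	// Categories lists the harm categories the model flagged.
	fmt.Println(resp.Categories)
}
```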
diff --git a/examples/moderation/main.go b/examples/moderation/main.go
new file mode 100644
index 0000000..8af45c5
--- /dev/null
+++ b/examples/moderation/main.go
@@ -0,0 +1,44 @@
+// Package main is an example of using groq-go to moderate chat messages
+// with the Llama Guard 3 8B (llama-guard-3-8b) model.
+package main
+
+import (
+ "context"
+ "fmt"
+ "os"
+
+ "github.com/conneroisu/groq-go"
+)
+
+func main() {
+ ctx := context.Background()
+ if err := run(ctx); err != nil {
+ fmt.Println(err)
+ os.Exit(1)
+ }
+}
+
+func run(
+ ctx context.Context,
+) error {
+ key := os.Getenv("GROQ_KEY")
+ client, err := groq.NewClient(key)
+ if err != nil {
+ return err
+ }
+ response, err := client.Moderate(ctx, groq.ModerationRequest{
+ Model: groq.LlamaGuard38B,
+ // Input: "I want to kill them.",
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "I want to kill them.",
+ },
+ },
+ })
+ if err != nil {
+ return err
+ }
+ fmt.Println(response.Categories)
+ return nil
+}
diff --git a/examples/vhdl-documentor-json/main.go b/examples/vhdl-documentor-json/main.go
new file mode 100644
index 0000000..8bf757e
--- /dev/null
+++ b/examples/vhdl-documentor-json/main.go
@@ -0,0 +1,382 @@
+// Package main uses groq-go and the llama3-groq-70b-8192-tool-use-preview
+// model to generate documentation headers for VHDL project files.
+package main
+
+import (
+ "bytes"
+ "context"
+ _ "embed"
+ "fmt"
+ "os"
+ "strings"
+ "text/template"
+ "time"
+
+ "github.com/charmbracelet/log"
+ "github.com/conneroisu/groq-go"
+)
+
+var (
+ //go:embed template.tmpl
+ emTempl string
+ codeTemplate *template.Template
+ headerTemplate *template.Template
+)
+
+func init() {
+ var err error
+ codeTemplate, err = template.New("code").Parse(emTempl)
+ if err != nil {
+ log.Fatal(err)
+ }
+ headerTemplate, err = template.New("header").
+ Funcs(template.FuncMap{}).
+ Parse(emTempl)
+ if err != nil {
+ log.Fatal(err)
+ }
+}
+func main() {
+ log.SetLevel(log.DebugLevel)
+ log.SetReportCaller(true)
+ ctx := context.Background()
+ err := run(ctx, os.Getenv)
+ if err != nil {
+ fmt.Println(fmt.Errorf("failed to run: %w", err))
+ }
+}
+func run(
+ ctx context.Context,
+ getenv func(string) string,
+) error {
+ client, err := groq.NewClient(getenv("GROQ_KEY"))
+ if err != nil {
+ return err
+ }
+ log.Debugf("running with %s", getenv("GROQ_KEY"))
+ for _, val := range fileMap {
+ retry:
+ time.Sleep(6 * time.Second)
+ log.Debugf("processing %s", val.Destination)
+ filename := strings.Split(val.Destination, "/")[len(strings.Split(val.Destination, "/"))-1]
+ prompt, err := executeCodeTemplate(CodeTemplateData{
+ Source: val.Source,
+ Name: filename,
+ Files: fileMap.ToFileArray(val.Destination),
+ })
+ if err != nil {
+ return err
+ }
+ var thoughtThroughCode thoughtThroughCode
+ err = client.CreateChatCompletionJSON(
+ ctx,
+ groq.ChatCompletionRequest{
+ Model: groq.Llama3Groq70B8192ToolUsePreview,
+ Messages: []groq.ChatCompletionMessage{{
+ Role: groq.ChatMessageRoleSystem,
+ Content: prompt,
+ }},
+ },
+ &thoughtThroughCode,
+ )
+ if err != nil {
+ goto retry
+ }
+ log.Debugf(
+ "thoughts for %s: %s",
+ val.Destination,
+ thoughtThroughCode.Thoughts,
+ )
+ content, err := executeHeaderTemplate(HeaderTemplateData{
+ FileName: filename,
+ Description: wrapText(thoughtThroughCode.Description),
+ Code: val.Source,
+ })
+ if err != nil {
+ return err
+ }
+ log.Debugf("Creating file %s", val.Destination)
+ oF, err := os.Create(val.Destination)
+ if err != nil {
+ return err
+ }
+ defer oF.Close()
+ log.Debugf("Writing file %s", val.Destination)
+ _, err = oF.WriteString(content)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// FileMapper maps source file contents to their output/report destinations.
+type FileMapper []struct {
+ Source string
+ Destination string
+}
+
+const (
+ // Destinations
+ muxDest = "./report/Mux/"
+ fullAdderDest = "./report/Adder/"
+ adderSubtractorDest = "./report/AddSub/"
+ nMuxDest = "./report/NMux/"
+ onesCompDest = "./report/OnesComp/"
+ tpuElementDest = "./report/MAC/"
+)
+
+var (
+ fileMap = FileMapper{
+ {Source: tbMux2t1, Destination: muxDest + "tb_Mux2t1.vhd"},
+ {Source: mux2t1, Destination: muxDest + "mux2t1.vhd"},
+ {Source: mux2t1s, Destination: muxDest + "mux2t1s.vhd"},
+ {Source: tbMux2t1s, Destination: muxDest + "tb_Mux2t1s.vhd"},
+ {Source: mux2t1N, Destination: nMuxDest + "mux2t1_N.vhd"},
+ {Source: tbMux2t1, Destination: nMuxDest + "tb_Mux2t1.vhd"},
+ {Source: mux2t1, Destination: nMuxDest + "mux2t1.vhd"},
+ {Source: tbNMux2t1, Destination: nMuxDest + "tb_NMux2t1.vhd"},
+ {Source: fullAdder, Destination: fullAdderDest + "FullAdder.vhd"},
+ {Source: tbFullAdder, Destination: fullAdderDest + "tb_FullAdder.vhd"},
+ {Source: tbNBitAdder, Destination: fullAdderDest + "tb_NBitAdder.vhd"},
+ {Source: nBitAdder, Destination: fullAdderDest + "nBitAdder.vhd"},
+ {
+ Source: adderSubtractor,
+ Destination: adderSubtractorDest + "AdderSubtractor.vhd",
+ },
+ {
+ Source: tbAdderSubtractor,
+ Destination: adderSubtractorDest + "tb_AdderSubtractor.vhd",
+ },
+ {
+ Source: nBitInverter,
+ Destination: adderSubtractorDest + "nBitInverter.vhd",
+ },
+ {
+ Source: tbNBitInverter,
+ Destination: adderSubtractorDest + "tb_nBitInverter.vhd",
+ },
+ {Source: mux2t1N, Destination: adderSubtractorDest + "mux2t1_N.vhd"},
+ {
+ Source: nBitAdder,
+ Destination: adderSubtractorDest + "nBitAdder.vhd",
+ },
+ {Source: xorg2, Destination: onesCompDest + "xorg2.vhd"},
+ {Source: org2, Destination: onesCompDest + "org2.vhd"},
+ {Source: onesComp, Destination: onesCompDest + "OnesComp.vhd"},
+ {Source: tbOnesComp, Destination: onesCompDest + "tb_OnesComp.vhd"},
+ {
+ Source: tpuMvElement,
+ Destination: tpuElementDest + "tpuMvElement.vhd",
+ },
+ {Source: xorg2, Destination: tpuElementDest + "xorg2.vhd"},
+ {Source: org2, Destination: tpuElementDest + "org2.vhd"},
+ {
+ Source: nBitInverter,
+ Destination: tpuElementDest + "nBitInverter.vhd",
+ },
+ {Source: nBitAdder, Destination: tpuElementDest + "nBitAdder.vhd"},
+ {Source: mux2t1, Destination: tpuElementDest + "mux2t1.vhd"},
+ {Source: andg2, Destination: tpuElementDest + "andg2.vhd"},
+ {Source: regLd, Destination: tpuElementDest + "regLd.vhd"},
+ {Source: invg, Destination: tpuElementDest + "invg.vhd"},
+ {Source: adder, Destination: tpuElementDest + "adder.vhd"},
+ {
+ Source: adderSubtractor,
+ Destination: tpuElementDest + "adderSubtractor.vhd",
+ },
+ {Source: fullAdder, Destination: tpuElementDest + "fullAdder.vhd"},
+ {Source: multiplier, Destination: tpuElementDest + "multiplier.vhd"},
+ {Source: regLd, Destination: tpuElementDest + "regLd.vhd"},
+ {Source: reg, Destination: tpuElementDest + "reg.vhd"},
+ {
+ Source: tbTPUElement,
+ Destination: tpuElementDest + "tb_TPUElement.vhd",
+ },
+ }
+)
+
+//go:embed src/Adder.vhd
+var adder string
+
+//go:embed src/AdderSubtractor.vhd
+var adderSubtractor string
+
+//go:embed src/FullAdder.vhd
+var fullAdder string
+
+//go:embed src/Multiplier.vhd
+var multiplier string
+
+//go:embed src/NBitAdder.vhd
+var nBitAdder string
+
+//go:embed src/NBitInverter.vhd
+var nBitInverter string
+
+//go:embed src/OnesComp.vhd
+var onesComp string
+
+//go:embed src/Reg.vhd
+var reg string
+
+//go:embed src/RegLd.vhd
+var regLd string
+
+//go:embed src/TPU_MV_Element.vhd
+var tpuMvElement string
+
+//go:embed src/andg2.vhd
+var andg2 string
+
+//go:embed src/invg.vhd
+var invg string
+
+//go:embed src/mux2t1.vhd
+var mux2t1 string
+
+//go:embed src/mux2t1s.vhd
+var mux2t1s string
+
+//go:embed src/mux2t1_N.vhd
+var mux2t1N string
+
+//go:embed src/org2.vhd
+var org2 string
+
+//go:embed src/xorg2.vhd
+var xorg2 string
+
+//go:embed test/tb_NMux2t1.vhd
+var tbNMux2t1 string
+
+//go:embed test/tb_NMux2t1.vhd
+var tbMux2t1 string
+
+//go:embed test/tb_nBitAdder.vhd
+var tbNBitAdder string
+
+//go:embed test/tb_mux2t1s.vhd
+var tbMux2t1s string
+
+//go:embed test/tb_nBitInverter.vhd
+var tbNBitInverter string
+
+//go:embed test/tb_AdderSubtractor.vhd
+var tbAdderSubtractor string
+
+//go:embed test/tb_NFullAdder.vhd
+var tbFullAdder string
+
+//go:embed test/tb_OnesComp.vhd
+var tbOnesComp string
+
+//go:embed test/tb_TPU_MV_Element.vhd
+var tbTPUElement string
+
+// File represents a single file with its name and content
+type File struct {
+ Name string
+ Content string
+}
+
+// ToFileArray returns the Files that share the report directory of dest.
+func (fM *FileMapper) ToFileArray(dest string) []File {
+ var files []File
+ for _, file := range *fM {
+ split := strings.Split(file.Destination, "/")
+ if strings.Contains(dest, "/"+split[2]+"/") {
+ log.Debugf("adding %s to files", file.Destination)
+ files = append(files, File{
+ Name: file.Destination,
+ Content: file.Source,
+ })
+ continue
+ }
+ }
+ log.Debugf("n-files: %v", len(files))
+ return files
+}
+
+// CodeTemplateData represents the data structure for the code template
+type CodeTemplateData struct {
+ Source string
+ Name string
+ Files []File
+}
+
+// HeaderTemplateData represents the data structure for the header template
+type HeaderTemplateData struct {
+ FileName string
+ Description string
+ Code string
+}
+type thoughtThroughCode struct {
+ Thoughts string `json:"thoughts" jsonschema:"title=Thoughts,description=Thoughts on the code and thinking through exactly how it interacts with other given code in the project."`
+ Description string `json:"description" jsonschema:"title=Description,description=A description of the code's function, form, etc."`
+}
+
+func executeCodeTemplate(data CodeTemplateData) (string, error) {
+ buf := new(bytes.Buffer)
+ err := codeTemplate.Execute(buf, data)
+ if err != nil {
+ return "", fmt.Errorf("failed to execute code template: %v", err)
+ }
+ return buf.String(), nil
+}
+func executeHeaderTemplate(data HeaderTemplateData) (string, error) {
+ var result strings.Builder
+ err := headerTemplate.Execute(&result, data)
+ if err != nil {
+ return "", fmt.Errorf("failed to execute header template: %v", err)
+ }
+ return result.String(), nil
+}
+
+// wrapText wraps the string to at most 80 characters per line, prefixing
+// continuation lines with "-- " and hyphenating over-long words.
+func wrapText(s string) string {
+ var result strings.Builder
+ maxLineLength := 80
+ words := strings.Fields(s)
+ lineLength := 0
+ for i, word := range words {
+ wordLength := len(word)
+ if lineLength+wordLength > maxLineLength {
+ if wordLength > maxLineLength {
+ remaining := word
+ for len(remaining) > 0 {
+ spaceLeft := maxLineLength - lineLength
+ if spaceLeft <= 0 {
+ result.WriteString("\n-- ")
+ lineLength = 0
+ spaceLeft = maxLineLength
+ }
+ take := min(spaceLeft, len(remaining))
+ if take < len(remaining) {
+ // Add hyphen when breaking a word
+ result.WriteString(remaining[:take] + "-\n")
+ lineLength = 0
+ } else {
+ result.WriteString(remaining[:take])
+ lineLength += take
+ }
+ remaining = remaining[take:]
+ }
+ } else {
+ // Start a new line
+ result.WriteString("\n-- ")
+ result.WriteString(word)
+ lineLength = wordLength
+ }
+ } else {
+ if i > 0 {
+ result.WriteString(" ")
+ lineLength++
+ }
+ result.WriteString(word)
+ lineLength += wordLength
+ }
+ }
+ return result.String()
+}
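The heart of the generator above is `CreateChatCompletionJSON`, which decodes the model's JSON reply straight into a caller-supplied struct described by `jsonschema` tags. The sketch below isolates that pattern with a deliberately small, hypothetical `review` struct and an illustrative prompt; the call signature and model constant are the ones used in `main.go` above, everything else is a placeholder.

```go
package main

import (
	"context"
	"fmt"
	"os"

	"github.com/conneroisu/groq-go"
)

// review is a hypothetical response schema; the jsonschema tags describe the
// fields to the model, mirroring thoughtThroughCode in the example above.
type review struct {
	Thoughts string `json:"thoughts" jsonschema:"title=Thoughts,description=Reasoning about the code."`
	Summary  string `json:"summary" jsonschema:"title=Summary,description=A short summary of the code."`
}

func main() {
	client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
	if err != nil {
		panic(err)
	}
	var out review
	// The model's JSON reply is unmarshalled directly into out.
	err = client.CreateChatCompletionJSON(
		context.Background(),
		groq.ChatCompletionRequest{
			Model: groq.Llama3Groq70B8192ToolUsePreview,
			Messages: []groq.ChatCompletionMessage{{
				Role:    groq.ChatMessageRoleSystem,
				Content: "Summarize the following VHDL entity: ...",
			}},
		},
		&out,
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(out.Summary)
}
```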
diff --git a/examples/vhdl-documentor-json/report/AddSub/.keep b/examples/vhdl-documentor-json/report/AddSub/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/report/Adder/.keep b/examples/vhdl-documentor-json/report/Adder/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/report/Mux/.keep b/examples/vhdl-documentor-json/report/Mux/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/report/Mux/tb_Mux2t1.vhd b/examples/vhdl-documentor-json/report/Mux/tb_Mux2t1.vhd
new file mode 100644
index 0000000..9236b58
--- /dev/null
+++ b/examples/vhdl-documentor-json/report/Mux/tb_Mux2t1.vhd
@@ -0,0 +1,55 @@
+--------------------------------------------------------------------------------
+-- author: [ Conner Ohnesorge ](https://github.com/conneroisu)
+-- file_name: tb_Mux2t1.vhd
+-- desc: The file is a part of a larger project that includes a 2-to-1 multiplexer
+-- component and its structural implementation. It is used to test the component's
+-- behavior under different input conditions.
+--------------------------------------------------------------------------------
+
+library IEEE;
+use IEEE.std_logic_1164.all;
+
+entity tb_NMux2t1 is end tb_NMux2t1;
+
+architecture behavior of tb_NMux2t1 is
+ component nmux2t1
+ port (
+ AMux : in std_logic;
+ BMux : in std_logic;
+ Sel : in std_logic;
+ Output : out std_logic
+ );
+
+
+ end component;
+
+ --Inputs
+ signal s_AMux : std_logic := '0';
+ signal s_BMux : std_logic := '0';
+ signal s_Out : std_logic := '0';
+
+ --Outputs
+ signal f : std_logic;
+
+begin
+ DUT0 : nmux2t1
+ port map (
+ AMux => s_AMux,
+ BMux => s_BMux,
+ Sel => s_Out,
+ Output => f
+ );
+ -- Stimulus process
+ stim_proc : process
+ begin
+ s_AMux <= '0';
+ s_BMux <= '0';
+ s_Out <= '0';
+ wait for 100 ns;
+
+ assert f = '0' report "Test failed for s=0" severity error;
+
+ wait;
+ end process stim_proc;
+end behavior;
+
diff --git a/examples/vhdl-documentor-json/report/NMux/.keep b/examples/vhdl-documentor-json/report/NMux/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/report/OnesComp/.keep b/examples/vhdl-documentor-json/report/OnesComp/.keep
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/src/Adder.vhd b/examples/vhdl-documentor-json/src/Adder.vhd
new file mode 100644
index 0000000..ccfbb0a
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/Adder.vhd
@@ -0,0 +1,26 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity Adder is
+
+ port(
+ iCLK : in std_logic;
+ iA : in integer;
+ iB : in integer;
+ oC : out integer
+ );
+
+end Adder;
+
+architecture behavior of Adder is
+begin
+
+ process(iCLK, iA, iB)
+ begin
+ if rising_edge(iCLK) then
+ oC <= iA + iB;
+ end if;
+ end process;
+
+end behavior;
diff --git a/examples/vhdl-documentor-json/src/AdderSubtractor.vhd b/examples/vhdl-documentor-json/src/AdderSubtractor.vhd
new file mode 100644
index 0000000..92a043f
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/AdderSubtractor.vhd
@@ -0,0 +1,78 @@
+library IEEE;
+
+use IEEE.STD_LOGIC_1164.all;
+
+entity AdderSubtractor is
+
+ generic (
+ N : integer := 5
+ );
+ port (
+ A : in std_logic_vector (N-1 downto 0);
+ B : in std_logic_vector (N-1 downto 0);
+ nAdd_Sub : in std_logic;
+ Sum : out std_logic_vector (N-1 downto 0);
+ Carry : out std_logic_vector (N-1 downto 0)
+ );
+
+end AdderSubtractor;
+
+architecture Structural of AdderSubtractor is
+
+ component NBitInverter
+ port (
+ Input : in std_logic_vector (N-1 downto 0);
+ Output : out std_logic_vector (N-1 downto 0)
+ );
+ end component;
+
+ component mux2t1_N
+ generic (
+ N : integer := 4
+ );
+ port (
+ i_D0 : in std_logic_vector (N-1 downto 0);
+ i_D1 : in std_logic_vector (N-1 downto 0);
+ i_S : in std_logic;
+ o_O : out std_logic_vector (N-1 downto 0)
+ );
+ end component;
+
+ component NBitAdder
+ port (
+ A : in std_logic_vector (N-1 downto 0);
+ B : in std_logic_vector (N-1 downto 0);
+ Sum : out std_logic_vector (N-1 downto 0);
+ CarryOut : out std_logic_vector (N-1 downto 0)
+ );
+ end component;
+
+ signal s_inverted : std_logic_vector (N-1 downto 0);
+ signal s_muxed : std_logic_vector (N-1 downto 0);
+
+begin
+
+ Inv : NBitInverter
+ port map (
+ Input => B,
+ Output => s_inverted
+ );
+
+ Mux : mux2t1_N
+ port map (
+ i_D0 => B,
+ i_D1 => s_inverted,
+ i_S => nAdd_Sub,
+ o_O => s_muxed
+ );
+
+ -- Instantiate the N-bit adder
+ Adder : NBitAdder
+ port map (
+ A => A,
+ B => s_muxed,
+ Sum => Sum,
+ CarryOut => Carry
+ );
+
+end Structural;
diff --git a/examples/vhdl-documentor-json/src/FullAdder.vhd b/examples/vhdl-documentor-json/src/FullAdder.vhd
new file mode 100644
index 0000000..5b6fe8c
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/FullAdder.vhd
@@ -0,0 +1,29 @@
+library IEEE;
+
+use IEEE.STD_LOGIC_1164.all;
+
+entity FullAdder is
+
+ port (
+ i_A : in std_logic;
+ i_B : in std_logic;
+ Cin : in std_logic;
+ Sum : out std_logic;
+ Cout : out std_logic
+ );
+
+end FullAdder;
+
+architecture Structural of FullAdder is
+ signal s_Adder, s_C1, s_C2 : std_logic;
+begin
+
+ s_Adder <= i_A xor i_B;
+ s_C1 <= i_A and i_B;
+
+ Sum <= s_Adder xor Cin;
+ s_C2 <= s_Adder and Cin;
+
+ Cout <= s_C1 or s_C2;
+
+end Structural;
diff --git a/examples/vhdl-documentor-json/src/Multiplier.vhd b/examples/vhdl-documentor-json/src/Multiplier.vhd
new file mode 100644
index 0000000..52557be
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/Multiplier.vhd
@@ -0,0 +1,26 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity Multiplier is
+
+ port(
+ iCLK : in std_logic;
+ i_A : in integer;
+ i_B : in integer;
+ o_P : out integer
+ );
+
+end Multiplier;
+
+architecture behavior of Multiplier is
+begin
+
+ process(iCLK, i_A, i_B)
+ begin
+ if rising_edge(iCLK) then
+ o_P <= i_A * i_B;
+ end if;
+ end process;
+
+end behavior;
diff --git a/examples/vhdl-documentor-json/src/NBitAdder.vhd b/examples/vhdl-documentor-json/src/NBitAdder.vhd
new file mode 100644
index 0000000..64485ac
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/NBitAdder.vhd
@@ -0,0 +1,40 @@
+library IEEE;
+
+use IEEE.STD_LOGIC_1164.all;
+use IEEE.NUMERIC_STD.all;
+use work.FullAdder;
+
+entity NBitAdder is
+ generic (
+ N : integer := 4
+ );
+ port (
+ A : in std_logic_vector(N-1 downto 0);
+ B : in std_logic_vector(N-1 downto 0);
+ Sum : out std_logic_vector(N-1 downto 0);
+ CarryOut : out std_logic_vector(N-1 downto 0)
+ );
+end NBitAdder;
+
+
+architecture Behavioral of NBitAdder is
+ signal carries : std_logic_vector(N downto 0);
+begin
+ -- Initialize the carry-in for the first adder to 0
+ carries(0) <= '0';
+
+ gen_full_adders : for i in 0 to N-1 generate
+ full_adder_inst : entity FullAdder
+ port map (
+ i_A => A(i),
+ i_B => B(i),
+ Cin => carries(i),
+ Sum => Sum(i),
+ Cout => carries(i+1)
+ );
+ end generate;
+
+ -- The carry-out of the last full adder is the CarryOut of the N-bit adder
+ CarryOut <= carries(N-1 downto 0);
+
+end Behavioral;
diff --git a/examples/vhdl-documentor-json/src/NBitInverter.vhd b/examples/vhdl-documentor-json/src/NBitInverter.vhd
new file mode 100644
index 0000000..0741404
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/NBitInverter.vhd
@@ -0,0 +1,31 @@
+library IEEE;
+
+use IEEE.STD_LOGIC_1164.all;
+
+entity NBitInverter is
+
+ generic (
+ N : integer := 4 -- N is the width of the input/output
+ );
+ port (
+ Input : in std_logic_vector(N-1 downto 0); -- Input vector
+ Output : out std_logic_vector(N-1 downto 0) -- Output vector
+ );
+
+end NBitInverter;
+
+architecture Behavioral of NBitInverter is
+begin
+
+ Complement_Process : process(Input)
+ begin
+ for i in 0 to N-1 loop
+ if Input(i) = '1' then
+ Output(i) <= '0';
+ else
+ Output(i) <= '1';
+ end if;
+ end loop;
+ end process;
+
+end Behavioral;
diff --git a/examples/vhdl-documentor-json/src/OnesComp.vhd b/examples/vhdl-documentor-json/src/OnesComp.vhd
new file mode 100644
index 0000000..a6a3e75
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/OnesComp.vhd
@@ -0,0 +1,27 @@
+library ieee;
+
+use ieee.STD_LOGIC_1164.all;
+
+entity OnesComp is
+
+ generic (
+ N : integer := 8
+ );
+ port (
+ Input : in std_logic_vector(N-1 downto 0);
+ Output : out std_logic_vector(N-1 downto 0)
+ );
+
+end OnesComp;
+
+architecture Behavioral of OnesComp is
+begin
+
+ Complement_Process : process(Input)
+ begin
+ for i in 0 to N-1 loop
+ Output(i) <= not Input(i);
+ end loop;
+ end process;
+
+end Behavioral;
diff --git a/examples/vhdl-documentor-json/src/Reg.vhd b/examples/vhdl-documentor-json/src/Reg.vhd
new file mode 100644
index 0000000..cb86e79
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/Reg.vhd
@@ -0,0 +1,25 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity Reg is
+
+ port(
+ iCLK : in std_logic;
+ iD : in integer;
+ oQ : out integer
+ );
+
+end Reg;
+
+architecture behavior of Reg is
+begin
+
+ process(iCLK, iD)
+ begin
+ if rising_edge(iCLK) then
+ oQ <= iD;
+ end if;
+ end process;
+
+end behavior;
diff --git a/examples/vhdl-documentor-json/src/RegLd.vhd b/examples/vhdl-documentor-json/src/RegLd.vhd
new file mode 100644
index 0000000..f470274
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/RegLd.vhd
@@ -0,0 +1,34 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity RegLd is
+
+ port(
+ iCLK : in std_logic;
+ iD : in integer;
+ iLd : in integer;
+ oQ : out integer
+ );
+
+end RegLd;
+
+architecture behavior of RegLd is
+ signal s_Q : integer;
+begin
+
+
+ process(iCLK, iLd, iD)
+ begin
+ if rising_edge(iCLK) then
+ if (iLd = 1) then
+ s_Q <= iD;
+ else
+ s_Q <= s_Q;
+ end if;
+ end if;
+ end process;
+
+ oQ <= s_Q; -- connect internal storage signal with final output
+
+end behavior;
diff --git a/examples/vhdl-documentor-json/src/TPU_MV_Element.vhd b/examples/vhdl-documentor-json/src/TPU_MV_Element.vhd
new file mode 100644
index 0000000..4e645cd
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/TPU_MV_Element.vhd
@@ -0,0 +1,105 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity TPU_MV_Element is
+
+ port(
+ iCLK : in std_logic;
+ iX : in integer;
+ iW : in integer;
+ iLdW : in integer;
+ iY : in integer;
+ oY : out integer;
+ oX : out integer
+ );
+
+end TPU_MV_Element;
+
+architecture structure of TPU_MV_Element is
+
+ component Adder
+ port(
+ iCLK : in std_logic;
+ iA : in integer;
+ iB : in integer;
+ oC : out integer
+ );
+ end component;
+
+ component Multiplier
+ port(
+ iCLK : in std_logic;
+ iA : in integer;
+ iB : in integer;
+ oP : out integer
+ );
+ end component;
+
+ component Reg
+ port(
+ iCLK : in std_logic;
+ iD : in integer;
+ oQ : out integer
+ );
+ end component;
+
+ component RegLd
+ port(
+ iCLK : in std_logic;
+ iD : in integer;
+ iLd : in integer;
+ oQ : out integer
+ );
+ end component;
+
+ signal s_W : integer;
+ signal s_X1 : integer;
+ signal s_Y1 : integer;
+ signal s_WxX : integer; -- Signal to carry stored W*X
+
+begin
+ -- Level 0: Conditionally load new W
+ g_Weight : RegLd
+ port map(
+ iCLK => iCLK,
+ iD => iW,
+ iLd => iLdW,
+ oQ => s_W
+ );
+ -- Level 1: Delay X and Y, calculate W*X
+ g_Delay1 : Reg
+ port map(
+ iCLK => iCLK,
+ iD => iX,
+ oQ => s_X1
+ );
+ g_Delay2 : Reg
+ port map(
+ iCLK => iCLK,
+ iD => iY,
+ oQ => s_Y1
+ );
+ g_Mult1 : Multiplier
+ port map(
+ iCLK => iCLK,
+ iA => iX,
+ iB => s_W,
+ oP => s_WxX
+ );
+ -- Level 2: Delay X, calculate Y += W*X
+ g_Delay3 : Reg
+ port map(
+ iCLK => iCLK,
+ iD => s_X1,
+ oQ => oX
+ );
+ g_Add1 : Adder
+ port map(
+ iCLK => iCLK,
+ iA => s_WxX,
+ iB => s_Y1,
+ oC => oY
+ );
+
+end structure;
diff --git a/examples/vhdl-documentor-json/src/andg2.vhd b/examples/vhdl-documentor-json/src/andg2.vhd
new file mode 100644
index 0000000..d38cc09
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/andg2.vhd
@@ -0,0 +1,20 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity andg2 is
+
+ port(
+ i_A : in std_logic;
+ i_B : in std_logic;
+ o_F : out std_logic
+ );
+
+end andg2;
+
+architecture dataflow of andg2 is
+begin
+
+ o_F <= i_A and i_B;
+
+end dataflow;
diff --git a/examples/vhdl-documentor-json/src/invg.vhd b/examples/vhdl-documentor-json/src/invg.vhd
new file mode 100644
index 0000000..a5dbd73
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/invg.vhd
@@ -0,0 +1,19 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity invg is
+
+ port(
+ i_A : in std_logic;
+ o_F : out std_logic
+ );
+
+end invg;
+
+architecture dataflow of invg is
+begin
+
+ o_F <= not i_A;
+
+end dataflow;
diff --git a/examples/vhdl-documentor-json/src/mux2t1.vhd b/examples/vhdl-documentor-json/src/mux2t1.vhd
new file mode 100644
index 0000000..e4705fe
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/mux2t1.vhd
@@ -0,0 +1,26 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity mux2t1 is
+
+ port (
+ i_D0, i_D1, i_S : in std_logic;
+ o_O : out std_logic
+ );
+
+end mux2t1;
+
+architecture behaviour of mux2t1 is
+begin
+
+ process (i_D0, i_D1, i_S)
+ begin
+ if i_s = '0' then
+ o_O <= i_D0;
+ else
+ o_O <= i_D1;
+ end if;
+ end process;
+
+end behaviour;
diff --git a/examples/vhdl-documentor-json/src/mux2t1_N.vhd b/examples/vhdl-documentor-json/src/mux2t1_N.vhd
new file mode 100644
index 0000000..48d1f2a
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/mux2t1_N.vhd
@@ -0,0 +1,41 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity mux2t1_N is
+
+ generic(
+ N : integer := 16
+ );
+ port(
+ i_S : in std_logic;
+ i_D0 : in std_logic_vector(N-1 downto 0);
+ i_D1 : in std_logic_vector(N-1 downto 0);
+ o_O : out std_logic_vector(N-1 downto 0)
+ );
+
+end mux2t1_N;
+
+architecture structural of mux2t1_N is
+
+ component mux2t1 is
+ port(
+ i_S : in std_logic;
+ i_D0 : in std_logic;
+ i_D1 : in std_logic;
+ o_O : out std_logic
+ );
+ end component;
+
+begin
+
+ G_NBit_MUX : for i in 0 to N-1 generate
+ MUXI : mux2t1 port map(
+ i_S => i_S, -- All instances share the same select input.
+ i_D0 => i_D0(i), -- ith instance's data 0 input hooked up to ith data 0 input.
+ i_D1 => i_D1(i), -- ith instance's data 1 input hooked up to ith data 1 input.
+ o_O => o_O(i) -- ith instance's data output hooked up to ith data output.
+ );
+ end generate G_NBit_MUX;
+
+end structural;
diff --git a/examples/vhdl-documentor-json/src/mux2t1s.vhd b/examples/vhdl-documentor-json/src/mux2t1s.vhd
new file mode 100644
index 0000000..527106a
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/mux2t1s.vhd
@@ -0,0 +1,82 @@
+library IEEE;
+use IEEE.std_logic_1164.all;
+
+entity mux2t1s is
+ port (
+ i_S : in std_logic; -- selector
+ i_D0 : in std_logic; -- data inputs
+ i_D1 : in std_logic; -- data inputs
+ o_O : out std_logic -- output
+ );
+end mux2t1s;
+
+architecture structure of mux2t1s is
+
+ component andg2 is
+ port (
+ i_A : in std_logic; -- input A to AND gate
+ i_B : in std_logic; -- input B to AND gate
+ o_F : out std_logic -- output of AND gate
+ );
+
+ end component;
+
+ component org2 is
+ port (
+ i_A : in std_logic; -- input A to OR gate
+ i_B : in std_logic; -- input B to OR gate
+ o_F : out std_logic -- output of OR gate
+ );
+
+ end component;
+
+ component invg is
+ port (
+ i_A : in std_logic; -- input to NOT gate
+ o_F : out std_logic -- output of NOT gate
+ );
+
+ end component;
+
+ -- Signal to hold invert of the selector bit
+ signal s_inv_S1 : std_logic;
+ -- Signals to hold output values from the 'AND' gates (needed to wire one component to another)
+ signal s_oX, s_oY : std_logic;
+
+begin
+ ---------------------------------------------------------------------------
+ -- Level 0: signals go through NOT stage
+ ---------------------------------------------------------------------------
+ invg1 : invg
+ port map(
+ i_A => i_S, -- input to NOT gate
+ o_F => s_inv_S1 -- output of NOT gate
+ );
+ ---------------------------------------------------------------------------
+ -- Level 1: signals go through AND stage
+ ---------------------------------------------------------------------------
+
+ and1 : andg2
+ port map(
+ i_A => i_D0, -- input to AND gate
+ i_B => s_inv_S1, -- input to AND gate
+ o_F => s_oX -- output of AND gate
+ );
+
+ and2 : andg2
+ port map(
+ i_A => i_D1, -- input to AND gate
+ i_B => i_S, -- input to AND gate
+ o_F => s_oY -- output of AND gate
+ );
+ ---------------------------------------------------------------------------
+ -- Level 1: signals go through OR stage (and then output)
+ ---------------------------------------------------------------------------
+
+ org1 : org2
+ port map(
+ i_A => s_oX, -- input to OR gate
+ i_B => s_oY, -- input to OR gate
+ o_F => o_O -- output of OR gate
+ );
+end structure;
diff --git a/examples/vhdl-documentor-json/src/org2.vhd b/examples/vhdl-documentor-json/src/org2.vhd
new file mode 100644
index 0000000..0d9542d
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/org2.vhd
@@ -0,0 +1,20 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity org2 is
+
+ port(
+ i_A : in std_logic;
+ i_B : in std_logic;
+ o_F : out std_logic
+ );
+
+end org2;
+
+architecture dataflow of org2 is
+begin
+
+ o_F <= i_A or i_B;
+
+end dataflow;
diff --git a/examples/vhdl-documentor-json/src/xorg2.vhd b/examples/vhdl-documentor-json/src/xorg2.vhd
new file mode 100644
index 0000000..7afda5c
--- /dev/null
+++ b/examples/vhdl-documentor-json/src/xorg2.vhd
@@ -0,0 +1,20 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+
+entity xorg2 is
+
+ port(
+ i_A : in std_logic;
+ i_B : in std_logic;
+ o_F : out std_logic
+ );
+
+end xorg2;
+
+architecture dataflow of xorg2 is
+begin
+
+ o_F <= i_A xor i_B;
+
+end dataflow;
diff --git a/examples/vhdl-documentor-json/template.tmpl b/examples/vhdl-documentor-json/template.tmpl
new file mode 100644
index 0000000..6834ffa
--- /dev/null
+++ b/examples/vhdl-documentor-json/template.tmpl
@@ -0,0 +1,35 @@
+{{ define "code" }}
+Write a description of the first file inside a JSON response.
+
+Your response should be just a JSON object with the following fields:
+In your description, do not jump to conclusions; instead, provide a thoughtful description of the file's content.
+If you refer to a file, use its entire path, including the file name.
+
+
+{
+ "thoughts": "Your thoughts on the task",
+ "description": "A description of the file in relation to the other files",
+}
+
+
+Your file is:
+
+{{.Source}}
+
+
+Related files:
+{{- range $file := .Files }}
+
+{{ $file.Content }}
+
+{{- end }}
+{{ end }}
+
+{{ define "header" }}--------------------------------------------------------------------------------
+-- author: [ Conner Ohnesorge ](https://github.com/conneroisu)
+-- file_name: {{ .FileName }}
+-- desc: {{ .Description }}
+--------------------------------------------------------------------------------
+
+{{ .Code }}
+{{ end }}
diff --git a/examples/vhdl-documentor-json/test/tb_Adder.vhd b/examples/vhdl-documentor-json/test/tb_Adder.vhd
new file mode 100644
index 0000000..dfb97bf
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_Adder.vhd
@@ -0,0 +1,67 @@
+library ieee;
+
+
+use IEEE.STD_LOGIC_1164.all;
+use IEEE.STD_LOGIC_ARITH.all;
+use IEEE.STD_LOGIC_UNSIGNED.all;
+use IEEE.numeric_std.all;
+
+entity tb_Adder is end tb_Adder;
+
+architecture behavior of tb_Adder is
+ component Adder
+ generic (
+ N : integer := 4
+ );
+ port (
+ A : in std_logic_vector (N-1 downto 0);
+ B : in std_logic_vector (N-1 downto 0);
+ nAdd_Sub : in std_logic;
+ Sum : out std_logic_vector (N-1 downto 0);
+ Carry : out std_logic_vector (N-1 downto 0)
+ );
+ end component;
+
+ signal s_A : std_logic_vector (3 downto 0) := (others => '0');
+ signal s_B : std_logic_vector (3 downto 0) := (others => '0');
+ signal s_nAddSub : std_logic := '0';
+ signal s_Carry : std_logic_vector (3 downto 0);
+
+ signal s_Sum : std_logic_vector (3 downto 0);
+
+begin
+ DUT0 : Adder generic map (N => 4)
+ port map (
+ A => s_A,
+ B => s_B,
+ nAdd_Sub => s_nAddSub,
+ Sum => s_Sum,
+ Carry => s_Carry
+ );
+
+ s_A <= "0011"; s_B <= "0101"; s_nAddSub <= '0';
+ process
+ begin
+ wait for 10 ns;
+
+ s_A <= "0110"; s_B <= "0011"; s_nAddSub <= '1';
+ wait for 10 ns;
+
+ s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '0'; -- Add zero
+ wait for 10 ns;
+ s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '1'; -- Subtract zero
+ wait for 10 ns;
+ s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '0'; -- Add max
+ wait for 10 ns;
+ s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '1'; -- Subtract max
+ wait for 10 ns;
+
+ s_A <= "1111"; s_B <= "0001"; s_nAddSub <= '0'; -- Add overflow
+ wait for 10 ns;
+ s_A <= "0000"; s_B <= "0001"; s_nAddSub <= '1'; -- Subtract overflow
+ wait for 10 ns;
+
+ assert false report "Testbench completed" severity note;
+ wait;
+ end process;
+end;
diff --git a/examples/vhdl-documentor-json/test/tb_AdderSubtractor.vhd b/examples/vhdl-documentor-json/test/tb_AdderSubtractor.vhd
new file mode 100644
index 0000000..ce04ef2
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_AdderSubtractor.vhd
@@ -0,0 +1,68 @@
+library IEEE;
+
+use IEEE.STD_LOGIC_1164.all;
+use IEEE.STD_LOGIC_ARITH.all;
+use IEEE.STD_LOGIC_UNSIGNED.all;
+use IEEE.numeric_std.all;
+
+entity tb_AdderSubtractor is
+
+end tb_AdderSubtractor;
+
+architecture behavior of tb_AdderSubtractor is
+ component AdderSubtractor
+ generic (
+ N : integer := 4
+ );
+ port (
+ A : in std_logic_vector (N-1 downto 0);
+ B : in std_logic_vector (N-1 downto 0);
+ nAdd_Sub : in std_logic;
+ Sum : out std_logic_vector (N-1 downto 0);
+ Carry : out std_logic_vector (N-1 downto 0)
+ );
+ end component;
+
+ signal s_A : std_logic_vector (3 downto 0) := (others => '0');
+ signal s_B : std_logic_vector (3 downto 0) := (others => '0');
+ signal s_nAddSub : std_logic := '0';
+
+
+ signal s_Sum : std_logic_vector (3 downto 0);
+ signal s_Carry : std_logic_vector (3 downto 0);
+begin
+ DUT0 : AdderSubtractor generic map (N => 4)
+ port map (
+ A => s_A,
+ B => s_B,
+ nAdd_Sub => s_nAddSub,
+ Sum => s_Sum,
+ Carry => s_Carry
+ );
+
+ s_A <= "0011"; s_B <= "0101"; s_nAddSub <= '0';
+ process
+ begin
+ wait for 10 ns;
+
+ s_A <= "0110"; s_B <= "0011"; s_nAddSub <= '1';
+ wait for 10 ns;
+
+ s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '0'; -- Add zero
+ wait for 10 ns;
+ s_A <= "0000"; s_B <= "0000"; s_nAddSub <= '1'; -- Subtract zero
+ wait for 10 ns;
+ s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '0'; -- Add max
+ wait for 10 ns;
+ s_A <= "1111"; s_B <= "1111"; s_nAddSub <= '1'; -- Subtract max
+ wait for 10 ns;
+
+ s_A <= "1111"; s_B <= "0001"; s_nAddSub <= '0'; -- Add overflow
+ wait for 10 ns;
+ s_A <= "0000"; s_B <= "0001"; s_nAddSub <= '1'; -- Subtract overflow
+ wait for 10 ns;
+
+ assert false report "Testbench completed" severity note;
+ wait;
+ end process;
+end;
diff --git a/examples/vhdl-documentor-json/test/tb_NFullAdder.vhd b/examples/vhdl-documentor-json/test/tb_NFullAdder.vhd
new file mode 100644
index 0000000..59d4f47
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_NFullAdder.vhd
@@ -0,0 +1,62 @@
+library IEEE;
+use IEEE.STD_LOGIC_1164.all;
+use IEEE.NUMERIC_STD.all;
+
+entity tb_NFullAdder is
+-- Empty entity as this is a test bench
+end tb_NFullAdder;
+
+architecture Behavioral of tb_NFullAdder is
+ component NBitAdder
+ Generic (N : integer := 4);
+ Port (
+ A : in STD_LOGIC_VECTOR(N-1 downto 0);
+ B : in STD_LOGIC_VECTOR(N-1 downto 0);
+ Sum : out STD_LOGIC_VECTOR(N-1 downto 0);
+ CarryOut : out STD_LOGIC_VECTOR(N-1 downto 0)
+ );
+ end component;
+
+ signal s_A : STD_LOGIC_VECTOR(3 downto 0) := (others => '0');
+ signal s_B : STD_LOGIC_VECTOR(3 downto 0) := (others => '0');
+
+ signal s_Sum : STD_LOGIC_VECTOR(3 downto 0);
+ signal s_CarryOut : STD_LOGIC_VECTOR(3 downto 0);
+
+ signal clk : STD_LOGIC := '0';
+
+begin
+ DUT0: NBitAdder
+ port map (
+ A => s_A,
+ B => s_B,
+ Sum => s_Sum,
+ CarryOut => s_CarryOut
+ );
+
+ -- Clock process
+ clk_process : process
+ begin
+ clk <= '0';
+ wait for 5 ns;
+ clk <= '1';
+ wait for 5 ns;
+ end process;
+
+ -- Test process
+ stim_proc: process
+ begin
+ -- Testing all combinations
+ for i in 0 to 15 loop
+ for j in 0 to 15 loop
+ s_A <= std_logic_vector(to_unsigned(i, 4));
+ s_B <= std_logic_vector(to_unsigned(j, 4));
+ wait for 10 ns; -- Wait for one clock cycle
+ end loop;
+ end loop;
+
+ -- End simulation
+ wait;
+ end process;
+
+end Behavioral;
diff --git a/examples/vhdl-documentor-json/test/tb_NMux2t1.vhd b/examples/vhdl-documentor-json/test/tb_NMux2t1.vhd
new file mode 100644
index 0000000..458c9e5
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_NMux2t1.vhd
@@ -0,0 +1,46 @@
+library IEEE;
+use IEEE.std_logic_1164.all;
+
+entity tb_NMux2t1 is end tb_NMux2t1;
+
+architecture behavior of tb_NMux2t1 is
+ component nmux2t1
+ port (
+ AMux : in std_logic;
+ BMux : in std_logic;
+ Sel : in std_logic;
+ Output : out std_logic
+ );
+
+
+ end component;
+
+ --Inputs
+ signal s_AMux : std_logic := '0';
+ signal s_BMux : std_logic := '0';
+ signal s_Out : std_logic := '0';
+
+ --Outputs
+ signal f : std_logic;
+
+begin
+ DUT0 : nmux2t1
+ port map (
+ AMux => s_AMux,
+ BMux => s_BMux,
+ Sel => s_Out,
+ Output => f
+ );
+ -- Stimulus process
+ stim_proc : process
+ begin
+ s_AMux <= '0';
+ s_BMux <= '0';
+ s_Out <= '0';
+ wait for 100 ns;
+
+ assert f = '0' report "Test failed for s=0" severity error;
+
+ wait;
+ end process stim_proc;
+end behavior;
diff --git a/examples/vhdl-documentor-json/test/tb_OnesComp.vhd b/examples/vhdl-documentor-json/test/tb_OnesComp.vhd
new file mode 100644
index 0000000..3f1942d
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_OnesComp.vhd
@@ -0,0 +1,60 @@
+library IEEE;
+use IEEE.STD_LOGIC_1164.all;
+use IEEE.NUMERIC_STD.all;
+
+entity tb_ones_comp is end tb_ones_comp;
+
+architecture Behavioral of tb_ones_comp is
+
+ constant N : integer := 8;
+ constant M : integer := 32;
+
+ signal Input : std_logic_vector(N-1 downto 0);
+ signal Output : std_logic_vector(N-1 downto 0);
+
+ component OnesComplementor
+ generic (N : integer := 8);
+ port (
+ Input : in std_logic_vector(N-1 downto 0);
+ Output : out std_logic_vector(N-1 downto 0)
+ );
+ end component;
+
+ component OnesComplementor_2
+ generic (M : integer := 32);
+ port (
+ Input2 : in std_logic_vector(M-1 downto 0);
+ Output2 : out std_logic_vector(M-1 downto 0)
+ );
+ end component;
+begin
+ DUT0 : OnesComplementor
+ generic map (N => N)
+ port map (
+ Input => Input,
+ Output => Output
+ );
+
+ DUT1 : OnesComplementor_2
+ generic map (M => M)
+ port map (
+ Input2 => Input,
+ Output2 => Output
+ );
+
+ stim_proc : process
+ begin
+ for i in 0 to 2**N-1 loop
+ Input <= std_logic_vector(to_unsigned(i, N));
+ wait for 10 ns;
+ assert Output = not Input report "OnesComp output mismatch" severity failure;
+ end loop;
+ for i in 0 to 2**M-1 loop
+ Input <= std_logic_vector(to_unsigned(i, M));
+ wait for 10 ns;
+ assert Output = not Input report "OnesComp output mismatch" severity failure;
+ end loop;
+
+ end process;
+
+end Behavioral;
diff --git a/examples/vhdl-documentor-json/test/tb_TPU_MV_Element.vhd b/examples/vhdl-documentor-json/test/tb_TPU_MV_Element.vhd
new file mode 100644
index 0000000..8eebf6e
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_TPU_MV_Element.vhd
@@ -0,0 +1,113 @@
+library IEEE;
+
+use IEEE.std_logic_1164.all;
+use IEEE.std_logic_textio.all; -- For logic types I/O
+
+library std;
+
+use std.env.all; -- For hierarchical/external signals
+use std.textio.all; -- For basic I/O
+
+entity tb_TPU_MV_Element is
+ generic(gCLK_HPER : time := 10 ns); -- Generic for half of the clock cycle period
+end tb_TPU_MV_Element;
+
+architecture mixed of tb_TPU_MV_Element is
+ component TPU_MV_Element is port(
+ iCLK : in std_logic;
+ iX : in integer;
+ iW : in integer;
+ iLdW : in integer;
+ iY : in integer;
+ oY : out integer;
+ oX : out integer
+ );
+ end component;
+-- Create signals for all of the inputs and outputs of the file that you are testing
+-- := '0' or := (others => '0') just make all the signals start at an initial value of zero
+ signal CLK, reset : std_logic := '0';
+ signal s_iX : integer := 0;
+ signal s_iW : integer := 0;
+ signal s_iLdW : integer := 0;
+ signal s_iY : integer := 0;
+ signal s_oY : integer;
+ signal s_oX : integer;
+begin
+ DUT0 : TPU_MV_Element
+ port map(
+ iCLK => CLK,
+ iX => s_iX,
+ iW => s_iW,
+ iLdW => s_iLdW,
+ iY => s_iY,
+ oY => s_oY,
+ oX => s_oX
+ );
+
+ P_CLK : process
+ begin
+ CLK <= '1'; -- clock starts at 1
+ wait for gCLK_HPER; -- after half a cycle
+ CLK <= '0'; -- clock becomes a 0 (negative edge)
+ wait for gCLK_HPER; -- after half a cycle, process begins evaluation again
+ end process;
+
+ P_RST : process
+ begin
+ reset <= '0';
+ wait for gCLK_HPER/2;
+ reset <= '1';
+ wait for gCLK_HPER*2;
+ reset <= '0';
+ wait;
+ end process;
+ P_TEST_CASES : process
+ begin
+
+ wait for gCLK_HPER/2; -- for waveform clarity, I prefer not to change inputs on clk edges
+ -- Test case 1:
+ -- Initialize weight value to 10.
+ s_iX <= 0; -- Not strictly necessary, but this makes the testcases easier to read
+ s_iW <= 10;
+ s_iLdW <= 1;
+ s_iY <= 0; -- Not strictly necessary, but this makes the testcases easier to read
+ wait for gCLK_HPER*2;
+ -- Expect: s_W internal signal to be 10 after the positive edge of the clock
+ -- Test case 2:
+ -- Perform a typical example: an input activation of 3 and a partial sum of 25. The weight is still 10.
+ s_iX <= 3;
+ s_iW <= 0;
+ s_iLdW <= 0;
+ s_iY <= 25;
+ wait for gCLK_HPER*2;
+ wait for gCLK_HPER*2;
+ -- Expect: o_Y output signal to be 55 = 3*10+25 and o_X output signal to
+ -- be 3 after two positive edges of the clock.
+ assert s_oY = 55 report "Test case 2 failed" severity error;
+ -- Test case 3:
+ -- Perform one MAC operation with minimum-case values
+ s_iX <= 0;
+ s_iW <= 10;
+ s_iLdW <= 0;
+ s_iY <= 0;
+ wait for gCLK_HPER*2;
+ wait for gCLK_HPER*2;
+ -- Expect: o_Y output signal to be 0 = 10*0+0 and o_X output signal to be 0
+ -- after two positive edges of the clock.
+ assert s_oY = 0 report "Test case 3 failed" severity error;
+ -- Test case 4:
+ -- Change the weight and perform a MAC operation
+ s_iX <= 2;
+ s_iW <= 5;
+ s_iLdW <= 1;
+ s_iY <= 10;
+ wait for gCLK_HPER*2;
+ wait for gCLK_HPER*2;
+ wait for gCLK_HPER*2;
+ -- Expect: o_Y output signal to be 20 = 5*2+10 and o_X output signal
+ -- to be 2 after three positive edges of the clock.
+ assert s_oY = 20 report "Test case 4 failed" severity error;
+ wait;
+ end process;
+
+end mixed;
diff --git a/examples/vhdl-documentor-json/test/tb_mux2t1.vhd b/examples/vhdl-documentor-json/test/tb_mux2t1.vhd
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/test/tb_mux2t1s.vhd b/examples/vhdl-documentor-json/test/tb_mux2t1s.vhd
new file mode 100644
index 0000000..e69de29
diff --git a/examples/vhdl-documentor-json/test/tb_nBitAdder.vhd b/examples/vhdl-documentor-json/test/tb_nBitAdder.vhd
new file mode 100644
index 0000000..83b9f93
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_nBitAdder.vhd
@@ -0,0 +1,60 @@
+library IEEE;
+use IEEE.STD_LOGIC_1164.all;
+use IEEE.NUMERIC_STD.all;
+
+entity NBitAdder_tb is
+end NBitAdder_tb;
+
+architecture Behavioral of NBitAdder_tb is
+
+ constant N : integer := 4;
+
+ signal A : std_logic_vector(N-1 downto 0) := (others => '0');
+ signal B : std_logic_vector(N-1 downto 0) := (others => '0');
+ signal Sum : std_logic_vector(N-1 downto 0);
+ signal CarryOut : std_logic_vector(N-1 downto 0);
+
+begin
+
+ uut : entity work.NBitAdder
+ generic map (N => N)
+ port map (
+ A => A,
+ B => B,
+ Sum => Sum,
+ CarryOut => CarryOut
+ );
+
+ -- Stimulus process to apply test vectors
+ stim_proc : process
+ begin
+ -- Test Case 1: 0 + 1
+ A <= "0000";
+ B <= "0001";
+ wait for 10 ns;
+
+ -- Test Case 2: 3 + 3
+ A <= "0011";
+ B <= "0011";
+ wait for 10 ns;
+
+ -- Test Case 3: 15 + 1
+ A <= "1111";
+ B <= "0001";
+ wait for 10 ns;
+
+ -- Test Case 4: 10 + 5
+ A <= "1010";
+ B <= "0101";
+ wait for 10 ns;
+
+ -- Test Case 5: Maximum values
+ A <= (others => '1');
+ B <= (others => '1');
+ wait for 10 ns;
+
+ -- Finish simulation
+ wait;
+ end process;
+
+end Behavioral;
diff --git a/examples/vhdl-documentor-json/test/tb_nBitInverter.vhd b/examples/vhdl-documentor-json/test/tb_nBitInverter.vhd
new file mode 100644
index 0000000..b94aa00
--- /dev/null
+++ b/examples/vhdl-documentor-json/test/tb_nBitInverter.vhd
@@ -0,0 +1,51 @@
+
+library IEEE;
+use IEEE.STD_LOGIC_1164.all;
+
+entity NBitInverter_tb is
+end NBitInverter_tb;
+
+architecture Behavioral of NBitInverter_tb is
+
+ constant N : integer := 4;
+
+ signal Input : std_logic_vector(N-1 downto 0) := (others => '0');
+ signal Output : std_logic_vector(N-1 downto 0);
+
+begin
+
+ uut : entity work.NBitInverter
+ generic map (N => N)
+ port map (
+ Input => Input,
+ Output => Output
+ );
+
+ -- Stimulus process to apply test vectors
+ stim_proc : process
+ begin
+ -- Test Case 1: All zeros
+ Input <= "0000";
+ wait for 10 ns;
+
+ -- Test Case 2: All ones
+ Input <= "1111";
+ wait for 10 ns;
+
+ -- Test Case 3: Alternating bits (1010)
+ Input <= "1010";
+ wait for 10 ns;
+
+ -- Test Case 4: Alternating bits (0101)
+ Input <= "0101";
+ wait for 10 ns;
+
+ -- Test Case 5: Random value
+ Input <= "1100";
+ wait for 10 ns;
+
+ -- Finish simulation
+ wait;
+ end process;
+
+end Behavioral;
diff --git a/go.mod b/go.mod
index 020312f..f547c28 100644
--- a/go.mod
+++ b/go.mod
@@ -6,6 +6,7 @@ require (
github.com/charmbracelet/bubbles v0.20.0
github.com/charmbracelet/bubbletea v1.1.0
github.com/charmbracelet/lipgloss v0.13.0
+ github.com/charmbracelet/log v0.4.0
github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.9.0
github.com/wk8/go-ordered-map/v2 v2.1.8
@@ -20,6 +21,7 @@ require (
github.com/charmbracelet/x/term v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
+ github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
@@ -31,6 +33,7 @@ require (
github.com/muesli/termenv v0.15.3-0.20240618155329-98d742f6907a // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
+ golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
diff --git a/go.sum b/go.sum
index 4233a70..2613208 100644
--- a/go.sum
+++ b/go.sum
@@ -14,6 +14,8 @@ github.com/charmbracelet/bubbletea v1.1.0 h1:FjAl9eAL3HBCHenhz/ZPjkKdScmaS5SK69J
github.com/charmbracelet/bubbletea v1.1.0/go.mod h1:9Ogk0HrdbHolIKHdjfFpyXJmiCzGwy+FesYkZr7hYU4=
github.com/charmbracelet/lipgloss v0.13.0 h1:4X3PPeoWEDCMvzDvGmTajSyYPcZM4+y8sCA/SsA3cjw=
github.com/charmbracelet/lipgloss v0.13.0/go.mod h1:nw4zy0SBX/F/eAO1cWdcvy6qnkDUxr8Lw7dvFrAIbbY=
+github.com/charmbracelet/log v0.4.0 h1:G9bQAcx8rWA2T3pWvx7YtPTPwgqpk7D68BX21IRW8ZM=
+github.com/charmbracelet/log v0.4.0/go.mod h1:63bXt/djrizTec0l11H20t8FDSvA4CRZJ1KH22MdptM=
github.com/charmbracelet/x/ansi v0.2.3 h1:VfFN0NUpcjBRd4DnKfRaIRo53KRgey/nhOoEqosGDEY=
github.com/charmbracelet/x/ansi v0.2.3/go.mod h1:dk73KoMTT5AX5BsX0KrqhsTqAnhZZoCBjs7dGWp4Ktw=
github.com/charmbracelet/x/term v0.2.0 h1:cNB9Ot9q8I711MyZ7myUR5HFWL/lc3OpU8jZ4hwm0x0=
@@ -23,6 +25,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
+github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
+github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
@@ -58,6 +62,7 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
+golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
diff --git a/internal/test/failer_test.go b/internal/test/failer_test.go
new file mode 100644
index 0000000..94ea30b
--- /dev/null
+++ b/internal/test/failer_test.go
@@ -0,0 +1,108 @@
+//go:build !test
+// +build !test
+
+package test
+
+import (
+ "errors"
+ "testing"
+)
+
+// TestErrTestErrorAccumulatorWriteFailed_Error tests the Error method of ErrTestErrorAccumulatorWriteFailed.
+func TestErrTestErrorAccumulatorWriteFailed_Error(t *testing.T) {
+ err := ErrTestErrorAccumulatorWriteFailed{}
+ expected := "test error accumulator failed"
+
+ if err.Error() != expected {
+ t.Errorf("Error() returned %q, expected %q", err.Error(), expected)
+ }
+}
+
+// TestFailingErrorBuffer_Write tests the Write method of FailingErrorBuffer with various inputs.
+func TestFailingErrorBuffer_Write(t *testing.T) {
+ buf := &FailingErrorBuffer{}
+
+ testCases := []struct {
+ name string
+ input []byte
+ }{
+ {"nil slice", nil},
+ {"empty slice", []byte{}},
+ {"non-empty slice", []byte("test data")},
+ {"large slice", make([]byte, 1000)},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ n, err := buf.Write(tc.input)
+ if n != 0 {
+ t.Errorf("Write(%q) returned n=%d, expected n=0", tc.input, n)
+ }
+ if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed{}) {
+ t.Errorf("Write(%q) returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}", tc.input, err)
+ }
+ })
+ }
+}
+
+// TestFailingErrorBuffer_Len tests the Len method of FailingErrorBuffer.
+func TestFailingErrorBuffer_Len(t *testing.T) {
+ buf := &FailingErrorBuffer{}
+
+ length := buf.Len()
+ if length != 0 {
+ t.Errorf("Len() returned %d, expected 0", length)
+ }
+
+ // After Write calls
+ _, _ = buf.Write([]byte("test"))
+ length = buf.Len()
+ if length != 0 {
+ t.Errorf("Len() after Write returned %d, expected 0", length)
+ }
+}
+
+// TestFailingErrorBuffer_Bytes tests the Bytes method of FailingErrorBuffer.
+func TestFailingErrorBuffer_Bytes(t *testing.T) {
+ buf := &FailingErrorBuffer{}
+
+ bytes := buf.Bytes()
+ if len(bytes) != 0 {
+ t.Errorf("Bytes() returned %v (len=%d), expected empty byte slice", bytes, len(bytes))
+ }
+
+ // After Write calls
+ _, _ = buf.Write([]byte("test"))
+ bytes = buf.Bytes()
+ if len(bytes) != 0 {
+ t.Errorf("Bytes() after Write returned %v (len=%d), expected empty byte slice", bytes, len(bytes))
+ }
+}
+
+// TestFailingErrorBuffer_MultipleWrites tests multiple Write calls to FailingErrorBuffer.
+func TestFailingErrorBuffer_MultipleWrites(t *testing.T) {
+ buf := &FailingErrorBuffer{}
+
+ for i := 0; i < 5; i++ {
+ n, err := buf.Write([]byte("data"))
+ if n != 0 {
+ t.Errorf("Write call %d returned n=%d, expected n=0", i+1, n)
+ }
+ if !errors.Is(err, ErrTestErrorAccumulatorWriteFailed{}) {
+ t.Errorf("Write call %d returned err=%v, expected ErrTestErrorAccumulatorWriteFailed{}", i+1, err)
+ }
+ }
+
+ if buf.Len() != 0 {
+ t.Errorf("Len() after multiple Writes returned %d, expected 0", buf.Len())
+ }
+
+ if len(buf.Bytes()) != 0 {
+ t.Errorf("Bytes() after multiple Writes returned len=%d, expected 0", len(buf.Bytes()))
+ }
+}
+
+var _ error = ErrTestErrorAccumulatorWriteFailed{}
+var _ interface{ Write([]byte) (int, error) } = &FailingErrorBuffer{}
+var _ interface{ Len() int } = &FailingErrorBuffer{}
+var _ interface{ Bytes() []byte } = &FailingErrorBuffer{}
diff --git a/internal/test/helpers_test.go b/internal/test/helpers_test.go
new file mode 100644
index 0000000..7d1b6f2
--- /dev/null
+++ b/internal/test/helpers_test.go
@@ -0,0 +1,113 @@
+//go:build !test
+// +build !test
+
+package test
+
+import (
+ "io"
+ "net/http"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+// TestCreateTestFile verifies that CreateTestFile correctly creates a file with the expected content.
+func TestCreateTestFile(t *testing.T) {
+ a := assert.New(t)
+
+ // Create a temporary directory and ensure cleanup.
+ dir, cleanup := CreateTestDirectory(t)
+ defer cleanup()
+
+ // Define the path for the test file.
+ filePath := filepath.Join(dir, "testfile.txt")
+
+ // Call the function under test.
+ CreateTestFile(t, filePath)
+
+ // Check that the file exists.
+ info, err := os.Stat(filePath)
+ a.NoError(err, "File should exist")
+ a.False(info.IsDir(), "Should be a file, not a directory")
+
+ // Read and verify the file content.
+ content, err := os.ReadFile(filePath)
+ a.NoError(err, "Should be able to read the file")
+ a.Equal("hello", string(content), "File content should be 'hello'")
+}
+
+// TestCreateTestDirectory ensures that CreateTestDirectory creates a directory and the cleanup function removes it.
+func TestCreateTestDirectory(t *testing.T) {
+ a := assert.New(t)
+
+ // Create the test directory.
+ dir, cleanup := CreateTestDirectory(t)
+
+ // Check that the directory exists.
+ info, err := os.Stat(dir)
+ a.NoError(err, "Directory should exist")
+ a.True(info.IsDir(), "Should be a directory")
+
+ // Write a test file inside the directory.
+ testFilePath := filepath.Join(dir, "test.txt")
+ err = os.WriteFile(testFilePath, []byte("test content"), 0644)
+ a.NoError(err, "Should be able to write a file in the directory")
+
+ // Perform cleanup.
+ cleanup()
+
+ // Verify that the directory has been removed.
+ _, err = os.Stat(dir)
+ a.True(os.IsNotExist(err), "Directory should be deleted after cleanup")
+}
+
+// MockRoundTripper is a mock implementation of http.RoundTripper for testing purposes.
+type MockRoundTripper struct {
+ LastRequest *http.Request
+}
+
+// RoundTrip captures the request and returns a dummy response.
+func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
+ m.LastRequest = req
+ return &http.Response{
+ StatusCode: http.StatusOK,
+ Body: io.NopCloser(strings.NewReader("OK")),
+ Header: make(http.Header),
+ }, nil
+}
+
+// TestTokenRoundTripper verifies that TokenRoundTripper adds the correct Authorization header.
+func TestTokenRoundTripper(t *testing.T) {
+ a := assert.New(t)
+
+ // Prepare the mock fallback RoundTripper.
+ mockRT := &MockRoundTripper{}
+
+ // Initialize the TokenRoundTripper with a test token.
+ tokenRT := &TokenRoundTripper{
+ Token: "test-token",
+ Fallback: mockRT,
+ }
+
+ // Create an HTTP client using the TokenRoundTripper.
+ client := &http.Client{
+ Transport: tokenRT,
+ }
+
+ // Prepare a test HTTP request.
+ req, err := http.NewRequest("GET", "http://example.com", nil)
+ a.NoError(err, "Should be able to create a new request")
+
+ // Perform the HTTP request.
+ resp, err := client.Do(req)
+ a.NoError(err, "HTTP request should succeed")
+ a.Equal(http.StatusOK, resp.StatusCode, "Response status code should be 200")
+
+ // Verify that the Authorization header was added.
+ a.NotNil(mockRT.LastRequest, "LastRequest should be captured")
+ authHeader := mockRT.LastRequest.Header.Get("Authorization")
+ a.Equal("Bearer test-token", authHeader, "Authorization header should contain the correct token")
+}
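The `TokenRoundTripper` exercised above behaves like an `http.RoundTripper` that injects an `Authorization: Bearer <token>` header before delegating to its fallback transport. Here is a minimal sketch of wiring it into a client, assuming only the `Token` and `Fallback` fields shown in the test and that `Fallback` accepts any `http.RoundTripper`; the target URL is illustrative, and the internal import path means this only compiles inside the groq-go module.

```go
package main

import (
	"fmt"
	"net/http"

	"github.com/conneroisu/groq-go/internal/test"
)

func main() {
	// Wrap the default transport so every request carries the test token.
	client := &http.Client{
		Transport: &test.TokenRoundTripper{
			Token:    "test-token",
			Fallback: http.DefaultTransport, // assumed to be an http.RoundTripper field
		},
	}
	// Any endpoint reached through this client receives the header
	// "Authorization: Bearer test-token".
	resp, err := client.Get("http://example.com")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
```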
diff --git a/models.go b/models.go
index 52f8485..c66fe54 100644
--- a/models.go
+++ b/models.go
@@ -1,6 +1,6 @@
// Code generated by groq-modeler DO NOT EDIT.
//
-// Created at: 2024-09-06 13:41:20
+// Created at: 2024-09-13 17:06:03
//
// groq-modeler Version 1.0.0
@@ -42,19 +42,19 @@ const (
TranscriptionTimestampGranularityWord TranscriptionTimestampGranularity = "word" // TranscriptionTimestampGranularityWord is the word timestamp granularity.
TranscriptionTimestampGranularitySegment TranscriptionTimestampGranularity = "segment" // TranscriptionTimestampGranularitySegment is the segment timestamp granularity.
- Llama370B8192 Model = "llama3-70b-8192" // Llama370B8192 is an AI text model provided by Meta. It has 8192 context window.
DistilWhisperLargeV3En Model = "distil-whisper-large-v3-en" // DistilWhisperLargeV3En is an AI audio model provided by Hugging Face. It has 448 context window.
+ Gemma29BIt Model = "gemma2-9b-it" // Gemma29BIt is an AI text model provided by Google. It has 8192 context window.
Gemma7BIt Model = "gemma-7b-it" // Gemma7BIt is an AI text model provided by Google. It has 8192 context window.
- LlavaV157B4096Preview Model = "llava-v1.5-7b-4096-preview" // LlavaV157B4096Preview is an AI text model provided by Other. It has 4096 context window.
Llama3170BVersatile Model = "llama-3.1-70b-versatile" // Llama3170BVersatile is an AI text model provided by Meta. It has 131072 context window.
- Llama38B8192 Model = "llama3-8b-8192" // Llama38B8192 is an AI text model provided by Meta. It has 8192 context window.
Llama318BInstant Model = "llama-3.1-8b-instant" // Llama318BInstant is an AI text model provided by Meta. It has 131072 context window.
- WhisperLargeV3 Model = "whisper-large-v3" // WhisperLargeV3 is an AI audio model provided by OpenAI. It has 448 context window.
+ Llama370B8192 Model = "llama3-70b-8192" // Llama370B8192 is an AI text model provided by Meta. It has 8192 context window.
+ Llama38B8192 Model = "llama3-8b-8192" // Llama38B8192 is an AI text model provided by Meta. It has 8192 context window.
+ Llama3Groq70B8192ToolUsePreview Model = "llama3-groq-70b-8192-tool-use-preview" // Llama3Groq70B8192ToolUsePreview is an AI text model provided by Groq. It has 8192 context window.
Llama3Groq8B8192ToolUsePreview Model = "llama3-groq-8b-8192-tool-use-preview" // Llama3Groq8B8192ToolUsePreview is an AI text model provided by Groq. It has 8192 context window.
- Gemma29BIt Model = "gemma2-9b-it" // Gemma29BIt is an AI text model provided by Google. It has 8192 context window.
+ LlamaGuard38B Model = "llama-guard-3-8b" // LlamaGuard38B is an AI model provided by Meta. It has 8192 context window.
+ LlavaV157B4096Preview Model = "llava-v1.5-7b-4096-preview" // LlavaV157B4096Preview is an AI text model provided by Other. It has 4096 context window.
Mixtral8X7B32768 Model = "mixtral-8x7b-32768" // Mixtral8X7B32768 is an AI text model provided by Mistral AI. It has 32768 context window.
- Llama3Groq70B8192ToolUsePreview Model = "llama3-groq-70b-8192-tool-use-preview" // Llama3Groq70B8192ToolUsePreview is an AI text model provided by Groq. It has 8192 context window.
- LlamaGuard38B Model = "llama-guard-3-8b" // LlamaGuard38B is an AI moderation model provided by Meta. It has 8192 context window.
+ WhisperLargeV3 Model = "whisper-large-v3" // WhisperLargeV3 is an AI audio model provided by OpenAI. It has 448 context window.
)
var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{
@@ -67,31 +67,42 @@ var disabledModelsForEndpoints = map[Endpoint]map[Model]bool{
WhisperLargeV3: true,
},
transcriptionsSuffix: {
- Llama370B8192: true,
+ Gemma29BIt: true,
Gemma7BIt: true,
- LlavaV157B4096Preview: true,
Llama3170BVersatile: true,
- Llama38B8192: true,
Llama318BInstant: true,
+ Llama370B8192: true,
+ Llama38B8192: true,
+ Llama3Groq70B8192ToolUsePreview: true,
Llama3Groq8B8192ToolUsePreview: true,
- Gemma29BIt: true,
+ LlavaV157B4096Preview: true,
Mixtral8X7B32768: true,
- Llama3Groq70B8192ToolUsePreview: true,
},
translationsSuffix: {
- Llama370B8192: true,
+ Gemma29BIt: true,
Gemma7BIt: true,
- LlavaV157B4096Preview: true,
Llama3170BVersatile: true,
- Llama38B8192: true,
Llama318BInstant: true,
+ Llama370B8192: true,
+ Llama38B8192: true,
+ Llama3Groq70B8192ToolUsePreview: true,
Llama3Groq8B8192ToolUsePreview: true,
- Gemma29BIt: true,
+ LlavaV157B4096Preview: true,
Mixtral8X7B32768: true,
- Llama3Groq70B8192ToolUsePreview: true,
},
moderationsSuffix: {
- LlamaGuard38B: true,
+ DistilWhisperLargeV3En: true,
+ Gemma29BIt: true,
+ Gemma7BIt: true,
+ Llama3170BVersatile: true,
+ Llama318BInstant: true,
+ Llama370B8192: true,
+ Llama38B8192: true,
+ Llama3Groq70B8192ToolUsePreview: true,
+ Llama3Groq8B8192ToolUsePreview: true,
+ LlavaV157B4096Preview: true,
+ Mixtral8X7B32768: true,
+ WhisperLargeV3: true,
},
}
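The regenerated disabledModelsForEndpoints table above is keyed by endpoint and then by model, so a lookup reduces to a nested map check. A minimal sketch; endpointSupportsModel is a hypothetical helper, not part of the package:

    package groq

    // endpointSupportsModel shows how the generated table is meant to be read:
    // a model is usable on an endpoint unless it is explicitly marked as
    // disabled there. A missing inner map yields false, so unknown
    // combinations are treated as supported. Hypothetical helper.
    func endpointSupportsModel(endpoint Endpoint, model Model) bool {
        return !disabledModelsForEndpoints[endpoint][model]
    }
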
diff --git a/moderation.go b/moderation.go
index 1502171..78e31c4 100644
--- a/moderation.go
+++ b/moderation.go
@@ -148,8 +148,9 @@ var (
// ModerationRequest represents a request structure for moderation API.
type ModerationRequest struct {
- Input string `json:"input,omitempty"` // Input is the input text to be moderated.
- Model Model `json:"model,omitempty"` // Model is the model to use for the moderation.
+ // Input string `json:"input,omitempty"` // Input is the input text to be moderated.
+ Messages []ChatCompletionMessage `json:"messages"` // Messages is the messages of the chat completion request. These act as the prompt for the model.
+ Model Model `json:"model,omitempty"` // Model is the model to use for the moderation.
}
// Moderation represents one of the possible moderation results.

@@ -182,11 +183,12 @@ func (c *Client) Moderate(
if err != nil {
return
}
- content := resp.Choices[0].Message.Content
- println(content)
- if strings.Contains(content, "unsafe") {
+ if strings.Contains(resp.Choices[0].Message.Content, "unsafe") {
response.Flagged = true
- split := strings.Split(strings.Split(content, "\n")[1], ",")
+ split := strings.Split(
+ strings.Split(resp.Choices[0].Message.Content, "\n")[1],
+ ",",
+ )
for _, s := range split {
response.Categories = append(
response.Categories,
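The rework above treats the assistant reply as a verdict line followed by a comma-separated list of category codes (the test fixture below uses "unsafe\nS1,S2"). A standalone worked example of just that string handling; mapping codes such as S1 onto HarmfulCategory values happens elsewhere in the package:

    package main

    import (
        "fmt"
        "strings"
    )

    func main() {
        // Example assistant reply in the format Moderate expects:
        // a verdict line, then a comma-separated list of category codes.
        content := "unsafe\nS1,S2"

        flagged := strings.Contains(content, "unsafe")
        var codes []string
        if flagged {
            // The second line carries the category codes, e.g. "S1,S2".
            codes = strings.Split(strings.Split(content, "\n")[1], ",")
        }
        fmt.Println(flagged, codes) // true [S1 S2]
    }
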
diff --git a/moderation_test.go b/moderation_test.go
index 8c83a57..07a7708 100644
--- a/moderation_test.go
+++ b/moderation_test.go
@@ -1,74 +1,4 @@
-package groq_test
+//go:build !test
+// +build !test
-import (
- "bytes"
- "context"
- "encoding/json"
- "net/http"
- "testing"
-
- groq "github.com/conneroisu/groq-go"
- "github.com/stretchr/testify/assert"
-)
-
-// TestModerate tests the Moderate method of the client.
-func TestModerate(t *testing.T) {
- client, server, teardown := setupGroqTestServer()
- defer teardown()
- server.RegisterHandler(
- "/v1/chat/completions",
- handleModerationEndpoint,
- )
- mod, err := client.Moderate(context.Background(), groq.ModerationRequest{
- Model: groq.ModerationTextStable,
- Input: "I want to kill them.",
- })
- a := assert.New(t)
- a.NoError(err, "Moderation error")
- a.Equal(true, mod.Flagged)
- a.Equal(
- mod.Categories,
- []groq.HarmfulCategory{
- groq.CategoryViolentCrimes,
- groq.CategoryNonviolentCrimes,
- },
- )
-}
-
-// handleModerationEndpoint handles the moderation endpoint.
-func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
- response := groq.ChatCompletionResponse{
- ID: "chatcmpl-123",
- Object: "chat.completion",
- Created: 1693721698,
- Model: groq.ModerationTextStable,
- Choices: []groq.ChatCompletionChoice{
- {
- Message: groq.ChatCompletionMessage{
- Role: groq.ChatMessageRoleAssistant,
- Content: "unsafe\nS1,S2",
- },
- FinishReason: "stop",
- },
- },
- }
- buf := new(bytes.Buffer)
- err := json.NewEncoder(buf).Encode(response)
- if err != nil {
- http.Error(
- w,
- "could not encode response",
- http.StatusInternalServerError,
- )
- return
- }
- _, err = w.Write(buf.Bytes())
- if err != nil {
- http.Error(
- w,
- "could not write response",
- http.StatusInternalServerError,
- )
- return
- }
-}
+package groq
diff --git a/moderation_unit_test.go b/moderation_unit_test.go
new file mode 100644
index 0000000..310636d
--- /dev/null
+++ b/moderation_unit_test.go
@@ -0,0 +1,79 @@
+package groq_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ groq "github.com/conneroisu/groq-go"
+ "github.com/stretchr/testify/assert"
+)
+
+// TestModerate tests the Moderate method of the client.
+func TestModerate(t *testing.T) {
+ client, server, teardown := setupGroqTestServer()
+ defer teardown()
+ server.RegisterHandler(
+ "/v1/chat/completions",
+ handleModerationEndpoint,
+ )
+ mod, err := client.Moderate(context.Background(), groq.ModerationRequest{
+ Model: groq.ModerationTextStable,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: "I want to kill them.",
+ },
+ },
+ })
+ a := assert.New(t)
+ a.NoError(err, "Moderation error")
+ a.Equal(true, mod.Flagged)
+ a.Equal(
+ mod.Categories,
+ []groq.HarmfulCategory{
+ groq.CategoryViolentCrimes,
+ groq.CategoryNonviolentCrimes,
+ },
+ )
+}
+
+// handleModerationEndpoint handles the moderation endpoint.
+func handleModerationEndpoint(w http.ResponseWriter, r *http.Request) {
+ response := groq.ChatCompletionResponse{
+ ID: "chatcmpl-123",
+ Object: "chat.completion",
+ Created: 1693721698,
+ Model: groq.ModerationTextStable,
+ Choices: []groq.ChatCompletionChoice{
+ {
+ Message: groq.ChatCompletionMessage{
+ Role: groq.ChatMessageRoleAssistant,
+ Content: "unsafe\nS1,S2",
+ },
+ FinishReason: "stop",
+ },
+ },
+ }
+ buf := new(bytes.Buffer)
+ err := json.NewEncoder(buf).Encode(response)
+ if err != nil {
+ http.Error(
+ w,
+ "could not encode response",
+ http.StatusInternalServerError,
+ )
+ return
+ }
+ _, err = w.Write(buf.Bytes())
+ if err != nil {
+ http.Error(
+ w,
+ "could not write response",
+ http.StatusInternalServerError,
+ )
+ return
+ }
+}
diff --git a/schema.go b/schema.go
index 4abdcbf..7ff1c50 100644
--- a/schema.go
+++ b/schema.go
@@ -30,21 +30,27 @@ var (
falseSchema = &schema{boolean: &[]bool{false}[0]}
timeType = reflect.TypeOf(time.Time{}) // date-time RFC section 7.3.1
- ipType = reflect.TypeOf(net.IP{}) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5
- uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6
+ ipType = reflect.TypeOf(
+ net.IP{},
+ ) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5
+ uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6
byteSliceType = reflect.TypeOf([]byte(nil))
rawMessageType = reflect.TypeOf(json.RawMessage{})
- customType = reflect.TypeOf((*customSchemaImpl)(nil)).Elem()
- extendType = reflect.TypeOf((*extendSchemaImpl)(nil)).Elem()
- customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).Elem()
- protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem()
- matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
- matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
+ customType = reflect.TypeOf((*customSchemaImpl)(nil)).
+ Elem()
+ extendType = reflect.TypeOf((*extendSchemaImpl)(nil)).
+ Elem()
+ customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).
+ Elem()
+ protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem()
+ matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)")
+ matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])")
customAliasSchema = reflect.TypeOf((*aliasSchemaImpl)(nil)).Elem()
- customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)).Elem()
+ customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)).
+ Elem()
)
// customSchemaImpl is used to detect if the type provides its own
@@ -328,8 +334,16 @@ func (r *reflector) reflectTypeToSchema(
case reflect.Interface:
// empty
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
- reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ case reflect.Int,
+ reflect.Int8,
+ reflect.Int16,
+ reflect.Int32,
+ reflect.Int64,
+ reflect.Uint,
+ reflect.Uint8,
+ reflect.Uint16,
+ reflect.Uint32,
+ reflect.Uint64:
st.Type = "integer"
case reflect.Float32, reflect.Float64:
@@ -437,7 +451,11 @@ func (r *reflector) reflectMap(
}
switch t.Key().Kind() {
- case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ case reflect.Int,
+ reflect.Int8,
+ reflect.Int16,
+ reflect.Int32,
+ reflect.Int64:
st.PatternProperties = map[string]*schema{
"^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()),
}
@@ -445,7 +463,10 @@ func (r *reflector) reflectMap(
return
}
if t.Elem().Kind() != reflect.Interface {
- st.AdditionalProperties = r.refOrReflectTypeToSchema(definitions, t.Elem())
+ st.AdditionalProperties = r.refOrReflectTypeToSchema(
+ definitions,
+ t.Elem(),
+ )
}
}
@@ -533,7 +554,10 @@ func (r *reflector) reflectStructFields(
// the provided object's type instead of the field's type.
var property *schema
if alias := customPropertyMethod(name); alias != nil {
- property = r.refOrReflectTypeToSchema(definitions, reflect.TypeOf(alias))
+ property = r.refOrReflectTypeToSchema(
+ definitions,
+ reflect.TypeOf(alias),
+ )
} else {
property = r.refOrReflectTypeToSchema(definitions, f.Type)
}
@@ -599,7 +623,11 @@ func (r *reflector) lookupComment(t reflect.Type, name string) string {
}
// addDefinition will append the provided schema. If needed, an ID and anchor will also be added.
-func (r *reflector) addDefinition(definitions schemaDefinitions, t reflect.Type, s *schema) {
+func (r *reflector) addDefinition(
+ definitions schemaDefinitions,
+ t reflect.Type,
+ s *schema,
+) {
name := r.typeName(t)
if name == "" {
return
@@ -608,7 +636,10 @@ func (r *reflector) addDefinition(definitions schemaDefinitions, t reflect.Type,
}
// refDefinition will provide a schema with a reference to an existing definition.
-func (r *reflector) refDefinition(definitions schemaDefinitions, t reflect.Type) *schema {
+func (r *reflector) refDefinition(
+ definitions schemaDefinitions,
+ t reflect.Type,
+) *schema {
if r.DoNotReference {
return nil
}
@@ -635,7 +666,11 @@ func (r *reflector) lookupID(t reflect.Type) schemaID {
return EmptyID
}
-func (t *schema) fieldsFromTags(f reflect.StructField, parent *schema, propertyName string) {
+func (t *schema) fieldsFromTags(
+ f reflect.StructField,
+ parent *schema,
+ propertyName string,
+) {
t.Description = f.Tag.Get("jsonschema_description")
tags := splitOnUnescapedCommas(f.Tag.Get("jsonschema"))
@@ -874,7 +909,10 @@ func (t *schema) arrayfields(tags []string) {
case "pattern":
t.Items.Pattern = val
default:
- unprocessed = append(unprocessed, tag) // left for further processing by underlying type
+ unprocessed = append(
+ unprocessed,
+ tag,
+ ) // left for further processing by underlying type
}
}
}
@@ -1024,7 +1062,9 @@ func (r *reflector) fieldNameTag() string {
return "json"
}
-func (r *reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, bool) {
+func (r *reflector) reflectFieldName(
+ f reflect.StructField,
+) (string, bool, bool, bool) {
jsonTagString := f.Tag.Get(r.fieldNameTag())
jsonTags := strings.Split(jsonTagString, ",")
if ignoredByJSONTags(jsonTags) {
@@ -1049,7 +1089,8 @@ func (r *reflector) reflectFieldName(f reflect.StructField) (string, bool, bool,
}
// As per JSON Marshal rules, anonymous pointer to structs are inherited
- if f.Type.Kind() == reflect.Ptr && f.Type.Elem().Kind() == reflect.Struct {
+ if f.Type.Kind() == reflect.Ptr &&
+ f.Type.Elem().Kind() == reflect.Struct {
return "", true, false, false
}
}
diff --git a/schema_test.go b/schema_test.go
index 8da5a0b..eb9560f 100644
--- a/schema_test.go
+++ b/schema_test.go
@@ -1,3 +1,6 @@
+//go:build !test
+// +build !test
+
package groq
import (
@@ -42,7 +45,11 @@ func TestIDValidation(t *testing.T) {
id = "https://encoding/json"
if assert.Error(t, id.Validate()) {
- assert.Contains(t, id.Validate().Error(), "hostname does not look valid")
+ assert.Contains(
+ t,
+ id.Validate().Error(),
+ "hostname does not look valid",
+ )
}
id = "time"
@@ -82,7 +89,7 @@ type (
// Plant represents the plants the user might have and serves as a test
// of structs inside a `type` set.
Plant struct {
- Variant string `json:"variant" jsonschema:"title=Variant"` // This comment will be used
+ Variant string `json:"variant" jsonschema:"title=Variant"` // This comment will be used
// Multicellular is true if the plant is multicellular
Multicellular bool `json:"multicellular,omitempty" jsonschema:"title=Multicellular"` // This comment will be ignored
}
@@ -92,10 +99,9 @@ type (
// Don't forget to check out the nested path.
type User struct {
// Unique sequential identifier.
- ID int `json:"id" jsonschema:"required"`
- // This comment will be ignored
- Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex"`
- Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"`
+ ID int `json:"id" jsonschema:"required"`
+ Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex"`
+ Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"`
Tags map[string]any `json:"tags,omitempty"`
// An array of pets the user cares for.
@@ -109,7 +115,12 @@ type User struct {
}
var updateFixtures = flag.Bool("update", false, "set to update fixtures")
-var compareFixtures = flag.Bool("compare", false, "output failed fixtures with .out.json")
+
+var compareFixtures = flag.Bool(
+ "compare",
+ false,
+ "output failed fixtures with .out.json",
+)
type GrandfatherType struct {
FamilyName string `json:"family_name" jsonschema:"required"`
@@ -121,8 +132,8 @@ type SomeBaseType struct {
// The jsonschema required tag is nonsensical for private and ignored properties.
// Their presence here tests that the fields *will not* be required in the output
// schema, even if they are tagged required.
- SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"`
- SomeSchemaIgnoredProperty string `jsonschema:"-,required"`
+ SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"`
+ SomeSchemaIgnoredProperty string ` jsonschema:"-,required"`
Grandfather GrandfatherType `json:"grand"`
SomeUntaggedBaseProperty bool `jsonschema:"required"`
@@ -150,10 +161,10 @@ type TestUser struct {
nonExported
MapType
- ID int `json:"id" jsonschema:"required,minimum=bad,maximum=bad,exclusiveMinimum=bad,exclusiveMaximum=bad,default=bad"`
- Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex,readOnly=true"`
- Password string `json:"password" jsonschema:"writeOnly=true"`
- Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"`
+ ID int `json:"id" jsonschema:"required,minimum=bad,maximum=bad,exclusiveMinimum=bad,exclusiveMaximum=bad,default=bad"`
+ Name string `json:"name" jsonschema:"required,minLength=1,maxLength=20,pattern=.*,description=this is a property,title=the name,example=joe,example=lucy,default=alex,readOnly=true"`
+ Password string `json:"password" jsonschema:"writeOnly=true"`
+ Friends []int `json:"friends,omitempty" jsonschema_description:"list of IDs, omitted when empty"`
Tags map[string]string `json:"tags,omitempty"`
Options map[string]any `json:"options,omitempty"`
@@ -168,29 +179,29 @@ type TestUser struct {
IPAddress net.IP `json:"network_address,omitempty"`
// Tests for RFC draft-wright-json-schema-hyperschema-00, section 4
- Photo []byte `json:"photo,omitempty" jsonschema:"required"`
+ Photo []byte `json:"photo,omitempty" jsonschema:"required"`
Photo2 Bytes `json:"photo2,omitempty" jsonschema:"required"`
// Tests for jsonpb enum support
Feeling ProtoEnum `json:"feeling,omitempty"`
- Age int `json:"age" jsonschema:"minimum=18,maximum=120,exclusiveMaximum=121,exclusiveMinimum=17"`
+ Age int `json:"age" jsonschema:"minimum=18,maximum=120,exclusiveMaximum=121,exclusiveMinimum=17"`
Email string `json:"email" jsonschema:"format=email"`
- UUID string `json:"uuid" jsonschema:"format=uuid"`
+ UUID string `json:"uuid" jsonschema:"format=uuid"`
// Test for "extras" support
Baz string `jsonschema_extras:"foo=bar,hello=world,foo=bar1"`
- BoolExtra string `json:"bool_extra,omitempty" jsonschema_extras:"isTrue=true,isFalse=false"`
+ BoolExtra string `jsonschema_extras:"isTrue=true,isFalse=false" json:"bool_extra,omitempty"`
// Tests for simple enum tags
- Color string `json:"color" jsonschema:"enum=red,enum=green,enum=blue"`
+ Color string `json:"color" jsonschema:"enum=red,enum=green,enum=blue"`
Rank int `json:"rank,omitempty" jsonschema:"enum=1,enum=2,enum=3"`
Multiplier float64 `json:"mult,omitempty" jsonschema:"enum=1.0,enum=1.5,enum=2.0"`
// Tests for enum tags on slices
- Roles []string `json:"roles" jsonschema:"enum=admin,enum=moderator,enum=user"`
+ Roles []string `json:"roles" jsonschema:"enum=admin,enum=moderator,enum=user"`
Priorities []int `json:"priorities,omitempty" jsonschema:"enum=-1,enum=0,enum=1,enun=2"`
- Offsets []float64 `json:"offsets,omitempty" jsonschema:"enum=1.570796,enum=3.141592,enum=6.283185"`
+ Offsets []float64 `json:"offsets,omitempty" jsonschema:"enum=1.570796,enum=3.141592,enum=6.283185"`
// Test for raw JSON
Anything any `json:"anything,omitempty"`
@@ -397,7 +408,7 @@ type KeyNamed struct {
NotComingFromJSON bool `json:"coming_from_json_tag_not_renamed"`
NestedNotRenamed KeyNamedNested `json:"nested_not_renamed"`
UnicodeShenanigans string
- RenamedByComputation int `jsonschema_description:"Description was preserved"`
+ RenamedByComputation int ` jsonschema_description:"Description was preserved"`
}
type SchemaExtendTestBase struct {
@@ -425,7 +436,7 @@ type Expression struct {
type PatternEqualsTest struct {
WithEquals string `jsonschema:"pattern=foo=bar"`
- WithEqualsAndCommas string `jsonschema:"pattern=foo\\,=bar"`
+ WithEqualsAndCommas string `jsonschema:"pattern=foo,=bar"`
}
func TestReflector(t *testing.T) {
@@ -441,7 +452,11 @@ func TestReflectFromType(t *testing.T) {
typ := reflect.TypeOf(tu)
s := r.ReflectFromType(typ)
- assert.EqualValues(t, "https://github.com/conneroisu/groq-go/test-user", s.ID)
+ assert.EqualValues(
+ t,
+ "https://github.com/conneroisu/groq-go/test-user",
+ s.ID,
+ )
x := struct {
Test string
@@ -461,15 +476,51 @@ func TestSchemaGeneration(t *testing.T) {
}{
{&TestUser{}, &reflector{}, "testdata/test_user.json"},
{&UserWithAnchor{}, &reflector{}, "testdata/user_with_anchor.json"},
- {&TestUser{}, &reflector{AssignAnchor: true}, "testdata/test_user_assign_anchor.json"},
- {&TestUser{}, &reflector{AllowAdditionalProperties: true}, "testdata/allow_additional_props.json"},
- {&TestUser{}, &reflector{RequiredFromJSONSchemaTags: true}, "testdata/required_from_jsontags.json"},
- {&TestUser{}, &reflector{ExpandedStruct: true}, "testdata/defaults_expanded_toplevel.json"},
- {&TestUser{}, &reflector{IgnoredTypes: []any{GrandfatherType{}}}, "testdata/ignore_type.json"},
- {&TestUser{}, &reflector{DoNotReference: true}, "testdata/no_reference.json"},
- {&TestUser{}, &reflector{DoNotReference: true, AssignAnchor: true}, "testdata/no_reference_anchor.json"},
- {&RootOneOf{}, &reflector{RequiredFromJSONSchemaTags: true}, "testdata/oneof.json"},
- {&RootAnyOf{}, &reflector{RequiredFromJSONSchemaTags: true}, "testdata/anyof.json"},
+ {
+ &TestUser{},
+ &reflector{AssignAnchor: true},
+ "testdata/test_user_assign_anchor.json",
+ },
+ {
+ &TestUser{},
+ &reflector{AllowAdditionalProperties: true},
+ "testdata/allow_additional_props.json",
+ },
+ {
+ &TestUser{},
+ &reflector{RequiredFromJSONSchemaTags: true},
+ "testdata/required_from_jsontags.json",
+ },
+ {
+ &TestUser{},
+ &reflector{ExpandedStruct: true},
+ "testdata/defaults_expanded_toplevel.json",
+ },
+ {
+ &TestUser{},
+ &reflector{IgnoredTypes: []any{GrandfatherType{}}},
+ "testdata/ignore_type.json",
+ },
+ {
+ &TestUser{},
+ &reflector{DoNotReference: true},
+ "testdata/no_reference.json",
+ },
+ {
+ &TestUser{},
+ &reflector{DoNotReference: true, AssignAnchor: true},
+ "testdata/no_reference_anchor.json",
+ },
+ {
+ &RootOneOf{},
+ &reflector{RequiredFromJSONSchemaTags: true},
+ "testdata/oneof.json",
+ },
+ {
+ &RootAnyOf{},
+ &reflector{RequiredFromJSONSchemaTags: true},
+ "testdata/anyof.json",
+ },
{&CustomTypeField{}, &reflector{
Mapper: func(i reflect.Type) *schema {
if i == reflect.TypeOf(CustomTime{}) {
@@ -481,7 +532,11 @@ func TestSchemaGeneration(t *testing.T) {
return nil
},
}, "testdata/custom_type.json"},
- {LookupUser{}, &reflector{BaseSchemaID: "https://example.com/schemas"}, "testdata/base_schema_id.json"},
+ {
+ LookupUser{},
+ &reflector{BaseSchemaID: "https://example.com/schemas"},
+ "testdata/base_schema_id.json",
+ },
{LookupUser{}, &reflector{
Lookup: func(i reflect.Type) schemaID {
switch i {
@@ -507,11 +562,31 @@ func TestSchemaGeneration(t *testing.T) {
return EmptyID
},
}, "testdata/lookup_expanded.json"},
- {&Outer{}, &reflector{ExpandedStruct: true}, "testdata/inlining_inheritance.json"},
- {&OuterNamed{}, &reflector{ExpandedStruct: true}, "testdata/inlining_embedded.json"},
- {&OuterNamed{}, &reflector{ExpandedStruct: true, AssignAnchor: true}, "testdata/inlining_embedded_anchored.json"},
- {&OuterInlined{}, &reflector{ExpandedStruct: true}, "testdata/inlining_tag.json"},
- {&OuterPtr{}, &reflector{ExpandedStruct: true}, "testdata/inlining_ptr.json"},
+ {
+ &Outer{},
+ &reflector{ExpandedStruct: true},
+ "testdata/inlining_inheritance.json",
+ },
+ {
+ &OuterNamed{},
+ &reflector{ExpandedStruct: true},
+ "testdata/inlining_embedded.json",
+ },
+ {
+ &OuterNamed{},
+ &reflector{ExpandedStruct: true, AssignAnchor: true},
+ "testdata/inlining_embedded_anchored.json",
+ },
+ {
+ &OuterInlined{},
+ &reflector{ExpandedStruct: true},
+ "testdata/inlining_tag.json",
+ },
+ {
+ &OuterPtr{},
+ &reflector{ExpandedStruct: true},
+ "testdata/inlining_ptr.json",
+ },
{&MinValue{}, &reflector{}, "testdata/schema_with_minimum.json"},
{&TestNullable{}, &reflector{}, "testdata/nullable.json"},
{&GrandfatherType{}, &reflector{
@@ -526,12 +601,19 @@ func TestSchemaGeneration(t *testing.T) {
}
},
}, "testdata/custom_additional.json"},
- {&TestDescriptionOverride{}, &reflector{}, "testdata/test_description_override.json"},
+ {
+ &TestDescriptionOverride{},
+ &reflector{},
+ "testdata/test_description_override.json",
+ },
{&CompactDate{}, &reflector{}, "testdata/compact_date.json"},
{&CustomSliceOuter{}, &reflector{}, "testdata/custom_slice_type.json"},
{&CustomMapOuter{}, &reflector{}, "testdata/custom_map_type.json"},
- {&CustomTypeFieldWithInterface{}, &reflector{}, "testdata/custom_type_with_interface.json"},
- {&PatternTest{}, &reflector{}, "testdata/commas_in_pattern.json"},
+ {
+ &CustomTypeFieldWithInterface{},
+ &reflector{},
+ "testdata/custom_type_with_interface.json",
+ },
{&RecursiveExample{}, &reflector{}, "testdata/recursive.json"},
{&KeyNamed{}, &reflector{
KeyNamer: func(s string) string {
@@ -560,7 +642,7 @@ func TestSchemaGeneration(t *testing.T) {
{ArrayType{}, &reflector{}, "testdata/array_type.json"},
{SchemaExtendTest{}, &reflector{}, "testdata/custom_type_extend.json"},
{Expression{}, &reflector{}, "testdata/schema_with_expression.json"},
- {PatternEqualsTest{}, &reflector{}, "testdata/equals_in_pattern.json"},
+ {&PatternTest{}, &reflector{}, "testdata/commas_in_pattern.json"},
}
for _, tt := range tests {
@@ -584,7 +666,11 @@ func compareSchemaOutput(t *testing.T, f string, r *reflector, obj any) {
require.NoError(t, err)
actualSchema := r.Reflect(obj)
- actualJSON, _ := json.MarshalIndent(actualSchema, "", " ") //nolint:errchkjson
+ actualJSON, _ := json.MarshalIndent(
+ actualSchema,
+ "",
+ " ",
+ ) //nolint:errchkjson
if *updateFixtures {
_ = os.WriteFile(f, actualJSON, 0600)
@@ -592,7 +678,11 @@ func compareSchemaOutput(t *testing.T, f string, r *reflector, obj any) {
if !assert.JSONEq(t, string(expectedJSON), string(actualJSON)) {
if *compareFixtures {
- _ = os.WriteFile(strings.TrimSuffix(f, ".json")+".out.json", actualJSON, 0600)
+ _ = os.WriteFile(
+ strings.TrimSuffix(f, ".json")+".out.json",
+ actualJSON,
+ 0600,
+ )
}
}
}
@@ -609,7 +699,10 @@ func TestSplitOnUnescapedCommas(t *testing.T) {
strToSplit string
expected []string
}{
- {`Hello,this,is\,a\,string,haha`, []string{`Hello`, `this`, `is,a,string`, `haha`}},
+ {
+ `Hello,this,is\,a\,string,haha`,
+ []string{`Hello`, `this`, `is,a,string`, `haha`},
+ },
{`hello,no\\,split`, []string{`hello`, `no\,split`}},
{`string without commas`, []string{`string without commas`}},
{`ünicode,𐂄,Ж\,П,ᠳ`, []string{`ünicode`, `𐂄`, `Ж,П`, `ᠳ`}},
@@ -656,9 +749,9 @@ func TestFieldNameTag(t *testing.T) {
func TestFieldOneOfRef(t *testing.T) {
type Server struct {
- IPAddress any `json:"ip_address,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"`
- IPAddresses []any `json:"ip_addresses,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"`
- IPAddressAny any `json:"ip_address_any,omitempty" jsonschema:"anyof_ref=#/$defs/ipv4;#/$defs/ipv6"`
+ IPAddress any `json:"ip_address,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"`
+ IPAddresses []any `json:"ip_addresses,omitempty" jsonschema:"oneof_ref=#/$defs/ipv4;#/$defs/ipv6"`
+ IPAddressAny any `json:"ip_address_any,omitempty" jsonschema:"anyof_ref=#/$defs/ipv4;#/$defs/ipv6"`
IPAddressesAny []any `json:"ip_addresses_any,omitempty" jsonschema:"anyof_ref=#/$defs/ipv4;#/$defs/ipv6"`
}
@@ -668,12 +761,17 @@ func TestFieldOneOfRef(t *testing.T) {
func TestNumberHandling(t *testing.T) {
type NumberHandler struct {
- Int64 int64 `json:"int64" jsonschema:"default=12"`
+ Int64 int64 `json:"int64" jsonschema:"default=12"`
Float32 float32 `json:"float32" jsonschema:"default=12.5"`
}
r := &reflector{}
- compareSchemaOutput(t, "testdata/number_handling.json", r, &NumberHandler{})
+ compareSchemaOutput(
+ t,
+ "testdata/number_handling.json",
+ r,
+ &NumberHandler{},
+ )
fixtureContains(t, "testdata/number_handling.json", `"default": 12`)
fixtureContains(t, "testdata/number_handling.json", `"default": 12.5`)
}
@@ -692,14 +790,19 @@ func TestArrayHandling(t *testing.T) {
func TestUnsignedIntHandling(t *testing.T) {
type UnsignedIntHandler struct {
- MinLen []string `json:"min_len" jsonschema:"minLength=0"`
- MaxLen []string `json:"max_len" jsonschema:"maxLength=0"`
+ MinLen []string `json:"min_len" jsonschema:"minLength=0"`
+ MaxLen []string `json:"max_len" jsonschema:"maxLength=0"`
MinItems []string `json:"min_items" jsonschema:"minItems=0"`
MaxItems []string `json:"max_items" jsonschema:"maxItems=0"`
}
r := &reflector{}
- compareSchemaOutput(t, "testdata/unsigned_int_handling.json", r, &UnsignedIntHandler{})
+ compareSchemaOutput(
+ t,
+ "testdata/unsigned_int_handling.json",
+ r,
+ &UnsignedIntHandler{},
+ )
fixtureContains(t, "testdata/unsigned_int_handling.json", `"minLength": 0`)
fixtureContains(t, "testdata/unsigned_int_handling.json", `"maxLength": 0`)
fixtureContains(t, "testdata/unsigned_int_handling.json", `"minItems": 0`)
@@ -709,11 +812,16 @@ func TestUnsignedIntHandling(t *testing.T) {
func TestJSONSchemaFormat(t *testing.T) {
type WithCustomFormat struct {
Dates []string `json:"dates" jsonschema:"format=date"`
- Odds []string `json:"odds" jsonschema:"format=odd"`
+ Odds []string `json:"odds" jsonschema:"format=odd"`
}
r := &reflector{}
- compareSchemaOutput(t, "testdata/with_custom_format.json", r, &WithCustomFormat{})
+ compareSchemaOutput(
+ t,
+ "testdata/with_custom_format.json",
+ r,
+ &WithCustomFormat{},
+ )
fixtureContains(t, "testdata/with_custom_format.json", `"format": "date"`)
fixtureContains(t, "testdata/with_custom_format.json", `"format": "odd"`)
}
@@ -744,7 +852,12 @@ func (AliasObjectB) JSONSchemaAlias() any {
func TestJSONSchemaProperty(t *testing.T) {
r := &reflector{}
- compareSchemaOutput(t, "testdata/schema_property_alias.json", r, &AliasPropertyObjectBase{})
+ compareSchemaOutput(
+ t,
+ "testdata/schema_property_alias.json",
+ r,
+ &AliasPropertyObjectBase{},
+ )
}
func TestJSONSchemaAlias(t *testing.T) {
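The schema_test.go hunks above exercise splitOnUnescapedCommas: jsonschema tag values are split on commas unless the comma is escaped with a backslash, in which case the backslash is dropped and the pieces are rejoined. A simplified re-implementation for illustration only (splitTags is not the package's actual code, which may differ in detail):

    package groq

    import "strings"

    // splitTags mirrors the behaviour checked by TestSplitOnUnescapedCommas:
    // split on commas, but treat "\," as a literal comma by dropping the
    // backslash and rejoining the adjacent pieces. Simplified sketch.
    func splitTags(s string) []string {
        parts := strings.Split(s, ",")
        out := []string{parts[0]}
        for _, p := range parts[1:] {
            last := out[len(out)-1]
            if strings.HasSuffix(last, `\`) {
                out[len(out)-1] = strings.TrimSuffix(last, `\`) + "," + p
                continue
            }
            out = append(out, p)
        }
        return out
    }

For example, splitTags(`Hello,this,is\,a\,string,haha`) yields [Hello this is,a,string haha], matching the expectations in the test table above.
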
diff --git a/scripts/makefile/fmt.sh b/scripts/makefile/fmt.sh
index 9de9ede..1e79135 100644
--- a/scripts/makefile/fmt.sh
+++ b/scripts/makefile/fmt.sh
@@ -7,4 +7,4 @@
gofmt -w .
-golines -w --max-len=79 .
+# golines -w --max-len=79 .
diff --git a/stream_test.go b/stream_test.go
index 1ff37d1..6750024 100644
--- a/stream_test.go
+++ b/stream_test.go
@@ -1,4 +1,7 @@
-package groq //nolint:testpackage // testing private field
+//go:build !test
+// +build !test
+
+package groq
import (
"bufio"
diff --git a/test/unit_test.go b/test/unit_test.go
deleted file mode 100644
index 59dd4dd..0000000
--- a/test/unit_test.go
+++ /dev/null
@@ -1,74 +0,0 @@
-package test
-
-import (
- "context"
- "errors"
- "fmt"
- "io"
- "math/rand"
- "os"
- "testing"
-
- "github.com/conneroisu/groq-go"
- "github.com/stretchr/testify/assert"
-)
-
-func TestTestServer(t *testing.T) {
- num := rand.Intn(100)
- a := assert.New(t)
- ctx := context.Background()
- client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
- a.NoError(err, "NewClient error")
- strm, err := client.CreateChatCompletionStream(
- ctx,
- groq.ChatCompletionRequest{
- Model: groq.Llama38B8192,
- Messages: []groq.ChatCompletionMessage{
- {
- Role: groq.ChatMessageRoleUser,
- // convert the content of a excel file into a csv that can be imported into google calendar:
- //
- //
- //
- // Course ListingSectionInstructional FormatDelivery ModeMeeting PatternsInstructor
- //
- // CPRE 3810 - Computer Organization and Assembly Level ProgrammingCPRE 3810-1 - Computer Organization and Assembly Level ProgrammingLectureIn-PersonMWF | 8:50 AM - 9:40 AM | 0101 CARVER - Carver HallBerk Gulmezoglu
- //
- // EE 3240 - Signals and Systems IIEE 3240-1 - Signals and Systems IILectureIn-PersonMWF | 2:15 PM - 3:05 PM | 1134 SWEENEY - Sweeney HallRatnesh Kumar
- //
- // EE 3220 - Probabilistic Methods for Electrical EngineersEE 3220-1 - Probabilistic Methods for Electrical EngineersLectureIn-PersonTR | 11:00 AM - 12:15 PM | 1134 SWEENEY - Sweeney HallJulie A Dickerson
- //
- // SOC 1340 - Introduction to SociologySOC 1340-2 - Introduction to SociologyLectureIn-PersonMWF | 11:00 AM - 11:50 AM | 0127 CURTISS - Curtiss HallDavid Scott Schweingruber
- //
- // CPRE 3810 - Computer Organization and Assembly Level ProgrammingCPRE 3810-F - Computer Organization and Assembly Level ProgrammingLaboratoryIn-PersonR | 8:00 AM - 9:50 AM | 2050 COOVER - Coover HallBerk Gulmezoglu
- //
- // EE 3240 - Signals and Systems IIEE 3240-C - Signals and Systems IILaboratoryIn-PersonR | 4:10 PM - 7:00 PM | 2011 COOVER - Coover HallRatnesh Kumar
- //
- Content: fmt.Sprintf(`
-problem: %d
-You have a six-sided die that you roll once. Let $R{i}$ denote the event that the roll is $i$. Let $G{j}$ denote the event that the roll is greater than $j$. Let $E$ denote the event that the roll of the die is even-numbered.
-(a) What is $P\left[R{3} \mid G{1}\right]$, the conditional probability that 3 is rolled given that the roll is greater than 1 ?
-(b) What is the conditional probability that 6 is rolled given that the roll is greater than 3 ?
-(c) What is the $P\left[G_{3} \mid E\right]$, the conditional probability that the roll is greater than 3 given that the roll is even?
-(d) Given that the roll is greater than 3, what is the conditional probability that the roll is even?
- `, num,
- ),
- },
- },
- MaxTokens: 2000,
- Stream: true,
- },
- )
- a.NoError(err, "CreateCompletionStream error")
-
- i := 0
- for {
- i++
- val, err := strm.Recv()
- if errors.Is(err, io.EOF) {
- break
- }
- // t.Logf("%d %s\n", i, val.Choices[0].Delta.Content)
- print(val.Choices[0].Delta.Content)
- }
-}
diff --git a/unit_unit_test.go b/unit_unit_test.go
new file mode 100644
index 0000000..26d0d46
--- /dev/null
+++ b/unit_unit_test.go
@@ -0,0 +1,59 @@
+//go:build !test
+// +build !test
+
+package groq_test
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "math/rand"
+ "os"
+ "testing"
+
+ "github.com/conneroisu/groq-go"
+ "github.com/stretchr/testify/assert"
+)
+
+func TestTestServer(t *testing.T) {
+ num := rand.Intn(100)
+ a := assert.New(t)
+ ctx := context.Background()
+ client, err := groq.NewClient(os.Getenv("GROQ_KEY"))
+ a.NoError(err, "NewClient error")
+ strm, err := client.CreateChatCompletionStream(
+ ctx,
+ groq.ChatCompletionRequest{
+ Model: groq.Llama38B8192,
+ Messages: []groq.ChatCompletionMessage{
+ {
+ Role: groq.ChatMessageRoleUser,
+ Content: fmt.Sprintf(`
+problem: %d
+You have a six-sided die that you roll once. Let $R{i}$ denote the event that the roll is $i$. Let $G{j}$ denote the event that the roll is greater than $j$. Let $E$ denote the event that the roll of the die is even-numbered.
+(a) What is $P\left[R{3} \mid G{1}\right]$, the conditional probability that 3 is rolled given that the roll is greater than 1 ?
+(b) What is the conditional probability that 6 is rolled given that the roll is greater than 3 ?
+(c) What is the $P\left[G_{3} \mid E\right]$, the conditional probability that the roll is greater than 3 given that the roll is even?
+(d) Given that the roll is greater than 3, what is the conditional probability that the roll is even?
+ `, num,
+ ),
+ },
+ },
+ MaxTokens: 2000,
+ Stream: true,
+ },
+ )
+ a.NoError(err, "CreateCompletionStream error")
+
+ i := 0
+ for {
+ i++
+ val, err := strm.Recv()
+ if errors.Is(err, io.EOF) {
+ break
+ }
+ // t.Logf("%d %s\n", i, val.Choices[0].Delta.Content)
+ print(val.Choices[0].Delta.Content)
+ }
+}
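
The Recv loop in unit_unit_test.go breaks only on io.EOF and otherwise indexes val.Choices directly, so any other stream error would surface as a panic rather than a test failure. A slightly more defensive version of the same loop, as a sketch; the *groq.ChatCompletionStream parameter type is an assumption about CreateChatCompletionStream's return value, while Recv, Choices, and Delta.Content follow the usage in the test above:

    package groq_test

    import (
        "errors"
        "io"
        "testing"

        groq "github.com/conneroisu/groq-go"
    )

    // drainStream reads a chat completion stream until io.EOF, failing the
    // test on any other error instead of indexing a zero-value response.
    func drainStream(t *testing.T, strm *groq.ChatCompletionStream) string {
        t.Helper()
        var out string
        for {
            val, err := strm.Recv()
            if errors.Is(err, io.EOF) {
                return out
            }
            if err != nil {
                t.Fatalf("Recv error: %v", err)
            }
            if len(val.Choices) > 0 {
                out += val.Choices[0].Delta.Content
            }
        }
    }

It could replace the manual loop in TestTestServer, e.g. out := drainStream(t, strm).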