Skip to content

Commit

Permalink
Merge pull request #24 from conneroisu/devie
Browse files Browse the repository at this point in the history
devie
  • Loading branch information
conneroisu authored Sep 6, 2024
2 parents ac1505b + 4bac49b commit 7c1dde0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 75 deletions.
31 changes: 13 additions & 18 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,34 @@ func (r *audioTextResponse) SetHeader(header http.Header) {
r.header = header
}

// ToAudioResponse converts the audio text response to an audio response.
func (r *audioTextResponse) ToAudioResponse() AudioResponse {
// toAudioResponse converts the audio text response to an audio response.
func (r *audioTextResponse) toAudioResponse() AudioResponse {
return AudioResponse{
Text: r.Text,
Header: r.header,
}
}

// CreateTranscription — API call to create a transcription. Returns transcribed text.
// CreateTranscription calls the transcriptions endpoint with the given request.
//
// Returns transcribed text in the response_format specified in the request.
func (c *Client) CreateTranscription(
ctx context.Context,
request AudioRequest,
) (response AudioResponse, err error) {
return c.callAudioAPI(ctx, request, transcriptionsSuffix)
}

// CreateTranslation — API call to translate audio into English.
// CreateTranslation calls the translations endpoint with the given request.
//
// Returns the translated text in the response_format specified in the request.
func (c *Client) CreateTranslation(
ctx context.Context,
request AudioRequest,
) (response AudioResponse, err error) {
return c.callAudioAPI(ctx, request, translationsSuffix)
}

// callAudioAPI — API call to an audio endpoint.
func (c *Client) callAudioAPI(
ctx context.Context,
request AudioRequest,
Expand All @@ -118,54 +121,49 @@ func (c *Client) callAudioAPI(
return AudioResponse{}, err
}

if request.HasJSONResponse() {
if request.hasJSONResponse() {
err = c.sendRequest(req, &response)
} else {
var textResponse audioTextResponse
err = c.sendRequest(req, &textResponse)
response = textResponse.ToAudioResponse()
response = textResponse.toAudioResponse()
}
if err != nil {
return AudioResponse{}, err
}
return
}

// HasJSONResponse returns true if the response format is JSON.
func (r AudioRequest) HasJSONResponse() bool {
func (r AudioRequest) hasJSONResponse() bool {
return r.Format == "" || r.Format == AudioResponseFormatJSON ||
r.Format == AudioResponseFormatVerboseJSON
}

// audioMultipartForm creates a form with audio file contents and the name of the model to use for
// audio processing.
func audioMultipartForm(request AudioRequest, b FormBuilder) error {
func audioMultipartForm(request AudioRequest, b formBuilder) error {
err := createFileField(request, b)
if err != nil {
return err
}

err = b.WriteField("model", string(request.Model))
if err != nil {
return fmt.Errorf("writing model name: %w", err)
}

// Create a form field for the prompt (if provided)
if request.Prompt != "" {
err = b.WriteField("prompt", request.Prompt)
if err != nil {
return fmt.Errorf("writing prompt: %w", err)
}
}

// Create a form field for the format (if provided)
if request.Format != "" {
err = b.WriteField("response_format", string(request.Format))
if err != nil {
return fmt.Errorf("writing format: %w", err)
}
}

// Create a form field for the temperature (if provided)
if request.Temperature != 0 {
err = b.WriteField(
Expand All @@ -176,15 +174,13 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error {
return fmt.Errorf("writing temperature: %w", err)
}
}

// Create a form field for the language (if provided)
if request.Language != "" {
err = b.WriteField("language", request.Language)
if err != nil {
return fmt.Errorf("writing language: %w", err)
}
}

if len(request.TimestampGranularities) > 0 {
for _, tg := range request.TimestampGranularities {
err = b.WriteField("timestamp_granularities[]", string(tg))
Expand All @@ -193,14 +189,13 @@ func audioMultipartForm(request AudioRequest, b FormBuilder) error {
}
}
}
// Close the multipart writer
return b.Close()
}

// createFileField creates the "file" form field from either an existing file or by using the reader.
func createFileField(
request AudioRequest,
b FormBuilder,
b formBuilder,
) (err error) {
if request.Reader != nil {
err := b.CreateFormFileReader("file", request.Reader, request.FilePath)
Expand Down
65 changes: 19 additions & 46 deletions builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,43 @@ import (
"path"
)

// FormBuilder is an interface for building a form.
type FormBuilder interface {
// formBuilder is an interface for building a form.
type formBuilder interface {
CreateFormFile(fieldname string, file *os.File) error
CreateFormFileReader(fieldname string, r io.Reader, filename string) error
WriteField(fieldname, value string) error
Close() error
FormDataContentType() string
}

// DefaultFormBuilder is a default implementation of FormBuilder.
type DefaultFormBuilder struct {
// defaultFormBuilder is a default implementation of FormBuilder.
type defaultFormBuilder struct {
writer *multipart.Writer
}

// NewFormBuilder creates a new DefaultFormBuilder.
func NewFormBuilder(body io.Writer) *DefaultFormBuilder {
return &DefaultFormBuilder{
// newFormBuilder creates a new DefaultFormBuilder.
func newFormBuilder(body io.Writer) *defaultFormBuilder {
return &defaultFormBuilder{
writer: multipart.NewWriter(body),
}
}

// CreateFormFile creates a form file.
func (fb *DefaultFormBuilder) CreateFormFile(
func (fb *defaultFormBuilder) CreateFormFile(
fieldname string,
file *os.File,
) error {
return fb.createFormFile(fieldname, file, file.Name())
}

// CreateFormFileReader creates a form file from a reader.
func (fb *DefaultFormBuilder) CreateFormFileReader(
func (fb *defaultFormBuilder) CreateFormFileReader(
fieldname string,
r io.Reader,
filename string,
) error {
return fb.createFormFile(fieldname, r, path.Base(filename))
}

// createFormFile creates a form file.
func (fb *DefaultFormBuilder) createFormFile(
func (fb *defaultFormBuilder) createFormFile(
fieldname string,
r io.Reader,
filename string,
Expand All @@ -71,23 +68,19 @@ func (fb *DefaultFormBuilder) createFormFile(
return nil
}

// WriteField writes a field to the form.
func (fb *DefaultFormBuilder) WriteField(fieldname, value string) error {
func (fb *defaultFormBuilder) WriteField(fieldname, value string) error {
return fb.writer.WriteField(fieldname, value)
}

// Close closes the form.
func (fb *DefaultFormBuilder) Close() error {
func (fb *defaultFormBuilder) Close() error {
return fb.writer.Close()
}

// FormDataContentType returns the content type of the form.
func (fb *DefaultFormBuilder) FormDataContentType() string {
func (fb *defaultFormBuilder) FormDataContentType() string {
return fb.writer.FormDataContentType()
}

// RequestBuilder is an interface that defines the Build method.
type RequestBuilder interface {
type requestBuilder interface {
Build(
ctx context.Context,
method, url string,
Expand All @@ -96,33 +89,13 @@ type RequestBuilder interface {
) (*http.Request, error)
}

// HTTPRequestBuilder is a struct that implements the RequestBuilder interface.
type HTTPRequestBuilder struct {
marshaller Marshaller
}

// Marshaller is an interface that defines the Marshal method.
type Marshaller interface {
Marshal(v any) ([]byte, error)
}

// JSONMarshaller is a struct that implements the Marshaller interface.
type JSONMarshaller struct{}
type httpRequestBuilder struct{}

// Marshal marshals the given value to JSON.
func (j *JSONMarshaller) Marshal(v any) ([]byte, error) {
return json.Marshal(v)
}

// NewRequestBuilder returns a new HTTPRequestBuilder.
func NewRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &JSONMarshaller{},
}
func newRequestBuilder() *httpRequestBuilder {
return &httpRequestBuilder{}
}

// Build builds a new request.
func (b *HTTPRequestBuilder) Build(
func (b *httpRequestBuilder) Build(
ctx context.Context,
method string,
url string,
Expand All @@ -135,7 +108,7 @@ func (b *HTTPRequestBuilder) Build(
bodyReader = v
} else {
var reqBytes []byte
reqBytes, err = b.marshaller.Marshal(body)
reqBytes, err = json.Marshal(body)
if err != nil {
return
}
Expand Down
11 changes: 6 additions & 5 deletions builders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package groq // testing private field
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
Expand Down Expand Up @@ -79,7 +80,7 @@ func TestFormBuilderWithFailingWriter(t *testing.T) {
defer file.Close()
defer os.Remove(file.Name())

Check warning on line 82 in builders_test.go

View check run for this annotation

Codeac.io / Codeac Code Quality

CodeDuplication

This block of 10 lines is too similar to builders_test.go:93
builder := NewFormBuilder(&failingWriter{})
builder := newFormBuilder(&failingWriter{})
err = builder.CreateFormFile("file", file)
a.ErrorIs(
err,
Expand All @@ -102,7 +103,7 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
defer os.Remove(file.Name())

Check warning on line 103 in builders_test.go

View check run for this annotation

Codeac.io / Codeac Code Quality

CodeDuplication

This block of 10 lines is too similar to builders_test.go:72

body := &bytes.Buffer{}
builder := NewFormBuilder(body)
builder := newFormBuilder(body)
err = builder.CreateFormFile("file", file)
a.Error(err, "formbuilder should return error if file is closed")
a.ErrorIs(
Expand All @@ -115,13 +116,13 @@ func TestFormBuilderWithClosedFile(t *testing.T) {
// TestRequestBuilderReturnsRequest tests the request builder returns a
// request.
func TestRequestBuilderReturnsRequest(t *testing.T) {
b := NewRequestBuilder()
b := newRequestBuilder()
var (
ctx = context.Background()
method = http.MethodPost
url = "/foo"
request = map[string]string{"foo": "bar"}
reqBytes, _ = b.marshaller.Marshal(request)
reqBytes, _ = json.Marshal(request)
want, _ = http.NewRequestWithContext(
ctx,
method,
Expand All @@ -146,7 +147,7 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
url = "/foo"
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
)
b := NewRequestBuilder()
b := newRequestBuilder()
got, _ := b.Build(ctx, method, url, nil, nil)
if !reflect.DeepEqual(got, want) {
t.Errorf("Build() got = %v, want %v", got, want)
Expand Down
12 changes: 6 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ type Client struct {
baseURL string // Base URL for the client.
client *http.Client // Client is the HTTP client to use
EmptyMessagesLimit uint // EmptyMessagesLimit is the limit for the empty messages.
requestBuilder RequestBuilder
requestFormBuilder FormBuilder
createFormBuilder func(body io.Writer) FormBuilder
requestBuilder requestBuilder
requestFormBuilder formBuilder
createFormBuilder func(body io.Writer) formBuilder
logger zerolog.Logger // Logger is the logger for the client.
}

Expand All @@ -57,10 +57,10 @@ func NewClient(groqAPIKey string, opts ...Opts) (*Client, error) {
Logger(),
baseURL: groqAPIURLv1,
EmptyMessagesLimit: 10,
createFormBuilder: func(body io.Writer) FormBuilder {
return NewFormBuilder(body)
createFormBuilder: func(body io.Writer) formBuilder {
return newFormBuilder(body)
},
requestBuilder: NewRequestBuilder(),
requestBuilder: newRequestBuilder(),
}
for _, opt := range opts {
opt(c)
Expand Down

0 comments on commit 7c1dde0

Please sign in to comment.