Skip to content

Commit

Permalink
Add ListModels method (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt authored Oct 4, 2023
1 parent d0842c9 commit 9c83236
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
10 changes: 10 additions & 0 deletions model.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ func (m *ModelVersion) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, alias)
}

// ListModels lists public models.
func (r *Client) ListModels(ctx context.Context) (*Page[Model], error) {
response := &Page[Model]{}
err := r.request(ctx, "GET", "/models", nil, response)
if err != nil {
return nil, fmt.Errorf("failed to list models: %w", err)
}
return response, nil
}

// GetModel retrieves information about a model.
func (r *Client) GetModel(ctx context.Context, modelOwner string, modelName string) (*Model, error) {
model := &Model{}
Expand Down
45 changes: 45 additions & 0 deletions replicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,51 @@ func TestGetCollection(t *testing.T) {
assert.Empty(t, *collection.Models)
}

func TestListModels(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/models", r.URL.Path)
assert.Equal(t, http.MethodGet, r.Method)

response := replicate.Page[replicate.Model]{
Results: []replicate.Model{
{
Owner: "stability-ai",
Name: "sdxl",
Description: "A text-to-image generative AI model that creates beautiful 1024x1024 images",
},
{
Owner: "meta",
Name: "codellama-13b",
Description: "A 13 billion parameter Llama tuned for code completion",
},
},
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
body, _ := json.Marshal(response)
w.Write(body)
}))
defer mockServer.Close()

client, err := replicate.NewClient(
replicate.WithToken("test-token"),
replicate.WithBaseURL(mockServer.URL),
)
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

modelsPage, err := client.ListModels(ctx)
assert.NoError(t, err)
assert.Equal(t, 2, len(modelsPage.Results))
assert.Equal(t, "stability-ai", modelsPage.Results[0].Owner)
assert.Equal(t, "sdxl", modelsPage.Results[0].Name)
assert.Equal(t, "meta", modelsPage.Results[1].Owner)
assert.Equal(t, "codellama-13b", modelsPage.Results[1].Name)
}

func TestGetModel(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/models/replicate/hello-world", r.URL.Path)
Expand Down

0 comments on commit 9c83236

Please sign in to comment.