Add support for Assistant APIs #464

Open · wants to merge 8 commits into base: main
5 changes: 0 additions & 5 deletions packages/grafana-llm-app/llmclient/go.mod
Contributor

Removing this completely seems a bit scary; my knowledge of Go dependencies isn't strong enough to know how it would handle a breaking change (e.g. v2) in sashabaranov/go-openai, though 🤔

Contributor

Perhaps we should consider this in a separate PR, unless it's really needed here?

Collaborator

Yeah, what's the reason for this removal? I like specifying that we depend on >=1.15.3, <2.0.0. Are you running into import issues somewhere?
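
For context on what the deleted constraint expressed, here is a hedged sketch of what the llmclient `go.mod` might contain (the module path and Go version are assumptions, not copied from this repo). Go modules treat v2+ as a distinct import path (`.../v2`), so a plain `require` line can only ever float within one major version — a v2 release of sashabaranov/go-openai would require changing the import path, not just the version:

```
module github.com/grafana/grafana-llm-app/packages/grafana-llm-app/llmclient  // assumed path

go 1.21  // assumed

// Stays within v1: Go's minimal version selection picks >=1.15.3, <2.0.0,
// because v2 would live at github.com/sashabaranov/go-openai/v2.
require github.com/sashabaranov/go-openai v1.15.3
```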

Author

@sd2k @csmarchbanks I got confused about why there were two go.mod files - I didn't quite realise the client was a separate build product! I'll revert, sorry for the confusion.

This file was deleted.

2 changes: 0 additions & 2 deletions packages/grafana-llm-app/llmclient/go.sum

This file was deleted.

99 changes: 99 additions & 0 deletions packages/grafana-llm-app/llmclient/llmclient.go
@@ -35,6 +35,12 @@ type ChatCompletionRequest struct {
Model Model `json:"model"`
}

// AssistantRequest is a request for creating an assistant using an abstract model.
type AssistantRequest struct {
openai.AssistantRequest
Model Model `json:"model"`
}

// OpenAI is an interface for talking to OpenAI via the Grafana LLM app.
// Requests made using this interface will be routed to the OpenAI backend
// configured in the Grafana LLM app's settings, with authentication handled
@@ -47,6 +53,36 @@ type OpenAI interface {
ChatCompletions(ctx context.Context, req ChatCompletionRequest) (openai.ChatCompletionResponse, error)
// ChatCompletionsStream makes a streaming request to the OpenAI Chat Completion API.
ChatCompletionsStream(ctx context.Context, req ChatCompletionRequest) (*openai.ChatCompletionStream, error)
// CreateAssistant creates an assistant using the given request.
CreateAssistant(ctx context.Context, req AssistantRequest) (openai.Assistant, error)
// RetrieveAssistant retrieves an assistant by ID.
RetrieveAssistant(ctx context.Context, assistantID string) (openai.Assistant, error)
// ListAssistants lists assistants.
ListAssistants(ctx context.Context, limit *int, order *string, after *string, before *string) (openai.AssistantsList, error)
// DeleteAssistant deletes an assistant by ID.
DeleteAssistant(ctx context.Context, assistantID string) (openai.AssistantDeleteResponse, error)
// CreateThread creates a new thread.
CreateThread(ctx context.Context, req openai.ThreadRequest) (openai.Thread, error)
// RetrieveThread retrieves a thread by ID.
RetrieveThread(ctx context.Context, threadID string) (openai.Thread, error)
// DeleteThread deletes a thread by ID.
DeleteThread(ctx context.Context, threadID string) (openai.ThreadDeleteResponse, error)
// CreateMessage creates a new message in a thread.
CreateMessage(ctx context.Context, threadID string, request openai.MessageRequest) (msg openai.Message, err error)
// ListMessages lists messages in a thread.
ListMessages(ctx context.Context, threadID string, limit *int, order *string, after *string, before *string, runID *string) (openai.MessagesList, error)
// RetrieveMessage retrieves a message in a thread.
RetrieveMessage(ctx context.Context, threadID string, messageID string) (msg openai.Message, err error)
// DeleteMessage deletes a message in a thread.
DeleteMessage(ctx context.Context, threadID string, messageID string) (msg openai.MessageDeletionStatus, err error)
// CreateRun creates a new run in a thread.
CreateRun(ctx context.Context, threadID string, request openai.RunRequest) (run openai.Run, err error)
// RetrieveRun retrieves a run in a thread.
RetrieveRun(ctx context.Context, threadID string, runID string) (run openai.Run, err error)
// CancelRun cancels a run in a thread.
CancelRun(ctx context.Context, threadID string, runID string) (run openai.Run, err error)
// SubmitToolOutputs submits tool outputs for a run in a thread.
SubmitToolOutputs(ctx context.Context, threadID string, runID string, request openai.SubmitToolOutputsRequest) (response openai.Run, err error)
Comment on lines +56 to +85
Contributor

A niggling part of my brain thinks this should be a separate OpenAIAssistant interface to avoid the OpenAI interface becoming too big (and make it easier to mock for users in tests), particularly if users are importing and using this in their code already. WDYT?

I guess people would need to type-switch, though, so maybe it's not worth it?

Collaborator

I also think having a second interface could be nice. It would also let us check for features based on whether the interface is implemented for a connection or not (in case we add more first-class implementations of this interface).

Author

Good point, will add this interface.
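
A hedged sketch of what that split might look like — the interface and method names here are assumptions, not the final API. A separate `OpenAIAssistant` interface keeps the chat interface small and easy to mock, and a type assertion answers the "does this connection support assistants?" question raised above:

```go
package main

import "fmt"

// OpenAI stands in for the existing chat-completions interface
// (signatures elided for the sketch).
type OpenAI interface {
	ChatCompletions() string
}

// OpenAIAssistant is the hypothetical split-out Assistants surface.
type OpenAIAssistant interface {
	CreateAssistant() string
}

// client implements both interfaces, as the concrete openAI type would.
type client struct{}

func (client) ChatCompletions() string { return "chat" }
func (client) CreateAssistant() string { return "assistant" }

// supportsAssistants reports whether a connection also implements the
// optional Assistants interface — the feature check suggested in the review.
func supportsAssistants(c OpenAI) (OpenAIAssistant, bool) {
	a, ok := c.(OpenAIAssistant)
	return a, ok
}

func main() {
	var c OpenAI = client{}
	if a, ok := supportsAssistants(c); ok {
		fmt.Println("assistants supported:", a.CreateAssistant())
	}
}
```

Callers that only need chat completions keep mocking the small `OpenAI` interface; assistant-aware callers assert for the second one.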

}

type openAI struct {
@@ -90,6 +126,7 @@ type openAIHealthDetails struct {
OK bool `json:"ok"`
Error string `json:"error,omitempty"`
Models map[Model]openAIModelHealth `json:"models"`
Assistant openAIModelHealth `json:"assistant"`
}

type vectorHealthDetails struct {
@@ -159,3 +196,65 @@ func (o *openAI) ChatCompletionsStream(ctx context.Context, req ChatCompletionRe
r.Model = string(req.Model)
return o.client.CreateChatCompletionStream(ctx, r)
}

func (o *openAI) CreateAssistant(ctx context.Context, req AssistantRequest) (openai.Assistant, error) {
r := req.AssistantRequest
r.Model = string(req.Model)
return o.client.CreateAssistant(ctx, r)
}

func (o *openAI) RetrieveAssistant(ctx context.Context, assistantID string) (openai.Assistant, error) {
return o.client.RetrieveAssistant(ctx, assistantID)
}

func (o *openAI) ListAssistants(ctx context.Context, limit *int, order *string, after *string, before *string) (openai.AssistantsList, error) {
return o.client.ListAssistants(ctx, limit, order, after, before)
}

func (o *openAI) DeleteAssistant(ctx context.Context, assistantID string) (openai.AssistantDeleteResponse, error) {
return o.client.DeleteAssistant(ctx, assistantID)
}

func (o *openAI) CreateThread(ctx context.Context, req openai.ThreadRequest) (openai.Thread, error) {
return o.client.CreateThread(ctx, req)
}

func (o *openAI) RetrieveThread(ctx context.Context, threadID string) (openai.Thread, error) {
return o.client.RetrieveThread(ctx, threadID)
}

func (o *openAI) DeleteThread(ctx context.Context, threadID string) (openai.ThreadDeleteResponse, error) {
return o.client.DeleteThread(ctx, threadID)
}

func (o *openAI) CreateMessage(ctx context.Context, threadID string, request openai.MessageRequest) (msg openai.Message, err error) {
return o.client.CreateMessage(ctx, threadID, request)
}

func (o *openAI) ListMessages(ctx context.Context, threadID string, limit *int, order *string, after *string, before *string, runID *string) (msg openai.MessagesList, err error) {
return o.client.ListMessage(ctx, threadID, limit, order, after, before, runID)
}

func (o *openAI) RetrieveMessage(ctx context.Context, threadID string, messageID string) (msg openai.Message, err error) {
return o.client.RetrieveMessage(ctx, threadID, messageID)
}

func (o *openAI) DeleteMessage(ctx context.Context, threadID string, messageID string) (msg openai.MessageDeletionStatus, err error) {
return o.client.DeleteMessage(ctx, threadID, messageID)
}

func (o *openAI) CreateRun(ctx context.Context, threadID string, request openai.RunRequest) (run openai.Run, err error) {
return o.client.CreateRun(ctx, threadID, request)
}

func (o *openAI) RetrieveRun(ctx context.Context, threadID string, runID string) (run openai.Run, err error) {
return o.client.RetrieveRun(ctx, threadID, runID)
}

func (o *openAI) CancelRun(ctx context.Context, threadID string, runID string) (run openai.Run, err error) {
return o.client.CancelRun(ctx, threadID, runID)
}

func (o *openAI) SubmitToolOutputs(ctx context.Context, threadID string, runID string, request openai.SubmitToolOutputsRequest) (response openai.Run, err error) {
return o.client.SubmitToolOutputs(ctx, threadID, runID, request)
}
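
None of the methods above block until a run finishes, so a caller will typically poll `RetrieveRun` until the run reaches a terminal status. A minimal sketch of that loop — the status names mirror the Assistants API, and the `retrieve` callback stands in for `client.RetrieveRun(ctx, threadID, runID)`; none of this is part of the package itself:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// runStatus mirrors the Assistants API run states (assumed subset).
type runStatus string

const (
	statusQueued     runStatus = "queued"
	statusInProgress runStatus = "in_progress"
	statusCompleted  runStatus = "completed"
	statusFailed     runStatus = "failed"
)

// pollRun calls retrieve until the run reaches a terminal state, sleeping
// between polls, and gives up after maxPolls attempts.
func pollRun(retrieve func() (runStatus, error), interval time.Duration, maxPolls int) (runStatus, error) {
	for i := 0; i < maxPolls; i++ {
		status, err := retrieve()
		if err != nil {
			return "", err
		}
		switch status {
		case statusCompleted, statusFailed:
			return status, nil
		}
		time.Sleep(interval)
	}
	return "", errors.New("run did not finish within poll budget")
}

func main() {
	// Fake retrieve: the run is queued, then in progress, then completed.
	statuses := []runStatus{statusQueued, statusInProgress, statusCompleted}
	i := 0
	retrieve := func() (runStatus, error) {
		s := statuses[i]
		if i < len(statuses)-1 {
			i++
		}
		return s, nil
	}
	final, err := pollRun(retrieve, time.Millisecond, 10)
	fmt.Println(final, err)
}
```

A run that ends in `requires_action` would instead loop back through `SubmitToolOutputs` before polling again; that branch is omitted here to keep the sketch small.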
250 changes: 250 additions & 0 deletions packages/grafana-llm-app/llmclient/llmclient_test.go
@@ -138,3 +138,253 @@ func TestChatCompletionsStream(t *testing.T) {
t.Errorf("expected streamed content to be 'hello there', got '%s'", content)
}
}

func TestCreateAssistant(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/assistants" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
req := openai.AssistantRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
response := openai.Assistant{
ID: "test",
}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: Create assistant request succeeds
req := AssistantRequest{
AssistantRequest: openai.AssistantRequest{},
Model: ModelBase,
}
_, err := client.CreateAssistant(ctx, req)
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}

func TestCreateThread(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/threads" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
req := openai.ThreadRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
response := openai.Thread{
ID: "test",
}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: Create thread request succeeds
req := openai.ThreadRequest{}
_, err := client.CreateThread(ctx, req)
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}

func TestCreateMessage(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/threads/test/messages" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
req := openai.MessageRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
response := openai.Message{
ID: "test",
}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: Create message request succeeds
req := openai.MessageRequest{}
_, err := client.CreateMessage(ctx, "test", req)
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}

func TestCreateRun(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/threads/test/runs" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
req := openai.RunRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
response := openai.Run{
ID: "test",
}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: Create run request succeeds
req := openai.RunRequest{}
_, err := client.CreateRun(ctx, "test", req)
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}

func TestRetrieveRun(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/threads/test/runs/test" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
w.WriteHeader(http.StatusOK)
response := openai.Run{
ID: "test",
}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: Retrieve run request succeeds
_, err := client.RetrieveRun(ctx, "test", "test")
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}

func TestSubmitToolOutputs(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/threads/test/runs/test/submit_tool_outputs" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
req := openai.SubmitToolOutputsRequest{}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusOK)
response := openai.Run{
ID: "test",
}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: Submit tool outputs request succeeds
req := openai.SubmitToolOutputsRequest{}
_, err := client.SubmitToolOutputs(ctx, "test", "test", req)
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}

func TestListMessage(t *testing.T) {
ctx := context.Background()
key := "test"
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/plugins/grafana-llm-app/resources/openai/v1/threads/test/messages" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("404 page not found"))
}
if r.Method != http.MethodGet {
w.WriteHeader(http.StatusMethodNotAllowed)
}
if r.Header.Get("Authorization") != "Bearer "+key {
w.WriteHeader(http.StatusUnauthorized)
}
w.WriteHeader(http.StatusOK)
response := openai.MessagesList{}
w.Header().Set("Content-Type", "application/json")
j, _ := json.Marshal(response)
w.Write(j)
})
server := httptest.NewServer(handler)
client := NewOpenAI(server.URL, key)
// Test case: List messages request succeeds
_, err := client.ListMessages(ctx, "test", nil, nil, nil, nil, nil)
if err != nil {
t.Errorf("Expected no error, but got: %v", err)
}
}
4 changes: 4 additions & 0 deletions packages/grafana-llm-app/pkg/plugin/azure_provider.go
@@ -107,3 +107,7 @@ func (p *azure) getAzureMapping() (map[Model]string, error) {
}
return result, nil
}

func (p *azure) ListAssistants(ctx context.Context, limit *int, order *string, after *string, before *string) (openai.AssistantsList, error) {
return p.oc.ListAssistants(ctx, limit, order, after, before)
}