diff --git a/examples/subquery_retriever.yaml b/examples/subquery_retriever.yaml new file mode 100644 index 00000000..46b3485b --- /dev/null +++ b/examples/subquery_retriever.yaml @@ -0,0 +1,15 @@ +flows: + foo: + default: false + retrieval: + retriever: + name: subquery + options: + limit: 3 + topK: 5 + model: + openai: + apiKey: "${OPENAI_API_KEY}" + model: gpt-4o + apiType: OPEN_AI + apiBase: https://api.openai.com/v1 \ No newline at end of file diff --git a/pkg/cmd/client.go b/pkg/cmd/client.go index c8234796..e8a3d132 100644 --- a/pkg/cmd/client.go +++ b/pkg/cmd/client.go @@ -91,7 +91,6 @@ func (s *Client) loadArchive() error { } func (s *Client) getClient() (client.Client, error) { - if err := s.loadArchive(); err != nil { return nil, err } diff --git a/pkg/cmd/export.go b/pkg/cmd/export.go index 69ee5765..545f656f 100644 --- a/pkg/cmd/export.go +++ b/pkg/cmd/export.go @@ -39,7 +39,6 @@ func (s *ClientExportDatasets) Run(cmd *cobra.Command, args []string) error { dsnames[i] = ds.ID } } else { - for _, datasetID := range dsnames { ds, err := c.GetDataset(cmd.Context(), datasetID) if err != nil { diff --git a/pkg/cmd/list_datasets.go b/pkg/cmd/list_datasets.go index 0c020784..5ba17bc6 100644 --- a/pkg/cmd/list_datasets.go +++ b/pkg/cmd/list_datasets.go @@ -18,7 +18,6 @@ func (s *ClientListDatasets) Customize(cmd *cobra.Command) { } func (s *ClientListDatasets) Run(cmd *cobra.Command, args []string) error { - c, err := s.getClient() if err != nil { return err diff --git a/pkg/cmd/retrieve.go b/pkg/cmd/retrieve.go index c73388bf..fc1e640e 100644 --- a/pkg/cmd/retrieve.go +++ b/pkg/cmd/retrieve.go @@ -30,7 +30,6 @@ func (s *ClientRetrieve) Customize(cmd *cobra.Command) { } func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error { - c, err := s.getClient() if err != nil { return err diff --git a/pkg/datastore/retrievers/retrievers.go b/pkg/datastore/retrievers/retrievers.go index 7fac64db..54b9fde4 100644 --- a/pkg/datastore/retrievers/retrievers.go 
+++ b/pkg/datastore/retrievers/retrievers.go @@ -2,6 +2,7 @@ package retrievers import ( "context" + "fmt" "log/slog" "github.com/gptscript-ai/knowledge/pkg/datastore/defaults" @@ -16,8 +17,10 @@ func GetRetriever(name string) (Retriever, error) { switch name { case "basic", "default": return &BasicRetriever{TopK: defaults.TopK}, nil + case "subquery": + return &SubqueryRetriever{Limit: 3, TopK: 3}, nil default: - return nil, nil + return nil, fmt.Errorf("unknown retriever %q", name) } } diff --git a/pkg/datastore/retrievers/subquery.go b/pkg/datastore/retrievers/subquery.go new file mode 100644 index 00000000..e83b9612 --- /dev/null +++ b/pkg/datastore/retrievers/subquery.go @@ -0,0 +1,76 @@ +package retrievers + +import ( + "context" + "encoding/json" + "fmt" + "github.com/gptscript-ai/knowledge/pkg/llm" + vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" + "log/slog" + "strings" +) + +type SubqueryRetriever struct { + Model llm.LLMConfig + Limit int + TopK int +} + +var subqueryPrompt = `The following query will be used for a vector similarity search. +If it is too complex or covering multiple topics or entities, please split it into multiple subqueries. +I.e. a comparative query like "What are the differences between cats and dogs?" could be split into subqueries concerning cats and dogs separately. +The resulting subqueries will then be used for separate vector similarity searches. +Just changing the phrasing of the input question often won't change the semantic meaning, so those may not be good candidates. +Limit the number of subqueries to a maximum of {{.limit}} (less is ok). +Query: "{{.query}}" +Reply with all subqueries in a json list like the following and don't reply with anything else (also don't use any markdown syntax). 
+Response schema: {"results": ["", ""]}` + +type subqueryResp struct { + Results []string `json:"results"` +} + +func (s SubqueryRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) { + m, err := llm.NewFromConfig(s.Model) + if err != nil { + return nil, err + } + + if s.TopK <= 0 { + s.TopK = 3 + } + + if s.Limit < 1 { + return nil, fmt.Errorf("limit must be at least 1") + } + + if s.Limit == 0 { + s.Limit = 3 + } + + result, err := m.Prompt(context.Background(), subqueryPrompt, map[string]interface{}{"query": query, "limit": s.Limit}) + if err != nil { + return nil, err + } + var resp subqueryResp + err = json.Unmarshal([]byte(result), &resp) + if err != nil { + slog.Debug("llm response", "response", result) + return nil, fmt.Errorf("[retrievers/subquery] failed to unmarshal llm response: %w", err) + } + + queries := resp.Results + + slog.Debug("SubqueryQueryRetriever generated subqueries", "queries", strings.Join(queries, " | ")) + + var resultDocs []vs.Document + for _, q := range queries { + docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID) + if err != nil { + return nil, err + } + resultDocs = append(resultDocs, docs...) 
+ } + + return resultDocs, nil +} diff --git a/pkg/flows/config/config.go b/pkg/flows/config/config.go index 6446749f..028aa99d 100644 --- a/pkg/flows/config/config.go +++ b/pkg/flows/config/config.go @@ -203,7 +203,7 @@ func (i *IngestionFlowConfig) AsIngestionFlow() (*flows.IngestionFlow, error) { if err := mapstructure.Decode(tf.Options, &transformer); err != nil { return nil, fmt.Errorf("failed to decode transformer configuration: %w", err) } - slog.Debug("Transformer custom configuration", "name", tf.Name, "config", transformer) + slog.Debug("Transformer custom configuration", "name", tf.Name, "config", RedactSensitive(transformer)) } flow.Transformations = append(flow.Transformations, transformer) } @@ -235,7 +235,7 @@ func (r *RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) { if err := mapstructure.Decode(qm.Options, &modifier); err != nil { return nil, fmt.Errorf("failed to decode query modifier configuration: %w", err) } - slog.Debug("Query Modifier custom configuration", "name", qm.Name, "config", modifier) + slog.Debug("Query Modifier custom configuration", "name", qm.Name, "config", RedactSensitive(modifier)) } flow.QueryModifiers = append(flow.QueryModifiers, modifier) } @@ -250,7 +250,7 @@ func (r *RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) { if err := mapstructure.Decode(r.Retriever.Options, &ret); err != nil { return nil, fmt.Errorf("failed to decode retriever configuration: %w", err) } - slog.Debug("Retriever custom configuration", "name", r.Retriever.Name, "config", ret) + slog.Debug("Retriever custom configuration", "name", r.Retriever.Name, "config", RedactSensitive(ret)) } flow.Retriever = ret } @@ -265,7 +265,7 @@ func (r *RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) { if err := mapstructure.Decode(pp.Options, &postprocessor); err != nil { return nil, fmt.Errorf("failed to decode postprocessor configuration: %w", err) } - slog.Debug("Postprocessor custom configuration", 
// SensitiveFields lists the lower-cased struct field names that RedactSensitive
// masks when the caller does not pass an explicit field list.
var SensitiveFields = []string{
	"password",
	"apikey",
	"token",
	"secret",
	"credentials",
	"auth",
}

// RedactSensitive returns a copy of s with sensitive fields masked, so the
// result can be logged safely. The input is never modified.
//
// s may be a struct or a pointer to a struct; in both cases a struct value is
// returned. Any other input (including a nil pointer) is returned unchanged.
//
// fields optionally replaces the default SensitiveFields list. Names are
// matched case-insensitively against the Go field name, and the same list is
// applied to nested structs.
//
// Masking rules for exported fields:
//   - nested struct values are redacted recursively,
//   - string-kinded sensitive fields are set to "REDACTED",
//   - sensitive fields of any other kind are reset to their zero value.
//
// Unexported fields cannot be set through reflection and are copied as-is.
func RedactSensitive(s any, fields ...string) any {
	toRedact := SensitiveFields
	if len(fields) > 0 {
		// Field names are lower-cased before comparison, so normalize the
		// caller's list too (into a fresh slice, leaving the argument untouched).
		toRedact = make([]string, len(fields))
		for i, f := range fields {
			toRedact[i] = strings.ToLower(f)
		}
	}

	v := reflect.ValueOf(s)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	if v.Kind() != reflect.Struct {
		return s
	}

	// Start from a full copy so unexported fields, which cannot be set
	// individually, keep their values instead of causing a panic.
	redacted := reflect.New(v.Type()).Elem()
	redacted.Set(v)

	for i := 0; i < v.NumField(); i++ {
		sf := v.Type().Field(i)
		if !sf.IsExported() {
			continue
		}

		src, dst := v.Field(i), redacted.Field(i)
		switch {
		case src.Kind() == reflect.Struct:
			// Recurse with the same field list so custom names also apply
			// to nested structs.
			dst.Set(reflect.ValueOf(RedactSensitive(src.Interface(), toRedact...)))
		case !slices.Contains(toRedact, strings.ToLower(sf.Name)):
			// Not sensitive: the value was already copied above.
		case src.Kind() == reflect.String:
			dst.SetString("REDACTED")
		default:
			// SetString would panic on non-string kinds; drop the value instead.
			dst.Set(reflect.Zero(dst.Type()))
		}
	}

	return redacted.Interface()
}
redacted.Username) +} + +func TestRedactSensitiveWithCustomFields(t *testing.T) { + original := struct { + Foo string + Spam string + Bar string + }{"some", "weird", "fields"} + + redacted := RedactSensitive(original, "spam").(struct { + Foo string + Spam string + Bar string + }) + + assert.Equal(t, "some", redacted.Foo) + assert.Equal(t, "REDACTED", redacted.Spam) + assert.Equal(t, "fields", redacted.Bar) +} + +func TestRedactSensitiveWithUnsupportedType(t *testing.T) { + original := "string" + redacted := RedactSensitive(original) + assert.Equal(t, original, redacted) +} + +func TestRedactSensitiveWithPointerToStruct(t *testing.T) { + original := &struct { + Token string + Info string + }{"token", "info"} + + redacted := RedactSensitive(original).(struct { + Token string + Info string + }) + + assert.Equal(t, "REDACTED", redacted.Token) + assert.Equal(t, "info", redacted.Info) +} + +func TestRedactSensitiveWithNestedStruct(t *testing.T) { + original := struct { + Credentials struct { + Password string + } + Detail string + }{struct{ Password string }{"pass"}, "detail"} + + redacted := RedactSensitive(original).(struct { + Credentials struct { + Password string + } + Detail string + }) + + assert.Equal(t, "REDACTED", redacted.Credentials.Password) + assert.Equal(t, "detail", redacted.Detail) +} diff --git a/pkg/llm/llm.go b/pkg/llm/llm.go index 2940f08b..e97df6fa 100644 --- a/pkg/llm/llm.go +++ b/pkg/llm/llm.go @@ -44,7 +44,7 @@ func (llm *LLM) Prompt(ctx context.Context, promptTpl string, values map[string] if err != nil { return "", err } - slog.Debug("Prompting LLM with: %s", p) + slog.Debug("Prompting LLM", "prompt", p) res, err := golcmodel.GeneratePrompt(ctx, llm.model, p) if err != nil {