Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
add: subquery retriever + redact sensitive values in logs (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 authored Jul 3, 2024
1 parent 0a0f13c commit 26373f2
Show file tree
Hide file tree
Showing 11 changed files with 234 additions and 10 deletions.
15 changes: 15 additions & 0 deletions examples/subquery_retriever.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
flows:
foo:
default: false
retrieval:
retriever:
name: subquery
options:
limit: 3
topK: 5
model:
openai:
apiKey: "${OPENAI_API_KEY}"
model: gpt-4o
apiType: OPEN_AI
apiBase: https://api.openai.com/v1
1 change: 0 additions & 1 deletion pkg/cmd/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ func (s *Client) loadArchive() error {
}

func (s *Client) getClient() (client.Client, error) {

if err := s.loadArchive(); err != nil {
return nil, err
}
Expand Down
1 change: 0 additions & 1 deletion pkg/cmd/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ func (s *ClientExportDatasets) Run(cmd *cobra.Command, args []string) error {
dsnames[i] = ds.ID
}
} else {

for _, datasetID := range dsnames {
ds, err := c.GetDataset(cmd.Context(), datasetID)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion pkg/cmd/list_datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ func (s *ClientListDatasets) Customize(cmd *cobra.Command) {
}

func (s *ClientListDatasets) Run(cmd *cobra.Command, args []string) error {

c, err := s.getClient()
if err != nil {
return err
Expand Down
1 change: 0 additions & 1 deletion pkg/cmd/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func (s *ClientRetrieve) Customize(cmd *cobra.Command) {
}

func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error {

c, err := s.getClient()
if err != nil {
return err
Expand Down
5 changes: 4 additions & 1 deletion pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package retrievers

import (
"context"
"fmt"
"log/slog"

"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
Expand All @@ -16,8 +17,10 @@ func GetRetriever(name string) (Retriever, error) {
switch name {
case "basic", "default":
return &BasicRetriever{TopK: defaults.TopK}, nil
case "subquery":
return &SubqueryRetriever{Limit: 3, TopK: 3}, nil
default:
return nil, nil
return nil, fmt.Errorf("unknown retriever %q", name)
}
}

Expand Down
76 changes: 76 additions & 0 deletions pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package retrievers

import (
"context"
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"log/slog"
"strings"
)

type SubqueryRetriever struct {
Model llm.LLMConfig
Limit int
TopK int
}

var subqueryPrompt = `The following query will be used for a vector similarity search.
If it is too complex or covering multiple topics or entities, please split it into multiple subqueries.
I.e. a comparative query like "What are the differences between cats and dogs?" could be split into subqueries concerning cats and dogs separately.
The resulting subqueries will then be used for separate vector similarity searches.
Just changing the phrasing of the input question often won't change the semantic meaning, so those may not be good candidates.
Limit the number of subqueries to a maximum of {{.limit}} (less is ok).
Query: "{{.query}}"
Reply with all subqueries in a json list like the following and don't reply with anything else (also don't use any markdown syntax).
Response schema: {"results": ["<subquery-1>", "<subquery-2>"]}`

type subqueryResp struct {
Results []string `json:"results"`
}

func (s SubqueryRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
m, err := llm.NewFromConfig(s.Model)
if err != nil {
return nil, err
}

if s.TopK <= 0 {
s.TopK = 3
}

if s.Limit < 1 {
return nil, fmt.Errorf("limit must be at least 1")
}

if s.Limit == 0 {
s.Limit = 3
}

result, err := m.Prompt(context.Background(), subqueryPrompt, map[string]interface{}{"query": query, "limit": s.Limit})
if err != nil {
return nil, err
}
var resp subqueryResp
err = json.Unmarshal([]byte(result), &resp)
if err != nil {
slog.Debug("llm response", "response", result)
return nil, fmt.Errorf("[retrievers/subquery] failed to unmarshal llm response: %w", err)
}

queries := resp.Results

slog.Debug("SubqueryQueryRetriever generated subqueries", "queries", strings.Join(queries, " | "))

var resultDocs []vs.Document
for _, q := range queries {
docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID)
if err != nil {
return nil, err
}
resultDocs = append(resultDocs, docs...)
}

return resultDocs, nil
}
8 changes: 4 additions & 4 deletions pkg/flows/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (i *IngestionFlowConfig) AsIngestionFlow() (*flows.IngestionFlow, error) {
if err := mapstructure.Decode(tf.Options, &transformer); err != nil {
return nil, fmt.Errorf("failed to decode transformer configuration: %w", err)
}
slog.Debug("Transformer custom configuration", "name", tf.Name, "config", transformer)
slog.Debug("Transformer custom configuration", "name", tf.Name, "config", RedactSensitive(transformer))
}
flow.Transformations = append(flow.Transformations, transformer)
}
Expand Down Expand Up @@ -235,7 +235,7 @@ func (r *RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) {
if err := mapstructure.Decode(qm.Options, &modifier); err != nil {
return nil, fmt.Errorf("failed to decode query modifier configuration: %w", err)
}
slog.Debug("Query Modifier custom configuration", "name", qm.Name, "config", modifier)
slog.Debug("Query Modifier custom configuration", "name", qm.Name, "config", RedactSensitive(modifier))
}
flow.QueryModifiers = append(flow.QueryModifiers, modifier)
}
Expand All @@ -250,7 +250,7 @@ func (r *RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) {
if err := mapstructure.Decode(r.Retriever.Options, &ret); err != nil {
return nil, fmt.Errorf("failed to decode retriever configuration: %w", err)
}
slog.Debug("Retriever custom configuration", "name", r.Retriever.Name, "config", ret)
slog.Debug("Retriever custom configuration", "name", r.Retriever.Name, "config", RedactSensitive(ret))
}
flow.Retriever = ret
}
Expand All @@ -265,7 +265,7 @@ func (r *RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) {
if err := mapstructure.Decode(pp.Options, &postprocessor); err != nil {
return nil, fmt.Errorf("failed to decode postprocessor configuration: %w", err)
}
slog.Debug("Postprocessor custom configuration", "name", pp.Name, "config", postprocessor)
slog.Debug("Postprocessor custom configuration", "name", pp.Name, "config", RedactSensitive(postprocessor))
}
flow.Postprocessors = append(flow.Postprocessors, postprocessor)
}
Expand Down
54 changes: 54 additions & 0 deletions pkg/flows/config/redact.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package config

import (
"reflect"
"slices"
"strings"
)

var SensitiveFields = []string{
"password",
"apikey",
"token",
"secret",
"credentials",
"auth",
}

func RedactSensitive(s any, fields ...string) any {
toRedact := fields
if len(toRedact) == 0 {
toRedact = SensitiveFields
}

v := reflect.ValueOf(s)
if v.Kind() == reflect.Ptr {
v = v.Elem()
}

if v.Kind() != reflect.Struct {
return s
}

redactedStruct := reflect.New(v.Type()).Elem()

for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
fieldType := v.Type().Field(i)
fieldName := strings.ToLower(fieldType.Name)

// Handle nested structs recursively
if field.Kind() == reflect.Struct {
redactedStruct.Field(i).Set(reflect.ValueOf(RedactSensitive(field.Interface())))
continue
}

if slices.Contains(toRedact, fieldName) {
redactedStruct.Field(i).SetString("REDACTED")
} else {
redactedStruct.Field(i).Set(field)
}
}

return redactedStruct.Interface()
}
80 changes: 80 additions & 0 deletions pkg/flows/config/redact_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package config

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestRedactSensitiveWithNoCustomFields(t *testing.T) {
original := struct {
ApiKey string
Username string
}{"secret", "user"}

redacted := RedactSensitive(original).(struct {
ApiKey string
Username string
})

assert.Equal(t, "REDACTED", redacted.ApiKey)
assert.Equal(t, "user", redacted.Username)
}

func TestRedactSensitiveWithCustomFields(t *testing.T) {
original := struct {
Foo string
Spam string
Bar string
}{"some", "weird", "fields"}

redacted := RedactSensitive(original, "spam").(struct {
Foo string
Spam string
Bar string
})

assert.Equal(t, "some", redacted.Foo)
assert.Equal(t, "REDACTED", redacted.Spam)
assert.Equal(t, "fields", redacted.Bar)
}

func TestRedactSensitiveWithUnsupportedType(t *testing.T) {
original := "string"
redacted := RedactSensitive(original)
assert.Equal(t, original, redacted)
}

func TestRedactSensitiveWithPointerToStruct(t *testing.T) {
original := &struct {
Token string
Info string
}{"token", "info"}

redacted := RedactSensitive(original).(struct {
Token string
Info string
})

assert.Equal(t, "REDACTED", redacted.Token)
assert.Equal(t, "info", redacted.Info)
}

func TestRedactSensitiveWithNestedStruct(t *testing.T) {
original := struct {
Credentials struct {
Password string
}
Detail string
}{struct{ Password string }{"pass"}, "detail"}

redacted := RedactSensitive(original).(struct {
Credentials struct {
Password string
}
Detail string
})

assert.Equal(t, "REDACTED", redacted.Credentials.Password)
assert.Equal(t, "detail", redacted.Detail)
}
2 changes: 1 addition & 1 deletion pkg/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (llm *LLM) Prompt(ctx context.Context, promptTpl string, values map[string]
if err != nil {
return "", err
}
slog.Debug("Prompting LLM with: %s", p)
slog.Debug("Prompting LLM", "prompt", p)

res, err := golcmodel.GeneratePrompt(ctx, llm.model, p)
if err != nil {
Expand Down

0 comments on commit 26373f2

Please sign in to comment.