This repository has been archived by the owner on Oct 30, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add: subquery retriever + redact sensitive values in logs (#39)
- Loading branch information
1 parent
0a0f13c
commit 26373f2
Showing
11 changed files
with
234 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
flows: | ||
foo: | ||
default: false | ||
retrieval: | ||
retriever: | ||
name: subquery | ||
options: | ||
limit: 3 | ||
topK: 5 | ||
model: | ||
openai: | ||
apiKey: "${OPENAI_API_KEY}" | ||
model: gpt-4o | ||
apiType: OPEN_AI | ||
apiBase: https://api.openai.com/v1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package retrievers | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"fmt" | ||
"github.com/gptscript-ai/knowledge/pkg/llm" | ||
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" | ||
"log/slog" | ||
"strings" | ||
) | ||
|
||
type SubqueryRetriever struct { | ||
Model llm.LLMConfig | ||
Limit int | ||
TopK int | ||
} | ||
|
||
// subqueryPrompt is the prompt template sent to the LLM. It carries two
// placeholders: {{.query}} for the user's query and {{.limit}} for the maximum
// number of subqueries. The reply is expected to follow the JSON schema that
// subqueryResp decodes.
var subqueryPrompt = `The following query will be used for a vector similarity search.
If it is too complex or covering multiple topics or entities, please split it into multiple subqueries.
I.e. a comparative query like "What are the differences between cats and dogs?" could be split into subqueries concerning cats and dogs separately.
The resulting subqueries will then be used for separate vector similarity searches.
Just changing the phrasing of the input question often won't change the semantic meaning, so those may not be good candidates.
Limit the number of subqueries to a maximum of {{.limit}} (less is ok).
Query: "{{.query}}"
Reply with all subqueries in a json list like the following and don't reply with anything else (also don't use any markdown syntax).
Response schema: {"results": ["<subquery-1>", "<subquery-2>"]}`
|
||
// subqueryResp is the JSON document the LLM is asked to reply with: an object
// whose "results" key holds the list of generated subqueries.
type subqueryResp struct {
	// Results holds the generated subqueries.
	Results []string `json:"results"`
}
|
||
func (s SubqueryRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) { | ||
m, err := llm.NewFromConfig(s.Model) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if s.TopK <= 0 { | ||
s.TopK = 3 | ||
} | ||
|
||
if s.Limit < 1 { | ||
return nil, fmt.Errorf("limit must be at least 1") | ||
} | ||
|
||
if s.Limit == 0 { | ||
s.Limit = 3 | ||
} | ||
|
||
result, err := m.Prompt(context.Background(), subqueryPrompt, map[string]interface{}{"query": query, "limit": s.Limit}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
var resp subqueryResp | ||
err = json.Unmarshal([]byte(result), &resp) | ||
if err != nil { | ||
slog.Debug("llm response", "response", result) | ||
return nil, fmt.Errorf("[retrievers/subquery] failed to unmarshal llm response: %w", err) | ||
} | ||
|
||
queries := resp.Results | ||
|
||
slog.Debug("SubqueryQueryRetriever generated subqueries", "queries", strings.Join(queries, " | ")) | ||
|
||
var resultDocs []vs.Document | ||
for _, q := range queries { | ||
docs, err := store.SimilaritySearch(ctx, q, s.TopK, datasetID) | ||
if err != nil { | ||
return nil, err | ||
} | ||
resultDocs = append(resultDocs, docs...) | ||
} | ||
|
||
return resultDocs, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package config | ||
|
||
import ( | ||
"reflect" | ||
"slices" | ||
"strings" | ||
) | ||
|
||
// SensitiveFields lists the lower-cased struct field names that
// RedactSensitive masks when the caller does not pass its own field names.
var SensitiveFields = []string{
	"password",
	"apikey",
	"token",
	"secret",
	"credentials",
	"auth",
}

// RedactSensitive returns a copy of the struct s (or of the struct s points
// to) in which sensitive fields are masked, so the result can be logged
// safely. The result is always a struct value, even when s is a pointer.
//
// A field is sensitive when its lower-cased Go field name equals one of
// fields, or one of SensitiveFields when no fields are given; custom names
// must therefore be passed in lower case. Sensitive string fields are set to
// "REDACTED"; sensitive fields of any other non-struct kind are reset to
// their zero value. Struct-typed fields are never masked as a whole: they are
// traversed recursively with the same field list.
//
// Unexported fields are copied unchanged. Values behind pointers, slices,
// maps and interfaces are shared with the original and are not traversed.
// If s is not a struct or a non-nil pointer to a struct, s is returned as is.
func RedactSensitive(s any, fields ...string) any {
	toRedact := fields
	if len(toRedact) == 0 {
		toRedact = SensitiveFields
	}

	v := reflect.ValueOf(s)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}

	if v.Kind() != reflect.Struct {
		return s
	}

	// Copy the whole value first so that unexported fields, which reflection
	// cannot set individually, are preserved. The copy is addressable, which
	// lets redactFields modify it in place.
	redacted := reflect.New(v.Type()).Elem()
	redacted.Set(v)
	redactFields(redacted, toRedact)

	return redacted.Interface()
}

// redactFields masks, in place, every settable field of the addressable
// struct v whose lower-cased name is listed in toRedact, descending into
// nested struct fields.
func redactFields(v reflect.Value, toRedact []string) {
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)

		// Unexported fields cannot be modified through reflection; they keep
		// the value copied from the original.
		if !field.CanSet() {
			continue
		}

		// Handle nested structs recursively, with the caller's field list.
		if field.Kind() == reflect.Struct {
			redactFields(field, toRedact)
			continue
		}

		if !slices.Contains(toRedact, strings.ToLower(t.Field(i).Name)) {
			continue
		}

		if field.Kind() == reflect.String {
			field.SetString("REDACTED")
		} else {
			// SetString would panic here; zeroing still hides the value.
			field.Set(reflect.Zero(field.Type()))
		}
	}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
package config | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestRedactSensitiveWithNoCustomFields(t *testing.T) { | ||
original := struct { | ||
ApiKey string | ||
Username string | ||
}{"secret", "user"} | ||
|
||
redacted := RedactSensitive(original).(struct { | ||
ApiKey string | ||
Username string | ||
}) | ||
|
||
assert.Equal(t, "REDACTED", redacted.ApiKey) | ||
assert.Equal(t, "user", redacted.Username) | ||
} | ||
|
||
func TestRedactSensitiveWithCustomFields(t *testing.T) { | ||
original := struct { | ||
Foo string | ||
Spam string | ||
Bar string | ||
}{"some", "weird", "fields"} | ||
|
||
redacted := RedactSensitive(original, "spam").(struct { | ||
Foo string | ||
Spam string | ||
Bar string | ||
}) | ||
|
||
assert.Equal(t, "some", redacted.Foo) | ||
assert.Equal(t, "REDACTED", redacted.Spam) | ||
assert.Equal(t, "fields", redacted.Bar) | ||
} | ||
|
||
func TestRedactSensitiveWithUnsupportedType(t *testing.T) { | ||
original := "string" | ||
redacted := RedactSensitive(original) | ||
assert.Equal(t, original, redacted) | ||
} | ||
|
||
func TestRedactSensitiveWithPointerToStruct(t *testing.T) { | ||
original := &struct { | ||
Token string | ||
Info string | ||
}{"token", "info"} | ||
|
||
redacted := RedactSensitive(original).(struct { | ||
Token string | ||
Info string | ||
}) | ||
|
||
assert.Equal(t, "REDACTED", redacted.Token) | ||
assert.Equal(t, "info", redacted.Info) | ||
} | ||
|
||
func TestRedactSensitiveWithNestedStruct(t *testing.T) { | ||
original := struct { | ||
Credentials struct { | ||
Password string | ||
} | ||
Detail string | ||
}{struct{ Password string }{"pass"}, "detail"} | ||
|
||
redacted := RedactSensitive(original).(struct { | ||
Credentials struct { | ||
Password string | ||
} | ||
Detail string | ||
}) | ||
|
||
assert.Equal(t, "REDACTED", redacted.Credentials.Password) | ||
assert.Equal(t, "detail", redacted.Detail) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters