diff --git a/pkg/plugin/health.go b/pkg/plugin/health.go index 7ec2f716..64733bd9 100644 --- a/pkg/plugin/health.go +++ b/pkg/plugin/health.go @@ -122,7 +122,7 @@ func (a *App) testVectorService(ctx context.Context) error { if a.vectorService == nil { return fmt.Errorf("vector service not configured") } - _, err := a.vectorService.Search(ctx, vectorCollections[0], "test", 1) + _, err := a.vectorService.Search(ctx, vectorCollections[0], "test", 1, nil) return err } diff --git a/pkg/plugin/health_test.go b/pkg/plugin/health_test.go index 3cc4b54b..e24b4c9f 100644 --- a/pkg/plugin/health_test.go +++ b/pkg/plugin/health_test.go @@ -23,7 +23,7 @@ func (m *mockHealthCheckClient) Do(req *http.Request) (*http.Response, error) { type mockVectorService struct{} -func (m *mockVectorService) Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) { +func (m *mockVectorService) Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) { return []store.SearchResult{{Payload: map[string]any{"a": "b"}, Score: 1.0}}, nil } diff --git a/pkg/plugin/resources.go b/pkg/plugin/resources.go index 1d795d89..7915f999 100644 --- a/pkg/plugin/resources.go +++ b/pkg/plugin/resources.go @@ -168,9 +168,10 @@ func newAzureOpenAIProxy(settings Settings) http.Handler { } type vectorSearchRequest struct { - Query string `json:"query"` - Collection string `json:"collection"` - TopK uint64 `json:"topK"` + Query string `json:"query"` + Collection string `json:"collection"` + TopK uint64 `json:"topK"` + Filter map[string]interface{} `json:"filter"` } type vectorSearchResponse struct { @@ -194,7 +195,7 @@ func (app *App) handleVectorSearch(w http.ResponseWriter, req *http.Request) { if body.TopK == 0 { body.TopK = 10 } - results, err := app.vectorService.Search(req.Context(), body.Collection, body.Query, body.TopK) + results, err := app.vectorService.Search(req.Context(), body.Collection, body.Query, body.TopK, body.Filter) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/pkg/plugin/vector/service.go b/pkg/plugin/vector/service.go index b1a6b4bb..6120c7e8 100644 --- a/pkg/plugin/vector/service.go +++ b/pkg/plugin/vector/service.go @@ -12,7 +12,7 @@ import ( ) type Service interface { - Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) + Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) Cancel() } @@ -57,7 +57,7 @@ func NewService(s VectorSettings, secrets map[string]string) (Service, error) { }, nil } -func (v *vectorService) Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) { +func (v *vectorService) Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) { if query == "" { return nil, fmt.Errorf("query cannot be empty") } @@ -78,7 +78,7 @@ func (v *vectorService) Search(ctx context.Context, collection string, query str log.DefaultLogger.Info("Searching", "collection", collection, "query", query) // Search the vector store for similar vectors. - results, err := v.store.Search(ctx, collection, e, topK) + results, err := v.store.Search(ctx, collection, e, topK, filter) if err != nil { return nil, fmt.Errorf("vector store search: %w", err) } diff --git a/pkg/plugin/vector/store/qdrant.go b/pkg/plugin/vector/store/qdrant.go index c7ba1b1b..573bba0e 100644 --- a/pkg/plugin/vector/store/qdrant.go +++ b/pkg/plugin/vector/store/qdrant.go @@ -3,6 +3,7 @@ package store import ( "context" "crypto/tls" + "fmt" "github.com/grafana/grafana-plugin-sdk-go/backend/log" qdrant "github.com/qdrant/go-client/qdrant" @@ -82,14 +83,105 @@ func (q *qdrantStore) CollectionExists(ctx context.Context, collection string) ( return true, nil } -func (q *qdrantStore) Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) { +func (q *qdrantStore) mapFilters(ctx context.Context, filter map[string]interface{}) (*qdrant.Filter, error) { + qdrantFilterMap := &qdrant.Filter{} + + if filter == nil { + return qdrantFilterMap, nil + } + + for k, v := range filter { + switch v := v.(type) { + case map[string]interface{}: + for op, val := range v { + match, err := createQdrantMatch(val) + if err != nil { + return nil, err + } + + condition := &qdrant.Condition{ + ConditionOneOf: &qdrant.Condition_Field{ + Field: &qdrant.FieldCondition{ + Key: k, + Match: match, + }, + }, + } + + switch op { + case "$eq": + qdrantFilterMap.Must = append(qdrantFilterMap.Must, condition) + case "$ne": + qdrantFilterMap.MustNot = append(qdrantFilterMap.MustNot, condition) + default: + return nil, fmt.Errorf("unsupported operator: %s", op) + } + } + case []interface{}: + switch k { + case "$or": + for _, u := range v { + filterMap, err := q.mapFilters(ctx, u.(map[string]interface{})) + if err != nil { + return nil, err + } + qdrantFilterMap.Should = append(qdrantFilterMap.Should, &qdrant.Condition{ + ConditionOneOf: &qdrant.Condition_Filter{ + Filter: filterMap, + }, + }) + } + case "$and": + for _, u := range v { + filterMap, err := q.mapFilters(ctx, u.(map[string]interface{})) + if err != nil { + return nil, err + } + qdrantFilterMap.Must = append(qdrantFilterMap.Must, &qdrant.Condition{ + ConditionOneOf: &qdrant.Condition_Filter{ + Filter: filterMap, + }, + }) + } + default: + return nil, fmt.Errorf("unsupported operator: %s", k) + } + default: + return nil, fmt.Errorf("unsupported filter struct: %T", v) + } + } + + return qdrantFilterMap, nil +} + +func createQdrantMatch(val interface{}) (*qdrant.Match, error) { + match := &qdrant.Match{} + switch val := val.(type) { + case string: + match.MatchValue = &qdrant.Match_Keyword{ + Keyword: val, + } + default: + return nil, fmt.Errorf("unsupported filter type: %T", val) + } + return match, nil +} + +func (q *qdrantStore) Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) { if q.md != nil { ctx = metadata.NewOutgoingContext(ctx, *q.md) } + + qdrantFilter, err := q.mapFilters(ctx, filter) + if err != nil { + return nil, err + } + result, err := q.pointsClient.Search(ctx, &qdrant.SearchPoints{ CollectionName: collection, Vector: vector, Limit: topK, + Filter: qdrantFilter, // Include all payloads in the search result WithVectors: &qdrant.WithVectorsSelector{SelectorOptions: &qdrant.WithVectorsSelector_Enable{Enable: false}}, WithPayload: &qdrant.WithPayloadSelector{SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true}}, diff --git a/pkg/plugin/vector/store/store.go b/pkg/plugin/vector/store/store.go index 91bcee38..e7f31a6c 100644 --- a/pkg/plugin/vector/store/store.go +++ b/pkg/plugin/vector/store/store.go @@ -20,7 +20,7 @@ type SearchResult struct { type ReadVectorStore interface { CollectionExists(ctx context.Context, collection string) (bool, error) - Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) + Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) } type WriteVectorStore interface { diff --git a/pkg/plugin/vector/store/vectorapi.go b/pkg/plugin/vector/store/vectorapi.go index 230d4e71..84cdddc7 100644 --- a/pkg/plugin/vector/store/vectorapi.go +++ b/pkg/plugin/vector/store/vectorapi.go @@ -31,14 +31,17 @@ func (g *grafanaVectorAPI) CollectionExists(ctx context.Context, collection stri return true, nil } -func (g *grafanaVectorAPI) Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) { +func (g *grafanaVectorAPI) Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) { type queryPointsRequest struct { Query []float32 `json:"query"` TopK uint64 `json:"top_k"` + // optional filter json field + Filter map[string]interface{} `json:"filter"` } reqBody := queryPointsRequest{ - Query: vector, - TopK: topK, + Query: vector, + TopK: topK, + Filter: filter, } reqJSON, err := json.Marshal(reqBody) if err != nil {