Skip to content

Commit

Permalink
Merge pull request #100 from grafana/vector-search-add-filters
Browse files Browse the repository at this point in the history
Add support for filtered vector search
  • Loading branch information
yoziru authored Oct 17, 2023
2 parents 27588c9 + 8b6736e commit bbca65b
Show file tree
Hide file tree
Showing 7 changed files with 110 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pkg/plugin/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (a *App) testVectorService(ctx context.Context) error {
if a.vectorService == nil {
return fmt.Errorf("vector service not configured")
}
_, err := a.vectorService.Search(ctx, vectorCollections[0], "test", 1)
_, err := a.vectorService.Search(ctx, vectorCollections[0], "test", 1, nil)
return err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/plugin/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (m *mockHealthCheckClient) Do(req *http.Request) (*http.Response, error) {

type mockVectorService struct{}

func (m *mockVectorService) Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) {
func (m *mockVectorService) Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) {
return []store.SearchResult{{Payload: map[string]any{"a": "b"}, Score: 1.0}}, nil
}

Expand Down
9 changes: 5 additions & 4 deletions pkg/plugin/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,10 @@ func newAzureOpenAIProxy(settings Settings) http.Handler {
}

type vectorSearchRequest struct {
Query string `json:"query"`
Collection string `json:"collection"`
TopK uint64 `json:"topK"`
Query string `json:"query"`
Collection string `json:"collection"`
TopK uint64 `json:"topK"`
Filter map[string]interface{} `json:"filter"`
}

type vectorSearchResponse struct {
Expand All @@ -194,7 +195,7 @@ func (app *App) handleVectorSearch(w http.ResponseWriter, req *http.Request) {
if body.TopK == 0 {
body.TopK = 10
}
results, err := app.vectorService.Search(req.Context(), body.Collection, body.Query, body.TopK)
results, err := app.vectorService.Search(req.Context(), body.Collection, body.Query, body.TopK, body.Filter)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
6 changes: 3 additions & 3 deletions pkg/plugin/vector/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

type Service interface {
Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error)
Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error)
Cancel()
}

Expand Down Expand Up @@ -57,7 +57,7 @@ func NewService(s VectorSettings, secrets map[string]string) (Service, error) {
}, nil
}

func (v *vectorService) Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) {
func (v *vectorService) Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) {
if query == "" {
return nil, fmt.Errorf("query cannot be empty")
}
Expand All @@ -78,7 +78,7 @@ func (v *vectorService) Search(ctx context.Context, collection string, query str

log.DefaultLogger.Info("Searching", "collection", collection, "query", query)
// Search the vector store for similar vectors.
results, err := v.store.Search(ctx, collection, e, topK)
results, err := v.store.Search(ctx, collection, e, topK, filter)
if err != nil {
return nil, fmt.Errorf("vector store search: %w", err)
}
Expand Down
94 changes: 93 additions & 1 deletion pkg/plugin/vector/store/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package store
import (
"context"
"crypto/tls"
"fmt"

"github.com/grafana/grafana-plugin-sdk-go/backend/log"
qdrant "github.com/qdrant/go-client/qdrant"
Expand Down Expand Up @@ -82,14 +83,105 @@ func (q *qdrantStore) CollectionExists(ctx context.Context, collection string) (
return true, nil
}

func (q *qdrantStore) Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) {
func (q *qdrantStore) mapFilters(ctx context.Context, filter map[string]interface{}) (*qdrant.Filter, error) {
qdrantFilterMap := &qdrant.Filter{}

if filter == nil {
return qdrantFilterMap, nil
}

for k, v := range filter {
switch v := v.(type) {
case map[string]interface{}:
for op, val := range v {
match, err := createQdrantMatch(val)
if err != nil {
return nil, err
}

condition := &qdrant.Condition{
ConditionOneOf: &qdrant.Condition_Field{
Field: &qdrant.FieldCondition{
Key: k,
Match: match,
},
},
}

switch op {
case "$eq":
qdrantFilterMap.Must = append(qdrantFilterMap.Must, condition)
case "$ne":
qdrantFilterMap.MustNot = append(qdrantFilterMap.MustNot, condition)
default:
return nil, fmt.Errorf("unsupported operator: %s", op)
}
}
case []interface{}:
switch k {
case "$or":
for _, u := range v {
filterMap, err := q.mapFilters(ctx, u.(map[string]interface{}))
if err != nil {
return nil, err
}
qdrantFilterMap.Should = append(qdrantFilterMap.Should, &qdrant.Condition{
ConditionOneOf: &qdrant.Condition_Filter{
Filter: filterMap,
},
})
}
case "$and":
for _, u := range v {
filterMap, err := q.mapFilters(ctx, u.(map[string]interface{}))
if err != nil {
return nil, err
}
qdrantFilterMap.Must = append(qdrantFilterMap.Must, &qdrant.Condition{
ConditionOneOf: &qdrant.Condition_Filter{
Filter: filterMap,
},
})
}
default:
return nil, fmt.Errorf("unsupported operator: %s", k)
}
default:
return nil, fmt.Errorf("unsupported filter struct: %T", v)
}
}

return qdrantFilterMap, nil
}

func createQdrantMatch(val interface{}) (*qdrant.Match, error) {
match := &qdrant.Match{}
switch val := val.(type) {
case string:
match.MatchValue = &qdrant.Match_Keyword{
Keyword: val,
}
default:
return nil, fmt.Errorf("unsupported filter type: %T", val)
}
return match, nil
}

func (q *qdrantStore) Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) {
if q.md != nil {
ctx = metadata.NewOutgoingContext(ctx, *q.md)
}

qdrantFilter, err := q.mapFilters(ctx, filter)
if err != nil {
return nil, err
}

result, err := q.pointsClient.Search(ctx, &qdrant.SearchPoints{
CollectionName: collection,
Vector: vector,
Limit: topK,
Filter: qdrantFilter,
// Include all payloads in the search result
WithVectors: &qdrant.WithVectorsSelector{SelectorOptions: &qdrant.WithVectorsSelector_Enable{Enable: false}},
WithPayload: &qdrant.WithPayloadSelector{SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true}},
Expand Down
2 changes: 1 addition & 1 deletion pkg/plugin/vector/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type SearchResult struct {

type ReadVectorStore interface {
CollectionExists(ctx context.Context, collection string) (bool, error)
Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error)
Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error)
}

type WriteVectorStore interface {
Expand Down
9 changes: 6 additions & 3 deletions pkg/plugin/vector/store/vectorapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ func (g *grafanaVectorAPI) CollectionExists(ctx context.Context, collection stri
return true, nil
}

func (g *grafanaVectorAPI) Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) {
func (g *grafanaVectorAPI) Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) {
type queryPointsRequest struct {
Query []float32 `json:"query"`
TopK uint64 `json:"top_k"`
// optional filter json field
Filter map[string]interface{} `json:"filter"`
}
reqBody := queryPointsRequest{
Query: vector,
TopK: topK,
Query: vector,
TopK: topK,
Filter: filter,
}
reqJSON, err := json.Marshal(reqBody)
if err != nil {
Expand Down

0 comments on commit bbca65b

Please sign in to comment.