Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for filtered vector search #100

Merged
merged 13 commits into from
Oct 17, 2023
2 changes: 1 addition & 1 deletion pkg/plugin/health.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (a *App) testVectorService(ctx context.Context) error {
if a.vectorService == nil {
return fmt.Errorf("vector service not configured")
}
_, err := a.vectorService.Search(ctx, vectorCollections[0], "test", 1)
_, err := a.vectorService.Search(ctx, vectorCollections[0], "test", 1, nil)
return err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/plugin/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func (m *mockHealthCheckClient) Do(req *http.Request) (*http.Response, error) {

type mockVectorService struct{}

func (m *mockVectorService) Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) {
func (m *mockVectorService) Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) {
return []store.SearchResult{{Payload: map[string]any{"a": "b"}, Score: 1.0}}, nil
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/plugin/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ type vectorSearchRequest struct {
Query string `json:"query"`
Collection string `json:"collection"`
TopK uint64 `json:"topK"`
Filter map[string]interface{}
sd2k marked this conversation as resolved.
Show resolved Hide resolved
}

type vectorSearchResponse struct {
Expand All @@ -194,7 +195,7 @@ func (app *App) handleVectorSearch(w http.ResponseWriter, req *http.Request) {
if body.TopK == 0 {
body.TopK = 10
}
results, err := app.vectorService.Search(req.Context(), body.Collection, body.Query, body.TopK)
results, err := app.vectorService.Search(req.Context(), body.Collection, body.Query, body.TopK, body.Filter)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand Down
6 changes: 3 additions & 3 deletions pkg/plugin/vector/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

type Service interface {
Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error)
Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error)
Cancel()
}

Expand Down Expand Up @@ -57,7 +57,7 @@ func NewService(s VectorSettings, secrets map[string]string) (Service, error) {
}, nil
}

func (v *vectorService) Search(ctx context.Context, collection string, query string, topK uint64) ([]store.SearchResult, error) {
func (v *vectorService) Search(ctx context.Context, collection string, query string, topK uint64, filter map[string]interface{}) ([]store.SearchResult, error) {
if query == "" {
return nil, fmt.Errorf("query cannot be empty")
}
Expand All @@ -78,7 +78,7 @@ func (v *vectorService) Search(ctx context.Context, collection string, query str

log.DefaultLogger.Info("Searching", "collection", collection, "query", query)
// Search the vector store for similar vectors.
results, err := v.store.Search(ctx, collection, e, topK)
results, err := v.store.Search(ctx, collection, e, topK, filter)
if err != nil {
return nil, fmt.Errorf("vector store search: %w", err)
}
Expand Down
78 changes: 77 additions & 1 deletion pkg/plugin/vector/store/qdrant.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package store
import (
"context"
"crypto/tls"
"fmt"

"github.com/grafana/grafana-plugin-sdk-go/backend/log"
qdrant "github.com/qdrant/go-client/qdrant"
Expand Down Expand Up @@ -82,14 +83,89 @@ func (q *qdrantStore) CollectionExists(ctx context.Context, collection string) (
return true, nil
}

func (q *qdrantStore) Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) {
func (q *qdrantStore) mapFilters(ctx context.Context, filter map[string]interface{}) (*qdrant.Filter, error) {
qdrantFilterMap := &qdrant.Filter{}

if filter == nil {
return qdrantFilterMap, nil
}

for k, v := range filter {
switch v := v.(type) {
case map[string]interface{}:
for op, val := range v {
match, err := createQdrantMatch(val)
if err != nil {
return nil, err
}

condition := &qdrant.Condition{
ConditionOneOf: &qdrant.Condition_Field{
Field: &qdrant.FieldCondition{
Key: k,
Match: match,
},
},
}

switch op {
case "$eq":
qdrantFilterMap.Must = append(qdrantFilterMap.Must, condition)
case "$ne":
qdrantFilterMap.MustNot = append(qdrantFilterMap.MustNot, condition)
}
sd2k marked this conversation as resolved.
Show resolved Hide resolved
}
case []interface{}:
switch k {
case "$or":
for _, u := range v {
filterMap, err := q.mapFilters(ctx, u.(map[string]interface{}))
if err != nil {
return nil, err
}
qdrantFilterMap.Should = append(qdrantFilterMap.Should, &qdrant.Condition{
ConditionOneOf: &qdrant.Condition_Filter{
Filter: filterMap,
},
})
}
}
sd2k marked this conversation as resolved.
Show resolved Hide resolved
default:
return nil, fmt.Errorf("unsupported filter struct: %T", v)
}
}

return qdrantFilterMap, nil
}

func createQdrantMatch(val interface{}) (*qdrant.Match, error) {
match := &qdrant.Match{}
switch val := val.(type) {
case string:
match.MatchValue = &qdrant.Match_Keyword{
Keyword: val,
}
default:
return nil, fmt.Errorf("unsupported filter type: %T", val)
}
return match, nil
}

func (q *qdrantStore) Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) {
if q.md != nil {
ctx = metadata.NewOutgoingContext(ctx, *q.md)
}

qdrantFilter, err := q.mapFilters(ctx, filter)
if err != nil {
return nil, err
}

result, err := q.pointsClient.Search(ctx, &qdrant.SearchPoints{
CollectionName: collection,
Vector: vector,
Limit: topK,
Filter: qdrantFilter,
// Include all payloads in the search result
WithVectors: &qdrant.WithVectorsSelector{SelectorOptions: &qdrant.WithVectorsSelector_Enable{Enable: false}},
WithPayload: &qdrant.WithPayloadSelector{SelectorOptions: &qdrant.WithPayloadSelector_Enable{Enable: true}},
Expand Down
2 changes: 1 addition & 1 deletion pkg/plugin/vector/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type SearchResult struct {

type ReadVectorStore interface {
CollectionExists(ctx context.Context, collection string) (bool, error)
Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error)
Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error)
}

type WriteVectorStore interface {
Expand Down
9 changes: 6 additions & 3 deletions pkg/plugin/vector/store/vectorapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ func (g *grafanaVectorAPI) CollectionExists(ctx context.Context, collection stri
return true, nil
}

func (g *grafanaVectorAPI) Search(ctx context.Context, collection string, vector []float32, topK uint64) ([]SearchResult, error) {
func (g *grafanaVectorAPI) Search(ctx context.Context, collection string, vector []float32, topK uint64, filter map[string]interface{}) ([]SearchResult, error) {
type queryPointsRequest struct {
Query []float32 `json:"query"`
TopK uint64 `json:"top_k"`
// optional filter json field
Filter map[string]interface{} `json:"filter"`
}
reqBody := queryPointsRequest{
Query: vector,
TopK: topK,
Query: vector,
TopK: topK,
Filter: filter,
}
reqJSON, err := json.Marshal(reqBody)
if err != nil {
Expand Down