diff --git a/pkg/datastore/lib/scores/scores.go b/pkg/datastore/lib/scores/scores.go
index bb6c473..03acc9f 100644
--- a/pkg/datastore/lib/scores/scores.go
+++ b/pkg/datastore/lib/scores/scores.go
@@ -40,6 +40,9 @@ func NormalizeDocScores(docs []vs.Document) []vs.Document {
 
 // NormalizeScore normalizes a single score
 func NormalizeScore(score float32, minScore float32, maxScore float32) float32 {
+	if maxScore == 0 {
+		return 0
+	}
 	if maxScore-minScore == 0 {
 		return 1 // Avoid division by zero - also, this happens for a single document, so we want a score of 1 here
 	}
diff --git a/pkg/datastore/retrievers/bm25.go b/pkg/datastore/retrievers/bm25.go
index 2ffbc95..151f030 100644
--- a/pkg/datastore/retrievers/bm25.go
+++ b/pkg/datastore/retrievers/bm25.go
@@ -29,6 +29,11 @@ func (r *BM25Retriever) Name() string {
 	return BM25RetrieverName
 }
 
+// NormalizedScores reports whether the retriever returns normalized scores; BM25Retriever reports false.
+func (r *BM25Retriever) NormalizedScores() bool {
+	return false
+}
+
 func (r *BM25Retriever) DecodeConfig(cfg map[string]any) error {
 	return DefaultConfigDecoder(r, cfg)
 }
diff --git a/pkg/datastore/retrievers/merging.go b/pkg/datastore/retrievers/merging.go
index da14f94..bca6880 100644
--- a/pkg/datastore/retrievers/merging.go
+++ b/pkg/datastore/retrievers/merging.go
@@ -32,6 +32,11 @@ func (r *MergingRetriever) Name() string {
 	return MergingRetrieverName
 }
 
+// NormalizedScores reports whether the retriever returns normalized scores; MergingRetriever reports true.
+func (r *MergingRetriever) NormalizedScores() bool {
+	return true
+}
+
 func (r *MergingRetriever) DecodeConfig(cfg map[string]any) error {
 	if err := mapstructure.Decode(cfg, &r); err != nil {
 		return fmt.Errorf("failed to decode merging retriever configuration: %w", err)
@@ -75,7 +80,12 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
 
 		slog.Debug("Retrieved documents from retriever", "retriever", retriever.Name, "numDocs", len(retrievedDocs))
 
-		min, max := scores.FindMinMaxScores(retrievedDocs)
+		// The min/max bounds are only computed for retrievers that do not report normalized scores.
+		normalized := r.retrievers[ri].NormalizedScores()
+		var minScore, maxScore float32
+		if !normalized {
+			minScore, maxScore = scores.FindMinMaxScores(retrievedDocs)
+		}
 
 	docLoop:
 		for _, retrievedDoc := range retrievedDocs {
@@ -85,7 +95,11 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
 					// Note that this was found by another retriever and note it's similarityScore
 					resultDocs[i].Metadata["retriever"] = fmt.Sprintf("%s,%s", resultDocs[i].Metadata["retriever"], retriever.Name)
 					resultDocs[i].Metadata["retrieverScore::"+retriever.Name] = retrievedDoc.SimilarityScore
-					normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, min, max)
+					normalizedScore := retrievedDoc.SimilarityScore
+					if !normalized {
+						normalizedScore = scores.NormalizeScore(retrievedDoc.SimilarityScore, minScore, maxScore)
+						slog.Debug("Normalized score", "retriever", retriever.Name, "score", retrievedDoc.SimilarityScore, "minScore", minScore, "maxScore", maxScore, "normalizedScore", normalizedScore)
+					}
 					resultDocs[i].Metadata["retrieverScoreNormalized::"+retriever.Name] = normalizedScore
 					resultDocs[i].SimilarityScore += normalizedScore * z.Dereference(retriever.Weight)
 					continue docLoop
@@ -94,7 +108,11 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
 			// not in resultDocs yet, add it
 			retrievedDoc.Metadata["retriever"] = retriever.Name
 			retrievedDoc.Metadata["retrieverScore::"+retriever.Name] = retrievedDoc.SimilarityScore
-			normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, min, max)
+			// Already-normalized scores are kept as-is: minScore/maxScore are zero here, so NormalizeScore would return 0.
+			normalizedScore := retrievedDoc.SimilarityScore
+			if !normalized {
+				normalizedScore = scores.NormalizeScore(retrievedDoc.SimilarityScore, minScore, maxScore)
+			}
 			retrievedDoc.Metadata["retrieverScoreNormalized::"+retriever.Name] = normalizedScore
 			retrievedDoc.SimilarityScore = normalizedScore * z.Dereference(retriever.Weight)
 			resultDocs = append(resultDocs, retrievedDoc)
diff --git a/pkg/datastore/retrievers/retrievers.go b/pkg/datastore/retrievers/retrievers.go
index d2819bf..00d7c65 100644
--- a/pkg/datastore/retrievers/retrievers.go
+++ b/pkg/datastore/retrievers/retrievers.go
@@ -21,6 +21,7 @@ type Retriever interface {
 	Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
 	Name() string
 	DecodeConfig(cfg map[string]any) error
+	NormalizedScores() bool // whether the retriever returns normalized scores
 }
 
 func GetRetriever(name string) (Retriever, error) {
@@ -68,6 +69,11 @@ func (r *BasicRetriever) Name() string {
 	return BasicRetrieverName
 }
 
+// NormalizedScores reports whether the retriever returns normalized scores; BasicRetriever reports true.
+func (r *BasicRetriever) NormalizedScores() bool {
+	return true
+}
+
 func (r *BasicRetriever) DecodeConfig(cfg map[string]any) error {
 	return DefaultConfigDecoder(r, cfg)
 }
diff --git a/pkg/datastore/retrievers/routing.go b/pkg/datastore/retrievers/routing.go
index d0c3265..6419635 100644
--- a/pkg/datastore/retrievers/routing.go
+++ b/pkg/datastore/retrievers/routing.go
@@ -25,6 +25,11 @@ func (r *RoutingRetriever) Name() string {
 	return RoutingRetrieverName
 }
 
+// NormalizedScores reports whether the retriever returns normalized scores; RoutingRetriever reports true.
+func (r *RoutingRetriever) NormalizedScores() bool {
+	return true
+}
+
 func (r *RoutingRetriever) DecodeConfig(cfg map[string]any) error {
 	return DefaultConfigDecoder(r, cfg)
 }
diff --git a/pkg/datastore/retrievers/subquery.go b/pkg/datastore/retrievers/subquery.go
index ab194e1..6b5a7db 100644
--- a/pkg/datastore/retrievers/subquery.go
+++ b/pkg/datastore/retrievers/subquery.go
@@ -28,6 +28,11 @@ func (s *SubqueryRetriever) Name() string {
 	return SubqueryRetrieverName
 }
 
+// NormalizedScores reports whether the retriever returns normalized scores; SubqueryRetriever reports true.
+func (s *SubqueryRetriever) NormalizedScores() bool {
+	return true
+}
+
 func (s *SubqueryRetriever) DecodeConfig(cfg map[string]any) error {
 	return DefaultConfigDecoder(s, cfg)
 }