Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

fix: do not min/max-normalize already normalized similarity scores to not distort them #107

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/datastore/lib/scores/scores.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func NormalizeDocScores(docs []vs.Document) []vs.Document {

// NormalizeScore normalizes a single score
func NormalizeScore(score float32, minScore float32, maxScore float32) float32 {
if maxScore == 0 {
return 0
}
if maxScore-minScore == 0 {
return 1 // Avoid division by zero - also, this happens for a single document, so we want a score of 1 here
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/datastore/retrievers/bm25.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ func (r *BM25Retriever) Name() string {
return BM25RetrieverName
}

// NormalizedScores reports whether this retriever returns already-normalized
// similarity scores. For BM25 this is always false, which tells callers that
// combine results from several retrievers (see MergingRetriever.Retrieve) to
// min/max-normalize these scores themselves before weighting them.
func (r *BM25Retriever) NormalizedScores() bool {
	// BM25 scores are reported as not normalized.
	return false
}

// DecodeConfig applies the given configuration map to the retriever.
// It delegates to DefaultConfigDecoder and returns that function's error
// unchanged.
func (r *BM25Retriever) DecodeConfig(cfg map[string]any) error {
	return DefaultConfigDecoder(r, cfg)
}
Expand Down
18 changes: 15 additions & 3 deletions pkg/datastore/retrievers/merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (r *MergingRetriever) Name() string {
return MergingRetrieverName
}

// NormalizedScores reports whether this retriever returns already-normalized
// similarity scores. The merging retriever always reports true.
func (r *MergingRetriever) NormalizedScores() bool {
	return true
}

func (r *MergingRetriever) DecodeConfig(cfg map[string]any) error {
if err := mapstructure.Decode(cfg, &r); err != nil {
return fmt.Errorf("failed to decode merging retriever configuration: %w", err)
Expand Down Expand Up @@ -75,7 +79,11 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer

slog.Debug("Retrieved documents from retriever", "retriever", retriever.Name, "numDocs", len(retrievedDocs))

min, max := scores.FindMinMaxScores(retrievedDocs)
normalized := r.retrievers[ri].NormalizedScores()
var minScore, maxScore float32
if !normalized {
minScore, maxScore = scores.FindMinMaxScores(retrievedDocs)
}

docLoop:
for _, retrievedDoc := range retrievedDocs {
Expand All @@ -85,7 +93,11 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
// Note that this was found by another retriever and note its similarityScore
resultDocs[i].Metadata["retriever"] = fmt.Sprintf("%s,%s", resultDocs[i].Metadata["retriever"], retriever.Name)
resultDocs[i].Metadata["retrieverScore::"+retriever.Name] = retrievedDoc.SimilarityScore
normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, min, max)
normalizedScore := retrievedDoc.SimilarityScore
if !normalized {
normalizedScore = scores.NormalizeScore(retrievedDoc.SimilarityScore, minScore, maxScore)
slog.Debug("Normalized score", "retriever", retriever.Name, "score", retrievedDoc.SimilarityScore, "minScore", minScore, "maxScore", maxScore, "normalizedScore", normalizedScore)
}
resultDocs[i].Metadata["retrieverScoreNormalized::"+retriever.Name] = normalizedScore
resultDocs[i].SimilarityScore += normalizedScore * z.Dereference(retriever.Weight)
continue docLoop
Expand All @@ -94,7 +106,7 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
// not in resultDocs yet, add it
retrievedDoc.Metadata["retriever"] = retriever.Name
retrievedDoc.Metadata["retrieverScore::"+retriever.Name] = retrievedDoc.SimilarityScore
normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, min, max)
normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, minScore, maxScore)
retrievedDoc.Metadata["retrieverScoreNormalized::"+retriever.Name] = normalizedScore
retrievedDoc.SimilarityScore = normalizedScore * z.Dereference(retriever.Weight)
resultDocs = append(resultDocs, retrievedDoc)
Expand Down
5 changes: 5 additions & 0 deletions pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Retriever interface {
Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
Name() string
DecodeConfig(cfg map[string]any) error
NormalizedScores() bool // whether the retriever returns normalized scores
}

func GetRetriever(name string) (Retriever, error) {
Expand Down Expand Up @@ -68,6 +69,10 @@ func (r *BasicRetriever) Name() string {
return BasicRetrieverName
}

// NormalizedScores reports whether this retriever returns already-normalized
// similarity scores. The basic retriever always reports true, so callers
// should not min/max-normalize its scores again.
func (r *BasicRetriever) NormalizedScores() bool {
	return true
}

// DecodeConfig applies the given configuration map to the retriever.
// It delegates to DefaultConfigDecoder and returns that function's error
// unchanged.
func (r *BasicRetriever) DecodeConfig(cfg map[string]any) error {
	return DefaultConfigDecoder(r, cfg)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/datastore/retrievers/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func (r *RoutingRetriever) Name() string {
return RoutingRetrieverName
}

// NormalizedScores reports whether this retriever returns already-normalized
// similarity scores. The routing retriever always reports true.
func (r *RoutingRetriever) NormalizedScores() bool {
	return true
}

// DecodeConfig applies the given configuration map to the retriever.
// It delegates to DefaultConfigDecoder and returns that function's error
// unchanged.
func (r *RoutingRetriever) DecodeConfig(cfg map[string]any) error {
	return DefaultConfigDecoder(r, cfg)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func (s *SubqueryRetriever) Name() string {
return SubqueryRetrieverName
}

// NormalizedScores reports whether this retriever returns already-normalized
// similarity scores. The subquery retriever always reports true.
func (s *SubqueryRetriever) NormalizedScores() bool {
	return true
}

// DecodeConfig applies the given configuration map to the retriever.
// It delegates to DefaultConfigDecoder and returns that function's error
// unchanged.
func (s *SubqueryRetriever) DecodeConfig(cfg map[string]any) error {
	return DefaultConfigDecoder(s, cfg)
}
Expand Down