Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
fix: do not min/max-normalize already normalized similarity scores to…
Browse files Browse the repository at this point in the history
… not distort them (#107)
  • Loading branch information
iwilltry42 authored Sep 6, 2024
1 parent b7ef94f commit 8d35da0
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 3 deletions.
3 changes: 3 additions & 0 deletions pkg/datastore/lib/scores/scores.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func NormalizeDocScores(docs []vs.Document) []vs.Document {

// NormalizeScore normalizes a single score
func NormalizeScore(score float32, minScore float32, maxScore float32) float32 {
if maxScore == 0 {
return 0
}
if maxScore-minScore == 0 {
return 1 // Avoid division by zero - also, this happens for a single document, so we want a score of 1 here
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/datastore/retrievers/bm25.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ func (r *BM25Retriever) Name() string {
return BM25RetrieverName
}

func (r *BM25Retriever) NormalizedScores() bool {
return false
}

func (r *BM25Retriever) DecodeConfig(cfg map[string]any) error {
return DefaultConfigDecoder(r, cfg)
}
Expand Down
18 changes: 15 additions & 3 deletions pkg/datastore/retrievers/merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func (r *MergingRetriever) Name() string {
return MergingRetrieverName
}

func (r *MergingRetriever) NormalizedScores() bool {
return true
}

func (r *MergingRetriever) DecodeConfig(cfg map[string]any) error {
if err := mapstructure.Decode(cfg, &r); err != nil {
return fmt.Errorf("failed to decode merging retriever configuration: %w", err)
Expand Down Expand Up @@ -75,7 +79,11 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer

slog.Debug("Retrieved documents from retriever", "retriever", retriever.Name, "numDocs", len(retrievedDocs))

min, max := scores.FindMinMaxScores(retrievedDocs)
normalized := r.retrievers[ri].NormalizedScores()
var minScore, maxScore float32
if !normalized {
minScore, maxScore = scores.FindMinMaxScores(retrievedDocs)
}

docLoop:
for _, retrievedDoc := range retrievedDocs {
Expand All @@ -85,7 +93,11 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
// Note that this was found by another retriever and note it's similarityScore
resultDocs[i].Metadata["retriever"] = fmt.Sprintf("%s,%s", resultDocs[i].Metadata["retriever"], retriever.Name)
resultDocs[i].Metadata["retrieverScore::"+retriever.Name] = retrievedDoc.SimilarityScore
normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, min, max)
normalizedScore := retrievedDoc.SimilarityScore
if !normalized {
normalizedScore = scores.NormalizeScore(retrievedDoc.SimilarityScore, minScore, maxScore)
slog.Debug("Normalized score", "retriever", retriever.Name, "score", retrievedDoc.SimilarityScore, "minScore", minScore, "maxScore", maxScore, "normalizedScore", normalizedScore)
}
resultDocs[i].Metadata["retrieverScoreNormalized::"+retriever.Name] = normalizedScore
resultDocs[i].SimilarityScore += normalizedScore * z.Dereference(retriever.Weight)
continue docLoop
Expand All @@ -94,7 +106,7 @@ func (r *MergingRetriever) Retrieve(ctx context.Context, store store.Store, quer
// not in resultDocs yet, add it
retrievedDoc.Metadata["retriever"] = retriever.Name
retrievedDoc.Metadata["retrieverScore::"+retriever.Name] = retrievedDoc.SimilarityScore
normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, min, max)
normalizedScore := scores.NormalizeScore(retrievedDoc.SimilarityScore, minScore, maxScore)
retrievedDoc.Metadata["retrieverScoreNormalized::"+retriever.Name] = normalizedScore
retrievedDoc.SimilarityScore = normalizedScore * z.Dereference(retriever.Weight)
resultDocs = append(resultDocs, retrievedDoc)
Expand Down
5 changes: 5 additions & 0 deletions pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Retriever interface {
Retrieve(ctx context.Context, store store.Store, query string, datasetIDs []string, where map[string]string, whereDocument []chromem.WhereDocument) ([]vs.Document, error)
Name() string
DecodeConfig(cfg map[string]any) error
NormalizedScores() bool // whether the retriever returns normalized scores
}

func GetRetriever(name string) (Retriever, error) {
Expand Down Expand Up @@ -68,6 +69,10 @@ func (r *BasicRetriever) Name() string {
return BasicRetrieverName
}

func (r *BasicRetriever) NormalizedScores() bool {
return true
}

func (r *BasicRetriever) DecodeConfig(cfg map[string]any) error {
return DefaultConfigDecoder(r, cfg)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/datastore/retrievers/routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func (r *RoutingRetriever) Name() string {
return RoutingRetrieverName
}

func (r *RoutingRetriever) NormalizedScores() bool {
return true
}

func (r *RoutingRetriever) DecodeConfig(cfg map[string]any) error {
return DefaultConfigDecoder(r, cfg)
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func (s *SubqueryRetriever) Name() string {
return SubqueryRetrieverName
}

func (s *SubqueryRetriever) NormalizedScores() bool {
return true
}

func (s *SubqueryRetriever) DecodeConfig(cfg map[string]any) error {
return DefaultConfigDecoder(s, cfg)
}
Expand Down

0 comments on commit 8d35da0

Please sign in to comment.