From 3be046adf54bddd9fe5b3cf2086617fa6a6e00d9 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 29 May 2024 14:27:42 +0200 Subject: [PATCH] add: postprocessors in retrieval flow --- .../postprocessors/postprocessors.go | 26 +++++++++++++++++++ pkg/flows/config/config.go | 17 ++++++++++++ pkg/flows/flows.go | 10 +++++-- 3 files changed, 51 insertions(+), 2 deletions(-) diff --git a/pkg/datastore/postprocessors/postprocessors.go b/pkg/datastore/postprocessors/postprocessors.go index 7cc9f6e2..b6eb608d 100644 --- a/pkg/datastore/postprocessors/postprocessors.go +++ b/pkg/datastore/postprocessors/postprocessors.go @@ -1 +1,27 @@ +// Package postprocessors is basically the same as package transformers, but used at a different stage of the RAG pipeline package postprocessors + +import ( + "fmt" + "github.com/gptscript-ai/knowledge/pkg/datastore/transformers" + "github.com/gptscript-ai/knowledge/pkg/datastore/types" +) + +// Postprocessor may be a "normal" document transformer: the type is declared from types.DocumentTransformer, and GetPostprocessor falls back to transformers.GetTransformer for names not found in PostprocessorMap. +type Postprocessor types.DocumentTransformer + +var PostprocessorMap = map[string]Postprocessor{} + +func GetPostprocessor(name string) (Postprocessor, error) { + var postprocessor Postprocessor + var ok bool + postprocessor, ok = PostprocessorMap[name] + if !ok { + var err error + postprocessor, err = transformers.GetTransformer(name) + if err != nil { + return nil, fmt.Errorf("unknown postprocessor %q", name) + } + } + return postprocessor, nil +} diff --git a/pkg/flows/config/config.go b/pkg/flows/config/config.go index d38e33dc..b39d569f 100644 --- a/pkg/flows/config/config.go +++ b/pkg/flows/config/config.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/gptscript-ai/knowledge/pkg/datastore/documentloader" + "github.com/gptscript-ai/knowledge/pkg/datastore/postprocessors" "github.com/gptscript-ai/knowledge/pkg/datastore/querymodifiers" "github.com/gptscript-ai/knowledge/pkg/datastore/retrievers" "github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter" @@ -250,5 +251,21 @@ func (r 
*RetrievalFlowConfig) AsRetrievalFlow() (*flows.RetrievalFlow, error) { flow.Retriever = ret } + if len(r.Postprocessors) > 0 { + for _, pp := range r.Postprocessors { + postprocessor, err := postprocessors.GetPostprocessor(pp.Name) + if err != nil { + return nil, err + } + if len(pp.Options) > 0 { + if err := mapstructure.Decode(pp.Options, &postprocessor); err != nil { + return nil, fmt.Errorf("failed to decode postprocessor configuration: %w", err) + } + slog.Debug("Postprocessor custom configuration", "name", pp.Name, "config", postprocessor) + } + flow.Postprocessors = append(flow.Postprocessors, postprocessor) + } + } + return flow, nil } diff --git a/pkg/flows/flows.go b/pkg/flows/flows.go index 2e73023b..c69a4ac6 100644 --- a/pkg/flows/flows.go +++ b/pkg/flows/flows.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/gptscript-ai/knowledge/pkg/datastore/defaults" "github.com/gptscript-ai/knowledge/pkg/datastore/documentloader" + "github.com/gptscript-ai/knowledge/pkg/datastore/postprocessors" "github.com/gptscript-ai/knowledge/pkg/datastore/querymodifiers" "github.com/gptscript-ai/knowledge/pkg/datastore/retrievers" "github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter" @@ -98,7 +99,7 @@ func (f *IngestionFlow) Run(ctx context.Context, reader io.Reader) ([]vs.Documen type RetrievalFlow struct { QueryModifiers []querymodifiers.QueryModifier Retriever retrievers.Retriever - // TODO: Postprocessors + Postprocessors []postprocessors.Postprocessor } func (f *RetrievalFlow) FillDefaults() { @@ -121,7 +122,12 @@ func (f *RetrievalFlow) Run(ctx context.Context, store vs.VectorStore, query str return nil, err } - // TODO: add postprocessors + for _, pp := range f.Postprocessors { + docs, err = pp.Transform(ctx, docs) + if err != nil { + return nil, err + } + } slog.Debug("Retrieved documents", "num_documents", len(docs), "query", query, "dataset", datasetID) return docs, nil