From 5808864f0d7d81bb865701f04ada4df0bf66141c Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Tue, 18 Jun 2024 14:35:19 +0200 Subject: [PATCH] change: no error on empty collection, so that the LLM session can 'recover' from the tool call --- pkg/cmd/retrieve.go | 7 +++++++ pkg/vectorstore/chromem/chromem.go | 12 +++++++++--- pkg/vectorstore/errors.go | 9 --------- pkg/vectorstore/errors/errors.go | 10 ++++++++++ 4 files changed, 26 insertions(+), 12 deletions(-) delete mode 100644 pkg/vectorstore/errors.go create mode 100644 pkg/vectorstore/errors/errors.go diff --git a/pkg/cmd/retrieve.go b/pkg/cmd/retrieve.go index d5eef119..3d5b69da 100644 --- a/pkg/cmd/retrieve.go +++ b/pkg/cmd/retrieve.go @@ -2,7 +2,9 @@ package cmd import ( "encoding/json" + "errors" "fmt" + vserr "github.com/gptscript-ai/knowledge/pkg/vectorstore/errors" "log/slog" "github.com/gptscript-ai/knowledge/pkg/datastore" @@ -73,6 +75,11 @@ func (s *ClientRetrieve) Run(cmd *cobra.Command, args []string) error { sources, err := c.Retrieve(cmd.Context(), datasetID, query, retrieveOpts) if err != nil { + // An empty collection is not a hard error - the LLM session can "recover" from it + if errors.Is(err, vserr.ErrCollectionEmpty) { + fmt.Printf("Dataset %q does not contain any documents\n", datasetID) + return nil + } return err } diff --git a/pkg/vectorstore/chromem/chromem.go b/pkg/vectorstore/chromem/chromem.go index cc5e8030..9742a629 100644 --- a/pkg/vectorstore/chromem/chromem.go +++ b/pkg/vectorstore/chromem/chromem.go @@ -2,6 +2,8 @@ package chromem import ( "context" + "fmt" + "github.com/gptscript-ai/knowledge/pkg/vectorstore/errors" "log/slog" "maps" "runtime" @@ -55,7 +57,7 @@ func (s *Store) AddDocuments(ctx context.Context, docs []vs.Document, collection col := s.db.GetCollection(collection, s.embeddingFunc) if col == nil { - return nil, vs.ErrCollectionNotFound{Collection: collection} + return nil, fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection) } err := col.AddDocuments(ctx, chromemDocs, runtime.NumCPU()/2) @@ -102,7 +104,11 @@ func convertStringMapToAnyMap(m map[string]string) map[string]any { func (s *Store) SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error) { col := s.db.GetCollection(collection, s.embeddingFunc) if col == nil { - return nil, vs.ErrCollectionNotFound{Collection: collection} + return nil, fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection) + } + + if col.Count() == 0 { + return nil, fmt.Errorf("%w: %q", errors.ErrCollectionEmpty, collection) } if numDocuments > col.Count() { @@ -139,7 +145,7 @@ func (s *Store) RemoveCollection(_ context.Context, collection string) error { func (s *Store) RemoveDocument(ctx context.Context, documentID string, collection string) error { col := s.db.GetCollection(collection, s.embeddingFunc) if col == nil { - return vs.ErrCollectionNotFound{Collection: collection} + return fmt.Errorf("%w: %q", errors.ErrCollectionNotFound, collection) } return col.Delete(ctx, nil, nil, documentID) } diff --git a/pkg/vectorstore/errors.go b/pkg/vectorstore/errors.go deleted file mode 100644 index 3e091a10..00000000 --- a/pkg/vectorstore/errors.go +++ /dev/null @@ -1,9 +0,0 @@ -package vectorstore - -type ErrCollectionNotFound struct { - Collection string -} - -func (e ErrCollectionNotFound) Error() string { - return "collection not found: " + e.Collection -} diff --git a/pkg/vectorstore/errors/errors.go b/pkg/vectorstore/errors/errors.go new file mode 100644 index 00000000..ae34eb72 --- /dev/null +++ b/pkg/vectorstore/errors/errors.go @@ -0,0 +1,10 @@ +package errors + +import ( + "errors" +) + +var ( + ErrCollectionNotFound = errors.New("collection not found") + ErrCollectionEmpty = errors.New("collection is empty") +)