diff --git a/examples/routing_retriever.yaml b/examples/routing_retriever.yaml
new file mode 100644
index 0000000..65f6882
--- /dev/null
+++ b/examples/routing_retriever.yaml
@@ -0,0 +1,14 @@
+flows:
+  foo:
+    default: true
+    retrieval:
+      retriever:
+        name: routing
+        options:
+          model:
+            openai:
+              apiKey: "${OPENAI_API_KEY}"
+              model: gpt-4o
+              apiType: OPEN_AI
+              apiBase: https://api.openai.com/v1
+          topK: 6
\ No newline at end of file
diff --git a/pkg/client/client.go b/pkg/client/client.go
index 5df6e7e..902214c 100644
--- a/pkg/client/client.go
+++ b/pkg/client/client.go
@@ -30,4 +30,5 @@ type Client interface {
 	Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) ([]vectorstore.Document, error)
 	ExportDatasets(ctx context.Context, path string, datasets ...string) error
 	ImportDatasets(ctx context.Context, path string, datasets ...string) error
+	UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error)
 }
diff --git a/pkg/client/default.go b/pkg/client/default.go
index 846baf4..279d43b 100644
--- a/pkg/client/default.go
+++ b/pkg/client/default.go
@@ -228,3 +228,8 @@ func (c *DefaultClient) ImportDatasets(ctx context.Context, path string, datasets ...string) error {
 	// TODO: implement
 	panic("not implemented")
 }
+
+func (c *DefaultClient) UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error) {
+	// TODO: implement
+	panic("not implemented")
+}
diff --git a/pkg/client/standalone.go b/pkg/client/standalone.go
index 3661431..642bdfe 100644
--- a/pkg/client/standalone.go
+++ b/pkg/client/standalone.go
@@ -132,3 +132,7 @@ func (c *StandaloneClient) ExportDatasets(ctx context.Context, path string, datasets ...string) error {
 func (c *StandaloneClient) ImportDatasets(ctx context.Context, path string, datasets ...string) error {
 	return c.Datastore.ImportDatasetsFromFile(ctx, path, datasets...)
 }
+
+func (c *StandaloneClient) UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error) {
+	return c.Datastore.UpdateDataset(ctx, dataset, opts)
+}
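For context, a minimal sketch of how a caller might drive the new `UpdateDataset` client method. The dataset ID and metadata values are hypothetical, and construction of the concrete client is omitted; this is not part of the diff:

```go
package examples

import (
	"context"
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/client"
	"github.com/gptscript-ai/knowledge/pkg/datastore"
	"github.com/gptscript-ai/knowledge/pkg/index"
)

// mergeTopic merges one metadata key into an existing dataset.
// With ReplaceMetadata set to false, existing keys are preserved.
func mergeTopic(ctx context.Context, c client.Client) error {
	updated, err := c.UpdateDataset(ctx, index.Dataset{
		ID:       "my-dataset", // hypothetical dataset ID
		Metadata: map[string]any{"topic": "networking"},
	}, &datastore.UpdateDatasetOpts{ReplaceMetadata: false})
	if err != nil {
		return err
	}
	fmt.Println("metadata is now:", updated.Metadata)
	return nil
}
```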
diff --git a/pkg/cmd/edit_dataset.go b/pkg/cmd/edit_dataset.go
new file mode 100644
index 0000000..0ddc42b
--- /dev/null
+++ b/pkg/cmd/edit_dataset.go
@@ -0,0 +1,75 @@
+package cmd
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/gptscript-ai/knowledge/pkg/datastore"
+	"github.com/gptscript-ai/knowledge/pkg/index"
+	"github.com/spf13/cobra"
+)
+
+type ClientEditDataset struct {
+	Client
+	ResetMetadata   bool              `usage:"reset metadata to default (empty)"`
+	UpdateMetadata  map[string]string `usage:"update metadata key-value pairs (existing metadata will be updated/preserved)"`
+	ReplaceMetadata map[string]string `usage:"replace metadata with key-value pairs (existing metadata will be removed)"`
+}
+
+func (s *ClientEditDataset) Customize(cmd *cobra.Command) {
+	cmd.Use = "edit-dataset <dataset-id>"
+	cmd.Short = "Edit an existing dataset"
+	cmd.Args = cobra.ExactArgs(1)
+	cmd.MarkFlagsMutuallyExclusive("reset-metadata", "update-metadata", "replace-metadata")
+}
+
+func (s *ClientEditDataset) Run(cmd *cobra.Command, args []string) error {
+	c, err := s.getClient()
+	if err != nil {
+		return err
+	}
+
+	datasetID := args[0]
+
+	// Get the current dataset
+	dataset, err := c.GetDataset(cmd.Context(), datasetID)
+	if err != nil {
+		return fmt.Errorf("failed to get dataset: %w", err)
+	}
+
+	if dataset == nil {
+		return fmt.Errorf("dataset not found: %s", datasetID)
+	}
+
+	updatedDataset := index.Dataset{
+		ID: dataset.ID,
+	}
+
+	// Since the flags are mutually exclusive, this is either an empty map or one of the update/replace maps
+	metadata := map[string]any{}
+
+	for k, v := range s.UpdateMetadata {
+		metadata[k] = v
+	}
+
+	for k, v := range s.ReplaceMetadata {
+		metadata[k] = v
+	}
+
+	updatedDataset.Metadata = metadata
+
+	dataset, err = c.UpdateDataset(cmd.Context(), updatedDataset, &datastore.UpdateDatasetOpts{ReplaceMetadata: s.ResetMetadata || len(s.ReplaceMetadata) > 0})
+	if err != nil {
+		return fmt.Errorf("failed to update dataset: %w", err)
+	}
+
+	dataset.Files = nil // Don't print files
+
+	jsonOutput, err := json.Marshal(dataset)
+	if err != nil {
+		return fmt.Errorf("failed to marshal dataset: %w", err)
+	}
+
+	fmt.Printf("Updated dataset:\n%s\n", string(jsonOutput))
+	return nil
+}
diff --git a/pkg/cmd/main.go b/pkg/cmd/main.go
index 7a4bf2e..2ed54ab 100644
--- a/pkg/cmd/main.go
+++ b/pkg/cmd/main.go
@@ -30,6 +30,7 @@ func New() *cobra.Command {
 		new(ClientAskDir),
 		new(ClientExportDatasets),
 		new(ClientImportDatasets),
+		new(ClientEditDataset),
 		new(Version),
 	)
 }
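Usage would look like `knowledge edit-dataset my-dataset --update-metadata topic=networking` (the exact map-flag syntax depends on the CLI framework's flag parsing). A hypothetical test sketch, not part of this diff, pinning down how the three mutually exclusive flags collapse into the single `ReplaceMetadata` option:

```go
package cmd

import "testing"

// TestEditDatasetReplaceSemantics is illustrative only: --reset-metadata and
// --replace-metadata both request a full replace, while --update-metadata merges.
func TestEditDatasetReplaceSemantics(t *testing.T) {
	cases := []struct {
		name        string
		flags       ClientEditDataset
		wantReplace bool
	}{
		{"update merges into existing metadata", ClientEditDataset{UpdateMetadata: map[string]string{"a": "1"}}, false},
		{"replace drops existing metadata", ClientEditDataset{ReplaceMetadata: map[string]string{"a": "1"}}, true},
		{"reset clears metadata entirely", ClientEditDataset{ResetMetadata: true}, true},
	}
	for _, c := range cases {
		// Mirrors the expression passed to UpdateDatasetOpts in Run above
		got := c.flags.ResetMetadata || len(c.flags.ReplaceMetadata) > 0
		if got != c.wantReplace {
			t.Errorf("%s: ReplaceMetadata = %v, want %v", c.name, got, c.wantReplace)
		}
	}
}
```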
diff --git a/pkg/datastore/dataset.go b/pkg/datastore/dataset.go
index 0a09406..e3c336c 100644
--- a/pkg/datastore/dataset.go
+++ b/pkg/datastore/dataset.go
@@ -11,6 +11,10 @@ import (
 	"gorm.io/gorm"
 )
 
+type UpdateDatasetOpts struct {
+	ReplaceMetadata bool
+}
+
 func (s *Datastore) NewDataset(ctx context.Context, dataset index.Dataset) error {
 	// Set defaults
 	if dataset.EmbedDimension <= 0 {
@@ -75,3 +79,44 @@ func (s *Datastore) ListDatasets(ctx context.Context) ([]index.Dataset, error) {
 
 	return datasets, nil
 }
+
+func (s *Datastore) UpdateDataset(ctx context.Context, updatedDataset index.Dataset, opts *UpdateDatasetOpts) (*index.Dataset, error) {
+	if opts == nil {
+		opts = &UpdateDatasetOpts{}
+	}
+
+	var origDS *index.Dataset
+	var err error
+
+	if updatedDataset.ID == "" {
+		return origDS, fmt.Errorf("dataset ID is required")
+	}
+
+	origDS, err = s.GetDataset(ctx, updatedDataset.ID)
+	if err != nil {
+		return origDS, err
+	}
+	if origDS == nil {
+		return origDS, fmt.Errorf("dataset not found: %s", updatedDataset.ID)
+	}
+
+	// Reject updates to any field other than metadata before touching the dataset
+	if updatedDataset.EmbedDimension > 0 {
+		return origDS, fmt.Errorf("embedding dimension cannot be updated")
+	}
+
+	if updatedDataset.Files != nil {
+		return origDS, fmt.Errorf("files cannot be updated")
+	}
+
+	// Update metadata
+	if opts.ReplaceMetadata {
+		origDS.ReplaceMetadata(updatedDataset.Metadata)
+	} else {
+		origDS.UpdateMetadata(updatedDataset.Metadata)
+	}
+
+	slog.Debug("Updating dataset", "id", updatedDataset.ID, "metadata", updatedDataset.Metadata)
+
+	return origDS, s.Index.UpdateDataset(ctx, *origDS)
+}
diff --git a/pkg/datastore/retrieve.go b/pkg/datastore/retrieve.go
index 812c757..e984dbc 100644
--- a/pkg/datastore/retrieve.go
+++ b/pkg/datastore/retrieve.go
@@ -27,5 +27,9 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string
 	}
 
 	retrievalFlow.FillDefaults(topK)
-	return retrievalFlow.Run(ctx, s.Vectorstore, query, datasetID)
+	return retrievalFlow.Run(ctx, s, query, datasetID)
+}
+
+func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string) ([]vectorstore.Document, error) {
+	return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID)
 }
diff --git a/pkg/datastore/retrievers/retrievers.go b/pkg/datastore/retrievers/retrievers.go
index 54b9fde..2135fd9 100644
--- a/pkg/datastore/retrievers/retrievers.go
+++ b/pkg/datastore/retrievers/retrievers.go
@@ -3,6 +3,7 @@ package retrievers
 import (
 	"context"
 	"fmt"
+	"github.com/gptscript-ai/knowledge/pkg/datastore/store"
 	"log/slog"
 
 	"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
@@ -10,7 +11,7 @@ import (
 )
 
 type Retriever interface {
-	Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error)
+	Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error)
 }
 
 func GetRetriever(name string) (Retriever, error) {
@@ -19,6 +20,8 @@ func GetRetriever(name string) (Retriever, error) {
 		return &BasicRetriever{TopK: defaults.TopK}, nil
 	case "subquery":
 		return &SubqueryRetriever{Limit: 3, TopK: 3}, nil
+	case "routing":
+		return &RoutingRetriever{TopK: defaults.TopK}, nil
 	default:
 		return nil, fmt.Errorf("unknown retriever %q", name)
 	}
@@ -32,7 +35,7 @@ type BasicRetriever struct {
 	TopK int
 }
 
-func (r *BasicRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
+func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
 	if r.TopK <= 0 {
 		slog.Debug("[BasicRetriever] TopK not set, using default", "default", defaults.TopK)
 		r.TopK = defaults.TopK
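Widening `Retrieve` from `vs.VectorStore` to `store.Store` lets retrievers query dataset listings and metadata, not just run similarity searches. A hypothetical custom retriever, not part of this diff, showing what the new parameter enables:

```go
package retrievers

import (
	"context"
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/datastore/store"
	vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
)

// ValidatingRetriever is illustrative only: it verifies the target dataset
// exists before searching - a lookup the old vs.VectorStore parameter
// could not express.
type ValidatingRetriever struct {
	TopK int
}

func (r *ValidatingRetriever) Retrieve(ctx context.Context, s store.Store, query string, datasetID string) ([]vs.Document, error) {
	ds, err := s.GetDataset(ctx, datasetID)
	if err != nil {
		return nil, err
	}
	if ds == nil {
		return nil, fmt.Errorf("dataset not found: %q", datasetID)
	}
	return s.SimilaritySearch(ctx, query, r.TopK, datasetID)
}
```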
diff --git a/pkg/datastore/retrievers/routing.go b/pkg/datastore/retrievers/routing.go
new file mode 100644
index 0000000..bfbbfc2
--- /dev/null
+++ b/pkg/datastore/retrievers/routing.go
@@ -0,0 +1,89 @@
+package retrievers
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log/slog"
+
+	"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
+	"github.com/gptscript-ai/knowledge/pkg/datastore/store"
+	"github.com/gptscript-ai/knowledge/pkg/llm"
+	vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
+)
+
+type RoutingRetriever struct {
+	Model             llm.LLMConfig
+	AvailableDatasets []string
+	TopK              int
+}
+
+var routingPromptTpl = `The following query will be used for a vector similarity search.
+Please route it to the appropriate dataset. Choose the one that fits best to the query based on the metadata.
+Query: "{{.query}}"
+Available datasets in a JSON map, where the key is the dataset ID and the value is a map of metadata fields:
+{{ .datasets }}
+Reply only in the following JSON format, without any styling or markdown syntax:
+{"result": "<dataset-id>"}`
+
+type routingResp struct {
+	Result string `json:"result"`
+}
+
+func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
+	log := slog.With("component", "RoutingRetriever")
+
+	log.Debug("Ignoring input datasetID in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetID)
+
+	if r.TopK <= 0 {
+		log.Debug("TopK not set, using default", "default", defaults.TopK)
+		r.TopK = defaults.TopK
+	}
+
+	if len(r.AvailableDatasets) == 0 {
+		allDatasets, err := store.ListDatasets(ctx)
+		if err != nil {
+			return nil, err
+		}
+		for _, ds := range allDatasets {
+			r.AvailableDatasets = append(r.AvailableDatasets, ds.ID)
+		}
+	}
+	log.Debug("Available datasets", "datasets", r.AvailableDatasets)
+
+	datasets := map[string]map[string]any{}
+	for _, dsID := range r.AvailableDatasets {
+		dataset, err := store.GetDataset(ctx, dsID)
+		if err != nil {
+			return nil, err
+		}
+		if dataset == nil {
+			return nil, fmt.Errorf("dataset not found: %q", dsID)
+		}
+		datasets[dataset.ID] = dataset.Metadata
+	}
+
+	datasetsJSON, err := json.Marshal(datasets)
+	if err != nil {
+		return nil, err
+	}
+
+	m, err := llm.NewFromConfig(r.Model)
+	if err != nil {
+		return nil, err
+	}
+
+	result, err := m.Prompt(ctx, routingPromptTpl, map[string]interface{}{"query": query, "datasets": string(datasetsJSON)})
+	if err != nil {
+		return nil, err
+	}
+	log.Debug("Routing result", "result", result)
+	var resp routingResp
+	if err := json.Unmarshal([]byte(result), &resp); err != nil {
+		return nil, err
+	}
+
+	log.Debug("Routing query to dataset", "query", query, "dataset", resp.Result)
+
+	return store.SimilaritySearch(ctx, query, r.TopK, resp.Result)
+}
diff --git a/pkg/datastore/retrievers/subquery.go b/pkg/datastore/retrievers/subquery.go
index e83b961..9da7ddf 100644
--- a/pkg/datastore/retrievers/subquery.go
+++ b/pkg/datastore/retrievers/subquery.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"github.com/gptscript-ai/knowledge/pkg/datastore/store"
 	"github.com/gptscript-ai/knowledge/pkg/llm"
 	vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
 	"log/slog"
@@ -30,7 +31,7 @@ type subqueryResp struct {
 	Results []string `json:"results"`
 }
 
-func (s SubqueryRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
+func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
 	m, err := llm.NewFromConfig(s.Model)
 	if err != nil {
 		return nil, err
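A sketch of driving the routing retriever programmatically, mirroring examples/routing_retriever.yaml; the query string is made up, and the store and LLM config are assumed to come from the caller:

```go
package examples

import (
	"context"
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/datastore/retrievers"
	"github.com/gptscript-ai/knowledge/pkg/datastore/store"
	"github.com/gptscript-ai/knowledge/pkg/llm"
)

// routeAndSearch lets the LLM pick the target dataset for a query.
// Leaving AvailableDatasets empty makes the retriever list all datasets itself.
func routeAndSearch(ctx context.Context, s store.Store, cfg llm.LLMConfig) error {
	r := &retrievers.RoutingRetriever{Model: cfg, TopK: 6}
	docs, err := r.Retrieve(ctx, s, "how do I configure the VPN?", "") // input datasetID is ignored
	if err != nil {
		return err
	}
	fmt.Printf("retrieved %d documents\n", len(docs))
	return nil
}
```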
diff --git a/pkg/datastore/store/store.go b/pkg/datastore/store/store.go
new file mode 100644
index 0000000..71ef58c
--- /dev/null
+++ b/pkg/datastore/store/store.go
@@ -0,0 +1,13 @@
+package store
+
+import (
+	"context"
+	"github.com/gptscript-ai/knowledge/pkg/index"
+	vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
+)
+
+type Store interface {
+	ListDatasets(ctx context.Context) ([]index.Dataset, error)
+	GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error)
+	SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error)
+}
diff --git a/pkg/flows/flows.go b/pkg/flows/flows.go
index e06fcdd..b8d07b9 100644
--- a/pkg/flows/flows.go
+++ b/pkg/flows/flows.go
@@ -3,6 +3,7 @@ package flows
 import (
 	"context"
 	"fmt"
+	"github.com/gptscript-ai/knowledge/pkg/datastore/store"
 	"io"
 	"log/slog"
 	"slices"
@@ -113,7 +114,7 @@ func (f *RetrievalFlow) FillDefaults(topK int) {
 	}
 }
 
-func (f *RetrievalFlow) Run(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
+func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
 	var err error
 
 	originalQuery := query
diff --git a/pkg/index/datasets.go b/pkg/index/datasets.go
index cf2b16d..a4b58e6 100644
--- a/pkg/index/datasets.go
+++ b/pkg/index/datasets.go
@@ -87,3 +87,12 @@ func (db *DB) ImportDatasetsFromFile(ctx context.Context, path string) error {
 
 	return nil
 }
+
+func (db *DB) UpdateDataset(ctx context.Context, dataset Dataset) error {
+	gdb := db.gormDB.WithContext(ctx)
+
+	slog.Debug("Updating dataset in DB", "id", dataset.ID, "metadata", dataset.Metadata)
+
+	// Save upserts the full record; no explicit commit is needed outside a transaction
+	return gdb.Save(&dataset).Error
+}
diff --git a/pkg/index/metadata.go b/pkg/index/metadata.go
new file mode 100644
index 0000000..84e130c
--- /dev/null
+++ b/pkg/index/metadata.go
@@ -0,0 +1,42 @@
+package index
+
+import (
+	"fmt"
+	"slices"
+)
+
+// SetMetadataField sets a metadata field in the dataset. If the metadata map does not exist, it will be created.
+// If the metadata field already exists, it will be overwritten.
+func (d *Dataset) SetMetadataField(key string, value interface{}) {
+	if d.Metadata == nil {
+		d.Metadata = make(map[string]interface{})
+	}
+	d.Metadata[key] = value
+	d.cleanMetadata()
+}
+
+// ReplaceMetadata replaces the metadata of the dataset with the given metadata.
+func (d *Dataset) ReplaceMetadata(metadata map[string]interface{}) {
+	d.Metadata = metadata
+	d.cleanMetadata()
+}
+
+// UpdateMetadata updates the metadata of the dataset with the given metadata.
+// If a metadata field already exists, it will be overwritten. If a metadata field does not exist, it will be created.
+// Existing metadata fields that are not present in the given metadata remain unchanged.
+func (d *Dataset) UpdateMetadata(metadata map[string]interface{}) {
+	for k, v := range metadata {
+		d.SetMetadataField(k, v)
+	}
+	d.cleanMetadata()
+}
+
+// cleanMetadata removes keys whose values are nil or one of the sentinel strings "", "-", "null", "nil",
+// so users can delete metadata fields by setting them to such a value.
+func (d *Dataset) cleanMetadata() {
+	for k, v := range d.Metadata {
+		if v == nil || slices.Contains([]string{"", "-", "null", "nil"}, fmt.Sprintf("%v", v)) {
+			delete(d.Metadata, k)
+		}
+	}
+}
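The helper semantics above in a runnable sketch (dataset ID and keys are made up); note the sentinel values that cleanMetadata treats as deletions:

```go
package main

import (
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/index"
)

func main() {
	d := &index.Dataset{ID: "demo"}

	d.UpdateMetadata(map[string]any{"lang": "en", "topic": "vpn"})
	d.UpdateMetadata(map[string]any{"topic": "networking"}) // overwrites topic, keeps lang
	fmt.Println(d.Metadata)                                 // map[lang:en topic:networking]

	d.UpdateMetadata(map[string]any{"lang": "-"}) // "-" is a sentinel: the key is deleted
	fmt.Println(d.Metadata)                       // map[topic:networking]

	d.ReplaceMetadata(map[string]any{"owner": "alice"}) // replaces the whole map
	fmt.Println(d.Metadata)                             // map[owner:alice]
}
```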
diff --git a/pkg/index/models.go b/pkg/index/models.go
index 4e21d00..297cde0 100644
--- a/pkg/index/models.go
+++ b/pkg/index/models.go
@@ -7,9 +7,10 @@ import (
 // Dataset refers to a VectorDB data space.
 // @Description Dataset refers to a VectorDB data space.
 type Dataset struct {
-	ID             string `gorm:"primaryKey" json:"id"`
-	EmbedDimension int    `json:"embed_dim,omitempty"`
-	Files          []File `gorm:"foreignKey:Dataset;references:ID;constraint:OnDelete:CASCADE;"`
+	ID             string         `gorm:"primaryKey" json:"id"`
+	EmbedDimension int            `json:"embed_dim,omitempty"`
+	Files          []File         `gorm:"foreignKey:Dataset;references:ID;constraint:OnDelete:CASCADE;"`
+	Metadata       map[string]any `json:"metadata,omitempty" gorm:"serializer:json"`
 }
 
 type File struct {
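The `gorm:"serializer:json"` tag persists `Metadata` as a JSON-encoded column. A minimal round-trip sketch; the in-memory sqlite driver is an assumption for illustration, not part of this diff:

```go
package main

import (
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/index"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&index.Dataset{}); err != nil {
		panic(err)
	}

	ds := index.Dataset{ID: "demo", Metadata: map[string]any{"topic": "vpn"}}
	if err := db.Create(&ds).Error; err != nil {
		panic(err)
	}

	var got index.Dataset
	if err := db.First(&got, "id = ?", "demo").Error; err != nil {
		panic(err)
	}
	fmt.Println(got.Metadata) // map[topic:vpn], deserialized from the JSON column
}
```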