Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
add: dataset metadata + routing retriever (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 authored Jul 3, 2024
1 parent 26373f2 commit f5559ed
Show file tree
Hide file tree
Showing 16 changed files with 318 additions and 8 deletions.
14 changes: 14 additions & 0 deletions examples/routing_retriever.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Example flow configuration for the "routing" retriever, which asks an LLM
# to pick the best-matching dataset for a query based on dataset metadata.
# NOTE(review): nesting reconstructed from a flattened source — confirm that
# topK is a sibling of model under retriever options.
flows:
  foo:
    default: true
    retrieval:
      retriever:
        name: routing
        options:
          model:
            openai:
              apiKey: "${OPENAI_API_KEY}"
              model: gpt-4o
              apiType: OPEN_AI
              apiBase: https://api.openai.com/v1
          topK: 6
1 change: 1 addition & 0 deletions pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ type Client interface {
Retrieve(ctx context.Context, datasetID string, query string, opts datastore.RetrieveOpts) ([]vectorstore.Document, error)
ExportDatasets(ctx context.Context, path string, datasets ...string) error
ImportDatasets(ctx context.Context, path string, datasets ...string) error
UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error)
}
5 changes: 5 additions & 0 deletions pkg/client/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,8 @@ func (c *DefaultClient) ImportDatasets(ctx context.Context, path string, dataset
// TODO: implement
panic("not implemented")
}

// UpdateDataset is not yet implemented on the HTTP-backed default client;
// it panics like the sibling ImportDatasets stub above.
// NOTE(review): consider returning an error instead of panicking once the
// server-side endpoint exists.
func (c *DefaultClient) UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error) {
	// TODO: implement
	panic("not implemented")
}
4 changes: 4 additions & 0 deletions pkg/client/standalone.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,7 @@ func (c *StandaloneClient) ExportDatasets(ctx context.Context, path string, data
// ImportDatasets delegates to the embedded datastore, importing the given
// datasets from the archive file at path.
func (c *StandaloneClient) ImportDatasets(ctx context.Context, path string, datasets ...string) error {
	return c.Datastore.ImportDatasetsFromFile(ctx, path, datasets...)
}

// UpdateDataset delegates to the embedded datastore's UpdateDataset,
// returning the persisted dataset.
func (c *StandaloneClient) UpdateDataset(ctx context.Context, dataset index.Dataset, opts *datastore.UpdateDatasetOpts) (*index.Dataset, error) {
	return c.Datastore.UpdateDataset(ctx, dataset, opts)
}
75 changes: 75 additions & 0 deletions pkg/cmd/edit_dataset.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package cmd

import (
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore"
"github.com/gptscript-ai/knowledge/pkg/index"
"github.com/spf13/cobra"
)

// ClientEditDataset implements the "edit-dataset" subcommand, which edits
// the metadata of an existing dataset. The three metadata flags are
// mutually exclusive (enforced in Customize).
type ClientEditDataset struct {
	Client
	ResetMetadata   bool              `usage:"reset metadata to default (empty)"`
	UpdateMetadata  map[string]string `usage:"update metadata key-value pairs (existing metadata will be updated/preserved)"`
	ReplaceMetadata map[string]string `usage:"replace metadata with key-value pairs (existing metadata will be removed)"`
}

// Customize wires up the cobra command for "edit-dataset": usage and help
// text, exactly one positional argument (the dataset ID), and mutual
// exclusivity of the three metadata flags.
func (s *ClientEditDataset) Customize(cmd *cobra.Command) {
	cmd.Args = cobra.ExactArgs(1)
	cmd.MarkFlagsMutuallyExclusive("reset-metadata", "update-metadata", "replace-metadata")
	cmd.Short = "Edit an existing dataset"
	cmd.Use = "edit-dataset <dataset-id>"
}

// Run edits the metadata of the dataset named by the single positional
// argument. Depending on the flags, metadata is merged (--update-metadata),
// replaced (--replace-metadata) or reset to empty (--reset-metadata), and
// the updated dataset is printed as JSON (without its file list).
func (s *ClientEditDataset) Run(cmd *cobra.Command, args []string) error {
	c, err := s.getClient()
	if err != nil {
		return err
	}

	datasetID := args[0]

	// Fetch the dataset first so a bad ID fails with a clear error.
	dataset, err := c.GetDataset(cmd.Context(), datasetID)
	if err != nil {
		return fmt.Errorf("failed to get dataset: %w", err)
	}
	if dataset == nil {
		// Report once via the returned error (the original also printed the
		// same message, duplicating the report).
		return fmt.Errorf("dataset not found: %s", datasetID)
	}

	updatedDataset := index.Dataset{
		ID: dataset.ID,
	}

	// Update Metadata - since flags are mutually exclusive, this should be
	// either an empty map, or one of the update/replace maps.
	metadata := map[string]any{}

	for k, v := range s.UpdateMetadata {
		metadata[k] = v
	}

	for k, v := range s.ReplaceMetadata {
		metadata[k] = v
	}

	updatedDataset.Metadata = metadata

	// Replace (rather than merge) when resetting or explicitly replacing.
	// (ReplaceMedata is the existing, misspelled field name — kept as-is.)
	dataset, err = c.UpdateDataset(cmd.Context(), updatedDataset, &datastore.UpdateDatasetOpts{ReplaceMedata: s.ResetMetadata || len(s.ReplaceMetadata) > 0})
	if err != nil {
		return fmt.Errorf("failed to update dataset: %w", err)
	}

	dataset.Files = nil // Don't print files

	jsonOutput, err := json.Marshal(dataset)
	if err != nil {
		return fmt.Errorf("failed to marshal dataset: %w", err)
	}

	// Printf avoids the stray space fmt.Println inserted between its two
	// arguments after the newline.
	fmt.Printf("Updated dataset:\n%s\n", jsonOutput)
	return nil
}
1 change: 1 addition & 0 deletions pkg/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func New() *cobra.Command {
new(ClientAskDir),
new(ClientExportDatasets),
new(ClientImportDatasets),
new(ClientEditDataset),
new(Version),
)
}
Expand Down
45 changes: 45 additions & 0 deletions pkg/datastore/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import (
"gorm.io/gorm"
)

// UpdateDatasetOpts controls how Datastore.UpdateDataset applies metadata.
type UpdateDatasetOpts struct {
	// ReplaceMedata, when true, replaces the stored metadata wholesale
	// instead of merging in the provided keys.
	// NOTE(review): the name is missing a 't' ("Medata"); renaming it to
	// ReplaceMetadata would be clearer but breaks existing callers.
	ReplaceMedata bool
}

func (s *Datastore) NewDataset(ctx context.Context, dataset index.Dataset) error {
// Set defaults
if dataset.EmbedDimension <= 0 {
Expand Down Expand Up @@ -75,3 +79,44 @@ func (s *Datastore) ListDatasets(ctx context.Context) ([]index.Dataset, error) {

return datasets, nil
}

// UpdateDataset applies the metadata from updatedDataset to the stored
// dataset with the same ID and persists the result via the index.
//
// Only metadata may change: depending on opts.ReplaceMedata the stored
// metadata is replaced wholesale or merged key-by-key. A non-zero
// EmbedDimension or non-nil Files on updatedDataset is rejected.
// Returns the persisted dataset, or nil together with an error.
func (s *Datastore) UpdateDataset(ctx context.Context, updatedDataset index.Dataset, opts *UpdateDatasetOpts) (*index.Dataset, error) {
	if opts == nil {
		opts = &UpdateDatasetOpts{}
	}

	if updatedDataset.ID == "" {
		return nil, fmt.Errorf("dataset ID is required")
	}

	// Reject unsupported field updates up front, before fetching or mutating
	// anything. (The original validated these only after merging metadata.)
	if updatedDataset.EmbedDimension > 0 {
		return nil, fmt.Errorf("embedding dimension cannot be updated")
	}
	if updatedDataset.Files != nil {
		return nil, fmt.Errorf("files cannot be updated")
	}

	origDS, err := s.GetDataset(ctx, updatedDataset.ID)
	if err != nil {
		return nil, err
	}
	if origDS == nil {
		return nil, fmt.Errorf("dataset not found: %s", updatedDataset.ID)
	}

	// Update Metadata: replace or merge depending on the options.
	if opts.ReplaceMedata {
		origDS.ReplaceMetadata(updatedDataset.Metadata)
	} else {
		origDS.UpdateMetadata(updatedDataset.Metadata)
	}

	slog.Debug("Updating dataset", "id", updatedDataset.ID, "metadata", updatedDataset.Metadata)

	return origDS, s.Index.UpdateDataset(ctx, *origDS)
}
6 changes: 5 additions & 1 deletion pkg/datastore/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,9 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetID string, query string
}
retrievalFlow.FillDefaults(topK)

return retrievalFlow.Run(ctx, s.Vectorstore, query, datasetID)
return retrievalFlow.Run(ctx, s, query, datasetID)
}

// SimilaritySearch delegates to the underlying vector store, returning the
// numDocuments most similar documents to query within the given dataset.
func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string) ([]vectorstore.Document, error) {
	return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID)
}
7 changes: 5 additions & 2 deletions pkg/datastore/retrievers/retrievers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package retrievers
import (
"context"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"log/slog"

"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
)

// Retriever fetches the documents relevant to a query from a store.
// datasetID names the dataset to search; implementations may ignore it
// (the routing retriever, for example, selects a dataset itself).
type Retriever interface {
	Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error)
}

func GetRetriever(name string) (Retriever, error) {
Expand All @@ -19,6 +20,8 @@ func GetRetriever(name string) (Retriever, error) {
return &BasicRetriever{TopK: defaults.TopK}, nil
case "subquery":
return &SubqueryRetriever{Limit: 3, TopK: 3}, nil
case "routing":
return &RoutingRetriever{TopK: defaults.TopK}, nil
default:
return nil, fmt.Errorf("unknown retriever %q", name)
}
Expand All @@ -32,7 +35,7 @@ type BasicRetriever struct {
TopK int
}

func (r *BasicRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
func (r *BasicRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
if r.TopK <= 0 {
slog.Debug("[BasicRetriever] TopK not set, using default", "default", defaults.TopK)
r.TopK = defaults.TopK
Expand Down
89 changes: 89 additions & 0 deletions pkg/datastore/retrievers/routing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package retrievers

import (
"context"
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"log/slog"
)

// RoutingRetriever asks an LLM to pick the best-matching dataset for a
// query (based on dataset metadata) and then runs a similarity search
// against that dataset.
type RoutingRetriever struct {
	Model             llm.LLMConfig // LLM used to route the query
	AvailableDatasets []string      // candidate dataset IDs; if empty, all datasets in the store are used
	TopK              int           // number of documents to retrieve; <=0 falls back to defaults.TopK
}

// routingPromptTpl is the routing prompt. It asks the model to reply with
// bare JSON ({"result": "<dataset-id>"}) so the answer can be unmarshalled
// directly into routingResp.
var routingPromptTpl = `The following query will be used for a vector similarity search.
Please route it to the appropriate dataset. Choose the one that fits best to the query based on the metadata.
Query: "{{.query}}"
Available datasets in a JSON map, where the key is the dataset ID and the value is a map of metadata fields:
{{ .datasets }}
Reply only in the following JSON format, without any styling or markdown syntax:
{"result": "<dataset-id>"}`

// routingResp mirrors the JSON shape requested by routingPromptTpl.
type routingResp struct {
	Result string `json:"result"`
}

// Retrieve routes the query to the most fitting dataset — chosen by an LLM
// from the candidates' metadata — and runs a top-K similarity search there.
// The datasetID argument is ignored; the retriever picks the dataset itself.
func (r *RoutingRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
	log := slog.With("component", "RoutingRetriever")

	log.Debug("Ignoring input datasetID in routing retriever, as it chooses one by itself", "query", query, "inputDataset", datasetID)

	if r.TopK <= 0 {
		log.Debug("TopK not set, using default", "default", defaults.TopK)
		r.TopK = defaults.TopK
	}

	// Default to every dataset in the store when no candidates were configured.
	if len(r.AvailableDatasets) == 0 {
		allDatasets, err := store.ListDatasets(ctx)
		if err != nil {
			return nil, err
		}
		for _, ds := range allDatasets {
			r.AvailableDatasets = append(r.AvailableDatasets, ds.ID)
		}
	}
	log.Debug("Available datasets", "datasets", r.AvailableDatasets)

	// Build the datasetID -> metadata map presented to the model.
	datasets := map[string]map[string]any{}
	for _, dsID := range r.AvailableDatasets {
		dataset, err := store.GetDataset(ctx, dsID)
		if err != nil {
			return nil, err
		}
		if dataset == nil {
			return nil, fmt.Errorf("dataset not found: %q", dsID)
		}
		datasets[dataset.ID] = dataset.Metadata
	}

	datasetsJSON, err := json.Marshal(datasets)
	if err != nil {
		return nil, err
	}

	m, err := llm.NewFromConfig(r.Model)
	if err != nil {
		return nil, err
	}

	// Pass the caller's ctx — the original used context.Background(), which
	// dropped cancellation and deadlines for the LLM call.
	result, err := m.Prompt(ctx, routingPromptTpl, map[string]any{"query": query, "datasets": string(datasetsJSON)})
	if err != nil {
		return nil, err
	}
	log.Debug("Routing result", "result", result)

	var resp routingResp
	if err := json.Unmarshal([]byte(result), &resp); err != nil {
		return nil, err
	}

	log.Debug("Routing query to dataset", "query", query, "dataset", resp.Result)

	return store.SimilaritySearch(ctx, query, r.TopK, resp.Result)
}
3 changes: 2 additions & 1 deletion pkg/datastore/retrievers/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"github.com/gptscript-ai/knowledge/pkg/llm"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"log/slog"
Expand All @@ -30,7 +31,7 @@ type subqueryResp struct {
Results []string `json:"results"`
}

func (s SubqueryRetriever) Retrieve(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
func (s SubqueryRetriever) Retrieve(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
m, err := llm.NewFromConfig(s.Model)
if err != nil {
return nil, err
Expand Down
13 changes: 13 additions & 0 deletions pkg/datastore/store/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package store

import (
"context"
"github.com/gptscript-ai/knowledge/pkg/index"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
)

// Store is the minimal read/search surface retrievers and retrieval flows
// need from a datastore: list datasets, fetch one by ID, and run a vector
// similarity search against a collection.
type Store interface {
	ListDatasets(ctx context.Context) ([]index.Dataset, error)
	GetDataset(ctx context.Context, datasetID string) (*index.Dataset, error)
	SimilaritySearch(ctx context.Context, query string, numDocuments int, collection string) ([]vs.Document, error)
}
3 changes: 2 additions & 1 deletion pkg/flows/flows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package flows
import (
"context"
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/store"
"io"
"log/slog"
"slices"
Expand Down Expand Up @@ -113,7 +114,7 @@ func (f *RetrievalFlow) FillDefaults(topK int) {
}
}

func (f *RetrievalFlow) Run(ctx context.Context, store vs.VectorStore, query string, datasetID string) ([]vs.Document, error) {
func (f *RetrievalFlow) Run(ctx context.Context, store store.Store, query string, datasetID string) ([]vs.Document, error) {
var err error
originalQuery := query

Expand Down
13 changes: 13 additions & 0 deletions pkg/index/datasets.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,16 @@ func (db *DB) ImportDatasetsFromFile(ctx context.Context, path string) error {

return nil
}

// UpdateDataset persists the given dataset record via GORM's Save, which
// updates all fields of the row matching the primary key.
func (db *DB) UpdateDataset(ctx context.Context, dataset Dataset) error {
	gdb := db.gormDB.WithContext(ctx)

	slog.Debug("Updating dataset in DB", "id", dataset.ID, "metadata", dataset.Metadata)

	// Pass a pointer: GORM's Save expects an addressable value so it can
	// write generated fields back.
	if err := gdb.Save(&dataset).Error; err != nil {
		return err
	}

	// The original called gdb.Commit() here, but Save runs outside any
	// explicit transaction, so that Commit was a no-op whose error was
	// silently ignored. It has been removed.
	return nil
}
Loading

0 comments on commit f5559ed

Please sign in to comment.