diff --git a/pkg/datastore/defaults/defaults.go b/pkg/datastore/defaults/defaults.go index b8e6e7aa..81face0f 100644 --- a/pkg/datastore/defaults/defaults.go +++ b/pkg/datastore/defaults/defaults.go @@ -1,4 +1,11 @@ package defaults -const EmbeddingDimension int = 1536 -const TopK int = 5 +const ( + EmbeddingDimension int = 1536 + TopK int = 5 + + TextSplitterTokenModel = "gpt-4" + TextSplitterChunkSize = 1024 + TextSplitterChunkOverlap = 256 + TextSplitterTokenEncoding = "cl100k_base" +) diff --git a/pkg/datastore/documentloader/documentloaders_test.go b/pkg/datastore/documentloader/documentloaders_test.go new file mode 100644 index 00000000..eb96de02 --- /dev/null +++ b/pkg/datastore/documentloader/documentloaders_test.go @@ -0,0 +1,54 @@ +package documentloader + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetDocumentLoaderConfig_ValidLoader(t *testing.T) { + cfg, err := GetDocumentLoaderConfig("pdf") + assert.NoError(t, err) + assert.IsTypef(t, PDFOptions{}, cfg, "cfg is not of type PDFOptions") +} + +func TestGetDocumentLoaderConfig_InvalidLoader(t *testing.T) { + _, err := GetDocumentLoaderConfig("invalid") + assert.Error(t, err) +} + +func TestGetDocumentLoaderFunc_ValidLoaderWithoutConfig(t *testing.T) { + _, err := GetDocumentLoaderFunc("plaintext", nil) + assert.NoError(t, err) +} + +func TestGetDocumentLoaderFunc_ValidLoaderWithInvalidConfig(t *testing.T) { + _, err := GetDocumentLoaderFunc("pdf", "invalid") + assert.Error(t, err) +} + +func TestGetDocumentLoaderFunc_ValidLoaderWithValidConfig(t *testing.T) { + _, err := GetDocumentLoaderFunc("pdf", PDFOptions{}) + assert.NoError(t, err) +} + +func TestGetDocumentLoaderFunc_InvalidLoader(t *testing.T) { + _, err := GetDocumentLoaderFunc("invalid", nil) + assert.Error(t, err) +} + +func TestGetDocumentLoaderFunc_LoadPlainText(t *testing.T) { + loaderFunc, _ := GetDocumentLoaderFunc("plaintext", nil) + docs, err := loaderFunc(context.Background(), strings.NewReader("test")) + assert.NoError(t, err) + assert.Len(t, docs, 1) + assert.Equal(t, "test", docs[0].Content) +} + +func TestGetDocumentLoaderFunc_LoadPDF(t *testing.T) { + loaderFunc, _ := GetDocumentLoaderFunc("pdf", PDFOptions{}) + _, err := loaderFunc(context.Background(), strings.NewReader("test")) + assert.Error(t, err) +} diff --git a/pkg/datastore/ingest.go b/pkg/datastore/ingest.go index da9782b9..018b6528 100644 --- a/pkg/datastore/ingest.go +++ b/pkg/datastore/ingest.go @@ -26,12 +26,7 @@ import ( "strings" ) -const ( - defaultTokenModel = "gpt-4" - defaultChunkSize = 1024 - defaultChunkOverlap = 256 - defaultTokenEncoding = "cl100k_base" -) +const () var firstclassFileExtensions = map[string]struct{}{ ".pdf": {}, @@ -51,7 +46,7 @@ type IngestOpts struct { FileMetadata *index.FileMetadata IsDuplicateFuncName string IsDuplicateFunc IsDuplicateFunc - TextSplitterOpts *TextSplitterOpts + TextSplitterOpts *textsplitter.TextSplitterOpts } // Ingest loads a document from a reader and adds it to the dataset. @@ -270,12 +265,12 @@ func DefaultDocLoaderFunc(filetype string) func(ctx context.Context, reader io.R } } -func DefaultTextSplitter(filetype string, textSplitterOpts *TextSplitterOpts) types.TextSplitter { +func DefaultTextSplitter(filetype string, textSplitterOpts *textsplitter.TextSplitterOpts) types.TextSplitter { if textSplitterOpts == nil { - textSplitterOpts = z.Pointer(NewTextSplitterOpts()) + textSplitterOpts = z.Pointer(textsplitter.NewTextSplitterOpts()) } - genericTextSplitter := textsplitter.FromLangchain(NewLcgoTextSplitter(*textSplitterOpts)) - markdownTextSplitter := textsplitter.FromLangchain(NewLcgoMarkdownSplitter(*textSplitterOpts)) + genericTextSplitter := textsplitter.FromLangchain(textsplitter.NewLcgoTextSplitter(*textSplitterOpts)) + markdownTextSplitter := textsplitter.FromLangchain(textsplitter.NewLcgoMarkdownSplitter(*textSplitterOpts)) switch filetype { case ".md", "text/markdown": diff --git a/pkg/datastore/textsplitter.go b/pkg/datastore/textsplitter.go index 075a1c19..15e76f05 100644 --- a/pkg/datastore/textsplitter.go +++ b/pkg/datastore/textsplitter.go @@ -2,47 +2,9 @@ package datastore import ( vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" - lcgosplitter "github.com/tmc/langchaingo/textsplitter" "strings" ) -type TextSplitterOpts struct { - ChunkSize int `usage:"Textsplitter Chunk Size" default:"1024" env:"KNOW_TEXTSPLITTER_CHUNK_SIZE" name:"textsplitter-chunk-size"` - ChunkOverlap int `usage:"Textsplitter Chunk Overlap" default:"256" env:"KNOW_TEXTSPLITTER_CHUNK_OVERLAP" name:"textsplitter-chunk-overlap"` - ModelName string `usage:"Textsplitter Model Name" default:"gpt-4" env:"KNOW_TEXTSPLITTER_MODEL_NAME" name:"textsplitter-model-name"` - EncodingName string `usage:"Textsplitter Encoding Name" default:"cl100k_base" env:"KNOW_TEXTSPLITTER_ENCODING_NAME" name:"textsplitter-encoding-name"` -} - -// NewTextSplitterOpts returns the default options for a text splitter. -func NewTextSplitterOpts() TextSplitterOpts { - return TextSplitterOpts{ - ChunkSize: defaultChunkSize, - ChunkOverlap: defaultChunkOverlap, - ModelName: defaultTokenModel, - EncodingName: defaultTokenEncoding, - } -} - -// NewLcgoTextSplitter returns a new langchain-go text splitter. -func NewLcgoTextSplitter(opts TextSplitterOpts) lcgosplitter.TokenSplitter { - return lcgosplitter.NewTokenSplitter( - lcgosplitter.WithChunkSize(opts.ChunkSize), - lcgosplitter.WithChunkOverlap(opts.ChunkOverlap), - lcgosplitter.WithModelName(opts.ModelName), - lcgosplitter.WithEncodingName(opts.EncodingName), - ) -} - -func NewLcgoMarkdownSplitter(opts TextSplitterOpts) *lcgosplitter.MarkdownTextSplitter { - return lcgosplitter.NewMarkdownTextSplitter( - lcgosplitter.WithChunkSize(opts.ChunkSize), - lcgosplitter.WithChunkOverlap(opts.ChunkOverlap), - lcgosplitter.WithModelName(opts.ModelName), - lcgosplitter.WithEncodingName(opts.EncodingName), - lcgosplitter.WithHeadingHierarchy(true), - ) -} - // FilterMarkdownDocsNoContent filters out Markdown documents with no content or only headings // // TODO: this may be moved into the MarkdownTextSplitter as well diff --git a/pkg/datastore/textsplitter/textsplitter.go b/pkg/datastore/textsplitter/textsplitter.go new file mode 100644 index 00000000..468f8764 --- /dev/null +++ b/pkg/datastore/textsplitter/textsplitter.go @@ -0,0 +1,82 @@ +package textsplitter + +import ( + "fmt" + "github.com/gptscript-ai/knowledge/pkg/datastore/defaults" + vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" + lcgosplitter "github.com/tmc/langchaingo/textsplitter" +) + +type SplitterFunc func([]vs.Document) ([]vs.Document, error) + +type TextSplitterOpts struct { + ChunkSize int `usage:"Textsplitter Chunk Size" default:"1024" env:"KNOW_TEXTSPLITTER_CHUNK_SIZE" name:"textsplitter-chunk-size"` + ChunkOverlap int `usage:"Textsplitter Chunk Overlap" default:"256" env:"KNOW_TEXTSPLITTER_CHUNK_OVERLAP" name:"textsplitter-chunk-overlap"` + ModelName string `usage:"Textsplitter Model Name" default:"gpt-4" env:"KNOW_TEXTSPLITTER_MODEL_NAME" name:"textsplitter-model-name"` + EncodingName string `usage:"Textsplitter Encoding Name" default:"cl100k_base" env:"KNOW_TEXTSPLITTER_ENCODING_NAME" name:"textsplitter-encoding-name"` +} + +// NewTextSplitterOpts returns the default options for a text splitter. +func NewTextSplitterOpts() TextSplitterOpts { + return TextSplitterOpts{ + ChunkSize: defaults.TextSplitterChunkSize, + ChunkOverlap: defaults.TextSplitterChunkOverlap, + ModelName: defaults.TextSplitterTokenModel, + EncodingName: defaults.TextSplitterTokenEncoding, + } +} + +// NewLcgoTextSplitter returns a new langchain-go text splitter. +func NewLcgoTextSplitter(opts TextSplitterOpts) lcgosplitter.TokenSplitter { + return lcgosplitter.NewTokenSplitter( + lcgosplitter.WithChunkSize(opts.ChunkSize), + lcgosplitter.WithChunkOverlap(opts.ChunkOverlap), + lcgosplitter.WithModelName(opts.ModelName), + lcgosplitter.WithEncodingName(opts.EncodingName), + ) +} + +func NewLcgoMarkdownSplitter(opts TextSplitterOpts) *lcgosplitter.MarkdownTextSplitter { + return lcgosplitter.NewMarkdownTextSplitter( + lcgosplitter.WithChunkSize(opts.ChunkSize), + lcgosplitter.WithChunkOverlap(opts.ChunkOverlap), + lcgosplitter.WithModelName(opts.ModelName), + lcgosplitter.WithEncodingName(opts.EncodingName), + lcgosplitter.WithHeadingHierarchy(true), + ) +} + +func GetTextSplitterConfig(name string) (any, error) { + // TODO: expose splitter-specific config, not only our top-level options + switch name { + case "text", "markdown": + return TextSplitterOpts{}, nil + default: + return nil, fmt.Errorf("unknown text splitter %q", name) + } +} + +func GetTextSplitterFunc(name string, config any) (SplitterFunc, error) { + switch name { + case "text": + if config == nil { + config = NewTextSplitterOpts() + } + config, ok := config.(TextSplitterOpts) + if !ok { + return nil, fmt.Errorf("invalid text splitter configuration") + } + return FromLangchain(NewLcgoTextSplitter(config)).SplitDocuments, nil + case "markdown": + if config == nil { + config = NewTextSplitterOpts() + } + config, ok := config.(TextSplitterOpts) + if !ok { + return nil, fmt.Errorf("invalid markdown text splitter configuration") + } + return FromLangchain(NewLcgoMarkdownSplitter(config)).SplitDocuments, nil + default: + return nil, fmt.Errorf("unknown text splitter %q", name) + } +} diff --git a/pkg/datastore/textsplitter/textsplitter_test.go b/pkg/datastore/textsplitter/textsplitter_test.go new file mode 100644 index 00000000..c1764142 --- /dev/null +++ b/pkg/datastore/textsplitter/textsplitter_test.go @@ -0,0 +1,36 @@ +package textsplitter + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestGetTextSplitterConfigWithValidName(t *testing.T) { + _, err := GetTextSplitterConfig("text") + assert.NoError(t, err) +} + +func TestGetTextSplitterConfigWithInvalidName(t *testing.T) { + _, err := GetTextSplitterConfig("invalid") + assert.Error(t, err) +} + +func TestGetTextSplitterFuncWithValidNameAndNilConfig(t *testing.T) { + _, err := GetTextSplitterFunc("text", nil) + assert.NoError(t, err) +} + +func TestGetTextSplitterFuncWithValidNameAndInvalidConfig(t *testing.T) { + _, err := GetTextSplitterFunc("text", "invalid") + assert.Error(t, err) +} + +func TestGetTextSplitterFuncWithValidNameAndValidConfig(t *testing.T) { + _, err := GetTextSplitterFunc("text", NewTextSplitterOpts()) + assert.NoError(t, err) +} + +func TestGetTextSplitterFuncWithInvalidName(t *testing.T) { + _, err := GetTextSplitterFunc("invalid", nil) + assert.Error(t, err) +} diff --git a/pkg/datastore/types/types.go b/pkg/datastore/types/types.go index ddbecdc0..1b5bd218 100644 --- a/pkg/datastore/types/types.go +++ b/pkg/datastore/types/types.go @@ -3,7 +3,6 @@ package types import ( "context" vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" - "io" ) type DocumentTransformerFunc func(context.Context, []vs.Document) ([]vs.Document, error) @@ -12,10 +11,6 @@ type DocumentTransformer interface { Transform(context.Context, []vs.Document) ([]vs.Document, error) } -type DocumentLoaderFunc func(context.Context, io.Reader) ([]vs.Document, error) - -type TextSplitterFunc func([]vs.Document) ([]vs.Document, error) - type DocumentLoader interface { Load(ctx context.Context) ([]vs.Document, error) LoadAndSplit(ctx context.Context, splitter TextSplitter) ([]vs.Document, error) diff --git a/pkg/flows/config/config.go b/pkg/flows/config/config.go index fb010f62..68ed4030 100644 --- a/pkg/flows/config/config.go +++ b/pkg/flows/config/config.go @@ -3,6 +3,7 @@ package config import ( "encoding/json" "github.com/gptscript-ai/knowledge/pkg/datastore/documentloader" + "github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter" "github.com/gptscript-ai/knowledge/pkg/flows" "os" "sigs.k8s.io/yaml" @@ -79,6 +80,34 @@ func (i *IngestionFlowConfig) AsIngestionFlow() (*flows.IngestionFlow, error) { return nil, err } } + loaderFunc, err := documentloader.GetDocumentLoaderFunc(name, cfg) + if err != nil { + return nil, err + } + flow.Load = loaderFunc + } + + if i.TextSplitter.Name != "" { + name := strings.ToLower(strings.Trim(i.TextSplitter.Name, " ")) + cfg, err := textsplitter.GetTextSplitterConfig(name) + if err != nil { + return nil, err + } + if len(i.TextSplitter.Options) > 0 { + jsondata, err := json.Marshal(i.TextSplitter.Options) + if err != nil { + return nil, err + } + err = json.Unmarshal(jsondata, &cfg) + if err != nil { + return nil, err + } + } + splitterFunc, err := textsplitter.GetTextSplitterFunc(name, cfg) + if err != nil { + return nil, err + } + flow.Split = splitterFunc } return flow, nil diff --git a/pkg/flows/config/config_test.go b/pkg/flows/config/config_test.go index 6aedff53..d63ae2d6 100644 --- a/pkg/flows/config/config_test.go +++ b/pkg/flows/config/config_test.go @@ -2,13 +2,14 @@ package config import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" ) func TestLoadConfigFromValidJSONFile(t *testing.T) { cfg, err := FromFile("testdata/valid.json") assert.NoError(t, err) - assert.NotNil(t, cfg) + require.NotNil(t, cfg) assert.NotEmpty(t, cfg.Flows) assert.Equal(t, 2, len(cfg.Flows)) assert.Equal(t, 1, len(cfg.Flows["flow1"].Ingestion)) @@ -19,7 +20,7 @@ func TestLoadConfigFromValidJSONFile(t *testing.T) { func TestLoadConfigFromValidYAMLFile(t *testing.T) { cfg, err := FromFile("testdata/valid.yaml") assert.NoError(t, err) - assert.NotNil(t, cfg) + require.NotNil(t, cfg) assert.NotEmpty(t, cfg.Flows) } diff --git a/pkg/flows/config/testdata/valid.json b/pkg/flows/config/testdata/valid.json index f0f4f8bb..b7f6360f 100644 --- a/pkg/flows/config/testdata/valid.json +++ b/pkg/flows/config/testdata/valid.json @@ -4,8 +4,12 @@ "ingestion": [ { "filetypes": [".txt", ".md"], - "documentLoader": "textLoader", - "textSplitter": "simpleSplitter", + "documentLoader": { + "name": "textLoader" + }, + "textSplitter": { + "name": "simpleSplitter" + }, "transformers": ["transformer1", "transformer2"] } ], @@ -15,8 +19,12 @@ "ingestion": [ { "filetypes": [".json"], - "documentLoader": "jsonLoader", - "textSplitter": "jsonSplitter", + "documentLoader": { + "name": "jsonLoader" + }, + "textSplitter": { + "name": "jsonSplitter" + }, "transformers": ["transformer3"] } ], diff --git a/pkg/flows/config/testdata/valid.yaml b/pkg/flows/config/testdata/valid.yaml index 477aeb5a..f587f5ad 100644 --- a/pkg/flows/config/testdata/valid.yaml +++ b/pkg/flows/config/testdata/valid.yaml @@ -2,14 +2,20 @@ flows: flow1: ingestion: - filetypes: [".txt", ".md"] - documentLoader: "textLoader" - textSplitter: "simpleSplitter" - transformers: ["transformer1", "transformer2"] + documentLoader: + name: "textLoader" + textSplitter: + name: "simpleSplitter" + transformers: + - "transformer1" + - "transformer2" retrieval: {} flow2: ingestion: - filetypes: [".json"] - documentLoader: "jsonLoader" - textSplitter: "jsonSplitter" + documentLoader: + name: "jsonLoader" + textSplitter: + name: "jsonSplitter" transformers: ["transformer3"] retrieval: {} \ No newline at end of file diff --git a/pkg/flows/flows.go b/pkg/flows/flows.go index 143496ad..cf8409c8 100644 --- a/pkg/flows/flows.go +++ b/pkg/flows/flows.go @@ -2,13 +2,15 @@ package flows import ( "context" + "github.com/gptscript-ai/knowledge/pkg/datastore/documentloader" + "github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter" dstypes "github.com/gptscript-ai/knowledge/pkg/datastore/types" vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" ) type IngestionFlow struct { - Load dstypes.DocumentLoaderFunc - Split dstypes.TextSplitterFunc + Load documentloader.LoaderFunc + Split textsplitter.SplitterFunc Transformations []dstypes.DocumentTransformer }