This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

add: documentloader and textsplitter config
iwilltry42 committed May 27, 2024
1 parent 6a49f15 · commit 38556c6
Showing 16 changed files with 257 additions and 78 deletions.
3 changes: 2 additions & 1 deletion pkg/client/client.go
@@ -3,6 +3,7 @@ package client
import (
"context"
"github.com/gptscript-ai/knowledge/pkg/datastore"
"github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter"
"github.com/gptscript-ai/knowledge/pkg/index"
"github.com/gptscript-ai/knowledge/pkg/server/types"
"github.com/gptscript-ai/knowledge/pkg/vectorstore"
@@ -12,7 +13,7 @@ type IngestPathsOpts struct {
IgnoreExtensions []string
Concurrency int
Recursive bool
TextSplitterOpts *datastore.TextSplitterOpts
TextSplitterOpts *textsplitter.TextSplitterOpts
}

type RetrieveOpts struct {
4 changes: 2 additions & 2 deletions pkg/cmd/ingest.go
@@ -3,7 +3,7 @@ package cmd
import (
"fmt"
"github.com/gptscript-ai/knowledge/pkg/client"
"github.com/gptscript-ai/knowledge/pkg/datastore"
"github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter"
"github.com/spf13/cobra"
"strings"
)
@@ -12,7 +12,7 @@ type ClientIngest struct {
Client
Dataset string `usage:"Target Dataset ID" short:"d" default:"default" env:"KNOW_TARGET_DATASET"`
ClientIngestOpts
datastore.TextSplitterOpts
textsplitter.TextSplitterOpts
}

type ClientIngestOpts struct {
11 changes: 9 additions & 2 deletions pkg/datastore/defaults/defaults.go
@@ -1,4 +1,11 @@
package defaults

const EmbeddingDimension int = 1536
const TopK int = 5
const (
EmbeddingDimension int = 1536
TopK int = 5

TextSplitterTokenModel = "gpt-4"
TextSplitterChunkSize = 1024
TextSplitterChunkOverlap = 256
TextSplitterTokenEncoding = "cl100k_base"
)
54 changes: 54 additions & 0 deletions pkg/datastore/documentloader/documentloaders_test.go
@@ -0,0 +1,54 @@
package documentloader

import (
"context"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetDocumentLoaderConfig_ValidLoader(t *testing.T) {
cfg, err := GetDocumentLoaderConfig("pdf")
assert.NoError(t, err)
assert.IsTypef(t, PDFOptions{}, cfg, "cfg is not of type PDFOptions")
}

func TestGetDocumentLoaderConfig_InvalidLoader(t *testing.T) {
_, err := GetDocumentLoaderConfig("invalid")
assert.Error(t, err)
}

func TestGetDocumentLoaderFunc_ValidLoaderWithoutConfig(t *testing.T) {
_, err := GetDocumentLoaderFunc("plaintext", nil)
assert.NoError(t, err)
}

func TestGetDocumentLoaderFunc_ValidLoaderWithInvalidConfig(t *testing.T) {
_, err := GetDocumentLoaderFunc("pdf", "invalid")
assert.Error(t, err)
}

func TestGetDocumentLoaderFunc_ValidLoaderWithValidConfig(t *testing.T) {
_, err := GetDocumentLoaderFunc("pdf", PDFOptions{})
assert.NoError(t, err)
}

func TestGetDocumentLoaderFunc_InvalidLoader(t *testing.T) {
_, err := GetDocumentLoaderFunc("invalid", nil)
assert.Error(t, err)
}

func TestGetDocumentLoaderFunc_LoadPlainText(t *testing.T) {
loaderFunc, _ := GetDocumentLoaderFunc("plaintext", nil)
docs, err := loaderFunc(context.Background(), strings.NewReader("test"))
assert.NoError(t, err)
assert.Len(t, docs, 1)
assert.Equal(t, "test", docs[0].Content)
}

func TestGetDocumentLoaderFunc_LoadPDF(t *testing.T) {
loaderFunc, _ := GetDocumentLoaderFunc("pdf", PDFOptions{})
_, err := loaderFunc(context.Background(), strings.NewReader("test"))
assert.Error(t, err)
}
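For reference, a minimal sketch of how the new loader registry might be called outside the tests. GetDocumentLoaderFunc and PDFOptions come from this package (see the tests above); the main wrapper and the example.pdf path are purely illustrative and not part of this commit.

package main

import (
	"context"
	"fmt"
	"os"

	"github.com/gptscript-ai/knowledge/pkg/datastore/documentloader"
)

func main() {
	// Resolve the loader function by name, as the tests above do.
	loaderFunc, err := documentloader.GetDocumentLoaderFunc("pdf", documentloader.PDFOptions{})
	if err != nil {
		panic(err)
	}

	f, err := os.Open("example.pdf") // illustrative input file
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// The loader reads from any io.Reader and returns the parsed documents.
	docs, err := loaderFunc(context.Background(), f)
	if err != nil {
		panic(err)
	}
	fmt.Printf("loaded %d document(s)\n", len(docs))
}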
17 changes: 6 additions & 11 deletions pkg/datastore/ingest.go
@@ -26,12 +26,7 @@ import (
"strings"
)

const (
defaultTokenModel = "gpt-4"
defaultChunkSize = 1024
defaultChunkOverlap = 256
defaultTokenEncoding = "cl100k_base"
)
const ()

var firstclassFileExtensions = map[string]struct{}{
".pdf": {},
@@ -51,7 +46,7 @@ type IngestOpts struct {
FileMetadata *index.FileMetadata
IsDuplicateFuncName string
IsDuplicateFunc IsDuplicateFunc
TextSplitterOpts *TextSplitterOpts
TextSplitterOpts *textsplitter.TextSplitterOpts
}

// Ingest loads a document from a reader and adds it to the dataset.
@@ -270,12 +265,12 @@ func DefaultDocLoaderFunc(filetype string) func(ctx context.Context, reader io.R
}
}

func DefaultTextSplitter(filetype string, textSplitterOpts *TextSplitterOpts) types.TextSplitter {
func DefaultTextSplitter(filetype string, textSplitterOpts *textsplitter.TextSplitterOpts) types.TextSplitter {
if textSplitterOpts == nil {
textSplitterOpts = z.Pointer(NewTextSplitterOpts())
textSplitterOpts = z.Pointer(textsplitter.NewTextSplitterOpts())
}
genericTextSplitter := textsplitter.FromLangchain(NewLcgoTextSplitter(*textSplitterOpts))
markdownTextSplitter := textsplitter.FromLangchain(NewLcgoMarkdownSplitter(*textSplitterOpts))
genericTextSplitter := textsplitter.FromLangchain(textsplitter.NewLcgoTextSplitter(*textSplitterOpts))
markdownTextSplitter := textsplitter.FromLangchain(textsplitter.NewLcgoMarkdownSplitter(*textSplitterOpts))

switch filetype {
case ".md", "text/markdown":
3 changes: 2 additions & 1 deletion pkg/datastore/ingest_test.go
@@ -2,6 +2,7 @@ package datastore

import (
"context"
"github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter"
"github.com/gptscript-ai/knowledge/pkg/datastore/transformers"
"github.com/gptscript-ai/knowledge/pkg/flows"
"github.com/stretchr/testify/require"
@@ -13,7 +14,7 @@ import (

func TestExtractPDF(t *testing.T) {
ctx := context.Background()
textSplitterOpts := NewTextSplitterOpts()
textSplitterOpts := textsplitter.NewTextSplitterOpts()
err := filepath.WalkDir("testdata/pdf", func(path string, d fs.DirEntry, err error) error {
if err != nil {
t.Fatalf("filepath.WalkDir() error = %v", err)
38 changes: 0 additions & 38 deletions pkg/datastore/textsplitter.go
@@ -2,47 +2,9 @@ package datastore

import (
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
lcgosplitter "github.com/tmc/langchaingo/textsplitter"
"strings"
)

type TextSplitterOpts struct {
ChunkSize int `usage:"Textsplitter Chunk Size" default:"1024" env:"KNOW_TEXTSPLITTER_CHUNK_SIZE" name:"textsplitter-chunk-size"`
ChunkOverlap int `usage:"Textsplitter Chunk Overlap" default:"256" env:"KNOW_TEXTSPLITTER_CHUNK_OVERLAP" name:"textsplitter-chunk-overlap"`
ModelName string `usage:"Textsplitter Model Name" default:"gpt-4" env:"KNOW_TEXTSPLITTER_MODEL_NAME" name:"textsplitter-model-name"`
EncodingName string `usage:"Textsplitter Encoding Name" default:"cl100k_base" env:"KNOW_TEXTSPLITTER_ENCODING_NAME" name:"textsplitter-encoding-name"`
}

// NewTextSplitterOpts returns the default options for a text splitter.
func NewTextSplitterOpts() TextSplitterOpts {
return TextSplitterOpts{
ChunkSize: defaultChunkSize,
ChunkOverlap: defaultChunkOverlap,
ModelName: defaultTokenModel,
EncodingName: defaultTokenEncoding,
}
}

// NewLcgoTextSplitter returns a new langchain-go text splitter.
func NewLcgoTextSplitter(opts TextSplitterOpts) lcgosplitter.TokenSplitter {
return lcgosplitter.NewTokenSplitter(
lcgosplitter.WithChunkSize(opts.ChunkSize),
lcgosplitter.WithChunkOverlap(opts.ChunkOverlap),
lcgosplitter.WithModelName(opts.ModelName),
lcgosplitter.WithEncodingName(opts.EncodingName),
)
}

func NewLcgoMarkdownSplitter(opts TextSplitterOpts) *lcgosplitter.MarkdownTextSplitter {
return lcgosplitter.NewMarkdownTextSplitter(
lcgosplitter.WithChunkSize(opts.ChunkSize),
lcgosplitter.WithChunkOverlap(opts.ChunkOverlap),
lcgosplitter.WithModelName(opts.ModelName),
lcgosplitter.WithEncodingName(opts.EncodingName),
lcgosplitter.WithHeadingHierarchy(true),
)
}

// FilterMarkdownDocsNoContent filters out Markdown documents with no content or only headings
//
// TODO: this may be moved into the MarkdownTextSplitter as well
82 changes: 82 additions & 0 deletions pkg/datastore/textsplitter/textsplitter.go
@@ -0,0 +1,82 @@
package textsplitter

import (
"fmt"
"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
lcgosplitter "github.com/tmc/langchaingo/textsplitter"
)

type SplitterFunc func([]vs.Document) ([]vs.Document, error)

type TextSplitterOpts struct {
ChunkSize int `usage:"Textsplitter Chunk Size" default:"1024" env:"KNOW_TEXTSPLITTER_CHUNK_SIZE" name:"textsplitter-chunk-size"`
ChunkOverlap int `usage:"Textsplitter Chunk Overlap" default:"256" env:"KNOW_TEXTSPLITTER_CHUNK_OVERLAP" name:"textsplitter-chunk-overlap"`
ModelName string `usage:"Textsplitter Model Name" default:"gpt-4" env:"KNOW_TEXTSPLITTER_MODEL_NAME" name:"textsplitter-model-name"`
EncodingName string `usage:"Textsplitter Encoding Name" default:"cl100k_base" env:"KNOW_TEXTSPLITTER_ENCODING_NAME" name:"textsplitter-encoding-name"`
}

// NewTextSplitterOpts returns the default options for a text splitter.
func NewTextSplitterOpts() TextSplitterOpts {
return TextSplitterOpts{
ChunkSize: defaults.TextSplitterChunkSize,
ChunkOverlap: defaults.TextSplitterChunkOverlap,
ModelName: defaults.TextSplitterTokenModel,
EncodingName: defaults.TextSplitterTokenEncoding,
}
}

// NewLcgoTextSplitter returns a new langchain-go text splitter.
func NewLcgoTextSplitter(opts TextSplitterOpts) lcgosplitter.TokenSplitter {
return lcgosplitter.NewTokenSplitter(
lcgosplitter.WithChunkSize(opts.ChunkSize),
lcgosplitter.WithChunkOverlap(opts.ChunkOverlap),
lcgosplitter.WithModelName(opts.ModelName),
lcgosplitter.WithEncodingName(opts.EncodingName),
)
}

func NewLcgoMarkdownSplitter(opts TextSplitterOpts) *lcgosplitter.MarkdownTextSplitter {
return lcgosplitter.NewMarkdownTextSplitter(
lcgosplitter.WithChunkSize(opts.ChunkSize),
lcgosplitter.WithChunkOverlap(opts.ChunkOverlap),
lcgosplitter.WithModelName(opts.ModelName),
lcgosplitter.WithEncodingName(opts.EncodingName),
lcgosplitter.WithHeadingHierarchy(true),
)
}

func GetTextSplitterConfig(name string) (any, error) {
// TODO: expose splitter-specific config, not only our top-level options
switch name {
case "text", "markdown":
return TextSplitterOpts{}, nil
default:
return nil, fmt.Errorf("unknown text splitter %q", name)
}
}

func GetTextSplitterFunc(name string, config any) (SplitterFunc, error) {
switch name {
case "text":
if config == nil {
config = NewTextSplitterOpts()
}
config, ok := config.(TextSplitterOpts)
if !ok {
return nil, fmt.Errorf("invalid text splitter configuration")
}
return FromLangchain(NewLcgoTextSplitter(config)).SplitDocuments, nil
case "markdown":
if config == nil {
config = NewTextSplitterOpts()
}
config, ok := config.(TextSplitterOpts)
if !ok {
return nil, fmt.Errorf("invalid markdown text splitter configuration")
}
return FromLangchain(NewLcgoMarkdownSplitter(config)).SplitDocuments, nil
default:
return nil, fmt.Errorf("unknown text splitter %q", name)
}
}
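A minimal usage sketch for these helpers. GetTextSplitterFunc, NewTextSplitterOpts, and the SplitterFunc signature come from the file above; the sample document and the assumption that vs.Document can be built with only its Content field set are illustrative.

package main

import (
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter"
	vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
)

func main() {
	// Start from the package defaults and narrow the chunk size.
	opts := textsplitter.NewTextSplitterOpts()
	opts.ChunkSize = 512

	splitFunc, err := textsplitter.GetTextSplitterFunc("markdown", opts)
	if err != nil {
		panic(err)
	}

	docs := []vs.Document{{Content: "# Heading\n\nSome body text to be chunked."}}
	chunks, err := splitFunc(docs)
	if err != nil {
		panic(err)
	}
	fmt.Printf("split %d document(s) into %d chunk(s)\n", len(docs), len(chunks))
}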
36 changes: 36 additions & 0 deletions pkg/datastore/textsplitter/textsplitter_test.go
@@ -0,0 +1,36 @@
package textsplitter

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestGetTextSplitterConfigWithValidName(t *testing.T) {
_, err := GetTextSplitterConfig("text")
assert.NoError(t, err)
}

func TestGetTextSplitterConfigWithInvalidName(t *testing.T) {
_, err := GetTextSplitterConfig("invalid")
assert.Error(t, err)
}

func TestGetTextSplitterFuncWithValidNameAndNilConfig(t *testing.T) {
_, err := GetTextSplitterFunc("text", nil)
assert.NoError(t, err)
}

func TestGetTextSplitterFuncWithValidNameAndInvalidConfig(t *testing.T) {
_, err := GetTextSplitterFunc("text", "invalid")
assert.Error(t, err)
}

func TestGetTextSplitterFuncWithValidNameAndValidConfig(t *testing.T) {
_, err := GetTextSplitterFunc("text", NewTextSplitterOpts())
assert.NoError(t, err)
}

func TestGetTextSplitterFuncWithInvalidName(t *testing.T) {
_, err := GetTextSplitterFunc("invalid", nil)
assert.Error(t, err)
}
5 changes: 0 additions & 5 deletions pkg/datastore/types/types.go
@@ -3,7 +3,6 @@ package types
import (
"context"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore"
"io"
)

type DocumentTransformerFunc func(context.Context, []vs.Document) ([]vs.Document, error)
@@ -12,10 +11,6 @@ type DocumentTransformer interface {
Transform(context.Context, []vs.Document) ([]vs.Document, error)
}

type DocumentLoaderFunc func(context.Context, io.Reader) ([]vs.Document, error)

type TextSplitterFunc func([]vs.Document) ([]vs.Document, error)

type DocumentLoader interface {
Load(ctx context.Context) ([]vs.Document, error)
LoadAndSplit(ctx context.Context, splitter TextSplitter) ([]vs.Document, error)
29 changes: 29 additions & 0 deletions pkg/flows/config/config.go
@@ -3,6 +3,7 @@ package config
import (
"encoding/json"
"github.com/gptscript-ai/knowledge/pkg/datastore/documentloader"
"github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter"
"github.com/gptscript-ai/knowledge/pkg/flows"
"os"
"sigs.k8s.io/yaml"
@@ -79,6 +80,34 @@ func (i *IngestionFlowConfig) AsIngestionFlow() (*flows.IngestionFlow, error) {
return nil, err
}
}
loaderFunc, err := documentloader.GetDocumentLoaderFunc(name, cfg)
if err != nil {
return nil, err
}
flow.Load = loaderFunc
}

if i.TextSplitter.Name != "" {
name := strings.ToLower(strings.Trim(i.TextSplitter.Name, " "))
cfg, err := textsplitter.GetTextSplitterConfig(name)
if err != nil {
return nil, err
}
if len(i.TextSplitter.Options) > 0 {
jsondata, err := json.Marshal(i.TextSplitter.Options)
if err != nil {
return nil, err
}
err = json.Unmarshal(jsondata, &cfg)
if err != nil {
return nil, err
}
}
splitterFunc, err := textsplitter.GetTextSplitterFunc(name, cfg)
if err != nil {
return nil, err
}
flow.Split = splitterFunc
}

return flow, nil
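The block above overlays the free-form Options map onto the typed splitter config via a JSON round-trip (marshal the map, unmarshal into the struct). A standalone sketch of that pattern, assuming the option keys follow the Go field names of TextSplitterOpts, since the struct declares no json tags:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/gptscript-ai/knowledge/pkg/datastore/textsplitter"
)

func main() {
	// Defaults first, then overlay whatever the flow config supplied.
	cfg := textsplitter.NewTextSplitterOpts()

	options := map[string]any{ // stand-in for IngestionFlowConfig's TextSplitter.Options
		"chunkSize":    512,
		"chunkOverlap": 64,
	}

	jsondata, err := json.Marshal(options)
	if err != nil {
		panic(err)
	}
	if err := json.Unmarshal(jsondata, &cfg); err != nil {
		panic(err)
	}

	// ChunkSize and ChunkOverlap are overridden; ModelName and EncodingName keep their defaults.
	fmt.Printf("%+v\n", cfg)
}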
Loading
