This repository has been archived by the owner on Oct 30, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add: documentloader and textsplitter config
- Loading branch information
1 parent
6a49f15
commit 0458aef
Showing
12 changed files
with
246 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,11 @@ | ||
package defaults | ||
|
||
const EmbeddingDimension int = 1536 | ||
const TopK int = 5 | ||
const ( | ||
EmbeddingDimension int = 1536 | ||
TopK int = 5 | ||
|
||
TextSplitterTokenModel = "gpt-4" | ||
TextSplitterChunkSize = 1024 | ||
TextSplitterChunkOverlap = 256 | ||
TextSplitterTokenEncoding = "cl100k_base" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package documentloader | ||
|
||
import ( | ||
"context" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestGetDocumentLoaderConfig_ValidLoader(t *testing.T) { | ||
cfg, err := GetDocumentLoaderConfig("pdf") | ||
assert.NoError(t, err) | ||
assert.IsTypef(t, PDFOptions{}, cfg, "cfg is not of type PDFOptions") | ||
} | ||
|
||
func TestGetDocumentLoaderConfig_InvalidLoader(t *testing.T) { | ||
_, err := GetDocumentLoaderConfig("invalid") | ||
assert.Error(t, err) | ||
} | ||
|
||
func TestGetDocumentLoaderFunc_ValidLoaderWithoutConfig(t *testing.T) { | ||
_, err := GetDocumentLoaderFunc("plaintext", nil) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetDocumentLoaderFunc_ValidLoaderWithInvalidConfig(t *testing.T) { | ||
_, err := GetDocumentLoaderFunc("pdf", "invalid") | ||
assert.Error(t, err) | ||
} | ||
|
||
func TestGetDocumentLoaderFunc_ValidLoaderWithValidConfig(t *testing.T) { | ||
_, err := GetDocumentLoaderFunc("pdf", PDFOptions{}) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetDocumentLoaderFunc_InvalidLoader(t *testing.T) { | ||
_, err := GetDocumentLoaderFunc("invalid", nil) | ||
assert.Error(t, err) | ||
} | ||
|
||
func TestGetDocumentLoaderFunc_LoadPlainText(t *testing.T) { | ||
loaderFunc, _ := GetDocumentLoaderFunc("plaintext", nil) | ||
docs, err := loaderFunc(context.Background(), strings.NewReader("test")) | ||
assert.NoError(t, err) | ||
assert.Len(t, docs, 1) | ||
assert.Equal(t, "test", docs[0].Content) | ||
} | ||
|
||
func TestGetDocumentLoaderFunc_LoadPDF(t *testing.T) { | ||
loaderFunc, _ := GetDocumentLoaderFunc("pdf", PDFOptions{}) | ||
_, err := loaderFunc(context.Background(), strings.NewReader("test")) | ||
assert.Error(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package textsplitter | ||
|
||
import ( | ||
"fmt" | ||
"github.com/gptscript-ai/knowledge/pkg/datastore/defaults" | ||
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore" | ||
lcgosplitter "github.com/tmc/langchaingo/textsplitter" | ||
) | ||
|
||
type SplitterFunc func([]vs.Document) ([]vs.Document, error) | ||
|
||
type TextSplitterOpts struct { | ||
ChunkSize int `usage:"Textsplitter Chunk Size" default:"1024" env:"KNOW_TEXTSPLITTER_CHUNK_SIZE" name:"textsplitter-chunk-size"` | ||
ChunkOverlap int `usage:"Textsplitter Chunk Overlap" default:"256" env:"KNOW_TEXTSPLITTER_CHUNK_OVERLAP" name:"textsplitter-chunk-overlap"` | ||
ModelName string `usage:"Textsplitter Model Name" default:"gpt-4" env:"KNOW_TEXTSPLITTER_MODEL_NAME" name:"textsplitter-model-name"` | ||
EncodingName string `usage:"Textsplitter Encoding Name" default:"cl100k_base" env:"KNOW_TEXTSPLITTER_ENCODING_NAME" name:"textsplitter-encoding-name"` | ||
} | ||
|
||
// NewTextSplitterOpts returns the default options for a text splitter. | ||
func NewTextSplitterOpts() TextSplitterOpts { | ||
return TextSplitterOpts{ | ||
ChunkSize: defaults.TextSplitterChunkSize, | ||
ChunkOverlap: defaults.TextSplitterChunkOverlap, | ||
ModelName: defaults.TextSplitterTokenModel, | ||
EncodingName: defaults.TextSplitterTokenEncoding, | ||
} | ||
} | ||
|
||
// NewLcgoTextSplitter returns a new langchain-go text splitter. | ||
func NewLcgoTextSplitter(opts TextSplitterOpts) lcgosplitter.TokenSplitter { | ||
return lcgosplitter.NewTokenSplitter( | ||
lcgosplitter.WithChunkSize(opts.ChunkSize), | ||
lcgosplitter.WithChunkOverlap(opts.ChunkOverlap), | ||
lcgosplitter.WithModelName(opts.ModelName), | ||
lcgosplitter.WithEncodingName(opts.EncodingName), | ||
) | ||
} | ||
|
||
func NewLcgoMarkdownSplitter(opts TextSplitterOpts) *lcgosplitter.MarkdownTextSplitter { | ||
return lcgosplitter.NewMarkdownTextSplitter( | ||
lcgosplitter.WithChunkSize(opts.ChunkSize), | ||
lcgosplitter.WithChunkOverlap(opts.ChunkOverlap), | ||
lcgosplitter.WithModelName(opts.ModelName), | ||
lcgosplitter.WithEncodingName(opts.EncodingName), | ||
lcgosplitter.WithHeadingHierarchy(true), | ||
) | ||
} | ||
|
||
func GetTextSplitterConfig(name string) (any, error) { | ||
// TODO: expose splitter-specific config, not only our top-level options | ||
switch name { | ||
case "text", "markdown": | ||
return TextSplitterOpts{}, nil | ||
default: | ||
return nil, fmt.Errorf("unknown text splitter %q", name) | ||
} | ||
} | ||
|
||
func GetTextSplitterFunc(name string, config any) (SplitterFunc, error) { | ||
switch name { | ||
case "text": | ||
if config == nil { | ||
config = NewTextSplitterOpts() | ||
} | ||
config, ok := config.(TextSplitterOpts) | ||
if !ok { | ||
return nil, fmt.Errorf("invalid text splitter configuration") | ||
} | ||
return FromLangchain(NewLcgoTextSplitter(config)).SplitDocuments, nil | ||
case "markdown": | ||
if config == nil { | ||
config = NewTextSplitterOpts() | ||
} | ||
config, ok := config.(TextSplitterOpts) | ||
if !ok { | ||
return nil, fmt.Errorf("invalid markdown text splitter configuration") | ||
} | ||
return FromLangchain(NewLcgoMarkdownSplitter(config)).SplitDocuments, nil | ||
default: | ||
return nil, fmt.Errorf("unknown text splitter %q", name) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package textsplitter | ||
|
||
import ( | ||
"github.com/stretchr/testify/assert" | ||
"testing" | ||
) | ||
|
||
func TestGetTextSplitterConfigWithValidName(t *testing.T) { | ||
_, err := GetTextSplitterConfig("text") | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetTextSplitterConfigWithInvalidName(t *testing.T) { | ||
_, err := GetTextSplitterConfig("invalid") | ||
assert.Error(t, err) | ||
} | ||
|
||
func TestGetTextSplitterFuncWithValidNameAndNilConfig(t *testing.T) { | ||
_, err := GetTextSplitterFunc("text", nil) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetTextSplitterFuncWithValidNameAndInvalidConfig(t *testing.T) { | ||
_, err := GetTextSplitterFunc("text", "invalid") | ||
assert.Error(t, err) | ||
} | ||
|
||
func TestGetTextSplitterFuncWithValidNameAndValidConfig(t *testing.T) { | ||
_, err := GetTextSplitterFunc("text", NewTextSplitterOpts()) | ||
assert.NoError(t, err) | ||
} | ||
|
||
func TestGetTextSplitterFuncWithInvalidName(t *testing.T) { | ||
_, err := GetTextSplitterFunc("invalid", nil) | ||
assert.Error(t, err) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.