This repository has been archived by the owner on Oct 30, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
17ea088
commit fdfc79e
Showing
4 changed files
with
342 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
187 changes: 187 additions & 0 deletions
187
pkg/datastore/textsplitter/markdown_rolling/markdown_rolling.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
package markdown_rolling | ||
|
||
import ( | ||
"fmt" | ||
"strings" | ||
|
||
"github.com/pkoukk/tiktoken-go" | ||
lcgosplitter "github.com/tmc/langchaingo/textsplitter" | ||
) | ||
|
||
// NewMarkdownTextSplitter creates a new Markdown text splitter. | ||
func NewMarkdownTextSplitter(opts ...Option) (*MarkdownTextSplitter, error) { | ||
options := DefaultOptions() | ||
|
||
for _, opt := range opts { | ||
opt(&options) | ||
} | ||
|
||
var tk *tiktoken.Tiktoken | ||
var err error | ||
if options.EncodingName != "" { | ||
tk, err = tiktoken.GetEncoding(options.EncodingName) | ||
} else { | ||
tk, err = tiktoken.EncodingForModel(options.ModelName) | ||
} | ||
if err != nil { | ||
return nil, fmt.Errorf("couldn't get encoding: %w", err) | ||
} | ||
|
||
tokenSplitter := lcgosplitter.TokenSplitter{ | ||
ChunkSize: options.ChunkSize, | ||
ChunkOverlap: options.ChunkOverlap, | ||
ModelName: options.ModelName, | ||
EncodingName: options.EncodingName, | ||
AllowedSpecial: []string{}, | ||
DisallowedSpecial: []string{"all"}, | ||
} | ||
|
||
return &MarkdownTextSplitter{ | ||
options, | ||
tk, | ||
tokenSplitter, | ||
}, nil | ||
} | ||
|
||
// MarkdownTextSplitter markdown header text splitter. | ||
type MarkdownTextSplitter struct { | ||
Options | ||
*tiktoken.Tiktoken | ||
tokenSplitter lcgosplitter.TokenSplitter | ||
} | ||
|
||
type block struct { | ||
headings []string | ||
lines []string | ||
text string | ||
tokenSize int | ||
} | ||
|
||
func (s *MarkdownTextSplitter) getTokenSize(text string) int { | ||
return len(s.Encode(text, []string{}, []string{"all"})) | ||
} | ||
|
||
func (s *MarkdownTextSplitter) finishBlock(blocks []block, currentBlock block, headingStack []string) ([]block, block, error) { | ||
|
||
for _, header := range headingStack { | ||
if header != "" { | ||
currentBlock.headings = append(currentBlock.headings, header) | ||
} | ||
} | ||
|
||
if len(currentBlock.lines) == 0 && s.IgnoreHeadingOnly { | ||
return blocks, block{}, nil | ||
} | ||
|
||
headingStr := strings.TrimSpace(strings.Join(currentBlock.headings, "\n")) | ||
contentStr := strings.TrimSpace(strings.Join(currentBlock.lines, "\n")) | ||
text := headingStr + "\n" + contentStr | ||
|
||
if len(text) == 0 { | ||
return blocks, block{}, nil | ||
} | ||
|
||
textTokenSize := s.getTokenSize(text) | ||
|
||
if textTokenSize <= s.ChunkSize { | ||
// append new block to free up some space | ||
return append(blocks, block{ | ||
text: text, | ||
tokenSize: textTokenSize, | ||
}), block{}, nil | ||
} | ||
|
||
// If the block is larger than the chunk size, split it | ||
headingTokenSize := s.getTokenSize(headingStr) | ||
|
||
// Split into chunks that leave room for the heading | ||
s.tokenSplitter.ChunkSize = s.ChunkSize - headingTokenSize | ||
|
||
splits, err := s.tokenSplitter.SplitText(contentStr) | ||
if err != nil { | ||
return blocks, block{}, err | ||
} | ||
|
||
for _, split := range splits { | ||
text = headingStr + "\n" + split | ||
blocks = append(blocks, block{ | ||
text: text, | ||
tokenSize: s.getTokenSize(text), | ||
}) | ||
} | ||
|
||
return blocks, block{}, nil | ||
|
||
} | ||
|
||
// SplitText splits text into chunks. | ||
func (s *MarkdownTextSplitter) SplitText(text string) ([]string, error) { | ||
|
||
var ( | ||
headingStack []string | ||
chunks []string | ||
currentChunk block | ||
currentHeadingLevel int = 1 | ||
currentBlock block | ||
|
||
blocks []block | ||
err error | ||
) | ||
|
||
// Parse markdown line-by-line and build heading-delimited blocks | ||
for _, line := range strings.Split(text, "\n") { | ||
|
||
// Handle header = start a new block | ||
if strings.HasPrefix(line, "#") { | ||
// Finish the previous Block | ||
blocks, currentBlock, err = s.finishBlock(blocks, currentBlock, headingStack) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Get the header level | ||
headingLevel := strings.Count(strings.Split(line, " ")[0], "#") - 1 | ||
|
||
headingStack = append(headingStack[:headingLevel], line) | ||
|
||
// Clear the header stack for lower level headers | ||
for j := headingLevel + 1; j < len(headingStack); j++ { | ||
headingStack[j] = "" | ||
} | ||
|
||
// Reset header stack indices between this level and the last seen level, backwards | ||
for j := headingLevel - 1; j > currentHeadingLevel; j-- { | ||
headingStack[j] = "" | ||
} | ||
|
||
currentHeadingLevel = headingLevel | ||
continue | ||
|
||
} | ||
|
||
// If the line is not a header, add it to the current block | ||
currentBlock.lines = append(currentBlock.lines, line) | ||
|
||
} | ||
|
||
// Finish the last block | ||
blocks, currentBlock, err = s.finishBlock(blocks, currentBlock, headingStack) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Combine blocks into chunks as close to the target token size as possible | ||
for _, b := range blocks { | ||
if currentChunk.tokenSize+b.tokenSize <= s.ChunkSize { | ||
// Doesn't exceed chunk size, so add to the current chunk | ||
currentChunk.text += "\n" + b.text | ||
currentChunk.tokenSize += b.tokenSize | ||
} else { | ||
// Exceeds chunk size, so start a new chunk | ||
chunks = append(chunks, currentChunk.text) | ||
currentChunk = b | ||
} | ||
} | ||
|
||
return chunks, nil | ||
} |
85 changes: 85 additions & 0 deletions
85
pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package markdown_rolling | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestSplitTextWithBasicMarkdown(t *testing.T) { | ||
splitter := NewMarkdownTextSplitter() | ||
chunks, err := splitter.SplitText("# Heading\n\nThis is a paragraph.") | ||
assert.NoError(t, err) | ||
assert.Equal(t, 1, len(chunks)) | ||
|
||
expected := []string{"# Heading\nThis is a paragraph."} | ||
|
||
assert.Equal(t, expected, chunks) | ||
} | ||
|
||
func TestSplitTextWithOptions(t *testing.T) { | ||
md := ` | ||
# Heading 1 | ||
some p under h1 | ||
## Heading 2 | ||
### Heading 3 | ||
- some | ||
- list | ||
- items | ||
**bold** | ||
# 2nd Heading 1 | ||
#### Heading 4 | ||
some p under h4 | ||
` | ||
|
||
testcases := []struct { | ||
name string | ||
splitter *MarkdownTextSplitter | ||
expected []string | ||
}{ | ||
{ | ||
name: "default", | ||
splitter: NewMarkdownTextSplitter(), | ||
expected: []string{ | ||
"# Heading 1\nsome p under h1", | ||
"# Heading 1\n## Heading 2", | ||
"# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**", | ||
"# 2nd Heading 1", | ||
"# 2nd Heading 1\n#### Heading 4\nsome p under h4", | ||
}, | ||
}, | ||
{ | ||
name: "ignore_heading_only", | ||
splitter: NewMarkdownTextSplitter(WithIgnoreHeadingOnly(true)), | ||
Check failure on line 59 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go GitHub Actions / Full Test Suite
|
||
expected: []string{ | ||
"# Heading 1\nsome p under h1", | ||
"# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**", | ||
"# 2nd Heading 1\n#### Heading 4\nsome p under h4", | ||
}, | ||
}, | ||
{ | ||
name: "split_h1_only", | ||
splitter: NewMarkdownTextSplitter(), | ||
expected: []string{ | ||
"# Heading 1\nsome p under h1\n\n## Heading 2\n### Heading 3\n\n- some\n- list\n- items\n\n**bold**", | ||
"# 2nd Heading 1\n#### Heading 4\n\nsome p under h4", | ||
}, | ||
}, | ||
} | ||
|
||
for _, tc := range testcases { | ||
t.Run(tc.name, func(t *testing.T) { | ||
chunks, err := tc.splitter.SplitText(md) | ||
assert.NoError(t, err) | ||
assert.Equal(t, len(tc.expected), len(chunks)) | ||
|
||
assert.Equal(t, tc.expected, chunks) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package markdown_rolling | ||
|
||
import ( | ||
"github.com/gptscript-ai/knowledge/pkg/datastore/defaults" | ||
lcgosplitter "github.com/tmc/langchaingo/textsplitter" | ||
) | ||
|
||
// Options is a struct that contains options for a text splitter. | ||
type Options struct { | ||
ChunkSize int | ||
ChunkOverlap int | ||
Separators []string | ||
KeepSeparator bool | ||
ModelName string | ||
EncodingName string | ||
SecondSplitter lcgosplitter.TextSplitter | ||
|
||
IgnoreHeadingOnly bool // Ignore chunks that only contain headings | ||
} | ||
|
||
// DefaultOptions returns the default options for all text splitter. | ||
func DefaultOptions() Options { | ||
return Options{ | ||
ChunkSize: defaults.TextSplitterChunkSize, | ||
ChunkOverlap: defaults.TextSplitterChunkOverlap, | ||
|
||
ModelName: defaults.TextSplitterTokenModel, | ||
EncodingName: defaults.TextSplitterTokenEncoding, | ||
|
||
IgnoreHeadingOnly: true, | ||
} | ||
} | ||
|
||
// Option is a function that can be used to set options for a text splitter. | ||
type Option func(*Options) | ||
|
||
// WithChunkSize sets the chunk size for a text splitter. | ||
func WithChunkSize(chunkSize int) Option { | ||
return func(o *Options) { | ||
o.ChunkSize = chunkSize | ||
} | ||
} | ||
|
||
// WithChunkOverlap sets the chunk overlap for a text splitter. | ||
func WithChunkOverlap(chunkOverlap int) Option { | ||
return func(o *Options) { | ||
o.ChunkOverlap = chunkOverlap | ||
} | ||
} | ||
|
||
// WithModelName sets the model name for a text splitter. | ||
func WithModelName(modelName string) Option { | ||
return func(o *Options) { | ||
o.ModelName = modelName | ||
} | ||
} | ||
|
||
// WithEncodingName sets the encoding name for a text splitter. | ||
func WithEncodingName(encodingName string) Option { | ||
return func(o *Options) { | ||
o.EncodingName = encodingName | ||
} | ||
} | ||
|
||
func WithIgnoreHeadingOnly(ignoreHeadingOnly bool) Option { | ||
return func(o *Options) { | ||
o.IgnoreHeadingOnly = ignoreHeadingOnly | ||
} | ||
} |