Skip to content
This repository has been archived by the owner on Oct 30, 2024. It is now read-only.

Commit

Permalink
add: rolling markdown splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
iwilltry42 committed Oct 29, 2024
1 parent 17ea088 commit fdfc79e
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pkg/datastore/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ const (
TopK int = 10

TextSplitterTokenModel = "gpt-4"
- TextSplitterChunkSize = 1024
+ TextSplitterChunkSize = 2048
TextSplitterChunkOverlap = 256
TextSplitterTokenEncoding = "cl100k_base"
)
187 changes: 187 additions & 0 deletions pkg/datastore/textsplitter/markdown_rolling/markdown_rolling.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
package markdown_rolling

import (
"fmt"
"strings"

"github.com/pkoukk/tiktoken-go"
lcgosplitter "github.com/tmc/langchaingo/textsplitter"
)

// NewMarkdownTextSplitter builds a MarkdownTextSplitter from DefaultOptions
// with the given opts applied in order.
//
// The tokenizer is looked up by EncodingName when that is non-empty and by
// ModelName otherwise. A failed lookup is returned as a wrapped error.
func NewMarkdownTextSplitter(opts ...Option) (*MarkdownTextSplitter, error) {
	options := DefaultOptions()
	for _, apply := range opts {
		apply(&options)
	}

	var (
		tk  *tiktoken.Tiktoken
		err error
	)
	if options.EncodingName == "" {
		tk, err = tiktoken.EncodingForModel(options.ModelName)
	} else {
		tk, err = tiktoken.GetEncoding(options.EncodingName)
	}
	if err != nil {
		return nil, fmt.Errorf("couldn't get encoding: %w", err)
	}

	// Fallback splitter used for blocks that are too large on their own.
	tokenSplitter := lcgosplitter.TokenSplitter{
		ChunkSize:         options.ChunkSize,
		ChunkOverlap:      options.ChunkOverlap,
		ModelName:         options.ModelName,
		EncodingName:      options.EncodingName,
		AllowedSpecial:    []string{},
		DisallowedSpecial: []string{"all"},
	}

	return &MarkdownTextSplitter{
		Options:       options,
		Tiktoken:      tk,
		tokenSplitter: tokenSplitter,
	}, nil
}

// MarkdownTextSplitter splits markdown text into heading-delimited blocks and
// combines them into chunks bounded by a token budget (see SplitText).
type MarkdownTextSplitter struct {
// Options is the splitter configuration; embedding promotes fields such as
// ChunkSize and IgnoreHeadingOnly onto the splitter.
Options
// Tiktoken is the tokenizer; its promoted Encode method is what
// getTokenSize uses to count tokens.
*tiktoken.Tiktoken
// tokenSplitter splits the content of blocks that exceed ChunkSize
// (see finishBlock).
tokenSplitter lcgosplitter.TokenSplitter
}

// block is one heading-delimited section of the document. While a section is
// being read, headings and lines accumulate; finished blocks carry only text
// and tokenSize. SplitText also uses a block as the accumulator for the chunk
// currently being assembled.
type block struct {
// headings are the non-empty heading lines in effect for this section.
headings []string
// lines are the raw non-heading lines collected for this section.
lines []string
// text is the rendered section: headings followed by content.
text string
// tokenSize is the token count of text as reported by getTokenSize.
tokenSize int
}

// getTokenSize reports how many tokens the embedded tokenizer produces for text.
//
// NOTE(review): the two slice arguments are presumably the allowed and
// disallowed special-token sets, mirroring the tokenSplitter configuration —
// confirm against tiktoken-go.
func (s *MarkdownTextSplitter) getTokenSize(text string) int {
	tokens := s.Encode(text, []string{}, []string{"all"})
	return len(tokens)
}

// finishBlock finalizes currentBlock and appends the result to blocks.
//
// The non-empty entries of headingStack are recorded as the block's headings
// and placed above its content, so every emitted block carries its heading
// context. A block whose content is empty after trimming is dropped when
// IgnoreHeadingOnly is set; a block with neither headings nor content is always
// dropped. If the rendered text exceeds ChunkSize tokens, the content is split
// by the token splitter and each piece is emitted as its own block, prefixed
// with the headings.
//
// It returns the extended blocks slice, an empty block to accumulate the next
// section into, and any error encountered while splitting.
func (s *MarkdownTextSplitter) finishBlock(blocks []block, currentBlock block, headingStack []string) ([]block, block, error) {
	for _, heading := range headingStack {
		if heading != "" {
			currentBlock.headings = append(currentBlock.headings, heading)
		}
	}

	headingStr := strings.TrimSpace(strings.Join(currentBlock.headings, "\n"))
	contentStr := strings.TrimSpace(strings.Join(currentBlock.lines, "\n"))

	// Decide on the trimmed content rather than the raw line count: blank lines
	// between two headings must not make a section count as having content.
	if contentStr == "" && (s.IgnoreHeadingOnly || headingStr == "") {
		return blocks, block{}, nil
	}

	// render joins headings and body without leaving a stray leading or
	// trailing newline when either part is empty.
	render := func(body string) string {
		switch {
		case headingStr == "":
			return body
		case body == "":
			return headingStr
		default:
			return headingStr + "\n" + body
		}
	}

	text := render(contentStr)
	textTokenSize := s.getTokenSize(text)
	if textTokenSize <= s.ChunkSize {
		return append(blocks, block{
			text:      text,
			tokenSize: textTokenSize,
		}), block{}, nil
	}

	// The block is too large: split the content into pieces that leave room
	// for the headings. Work on a copy so the shared splitter configuration is
	// not modified by this call.
	splitter := s.tokenSplitter
	splitter.ChunkSize = s.ChunkSize - s.getTokenSize(headingStr)

	// NOTE(review): langchaingo's TokenSplitter appears to advance by
	// ChunkSize-ChunkOverlap tokens per piece, so a budget at or below the
	// overlap could not make progress — confirm against the library.
	if contentStr != "" && splitter.ChunkSize <= splitter.ChunkOverlap {
		return blocks, block{}, fmt.Errorf(
			"headings leave %d tokens per chunk for content, need more than the overlap of %d",
			splitter.ChunkSize, splitter.ChunkOverlap,
		)
	}

	splits, err := splitter.SplitText(contentStr)
	if err != nil {
		return blocks, block{}, err
	}

	for _, split := range splits {
		piece := render(split)
		blocks = append(blocks, block{
			text:      piece,
			tokenSize: s.getTokenSize(piece),
		})
	}

	return blocks, block{}, nil
}

// SplitText splits markdown text into chunks.
//
// The text is read line by line. Every line starting with "#" is treated as a
// heading and starts a new block (this includes such lines inside fenced code
// blocks, which are not detected). Each block is rendered together with the
// headings currently in effect, and consecutive blocks are then merged into
// chunks for as long as their summed token sizes stay within ChunkSize. The
// newline used to join blocks is not counted, so a chunk may exceed ChunkSize
// by a few tokens.
//
// It returns the chunks in document order, or the first error reported while
// finishing a block.
func (s *MarkdownTextSplitter) SplitText(text string) ([]string, error) {
	var (
		headingStack []string
		currentBlock block
		blocks       []block
		err          error
	)

	// Parse markdown line-by-line and build heading-delimited blocks.
	for _, line := range strings.Split(text, "\n") {
		if !strings.HasPrefix(line, "#") {
			currentBlock.lines = append(currentBlock.lines, line)
			continue
		}

		// A heading closes the block collected so far, which still belongs to
		// the previous heading stack.
		blocks, currentBlock, err = s.finishBlock(blocks, currentBlock, headingStack)
		if err != nil {
			return nil, err
		}

		// Zero-based level: "#" is 0, "##" is 1, and so on.
		marker := strings.SplitN(line, " ", 2)[0]
		level := strings.Count(marker, "#") - 1

		// Pad skipped levels with empty entries so that a jump such as
		// "#" -> "###" neither slices past the end of the stack nor revives
		// stale headings; finishBlock ignores empty entries.
		for len(headingStack) < level {
			headingStack = append(headingStack, "")
		}
		// Truncating to level drops this heading's former siblings and
		// everything nested below them.
		headingStack = append(headingStack[:level], line)
	}

	// Finish the last block.
	blocks, _, err = s.finishBlock(blocks, currentBlock, headingStack)
	if err != nil {
		return nil, err
	}

	// Combine blocks into chunks as close to the target token size as possible.
	var (
		chunks  []string
		current block
	)
	for _, b := range blocks {
		switch {
		case current.text == "":
			// Nothing accumulated yet: start the chunk without a leading newline.
			current = b
		case current.tokenSize+b.tokenSize <= s.ChunkSize:
			current.text += "\n" + b.text
			current.tokenSize += b.tokenSize
		default:
			chunks = append(chunks, current.text)
			current = b
		}
	}
	// Flush the chunk that was still being assembled when the blocks ran out.
	if current.text != "" {
		chunks = append(chunks, current.text)
	}

	return chunks, nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package markdown_rolling

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitTextWithBasicMarkdown(t *testing.T) {
splitter := NewMarkdownTextSplitter()

Check failure on line 10 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

assignment mismatch: 1 variable but NewMarkdownTextSplitter returns 2 values
chunks, err := splitter.SplitText("# Heading\n\nThis is a paragraph.")
assert.NoError(t, err)
assert.Equal(t, 1, len(chunks))

expected := []string{"# Heading\nThis is a paragraph."}

assert.Equal(t, expected, chunks)
}

func TestSplitTextWithOptions(t *testing.T) {
md := `
# Heading 1
some p under h1
## Heading 2
### Heading 3
- some
- list
- items
**bold**
# 2nd Heading 1
#### Heading 4
some p under h4
`

testcases := []struct {
name string
splitter *MarkdownTextSplitter
expected []string
}{
{
name: "default",
splitter: NewMarkdownTextSplitter(),

Check failure on line 48 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

multiple-value NewMarkdownTextSplitter() (value of type (*MarkdownTextSplitter, error)) in single-value context
expected: []string{
"# Heading 1\nsome p under h1",
"# Heading 1\n## Heading 2",
"# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**",
"# 2nd Heading 1",
"# 2nd Heading 1\n#### Heading 4\nsome p under h4",
},
},
{
name: "ignore_heading_only",
splitter: NewMarkdownTextSplitter(WithIgnoreHeadingOnly(true)),

Check failure on line 59 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

multiple-value NewMarkdownTextSplitter(WithIgnoreHeadingOnly(true)) (value of type (*MarkdownTextSplitter, error)) in single-value context
expected: []string{
"# Heading 1\nsome p under h1",
"# Heading 1\n## Heading 2\n### Heading 3\n- some\n- list\n- items\n\n**bold**",
"# 2nd Heading 1\n#### Heading 4\nsome p under h4",
},
},
{
name: "split_h1_only",
splitter: NewMarkdownTextSplitter(),

Check failure on line 68 in pkg/datastore/textsplitter/markdown_rolling/markdown_rolling_test.go

View workflow job for this annotation

GitHub Actions / Full Test Suite

multiple-value NewMarkdownTextSplitter() (value of type (*MarkdownTextSplitter, error)) in single-value context
expected: []string{
"# Heading 1\nsome p under h1\n\n## Heading 2\n### Heading 3\n\n- some\n- list\n- items\n\n**bold**",
"# 2nd Heading 1\n#### Heading 4\n\nsome p under h4",
},
},
}

for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
chunks, err := tc.splitter.SplitText(md)
assert.NoError(t, err)
assert.Equal(t, len(tc.expected), len(chunks))

assert.Equal(t, tc.expected, chunks)
})
}
}
69 changes: 69 additions & 0 deletions pkg/datastore/textsplitter/markdown_rolling/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package markdown_rolling

import (
"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
lcgosplitter "github.com/tmc/langchaingo/textsplitter"
)

// Options holds the configuration of a markdown text splitter. It is embedded
// in MarkdownTextSplitter, so its fields are readable directly on the splitter.
type Options struct {
// ChunkSize is the token budget that block and chunk sizes are compared against.
ChunkSize int
// ChunkOverlap is handed to the token splitter used for oversized blocks.
ChunkOverlap int
// Separators, KeepSeparator and SecondSplitter (below) are not read by any
// code visible in this package. NOTE(review): presumably kept for parity
// with other splitters' options — confirm before relying on them.
Separators []string
KeepSeparator bool
// ModelName selects the tokenizer when EncodingName is empty.
ModelName string
// EncodingName selects the tokenizer by encoding and takes precedence over ModelName.
EncodingName string
SecondSplitter lcgosplitter.TextSplitter

IgnoreHeadingOnly bool // Ignore chunks that only contain headings
}

// DefaultOptions returns the baseline configuration for the splitter: chunk
// size, chunk overlap, token model and token encoding come from the defaults
// package, and heading-only blocks are ignored.
func DefaultOptions() Options {
	var o Options

	o.ChunkSize = defaults.TextSplitterChunkSize
	o.ChunkOverlap = defaults.TextSplitterChunkOverlap

	o.ModelName = defaults.TextSplitterTokenModel
	o.EncodingName = defaults.TextSplitterTokenEncoding

	o.IgnoreHeadingOnly = true

	return o
}

// Option mutates an Options value in place. NewMarkdownTextSplitter applies
// the given options, in order, on top of DefaultOptions.
type Option func(*Options)

// WithChunkSize returns an Option that sets Options.ChunkSize to chunkSize.
func WithChunkSize(chunkSize int) Option {
	return func(opts *Options) { opts.ChunkSize = chunkSize }
}

// WithChunkOverlap returns an Option that sets Options.ChunkOverlap to
// chunkOverlap.
func WithChunkOverlap(chunkOverlap int) Option {
	return func(opts *Options) { opts.ChunkOverlap = chunkOverlap }
}

// WithModelName returns an Option that sets Options.ModelName to modelName.
func WithModelName(modelName string) Option {
	return func(opts *Options) { opts.ModelName = modelName }
}

// WithEncodingName returns an Option that sets Options.EncodingName to
// encodingName.
func WithEncodingName(encodingName string) Option {
	return func(opts *Options) { opts.EncodingName = encodingName }
}

// WithIgnoreHeadingOnly returns an Option that sets Options.IgnoreHeadingOnly,
// which controls whether blocks consisting only of headings are dropped.
func WithIgnoreHeadingOnly(ignoreHeadingOnly bool) Option {
	return func(opts *Options) { opts.IgnoreHeadingOnly = ignoreHeadingOnly }
}

0 comments on commit fdfc79e

Please sign in to comment.