Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TokenTextSplitter enhancement #1558

Closed
wants to merge 3 commits
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,49 @@
/**
* @author Raphael Yu
* @author Christian Tzolov
* @author Ricken Bazolo
*/
public class TokenTextSplitter extends TextSplitter {

private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();

private final Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE);

private final static int DEFAULT_CHUNK_SIZE = 800;

private final static int MIN_CHUNK_SIZE_CHARS = 350;

private final static int MIN_CHUNK_LENGTH_TO_EMBED = 5;

private final static int MAX_NUM_CHUNKS = 10000;

private final static boolean KEEP_SEPARATOR = true;

// The target size of each text chunk in tokens
private int defaultChunkSize = 800;
private final int chunkSize;

// The minimum size of each text chunk in characters
private int minChunkSizeChars = 350;
private final int minChunkSizeChars;

// Discard chunks shorter than this
private int minChunkLengthToEmbed = 5;
private final int minChunkLengthToEmbed;

// The maximum number of chunks to generate from a text
private int maxNumChunks = 10000;
private final int maxNumChunks;

private boolean keepSeparator = true;
private final boolean keepSeparator;

/**
 * Creates a {@link TokenTextSplitter} with the default configuration:
 * 800-token chunks, 350-character minimum chunk size, 5-character minimum
 * embeddable length, at most 10000 chunks, and separators kept.
 */
public TokenTextSplitter() {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
}

/**
 * Creates a {@link TokenTextSplitter} with default sizing thresholds and an
 * explicit separator policy.
 *
 * @param keepSeparator whether sentence separators are kept in the emitted chunks
 */
public TokenTextSplitter(boolean keepSeparator) {
	// Delegate to the canonical constructor. The stray field assignment that
	// preceded the this(...) call has been removed: an explicit constructor
	// invocation must be the first statement, and the field is final.
	this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
}

public TokenTextSplitter(int defaultChunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
boolean keepSeparator) {
this.defaultChunkSize = defaultChunkSize;
this.chunkSize = chunkSize;
this.minChunkSizeChars = minChunkSizeChars;
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
this.maxNumChunks = maxNumChunks;
Expand All @@ -68,7 +80,7 @@ public TokenTextSplitter(int defaultChunkSize, int minChunkSizeChars, int minChu

/**
 * Splits the given text into chunks of roughly {@code chunkSize} tokens.
 *
 * @param text the text to split
 * @return the chunk strings produced by {@code doSplit}
 */
@Override
protected List<String> splitText(String text) {
	// Single return using the configured chunk size; the duplicate return
	// statement left over from the rename of defaultChunkSize is removed.
	return doSplit(text, this.chunkSize);
}

protected List<String> doSplit(String text, int chunkSize) {
Expand Down Expand Up @@ -133,4 +145,55 @@ private String decodeTokens(List<Integer> tokens) {
return this.encoding.decode(tokensIntArray);
}

/**
 * Returns a new {@link Builder} for fluently configuring a {@link TokenTextSplitter}.
 */
public static Builder builder() {
return new Builder();
}

/**
 * Fluent builder for {@link TokenTextSplitter}. Every setting is
 * pre-populated with the splitter's defaults, so callers may override only
 * the values they care about; a plain {@code builder().build()} is
 * equivalent to {@code new TokenTextSplitter()}.
 */
public static class Builder {

	// Seed each field with the class default so an unconfigured builder
	// produces a working splitter rather than one with zero-valued
	// thresholds (chunkSize == 0 would make the splitter unusable).
	private int chunkSize = DEFAULT_CHUNK_SIZE;

	private int minChunkSizeChars = MIN_CHUNK_SIZE_CHARS;

	private int minChunkLengthToEmbed = MIN_CHUNK_LENGTH_TO_EMBED;

	private int maxNumChunks = MAX_NUM_CHUNKS;

	private boolean keepSeparator = KEEP_SEPARATOR;

	// Instances are obtained via TokenTextSplitter.builder() only.
	private Builder() {
	}

	/**
	 * @param chunkSize the target size of each text chunk in tokens
	 */
	public Builder withChunkSize(int chunkSize) {
		this.chunkSize = chunkSize;
		return this;
	}

	/**
	 * @param minChunkSizeChars the minimum size of each text chunk in characters
	 */
	public Builder withMinChunkSizeChars(int minChunkSizeChars) {
		this.minChunkSizeChars = minChunkSizeChars;
		return this;
	}

	/**
	 * @param minChunkLengthToEmbed chunks shorter than this are discarded
	 */
	public Builder withMinChunkLengthToEmbed(int minChunkLengthToEmbed) {
		this.minChunkLengthToEmbed = minChunkLengthToEmbed;
		return this;
	}

	/**
	 * @param maxNumChunks the maximum number of chunks to generate from a text
	 */
	public Builder withMaxNumChunks(int maxNumChunks) {
		this.maxNumChunks = maxNumChunks;
		return this;
	}

	/**
	 * @param keepSeparator whether sentence separators are kept in the chunks
	 */
	public Builder withKeepSeparator(boolean keepSeparator) {
		this.keepSeparator = keepSeparator;
		return this;
	}

	/**
	 * @return a {@link TokenTextSplitter} configured with this builder's values
	 */
	public TokenTextSplitter build() {
		return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
				this.maxNumChunks, this.keepSeparator);
	}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package org.springframework.ai.transformer.splitter;

import org.junit.jupiter.api.Test;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.Document;

import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Ricken Bazolo
*/
public class TokenTextSplitterTest {

// Verifies that the no-arg constructor uses defaults large enough (800-token
// chunks) that each short input document passes through as a single chunk,
// and that per-document metadata is preserved on the corresponding chunk.
@Test
public void testTokenTextSplitterBuilderWithDefaultValues() {

var contentFormatter1 = DefaultContentFormatter.defaultConfig();
var contentFormatter2 = DefaultContentFormatter.defaultConfig();

// Sanity check: defaultConfig() returns a fresh instance each call.
assertThat(contentFormatter1).isNotSameAs(contentFormatter2);

var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough.",
Map.of("key1", "value1", "key2", "value2"));
doc1.setContentFormatter(contentFormatter1);

var doc2 = new Document("The most oppressive thing about the labyrinth is that you are constantly "
+ "being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.",
Map.of("key2", "value22", "key3", "value3"));
doc2.setContentFormatter(contentFormatter2);

var tokenTextSplitter = new TokenTextSplitter();

var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));

// With default thresholds neither document is split, so one chunk per doc.
assertThat(chunks.size()).isEqualTo(2);

// Doc 1
assertThat(chunks.get(0).getContent())
.isEqualTo("In the end, writing arises when man realizes that memory is not enough.");
// Doc 2
assertThat(chunks.get(1).getContent()).isEqualTo(
"The most oppressive thing about the labyrinth is that you are constantly being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.");

// Each chunk carries only its source document's metadata keys.
assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
assertThat(chunks.get(1).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
}

// Exercises the builder with every setting overridden: a 10-token chunk size
// forces both documents to be split, and metadata must be copied to every
// chunk derived from the same document.
@Test
public void testTokenTextSplitterBuilderWithAllFields() {

var contentFormatter1 = DefaultContentFormatter.defaultConfig();
var contentFormatter2 = DefaultContentFormatter.defaultConfig();

// Sanity check: defaultConfig() returns a fresh instance each call.
assertThat(contentFormatter1).isNotSameAs(contentFormatter2);

var doc1 = new Document("In the end, writing arises when man realizes that memory is not enough.",
Map.of("key1", "value1", "key2", "value2"));
doc1.setContentFormatter(contentFormatter1);

var doc2 = new Document("The most oppressive thing about the labyrinth is that you are constantly "
+ "being forced to choose. It isn’t the lack of an exit, but the abundance of exits that is so disorienting.",
Map.of("key2", "value22", "key3", "value3"));
doc2.setContentFormatter(contentFormatter2);

// Small thresholds so the short test sentences still get split.
var tokenTextSplitter = TokenTextSplitter.builder()
.withChunkSize(10)
.withMinChunkSizeChars(5)
.withMinChunkLengthToEmbed(3)
.withMaxNumChunks(50)
.withKeepSeparator(true)
.build();

var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));

// Doc 1 splits into 2 chunks, doc 2 into 4.
assertThat(chunks.size()).isEqualTo(6);

// Doc 1
assertThat(chunks.get(0).getContent()).isEqualTo("In the end, writing arises when man realizes that");
assertThat(chunks.get(1).getContent()).isEqualTo("memory is not enough.");

// Doc 2
assertThat(chunks.get(2).getContent()).isEqualTo("The most oppressive thing about the labyrinth is that you");
assertThat(chunks.get(3).getContent()).isEqualTo("are constantly being forced to choose.");
assertThat(chunks.get(4).getContent()).isEqualTo("It isn’t the lack of an exit, but");
assertThat(chunks.get(5).getContent()).isEqualTo("the abundance of exits that is so disorienting");

// Verify that the same, merged metadata is copied to all chunks.
assertThat(chunks.get(0).getMetadata()).isEqualTo(chunks.get(1).getMetadata());
assertThat(chunks.get(2).getMetadata()).isEqualTo(chunks.get(3).getMetadata());

assertThat(chunks.get(0).getMetadata()).containsKeys("key1", "key2").doesNotContainKeys("key3");
assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
}

}