Skip to content

Commit

Permalink
feat: auto limit string if truncate is set (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
OlivierDehaene authored Oct 17, 2024
1 parent 750898d commit cb1e594
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions core/src/tokenization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ fn prepare_pre_prompt(

#[allow(clippy::too_many_arguments)]
fn tokenize_input(
inputs: EncodingInput,
mut inputs: EncodingInput,
add_special_tokens: bool,
max_input_length: usize,
truncate_params: Option<TruncationParams>,
Expand All @@ -288,9 +288,12 @@ fn tokenize_input(
let input_chars = inputs.count_chars();
let limit = max_input_length * MAX_CHAR_MULTIPLIER;
if input_chars > limit {
return Err(TextEmbeddingsError::Validation(format!(
"`inputs` must have less than {limit} characters. Given: {input_chars}"
)));
if truncate_params.is_none() {
return Err(TextEmbeddingsError::Validation(format!(
"`inputs` must have less than {limit} characters. Given: {input_chars}"
)));
}
inputs.apply_limit(limit);
}

let encoding = match inputs {
Expand Down Expand Up @@ -426,6 +429,25 @@ impl EncodingInput {
EncodingInput::Ids(v) => v.len(),
}
}

/// Shortens the text carried by this input in place so it respects `limit`.
///
/// * `Single` — the string is cut to at most `limit` bytes.
/// * `Dual` — each of the two strings is cut to at most `limit / 2` bytes.
/// * `Ids` — left untouched; there is no text to shorten.
///
/// The cut position is always moved back to the closest UTF-8 character
/// boundary at or before the requested offset, so the result is valid UTF-8
/// and the string is *always* shortened when it is longer than the limit.
///
/// NOTE(review): the caller appears to compute `limit` from a character count
/// while `String::truncate` works on byte offsets; cutting by bytes keeps the
/// result at or below `limit` characters either way — confirm this stricter
/// bound is intended.
fn apply_limit(&mut self, limit: usize) {
    /// Truncates `s` to at most `limit` bytes without splitting a character.
    fn truncate_string(s: &mut String, limit: usize) {
        // Nothing to do when the string already fits.
        if limit >= s.len() {
            return;
        }

        // `String::truncate` panics if the new length is not on a char
        // boundary. Instead of skipping the truncation entirely in that case
        // (which would leave an over-long input untouched), step back to the
        // previous boundary. Offset 0 is always a boundary, so this terminates.
        let mut end = limit;
        while !s.is_char_boundary(end) {
            end -= 1;
        }
        s.truncate(end);
    }

    match self {
        EncodingInput::Single(s) => truncate_string(s, limit),
        EncodingInput::Dual(s1, s2) => {
            // Split the budget evenly between the two halves of the pair.
            truncate_string(s1, limit / 2);
            truncate_string(s2, limit / 2);
        }
        EncodingInput::Ids(_) => {}
    }
}
}

impl From<String> for EncodingInput {
Expand Down

0 comments on commit cb1e594

Please sign in to comment.