parser: account for number of header columns in dialect detection #1359

Merged
merged 1 commit into from
Jan 30, 2024
86 changes: 81 additions & 5 deletions crates/parser/src/format/character_separated/detection.rs
@@ -71,9 +71,14 @@ pub struct Dialect {
/// `line_separator` and `escape` are required. If `config_quote` or `config_delimiter` are Some,
/// then the search space will be limited to only those values, and they will be returned in the
/// detected dialect.
/// The `header_count` parameter must be the number of header columns that was passed in the config,
/// and must be > 0 if it is `Some`. More typically, it will be `None`, and will be inferred from the
/// first row. This is needed because we consider it an error for any row to have more values than there
/// are headers.
/// A dialect is always detected and returned, even though it may not be a very good fit. This reflects
/// the reality that even an incorrect dialect can usually at least result in a single column per line.
pub fn detect_dialect(
header_count: Option<usize>,
line_separator: LineEnding,
escape: Escape,
peeked: Bytes,
@@ -86,7 +91,15 @@ pub fn detect_dialect(
.iter()
.copied()
.map(|(quote, delimiter)| {
let score = compute_score(peeked.clone(), quote, delimiter, line_separator, escape);
let score = compute_score(
header_count,
peeked.clone(),
quote,
delimiter,
line_separator,
escape,
);
tracing::debug!(?quote, ?delimiter, ?score, "computed score for dialect");
Dialect {
quote,
delimiter,
@@ -106,9 +119,14 @@ pub fn detect_dialect(
.pop()
.expect("must have at least one candidate dialect");
// Log the top few candidates, as it's helpful to see the runner up when detection doesn't go as we expected
let runners_up = &dialects[0..(dialects.len().min(3))];
let runners_up = &dialects[dialects.len().saturating_sub(3)..];
Review comment from the author (project member):

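The previous expression here was incorrect: it took the first three candidates, which are the lowest-scoring ones, since the winner is popped from the end of the list. This now correctly prints the top 3 runners-up, with the last value being the 2nd place finisher.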
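
To make the difference concrete, here is a minimal sketch. It assumes the candidates are sorted in ascending order of score and that the winner has already been popped from the end. The helper names `first_three` and `last_three` are hypothetical and not part of this PR.

```rust
/// Old expression: takes up to three elements from the *front* of the slice,
/// which are the lowest-scoring candidates when the slice is sorted ascending.
fn first_three<T>(sorted_ascending: &[T]) -> &[T] {
    &sorted_ascending[0..sorted_ascending.len().min(3)]
}

/// New expression: takes up to three elements from the *back* of the slice.
/// `saturating_sub` keeps the start index at 0 when fewer than three
/// candidates remain, so the slice never panics.
fn last_three<T>(sorted_ascending: &[T]) -> &[T] {
    &sorted_ascending[sorted_ascending.len().saturating_sub(3)..]
}
```

With scores `[10, 20, 30, 40]` left after popping the winner, the old expression yields `[10, 20, 30]` and the new one yields `[20, 30, 40]`, so the second-place candidate comes last.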


tracing::debug!(?winning_dialect, ?runners_up, "detected CSV dialect");
tracing::debug!(
?winning_dialect,
?runners_up,
total_checked_dialects = permutations.len(),
"detected CSV dialect"
);
winning_dialect
}

@@ -159,6 +177,7 @@ fn get_dialect_candidates(
}

fn compute_score(
mut header_count: Option<usize>,
peeked: Bytes,
quote: Quote,
delimiter: Delimiter,
@@ -188,7 +207,18 @@ fn compute_score(
if !more {
break;
}
let score = record.len().saturating_sub(1);
let mut score = record.len().saturating_sub(1);
if let Some(n_headers) = header_count {
// It is an error for a CSV row to have more values than there are headers, so count this row
// as an error in that case. Note that it's permissible for a row to have _fewer_ values.
// This behavior matches that of `CsvOutput`, which returns an error if a row has too many values.
if record.len() > n_headers {
score = 0;
}
} else {
// Consider the first row to be headers, and note the number so we can properly score subsequent rows.
header_count = Some(record.len());
}
if score > 0 {
row_count += 1;
}
@@ -263,6 +293,52 @@ mod test {
}
}

#[test]
fn account_for_header_count_when_scoring() {
let input = Bytes::from_static(b"a;b;c;d\n'a;b;c';d;'e;f;g';h");
let sq_score = compute_score(
None,
input.clone(),
Quote::SingleQuote,
Delimiter::Semicolon,
LineEnding::CRLF,
Escape::None,
);
let dq_score = compute_score(
None,
input.clone(),
Quote::DoubleQuote,
Delimiter::Semicolon,
LineEnding::CRLF,
Escape::None,
);

assert!(sq_score > dq_score);
assert_eq!(2, sq_score.row_count);
assert_eq!(1, dq_score.row_count);

let sq_score = compute_score(
Some(9),
input.clone(),
Quote::SingleQuote,
Delimiter::Semicolon,
LineEnding::CRLF,
Escape::None,
);
let dq_score = compute_score(
Some(9),
input.clone(),
Quote::DoubleQuote,
Delimiter::Semicolon,
LineEnding::CRLF,
Escape::None,
);
// With 9 headers, the double quote dialect should result in 2 rows
assert_eq!(2, dq_score.row_count);
// But the single quote dialect should still have a higher score due to greater consistency
assert!(sq_score > dq_score);
}

#[test]
fn dialect_detection() {
#[derive(Debug, PartialEq, serde::Deserialize)]
@@ -286,7 +362,7 @@ mod test {
let expected: DetectionResult = serde_json::from_slice(&expect_json)
.expect("failed to deserialize expected detection result");

let dialect = detect_dialect(LineEnding::CRLF, Escape::None, csv, None, None);
let dialect = detect_dialect(None, LineEnding::CRLF, Escape::None, csv, None, None);
let actual = DetectionResult {
quote: dialect.quote,
delimiter: dialect.delimiter,
2 changes: 1 addition & 1 deletion crates/parser/src/format/character_separated/mod.rs
@@ -73,13 +73,13 @@ impl Parser for CsvParser {
Some(qd) => qd,
None => {
let dialect = detection::detect_dialect(
Some(self.config.headers.len()).filter(|n| *n > 0),
line_ending,
escape,
peek,
config_quote,
config_delimiter,
);
tracing::debug!(?dialect, "detected CSV dialect");
(dialect.quote, dialect.delimiter)
}
};
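
The header-count rule added to `compute_score` can be sketched in isolation. This is a simplified illustration, not the actual implementation: it assumes each row has already been parsed into a field count under one candidate dialect, and it returns only the number of scoring rows rather than the full score value. The function name `count_scoring_rows` and its parameters are hypothetical.

```rust
/// Counts the rows that would contribute to a candidate dialect's score,
/// given the number of fields parsed from each row under that dialect.
/// `configured_headers` is the number of headers from the config (0 if none).
fn count_scoring_rows(configured_headers: usize, fields_per_row: &[usize]) -> usize {
    // Mirrors `Some(self.config.headers.len()).filter(|n| *n > 0)`:
    // an empty header list becomes `None`, meaning "infer from the first row".
    let mut header_count = Some(configured_headers).filter(|n| *n > 0);
    let mut row_count = 0;
    for &n_fields in fields_per_row {
        // A row with a single field scores 0 because no delimiter was found.
        let mut score = n_fields.saturating_sub(1);
        if let Some(n_headers) = header_count {
            // More values than headers is an error; fewer values is allowed.
            if n_fields > n_headers {
                score = 0;
            }
        } else {
            // Treat the first row as the header row.
            header_count = Some(n_fields);
        }
        if score > 0 {
            row_count += 1;
        }
    }
    row_count
}
```

For the input used in `account_for_header_count_when_scoring`, the single-quote dialect parses both rows into 4 fields, while the double-quote dialect parses the second row into 8. Under this sketch, `count_scoring_rows(0, &[4, 8])` counts only the header row, and `count_scoring_rows(9, &[4, 8])` counts both rows, matching the row counts asserted in the test.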