Skip to content

Commit

Permalink
parser: account for number of header columns in dialect detection
Browse files Browse the repository at this point in the history
When we parse CSVs, we consider it an error for any row to contain more values
than there are header columns. But the dialect detection wasn't consistent with
that behavior, and if it encountered such a row it would score it higher than
it would a row containing fewer values than there are headers. The consequence
of that is that we could end up scoring an incorrect quote character higher
than a correct one if it produces more columns (which is often the case when
quoted values contain delimiters). This commit addresses that oversight by
zeroing the score of any row that contains too many values. Thus it is treated
the same as if the row couldn't be parsed at all. The result is that dialect
detection produces a much more accurate guess of the correct quote character.
  • Loading branch information
psFried committed Jan 30, 2024
1 parent 8c53c97 commit b32a851
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 6 deletions.
86 changes: 81 additions & 5 deletions crates/parser/src/format/character_separated/detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,14 @@ pub struct Dialect {
/// `line_separator` and `escape` are required. If `config_quote` or `config_delimiter` are Some,
/// then the search space will be limited to only those values, and they will be returned in the
/// detected dialect.
/// The `header_count` parameter must be the number of header columns that was passed in the config,
/// and must be > 0 if it is `Some`. More typically, it will be `None`, and will be inferred from the
/// first row. This is needed because we consider it an error for any row to have more values than there
/// are headers.
/// A dialect is always detected and returned, even though it may not be a very good fit. This reflects
/// the reality that even an incorrect dialect can usually at least result in a single column per line.
pub fn detect_dialect(
header_count: Option<usize>,
line_separator: LineEnding,
escape: Escape,
peeked: Bytes,
Expand All @@ -86,7 +91,15 @@ pub fn detect_dialect(
.iter()
.copied()
.map(|(quote, delimiter)| {
let score = compute_score(peeked.clone(), quote, delimiter, line_separator, escape);
let score = compute_score(
header_count,
peeked.clone(),
quote,
delimiter,
line_separator,
escape,
);
tracing::debug!(?quote, ?delimiter, ?score, "computed score for dialect");
Dialect {
quote,
delimiter,
Expand All @@ -106,9 +119,14 @@ pub fn detect_dialect(
.pop()
.expect("must have at least one candidate dialect");
// Log the top few candidates, as it's helpful to see the runner up when detection doesn't go as we expected
let runners_up = &dialects[0..(dialects.len().min(3))];
let runners_up = &dialects[dialects.len().saturating_sub(3)..];

tracing::debug!(?winning_dialect, ?runners_up, "detected CSV dialect");
tracing::debug!(
?winning_dialect,
?runners_up,
total_checked_dialects = permutations.len(),
"detected CSV dialect"
);
winning_dialect
}

Expand Down Expand Up @@ -159,6 +177,7 @@ fn get_dialect_candidates(
}

fn compute_score(
mut header_count: Option<usize>,
peeked: Bytes,
quote: Quote,
delimiter: Delimiter,
Expand Down Expand Up @@ -188,7 +207,18 @@ fn compute_score(
if !more {
break;
}
let score = record.len().saturating_sub(1);
let mut score = record.len().saturating_sub(1);
if let Some(n_headers) = header_count {
// It is an error for a CSV row to have more values than there are headers, so count this row
// as an error in that case. Note that it's permissible for a row to have _fewer_ values.
// This behavior matches that of `CsvOutput`, which returns an error if a row has too many values.
if record.len() > n_headers {
score = 0;
}
} else {
// Consider the first row to be headers, and note the number so we can properly score subsequent rows.
header_count = Some(record.len());
}
if score > 0 {
row_count += 1;
}
Expand Down Expand Up @@ -263,6 +293,52 @@ mod test {
}
}

#[test]
fn account_for_header_count_when_scoring() {
    // The second line is wrapped in single quotes around values that themselves contain the
    // delimiter, so the choice of quote character changes how many values that row yields.
    let input = Bytes::from_static(b"a;b;c;d\n'a;b;c';d;'e;f;g';h");

    // Every call in this test differs only in the header count and the quote character,
    // so the remaining arguments are fixed here.
    let score_with = |header_count: Option<usize>, quote: Quote| {
        compute_score(
            header_count,
            input.clone(),
            quote,
            Delimiter::Semicolon,
            LineEnding::CRLF,
            Escape::None,
        )
    };

    // No header count supplied: it is taken from the first row.
    let inferred_single = score_with(None, Quote::SingleQuote);
    let inferred_double = score_with(None, Quote::DoubleQuote);

    assert!(inferred_single > inferred_double);
    assert_eq!(2, inferred_single.row_count);
    assert_eq!(1, inferred_double.row_count);

    // An explicit header count that is large enough to accommodate the second row
    // under either quote character.
    let explicit_single = score_with(Some(9), Quote::SingleQuote);
    let explicit_double = score_with(Some(9), Quote::DoubleQuote);

    // With 9 headers, the double quote dialect should result in 2 rows
    assert_eq!(2, explicit_double.row_count);
    // But the single quote dialect should still have a higher score due to greater consistency
    assert!(explicit_single > explicit_double);
}

#[test]
fn dialect_detection() {
#[derive(Debug, PartialEq, serde::Deserialize)]
Expand All @@ -286,7 +362,7 @@ mod test {
let expected: DetectionResult = serde_json::from_slice(&expect_json)
.expect("failed to deserialize expected detection result");

let dialect = detect_dialect(LineEnding::CRLF, Escape::None, csv, None, None);
let dialect = detect_dialect(None, LineEnding::CRLF, Escape::None, csv, None, None);
let actual = DetectionResult {
quote: dialect.quote,
delimiter: dialect.delimiter,
Expand Down
2 changes: 1 addition & 1 deletion crates/parser/src/format/character_separated/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ impl Parser for CsvParser {
Some(qd) => qd,
None => {
let dialect = detection::detect_dialect(
Some(self.config.headers.len()).filter(|n| *n > 0),
line_ending,
escape,
peek,
config_quote,
config_delimiter,
);
tracing::debug!(?dialect, "detected CSV dialect");
(dialect.quote, dialect.delimiter)
}
};
Expand Down

0 comments on commit b32a851

Please sign in to comment.