From b32a8519209a15562ec141ddee222412e3335a45 Mon Sep 17 00:00:00 2001 From: Phil Date: Tue, 30 Jan 2024 13:48:40 -0500 Subject: [PATCH] parser: account for number of header columns in dialect detection When we parse CSVs, we consider it an error for any row to contain more values than there are header columns. But the dialect detection wasn't consistent with that behavior, and if it encountered such a row it would score it higher than it would a row containing fewer values than there are headers. The consequence of that is that we could end up scoring an incorrect quote character higher than a correct one if it produces more columns (which often the case when quoted values contain delimiters). This commit addresses that oversight by zeroing the score of any row that contains too many values. Thus it is treated the same as if the row couldn't be parsed at all. The result is that dialect detection produces a much more accurate guess of the correct quote character. --- .../format/character_separated/detection.rs | 86 +++++++++++++++++-- .../src/format/character_separated/mod.rs | 2 +- 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/crates/parser/src/format/character_separated/detection.rs b/crates/parser/src/format/character_separated/detection.rs index b216d98cde..d43cda1a50 100644 --- a/crates/parser/src/format/character_separated/detection.rs +++ b/crates/parser/src/format/character_separated/detection.rs @@ -71,9 +71,14 @@ pub struct Dialect { /// `line_separator` and `escape` are required. If `config_quote` or `config_delimiter` are Some, /// then the search space will be limited to only those values, and they will be returned in the /// detected dialect. +/// The `header_count` parameter must be the number of header columns that was passed in the config, +/// and must be > 0 if it is `Some`. More typically, it will be `None`, and will be inferred from the +/// first row. This is needed because we consider it an error for any row to have more values than there +/// are headers. /// A dialect is always detected and returned, even though it may not be a very good fit. This reflects /// the reality that even an incorrect dialect can usually at least result in a single column per line. pub fn detect_dialect( + header_count: Option, line_separator: LineEnding, escape: Escape, peeked: Bytes, @@ -86,7 +91,15 @@ pub fn detect_dialect( .iter() .copied() .map(|(quote, delimiter)| { - let score = compute_score(peeked.clone(), quote, delimiter, line_separator, escape); + let score = compute_score( + header_count, + peeked.clone(), + quote, + delimiter, + line_separator, + escape, + ); + tracing::debug!(?quote, ?delimiter, ?score, "computed score for dialect"); Dialect { quote, delimiter, @@ -106,9 +119,14 @@ pub fn detect_dialect( .pop() .expect("must have at least one candidate dialect"); // Log the top few candidates, as it's helpful to see the runner up when detection doesn't go as we expected - let runners_up = &dialects[0..(dialects.len().min(3))]; + let runners_up = &dialects[dialects.len().saturating_sub(3)..]; - tracing::debug!(?winning_dialect, ?runners_up, "detected CSV dialect"); + tracing::debug!( + ?winning_dialect, + ?runners_up, + total_checked_dialects = permutations.len(), + "detected CSV dialect" + ); winning_dialect } @@ -159,6 +177,7 @@ fn get_dialect_candidates( } fn compute_score( + mut header_count: Option, peeked: Bytes, quote: Quote, delimiter: Delimiter, @@ -188,7 +207,18 @@ fn compute_score( if !more { break; } - let score = record.len().saturating_sub(1); + let mut score = record.len().saturating_sub(1); + if let Some(n_headers) = header_count { + // It is an error for a CSV row to have more values than there are headers, so count this row + // as an error in that case. Note that it's permissible for a row to have _fewer_ values. + // This behavior matches that of `CsvOutput`, which returns an error if a row has too many values. + if record.len() > n_headers { + score = 0; + } + } else { + // Consider the first row to be headers, and note the number so we can properly score subsequent rows. + header_count = Some(record.len()); + } if score > 0 { row_count += 1; } @@ -263,6 +293,52 @@ mod test { } } + #[test] + fn account_for_header_count_when_scoring() { + let input = Bytes::from_static(b"a;b;c;d\n'a;b;c';d;'e;f;g';h"); + let sq_score = compute_score( + None, + input.clone(), + Quote::SingleQuote, + Delimiter::Semicolon, + LineEnding::CRLF, + Escape::None, + ); + let dq_score = compute_score( + None, + input.clone(), + Quote::DoubleQuote, + Delimiter::Semicolon, + LineEnding::CRLF, + Escape::None, + ); + + assert!(sq_score > dq_score); + assert_eq!(2, sq_score.row_count); + assert_eq!(1, dq_score.row_count); + + let sq_score = compute_score( + Some(9), + input.clone(), + Quote::SingleQuote, + Delimiter::Semicolon, + LineEnding::CRLF, + Escape::None, + ); + let dq_score = compute_score( + Some(9), + input.clone(), + Quote::DoubleQuote, + Delimiter::Semicolon, + LineEnding::CRLF, + Escape::None, + ); + // With 9 headers, the double quote dialect should result in 2 rows + assert_eq!(2, dq_score.row_count); + // But the single quote dialect should still have a higher score due to greater consistency + assert!(sq_score > dq_score); + } + #[test] fn dialect_detection() { #[derive(Debug, PartialEq, serde::Deserialize)] @@ -286,7 +362,7 @@ mod test { let expected: DetectionResult = serde_json::from_slice(&expect_json) .expect("failed to deserialize expected detection result"); - let dialect = detect_dialect(LineEnding::CRLF, Escape::None, csv, None, None); + let dialect = detect_dialect(None, LineEnding::CRLF, Escape::None, csv, None, None); let actual = DetectionResult { quote: dialect.quote, delimiter: dialect.delimiter, diff --git a/crates/parser/src/format/character_separated/mod.rs b/crates/parser/src/format/character_separated/mod.rs index ef73e1b6a8..0075fb38a5 100644 --- a/crates/parser/src/format/character_separated/mod.rs +++ b/crates/parser/src/format/character_separated/mod.rs @@ -73,13 +73,13 @@ impl Parser for CsvParser { Some(qd) => qd, None => { let dialect = detection::detect_dialect( + Some(self.config.headers.len()).filter(|n| *n > 0), line_ending, escape, peek, config_quote, config_delimiter, ); - tracing::debug!(?dialect, "detected CSV dialect"); (dialect.quote, dialect.delimiter) } };