Skip to content

Commit

Permalink
preserve sql formatting through a parse + display roundtrip (partial …
Browse files Browse the repository at this point in the history
…implementation)

this implements (a tiny portion of) apache#1634

pros: really useful when passing formatted queries to a real database, in order for database error message locations to match the original user's source locations

cons: if we want to do it well, we need to track source locations better, and this adds complexity to the Display implementations
  • Loading branch information
lovasoa committed Jan 2, 2025
1 parent 94ea206 commit 31d057a
Show file tree
Hide file tree
Showing 10 changed files with 303 additions and 36 deletions.
145 changes: 140 additions & 5 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,117 @@ where
DisplaySeparated { slice, sep: ", " }
}

/// Display adapter that writes the items of `slice` joined by `sep`, using each
/// item's [`Spanned::span`] to re-emit the newlines and spaces that separated
/// the items in the parsed source.
///
/// Built by [`display_separated_with_newlines`] and
/// [`display_comma_separated_with_newlines`].
pub struct DisplaySeparatedWithNewlines<'a, T>
where
T: fmt::Display + Spanned,
{
/// Items to write, in order.
slice: &'a [T],
/// Text written between consecutive items (never before the first one).
sep: &'static str,
/// Span of whatever was written just before the list; its `end` is the
/// position the output cursor starts from.
last_span: Span,
}

impl<T> fmt::Display for DisplaySeparatedWithNewlines<'_, T>
where
    T: fmt::Display + Spanned,
{
    /// Writes every item of the slice, separated by `sep`, padding with
    /// newlines and spaces so that each item lands on the line and column
    /// recorded in its span.
    ///
    /// The cursor starts at `self.last_span`. Before the first item only the
    /// line gap (or a single space when there is none) is written via
    /// `write_span_gap_lines`; every item, including the first, is then aligned
    /// to its start column via `write_span_gap`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // `cursor.end` tracks where the output position currently is.
        let mut cursor = self.last_span;
        if let Some(head) = self.slice.first() {
            write_span_gap_lines(f, &mut cursor, head.span())?;
        }

        for (index, item) in self.slice.iter().enumerate() {
            // No separator in front of the first item.
            let separator = if index == 0 { "" } else { self.sep };
            write!(f, "{separator}")?;
            // Account for the columns the separator just consumed.
            cursor.end.column += u64::try_from(separator.len()).unwrap_or(1);

            let item_span = item.span();
            write_span_gap(f, cursor, item_span)?;
            write!(f, "{item}")?;
            cursor = item_span;
        }
        Ok(())
    }
}

/// Writes the whitespace that separates the end of `last_span` from the start
/// of `current_span`.
///
/// One newline is written for every line between `last_span.end.line` and
/// `current_span.start.line`. Spaces are then written up to
/// `current_span.start.column`, counting from column 1 if at least one newline
/// was written and from `last_span.end.column` otherwise. Nothing is written
/// when `current_span` does not start after the end of `last_span`.
pub fn write_span_gap(
    f: &mut fmt::Formatter,
    last_span: Span,
    current_span: Span,
) -> fmt::Result {
    let from = last_span.end;
    let to = current_span.start;

    // An empty range (from.line >= to.line) writes no newline at all.
    let crosses_lines = from.line < to.line;
    for _ in from.line..to.line {
        f.write_str("\n")?;
    }

    // After a newline the cursor is back at the first column of the new line.
    let first_column = if crosses_lines { 1 } else { from.column };
    for _ in first_column..to.column {
        f.write_str(" ")?;
    }
    Ok(())
}

/// Writes the line gap between the end of `last_span` and the start of
/// `current_span`, advancing `last_span.end` to match what was written.
///
/// If `current_span` starts on a later line, one newline is written per line
/// of distance and `last_span.end` ends up at column 1 of that line. Otherwise
/// a single space is written and `last_span.end.column` is advanced by one.
pub fn write_span_gap_lines(
    f: &mut fmt::Formatter,
    last_span: &mut Span,
    current_span: Span,
) -> fmt::Result {
    // No line to skip: separate the two items with exactly one space.
    if last_span.end.line >= current_span.start.line {
        f.write_str(" ")?;
        last_span.end.column += 1;
        return Ok(());
    }

    // The cursor is updated after every newline so it stays accurate even if a
    // later write fails.
    while last_span.end.line < current_span.start.line {
        f.write_str("\n")?;
        last_span.end.line += 1;
        last_span.end.column = 1;
    }
    Ok(())
}

/// Wraps `slice` in a [`DisplaySeparatedWithNewlines`] that joins the items
/// with `sep`.
///
/// `last_span` is stored unchanged; the adapter uses it as the position the
/// output cursor starts from when it is displayed.
pub fn display_separated_with_newlines<'a, T>(
slice: &'a [T],
sep: &'static str,
last_span: Span,
) -> DisplaySeparatedWithNewlines<'a, T>
where
T: fmt::Display + Spanned,
{
DisplaySeparatedWithNewlines {
slice,
sep,
last_span,
}
}

/// Wraps `slice` in a [`DisplaySeparatedWithNewlines`] that joins the items
/// with a comma.
///
/// When every item reports `Span::empty()` (including the empty-slice case)
/// there is no layout to reproduce, so the separator is `", "`. As soon as one
/// item carries a real span the separator is a bare `","` and the whitespace
/// after it is derived from the spans at display time.
pub fn display_comma_separated_with_newlines<T>(
    slice: &[T],
    last_span: Span,
) -> DisplaySeparatedWithNewlines<'_, T>
where
    T: fmt::Display + Spanned,
{
    let has_span_info = slice.iter().any(|item| item.span() != Span::empty());
    DisplaySeparatedWithNewlines {
        slice,
        sep: if has_span_info { "," } else { ", " },
        last_span,
    }
}

/// An identifier, decomposed into its value or character data and the quote style.
#[derive(Debug, Clone, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
Expand Down Expand Up @@ -3763,21 +3874,44 @@ impl fmt::Display for Statement {
if let Some(or) = or {
write!(f, "{or} ")?;
}
let mut last_span = table.span();
write!(f, "{table}")?;
if let Some(UpdateTableFromKind::BeforeSet(from)) = from {
write!(f, " FROM {from}")?;
let from_span = from.span();
write_span_gap_lines(f, &mut last_span, from_span)?;
last_span = from_span;
write!(f, "FROM {from}")?;
}
if !assignments.is_empty() {
write!(f, " SET {}", display_comma_separated(assignments))?;
let assign_span = assignments.first().unwrap().span();
write_span_gap_lines(f, &mut last_span, assign_span)?;
last_span.end.column += 3;
write!(
f,
"SET{}",
display_comma_separated_with_newlines(assignments, last_span)
)?;
last_span = assignments.last().unwrap().span();
}
if let Some(UpdateTableFromKind::AfterSet(from)) = from {
write!(f, " FROM {from}")?;
write_span_gap_lines(f, &mut last_span, from.span())?;
last_span = from.span();
write!(f, "FROM {from}")?;
}
if let Some(selection) = selection {
write!(f, " WHERE {selection}")?;
write_span_gap_lines(f, &mut last_span, selection.span())?;
last_span = selection.span();
write!(f, "WHERE {selection}")?;
}
if let Some(returning) = returning {
write!(f, " RETURNING {}", display_comma_separated(returning))?;
let returning_span = returning.first().unwrap().span();
write_span_gap_lines(f, &mut last_span, returning_span)?;
last_span.end = returning_span.start;
write!(
f,
"RETURNING{}",
display_comma_separated_with_newlines(returning, last_span)
)?;
}
Ok(())
}
Expand Down Expand Up @@ -5420,6 +5554,7 @@ impl fmt::Display for GrantObjects {
pub struct Assignment {
/// Left-hand side of the assignment (what is being assigned to).
pub target: AssignmentTarget,
/// Expression on the right-hand side of the `=`.
pub value: Expr,
/// Source span of the whole assignment; the parser sets it from the start
/// of the first token of the target to the end of the last token of the value.
pub span: Span,
}

impl fmt::Display for Assignment {
Expand Down
4 changes: 1 addition & 3 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,7 @@ impl Spanned for DoUpdate {

impl Spanned for Assignment {
fn span(&self) -> Span {
let Assignment { target, value } = self;

target.span().union(&value.span())
self.span
}
}

Expand Down
9 changes: 8 additions & 1 deletion src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12052,10 +12052,17 @@ impl<'a> Parser<'a> {

/// Parse a `var = expr` assignment, used in an UPDATE statement
pub fn parse_assignment(&mut self) -> Result<Assignment, ParserError> {
let start = self.peek_token().span.start;
let target = self.parse_assignment_target()?;
self.expect_token(&Token::Eq)?;
let value = self.parse_expr()?;
Ok(Assignment { target, value })
self.prev_token();
let end = self.next_token().span.end;
Ok(Assignment {
target,
value,
span: Span::new(start, end),
})
}

/// Parse the left-hand side of an assignment, used in an UPDATE statement
Expand Down
38 changes: 37 additions & 1 deletion src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,26 @@ impl TestedDialects {
// Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
}

/// Parses a single SQL string into multiple statements while keeping token
/// locations, ensuring the result is the same for all tested dialects.
///
/// The input is tokenized with `tokenize_with_location` and the located
/// tokens are handed to the parser, so the parser sees the source position
/// of every token.
pub fn parse_sql_statements_with_locations(
    &self,
    sql: &str,
) -> Result<Vec<Statement>, ParserError> {
    self.one_of_identical_results(|dialect| {
        let plain_tokenizer = Tokenizer::new(dialect, sql);
        // Apply the configured unescape mode only when options were supplied.
        let mut tokenizer = match &self.options {
            Some(options) => plain_tokenizer.with_unescape(options.unescape),
            None => plain_tokenizer,
        };
        let located_tokens = tokenizer.tokenize_with_location()?;
        self.new_parser(dialect)
            .with_tokens_with_locations(located_tokens)
            .parse_statements()
    })
    // To fail the `ensure_multiple_dialects_are_tested` test:
    // Parser::parse_sql(&**self.dialects.first().unwrap(), sql)
}

/// Ensures that `sql` parses as a single [Statement] for all tested
/// dialects.
///
Expand All @@ -152,7 +172,7 @@ impl TestedDialects {
/// 2. re-serializing the result of parsing `sql` produces the same
/// `canonical` sql string
pub fn one_statement_parses_to(&self, sql: &str, canonical: &str) -> Statement {
let mut statements = self.parse_sql_statements(sql).expect(sql);
let mut statements = self.parse_sql_statements_with_locations(sql).expect(sql);
assert_eq!(statements.len(), 1);

if !canonical.is_empty() && sql != canonical {
Expand All @@ -167,6 +187,17 @@ impl TestedDialects {
only_statement
}

/// Variant of `one_statement_parses_to` that parses through
/// `parse_sql_statements` instead of the location-tracking entry point; per
/// the original note, all locations in the result are empty.
///
/// Panics unless `sql` parses to exactly one statement. When `canonical` is
/// non-empty, also panics unless re-serializing that statement yields
/// `canonical`. Unlike `one_statement_parses_to`, this check runs even when
/// `sql == canonical`.
pub fn one_statement_parses_to_no_span(&self, sql: &str, canonical: &str) -> Statement {
    let parsed = self.parse_sql_statements(sql).expect(sql);
    assert_eq!(parsed.len(), 1);
    let statement = parsed.into_iter().next().unwrap();

    // An empty `canonical` means the caller does not want a round-trip check.
    if canonical.is_empty() {
        return statement;
    }
    assert_eq!(canonical, statement.to_string());
    statement
}

/// Ensures that `sql` parses as an [`Expr`], and that
/// re-serializing the parse result produces canonical
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
Expand All @@ -184,6 +215,11 @@ impl TestedDialects {
self.one_statement_parses_to(sql, sql)
}

/// Identical to `verified_stmt`, but sets all locations to empty.
///
/// Delegates to `one_statement_parses_to_no_span`, passing `sql` as both the
/// input and the expected canonical form.
pub fn verified_stmt_no_span(&self, sql: &str) -> Statement {
self.one_statement_parses_to_no_span(sql, sql)
}

/// Ensures that `sql` parses as a single [Query], and that
/// re-serializing the parse result produces the same `sql`
/// string (is not modified after a serialization round-trip).
Expand Down
4 changes: 3 additions & 1 deletion tests/sqlparser_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1624,16 +1624,18 @@ fn parse_merge() {
let update_action = MergeAction::Update {
assignments: vec![
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("a")])),
value: Expr::Value(number("1")),
},
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("b")])),
value: Expr::Value(number("2")),
},
],
};
match bigquery_and_generic().verified_stmt(sql) {
match bigquery_and_generic().verified_stmt_no_span(sql) {
Statement::Merge {
into,
table,
Expand Down
24 changes: 20 additions & 4 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,25 +297,30 @@ fn parse_update() {
match verified_stmt(sql) {
Statement::Update {
table,
assignments,
mut assignments,
selection,
..
} => {
assert_eq!(table.to_string(), "t".to_string());
// remove the span from the assignments before comparison
assignments.iter_mut().for_each(|a| a.span = Span::empty());
assert_eq!(
assignments,
vec![
Assignment {
target: AssignmentTarget::ColumnName(ObjectName(vec!["a".into()])),
value: Expr::Value(number("1")),
span: Span::empty(),
},
Assignment {
target: AssignmentTarget::ColumnName(ObjectName(vec!["b".into()])),
value: Expr::Value(number("2")),
span: Span::empty(),
},
Assignment {
target: AssignmentTarget::ColumnName(ObjectName(vec!["c".into()])),
value: Expr::Value(number("3")),
span: Span::empty(),
},
]
);
Expand Down Expand Up @@ -354,7 +359,7 @@ fn parse_update_set_from() {
Box::new(MsSqlDialect {}),
Box::new(SQLiteDialect {}),
]);
let stmt = dialects.verified_stmt(sql);
let stmt = dialects.verified_stmt_no_span(sql);
assert_eq!(
stmt,
Statement::Update {
Expand All @@ -363,6 +368,7 @@ fn parse_update_set_from() {
joins: vec![],
},
assignments: vec![Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![Ident::new("name")])),
value: Expr::CompoundIdentifier(vec![Ident::new("t2"), Ident::new("name")])
}],
Expand Down Expand Up @@ -439,7 +445,7 @@ fn parse_update_set_from() {
#[test]
fn parse_update_with_table_alias() {
let sql = "UPDATE users AS u SET u.username = 'new_user' WHERE u.username = 'old_user'";
match verified_stmt(sql) {
match verified_stmt_no_span(sql) {
Statement::Update {
table,
assignments,
Expand Down Expand Up @@ -470,6 +476,7 @@ fn parse_update_with_table_alias() {
);
assert_eq!(
vec![Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![
Ident::new("u"),
Ident::new("username")
Expand Down Expand Up @@ -8529,7 +8536,10 @@ fn test_revoke() {
fn parse_merge() {
let sql = "MERGE INTO s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE";
let sql_no_into = "MERGE s.bar AS dest USING (SELECT * FROM s.foo) AS stg ON dest.D = stg.D AND dest.E = stg.E WHEN NOT MATCHED THEN INSERT (A, B, C) VALUES (stg.A, stg.B, stg.C) WHEN MATCHED AND dest.A = 'a' THEN UPDATE SET dest.F = stg.F, dest.G = stg.G WHEN MATCHED THEN DELETE";
match (verified_stmt(sql), verified_stmt(sql_no_into)) {
match (
verified_stmt_no_span(sql),
verified_stmt_no_span(sql_no_into),
) {
(
Statement::Merge {
into,
Expand Down Expand Up @@ -8698,6 +8708,7 @@ fn parse_merge() {
action: MergeAction::Update {
assignments: vec![
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![
Ident::new("dest"),
Ident::new("F")
Expand All @@ -8708,6 +8719,7 @@ fn parse_merge() {
]),
},
Assignment {
span: Span::empty(),
target: AssignmentTarget::ColumnName(ObjectName(vec![
Ident::new("dest"),
Ident::new("G")
Expand Down Expand Up @@ -8992,6 +9004,10 @@ fn verified_stmt(query: &str) -> Statement {
all_dialects().verified_stmt(query)
}

/// Runs `verified_stmt_no_span` on `query` against `all_dialects()` and
/// returns the parsed statement.
fn verified_stmt_no_span(query: &str) -> Statement {
all_dialects().verified_stmt_no_span(query)
}

fn verified_query(query: &str) -> Query {
all_dialects().verified_query(query)
}
Expand Down
Loading

0 comments on commit 31d057a

Please sign in to comment.