diff --git a/sqlparse/engine/statement_splitter.py b/sqlparse/engine/statement_splitter.py
index 6c69d303..0af67baa 100644
--- a/sqlparse/engine/statement_splitter.py
+++ b/sqlparse/engine/statement_splitter.py
@@ -84,7 +84,7 @@ def process(self, stream):
         EOS_TTYPE = T.Whitespace, T.Comment.Single
 
         # Run over all stream tokens
-        for ttype, value in stream:
+        for ttype, value, pos in stream:
             # Yield token if we finished a statement and there's no whitespaces
             # It will count newline token as a non whitespace. In this context
             # whitespace ignores newlines.
@@ -99,7 +99,7 @@ def process(self, stream):
             self.level += self._change_splitlevel(ttype, value)
 
             # Append the token to the current statement
-            self.tokens.append(sql.Token(ttype, value))
+            self.tokens.append(sql.Token(ttype, value, pos))
 
             # Check if we get the end of a statement
             # Issue762: Allow GO (or "GO 2") as statement splitter.
diff --git a/sqlparse/filters/tokens.py b/sqlparse/filters/tokens.py
index cc00a844..863ce2c4 100644
--- a/sqlparse/filters/tokens.py
+++ b/sqlparse/filters/tokens.py
@@ -16,10 +16,10 @@ def __init__(self, case=None):
         self.convert = getattr(str, case)
 
     def process(self, stream):
-        for ttype, value in stream:
+        for ttype, value, pos in stream:
             if ttype in self.ttype:
                 value = self.convert(value)
-            yield ttype, value
+            yield ttype, value, pos
 
 
 class KeywordCaseFilter(_CaseFilter):
@@ -30,10 +30,10 @@ class IdentifierCaseFilter(_CaseFilter):
     ttype = T.Name, T.String.Symbol
 
     def process(self, stream):
-        for ttype, value in stream:
+        for ttype, value, pos in stream:
             if ttype in self.ttype and value.strip()[0] != '"':
                 value = self.convert(value)
-            yield ttype, value
+            yield ttype, value, pos
 
 
 class TruncateStringFilter:
@@ -42,9 +42,9 @@ def __init__(self, width, char):
         self.char = char
 
     def process(self, stream):
-        for ttype, value in stream:
+        for ttype, value, pos in stream:
             if ttype != T.Literal.String.Single:
-                yield ttype, value
+                yield ttype, value, pos
                 continue
 
             if value[:2] == "''":
@@ -56,4 +56,4 @@ def process(self, stream):
 
             if len(inner) > self.width:
                 value = ''.join((quote, inner[:self.width], self.char, quote))
-            yield ttype, value
+            yield ttype, value, pos
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index 8f88d171..7280e45d 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -106,14 +106,14 @@ def is_keyword(self, value):
 
     def get_tokens(self, text, encoding=None):
         """
-        Return an iterable of (tokentype, value) pairs generated from
+        Return an iterable of (tokentype, value, pos) tuples generated from
         `text`. If `unfiltered` is set to `True`, the filtering mechanism
         is bypassed even if filters are defined.
 
         Also preprocess the text, i.e. expand tabs and strip it if
         wanted and applies registered filters.
 
-        Split ``text`` into (tokentype, text) pairs.
+        Split ``text`` into (tokentype, text, pos) tuples.
 
         ``stack`` is the initial stack (default: ``['root']``)
         """
@@ -142,20 +142,20 @@ def get_tokens(self, text, encoding=None):
                 if not m:
                     continue
                 elif isinstance(action, tokens._TokenType):
-                    yield action, m.group()
+                    yield action, m.group(), pos
                 elif action is keywords.PROCESS_AS_KEYWORD:
-                    yield self.is_keyword(m.group())
+                    yield (*self.is_keyword(m.group()), pos)
 
                 consume(iterable, m.end() - pos - 1)
                 break
             else:
-                yield tokens.Error, char
+                yield tokens.Error, char, pos
 
 
 def tokenize(sql, encoding=None):
     """Tokenize sql.
 
-    Tokenize *sql* using the :class:`Lexer` and return a 2-tuple stream
-    of ``(token type, value)`` items.
+    Tokenize *sql* using the :class:`Lexer` and return a 3-tuple stream
+    of ``(token type, value, pos)`` items.
     """
     return Lexer.get_default_instance().get_tokens(sql, encoding)
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 10373751..0fb3ef2c 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -45,14 +45,16 @@ class Token:
     the type of the token.
     """
 
-    __slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword',
-                 'is_group', 'is_whitespace', 'is_newline')
+    __slots__ = ('value', 'ttype', 'parent', 'pos', 'length', 'normalized',
+                 'is_keyword', 'is_group', 'is_whitespace', 'is_newline')
 
-    def __init__(self, ttype, value):
+    def __init__(self, ttype, value, pos=None):
         value = str(value)
         self.value = value
         self.ttype = ttype
         self.parent = None
+        self.pos = pos
+        self.length = len(value)
         self.is_group = False
         self.is_keyword = ttype in T.Keyword
         self.is_whitespace = self.ttype in T.Whitespace
@@ -163,6 +165,27 @@ def __init__(self, tokens=None):
         super().__init__(None, str(self))
         self.is_group = True
 
+    @property
+    def pos(self):
+        if len(self.tokens) > 0:
+            return self.tokens[0].pos
+
+    @property
+    def length(self):
+        if len(self.tokens) > 0:
+            first, last = self.tokens[0], self.tokens[-1]
+            if first.pos is not None and last.pos is not None:
+                return last.length + (last.pos - first.pos)
+
+        return len(str(self))
+
+    # this is a bit of a hack to avoid problems with the super constructor
+    # trying to set these attributes, which we want to compute dynamically
+    @pos.setter
+    def pos(self, value): ...
+    @length.setter
+    def length(self, value): ...
+
     def __str__(self):
         return ''.join(token.value for token in self.flatten())
 
diff --git a/tests/test_source_positions.py b/tests/test_source_positions.py
new file mode 100644
index 00000000..b84a0fab
--- /dev/null
+++ b/tests/test_source_positions.py
@@ -0,0 +1,136 @@
+from typing import List, Optional, Tuple, Type, Union
+
+from sqlparse import parse
+from sqlparse.sql import Identifier, IdentifierList, Statement, Token
+from sqlparse.tokens import (
+    CTE,
+    DML,
+    Comparison,
+    Keyword,
+    Number,
+    _TokenType as TokenType,
+)
+
+
+def itertokens(token: Token):
+    yield token
+    if token.is_group:
+        for child in token.tokens:
+            yield from itertokens(child)
+
+
+# allow matching by Token subclass or ttype
+TokenClassOrType = Union[TokenType, Type[Token]]
+
+
+def parsed_tokens_with_sources(
+    sql: str, types: Tuple[TokenClassOrType, ...]
+) -> List[Tuple[TokenClassOrType, str, str]]:
+    # given a query, parses it, iterates over all the tokens it contains, and
+    # for each token that matches `types`, returns a tuple of the matched token
+    # type, the token's value, and the source of the token retrieved by slicing
+    # into the original query using the token's `pos` and `length` attributes
+
+    def matches(token: Token) -> Optional[TokenClassOrType]:
+        for class_or_type in types:
+            if isinstance(class_or_type, TokenType):
+                if token.ttype is class_or_type:
+                    return class_or_type
+            elif isinstance(token, class_or_type):
+                return class_or_type
+
+    def get_source(token: Token) -> str:
+        return sql[token.pos : token.pos + token.length]
+
+    statements = parse(sql)
+    return [
+        (match, token.value, get_source(token))
+        for statement in statements
+        for token in itertokens(statement)
+        if (match := matches(token))
+    ]
+
+
+def test_simple_query():
+    assert parsed_tokens_with_sources(
+        "select * from foo;", (DML, Identifier, Keyword, Statement)
+    ) == [
+        (Statement, "select * from foo;", "select * from foo;"),
+        (DML, "select", "select"),
+        (Keyword, "from", "from"),
+        (Identifier, "foo", "foo"),
+    ]
+
+
+def test_multiple_statements():
+    stmt1 = "select * from foo;"
+    stmt2 = "\nselect *\nfrom bar;"
+    assert parsed_tokens_with_sources(
+        stmt1 + stmt2, (DML, Identifier, Keyword, Statement)
+    ) == [
+        (Statement, stmt1, stmt1),
+        (DML, "select", "select"),
+        (Keyword, "from", "from"),
+        (Identifier, "foo", "foo"),
+        (Statement, stmt2, stmt2),
+        (DML, "select", "select"),
+        (Keyword, "from", "from"),
+        (Identifier, "bar", "bar"),
+    ]
+
+
+def test_subselect():
+    stmt = """select a0, b0, c0, d0, e0 from
+        (select * from dual) q0 where 1=1 and 2=2"""
+    assert parsed_tokens_with_sources(
+        stmt,
+        (
+            DML,
+            Comparison,
+            Identifier,
+            IdentifierList,
+            Keyword,
+            Number.Integer,
+            Statement,
+        ),
+    ) == [
+        (Statement, stmt, stmt),
+        (DML, "select", "select"),
+        (IdentifierList, "a0, b0, c0, d0, e0", "a0, b0, c0, d0, e0"),
+        (Identifier, "a0", "a0"),
+        (Identifier, "b0", "b0"),
+        (Identifier, "c0", "c0"),
+        (Identifier, "d0", "d0"),
+        (Identifier, "e0", "e0"),
+        (Keyword, "from", "from"),
+        (Identifier, "(select * from dual) q0", "(select * from dual) q0"),
+        (DML, "select", "select"),
+        (Keyword, "from", "from"),
+        (Identifier, "dual", "dual"),
+        (Identifier, "q0", "q0"),
+        (Keyword, "where", "where"),
+        (Number.Integer, "1", "1"),
+        (Comparison, "=", "="),
+        (Number.Integer, "1", "1"),
+        (Keyword, "and", "and"),
+        (Number.Integer, "2", "2"),
+        (Comparison, "=", "="),
+        (Number.Integer, "2", "2"),
+    ]
+
+
+def test_cte():
+    stmt = """WITH foo AS (SELECT 1, 2, 3)
+        SELECT * FROM foo;"""
+    assert parsed_tokens_with_sources(
+        stmt, (CTE, DML, Identifier, Keyword, Statement)
+    ) == [
+        (Statement, stmt, stmt),
+        (CTE, "WITH", "WITH"),
+        (Identifier, "foo AS (SELECT 1, 2, 3)", "foo AS (SELECT 1, 2, 3)"),
+        (Keyword, "AS", "AS"),
+        (DML, "SELECT", "SELECT"),
+        (DML, "SELECT", "SELECT"),
+        (Keyword, "FROM", "FROM"),
+        (Identifier, "foo", "foo"),
+    ]
diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py
index 8ec12d83..9ba797c8 100644
--- a/tests/test_tokenize.py
+++ b/tests/test_tokenize.py
@@ -14,16 +14,16 @@ def test_tokenize_simple():
     assert isinstance(stream, types.GeneratorType)
     tokens = list(stream)
     assert len(tokens) == 8
-    assert len(tokens[0]) == 2
-    assert tokens[0] == (T.Keyword.DML, 'select')
-    assert tokens[-1] == (T.Punctuation, ';')
+    assert len(tokens[0]) == 3
+    assert tokens[0] == (T.Keyword.DML, 'select', 0)
+    assert tokens[-1] == (T.Punctuation, ';', 17)
 
 
 def test_tokenize_backticks():
     s = '`foo`.`bar`'
     tokens = list(lexer.tokenize(s))
     assert len(tokens) == 3
-    assert tokens[0] == (T.Name, '`foo`')
+    assert tokens[0] == (T.Name, '`foo`', 0)
 
 
 @pytest.mark.parametrize('s', ['foo\nbar\n', 'foo\rbar\r',