token source position tracking #794

Open · wants to merge 3 commits into base: master
4 changes: 2 additions & 2 deletions sqlparse/engine/statement_splitter.py
@@ -84,7 +84,7 @@ def process(self, stream):
EOS_TTYPE = T.Whitespace, T.Comment.Single

# Run over all stream tokens
for ttype, value in stream:
for ttype, value, pos in stream:
# Yield token if we finished a statement and there's no whitespaces
# It will count newline token as a non whitespace. In this context
# whitespace ignores newlines.
@@ -99,7 +99,7 @@ def process(self, stream):
self.level += self._change_splitlevel(ttype, value)

# Append the token to the current statement
self.tokens.append(sql.Token(ttype, value))
self.tokens.append(sql.Token(ttype, value, pos))

# Check if we get the end of a statement
# Issue762: Allow GO (or "GO 2") as statement splitter.
14 changes: 7 additions & 7 deletions sqlparse/filters/tokens.py
@@ -16,10 +16,10 @@ def __init__(self, case=None):
self.convert = getattr(str, case)

def process(self, stream):
for ttype, value in stream:
for ttype, value, pos in stream:
if ttype in self.ttype:
value = self.convert(value)
yield ttype, value
yield ttype, value, pos


class KeywordCaseFilter(_CaseFilter):
@@ -30,10 +30,10 @@ class IdentifierCaseFilter(_CaseFilter):
ttype = T.Name, T.String.Symbol

def process(self, stream):
for ttype, value in stream:
for ttype, value, pos in stream:
if ttype in self.ttype and value.strip()[0] != '"':
value = self.convert(value)
yield ttype, value
yield ttype, value, pos


class TruncateStringFilter:
@@ -42,9 +42,9 @@ def __init__(self, width, char):
self.char = char

def process(self, stream):
for ttype, value in stream:
for ttype, value, pos in stream:
if ttype != T.Literal.String.Single:
yield ttype, value
yield ttype, value, pos
continue

if value[:2] == "''":
@@ -56,4 +56,4 @@ def process(self, stream):

if len(inner) > self.width:
value = ''.join((quote, inner[:self.width], self.char, quote))
yield ttype, value
yield ttype, value, pos
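Any third-party filter written against the old 2-tuple stream protocol would need the same one-line change as the built-in filters above: unpack the position and re-yield it. A minimal sketch, assuming this branch's 3-tuple stream; the uppercase_keywords helper is hypothetical and not part of sqlparse:

from sqlparse import lexer
from sqlparse import tokens as T

def uppercase_keywords(stream):
    # stream yields (ttype, value, pos); pos must be passed through unchanged
    for ttype, value, pos in stream:
        if ttype in T.Keyword:
            value = value.upper()
        yield ttype, value, pos

for ttype, value, pos in uppercase_keywords(lexer.tokenize("select 1 from dual")):
    print(pos, ttype, value)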
14 changes: 7 additions & 7 deletions sqlparse/lexer.py
@@ -106,14 +106,14 @@ def is_keyword(self, value):

def get_tokens(self, text, encoding=None):
"""
Return an iterable of (tokentype, value) pairs generated from
Return an iterable of (tokentype, value, pos) tuples generated from
`text`. If `unfiltered` is set to `True`, the filtering mechanism
is bypassed even if filters are defined.

Also preprocess the text, i.e. expand tabs and strip it if
wanted and applies registered filters.

Split ``text`` into (tokentype, text) pairs.
Split ``text`` into (tokentype, text, pos) tuples.

``stack`` is the initial stack (default: ``['root']``)
"""
@@ -142,20 +142,20 @@ def get_tokens(self, text, encoding=None):
if not m:
continue
elif isinstance(action, tokens._TokenType):
yield action, m.group()
yield action, m.group(), pos
elif action is keywords.PROCESS_AS_KEYWORD:
yield self.is_keyword(m.group())
yield (*self.is_keyword(m.group()), pos)

consume(iterable, m.end() - pos - 1)
break
else:
yield tokens.Error, char
yield tokens.Error, char, pos


def tokenize(sql, encoding=None):
"""Tokenize sql.

Tokenize *sql* using the :class:`Lexer` and return a 2-tuple stream
of ``(token type, value)`` items.
Tokenize *sql* using the :class:`Lexer` and return a 3-tuple stream
of ``(token type, value, pos)`` items.
"""
return Lexer.get_default_instance().get_tokens(sql, encoding)
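As a quick check of the new return shape, a minimal sketch assuming this branch: for plain string input with no filters applied, the reported offset slices back to the token's text.

from sqlparse import lexer

sql = "select * from foo;"
for ttype, value, pos in lexer.tokenize(sql):
    # pos is the character offset of the token within the input string
    assert sql[pos:pos + len(value)] == value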
29 changes: 26 additions & 3 deletions sqlparse/sql.py
@@ -45,14 +45,16 @@ class Token:
the type of the token.
"""

__slots__ = ('value', 'ttype', 'parent', 'normalized', 'is_keyword',
'is_group', 'is_whitespace', 'is_newline')
__slots__ = ('value', 'ttype', 'parent', 'pos', 'length', 'normalized',
'is_keyword', 'is_group', 'is_whitespace', 'is_newline')

def __init__(self, ttype, value):
def __init__(self, ttype, value, pos=None):
value = str(value)
self.value = value
self.ttype = ttype
self.parent = None
self.pos = pos
self.length = len(value)
self.is_group = False
self.is_keyword = ttype in T.Keyword
self.is_whitespace = self.ttype in T.Whitespace
@@ -163,6 +165,27 @@ def __init__(self, tokens=None):
super().__init__(None, str(self))
self.is_group = True

@property
def pos(self):
if len(self.tokens) > 0:
return self.tokens[0].pos

@property
def length(self):
if len(self.tokens) > 0:
first, last = self.tokens[0], self.tokens[-1]
if first.pos is not None and last.pos is not None:
return last.length + (last.pos - first.pos)

return len(str(self))

# this is a bit of a hack to avoid problems with the super constructor
# trying to set these attributes, which we want to compute dynamically
@pos.setter
def pos(self, value): ...
@length.setter
def length(self, value): ...
Review comment on lines +182 to +187, by @dylanscott (author), Oct 18, 2024:
I initially had the Token constructor do a hasattr(self, 'pos') check before assigning, which I didn't love either. But that approach broke one test where deepcopy was called on a parsed statement, whereas this approach is robust to that case.

It seemed worth making these properties dynamic given that TokenList exposes methods for inserting tokens.


def __str__(self):
return ''.join(token.value for token in self.flatten())

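To show the new attributes end to end, a short usage sketch mirroring tests/test_source_positions.py below (assuming this branch): a group's pos comes from its first leaf token and its length spans through its last leaf, so slicing the original query recovers the group's source text.

import sqlparse
from sqlparse.sql import Identifier

sql = "select * from (select * from dual) q0"
stmt = sqlparse.parse(sql)[0]
for token in stmt.tokens:
    if isinstance(token, Identifier):
        # prints "(select * from dual) q0"
        print(sql[token.pos:token.pos + token.length])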
136 changes: 136 additions & 0 deletions tests/test_source_positions.py
@@ -0,0 +1,136 @@
from typing import List, Optional, Tuple, Type, Union

from sqlparse import parse
from sqlparse.sql import Identifier, IdentifierList, Statement, Token
from sqlparse.tokens import (
CTE,
DML,
Comparison,
Keyword,
Number,
_TokenType as TokenType,
)


def itertokens(token: Token):
yield token
if token.is_group:
for child in token.tokens:
yield from itertokens(child)


# allow matching by Token subclass or ttype
TokenClassOrType = Union[TokenType, Type[Token]]


def parsed_tokens_with_sources(
sql: str, types: Tuple[TokenClassOrType, ...]
) -> List[Tuple[TokenClassOrType, str, str]]:
# given a query, parses it, iterates over all the tokens it contains, and
# for each token that matches `types`, returns a tuple of the matched token
# type, the token's value, and the source of the token retrieved by slicing
# into the original query using the token's `pos` and `length` attributes

def matches(token: Token) -> Optional[TokenClassOrType]:
for class_or_type in types:
if isinstance(class_or_type, TokenType):
if token.ttype is class_or_type:
return class_or_type
elif isinstance(token, class_or_type):
return class_or_type

def get_source(token: Token) -> str:
return sql[token.pos : token.pos + token.length]

statements = parse(sql)
return [
(match, token.value, get_source(token))
for statement in statements
for token in itertokens(statement)
if (match := matches(token))
]


def test_simple_query():
assert parsed_tokens_with_sources(
"select * from foo;", (DML, Identifier, Keyword, Statement)
) == [
(Statement, "select * from foo;", "select * from foo;"),
(DML, "select", "select"),
(Keyword, "from", "from"),
(Identifier, "foo", "foo"),
]


def test_multiple_statements():
stmt1 = "select * from foo;"
stmt2 = "\nselect *\nfrom bar;"
assert parsed_tokens_with_sources(
stmt1 + stmt2, (DML, Identifier, Keyword, Statement)
) == [
(Statement, stmt1, stmt1),
(DML, "select", "select"),
(Keyword, "from", "from"),
(Identifier, "foo", "foo"),
(Statement, stmt2, stmt2),
(DML, "select", "select"),
(Keyword, "from", "from"),
(Identifier, "bar", "bar"),
]


def test_subselect():
stmt = """select a0, b0, c0, d0, e0 from
(select * from dual) q0 where 1=1 and 2=2"""
assert parsed_tokens_with_sources(
stmt,
(
DML,
Comparison,
Identifier,
IdentifierList,
Keyword,
Number.Integer,
Statement,
),
) == [
(Statement, stmt, stmt),
(DML, "select", "select"),
(IdentifierList, "a0, b0, c0, d0, e0", "a0, b0, c0, d0, e0"),
(Identifier, "a0", "a0"),
(Identifier, "b0", "b0"),
(Identifier, "c0", "c0"),
(Identifier, "d0", "d0"),
(Identifier, "e0", "e0"),
(Keyword, "from", "from"),
(Identifier, "(select * from dual) q0", "(select * from dual) q0"),
(DML, "select", "select"),
(Keyword, "from", "from"),
(Identifier, "dual", "dual"),
(Identifier, "q0", "q0"),
(Keyword, "where", "where"),
(Number.Integer, "1", "1"),
(Comparison, "=", "="),
(Number.Integer, "1", "1"),
(Keyword, "and", "and"),
(Number.Integer, "2", "2"),
(Comparison, "=", "="),
(Number.Integer, "2", "2"),
]


def test_cte():
stmt = """WITH foo AS (SELECT 1, 2, 3)
SELECT * FROM foo;"""
assert parsed_tokens_with_sources(
stmt, (CTE, DML, Identifier, Keyword, Statement)
) == [
(Statement, stmt, stmt),
(CTE, "WITH", "WITH"),
(Identifier, "foo AS (SELECT 1, 2, 3)", "foo AS (SELECT 1, 2, 3)"),
(Keyword, "AS", "AS"),
(DML, "SELECT", "SELECT"),
(DML, "SELECT", "SELECT"),
(Keyword, "FROM", "FROM"),
(Identifier, "foo", "foo"),
]
8 changes: 4 additions & 4 deletions tests/test_tokenize.py
@@ -14,16 +14,16 @@ def test_tokenize_simple():
assert isinstance(stream, types.GeneratorType)
tokens = list(stream)
assert len(tokens) == 8
assert len(tokens[0]) == 2
assert tokens[0] == (T.Keyword.DML, 'select')
assert tokens[-1] == (T.Punctuation, ';')
assert len(tokens[0]) == 3
assert tokens[0] == (T.Keyword.DML, 'select', 0)
assert tokens[-1] == (T.Punctuation, ';', 17)


def test_tokenize_backticks():
s = '`foo`.`bar`'
tokens = list(lexer.tokenize(s))
assert len(tokens) == 3
assert tokens[0] == (T.Name, '`foo`')
assert tokens[0] == (T.Name, '`foo`', 0)


@pytest.mark.parametrize('s', ['foo\nbar\n', 'foo\rbar\r',