Skip to content

Commit

Permalink
Merge pull request #1 from skryzh/format_fix
Browse files Browse the repository at this point in the history
[fix] issue262 fixing problem w/ removing SQL-Hints w/ comments
  • Loading branch information
skryzh authored Oct 18, 2024
2 parents a2c7e64 + bf68cb1 commit b0666d8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ MANIFEST
.cache/
*.egg-info/
htmlcov/
.pytest_cache
.pytest_cache
**/.vscode
**/.env
35 changes: 29 additions & 6 deletions sqlparse/filters/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class StripCommentsFilter:

@staticmethod
def _process(tlist):
def get_next_comment():
def get_next_comment(idx=-1):
# TODO(andi) Comment types should be unified, see related issue38
return tlist.token_next_by(i=sql.Comment, t=T.Comment)
return tlist.token_next_by(i=sql.Comment, t=T.Comment, idx=idx)

def _get_insert_token(token):
"""Returns either a whitespace or the line breaks from token."""
Expand All @@ -31,15 +31,35 @@ def _get_insert_token(token):
else:
return sql.Token(T.Whitespace, ' ')

sql_hints = (T.Comment.Multiline.Hint, T.Comment.Single.Hint)
tidx, token = get_next_comment()
while token:
# skipping token remove if token is a SQL-Hint. issue262
is_sql_hint = False
if token.ttype in sql_hints:
is_sql_hint = True
elif isinstance(token, sql.Comment):
comment_tokens = token.tokens
if len(comment_tokens) > 0:
if comment_tokens[0].ttype in sql_hints:
is_sql_hint = True

if is_sql_hint:
# using current index as start index to search next token for
# preventing infinite loop in cases when token type is a
# "SQL-Hint"and has to be skipped
tidx, token = get_next_comment(idx=tidx)
continue

pidx, prev_ = tlist.token_prev(tidx, skip_ws=False)
nidx, next_ = tlist.token_next(tidx, skip_ws=False)
# Replace by whitespace if prev and next exist and if they're not
# whitespaces. This doesn't apply if prev or next is a parenthesis.
if (prev_ is None or next_ is None
or prev_.is_whitespace or prev_.match(T.Punctuation, '(')
or next_.is_whitespace or next_.match(T.Punctuation, ')')):
if (
prev_ is None or next_ is None
or prev_.is_whitespace or prev_.match(T.Punctuation, '(')
or next_.is_whitespace or next_.match(T.Punctuation, ')')
):
# Insert a whitespace to ensure the following SQL produces
# a valid SQL (see #425).
if prev_ is not None and not prev_.match(T.Punctuation, '('):
Expand All @@ -48,7 +68,10 @@ def _get_insert_token(token):
else:
tlist.tokens[tidx] = _get_insert_token(token)

tidx, token = get_next_comment()
# using current index as start index to search next token for
# preventing infinite loop in cases when token type is a
# "SQL-Hint"and has to be skipped
tidx, token = get_next_comment(idx=tidx)

def process(self, stmt):
[self.process(sgroup) for sgroup in stmt.get_sublists()]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ def test_strip_comments_single(self):
'from foo--comment\nf'
res = sqlparse.format(sql, strip_comments=True)
assert res == 'select a\nfrom foo\nf'
# SQL-Hints have to be preserved
sql = 'select --+full(u)'
res = sqlparse.format(sql, strip_comments=True)
assert res == sql
sql = '#+ hint\nselect * from foo'
res = sqlparse.format(sql, strip_comments=True)
assert res == sql
sql = 'select --+full(u)\n--comment simple'
res = sqlparse.format(sql, strip_comments=True)
assert res == 'select --+full(u)\n'
sql = '#+ hint\nselect * from foo\n# comment simple'
res = sqlparse.format(sql, strip_comments=True)
assert res == '#+ hint\nselect * from foo\n'

def test_strip_comments_invalid_option(self):
sql = 'select-- foo\nfrom -- bar\nwhere'
Expand All @@ -83,6 +96,13 @@ def test_strip_comments_multi(self):
sql = 'select (/* sql /* starts here */ select 2)'
res = sqlparse.format(sql, strip_comments=True, strip_whitespace=True)
assert res == 'select (select 2)'
# SQL-Hints have to be preserved
sql = 'SELECT /*+cluster(T)*/* FROM T_EEE T where A >:1'
res = sqlparse.format(sql, strip_comments=True)
assert res == sql
sql = 'insert /*+ DIRECT */ into sch.table_name as select * from foo'
res = sqlparse.format(sql, strip_comments=True)
assert res == sql

def test_strip_comments_preserves_linebreak(self):
sql = 'select * -- a comment\r\nfrom foo'
Expand Down

0 comments on commit b0666d8

Please sign in to comment.