Skip to content

Commit

Permalink
Add custom tsquery from websearch function and related tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adelavega committed Oct 29, 2024
1 parent e72d9c9 commit 243c150
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 2 deletions.
7 changes: 5 additions & 2 deletions store/neurostore/resources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from ..core import cache
from ..database import db
from .utils import get_current_user
from .utils import get_current_user, validate_search_query, search_to_tsquery
from ..models import (
StudysetStudy,
AnnotationAnalysis,
Expand Down Expand Up @@ -613,7 +613,10 @@ def search(self):
if s is not None and s.isdigit():
q = q.filter_by(pmid=s)
elif s is not None and self._fulltext_fields:
tsquery = sa.func.websearch_to_tsquery("english", s)
valid = validate_search_query(s)
if not valid:
abort(400, description=valid)
tsquery = search_to_tsquery(s)
q = q.filter(m._ts_vector.op("@@")(tsquery))

# Alternatively (or in addition), search on individual fields.
Expand Down
179 changes: 179 additions & 0 deletions store/neurostore/resources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,182 @@ class ClassView(cls):
ClassView.__name__ = cls.__name__

return ClassView


class _InvalidQuery(str):
    """Validation error message that is falsy in a boolean context.

    A plain non-empty ``str`` is truthy, so a caller written as
    ``if not validate_search_query(q): abort(...)`` would never reject
    anything.  This subclass still compares equal to (and serializes as) the
    message text, but evaluates as ``False``.
    """

    def __bool__(self) -> bool:
        return False


def validate_search_query(query: str) -> "bool | str":
    """
    Validate a search query string.

    Args:
        query (str): The query string to validate.

    Returns:
        ``True`` if the query is valid.  Otherwise a *falsy* string holding
        the reason the query was rejected (``'Unmatched parentheses'`` or
        ``'Query cannot end with an operator'``), so the result can be used
        both as a truth value and as an error description.
    """
    # Check for valid parentheses
    if not validate_parentheses(query):
        return _InvalidQuery('Unmatched parentheses')

    # Check for valid query end
    if not validate_query_end(query):
        return _InvalidQuery('Query cannot end with an operator')

    return True


def validate_parentheses(query: str) -> bool:
    """
    Check that every parenthesis in a query string is properly paired.

    Every character is inspected, including characters inside quoted
    phrases.

    Args:
        query (str): The query string to validate.

    Returns:
        bool: True if each ')' closes an earlier '(' and no '(' is left
        open, False otherwise.
    """
    depth = 0  # number of '(' still waiting for their ')'
    for symbol in query:
        if symbol == '(':
            depth += 1
            continue
        if symbol != ')':
            continue
        if depth == 0:
            # A ')' with nothing open to close
            return False
        depth -= 1
    return depth == 0


def validate_query_end(query: str) -> bool:
    """
    Check that a query does not end with a dangling boolean operator.

    The comparison is case-insensitive because the tsquery conversion
    upper-cases the whole query, so a trailing ``or`` is just as much an
    operator as ``OR``.  Words are split on any whitespace, not only single
    spaces.

    Args:
        query (str): The query string to validate.

    Returns:
        bool: False if the last whitespace-separated word is AND, OR or NOT
        (in any letter case); True otherwise, including for an empty or
        whitespace-only query.
    """
    operators = ('AND', 'OR', 'NOT')

    words = query.split()
    if not words:
        return True
    return words[-1].upper() not in operators


def count_chars(target, query: str) -> int:
    """Count occurrences of ``target`` in ``query`` outside double quotes.

    Each ``"`` flips the quoted state before the character itself is
    considered, so an opening quote is treated as quoted and a closing quote
    as unquoted.
    """

    def unquoted(text):
        # Yield only the characters that lie outside a "..." phrase
        inside = False
        for symbol in text:
            if symbol == '"':
                inside = not inside
            if not inside:
                yield symbol

    return sum(1 for symbol in unquoted(query) if symbol == target)


def _combine_parentheses(tokens: list) -> list:
    """
    Merge the tokens of each top-level parenthesized group into one token.

    Args:
        tokens (list): Tokens produced by splitting the query.

    Returns:
        list: Tokens where every run that starts with a token opening a
        parenthesis and ends once the parentheses balance again has been
        joined with single spaces.  An unterminated group is emitted as-is
        (fallback for invalid input).
    """
    combined = []
    group = []  # tokens of the parenthesized group currently being collected
    depth = 0   # unclosed '(' inside the current group
    for token in tokens:
        # Parentheses inside double-quoted phrases are ignored by count_chars
        balance = count_chars('(', token) - count_chars(')', token)
        if group:
            group.append(token)
            depth += balance
            if depth < 1:
                # Group is closed: emit it as a single token
                combined.append(' '.join(group))
                group = []
                depth = 0
        elif token[0] == '(' and balance > 0:
            # Token opens a group that is not closed within the same token
            depth += balance
            group.append(token)
        else:
            combined.append(token)

    if group:
        combined.append(' '.join(group))

    return combined


def pubmed_to_tsquery(query: str) -> str:
    """
    Convert a PubMed-like search query to PostgreSQL tsquery text.

    The query is upper-cased, then:

    * single- or double-quoted phrases become their words joined with the
      ``<->`` (followed-by) operator;
    * parenthesized groups are converted recursively and keep their
      parentheses;
    * ``AND`` / ``OR`` / ``NOT`` become ``&`` / ``|`` / negation;
    * an implicit ``&`` is inserted between adjacent terms that are not
      explicitly connected.

    Negation is unary: after a term it is emitted as ``&!`` (``a NOT b`` ->
    ``A &! B``), while after another operator or at the start of the query
    it is emitted as a bare ``!`` (``a AND NOT b`` -> ``A & ! B``) so that
    two binary operators never end up next to each other.

    Args:
        query (str): The search query.

    Returns:
        str: The tsquery equivalent, tokens separated by single spaces.
    """
    operators = ('&', '|', '!', '&!')
    replacements = {'AND': '&', 'OR': '|', 'NOT': '!'}

    query = query.upper()  # Ensure uniformity

    # Step 1: split into tokens; a quoted phrase stays one token, everything
    # else is split on whitespace.
    tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)

    # Step 2: collapse each parenthesized group into a single token.
    tokens = _combine_parentheses(tokens)

    # Step 3: translate each token.
    for i, token in enumerate(tokens):
        if token[0] == '(' and token[-1] == ')':
            # Recursive: convert the contents of the parentheses
            tokens[i] = '(' + pubmed_to_tsquery(token[1:-1]) + ')'
        elif token[0] in ('"', "'"):
            # Quoted phrase: keep only the words and chain them with <->
            tokens[i] = '<->'.join(re.findall(r'\w+', token))
        else:
            tokens[i] = replacements.get(token, token)

    # Step 4: join, adding the implicit & between unconnected terms.
    processed_tokens = []
    last_token = None
    for token in tokens:
        token = token.strip()
        if token == '':
            continue  # e.g. a quoted phrase that contained no words

        after_term = last_token is not None and last_token not in operators
        if token == '!':
            # Only a negation that directly follows a term needs the
            # conjunction; otherwise the preceding operator already
            # connects it.
            token = '&!' if after_term else '!'
        elif after_term and token not in operators:
            processed_tokens.append('&')

        processed_tokens.append(token)
        last_token = token

    return ' '.join(processed_tokens)


# The resource module and the tests in this commit import the converter as
# ``search_to_tsquery``; expose it under that name as well.
search_to_tsquery = pubmed_to_tsquery
54 changes: 54 additions & 0 deletions store/neurostore/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest

from ..utils import search_to_tsquery, validate_search_query


# (query, expected validation message) pairs for queries that must be rejected.
invalid_queries = [
    ('("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ', 'Unmatched parentheses'),
    ('"autism" OR "ASD" OR "autistic" OR ', 'Query cannot end with an operator'),
    ('(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") OR ("ASD")) AND (("decision*" OR "Dec', 'Unmatched parentheses'),
]

# (query, expected tsquery text) pairs for queries that must convert cleanly.
valid_queries = [
    ('"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or "Mild Neurocognitive Disorder"',
     'MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | MILD<->NEUROCOGNITIVE<->DISORDER'),
    ('("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
     '(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)'),
    ('stroop and depression or back and depression or go',
     'STROOP & DEPRESSION | BACK & DEPRESSION | GO'),
    ('("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR "selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR "error" OR "outcome" OR "punishment" OR "reinforcement"))',
     '(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION | VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | REINFORCEMENT))'),
    ('"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or "Phonological Processing Disorder" or "Word Blindness"',
     'DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS'),
    ('emotion and pain -physical -touch',
     'EMOTION & PAIN & -PHYSICAL & -TOUCH'),
    # The comma after this entry is required: without it Python tries to
    # call the tuple with the next one and the module fails to import.
    ('("Schizophrenia"[Mesh] OR schizophrenia )',
     '(SCHIZOPHRENIA & [MESH] | SCHIZOPHRENIA)'),
    ('Bipolar Disorder',
     'BIPOLAR & DISORDER'),
    ('"quchi" or "LI11"',
     'QUCHI | LI11'),
    ('"rubber hand illusion"',
     'RUBBER<->HAND<->ILLUSION'),
]

# Queries for which the conversion is expected to raise ValueError.
error_queries = [
    "[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]"
]

# Every rejected query with its message, plus every valid query paired with True.
validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]


@pytest.mark.parametrize("query, expected", valid_queries)
def test_search_to_tsquery(query, expected):
    """Each valid query converts to exactly the expected tsquery text."""
    converted = search_to_tsquery(query)
    assert converted == expected


@pytest.mark.parametrize("query, expected", validate_queries)
def test_validate_search_query(query, expected):
    """Validation yields the rejection message, or True for a valid query.

    Parametrized over ``validate_queries`` so that the valid queries are
    checked for a ``True`` result too, not only the rejected ones.
    """
    assert validate_search_query(query) == expected

@pytest.mark.parametrize("query", error_queries)
def test_search_to_tsquery_error(query):
    """Converting an unsupported query is expected to raise ValueError.

    NOTE(review): the converter shown alongside this test contains no
    explicit ``raise ValueError`` - confirm where the error is meant to
    come from.
    """
    with pytest.raises(ValueError):
        search_to_tsquery(query)

0 comments on commit 243c150

Please sign in to comment.