ENH: Add custom tsquery from websearch function and related tests #838

Open
wants to merge 1 commit into master
7 changes: 5 additions & 2 deletions store/neurostore/resources/base.py
@@ -21,7 +21,7 @@

from ..core import cache
from ..database import db
from .utils import get_current_user
from .utils import get_current_user, validate_search_query, search_to_tsquery
from ..models import (
StudysetStudy,
AnnotationAnalysis,
@@ -613,7 +613,10 @@ def search(self):
if s is not None and s.isdigit():
q = q.filter_by(pmid=s)
elif s is not None and self._fulltext_fields:
tsquery = sa.func.websearch_to_tsquery("english", s)
valid = validate_search_query(s)
if valid is not True:
abort(400, description=valid)
tsquery = search_to_tsquery(s)
q = q.filter(m._ts_vector.op("@@")(tsquery))

# Alternatively (or in addition), search on individual fields.
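The change above first validates the raw search string and only then converts it to a tsquery string. A rough standalone sketch of that flow, using the helpers this PR adds (the `run_search` wrapper and its return values are illustrative, not part of the diff):

from neurostore.resources.utils import search_to_tsquery, validate_search_query

def run_search(s):
    valid = validate_search_query(s)
    if valid is not True:
        return 400, valid                 # the endpoint aborts with this message
    return 200, search_to_tsquery(s)      # string handed to the @@ filter

run_search('"pain" AND (')        # -> (400, 'Unmatched parentheses')
run_search('"pain" OR "touch"')   # -> (200, 'PAIN | TOUCH')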
179 changes: 179 additions & 0 deletions store/neurostore/resources/utils.py
@@ -44,3 +44,182 @@ class ClassView(cls):
ClassView.__name__ = cls.__name__

return ClassView


def validate_search_query(query: str):
"""
Validate a search query string.

Args:
query (str): The query string to validate.

Returns:
True if the query is valid, otherwise a string describing the problem.
"""
# Check for valid parentheses
if not validate_parentheses(query):
return 'Unmatched parentheses'

# Check for valid query end
if not validate_query_end(query):
return 'Query cannot end with an operator'

return True
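

# Illustrative behaviour, traced by hand from the checks above:
#   validate_search_query('("fMRI" AND "memory")')  ->  True
#   validate_search_query('("fMRI" AND "memory"')   ->  'Unmatched parentheses'
#   validate_search_query('"fMRI" AND')             ->  'Query cannot end with an operator'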


def validate_parentheses(query: str) -> bool:
"""
Validate the parentheses in a query string.

Args:
query (str): The query string to validate.

Returns:
bool: True if parentheses are valid, False otherwise.
"""
stack = []
for char in query:
if char == '(':
stack.append(char)
elif char == ')':
if not stack:
return False # Unmatched closing parenthesis
stack.pop()
return not stack # Ensure all opening parentheses are closed
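

# For example:
#   validate_parentheses('(a OR (b AND c))')  ->  True
#   validate_parentheses('(a OR b))')         ->  False  (extra closing parenthesis)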


def validate_query_end(query: str) -> bool:
""" Query should not end with an operator """
operators = ('AND', 'OR', 'NOT')

if query.strip().split(' ')[-1].upper() in operators:
return False
return True
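

# For example:
#   validate_query_end('"pain" AND "touch"')  ->  True
#   validate_query_end('"pain" AND')          ->  False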


def count_chars(target: str, query: str) -> int:
""" Count occurrences of `target` in a query string,
excluding those inside quoted phrases. """
count = 0
in_quotes = False
for char in query:
if char == '"':
in_quotes = not in_quotes
if char == target and not in_quotes:
count += 1
return count
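

# For example, parentheses inside a quoted phrase are not counted:
#   count_chars('(', '("a (b)" OR c)')  ->  1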


def search_to_tsquery(query: str) -> str:
"""
Convert a PubMed-like search query to PostgreSQL tsquery format,
grouping both single-quoted and double-quoted text with the <-> operator
for proximity search.

It also inserts an implicit & between terms that are not explicitly joined
by an operator, and handles NOT terms.

Args:
query (str): The search query.

Returns:
str: The PostgreSQL tsquery equivalent.
"""

query = query.upper()  # Normalize case so terms and operators compare uniformly

# Step 1: Split into tokens (preserving quoted phrases)
# Regex pattern: match quoted phrases or non-space sequences
tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)

# Step 2: Combine tokens in parentheses into single tokens
def combine_parentheses(tokens: list) -> list:
"""
Combine tokens within parentheses into a single token.

Args:
tokens (list): List of tokens to process.

Returns:
list: Processed list with tokens inside parentheses combined.
"""
combined_tokens = []
buffer = []
paren_count = 0
for token in tokens:
# If buffer is not empty, we are inside parentheses
if len(buffer) > 0:
buffer.append(token)

# Adjust the count of parentheses
paren_count += count_chars('(', token) - count_chars(')', token)

if paren_count < 1:
# Combine all tokens in parentheses
combined_tokens.append(' '.join(buffer))
buffer = [] # Clear the buffer
paren_count = 0

else:
n_paren = count_chars('(', token) - count_chars(')', token)
# If not in parentheses, but token contains opening parentheses
# Start capturing tokens inside parentheses
if token[0] == '(' and n_paren > 0:
paren_count += n_paren
buffer.append(token) # Start capturing tokens in parens
else:
combined_tokens.append(token)

# If the list ends without a closing parenthesis (invalid input)
# append buffer contents (fallback)
if buffer:
combined_tokens.append(' '.join(buffer))

return combined_tokens
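
# For instance, combine_parentheses(['(FMRI', 'OR', 'EEG)', 'AND', 'PAIN'])
# returns ['(FMRI OR EEG)', 'AND', 'PAIN'].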

tokens = combine_parentheses(tokens)
for i, token in enumerate(tokens):
if token[0] == "(" and token[-1] == ")":
# Recursively process the contents of the parentheses
token_res = search_to_tsquery(token[1:-1])
token = '(' + token_res + ')'
tokens[i] = token

# Step 3: Handle both single-quoted and double-quoted phrases,
# grouping them with <-> (proximity operator)
elif token[0] in ('"', "'"):
# Split quoted text into individual words and join with <-> for
# proximity search
words = re.findall(r'\w+', token)
tokens[i] = '<->'.join(words)

# Step 4: Replace logical operators AND, OR, NOT
else:
if token == 'AND':
tokens[i] = '&'
elif token == 'OR':
tokens[i] = '|'
elif token == 'NOT':
tokens[i] = '&!'

processed_tokens = []
last_token = None
for token in tokens:
# Step 5: Add & between consecutive terms that aren't already
# connected by an operator
stripped_token = token.strip()

if stripped_token == '':
continue # Ignore empty tokens from splitting

if last_token and last_token not in ('&', '|', '!', '&!'):
if stripped_token not in ('&', '|', '!', '&!'):
# Insert an implicit AND (&) between two non-operator tokens
processed_tokens.append('&')

processed_tokens.append(stripped_token)
last_token = stripped_token

return ' '.join(processed_tokens)
54 changes: 54 additions & 0 deletions store/neurostore/tests/test_utils.py
@@ -0,0 +1,54 @@
import pytest

from ..resources.utils import search_to_tsquery, validate_search_query


invalid_queries = [
('("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ', 'Unmatched parentheses'),
('"autism" OR "ASD" OR "autistic" OR ', 'Query cannot end with an operator'),
('(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") OR ("ASD")) AND (("decision*" OR "Dec', 'Unmatched parentheses')
]

valid_queries = [
('"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or "Mild Neurocognitive Disorder"',
'MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | MILD<->NEUROCOGNITIVE<->DISORDER'),
('("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
'(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)'),
('stroop and depression or back and depression or go',
'STROOP & DEPRESSION | BACK & DEPRESSION | GO'),
('("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR "selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR "error" OR "outcome" OR "punishment" OR "reinforcement"))',
'(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION | VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | REINFORCEMENT))'),
('"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or "Phonological Processing Disorder" or "Word Blindness"',
'DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS'),
('emotion and pain -physical -touch',
'EMOTION & PAIN & -PHYSICAL & -TOUCH'),
('("Schizophrenia"[Mesh] OR schizophrenia )',
'(SCHIZOPHRENIA & [MESH] | SCHIZOPHRENIA)'),
('Bipolar Disorder',
'BIPOLAR & DISORDER'),
('"quchi" or "LI11"',
'QUCHI | LI11'),
('"rubber hand illusion"',
'RUBBER<->HAND<->ILLUSION'),
]

error_queries = [
"[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]"
]

validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]


@pytest.mark.parametrize("query, expected", valid_queries)
def test_search_to_tsquery(query, expected):
assert search_to_tsquery(query) == expected


@pytest.mark.parametrize("query, expected", invalid_queries)
def test_validate_search_query(query, expected):
assert validate_search_query(query) == expected

@pytest.mark.parametrize("query", error_queries)
def test_search_to_tsquery_error(query):
with pytest.raises(ValueError):
search_to_tsquery(query)