diff --git a/store/neurostore/resources/base.py b/store/neurostore/resources/base.py
index dc18f441..4f7bba7d 100644
--- a/store/neurostore/resources/base.py
+++ b/store/neurostore/resources/base.py
@@ -21,7 +21,7 @@
 
 from ..core import cache
 from ..database import db
-from .utils import get_current_user
+from .utils import get_current_user, validate_search_query, search_to_tsquery
 from ..models import (
     StudysetStudy,
     AnnotationAnalysis,
@@ -613,7 +613,10 @@ def search(self):
         if s is not None and s.isdigit():
             q = q.filter_by(pmid=s)
         elif s is not None and self._fulltext_fields:
-            tsquery = sa.func.websearch_to_tsquery("english", s)
+            valid = validate_search_query(s)
+            if valid is not True:
+                abort(400, description=valid)
+            tsquery = search_to_tsquery(s)
             q = q.filter(m._ts_vector.op("@@")(tsquery))
 
         # Alternatively (or in addition), search on individual fields.
diff --git a/store/neurostore/resources/utils.py b/store/neurostore/resources/utils.py
index 9fa0fc68..962906ee 100644
--- a/store/neurostore/resources/utils.py
+++ b/store/neurostore/resources/utils.py
@@ -44,3 +44,207 @@ class ClassView(cls):
     ClassView.__name__ = cls.__name__
 
     return ClassView
+
+
+def validate_search_query(query: str):
+    """
+    Validate a search query string.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool or str: True if the query is valid; otherwise a message
+        describing the first problem found. The message is a non-empty
+        string, so callers must test the result with ``is True``.
+    """
+    # Check for valid parentheses
+    if not validate_parentheses(query):
+        return 'Unmatched parentheses'
+
+    # Check for valid query end
+    if not validate_query_end(query):
+        return 'Query cannot end with an operator'
+
+    return True
+
+
+def validate_parentheses(query: str) -> bool:
+    """
+    Validate the parentheses in a query string.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool: True if parentheses are valid, False otherwise.
+    """
+    stack = []
+    for char in query:
+        if char == '(':
+            stack.append(char)
+        elif char == ')':
+            if not stack:
+                return False  # Unmatched closing parenthesis
+            stack.pop()
+    return not stack  # Ensure all opening parentheses are closed
+
+
+def validate_query_end(query: str) -> bool:
+    """
+    Check that a query does not end with a logical operator.
+
+    The check is case-insensitive because the converter upper-cases the
+    query before it interprets AND, OR and NOT.
+
+    Args:
+        query (str): The query string to validate.
+
+    Returns:
+        bool: False if the last word is an operator, True otherwise.
+    """
+    operators = ('AND', 'OR', 'NOT')
+
+    words = query.split()
+    if words and words[-1].upper() in operators:
+        return False
+    return True
+
+
+def count_chars(target, query: str) -> int:
+    """ Count the number of chars in a query string.
+    Excluding those in quoted phrases."""
+    count = 0
+    in_quotes = False
+    for char in query:
+        if char == '"':
+            in_quotes = not in_quotes
+        if char == target and not in_quotes:
+            count += 1
+    return count
+
+
+def pubmed_to_tsquery(query: str) -> str:
+    """
+    Convert a PubMed-like search query to PostgreSQL tsquery format,
+    grouping both single-quoted and double-quoted text with the <-> operator
+    for proximity search.
+
+    Additionally, automatically adds & between non-explicitly connected terms
+    and handles NOT terms.
+
+    Args:
+        query (str): The search query.
+
+    Returns:
+        str: The PostgreSQL tsquery equivalent.
+    """
+
+    query = query.upper()  # Ensure uniformity
+
+    # Step 1: Split into tokens (preserving quoted phrases)
+    # Regex pattern: match quoted phrases or non-space sequences
+    tokens = re.findall(r'"[^"]*"|\'[^\']*\'|\S+', query)
+
+    # Step 2: Combine tokens in parentheses into single tokens
+    def combine_parentheses(tokens: list) -> list:
+        """
+        Combine tokens within parentheses into a single token.
+
+        Args:
+            tokens (list): List of tokens to process.
+
+        Returns:
+            list: Processed list with tokens inside parentheses combined.
+        """
+        combined_tokens = []
+        buffer = []
+        paren_count = 0
+        for token in tokens:
+            # If buffer is not empty, we are inside parentheses
+            if len(buffer) > 0:
+                buffer.append(token)
+
+                # Adjust the count of parentheses
+                paren_count += count_chars('(', token) - count_chars(')', token)
+
+                if paren_count < 1:
+                    # Combine all tokens in parentheses
+                    combined_tokens.append(' '.join(buffer))
+                    buffer = []  # Clear the buffer
+                    paren_count = 0
+
+            else:
+                n_paren = count_chars('(', token) - count_chars(')', token)
+                # If not in parentheses, but token contains opening parentheses
+                # Start capturing tokens inside parentheses
+                if token[0] == '(' and n_paren > 0:
+                    paren_count += n_paren
+                    buffer.append(token)  # Start capturing tokens in parens
+                else:
+                    combined_tokens.append(token)
+
+        # If the list ends without a closing parenthesis (invalid input)
+        # append buffer contents (fallback)
+        if buffer:
+            combined_tokens.append(' '.join(buffer))
+
+        return combined_tokens
+
+    tokens = combine_parentheses(tokens)
+    for i, token in enumerate(tokens):
+        if token[0] == "(" and token[-1] == ")":
+            # RECURSIVE: Process the contents of the parentheses
+            token_res = pubmed_to_tsquery(token[1:-1])
+            token = '(' + token_res + ')'
+            tokens[i] = token
+
+        # Step 3: Handle both single-quoted and double-quoted phrases,
+        # grouping them with <-> (proximity operator)
+        elif token[0] in ('"', "'"):
+            # Split quoted text into individual words and join with <-> for
+            # proximity search
+            words = re.findall(r'\w+', token)
+            tokens[i] = '<->'.join(words)
+
+        # Step 4: Replace logical operators AND, OR, NOT
+        else:
+            if token == 'AND':
+                tokens[i] = '&'
+            elif token == 'OR':
+                tokens[i] = '|'
+            elif token == 'NOT':
+                tokens[i] = '&!'
+
+    operators = ('&', '|', '!', '&!')
+    processed_tokens = []
+    last_token = None
+    for token in tokens:
+        # Step 5: Add & between consecutive terms that aren't already
+        # connected by an operator
+        stripped_token = token.strip()
+
+        if stripped_token == '':
+            continue  # Ignore empty tokens from splitting
+
+        if stripped_token == '&!' and (
+            last_token is None or last_token in operators
+        ):
+            # NOT at the start of the query, or right after another
+            # operator ("a AND NOT b"), must not carry its own '&':
+            # that would emit the invalid sequence "& &!".
+            stripped_token = '!'
+
+        if last_token and last_token not in operators:
+            if stripped_token not in operators:
+                # Insert an implicit AND (&) between two non-operator tokens
+                processed_tokens.append('&')
+
+        processed_tokens.append(stripped_token)
+        last_token = stripped_token
+
+    return ' '.join(processed_tokens)
+
+
+# Name under which resources/base.py and the tests import the converter.
+search_to_tsquery = pubmed_to_tsquery
diff --git a/store/neurostore/tests/test_utils.py b/store/neurostore/tests/test_utils.py
new file mode 100644
index 00000000..3e9cf35c
--- /dev/null
+++ b/store/neurostore/tests/test_utils.py
@@ -0,0 +1,57 @@
+import pytest
+
+from ..resources.utils import search_to_tsquery, validate_search_query
+
+
+invalid_queries = [
+    ('("autism" OR "ASD" OR "autistic") AND (("decision*" OR "choice*" ', 'Unmatched parentheses'),
+    ('"autism" OR "ASD" OR "autistic" OR ', 'Query cannot end with an operator'),
+    ('(("Autism Spectrum Disorder" OR "autism spectrum disorder") OR ("Autism" OR "autism") OR ("ASD")) AND (("decision*" OR "Dec', 'Unmatched parentheses')
+]
+
+valid_queries = [
+    ('"Mild Cognitive Impairment" or "Early Cognitive Decline" or "Pre-Dementia" or "Mild Neurocognitive Disorder"',
+     'MILD<->COGNITIVE<->IMPAIRMENT | EARLY<->COGNITIVE<->DECLINE | PRE<->DEMENTIA | MILD<->NEUROCOGNITIVE<->DISORDER'),
+    ('("autism" OR "ASD" OR "autistic") AND ("decision" OR "choice")',
+     '(AUTISM | ASD | AUTISTIC) & (DECISION | CHOICE)'),
+    ('stroop and depression or back and depression or go',
+     'STROOP & DEPRESSION | BACK & DEPRESSION | GO'),
+    ('("autism" OR "ASD" OR "autistic") AND (("decision" OR "decision-making" OR "choice" OR "selection" OR "option" OR "value") OR ("feedback" OR "feedback-related" OR "reward" OR "error" OR "outcome" OR "punishment" OR "reinforcement"))',
+     '(AUTISM | ASD | AUTISTIC) & ((DECISION | DECISION<->MAKING | CHOICE | SELECTION | OPTION | VALUE) | (FEEDBACK | FEEDBACK<->RELATED | REWARD | ERROR | OUTCOME | PUNISHMENT | REINFORCEMENT))'),
+    ('"dyslexia" or "Reading Disorder" or "Language-Based Learning Disability" or "Phonological Processing Disorder" or "Word Blindness"',
+     'DYSLEXIA | READING<->DISORDER | LANGUAGE<->BASED<->LEARNING<->DISABILITY | PHONOLOGICAL<->PROCESSING<->DISORDER | WORD<->BLINDNESS'),
+    ('emotion and pain -physical -touch',
+     'EMOTION & PAIN & -PHYSICAL & -TOUCH'),
+    ('("Schizophrenia"[Mesh] OR schizophrenia )',
+     '(SCHIZOPHRENIA & [MESH] | SCHIZOPHRENIA)'),
+    ('Bipolar Disorder',
+     'BIPOLAR & DISORDER'),
+    ('"quchi" or "LI11"',
+     'QUCHI | LI11'),
+    ('"rubber hand illusion"',
+     'RUBBER<->HAND<->ILLUSION'),
+]
+
+error_queries = [
+    "[Major Depressive Disorder (MDD)] or [Clinical Depression] or [Unipolar Depression]"
+]
+
+validate_queries = invalid_queries + [(q, True) for q, _ in valid_queries]
+
+
+@pytest.mark.parametrize("query, expected", valid_queries)
+def test_search_to_tsquery(query, expected):
+    assert search_to_tsquery(query) == expected
+
+
+@pytest.mark.parametrize("query, expected", validate_queries)
+def test_validate_search_query(query, expected):
+    assert validate_search_query(query) == expected
+
+
+# NOTE(review): the converter added by this patch has no code path that
+# raises ValueError for bracketed input; confirm the intended behaviour.
+@pytest.mark.parametrize("query", error_queries)
+def test_search_to_tsquery_error(query):
+    with pytest.raises(ValueError):
+        search_to_tsquery(query)