diff --git a/sphinx_needs/directives/needpie.py b/sphinx_needs/directives/needpie.py index 5a2bda1d4..366376c57 100644 --- a/sphinx_needs/directives/needpie.py +++ b/sphinx_needs/directives/needpie.py @@ -7,6 +7,7 @@ from docutils.parsers.rst import directives from sphinx.application import Sphinx +from sphinx_needs.api.exceptions import NeedsInvalidFilter from sphinx_needs.config import NeedsSphinxConfig from sphinx_needs.data import NeedsPieType, SphinxNeedsData from sphinx_needs.debug import measure_time @@ -169,41 +170,49 @@ def process_needpie( ) sizes.append(result) elif current_needpie["filter_func"] and not content: + # check and get filter_func try: - # check and get filter_func - filter_func, filter_args = check_and_get_external_filter_func( + filter_func_sig = check_and_get_external_filter_func( current_needpie.get("filter_func") ) - # execute filter_func code + except NeedsInvalidFilter as e: + log_warning( + logger, + str(e), + "filter_func", + location=node, + ) + remove_node_from_tree(node) + continue + + # execute filter_func code + if filter_func_sig: # Provides only a copy of needs to avoid data manipulations. context: dict[str, Any] = { "needs": need_list, "results": [], } - args = [] - if filter_args: - args = filter_args.split(",") + args = filter_func_sig.args.split(",") if filter_func_sig.args else [] for index, arg in enumerate(args): # All rgs are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str context[f"arg{index + 1}"] = arg - if filter_func: - filter_func(**context) + filter_func_sig.func(**context) + sizes = context["results"] # check items in sizes if not isinstance(sizes, list): logger.error( - f"The returned values from the given filter_func {filter_func.__name__} is not valid." + f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid." " It must be a list." ) for item in sizes: if not isinstance(item, int) and not isinstance(item, float): logger.error( - f"The returned values from the given filter_func {filter_func.__name__} is not valid. " + f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid. " "It must be a list with items of type int/float." ) - except Exception as e: - raise e + elif current_needpie["filter_func"] and content: logger.error( "filter_func and content can't be used at the same time for needpie." diff --git a/sphinx_needs/filter_common.py b/sphinx_needs/filter_common.py index ff8977ed0..8d6e3e39d 100644 --- a/sphinx_needs/filter_common.py +++ b/sphinx_needs/filter_common.py @@ -147,16 +147,25 @@ def process_filters( all_needs_incl_parts = prepare_need_list(checked_all_needs) # Check if external filter code is defined - filter_func, filter_args = check_and_get_external_filter_func( - filter_data.get("filter_func") - ) + try: + filter_func_sig = check_and_get_external_filter_func( + filter_data.get("filter_func") + ) + except NeedsInvalidFilter as e: + log_warning( + log, + str(e), + "filter_func", + location=location, + ) + return [] filter_code = None # Get filter_code from if not filter_code and filter_data["filter_code"]: filter_code = "\n".join(filter_data["filter_code"]) - if (not filter_code or filter_code.isspace()) and not filter_func: + if (not filter_code or filter_code.isspace()) and not filter_func_sig: if bool(filter_data["status"] or filter_data["tags"] or filter_data["types"]): for need_info in all_needs_incl_parts: status_filter_passed = False @@ -214,17 +223,17 @@ def process_filters( if filter_code: # code from content exec(filter_code, context) - elif filter_func: # code from external file + elif filter_func_sig: # code from external file args = [] - if filter_args: - args = filter_args.split(",") + if filter_func_sig.args: + args = filter_func_sig.args.split(",") for index, arg in enumerate(args): # All args are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str context[f"arg{index+1}"] = arg # Decorate function to allow time measurments filter_func = measure_time_func( - filter_func, category="filter_func", source="user" + filter_func_sig.func, category="filter_func", source="user" ) filter_func(**context) else: diff --git a/sphinx_needs/utils.py b/sphinx_needs/utils.py index 7472efcc8..b065f802b 100644 --- a/sphinx_needs/utils.py +++ b/sphinx_needs/utils.py @@ -5,6 +5,7 @@ import operator import os import re +from dataclasses import dataclass from functools import lru_cache, reduce, wraps from typing import TYPE_CHECKING, Any, Callable, TypeVar from urllib.parse import urlparse @@ -13,6 +14,7 @@ from jinja2 import Environment, Template from sphinx.application import BuildEnvironment, Sphinx +from sphinx_needs.api.exceptions import NeedsInvalidFilter from sphinx_needs.config import LinkOptionsType, NeedsSphinxConfig from sphinx_needs.data import NeedsInfoType, SphinxNeedsData from sphinx_needs.defaults import NEEDS_PROFILING @@ -308,43 +310,45 @@ def check_and_calc_base_url_rel_path(external_url: str, fromdocname: str) -> str return ref_uri -def check_and_get_external_filter_func(filter_func_ref: str | None) -> tuple[Any, str]: +@dataclass +class FilterFunc: + """Dataclass for filter function.""" + + sig: str + func: Callable[..., Any] + args: str + + +@lru_cache(maxsize=32) +def check_and_get_external_filter_func( + filter_func_ref: str | None, +) -> FilterFunc | None: """Check and import filter function from external python file.""" - # Check if external filter code is defined - filter_func = None - filter_args = "" + if not filter_func_ref: + return None - if filter_func_ref: - try: - filter_module, filter_function = filter_func_ref.rsplit(".") - except ValueError: - log_warning( - logger, - f'Filter function not valid "{filter_func_ref}". Example: my_module:my_func', - None, - None, - ) - return filter_func, filter_args + try: + filter_module, filter_function = filter_func_ref.rsplit(".") + except ValueError: + raise NeedsInvalidFilter("does not contain a dot") - result = re.search(r"^(\w+)(?:\((.*)\))*$", filter_function) - if not result: - return filter_func, filter_args - filter_function = result.group(1) - filter_args = result.group(2) or "" + result = re.search(r"^(\w+)(?:\((.*)\))*$", filter_function) + if not result: + raise NeedsInvalidFilter(f"malformed function signature: {filter_function!r}") + filter_function = result.group(1) + filter_args = result.group(2) or "" - try: - final_module = importlib.import_module(filter_module) - filter_func = getattr(final_module, filter_function) - except Exception: - log_warning( - logger, - f"Could not import filter function: {filter_func_ref}", - None, - None, - ) - return filter_func, filter_args + try: + final_module = importlib.import_module(filter_module) + except Exception: + raise NeedsInvalidFilter(f"cannot import module: {filter_module}") + + try: + filter_func = getattr(final_module, filter_function) + except Exception: + raise NeedsInvalidFilter(f"module does not have function: {filter_function}") - return filter_func, filter_args + return FilterFunc(filter_func_ref, filter_func, filter_args) def jinja_parse(context: dict[str, Any], jinja_string: str) -> str: diff --git a/tests/doc_test/doc_needs_filter_data/filter_code.rst b/tests/doc_test/doc_needs_filter_data/filter_code.rst index fd565180f..c3ba58272 100644 --- a/tests/doc_test/doc_needs_filter_data/filter_code.rst +++ b/tests/doc_test/doc_needs_filter_data/filter_code.rst @@ -21,3 +21,13 @@ Filter code test cases .. needpie:: Filter code func pie :labels: project_x, project_y :filter-func: filter_code_func.my_pie_filter_code + + +.. needtable:: Malformed filter func table + :style: table + :filter-func: filter_code_func.own_filter_code( + + +.. needpie:: Malformed filter func pie + :labels: project_x, project_y + :filter-func: filter_code_func.my_pie_filter_code( diff --git a/tests/doc_test/doc_needs_filter_data/index.rst b/tests/doc_test/doc_needs_filter_data/index.rst index 74ed4bf22..aba1731ce 100644 --- a/tests/doc_test/doc_needs_filter_data/index.rst +++ b/tests/doc_test/doc_needs_filter_data/index.rst @@ -61,3 +61,9 @@ Needflow example .. needflow:: My needflow :filter: variant == current_variant + + +.. toctree:: + + filter_code + filter_code_args diff --git a/tests/test_needs_filter_data.py b/tests/test_needs_filter_data.py index 67978b448..53e45f3dd 100644 --- a/tests/test_needs_filter_data.py +++ b/tests/test_needs_filter_data.py @@ -1,7 +1,9 @@ +import os from pathlib import Path import pytest from docutils import __version__ as doc_ver +from sphinx.util.console import strip_colors @pytest.mark.parametrize( @@ -12,6 +14,19 @@ def test_doc_needs_filter_data_html(test_app): app = test_app app.build() + + warnings = strip_colors( + app._warning.getvalue().replace(str(app.srcdir) + os.sep, "srcdir/") + ).splitlines() + print(warnings) + assert warnings == [ + "srcdir/filter_code.rst:26: WARNING: malformed function signature: 'own_filter_code(' [needs.filter_func]", + "srcdir/filter_code.rst:31: WARNING: malformed function signature: 'my_pie_filter_code(' [needs.filter_func]", + "WARNING: variant_not_equal_current_variant: failed", + "\t\tfailed needs: 1 (extern_filter_story_002)", + "\t\tused filter: variant != current_variant [needs.warnings]", + ] + index_html = Path(app.outdir, "index.html").read_text() # Check need_count works @@ -52,15 +67,6 @@ def test_doc_needs_filter_data_html(test_app): in index_html ) - # check needs_warnings works - warning = app._warning - warnings = warning.getvalue() - - # check warnings contents - assert "WARNING: variant_not_equal_current_variant: failed" in warnings - assert "failed needs: 1 (extern_filter_story_002)" in warnings - assert "used filter: variant != current_variant" in warnings - @pytest.mark.parametrize( "test_app",