Skip to content

Commit

Permalink
🐛 Handle malformed filter-func option value (#1254)
Browse files Browse the repository at this point in the history
This commit provides a warning for all malformed `filter-func` values, and also returns the result as an empty list of needs,
so we do not wasted time processing them.
  • Loading branch information
chrisjsewell authored Aug 30, 2024
1 parent 3665b04 commit e668a05
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 61 deletions.
33 changes: 21 additions & 12 deletions sphinx_needs/directives/needpie.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from docutils.parsers.rst import directives
from sphinx.application import Sphinx

from sphinx_needs.api.exceptions import NeedsInvalidFilter
from sphinx_needs.config import NeedsSphinxConfig
from sphinx_needs.data import NeedsPieType, SphinxNeedsData
from sphinx_needs.debug import measure_time
Expand Down Expand Up @@ -169,41 +170,49 @@ def process_needpie(
)
sizes.append(result)
elif current_needpie["filter_func"] and not content:
# check and get filter_func
try:
# check and get filter_func
filter_func, filter_args = check_and_get_external_filter_func(
filter_func_sig = check_and_get_external_filter_func(
current_needpie.get("filter_func")
)
# execute filter_func code
except NeedsInvalidFilter as e:
log_warning(
logger,
str(e),
"filter_func",
location=node,
)
remove_node_from_tree(node)
continue

# execute filter_func code
if filter_func_sig:
# Provides only a copy of needs to avoid data manipulations.
context: dict[str, Any] = {
"needs": need_list,
"results": [],
}
args = []
if filter_args:
args = filter_args.split(",")
args = filter_func_sig.args.split(",") if filter_func_sig.args else []
for index, arg in enumerate(args):
# All rgs are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str
context[f"arg{index + 1}"] = arg

if filter_func:
filter_func(**context)
filter_func_sig.func(**context)

sizes = context["results"]
# check items in sizes
if not isinstance(sizes, list):
logger.error(
f"The returned values from the given filter_func {filter_func.__name__} is not valid."
f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid."
" It must be a list."
)
for item in sizes:
if not isinstance(item, int) and not isinstance(item, float):
logger.error(
f"The returned values from the given filter_func {filter_func.__name__} is not valid. "
f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid. "
"It must be a list with items of type int/float."
)
except Exception as e:
raise e

elif current_needpie["filter_func"] and content:
logger.error(
"filter_func and content can't be used at the same time for needpie."
Expand Down
25 changes: 17 additions & 8 deletions sphinx_needs/filter_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,25 @@ def process_filters(
all_needs_incl_parts = prepare_need_list(checked_all_needs)

# Check if external filter code is defined
filter_func, filter_args = check_and_get_external_filter_func(
filter_data.get("filter_func")
)
try:
filter_func_sig = check_and_get_external_filter_func(
filter_data.get("filter_func")
)
except NeedsInvalidFilter as e:
log_warning(
log,
str(e),
"filter_func",
location=location,
)
return []

filter_code = None
# Get filter_code from
if not filter_code and filter_data["filter_code"]:
filter_code = "\n".join(filter_data["filter_code"])

if (not filter_code or filter_code.isspace()) and not filter_func:
if (not filter_code or filter_code.isspace()) and not filter_func_sig:
if bool(filter_data["status"] or filter_data["tags"] or filter_data["types"]):
for need_info in all_needs_incl_parts:
status_filter_passed = False
Expand Down Expand Up @@ -214,17 +223,17 @@ def process_filters(

if filter_code: # code from content
exec(filter_code, context)
elif filter_func: # code from external file
elif filter_func_sig: # code from external file
args = []
if filter_args:
args = filter_args.split(",")
if filter_func_sig.args:
args = filter_func_sig.args.split(",")
for index, arg in enumerate(args):
# All args are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str
context[f"arg{index+1}"] = arg

# Decorate function to allow time measurments
filter_func = measure_time_func(
filter_func, category="filter_func", source="user"
filter_func_sig.func, category="filter_func", source="user"
)
filter_func(**context)
else:
Expand Down
68 changes: 36 additions & 32 deletions sphinx_needs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import operator
import os
import re
from dataclasses import dataclass
from functools import lru_cache, reduce, wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar
from urllib.parse import urlparse
Expand All @@ -13,6 +14,7 @@
from jinja2 import Environment, Template
from sphinx.application import BuildEnvironment, Sphinx

from sphinx_needs.api.exceptions import NeedsInvalidFilter
from sphinx_needs.config import LinkOptionsType, NeedsSphinxConfig
from sphinx_needs.data import NeedsInfoType, SphinxNeedsData
from sphinx_needs.defaults import NEEDS_PROFILING
Expand Down Expand Up @@ -308,43 +310,45 @@ def check_and_calc_base_url_rel_path(external_url: str, fromdocname: str) -> str
return ref_uri


def check_and_get_external_filter_func(filter_func_ref: str | None) -> tuple[Any, str]:
@dataclass
class FilterFunc:
"""Dataclass for filter function."""

sig: str
func: Callable[..., Any]
args: str


@lru_cache(maxsize=32)
def check_and_get_external_filter_func(
filter_func_ref: str | None,
) -> FilterFunc | None:
"""Check and import filter function from external python file."""
# Check if external filter code is defined
filter_func = None
filter_args = ""
if not filter_func_ref:
return None

if filter_func_ref:
try:
filter_module, filter_function = filter_func_ref.rsplit(".")
except ValueError:
log_warning(
logger,
f'Filter function not valid "{filter_func_ref}". Example: my_module:my_func',
None,
None,
)
return filter_func, filter_args
try:
filter_module, filter_function = filter_func_ref.rsplit(".")
except ValueError:
raise NeedsInvalidFilter("does not contain a dot")

result = re.search(r"^(\w+)(?:\((.*)\))*$", filter_function)
if not result:
return filter_func, filter_args
filter_function = result.group(1)
filter_args = result.group(2) or ""
result = re.search(r"^(\w+)(?:\((.*)\))*$", filter_function)
if not result:
raise NeedsInvalidFilter(f"malformed function signature: {filter_function!r}")
filter_function = result.group(1)
filter_args = result.group(2) or ""

try:
final_module = importlib.import_module(filter_module)
filter_func = getattr(final_module, filter_function)
except Exception:
log_warning(
logger,
f"Could not import filter function: {filter_func_ref}",
None,
None,
)
return filter_func, filter_args
try:
final_module = importlib.import_module(filter_module)
except Exception:
raise NeedsInvalidFilter(f"cannot import module: {filter_module}")

try:
filter_func = getattr(final_module, filter_function)
except Exception:
raise NeedsInvalidFilter(f"module does not have function: {filter_function}")

return filter_func, filter_args
return FilterFunc(filter_func_ref, filter_func, filter_args)


def jinja_parse(context: dict[str, Any], jinja_string: str) -> str:
Expand Down
10 changes: 10 additions & 0 deletions tests/doc_test/doc_needs_filter_data/filter_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,13 @@ Filter code test cases
.. needpie:: Filter code func pie
:labels: project_x, project_y
:filter-func: filter_code_func.my_pie_filter_code


.. needtable:: Malformed filter func table
:style: table
:filter-func: filter_code_func.own_filter_code(


.. needpie:: Malformed filter func pie
:labels: project_x, project_y
:filter-func: filter_code_func.my_pie_filter_code(
6 changes: 6 additions & 0 deletions tests/doc_test/doc_needs_filter_data/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,9 @@ Needflow example

.. needflow:: My needflow
:filter: variant == current_variant


.. toctree::

filter_code
filter_code_args
24 changes: 15 additions & 9 deletions tests/test_needs_filter_data.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
from pathlib import Path

import pytest
from docutils import __version__ as doc_ver
from sphinx.util.console import strip_colors


@pytest.mark.parametrize(
Expand All @@ -12,6 +14,19 @@
def test_doc_needs_filter_data_html(test_app):
app = test_app
app.build()

warnings = strip_colors(
app._warning.getvalue().replace(str(app.srcdir) + os.sep, "srcdir/")
).splitlines()
print(warnings)
assert warnings == [
"srcdir/filter_code.rst:26: WARNING: malformed function signature: 'own_filter_code(' [needs.filter_func]",
"srcdir/filter_code.rst:31: WARNING: malformed function signature: 'my_pie_filter_code(' [needs.filter_func]",
"WARNING: variant_not_equal_current_variant: failed",
"\t\tfailed needs: 1 (extern_filter_story_002)",
"\t\tused filter: variant != current_variant [needs.warnings]",
]

index_html = Path(app.outdir, "index.html").read_text()

# Check need_count works
Expand Down Expand Up @@ -52,15 +67,6 @@ def test_doc_needs_filter_data_html(test_app):
in index_html
)

# check needs_warnings works
warning = app._warning
warnings = warning.getvalue()

# check warnings contents
assert "WARNING: variant_not_equal_current_variant: failed" in warnings
assert "failed needs: 1 (extern_filter_story_002)" in warnings
assert "used filter: variant != current_variant" in warnings


@pytest.mark.parametrize(
"test_app",
Expand Down

0 comments on commit e668a05

Please sign in to comment.