From 65e2dd775566c36d735d22eb94e4d464a5608be3 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 4 Sep 2024 16:26:47 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20Introduce=20`NeedsPartsView`=20t?= =?UTF-8?q?ype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Make it clearer where we are using an list of the "expanded" needs+parts. As discussed in #1264, there are a number of different representations of the needs, and so this makes it clearer which one a variable is. --- sphinx_needs/data.py | 20 ++++++++++- sphinx_needs/directives/needpie.py | 29 ++++++---------- sphinx_needs/filter_common.py | 55 +++++++++++++++--------------- sphinx_needs/utils.py | 22 ++++++++---- 4 files changed, 73 insertions(+), 53 deletions(-) diff --git a/sphinx_needs/data.py b/sphinx_needs/data.py index 460bba008..d36e99b2e 100644 --- a/sphinx_needs/data.py +++ b/sphinx_needs/data.py @@ -4,7 +4,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Final, Literal, Mapping, NewType, TypedDict +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Final, + Literal, + Mapping, + NewType, + Sequence, + TypedDict, +) from sphinx.util.logging import getLogger @@ -685,6 +695,14 @@ class NeedsUmlType(NeedsBaseDataType): (e.g. back links have been computed etc) """ +NeedsPartsView = NewType("NeedsPartsView", Sequence[NeedsInfoType]) +"""A read-only view of a sequence of needs and parts, +after resolution (e.g. back links have been computed etc) + +The parts are created by creating a copy of the need for each item in ``parts``, +and then overwriting the fields with the values from the part. +""" + class SphinxNeedsData: """Centralised access to sphinx-needs data, stored within the Sphinx environment.""" diff --git a/sphinx_needs/directives/needpie.py b/sphinx_needs/directives/needpie.py index 1e4faafe7..639b3117b 100644 --- a/sphinx_needs/directives/needpie.py +++ b/sphinx_needs/directives/needpie.py @@ -1,7 +1,7 @@ from __future__ import annotations import hashlib -from typing import Any, Iterable, Sequence +from typing import Iterable, Sequence from docutils import nodes from docutils.parsers.rst import directives @@ -171,7 +171,7 @@ def process_needpie( elif current_needpie["filter_func"] and not content: # check and get filter_func try: - filter_func_sig = check_and_get_external_filter_func( + ff_result = check_and_get_external_filter_func( current_needpie.get("filter_func") ) except NeedsInvalidFilter as e: @@ -185,30 +185,23 @@ def process_needpie( continue # execute filter_func code - if filter_func_sig: - # Provides only a copy of needs to avoid data manipulations. - context: dict[str, Any] = { - "needs": need_list, - "results": [], - } - args = filter_func_sig.args.split(",") if filter_func_sig.args else [] - for index, arg in enumerate(args): - # All rgs are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str - context[f"arg{index + 1}"] = arg - - filter_func_sig.func(**context) - - sizes = context["results"] + if ff_result: + args = ff_result.args.split(",") if ff_result.args else [] + args_context = {f"arg{index+1}": arg for index, arg in enumerate(args)} + + sizes = [] + ff_result.func(needs=need_list, results=sizes, **args_context) + # check items in sizes if not isinstance(sizes, list): logger.error( - f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid." + f"The returned values from the given filter_func {ff_result.sig!r} is not valid." " It must be a list." ) for item in sizes: if not isinstance(item, int) and not isinstance(item, float): logger.error( - f"The returned values from the given filter_func {filter_func_sig.sig!r} is not valid. " + f"The returned values from the given filter_func {ff_result.sig!r} is not valid. " "It must be a list with items of type int/float." ) diff --git a/sphinx_needs/filter_common.py b/sphinx_needs/filter_common.py index 897bc434b..c8c0b0f26 100644 --- a/sphinx_needs/filter_common.py +++ b/sphinx_needs/filter_common.py @@ -20,6 +20,7 @@ from sphinx_needs.data import ( NeedsFilteredBaseType, NeedsInfoType, + NeedsPartsView, NeedsView, SphinxNeedsData, ) @@ -134,9 +135,7 @@ def process_filters( # Check if external filter code is defined try: - filter_func_sig = check_and_get_external_filter_func( - filter_data.get("filter_func") - ) + ff_result = check_and_get_external_filter_func(filter_data.get("filter_func")) except NeedsInvalidFilter as e: log_warning( log, @@ -151,7 +150,7 @@ def process_filters( if not filter_code and filter_data["filter_code"]: filter_code = "\n".join(filter_data["filter_code"]) - if (not filter_code or filter_code.isspace()) and not filter_func_sig: + if (not filter_code or filter_code.isspace()) and not ff_result: if bool(filter_data["status"] or filter_data["tags"] or filter_data["types"]): found_needs_by_options: list[NeedsInfoType] = [] for need_info in all_needs_incl_parts: @@ -202,35 +201,36 @@ def process_filters( location=location, ) else: - # Provides only a copy of needs to avoid data manipulations. - context: dict[str, Any] = { - "needs": all_needs_incl_parts, - "results": [], - } + # The filter results may be dirty, as it may continue manipulated needs. + found_dirty_needs: list[NeedsInfoType] = [] if filter_code: # code from content + # TODO better context type + context: dict[str, list[NeedsInfoType]] = { + "needs": all_needs_incl_parts, # type: ignore[dict-item] + "results": [], + } exec(filter_code, context) - elif filter_func_sig: # code from external file + found_dirty_needs = context["results"] + elif ff_result: # code from external file args = [] - if filter_func_sig.args: - args = filter_func_sig.args.split(",") - for index, arg in enumerate(args): - # All args are strings, but we must transform them to requested type, e.g. 1 -> int, "1" -> str - context[f"arg{index+1}"] = arg + if ff_result.args: + args = ff_result.args.split(",") + args_context = {f"arg{index+1}": arg for index, arg in enumerate(args)} # Decorate function to allow time measurments filter_func = measure_time_func( - filter_func_sig.func, category="filter_func", source="user" + ff_result.func, category="filter_func", source="user" + ) + filter_func( + needs=all_needs_incl_parts, results=found_dirty_needs, **args_context ) - filter_func(**context) else: log_warning( log, "Something went wrong running filter", None, location=location ) return [] - # The filter results may be dirty, as it may continue manipulated needs. - found_dirty_needs: list[NeedsInfoType] = context["results"] found_needs = [] # Check if config allow unsafe filters @@ -277,15 +277,17 @@ def process_filters( return found_needs -def expand_needs_view(needs_view: NeedsView) -> list[NeedsInfoType]: - """Turns a needs view into a list of needs, expanding all need["parts"] to be items of the list.""" - all_needs_incl_parts: list[NeedsInfoType] = [] +def expand_needs_view(needs_view: NeedsView) -> NeedsPartsView: + """Turns a needs view into a sequence of needs, + expanding all ``need["parts"]`` to be items of the list. + """ + all_needs_incl_parts = [] for need in needs_view.values(): all_needs_incl_parts.append(need) for need_part in iter_need_parts(need): all_needs_incl_parts.append(need_part) - return all_needs_incl_parts + return NeedsPartsView(all_needs_incl_parts) T = TypeVar("T") @@ -295,19 +297,16 @@ def intersection_of_need_results(list_a: list[T], list_b: list[T]) -> list[T]: return [a for a in list_a if a in list_b] -V = TypeVar("V", bound=NeedsInfoType) - - @measure_time("filtering") def filter_needs( - needs: Iterable[V], + needs: Iterable[NeedsInfoType], config: NeedsSphinxConfig, filter_string: None | str = "", current_need: NeedsInfoType | None = None, *, location: tuple[str, int | None] | nodes.Node | None = None, append_warning: str = "", -) -> list[V]: +) -> list[NeedsInfoType]: """ Filters given needs based on a given filter string. Returns all needs, which pass the given filter. diff --git a/sphinx_needs/utils.py b/sphinx_needs/utils.py index 8f09db231..adb1dd564 100644 --- a/sphinx_needs/utils.py +++ b/sphinx_needs/utils.py @@ -7,7 +7,7 @@ import re from dataclasses import dataclass from functools import lru_cache, reduce, wraps -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Protocol, TypeVar from urllib.parse import urlparse from docutils import nodes @@ -16,7 +16,7 @@ from sphinx_needs.api.exceptions import NeedsInvalidFilter from sphinx_needs.config import LinkOptionsType, NeedsSphinxConfig -from sphinx_needs.data import NeedsInfoType, NeedsView, SphinxNeedsData +from sphinx_needs.data import NeedsInfoType, NeedsPartsView, NeedsView, SphinxNeedsData from sphinx_needs.defaults import NEEDS_PROFILING from sphinx_needs.logging import get_logger, log_warning @@ -310,19 +310,29 @@ def check_and_calc_base_url_rel_path(external_url: str, fromdocname: str) -> str return ref_uri +class FilterFunc(Protocol): + def __call__( + self, + *, + needs: NeedsPartsView, + results: list[Any], + **kwargs: str, + ) -> None: ... + + @dataclass -class FilterFunc: +class FilterFuncResult: """Dataclass for filter function.""" sig: str - func: Callable[..., Any] + func: FilterFunc args: str @lru_cache(maxsize=32) def check_and_get_external_filter_func( filter_func_ref: str | None, -) -> FilterFunc | None: +) -> FilterFuncResult | None: """Check and import filter function from external python file.""" if not filter_func_ref: return None @@ -348,7 +358,7 @@ def check_and_get_external_filter_func( except Exception: raise NeedsInvalidFilter(f"module does not have function: {filter_function}") - return FilterFunc(filter_func_ref, filter_func, filter_args) + return FilterFuncResult(filter_func_ref, filter_func, filter_args) def jinja_parse(context: dict[str, Any], jinja_string: str) -> str: