Skip to content

Commit

Permalink
✨ add needs_import_cache_size configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Sep 12, 2024
1 parent 6776e98 commit 1f10d62
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 23 deletions.
11 changes: 11 additions & 0 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1645,7 +1645,18 @@ keys:
The related CSS class definition must be done by the user, e.g. by :ref:`own_css`.
(*optional*) (*default*: ``external_link``)
.. _needs_import_cache_size:
needs_import_cache_size
~~~~~~~~~~~~~~~~~~~~~~~
.. versionadded:: 3.1.0
Sets the maximum number of needs cached by the :ref:`needimport` directive,
which is used to avoid multiple reads of the same file.
Note that setting this value too high may lead to high memory usage during the Sphinx build.
Default: :need_config_default:`import_cache_size`
.. _needs_needextend_strict:
Expand Down
5 changes: 5 additions & 0 deletions docs/directives/needimport.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ The directive also supports URL as argument to download ``needs.json`` from remo
.. needimport:: https://my_company.com/docs/remote-needs.json
.. seealso::

:ref:`needs_import_cache_size`,
to control the cache size for imported needs.

Options
-------

Expand Down
4 changes: 4 additions & 0 deletions sphinx_needs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ def __setattr__(self, name: str, value: Any) -> None:
default_factory=list, metadata={"rebuild": "html", "types": (list,)}
)
"""List of external sources to load needs from."""
import_cache_size: int = field(
default=100, metadata={"rebuild": "html", "types": (int,)}
)
"""Maximum number of imported needs to cache."""
builder_filter: str = field(
default="is_external==False", metadata={"rebuild": "html", "types": (str,)}
)
Expand Down
99 changes: 76 additions & 23 deletions sphinx_needs/directives/needimport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import json
import os
import re
from typing import Sequence
import threading
from copy import deepcopy
from typing import Any, OrderedDict, Sequence
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -52,7 +54,8 @@ class NeedimportDirective(SphinxDirective):

@measure_time("needimport")
def run(self) -> Sequence[nodes.Node]:
# needs_list = {}
needs_config = NeedsSphinxConfig(self.config)

version = self.options.get("version")
filter_string = self.options.get("filter")
id_prefix = self.options.get("id_prefix", "")
Expand Down Expand Up @@ -111,21 +114,32 @@ def run(self) -> Sequence[nodes.Node]:
raise ReferenceError(
f"Could not load needs import file {correct_need_import_path}"
)
mtime = os.path.getmtime(correct_need_import_path)

try:
with open(correct_need_import_path) as needs_file:
needs_import_list = json.load(needs_file)
except json.JSONDecodeError as e:
# TODO: Add exception handling
raise SphinxNeedsFileException(correct_need_import_path) from e

errors = check_needs_data(needs_import_list)
if errors.schema:
logger.info(
f"Schema validation errors detected in file {correct_need_import_path}:"
)
for error in errors.schema:
logger.info(f' {error.message} -> {".".join(error.path)}')
if (
needs_import_list := _FileCache.get(correct_need_import_path, mtime)
) is None:
try:
with open(correct_need_import_path) as needs_file:
needs_import_list = json.load(needs_file)
except json.JSONDecodeError as e:
# TODO: Add exception handling
raise SphinxNeedsFileException(correct_need_import_path) from e

errors = check_needs_data(needs_import_list)
if errors.schema:
logger.info(
f"Schema validation errors detected in file {correct_need_import_path}:"
)
for error in errors.schema:
logger.info(f' {error.message} -> {".".join(error.path)}')
else:
_FileCache.set(
correct_need_import_path,
mtime,
needs_import_list,
needs_config.import_cache_size,
)

if version is None:
try:
Expand All @@ -141,17 +155,17 @@ def run(self) -> Sequence[nodes.Node]:
f"Version {version} not found in needs import file {correct_need_import_path}"
)

needs_config = NeedsSphinxConfig(self.config)
data = needs_import_list["versions"][version]

# TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
needs_list: dict[str, NeedsInfoType] = data["needs"]

if ids := self.options.get("ids"):
id_list = [i.strip() for i in ids.split(",") if i.strip()]
data["needs"] = {
needs_list = {
key: data["needs"][key] for key in id_list if key in data["needs"]
}

# TODO this is not exactly NeedsInfoType, because the export removes/adds some keys
needs_list: dict[str, NeedsInfoType] = data["needs"]
if schema := data.get("needs_schema"):
# Set defaults from schema
defaults = {
Expand All @@ -160,7 +174,8 @@ def run(self) -> Sequence[nodes.Node]:
if "default" in value
}
needs_list = {
key: {**defaults, **value} for key, value in needs_list.items()
key: {**defaults, **value} # type: ignore[typeddict-item]
for key, value in needs_list.items()
}

# Filter imported needs
Expand All @@ -169,7 +184,8 @@ def run(self) -> Sequence[nodes.Node]:
if filter_string is None:
needs_list_filtered[key] = need
else:
filter_context = need.copy()
# we deepcopy here, to ensure that the original data is not modified
filter_context = deepcopy(need)

# Support both ways of addressing the description, as "description" is used in json file, but
# "content" is the sphinx internal name for this kind of information
Expand All @@ -185,7 +201,9 @@ def run(self) -> Sequence[nodes.Node]:
location=(self.env.docname, self.lineno),
)

needs_list = needs_list_filtered
# note we need to deepcopy here, as we are going to modify the data,
# but we want to ensure data referenced from the cache is not modified
needs_list = deepcopy(needs_list_filtered)

# If we need to set an id prefix, we also need to manipulate all used ids in the imported data.
extra_links = needs_config.extra_links
Expand Down Expand Up @@ -283,6 +301,41 @@ def docname(self) -> str:
return self.env.docname


class _ImportCache:
    """A thread-safe, size-bounded cache for imported needs.

    Maps a ``(path, mtime)`` key to the dictionary of needs read from that
    file, evicting the oldest entries (FIFO) once the total number of cached
    needs exceeds a maximum size.
    """

    def __init__(self) -> None:
        # insertion-ordered mapping of (path, mtime) -> needs data;
        # the oldest entry is evicted first
        self._cache: OrderedDict[tuple[str, float], dict[str, Any]] = OrderedDict()
        # running total of needs stored across all cached entries
        self._need_count = 0
        self._lock = threading.Lock()

    def set(
        self, path: str, mtime: float, value: dict[str, Any], max_size: int
    ) -> None:
        """Store ``value`` under ``(path, mtime)``, then evict the oldest
        entries until at most ``max_size`` needs remain cached.

        :param path: Path of the imported file.
        :param mtime: Modification time of the file when it was read.
        :param value: Mapping of need IDs to need data.
        :param max_size: Maximum total number of cached needs
            (values below 0 are treated as 0).
        """
        key = (path, mtime)
        with self._lock:
            if key in self._cache:
                # Replacing an existing entry (possible when two threads miss
                # `get` concurrently and both call `set`): remove the old
                # value's contribution first, so it is not counted twice.
                # Popping also moves the refreshed entry to the newest slot.
                self._need_count -= len(self._cache.pop(key))
            self._cache[key] = value
            self._need_count += len(value)
            max_size = max(max_size, 0)
            while self._need_count > max_size:
                # evict oldest entries first (FIFO order)
                _, evicted = self._cache.popitem(last=False)
                self._need_count -= len(evicted)

    def get(self, path: str, mtime: float) -> dict[str, Any] | None:
        """Return the cached needs for ``(path, mtime)``, or None if absent."""
        with self._lock:
            return self._cache.get((path, mtime), None)

    def __repr__(self) -> str:
        with self._lock:
            return f"{self.__class__.__name__}({list(self._cache)})"


#: module-level singleton shared by all ``needimport`` directive runs
_FileCache = _ImportCache()


class VersionNotFound(BaseException):
    """Raised when the requested version is missing from an imported needs file.

    NOTE(review): subclasses ``BaseException`` rather than ``Exception``, so
    generic ``except Exception`` handlers will NOT catch it — confirm this is
    intentional before changing the base class, as it would alter which
    handlers catch it.
    """

    pass

Expand Down

0 comments on commit 1f10d62

Please sign in to comment.