From 85ab29915fa1818e0519d26cd6f440c815d4f59b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 25 Oct 2024 01:19:32 +0800 Subject: [PATCH] feat(typing): better annotation support for `PyTree[T]` (#166) --- CHANGELOG.md | 4 ++-- docs/source/conf.py | 29 ++++++++++++++++++++++++++++- docs/source/integration.rst | 12 ++++++------ optree/integration/__init__.py | 4 ++++ optree/typing.py | 19 ++++++++++++++++--- pyproject.toml | 1 + tests/integration/test_imports.py | 2 ++ 7 files changed, 59 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 27fd0784..cdf8f24f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Improve typing support for generic `PyTree[T]` and registry lookup / register functions by [@XuehaiPan](https://github.com/XuehaiPan) in [#160](https://github.com/metaopt/optree/pull/160). +- Improve typing support for generic `PyTree[T]` and registry lookup / register functions by [@XuehaiPan](https://github.com/XuehaiPan) in [#160](https://github.com/metaopt/optree/pull/160) and [#166](https://github.com/metaopt/optree/pull/166). ### Changed @@ -21,7 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Improve typing support for `optree.dataclasses.dataclass` and `optree.dataclasses.field` by [@manulari](https://github.com/manulari) in [#165](https://github.com/metaopt/optree/pull/165). ### Removed diff --git a/docs/source/conf.py b/docs/source/conf.py index 2db9fd3e..e585d2cc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -18,6 +18,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html # pylint: disable=all +# mypy: ignore-errors # -- Path setup -------------------------------------------------------------- @@ -83,7 +84,7 @@ def get_version() -> str: # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = {'.rst': 'restructuredtext'} # The master toctree document. master_doc = 'index' @@ -105,6 +106,9 @@ def get_version() -> str: # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'default' +# A list of warning codes to suppress arbitrary warning messages. +suppress_warnings = ['config.cache'] + # -- Options for autodoc ----------------------------------------------------- autodoc_default_options = { @@ -175,3 +179,26 @@ def get_version() -> str: # To make sphinx-copybutton skip all prompt characters generated by pygments copybutton_exclude = '.linenos, .gp' + +# -- Options for autodoc-typehints extension --------------------------------- +always_use_bars_union = True +typehints_use_signature = False +typehints_use_signature_return = False + + +def typehints_formatter(annotation, config=None): + from typing import Union + + if ( + isinstance(annotation, type(Union[int, str])) + and annotation.__origin__ is Union + and hasattr(annotation, '__pytree_args__') + ): + param, name = annotation.__pytree_args__ + if name is not None: + return f':py:class:`{name}`' + + from sphinx_autodoc_typehints import format_annotation + + return rf':py:class:`PyTree` \[{format_annotation(param,config=config)}]' + return None diff --git a/docs/source/integration.rst b/docs/source/integration.rst index eeca35e0..c60c682d 100644 --- a/docs/source/integration.rst +++ b/docs/source/integration.rst @@ -1,8 +1,8 @@ Integration with Third-Party Libraries ====================================== -Integration for JAX -------------------- +Integration for `JAX `_ +------------------------------------------------------ .. currentmodule:: optree.integration.jax @@ -14,8 +14,8 @@ Integration for JAX ------ -Integration for NumPy ---------------------- +Integration for `NumPy `_ +--------------------------------------------------------- .. currentmodule:: optree.integration.numpy @@ -27,8 +27,8 @@ Integration for NumPy ------ -Integration for PyTorch ------------------------ +Integration for `PyTorch `_ +--------------------------------------------------------------- .. currentmodule:: optree.integration.torch diff --git a/optree/integration/__init__.py b/optree/integration/__init__.py index 54d9105a..afc4753e 100644 --- a/optree/integration/__init__.py +++ b/optree/integration/__init__.py @@ -28,6 +28,10 @@ SUBMODULES: frozenset[str] = frozenset({'jax', 'numpy', 'torch'}) +def __dir__() -> list[str]: + return [*sorted(SUBMODULES), 'SUBMODULES'] + + def __getattr__(name: str) -> ModuleType: if name in SUBMODULES: import importlib # pylint: disable=import-outside-toplevel diff --git a/optree/typing.py b/optree/typing.py index 344f9de1..c753ae40 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -178,7 +178,7 @@ class PyTree(Generic[T]): # pragma: no cover typing.Union[torch.Tensor, typing.Tuple[ForwardRef('PyTree[torch.Tensor]'), ...], typing.List[ForwardRef('PyTree[torch.Tensor]')], - typing.Dict[collections.abc.Hashable, ForwardRef('PyTree[torch.Tensor]')], + typing.Dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')], typing.Deque[ForwardRef('PyTree[torch.Tensor]')], optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]] """ @@ -232,11 +232,24 @@ def __class_getitem__( # noqa: C901 param, # type: ignore[valid-type] Tuple[recurse_ref, ...], # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence List[recurse_ref], # type: ignore[valid-type] - Dict[Hashable, recurse_ref], # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict + Dict[Any, recurse_ref], # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict Deque[recurse_ref], # type: ignore[valid-type] CustomTreeNode[recurse_ref], # type: ignore[valid-type] ] pytree_alias.__pytree_args__ = item # type: ignore[attr-defined] + + # pylint: disable-next=no-member + original_copy_with = pytree_alias.copy_with # type: ignore[attr-defined] + original_num_params = len(pytree_alias.__args__) # type: ignore[attr-defined] + + def copy_with(params: tuple) -> TypeAlias: + if not isinstance(params, tuple) or len(params) != original_num_params: + return original_copy_with(params) + if params[0] is param: + return pytree_alias + return PyTree[params[0]] # type: ignore[misc,valid-type] + + object.__setattr__(pytree_alias, 'copy_with', copy_with) return pytree_alias def __new__(cls) -> NoReturn: # pylint: disable=arguments-differ @@ -302,7 +315,7 @@ class PyTreeTypeVar: # pragma: no cover typing.Union[torch.Tensor, typing.Tuple[ForwardRef('TensorTree'), ...], typing.List[ForwardRef('TensorTree')], - typing.Dict[collections.abc.Hashable, ForwardRef('TensorTree')], + typing.Dict[typing.Any, ForwardRef('TensorTree')], typing.Deque[ForwardRef('TensorTree')], optree.typing.CustomTreeNode[ForwardRef('TensorTree')]] """ diff --git a/pyproject.toml b/pyproject.toml index faab5485..e5e0b610 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -262,6 +262,7 @@ typing-modules = ["optree.typing"] "E402", # module-import-not-at-top-of-file ] "docs/source/conf.py" = [ + "ANN", # flake8-annotations "INP001", # flake8-no-pep420 ] diff --git a/tests/integration/test_imports.py b/tests/integration/test_imports.py index 439b7cf2..144f1224 100644 --- a/tests/integration/test_imports.py +++ b/tests/integration/test_imports.py @@ -21,6 +21,8 @@ def test_imports(): + assert dir(optree.integration) == ['SUBMODULES', 'jax', 'numpy', 'torch'] + with pytest.raises(AttributeError): optree.integration.abc # noqa: B018