diff --git a/cdd/__init__.py b/cdd/__init__.py index 6f0e3b7a..8a540277 100644 --- a/cdd/__init__.py +++ b/cdd/__init__.py @@ -9,7 +9,7 @@ from logging import getLogger as get_logger __author__ = "Samuel Marks" # type: str -__version__ = "0.0.99rc44" # type: str +__version__ = "0.0.99rc45" # type: str __description__ = ( "Open API to/fro routes, models, and tests. " "Convert between docstrings, classes, methods, argparse, pydantic, and SQLalchemy." diff --git a/cdd/shared/ast_utils.py b/cdd/shared/ast_utils.py index 745e18d4..39999517 100644 --- a/cdd/shared/ast_utils.py +++ b/cdd/shared/ast_utils.py @@ -46,7 +46,8 @@ from json import dumps from operator import attrgetter, contains, inv, itemgetter, neg, not_, pos from os import path -from typing import FrozenSet, Generator, Optional +from typing import Callable, FrozenSet, Generator, MutableSet, Optional +from typing import Tuple as TTuple from typing import __all__ as typing__all__ import cdd.shared.source_transformer @@ -2215,7 +2216,14 @@ def get_types(node): return iter((node.value.id, node.slice.id)) elif isinstance(node.slice, Tuple): return chain.from_iterable( - ((node.value.id,), map(get_value, map(get_value, node.slice.elts))) + ( + (node.value.id,), + ( + iter(()) + if node.value.id == "Literal" + else map(get_value, map(get_value, node.slice.elts)) + ), + ) ) @@ -2228,16 +2236,16 @@ def infer_imports(module, modules_to_all=DEFAULT_MODULES_TO_ALL): - sqlalchemy - pydantic - :param module: Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign - :type module: ```Union[ClassDef, FunctionDef, AsyncFunctionDef, Assign]``` + :param module: Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign + :type module: ```Union[Module, ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign]``` :param modules_to_all: Tuple of module_name to __all__ of module; (str) to FrozenSet[str] :type modules_to_all: ```tuple[tuple[str, frozenset], ...]``` :return: List of imports - :rtype: ```Optional[Tuple[Union[Import, ImportFrom]]]``` + :rtype: ```Optional[Tuple[Union[Import, ImportFrom], ...]]``` """ - if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign)): + if isinstance(module, (ClassDef, FunctionDef, AsyncFunctionDef, Assign, AnnAssign)): module: Module = Module(body=[module], type_ignores=[], stmt=None) assert isinstance(module, Module), "Expected `Module` got `{type_name}`".format( type_name=type(module).__name__ @@ -2252,7 +2260,13 @@ def node_to_importable_name(node): :rtype: ```Optional[str]``` """ if getattr(node, "type_comment", None) is not None: - return node.type_comment + return ( + node.type_comment + if node.type_comment in simple_types + else get_value( + get_value(get_value(ast.parse(node.type_comment).body[0])) + ) + ) elif getattr(node, "annotation", None) is not None: node = node # type: Union[AnnAssign, arg] return node.annotation # cast(node, Union[AnnAssign, arg]) @@ -2261,7 +2275,9 @@ def node_to_importable_name(node): else: return None - _symbol_to_import = partial(symbol_to_import, modules_to_all=modules_to_all) + _symbol_to_import: Callable[[str], Optional[TTuple[str, str]]] = partial( + symbol_to_import, modules_to_all=modules_to_all + ) # Lots of room for optimisation here; but its probably NP-hard: imports = tuple( @@ -2352,8 +2368,10 @@ def deduplicate_sorted_imports(module): :return: Module but with duplicate import entries in first import block removed :rtype: ```Module``` """ - assert isinstance(module, Module) - fst_import_idx = next( + assert isinstance(module, Module), "Expected `Module` got `{}`".format( + type(module).__name__ + ) + fst_import_idx: Optional[int] = next( map( itemgetter(0), filter( @@ -2365,7 +2383,7 @@ def deduplicate_sorted_imports(module): ) if fst_import_idx is None: return module - lst_import_idx = next( + lst_import_idx: Optional[int] = next( iter( deque( map( @@ -2380,7 +2398,7 @@ def deduplicate_sorted_imports(module): ), None, ) - name_seen = set() + name_seen: MutableSet[str] = set() module.body = ( module.body[:fst_import_idx] diff --git a/cdd/tests/test_shared/test_ast_utils.py b/cdd/tests/test_shared/test_ast_utils.py index 0a7bd60d..fb515788 100644 --- a/cdd/tests/test_shared/test_ast_utils.py +++ b/cdd/tests/test_shared/test_ast_utils.py @@ -31,6 +31,7 @@ arguments, keyword, ) +from collections import deque from copy import deepcopy from itertools import repeat from os import extsep, path @@ -88,6 +89,7 @@ function_adder_ast, function_adder_str, ) +from cdd.tests.mocks.pydantic import pydantic_class_cls_def from cdd.tests.mocks.sqlalchemy import config_decl_base_ast from cdd.tests.utils_for_tests import inspectable_compile, run_ast_test, unittest_main @@ -432,7 +434,7 @@ def test_infer_imports_with_sqlalchemy(self) -> None: """ imports = infer_imports( config_decl_base_ast - ) # type: Optional[Tuple[Union[Import, ImportFrom]]] + ) # type: Optional[TTuple[Union[Import, ImportFrom], ...]] self.assertIsNotNone(imports) self.assertEqual(len(imports), 1) run_ast_test( @@ -455,6 +457,57 @@ def test_infer_imports_with_sqlalchemy(self) -> None: ), ) + def test_infer_imports_with_simple_node_variants(self) -> None: + """ + Test that `infer_imports` with some simple variants + """ + + def inner_test(imports): + """ + Run the actual test + + :param imports: The imports to compare against + :type imports: ```TList[ImportFrom]``` + """ + self.assertIsNotNone(imports) + self.assertEqual(len(imports), 1) + run_ast_test( + self, + imports[0], + ImportFrom( + module="typing" if PY_GTE_3_8 else "typing_extensions", + names=[ + alias( + "Literal", + None, + identifier=None, + identifier_name=None, + ) + ], + level=0, + ), + ) + + deque( + map( + inner_test, + map( + infer_imports, + ( + pydantic_class_cls_def, + Assign( + targets=[Name("a", Load(), lineno=None, col_offset=None)], + value=set_value("cat"), + type_comment="Literal['cat']", + expr=None, + lineno=None, + ), + ), + ), + ), + maxlen=0, + ) + def test_node_to_dict(self) -> None: """ Tests `node_to_dict` @@ -642,6 +695,7 @@ def test_get_value(self) -> None: ) self.assertIsNone(get_value(Name(None, None))) self.assertEqual(get_value(get_value(ast.parse("-5").body[0])), -5) + self.assertEqual(get_value(Num(n=-5, constant_value=None, string=None)), -5) def test_set_value(self) -> None: """Tests that `set_value` returns the right type for the right Python version""" @@ -749,21 +803,76 @@ def test_find_ast_type_fails(self) -> None: def test_get_types(self) -> None: """Test that `get_types` functions correctly""" + self.assertTupleEqual(tuple(get_types(None)), tuple()) + self.assertTupleEqual(tuple(get_types("str")), ("str",)) + self.assertTupleEqual( + tuple( + get_types( + Subscript( + value=Name( + id="Optional", ctx=Load(), lineno=None, col_offset=None + ), + slice=Name(id="Any", ctx=Load(), lineno=None, col_offset=None), + ctx=Load(), + expr_context_ctx=None, + expr_slice=None, + expr_value=None, + lineno=None, + col_offset=None, + ) + ) + ), + ("Optional", "Any"), + ) self.assertTupleEqual( - tuple(get_types("str")), - ("str",), + tuple( + get_types( + Subscript( + value=Name( + id="Literal", ctx=Load(), lineno=None, col_offset=None + ), + slice=Tuple( + elts=list(map(set_value, ("foo", "bar"))), + ctx=Load(), + expr=None, + lineno=None, + col_offset=None, + ), + ctx=Load(), + expr_context_ctx=None, + expr_slice=None, + expr_value=None, + lineno=None, + col_offset=None, + ) + ) + ), + ("Literal",), ) self.assertTupleEqual( tuple( get_types( Subscript( - value=Name(id="Optional", ctx=Load()), - slice=Name(id="Any", ctx=Load()), + value=Name( + id="Tuple", ctx=Load(), lineno=None, col_offset=None + ), + slice=Tuple( + elts=list(map(set_value, ("int", "float"))), + ctx=Load(), + expr=None, + lineno=None, + col_offset=None, + ), ctx=Load(), + expr_context_ctx=None, + expr_slice=None, + expr_value=None, + lineno=None, + col_offset=None, ) ) ), - ("Optional", "Any"), + ("Tuple", "int", "float"), ) def test_to_named_class_def(self) -> None: