From c6bfc5862bcaf26d75b6551c6505c5138f2a0442 Mon Sep 17 00:00:00 2001 From: Samuel Marks <807580+SamuelMarks@users.noreply.github.com> Date: Mon, 25 Dec 2023 18:34:28 -0500 Subject: [PATCH] [cdd/**/*.py] Increase type annotation coverage ; minor logic fixes that this annotation found --- cdd/argparse_function/emit.py | 2 +- cdd/argparse_function/parse.py | 25 ++++++---- cdd/argparse_function/utils/emit_utils.py | 38 ++++++++------- cdd/class_/parse.py | 32 +++++++------ cdd/compound/doctrans.py | 10 ++-- cdd/compound/exmod.py | 43 +++++++++-------- cdd/compound/exmod_utils.py | 43 ++++++++++------- cdd/compound/openapi/gen_openapi.py | 3 +- cdd/compound/openapi/gen_routes.py | 35 +++++++++----- cdd/compound/openapi/parse.py | 6 +-- .../openapi/utils/emit_openapi_utils.py | 6 +-- cdd/compound/openapi/utils/emit_utils.py | 45 +++++++++-------- cdd/compound/openapi/utils/parse_utils.py | 2 +- cdd/shared/pure_utils.py | 48 ++++++++++++------- cdd/shared/types.py | 19 ++------ cdd/sqlalchemy/utils/shared_utils.py | 3 +- cdd/tests/test_cli/test_cli.py | 6 +-- 17 files changed, 203 insertions(+), 163 deletions(-) diff --git a/cdd/argparse_function/emit.py b/cdd/argparse_function/emit.py index 03485d98..71d22f45 100644 --- a/cdd/argparse_function/emit.py +++ b/cdd/argparse_function/emit.py @@ -274,4 +274,4 @@ def argparse_function( ) -__all__ = ["argparse_function"] +__all__ = ["argparse_function"] # type: list[str] diff --git a/cdd/argparse_function/parse.py b/cdd/argparse_function/parse.py index 50997ac9..f823f569 100644 --- a/cdd/argparse_function/parse.py +++ b/cdd/argparse_function/parse.py @@ -1,13 +1,13 @@ """ Argparse function parser """ - +from _ast import Call from ast import Assign, FunctionDef, Return, Tuple, get_docstring from collections import OrderedDict from functools import partial from itertools import filterfalse from operator import setitem -from typing import Optional +from typing import List, Optional, cast from cdd.argparse_function.utils.emit_utils import _parse_return, parse_out_param from cdd.shared.ast_utils import ( @@ -82,7 +82,9 @@ def argparse_ast( require_default = False # Parse all relevant nodes from function body - body = function_def.body if doc_string is None else function_def.body[1:] + body: FunctionDef.body = ( + function_def.body if doc_string is None else function_def.body[1:] + ) for node in body: if is_argparse_add_argument(node): name, _param = parse_out_param( @@ -95,7 +97,7 @@ def argparse_ast( else partial(setitem, intermediate_repr["params"], name) )(_param) if not require_default and _param.get("default") is not None: - require_default = True + require_default: bool = True elif isinstance(node, Assign) and is_argparse_description(node): intermediate_repr["doc"] = get_value(node.value) elif isinstance(node, Return) and isinstance(node.value, Tuple): @@ -110,11 +112,14 @@ def argparse_ast( ) ) - inner_body = list( - filterfalse( - is_argparse_description, - filterfalse(is_argparse_add_argument, body), - ) + inner_body: List[Call] = cast( + List[Call], + list( + filterfalse( + is_argparse_description, + filterfalse(is_argparse_add_argument, body), + ) + ), ) if inner_body: intermediate_repr["_internal"] = { @@ -129,4 +134,4 @@ def argparse_ast( return intermediate_repr -__all__ = ["argparse_ast"] +__all__ = ["argparse_ast"] # type: list[str] diff --git a/cdd/argparse_function/utils/emit_utils.py b/cdd/argparse_function/utils/emit_utils.py index fb62c606..3d85f99d 100644 --- a/cdd/argparse_function/utils/emit_utils.py +++ b/cdd/argparse_function/utils/emit_utils.py @@ -4,7 +4,7 @@ import ast from ast import Name, Return -from typing import Any +from typing import Any, Callable, Dict, Optional, Union, cast from cdd.shared.ast_utils import NoneStr, get_value, set_value from cdd.shared.defaults_utils import extract_default, set_default_doc @@ -45,9 +45,9 @@ def _parse_return(e, intermediate_repr, function_def, emit_default_doc): typ: str = intermediate_repr["returns"]["return_type"]["typ"] if "[" in intermediate_repr["returns"]["return_type"]["typ"]: - typ = to_code(get_value(ast.parse(typ).body[0].value.slice).elts[1]).rstrip( - "\n" - ) + typ: str = to_code( + get_value(ast.parse(typ).body[0].value.slice).elts[1] + ).rstrip("\n") return set_default_doc( ( @@ -90,7 +90,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): :rtype: ```tuple[str, dict]``` """ # print("require_default:", require_default, ";") - required = get_value( + required: bool = get_value( get_value( next( ( @@ -103,7 +103,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): ) ) - typ = next( + typ: str = next( ( _handle_value(get_value(key_word)) for key_word in expr.value.keywords @@ -112,7 +112,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): "str", ) name: str = get_value(expr.value.args[0])[len("--") :] - default = next( + default: Optional[Any] = next( ( get_value(key_word.value) for key_word in expr.value.keywords @@ -120,7 +120,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): ), None, ) - doc = ( + doc: Optional[str] = ( lambda help_: help_ if help_ is None else ( @@ -152,12 +152,16 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): # if name.endswith("kwargs"): # default = NoneStr # else: - default = simple_types[typ] if typ in simple_types else NoneStr + default: Optional[ + Dict[Optional[str], Union[int, float, complex, str, bool, None]] + ] = (simple_types[typ] if typ in simple_types else NoneStr) elif require_default: # or typ.startswith("Optional"): - default = NoneStr + default: Optional[ + Dict[Optional[str], Union[int, float, complex, str, bool, None]] + ] = NoneStr - action = next( + action: Optional[Any] = next( ( get_value(key_word.value) for key_word in expr.value.keywords @@ -166,7 +170,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): None, ) - typ = next( + typ: Optional[Any] = next( ( _handle_keyword(keyword, typ) for keyword in expr.value.keywords @@ -175,10 +179,10 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True): typ, ) if action == "append": - typ = "List[{typ}]".format(typ=typ) + typ: str = "List[{typ}]".format(typ=typ) if not required and "Optional" not in typ: - typ = "Optional[{typ}]".format(typ=typ) + typ: str = "Optional[{typ}]".format(typ=typ) return name, dict( doc=doc, typ=typ, **({} if default is None else {"default": default}) @@ -198,9 +202,9 @@ def _handle_keyword(keyword, typ): :return: string representation of type :rtype: ```str``` """ - quote_f = identity + quote_f: Callable[[str], str] = cast(Callable[[str], str], identity) - type_ = "Union" + type_: str = "Union" if typ == Any or typ in simple_types: if typ in ("str", Any): @@ -216,7 +220,7 @@ def quote_f(s): """ return "'{}'".format(s) - type_ = "Literal" + type_: str = "Literal" return "{type}[{types}]".format( type=type_, diff --git a/cdd/class_/parse.py b/cdd/class_/parse.py index b1e79ce3..e23e835c 100644 --- a/cdd/class_/parse.py +++ b/cdd/class_/parse.py @@ -111,10 +111,10 @@ def class_( (("return_type", intermediate_repr["params"].pop("return_type")),) ) - body = class_def.body if doc_str is None else class_def.body[1:] + body: ClassDef.body = class_def.body if doc_str is None else class_def.body[1:] for e in body: if isinstance(e, AnnAssign): - typ = to_code(e.annotation).rstrip("\n") + typ: str = to_code(e.annotation).rstrip("\n") # print(ast.dump(e, indent=4)) val = ( ( @@ -137,14 +137,16 @@ def class_( ) # if 'str' in typ and val: val["default"] = val["default"].strip("'") # Unquote? - typ_default = {"typ": typ} if val is None else dict(typ=typ, **val) + typ_default = ( + {"typ": typ} if val is None else dict(typ=typ, **val) + ) # type: Union[bool, dict[str, Any]] - target_id = e.target.id.lstrip("*") + target_id: str = e.target.id.lstrip("*") for key in "params", "returns": if target_id in (intermediate_repr[key] or iter(())): intermediate_repr[key][target_id].update(typ_default) - typ_default = False + typ_default: bool = False break if typ_default: @@ -281,16 +283,16 @@ def _class_from_memory( src: Optional[str] = get_source(class_def) if src is None: return ir - parsed_body: ClassDef = cast(ClassDef, ast.parse(src.lstrip()).body[0]) + parsed_class: ClassDef = cast(ClassDef, ast.parse(src.lstrip()).body[0]) original_doc_str: Optional[str] = get_docstring( - parsed_body, clean=parse_original_whitespace + parsed_class, clean=parse_original_whitespace ) - parsed_body.body = ( - parsed_body.body if original_doc_str is None else parsed_body.body[1:] + parsed_class.body = ( + parsed_class.body if original_doc_str is None else parsed_class.body[1:] ) if merge_inner_function is not None: _merge_inner_function( - parsed_body, + parsed_class, infer_type=infer_type, intermediate_repr=ir, merge_inner_function=merge_inner_function, @@ -299,20 +301,20 @@ def _class_from_memory( ir["_internal"] = { "original_doc_str": original_doc_str if parse_original_whitespace - else get_docstring(parsed_body, clean=False), + else get_docstring(parsed_class, clean=False), "body": list( filterfalse( rpartial(isinstance, (AnnAssign, Assign)), - parsed_body.body, + parsed_class.body, ) ), - "from_name": class_name, + "from_name": cast(str, class_name), "from_type": "cls", } if class_name is None: - class_name: Optional[str] = parsed_body.name + class_name: Optional[str] = parsed_class.name body_ir: IntermediateRepr = class_( - class_def=parsed_body, + class_def=parsed_class, class_name=class_name, merge_inner_function=merge_inner_function, ) diff --git a/cdd/compound/doctrans.py b/cdd/compound/doctrans.py index 7a229dac..3f5a0cbc 100644 --- a/cdd/compound/doctrans.py +++ b/cdd/compound/doctrans.py @@ -1,7 +1,7 @@ """ Helper to traverse the AST of the input file, extract the docstring out, parse and format to intended style, and emit """ - +from _ast import Module from ast import fix_missing_locations from copy import deepcopy from operator import attrgetter @@ -30,11 +30,11 @@ def doctrans(filename, docstring_format, type_annotations, no_word_wrap): :type no_word_wrap: ```Optional[Literal[True]]``` """ with open(filename, "rt") as f: - original_source = f.read() - node = ast_parse(original_source, skip_docstring_remit=False) - original_module = deepcopy(node) + original_source: str = f.read() + node: Module = ast_parse(original_source, skip_docstring_remit=False) + original_module: Module = deepcopy(node) - node = fix_missing_locations( + node: Module = fix_missing_locations( DocTrans( docstring_format=docstring_format, word_wrap=no_word_wrap is None, diff --git a/cdd/compound/exmod.py b/cdd/compound/exmod.py index a18dc39f..3f3659c7 100644 --- a/cdd/compound/exmod.py +++ b/cdd/compound/exmod.py @@ -8,7 +8,7 @@ from itertools import chain, groupby from operator import attrgetter, itemgetter from os import makedirs, path -from typing import Optional +from typing import Optional, Tuple, cast import cdd.class_.parse import cdd.compound.exmod_utils @@ -123,7 +123,7 @@ def exmod( except AssertionError as e: raise ModuleNotFoundError(e) - mod_path = ( + mod_path: str = ( module_name if module_name.startswith(module_root + ".") else ".".join((module_root, module_name)) @@ -131,7 +131,7 @@ def exmod( blacklist, whitelist = map( frozenset, (blacklist or iter(()), whitelist or iter(())) ) - proceed = any( + proceed: bool = any( ( sum(map(len, (blacklist, whitelist))) == 0, mod_path not in blacklist and (mod_path in whitelist or not whitelist), @@ -153,7 +153,7 @@ def exmod( imports = _emit_files_from_module_and_return_imports( module_name=module_name, module=module, module_root_dir=module_root_dir - ) + ) # type: Optional[list[ImportFrom]] if not imports: # Case: no obvious folder hierarchy, so parse the `__init__` file in root with open( @@ -214,26 +214,29 @@ def exmod( ), ) ) - ) + ) # type: list[ImportFrom] assert imports, "Module contents are empty" - modules_names = tuple( - map( - lambda name_module: ( - name_module[0], - tuple(map(itemgetter(1), name_module[1])), - ), - groupby( - map( - lambda node_mod: ( - node_mod[0], - node_mod[2].module, + modules_names: Tuple[str, ...] = cast( + Tuple[str, ...], + tuple( + map( + lambda name_module: ( + name_module[0], + tuple(map(itemgetter(1), name_module[1])), + ), + groupby( + map( + lambda node_mod: ( + node_mod[0], + node_mod[2].module, + ), + imports, ), - imports, + itemgetter(0), ), - itemgetter(0), - ), - ) + ) + ), ) init_filepath: str = path.join( diff --git a/cdd/compound/exmod_utils.py b/cdd/compound/exmod_utils.py index c9f3384e..af44638d 100644 --- a/cdd/compound/exmod_utils.py +++ b/cdd/compound/exmod_utils.py @@ -1,6 +1,8 @@ """ Exmod utils """ import ast +import sys +from _ast import AST from ast import Assign, Expr, ImportFrom, List, Load, Module, Name, Store, alias from collections import OrderedDict, defaultdict, deque from functools import partial @@ -8,8 +10,7 @@ from itertools import chain from operator import attrgetter, eq from os import environ, extsep, makedirs, path -from sys import stdout -from typing import Optional +from typing import Any, Dict, Optional, TextIO, cast import cdd.argparse_function.emit import cdd.class_ @@ -34,7 +35,7 @@ from cdd.shared.source_transformer import ast_parse from cdd.tests.mocks import imports_header_ast -EXMOD_OUT_STREAM = environ.get("EXMOD_OUT_STREAM", stdout) +EXMOD_OUT_STREAM: TextIO = getattr(sys, environ.get("EXMOD_OUT_STREAM", "stdout")) def get_module_contents(obj, module_root_dir, current_module=None, _result={}): @@ -83,8 +84,8 @@ def get_module_contents(obj, module_root_dir, current_module=None, _result={}): ), ), iter(()), - ) - mod_to_symbol = defaultdict(list) + ) # type: Union[list[str], Iterator] + mod_to_symbol: defaultdict[Any, list] = defaultdict(list) deque( ( mod_to_symbol[import_from.module].append(name.name) @@ -98,7 +99,7 @@ def get_module_contents(obj, module_root_dir, current_module=None, _result={}): ), maxlen=0, ) - res = { + res: Dict[str, AST] = { "{module_name}.{submodule_name}.{node_name}".format( module_name=module_name, submodule_name=submodule_name, @@ -162,11 +163,11 @@ def _process_module_contents(_result, current_module, module_root_dir, name, sym :param symbol: Symbol—second value—from `dir(module)` :type symbol: ```type``` """ - fq = "{current_module}.{name}".format(current_module=current_module, name=name) + fq: str = "{current_module}.{name}".format(current_module=current_module, name=name) try: - symbol_location = getfile(symbol) + symbol_location: Optional[str] = getfile(symbol) except TypeError: - symbol_location = None + symbol_location: Optional[str] = None if symbol_location is not None and symbol_location.startswith(module_root_dir): if isinstance(symbol, type): _result[fq] = symbol @@ -231,15 +232,17 @@ def emit_file_on_hierarchy( original_relative_filename_path, ir = name_orig_ir[1], name_orig_ir[2] assert original_relative_filename_path - relative_filename_path = original_relative_filename_path - module_name_as_path = module_name.replace(".", path.sep) - new_module_name_as_path = new_module_name.replace(".", path.sep) + relative_filename_path: str = original_relative_filename_path + module_name_as_path: str = module_name.replace(".", path.sep) + new_module_name_as_path: str = new_module_name.replace(".", path.sep) if relative_filename_path.startswith(module_name_as_path + path.sep): - relative_filename_path = relative_filename_path[len(new_module_name_as_path) :] + relative_filename_path: str = relative_filename_path[ + len(new_module_name_as_path) : + ] if not name and ir.get("name") is not None: name: Optional[str] = ir.get("name") - output_dir_is_module = output_directory.replace(path.sep, ".").endswith( + output_dir_is_module: bool = output_directory.replace(path.sep, ".").endswith( new_module_name ) mod_path: str = path.join( @@ -290,15 +293,16 @@ def emit_file_on_hierarchy( ), ) ) - isfile_emit_filename = symbol_in_file = path.isfile(emit_filename) + symbol_in_file: bool = path.isfile(emit_filename) + isfile_emit_filename: bool = symbol_in_file existent_mod: Optional[Module] = None if isfile_emit_filename: with open(emit_filename, "rt") as f: - emit_filename_contents = f.read() + emit_filename_contents: str = f.read() existent_mod: Module = ast.parse( emit_filename_contents ) # Also, useful as this catches syntax errors - symbol_in_file = any( + symbol_in_file: bool = any( filter( partial(eq, name), map( @@ -481,7 +485,10 @@ def _emit_symbol( type_ignores=[], ) if isfile_emit_filename: - gen_node = cdd.shared.ast_utils.merge_modules(existent_mod, gen_node) + if existent_mod is not None: + gen_node: Module = cdd.shared.ast_utils.merge_modules( + cast(Module, existent_mod), gen_node + ) cdd.shared.ast_utils.merge_assignment_lists(gen_node, "__all__") if dry_run: print( diff --git a/cdd/compound/openapi/gen_openapi.py b/cdd/compound/openapi/gen_openapi.py index b7d82be4..10a919cb 100644 --- a/cdd/compound/openapi/gen_openapi.py +++ b/cdd/compound/openapi/gen_openapi.py @@ -18,6 +18,7 @@ from cdd.shared.ast_utils import get_value from cdd.shared.parse.utils.parser_utils import infer from cdd.shared.pure_utils import rpartial, update_d +from cdd.sqlalchemy.utils.shared_utils import OpenAPI_requestBodies from cdd.tests.mocks.json_schema import server_error_schema @@ -37,7 +38,7 @@ def openapi_bulk(app_name, model_paths, routes_paths): :return: OpenAPI dictionary :rtype: ```dict``` """ - request_bodies = {} + request_bodies: OpenAPI_requestBodies = {} def parse_model(filename): """ diff --git a/cdd/compound/openapi/gen_routes.py b/cdd/compound/openapi/gen_routes.py index 83e74455..008451fb 100644 --- a/cdd/compound/openapi/gen_routes.py +++ b/cdd/compound/openapi/gen_routes.py @@ -4,9 +4,11 @@ import ast from ast import Attribute, Call, ClassDef, FunctionDef, Module, Name +from collections.abc import dict_keys from itertools import chain from operator import attrgetter, itemgetter from os import path +from typing import Any, Callable, Dict, List, Optional, Union import cdd.routes.emit.bottle import cdd.sqlalchemy.parse @@ -50,18 +52,20 @@ def gen_routes(app, model_path, model_name, crud, route): :return: Iterator of functions representing relevant CRUD operations :rtype: ```Iterator[FunctionDef]``` """ - model_path = filename_from_mod_or_filename(model_path) + model_path: str = filename_from_mod_or_filename(model_path) assert path.isfile(model_path) with open(model_path, "rt") as f: mod: Module = ast.parse(f.read()) - sqlalchemy_node = next( + sqlalchemy_node: Optional[ClassDef] = next( filter( lambda node: isinstance(node, ClassDef) - and node.name == model_name - or isinstance(node, Name) - and node.id == model_name, + and ( + node.name == model_name + or isinstance(node.name, Name) + and node.name.id == model_name + ), ast.walk(mod), ), None, @@ -69,7 +73,7 @@ def gen_routes(app, model_path, model_name, crud, route): sqlalchemy_ir: IntermediateRepr = cdd.sqlalchemy.parse.sqlalchemy( Module(body=[sqlalchemy_node], stmt=None, type_ignores=[]) ) - primary_key = next( + primary_key: str = next( map( itemgetter(0), filter( @@ -79,13 +83,18 @@ def gen_routes(app, model_path, model_name, crud, route): ), next(iter(sqlalchemy_ir["params"].keys())), ) - _route_config = {"app": app, "name": model_name, "route": route, "variant": -1} - routes = [] + _route_config: dict[str, Union[str, int]] = { + "app": app, + "name": model_name, + "route": route, + "variant": -1, + } + routes: List[str] = [] if "C" in crud: routes.append(cdd.routes.emit.bottle.create(**_route_config)) _route_config["primary_key"] = primary_key - funcs = { + funcs: dict[str, Optional[Callable[[str, str, str, Any, int], str]]] = { "R": cdd.routes.emit.bottle.read, "U": None, "D": cdd.routes.emit.bottle.destroy, @@ -116,7 +125,7 @@ def upsert_routes(app, routes, routes_path, route, primary_key): :param routes_path: The path/module-resolution whence the routes are / will be :type routes_path: ```str``` """ - routes_path = filename_from_mod_or_filename(routes_path) + routes_path: str = filename_from_mod_or_filename(routes_path) if not path.isfile(routes_path): with open(routes_path, "wt") as f: @@ -172,8 +181,8 @@ def get_names(functions): ) ) - routes_required = get_names(routes) - routes_existing = get_names( + routes_required: Dict[str, FunctionDef] = get_names(routes) + routes_existing: Dict[str, FunctionDef] = get_names( filter( lambda node: any( filter( @@ -197,7 +206,7 @@ def get_names(functions): filter(rpartial(isinstance, FunctionDef), ast.walk(mod)), ) ) - missing_routes = ( + missing_routes: dict_keys[str, str] = ( routes_required.keys() & routes_existing.keys() ^ routes_required.keys() ) diff --git a/cdd/compound/openapi/parse.py b/cdd/compound/openapi/parse.py index 685088a4..195d03d4 100644 --- a/cdd/compound/openapi/parse.py +++ b/cdd/compound/openapi/parse.py @@ -3,7 +3,7 @@ """ from json import loads -from typing import Optional +from typing import List, Optional from yaml import safe_load @@ -25,7 +25,7 @@ def openapi(openapi_str, routes_dict, summary): :return: OpenAPI dictionary """ - entities = extract_entities(openapi_str) + entities: List[str] = extract_entities(openapi_str) non_error_entity: Optional[str] = None @@ -35,7 +35,7 @@ def openapi(openapi_str, routes_dict, summary): "{{'$ref': '#/components/schemas/{entity}'}}".format(entity=entity), ) if entity != "ServerError": - non_error_entity = entity + non_error_entity: str = entity openapi_d: dict = (loads if openapi_str.startswith("{") else safe_load)(openapi_str) if non_error_entity is not None: openapi_d["summary"] = "{located} `{entity}` object.".format( diff --git a/cdd/compound/openapi/utils/emit_openapi_utils.py b/cdd/compound/openapi/utils/emit_openapi_utils.py index 639cd97e..9863a2c6 100644 --- a/cdd/compound/openapi/utils/emit_openapi_utils.py +++ b/cdd/compound/openapi/utils/emit_openapi_utils.py @@ -51,7 +51,7 @@ def components_paths_from_name_model_route_id_crud( Literal['D', 'C', 'R'], Literal['D', 'C', 'U'], Literal['D', 'R', 'C'], Literal['D', 'R', 'U'], Literal['D', 'U', 'C'], Literal['D', 'U', 'R']]``` """ - _request_body = False + _request_body: bool = False if "C" in crud: paths[route] = { "post": { @@ -84,9 +84,9 @@ def components_paths_from_name_model_route_id_crud( }, } } - _request_body = True + _request_body: bool = True if not frozenset(crud) - frozenset("CRUD"): - _route = "{route}/{{{id}}}".format(route=route, id=_id) + _route: str = "{route}/{{{id}}}".format(route=route, id=_id) paths[_route] = { "parameters": [ { diff --git a/cdd/compound/openapi/utils/emit_utils.py b/cdd/compound/openapi/utils/emit_utils.py index 547f7c15..237209fe 100644 --- a/cdd/compound/openapi/utils/emit_utils.py +++ b/cdd/compound/openapi/utils/emit_utils.py @@ -33,7 +33,7 @@ from operator import attrgetter, eq, methodcaller from os import path from platform import system -from typing import Optional +from typing import Any, Dict, List, Optional import cdd.sqlalchemy.utils.shared_utils from cdd.shared.ast_utils import ( @@ -51,6 +51,7 @@ tab, ) from cdd.shared.source_transformer import to_code +from cdd.shared.types import ParamVal from cdd.sqlalchemy.utils.parse_utils import ( column_type2typ, get_pk_and_type, @@ -88,25 +89,27 @@ def param_to_sqlalchemy_column_call(name_param, include_name): if include_name: args.append(set_value(name)) - x_typ_sql = _param.get("x_typ", {}).get("sql", {}) + x_typ_sql = _param.get("x_typ", {}).get("sql", {}) # type: dict if "typ" in _param: - nullable = cdd.sqlalchemy.utils.shared_utils.update_args_infer_typ_sqlalchemy( - _param, args, name, nullable, x_typ_sql + nullable: bool = ( + cdd.sqlalchemy.utils.shared_utils.update_args_infer_typ_sqlalchemy( + _param, args, name, nullable, x_typ_sql + ) ) default = x_typ_sql.get("default", _param.get("default", ast)) - has_default = default is not ast - pk = _param.get("doc", "").startswith("[PK]") - fk = _param.get("doc", "").startswith("[FK") + has_default: bool = default is not ast + pk: bool = _param.get("doc", "").startswith("[PK]") + fk: bool = _param.get("doc", "").startswith("[FK") if pk: _param["doc"] = _param["doc"][4:].lstrip() keywords.append( ast.keyword(arg="primary_key", value=set_value(True), identifier=None), ) elif fk: - end = _param["doc"].find("]") + 1 - fk_val = _param["doc"][len("[FK(") : end - len(")]")] + end: int = _param["doc"].find("]") + 1 + fk_val: str = _param["doc"][len("[FK(") : end - len(")]")] _param["doc"] = _param["doc"][end:].lstrip() args.append( Call( @@ -116,9 +119,9 @@ def param_to_sqlalchemy_column_call(name_param, include_name): ) ) elif has_default and default not in none_types: - nullable = False + nullable: bool = False - rstripped_dot_doc = _param.get("doc", "").rstrip(".") + rstripped_dot_doc: str = _param.get("doc", "").rstrip(".") doc_added_at: Optional[int] = None if rstripped_dot_doc: doc_added_at: int = len(keywords) @@ -195,7 +198,7 @@ def generate_repr_method(params, cls_name, docstring_format): :return: `__repr__` method :rtype: ```FunctionDef``` """ - keys = tuple(params.keys()) + keys = tuple(params.keys()) # type: tuple[str, ...] return FunctionDef( name="__repr__", args=arguments( @@ -280,7 +283,7 @@ def generate_create_from_attr_staticmethod(params, cls_name, docstring_format): :return: `__repr__` method :rtype: ```FunctionDef``` """ - keys = tuple(params.keys()) + keys = tuple(params.keys()) # type: tuple[str, ...] return FunctionDef( name="create_from_attr", args=arguments( @@ -399,7 +402,7 @@ def ensure_has_primary_key(intermediate_repr, force_pk_id=False): }) :rtype: ```dict``` """ - params = ( + params: OrderedDict[str, ParamVal] = ( intermediate_repr if isinstance(intermediate_repr, OrderedDict) else intermediate_repr["params"] @@ -413,7 +416,7 @@ def ensure_has_primary_key(intermediate_repr, force_pk_id=False): ), ) ): - candidate_pks = [] + candidate_pks: List[str] = [] deque( map( candidate_pks.append, @@ -607,7 +610,7 @@ def handle_sqlalchemy_cls(symbol_to_module, sqlalchemy_class_def): ) return sqlalchemy_class_def - symbol2module = dict( + symbol2module: Dict[str, Any] = dict( chain.from_iterable( map( lambda import_from: map( @@ -704,14 +707,14 @@ def rewrite_fk_from_import(column_name, foreign_key_call): "rt", ) as f: mod: Module = ast.parse(f.read()) - matching_class = next( + matching_class: ClassDef = next( filter( lambda node: isinstance(node, ClassDef) and node.name == column_name.id, mod.body, ) ) - pk_typ = get_pk_and_type(matching_class) + pk_typ = get_pk_and_type(matching_class) # type: tuple[str, str] assert pk_typ is not None pk, typ = pk_typ del pk_typ @@ -754,7 +757,7 @@ def sqlalchemy_class_to_table(class_def, parse_original_whitespace): ) # Hybrid SQLalchemy class/table handler - table_dunder = next( + table_dunder: Optional[Call] = next( filter( lambda assign: any( filter( @@ -771,7 +774,7 @@ def sqlalchemy_class_to_table(class_def, parse_original_whitespace): # Parse into the same format that `sqlalchemy_table` can read, then return with a call to it - name = get_value( + name: str = get_value( next( filter( lambda assign: any( @@ -898,7 +901,7 @@ def sqlalchemy_table_to_class(table_expr_ass): ) -typ2column_type = {v: k for k, v in column_type2typ.items()} +typ2column_type: Dict[str, str] = {v: k for k, v in column_type2typ.items()} typ2column_type.update( { "bool": "Boolean", diff --git a/cdd/compound/openapi/utils/parse_utils.py b/cdd/compound/openapi/utils/parse_utils.py index 65daac39..196fcfea 100644 --- a/cdd/compound/openapi/utils/parse_utils.py +++ b/cdd/compound/openapi/utils/parse_utils.py @@ -19,7 +19,7 @@ def add_then_clear_stack(): """ Join entity, if non-empty add to entities. Clear stack. """ - entity = "".join(stack) + entity: str = "".join(stack) if entity: entities.append(entity) stack.clear() diff --git a/cdd/shared/pure_utils.py b/cdd/shared/pure_utils.py index 94a9856f..13141630 100644 --- a/cdd/shared/pure_utils.py +++ b/cdd/shared/pure_utils.py @@ -32,7 +32,7 @@ "bool": False, None: None, } -type_to_name = { +type_to_name: Dict[str, str] = { "Int": "int", "int": "int", "Float": "float", @@ -45,8 +45,8 @@ "None": "None", } -line_length = environ.get("DOCTRANS_LINE_LENGTH", 100) -fill = partial(_fill, width=line_length) +line_length: int = int(environ.get("DOCTRANS_LINE_LENGTH", 100)) +fill: Callable[[str], str] = partial(_fill, width=line_length) def read_file_to_str(filename, mode="rt"): @@ -297,7 +297,7 @@ def deindent(s, level=None, sep=tab): :rtype: ```AnyStr``` """ if level is None: - process_line = str.lstrip + process_line: Callable[[str], str] = str.lstrip else: sep *= level @@ -475,7 +475,12 @@ def quote(s, mark='"'): :return: Quoted string or input (if input is not str) :rtype: ```Union[str, float, complex, int, None]``` """ - very_simple_types = type(None), int, float, complex + very_simple_types = ( + type(None), + int, + float, + complex, + ) # type: tuple[Type[None], Type[int], Type[float], Type[complex]] s: str = ( s if isinstance(s, (str, *very_simple_types)) @@ -583,7 +588,7 @@ def update_d(d, arg=None, **kwargs): :rtype: ```dict``` """ if arg: - d.update(arg) + d.update(typing.cast(dict, arg)) if kwargs: d.update(kwargs) return d @@ -616,7 +621,7 @@ def diff(input_obj, op): :type input_obj: ```Any``` :param op: The operation to run - :type op: ```Callable[[Any], Any]``` + :type op: ```Callable[[Any], Sized]``` :return: length of difference, response of operated input :rtype: ```tuple[int, Any]``` @@ -624,13 +629,13 @@ def diff(input_obj, op): input_len: int = len( input_obj ) # Separate line and binding, as `op` could mutate the `input` - result = op(input_obj) + result: typing.Sized = op(input_obj) return input_len - len(result), result -strip_diff = partial(diff, op=str.strip) -lstrip_diff = partial(diff, op=str.lstrip) -rstrip_diff = partial(diff, op=str.rstrip) +strip_diff: Callable[[str], tuple[int, typing.Any]] = partial(diff, op=str.strip) +lstrip_diff: Callable[[str], tuple[int, typing.Any]] = partial(diff, op=str.lstrip) +rstrip_diff: Callable[[str], tuple[int, typing.Any]] = partial(diff, op=str.rstrip) def balanced_parentheses(s): @@ -643,8 +648,9 @@ def balanced_parentheses(s): :return: Whether the parens are balanced :rtype: ```bool``` """ - open_parens, closed_parens = "([{", ")]}" - counter = {paren: 0 for paren in open_parens + closed_parens} + open_parens: str = "([{" + closed_parens: str = ")]}" + counter: Dict[str, int] = {paren: 0 for paren in open_parens + closed_parens} quote_mark: Optional[typing.Literal["'", '"']] = None for idx, ch in enumerate(s): if ( @@ -652,10 +658,12 @@ def balanced_parentheses(s): and ch == quote_mark and (idx == 0 or s[idx - 1] != "\\") ): - quote_mark = None + quote_mark: Optional[typing.Literal["'", '"']] = None elif quote_mark is None: if ch in frozenset(("'", '"')): - quote_mark = typing.cast(typing.Literal["'", '"'], ch) + quote_mark: Optional[typing.Literal["'", '"']] = typing.cast( + typing.Literal["'", '"'], ch + ) elif ch in counter: counter[ch] += 1 return all( @@ -702,7 +710,7 @@ def location_within(container, iterable, cmp=eq): :rtype: ```tuple[int, int, Optional[Any]]``` """ if not hasattr(container, "__len__"): - container = tuple(container) + container: Tuple[typing.Any] = tuple(container) container_len: int = len(container) for elem in iterable: @@ -1052,6 +1060,10 @@ def paren_wrap_code(code): ) +class FilenameProtocol(typing.Protocol): + origin: str + + def filename_from_mod_or_filename(mod_or_filename): """ Resolve filename from module name or filename @@ -1062,7 +1074,9 @@ def filename_from_mod_or_filename(mod_or_filename): :return: Filename :rtype: ```str``` """ - filename = type("", tuple(), {"origin": mod_or_filename}) + filename: FilenameProtocol = typing.cast( + FilenameProtocol, type("", tuple(), {"origin": mod_or_filename}) + ) return ( filename if path.sep in mod_or_filename or path.isfile(mod_or_filename) diff --git a/cdd/shared/types.py b/cdd/shared/types.py index fa564ef1..d683b943 100644 --- a/cdd/shared/types.py +++ b/cdd/shared/types.py @@ -1,8 +1,7 @@ """ Shared types """ - -from ast import AnnAssign, Assign +from _ast import AST from cdd.shared.pure_utils import PY_GTE_3_8, PY_GTE_3_9, PY_GTE_3_11 @@ -11,22 +10,14 @@ from collections import OrderedDict else: from typing import OrderedDict - from typing import Any, List, Optional, TypedDict, Union + from typing import Any, List, Optional, TypedDict if PY_GTE_3_11: from typing import Required else: from typing_extensions import Required else: - from typing_extensions import ( - Any, - List, - Optional, - OrderedDict, - Required, - TypedDict, - Union, - ) + from typing_extensions import Any, List, Optional, OrderedDict, Required, TypedDict # class Parse(Protocol): @@ -43,8 +34,8 @@ "Internal", { "original_doc_str": Optional[str], - "body": List[Union[AnnAssign, Assign]], - "from_name": str, + "body": List[AST], + "from_name": Optional[str], "from_type": str, }, total=False, diff --git a/cdd/sqlalchemy/utils/shared_utils.py b/cdd/sqlalchemy/utils/shared_utils.py index d6db78a4..c1406ae8 100644 --- a/cdd/sqlalchemy/utils/shared_utils.py +++ b/cdd/sqlalchemy/utils/shared_utils.py @@ -173,8 +173,9 @@ def update_args_infer_typ_sqlalchemy(_param, args, name, nullable, x_typ_sql): # TODO: Finish writing these types OpenAPI_info = TypedDict("OpenAPI_info", {"title": str, "version": str}) +OpenAPI_requestBodies = dict OpenAPI_components = TypedDict( - "OpenAPI_components", {"requestBodies": dict, "schemas": dict} + "OpenAPI_components", {"requestBodies": OpenAPI_requestBodies, "schemas": dict} ) JSON_ref = TypedDict("JSON_ref", {"$ref": str, "required": bool}) OpenAPI_paths = dict diff --git a/cdd/tests/test_cli/test_cli.py b/cdd/tests/test_cli/test_cli.py index 181572b0..544828f1 100644 --- a/cdd/tests/test_cli/test_cli.py +++ b/cdd/tests/test_cli/test_cli.py @@ -19,7 +19,7 @@ class TestCli(TestCase): def test_build_parser(self) -> None: """Test that `_build_parser` produces a parser object""" - parser = _build_parser() + parser: ArgumentParser = _build_parser() self.assertIsInstance(parser, ArgumentParser) self.assertEqual(parser.description, __description__) @@ -36,9 +36,9 @@ def test_version(self) -> None: def test_name_main(self) -> None: """Test the `if __name__ == '__main___'` block""" - argparse_mock = MagicMock() + argparse_mock: MagicMock = MagicMock() - loader = SourceFileLoader( + loader: SourceFileLoader = SourceFileLoader( "__main__", os.path.join( os.path.dirname(os.path.dirname(os.path.dirname(__file__))),