Skip to content

Commit

Permalink
[cdd/**/*.py] Increase type annotation coverage ; minor logic fixes …
Browse files Browse the repository at this point in the history
…that this annotation found
  • Loading branch information
SamuelMarks committed Dec 25, 2023
1 parent a21f899 commit c6bfc58
Show file tree
Hide file tree
Showing 17 changed files with 203 additions and 163 deletions.
2 changes: 1 addition & 1 deletion cdd/argparse_function/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,4 @@ def argparse_function(
)


__all__ = ["argparse_function"]
__all__ = ["argparse_function"] # type: list[str]
25 changes: 15 additions & 10 deletions cdd/argparse_function/parse.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""
Argparse function parser
"""

from _ast import Call
from ast import Assign, FunctionDef, Return, Tuple, get_docstring
from collections import OrderedDict
from functools import partial
from itertools import filterfalse
from operator import setitem
from typing import Optional
from typing import List, Optional, cast

from cdd.argparse_function.utils.emit_utils import _parse_return, parse_out_param
from cdd.shared.ast_utils import (
Expand Down Expand Up @@ -82,7 +82,9 @@ def argparse_ast(
require_default = False

# Parse all relevant nodes from function body
body = function_def.body if doc_string is None else function_def.body[1:]
body: FunctionDef.body = (
function_def.body if doc_string is None else function_def.body[1:]
)
for node in body:
if is_argparse_add_argument(node):
name, _param = parse_out_param(
Expand All @@ -95,7 +97,7 @@ def argparse_ast(
else partial(setitem, intermediate_repr["params"], name)
)(_param)
if not require_default and _param.get("default") is not None:
require_default = True
require_default: bool = True
elif isinstance(node, Assign) and is_argparse_description(node):
intermediate_repr["doc"] = get_value(node.value)
elif isinstance(node, Return) and isinstance(node.value, Tuple):
Expand All @@ -110,11 +112,14 @@ def argparse_ast(
)
)

inner_body = list(
filterfalse(
is_argparse_description,
filterfalse(is_argparse_add_argument, body),
)
inner_body: List[Call] = cast(
List[Call],
list(
filterfalse(
is_argparse_description,
filterfalse(is_argparse_add_argument, body),
)
),
)
if inner_body:
intermediate_repr["_internal"] = {
Expand All @@ -129,4 +134,4 @@ def argparse_ast(
return intermediate_repr


__all__ = ["argparse_ast"]
__all__ = ["argparse_ast"] # type: list[str]
38 changes: 21 additions & 17 deletions cdd/argparse_function/utils/emit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ast
from ast import Name, Return
from typing import Any
from typing import Any, Callable, Dict, Optional, Union, cast

from cdd.shared.ast_utils import NoneStr, get_value, set_value
from cdd.shared.defaults_utils import extract_default, set_default_doc
Expand Down Expand Up @@ -45,9 +45,9 @@ def _parse_return(e, intermediate_repr, function_def, emit_default_doc):

typ: str = intermediate_repr["returns"]["return_type"]["typ"]
if "[" in intermediate_repr["returns"]["return_type"]["typ"]:
typ = to_code(get_value(ast.parse(typ).body[0].value.slice).elts[1]).rstrip(
"\n"
)
typ: str = to_code(
get_value(ast.parse(typ).body[0].value.slice).elts[1]
).rstrip("\n")

return set_default_doc(
(
Expand Down Expand Up @@ -90,7 +90,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True):
:rtype: ```tuple[str, dict]```
"""
# print("require_default:", require_default, ";")
required = get_value(
required: bool = get_value(
get_value(
next(
(
Expand All @@ -103,7 +103,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True):
)
)

typ = next(
typ: str = next(
(
_handle_value(get_value(key_word))
for key_word in expr.value.keywords
Expand All @@ -112,15 +112,15 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True):
"str",
)
name: str = get_value(expr.value.args[0])[len("--") :]
default = next(
default: Optional[Any] = next(
(
get_value(key_word.value)
for key_word in expr.value.keywords
if key_word.arg == "default"
),
None,
)
doc = (
doc: Optional[str] = (
lambda help_: help_
if help_ is None
else (
Expand Down Expand Up @@ -152,12 +152,16 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True):
# if name.endswith("kwargs"):
# default = NoneStr
# else:
default = simple_types[typ] if typ in simple_types else NoneStr
default: Optional[
Dict[Optional[str], Union[int, float, complex, str, bool, None]]
] = (simple_types[typ] if typ in simple_types else NoneStr)

elif require_default: # or typ.startswith("Optional"):
default = NoneStr
default: Optional[
Dict[Optional[str], Union[int, float, complex, str, bool, None]]
] = NoneStr

action = next(
action: Optional[Any] = next(
(
get_value(key_word.value)
for key_word in expr.value.keywords
Expand All @@ -166,7 +170,7 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True):
None,
)

typ = next(
typ: Optional[Any] = next(
(
_handle_keyword(keyword, typ)
for keyword in expr.value.keywords
Expand All @@ -175,10 +179,10 @@ def parse_out_param(expr, require_default=False, emit_default_doc=True):
typ,
)
if action == "append":
typ = "List[{typ}]".format(typ=typ)
typ: str = "List[{typ}]".format(typ=typ)

if not required and "Optional" not in typ:
typ = "Optional[{typ}]".format(typ=typ)
typ: str = "Optional[{typ}]".format(typ=typ)

return name, dict(
doc=doc, typ=typ, **({} if default is None else {"default": default})
Expand All @@ -198,9 +202,9 @@ def _handle_keyword(keyword, typ):
:return: string representation of type
:rtype: ```str```
"""
quote_f = identity
quote_f: Callable[[str], str] = cast(Callable[[str], str], identity)

type_ = "Union"
type_: str = "Union"
if typ == Any or typ in simple_types:
if typ in ("str", Any):

Expand All @@ -216,7 +220,7 @@ def quote_f(s):
"""
return "'{}'".format(s)

type_ = "Literal"
type_: str = "Literal"

return "{type}[{types}]".format(
type=type_,
Expand Down
32 changes: 17 additions & 15 deletions cdd/class_/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def class_(
(("return_type", intermediate_repr["params"].pop("return_type")),)
)

body = class_def.body if doc_str is None else class_def.body[1:]
body: ClassDef.body = class_def.body if doc_str is None else class_def.body[1:]
for e in body:
if isinstance(e, AnnAssign):
typ = to_code(e.annotation).rstrip("\n")
typ: str = to_code(e.annotation).rstrip("\n")
# print(ast.dump(e, indent=4))
val = (
(
Expand All @@ -137,14 +137,16 @@ def class_(
)

# if 'str' in typ and val: val["default"] = val["default"].strip("'") # Unquote?
typ_default = {"typ": typ} if val is None else dict(typ=typ, **val)
typ_default = (
{"typ": typ} if val is None else dict(typ=typ, **val)
) # type: Union[bool, dict[str, Any]]

target_id = e.target.id.lstrip("*")
target_id: str = e.target.id.lstrip("*")

for key in "params", "returns":
if target_id in (intermediate_repr[key] or iter(())):
intermediate_repr[key][target_id].update(typ_default)
typ_default = False
typ_default: bool = False
break

if typ_default:
Expand Down Expand Up @@ -281,16 +283,16 @@ def _class_from_memory(
src: Optional[str] = get_source(class_def)
if src is None:
return ir
parsed_body: ClassDef = cast(ClassDef, ast.parse(src.lstrip()).body[0])
parsed_class: ClassDef = cast(ClassDef, ast.parse(src.lstrip()).body[0])
original_doc_str: Optional[str] = get_docstring(
parsed_body, clean=parse_original_whitespace
parsed_class, clean=parse_original_whitespace
)
parsed_body.body = (
parsed_body.body if original_doc_str is None else parsed_body.body[1:]
parsed_class.body = (
parsed_class.body if original_doc_str is None else parsed_class.body[1:]
)
if merge_inner_function is not None:
_merge_inner_function(
parsed_body,
parsed_class,
infer_type=infer_type,
intermediate_repr=ir,
merge_inner_function=merge_inner_function,
Expand All @@ -299,20 +301,20 @@ def _class_from_memory(
ir["_internal"] = {
"original_doc_str": original_doc_str
if parse_original_whitespace
else get_docstring(parsed_body, clean=False),
else get_docstring(parsed_class, clean=False),
"body": list(
filterfalse(
rpartial(isinstance, (AnnAssign, Assign)),
parsed_body.body,
parsed_class.body,
)
),
"from_name": class_name,
"from_name": cast(str, class_name),
"from_type": "cls",
}
if class_name is None:
class_name: Optional[str] = parsed_body.name
class_name: Optional[str] = parsed_class.name
body_ir: IntermediateRepr = class_(
class_def=parsed_body,
class_def=parsed_class,
class_name=class_name,
merge_inner_function=merge_inner_function,
)
Expand Down
10 changes: 5 additions & 5 deletions cdd/compound/doctrans.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Helper to traverse the AST of the input file, extract the docstring out, parse and format to intended style, and emit
"""

from _ast import Module
from ast import fix_missing_locations
from copy import deepcopy
from operator import attrgetter
Expand Down Expand Up @@ -30,11 +30,11 @@ def doctrans(filename, docstring_format, type_annotations, no_word_wrap):
:type no_word_wrap: ```Optional[Literal[True]]```
"""
with open(filename, "rt") as f:
original_source = f.read()
node = ast_parse(original_source, skip_docstring_remit=False)
original_module = deepcopy(node)
original_source: str = f.read()
node: Module = ast_parse(original_source, skip_docstring_remit=False)
original_module: Module = deepcopy(node)

node = fix_missing_locations(
node: Module = fix_missing_locations(
DocTrans(
docstring_format=docstring_format,
word_wrap=no_word_wrap is None,
Expand Down
43 changes: 23 additions & 20 deletions cdd/compound/exmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from itertools import chain, groupby
from operator import attrgetter, itemgetter
from os import makedirs, path
from typing import Optional
from typing import Optional, Tuple, cast

import cdd.class_.parse
import cdd.compound.exmod_utils
Expand Down Expand Up @@ -123,15 +123,15 @@ def exmod(
except AssertionError as e:
raise ModuleNotFoundError(e)

mod_path = (
mod_path: str = (
module_name
if module_name.startswith(module_root + ".")
else ".".join((module_root, module_name))
)
blacklist, whitelist = map(
frozenset, (blacklist or iter(()), whitelist or iter(()))
)
proceed = any(
proceed: bool = any(
(
sum(map(len, (blacklist, whitelist))) == 0,
mod_path not in blacklist and (mod_path in whitelist or not whitelist),
Expand All @@ -153,7 +153,7 @@ def exmod(

imports = _emit_files_from_module_and_return_imports(
module_name=module_name, module=module, module_root_dir=module_root_dir
)
) # type: Optional[list[ImportFrom]]
if not imports:
# Case: no obvious folder hierarchy, so parse the `__init__` file in root
with open(
Expand Down Expand Up @@ -214,26 +214,29 @@ def exmod(
),
)
)
)
) # type: list[ImportFrom]

assert imports, "Module contents are empty"
modules_names = tuple(
map(
lambda name_module: (
name_module[0],
tuple(map(itemgetter(1), name_module[1])),
),
groupby(
map(
lambda node_mod: (
node_mod[0],
node_mod[2].module,
modules_names: Tuple[str, ...] = cast(
Tuple[str, ...],
tuple(
map(
lambda name_module: (
name_module[0],
tuple(map(itemgetter(1), name_module[1])),
),
groupby(
map(
lambda node_mod: (
node_mod[0],
node_mod[2].module,
),
imports,
),
imports,
itemgetter(0),
),
itemgetter(0),
),
)
)
),
)

init_filepath: str = path.join(
Expand Down
Loading

0 comments on commit c6bfc58

Please sign in to comment.