Skip to content

Commit

Permalink
🐛 修复函数签名中varargkwarg缺失的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
snowykami committed Aug 29, 2024
1 parent 1584069 commit bfc17ab
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 1 deletion.
1 change: 1 addition & 0 deletions litedoc/style/markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

litedoc_hide = "@litedoc-hide"


def generate(parser: AstParser, lang: str, frontmatter: Optional[dict] = None, style: str = "google") -> str:
"""
Generate markdown style document from ast
Expand Down
20 changes: 19 additions & 1 deletion litedoc/syntax/astparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .node import *
from ..docstring.parser import parse


class AstParser:
def __init__(self, code: str, style: str = "google"):
"""
Expand Down Expand Up @@ -110,13 +111,23 @@ def parse(self):
)
for arg in sub_node.args.args
],
vararg=ArgNode(
name=sub_node.args.vararg.arg,
type=self.clear_quotes(
ast.unparse(sub_node.args.vararg.annotation).strip()) if sub_node.args.vararg.annotation else TypeHint.NO_TYPEHINT
) if sub_node.args.vararg else None,
kwonlyargs=[
ArgNode(
name=arg.arg,
type=self.clear_quotes(ast.unparse(arg.annotation).strip()) if arg.annotation else TypeHint.NO_TYPEHINT,
)
for arg in sub_node.args.kwonlyargs
],
kwarg=ArgNode(
name=sub_node.args.kwarg.arg,
type=self.clear_quotes(
ast.unparse(sub_node.args.kwarg.annotation).strip()) if sub_node.args.kwarg.annotation else TypeHint.NO_TYPEHINT
) if sub_node.args.kwarg else None,
kw_defaults=[
ConstantNode(
value=ast.unparse(default).strip() if default else TypeHint.NO_DEFAULT
Expand Down Expand Up @@ -155,7 +166,6 @@ def parse(self):
# 仅打印模块级别的函数
if not self._is_module_level_function(node):
continue

self.functions.append(FunctionNode(
name=node.name,
docs=parse(ast.get_docstring(node), parser=self.style) if ast.get_docstring(node) else None,
Expand All @@ -173,13 +183,21 @@ def parse(self):
)
for arg, default in zip(node.args.args, node.args.defaults)
],
vararg=ArgNode(
name=node.args.vararg.arg,
type=self.clear_quotes(ast.unparse(node.args.vararg.annotation).strip()) if node.args.vararg.annotation else TypeHint.NO_TYPEHINT
) if node.args.vararg else None,
kwonlyargs=[
ArgNode(
name=arg.arg,
type=self.clear_quotes(ast.unparse(arg.annotation).strip()) if arg.annotation else TypeHint.NO_TYPEHINT,
)
for arg in node.args.kwonlyargs
],
kwarg=ArgNode(
name=node.args.kwarg.arg,
type=self.clear_quotes(ast.unparse(node.args.kwarg.annotation).strip()) if node.args.kwarg.annotation else TypeHint.NO_TYPEHINT
) if node.args.kwarg else None,
kw_defaults=[
ConstantNode(
value=ast.unparse(default).strip() if default else TypeHint.NO_DEFAULT
Expand Down
14 changes: 14 additions & 0 deletions litedoc/syntax/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ class FunctionNode(BaseModel):

posonlyargs: list[ArgNode] = []
args: list[ArgNode] = []
vararg: Optional[ArgNode] = None
kwonlyargs: list[ArgNode] = []
kwarg: Optional[ArgNode] = None
kw_defaults: list[ConstantNode] = []
defaults: list[ConstantNode] = []

Expand Down Expand Up @@ -212,6 +214,12 @@ def markdown(self, lang: str, indent: int = 0) -> str:
args.append(arg_text)
arg_i += 1

if arg := self.vararg:
arg_text = f"*{arg.name}"
if arg.type != TypeHint.NO_TYPEHINT:
arg_text += f": {arg.type}"
args.append(arg_text)

if len(self.kwonlyargs) > 0:
# 加关键字参数分割符 *
args.append("*")
Expand All @@ -223,6 +231,12 @@ def markdown(self, lang: str, indent: int = 0) -> str:
arg_text += f" = {kw_default.value}"
args.append(arg_text)

if self.kwarg is not None:
arg_text = f"**{self.kwarg.name}"
if self.kwarg.type != TypeHint.NO_TYPEHINT:
arg_text += f": {self.kwarg.type}"
args.append(arg_text)

"""魔法方法"""
if self.name in self.magic_methods:
if len(args) == 2:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_modules/mbcp/mp_math/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def partial_derivative_func(*args: Var) -> Var:
args_list_minus = list(args)
args_list_minus[var] -= epsilon
return (func(*args_list_plus) - func(*args_list_minus)) / (2 * epsilon)

return partial_derivative_func
elif isinstance(var, tuple):
def high_order_partial_derivative_func(*args: Var) -> Var:
Expand All @@ -80,6 +81,37 @@ def high_order_partial_derivative_func(*args: Var) -> Var:
for v in var:
result_func = get_partial_derivative_func(result_func, v, epsilon)
return result_func(*args)

return high_order_partial_derivative_func
else:
raise ValueError("Invalid var type")


def curry(func: MultiVarsFunc, *args: Var) -> OneVarFunc:
"""
对多参数函数进行柯里化。
> [!tip]
> 有关函数柯里化,可参考[函数式编程--柯理化(Currying)](https://zhuanlan.zhihu.com/p/355859667)
Args:
func: 函数
*args: 参数
Returns:
柯里化后的函数
"""

def curried_func(*args2: Var) -> Var:
"""@litedoc-hide"""
return func(*args, *args2)

return curried_func


def test_kwargs(*args, **kwargs):
"""
测试kwargs
Args:
*args:
**kwargs:
"""
print(args)
print(kwargs)

0 comments on commit bfc17ab

Please sign in to comment.