From 552feffc040e1b7ff91d782cfa5b08a04da6de92 Mon Sep 17 00:00:00 2001 From: Adrian Stachlewski Date: Fri, 17 Jan 2025 21:28:59 +0100 Subject: [PATCH] Fixes #3334: Add support for custom parameter classes in mypy plugin --- luigi/mypy.py | 55 +++++++++++++++++++++++------------------------ test/mypy_test.py | 9 +++++++- 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/luigi/mypy.py b/luigi/mypy.py index 8de6d5a37e..363538bf21 100644 --- a/luigi/mypy.py +++ b/luigi/mypy.py @@ -6,7 +6,6 @@ from __future__ import annotations -import re import sys from typing import Callable, Dict, Final, Iterator, List, Literal, Optional @@ -61,10 +60,6 @@ METADATA_TAG: Final[str] = "task" -PARAMETER_FULLNAME_MATCHER: Final[re.Pattern] = re.compile( - r"^luigi(\.parameter)?\.\w*Parameter$" -) - if sys.version_info[:2] < (3, 8): # This plugin uses the walrus operator, which is only available in Python 3.8+ raise RuntimeError("This plugin requires Python 3.8+") @@ -84,12 +79,17 @@ def get_function_hook( self, fullname: str ) -> Callable[[FunctionContext], Type] | None: """Adjust the return type of the `Parameters` function.""" - if PARAMETER_FULLNAME_MATCHER.match(fullname): + if self.check_parameter(fullname): return self._task_parameter_field_callback return None + def check_parameter(self, fullname): + sym = self.lookup_fully_qualified(fullname) + if sym and isinstance(sym.node, TypeInfo): + return any(base.fullname == "luigi.parameter.Parameter" for base in sym.node.mro) + def _task_class_maker_callback(self, ctx: ClassDefContext) -> None: - transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api) + transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api, self) transformer.transform() def _task_parameter_field_callback(self, ctx: FunctionContext) -> Type: @@ -210,10 +210,12 @@ def __init__( cls: ClassDef, reason: Expression | Statement, api: SemanticAnalyzerPluginInterface, + task_plugin: TaskPlugin, ) -> None: self._cls = cls self._reason = reason self._api = api + self._task_plugin = task_plugin def transform(self) -> bool: """Apply all the necessary transformations to the underlying gokart.Task""" @@ -311,7 +313,7 @@ def collect_attributes(self) -> Optional[List[TaskAttribute]]: # Second, collect attributes belonging to the current class. current_attr_names: set[str] = set() for stmt in self._get_assignment_statements_from_block(cls.defs): - if not is_parameter_call(stmt.rvalue): + if not self.is_parameter_call(stmt.rvalue): continue # a: int, b: str = 1, 'foo' is not supported syntax so we @@ -435,29 +437,26 @@ def _infer_task_attr_init_type( return default + def is_parameter_call(self, expr: Expression) -> bool: + """Checks if the expression is a call to luigi.Parameter()""" + if not isinstance(expr, CallExpr): + return False -def is_parameter_call(expr: Expression) -> bool: - """Checks if the expression is a call to luigi.Parameter()""" - if not isinstance(expr, CallExpr): - return False - - callee = expr.callee - if isinstance(callee, MemberExpr): - type_info = callee.node - if type_info is None and isinstance(callee.expr, NameExpr): - return ( - PARAMETER_FULLNAME_MATCHER.match(f"{callee.expr.name}.{callee.name}") - is not None - ) - elif isinstance(callee, NameExpr): - type_info = callee.node - else: - return False + callee = expr.callee + fullname = None + if isinstance(callee, MemberExpr): + type_info = callee.node + if type_info is None and isinstance(callee.expr, NameExpr): + fullname = f"{callee.expr.name}.{callee.name}" + elif isinstance(callee, NameExpr): + type_info = callee.node + else: + return False - if isinstance(type_info, TypeInfo): - return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None + if isinstance(type_info, TypeInfo): + fullname = type_info.fullname - return False + return fullname is not None and self._task_plugin.check_parameter(fullname) def plugin(version: str) -> type[Plugin]: diff --git a/test/mypy_test.py b/test/mypy_test.py index 9460a0e9c8..d7347979b8 100644 --- a/test/mypy_test.py +++ b/test/mypy_test.py @@ -12,14 +12,21 @@ def test_plugin_no_issue(self): test_code = """ import luigi +from uuid import UUID + + +class UUIDParameter(luigi.Parameter): + def parse(self, s): + return UUID(s) class MyTask(luigi.Task): foo: int = luigi.IntParameter() bar: str = luigi.Parameter() + uniq: UUID = UUIDParameter() baz: str = luigi.Parameter(default="baz") -MyTask(foo=1, bar='bar') +MyTask(foo=1, bar='bar', uniq=UUID("9b0591d7-a167-4978-bc6d-41f7d84a288c")) """ with tempfile.NamedTemporaryFile(suffix=".py") as test_file: