Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #3334: Add support for custom parameter classes in mypy plugin #3335

Merged
merged 1 commit into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 27 additions & 28 deletions luigi/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from __future__ import annotations

import re
import sys
from typing import Callable, Dict, Final, Iterator, List, Literal, Optional

Expand Down Expand Up @@ -61,10 +60,6 @@

METADATA_TAG: Final[str] = "task"

PARAMETER_FULLNAME_MATCHER: Final[re.Pattern] = re.compile(
r"^luigi(\.parameter)?\.\w*Parameter$"
)

if sys.version_info[:2] < (3, 8):
# This plugin uses the walrus operator, which is only available in Python 3.8+
raise RuntimeError("This plugin requires Python 3.8+")
Expand All @@ -84,12 +79,17 @@
self, fullname: str
) -> Callable[[FunctionContext], Type] | None:
"""Adjust the return type of the `Parameters` function."""
if PARAMETER_FULLNAME_MATCHER.match(fullname):
if self.check_parameter(fullname):
return self._task_parameter_field_callback
return None

def check_parameter(self, fullname):
sym = self.lookup_fully_qualified(fullname)
if sym and isinstance(sym.node, TypeInfo):
return any(base.fullname == "luigi.parameter.Parameter" for base in sym.node.mro)

def _task_class_maker_callback(self, ctx: ClassDefContext) -> None:
transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api)
transformer = TaskTransformer(ctx.cls, ctx.reason, ctx.api, self)
transformer.transform()

def _task_parameter_field_callback(self, ctx: FunctionContext) -> Type:
Expand Down Expand Up @@ -210,10 +210,12 @@
cls: ClassDef,
reason: Expression | Statement,
api: SemanticAnalyzerPluginInterface,
task_plugin: TaskPlugin,
) -> None:
self._cls = cls
self._reason = reason
self._api = api
self._task_plugin = task_plugin

def transform(self) -> bool:
"""Apply all the necessary transformations to the underlying gokart.Task"""
Expand Down Expand Up @@ -311,7 +313,7 @@
# Second, collect attributes belonging to the current class.
current_attr_names: set[str] = set()
for stmt in self._get_assignment_statements_from_block(cls.defs):
if not is_parameter_call(stmt.rvalue):
if not self.is_parameter_call(stmt.rvalue):
continue

# a: int, b: str = 1, 'foo' is not supported syntax so we
Expand Down Expand Up @@ -435,29 +437,26 @@

return default

def is_parameter_call(self, expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return False

def is_parameter_call(expr: Expression) -> bool:
"""Checks if the expression is a call to luigi.Parameter()"""
if not isinstance(expr, CallExpr):
return False

callee = expr.callee
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
return (
PARAMETER_FULLNAME_MATCHER.match(f"{callee.expr.name}.{callee.name}")
is not None
)
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False
callee = expr.callee
fullname = None
if isinstance(callee, MemberExpr):
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
fullname = f"{callee.expr.name}.{callee.name}"

Check warning on line 450 in luigi/mypy.py

View check run for this annotation

Codecov / codecov/patch

luigi/mypy.py#L450

Added line #L450 was not covered by tests
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False

Check warning on line 454 in luigi/mypy.py

View check run for this annotation

Codecov / codecov/patch

luigi/mypy.py#L454

Added line #L454 was not covered by tests

if isinstance(type_info, TypeInfo):
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
if isinstance(type_info, TypeInfo):
fullname = type_info.fullname

return False
return fullname is not None and self._task_plugin.check_parameter(fullname)


def plugin(version: str) -> type[Plugin]:
Expand Down
9 changes: 8 additions & 1 deletion test/mypy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@ def test_plugin_no_issue(self):

test_code = """
import luigi
from uuid import UUID


class UUIDParameter(luigi.Parameter):
def parse(self, s):
return UUID(s)


class MyTask(luigi.Task):
foo: int = luigi.IntParameter()
bar: str = luigi.Parameter()
uniq: UUID = UUIDParameter()
baz: str = luigi.Parameter(default="baz")

MyTask(foo=1, bar='bar')
MyTask(foo=1, bar='bar', uniq=UUID("9b0591d7-a167-4978-bc6d-41f7d84a288c"))
"""

with tempfile.NamedTemporaryFile(suffix=".py") as test_file:
Expand Down
Loading