Skip to content

Commit

Permalink
avoid using return cast
Browse files Browse the repository at this point in the history
return cast is equivalent to `type: ignore` on the result of the expression,
but more expensive and can cause errors e.g. preventing access to traits during __del__ in process teardown
  • Loading branch information
minrk committed Nov 28, 2024
1 parent 9522ad6 commit 772dcd9
Showing 1 changed file with 59 additions and 53 deletions.
112 changes: 59 additions & 53 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ class TraitError(Exception):
# -----------------------------------------------------------------------------


def isidentifier(s: t.Any) -> bool:
return t.cast(bool, s.isidentifier())
def isidentifier(s: str) -> bool:
return s.isidentifier()


def _safe_literal_eval(s: str) -> t.Any:
Expand Down Expand Up @@ -293,13 +293,21 @@ class link:

updating = False

def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> None:
def __init__(
self, source: t.Any, target: t.Any, transform: t.Iterable[FuncT] | None = None
) -> None:
_validate_link(source, target)
self.source, self.target = source, target
self._transform, self._transform_inv = transform if transform else (lambda x: x,) * 2

if transform:
self._transform, self._transform_inv = transform # type:ignore[method-assign]
self.link()

def _transform(self, x: T) -> T:
"""default transform: no-op"""
return x

_transform_inv = _transform

def link(self) -> None:
try:
setattr(
Expand Down Expand Up @@ -597,12 +605,12 @@ def default(self, obj: t.Any = None) -> G | None:
in the same way that dynamic defaults defined by ``@default`` are.
"""
if self.default_value is not Undefined:
return t.cast(G, self.default_value)
return self.default_value # type:ignore[no-any-return]
elif hasattr(self, "make_dynamic_default"):
return t.cast(G, self.make_dynamic_default())
return self.make_dynamic_default() # type:ignore[no-any-return]
else:
# Undefined will raise in TraitType.get
return t.cast(G, self.default_value)
return self.default_value # type:ignore[no-any-return]

def get_default_value(self) -> G | None:
"""DEPRECATED: Retrieve the static default value for this trait.
Expand All @@ -613,7 +621,7 @@ def get_default_value(self) -> G | None:
DeprecationWarning,
stacklevel=2,
)
return t.cast(G, self.default_value)
return self.default_value # type:ignore[no-any-return]

def init_default_value(self, obj: t.Any) -> G | None:
"""DEPRECATED: Set the static default value for the trait type."""
Expand Down Expand Up @@ -658,12 +666,12 @@ def get(self, obj: HasTraits, cls: type[t.Any] | None = None) -> G | None:
type="default",
)
)
return t.cast(G, value)
return value # type:ignore[no-any-return]
except Exception as e:
# This should never be reached.
raise TraitError("Unexpected error in TraitType: default value not set properly") from e
else:
return t.cast(G, value)
return value # type:ignore[no-any-return]

@t.overload
def __get__(self, obj: None, cls: type[t.Any]) -> Self:
Expand All @@ -684,7 +692,7 @@ def __get__(self, obj: HasTraits | None, cls: type[t.Any]) -> Self | G:
if obj is None:
return self
else:
return t.cast(G, self.get(obj, cls)) # the G should encode the Optional
return self.get(obj, cls) # type:ignore[return-value]

def set(self, obj: HasTraits, value: S) -> None:
new_value = self._validate(obj, value)
Expand Down Expand Up @@ -722,7 +730,7 @@ def _validate(self, obj: t.Any, value: t.Any) -> G | None:
value = self.validate(obj, value)
if obj._cross_validation_lock is False:
value = self._cross_validate(obj, value)
return t.cast(G, value)
return value # type:ignore[no-any-return]

def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None:
if self.name in obj._trait_validators:
Expand All @@ -738,7 +746,7 @@ def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None:
"use @validate decorator instead.",
)
value = cross_validate(value, self)
return t.cast(G, value)
return value # type:ignore[no-any-return]

def __or__(self, other: TraitType[t.Any, t.Any]) -> Union:
if isinstance(other, Union):
Expand Down Expand Up @@ -1142,7 +1150,7 @@ def compatible_observer(
)
return func(self, change)

return t.cast(FuncT, compatible_observer)
return compatible_observer # type:ignore[return-value]


def validate(*names: Sentinel | str) -> ValidateHandler:
Expand Down Expand Up @@ -1894,7 +1902,7 @@ def trait_defaults(self, *names: str, **metadata: t.Any) -> dict[str, t.Any] | S
raise TraitError(f"'{n}' is not a trait of '{type(self).__name__}' instances")

if len(names) == 1 and len(metadata) == 0:
return t.cast(Sentinel, self._get_trait_default_generator(names[0])(self))
return self._get_trait_default_generator(names[0])(self) # type:ignore[no-any-return]

trait_names = self.trait_names(**metadata)
trait_names.extend(names)
Expand Down Expand Up @@ -2144,7 +2152,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
) from e
try:
if issubclass(value, self.klass): # type:ignore[arg-type]
return t.cast(G, value)
return value # type:ignore[no-any-return]
except Exception:
pass

Expand Down Expand Up @@ -2306,7 +2314,7 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None:
if self.allow_none and value is None:
return value
if isinstance(value, self.klass): # type:ignore[arg-type]
return t.cast(T, value)
return value # type:ignore[no-any-return]
else:
self.error(obj, value)

Expand Down Expand Up @@ -2338,7 +2346,7 @@ def default_value_repr(self) -> str:
return repr(self.make_dynamic_default())

def from_string(self, s: str) -> T | None:
return t.cast(T, _safe_literal_eval(s))
return _safe_literal_eval(s) # type:ignore[no-any-return]


class ForwardDeclaredMixin:
Expand Down Expand Up @@ -2635,12 +2643,12 @@ def __init__(
def validate(self, obj: t.Any, value: t.Any) -> G:
if not isinstance(value, int):
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return t.cast(G, int(s))
return None # type:ignore[return-value]
return int(s) # type:ignore[return-value]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -2691,7 +2699,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
value = int(value)
except Exception:
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]


Long, CLong = Int, CInt
Expand Down Expand Up @@ -2753,12 +2761,12 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
value = float(value)
if not isinstance(value, float):
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return t.cast(G, float(s))
return None # type:ignore[return-value]
return float(s) # type:ignore[return-value]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -2809,7 +2817,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
value = float(value)
except Exception:
self.error(obj, value)
return t.cast(G, _validate_bounds(self, obj, value))
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]


class Complex(TraitType[complex, t.Union[complex, float, int]]):
Expand Down Expand Up @@ -2935,18 +2943,18 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
if isinstance(value, str):
return t.cast(G, value)
return value # type:ignore[return-value]
if isinstance(value, bytes):
try:
return t.cast(G, value.decode("ascii", "strict"))
return value.decode("ascii", "strict") # type:ignore[return-value]
except UnicodeDecodeError as e:
msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
raise TraitError(msg.format(value, self.name, class_of(obj))) from e
self.error(obj, value)

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return None # type:ignore[return-value]
s = os.path.expanduser(s)
if len(s) >= 2:
# handle deprecated "1"
Expand All @@ -2960,7 +2968,7 @@ def from_string(self, s: str) -> G:
DeprecationWarning,
stacklevel=2,
)
return t.cast(G, s)
return s # type:ignore[return-value]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -3008,7 +3016,7 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
try:
return t.cast(G, str(value))
return str(value) # type:ignore[return-value]
except Exception:
self.error(obj, value)

Expand Down Expand Up @@ -3091,22 +3099,22 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
if isinstance(value, bool):
return t.cast(G, value)
return value # type:ignore[return-value]
elif isinstance(value, int):
if value == 1:
return t.cast(G, True)
return True # type:ignore[return-value]
elif value == 0:
return t.cast(G, False)
return False # type:ignore[return-value]
self.error(obj, value)

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return None # type:ignore[return-value]
s = s.lower()
if s in {"true", "1"}:
return t.cast(G, True)
return True # type:ignore[return-value]
elif s in {"false", "0"}:
return t.cast(G, False)
return False # type:ignore[return-value]
else:
raise ValueError("%r is not 1, 0, true, or false")

Expand Down Expand Up @@ -3163,7 +3171,7 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
try:
return t.cast(G, bool(value))
return bool(value) # type:ignore[return-value]
except Exception:
self.error(obj, value)

Expand Down Expand Up @@ -3220,7 +3228,7 @@ def __init__(

def validate(self, obj: t.Any, value: t.Any) -> G:
if self.values and value in self.values:
return t.cast(G, value)
return value # type:ignore[no-any-return]
self.error(obj, value)

def _choices_str(self, as_rst: bool = False) -> str:
Expand All @@ -3247,7 +3255,7 @@ def from_string(self, s: str) -> G:
try:
return self.validate(None, s)
except TraitError:
return t.cast(G, _safe_literal_eval(s))
return _safe_literal_eval(s) # type:ignore[no-any-return]

def subclass_init(self, cls: type[t.Any]) -> None:
pass # fully opt out of instance_init
Expand Down Expand Up @@ -3275,7 +3283,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
for v in self.values or []:
assert isinstance(v, str)
if v.lower() == value.lower():
return t.cast(G, v)
return v # type:ignore[return-value]
self.error(obj, value)

def _info(self, as_rst: bool = False) -> str:
Expand Down Expand Up @@ -3479,14 +3487,12 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None:
if value is None:
return value

value = self.validate_elements(obj, value)

return t.cast(T, value)
return self.validate_elements(obj, value)

def validate_elements(self, obj: t.Any, value: t.Any) -> T | None:
validated = []
if self._trait is None or isinstance(self._trait, Any):
return t.cast(T, value)
return value # type:ignore[no-any-return]
for v in value:
try:
v = self._trait._validate(obj, v)
Expand Down Expand Up @@ -3553,7 +3559,7 @@ def from_string_list(self, s_list: list[str]) -> T | None:
else:
# backward-compat: allow item_from_string to ignore index arg
def item_from_string(s: str, index: int | None = None) -> T | str:
return t.cast(T, self.item_from_string(s))
return self.item_from_string(s)

return self.klass( # type:ignore[call-arg]
[item_from_string(s, index=idx) for idx, s in enumerate(s_list)]
Expand All @@ -3565,7 +3571,7 @@ def item_from_string(self, s: str, index: int | None = None) -> T | str:
Evaluated when parsing CLI configuration from a string
"""
if self._trait:
return t.cast(T, self._trait.from_string(s))
return self._trait.from_string(s) # type:ignore[no-any-return]
else:
return s

Expand Down Expand Up @@ -4051,7 +4057,7 @@ def from_string(self, s: str) -> dict[K, V] | None:
if not isinstance(s, str):
raise TypeError(f"from_string expects a string, got {s!r} of type {type(s)}")
try:
return t.cast("dict[K, V]", self.from_string_list([s]))
return self.from_string_list([s]) # type:ignore[no-any-return]
except Exception:
test = _safe_literal_eval(s)
if isinstance(test, dict):
Expand Down Expand Up @@ -4109,7 +4115,7 @@ def item_from_string(self, s: str) -> dict[K, V]:
value_trait = (self._per_key_traits or {}).get(key, self._value_trait)
if value_trait:
value = value_trait.from_string(value)
return t.cast("dict[K, V]", {key: value})
return {key: value} # type:ignore[dict-item]


class TCPAddress(TraitType[G, S]):
Expand Down Expand Up @@ -4165,17 +4171,17 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
if isinstance(value[0], str) and isinstance(value[1], int):
port = value[1]
if port >= 0 and port <= 65535:
return t.cast(G, value)
return value # type:ignore[return-value]
self.error(obj, value)

def from_string(self, s: str) -> G:
if self.allow_none and s == "None":
return t.cast(G, None)
return None # type:ignore[return-value]
if ":" not in s:
raise ValueError("Require `ip:port`, got %r" % s)
ip, port_str = s.split(":", 1)
port = int(port_str)
return t.cast(G, (ip, port))
return (ip, port) # type:ignore[return-value]


class CRegExp(TraitType["re.Pattern[t.Any]", t.Union["re.Pattern[t.Any]", str]]):
Expand Down

0 comments on commit 772dcd9

Please sign in to comment.