From 401e55be6ddc6f959a0bee65f589a370ed33e143 Mon Sep 17 00:00:00 2001 From: Fritz Obermeyer Date: Tue, 27 Feb 2024 12:00:50 -0800 Subject: [PATCH] Refine type hint for ReparamMessenger --- pyro/poutine/handlers.py | 18 +++++++++--------- pyro/poutine/reparam_messenger.py | 2 +- pyro/poutine/trace_messenger.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pyro/poutine/handlers.py b/pyro/poutine/handlers.py index ff58166b72..278f6a60f2 100644 --- a/pyro/poutine/handlers.py +++ b/pyro/poutine/handlers.py @@ -79,7 +79,7 @@ from pyro.poutine.lift_messenger import LiftMessenger from pyro.poutine.markov_messenger import MarkovMessenger from pyro.poutine.mask_messenger import MaskMessenger -from pyro.poutine.reparam_messenger import ReparamMessenger +from pyro.poutine.reparam_messenger import ReparamHandler, ReparamMessenger from pyro.poutine.replay_messenger import ReplayMessenger from pyro.poutine.runtime import NonlocalExit from pyro.poutine.scale_messenger import ScaleMessenger @@ -152,7 +152,7 @@ def block( @overload def block( - fn: Callable[_P, _T] = ..., + fn: Callable[_P, _T], hide_fn: Optional[Callable[["Message"], Optional[bool]]] = None, expose_fn: Optional[Callable[["Message"], Optional[bool]]] = None, hide_all: bool = True, @@ -186,7 +186,7 @@ def broadcast( @overload def broadcast( - fn: Callable[_P, _T] = ..., + fn: Callable[_P, _T], ) -> Callable[_P, _T]: ... @@ -206,7 +206,7 @@ def collapse( @overload def collapse( - fn: Callable[_P, _T] = ..., + fn: Callable[_P, _T], *args: Any, **kwargs: Any, ) -> Callable[_P, _T]: ... @@ -269,7 +269,7 @@ def enum( @overload def enum( - fn: Callable[_P, _T] = ..., + fn: Callable[_P, _T], first_available_dim: Optional[int] = None, ) -> Callable[_P, _T]: ... @@ -371,14 +371,14 @@ def reparam( def reparam( fn: Callable[_P, _T], config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]], -) -> Callable[_P, _T]: ... +) -> ReparamHandler[_P, _T]: ... @_make_handler(ReparamMessenger) def reparam( # type: ignore[empty-body] fn: Callable[_P, _T], config: Union[Dict[str, "Reparam"], Callable[["Message"], Optional["Reparam"]]], -) -> Union[ReparamMessenger, Callable[_P, _T]]: ... +) -> Union[ReparamMessenger, ReparamHandler[_P, _T]]: ... @overload @@ -391,7 +391,7 @@ def replay( @overload def replay( - fn: Callable[_P, _T] = ..., + fn: Callable[_P, _T], trace: Optional["Trace"] = None, params: Optional[Dict[str, "torch.Tensor"]] = None, ) -> Callable[_P, _T]: ... @@ -467,7 +467,7 @@ def substitute( # type: ignore[empty-body] @overload def trace( - fn: None = None, + fn: None = ..., graph_type: Optional[Literal["flat", "dense"]] = None, param_only: Optional[bool] = None, ) -> TraceMessenger: ... diff --git a/pyro/poutine/reparam_messenger.py b/pyro/poutine/reparam_messenger.py index 10405e0330..751b254aa6 100644 --- a/pyro/poutine/reparam_messenger.py +++ b/pyro/poutine/reparam_messenger.py @@ -67,7 +67,7 @@ def __init__( self.config = config self._args_kwargs = None - def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]: + def __call__(self, fn: Callable[_P, _T]) -> "ReparamHandler[_P, _T]": return ReparamHandler(self, fn) def _pyro_sample(self, msg: "Message") -> None: diff --git a/pyro/poutine/trace_messenger.py b/pyro/poutine/trace_messenger.py index 157294137b..4c1b3068bf 100644 --- a/pyro/poutine/trace_messenger.py +++ b/pyro/poutine/trace_messenger.py @@ -110,7 +110,7 @@ def __exit__(self, *args, **kwargs) -> None: identify_dense_edges(self.trace) return super().__exit__(*args, **kwargs) - def __call__(self, fn: Callable[_P, _T]) -> Callable[_P, _T]: + def __call__(self, fn: Callable[_P, _T]) -> "TraceHandler[_P, _T]": """ TODO docs """