Skip to content

Commit

Permalink
fix: correct Nats JS request
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Oct 18, 2024
1 parent 10472df commit 248fd02
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 44 deletions.
6 changes: 3 additions & 3 deletions faststream/nats/publisher/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from faststream._internal.subscriber.utils import resolve_custom_func
from faststream.message import encode_message
from faststream.nats.parser import NatsParser
from faststream.nats.response import NatsPublishCommand

if TYPE_CHECKING:
from nats.aio.client import Client
Expand All @@ -20,6 +19,7 @@
AsyncCallable,
CustomCallable,
)
from faststream.nats.response import NatsPublishCommand


class NatsFastProducer(ProducerProto):
Expand All @@ -44,7 +44,7 @@ def __init__(
@override
async def publish( # type: ignore[override]
self,
message: NatsPublishCommand,
message: "NatsPublishCommand",
) -> None:
payload, content_type = encode_message(message.body)

Expand Down Expand Up @@ -126,7 +126,7 @@ async def request( # type: ignore[override]
self,
message: "NatsPublishCommand",
) -> "Msg":
payload, content_type = encode_message(message)
payload, content_type = encode_message(message.body)

reply_to = self._connection._nc.new_inbox()
future: asyncio.Future[Msg] = asyncio.Future()
Expand Down
52 changes: 19 additions & 33 deletions faststream/nats/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
if TYPE_CHECKING:
from faststream._internal.basic_types import SendableMessage
from faststream.nats.publisher.specified import SpecificationPublisher
from faststream.nats.response import NatsPublishCommand
from faststream.nats.subscriber.usecase import LogicSubscriber

__all__ = ("TestNatsBroker",)
Expand Down Expand Up @@ -74,28 +75,20 @@ def __init__(self, broker: NatsBroker) -> None:

@override
async def publish( # type: ignore[override]
self,
message: "SendableMessage",
subject: str,
reply_to: str = "",
headers: Optional[dict[str, str]] = None,
correlation_id: Optional[str] = None,
# NatsJSFastProducer compatibility
timeout: Optional[float] = None,
stream: Optional[str] = None,
self, cmd: "NatsPublishCommand"
) -> None:
incoming = build_message(
message=message,
subject=subject,
headers=headers,
correlation_id=correlation_id,
reply_to=reply_to,
message=cmd.body,
subject=cmd.destination,
headers=cmd.headers,
correlation_id=cmd.correlation_id,
reply_to=cmd.reply_to,
)

for handler in _find_handler(
self.broker._subscribers,
subject,
stream,
cmd.destination,
cmd.stream,
):
msg: Union[list[PatchedMessage], PatchedMessage]

Expand All @@ -104,31 +97,24 @@ async def publish( # type: ignore[override]
else:
msg = incoming

await self._execute_handler(msg, subject, handler)
await self._execute_handler(msg, cmd.destination, handler)

@override
async def request( # type: ignore[override]
self,
message: "SendableMessage",
subject: str,
*,
correlation_id: Optional[str] = None,
headers: Optional[dict[str, str]] = None,
timeout: float = 0.5,
# NatsJSFastProducer compatibility
stream: Optional[str] = None,
cmd: "NatsPublishCommand",
) -> "PatchedMessage":
incoming = build_message(
message=message,
subject=subject,
headers=headers,
correlation_id=correlation_id,
message=cmd.body,
subject=cmd.destination,
headers=cmd.headers,
correlation_id=cmd.correlation_id,
)

for handler in _find_handler(
self.broker._subscribers,
subject,
stream,
cmd.destination,
cmd.stream,
):
msg: Union[list[PatchedMessage], PatchedMessage]

Expand All @@ -137,8 +123,8 @@ async def request( # type: ignore[override]
else:
msg = incoming

with anyio.fail_after(timeout):
return await self._execute_handler(msg, subject, handler)
with anyio.fail_after(cmd.timeout):
return await self._execute_handler(msg, cmd.destination, handler)

raise SubscriberNotFound

Expand Down
5 changes: 4 additions & 1 deletion faststream/response/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from .response import Response
from .publish_type import PublishType
from .response import PublishCommand, Response
from .utils import ensure_response

__all__ = (
"PublishCommand",
"PublishType",
"Response",
"ensure_response",
)
15 changes: 8 additions & 7 deletions tests/brokers/base/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from faststream._internal.basic_types import DecodedMessage
from faststream.exceptions import SkipMessage
from faststream.middlewares import BaseMiddleware, ExceptionMiddleware
from faststream.response import PublishCommand

from .basic import BaseTestcaseConfig

Expand Down Expand Up @@ -331,8 +332,9 @@ async def test_patch_publish(
event: asyncio.Event,
) -> None:
class Mid(BaseMiddleware):
async def on_publish(self, msg: str, *args, **kwargs) -> str:
return msg * 2
async def on_publish(self, msg: PublishCommand) -> PublishCommand:
msg.body *= 2
return msg

broker = self.get_broker(middlewares=(Mid,))

Expand Down Expand Up @@ -370,11 +372,10 @@ async def test_global_publisher_middleware(
mock: Mock,
) -> None:
class Mid(BaseMiddleware):
async def on_publish(self, msg: str, *args, **kwargs) -> str:
data = msg * 2
assert args or kwargs
mock.enter(data)
return data
async def on_publish(self, msg: PublishCommand) -> PublishCommand:
msg.body *= 2
mock.enter(msg.body)
return msg

async def after_publish(self, *args, **kwargs) -> None:
mock.end()
Expand Down

0 comments on commit 248fd02

Please sign in to comment.