diff --git a/sanic_testing/testing.py b/sanic_testing/testing.py index c78d40e..fe3d6cd 100644 --- a/sanic_testing/testing.py +++ b/sanic_testing/testing.py @@ -345,6 +345,11 @@ def __init__( self.gather_request = True self.last_request = None + self._server_is_running = False + + @property + def server_is_running(self): + return self._server_is_running def _collect_request(self, request): if self.gather_request: @@ -360,13 +365,7 @@ def _start_test_mode(cls, sanic, *args, **kwargs): def _end_test_mode(cls, sanic, *args, **kwargs): Sanic.test_mode = False - async def request( # type: ignore - self, method, url, gather_request=True, *args, **kwargs - ) -> typing.Tuple[ - typing.Optional[Request], typing.Optional[TestingResponse] - ]: - self.sanic_app.router.reset() - self.sanic_app.signal_router.reset() + async def run(self): await self.sanic_app._startup() # type: ignore await self.sanic_app._server_event("init", "before") await self.sanic_app._server_event("init", "after") @@ -376,6 +375,25 @@ async def request( # type: ignore self._collect_request ) + self._server_is_running = True + + async def stop(self): + await self.sanic_app._server_event("shutdown", "before") + await self.sanic_app._server_event("shutdown", "after") + + self._server_is_running = False + + async def request( # type: ignore + self, method, url, gather_request=True, *args, **kwargs + ) -> typing.Tuple[ + typing.Optional[Request], typing.Optional[TestingResponse] + ]: + stop_after_request = False + + if not self._server_is_running: + await self.run() + stop_after_request = True + if not url.startswith( ("http:", "https:", "ftp:", "ftps://", "//", "ws:", "wss:") ): @@ -391,8 +409,8 @@ async def request( # type: ignore self.gather_request = gather_request response = await super().request(method, url, *args, **kwargs) - await self.sanic_app._server_event("shutdown", "before") - await self.sanic_app._server_event("shutdown", "after") + if stop_after_request: + await self.stop() response.__class__ = TestingResponse @@ -474,3 +492,9 @@ def __setstate__(self, d): # Need to create a new CookieJar when unpickling, # because it was killed on Pickle self._cookies = httpx.Cookies() + + async def __aenter__(self): + await self.run() + + async def __aexit__(self, *args): + await self.stop() diff --git a/tests/test_asgi_client.py b/tests/test_asgi_client.py index 4f2f8ce..b9c6170 100644 --- a/tests/test_asgi_client.py +++ b/tests/test_asgi_client.py @@ -17,6 +17,49 @@ async def test_basic_asgi_client(app, method): assert response.content_type == "text/plain; charset=utf-8" +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method", ["get", "post", "patch", "put", "delete", "options"] +) +async def test_asgi_client_as_async_context_manager(app, method): + assert app.asgi_client.server_is_running is False + + async with app.asgi_client: + assert app.asgi_client.server_is_running is True + + request, response = await getattr(app.asgi_client, method)("/") + + assert isinstance(request, Request) + assert response.body == b"foo" + assert response.status == 200 + assert response.content_type == "text/plain; charset=utf-8" + + assert app.asgi_client.server_is_running is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method", ["get", "post", "patch", "put", "delete", "options"] +) +async def test_asgi_client_manual_run_and_stop(app, method): + assert app.asgi_client.server_is_running is False + + await app.asgi_client.run() + + assert app.asgi_client.server_is_running is True + + request, response = await getattr(app.asgi_client, method)("/") + + assert isinstance(request, Request) + assert response.body == b"foo" + assert response.status == 200 + assert response.content_type == "text/plain; charset=utf-8" + + await app.asgi_client.stop() + + assert app.asgi_client.server_is_running is False + + @pytest.mark.asyncio async def test_websocket_route(app): ev = asyncio.Event()