Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async steps #609

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytest_bdd/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pytest_bdd.steps import async_given, async_then, async_when

__all__ = ["async_given", "async_when", "async_then"]
80 changes: 78 additions & 2 deletions src/pytest_bdd/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
"""
from __future__ import annotations

import asyncio
import contextlib
import functools
import inspect
import logging
import os
import re
Expand All @@ -34,7 +37,6 @@

from .parser import Feature, Scenario, ScenarioTemplate, Step


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -156,7 +158,14 @@ def _execute_step_function(

request.config.hook.pytest_bdd_before_step_call(**kw)
# Execute the step as if it was a pytest fixture, so that we can allow "yield" statements in it
return_value = call_fixture_func(fixturefunc=context.step_func, request=request, kwargs=kwargs)
step_func = context.step_func
if context.is_async:
if inspect.isasyncgenfunction(context.step_func):
step_func = _wrap_asyncgen(request, context.step_func)
elif inspect.iscoroutinefunction(context.step_func):
step_func = _wrap_coroutine(context.step_func)

return_value = call_fixture_func(fixturefunc=step_func, request=request, kwargs=kwargs)
except Exception as exception:
request.config.hook.pytest_bdd_step_error(exception=exception, **kw)
raise
Expand All @@ -167,6 +176,73 @@ def _execute_step_function(
request.config.hook.pytest_bdd_after_step(**kw)


def _wrap_asyncgen(request: FixtureRequest, func: Callable) -> Callable:
"""Wrapper for an async_generator function.

This will wrap the function in a synchronized method to return the first
yielded value from the generator. A finalizer will be added to the fixture
to ensure that no other values are yielded and that the loop is closed.

:param request: The fixture request.
:param func: The function to wrap.

:returns: The wrapped function.
"""

@functools.wraps(func)
def _wrapper(*args, **kwargs):
try:
loop, created = asyncio.get_running_loop(), False
except RuntimeError:
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True

async_obj = func(*args, **kwargs)

def _finalizer() -> None:
"""Ensure no more values are yielded and close the loop."""
try:
loop.run_until_complete(async_obj.__anext__())
except StopAsyncIteration:
pass
else:
raise ValueError("Async generator must only yield once.")

if created:
loop.close()

value = loop.run_until_complete(async_obj.__anext__())
request.addfinalizer(_finalizer)

return value

return _wrapper


def _wrap_coroutine(func: Callable) -> Callable:
"""Wrapper for a coroutine function.

:param func: The function to wrap.

:returns: The wrapped function.
"""

@functools.wraps(func)
def _wrapper(*args, **kwargs):
try:
loop, created = asyncio.get_running_loop(), False
except RuntimeError:
loop, created = asyncio.get_event_loop_policy().new_event_loop(), True

try:
async_obj = func(*args, **kwargs)
return loop.run_until_complete(async_obj)
finally:
if created:
loop.close()

return _wrapper


def _execute_scenario(feature: Feature, scenario: Scenario, request: FixtureRequest) -> None:
"""Execute the scenario.

Expand Down
80 changes: 77 additions & 3 deletions src/pytest_bdd/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class StepFunctionContext:
parser: StepParser
converters: dict[str, Callable[..., Any]] = field(default_factory=dict)
target_fixture: str | None = None
is_async: bool = False


def get_step_fixture_name(step: Step) -> str:
Expand All @@ -78,6 +79,7 @@ def given(
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable:
"""Given step decorator.

Expand All @@ -86,17 +88,62 @@ def given(
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)

:return: Decorator function for the step.
"""
return step(name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(
name, GIVEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
)


def async_given(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
"""Async Given step decorator.

:param name: Step name or a parser object.
:param converters: Optional `dict` of the argument or parameter converters in form
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.

:return: Decorator function for the step.
"""
return given(name, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)


def when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable:
"""When step decorator.

:param name: Step name or a parser object.
:param converters: Optional `dict` of the argument or parameter converters in form
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)

:return: Decorator function for the step.
"""
return step(
name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
)


def async_when(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
"""When step decorator.

Expand All @@ -108,14 +155,15 @@ def when(

:return: Decorator function for the step.
"""
return step(name, WHEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return when(name, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)


def then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable:
"""Then step decorator.

Expand All @@ -124,10 +172,32 @@ def then(
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)

:return: Decorator function for the step.
"""
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel)
return step(
name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=is_async
)


def async_then(
name: str | StepParser,
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
) -> Callable:
"""Then step decorator.

:param name: Step name or a parser object.
:param converters: Optional `dict` of the argument or parameter converters in form
{<param_name>: <converter function>}.
:param target_fixture: Target fixture name to replace by steps definition function.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.

:return: Decorator function for the step.
"""
return step(name, THEN, converters=converters, target_fixture=target_fixture, stacklevel=stacklevel, is_async=True)


def step(
Expand All @@ -136,6 +206,7 @@ def step(
converters: dict[str, Callable] | None = None,
target_fixture: str | None = None,
stacklevel: int = 1,
is_async: bool = False,
) -> Callable[[TCallable], TCallable]:
"""Generic step decorator.

Expand All @@ -144,6 +215,7 @@ def step(
:param converters: Optional step arguments converters mapping.
:param target_fixture: Optional fixture name to replace by step definition.
:param stacklevel: Stack level to find the caller frame. This is used when injecting the step definition fixture.
:param is_async: True if the step is asynchronous. (Default: False)

:return: Decorator function for the step.

Expand All @@ -165,6 +237,7 @@ def decorator(func: TCallable) -> TCallable:
parser=parser,
converters=converters,
target_fixture=target_fixture,
is_async=is_async,
)

def step_function_marker() -> StepFunctionContext:
Expand All @@ -177,6 +250,7 @@ def step_function_marker() -> StepFunctionContext:
f"{StepNamePrefix.step_def.value}_{type_ or '*'}_{parser.name}", seen=caller_locals.keys()
)
caller_locals[fixture_step_name] = pytest.fixture(name=fixture_step_name)(step_function_marker)

return func

return decorator
Expand Down
Loading