diff --git a/rclpy/CMakeLists.txt b/rclpy/CMakeLists.txt index a9966c040..3872edf4c 100644 --- a/rclpy/CMakeLists.txt +++ b/rclpy/CMakeLists.txt @@ -159,6 +159,7 @@ if(BUILD_TESTING) test/test_action_client.py test/test_action_graph.py test/test_action_server.py + test/test_asyncio_interop.py test/test_callback_group.py test/test_client.py test/test_clock.py diff --git a/rclpy/rclpy/executors.py b/rclpy/rclpy/executors.py index 3c5bea476..b62852654 100644 --- a/rclpy/rclpy/executors.py +++ b/rclpy/rclpy/executors.py @@ -190,6 +190,27 @@ def create_task(self, callback: Union[Callable, Coroutine], *args, **kwargs) -> # Task inherits from Future return task + def call_soon(self, callback, *args) -> Task: + """ + Add a callback or coroutine to be executed during :meth:`spin`. + + Arguments to this function are passed to the callback. + + :param callback: A callback to be run in the executor. + :return: A Task which the executor will execute. + """ + if self._is_shutdown: + raise ShutdownException() + + if not isinstance(callback, Task): + callback = Task(callback, args, None, executor=self) + + with self._tasks_lock: + self._tasks.append((callback, None, None)) + self._guard.trigger() + + return callback + def shutdown(self, timeout_sec: float = None) -> bool: """ Stop executing callbacks and wait for their completion. @@ -432,12 +453,9 @@ async def handler(entity, gc, is_shutdown, work_tracker): gc.trigger() except InvalidHandle: pass - task = Task( + return Task( handler, (entity, self._guard, self._is_shutdown, self._work_tracker), executor=self) - with self._tasks_lock: - self._tasks.append((task, entity, node)) - return task def can_execute(self, entity: WaitableEntityType) -> bool: """ @@ -481,16 +499,19 @@ def _wait_for_ready_callbacks( # Yield tasks in-progress before waiting for new work tasks = None with self._tasks_lock: - tasks = list(self._tasks) - if tasks: - for task, entity, node in reversed(tasks): - if (not task.executing() and not task.done() and - (node is None or node in nodes_to_use)): - yielded_work = True - yield task, entity, node - with self._tasks_lock: - # Get rid of any tasks that are done - self._tasks = list(filter(lambda t_e_n: not t_e_n[0].done(), self._tasks)) + tasks = self._tasks + # Tasks that need to be executed again will add themselves back to the executor + self._tasks = [] + while tasks: + task_trio = tasks.pop() + task, entity, node = task_trio + if node is None or node in nodes_to_use: + yielded_work = True + yield task_trio + else: + # Asked not to execute these tasks, so don't do them yet + with self._tasks_lock: + self._tasks.append(task_trio) # Gather entities that can be waited on subscriptions: List[Subscription] = [] diff --git a/rclpy/rclpy/task.py b/rclpy/rclpy/task.py index 2ed1d6cd4..89253813c 100644 --- a/rclpy/rclpy/task.py +++ b/rclpy/rclpy/task.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import inspect import sys import threading @@ -52,11 +53,12 @@ def __del__(self): file=sys.stderr) def __await__(self): - # Yield if the task is not finished - while not self._done: - yield + if not self._done: + yield self return self.result() + __iter__ = __await__ + def cancel(self): """Request cancellation of the running task if it is not done already.""" with self._lock: @@ -142,7 +144,7 @@ def _schedule_or_invoke_done_callbacks(self): if executor is not None: # Have the executor take care of the callbacks for callback in callbacks: - executor.create_task(callback, self) + executor.call_soon(callback, self) else: # No executor, call right away for callback in callbacks: @@ -176,7 +178,7 @@ def add_done_callback(self, callback): if self._done: executor = self._executor() if executor is not None: - executor.create_task(callback, self) + executor.call_soon(callback, self) else: invoke = True else: @@ -199,6 +201,8 @@ class Task(Future): def __init__(self, handler, args=None, kwargs=None, executor=None): super().__init__(executor=executor) + if executor is None: + raise RuntimeError('Task requires an executor') # _handler is either a normal function or a coroutine self._handler = handler # Arguments passed into the function @@ -212,12 +216,8 @@ def __init__(self, handler, args=None, kwargs=None, executor=None): self._handler = handler(*args, **kwargs) self._args = None self._kwargs = None - # True while the task is being executed - self._executing = False - # Lock acquired to prevent task from executing in parallel with itself - self._task_lock = threading.Lock() - def __call__(self): + def __call__(self, future=None): """ Run or resume a task. @@ -225,49 +225,57 @@ def __call__(self): await it. If there are done callbacks it will schedule them with the executor. The return value of the handler is stored as the task result. + + This function must not be called in parallel with itself. + + :param future: do not use """ - if self._done or self._executing or not self._task_lock.acquire(blocking=False): + if self._done: return - try: - if self._done: - return - self._executing = True - - if inspect.iscoroutine(self._handler): - # Execute a coroutine - try: - self._handler.send(None) - except StopIteration as e: - # The coroutine finished; store the result - self._handler.close() - self.set_result(e.value) - self._complete_task() - except Exception as e: - self.set_exception(e) - self._complete_task() - else: - # Execute a normal function - try: - self.set_result(self._handler(*self._args, **self._kwargs)) - except Exception as e: - self.set_exception(e) + if inspect.iscoroutine(self._handler): + # Execute a coroutine + try: + result = self._handler.send(None) + if isinstance(result, Future): + # Wait for an rclpy future to complete + result.add_done_callback(self) + elif asyncio.isfuture(result): + # Get the event loop of this thread (raises RuntimeError if there isn't one) + event_loop = asyncio.get_running_loop() + # Make sure we're in the same thread as the future's event loop. + # TODO(sloretz) is asyncio.Future.get_loop() thread-safe? + if result.get_loop() is not event_loop: + raise RuntimeError('Cannot await asyncio future from a different thread') + # Resume this task when the asyncio future completes + result.add_done_callback(lambda _: self._executor().call_soon(self)) + elif result is None: + # Wait for one iteration if a bare yield is used + self._executor().call_soon(self) + else: + # What is this intermediate value? + # Could be a different async library's coroutine + # Could be a generator yielded a value + raise RuntimeError(f'Coroutine yielded unexpected value: {result}') + except StopIteration as e: + # Coroutine or generator returning a result + self._handler.close() + self.set_result(e.value) self._complete_task() - - self._executing = False - finally: - self._task_lock.release() + except Exception as e: + # Coroutine or generator raising an exception + self._handler.close() + self.set_exception(e) + self._complete_task() + else: + # Execute a normal function + try: + self.set_result(self._handler(*self._args, **self._kwargs)) + except Exception as e: + self.set_exception(e) + self._complete_task() def _complete_task(self): """Cleanup after task finished.""" self._handler = None self._args = None self._kwargs = None - - def executing(self): - """ - Check if the task is currently being executed. - - :return: True if the task is currently executing. - :rtype: bool - """ - return self._executing diff --git a/rclpy/test/test_asyncio_interop.py b/rclpy/test/test_asyncio_interop.py new file mode 100644 index 000000000..9951f0b15 --- /dev/null +++ b/rclpy/test/test_asyncio_interop.py @@ -0,0 +1,66 @@ +# Copyright 2022 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time + +import pytest + +import rclpy +from rclpy.executors import SingleThreadedExecutor + + +MAX_TEST_TIME = 5.0 +TIME_FUDGE_FACTOR = 0.2 + + +@pytest.fixture +def node_and_executor(): + rclpy.init() + node = rclpy.create_node('test_asyncio_interop') + executor = SingleThreadedExecutor() + executor.add_node(node) + yield node, executor + executor.shutdown() + node.destroy_node() + rclpy.shutdown() + + +def test_sleep_in_event_loop(node_and_executor): + node, executor = node_and_executor + + expected_sleep_time = 0.5 + sleep_time = None + + async def cb(): + nonlocal sleep_time + start = time.monotonic() + await asyncio.sleep(expected_sleep_time) + end = time.monotonic() + sleep_time = end - start + + guard = node.create_guard_condition(cb) + guard.trigger() + + async def spin(): + nonlocal sleep_time + start = time.monotonic() + while not sleep_time and MAX_TEST_TIME > time.monotonic() - start: + executor.spin_once(timeout_sec=0) + # Don't use 100% CPU + await asyncio.sleep(0.01) + + asyncio.run(spin()) + assert sleep_time >= expected_sleep_time + assert abs(expected_sleep_time - sleep_time) <= expected_sleep_time * TIME_FUDGE_FACTOR diff --git a/rclpy/test/test_executor.py b/rclpy/test/test_executor.py index c0c3cc810..aab025f96 100644 --- a/rclpy/test/test_executor.py +++ b/rclpy/test/test_executor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import threading import time import unittest @@ -157,25 +156,15 @@ def test_execute_coroutine_timer(self): executor.add_node(self.node) called1 = False - called2 = False async def coroutine(): nonlocal called1 - nonlocal called2 called1 = True - await asyncio.sleep(0) - called2 = True tmr = self.node.create_timer(0.1, coroutine) try: executor.spin_once(timeout_sec=1.23) self.assertTrue(called1) - self.assertFalse(called2) - - called1 = False - executor.spin_once(timeout_sec=0) - self.assertFalse(called1) - self.assertTrue(called2) finally: self.node.destroy_timer(tmr) @@ -185,26 +174,16 @@ def test_execute_coroutine_guard_condition(self): executor.add_node(self.node) called1 = False - called2 = False async def coroutine(): nonlocal called1 - nonlocal called2 called1 = True - await asyncio.sleep(0) - called2 = True gc = self.node.create_guard_condition(coroutine) try: gc.trigger() executor.spin_once(timeout_sec=0) self.assertTrue(called1) - self.assertFalse(called2) - - called1 = False - executor.spin_once(timeout_sec=1) - self.assertFalse(called1) - self.assertTrue(called2) finally: self.node.destroy_guard_condition(gc) diff --git a/rclpy/test/test_guard_condition.py b/rclpy/test/test_guard_condition.py index f7eeb0d84..53cae82e0 100644 --- a/rclpy/test/test_guard_condition.py +++ b/rclpy/test/test_guard_condition.py @@ -26,15 +26,19 @@ def setUpClass(cls): rclpy.init(context=cls.context) cls.node = rclpy.create_node( 'TestGuardCondition', namespace='/rclpy/test', context=cls.context) - cls.executor = SingleThreadedExecutor(context=cls.context) - cls.executor.add_node(cls.node) @classmethod def tearDownClass(cls): - cls.executor.shutdown() cls.node.destroy_node() rclpy.shutdown(context=cls.context) + def setUp(self): + self.executor = SingleThreadedExecutor(context=self.context) + self.executor.add_node(self.node) + + def tearDown(self): + self.executor.shutdown() + def test_trigger(self): called = False diff --git a/rclpy/test/test_task.py b/rclpy/test/test_task.py index f0e92ccf9..8167fe94a 100644 --- a/rclpy/test/test_task.py +++ b/rclpy/test/test_task.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio import unittest from rclpy.task import Future @@ -27,15 +26,24 @@ def __init__(self): def create_task(self, cb, *args): self.done_callbacks.append((cb, args)) + def call_soon(self, cb, *args): + self.done_callbacks.append((cb, args)) + class TestTask(unittest.TestCase): + def setUp(self): + self.executor = DummyExecutor() + + def tearDown(self): + self.executor = None + def test_task_normal_callable(self): def func(): return 'Sentinel Result' - t = Task(func) + t = Task(func, executor=self.executor) t() self.assertTrue(t.done()) self.assertEqual('Sentinel Result', t.result()) @@ -45,58 +53,53 @@ def test_task_lambda(self): def func(): return 'Sentinel Result' - t = Task(lambda: func()) + t = Task(lambda: func(), executor=self.executor) t() self.assertTrue(t.done()) self.assertEqual('Sentinel Result', t.result()) def test_coroutine(self): called1 = False - called2 = False - async def coro(): + # async based coroutine + async def coro(fut): nonlocal called1 - nonlocal called2 called1 = True - await asyncio.sleep(0) - called2 = True - return 'Sentinel Result' + result = await fut + return 'Sentinel ' + result - t = Task(coro) + fut = Future() + t = Task(coro, args=(fut,), executor=self.executor) t() self.assertTrue(called1) - self.assertFalse(called2) + self.assertFalse(t.done()) called1 = False + fut.set_result('Result') t() self.assertFalse(called1) - self.assertTrue(called2) self.assertTrue(t.done()) self.assertEqual('Sentinel Result', t.result()) def test_done_callback_scheduled(self): - executor = DummyExecutor() - - t = Task(lambda: None, executor=executor) + t = Task(lambda: None, executor=self.executor) t.add_done_callback('Sentinel Value') t() self.assertTrue(t.done()) - self.assertEqual(1, len(executor.done_callbacks)) - self.assertEqual('Sentinel Value', executor.done_callbacks[0][0]) - args = executor.done_callbacks[0][1] + self.assertEqual(1, len(self.executor.done_callbacks)) + self.assertEqual('Sentinel Value', self.executor.done_callbacks[0][0]) + args = self.executor.done_callbacks[0][1] self.assertEqual(1, len(args)) self.assertEqual(t, args[0]) def test_done_task_done_callback_scheduled(self): - executor = DummyExecutor() - - t = Task(lambda: None, executor=executor) + t = Task(lambda: None, executor=self.executor) t() self.assertTrue(t.done()) t.add_done_callback('Sentinel Value') - self.assertEqual(1, len(executor.done_callbacks)) - self.assertEqual('Sentinel Value', executor.done_callbacks[0][0]) - args = executor.done_callbacks[0][1] + self.assertEqual(1, len(self.executor.done_callbacks)) + self.assertEqual('Sentinel Value', self.executor.done_callbacks[0][0]) + args = self.executor.done_callbacks[0][1] self.assertEqual(1, len(args)) self.assertEqual(t, args[0]) @@ -107,7 +110,7 @@ def func(): nonlocal called called = True - t = Task(func) + t = Task(func, executor=self.executor) t() self.assertTrue(called) self.assertTrue(t.done()) @@ -117,12 +120,12 @@ def func(): self.assertTrue(t.done()) def test_cancelled(self): - t = Task(lambda: None) + t = Task(lambda: None, executor=self.executor) t.cancel() self.assertTrue(t.cancelled()) def test_done_task_cancelled(self): - t = Task(lambda: None) + t = Task(lambda: None, executor=self.executor) t() t.cancel() self.assertFalse(t.cancelled()) @@ -134,7 +137,7 @@ def func(): e.sentinel_value = 'Sentinel Exception' raise e - t = Task(func) + t = Task(func, executor=self.executor) t() self.assertTrue(t.done()) self.assertEqual('Sentinel Exception', t.exception().sentinel_value) @@ -148,7 +151,7 @@ async def coro(): e.sentinel_value = 'Sentinel Exception' raise e - t = Task(coro) + t = Task(coro, executor=self.executor) t() self.assertTrue(t.done()) self.assertEqual('Sentinel Exception', t.exception().sentinel_value) @@ -161,7 +164,7 @@ def test_task_normal_callable_args(self): def func(arg): return arg - t = Task(func, args=(arg_in,)) + t = Task(func, args=(arg_in,), executor=self.executor) t() self.assertEqual('Sentinel Arg', t.result()) @@ -171,7 +174,7 @@ def test_coroutine_args(self): async def coro(arg): return arg - t = Task(coro, args=(arg_in,)) + t = Task(coro, args=(arg_in,), executor=self.executor) t() self.assertEqual('Sentinel Arg', t.result()) @@ -181,7 +184,7 @@ def test_task_normal_callable_kwargs(self): def func(kwarg=None): return kwarg - t = Task(func, kwargs={'kwarg': arg_in}) + t = Task(func, kwargs={'kwarg': arg_in}, executor=self.executor) t() self.assertEqual('Sentinel Arg', t.result()) @@ -191,14 +194,10 @@ def test_coroutine_kwargs(self): async def coro(kwarg=None): return kwarg - t = Task(coro, kwargs={'kwarg': arg_in}) + t = Task(coro, kwargs={'kwarg': arg_in}, executor=self.executor) t() self.assertEqual('Sentinel Arg', t.result()) - def test_executing(self): - t = Task(lambda: None) - self.assertFalse(t.executing()) - class TestFuture(unittest.TestCase):