From bdc32a610baa1230939c48f4b88a29f6e4849a3c Mon Sep 17 00:00:00 2001 From: ariel Date: Thu, 21 Mar 2024 11:03:15 +0100 Subject: [PATCH] Refactoring, graceful exit --- cogment_lab/humans/gradio_actor.py | 76 ++++++++++++++++- cogment_lab/humans/gradio_runner.py | 116 -------------------------- cogment_lab/process_manager.py | 7 +- examples/gymnasium/gradio_as_actor.py | 60 ++++++++++++- 4 files changed, 138 insertions(+), 121 deletions(-) delete mode 100644 cogment_lab/humans/gradio_runner.py diff --git a/cogment_lab/humans/gradio_actor.py b/cogment_lab/humans/gradio_actor.py index 29815c2..bab256e 100644 --- a/cogment_lab/humans/gradio_actor.py +++ b/cogment_lab/humans/gradio_actor.py @@ -2,13 +2,15 @@ import json import logging import multiprocessing as mp -from typing import Any +import signal +from typing import Any, Callable import cogment import numpy as np from cogment_lab.core import CogmentActor from cogment_lab.generated import cog_settings +from cogment_lab.utils.runners import setup_logging def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]: @@ -64,3 +66,75 @@ async def run_cogment_actor(port: int, send_queue: asyncio.Queue, recv_queue: as signal_queue.put(True) await serve + + +async def shutdown(): + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + asyncio.get_event_loop().stop() + + +def signal_handler(sig, frame): + asyncio.create_task(shutdown()) + + +async def gradio_actor_main( + cogment_port: int, + gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None], + signal_queue: mp.Queue, + log_file: str | None = None, +): + gradio_to_actor = mp.Queue() + actor_to_gradio = mp.Queue() + + logging.info("Starting gradio interface") + process = mp.Process(target=gradio_app_fn, args=(gradio_to_actor, actor_to_gradio, log_file)) + process.start() + + try: + logging.info("Starting cogment actor") + cogment_task = asyncio.create_task( + run_cogment_actor( + port=cogment_port, + send_queue=actor_to_gradio, + recv_queue=gradio_to_actor, + signal_queue=signal_queue, + ) + ) + + logging.info("Waiting for cogment actor to finish") + + await cogment_task + finally: + process.terminate() + process.join() + + +def gradio_actor_runner( + cogment_port: int, + gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None], + signal_queue: mp.Queue, + log_file: str | None = None, +): + if log_file: + setup_logging(log_file) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + for sig in [signal.SIGINT, signal.SIGTERM]: + loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame)) + + try: + loop.run_until_complete( + gradio_actor_main( + cogment_port=cogment_port, + gradio_app_fn=gradio_app_fn, + signal_queue=signal_queue, + log_file=log_file, + ) + ) + finally: + loop.run_until_complete(shutdown()) + loop.close() diff --git a/cogment_lab/humans/gradio_runner.py b/cogment_lab/humans/gradio_runner.py deleted file mode 100644 index 378b32b..0000000 --- a/cogment_lab/humans/gradio_runner.py +++ /dev/null @@ -1,116 +0,0 @@ -import asyncio -import logging -import multiprocessing as mp -import signal - -import gradio as gr - -from cogment_lab.humans.gradio_actor import run_cogment_actor -from cogment_lab.utils.runners import setup_logging - - -async def shutdown(): - tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] - for task in tasks: - task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) - asyncio.get_event_loop().stop() - - -def signal_handler(sig, frame): - asyncio.create_task(shutdown()) - - -def run_gradio_interface(send_queue: mp.Queue, recv_queue: mp.Queue): - # transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en") - - # def transcribe(audio): - # sr, y = audio - # y = y.astype(np.float32) - # y /= np.max(np.abs(y)) - # - # return transcriber({"sampling_rate": sr, "raw": y})["text"] - - def get_starting_action(): - logging.info("Getting starting action inside gradio") - obs, frame = recv_queue.get() - logging.info(f"Received obs {obs} and frame inside gradio") - return obs, frame - - def send_action(action: str): - logging.info(f"Sending action {action} inside gradio") - action = int(action) # TODO: Handle non-integer actions - send_queue.put(action) - logging.info(f"Sent action {action} inside gradio") - obs, frame = recv_queue.get() - logging.info(f"Received obs {obs} and frame inside gradio") - return obs, frame - - with gr.Blocks() as demo: - with gr.Row(): - with gr.Column(scale=2): - with gr.Row(): - image_output = gr.Image(label="Image Output") - with gr.Column(scale=1): - text_output = gr.Textbox(label="Text Output") - with gr.Row(): - text_input = gr.Textbox(label="Text Input") - start_button = gr.Button("Start") - - # start_button.click(fn=generate_random_image, outputs=image_output) - start_button.click(fn=get_starting_action, outputs=[text_output, image_output]) - text_input.submit(fn=send_action, inputs=text_input, outputs=[text_output, image_output]) - - demo.launch() - - -async def gradio_actor_main( - cogment_port: int, - signal_queue: mp.Queue, -): - gradio_to_actor = mp.Queue() - actor_to_gradio = mp.Queue() - - logging.info("Starting gradio interface") - process = mp.Process(target=run_gradio_interface, args=(gradio_to_actor, actor_to_gradio)) - process.start() - - logging.info("Starting cogment actor") - cogment_task = asyncio.create_task( - run_cogment_actor( - port=cogment_port, - send_queue=actor_to_gradio, - recv_queue=gradio_to_actor, - signal_queue=signal_queue, - ) - ) - - logging.info("Waiting for cogment actor to finish") - - await cogment_task - - logging.error("Cogment actor finished, runner exiting") - - -def gradio_actor_runner( - cogment_port: int, - signal_queue: mp.Queue, - log_file: str | None = None, -): - if log_file: - setup_logging(log_file) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - for sig in [signal.SIGINT, signal.SIGTERM]: - loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame)) - - try: - loop.run_until_complete( - gradio_actor_main( - cogment_port=cogment_port, - signal_queue=signal_queue, - ) - ) - finally: - loop.close() diff --git a/cogment_lab/process_manager.py b/cogment_lab/process_manager.py index 196dc5d..e27deca 100644 --- a/cogment_lab/process_manager.py +++ b/cogment_lab/process_manager.py @@ -33,7 +33,7 @@ from cogment_lab.core import BaseActor, BaseEnv from cogment_lab.envs.runner import env_runner from cogment_lab.generated import cog_settings, data_pb2 -from cogment_lab.humans.gradio_runner import gradio_actor_runner +from cogment_lab.humans.gradio_actor import gradio_actor_runner from cogment_lab.humans.runner import human_actor_runner from cogment_lab.utils.trial_utils import ( TrialData, @@ -343,7 +343,7 @@ def run_web_ui( def run_gradio_ui( self, - app_port: int = 7860, # TODO: currently doesn't work + gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None], cogment_port: int = 8998, log_file: str | None = None, ) -> Coroutine[None, None, bool]: @@ -368,11 +368,12 @@ def run_gradio_ui( name="gradio", args=( cogment_port, + gradio_app_fn, signal_queue, log_file, ), ) - logging.info(f"Started gradio UI on port {app_port} with log file {log_file}") + logging.info(f"Started gradio UI on a port with log file {log_file}") self.actor_ports["gradio"] = cogment_port return self.is_ready(signal_queue) diff --git a/examples/gymnasium/gradio_as_actor.py b/examples/gymnasium/gradio_as_actor.py index cf319b9..c348764 100644 --- a/examples/gymnasium/gradio_as_actor.py +++ b/examples/gymnasium/gradio_as_actor.py @@ -1,8 +1,66 @@ import asyncio import datetime +import logging +import multiprocessing as mp + +import gradio as gr +import numpy as np +from transformers import pipeline from cogment_lab.envs.gymnasium import GymEnvironment from cogment_lab.process_manager import Cogment +from cogment_lab.utils.runners import setup_logging + + +def run_gradio_interface(send_queue: mp.Queue, recv_queue: mp.Queue, log_file: str | None = None): + if log_file is not None: + setup_logging(log_file) + + transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en") + + def transcribe(audio): + if audio is None: + return None + sr, y = audio + y = y.astype(np.float32) + y /= np.max(np.abs(y)) + + return transcriber({"sampling_rate": sr, "raw": y})["text"] + + def get_starting_action(): + logging.info("Getting starting action inside gradio") + obs, frame = recv_queue.get() + logging.info(f"Received obs {obs} and frame inside gradio") + # start_button.visible = False + # text_input.visible = True + return obs, frame + + def send_action(action: str): + logging.info(f"Sending action {action} inside gradio") + action = int(action) # TODO: Handle non-integer actions + send_queue.put(action) + logging.info(f"Sent action {action} inside gradio") + obs, frame = recv_queue.get() + logging.info(f"Received obs {obs} and frame inside gradio") + return obs, frame + + with gr.Blocks() as demo: + with gr.Row(): + with gr.Column(scale=2): + with gr.Row(): + image_output = gr.Image(label="Image Output") + with gr.Column(scale=1): + text_output = gr.Textbox(label="Text Output") + with gr.Row(): + text_input = gr.Textbox(label="Text Input") + start_button = gr.Button("Start") + audio_input = gr.Audio(label="Audio Input") + + start_button.click(fn=get_starting_action, outputs=[text_output, image_output]) + text_input.submit(fn=send_action, inputs=text_input, outputs=[text_output, image_output]) + audio_input.change(fn=transcribe, inputs=audio_input, outputs=text_input) + + demo.launch() async def main(): @@ -20,7 +78,7 @@ async def main(): # Launch gradio env - await cog.run_gradio_ui(log_file="gradio_ui.log") + await cog.run_gradio_ui(gradio_app_fn=run_gradio_interface, log_file="gradio_ui.log") trial_id = await cog.start_trial( env_name="flake",