Skip to content

Commit

Permalink
Refactoring, graceful exit
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Mar 21, 2024
1 parent 047f696 commit bdc32a6
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 121 deletions.
76 changes: 75 additions & 1 deletion cogment_lab/humans/gradio_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import json
import logging
import multiprocessing as mp
from typing import Any
import signal
from typing import Any, Callable

import cogment
import numpy as np

from cogment_lab.core import CogmentActor
from cogment_lab.generated import cog_settings
from cogment_lab.utils.runners import setup_logging


def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]:
Expand Down Expand Up @@ -64,3 +66,75 @@ async def run_cogment_actor(port: int, send_queue: asyncio.Queue, recv_queue: as

signal_queue.put(True)
await serve


async def shutdown():
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
asyncio.get_event_loop().stop()


def signal_handler(sig, frame):
asyncio.create_task(shutdown())


async def gradio_actor_main(
cogment_port: int,
gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
signal_queue: mp.Queue,
log_file: str | None = None,
):
gradio_to_actor = mp.Queue()
actor_to_gradio = mp.Queue()

logging.info("Starting gradio interface")
process = mp.Process(target=gradio_app_fn, args=(gradio_to_actor, actor_to_gradio, log_file))
process.start()

try:
logging.info("Starting cogment actor")
cogment_task = asyncio.create_task(
run_cogment_actor(
port=cogment_port,
send_queue=actor_to_gradio,
recv_queue=gradio_to_actor,
signal_queue=signal_queue,
)
)

logging.info("Waiting for cogment actor to finish")

await cogment_task
finally:
process.terminate()
process.join()


def gradio_actor_runner(
cogment_port: int,
gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
signal_queue: mp.Queue,
log_file: str | None = None,
):
if log_file:
setup_logging(log_file)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for sig in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame))

try:
loop.run_until_complete(
gradio_actor_main(
cogment_port=cogment_port,
gradio_app_fn=gradio_app_fn,
signal_queue=signal_queue,
log_file=log_file,
)
)
finally:
loop.run_until_complete(shutdown())
loop.close()
116 changes: 0 additions & 116 deletions cogment_lab/humans/gradio_runner.py

This file was deleted.

7 changes: 4 additions & 3 deletions cogment_lab/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from cogment_lab.core import BaseActor, BaseEnv
from cogment_lab.envs.runner import env_runner
from cogment_lab.generated import cog_settings, data_pb2
from cogment_lab.humans.gradio_runner import gradio_actor_runner
from cogment_lab.humans.gradio_actor import gradio_actor_runner
from cogment_lab.humans.runner import human_actor_runner
from cogment_lab.utils.trial_utils import (
TrialData,
Expand Down Expand Up @@ -343,7 +343,7 @@ def run_web_ui(

def run_gradio_ui(
self,
app_port: int = 7860, # TODO: currently doesn't work
gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
cogment_port: int = 8998,
log_file: str | None = None,
) -> Coroutine[None, None, bool]:
Expand All @@ -368,11 +368,12 @@ def run_gradio_ui(
name="gradio",
args=(
cogment_port,
gradio_app_fn,
signal_queue,
log_file,
),
)
logging.info(f"Started gradio UI on port {app_port} with log file {log_file}")
logging.info(f"Started gradio UI on a port with log file {log_file}")

self.actor_ports["gradio"] = cogment_port
return self.is_ready(signal_queue)
Expand Down
60 changes: 59 additions & 1 deletion examples/gymnasium/gradio_as_actor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,66 @@
import asyncio
import datetime
import logging
import multiprocessing as mp

import gradio as gr
import numpy as np
from transformers import pipeline

from cogment_lab.envs.gymnasium import GymEnvironment
from cogment_lab.process_manager import Cogment
from cogment_lab.utils.runners import setup_logging


def run_gradio_interface(send_queue: mp.Queue, recv_queue: mp.Queue, log_file: str | None = None):
if log_file is not None:
setup_logging(log_file)

transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")

def transcribe(audio):
if audio is None:
return None
sr, y = audio
y = y.astype(np.float32)
y /= np.max(np.abs(y))

return transcriber({"sampling_rate": sr, "raw": y})["text"]

def get_starting_action():
logging.info("Getting starting action inside gradio")
obs, frame = recv_queue.get()
logging.info(f"Received obs {obs} and frame inside gradio")
# start_button.visible = False
# text_input.visible = True
return obs, frame

def send_action(action: str):
logging.info(f"Sending action {action} inside gradio")
action = int(action) # TODO: Handle non-integer actions
send_queue.put(action)
logging.info(f"Sent action {action} inside gradio")
obs, frame = recv_queue.get()
logging.info(f"Received obs {obs} and frame inside gradio")
return obs, frame

with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
image_output = gr.Image(label="Image Output")
with gr.Column(scale=1):
text_output = gr.Textbox(label="Text Output")
with gr.Row():
text_input = gr.Textbox(label="Text Input")
start_button = gr.Button("Start")
audio_input = gr.Audio(label="Audio Input")

start_button.click(fn=get_starting_action, outputs=[text_output, image_output])
text_input.submit(fn=send_action, inputs=text_input, outputs=[text_output, image_output])
audio_input.change(fn=transcribe, inputs=audio_input, outputs=text_input)

demo.launch()


async def main():
Expand All @@ -20,7 +78,7 @@ async def main():

# Launch gradio env

await cog.run_gradio_ui(log_file="gradio_ui.log")
await cog.run_gradio_ui(gradio_app_fn=run_gradio_interface, log_file="gradio_ui.log")

trial_id = await cog.start_trial(
env_name="flake",
Expand Down

0 comments on commit bdc32a6

Please sign in to comment.