Skip to content

Commit

Permalink
Add WiP gradio integration
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Mar 20, 2024
1 parent a0b6646 commit d18fe7b
Show file tree
Hide file tree
Showing 5 changed files with 258 additions and 21 deletions.
66 changes: 66 additions & 0 deletions cogment_lab/humans/gradio_actor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import asyncio
import json
import logging
import multiprocessing as mp
from typing import Any

import cogment
import numpy as np

from cogment_lab.core import CogmentActor
from cogment_lab.generated import cog_settings


def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]:
if isinstance(obs, np.ndarray):
obs = obs.tolist()
elif isinstance(obs, dict):
obs = {k: obs_to_msg(v) for k, v in obs.items()}
elif isinstance(obs, np.integer):
obs = int(obs)
elif isinstance(obs, np.floating):
obs = float(obs)
return obs


def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int:
if isinstance(action_map, list):
action_map = {action: i for i, action in enumerate(action_map)}
if data.startswith("{"):
action = json.loads(data)
elif data not in action_map:
action = action_map["no-op"]
else:
action = action_map[data]
logging.info(f"Processed action {action} from {data} with action_map {action_map}")
return action


class GradioActor(CogmentActor):
def __init__(self, send_queue: mp.Queue, recv_queue: mp.Queue):
super().__init__(send_queue, recv_queue)
self.send_queue = send_queue
self.recv_queue = recv_queue

async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) -> int:
logging.info(f"Received observation {observation} and frame inside gradio actor")
obs_data = obs_to_msg(observation)
self.send_queue.put((obs_data, rendered_frame))
logging.info(f"Sent observation {obs_data} and frame inside gradio actor")
action = self.recv_queue.get()
logging.info(f"Received action {action} inside gradio actor")
return action


async def run_cogment_actor(port: int, send_queue: asyncio.Queue, recv_queue: asyncio.Queue, signal_queue: mp.Queue):
context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab")
gradio_actor = GradioActor(send_queue, recv_queue)

logging.info("Registering actor")
context.register_actor(impl=gradio_actor.impl, impl_name="gradio", actor_classes=["player"])

logging.info("Serving actor")
serve = context.serve_all_registered(cogment.ServedEndpoint(port=port))

signal_queue.put(True)
await serve
116 changes: 116 additions & 0 deletions cogment_lab/humans/gradio_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
import logging
import multiprocessing as mp
import signal

import gradio as gr

from cogment_lab.humans.gradio_actor import run_cogment_actor
from cogment_lab.utils.runners import setup_logging


async def shutdown():
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
asyncio.get_event_loop().stop()


def signal_handler(sig, frame):
asyncio.create_task(shutdown())


def run_gradio_interface(send_queue: mp.Queue, recv_queue: mp.Queue):
# transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en")

# def transcribe(audio):
# sr, y = audio
# y = y.astype(np.float32)
# y /= np.max(np.abs(y))
#
# return transcriber({"sampling_rate": sr, "raw": y})["text"]

def get_starting_action():
logging.info("Getting starting action inside gradio")
obs, frame = recv_queue.get()
logging.info(f"Received obs {obs} and frame inside gradio")
return obs, frame

def send_action(action: str):
logging.info(f"Sending action {action} inside gradio")
action = int(action) # TODO: Handle non-integer actions
send_queue.put(action)
logging.info(f"Sent action {action} inside gradio")
obs, frame = recv_queue.get()
logging.info(f"Received obs {obs} and frame inside gradio")
return obs, frame

with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
image_output = gr.Image(label="Image Output")
with gr.Column(scale=1):
text_output = gr.Textbox(label="Text Output")
with gr.Row():
text_input = gr.Textbox(label="Text Input")
start_button = gr.Button("Start")

# start_button.click(fn=generate_random_image, outputs=image_output)
start_button.click(fn=get_starting_action, outputs=[text_output, image_output])
text_input.submit(fn=send_action, inputs=text_input, outputs=[text_output, image_output])

demo.launch()


async def gradio_actor_main(
cogment_port: int,
signal_queue: mp.Queue,
):
gradio_to_actor = mp.Queue()
actor_to_gradio = mp.Queue()

logging.info("Starting gradio interface")
process = mp.Process(target=run_gradio_interface, args=(gradio_to_actor, actor_to_gradio))
process.start()

logging.info("Starting cogment actor")
cogment_task = asyncio.create_task(
run_cogment_actor(
port=cogment_port,
send_queue=actor_to_gradio,
recv_queue=gradio_to_actor,
signal_queue=signal_queue,
)
)

logging.info("Waiting for cogment actor to finish")

await cogment_task

logging.error("Cogment actor finished, runner exiting")


def gradio_actor_runner(
cogment_port: int,
signal_queue: mp.Queue,
log_file: str | None = None,
):
if log_file:
setup_logging(log_file)

loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
for sig in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame))

try:
loop.run_until_complete(
gradio_actor_main(
cogment_port=cogment_port,
signal_queue=signal_queue,
)
)
finally:
loop.close()
21 changes: 0 additions & 21 deletions cogment_lab/humans/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,6 @@
from cogment_lab.utils.runners import setup_logging


# def human_actor_runner(
# app_port: int = 8000,
# cogment_port: int = 8999,
# log_file: str | None = None
# ):
# """Runs the human actor along with the FastAPI server"""
# if log_file:
# setup_logging(log_file)
#
# # Queues for communication between FastAPI and Cogment actor
# app_to_actor = asyncio.Queue()
# actor_to_app = asyncio.Queue()
#
# # Asyncio tasks for the FastAPI server and Cogment actor
# fastapi_task = start_fastapi(port=app_port, send_queue=app_to_actor, recv_queue=actor_to_app)
# cogment_task = asyncio.create_task(run_cogment_actor(port=cogment_port, send_queue=actor_to_app, recv_queue=app_to_actor))
#
# # Run the asyncio event loop
# asyncio.run(asyncio.gather(fastapi_task, cogment_task))


async def shutdown():
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
Expand Down
37 changes: 37 additions & 0 deletions cogment_lab/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from cogment_lab.core import BaseActor, BaseEnv
from cogment_lab.envs.runner import env_runner
from cogment_lab.generated import cog_settings, data_pb2
from cogment_lab.humans.gradio_runner import gradio_actor_runner
from cogment_lab.humans.runner import human_actor_runner
from cogment_lab.utils.trial_utils import (
TrialData,
Expand Down Expand Up @@ -340,6 +341,42 @@ def run_web_ui(

return self.is_ready(signal_queue)

def run_gradio_ui(
self,
app_port: int = 7860, # TODO: currently doesn't work
cogment_port: int = 8998,
log_file: str | None = None,
) -> Coroutine[None, None, bool]:
"""Runs the human actor in a separate process
Args:
app_port (int, optional): Port for web UI. Defaults to 8000.
cogment_port (int, optional): Port for Cogment connection. Defaults to 8999.
log_file (str | None, optional): Log file path. Defaults to None.
Returns:
bool: Whether the web UI startup succeeded
"""

signal_queue = Queue(1)

if self.log_dir is not None and log_file:
log_file = os.path.join(self.log_dir, log_file)

self._add_process(
target=gradio_actor_runner,
name="gradio",
args=(
cogment_port,
signal_queue,
log_file,
),
)
logging.info(f"Started gradio UI on port {app_port} with log file {log_file}")

self.actor_ports["gradio"] = cogment_port
return self.is_ready(signal_queue)

def stop_service(self, name: ImplName, timeout: float = 1.0):
"""Stops a process or a task.
Expand Down
39 changes: 39 additions & 0 deletions examples/gymnasium/gradio_as_actor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import asyncio
import datetime

from cogment_lab.envs.gymnasium import GymEnvironment
from cogment_lab.process_manager import Cogment


async def main():
logpath = f"logs/logs-{datetime.datetime.now().isoformat()}"

cog = Cogment(log_dir=logpath)

print(logpath)

# Launch an environment in a subprocess

cenv = GymEnvironment(env_id="FrozenLake-v1", render=True, make_kwargs={"is_slippery": False})

await cog.run_env(env=cenv, env_name="flake", port=9011, log_file="env.log")

# Launch gradio env

await cog.run_gradio_ui(log_file="gradio_ui.log")

trial_id = await cog.start_trial(
env_name="flake",
session_config={"render": True},
actor_impls={
"gym": "gradio",
},
)

data = await cog.get_trial_data(trial_id)

print(data)


if __name__ == "__main__":
asyncio.run(main())

0 comments on commit d18fe7b

Please sign in to comment.