-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a0b6646
commit d18fe7b
Showing
5 changed files
with
258 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
import multiprocessing as mp | ||
from typing import Any | ||
|
||
import cogment | ||
import numpy as np | ||
|
||
from cogment_lab.core import CogmentActor | ||
from cogment_lab.generated import cog_settings | ||
|
||
|
||
def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> dict[str, Any]: | ||
if isinstance(obs, np.ndarray): | ||
obs = obs.tolist() | ||
elif isinstance(obs, dict): | ||
obs = {k: obs_to_msg(v) for k, v in obs.items()} | ||
elif isinstance(obs, np.integer): | ||
obs = int(obs) | ||
elif isinstance(obs, np.floating): | ||
obs = float(obs) | ||
return obs | ||
|
||
|
||
def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int: | ||
if isinstance(action_map, list): | ||
action_map = {action: i for i, action in enumerate(action_map)} | ||
if data.startswith("{"): | ||
action = json.loads(data) | ||
elif data not in action_map: | ||
action = action_map["no-op"] | ||
else: | ||
action = action_map[data] | ||
logging.info(f"Processed action {action} from {data} with action_map {action_map}") | ||
return action | ||
|
||
|
||
class GradioActor(CogmentActor): | ||
def __init__(self, send_queue: mp.Queue, recv_queue: mp.Queue): | ||
super().__init__(send_queue, recv_queue) | ||
self.send_queue = send_queue | ||
self.recv_queue = recv_queue | ||
|
||
async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) -> int: | ||
logging.info(f"Received observation {observation} and frame inside gradio actor") | ||
obs_data = obs_to_msg(observation) | ||
self.send_queue.put((obs_data, rendered_frame)) | ||
logging.info(f"Sent observation {obs_data} and frame inside gradio actor") | ||
action = self.recv_queue.get() | ||
logging.info(f"Received action {action} inside gradio actor") | ||
return action | ||
|
||
|
||
async def run_cogment_actor(port: int, send_queue: asyncio.Queue, recv_queue: asyncio.Queue, signal_queue: mp.Queue): | ||
context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab") | ||
gradio_actor = GradioActor(send_queue, recv_queue) | ||
|
||
logging.info("Registering actor") | ||
context.register_actor(impl=gradio_actor.impl, impl_name="gradio", actor_classes=["player"]) | ||
|
||
logging.info("Serving actor") | ||
serve = context.serve_all_registered(cogment.ServedEndpoint(port=port)) | ||
|
||
signal_queue.put(True) | ||
await serve |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import asyncio | ||
import logging | ||
import multiprocessing as mp | ||
import signal | ||
|
||
import gradio as gr | ||
|
||
from cogment_lab.humans.gradio_actor import run_cogment_actor | ||
from cogment_lab.utils.runners import setup_logging | ||
|
||
|
||
async def shutdown(): | ||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] | ||
for task in tasks: | ||
task.cancel() | ||
await asyncio.gather(*tasks, return_exceptions=True) | ||
asyncio.get_event_loop().stop() | ||
|
||
|
||
def signal_handler(sig, frame): | ||
asyncio.create_task(shutdown()) | ||
|
||
|
||
def run_gradio_interface(send_queue: mp.Queue, recv_queue: mp.Queue): | ||
# transcriber = pipeline("automatic-speech-recognition", model="openai/whisper-base.en") | ||
|
||
# def transcribe(audio): | ||
# sr, y = audio | ||
# y = y.astype(np.float32) | ||
# y /= np.max(np.abs(y)) | ||
# | ||
# return transcriber({"sampling_rate": sr, "raw": y})["text"] | ||
|
||
def get_starting_action(): | ||
logging.info("Getting starting action inside gradio") | ||
obs, frame = recv_queue.get() | ||
logging.info(f"Received obs {obs} and frame inside gradio") | ||
return obs, frame | ||
|
||
def send_action(action: str): | ||
logging.info(f"Sending action {action} inside gradio") | ||
action = int(action) # TODO: Handle non-integer actions | ||
send_queue.put(action) | ||
logging.info(f"Sent action {action} inside gradio") | ||
obs, frame = recv_queue.get() | ||
logging.info(f"Received obs {obs} and frame inside gradio") | ||
return obs, frame | ||
|
||
with gr.Blocks() as demo: | ||
with gr.Row(): | ||
with gr.Column(scale=2): | ||
with gr.Row(): | ||
image_output = gr.Image(label="Image Output") | ||
with gr.Column(scale=1): | ||
text_output = gr.Textbox(label="Text Output") | ||
with gr.Row(): | ||
text_input = gr.Textbox(label="Text Input") | ||
start_button = gr.Button("Start") | ||
|
||
# start_button.click(fn=generate_random_image, outputs=image_output) | ||
start_button.click(fn=get_starting_action, outputs=[text_output, image_output]) | ||
text_input.submit(fn=send_action, inputs=text_input, outputs=[text_output, image_output]) | ||
|
||
demo.launch() | ||
|
||
|
||
async def gradio_actor_main( | ||
cogment_port: int, | ||
signal_queue: mp.Queue, | ||
): | ||
gradio_to_actor = mp.Queue() | ||
actor_to_gradio = mp.Queue() | ||
|
||
logging.info("Starting gradio interface") | ||
process = mp.Process(target=run_gradio_interface, args=(gradio_to_actor, actor_to_gradio)) | ||
process.start() | ||
|
||
logging.info("Starting cogment actor") | ||
cogment_task = asyncio.create_task( | ||
run_cogment_actor( | ||
port=cogment_port, | ||
send_queue=actor_to_gradio, | ||
recv_queue=gradio_to_actor, | ||
signal_queue=signal_queue, | ||
) | ||
) | ||
|
||
logging.info("Waiting for cogment actor to finish") | ||
|
||
await cogment_task | ||
|
||
logging.error("Cogment actor finished, runner exiting") | ||
|
||
|
||
def gradio_actor_runner( | ||
cogment_port: int, | ||
signal_queue: mp.Queue, | ||
log_file: str | None = None, | ||
): | ||
if log_file: | ||
setup_logging(log_file) | ||
|
||
loop = asyncio.new_event_loop() | ||
asyncio.set_event_loop(loop) | ||
for sig in [signal.SIGINT, signal.SIGTERM]: | ||
loop.add_signal_handler(sig, lambda s=sig, frame=None: signal_handler(s, frame)) | ||
|
||
try: | ||
loop.run_until_complete( | ||
gradio_actor_main( | ||
cogment_port=cogment_port, | ||
signal_queue=signal_queue, | ||
) | ||
) | ||
finally: | ||
loop.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import asyncio | ||
import datetime | ||
|
||
from cogment_lab.envs.gymnasium import GymEnvironment | ||
from cogment_lab.process_manager import Cogment | ||
|
||
|
||
async def main(): | ||
logpath = f"logs/logs-{datetime.datetime.now().isoformat()}" | ||
|
||
cog = Cogment(log_dir=logpath) | ||
|
||
print(logpath) | ||
|
||
# Launch an environment in a subprocess | ||
|
||
cenv = GymEnvironment(env_id="FrozenLake-v1", render=True, make_kwargs={"is_slippery": False}) | ||
|
||
await cog.run_env(env=cenv, env_name="flake", port=9011, log_file="env.log") | ||
|
||
# Launch gradio env | ||
|
||
await cog.run_gradio_ui(log_file="gradio_ui.log") | ||
|
||
trial_id = await cog.start_trial( | ||
env_name="flake", | ||
session_config={"render": True}, | ||
actor_impls={ | ||
"gym": "gradio", | ||
}, | ||
) | ||
|
||
data = await cog.get_trial_data(trial_id) | ||
|
||
print(data) | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |