Skip to content

Commit

Permalink
Support text-based spaces to support LLM stuff (#19)
Browse files Browse the repository at this point in the history
Text spaces, gradio interfaces
  • Loading branch information
RedTachyon authored Mar 25, 2024
1 parent 35c1931 commit 1302a37
Show file tree
Hide file tree
Showing 21 changed files with 589 additions and 619 deletions.
13 changes: 12 additions & 1 deletion cogment_lab/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ async def on_message(self, messages: list):
"""Handle received messages."""
pass

async def on_ending(self, observation, rendered_frame):
    """Handle trial ending.

    Called from ``impl`` when an event of type ``ENDING`` is received, with
    the ``value`` and ``rendered_frame`` of the observation extracted from
    that event. This base implementation does nothing.
    """
    pass

async def end(self):
    """Clean up when done.

    This base implementation does nothing.
    NOTE(review): the caller is not visible in this excerpt -- confirm when
    this hook is invoked.
    """
    pass
Expand All @@ -250,10 +254,17 @@ async def impl(self, actor_session: ActorSession):
async for event in actor_session.all_events():
event: RecvEvent
self.current_event = event
if event.type != cogment.EventType.ACTIVE:
if event.type not in (cogment.EventType.ACTIVE, cogment.EventType.ENDING):
logging.info(f"Skipping event of type {event.type}")
continue

if event.type == cogment.EventType.ENDING:
observation = self.session_helper.get_observation(event)
await self.on_ending(observation.value, observation.rendered_frame)
continue

# type = ACTIVE

if event.observation:
observation = self.session_helper.get_observation(event)
if observation is None:
Expand Down
54 changes: 26 additions & 28 deletions cogment_lab/generated/data_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: data.proto
"""Generated protocol buffer code."""
Expand All @@ -31,34 +30,33 @@
import cogment_lab.generated.spaces_pb2 as spaces__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05"\r\n\x0bTrialConfig"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 
\x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 
\x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "data_pb2", globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b"8\001"
_ENVIRONMENTSPECS._serialized_start = 57
_ENVIRONMENTSPECS._serialized_end = 272
_AGENTSPECS._serialized_start = 274
_AGENTSPECS._serialized_end = 389
_VALUE._serialized_start = 391
_VALUE._serialized_end = 480
_ENVIRONMENTCONFIG._serialized_start = 483
_ENVIRONMENTCONFIG._serialized_end = 724
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start = 656
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end = 724
_HFHUBMODEL._serialized_start = 726
_HFHUBMODEL._serialized_end = 773
_AGENTCONFIG._serialized_start = 776
_AGENTCONFIG._serialized_end = 940
_TRIALCONFIG._serialized_start = 942
_TRIALCONFIG._serialized_end = 955
_OBSERVATION._serialized_start = 958
_OBSERVATION._serialized_end = 1094
_PLAYERACTION._serialized_start = 1096
_PLAYERACTION._serialized_end = 1154

DESCRIPTOR._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._options = None
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b'8\001'
_ENVIRONMENTSPECS._serialized_start=57
_ENVIRONMENTSPECS._serialized_end=272
_AGENTSPECS._serialized_start=274
_AGENTSPECS._serialized_end=389
_VALUE._serialized_start=391
_VALUE._serialized_end=480
_ENVIRONMENTCONFIG._serialized_start=483
_ENVIRONMENTCONFIG._serialized_end=724
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start=656
_ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end=724
_HFHUBMODEL._serialized_start=726
_HFHUBMODEL._serialized_end=773
_AGENTCONFIG._serialized_start=776
_AGENTCONFIG._serialized_end=940
_TRIALCONFIG._serialized_start=942
_TRIALCONFIG._serialized_end=955
_OBSERVATION._serialized_start=958
_OBSERVATION._serialized_end=1094
_PLAYERACTION._serialized_start=1096
_PLAYERACTION._serialized_end=1154
# @@protoc_insertion_point(module_scope)
20 changes: 10 additions & 10 deletions cogment_lab/generated/ndarray_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: ndarray.proto
"""Generated protocol buffer code."""
Expand All @@ -27,16 +26,17 @@
_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array"\xb8\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r*\x83\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x62\x06proto3'
)


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xcd\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r\x12\x13\n\x0bstring_data\x18\t \x03(\t*\x95\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x12\x10\n\x0c\x44TYPE_STRING\x10\x07\x62\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "ndarray_pb2", globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_DTYPE._serialized_start = 227
_DTYPE._serialized_end = 358
_ARRAY._serialized_start = 40
_ARRAY._serialized_end = 224

DESCRIPTOR._options = None
_DTYPE._serialized_start=248
_DTYPE._serialized_end=397
_ARRAY._serialized_start=40
_ARRAY._serialized_end=245
# @@protoc_insertion_point(module_scope)
40 changes: 20 additions & 20 deletions cogment_lab/generated/spaces_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: spaces.proto
"""Generated protocol buffer code."""
Expand All @@ -30,26 +29,27 @@
import cogment_lab.generated.ndarray_pb2 as ndarray__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space"\x89\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x42\x06\n\x04kindb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xb3\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12(\n\x04text\x18\x06 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spaces_pb2", globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_DISCRETE._serialized_start = 51
_DISCRETE._serialized_end = 87
_BOX._serialized_start = 89
_BOX._serialized_end = 179
_MULTIBINARY._serialized_start = 181
_MULTIBINARY._serialized_end = 234
_MULTIDISCRETE._serialized_start = 236
_MULTIDISCRETE._serialized_end = 294
_DICT._serialized_start = 296
_DICT._serialized_end = 420
_DICT_SUBSPACE._serialized_start = 355
_DICT_SUBSPACE._serialized_end = 420
_SPACE._serialized_start = 423
_SPACE._serialized_end = 688

DESCRIPTOR._options = None
_DISCRETE._serialized_start=51
_DISCRETE._serialized_end=87
_BOX._serialized_start=89
_BOX._serialized_end=179
_MULTIBINARY._serialized_start=181
_MULTIBINARY._serialized_end=234
_MULTIDISCRETE._serialized_start=236
_MULTIDISCRETE._serialized_end=294
_DICT._serialized_start=296
_DICT._serialized_end=420
_DICT_SUBSPACE._serialized_start=355
_DICT_SUBSPACE._serialized_end=420
_TEXT._serialized_start=422
_TEXT._serialized_end=485
_SPACE._serialized_start=488
_SPACE._serialized_end=795
# @@protoc_insertion_point(module_scope)
160 changes: 160 additions & 0 deletions cogment_lab/humans/gradio_actor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright 2024 AI Redefined Inc. <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import asyncio
import json
import logging
import multiprocessing as mp
import signal
from typing import Any, Callable

import cogment
import numpy as np

from cogment_lab.core import CogmentActor
from cogment_lab.generated import cog_settings
from cogment_lab.utils.runners import setup_logging


def obs_to_msg(obs: np.ndarray | dict[str, np.ndarray | dict]) -> Any:
    """Convert an observation into plain Python containers and scalars.

    NumPy arrays become (nested) lists, NumPy boolean/integer/floating
    scalars become ``bool``/``int``/``float``, and dicts, lists and tuples
    are converted recursively (a tuple stays a tuple, a list stays a list).
    Any other value is returned unchanged.

    Args:
        obs: The observation to convert.

    Returns:
        The converted observation; its type depends on the input (list,
        dict, tuple, Python scalar, or the original object).
    """
    if isinstance(obs, np.ndarray):
        return obs.tolist()
    if isinstance(obs, dict):
        return {key: obs_to_msg(value) for key, value in obs.items()}
    if isinstance(obs, list):
        return [obs_to_msg(item) for item in obs]
    if isinstance(obs, tuple):
        return tuple(obs_to_msg(item) for item in obs)
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obs, np.bool_):
        return bool(obs)
    if isinstance(obs, np.integer):
        return int(obs)
    if isinstance(obs, np.floating):
        return float(obs)
    return obs


def msg_to_action(data: str, action_map: list[str] | dict[str, int]) -> int:
    """Translate a message string into an action.

    A list ``action_map`` is first turned into a name -> index mapping.
    A message starting with ``{`` is decoded as JSON and the decoded value
    is returned as the action. Any other message is looked up in the
    mapping; names that are not present fall back to the ``"no-op"`` entry.

    Args:
        data: The raw message.
        action_map: Action names, either as a list (index = action) or as
            an explicit name -> action mapping.

    Returns:
        The resolved action.

    Raises:
        KeyError: If ``data`` is unknown and the mapping has no ``"no-op"``.
    """
    if isinstance(action_map, list):
        action_map = dict(zip(action_map, range(len(action_map))))

    if data.startswith("{"):
        action = json.loads(data)
    else:
        # Unknown names resolve to the mapping's "no-op" entry.
        action = action_map[data if data in action_map else "no-op"]

    logging.info(f"Processed action {action} from {data} with action_map {action_map}")
    return action


class GradioActor(CogmentActor):
    """Cogment actor that trades observations for actions over two queues.

    Observations (converted with ``obs_to_msg``) and rendered frames are put
    on ``send_queue``; actions are read back from ``recv_queue``.
    """

    # Seconds to sleep between polls of ``recv_queue`` while awaiting an action.
    _POLL_INTERVAL = 0.01

    def __init__(self, send_queue: mp.Queue, recv_queue: mp.Queue):
        """Store the queues used to talk to the other end.

        Args:
            send_queue: Queue that receives ``(observation, rendered_frame)`` tuples.
            recv_queue: Queue from which actions are read.
        """
        super().__init__(send_queue, recv_queue)
        self.send_queue = send_queue
        self.recv_queue = recv_queue

    async def act(self, observation: Any, rendered_frame: np.ndarray | None = None) -> int:
        """Publish the observation and wait for the matching action.

        Args:
            observation: Current observation; converted with ``obs_to_msg``.
            rendered_frame: Optional rendered frame, forwarded unchanged.

        Returns:
            The next item read from ``recv_queue``.
        """
        from queue import Empty  # exception raised by mp.Queue.get_nowait()

        self.send_queue.put((obs_to_msg(observation), rendered_frame))

        # A plain ``recv_queue.get()`` would block the whole event loop until
        # an action arrives, so nothing else (other tasks, signal callbacks,
        # cancellation) could run in the meantime. Poll without blocking and
        # yield to the loop between attempts instead.
        while True:
            try:
                return self.recv_queue.get_nowait()
            except Empty:
                await asyncio.sleep(self._POLL_INTERVAL)

    async def on_ending(self, observation, rendered_frame):
        """Publish the final observation and frame; no action is read back."""
        self.send_queue.put((obs_to_msg(observation), rendered_frame))


async def run_cogment_actor(port: int, send_queue: mp.Queue, recv_queue: mp.Queue, signal_queue: mp.Queue):
    """Register a ``GradioActor`` with Cogment and serve it on ``port``.

    Args:
        port: Port of the served Cogment endpoint.
        send_queue: Queue handed to the actor for outgoing observations.
        recv_queue: Queue handed to the actor for incoming actions.
        signal_queue: Receives ``True`` once the serve awaitable has been
            created, immediately before it is awaited.
    """
    context = cogment.Context(cog_settings=cog_settings, user_id="cogment_lab")
    gradio_actor = GradioActor(send_queue, recv_queue)

    logging.info("Registering actor")
    context.register_actor(impl=gradio_actor.impl, impl_name="gradio", actor_classes=["player"])

    logging.info("Serving actor")
    serve = context.serve_all_registered(cogment.ServedEndpoint(port=port))

    # NOTE(review): the signal is sent before ``serve`` is awaited --
    # confirm that consumers of ``signal_queue`` tolerate this.
    signal_queue.put(True)
    await serve


async def shutdown():
    """Cancel every other task, wait for them to finish, then stop the loop.

    Exceptions raised by the cancelled tasks (including the cancellation
    itself) are collected by ``gather`` and not propagated.
    """
    me = asyncio.current_task()
    cancelled = []
    for task in asyncio.all_tasks():
        if task is me:
            continue
        task.cancel()
        cancelled.append(task)

    await asyncio.gather(*cancelled, return_exceptions=True)
    asyncio.get_running_loop().stop()


def signal_handler(sig, frame):
    """Schedule ``shutdown()`` as a task on the running event loop.

    ``sig`` and ``frame`` are accepted but not used.
    """
    shutdown_coro = shutdown()
    asyncio.get_running_loop().create_task(shutdown_coro)


async def gradio_actor_main(
    cogment_port: int,
    gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
    signal_queue: mp.Queue,
    log_file: str | None = None,
):
    """Run the Gradio app in a child process and the Cogment actor in this one.

    Args:
        cogment_port: Port on which the Cogment actor is served.
        gradio_app_fn: Target of the child process; called with the
            gradio-to-actor queue, the actor-to-gradio queue and ``log_file``.
        signal_queue: Passed through to ``run_cogment_actor``.
        log_file: Optional log file path forwarded to ``gradio_app_fn``.

    The child process is terminated and joined when the actor task ends,
    whether normally, by error, or by cancellation.
    """
    actor_to_gradio = mp.Queue()
    gradio_to_actor = mp.Queue()

    logging.info("Starting gradio interface")
    ui_process = mp.Process(
        target=gradio_app_fn,
        args=(gradio_to_actor, actor_to_gradio, log_file),
    )
    ui_process.start()

    try:
        logging.info("Starting cogment actor")
        actor_coro = run_cogment_actor(
            port=cogment_port,
            send_queue=actor_to_gradio,
            recv_queue=gradio_to_actor,
            signal_queue=signal_queue,
        )
        actor_task = asyncio.create_task(actor_coro)

        logging.info("Waiting for cogment actor to finish")

        await actor_task
    finally:
        ui_process.terminate()
        ui_process.join()


def gradio_actor_runner(
    cogment_port: int,
    gradio_app_fn: Callable[[mp.Queue, mp.Queue, str], None],
    signal_queue: mp.Queue,
    log_file: str | None = None,
):
    """Synchronous entry point: run ``gradio_actor_main`` on a fresh event loop.

    Logging is configured first when ``log_file`` is given. SIGINT and
    SIGTERM are routed to ``signal_handler``. Whatever happens while the
    main coroutine runs, ``shutdown()`` is run afterwards and the loop is
    closed.

    Args:
        cogment_port: Forwarded to ``gradio_actor_main``.
        gradio_app_fn: Forwarded to ``gradio_actor_main``.
        signal_queue: Forwarded to ``gradio_actor_main``.
        log_file: Optional log file path; also forwarded.
    """
    if log_file:
        setup_logging(log_file)

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    for sig in (signal.SIGINT, signal.SIGTERM):
        # add_signal_handler passes the extra positional arguments on to
        # the callback, so no wrapper closure is needed.
        loop.add_signal_handler(sig, signal_handler, sig, None)

    try:
        main_coro = gradio_actor_main(
            cogment_port=cogment_port,
            gradio_app_fn=gradio_app_fn,
            signal_queue=signal_queue,
            log_file=log_file,
        )
        loop.run_until_complete(main_coro)
    finally:
        loop.run_until_complete(shutdown())
        loop.close()
Loading

0 comments on commit 1302a37

Please sign in to comment.