Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Gymnasium Tuple Space #20

Merged
merged 5 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cogment_lab/generated/cog_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@


_player_class = _cog.actor.ActorClass(
name="player",
config_type=data_pb.AgentConfig,
action_space=data_pb.PlayerAction,
observation_space=data_pb.Observation,
)
name="player",
config_type=data_pb.AgentConfig,
action_space=data_pb.PlayerAction,
observation_space=data_pb.Observation,
)


actor_classes = _cog.actor.ActorClassList(_player_class)
Expand Down
88 changes: 85 additions & 3 deletions cogment_lab/generated/data_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder


# @@protoc_insertion_point(imports)
Expand All @@ -32,8 +33,89 @@

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals())


_ENVIRONMENTSPECS = DESCRIPTOR.message_types_by_name['EnvironmentSpecs']
_AGENTSPECS = DESCRIPTOR.message_types_by_name['AgentSpecs']
_VALUE = DESCRIPTOR.message_types_by_name['Value']
_ENVIRONMENTCONFIG = DESCRIPTOR.message_types_by_name['EnvironmentConfig']
_ENVIRONMENTCONFIG_RESETARGSENTRY = _ENVIRONMENTCONFIG.nested_types_by_name['ResetArgsEntry']
_HFHUBMODEL = DESCRIPTOR.message_types_by_name['HFHubModel']
_AGENTCONFIG = DESCRIPTOR.message_types_by_name['AgentConfig']
_TRIALCONFIG = DESCRIPTOR.message_types_by_name['TrialConfig']
_OBSERVATION = DESCRIPTOR.message_types_by_name['Observation']
_PLAYERACTION = DESCRIPTOR.message_types_by_name['PlayerAction']
EnvironmentSpecs = _reflection.GeneratedProtocolMessageType('EnvironmentSpecs', (_message.Message,), {
'DESCRIPTOR' : _ENVIRONMENTSPECS,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentSpecs)
})
_sym_db.RegisterMessage(EnvironmentSpecs)

AgentSpecs = _reflection.GeneratedProtocolMessageType('AgentSpecs', (_message.Message,), {
'DESCRIPTOR' : _AGENTSPECS,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.AgentSpecs)
})
_sym_db.RegisterMessage(AgentSpecs)

Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), {
'DESCRIPTOR' : _VALUE,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.Value)
})
_sym_db.RegisterMessage(Value)

EnvironmentConfig = _reflection.GeneratedProtocolMessageType('EnvironmentConfig', (_message.Message,), {

'ResetArgsEntry' : _reflection.GeneratedProtocolMessageType('ResetArgsEntry', (_message.Message,), {
'DESCRIPTOR' : _ENVIRONMENTCONFIG_RESETARGSENTRY,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentConfig.ResetArgsEntry)
})
,
'DESCRIPTOR' : _ENVIRONMENTCONFIG,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentConfig)
})
_sym_db.RegisterMessage(EnvironmentConfig)
_sym_db.RegisterMessage(EnvironmentConfig.ResetArgsEntry)

HFHubModel = _reflection.GeneratedProtocolMessageType('HFHubModel', (_message.Message,), {
'DESCRIPTOR' : _HFHUBMODEL,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.HFHubModel)
})
_sym_db.RegisterMessage(HFHubModel)

AgentConfig = _reflection.GeneratedProtocolMessageType('AgentConfig', (_message.Message,), {
'DESCRIPTOR' : _AGENTCONFIG,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.AgentConfig)
})
_sym_db.RegisterMessage(AgentConfig)

TrialConfig = _reflection.GeneratedProtocolMessageType('TrialConfig', (_message.Message,), {
'DESCRIPTOR' : _TRIALCONFIG,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.TrialConfig)
})
_sym_db.RegisterMessage(TrialConfig)

Observation = _reflection.GeneratedProtocolMessageType('Observation', (_message.Message,), {
'DESCRIPTOR' : _OBSERVATION,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.Observation)
})
_sym_db.RegisterMessage(Observation)

PlayerAction = _reflection.GeneratedProtocolMessageType('PlayerAction', (_message.Message,), {
'DESCRIPTOR' : _PLAYERACTION,
'__module__' : 'data_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.PlayerAction)
})
_sym_db.RegisterMessage(PlayerAction)

if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
Expand Down
26 changes: 23 additions & 3 deletions cogment_lab/generated/ndarray_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
from google.protobuf.internal import enum_type_wrapper


# @@protoc_insertion_point(imports)
Expand All @@ -30,8 +32,26 @@

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xcd\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r\x12\x13\n\x0bstring_data\x18\t \x03(\t*\x95\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x12\x10\n\x0c\x44TYPE_STRING\x10\x07\x62\x06proto3')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals())
_DTYPE = DESCRIPTOR.enum_types_by_name['DType']
DType = enum_type_wrapper.EnumTypeWrapper(_DTYPE)
DTYPE_UNKNOWN = 0
DTYPE_FLOAT32 = 1
DTYPE_FLOAT64 = 2
DTYPE_INT8 = 3
DTYPE_INT32 = 4
DTYPE_INT64 = 5
DTYPE_UINT8 = 6
DTYPE_STRING = 7


_ARRAY = DESCRIPTOR.message_types_by_name['Array']
Array = _reflection.GeneratedProtocolMessageType('Array', (_message.Message,), {
'DESCRIPTOR' : _ARRAY,
'__module__' : 'ndarray_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.nd_array.Array)
})
_sym_db.RegisterMessage(Array)

if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
Expand Down
103 changes: 95 additions & 8 deletions cogment_lab/generated/spaces_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder


# @@protoc_insertion_point(imports)
Expand All @@ -29,10 +30,92 @@
import cogment_lab.generated.ndarray_pb2 as ndarray__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xb3\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12(\n\x04text\x18\x06 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"q\n\x05Tuple\x12\x32\n\x06spaces\x18\x02 \x03(\x0b\x32\".cogment_lab.spaces.Tuple.SubSpace\x1a\x34\n\x08SubSpace\x12(\n\x05space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xdf\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12*\n\x05tuple\x18\x06 \x01(\x0b\x32\x19.cogment_lab.spaces.TupleH\x00\x12(\n\x04text\x18\x07 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3')



_DISCRETE = DESCRIPTOR.message_types_by_name['Discrete']
_BOX = DESCRIPTOR.message_types_by_name['Box']
_MULTIBINARY = DESCRIPTOR.message_types_by_name['MultiBinary']
_MULTIDISCRETE = DESCRIPTOR.message_types_by_name['MultiDiscrete']
_DICT = DESCRIPTOR.message_types_by_name['Dict']
_DICT_SUBSPACE = _DICT.nested_types_by_name['SubSpace']
_TUPLE = DESCRIPTOR.message_types_by_name['Tuple']
_TUPLE_SUBSPACE = _TUPLE.nested_types_by_name['SubSpace']
_TEXT = DESCRIPTOR.message_types_by_name['Text']
_SPACE = DESCRIPTOR.message_types_by_name['Space']
Discrete = _reflection.GeneratedProtocolMessageType('Discrete', (_message.Message,), {
'DESCRIPTOR' : _DISCRETE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Discrete)
})
_sym_db.RegisterMessage(Discrete)

Box = _reflection.GeneratedProtocolMessageType('Box', (_message.Message,), {
'DESCRIPTOR' : _BOX,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Box)
})
_sym_db.RegisterMessage(Box)

MultiBinary = _reflection.GeneratedProtocolMessageType('MultiBinary', (_message.Message,), {
'DESCRIPTOR' : _MULTIBINARY,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.MultiBinary)
})
_sym_db.RegisterMessage(MultiBinary)

MultiDiscrete = _reflection.GeneratedProtocolMessageType('MultiDiscrete', (_message.Message,), {
'DESCRIPTOR' : _MULTIDISCRETE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.MultiDiscrete)
})
_sym_db.RegisterMessage(MultiDiscrete)

Dict = _reflection.GeneratedProtocolMessageType('Dict', (_message.Message,), {

'SubSpace' : _reflection.GeneratedProtocolMessageType('SubSpace', (_message.Message,), {
'DESCRIPTOR' : _DICT_SUBSPACE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Dict.SubSpace)
})
,
'DESCRIPTOR' : _DICT,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Dict)
})
_sym_db.RegisterMessage(Dict)
_sym_db.RegisterMessage(Dict.SubSpace)

Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), {

'SubSpace' : _reflection.GeneratedProtocolMessageType('SubSpace', (_message.Message,), {
'DESCRIPTOR' : _TUPLE_SUBSPACE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Tuple.SubSpace)
})
,
'DESCRIPTOR' : _TUPLE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Tuple)
})
_sym_db.RegisterMessage(Tuple)
_sym_db.RegisterMessage(Tuple.SubSpace)

Text = _reflection.GeneratedProtocolMessageType('Text', (_message.Message,), {
'DESCRIPTOR' : _TEXT,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Text)
})
_sym_db.RegisterMessage(Text)

Space = _reflection.GeneratedProtocolMessageType('Space', (_message.Message,), {
'DESCRIPTOR' : _SPACE,
'__module__' : 'spaces_pb2'
# @@protoc_insertion_point(class_scope:cogment_lab.spaces.Space)
})
_sym_db.RegisterMessage(Space)

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
Expand All @@ -48,8 +131,12 @@
_DICT._serialized_end=420
_DICT_SUBSPACE._serialized_start=355
_DICT_SUBSPACE._serialized_end=420
_TEXT._serialized_start=422
_TEXT._serialized_end=485
_SPACE._serialized_start=488
_SPACE._serialized_end=795
_TUPLE._serialized_start=422
_TUPLE._serialized_end=535
_TUPLE_SUBSPACE._serialized_start=483
_TUPLE_SUBSPACE._serialized_end=535
_TEXT._serialized_start=537
_TEXT._serialized_end=600
_SPACE._serialized_start=603
_SPACE._serialized_end=954
# @@protoc_insertion_point(module_scope)
10 changes: 9 additions & 1 deletion cogment_lab/protos/spaces.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ message Dict {
repeated SubSpace spaces = 1;
}

message Tuple {
message SubSpace {
Space space = 1;
}
repeated SubSpace spaces = 2;
}

message Text {
int32 max_length = 1;
int32 min_length = 2;
Expand All @@ -57,6 +64,7 @@ message Space {
Dict dict = 3;
MultiBinary multi_binary = 4;
MultiDiscrete multi_discrete = 5;
Text text = 6;
Tuple tuple = 6;
Text text = 7;
}
}
12 changes: 11 additions & 1 deletion cogment_lab/specs/spaces_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cogment_lab.generated.spaces_pb2 import MultiDiscrete # type: ignore
from cogment_lab.generated.spaces_pb2 import Space # type: ignore
from cogment_lab.generated.spaces_pb2 import Text # type: ignore
from cogment_lab.generated.spaces_pb2 import Tuple # type: ignore

from .ndarray_serialization import (
SerializationFormat,
Expand Down Expand Up @@ -66,6 +67,12 @@ def serialize_gym_space(space: gym.Space, serialization_format=SerializationForm
spaces.append(Dict.SubSpace(key=key, space=serialize_gym_space(gym_sub_space)))
return Space(dict=Dict(spaces=spaces))

if isinstance(space, gym.spaces.Tuple):
spaces = []
for gym_sub_space in space.spaces:
spaces.append(Tuple.SubSpace(space=serialize_gym_space(gym_sub_space, serialization_format)))
return Space(tuple=Tuple(spaces=spaces))

if isinstance(space, gym.spaces.Text):
return Space(text=Text(max_length=space.max_length, min_length=space.min_length, charset=space.characters))

Expand Down Expand Up @@ -101,8 +108,11 @@ def deserialize_space(pb_space: Space) -> gym.Space:
spaces = []
for sub_space in dict_space_pb.spaces:
spaces.append((sub_space.key, deserialize_space(sub_space.space)))

return gym.spaces.Dict(spaces=spaces)
if space_kind == "tuple":
tuple_space_pb = pb_space.tuple
spaces = [deserialize_space(sub_space) for sub_space in tuple_space_pb.spaces]
return gym.spaces.Tuple(spaces=spaces)
if space_kind == "text":
text_space_pb = pb_space.text
return gym.spaces.Text(
Expand Down
Loading
Loading