From 1646eb9bb97423571479dc2542a1400e8c543421 Mon Sep 17 00:00:00 2001 From: wduguay-air Date: Wed, 3 Apr 2024 11:19:53 -0400 Subject: [PATCH] tuple support --- cogment_lab/generated/cog_settings.py | 29 ++---- cogment_lab/generated/data_pb2.py | 105 +++++++++++++++---- cogment_lab/generated/ndarray_pb2.py | 43 ++++---- cogment_lab/generated/spaces_pb2.py | 120 +++++++++++++++++----- cogment_lab/protos/spaces.proto | 11 +- cogment_lab/specs/spaces_serialization.py | 12 ++- tests/test_spaces_serialization.py | 90 ++++++++++++++++ 7 files changed, 324 insertions(+), 86 deletions(-) create mode 100644 tests/test_spaces_serialization.py diff --git a/cogment_lab/generated/cog_settings.py b/cogment_lab/generated/cog_settings.py index 6f3a19c..f2d7e56 100644 --- a/cogment_lab/generated/cog_settings.py +++ b/cogment_lab/generated/cog_settings.py @@ -1,32 +1,17 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from types import SimpleNamespace import cogment as _cog +from types import SimpleNamespace -import cogment_lab.generated.data_pb2 as data_pb import cogment_lab.generated.ndarray_pb2 as ndarray_pb import cogment_lab.generated.spaces_pb2 as spaces_pb - +import cogment_lab.generated.data_pb2 as data_pb _player_class = _cog.actor.ActorClass( - name="player", - config_type=data_pb.AgentConfig, - action_space=data_pb.PlayerAction, - observation_space=data_pb.Observation, -) + name="player", + config_type=data_pb.AgentConfig, + action_space=data_pb.PlayerAction, + observation_space=data_pb.Observation, + ) actor_classes = _cog.actor.ActorClassList(_player_class) diff --git a/cogment_lab/generated/data_pb2.py b/cogment_lab/generated/data_pb2.py index 9a11fbc..df60944 100644 --- a/cogment_lab/generated/data_pb2.py +++ b/cogment_lab/generated/data_pb2.py @@ -1,26 +1,12 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: data.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -32,8 +18,89 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals()) + + +_ENVIRONMENTSPECS = DESCRIPTOR.message_types_by_name['EnvironmentSpecs'] +_AGENTSPECS = DESCRIPTOR.message_types_by_name['AgentSpecs'] +_VALUE = DESCRIPTOR.message_types_by_name['Value'] +_ENVIRONMENTCONFIG = DESCRIPTOR.message_types_by_name['EnvironmentConfig'] +_ENVIRONMENTCONFIG_RESETARGSENTRY = _ENVIRONMENTCONFIG.nested_types_by_name['ResetArgsEntry'] +_HFHUBMODEL = DESCRIPTOR.message_types_by_name['HFHubModel'] +_AGENTCONFIG = DESCRIPTOR.message_types_by_name['AgentConfig'] +_TRIALCONFIG = DESCRIPTOR.message_types_by_name['TrialConfig'] +_OBSERVATION = DESCRIPTOR.message_types_by_name['Observation'] +_PLAYERACTION = DESCRIPTOR.message_types_by_name['PlayerAction'] +EnvironmentSpecs = _reflection.GeneratedProtocolMessageType('EnvironmentSpecs', (_message.Message,), { + 'DESCRIPTOR' : _ENVIRONMENTSPECS, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentSpecs) + }) +_sym_db.RegisterMessage(EnvironmentSpecs) + +AgentSpecs = _reflection.GeneratedProtocolMessageType('AgentSpecs', (_message.Message,), { + 'DESCRIPTOR' : _AGENTSPECS, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.AgentSpecs) + }) +_sym_db.RegisterMessage(AgentSpecs) + +Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), { + 'DESCRIPTOR' : _VALUE, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.Value) + }) +_sym_db.RegisterMessage(Value) + +EnvironmentConfig = _reflection.GeneratedProtocolMessageType('EnvironmentConfig', (_message.Message,), { + + 'ResetArgsEntry' : _reflection.GeneratedProtocolMessageType('ResetArgsEntry', (_message.Message,), { + 'DESCRIPTOR' : _ENVIRONMENTCONFIG_RESETARGSENTRY, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentConfig.ResetArgsEntry) + }) + , + 'DESCRIPTOR' : _ENVIRONMENTCONFIG, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.EnvironmentConfig) + }) +_sym_db.RegisterMessage(EnvironmentConfig) +_sym_db.RegisterMessage(EnvironmentConfig.ResetArgsEntry) + +HFHubModel = _reflection.GeneratedProtocolMessageType('HFHubModel', (_message.Message,), { + 'DESCRIPTOR' : _HFHUBMODEL, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.HFHubModel) + }) +_sym_db.RegisterMessage(HFHubModel) + +AgentConfig = _reflection.GeneratedProtocolMessageType('AgentConfig', (_message.Message,), { + 'DESCRIPTOR' : _AGENTCONFIG, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.AgentConfig) + }) +_sym_db.RegisterMessage(AgentConfig) + +TrialConfig = _reflection.GeneratedProtocolMessageType('TrialConfig', (_message.Message,), { + 'DESCRIPTOR' : _TRIALCONFIG, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.TrialConfig) + }) +_sym_db.RegisterMessage(TrialConfig) + +Observation = _reflection.GeneratedProtocolMessageType('Observation', (_message.Message,), { + 'DESCRIPTOR' : _OBSERVATION, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.Observation) + }) +_sym_db.RegisterMessage(Observation) + +PlayerAction = _reflection.GeneratedProtocolMessageType('PlayerAction', (_message.Message,), { + 'DESCRIPTOR' : _PLAYERACTION, + '__module__' : 'data_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.PlayerAction) + }) +_sym_db.RegisterMessage(PlayerAction) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None diff --git a/cogment_lab/generated/ndarray_pb2.py b/cogment_lab/generated/ndarray_pb2.py index 22b3d5c..f85547c 100644 --- a/cogment_lab/generated/ndarray_pb2.py +++ b/cogment_lab/generated/ndarray_pb2.py @@ -1,26 +1,13 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: ndarray.proto """Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -30,8 +17,26 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xcd\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r\x12\x13\n\x0bstring_data\x18\t \x03(\t*\x95\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x12\x10\n\x0c\x44TYPE_STRING\x10\x07\x62\x06proto3') -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals()) +_DTYPE = DESCRIPTOR.enum_types_by_name['DType'] +DType = enum_type_wrapper.EnumTypeWrapper(_DTYPE) +DTYPE_UNKNOWN = 0 +DTYPE_FLOAT32 = 1 +DTYPE_FLOAT64 = 2 +DTYPE_INT8 = 3 +DTYPE_INT32 = 4 +DTYPE_INT64 = 5 +DTYPE_UINT8 = 6 +DTYPE_STRING = 7 + + +_ARRAY = DESCRIPTOR.message_types_by_name['Array'] +Array = _reflection.GeneratedProtocolMessageType('Array', (_message.Message,), { + 'DESCRIPTOR' : _ARRAY, + '__module__' : 'ndarray_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.nd_array.Array) + }) +_sym_db.RegisterMessage(Array) + if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None diff --git a/cogment_lab/generated/spaces_pb2.py b/cogment_lab/generated/spaces_pb2.py index 1b29c05..8735b41 100644 --- a/cogment_lab/generated/spaces_pb2.py +++ b/cogment_lab/generated/spaces_pb2.py @@ -1,26 +1,12 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: spaces.proto """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder - - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -29,10 +15,92 @@ import cogment_lab.generated.ndarray_pb2 as ndarray__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xb3\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12(\n\x04text\x18\x06 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"q\n\x05Tuple\x12\x32\n\x06spaces\x18\x02 \x03(\x0b\x32\".cogment_lab.spaces.Tuple.SubSpace\x1a\x34\n\x08SubSpace\x12(\n\x05space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xdf\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12*\n\x05tuple\x18\x06 \x01(\x0b\x32\x19.cogment_lab.spaces.TupleH\x00\x12(\n\x04text\x18\x07 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3') + + + +_DISCRETE = DESCRIPTOR.message_types_by_name['Discrete'] +_BOX = DESCRIPTOR.message_types_by_name['Box'] +_MULTIBINARY = DESCRIPTOR.message_types_by_name['MultiBinary'] +_MULTIDISCRETE = DESCRIPTOR.message_types_by_name['MultiDiscrete'] +_DICT = DESCRIPTOR.message_types_by_name['Dict'] +_DICT_SUBSPACE = _DICT.nested_types_by_name['SubSpace'] +_TUPLE = DESCRIPTOR.message_types_by_name['Tuple'] +_TUPLE_SUBSPACE = _TUPLE.nested_types_by_name['SubSpace'] +_TEXT = DESCRIPTOR.message_types_by_name['Text'] +_SPACE = DESCRIPTOR.message_types_by_name['Space'] +Discrete = _reflection.GeneratedProtocolMessageType('Discrete', (_message.Message,), { + 'DESCRIPTOR' : _DISCRETE, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Discrete) + }) +_sym_db.RegisterMessage(Discrete) + +Box = _reflection.GeneratedProtocolMessageType('Box', (_message.Message,), { + 'DESCRIPTOR' : _BOX, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Box) + }) +_sym_db.RegisterMessage(Box) + +MultiBinary = _reflection.GeneratedProtocolMessageType('MultiBinary', (_message.Message,), { + 'DESCRIPTOR' : _MULTIBINARY, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.MultiBinary) + }) +_sym_db.RegisterMessage(MultiBinary) + +MultiDiscrete = _reflection.GeneratedProtocolMessageType('MultiDiscrete', (_message.Message,), { + 'DESCRIPTOR' : _MULTIDISCRETE, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.MultiDiscrete) + }) +_sym_db.RegisterMessage(MultiDiscrete) + +Dict = _reflection.GeneratedProtocolMessageType('Dict', (_message.Message,), { + + 'SubSpace' : _reflection.GeneratedProtocolMessageType('SubSpace', (_message.Message,), { + 'DESCRIPTOR' : _DICT_SUBSPACE, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Dict.SubSpace) + }) + , + 'DESCRIPTOR' : _DICT, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Dict) + }) +_sym_db.RegisterMessage(Dict) +_sym_db.RegisterMessage(Dict.SubSpace) + +Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), { + + 'SubSpace' : _reflection.GeneratedProtocolMessageType('SubSpace', (_message.Message,), { + 'DESCRIPTOR' : _TUPLE_SUBSPACE, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Tuple.SubSpace) + }) + , + 'DESCRIPTOR' : _TUPLE, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Tuple) + }) +_sym_db.RegisterMessage(Tuple) +_sym_db.RegisterMessage(Tuple.SubSpace) + +Text = _reflection.GeneratedProtocolMessageType('Text', (_message.Message,), { + 'DESCRIPTOR' : _TEXT, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Text) + }) +_sym_db.RegisterMessage(Text) + +Space = _reflection.GeneratedProtocolMessageType('Space', (_message.Message,), { + 'DESCRIPTOR' : _SPACE, + '__module__' : 'spaces_pb2' + # @@protoc_insertion_point(class_scope:cogment_lab.spaces.Space) + }) +_sym_db.RegisterMessage(Space) -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None @@ -48,8 +116,12 @@ _DICT._serialized_end=420 _DICT_SUBSPACE._serialized_start=355 _DICT_SUBSPACE._serialized_end=420 - _TEXT._serialized_start=422 - _TEXT._serialized_end=485 - _SPACE._serialized_start=488 - _SPACE._serialized_end=795 + _TUPLE._serialized_start=422 + _TUPLE._serialized_end=535 + _TUPLE_SUBSPACE._serialized_start=483 + _TUPLE_SUBSPACE._serialized_end=535 + _TEXT._serialized_start=537 + _TEXT._serialized_end=600 + _SPACE._serialized_start=603 + _SPACE._serialized_end=954 # @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/protos/spaces.proto b/cogment_lab/protos/spaces.proto index 23898fb..f9dbef3 100644 --- a/cogment_lab/protos/spaces.proto +++ b/cogment_lab/protos/spaces.proto @@ -44,6 +44,13 @@ message Dict { repeated SubSpace spaces = 1; } +message Tuple { + message SubSpace { + Space space = 1; + } + repeated SubSpace spaces = 2; +} + message Text { int32 max_length = 1; int32 min_length = 2; @@ -57,6 +64,8 @@ message Space { Dict dict = 3; MultiBinary multi_binary = 4; MultiDiscrete multi_discrete = 5; - Text text = 6; + Tuple tuple = 6; + Text text = 7; } } + diff --git a/cogment_lab/specs/spaces_serialization.py b/cogment_lab/specs/spaces_serialization.py index 170e9a1..d2547a1 100644 --- a/cogment_lab/specs/spaces_serialization.py +++ b/cogment_lab/specs/spaces_serialization.py @@ -22,6 +22,7 @@ from cogment_lab.generated.spaces_pb2 import MultiDiscrete # type: ignore from cogment_lab.generated.spaces_pb2 import Space # type: ignore from cogment_lab.generated.spaces_pb2 import Text # type: ignore +from cogment_lab.generated.spaces_pb2 import Tuple # type: ignore from .ndarray_serialization import ( SerializationFormat, @@ -66,6 +67,12 @@ def serialize_gym_space(space: gym.Space, serialization_format=SerializationForm spaces.append(Dict.SubSpace(key=key, space=serialize_gym_space(gym_sub_space))) return Space(dict=Dict(spaces=spaces)) + if isinstance(space, gym.spaces.Tuple): + spaces = [] + for gym_sub_space in space.spaces: + spaces.append(Tuple.SubSpace(space=serialize_gym_space(gym_sub_space, serialization_format))) + return Space(tuple=Tuple(spaces=spaces)) + if isinstance(space, gym.spaces.Text): return Space(text=Text(max_length=space.max_length, min_length=space.min_length, charset=space.characters)) @@ -101,8 +108,11 @@ def deserialize_space(pb_space: Space) -> gym.Space: spaces = [] for sub_space in dict_space_pb.spaces: spaces.append((sub_space.key, deserialize_space(sub_space.space))) - return gym.spaces.Dict(spaces=spaces) + if space_kind == "tuple": + tuple_space_pb = pb_space.tuple + spaces = [deserialize_space(sub_space) for sub_space in tuple_space_pb.spaces] + return gym.spaces.Tuple(spaces) if space_kind == "text": text_space_pb = pb_space.text return gym.spaces.Text( diff --git a/tests/test_spaces_serialization.py b/tests/test_spaces_serialization.py new file mode 100644 index 0000000..10a2333 --- /dev/null +++ b/tests/test_spaces_serialization.py @@ -0,0 +1,90 @@ +# Copyright 2023 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gymnasium +import numpy as np +import pytest +from gymnasium.spaces import Box, Dict, Discrete, MultiBinary, MultiDiscrete, Tuple + + +from cogment_lab.specs.ndarray_serialization import SerializationFormat +from cogment_lab.specs.spaces_serialization import deserialize_space, serialize_gym_space + +# pylint: disable=no-member + + +def test_serialize_custom_observation_space(): + """Test serialization of gym spaces of type: + Dict, Discrete, Box, MultiDiscrete, MultiBinary. + """ + gym_space = Dict( + { + "ext_controller": MultiDiscrete([5, 2, 2]), + "inner_state": Dict( + { + "charge": Discrete(100), + "system_checks": MultiBinary(10), + "system_checks_seq": MultiBinary([2, 5, 10]), + "system_checks_array": MultiBinary(np.array([2, 5, 10])), + "job_status": Dict( + { + "task": Discrete(5), + "progress": Box(low=0, high=100, shape=()), + } + ), + } + ), + "tuple_state": Tuple([Discrete(i) for i in range(1, 4)]), + } + ) + + pb_space = serialize_gym_space(gym_space, serialization_format=SerializationFormat.STRUCTURED) + + assert len(pb_space.dict.spaces) == 3 + assert pb_space.dict.spaces[0].key == "ext_controller" + assert pb_space.dict.spaces[0].space.multi_discrete.nvec.shape == [3] + + assert pb_space.dict.spaces[1].key == "inner_state" + assert len(pb_space.dict.spaces[1].space.dict.spaces) == 5 + + assert pb_space.dict.spaces[1].space.dict.spaces[0].key == "charge" + assert pb_space.dict.spaces[1].space.dict.spaces[0].space.discrete.n == 100 + + assert pb_space.dict.spaces[1].space.dict.spaces[1].key == "job_status" + assert len(pb_space.dict.spaces[1].space.dict.spaces[1].space.dict.spaces) == 2 + + assert pb_space.dict.spaces[1].space.dict.spaces[1].space.dict.spaces[0].key == "progress" + assert pb_space.dict.spaces[1].space.dict.spaces[1].space.dict.spaces[0].space.box.low.double_data[ + 0 + ] == pytest.approx(0.0) + assert pb_space.dict.spaces[1].space.dict.spaces[1].space.dict.spaces[0].space.box.high.double_data[ + 0 + ] == pytest.approx(100.0) + + assert pb_space.dict.spaces[1].space.dict.spaces[1].space.dict.spaces[1].key == "task" + assert pb_space.dict.spaces[1].space.dict.spaces[1].space.dict.spaces[1].space.discrete.n == 5 + + assert pb_space.dict.spaces[1].space.dict.spaces[2].key == "system_checks" + assert pb_space.dict.spaces[1].space.dict.spaces[2].space.multi_binary.n.shape == [1] + + assert pb_space.dict.spaces[1].space.dict.spaces[3].key == "system_checks_array" + assert pb_space.dict.spaces[1].space.dict.spaces[3].space.multi_binary.n.shape == [3] + + assert pb_space.dict.spaces[1].space.dict.spaces[4].key == "system_checks_seq" + assert pb_space.dict.spaces[1].space.dict.spaces[4].space.multi_binary.n.shape == [3] + + assert pb_space.dict.spaces[2].key == "tuple_state" + assert pb_space.dict.spaces[2].space.tuple.spaces[0].space.discrete.n == 1 + assert pb_space.dict.spaces[2].space.tuple.spaces[1].space.discrete.n == 2 + assert pb_space.dict.spaces[2].space.tuple.spaces[2].space.discrete.n == 3