diff --git a/cogment_lab/generated/data_pb2.py b/cogment_lab/generated/data_pb2.py index 97419f5..c32e6c7 100644 --- a/cogment_lab/generated/data_pb2.py +++ b/cogment_lab/generated/data_pb2.py @@ -1,18 +1,3 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: data.proto """Generated protocol buffer code.""" @@ -31,34 +16,33 @@ import cogment_lab.generated.spaces_pb2 as spaces__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05"\r\n\x0bTrialConfig"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3' -) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ndata.proto\x12\x0b\x63ogment_lab\x1a\rndarray.proto\x1a\x0cspaces.proto\"\xd7\x01\n\x10\x45nvironmentSpecs\x12\x16\n\x0eimplementation\x18\x01 \x01(\t\x12\x12\n\nturn_based\x18\x02 \x01(\x08\x12\x13\n\x0bnum_players\x18\x03 \x01(\x05\x12\x34\n\x11observation_space\x18\x04 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x05 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12\x1b\n\x13web_components_file\x18\x06 \x01(\t\"s\n\nAgentSpecs\x12\x34\n\x11observation_space\x18\x01 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\x12/\n\x0c\x61\x63tion_space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"Y\n\x05Value\x12\x16\n\x0cstring_value\x18\x01 \x01(\tH\x00\x12\x13\n\tint_value\x18\x02 \x01(\x05H\x00\x12\x15\n\x0b\x66loat_value\x18\x03 \x01(\x02H\x00\x42\x0c\n\nvalue_type\"\xf1\x01\n\x11\x45nvironmentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12\x0e\n\x06render\x18\x02 \x01(\x08\x12\x14\n\x0crender_width\x18\x03 \x01(\x05\x12\x0c\n\x04seed\x18\x04 \x01(\r\x12\x0f\n\x07\x66latten\x18\x05 \x01(\x08\x12\x41\n\nreset_args\x18\x06 \x03(\x0b\x32-.cogment_lab.EnvironmentConfig.ResetArgsEntry\x1a\x44\n\x0eResetArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12!\n\x05value\x18\x02 \x01(\x0b\x32\x12.cogment_lab.Value:\x02\x38\x01\"/\n\nHFHubModel\x12\x0f\n\x07repo_id\x18\x01 \x01(\t\x12\x10\n\x08\x66ilename\x18\x02 \x01(\t\"\xa4\x01\n\x0b\x41gentConfig\x12\x0e\n\x06run_id\x18\x01 \x01(\t\x12,\n\x0b\x61gent_specs\x18\x02 \x01(\x0b\x32\x17.cogment_lab.AgentSpecs\x12\x0c\n\x04seed\x18\x03 \x01(\r\x12\x10\n\x08model_id\x18\x04 \x01(\t\x12\x17\n\x0fmodel_iteration\x18\x05 \x01(\x05\x12\x1e\n\x16model_update_frequency\x18\x06 \x01(\x05\"\r\n\x0bTrialConfig\"\x88\x01\n\x0bObservation\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12\x0e\n\x06\x61\x63tive\x18\x02 \x01(\x08\x12\r\n\x05\x61live\x18\x03 \x01(\x08\x12\x1b\n\x0erendered_frame\x18\x04 \x01(\x0cH\x00\x88\x01\x01\x42\x11\n\x0f_rendered_frame\":\n\x0cPlayerAction\x12*\n\x05value\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Arrayb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "data_pb2", globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'data_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _ENVIRONMENTCONFIG_RESETARGSENTRY._options = None - _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b"8\001" - _ENVIRONMENTSPECS._serialized_start = 57 - _ENVIRONMENTSPECS._serialized_end = 272 - _AGENTSPECS._serialized_start = 274 - _AGENTSPECS._serialized_end = 389 - _VALUE._serialized_start = 391 - _VALUE._serialized_end = 480 - _ENVIRONMENTCONFIG._serialized_start = 483 - _ENVIRONMENTCONFIG._serialized_end = 724 - _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start = 656 - _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end = 724 - _HFHUBMODEL._serialized_start = 726 - _HFHUBMODEL._serialized_end = 773 - _AGENTCONFIG._serialized_start = 776 - _AGENTCONFIG._serialized_end = 940 - _TRIALCONFIG._serialized_start = 942 - _TRIALCONFIG._serialized_end = 955 - _OBSERVATION._serialized_start = 958 - _OBSERVATION._serialized_end = 1094 - _PLAYERACTION._serialized_start = 1096 - _PLAYERACTION._serialized_end = 1154 + + DESCRIPTOR._options = None + _ENVIRONMENTCONFIG_RESETARGSENTRY._options = None + _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_options = b'8\001' + _ENVIRONMENTSPECS._serialized_start=57 + _ENVIRONMENTSPECS._serialized_end=272 + _AGENTSPECS._serialized_start=274 + _AGENTSPECS._serialized_end=389 + _VALUE._serialized_start=391 + _VALUE._serialized_end=480 + _ENVIRONMENTCONFIG._serialized_start=483 + _ENVIRONMENTCONFIG._serialized_end=724 + _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_start=656 + _ENVIRONMENTCONFIG_RESETARGSENTRY._serialized_end=724 + _HFHUBMODEL._serialized_start=726 + _HFHUBMODEL._serialized_end=773 + _AGENTCONFIG._serialized_start=776 + _AGENTCONFIG._serialized_end=940 + _TRIALCONFIG._serialized_start=942 + _TRIALCONFIG._serialized_end=955 + _OBSERVATION._serialized_start=958 + _OBSERVATION._serialized_end=1094 + _PLAYERACTION._serialized_start=1096 + _PLAYERACTION._serialized_end=1154 # @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/generated/ndarray_pb2.py b/cogment_lab/generated/ndarray_pb2.py index 56513b7..84cba8f 100644 --- a/cogment_lab/generated/ndarray_pb2.py +++ b/cogment_lab/generated/ndarray_pb2.py @@ -1,18 +1,3 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: ndarray.proto """Generated protocol buffer code.""" @@ -27,16 +12,17 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array"\xb8\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r*\x83\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x62\x06proto3' -) + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rndarray.proto\x12\x14\x63ogment_lab.nd_array\"\xcd\x01\n\x05\x41rray\x12*\n\x05\x64type\x18\x01 \x01(\x0e\x32\x1b.cogment_lab.nd_array.DType\x12\r\n\x05shape\x18\x02 \x03(\r\x12\x10\n\x08raw_data\x18\x03 \x01(\x0c\x12\x10\n\x08npy_data\x18\x04 \x01(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x05 \x03(\x01\x12\x12\n\nint32_data\x18\x06 \x03(\x11\x12\x12\n\nint64_data\x18\x07 \x03(\x12\x12\x13\n\x0buint32_data\x18\x08 \x03(\r\x12\x13\n\x0bstring_data\x18\t \x03(\t*\x95\x01\n\x05\x44Type\x12\x11\n\rDTYPE_UNKNOWN\x10\x00\x12\x11\n\rDTYPE_FLOAT32\x10\x01\x12\x11\n\rDTYPE_FLOAT64\x10\x02\x12\x0e\n\nDTYPE_INT8\x10\x03\x12\x0f\n\x0b\x44TYPE_INT32\x10\x04\x12\x0f\n\x0b\x44TYPE_INT64\x10\x05\x12\x0f\n\x0b\x44TYPE_UINT8\x10\x06\x12\x10\n\x0c\x44TYPE_STRING\x10\x07\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "ndarray_pb2", globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'ndarray_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _DTYPE._serialized_start = 227 - _DTYPE._serialized_end = 358 - _ARRAY._serialized_start = 40 - _ARRAY._serialized_end = 224 + + DESCRIPTOR._options = None + _DTYPE._serialized_start=248 + _DTYPE._serialized_end=397 + _ARRAY._serialized_start=40 + _ARRAY._serialized_end=245 # @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/generated/spaces_pb2.py b/cogment_lab/generated/spaces_pb2.py index 3088fa8..335ebf3 100644 --- a/cogment_lab/generated/spaces_pb2.py +++ b/cogment_lab/generated/spaces_pb2.py @@ -1,18 +1,3 @@ -# Copyright 2024 AI Redefined Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: spaces.proto """Generated protocol buffer code.""" @@ -30,26 +15,27 @@ import cogment_lab.generated.ndarray_pb2 as ndarray__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space"\x89\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x42\x06\n\x04kindb\x06proto3' -) +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cspaces.proto\x12\x12\x63ogment_lab.spaces\x1a\rndarray.proto\"$\n\x08\x44iscrete\x12\t\n\x01n\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x05\"Z\n\x03\x42ox\x12(\n\x03low\x18\x02 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\x12)\n\x04high\x18\x03 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"5\n\x0bMultiBinary\x12&\n\x01n\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\":\n\rMultiDiscrete\x12)\n\x04nvec\x18\x01 \x01(\x0b\x32\x1b.cogment_lab.nd_array.Array\"|\n\x04\x44ict\x12\x31\n\x06spaces\x18\x01 \x03(\x0b\x32!.cogment_lab.spaces.Dict.SubSpace\x1a\x41\n\x08SubSpace\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05space\x18\x02 \x01(\x0b\x32\x19.cogment_lab.spaces.Space\"?\n\x04Text\x12\x12\n\nmax_length\x18\x01 \x01(\x05\x12\x12\n\nmin_length\x18\x02 \x01(\x05\x12\x0f\n\x07\x63harset\x18\x03 \x01(\t\"\xb3\x02\n\x05Space\x12\x30\n\x08\x64iscrete\x18\x01 \x01(\x0b\x32\x1c.cogment_lab.spaces.DiscreteH\x00\x12&\n\x03\x62ox\x18\x02 \x01(\x0b\x32\x17.cogment_lab.spaces.BoxH\x00\x12(\n\x04\x64ict\x18\x03 \x01(\x0b\x32\x18.cogment_lab.spaces.DictH\x00\x12\x37\n\x0cmulti_binary\x18\x04 \x01(\x0b\x32\x1f.cogment_lab.spaces.MultiBinaryH\x00\x12;\n\x0emulti_discrete\x18\x05 \x01(\x0b\x32!.cogment_lab.spaces.MultiDiscreteH\x00\x12(\n\x04text\x18\x06 \x01(\x0b\x32\x18.cogment_lab.spaces.TextH\x00\x42\x06\n\x04kindb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "spaces_pb2", globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'spaces_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _DISCRETE._serialized_start = 51 - _DISCRETE._serialized_end = 87 - _BOX._serialized_start = 89 - _BOX._serialized_end = 179 - _MULTIBINARY._serialized_start = 181 - _MULTIBINARY._serialized_end = 234 - _MULTIDISCRETE._serialized_start = 236 - _MULTIDISCRETE._serialized_end = 294 - _DICT._serialized_start = 296 - _DICT._serialized_end = 420 - _DICT_SUBSPACE._serialized_start = 355 - _DICT_SUBSPACE._serialized_end = 420 - _SPACE._serialized_start = 423 - _SPACE._serialized_end = 688 + + DESCRIPTOR._options = None + _DISCRETE._serialized_start=51 + _DISCRETE._serialized_end=87 + _BOX._serialized_start=89 + _BOX._serialized_end=179 + _MULTIBINARY._serialized_start=181 + _MULTIBINARY._serialized_end=234 + _MULTIDISCRETE._serialized_start=236 + _MULTIDISCRETE._serialized_end=294 + _DICT._serialized_start=296 + _DICT._serialized_end=420 + _DICT_SUBSPACE._serialized_start=355 + _DICT_SUBSPACE._serialized_end=420 + _TEXT._serialized_start=422 + _TEXT._serialized_end=485 + _SPACE._serialized_start=488 + _SPACE._serialized_end=795 # @@protoc_insertion_point(module_scope) diff --git a/cogment_lab/protos/ndarray.proto b/cogment_lab/protos/ndarray.proto index 24dc3a9..7fefc4f 100644 --- a/cogment_lab/protos/ndarray.proto +++ b/cogment_lab/protos/ndarray.proto @@ -24,6 +24,7 @@ enum DType { DTYPE_INT32 = 4; DTYPE_INT64 = 5; DTYPE_UINT8 = 6; + DTYPE_STRING = 7; } message Array { @@ -35,4 +36,5 @@ message Array { repeated sint32 int32_data = 6; repeated sint64 int64_data = 7; repeated uint32 uint32_data = 8; + repeated string string_data = 9; } diff --git a/cogment_lab/protos/spaces.proto b/cogment_lab/protos/spaces.proto index 04d4686..23898fb 100644 --- a/cogment_lab/protos/spaces.proto +++ b/cogment_lab/protos/spaces.proto @@ -44,6 +44,12 @@ message Dict { repeated SubSpace spaces = 1; } +message Text { + int32 max_length = 1; + int32 min_length = 2; + string charset = 3; +} + message Space { oneof kind { Discrete discrete = 1; @@ -51,5 +57,6 @@ message Space { Dict dict = 3; MultiBinary multi_binary = 4; MultiDiscrete multi_discrete = 5; + Text text = 6; } } diff --git a/cogment_lab/specs/ndarray_serialization.py b/cogment_lab/specs/ndarray_serialization.py index a6f9fb9..068ba1c 100644 --- a/cogment_lab/specs/ndarray_serialization.py +++ b/cogment_lab/specs/ndarray_serialization.py @@ -24,6 +24,7 @@ from cogment_lab.generated.ndarray_pb2 import DTYPE_INT8 # type: ignore from cogment_lab.generated.ndarray_pb2 import DTYPE_INT32 # type: ignore from cogment_lab.generated.ndarray_pb2 import DTYPE_INT64 # type: ignore +from cogment_lab.generated.ndarray_pb2 import DTYPE_STRING # type: ignore from cogment_lab.generated.ndarray_pb2 import DTYPE_UINT8 # type: ignore from cogment_lab.generated.ndarray_pb2 import DTYPE_UNKNOWN # type: ignore from cogment_lab.generated.ndarray_pb2 import Array # type: ignore @@ -45,6 +46,7 @@ DTYPE_INT32: np.dtype("int32"), DTYPE_INT64: np.dtype("int64"), DTYPE_UINT8: np.dtype("uint8"), + DTYPE_STRING: np.dtype("str"), } DOUBLE_DTYPES = frozenset(["float32", "float64"]) @@ -64,7 +66,7 @@ def serialize_ndarray( serialization_format: SerializationFormat = SerializationFormat.RAW, ) -> Array: str_dtype = str(nd_array.dtype) - pb_dtype = PB_DTYPE_FROM_DTYPE.get(str_dtype, DTYPE_UNKNOWN) + pb_dtype = DTYPE_STRING if "U" in str_dtype else PB_DTYPE_FROM_DTYPE.get(str_dtype, DTYPE_UNKNOWN) # SerializationFormat.RAW if serialization_format is SerializationFormat.RAW: @@ -109,6 +111,12 @@ def serialize_ndarray( dtype=pb_dtype, int64_data=nd_array.ravel(order="C").tolist(), ) + if "U" in str_dtype: + return Array( + shape=nd_array.shape, + dtype=pb_dtype, + string_data=nd_array.ravel(order="C").tolist(), + ) raise RuntimeError( f"[{str_dtype}] is not a supported numpy dtype for serialization format [{serialization_format}]" @@ -136,6 +144,8 @@ def deserialize_ndarray(pb_array: Array) -> np.ndarray | None: return np.array(pb_array.int32_data, dtype=dtype).reshape(shape, order="C") if str_dtype in INT64_DTYPES: return np.array(pb_array.int64_data, dtype=dtype).reshape(shape, order="C") + if "U" in str_dtype: + return np.array(pb_array.string_data).reshape(shape, order="C") return None # raise RuntimeError( diff --git a/cogment_lab/specs/spaces_serialization.py b/cogment_lab/specs/spaces_serialization.py index 500ff39..170e9a1 100644 --- a/cogment_lab/specs/spaces_serialization.py +++ b/cogment_lab/specs/spaces_serialization.py @@ -21,6 +21,7 @@ from cogment_lab.generated.spaces_pb2 import MultiBinary # type: ignore from cogment_lab.generated.spaces_pb2 import MultiDiscrete # type: ignore from cogment_lab.generated.spaces_pb2 import Space # type: ignore +from cogment_lab.generated.spaces_pb2 import Text # type: ignore from .ndarray_serialization import ( SerializationFormat, @@ -64,6 +65,10 @@ def serialize_gym_space(space: gym.Space, serialization_format=SerializationForm for key, gym_sub_space in space.spaces.items(): spaces.append(Dict.SubSpace(key=key, space=serialize_gym_space(gym_sub_space))) return Space(dict=Dict(spaces=spaces)) + + if isinstance(space, gym.spaces.Text): + return Space(text=Text(max_length=space.max_length, min_length=space.min_length, charset=space.characters)) + raise RuntimeError(f"[{type(space)}] is not a supported space type") @@ -98,5 +103,12 @@ def deserialize_space(pb_space: Space) -> gym.Space: spaces.append((sub_space.key, deserialize_space(sub_space.space))) return gym.spaces.Dict(spaces=spaces) + if space_kind == "text": + text_space_pb = pb_space.text + return gym.spaces.Text( + max_length=text_space_pb.max_length, + min_length=text_space_pb.min_length, + charset=text_space_pb.charset, + ) raise RuntimeError(f"[{space_kind}] is not a supported space kind") diff --git a/cogment_lab/utils/coltra_utils.py b/cogment_lab/utils/coltra_utils.py index 6a8337a..7ccdcaf 100644 --- a/cogment_lab/utils/coltra_utils.py +++ b/cogment_lab/utils/coltra_utils.py @@ -14,7 +14,7 @@ import numpy as np import torch -from coltra import Agent +from coltra import Agent, DAgent from coltra.buffers import Action, Observation, OnPolicyRecord from cogment_lab.utils.trial_utils import TrialData @@ -36,6 +36,8 @@ def convert_trial_data_to_coltra(trial_data: TrialData, agent: Agent) -> OnPolic done = trial_data.done # state = None # Assuming 'state' is not provided in TrialData + is_discrete = isinstance(agent, DAgent) + # last_value = agent.act(Observation(vector=trial_data.last_observation), get_value=True)[2]["value"] # value = agent.act(Observation(vector=trial_data.observations), get_value=True)[2]["value"] @@ -49,10 +51,12 @@ def convert_trial_data_to_coltra(trial_data: TrialData, agent: Agent) -> OnPolic if obs is None or action is None or reward is None or done is None: raise ValueError("Missing required fields in TrialData for conversion") + action_args = {"discrete": action} if is_discrete else {"continuous": action} + # Create an OnPolicyRecord instance with the mapped fields on_policy_record = OnPolicyRecord( obs=Observation(vector=obs).tensor(), # type: ignore - action=Action(discrete=action).tensor(), # type: ignore + action=Action(**action_args).tensor(), # type: ignore reward=torch.tensor(reward.astype(np.float32)), value=torch.tensor(value.astype(np.float32)), done=torch.tensor(done.astype(np.float32)), diff --git a/cogment_lab/utils/import_class.py b/cogment_lab/utils/import_class.py index 4cb4ddf..20e0a94 100644 --- a/cogment_lab/utils/import_class.py +++ b/cogment_lab/utils/import_class.py @@ -16,14 +16,20 @@ def import_object(class_name: str): - """Imports an object from a module based on a string + """ + Imports an object from a module based on a string. + If the argument is a string without any dots, it's taken from the current namespace. Args: class_name (str): The full path to the object e.g. "package.module.Class" + or the name of the object in the current namespace. Returns: object: The imported object """ - module_path, class_name = class_name.rsplit(".", 1) - module = import_module(module_path) - return getattr(module, class_name) + if "." not in class_name: + return globals()[class_name] + else: + module_path, class_name = class_name.rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) diff --git a/cogment_lab/utils/trial_utils.py b/cogment_lab/utils/trial_utils.py index 149a87a..f42cdb9 100644 --- a/cogment_lab/utils/trial_utils.py +++ b/cogment_lab/utils/trial_utils.py @@ -106,6 +106,8 @@ def initialize_buffer(space: gym.Space | None, length: int) -> np.ndarray | dict return {key: np.empty((length,) + space[key].shape, dtype=space[key].dtype) for key in space.spaces.keys()} # type: ignore elif isinstance(space, gym.spaces.Tuple): return {i: np.empty((length,) + space[i].shape, dtype=space[i].dtype) for i in range(len(space.spaces))} # type: ignore + elif isinstance(space, gym.spaces.Text): + return np.empty((length,), dtype=" tuple[str, dict]: + text = "".join(self.np_random.choice(self.observation_space.character_list, size=10)) + self.text = text + return text, {} + + def step(self, action: str): + reward = float(action == self.text) + truncated = False + terminated = action == self.text + return self.text, reward, terminated, truncated, {} + + +class NoisyEchoActor(CogmentActor): + def __init__(self, noise_prob: float = 0.1): + super().__init__(noise_prob=noise_prob) + self.noise_prob = noise_prob + self.rng = np.random.default_rng(0) + + async def act(self, observation: Observation, rendered_frame: np.ndarray | None = None) -> Action: + if self.rng.random() < self.noise_prob: + return "nope" + return observation + + +gym.register("Echo-v0", entry_point="tests.shared:EchoEnv", max_episode_steps=10) diff --git a/tests/test_text_spaces.py b/tests/test_text_spaces.py new file mode 100644 index 0000000..3330fe2 --- /dev/null +++ b/tests/test_text_spaces.py @@ -0,0 +1,102 @@ +# Copyright 2024 AI Redefined Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gymnasium as gym +import numpy as np +import pytest + +import tests.shared +from cogment_lab.envs.gymnasium import GymEnvironment +from cogment_lab.process_manager import Cogment +from cogment_lab.specs.ndarray_serialization import ( + SerializationFormat, + deserialize_ndarray, + serialize_ndarray, +) + + +@pytest.mark.parametrize( + "data", + [ + "foo", + "bar", + "hello world", + ["foo", "bar"], + ["hello", "world"], + ["one", "two", "three"], + [["11", "12", "13"], ["21", "22", "23"], ["31", "32", "33"]], + ], +) +def test_serialize_ndarray(data): + arr = np.array(data) + serialized = serialize_ndarray(arr, SerializationFormat.STRUCTURED) + deserialized = deserialize_ndarray(serialized) + assert np.array_equal(arr, deserialized) + + +def test_text_env(): + env = gym.make("Echo-v0") + obs, info = env.reset(seed=0) + assert isinstance(obs, str) + + obs, reward, terminated, truncated, info = env.step("not the same") + + assert isinstance(obs, str) + assert reward == 0.0 + assert terminated is False + assert truncated is False + + obs, reward, terminated, truncated, info = env.step(obs) + + assert isinstance(obs, str) + assert reward == 1.0 + assert terminated is True + assert truncated is False + + +@pytest.mark.asyncio +async def test_echo(): + """Test the echo environment.""" + + cog = Cogment(log_dir="logs") + + cenv = GymEnvironment(env_id="tests.shared:Echo-v0", render=False) + + await cog.run_env(env=cenv, env_name="echo_env", port=9011, log_file="echo_env.log") + + actor = tests.shared.NoisyEchoActor(noise_prob=0.8) + + await cog.run_actor( + actor=actor, + actor_name="echo_actor", + port=9021, + log_file="actor-echo.log", + ) + + trial_id = await cog.start_trial( + env_name="echo_env", + session_config={"render": False}, + actor_impls={ + "gym": "echo_actor", + }, + ) + + data = await cog.get_trial_data(trial_id=trial_id) + + assert isinstance(data, dict) + assert isinstance(data["gym"].observations, np.ndarray) + assert isinstance(data["gym"].actions, np.ndarray) + assert isinstance(data["gym"].rewards, np.ndarray) + + await cog.cleanup()