Skip to content

Commit

Permalink
Recompile protos, modify code to support text spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Mar 19, 2024
1 parent 35c1931 commit 0b66609
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 107 deletions.
68 changes: 26 additions & 42 deletions cogment_lab/generated/data_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 10 additions & 24 deletions cogment_lab/generated/ndarray_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 20 additions & 34 deletions cogment_lab/generated/spaces_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions cogment_lab/protos/ndarray.proto
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ enum DType {
DTYPE_INT32 = 4;
DTYPE_INT64 = 5;
DTYPE_UINT8 = 6;
DTYPE_STRING = 7;
}

message Array {
Expand All @@ -35,4 +36,5 @@ message Array {
repeated sint32 int32_data = 6;
repeated sint64 int64_data = 7;
repeated uint32 uint32_data = 8;
repeated string string_data = 9;
}
7 changes: 7 additions & 0 deletions cogment_lab/protos/spaces.proto
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,19 @@ message Dict {
repeated SubSpace spaces = 1;
}

message Text {
int32 max_length = 1;
int32 min_length = 2;
string charset = 3;
}

message Space {
oneof kind {
Discrete discrete = 1;
Box box = 2;
Dict dict = 3;
MultiBinary multi_binary = 4;
MultiDiscrete multi_discrete = 5;
Text text = 6;
}
}
12 changes: 11 additions & 1 deletion cogment_lab/specs/ndarray_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cogment_lab.generated.ndarray_pb2 import DTYPE_INT8 # type: ignore
from cogment_lab.generated.ndarray_pb2 import DTYPE_INT32 # type: ignore
from cogment_lab.generated.ndarray_pb2 import DTYPE_INT64 # type: ignore
from cogment_lab.generated.ndarray_pb2 import DTYPE_STRING # type: ignore
from cogment_lab.generated.ndarray_pb2 import DTYPE_UINT8 # type: ignore
from cogment_lab.generated.ndarray_pb2 import DTYPE_UNKNOWN # type: ignore
from cogment_lab.generated.ndarray_pb2 import Array # type: ignore
Expand All @@ -45,6 +46,7 @@
DTYPE_INT32: np.dtype("int32"),
DTYPE_INT64: np.dtype("int64"),
DTYPE_UINT8: np.dtype("uint8"),
DTYPE_STRING: np.dtype("str"),
}

DOUBLE_DTYPES = frozenset(["float32", "float64"])
Expand All @@ -64,7 +66,7 @@ def serialize_ndarray(
serialization_format: SerializationFormat = SerializationFormat.RAW,
) -> Array:
str_dtype = str(nd_array.dtype)
pb_dtype = PB_DTYPE_FROM_DTYPE.get(str_dtype, DTYPE_UNKNOWN)
pb_dtype = DTYPE_STRING if "U" in str_dtype else PB_DTYPE_FROM_DTYPE.get(str_dtype, DTYPE_UNKNOWN)

# SerializationFormat.RAW
if serialization_format is SerializationFormat.RAW:
Expand Down Expand Up @@ -109,6 +111,12 @@ def serialize_ndarray(
dtype=pb_dtype,
int64_data=nd_array.ravel(order="C").tolist(),
)
if "U" in str_dtype:
return Array(
shape=nd_array.shape,
dtype=pb_dtype,
string_data=nd_array.ravel(order="C").tolist(),
)

raise RuntimeError(
f"[{str_dtype}] is not a supported numpy dtype for serialization format [{serialization_format}]"
Expand Down Expand Up @@ -136,6 +144,8 @@ def deserialize_ndarray(pb_array: Array) -> np.ndarray | None:
return np.array(pb_array.int32_data, dtype=dtype).reshape(shape, order="C")
if str_dtype in INT64_DTYPES:
return np.array(pb_array.int64_data, dtype=dtype).reshape(shape, order="C")
if "U" in str_dtype:
return np.array(pb_array.string_data).reshape(shape, order="C")

return None
# raise RuntimeError(
Expand Down
12 changes: 12 additions & 0 deletions cogment_lab/specs/spaces_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from cogment_lab.generated.spaces_pb2 import MultiBinary # type: ignore
from cogment_lab.generated.spaces_pb2 import MultiDiscrete # type: ignore
from cogment_lab.generated.spaces_pb2 import Space # type: ignore
from cogment_lab.generated.spaces_pb2 import Text # type: ignore

from .ndarray_serialization import (
SerializationFormat,
Expand Down Expand Up @@ -64,6 +65,10 @@ def serialize_gym_space(space: gym.Space, serialization_format=SerializationForm
for key, gym_sub_space in space.spaces.items():
spaces.append(Dict.SubSpace(key=key, space=serialize_gym_space(gym_sub_space)))
return Space(dict=Dict(spaces=spaces))

if isinstance(space, gym.spaces.Text):
return Space(text=Text(max_length=space.max_length, min_length=space.min_length, charset=space.characters))

raise RuntimeError(f"[{type(space)}] is not a supported space type")


Expand Down Expand Up @@ -98,5 +103,12 @@ def deserialize_space(pb_space: Space) -> gym.Space:
spaces.append((sub_space.key, deserialize_space(sub_space.space)))

return gym.spaces.Dict(spaces=spaces)
if space_kind == "text":
text_space_pb = pb_space.text
return gym.spaces.Text(
max_length=text_space_pb.max_length,
min_length=text_space_pb.min_length,
charset=text_space_pb.charset,
)

raise RuntimeError(f"[{space_kind}] is not a supported space kind")
Loading

0 comments on commit 0b66609

Please sign in to comment.