Skip to content

Commit

Permalink
Drop the requirement for cv2, test the new functions
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Jan 22, 2024
1 parent 483d7fe commit a895d9e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 30 deletions.
48 changes: 19 additions & 29 deletions cogment_lab/specs/encode_rendered_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,58 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
from __future__ import annotations

import io

import numpy as np
from PIL import Image


MAX_RENDERED_WIDTH = 2048


def encode_rendered_frame(rendered_frame: np.ndarray, max_size: int = MAX_RENDERED_WIDTH) -> bytes:
def encode_rendered_frame(rendered_frame: np.ndarray, max_size: int = MAX_RENDERED_WIDTH, format: str = "PNG") -> bytes:
if max_size <= 0:
max_size = MAX_RENDERED_WIDTH
# gRPC max message size hack
height, width = rendered_frame.shape[:2]

image = Image.fromarray(rendered_frame.astype("uint8"), "RGB")

width, height = image.size
if max(height, width) > max_size:
if height > width:
new_height = max_size
new_width = int(new_height / height * width)
else:
new_width = max_size
new_height = int(height / width * new_width)
rendered_frame = cv2.resize(rendered_frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
image = image.resize((new_width, new_height), Image.ANTIALIAS)

# note rgb -> bgr for cv2
result, encoded_frame = cv2.imencode(".jpg", rendered_frame[:, :, ::-1])
assert result
with io.BytesIO() as output:
image.save(output, format=format)
encoded_frame = output.getvalue()

return encoded_frame.tobytes()
return encoded_frame


def decode_rendered_frame(encoded_frame: bytes) -> np.ndarray:
"""
Decode the rendered frame from bytes to a NumPy array.
assert len(encoded_frame) > 0, "Encoded frame is empty"

Args:
encoded_frame (bytes): The encoded frame as a byte array.
with io.BytesIO(encoded_frame) as input:
image = Image.open(input)
decoded_frame = np.array(image)

Returns:
np.ndarray: The decoded rendered frame as a NumPy array.
"""
if encoded_frame is None or len(encoded_frame) == 0:
return None

# Convert the byte array back to a NumPy array
encoded_frame_np = np.frombuffer(encoded_frame, dtype=np.uint8)

# Decode the image from the byte array
decoded_frame = cv2.imdecode(encoded_frame_np, cv2.IMREAD_COLOR)

# Check if the decoding was successful
if decoded_frame is None:
raise ValueError("Failed to decode the rendered frame.")

# Convert from BGR to RGB
decoded_frame = decoded_frame[:, :, ::-1]

return decoded_frame
2 changes: 1 addition & 1 deletion cogment_lab/specs/observation_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def value(self):
return self._value

def _deserialize_rendered_frame(self):
if not self._pb_observation.rendered_frame != b"":
if self._pb_observation.rendered_frame == b"" or self._pb_observation.rendered_frame is None:
return None
return decode_rendered_frame(self._pb_observation.rendered_frame)

Expand Down
58 changes: 58 additions & 0 deletions tests/test_frame_encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
import PIL
import pytest

from cogment_lab.specs.encode_rendered_frame import (
decode_rendered_frame,
encode_rendered_frame,
)


def create_test_image(width, height):
rng = np.random.default_rng(0)
return rng.integers(0, 255, (height, width, 3), dtype=np.uint8)


def test_encode_valid_input():
test_image = create_test_image(100, 100)
encoded = encode_rendered_frame(test_image)
assert isinstance(encoded, bytes)


def test_encode_invalid_size():
test_image = create_test_image(100, 100)
encoded = encode_rendered_frame(test_image, -1)
decoded = decode_rendered_frame(encoded)
assert max(decoded.shape[:2]) == 100


def test_encode_with_resizing():
test_image = create_test_image(2000, 2000)
max_size = 500
encoded = encode_rendered_frame(test_image, max_size)
decoded = decode_rendered_frame(encoded)

assert max(decoded.shape[:2]) == max_size

original_aspect_ratio = test_image.shape[1] / test_image.shape[0]
decoded_aspect_ratio = decoded.shape[1] / decoded.shape[0]
np.testing.assert_almost_equal(original_aspect_ratio, decoded_aspect_ratio)


def test_decode_valid_input():
test_image = create_test_image(100, 100)
encoded = encode_rendered_frame(test_image)
decoded = decode_rendered_frame(encoded)
assert decoded.shape == test_image.shape


def test_decode_failure():
with pytest.raises(PIL.UnidentifiedImageError):
decode_rendered_frame(b"invalid data")


def test_roundtrip():
test_image = create_test_image(100, 100)
encoded = encode_rendered_frame(test_image)
decoded = decode_rendered_frame(encoded)
np.testing.assert_array_almost_equal(test_image, decoded)

0 comments on commit a895d9e

Please sign in to comment.