Skip to content

Commit

Permalink
finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Sep 17, 2024
1 parent d311877 commit 0881bd1
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 9 deletions.
128 changes: 127 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistral_common"
version = "1.4.0"
version = "1.4.1"
description = ""
authors = ["bam4d <[email protected]>"]
readme = "README.md"
Expand Down Expand Up @@ -35,6 +35,7 @@ typing-extensions = "^4.11.0"
tiktoken = "^0.7.0"
pillow = "^10.3.0"
requests = "^2.0.0"
opencv-python-headless = "^4.10.0.84"

[tool.poetry.group.dev.dependencies]
types-jsonschema = "4.21.0.20240118"
Expand Down
2 changes: 1 addition & 1 deletion src/mistral_common/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.4.0"
__version__ = "1.4.1"
11 changes: 5 additions & 6 deletions src/mistral_common/tokens/tokenizers/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from io import BytesIO
from typing import Tuple, Union

import cv2
import numpy as np
from PIL import Image

Expand Down Expand Up @@ -54,22 +55,21 @@ def _convert_to_rgb(image: Image.Image) -> Image.Image:


def normalize(
image: Image.Image,
np_image: np.ndarray,
mean: Tuple[float, float, float],
std: Tuple[float, float, float],
) -> np.ndarray:
"""
Normalize a tensor image with mean and standard deviation.
Args:
image (Image.Image): Image to be normalized.
image (np.ndarray): Image to be normalized.
mean (tuple[float, float, float]): Mean for each channel.
std (tuple[float, float, float]): Standard deviation for each channel.
Returns:
np.ndarray: Normalized image with shape (C, H, W).
"""
np_image = np.array(image, dtype=np.float32)
np_image = np_image / 255.0

assert len(np_image.shape) == 3, f"{np_image.shape=}"
Expand All @@ -81,9 +81,8 @@ def normalize(


def transform_image(image: Image.Image, new_size: Tuple[int, int]) -> np.ndarray:
image = _convert_to_rgb(image)
image = image.resize(new_size, Image.Resampling.BICUBIC)
return normalize(image, DATASET_MEAN, DATASET_STD)
np_image = cv2.resize(np.array(_convert_to_rgb(image), dtype=np.float32), new_size, interpolation=cv2.INTER_CUBIC)
return normalize(np_image, DATASET_MEAN, DATASET_STD)


class ImageEncoder(MultiModalEncoder):
Expand Down
40 changes: 40 additions & 0 deletions tests/test_multimodal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import base64
from io import BytesIO
from typing import Tuple

import numpy as np
import pytest
import requests
from mistral_common.protocol.instruct.messages import (
Expand Down Expand Up @@ -70,6 +72,44 @@ def test_image_encoder(mm_config: MultimodalConfig, special_token_ids: SpecialIm
assert len(tokens) == (w + 1) * h


@pytest.mark.parametrize("size", [(200, 311), (300, 212), (251, 1374), (1475, 477), (1344, 1544), (2133, 3422)])
def test_image_processing(
mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs, size: Tuple[int, int]
) -> None:
mm_config.max_image_size = 1024
mm_encoder = ImageEncoder(mm_config, special_token_ids)

# all images with w,h >= 1024 should be resized to 1024
# else round to nearest multiple of 16
# all while keeping the aspect ratio
EXP_IMG_SIZES = {
(200, 311): (208, 320),
(300, 212): (304, 224),
(251, 1374): (192, 1024),
(1475, 477): (1024, 336),
(1344, 1544): (896, 1024),
(2133, 3422): (640, 1024),
}
# integration test to make sure the img processing stays 100% the same
EXP_IMG_SUM = {
(200, 311): 232038.65023772235,
(300, 212): 182668.98900347573,
(251, 1374): 726925.9371541862,
(1475, 477): 985935.4162606588,
(1344, 1544): 2982953.705365115,
(2133, 3422): 2304438.4010818982,
}

url = f"https://picsum.photos/id/237/{size[0]}/{size[1]}"

content = ImageURLChunk(image_url=url)

image = mm_encoder(content).image

assert image.transpose().shape[:2] == EXP_IMG_SIZES[size], image.transpose().shape[:2]
assert np.abs(image).sum() - EXP_IMG_SUM[size] < 1e-5, np.abs(image).sum()


def test_image_encoder_formats(mm_config: MultimodalConfig, special_token_ids: SpecialImageIDs) -> None:
mm_encoder = ImageEncoder(mm_config, special_token_ids)

Expand Down

0 comments on commit 0881bd1

Please sign in to comment.