diff --git a/web/api/tests/tests.py b/web/api/tests/tests.py index 14f39d7..4120637 100644 --- a/web/api/tests/tests.py +++ b/web/api/tests/tests.py @@ -6,15 +6,8 @@ from accounts.models import CustomUser from api.middleware import add_slash from api.models import Collection, Document, Page, PageEmbedding -from api.views import ( - Bearer, - QueryFilter, - QueryIn, - filter_collections, - filter_documents, - filter_query, - router, -) +from api.views import (Bearer, QueryFilter, QueryIn, filter_collections, + filter_documents, filter_query, router) from django.core.exceptions import ValidationError as DjangoValidationError from django.core.files.uploadedfile import SimpleUploadedFile from django.test import override_settings @@ -1684,6 +1677,72 @@ async def test_create_embedding(async_client, user): assert response.json()["data"] != [] +async def test_create_embedding_invalid_input(async_client, user): + task = "image" + input_data = ["/Users/user/Desktop/image.png"] + response = await async_client.post( + "/embeddings/", + json={"task": task, "input_data": input_data}, + headers={"Authorization": f"Bearer {user.token}"}, + ) + + assert response.status_code == 422 + assert response.json() == { + "detail": [ + { + "type": "value_error", + "loc": ["body", "payload"], + "msg": "Value error, Each input must be a valid base64 string or a URL. Please use our Python SDK if you want to provide a file path.", + "ctx": { + "error": "Each input must be a valid base64 string or a URL. Please use our Python SDK if you want to provide a file path." + }, + } + ] + } + + +async def test_create_embedding_valid_url_service_down(async_client, user): + task = "image" + input_data = ["https://tourism.gov.in/sites/default/files/2019-04/dummy-pdf_2.pdf"] + EMBEDDINGS_POST_PATH = "api.models.aiohttp.ClientSession.post" + # Create a mock response object with status 500 + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.json.return_value = AsyncMock(return_value={"error": "Service Down"}) + # Mock the context manager __aenter__ to return the mock_response + mock_response.__aenter__.return_value = mock_response + # Patch the aiohttp.ClientSession.post method to return the mock_response + with patch(EMBEDDINGS_POST_PATH, return_value=mock_response): + response = await async_client.post( + "/embeddings/", + json={"task": task, "input_data": input_data}, + headers={"Authorization": f"Bearer {user.token}"}, + ) + assert response.status_code == 503 + + +async def test_create_embedding_valid_base64_service_down(async_client, user): + task = "image" + input_data = [ + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" + ] + EMBEDDINGS_POST_PATH = "api.models.aiohttp.ClientSession.post" + # Create a mock response object with status 500 + mock_response = AsyncMock() + mock_response.status = 500 + mock_response.json.return_value = AsyncMock(return_value={"error": "Service Down"}) + # Mock the context manager __aenter__ to return the mock_response + mock_response.__aenter__.return_value = mock_response + # Patch the aiohttp.ClientSession.post method to return the mock_response + with patch(EMBEDDINGS_POST_PATH, return_value=mock_response): + response = await async_client.post( + "/embeddings/", + json={"task": task, "input_data": input_data}, + headers={"Authorization": f"Bearer {user.token}"}, + ) + assert response.status_code == 503 + + async def test_create_embedding_service_down(async_client, user): task = "query" input_data = ["What is 1 + 1"] diff --git a/web/api/views.py b/web/api/views.py index 6e03191..43acfc4 100644 --- a/web/api/views.py +++ b/web/api/views.py @@ -1,8 +1,10 @@ import asyncio import base64 import logging +import re from enum import Enum from typing import Dict, List, Optional, Tuple, Union +from urllib.parse import urlparse import aiohttp from accounts.models import CustomUser @@ -18,7 +20,8 @@ from ninja.security import HttpBearer from pgvector.utils import HalfVector from pydantic import Field, model_validator -from svix.api import ApplicationIn, EndpointIn, EndpointUpdate, MessageIn, SvixAsync +from svix.api import (ApplicationIn, EndpointIn, EndpointUpdate, MessageIn, + SvixAsync) from typing_extensions import Self from .models import Collection, Document, MaxSim, Page @@ -1266,6 +1269,24 @@ class EmbeddingsIn(Schema): input_data: List[str] task: TaskEnum + @model_validator(mode="after") + def validate_input_data(self) -> Self: + if self.task == TaskEnum.image: + for value in self.input_data: + # Validate base64 + base64_pattern = r"^[A-Za-z0-9+/]+={0,2}$" + is_base64 = re.match(base64_pattern, value) and len(value) % 4 == 0 + + # Validate URL + parsed = urlparse(value) + is_url = all([parsed.scheme, parsed.netloc]) + + if not (is_base64 or is_url): + raise ValueError( + "Each input must be a valid base64 string or a URL. Please use our Python SDK if you want to provide a file path." + ) + return self + class EmbeddingsOut(Schema): _object: str