diff --git a/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx b/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx
index d84e7625..ab35ae60 100644
--- a/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx
+++ b/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx
@@ -1,7 +1,8 @@
+import PageErrorMessage from '@/components/error/page-error-message'
import PageLayout from '@/components/layout/page-layout'
-import QueryError from '@/components/query/error'
import LoadingQuery from '@/components/query/loading'
import QueryWorkspace from '@/components/query/workspace'
+import { ContentBox } from '@/components/ui/content-box'
import useQueries from '@/hooks/api/query/useQueries'
import { useQuery } from '@/hooks/api/query/useQuery'
import useQueryExecution from '@/hooks/api/query/useQueryExecution'
@@ -73,9 +74,13 @@ const QueryPage: FC = () => {
pageContent =
} else if (error) {
pageContent = (
-
-
-
+
+
+
)
} else if (query)
pageContent = (
diff --git a/apps/ai/clients/admin-console/src/pages/queries/index.tsx b/apps/ai/clients/admin-console/src/pages/queries/index.tsx
index b68bebd6..8b0dddbe 100644
--- a/apps/ai/clients/admin-console/src/pages/queries/index.tsx
+++ b/apps/ai/clients/admin-console/src/pages/queries/index.tsx
@@ -1,8 +1,8 @@
import { DataTable } from '@/components/data-table'
import { LoadingTable } from '@/components/data-table/loading-table'
+import PageErrorMessage from '@/components/error/page-error-message'
import PageLayout from '@/components/layout/page-layout'
import { getColumns } from '@/components/queries/columns'
-import QueriesError from '@/components/queries/error'
import { ContentBox } from '@/components/ui/content-box'
import { useAppContext } from '@/contexts/app-context'
import useQueries from '@/hooks/api/query/useQueries'
@@ -51,8 +51,13 @@ const QueriesPage: FC = () => {
let pageContent: JSX.Element = <>>
- if (!isLoadingFirst && error) {
- pageContent =
+ if (error) {
+ pageContent = (
+
+ )
} else if (isLoadingFirst) {
pageContent =
} else
diff --git a/apps/ai/clients/slack/handlers/message/index.js b/apps/ai/clients/slack/handlers/message/index.js
index 53fe8e85..97ba132a 100644
--- a/apps/ai/clients/slack/handlers/message/index.js
+++ b/apps/ai/clients/slack/handlers/message/index.js
@@ -40,7 +40,7 @@ async function handleMessage(context, say) {
workspace_id: teamId,
channel_id: channel_id,
thread_ts: thread_ts,
- }
+ },
}
log('Fetching data from', endpointUrl)
log('Request payload:', payload)
@@ -54,15 +54,11 @@ async function handleMessage(context, say) {
})
if (!response.ok) {
try {
- const { prompt_id, display_id, error_message } =
- await response.json()
+ const { error_code, message, detail } = await response.json()
error(
- `API Response not ok: status code ${response.status}, ${response.statusText}, error message: ${error_message}, query id: ${prompt_id}`
+ `API Response not ok: status code ${response.status}, ${response.statusText}, error code: ${error_code}, error message: ${message}, detail: ${detail}`
)
- const responseMessage =
- prompt_id == undefined || display_id == undefined
- ? `:exclamation: Sorry, something went wrong when I was processing your request. Please try again later.`
- : `:warning: Sorry, something went wrong while generating response for query ${display_id}. We'll get back to you once it's been reviewed by the data-team admins.`
+ const responseMessage = `:warning: Sorry, something went wrong while generating response. Error message: \`${message}\``
await say({
blocks: [
{
diff --git a/apps/ai/server/app.py b/apps/ai/server/app.py
index 7bc432b2..412523bf 100644
--- a/apps/ai/server/app.py
+++ b/apps/ai/server/app.py
@@ -5,6 +5,9 @@
from fastapi.middleware.cors import CORSMiddleware
from config import settings
+from exceptions.exception_handlers import exception_handler
+from exceptions.exceptions import BaseError
+from middleware.error import UnknownErrorMiddleware
from modules.auth import controller as auth_controller
from modules.db_connection import controller as db_connection_controller
from modules.finetuning import controller as finetuning_controller
@@ -21,7 +24,6 @@
from modules.organization.invoice import controller as invoice_controller
from modules.table_description import controller as table_description_controller
from modules.user import controller as user_controller
-from utils.exception import GenerationEngineError, query_engine_exception_handler
tags_metadata = [
{"name": "Authentication", "description": "Login endpoints for authentication"},
@@ -46,6 +48,7 @@
app = FastAPI()
+app.add_middleware(UnknownErrorMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -54,8 +57,7 @@
allow_headers=["*"],
)
-app.add_exception_handler(GenerationEngineError, query_engine_exception_handler)
-
+app.add_exception_handler(BaseError, exception_handler)
app.include_router(db_connection_controller.router, tags=["Database Connection"])
app.include_router(finetuning_controller.router, tags=["Finetuning"])
diff --git a/apps/ai/server/dataherald b/apps/ai/server/dataherald
index b7e99006..9f34ba70 160000
--- a/apps/ai/server/dataherald
+++ b/apps/ai/server/dataherald
@@ -1 +1 @@
-Subproject commit b7e99006d64dfe0a0ac3c52bc77a80bde2372b28
+Subproject commit 9f34ba7016f4f5381528170a09c9a6a561341c20
diff --git a/apps/ai/server/exceptions/error_codes.py b/apps/ai/server/exceptions/error_codes.py
new file mode 100644
index 00000000..ece39e15
--- /dev/null
+++ b/apps/ai/server/exceptions/error_codes.py
@@ -0,0 +1,44 @@
+from enum import Enum, EnumMeta
+
+from pydantic import BaseModel
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_500_INTERNAL_SERVER_ERROR,
+)
+
+
+class ErrorCodeData(BaseModel):
+ status_code: int
+ message: str
+
+
+class ErrorCodeInterface(EnumMeta):
+ def __new__(cls, metacls, bases, classdict):
+ enum_class = super().__new__(cls, metacls, bases, classdict)
+ for name, member in enum_class.__members__.items():
+ if not isinstance(member.value, ErrorCodeData):
+ raise TypeError(
+ f"Enum value for {name} must be an instance of ErrorCodeData"
+ )
+ return enum_class
+
+
+class BaseErrorCode(Enum, metaclass=ErrorCodeInterface):
+ """ ""
+ This class serves as a base for all Error code enums
+ It will enforce that all enum values are instances of ErrorCodeData
+ """
+
+ pass
+
+
+class GeneralErrorCode(BaseErrorCode):
+ unknown_error = ErrorCodeData(
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR, message="Unknown error"
+ )
+ unhandled_engine_error = ErrorCodeData(
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR, message="Unhandled engine error"
+ )
+ reserved_metadata_key = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Metadata cannot contain reserved key"
+ )
diff --git a/apps/ai/server/exceptions/error_response.py b/apps/ai/server/exceptions/error_response.py
new file mode 100644
index 00000000..edf699cd
--- /dev/null
+++ b/apps/ai/server/exceptions/error_response.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class ErrorResponse(BaseModel):
+ trace_id: str
+ error_code: str
+ message: str
+ detail: dict | None = None
diff --git a/apps/ai/server/exceptions/exception_handlers.py b/apps/ai/server/exceptions/exception_handlers.py
new file mode 100644
index 00000000..f14f98a3
--- /dev/null
+++ b/apps/ai/server/exceptions/exception_handlers.py
@@ -0,0 +1,64 @@
+from fastapi import Request
+from fastapi.logger import logger
+from fastapi.responses import JSONResponse
+from httpx import Response
+
+from exceptions.error_response import ErrorResponse
+from exceptions.exceptions import (
+ BaseError,
+ EngineError,
+ UnhandledEngineError,
+)
+from exceptions.utils import is_http_error
+
+
+async def exception_handler(request: Request, exc: BaseError): # noqa: ARG001
+
+ trace_id = exc.trace_id
+ error_code = exc.error_code
+ status_code = exc.status_code
+ message = exc.message
+ detail = {k: v for k, v in exc.detail.items() if v is not None}
+
+ logger.error(
+ "ERROR\nTrace ID: %s, error_code: %s, detail: %s",
+ trace_id,
+ error_code,
+ detail,
+ )
+ return JSONResponse(
+ status_code=status_code,
+ content=ErrorResponse(
+ trace_id=trace_id,
+ error_code=error_code,
+ message=message,
+ detail=detail,
+ ).dict(),
+ )
+
+
+def raise_engine_exception(response: Response, org_id: str):
+ if not is_http_error(response.status_code):
+ return
+
+ response_json: dict = response.json()
+
+ if "error_code" in response_json:
+ error_code = response_json["error_code"]
+ message = response_json.get(
+ "message", f"Unknown translation engine error_code: {error_code}"
+ )
+ detail = response_json.get("detail", {})
+ detail["organization_id"] = org_id
+
+ logger.error("Handled error from translation engine: %s", message)
+
+ raise EngineError(
+ error_code=error_code,
+ status_code=response.status_code,
+ message=message,
+ detail=detail,
+ )
+
+ logger.error("Unhandled error from translation engine: %s", response.text)
+ raise UnhandledEngineError()
diff --git a/apps/ai/server/exceptions/exceptions.py b/apps/ai/server/exceptions/exceptions.py
new file mode 100644
index 00000000..a63f8c4a
--- /dev/null
+++ b/apps/ai/server/exceptions/exceptions.py
@@ -0,0 +1,117 @@
+from abc import ABC
+
+from exceptions.error_codes import BaseErrorCode, GeneralErrorCode
+from exceptions.utils import generate_trace_id
+
+
+class BaseError(ABC, Exception):
+ ERROR_CODES: BaseErrorCode = None
+
+ @property
+ def trace_id(self) -> str:
+ return self._trace_id
+
+ @property
+ def error_code(self) -> str:
+ return self._error_code
+
+ @property
+ def status_code(self) -> str:
+ return self._status_code
+
+ @property
+ def message(self) -> str:
+ return self._message
+
+ @property
+ def detail(self) -> dict:
+ return self._detail
+
+ def __init__(
+ self,
+ error_code: str = None,
+ status_code: str = None,
+ message: str = None,
+ detail: dict = None,
+ ) -> None:
+
+ if type(self) is BaseError:
+ raise TypeError("BaseError class may not be instantiated directly")
+
+ if self.ERROR_CODES is None or not hasattr(self.ERROR_CODES, "__members__"):
+ raise TypeError(
+ f"ERROR_CODES in {self.__class__.__name__} must be defined and be an enum type"
+ )
+
+ def handled_error_code(error_code: str) -> bool:
+ return error_code in self.ERROR_CODES.__members__
+
+ self._trace_id = generate_trace_id()
+
+ if error_code is not None:
+ self._error_code = error_code
+
+ if handled_error_code(error_code):
+ self._status_code = self.ERROR_CODES[error_code].value.status_code
+ self._message = (
+ message
+ if message is not None
+ else self.ERROR_CODES[error_code].value.message
+ )
+ else:
+ self._status_code = status_code if status_code is not None else "500"
+ self._message = (
+ message
+ if message is not None
+ else f"Unknown error_code: {error_code}"
+ )
+ else:
+ self._status_code = status_code if status_code is not None else 500
+ self._message = message if message is not None else "Unknown error"
+
+ self._detail = detail if detail is not None else {}
+
+
+class GeneralError(BaseError):
+ """
+ Base class for general exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = GeneralErrorCode
+
+
+class EngineError(GeneralError):
+ def __init__(
+ self,
+ error_code: str,
+ status_code: int,
+ message: str,
+ detail: dict,
+ ) -> None:
+ super().__init__(
+ error_code=error_code,
+ status_code=status_code,
+ message=message,
+ detail=detail,
+ )
+
+
+class UnhandledEngineError(GeneralError):
+ def __init__(self) -> None:
+ super().__init__(
+ error_code=GeneralErrorCode.unhandled_engine_error.name,
+ )
+
+
+class ReservedMetadataKeyError(GeneralError):
+ def __init__(self) -> None:
+ super().__init__(
+ error_code=GeneralErrorCode.reserved_metadata_key.name,
+ )
+
+
+class UnknownError(GeneralError):
+ def __init__(self, error: str = None) -> None:
+ super().__init__(
+ error_code=GeneralErrorCode.unknown_error.name, detail={"error": error}
+ )
diff --git a/apps/ai/server/exceptions/utils.py b/apps/ai/server/exceptions/utils.py
new file mode 100644
index 00000000..50e2791f
--- /dev/null
+++ b/apps/ai/server/exceptions/utils.py
@@ -0,0 +1,11 @@
+import uuid
+
+from starlette.status import HTTP_400_BAD_REQUEST
+
+
+def is_http_error(status_code: int) -> bool:
+ return status_code >= HTTP_400_BAD_REQUEST
+
+
+def generate_trace_id():
+ return f"E-{str(uuid.uuid4())}" # Generate a unique trace ID for each error
diff --git a/apps/ai/server/middleware/error.py b/apps/ai/server/middleware/error.py
new file mode 100644
index 00000000..46f441ad
--- /dev/null
+++ b/apps/ai/server/middleware/error.py
@@ -0,0 +1,31 @@
+from fastapi import Request
+from fastapi.logger import logger
+from fastapi.responses import JSONResponse
+from starlette.middleware.base import BaseHTTPMiddleware
+
+from exceptions.error_codes import GeneralErrorCode
+from exceptions.error_response import ErrorResponse
+from exceptions.utils import generate_trace_id
+
+
+class UnknownErrorMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request: Request, call_next):
+ try:
+ return await call_next(request)
+ except Exception:
+ trace_id = generate_trace_id()
+ logger.error(f"Unhandled ERROR\nTrace ID: {trace_id}", exc_info=True)
+
+ error_code = GeneralErrorCode.unknown_error.name
+ status_code = GeneralErrorCode.unknown_error.value.status_code
+ message = GeneralErrorCode.unknown_error.value.message
+
+ # raising an exception here causes problems with the exception handling
+ return JSONResponse(
+ status_code=status_code,
+ content=ErrorResponse(
+ trace_id=trace_id,
+ error_code=error_code,
+ message=message,
+ ).dict(),
+ )
diff --git a/apps/ai/server/modules/auth/models/exceptions.py b/apps/ai/server/modules/auth/models/exceptions.py
new file mode 100644
index 00000000..188ee577
--- /dev/null
+++ b/apps/ai/server/modules/auth/models/exceptions.py
@@ -0,0 +1,95 @@
+from starlette.status import (
+ HTTP_401_UNAUTHORIZED,
+ HTTP_403_FORBIDDEN,
+ HTTP_500_INTERNAL_SERVER_ERROR,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class AuthErrorCode(BaseErrorCode):
+ unauthorized_user = ErrorCodeData(
+ status_code=HTTP_401_UNAUTHORIZED, message="Unauthorized user"
+ )
+ unauthorized_operation = ErrorCodeData(
+ status_code=HTTP_401_UNAUTHORIZED, message="Unauthorized operation"
+ )
+ unauthorized_data_access = ErrorCodeData(
+ status_code=HTTP_401_UNAUTHORIZED, message="Unauthorized data access"
+ )
+ bearer_token_expired = ErrorCodeData(
+ status_code=HTTP_401_UNAUTHORIZED, message="Bearer token expired"
+ )
+ invalid_bearer_token = ErrorCodeData(
+ status_code=HTTP_403_FORBIDDEN, message="Bearer token is invalid"
+ )
+ invalid_or_revoked_key = ErrorCodeData(
+ status_code=HTTP_403_FORBIDDEN, message="Invalid or revoked API key"
+ )
+ py_jwk_client_error = ErrorCodeData(
+ status_code=HTTP_500_INTERNAL_SERVER_ERROR, message="PyJWKClient error"
+ )
+ decode_error = ErrorCodeData(
+ status_code=HTTP_401_UNAUTHORIZED, message="Decode error"
+ )
+
+
+class AuthError(BaseError):
+ """
+ Base class for auth exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = AuthErrorCode
+
+
+class UnauthorizedUserError(AuthError):
+ def __init__(self, email: str) -> None:
+ super().__init__(
+ error_code=AuthErrorCode.unauthorized_user.name,
+ detail={"email": email},
+ )
+
+
+class UnauthorizedOperationError(AuthError):
+ def __init__(self, user_id: str) -> None:
+ super().__init__(
+ error_code=AuthErrorCode.unauthorized_operation.name,
+ detail={"user_id": user_id},
+ )
+
+
+class UnauthorizedDataAccessError(AuthError):
+ def __init__(self, user_id: str) -> None:
+ super().__init__(
+ error_code=AuthErrorCode.unauthorized_data_access.name,
+ detail={"user_id": user_id},
+ )
+
+
+class InvalidOrRevokedAPIKeyError(AuthError):
+ def __init__(self, key_id: str) -> None:
+ super().__init__(
+ error_code=AuthErrorCode.invalid_or_revoked_key.name,
+ detail={"key_id": key_id},
+ )
+
+
+class BearerTokenExpiredError(AuthError):
+ def __init__(self) -> None:
+ super().__init__(error_code=AuthErrorCode.bearer_token_expired.name)
+
+
+class InvalidBearerTokenError(AuthError):
+ def __init__(self) -> None:
+ super().__init__(error_code=AuthErrorCode.invalid_bearer_token.name)
+
+
+class PyJWKClientError(AuthError):
+ def __init__(self) -> None:
+ super().__init__(error_code=AuthErrorCode.py_jwk_client_error.name)
+
+
+class DecodeError(AuthError):
+ def __init__(self) -> None:
+ super().__init__(error_code=AuthErrorCode.decode_error.name)
diff --git a/apps/ai/server/modules/db_connection/models/exceptions.py b/apps/ai/server/modules/db_connection/models/exceptions.py
new file mode 100644
index 00000000..b4ee16f7
--- /dev/null
+++ b/apps/ai/server/modules/db_connection/models/exceptions.py
@@ -0,0 +1,41 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class DBConnectionErrorCode(BaseErrorCode):
+ db_connection_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Database connection not found"
+ )
+ db_connection_alias_exists = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Existing database connection already has alias",
+ )
+
+
+class DBConnectionError(BaseError):
+ """
+ Base class for database connection exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = DBConnectionErrorCode
+
+
+class DBConnectionNotFoundError(DBConnectionError):
+ def __init__(self, db_connection_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=DBConnectionErrorCode.db_connection_not_found.name,
+ detail={"db_connection_id": db_connection_id, "organization_id": org_id},
+ )
+
+
+class DBConnectionAliasExistsError(DBConnectionError):
+ def __init__(self, db_connection_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=DBConnectionErrorCode.db_connection_alias_exists.name,
+ detail={"db_connection_id": db_connection_id, "organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/db_connection/repository.py b/apps/ai/server/modules/db_connection/repository.py
index e4971eac..3568a7fb 100644
--- a/apps/ai/server/modules/db_connection/repository.py
+++ b/apps/ai/server/modules/db_connection/repository.py
@@ -29,3 +29,17 @@ def get_db_connection(self, db_connection_id: str, org_id: str) -> DBConnection:
if db_connection
else None
)
+
+ def get_db_connection_by_alias(self, alias: str, org_id: str) -> DBConnection:
+ db_connection = MongoDB.find_one(
+ DATABASE_CONNECTION_COL,
+ {
+ "alias": alias,
+ "metadata.dh_internal.organization_id": org_id,
+ },
+ )
+ return (
+ DBConnection(id=str(db_connection["_id"]), **db_connection)
+ if db_connection
+ else None
+ )
diff --git a/apps/ai/server/modules/db_connection/service.py b/apps/ai/server/modules/db_connection/service.py
index aeec88c7..7bf14338 100644
--- a/apps/ai/server/modules/db_connection/service.py
+++ b/apps/ai/server/modules/db_connection/service.py
@@ -1,19 +1,23 @@
import json
import httpx
-from fastapi import HTTPException, UploadFile, status
+from fastapi import UploadFile
from config import settings, ssh_settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.db_connection.models.entities import (
DBConnection,
DBConnectionMetadata,
DHDBConnectionMetadata,
)
+from modules.db_connection.models.exceptions import (
+ DBConnectionAliasExistsError,
+ DBConnectionNotFoundError,
+)
from modules.db_connection.models.requests import DBConnectionRequest
from modules.db_connection.models.responses import DBConnectionResponse
from modules.db_connection.repository import DBConnectionRepository
from utils.analytics import Analytics, EventName, EventType
-from utils.exception import raise_for_status
from utils.misc import reserved_key_in_metadata
from utils.s3 import S3
@@ -52,6 +56,12 @@ async def add_db_connection(
file: UploadFile | None = None,
) -> DBConnectionResponse:
reserved_key_in_metadata(db_connection_request.metadata)
+ db_connection = self.repo.get_db_connection_by_alias(
+ db_connection_request.alias, org_id
+ )
+ if db_connection:
+ raise DBConnectionAliasExistsError(db_connection.id, org_id)
+
db_connection_internal_request = DBConnection(
**db_connection_request.dict(exclude_unset=True)
)
@@ -79,7 +89,7 @@ async def add_db_connection(
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
db_connection = DBConnectionResponse(**response.json())
self.analytics.track(
org_id,
@@ -130,7 +140,7 @@ async def update_db_connection(
json=db_connection_internal_request.dict(),
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return DBConnectionResponse(**response.json())
def get_db_connection_in_org(
@@ -138,10 +148,7 @@ def get_db_connection_in_org(
) -> DBConnection:
db_connection = self.repo.get_db_connection(db_connection_id, org_id)
if not db_connection:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Database connection not found",
- )
+ raise DBConnectionNotFoundError(db_connection_id, org_id)
return db_connection
def get_database_type(self, connection_uri: str) -> str:
diff --git a/apps/ai/server/modules/finetuning/models/exceptions.py b/apps/ai/server/modules/finetuning/models/exceptions.py
new file mode 100644
index 00000000..2677bbb5
--- /dev/null
+++ b/apps/ai/server/modules/finetuning/models/exceptions.py
@@ -0,0 +1,41 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class FinetuningErrorCode(BaseErrorCode):
+ finetuning_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Finetuning not found"
+ )
+ finetuning_alias_exists = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Existing finetuning already has alias",
+ )
+
+
+class FinetuningError(BaseError):
+ """
+ Base class for finetuning exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = FinetuningErrorCode
+
+
+class FinetuningNotFoundError(FinetuningError):
+ def __init__(self, finetuning_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=FinetuningErrorCode.finetuning_not_found.name,
+ detail={"finetuning_id": finetuning_id, "organization_id": org_id},
+ )
+
+
+class FinetuningAliasExistsError(FinetuningError):
+ def __init__(self, finetuning_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=FinetuningErrorCode.finetuning_alias_exists.name,
+ detail={"finetuning_id": finetuning_id, "organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/finetuning/repository.py b/apps/ai/server/modules/finetuning/repository.py
index a50860de..40a0dcc7 100644
--- a/apps/ai/server/modules/finetuning/repository.py
+++ b/apps/ai/server/modules/finetuning/repository.py
@@ -36,3 +36,17 @@ def get_finetuning_job(self, finetuning_id: str, org_id: str) -> Finetuning:
if finetuning_job
else None
)
+
+ def get_finetuning_job_by_alias(self, alias: str, org_id: str) -> Finetuning:
+ finetuning_job = MongoDB.find_one(
+ FINETUNING_COL,
+ {
+ "alias": alias,
+ "metadata.dh_internal.organization_id": org_id,
+ },
+ )
+ return (
+ Finetuning(id=str(finetuning_job["_id"]), **finetuning_job)
+ if finetuning_job
+ else None
+ )
diff --git a/apps/ai/server/modules/finetuning/service.py b/apps/ai/server/modules/finetuning/service.py
index 469ad9ed..232e9b0a 100644
--- a/apps/ai/server/modules/finetuning/service.py
+++ b/apps/ai/server/modules/finetuning/service.py
@@ -1,19 +1,22 @@
import httpx
-from fastapi import HTTPException, status
from config import settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.db_connection.service import DBConnectionService
from modules.finetuning.models.entities import (
DHFinetuningMetadata,
Finetuning,
FinetuningMetadata,
)
+from modules.finetuning.models.exceptions import (
+ FinetuningAliasExistsError,
+ FinetuningNotFoundError,
+)
from modules.finetuning.models.requests import FinetuningRequest
from modules.finetuning.models.responses import AggrFinetuning
from modules.finetuning.repository import FinetuningRepository
from modules.golden_sql.service import GoldenSQLService
from utils.analytics import Analytics, EventName, EventType
-from utils.exception import raise_for_status
from utils.misc import reserved_key_in_metadata
@@ -43,7 +46,7 @@ async def get_finetuning_jobs(
params={"db_connection_id": db_connection.id},
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
finetuning_jobs += [
AggrFinetuning(
**finetuning_job,
@@ -62,7 +65,7 @@ async def get_finetuning_job(
settings.engine_url + f"/finetunings/{finetuning_id}",
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
finetuning_job = Finetuning(**response.json())
db_connection = self.db_connection_service.get_db_connection_in_org(
finetuning_job.db_connection_id, org_id
@@ -79,6 +82,12 @@ async def create_finetuning_job(
finetuning_request.db_connection_id, org_id
)
+ finetuning = self.repo.get_finetuning_job_by_alias(
+ finetuning_request.alias, org_id
+ )
+ if finetuning:
+ raise FinetuningAliasExistsError(finetuning.id, org_id)
+
finetuning_request.metadata = FinetuningMetadata(
**finetuning_request.metadata,
dh_internal=DHFinetuningMetadata(organization_id=org_id),
@@ -89,7 +98,7 @@ async def create_finetuning_job(
settings.engine_url + "/finetunings",
json=finetuning_request.dict(exclude_unset=True),
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
aggr_finetuning = AggrFinetuning(
**response.json(), db_connection_alias=db_connection.alias
@@ -124,7 +133,7 @@ async def cancel_finetuning_job(
response = await client.post(
settings.engine_url + f"/finetunings/{finetuning_id}/cancel",
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return AggrFinetuning(
**response.json(), db_connection_alias=db_connection.alias
)
@@ -132,10 +141,7 @@ async def cancel_finetuning_job(
def get_finetuning_job_in_org(self, finetuning_id: str, org_id: str) -> Finetuning:
finetuning_job = self.repo.get_finetuning_job(finetuning_id, org_id)
if not finetuning_job:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Finetuning not found",
- )
+ raise FinetuningNotFoundError(finetuning_id, org_id)
return finetuning_job
def is_gpt_4_model(self, model_name: str) -> bool:
diff --git a/apps/ai/server/modules/generation/aggr_service.py b/apps/ai/server/modules/generation/aggr_service.py
index ac77235d..c7e59151 100644
--- a/apps/ai/server/modules/generation/aggr_service.py
+++ b/apps/ai/server/modules/generation/aggr_service.py
@@ -1,10 +1,10 @@
from datetime import datetime
import httpx
-from fastapi import HTTPException, status
from fastapi.responses import StreamingResponse
from config import settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.db_connection.service import DBConnectionService
from modules.generation.models.entities import (
DHNLGenerationMetadata,
@@ -21,6 +21,12 @@
SQLGenerationMetadata,
SQLGenerationStatus,
)
+from modules.generation.models.exceptions import (
+ GenerationVerifiedOrRejectedError,
+ InvalidSqlGenerationError,
+ PromptNotFoundError,
+ SqlGenerationNotFoundError,
+)
from modules.generation.models.requests import (
GenerationUpdateRequest,
NLGenerationRequest,
@@ -47,7 +53,6 @@
from modules.user.models.responses import UserResponse
from modules.user.service import UserService
from utils.analytics import Analytics, EventName, EventType
-from utils.exception import GenerationEngineError, raise_for_status
from utils.slack import SlackWebClient, remove_slack_mentions
CONFIDENCE_CAP = 0.95
@@ -79,9 +84,7 @@ async def get_generation(self, prompt_id: str, org_id: str) -> GenerationRespons
prompt, sql_generation, nl_generation
)
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found"
- )
+ raise PromptNotFoundError(prompt_id, org_id)
def get_generation_list(
self,
@@ -192,7 +195,7 @@ async def create_generation(
json=generation_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, display_id=display_id)
+ raise_engine_exception(response, org_id=organization.id)
nl_generation = NLGeneration(**response.json())
sql_generation = self.repo.get_sql_generation(
@@ -278,7 +281,7 @@ async def create_prompt_sql_generation_result(
json=generation_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response)
+ raise_engine_exception(response, org_id=org_id)
sql_generation = SQLGeneration(**response.json())
self.repo.update_prompt_dh_metadata(
@@ -294,15 +297,13 @@ async def create_prompt_sql_generation_result(
prompt = self.repo.get_prompt(sql_generation.prompt_id, organization.id)
if sql_generation.status == SQLGenerationStatus.VALID:
- sql_result_response = await client.get(
+ response = await client.get(
settings.engine_url
+ f"/sql-generations/{sql_generation.id}/execute",
timeout=settings.default_engine_timeout,
)
- raise_for_status(
- sql_result_response.status_code, sql_result_response.text
- )
- sql_result = sql_result_response.json()
+ raise_engine_exception(response, org_id=org_id)
+ sql_result = response.json()
else:
sql_result = None
@@ -323,9 +324,7 @@ async def update_generation(
) -> GenerationResponse:
prompt = self.repo.get_prompt(prompt_id, org_id)
if not prompt:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found"
- )
+ raise PromptNotFoundError(prompt_id, org_id)
sql_generation = self.repo.get_latest_sql_generation(prompt_id, org_id)
nl_generation = (
@@ -407,9 +406,7 @@ async def create_sql_nl_generation(
) -> GenerationResponse:
prompt = self.repo.get_prompt(prompt_id, org_id)
if not prompt:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found"
- )
+ raise PromptNotFoundError(prompt_id, org_id)
generation_request = SQLNLGenerationRequest(
metadata=NLGenerationMetadata(
@@ -435,7 +432,7 @@ async def create_sql_nl_generation(
json=generation_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, prompt=prompt)
+ raise_engine_exception(response, org_id=org_id)
nl_generation = NLGeneration(**response.json())
sql_generation = self.repo.get_sql_generation(
nl_generation.sql_generation_id, org_id
@@ -459,15 +456,13 @@ async def create_sql_nl_generation(
prompt = self.repo.get_prompt(prompt_id, org_id)
if sql_generation.status == SQLGenerationStatus.VALID:
- sql_result_response = await client.get(
+ response = await client.get(
settings.engine_url
+ f"/sql-generations/{sql_generation.id}/execute",
timeout=settings.default_engine_timeout,
)
- raise_for_status(
- sql_result_response.status_code, sql_result_response.text
- )
- sql_result = sql_result_response.json()
+ raise_engine_exception(response, org_id=org_id)
+ sql_result = response.json()
else:
sql_result = None
@@ -484,23 +479,18 @@ async def create_sql_generation_result(
self,
prompt_id: str,
sql_request: SQLRequest,
- org_id,
+ org_id: str,
user: UserResponse = None,
) -> GenerationResponse:
prompt = self.repo.get_prompt(prompt_id, org_id)
if not prompt:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found"
- )
+ raise PromptNotFoundError(prompt_id, org_id)
if prompt.metadata.dh_internal.generation_status in {
GenerationStatus.VERIFIED,
GenerationStatus.REJECTED,
}:
- raise_for_status(
- status_code=status.HTTP_400_BAD_REQUEST,
- message="generation has already been verified or rejected",
- )
+ raise GenerationVerifiedOrRejectedError(prompt_id, org_id)
generation_request = SQLGenerationRequest(
sql=sql_request.sql,
@@ -516,7 +506,7 @@ async def create_sql_generation_result(
json=generation_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, prompt=prompt)
+ raise_engine_exception(response, org_id=org_id)
sql_generation = SQLGeneration(**response.json())
self.repo.update_prompt_dh_metadata(
@@ -537,15 +527,13 @@ async def create_sql_generation_result(
prompt = self.repo.get_prompt(prompt_id, org_id)
if sql_generation.status == SQLGenerationStatus.VALID:
- sql_result_response = await client.get(
+ response = await client.get(
settings.engine_url
+ f"/sql-generations/{sql_generation.id}/execute",
timeout=settings.default_engine_timeout,
)
- raise_for_status(
- sql_result_response.status_code, sql_result_response.text
- )
- sql_result = sql_result_response.json()
+ raise_engine_exception(response, org_id=org_id)
+ sql_result = response.json()
else:
sql_result = None
@@ -562,14 +550,10 @@ async def create_nl_generation(
) -> NLGenerationResponse:
prompt = self.repo.get_prompt(prompt_id, org_id)
if not prompt:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found"
- )
+ raise PromptNotFoundError(prompt_id, org_id)
sql_generation = self.repo.get_latest_sql_generation(prompt_id, org_id)
if not sql_generation:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="SQL Prompt not found"
- )
+ raise SqlGenerationNotFoundError(prompt_id, org_id)
generation_request = NLGenerationRequest(
metadata=NLGenerationMetadata(
@@ -584,7 +568,7 @@ async def create_nl_generation(
json=generation_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
nl_generation = NLGeneration(**response.json())
self.repo.update_prompt_dh_metadata(
@@ -632,27 +616,19 @@ async def send_message(self, prompt_id: str, org_id: str):
async def export_csv_file(self, prompt_id: str, org_id: str) -> StreamingResponse:
if not self.repo.get_prompt(prompt_id, org_id):
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found"
- )
+ raise PromptNotFoundError(prompt_id, org_id)
sql_generation = self.repo.get_latest_sql_generation(prompt_id, org_id)
if not sql_generation:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="SQL Generation not found",
- )
+ raise SqlGenerationNotFoundError(prompt_id, org_id)
if sql_generation.status != SQLGenerationStatus.VALID:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="SQL Generation is not valid",
- )
+ raise InvalidSqlGenerationError(sql_generation.id, org_id)
async with httpx.AsyncClient() as client:
response = await client.get(
settings.engine_url + f"/sql-generations/{sql_generation.id}/csv-file",
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return StreamingResponse(
content=response.iter_bytes(),
headers=response.headers,
@@ -721,28 +697,3 @@ def _get_mapped_generation_response(
sql_result=sql_result,
**prompt.metadata.dh_internal.dict(exclude_unset=True),
)
-
- def _raise_for_generation_status(
- self, response: httpx.Response, prompt: Prompt = None, display_id: str = None
- ):
- response_json = response.json()
- if response.status_code != status.HTTP_201_CREATED:
- if prompt or ("prompt_id" in response_json and response_json["prompt_id"]):
- prompt_id = prompt.id if prompt else response_json["prompt_id"]
- self.repo.update_prompt_dh_metadata(
- prompt_id,
- DHPromptMetadata(
- generation_status=GenerationStatus.ERROR,
- ),
- )
- raise GenerationEngineError(
- status_code=response.status_code,
- prompt_id=prompt_id,
- display_id=(
- display_id or prompt.metadata.dh_internal.display_id
- if prompt
- else None
- ),
- error_message=response_json["message"],
- )
- raise_for_status(response.status_code, response.text)
diff --git a/apps/ai/server/modules/generation/models/exceptions.py b/apps/ai/server/modules/generation/models/exceptions.py
new file mode 100644
index 00000000..25002e25
--- /dev/null
+++ b/apps/ai/server/modules/generation/models/exceptions.py
@@ -0,0 +1,74 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class GenerationErrorCode(BaseErrorCode):
+ prompt_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Prompt not found"
+ )
+ sql_generation_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="SQL generation not found"
+ )
+ nl_generation_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="NL generation not found"
+ )
+ generation_verified_or_rejected = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Cannot modify verified or rejected generation",
+ )
+ invalid_sql_generation = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Invalid SQL generation"
+ )
+
+
+class GenerationError(BaseError):
+ """
+ Base class for generation exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = GenerationErrorCode
+
+
+class PromptNotFoundError(GenerationError):
+ def __init__(self, prompt_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GenerationErrorCode.prompt_not_found.name,
+ detail={"prompt_id": prompt_id, "organization_id": org_id},
+ )
+
+
+class SqlGenerationNotFoundError(GenerationError):
+ def __init__(self, sql_generation_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GenerationErrorCode.sql_generation_not_found.name,
+ detail={"sql_generation_id": sql_generation_id, "organization_id": org_id},
+ )
+
+
+class NlGenerationNotFoundError(GenerationError):
+ def __init__(self, nl_generation_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GenerationErrorCode.nl_generation_not_found.name,
+ detail={"nl_generation_id": nl_generation_id, "organization_id": org_id},
+ )
+
+
+class GenerationVerifiedOrRejectedError(GenerationError):
+ def __init__(self, nl_generation_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GenerationErrorCode.generation_verified_or_rejected.name,
+ detail={"nl_generation_id": nl_generation_id, "organization_id": org_id},
+ )
+
+
+class InvalidSqlGenerationError(GenerationError):
+ def __init__(self, sql_generation_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GenerationErrorCode.invalid_sql_generation.name,
+ detail={"sql_generation_id": sql_generation_id, "organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/generation/service.py b/apps/ai/server/modules/generation/service.py
index 6d93499d..509c909a 100644
--- a/apps/ai/server/modules/generation/service.py
+++ b/apps/ai/server/modules/generation/service.py
@@ -1,8 +1,8 @@
import httpx
-from fastapi import HTTPException, status
from fastapi.responses import StreamingResponse
from config import settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.db_connection.service import DBConnectionService
from modules.generation.models.entities import (
DHNLGenerationMetadata,
@@ -18,6 +18,13 @@
SQLGenerationMetadata,
SQLGenerationStatus,
)
+from modules.generation.models.exceptions import (
+ GenerationVerifiedOrRejectedError,
+ InvalidSqlGenerationError,
+ NlGenerationNotFoundError,
+ PromptNotFoundError,
+ SqlGenerationNotFoundError,
+)
from modules.generation.models.requests import (
NLGenerationRequest,
PromptRequest,
@@ -33,7 +40,6 @@
)
from modules.generation.repository import GenerationRepository
from utils.analytics import Analytics, EventName, EventType
-from utils.exception import GenerationEngineError, raise_for_status
from utils.misc import reserved_key_in_metadata
@@ -128,7 +134,7 @@ async def create_prompt(
settings.engine_url + "/prompts",
json=create_request.dict(exclude_unset=True),
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return PromptResponse(**response.json())
async def create_prompt_sql_generation(
@@ -155,7 +161,7 @@ async def create_prompt_sql_generation(
json=create_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, org_id)
+ raise_engine_exception(response, org_id=org_id)
sql_generation = SQLGenerationResponse(**response.json())
self._update_generation_status(
@@ -195,7 +201,7 @@ async def create_prompt_sql_nl_generation(
json=create_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, org_id)
+ raise_engine_exception(response, org_id=org_id)
nl_generation = NLGenerationResponse(**response.json())
sql_generation = self.repo.get_sql_generation(
nl_generation.sql_generation_id, org_id
@@ -218,11 +224,7 @@ async def create_sql_generation(
GenerationStatus.REJECTED,
GenerationStatus.VERIFIED,
}:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Cannot create SQL generation for a prompt that has been verified or rejected",
- )
-
+ raise GenerationVerifiedOrRejectedError(prompt_id, org_id)
create_request.metadata = SQLGenerationMetadata(
**create_request.metadata,
dh_internal=DHSQLGenerationMetadata(organization_id=org_id),
@@ -237,7 +239,7 @@ async def create_sql_generation(
json=create_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, org_id, prompt)
+ raise_engine_exception(response, org_id=org_id)
sql_generation = SQLGenerationResponse(**response.json())
self._update_generation_status(prompt_id, sql_generation.status)
@@ -256,10 +258,7 @@ async def create_sql_nl_generation(
GenerationStatus.REJECTED,
GenerationStatus.VERIFIED,
}:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Cannot create SQL generation for a prompt that has been verified or rejected",
- )
+ raise GenerationVerifiedOrRejectedError(prompt_id, org_id=org_id)
create_request.sql_generation.metadata = SQLGenerationMetadata(
**create_request.sql_generation.metadata,
dh_internal=DHSQLGenerationMetadata(organization_id=org_id),
@@ -281,7 +280,7 @@ async def create_sql_nl_generation(
json=create_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- self._raise_for_generation_status(response, org_id, prompt)
+ raise_engine_exception(response, org_id=org_id)
nl_generation = NLGenerationResponse(**response.json())
sql_generation = self.repo.get_sql_generation(
nl_generation.sql_generation_id, org_id
@@ -311,7 +310,7 @@ async def create_nl_generation(
json=create_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return NLGenerationResponse(**response.json())
async def execute_sql_generation(
@@ -322,17 +321,14 @@ async def execute_sql_generation(
) -> list[dict]:
sql_generation = self.get_sql_generation_in_org(sql_generation_id, org_id)
if sql_generation.status != SQLGenerationStatus.VALID:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="SQL Generation is not valid",
- )
+ raise InvalidSqlGenerationError(sql_generation_id, org_id)
async with httpx.AsyncClient() as client:
response = await client.get(
settings.engine_url + f"/sql-generations/{sql_generation_id}/execute",
params={"max_rows": max_rows},
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return response.json()
async def export_csv_file(
@@ -340,16 +336,13 @@ async def export_csv_file(
) -> StreamingResponse:
sql_generation = self.get_sql_generation_in_org(sql_generation_id, org_id)
if sql_generation.status != SQLGenerationStatus.VALID:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="SQL Generation is not valid",
- )
+ raise InvalidSqlGenerationError(sql_generation_id, org_id)
async with httpx.AsyncClient() as client:
response = await client.get(
settings.engine_url + f"/sql-generations/{sql_generation_id}/csv-file",
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return StreamingResponse(
content=response.iter_bytes(),
headers=response.headers,
@@ -360,10 +353,7 @@ async def export_csv_file(
def get_prompt_in_org(self, prompt_id: str, org_id: str) -> Prompt:
prompt = self.repo.get_prompt(prompt_id, org_id)
if not prompt:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Prompt not found",
- )
+ raise PromptNotFoundError(prompt_id, org_id)
return prompt
def get_sql_generation_in_org(
@@ -371,10 +361,7 @@ def get_sql_generation_in_org(
) -> SQLGeneration:
sql_generation = self.repo.get_sql_generation(sql_generation_id, org_id)
if not sql_generation:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="SQL Generation not found",
- )
+ raise SqlGenerationNotFoundError(sql_generation_id, org_id)
return sql_generation
def get_nl_generation_in_org(
@@ -382,32 +369,9 @@ def get_nl_generation_in_org(
) -> NLGeneration:
nl_generation = self.repo.get_nl_generation(nl_generation_id, org_id)
if not nl_generation:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="NL Generation not found",
- )
+ raise NlGenerationNotFoundError(nl_generation_id, org_id)
return nl_generation
- def _raise_for_generation_status(
- self, response: httpx.Response, org_id: str, prompt: Prompt = None
- ):
- response_json = response.json()
- if response.status_code != status.HTTP_201_CREATED:
- if "prompt_id" in response_json and response_json["prompt_id"]:
- prompt = self.get_prompt(response_json["prompt_id"], org_id)
- if prompt:
- self.repo.update_prompt_dh_metadata(
- prompt.id,
- DHPromptMetadata(generation_status=GenerationStatus.ERROR),
- )
- raise GenerationEngineError(
- status_code=response.status_code,
- prompt_id=prompt.id,
- display_id=prompt.metadata.dh_internal.display_id,
- error_message=response_json["message"],
- )
- raise_for_status(response.status_code, response.text)
-
def _update_generation_status(self, prompt_id: str, status: SQLGenerationStatus):
self.repo.update_prompt_dh_metadata(
prompt_id,
diff --git a/apps/ai/server/modules/golden_sql/models/exceptions.py b/apps/ai/server/modules/golden_sql/models/exceptions.py
new file mode 100644
index 00000000..4c578688
--- /dev/null
+++ b/apps/ai/server/modules/golden_sql/models/exceptions.py
@@ -0,0 +1,40 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class GoldenSQLErrorCode(BaseErrorCode):
+ golden_sql_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Golden SQL not found"
+ )
+ cannot_delete_golden_sql = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot delete golden SQL"
+ )
+
+
+class GoldenSQLError(BaseError):
+ """
+ Base class for golden SQL exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = GoldenSQLErrorCode
+
+
+class GoldenSqlNotFoundError(GoldenSQLError):
+ def __init__(self, golden_sql_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GoldenSQLErrorCode.golden_sql_not_found.name,
+ detail={"golden_sql_id": golden_sql_id, "organization_id": org_id},
+ )
+
+
+class CannotDeleteGoldenSqlError(GoldenSQLError):
+ def __init__(self, golden_sql_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=GoldenSQLErrorCode.cannot_delete_golden_sql.name,
+ detail={"golden_sql_id": golden_sql_id, "organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/golden_sql/service.py b/apps/ai/server/modules/golden_sql/service.py
index 427eb007..00344d1f 100644
--- a/apps/ai/server/modules/golden_sql/service.py
+++ b/apps/ai/server/modules/golden_sql/service.py
@@ -1,9 +1,9 @@
from typing import List
import httpx
-from fastapi import HTTPException, status
from config import settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.generation.models.entities import GenerationStatus
from modules.generation.service import DBConnectionService
from modules.golden_sql.models.entities import (
@@ -12,11 +12,14 @@
GoldenSQLMetadata,
GoldenSQLSource,
)
+from modules.golden_sql.models.exceptions import (
+ CannotDeleteGoldenSqlError,
+ GoldenSqlNotFoundError,
+)
from modules.golden_sql.models.requests import GoldenSQLRequest
from modules.golden_sql.models.responses import AggrGoldenSQL
from modules.golden_sql.repository import GoldenSQLRepository
from utils.analytics import Analytics, EventName, EventType
-from utils.exception import raise_for_status
from utils.misc import reserved_key_in_metadata
@@ -26,8 +29,8 @@ def __init__(self):
self.db_connection_service = DBConnectionService()
self.analytics = Analytics()
- def get_golden_sql(self, golden_id: str, org_id: str) -> AggrGoldenSQL:
- golden_sql = self.get_golden_sql_in_org(golden_id, org_id)
+ def get_golden_sql(self, golden_sql_id: str, org_id: str) -> AggrGoldenSQL:
+ golden_sql = self.get_golden_sql_in_org(golden_sql_id, org_id)
return AggrGoldenSQL(
**golden_sql.dict(),
db_connection_alias=self.db_connection_service.get_db_connection_in_org(
@@ -102,7 +105,7 @@ async def add_user_upload_golden_sql(
],
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
response_jsons = response.json()
golden_sqls = [
@@ -131,24 +134,24 @@ async def add_user_upload_golden_sql(
# we can avoid cyclic import if we avoid deleting verified golden sql
async def delete_golden_sql(
- self, golden_id: str, org_id: str, query_status: GenerationStatus = None
+ self, golden_sql_id: str, org_id: str, query_status: GenerationStatus = None
) -> dict:
- golden_sql = self.get_golden_sql_in_org(golden_id, org_id)
+ golden_sql = self.get_golden_sql_in_org(golden_sql_id, org_id)
async with httpx.AsyncClient() as client:
response = await client.delete(
- settings.engine_url + f"/golden-sqls/{golden_id}",
+ settings.engine_url + f"/golden-sqls/{golden_sql_id}",
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
if response.json()["status"]:
if query_status:
self.repo.update_generation_status(
golden_sql.metadata.dh_internal.prompt_id, query_status
)
- return {"id": golden_id}
+ return {"id": golden_sql_id}
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
+ raise CannotDeleteGoldenSqlError(golden_sql_id, org_id)
def get_verified_golden_sql(self, prompt_id: str) -> GoldenSQL:
return self.repo.get_verified_golden_sql(prompt_id)
@@ -182,7 +185,7 @@ async def add_verified_golden_sql(
json=[golden_sql_request.dict(exclude_unset=True)],
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
response_json = response.json()[0]
self.analytics.track(
@@ -193,11 +196,8 @@ async def add_verified_golden_sql(
return GoldenSQL(**response_json)
- def get_golden_sql_in_org(self, golden_id: str, org_id: str) -> GoldenSQL:
- golden_sql = self.repo.get_golden_sql(golden_id, org_id)
+ def get_golden_sql_in_org(self, golden_sql_id: str, org_id: str) -> GoldenSQL:
+ golden_sql = self.repo.get_golden_sql(golden_sql_id, org_id)
if not golden_sql:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Golden sql not found",
- )
+ raise GoldenSqlNotFoundError(golden_sql_id, org_id)
return golden_sql
diff --git a/apps/ai/server/modules/instruction/models/exceptions.py b/apps/ai/server/modules/instruction/models/exceptions.py
new file mode 100644
index 00000000..51f5bce4
--- /dev/null
+++ b/apps/ai/server/modules/instruction/models/exceptions.py
@@ -0,0 +1,49 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class InstructionErrorCode(BaseErrorCode):
+ instruction_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Instruction not found"
+ )
+ single_instruction_only = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Only one instruction allowed per database connection",
+ )
+
+
+class InstructionError(BaseError):
+ """
+ Base class for instruction exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = InstructionErrorCode
+
+
+class InstructionNotFoundError(InstructionError):
+ def __init__(
+ self, org_id: str, instruction_id: str | None, db_connection_id: str | None
+ ) -> None:
+ if instruction_id:
+ detail = {"db_connection_id": db_connection_id, "organization_id": org_id}
+ elif db_connection_id:
+ detail = {"db_connection_id": db_connection_id, "organization_id": org_id}
+ else:
+ raise ValueError("instruction_id or db_connection_id must be provided")
+ super().__init__(
+ error_code=InstructionErrorCode.instruction_not_found.name,
+ detail=detail,
+ )
+
+
+class SingleInstructionOnlyError(InstructionError):
+ def __init__(self, db_connection_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=InstructionErrorCode.single_instruction_only.name,
+ detail={"db_connection_id": db_connection_id, "organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/instruction/service.py b/apps/ai/server/modules/instruction/service.py
index 882d6648..a58ff8a5 100644
--- a/apps/ai/server/modules/instruction/service.py
+++ b/apps/ai/server/modules/instruction/service.py
@@ -1,17 +1,20 @@
import httpx
-from fastapi import HTTPException, status
from config import settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.db_connection.service import DBConnectionService
from modules.instruction.models.entities import (
DHInstructionMetadata,
Instruction,
InstructionMetadata,
)
+from modules.instruction.models.exceptions import (
+ InstructionNotFoundError,
+ SingleInstructionOnlyError,
+)
from modules.instruction.models.requests import InstructionRequest
from modules.instruction.models.responses import AggrInstruction
from modules.instruction.repository import InstructionRepository
-from utils.exception import raise_for_status
from utils.misc import reserved_key_in_metadata
@@ -57,10 +60,7 @@ def get_first_instruction(
)
instructions = self.repo.get_instructions(db_connection_id, org_id)
if len(instructions) == 0:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Instruction not found",
- )
+ raise InstructionNotFoundError(org_id, db_connection_id=db_connection_id)
return AggrInstruction(
**instructions[0].dict(), db_connection_alias=db_connection.alias
)
@@ -71,9 +71,8 @@ async def add_instruction(
reserved_key_in_metadata(instruction_request.metadata)
if self.repo.get_instructions(instruction_request.db_connection_id, org_id):
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Instruction already exists for this db connection",
+ raise SingleInstructionOnlyError(
+ instruction_request.db_connection_id, org_id
)
db_connection = self.db_connection_service.get_db_connection_in_org(
@@ -89,7 +88,7 @@ async def add_instruction(
settings.engine_url + "/instructions",
json=instruction_request.dict(exclude_unset=True),
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return AggrInstruction(
**response.json(), db_connection_alias=db_connection.alias
)
@@ -120,7 +119,7 @@ async def update_instruction(
settings.engine_url + f"/instructions/{instruction_id}",
json=instruction_request.dict(exclude_unset=True),
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return AggrInstruction(
**response.json(), db_connection_alias=db_connection.alias
)
@@ -135,14 +134,11 @@ async def delete_instruction(self, instruction_id: str, org_id: str):
response = await client.delete(
settings.engine_url + f"/instructions/{instruction_id}",
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return {"id": instruction_id}
def get_instruction_in_org(self, instruction_id: str, org_id: str) -> Instruction:
instruction = self.repo.get_instruction(instruction_id, org_id)
if not instruction:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Instruction not found",
- )
+ raise InstructionNotFoundError(org_id, instruction_id=instruction_id)
return instruction
diff --git a/apps/ai/server/modules/key/models/exceptions.py b/apps/ai/server/modules/key/models/exceptions.py
new file mode 100644
index 00000000..2e769fcd
--- /dev/null
+++ b/apps/ai/server/modules/key/models/exceptions.py
@@ -0,0 +1,62 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_401_UNAUTHORIZED,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class KeyErrorCode(BaseErrorCode):
+ key_not_found = ErrorCodeData(
+ status_code=HTTP_401_UNAUTHORIZED, message="API key not found"
+ )
+ key_name_exists = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Existing key already has name"
+ )
+ cannot_revoke_key = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot revoke api key"
+ )
+ cannot_create_key = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot create api key"
+ )
+
+
+class KeyError(BaseError):
+ """
+ Base class for api key exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = KeyErrorCode
+
+
+class KeyNotFoundError(KeyError):
+ def __init__(self, key_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=KeyErrorCode.key_not_found.name,
+ detail={"key_id": key_id, "organization_id": org_id},
+ )
+
+
+class KeyNameExistsError(KeyError):
+ def __init__(self, key_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=KeyErrorCode.key_name_exists.name,
+ detail={"key_id": key_id, "organization_id": org_id},
+ )
+
+
+class CannotRevokeKeyError(KeyError):
+ def __init__(self, key_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=KeyErrorCode.cannot_revoke_key.name,
+ detail={"key_id": key_id, "organization_id": org_id},
+ )
+
+
+class CannotCreateKeyError(KeyError):
+ def __init__(self, org_id: str) -> None:
+ super().__init__(
+ error_code=KeyErrorCode.cannot_create_key.name,
+ detail={"organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/key/repository.py b/apps/ai/server/modules/key/repository.py
index 5d637087..4e7fcdb3 100644
--- a/apps/ai/server/modules/key/repository.py
+++ b/apps/ai/server/modules/key/repository.py
@@ -12,6 +12,10 @@ def get_key(self, key_id: str, org_id: str) -> APIKey:
)
return APIKey(id=str(key["_id"]), **key) if key else None
+ def get_key_by_name(self, name: str, org_id: str) -> APIKey:
+ key = MongoDB.find_one(KEY_COL, {"name": name, "organization_id": org_id})
+ return APIKey(id=str(key["_id"]), **key) if key else None
+
def get_keys(self, org_id: str) -> list[APIKey]:
return [
APIKey(id=str(key["_id"]), **key)
@@ -22,8 +26,8 @@ def get_key_by_hash(self, key_hash: str) -> APIKey:
key = MongoDB.find_one(KEY_COL, {"key_hash": key_hash})
return APIKey(id=str(key["_id"]), **key) if key else None
- def add_key(self, key: dict) -> str:
- return str(MongoDB.insert_one(KEY_COL, key))
+ def add_key(self, key: APIKey) -> str:
+ return str(MongoDB.insert_one(KEY_COL, key.dict(exclude={"id"})))
def delete_key(self, key_id: str, org_id: str) -> int:
return MongoDB.delete_one(
diff --git a/apps/ai/server/modules/key/service.py b/apps/ai/server/modules/key/service.py
index af252bce..d859aecf 100644
--- a/apps/ai/server/modules/key/service.py
+++ b/apps/ai/server/modules/key/service.py
@@ -1,11 +1,13 @@
import hashlib
import secrets
-from datetime import datetime
-
-from fastapi import HTTPException, status
from config import settings
from modules.key.models.entities import APIKey
+from modules.key.models.exceptions import (
+ CannotCreateKeyError,
+ CannotRevokeKeyError,
+ KeyNameExistsError,
+)
from modules.key.models.requests import KeyGenerationRequest
from modules.key.models.responses import KeyPreviewResponse, KeyResponse
from modules.key.repository import KeyRepository
@@ -24,16 +26,18 @@ def get_keys(self, org_id: str) -> list[KeyPreviewResponse]:
def add_key(
self, key_request: KeyGenerationRequest, org_id: str, api_key: str = None
) -> KeyResponse:
+ key = self.repo.get_key_by_name(key_request.name, org_id)
+ if key:
+ raise KeyNameExistsError(key.id, org_id)
if not api_key:
api_key = KEY_PREFIX + self.generate_new_key()
key = APIKey(
key_hash=self.hash_key(key=api_key),
organization_id=org_id,
- created_at=datetime.now(),
name=key_request.name,
key_preview=KEY_PREFIX + "························" + api_key[-3:],
)
- key_id = self.repo.add_key(key.dict(exclude_unset=True))
+ key_id = self.repo.add_key(key)
if key_id:
return KeyResponse(
@@ -45,10 +49,7 @@ def add_key(
api_key=api_key,
)
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Could not create key",
- )
+ raise CannotCreateKeyError(org_id)
def validate_key(self, api_key: str) -> APIKey:
return self.repo.get_key_by_hash(key_hash=self.hash_key(api_key))
@@ -71,6 +72,4 @@ def revoke_key(self, key_id: str, org_id: str):
if self.repo.delete_key(key_id, org_id) == 1:
return {"id": key_id}
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Key not found"
- )
+ raise CannotRevokeKeyError(key_id, org_id)
diff --git a/apps/ai/server/modules/organization/invoice/exception/error_codes.py b/apps/ai/server/modules/organization/invoice/exception/error_codes.py
new file mode 100644
index 00000000..d3c8f4d9
--- /dev/null
+++ b/apps/ai/server/modules/organization/invoice/exception/error_codes.py
@@ -0,0 +1,49 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_402_PAYMENT_REQUIRED,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+
+
+class InvoiceErrorCode(BaseErrorCode):
+ no_payment_method = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="No payment method on file",
+ )
+ last_payment_method = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Last payment method"
+ )
+ spending_limit_exceeded = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED, message="Spending limit exceeded"
+ )
+ hard_spending_limit_exceeded = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="Hard spending limit exceeded",
+ )
+ subscription_past_due = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="Stripe subscription past due",
+ )
+ subscription_canceled = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="Stripe subscription canceled",
+ )
+ unknown_subscription_status = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Unknown stripe subscription status",
+ )
+ is_enterprise_plan = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Cannot perform action for enterprise plan",
+ )
+ cannot_update_spending_limit = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot update spending limit"
+ )
+ cannot_update_payment_method = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot update payment method"
+ )
+ missing_invoice_details = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Organization missing invoice details",
+ )
diff --git a/apps/ai/server/modules/organization/invoice/models/exceptions.py b/apps/ai/server/modules/organization/invoice/models/exceptions.py
new file mode 100644
index 00000000..3b7924fa
--- /dev/null
+++ b/apps/ai/server/modules/organization/invoice/models/exceptions.py
@@ -0,0 +1,146 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_402_PAYMENT_REQUIRED,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class InvoiceErrorCode(BaseErrorCode):
+ no_payment_method = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="No payment method on file",
+ )
+ last_payment_method = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Last payment method"
+ )
+ spending_limit_exceeded = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED, message="Spending limit exceeded"
+ )
+ hard_spending_limit_exceeded = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="Hard spending limit exceeded",
+ )
+ subscription_past_due = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="Stripe subscription past due",
+ )
+ subscription_canceled = ErrorCodeData(
+ status_code=HTTP_402_PAYMENT_REQUIRED,
+ message="Stripe subscription canceled",
+ )
+ unknown_subscription_status = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Unknown stripe subscription status",
+ )
+ is_enterprise_plan = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Cannot perform action for enterprise plan",
+ )
+ cannot_update_spending_limit = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot update spending limit"
+ )
+ cannot_update_payment_method = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot update payment method"
+ )
+ missing_invoice_details = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST,
+ message="Organization missing invoice details",
+ )
+
+
+class InvoiceError(BaseError):
+ """
+ Base class for invoice exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = InvoiceErrorCode
+
+
+class NoPaymentMethodError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.no_payment_method.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class LastPaymentMethodError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.last_payment_method.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class SpendingLimitExceededError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.spending_limit_exceeded.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class HardSpendingLimitExceededError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.hard_spending_limit_exceeded.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class SubscriptionPastDueError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.subscription_past_due.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class SubscriptionCanceledError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.subscription_canceled.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class UnknownSubscriptionStatusError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.unknown_subscription_status.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class IsEnterprisePlanError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.is_enterprise_plan.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class CannotUpdateSpendingLimitError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.cannot_update_spending_limit.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class CannotUpdatePaymentMethodError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.cannot_update_payment_method.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class MissingInvoiceDetailsError(InvoiceError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=InvoiceErrorCode.missing_invoice_details.name,
+ detail={"organization_id": organization_id},
+ )
diff --git a/apps/ai/server/modules/organization/invoice/repository.py b/apps/ai/server/modules/organization/invoice/repository.py
index 583483e3..f3d8ed51 100644
--- a/apps/ai/server/modules/organization/invoice/repository.py
+++ b/apps/ai/server/modules/organization/invoice/repository.py
@@ -88,8 +88,8 @@ def get_positive_credits(self, org_id: str) -> list[Credit]:
)
]
- def create_usage(self, usage: dict) -> str:
- return str(MongoDB.insert_one(USAGE_COL, usage))
+ def create_usage(self, usage: Usage) -> str:
+ return str(MongoDB.insert_one(USAGE_COL, usage.dict(exclude={"id"})))
def update_spending_limit(self, org_id: str, spending_limit: int) -> int:
return MongoDB.update_one(
@@ -119,8 +119,8 @@ def update_billing_cyce_anchor(self, org_id: str, billing_cycle_anchor: int) ->
{"invoice_details.billing_cycle_anchor": billing_cycle_anchor},
)
- def create_credit(self, credit: dict) -> str:
- return str(MongoDB.insert_one(CREDIT_COL, credit))
+ def create_credit(self, credit: Credit) -> str:
+ return str(MongoDB.insert_one(CREDIT_COL, credit.dict(exclude={"id"})))
def update_available_credits(self, org_id: str, credit: int) -> int:
return MongoDB.update_one(
diff --git a/apps/ai/server/modules/organization/invoice/service.py b/apps/ai/server/modules/organization/invoice/service.py
index 57263cc3..ae94ad05 100644
--- a/apps/ai/server/modules/organization/invoice/service.py
+++ b/apps/ai/server/modules/organization/invoice/service.py
@@ -1,6 +1,5 @@
from datetime import datetime
-from fastapi import HTTPException, status
from stripe import PaymentMethod
from config import invoice_settings
@@ -13,6 +12,19 @@
UsageInvoice,
UsageType,
)
+from modules.organization.invoice.models.exceptions import (
+ CannotUpdatePaymentMethodError,
+ CannotUpdateSpendingLimitError,
+ HardSpendingLimitExceededError,
+ IsEnterprisePlanError,
+ LastPaymentMethodError,
+ MissingInvoiceDetailsError,
+ NoPaymentMethodError,
+ SpendingLimitExceededError,
+ SubscriptionCanceledError,
+ SubscriptionPastDueError,
+ UnknownSubscriptionStatusError,
+)
from modules.organization.invoice.models.requests import (
CreditRequest,
PaymentMethodRequest,
@@ -28,7 +40,6 @@
from modules.organization.repository import OrganizationRepository
from utils.analytics import Analytics, EventName, EventType
from utils.billing import Billing
-from utils.exception import ErrorCode
class InvoiceService:
@@ -70,10 +81,7 @@ def update_spending_limit(
hard_spending_limit=organization.invoice_details.hard_spending_limit,
)
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Unable to update spending limit",
- )
+ raise CannotUpdateSpendingLimitError(org_id)
def get_pending_invoice(self, org_id: str) -> InvoiceResponse:
@@ -194,10 +202,7 @@ def set_default_payment_method(
):
return {"success": True}
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Unable to set default payment method",
- )
+ raise CannotUpdatePaymentMethodError(org_id)
def detach_payment_method(
self, org_id: str, payment_method_id: str
@@ -210,10 +215,7 @@ def detach_payment_method(
organization.invoice_details.stripe_customer_id
)
if len(payment_methods) <= 1:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Cannot detach last payment method",
- )
+ raise LastPaymentMethodError(org_id)
# check if payment method exists for customer, avoids using stripe api
payment_method = None
@@ -233,9 +235,7 @@ def detach_payment_method(
break
return self._get_mapped_payment_method_response(payment_method, False)
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Payment method not found"
- )
+ raise NoPaymentMethodError(org_id)
def record_usage(
self,
@@ -256,7 +256,7 @@ def record_usage(
description=description,
status=RecordStatus.UNRECORDED,
)
- usage_id = self.repo.create_usage(usage.dict(exclude={"id"}))
+ usage_id = self.repo.create_usage(usage)
print(f"New usage created: {usage_id}")
available_credits = organization.invoice_details.available_credits
self._apply_unrecorded_credits(
@@ -286,15 +286,24 @@ def check_usage(
# check if organization has payment method
organization = self.org_repo.get_organization(org_id)
if not organization.invoice_details:
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail="Organization does not have invoice details",
- )
+ raise MissingInvoiceDetailsError(org_id)
# skip check if enterprise
if organization.invoice_details.plan != PaymentPlan.ENTERPRISE:
- self._check_subscription_status(
+ if (
organization.invoice_details.stripe_subscription_status
- )
+ != StripeSubscriptionStatus.ACTIVE
+ ):
+ if (
+ organization.invoice_details.stripe_subscription_status
+ == StripeSubscriptionStatus.PAST_DUE
+ ):
+ raise SubscriptionPastDueError(org_id)
+ if (
+ organization.invoice_details.stripe_subscription_status
+ == StripeSubscriptionStatus.CANCELED
+ ):
+ raise SubscriptionCanceledError(org_id)
+ raise UnknownSubscriptionStatusError(org_id)
start_date, end_date = (
self.billing.get_current_subscription_period_with_anchor(
organization.invoice_details.billing_cycle_anchor
@@ -316,27 +325,23 @@ def check_usage(
)
> organization.invoice_details.available_credits
):
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail=ErrorCode.no_payment_method,
- )
+ raise NoPaymentMethodError(org_id)
# for usage based and credit only
- self._check_spending_limit_from_usage(
- usages,
- organization.invoice_details.spending_limit,
- organization.invoice_details.hard_spending_limit,
+ total_usage_cost = self._calculate_total_usage_cost(
+ self._get_invoice_from_usages(usages)
)
+ if total_usage_cost > organization.invoice_details.hard_spending_limit:
+ raise HardSpendingLimitExceededError(org_id)
+ if total_usage_cost > organization.invoice_details.spending_limit:
+ raise SpendingLimitExceededError(org_id)
def add_credits(
self, org_id: str, user_id: str, credit_request: CreditRequest
) -> CreditResponse:
organization = self.org_repo.get_organization(org_id)
if organization.invoice_details.plan == PaymentPlan.ENTERPRISE:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Cannot add credits to enterprise plan",
- )
+ raise IsEnterprisePlanError(org_id)
credit_id = self.repo.create_credit(
Credit(
@@ -344,7 +349,7 @@ def add_credits(
amount=credit_request.amount,
status=RecordStatus.RECORDED,
description=f"added by {user_id}: {credit_request.description}",
- ).dict(exclude={"id"})
+ )
)
print(f"New credit created: {credit_id}")
# apply credits to recorded usage
@@ -359,7 +364,7 @@ def add_credits(
amount=-credits_due,
status=RecordStatus.RECORDED,
description=f"negative credits for stripe pending invoice; used from new credit {credit_id}",
- ).dict(exclude={"id"})
+ )
)
self.billing.create_balance_transaction(
organization.invoice_details.stripe_customer_id,
@@ -381,23 +386,6 @@ def add_credits(
)
return self.repo.get_credit(credit_id)
- def _check_spending_limit_from_usage(
- self, usages: list[Usage], spending_limit: int, hard_spending_limit: int
- ):
- total_usage_cost = self._calculate_total_usage_cost(
- self._get_invoice_from_usages(usages)
- )
- if total_usage_cost > hard_spending_limit:
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail=ErrorCode.hard_spending_limit_exceeded,
- )
- if total_usage_cost > spending_limit:
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail=ErrorCode.spending_limit_exceeded,
- )
-
def _get_invoice_from_usages(self, usages: list[Usage]) -> UsageInvoice:
usage_invoice = {
UsageType.SQL_GENERATION: 0,
@@ -436,23 +424,6 @@ def _get_mapped_payment_method_response(
is_default=is_defualt,
)
- def _check_subscription_status(self, subscription_status: str):
- if subscription_status != StripeSubscriptionStatus.ACTIVE:
- if subscription_status == StripeSubscriptionStatus.PAST_DUE:
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail=ErrorCode.subscription_past_due,
- )
- if subscription_status == StripeSubscriptionStatus.CANCELED:
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail=ErrorCode.subscription_canceled,
- )
- raise HTTPException(
- status_code=status.HTTP_402_PAYMENT_REQUIRED,
- detail=ErrorCode.unknown_subscription_status,
- )
-
def _apply_unrecorded_credits(
self,
org_id: str,
@@ -469,7 +440,7 @@ def _apply_unrecorded_credits(
amount=-credits_due,
status=RecordStatus.UNRECORDED,
description=description,
- ).dict(exclude={"id"})
+ )
)
print(f"New negative credit created: {neg_credit_id}")
self.repo.update_available_credits(org_id, available_credits - credits_due)
diff --git a/apps/ai/server/modules/organization/models/exceptions.py b/apps/ai/server/modules/organization/models/exceptions.py
new file mode 100644
index 00000000..aaf3bb2e
--- /dev/null
+++ b/apps/ai/server/modules/organization/models/exceptions.py
@@ -0,0 +1,90 @@
+from starlette.status import (
+ HTTP_400_BAD_REQUEST,
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class OrganizationErrorCode(BaseErrorCode):
+ organization_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Organization not found"
+ )
+ slack_installation_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Slack installation not found"
+ )
+ cannot_create_organization = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot create organization"
+ )
+ cannot_update_organization = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot update organization"
+ )
+ cannot_delete_organization = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot delete organization"
+ )
+ invalid_llm_api_key = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Invalid LLM API key"
+ )
+
+
+class OrganizationError(BaseError):
+ """
+ Base class for organization exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = OrganizationErrorCode
+
+
+class OrganizationNotFoundError(OrganizationError):
+ def __init__(
+ self, slack_workspace_id: str | None, organization_id: str | None
+ ) -> None:
+ if slack_workspace_id:
+ detail = {"slack_workspace_id": slack_workspace_id}
+ elif organization_id:
+ detail = {"organization_id": organization_id}
+ else:
+ raise ValueError("workspace_id or organization_id must be provided")
+ super().__init__(
+ error_code=OrganizationErrorCode.organization_not_found.name,
+ detail=detail,
+ )
+
+
+class SlackInstallationNotFoundError(OrganizationError):
+ def __init__(self, slack_workspace_id: str) -> None:
+ super().__init__(
+ error_code=OrganizationErrorCode.slack_installation_not_found.name,
+ detail={"slack_workspace_id": slack_workspace_id},
+ )
+
+
+class CannotCreateOrganizationError(OrganizationError):
+ def __init__(self) -> None:
+ super().__init__(
+ error_code=OrganizationErrorCode.cannot_create_organization.name,
+ )
+
+
+class CannotUpdateOrganizationError(OrganizationError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=OrganizationErrorCode.cannot_update_organization.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class CannotDeleteOrganizationError(OrganizationError):
+ def __init__(self, organization_id: str) -> None:
+ super().__init__(
+ error_code=OrganizationErrorCode.cannot_delete_organization.name,
+ detail={"organization_id": organization_id},
+ )
+
+
+class InvalidLlmApiKeyError(OrganizationError):
+ def __init__(self) -> None:
+ super().__init__(
+ error_code=OrganizationErrorCode.invalid_llm_api_key.name,
+ )
diff --git a/apps/ai/server/modules/organization/repository.py b/apps/ai/server/modules/organization/repository.py
index 54c2bc04..b810200f 100644
--- a/apps/ai/server/modules/organization/repository.py
+++ b/apps/ai/server/modules/organization/repository.py
@@ -52,8 +52,7 @@ def update_organization(self, org_id: str, new_org_data: dict) -> int:
ORGANIZATION_COL, {"_id": ObjectId(org_id)}, new_org_data
)
- def add_organization(self, new_org_data: dict) -> str:
- # each organization should have unique name
- if MongoDB.find_one(ORGANIZATION_COL, {"name": new_org_data["name"]}):
- return None
- return str(MongoDB.insert_one(ORGANIZATION_COL, new_org_data))
+ def add_organization(self, new_org_data: Organization) -> str:
+ return str(
+ MongoDB.insert_one(ORGANIZATION_COL, new_org_data.dict(exclude={"id"}))
+ )
diff --git a/apps/ai/server/modules/organization/service.py b/apps/ai/server/modules/organization/service.py
index 54d53c8c..bd14ce5b 100644
--- a/apps/ai/server/modules/organization/service.py
+++ b/apps/ai/server/modules/organization/service.py
@@ -1,5 +1,4 @@
import openai
-from fastapi import HTTPException, status
from config import invoice_settings
from modules.organization.invoice.models.entities import (
@@ -14,6 +13,14 @@
SlackConfig,
SlackInstallation,
)
+from modules.organization.models.exceptions import (
+ CannotCreateOrganizationError,
+ CannotDeleteOrganizationError,
+ CannotUpdateOrganizationError,
+ InvalidLlmApiKeyError,
+ OrganizationNotFoundError,
+ SlackInstallationNotFoundError,
+)
from modules.organization.models.requests import OrganizationRequest
from modules.organization.models.responses import OrganizationResponse
from modules.organization.repository import OrganizationRepository
@@ -36,15 +43,14 @@ def get_organization(self, org_id: str) -> OrganizationResponse:
return self.repo.get_organization(org_id)
def get_organization_by_slack_workspace_id(
- self, workspace_id: str
+ self, slack_workspace_id: str
) -> OrganizationResponse:
- organization = self.repo.get_organization_by_slack_workspace_id(workspace_id)
+ organization = self.repo.get_organization_by_slack_workspace_id(
+ slack_workspace_id
+ )
if organization:
return organization
-
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Organization not found"
- )
+ raise OrganizationNotFoundError(slack_workspace_id=slack_workspace_id)
def add_organization(
self, org_request: OrganizationRequest
@@ -69,7 +75,7 @@ def add_organization(
hard_spending_limit=invoice_settings.default_hard_spending_limit,
available_credits=invoice_settings.signup_credits,
)
- new_id = self.repo.add_organization(organization.dict(exclude_unset=True))
+ new_id = self.repo.add_organization(organization)
if new_id:
new_organization = self.repo.get_organization(new_id)
# create signup credit, mark as recorded
@@ -79,7 +85,7 @@ def add_organization(
amount=invoice_settings.signup_credits,
status=RecordStatus.RECORDED,
description="Signup credits",
- ).dict(exclude={"id"})
+ )
)
print(f"New credit created: {credit_id}")
self.analytics.track(
@@ -93,10 +99,7 @@ def add_organization(
)
return OrganizationResponse(**new_organization.dict())
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Organization exists or cannot add organization",
- )
+ raise CannotCreateOrganizationError()
def add_user_organization(self, user_id: str, user_email: str) -> str:
new_organization = self.add_organization(
@@ -120,19 +123,13 @@ def update_organization(
):
return self.repo.get_organization(org_id)
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Organization not found or cannot be updated",
- )
+ raise CannotUpdateOrganizationError(org_id)
def delete_organization(self, org_id: str) -> dict:
if self.repo.delete_organization(org_id) == 1:
return {"id": org_id}
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Organization not found or cannot be deleted",
- )
+ raise CannotDeleteOrganizationError(org_id)
def add_organization_by_slack_installation(
self, slack_installation_request: SlackInstallation
@@ -153,10 +150,7 @@ def add_organization_by_slack_installation(
updated_org = self.repo.get_organization(str(current_org.id))
return OrganizationResponse(**updated_org.dict())
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="An error ocurred while updating organization",
- )
+ raise CannotUpdateOrganizationError(current_org.id)
organization = Organization(
name=slack_installation_request.team.name,
@@ -179,24 +173,31 @@ def add_organization_by_slack_installation(
available_credits=invoice_settings.signup_credits,
)
- new_id = self.repo.add_organization(organization.dict(exclude={"id"}))
+ new_id = self.repo.add_organization(organization)
if new_id:
# create signup credit, mark as recorded
+ new_organization = self.repo.get_organization(new_id)
credit_id = self.invoice_repo.create_credit(
Credit(
organization_id=new_id,
amount=invoice_settings.signup_credits,
status=RecordStatus.RECORDED,
description="Signup credits",
- ).dict(exclude={"id"})
+ )
)
print(f"New credit created: {credit_id}")
- return self.repo.get_organization(new_id)
+ self.analytics.track(
+ new_organization.id,
+ EventName.organization_created,
+ EventType.organization_event(
+ id=new_organization.id,
+ name=new_organization.name,
+ owner=new_organization.owner,
+ ),
+ )
+ return new_organization
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Organization exists or cannot add organization",
- )
+ raise CannotCreateOrganizationError()
def get_slack_installation_by_slack_workspace_id(
self, slack_workspace_id: str
@@ -207,9 +208,7 @@ def get_slack_installation_by_slack_workspace_id(
if organization:
return organization.slack_config.slack_installation
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="slack installation not found"
- )
+ raise SlackInstallationNotFoundError(slack_workspace_id)
def get_organization_by_customer_id(self, customer_id: str) -> Organization:
return self.repo.get_organization_by_customer_id(customer_id)
@@ -223,7 +222,4 @@ def _validate_api_key(self, llm_api_key: str):
try:
openai.Model.list()
except openai.error.AuthenticationError as e:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Invalid LLM API key",
- ) from e
+ raise InvalidLlmApiKeyError() from e
diff --git a/apps/ai/server/modules/table_description/models/exceptions.py b/apps/ai/server/modules/table_description/models/exceptions.py
new file mode 100644
index 00000000..6afcc31a
--- /dev/null
+++ b/apps/ai/server/modules/table_description/models/exceptions.py
@@ -0,0 +1,31 @@
+from starlette.status import (
+ HTTP_404_NOT_FOUND,
+)
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class TableDescriptionErrorCode(BaseErrorCode):
+ table_description_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="Table description not found"
+ )
+
+
+class TableDescriptionError(BaseError):
+ """
+ Base class for table description exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = TableDescriptionErrorCode
+
+
+class TableDescriptionNotFoundError(TableDescriptionError):
+ def __init__(self, table_description_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=TableDescriptionErrorCode.table_description_not_found.name,
+ detail={
+ "table_description_id": table_description_id,
+ "organization_id": org_id,
+ },
+ )
diff --git a/apps/ai/server/modules/table_description/service.py b/apps/ai/server/modules/table_description/service.py
index 42018f63..8821bee3 100644
--- a/apps/ai/server/modules/table_description/service.py
+++ b/apps/ai/server/modules/table_description/service.py
@@ -1,13 +1,14 @@
import httpx
-from fastapi import HTTPException, status
from config import settings
+from exceptions.exception_handlers import raise_engine_exception
from modules.db_connection.service import DBConnectionService
from modules.table_description.models.entities import (
DHTableDescriptionMetadata,
TableDescription,
TableDescriptionMetadata,
)
+from modules.table_description.models.exceptions import TableDescriptionNotFoundError
from modules.table_description.models.requests import (
ScanRequest,
TableDescriptionRequest,
@@ -18,7 +19,6 @@
DatabaseDescriptionResponse,
)
from modules.table_description.repository import TableDescriptionRepository
-from utils.exception import raise_for_status
from utils.misc import reserved_key_in_metadata
@@ -39,7 +39,7 @@ async def get_table_descriptions(
params={"db_connection_id": db_connection_id, "table_name": table_name},
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
table_descriptions = [
AggrTableDescription(
**table_description, db_connection_alias=db_connection.alias
@@ -68,7 +68,7 @@ async def get_table_description(
settings.engine_url + f"/table-descriptions/{table_description_id}",
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
table_description = AggrTableDescription(
**response.json(), db_connection_alias=db_connection.alias
)
@@ -94,7 +94,7 @@ async def refresh_table_description(
)
try:
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
table_descriptions = [
AggrTableDescription(**table_description)
for table_description in response.json()
@@ -139,7 +139,7 @@ async def get_database_description_list(
params={"db_connection_id": db_connection.id},
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
table_descriptions = [
AggrTableDescription(**table_description)
for table_description in response.json()
@@ -187,7 +187,7 @@ async def sync_databases_schemas(
json=scan_request.dict(exclude_unset=True),
timeout=settings.default_engine_timeout,
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
table_descriptions = [
AggrTableDescription(**table_description)
for table_description in response.json()
@@ -218,7 +218,7 @@ async def update_table_description(
settings.engine_url + f"/table-descriptions/{table_description_id}",
json=table_description_request.dict(exclude_unset=True),
)
- raise_for_status(response.status_code, response.text)
+ raise_engine_exception(response, org_id=org_id)
return AggrTableDescription(
**response.json(), db_connection_alias=db_connection.alias
)
@@ -230,8 +230,5 @@ def get_table_description_in_org(
table_description_id, org_id
)
if not table_description:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Table Description not found",
- )
+ raise TableDescriptionNotFoundError(table_description_id, org_id)
return table_description
diff --git a/apps/ai/server/modules/user/controller.py b/apps/ai/server/modules/user/controller.py
index fe1c1861..33d1b9b2 100644
--- a/apps/ai/server/modules/user/controller.py
+++ b/apps/ai/server/modules/user/controller.py
@@ -35,7 +35,7 @@ async def get_user(
async def add_user(
new_user_request: UserRequest, token: str = Depends(token_auth_scheme)
) -> UserResponse:
- authorize.is_admin_user(VerifyToken(token.credentials).verify())
+ authorize.is_admin_user(authorize.user(VerifyToken(token.credentials).verify()))
return user_service.add_user(new_user_request)
diff --git a/apps/ai/server/modules/user/models/entities.py b/apps/ai/server/modules/user/models/entities.py
index 9ec46be2..d4681190 100644
--- a/apps/ai/server/modules/user/models/entities.py
+++ b/apps/ai/server/modules/user/models/entities.py
@@ -1,7 +1,7 @@
from datetime import datetime
from enum import Enum
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
from utils.validation import ObjectIdString
@@ -24,4 +24,4 @@ class BaseUser(BaseModel):
class User(BaseUser):
id: ObjectIdString | None
role: Roles | None
- created_at: datetime = Field(default_factory=datetime.now)
+ created_at: datetime = datetime.now()
diff --git a/apps/ai/server/modules/user/models/exceptions.py b/apps/ai/server/modules/user/models/exceptions.py
new file mode 100644
index 00000000..b87354e5
--- /dev/null
+++ b/apps/ai/server/modules/user/models/exceptions.py
@@ -0,0 +1,83 @@
+from starlette.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT
+
+from exceptions.error_codes import BaseErrorCode, ErrorCodeData
+from exceptions.exceptions import BaseError
+
+
+class UserErrorCode(BaseErrorCode):
+ user_not_found = ErrorCodeData(
+ status_code=HTTP_404_NOT_FOUND, message="User not found"
+ )
+ user_exists_in_org = ErrorCodeData(
+ status_code=HTTP_409_CONFLICT,
+ message="User already exists in organization",
+ )
+ user_exists_in_other_org = ErrorCodeData(
+ status_code=HTTP_409_CONFLICT,
+ message="User already exists in other organization",
+ )
+ cannot_create_user = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot create user"
+ )
+ cannot_update_user = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot update user"
+ )
+ cannot_delete_user = ErrorCodeData(
+ status_code=HTTP_400_BAD_REQUEST, message="Cannot delete user"
+ )
+
+
+class UserError(BaseError):
+ """
+ Base class for user exceptions
+ """
+
+ ERROR_CODES: BaseErrorCode = UserErrorCode
+
+
+class UserNotFoundError(UserError):
+ def __init__(self, user_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=UserErrorCode.user_not_found.name,
+ detail={"user_id": user_id, "organization_id": org_id},
+ )
+
+
+class UserExistsInOrgError(UserError):
+ def __init__(self, user_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=UserErrorCode.user_exists_in_org.name,
+ detail={"user_id": user_id, "organization_id": org_id},
+ )
+
+
+class UserExistsInOtherOrgError(UserError):
+ def __init__(self, user_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=UserErrorCode.user_exists_in_other_org.name,
+ detail={"user_id": user_id, "organization_id": org_id},
+ )
+
+
+class CannotCreateUserError(UserError):
+ def __init__(self, org_id: str) -> None:
+ super().__init__(
+ error_code=UserErrorCode.cannot_create_user.name,
+ detail={"organization_id": org_id},
+ )
+
+
+class CannotUpdateUserError(UserError):
+ def __init__(self, user_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=UserErrorCode.cannot_update_user.name,
+ detail={"user_id": user_id, "organization_id": org_id},
+ )
+
+
+class CannotDeleteUserError(UserError):
+ def __init__(self, user_id: str, org_id: str) -> None:
+ super().__init__(
+ error_code=UserErrorCode.cannot_delete_user.name,
+ detail={"user_id": user_id, "organization_id": org_id},
+ )
diff --git a/apps/ai/server/modules/user/service.py b/apps/ai/server/modules/user/service.py
index 8cf54e9e..4b27be57 100644
--- a/apps/ai/server/modules/user/service.py
+++ b/apps/ai/server/modules/user/service.py
@@ -1,7 +1,13 @@
from bson import ObjectId
-from fastapi import HTTPException, status
from modules.user.models.entities import User
+from modules.user.models.exceptions import (
+ CannotCreateUserError,
+ CannotDeleteUserError,
+ CannotUpdateUserError,
+ UserExistsInOrgError,
+ UserExistsInOtherOrgError,
+)
from modules.user.models.requests import UserOrganizationRequest, UserRequest
from modules.user.models.responses import UserResponse
from modules.user.repository import UserRepository
@@ -34,10 +40,7 @@ def add_user(self, user_request: UserRequest) -> UserResponse:
added_user = self.repo.get_user({"_id": ObjectId(new_user_id)})
return UserResponse(**added_user.dict())
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="User exists or cannot add user",
- )
+ raise CannotCreateUserError(user_request.organization_id)
def invite_user_to_org(
self, user_request: UserRequest, org_id: str
@@ -45,24 +48,15 @@ def invite_user_to_org(
stored_user = self.repo.get_user_by_email(user_request.email)
if stored_user:
if stored_user.organization_id == org_id:
- error_code = "USER_ALREADY_EXISTS_IN_ORG"
- else:
- error_code = "USER_ALREADY_EXISTS_IN_OTHER_ORG"
-
- raise HTTPException(
- status_code=status.HTTP_409_CONFLICT,
- detail=error_code,
- )
+ raise UserExistsInOrgError(stored_user.id)
+ raise UserExistsInOtherOrgError(stored_user.id, stored_user.organization_id)
new_user_data = User(
**user_request.dict(exclude={"organization_id"}), organization_id=org_id
)
new_user_id = self.repo.add_user(new_user_data)
if not new_user_id:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="An error occurred while trying to create the user",
- )
+ raise CannotCreateUserError(org_id)
new_user = self.repo.get_user({"_id": ObjectId(new_user_id)})
@@ -90,10 +84,7 @@ def update_user(self, user_id: str, user_request: UserRequest) -> UserResponse:
new_user = self.repo.get_user({"_id": ObjectId(user_id)})
return UserResponse(**new_user.dict())
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="User not found or cannot be updated",
- )
+ raise CannotUpdateUserError(user_id)
def update_user_organization(
self, user_id: str, user_organization_request: UserOrganizationRequest
@@ -108,10 +99,7 @@ def update_user_organization(
new_user = self.repo.get_user({"_id": ObjectId(user_id)})
return UserResponse(**new_user.dict())
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="User not found or cannot be updated",
- )
+ raise CannotUpdateUserError(user_id)
def delete_user(self, user_id: str, org_id: str) -> dict:
if (
@@ -130,7 +118,4 @@ def delete_user(self, user_id: str, org_id: str) -> dict:
):
return {"id": user_id}
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="User not found or cannot be deleted",
- )
+ raise CannotDeleteUserError(user_id)
diff --git a/apps/ai/server/utils/auth.py b/apps/ai/server/utils/auth.py
index aadcf05a..11550e22 100644
--- a/apps/ai/server/utils/auth.py
+++ b/apps/ai/server/utils/auth.py
@@ -1,6 +1,6 @@
import jwt
from bson import ObjectId
-from fastapi import HTTPException, Security, status
+from fastapi import Security
from fastapi.security import APIKeyHeader
from config import (
@@ -8,9 +8,21 @@
auth_settings,
)
from database.mongo import MongoDB
+from exceptions.exceptions import UnknownError
+from modules.auth.models.exceptions import (
+ BearerTokenExpiredError,
+ DecodeError,
+ InvalidBearerTokenError,
+ InvalidOrRevokedAPIKeyError,
+ PyJWKClientError,
+ UnauthorizedDataAccessError,
+ UnauthorizedOperationError,
+ UnauthorizedUserError,
+)
from modules.key.service import KeyService
from modules.organization.service import OrganizationService
from modules.user.models.entities import Roles
+from modules.user.models.exceptions import UserNotFoundError
from modules.user.models.responses import UserResponse
from modules.user.service import UserService
@@ -38,13 +50,9 @@ def _fetch_signing_key(self):
try:
self.signing_key = self.jwks_client.get_signing_key_from_jwt(self.token).key
except jwt.exceptions.PyJWKClientError as error:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)
- ) from error
+ raise PyJWKClientError() from error
except jwt.exceptions.DecodeError as error:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail=str(error)
- ) from error
+ raise DecodeError() from error
def _decode_payload(self):
try:
@@ -56,21 +64,13 @@ def _decode_payload(self):
issuer=auth_settings.auth0_issuer,
)
except jwt.ExpiredSignatureError as error:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired"
- ) from error
+ raise BearerTokenExpiredError() from error
except (jwt.InvalidAudienceError, jwt.InvalidIssuerError) as error:
- raise HTTPException(
- status_code=status.HTTP_403_FORBIDDEN, detail="Token is invalid"
- ) from error
+ raise InvalidBearerTokenError() from error
except (jwt.DecodeError, jwt.InvalidTokenError) as error:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Token is invalid"
- ) from error
- except Exception as e:
- raise HTTPException(
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
- ) from e
+ raise InvalidBearerTokenError() from error
+ except Exception as error:
+ raise UnknownError(str(error)) from error
class Authorize:
@@ -78,52 +78,29 @@ def user(self, payload: dict) -> UserResponse:
email = payload[auth_settings.auth0_issuer + "email"]
user = user_service.get_user_by_email(email)
if not user:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized User"
- )
+ raise UnauthorizedUserError(email=email)
return user
def user_in_organization(self, user_id: str, org_id: str):
- self._item_in_organization(USER_COL, user_id, org_id)
+ if not MongoDB.find_one(
+ USER_COL,
+ {"_id": ObjectId(user_id), "organization_id": org_id},
+ ):
+ raise UserNotFoundError(user_id, org_id)
def is_admin_user(self, user: UserResponse):
if user.role != Roles.admin:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="User not authorized"
- )
+ raise UnauthorizedOperationError(user_id=user.id)
- def is_self(self, id_a: str, id_b: str):
- if id_a != id_b:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="User not authorized to access other user data",
- )
+ def is_self(self, user_a_id: str, user_b_id: str):
+ # TODO - fix param names to clear up confusion
+ if user_a_id != user_b_id:
+ raise UnauthorizedDataAccessError(user_id=user_a_id)
- def is_not_self(self, id_a: str, id_b: str):
- if id_a == id_b:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="User not authorized to self modify user data",
- )
-
- def _item_in_organization(
- self,
- collection: str,
- id: str,
- org_id: str,
- key: str = "_id",
- is_metadata: bool = False,
- ):
- metadata_prefix = "metadata" if is_metadata else ""
- item = MongoDB.find_one(
- collection,
- {key: ObjectId(id), f"{metadata_prefix}organization_id": org_id},
- )
-
- if not item:
- raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND, detail="Item not found"
- )
+ def is_not_self(self, user_a_id: str, user_b_id: str):
+ # TODO - fix param names to clear up confusion
+ if user_a_id == user_b_id:
+ raise UnauthorizedOperationError(user_id=user_a_id)
api_key_header = APIKeyHeader(name="X-API-Key")
@@ -134,6 +111,4 @@ def get_api_key(api_key: str = Security(api_key_header)) -> str:
if validated_key:
return validated_key
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="API Key does not exist"
- )
+ raise InvalidOrRevokedAPIKeyError(key_id=api_key)
diff --git a/apps/ai/server/utils/exception.py b/apps/ai/server/utils/exception.py
deleted file mode 100644
index 637f6a90..00000000
--- a/apps/ai/server/utils/exception.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import logging
-from enum import Enum
-
-from fastapi import HTTPException, Request, status
-from fastapi.responses import JSONResponse
-
-logger = logging.getLogger(__name__)
-
-
-class ErrorCode(str, Enum):
- no_payment_method = "no_payment_method"
- spending_limit_exceeded = "spending_limit_exceeded"
- hard_spending_limit_exceeded = "hard_spending_limit_exceeded"
- subscription_past_due = "subscription_past_due"
- subscription_canceled = "subscription_canceled"
- unknown_subscription_status = "unknown_subscription_status"
-
-
-class GenerationEngineError(Exception):
- def __init__(
- self, status_code: int, prompt_id: str, display_id: str, error_message: str
- ):
- self.status_code = status_code
- self.prompt_id = prompt_id
- self.display_id = display_id
- self.error_message = error_message
-
-
-async def query_engine_exception_handler(
- request: Request, exc: GenerationEngineError # noqa: ARG001
-):
- return JSONResponse(
- status_code=exc.status_code,
- content={
- "prompt_id": exc.prompt_id,
- "display_id": exc.display_id,
- "error_message": exc.error_message,
- },
- )
-
-
-def raise_for_status(status_code: int, message: str = None):
- if status_code < status.HTTP_400_BAD_REQUEST:
- return
-
- logger.error("Error from K2-Engine: %s", message)
- raise HTTPException(
- status_code=status_code, detail=f"Error from K2-Engine: {message}"
- )
diff --git a/apps/ai/server/utils/misc.py b/apps/ai/server/utils/misc.py
index 675ae755..9e06f67b 100644
--- a/apps/ai/server/utils/misc.py
+++ b/apps/ai/server/utils/misc.py
@@ -1,6 +1,5 @@
-from fastapi import HTTPException, status
-
from database.mongo import DESCENDING, MongoDB
+from exceptions.exceptions import ReservedMetadataKeyError
MAX_DISPLAY_ID = 99999
RESERVED_KEY = "dh_internal"
@@ -31,7 +30,4 @@ def get_next_display_id(collection, org_id: str, prefix: str) -> str:
def reserved_key_in_metadata(metadata: dict):
if RESERVED_KEY in metadata:
- raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail=f"Metadata cannot contain reserved key: {RESERVED_KEY}",
- )
+ raise ReservedMetadataKeyError()