diff --git a/apps/ai/clients/admin-console/src/components/api-keys/api-keys-list.tsx b/apps/ai/clients/admin-console/src/components/api-keys/api-keys-list.tsx index c2db69e5..63a6ecde 100644 --- a/apps/ai/clients/admin-console/src/components/api-keys/api-keys-list.tsx +++ b/apps/ai/clients/admin-console/src/components/api-keys/api-keys-list.tsx @@ -1,5 +1,4 @@ import { getApiKeysColumns } from '@/components/api-keys/columns' -import ApiKeysError from '@/components/api-keys/error' import GenerateApiKeyDialog from '@/components/api-keys/generate-api-key-dialog' import { DataTable } from '@/components/data-table' import { LoadingTable } from '@/components/data-table/loading-table' @@ -8,6 +7,7 @@ import useApiKeys from '@/hooks/api/api-keys/useApiKeys' import { useDeleteApiKey } from '@/hooks/api/api-keys/useDeleteApiKey' import { KeyRound, RefreshCcw } from 'lucide-react' import { useCallback, useMemo } from 'react' +import PageErrorMessage from '../error/page-error-message' const ApiKeysList = () => { const { isLoading, isValidating, error, apiKeys, mutate } = useApiKeys() @@ -42,7 +42,13 @@ const ApiKeysList = () => { ) } else if (error) { - pageContent = + pageContent = ( + + ) } else if (apiKeys?.length === 0) { pageContent = (
No API keys generated yet.
diff --git a/apps/ai/clients/admin-console/src/components/api-keys/delete-api-key-dialog.tsx b/apps/ai/clients/admin-console/src/components/api-keys/delete-api-key-dialog.tsx index 0d253e20..545b0212 100644 --- a/apps/ai/clients/admin-console/src/components/api-keys/delete-api-key-dialog.tsx +++ b/apps/ai/clients/admin-console/src/components/api-keys/delete-api-key-dialog.tsx @@ -9,6 +9,7 @@ import { import { Button } from '@/components/ui/button' import { ToastAction } from '@/components/ui/toast' import { toast } from '@/components/ui/use-toast' +import { ErrorResponse } from '@/models/api' import { Loader, Trash2 } from 'lucide-react' import { FC, useState } from 'react' @@ -31,10 +32,11 @@ const DeleteApiKeyDialog: FC = ({ deleteFnc }) => { setOpen(false) } catch (e) { console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with revoking your API Key.', + title, + description, action: ( Try again diff --git a/apps/ai/clients/admin-console/src/components/api-keys/error.tsx b/apps/ai/clients/admin-console/src/components/api-keys/error.tsx deleted file mode 100644 index 77818d57..00000000 --- a/apps/ai/clients/admin-console/src/components/api-keys/error.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import PageErrorMessage from '@/components/layout/page-error-message' -import { FC } from 'react' - -const ApiKeysError: FC = () => ( - -) - -export default ApiKeysError diff --git a/apps/ai/clients/admin-console/src/components/api-keys/generate-api-key-dialog.tsx b/apps/ai/clients/admin-console/src/components/api-keys/generate-api-key-dialog.tsx index e8a2de8a..47200fc7 100644 --- a/apps/ai/clients/admin-console/src/components/api-keys/generate-api-key-dialog.tsx +++ b/apps/ai/clients/admin-console/src/components/api-keys/generate-api-key-dialog.tsx @@ -19,6 +19,7 @@ import { Toaster } from '@/components/ui/toaster' import { toast } from '@/components/ui/use-toast' import { usePostApiKey } from '@/hooks/api/api-keys/usePostApiKey' import { copyToClipboard } from '@/lib/utils' +import { ErrorResponse } from '@/models/api' import { yupResolver } from '@hookform/resolvers/yup' import { Copy, Loader, Plus } from 'lucide-react' import { FC, useState } from 'react' @@ -89,12 +90,13 @@ const GenerateApiKeyDialog: FC = ({ onGeneratedKey }) => { title: 'Secret API key generated', description: `Your secret key was generated successfully.`, }) - } catch (error) { - console.error(`Error generating the API key: ${error}`) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem generating your secret key.', + title, + description, action: ( = ({ }) await onPaymentMethodAdded() reset() - } catch (error) { - console.error(error) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'The payment method could not be added.', + title, + description, }) } } } catch (error) { toast({ variant: 'destructive', - title: 'Oops! Something went wrong', + title: 'An error occurred', description: 'The payment method could not be added due to service provider error.', }) diff --git a/apps/ai/clients/admin-console/src/components/databases/database-connection-form-dialog.tsx b/apps/ai/clients/admin-console/src/components/databases/database-connection-form-dialog.tsx index dc9332bc..ba1b58c5 100644 --- a/apps/ai/clients/admin-console/src/components/databases/database-connection-form-dialog.tsx +++ b/apps/ai/clients/admin-console/src/components/databases/database-connection-form-dialog.tsx @@ -20,7 +20,7 @@ import { Toaster } from '@/components/ui/toaster' import { toast } from '@/components/ui/use-toast' import usePostDatabaseConnection from '@/hooks/api/usePostDatabaseConnection' import { formatDriver } from '@/lib/domain/database' -import { DatabaseConnection, Databases } from '@/models/api' +import { DatabaseConnection, Databases, ErrorResponse } from '@/models/api' import { yupResolver } from '@hookform/resolvers/yup' import { AlertCircle, @@ -99,10 +99,12 @@ const DatabaseConnectionFormDialog: FC = ({ setDatabaseConnected(true) onConnected(undefined, false) } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem connecting your Database.', + title, + description, action: ( = ({ await onRefresh(optimisticDatabasesUpdate) } catch (e) { console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem refreshing your Databases.', + title, + description, action: ( = ({ } } catch (e) { console.error(e) - onUpdateDatabasesData(databases) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: - 'There was a problem scanning your Databases table schemas.', + title, + description, action: ( Try again ), }) + onUpdateDatabasesData(databases) } finally { setIsSynchronizing(false) } diff --git a/apps/ai/clients/admin-console/src/components/databases/error.tsx b/apps/ai/clients/admin-console/src/components/databases/error.tsx deleted file mode 100644 index 834658d3..00000000 --- a/apps/ai/clients/admin-console/src/components/databases/error.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import PageErrorMessage from '@/components/layout/page-error-message' -import { FC } from 'react' - -const DatabasesError: FC = () => ( - -) - -export default DatabasesError diff --git a/apps/ai/clients/admin-console/src/components/error/error-details.tsx b/apps/ai/clients/admin-console/src/components/error/error-details.tsx new file mode 100644 index 00000000..15b31afa --- /dev/null +++ b/apps/ai/clients/admin-console/src/components/error/error-details.tsx @@ -0,0 +1,69 @@ +import { cn, copyToClipboard } from '@/lib/utils' +import { ErrorResponse } from '@/models/api' +import { Copy } from 'lucide-react' +import { FC, HTMLAttributes } from 'react' +import { Button } from '../ui/button' +import { toast } from '../ui/use-toast' + +type ErrorDetailsProps = HTMLAttributes & { + error: ErrorResponse + size?: 'default' | 'small' + displayTitle?: boolean +} + +const ErrorDetails: FC = ({ + error, + displayTitle = true, + size = 'default', + className, + ...props +}) => { + const handleCopyErrorTraceId = async () => { + try { + await copyToClipboard(error?.trace_id) + toast({ + variant: 'success', + title: 'Error Trace ID copied!', + }) + } catch (error) { + console.error('Could not copy text: ', error) + toast({ + variant: 'destructive', + title: 'Could not copy the Error Trace ID', + }) + } + } + + const isSmall = size === 'small' + + return ( +
+ {displayTitle && Error details} +
+ Trace ID: + {error.trace_id} + +
+
+ Description: + {error.message} +
+
+ ) +} + +export default ErrorDetails diff --git a/apps/ai/clients/admin-console/src/components/error/page-error-message.tsx b/apps/ai/clients/admin-console/src/components/error/page-error-message.tsx new file mode 100644 index 00000000..e136d53c --- /dev/null +++ b/apps/ai/clients/admin-console/src/components/error/page-error-message.tsx @@ -0,0 +1,27 @@ +import { Toaster } from '@/components/ui/toaster' +import { ErrorResponse } from '@/models/api' +import { AlertOctagon } from 'lucide-react' +import { FC, HTMLAttributes } from 'react' +import ErrorDetails from './error-details' + +export type PageErrorMessageProps = HTMLAttributes & { + message: string + error?: ErrorResponse +} + +const PageErrorMessage: FC = ({ + message, + error, + ...props +}) => ( +
+
+ + {message} +
+ {error && } + +
+) + +export default PageErrorMessage diff --git a/apps/ai/clients/admin-console/src/components/fine-tunnings/error.tsx b/apps/ai/clients/admin-console/src/components/fine-tunnings/error.tsx deleted file mode 100644 index 5034e78f..00000000 --- a/apps/ai/clients/admin-console/src/components/fine-tunnings/error.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import PageErrorMessage from '@/components/layout/page-error-message' -import { FC } from 'react' - -const FineTunningsError: FC = () => ( - -) - -export default FineTunningsError diff --git a/apps/ai/clients/admin-console/src/components/golden-sql/error.tsx b/apps/ai/clients/admin-console/src/components/golden-sql/error.tsx deleted file mode 100644 index a678d33b..00000000 --- a/apps/ai/clients/admin-console/src/components/golden-sql/error.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import PageErrorMessage from '@/components/layout/page-error-message' -import { FC } from 'react' - -const QueriesError: FC = () => ( - -) - -export default QueriesError diff --git a/apps/ai/clients/admin-console/src/components/hoc/WithApiFetcher.tsx b/apps/ai/clients/admin-console/src/components/hoc/WithApiFetcher.tsx new file mode 100644 index 00000000..41e6f86c --- /dev/null +++ b/apps/ai/clients/admin-console/src/components/hoc/WithApiFetcher.tsx @@ -0,0 +1,19 @@ +import useApiFetcher from '@/hooks/api/generics/useApiFetcher' +import { FC, ReactNode } from 'react' +import { SWRConfig } from 'swr' + +const WithApiFetcher: FC<{ children: ReactNode }> = ({ children }) => { + const { apiFetcher } = useApiFetcher() + return ( + + {children} + + ) +} + +export default WithApiFetcher diff --git a/apps/ai/clients/admin-console/src/components/layout/page-error-message.tsx b/apps/ai/clients/admin-console/src/components/layout/page-error-message.tsx deleted file mode 100644 index 7befff2f..00000000 --- a/apps/ai/clients/admin-console/src/components/layout/page-error-message.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { AlertOctagon } from 'lucide-react' -import { FC, HTMLAttributes } from 'react' - -export type PageErrorMessageProps = HTMLAttributes & { - message: string -} - -const PageErrorMessage: FC = ({ message, ...props }) => ( -
- - {message} -
-) - -export default PageErrorMessage diff --git a/apps/ai/clients/admin-console/src/components/organization/edit-organization-dialog.tsx b/apps/ai/clients/admin-console/src/components/organization/edit-organization-dialog.tsx index 6226ffe8..f25d013b 100644 --- a/apps/ai/clients/admin-console/src/components/organization/edit-organization-dialog.tsx +++ b/apps/ai/clients/admin-console/src/components/organization/edit-organization-dialog.tsx @@ -21,6 +21,7 @@ import { Toaster } from '@/components/ui/toaster' import { toast } from '@/components/ui/use-toast' import { useAppContext } from '@/contexts/app-context' import { usePutOrganization } from '@/hooks/api/organization/usePutOrganization' +import { ErrorResponse } from '@/models/api' import { yupResolver } from '@hookform/resolvers/yup' import { Edit, Loader } from 'lucide-react' import { FC, useEffect, useState } from 'react' @@ -70,11 +71,13 @@ const EditOrganizationDialog: FC = () => { title: 'Organization name updated', description: `The organization name has been updated.`, }) - } catch (error) { + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'The organization name could not be updated.', + title, + description, action: ( = ({ onOrganizationUpdate() setEditEnabled(false) form.reset() - } catch (error) { - console.error(error) - form.setError('llm_api_key', { - message: 'Invalid API key', - }) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: - 'There was a problem updating the API key. Please try again.', + title, + description, action: ( updateLlmApiKey()}> Try again ), }) + form.setError('llm_api_key', { + message: 'Invalid API key', + }) } finally { setUpdating(false) } diff --git a/apps/ai/clients/admin-console/src/components/organization/payment-methods-list.tsx b/apps/ai/clients/admin-console/src/components/organization/payment-methods-list.tsx index ed1ab757..0d286d10 100644 --- a/apps/ai/clients/admin-console/src/components/organization/payment-methods-list.tsx +++ b/apps/ai/clients/admin-console/src/components/organization/payment-methods-list.tsx @@ -1,6 +1,6 @@ import AddPaymentMethodDialog from '@/components/billing/add-payment-method-dialog' import LoadingList from '@/components/layout/loading-list' -import PageErrorMessage from '@/components/layout/page-error-message' +import PageErrorMessage from '@/components/error/page-error-message' import { AlertDialog, AlertDialogCancel, @@ -16,7 +16,7 @@ import { toast } from '@/components/ui/use-toast' import { useAppContext } from '@/contexts/app-context' import { useDeletePaymentMethod } from '@/hooks/api/billing/useDeletePaymentMethod' import usePaymentMethods from '@/hooks/api/billing/usePaymentMethods' -import { PaymentMethod } from '@/models/api' +import { ErrorResponse, PaymentMethod } from '@/models/api' import { CreditCard, Loader, Trash2 } from 'lucide-react' import CreditCardLogo from '@/components/billing/credit-card-logo' @@ -49,12 +49,13 @@ const PaymentMethodsList = () => { title: 'Payment Method Removed', description: `The payment method ending in ${pm.last4} was removed from the Organization.`, }) - } catch (error) { - console.error(error) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem removing the payment method.', + title, + description, action: ( {

Payment Methods

{error ? ( - + ) : ( <>
diff --git a/apps/ai/clients/admin-console/src/components/queries/error.tsx b/apps/ai/clients/admin-console/src/components/queries/error.tsx deleted file mode 100644 index c24e2e1f..00000000 --- a/apps/ai/clients/admin-console/src/components/queries/error.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import PageErrorMessage from '@/components/layout/page-error-message' -import { FC } from 'react' - -const QueriesError: FC = () => ( - -) - -export default QueriesError diff --git a/apps/ai/clients/admin-console/src/components/query/error.tsx b/apps/ai/clients/admin-console/src/components/query/error.tsx deleted file mode 100644 index ad239fc6..00000000 --- a/apps/ai/clients/admin-console/src/components/query/error.tsx +++ /dev/null @@ -1,11 +0,0 @@ -import PageErrorMessage from '@/components/layout/page-error-message' -import { FC } from 'react' - -const QueryError: FC = () => ( - -) - -export default QueryError diff --git a/apps/ai/clients/admin-console/src/components/query/message-section.tsx b/apps/ai/clients/admin-console/src/components/query/message-section.tsx index 35334f11..d82123b8 100644 --- a/apps/ai/clients/admin-console/src/components/query/message-section.tsx +++ b/apps/ai/clients/admin-console/src/components/query/message-section.tsx @@ -26,7 +26,7 @@ import { ToastAction } from '@/components/ui/toast' import { toast } from '@/components/ui/use-toast' import useQueryGenerateMessage from '@/hooks/api/query/useQueryGenerateMessage' import { cn } from '@/lib/utils' -import { Query } from '@/models/api' +import { ErrorResponse, Query } from '@/models/api' import { formatDistance } from 'date-fns' import { Bot, Edit, Info, Loader } from 'lucide-react' import Image from 'next/image' @@ -88,12 +88,13 @@ const MessageSection: FC = ({ title: 'Message updated', description: 'The query message was updated successfully.', }) - } catch (error) { - console.error(error) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with updating the message.', + title, + description, action: ( = ({ description: 'The query message was sent to the Slack thread.', }) await onMessageSent() - } catch (error) { - console.error(error) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with sending the message.', + title, + description, action: ( Try again diff --git a/apps/ai/clients/admin-console/src/components/query/workspace.tsx b/apps/ai/clients/admin-console/src/components/query/workspace.tsx index 4e92679b..a18ce51e 100644 --- a/apps/ai/clients/admin-console/src/components/query/workspace.tsx +++ b/apps/ai/clients/admin-console/src/components/query/workspace.tsx @@ -26,7 +26,7 @@ import { isVerified, } from '@/lib/domain/query' import { cn } from '@/lib/utils' -import { Query, QueryStatus } from '@/models/api' +import { ErrorResponse, Query, QueryStatus } from '@/models/api' import { EDomainQueryWorkspaceStatus, QueryWorkspaceStatus, @@ -149,12 +149,13 @@ const QueryWorkspace: FC = ({ description: 'The query was resubmitted to the platform for a new response.', }) - } catch (error) { - console.error(error) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with resubmitting the query.', + title, + description, action: ( Try again @@ -179,10 +180,11 @@ const QueryWorkspace: FC = ({ }) } catch (e) { console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with running the query.', + title, + description, action: ( Try again @@ -223,11 +225,11 @@ const QueryWorkspace: FC = ({ } } catch (e) { console.error(e) - setCurrentQueryStatus(getWorkspaceQueryStatus(status)) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with updating the query status.', + title, + description, action: ( = ({ ), }) + setCurrentQueryStatus(getWorkspaceQueryStatus(status)) } finally { setUpdatingQueryStatus(false) } diff --git a/apps/ai/clients/admin-console/src/components/usage/edit-spending-limit-dialog.tsx b/apps/ai/clients/admin-console/src/components/usage/edit-spending-limit-dialog.tsx index c13eaaab..5e6f502a 100644 --- a/apps/ai/clients/admin-console/src/components/usage/edit-spending-limit-dialog.tsx +++ b/apps/ai/clients/admin-console/src/components/usage/edit-spending-limit-dialog.tsx @@ -23,7 +23,7 @@ import { toast } from '@/components/ui/use-toast' import { useAppContext } from '@/contexts/app-context' import usePutSpendingLimits from '@/hooks/api/billing/usePutSpendingLimits' import { toCents, toDollars } from '@/lib/utils' -import { SpendingLimits } from '@/models/api' +import { ErrorResponse, SpendingLimits } from '@/models/api' import { yupResolver } from '@hookform/resolvers/yup' import { Loader } from 'lucide-react' import { FC, useState } from 'react' @@ -89,11 +89,13 @@ const EditSpendingLimitDialog: FC = ({ variant: 'success', title: 'Spending limit updated', }) - } catch (error) { + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'The spending limit could not be updated.', + title, + description, action: ( > = ({ className }) => { ) } else if (error) { pageContent = ( - + ) } else if (usage) { if (usage.current_period_start && usage.current_period_end) { diff --git a/apps/ai/clients/admin-console/src/components/user/invite-member-dialog.tsx b/apps/ai/clients/admin-console/src/components/user/invite-member-dialog.tsx index a01be0b9..07d75580 100644 --- a/apps/ai/clients/admin-console/src/components/user/invite-member-dialog.tsx +++ b/apps/ai/clients/admin-console/src/components/user/invite-member-dialog.tsx @@ -20,6 +20,8 @@ import { ToastAction } from '@/components/ui/toast' import { Toaster } from '@/components/ui/toaster' import { toast } from '@/components/ui/use-toast' import { usePostUserToOrganization } from '@/hooks/api/user/usePostUserToOrganization' +import { ErrorResponse } from '@/models/api' +import { EUserErrorCode } from '@/models/errorCodes' import { yupResolver } from '@hookform/resolvers/yup' import { AlertCircle, Loader, ShieldAlert } from 'lucide-react' import { FC, useState } from 'react' @@ -72,25 +74,22 @@ const InviteMemberDialog: FC = ({ }) onInviteMember() setError(undefined) - } catch (error) { - let errorDescription: string - let toastAction: JSX.Element | undefined - const errorResponse = error as Error - console.error(`Error adding user: ${errorResponse}`) - if (errorResponse.cause === 409) { - switch (errorResponse.message) { - case 'USER_ALREADY_EXISTS_IN_ORG': - errorDescription = `The user is already a member of your Organization.` - break - case 'USER_ALREADY_EXISTS_IN_OTHER_ORG': - errorDescription = `The user is already a member of another Organization.` - break - default: - errorDescription = `There was a problem inviting the new member to the Organization.` - } - } else { - errorDescription = `There was a problem inviting the new member to the Organization.` - toastAction = ( + } catch (e) { + console.error(e) + const { + message: title, + trace_id: description, + error_code, + } = e as ErrorResponse + let action: JSX.Element | undefined + if ( + ![ + EUserErrorCode.user_exists_in_org, + EUserErrorCode.user_exists_in_other_org, + ].includes(error_code as EUserErrorCode) + ) { + // should be able to retry if the error is not related to user already existing + action = ( form.handleSubmit(handleInviteMember)()} @@ -99,12 +98,12 @@ const InviteMemberDialog: FC = ({ ) } - setError(errorDescription) + setError(title) toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: errorDescription, - action: toastAction, + title, + description, + action, }) } finally { setInvitingMember(false) diff --git a/apps/ai/clients/admin-console/src/components/user/user-list.tsx b/apps/ai/clients/admin-console/src/components/user/user-list.tsx index 01b3647c..c1742926 100644 --- a/apps/ai/clients/admin-console/src/components/user/user-list.tsx +++ b/apps/ai/clients/admin-console/src/components/user/user-list.tsx @@ -19,14 +19,15 @@ import UserPicture from '@/components/user/user-picture' import { useAppContext } from '@/contexts/app-context' import { useDeleteUser } from '@/hooks/api/user/useDeleteUser' import useUsers from '@/hooks/api/user/useUsers' -import { User } from '@/models/api' +import { ErrorResponse, User } from '@/models/api' import { Trash2, UserPlus2 } from 'lucide-react' import { useState } from 'react' +import PageErrorMessage from '../error/page-error-message' const UserList = () => { const [openInviteMemberDialog, setOpenInviteMemberDialog] = useState(false) const { user: currentUser } = useAppContext() - const { isLoading, users, mutate } = useUsers() + const { isLoading, error, users, mutate } = useUsers() const deleteUser = useDeleteUser() const handleMemberInvited = () => { @@ -45,13 +46,13 @@ const UserList = () => { } was removed from the Organization.`, }) mutate() - } catch (error) { - console.error(error) + } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: - 'There was a problem removing the user from the Organization.', + title, + description, action: ( {
{isLoading ? ( + ) : error ? ( + ) : ( users && (
    diff --git a/apps/ai/clients/admin-console/src/contexts/subscription-context.tsx b/apps/ai/clients/admin-console/src/contexts/subscription-context.tsx index 835837fb..6afd977c 100644 --- a/apps/ai/clients/admin-console/src/contexts/subscription-context.tsx +++ b/apps/ai/clients/admin-console/src/contexts/subscription-context.tsx @@ -1,14 +1,6 @@ +import { ESubscriptionErrorCode } from '@/models/errorCodes' import React, { ReactNode, createContext, useContext, useState } from 'react' -export enum ESubscriptionErrorCode { - no_payment_method = 'no_payment_method', - spending_limit_exceeded = 'spending_limit_exceeded', - hard_spending_limit_exceeded = 'hard_spending_limit_exceeded', - subscription_past_due = 'subscription_past_due', - subscription_canceled = 'subscription_canceled', - unknown_subscription_status = 'unknown_subscription_status', -} - export type SubscriptionErrorCode = keyof typeof ESubscriptionErrorCode interface SubscriptionContextType { @@ -48,11 +40,3 @@ export const useSubscription = () => { } return context } - -export const isSubscriptionErrorCode = ( - errorCode: string, -): errorCode is ESubscriptionErrorCode => { - return Object.values(ESubscriptionErrorCode).includes( - errorCode as ESubscriptionErrorCode, - ) -} diff --git a/apps/ai/clients/admin-console/src/hooks/api/api-keys/useApiKeys.ts b/apps/ai/clients/admin-console/src/hooks/api/api-keys/useApiKeys.ts index 0bc492e9..9e13a966 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/api-keys/useApiKeys.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/api-keys/useApiKeys.ts @@ -1,13 +1,13 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { ApiKeys } from '@/models/api' +import { ApiKeys, ErrorResponse } from '@/models/api' import useSWR, { KeyedMutator } from 'swr' interface ApiKeysResponse { apiKeys: ApiKeys | undefined isLoading: boolean isValidating: boolean - error: unknown + error: ErrorResponse | null mutate: KeyedMutator } diff --git a/apps/ai/clients/admin-console/src/hooks/api/billing/usePaymentMethods.ts b/apps/ai/clients/admin-console/src/hooks/api/billing/usePaymentMethods.ts index ed883040..8c8537bc 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/billing/usePaymentMethods.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/billing/usePaymentMethods.ts @@ -1,13 +1,13 @@ import { API_URL } from '@/config' import { useAppContext } from '@/contexts/app-context' import { useAuth } from '@/contexts/auth-context' -import { PaymentMethods } from '@/models/api' +import { ErrorResponse, PaymentMethods } from '@/models/api' import useSWR, { KeyedMutator } from 'swr' interface PaymentMethodsResponse { paymentMethods: PaymentMethods | undefined isLoading: boolean - error: unknown + error: ErrorResponse | null mutate: KeyedMutator } diff --git a/apps/ai/clients/admin-console/src/hooks/api/billing/useSpendingLimits.tsx b/apps/ai/clients/admin-console/src/hooks/api/billing/useSpendingLimits.tsx index 411d8eaa..3e44a688 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/billing/useSpendingLimits.tsx +++ b/apps/ai/clients/admin-console/src/hooks/api/billing/useSpendingLimits.tsx @@ -1,14 +1,14 @@ import { API_URL } from '@/config' import { useAppContext } from '@/contexts/app-context' import { useAuth } from '@/contexts/auth-context' -import { SpendingLimits } from '@/models/api' +import { ErrorResponse, SpendingLimits } from '@/models/api' import useSWR, { KeyedMutator } from 'swr' interface SpendingLimitsResponse { limits: SpendingLimits | undefined isLoading: boolean isValidating: boolean - error: unknown + error: ErrorResponse | null mutate: KeyedMutator } diff --git a/apps/ai/clients/admin-console/src/hooks/api/billing/useUsage.tsx b/apps/ai/clients/admin-console/src/hooks/api/billing/useUsage.tsx index 378b7e97..d6bb8091 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/billing/useUsage.tsx +++ b/apps/ai/clients/admin-console/src/hooks/api/billing/useUsage.tsx @@ -1,14 +1,14 @@ import { API_URL } from '@/config' import { useAppContext } from '@/contexts/app-context' import { useAuth } from '@/contexts/auth-context' -import { Usage } from '@/models/api' +import { ErrorResponse, Usage } from '@/models/api' import useSWR, { KeyedMutator } from 'swr' interface UsageResponse { usage: Usage | undefined isLoading: boolean isValidating: boolean - error: unknown + error: ErrorResponse | null mutate: KeyedMutator } diff --git a/apps/ai/clients/admin-console/src/hooks/api/fine-tuning/useFinetunings.ts b/apps/ai/clients/admin-console/src/hooks/api/fine-tuning/useFinetunings.ts index a8c26715..873e5978 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/fine-tuning/useFinetunings.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/fine-tuning/useFinetunings.ts @@ -1,13 +1,13 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { FineTuningModels } from '@/models/api' +import { ErrorResponse, FineTuningModels } from '@/models/api' import useSWR, { KeyedMutator } from 'swr' interface FineTuningModelsResponse { models: FineTuningModels | undefined isLoading: boolean isValidating: boolean - error: unknown + error: ErrorResponse | null mutate: KeyedMutator } diff --git a/apps/ai/clients/admin-console/src/hooks/api/generics/useApiFetcher.ts b/apps/ai/clients/admin-console/src/hooks/api/generics/useApiFetcher.ts index cfbbd79e..4d52b688 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/generics/useApiFetcher.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/generics/useApiFetcher.ts @@ -1,10 +1,7 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { - ESubscriptionErrorCode, - isSubscriptionErrorCode, - useSubscription, -} from '@/contexts/subscription-context' +import { useSubscription } from '@/contexts/subscription-context' +import { isErrorResponse, isSubscriptionErrorCode } from '@/lib/domain/error' import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' @@ -48,16 +45,14 @@ const useApiFetcher = () => { router.push('api/auth/logout') } } else { - const serverError: { detail: string } = await response.json() - const error = new Error(serverError.detail, { - cause: response.status, - }) - // for now the error codes are in the `detail` field. This will become a JSON in the future - if (isSubscriptionErrorCode(serverError.detail)) { - setSubscriptionStatus(serverError.detail as ESubscriptionErrorCode) - } + const errorResponse = await response.json() + if (isErrorResponse(errorResponse)) { + if (isSubscriptionErrorCode(errorResponse.error_code)) { + setSubscriptionStatus(errorResponse.error_code) + } - throw error + throw errorResponse + } } } return response.json() diff --git a/apps/ai/clients/admin-console/src/hooks/api/generics/usePagination.ts b/apps/ai/clients/admin-console/src/hooks/api/generics/usePagination.ts index 3e50edca..4def35c8 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/generics/usePagination.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/generics/usePagination.ts @@ -1,5 +1,6 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' +import { ErrorResponse } from '@/models/api' import { KeyedMutator } from 'swr' import useSWRInfinite from 'swr/infinite' @@ -12,7 +13,7 @@ export interface PageResponse { isLoadingFirst: boolean isLoadingMore: boolean isReachingEnd: boolean - error: unknown + error: ErrorResponse | null page: number setPage: ( page: number | ((_page: number) => number), diff --git a/apps/ai/clients/admin-console/src/hooks/api/organization/useOrganizations.ts b/apps/ai/clients/admin-console/src/hooks/api/organization/useOrganizations.ts index a7ab169f..23819be1 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/organization/useOrganizations.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/organization/useOrganizations.ts @@ -1,12 +1,12 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { Organizations } from '@/models/api' +import { ErrorResponse, Organizations } from '@/models/api' import useSWR from 'swr' interface OrganizationsResponse { organizations: Organizations | undefined isLoading: boolean - error: unknown + error: ErrorResponse | null } const useOrganizations = (): OrganizationsResponse => { diff --git a/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnection.ts b/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnection.ts index cb19f51f..62dd781f 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnection.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnection.ts @@ -1,12 +1,12 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { DatabaseConnection } from '@/models/api' +import { DatabaseConnection, ErrorResponse } from '@/models/api' import useSWR from 'swr' interface DatabaseConnectionResponse { databaseConnection: DatabaseConnection | undefined isLoading: boolean - error: unknown + error: ErrorResponse | null } const useDatabaseConnection = ( diff --git a/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnections.ts b/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnections.ts index 11a5e437..5e3c54cd 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnections.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/useDatabaseConnections.ts @@ -1,12 +1,12 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { DatabaseConnections } from '@/models/api' +import { DatabaseConnections, ErrorResponse } from '@/models/api' import useSWR from 'swr' interface DatabaseConnectionResponse { dbConnections: DatabaseConnections | undefined isLoading: boolean - error: unknown + error: ErrorResponse | null } const useDatabaseConnections = (): DatabaseConnectionResponse => { diff --git a/apps/ai/clients/admin-console/src/hooks/api/useDatabases.ts b/apps/ai/clients/admin-console/src/hooks/api/useDatabases.ts index 6443dc8c..8c795fb8 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/useDatabases.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/useDatabases.ts @@ -1,14 +1,14 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' import useApiFetcher from '@/hooks/api/generics/useApiFetcher' -import { Databases } from '@/models/api' +import { Databases, ErrorResponse } from '@/models/api' import { useEffect, useState } from 'react' import useSWR from 'swr' interface DatabasesResponse { databases: Databases | undefined isLoading: boolean - error: unknown + error: ErrorResponse | null mutate: ( optimisticData?: Databases, refresh?: boolean, diff --git a/apps/ai/clients/admin-console/src/hooks/api/user/useUsers.ts b/apps/ai/clients/admin-console/src/hooks/api/user/useUsers.ts index 840d32f2..a3a085c4 100644 --- a/apps/ai/clients/admin-console/src/hooks/api/user/useUsers.ts +++ b/apps/ai/clients/admin-console/src/hooks/api/user/useUsers.ts @@ -1,12 +1,12 @@ import { API_URL } from '@/config' import { useAuth } from '@/contexts/auth-context' -import { Users } from '@/models/api' +import { ErrorResponse, Users } from '@/models/api' import useSWR, { KeyedMutator } from 'swr' interface UsersResponse { users: Users | undefined isLoading: boolean - error: unknown + error: ErrorResponse | null mutate: KeyedMutator } diff --git a/apps/ai/clients/admin-console/src/hooks/database/useColumnResource.ts b/apps/ai/clients/admin-console/src/hooks/database/useColumnResource.ts index 49dd21fd..a953969c 100644 --- a/apps/ai/clients/admin-console/src/hooks/database/useColumnResource.ts +++ b/apps/ai/clients/admin-console/src/hooks/database/useColumnResource.ts @@ -4,7 +4,11 @@ import { API_URL } from '@/config' import useApiFetcher from '@/hooks/api/generics/useApiFetcher' import { UseDatabaseResourceFromTree } from '@/hooks/database/useDatabaseResourceFromTree' import { isColumnResource } from '@/lib/domain/database' -import { ColumnDescription, TableDescription } from '@/models/api' +import { + ColumnDescription, + ErrorResponse, + TableDescription, +} from '@/models/api' import { ColumnResource, DatabaseResourceType } from '@/models/domain' import { useCallback } from 'react' import useSWR, { mutate } from 'swr' @@ -62,10 +66,12 @@ export const useColumnResource = ( title: `${columnDescription?.name} description updated`, }) } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: `${columnDescription?.name} description could not be updated`, + title, + description, }) } }, diff --git a/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResource.ts b/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResource.ts index bfe634c4..289cbd60 100644 --- a/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResource.ts +++ b/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResource.ts @@ -4,7 +4,7 @@ import { API_URL } from '@/config' import useApiFetcher from '@/hooks/api/generics/useApiFetcher' import { UseDatabaseResourceFromTree } from '@/hooks/database/useDatabaseResourceFromTree' import { isDatabaseResource } from '@/lib/domain/database' -import { Instruction } from '@/models/api' +import { ErrorResponse, Instruction } from '@/models/api' import { DatabaseResource, DatabaseResourceType } from '@/models/domain' import { useCallback } from 'react' import useSWR, { mutate } from 'swr' @@ -26,7 +26,6 @@ export const useDatabaseResource = ( isLoading, error, } = useSWR(resourceUrl, apiFetcher, { - errorRetryCount: 0, revalidateOnFocus: false, revalidateIfStale: false, }) @@ -61,10 +60,12 @@ export const useDatabaseResource = ( ) toast({ variant: 'success', title: 'Database instructions updated' }) } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'Database instructions could not be updated', + title, + description, }) } }, diff --git a/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResourceFromTree.ts b/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResourceFromTree.ts index 5d483488..7d3b5f52 100644 --- a/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResourceFromTree.ts +++ b/apps/ai/clients/admin-console/src/hooks/database/useDatabaseResourceFromTree.ts @@ -7,6 +7,7 @@ import { isDatabaseResource, isTableResource, } from '@/lib/domain/database' +import { ErrorResponse } from '@/models/api' import { ColumnResource, DatabaseResource, @@ -17,7 +18,7 @@ import { export interface UseDatabaseResourceFromTree { resource: R | null isLoading: boolean - error: unknown + error: ErrorResponse | null updateResource: ((newText: string) => Promise) | null } diff --git a/apps/ai/clients/admin-console/src/hooks/database/useTableResource.ts b/apps/ai/clients/admin-console/src/hooks/database/useTableResource.ts index 23d7a41c..1f3e1b82 100644 --- a/apps/ai/clients/admin-console/src/hooks/database/useTableResource.ts +++ b/apps/ai/clients/admin-console/src/hooks/database/useTableResource.ts @@ -4,7 +4,7 @@ import { API_URL } from '@/config' import useApiFetcher from '@/hooks/api/generics/useApiFetcher' import { UseDatabaseResourceFromTree } from '@/hooks/database/useDatabaseResourceFromTree' import { isTableResource } from '@/lib/domain/database' -import { TableDescription } from '@/models/api' +import { ErrorResponse, TableDescription } from '@/models/api' import { DatabaseResourceType, TableResource } from '@/models/domain' import { useCallback } from 'react' import useSWR, { mutate } from 'swr' @@ -56,10 +56,12 @@ export const useTableResource = ( title: `${tableDescription?.table_name} description updated`, }) } catch (e) { + console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: `${tableDescription?.table_name} description could not be updated`, + title, + description, }) } }, diff --git a/apps/ai/clients/admin-console/src/lib/domain/error.tsx b/apps/ai/clients/admin-console/src/lib/domain/error.tsx new file mode 100644 index 00000000..190bfd74 --- /dev/null +++ b/apps/ai/clients/admin-console/src/lib/domain/error.tsx @@ -0,0 +1,20 @@ +import { ErrorResponse } from '@/models/api' +import { ESubscriptionErrorCode } from '@/models/errorCodes' + +export const isErrorResponse = (error: unknown): error is ErrorResponse => { + return ( + typeof error === 'object' && + error !== null && + 'message' in error && + 'error_code' in error && + 'trace_id' in error + ) +} + +export const isSubscriptionErrorCode = ( + errorCode: string, +): errorCode is ESubscriptionErrorCode => { + return Object.values(ESubscriptionErrorCode).includes( + errorCode as ESubscriptionErrorCode, + ) +} diff --git a/apps/ai/clients/admin-console/src/models/api.ts b/apps/ai/clients/admin-console/src/models/api.ts index ebb2ece4..7b8d2a62 100644 --- a/apps/ai/clients/admin-console/src/models/api.ts +++ b/apps/ai/clients/admin-console/src/models/api.ts @@ -1,5 +1,12 @@ import { UserProfile } from '@auth0/nextjs-auth0/client' +export interface ErrorResponse { + trace_id: string + message: string + error_code: string + detail?: Record +} + export interface SlackTeam { id: string | null name: string | null diff --git a/apps/ai/clients/admin-console/src/models/errorCodes.ts b/apps/ai/clients/admin-console/src/models/errorCodes.ts new file mode 100644 index 00000000..c9a44dee --- /dev/null +++ b/apps/ai/clients/admin-console/src/models/errorCodes.ts @@ -0,0 +1,13 @@ +export enum ESubscriptionErrorCode { + no_payment_method = 'no_payment_method', + spending_limit_exceeded = 'spending_limit_exceeded', + hard_spending_limit_exceeded = 'hard_spending_limit_exceeded', + subscription_past_due = 'subscription_past_due', + subscription_canceled = 'subscription_canceled', + unknown_subscription_status = 'unknown_subscription_status', +} + +export enum EUserErrorCode { + user_exists_in_org = 'user_exists_in_org', + user_exists_in_other_org = 'user_exists_in_other_org', +} diff --git a/apps/ai/clients/admin-console/src/pages/_app.tsx b/apps/ai/clients/admin-console/src/pages/_app.tsx index f82bd6e5..fa407abf 100644 --- a/apps/ai/clients/admin-console/src/pages/_app.tsx +++ b/apps/ai/clients/admin-console/src/pages/_app.tsx @@ -1,16 +1,14 @@ import WithAnalytics from '@/components/hoc/WithAnalytics' +import WithApiFetcher from '@/components/hoc/WithApiFetcher' import WithMobileRedirect from '@/components/hoc/WithMobileRedirect' import WithSubscription from '@/components/hoc/WithSubscription' import { AppContextProvider } from '@/contexts/app-context' import { AuthProvider } from '@/contexts/auth-context' import { SubscriptionProvider } from '@/contexts/subscription-context' -import useApiFetcher from '@/hooks/api/generics/useApiFetcher' import { cn } from '@/lib/utils' import '@/styles/globals.css' import type { AppProps } from 'next/app' import { Nunito_Sans, Source_Code_Pro } from 'next/font/google' -import { FC, ReactNode } from 'react' -import { SWRConfig } from 'swr' export const sourceCode = Source_Code_Pro({ subsets: ['latin'], @@ -25,19 +23,6 @@ export const mainFont = Nunito_Sans({ display: 'swap', }) -const SWRConfigWithAuth: FC<{ children: ReactNode }> = ({ children }) => { - const { apiFetcher } = useApiFetcher() - return ( - - {children} - - ) -} - export default function App({ Component, pageProps }: AppProps) { return ( @@ -46,7 +31,7 @@ export default function App({ Component, pageProps }: AppProps) { - +
    -
    +
    diff --git a/apps/ai/clients/admin-console/src/pages/auth/error.tsx b/apps/ai/clients/admin-console/src/pages/auth/error.tsx index 5d8fb258..ef41219e 100644 --- a/apps/ai/clients/admin-console/src/pages/auth/error.tsx +++ b/apps/ai/clients/admin-console/src/pages/auth/error.tsx @@ -23,7 +23,7 @@ const getErrorCause = (errorDescription: string): string => { case ERROR_CODES.EMAIL_NOT_VERIFIED: return 'Verify your email address' default: - return "Oops! We couldn't verify your identity" + return "We couldn't verify your identity" } } diff --git a/apps/ai/clients/admin-console/src/pages/billing/index.tsx b/apps/ai/clients/admin-console/src/pages/billing/index.tsx index 0a4bcf72..fa7ed238 100644 --- a/apps/ai/clients/admin-console/src/pages/billing/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/billing/index.tsx @@ -1,3 +1,4 @@ +import PageErrorMessage from '@/components/error/page-error-message' import PageLayout from '@/components/layout/page-layout' import PaymentMethodsList from '@/components/organization/payment-methods-list' import { ContentBox } from '@/components/ui/content-box' @@ -21,7 +22,7 @@ import { FC } from 'react' const BillingPage: FC = () => { const { organization } = useAppContext() - const { usage } = useUsage() + const { usage, isLoading, error } = useUsage() const router = useRouter() if (!organization) return <> @@ -62,17 +63,26 @@ const BillingPage: FC = () => {
- {usage ? ( - <> -
- ${toDollars(usage?.amount_due)} -
- - {billingCycle} - - - ) : ( + {isLoading ? ( + ) : error ? ( + + + + ) : ( + usage && ( + <> +
+ ${toDollars(usage?.amount_due)} +
+ + {billingCycle} + + + ) )}
diff --git a/apps/ai/clients/admin-console/src/pages/databases/index.tsx b/apps/ai/clients/admin-console/src/pages/databases/index.tsx index dbecfa85..b6800579 100644 --- a/apps/ai/clients/admin-console/src/pages/databases/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/databases/index.tsx @@ -1,8 +1,8 @@ import DatabaseConnectionFormDialog from '@/components/databases/database-connection-form-dialog' import DatabaseDetails from '@/components/databases/database-details' -import DatabasesError from '@/components/databases/error' import FirstDatabaseConnection from '@/components/databases/first-database-connection' import LoadingDatabases from '@/components/databases/loading' +import PageErrorMessage from '@/components/error/page-error-message' import PageLayout from '@/components/layout/page-layout' import { ContentBox } from '@/components/ui/content-box' import { GlobalTreeSelectionProvider } from '@/components/ui/tree-view-global-context' @@ -54,9 +54,12 @@ const DatabasesPage: FC = () => { if (error) { pageContent = ( -
- -
+ + + ) } else if ( (isLoading || isLoadingAfterFirstConnection) && diff --git a/apps/ai/clients/admin-console/src/pages/fine-tuning/index.tsx b/apps/ai/clients/admin-console/src/pages/fine-tuning/index.tsx index 506c5022..4263888c 100644 --- a/apps/ai/clients/admin-console/src/pages/fine-tuning/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/fine-tuning/index.tsx @@ -1,7 +1,7 @@ import { DataTable } from '@/components/data-table' import { LoadingTable } from '@/components/data-table/loading-table' +import PageErrorMessage from '@/components/error/page-error-message' import { finetunningsColumns } from '@/components/fine-tunnings/columns' -import FineTunningsError from '@/components/fine-tunnings/error' import PageLayout from '@/components/layout/page-layout' import { Button } from '@/components/ui/button' import { ContentBox } from '@/components/ui/content-box' @@ -22,7 +22,12 @@ const FineTuningPage: FC = () => { ) } else if (error) { - pageContent = + pageContent = ( + + ) } else if (models?.length === 0) { pageContent =
No models created yet.
} else { diff --git a/apps/ai/clients/admin-console/src/pages/golden-sql/index.tsx b/apps/ai/clients/admin-console/src/pages/golden-sql/index.tsx index d86e238a..2a7294c0 100644 --- a/apps/ai/clients/admin-console/src/pages/golden-sql/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/golden-sql/index.tsx @@ -1,14 +1,15 @@ import { DataTable } from '@/components/data-table' import { LoadingTable } from '@/components/data-table/loading-table' +import PageErrorMessage from '@/components/error/page-error-message' import { getColumns } from '@/components/golden-sql/columns' import PageLayout from '@/components/layout/page-layout' -import QueriesError from '@/components/queries/error' import { ContentBox } from '@/components/ui/content-box' import { ToastAction } from '@/components/ui/toast' import { Toaster } from '@/components/ui/toaster' import { toast } from '@/components/ui/use-toast' import { useDeleteGoldenSql } from '@/hooks/api/useDeleteGoldenSql' import useGoldenSqlList from '@/hooks/api/useGoldenSqlList' +import { ErrorResponse } from '@/models/api' import { withPageAuthRequired } from '@auth0/nextjs-auth0/client' import Head from 'next/head' import { FC, useCallback, useMemo, useState } from 'react' @@ -38,11 +39,11 @@ const GoldenSQLPage: FC = () => { }) } catch (e) { console.error(e) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: - 'There was a problem with deleting your golden sql query.', + title, + description, action: ( handleDelete(id)}> Try again @@ -74,7 +75,12 @@ const GoldenSQLPage: FC = () => { let pageContent: JSX.Element = <> if (error) { - pageContent = + pageContent = ( + + ) } else if (isLoadingFirst) { pageContent = } else diff --git a/apps/ai/clients/admin-console/src/pages/organization/index.tsx b/apps/ai/clients/admin-console/src/pages/organization/index.tsx index 7c600d87..1a1b275d 100644 --- a/apps/ai/clients/admin-console/src/pages/organization/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/organization/index.tsx @@ -2,7 +2,6 @@ import PageLayout from '@/components/layout/page-layout' import EditOrganizationDialog from '@/components/organization/edit-organization-dialog' import { Button } from '@/components/ui/button' import { ContentBox } from '@/components/ui/content-box' -import { Input } from '@/components/ui/input' import { toast } from '@/components/ui/use-toast' import UserList from '@/components/user/user-list' import { useAppContext } from '@/contexts/app-context' diff --git a/apps/ai/clients/admin-console/src/pages/playground/index.tsx b/apps/ai/clients/admin-console/src/pages/playground/index.tsx index 31e40222..4328739d 100644 --- a/apps/ai/clients/admin-console/src/pages/playground/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/playground/index.tsx @@ -1,3 +1,4 @@ +import ErrorDetails from '@/components/error/error-details' import PageLayout from '@/components/layout/page-layout' import { SectionHeader, @@ -26,10 +27,11 @@ import { getDomainStatusColors, mapQuery, } from '@/lib/domain/query' -import { Query } from '@/models/api' +import { ErrorResponse, Query } from '@/models/api' import { withPageAuthRequired } from '@auth0/nextjs-auth0/client' import clsx from 'clsx' import { + AlertOctagon, CaseSensitive, Code2, DatabaseZap, @@ -54,8 +56,11 @@ const PlaygroundPage: FC = () => { const [currentQueryPrompt, setCurrentQueryPrompt] = useState('') const [queryError, setQueryError] = useState() - const { dbConnections, isLoading: loadingDatabases } = - useDatabaseConnections() + const { + dbConnections, + isLoading: loadingDatabases, + error: dbError, + } = useDatabaseConnections() // Database connections const [selectedDbConnectionId, setSelectedDbConnectionId] = @@ -143,24 +148,25 @@ const PlaygroundPage: FC = () => { toast({ title: 'Generation completed', }) - } catch (error) { - console.error(error) - if ((error as Error).name === 'AbortError') { + } catch (e) { + console.error(e) + if ((e as Error)?.name === 'AbortError') { toast({ title: 'Generation cancelled', }) } else { - setQueryError((error as Error).message) + const { message: title, trace_id: description } = e as ErrorResponse toast({ variant: 'destructive', - title: 'Oops! Something went wrong', - description: 'There was a problem with the SQL generation.', + title, + description, action: ( Try again ), }) + setQueryError(title) } } finally { setSubmittingQuery(false) @@ -181,7 +187,19 @@ const PlaygroundPage: FC = () => { let content =
- if (!dbConnectionOptions?.length) { + if (dbError) { + content = ( +
+ + There was a problem fetching your database connections. + +
+ ) + } else if (!dbConnectionOptions?.length) { content = (
diff --git a/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx b/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx index d84e7625..ab35ae60 100644 --- a/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/queries/[queryId]/index.tsx @@ -1,7 +1,8 @@ +import PageErrorMessage from '@/components/error/page-error-message' import PageLayout from '@/components/layout/page-layout' -import QueryError from '@/components/query/error' import LoadingQuery from '@/components/query/loading' import QueryWorkspace from '@/components/query/workspace' +import { ContentBox } from '@/components/ui/content-box' import useQueries from '@/hooks/api/query/useQueries' import { useQuery } from '@/hooks/api/query/useQuery' import useQueryExecution from '@/hooks/api/query/useQueryExecution' @@ -73,9 +74,13 @@ const QueryPage: FC = () => { pageContent = } else if (error) { pageContent = ( -
- -
+ + + ) } else if (query) pageContent = ( diff --git a/apps/ai/clients/admin-console/src/pages/queries/index.tsx b/apps/ai/clients/admin-console/src/pages/queries/index.tsx index b68bebd6..8b0dddbe 100644 --- a/apps/ai/clients/admin-console/src/pages/queries/index.tsx +++ b/apps/ai/clients/admin-console/src/pages/queries/index.tsx @@ -1,8 +1,8 @@ import { DataTable } from '@/components/data-table' import { LoadingTable } from '@/components/data-table/loading-table' +import PageErrorMessage from '@/components/error/page-error-message' import PageLayout from '@/components/layout/page-layout' import { getColumns } from '@/components/queries/columns' -import QueriesError from '@/components/queries/error' import { ContentBox } from '@/components/ui/content-box' import { useAppContext } from '@/contexts/app-context' import useQueries from '@/hooks/api/query/useQueries' @@ -51,8 +51,13 @@ const QueriesPage: FC = () => { let pageContent: JSX.Element = <> - if (!isLoadingFirst && error) { - pageContent = + if (error) { + pageContent = ( + + ) } else if (isLoadingFirst) { pageContent = } else diff --git a/apps/ai/clients/slack/handlers/message/index.js b/apps/ai/clients/slack/handlers/message/index.js index 53fe8e85..97ba132a 100644 --- a/apps/ai/clients/slack/handlers/message/index.js +++ b/apps/ai/clients/slack/handlers/message/index.js @@ -40,7 +40,7 @@ async function handleMessage(context, say) { workspace_id: teamId, channel_id: channel_id, thread_ts: thread_ts, - } + }, } log('Fetching data from', endpointUrl) log('Request payload:', payload) @@ -54,15 +54,11 @@ async function handleMessage(context, say) { }) if (!response.ok) { try { - const { prompt_id, display_id, error_message } = - await response.json() + const { error_code, message, detail } = await response.json() error( - `API Response not ok: status code ${response.status}, ${response.statusText}, error message: ${error_message}, query id: ${prompt_id}` + `API Response not ok: status code ${response.status}, ${response.statusText}, error code: ${error_code}, error message: ${message}, detail: ${detail}` ) - const responseMessage = - prompt_id == undefined || display_id == undefined - ? `:exclamation: Sorry, something went wrong when I was processing your request. Please try again later.` - : `:warning: Sorry, something went wrong while generating response for query ${display_id}. We'll get back to you once it's been reviewed by the data-team admins.` + const responseMessage = `:warning: Sorry, something went wrong while generating response. Error message: \`${message}\`` await say({ blocks: [ { diff --git a/apps/ai/server/app.py b/apps/ai/server/app.py index 7bc432b2..412523bf 100644 --- a/apps/ai/server/app.py +++ b/apps/ai/server/app.py @@ -5,6 +5,9 @@ from fastapi.middleware.cors import CORSMiddleware from config import settings +from exceptions.exception_handlers import exception_handler +from exceptions.exceptions import BaseError +from middleware.error import UnknownErrorMiddleware from modules.auth import controller as auth_controller from modules.db_connection import controller as db_connection_controller from modules.finetuning import controller as finetuning_controller @@ -21,7 +24,6 @@ from modules.organization.invoice import controller as invoice_controller from modules.table_description import controller as table_description_controller from modules.user import controller as user_controller -from utils.exception import GenerationEngineError, query_engine_exception_handler tags_metadata = [ {"name": "Authentication", "description": "Login endpoints for authentication"}, @@ -46,6 +48,7 @@ app = FastAPI() +app.add_middleware(UnknownErrorMiddleware) app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -54,8 +57,7 @@ allow_headers=["*"], ) -app.add_exception_handler(GenerationEngineError, query_engine_exception_handler) - +app.add_exception_handler(BaseError, exception_handler) app.include_router(db_connection_controller.router, tags=["Database Connection"]) app.include_router(finetuning_controller.router, tags=["Finetuning"]) diff --git a/apps/ai/server/dataherald b/apps/ai/server/dataherald index b7e99006..9f34ba70 160000 --- a/apps/ai/server/dataherald +++ b/apps/ai/server/dataherald @@ -1 +1 @@ -Subproject commit b7e99006d64dfe0a0ac3c52bc77a80bde2372b28 +Subproject commit 9f34ba7016f4f5381528170a09c9a6a561341c20 diff --git a/apps/ai/server/exceptions/error_codes.py b/apps/ai/server/exceptions/error_codes.py new file mode 100644 index 00000000..ece39e15 --- /dev/null +++ b/apps/ai/server/exceptions/error_codes.py @@ -0,0 +1,44 @@ +from enum import Enum, EnumMeta + +from pydantic import BaseModel +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_500_INTERNAL_SERVER_ERROR, +) + + +class ErrorCodeData(BaseModel): + status_code: int + message: str + + +class ErrorCodeInterface(EnumMeta): + def __new__(cls, metacls, bases, classdict): + enum_class = super().__new__(cls, metacls, bases, classdict) + for name, member in enum_class.__members__.items(): + if not isinstance(member.value, ErrorCodeData): + raise TypeError( + f"Enum value for {name} must be an instance of ErrorCodeData" + ) + return enum_class + + +class BaseErrorCode(Enum, metaclass=ErrorCodeInterface): + """ "" + This class serves as a base for all Error code enums + It will enforce that all enum values are instances of ErrorCodeData + """ + + pass + + +class GeneralErrorCode(BaseErrorCode): + unknown_error = ErrorCodeData( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, message="Unknown error" + ) + unhandled_engine_error = ErrorCodeData( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, message="Unhandled engine error" + ) + reserved_metadata_key = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Metadata cannot contain reserved key" + ) diff --git a/apps/ai/server/exceptions/error_response.py b/apps/ai/server/exceptions/error_response.py new file mode 100644 index 00000000..edf699cd --- /dev/null +++ b/apps/ai/server/exceptions/error_response.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class ErrorResponse(BaseModel): + trace_id: str + error_code: str + message: str + detail: dict | None = None diff --git a/apps/ai/server/exceptions/exception_handlers.py b/apps/ai/server/exceptions/exception_handlers.py new file mode 100644 index 00000000..f14f98a3 --- /dev/null +++ b/apps/ai/server/exceptions/exception_handlers.py @@ -0,0 +1,64 @@ +from fastapi import Request +from fastapi.logger import logger +from fastapi.responses import JSONResponse +from httpx import Response + +from exceptions.error_response import ErrorResponse +from exceptions.exceptions import ( + BaseError, + EngineError, + UnhandledEngineError, +) +from exceptions.utils import is_http_error + + +async def exception_handler(request: Request, exc: BaseError): # noqa: ARG001 + + trace_id = exc.trace_id + error_code = exc.error_code + status_code = exc.status_code + message = exc.message + detail = {k: v for k, v in exc.detail.items() if v is not None} + + logger.error( + "ERROR\nTrace ID: %s, error_code: %s, detail: %s", + trace_id, + error_code, + detail, + ) + return JSONResponse( + status_code=status_code, + content=ErrorResponse( + trace_id=trace_id, + error_code=error_code, + message=message, + detail=detail, + ).dict(), + ) + + +def raise_engine_exception(response: Response, org_id: str): + if not is_http_error(response.status_code): + return + + response_json: dict = response.json() + + if "error_code" in response_json: + error_code = response_json["error_code"] + message = response_json.get( + "message", f"Unknown translation engine error_code: {error_code}" + ) + detail = response_json.get("detail", {}) + detail["organization_id"] = org_id + + logger.error("Handled error from translation engine: %s", message) + + raise EngineError( + error_code=error_code, + status_code=response.status_code, + message=message, + detail=detail, + ) + + logger.error("Unhandled error from translation engine: %s", response.text) + raise UnhandledEngineError() diff --git a/apps/ai/server/exceptions/exceptions.py b/apps/ai/server/exceptions/exceptions.py new file mode 100644 index 00000000..a63f8c4a --- /dev/null +++ b/apps/ai/server/exceptions/exceptions.py @@ -0,0 +1,117 @@ +from abc import ABC + +from exceptions.error_codes import BaseErrorCode, GeneralErrorCode +from exceptions.utils import generate_trace_id + + +class BaseError(ABC, Exception): + ERROR_CODES: BaseErrorCode = None + + @property + def trace_id(self) -> str: + return self._trace_id + + @property + def error_code(self) -> str: + return self._error_code + + @property + def status_code(self) -> str: + return self._status_code + + @property + def message(self) -> str: + return self._message + + @property + def detail(self) -> dict: + return self._detail + + def __init__( + self, + error_code: str = None, + status_code: str = None, + message: str = None, + detail: dict = None, + ) -> None: + + if type(self) is BaseError: + raise TypeError("BaseError class may not be instantiated directly") + + if self.ERROR_CODES is None or not hasattr(self.ERROR_CODES, "__members__"): + raise TypeError( + f"ERROR_CODES in {self.__class__.__name__} must be defined and be an enum type" + ) + + def handled_error_code(error_code: str) -> bool: + return error_code in self.ERROR_CODES.__members__ + + self._trace_id = generate_trace_id() + + if error_code is not None: + self._error_code = error_code + + if handled_error_code(error_code): + self._status_code = self.ERROR_CODES[error_code].value.status_code + self._message = ( + message + if message is not None + else self.ERROR_CODES[error_code].value.message + ) + else: + self._status_code = status_code if status_code is not None else "500" + self._message = ( + message + if message is not None + else f"Unknown error_code: {error_code}" + ) + else: + self._status_code = status_code if status_code is not None else 500 + self._message = message if message is not None else "Unknown error" + + self._detail = detail if detail is not None else {} + + +class GeneralError(BaseError): + """ + Base class for general exceptions + """ + + ERROR_CODES: BaseErrorCode = GeneralErrorCode + + +class EngineError(GeneralError): + def __init__( + self, + error_code: str, + status_code: int, + message: str, + detail: dict, + ) -> None: + super().__init__( + error_code=error_code, + status_code=status_code, + message=message, + detail=detail, + ) + + +class UnhandledEngineError(GeneralError): + def __init__(self) -> None: + super().__init__( + error_code=GeneralErrorCode.unhandled_engine_error.name, + ) + + +class ReservedMetadataKeyError(GeneralError): + def __init__(self) -> None: + super().__init__( + error_code=GeneralErrorCode.reserved_metadata_key.name, + ) + + +class UnknownError(GeneralError): + def __init__(self, error: str = None) -> None: + super().__init__( + error_code=GeneralErrorCode.unknown_error.name, detail={"error": error} + ) diff --git a/apps/ai/server/exceptions/utils.py b/apps/ai/server/exceptions/utils.py new file mode 100644 index 00000000..50e2791f --- /dev/null +++ b/apps/ai/server/exceptions/utils.py @@ -0,0 +1,11 @@ +import uuid + +from starlette.status import HTTP_400_BAD_REQUEST + + +def is_http_error(status_code: int) -> bool: + return status_code >= HTTP_400_BAD_REQUEST + + +def generate_trace_id(): + return f"E-{str(uuid.uuid4())}" # Generate a unique trace ID for each error diff --git a/apps/ai/server/middleware/error.py b/apps/ai/server/middleware/error.py new file mode 100644 index 00000000..46f441ad --- /dev/null +++ b/apps/ai/server/middleware/error.py @@ -0,0 +1,31 @@ +from fastapi import Request +from fastapi.logger import logger +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +from exceptions.error_codes import GeneralErrorCode +from exceptions.error_response import ErrorResponse +from exceptions.utils import generate_trace_id + + +class UnknownErrorMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + return await call_next(request) + except Exception: + trace_id = generate_trace_id() + logger.error(f"Unhandled ERROR\nTrace ID: {trace_id}", exc_info=True) + + error_code = GeneralErrorCode.unknown_error.name + status_code = GeneralErrorCode.unknown_error.value.status_code + message = GeneralErrorCode.unknown_error.value.message + + # raising an exception here causes problems with the exception handling + return JSONResponse( + status_code=status_code, + content=ErrorResponse( + trace_id=trace_id, + error_code=error_code, + message=message, + ).dict(), + ) diff --git a/apps/ai/server/modules/auth/models/exceptions.py b/apps/ai/server/modules/auth/models/exceptions.py new file mode 100644 index 00000000..188ee577 --- /dev/null +++ b/apps/ai/server/modules/auth/models/exceptions.py @@ -0,0 +1,95 @@ +from starlette.status import ( + HTTP_401_UNAUTHORIZED, + HTTP_403_FORBIDDEN, + HTTP_500_INTERNAL_SERVER_ERROR, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class AuthErrorCode(BaseErrorCode): + unauthorized_user = ErrorCodeData( + status_code=HTTP_401_UNAUTHORIZED, message="Unauthorized user" + ) + unauthorized_operation = ErrorCodeData( + status_code=HTTP_401_UNAUTHORIZED, message="Unauthorized operation" + ) + unauthorized_data_access = ErrorCodeData( + status_code=HTTP_401_UNAUTHORIZED, message="Unauthorized data access" + ) + bearer_token_expired = ErrorCodeData( + status_code=HTTP_401_UNAUTHORIZED, message="Bearer token expired" + ) + invalid_bearer_token = ErrorCodeData( + status_code=HTTP_403_FORBIDDEN, message="Bearer token is invalid" + ) + invalid_or_revoked_key = ErrorCodeData( + status_code=HTTP_403_FORBIDDEN, message="Invalid or revoked API key" + ) + py_jwk_client_error = ErrorCodeData( + status_code=HTTP_500_INTERNAL_SERVER_ERROR, message="PyJWKClient error" + ) + decode_error = ErrorCodeData( + status_code=HTTP_401_UNAUTHORIZED, message="Decode error" + ) + + +class AuthError(BaseError): + """ + Base class for auth exceptions + """ + + ERROR_CODES: BaseErrorCode = AuthErrorCode + + +class UnauthorizedUserError(AuthError): + def __init__(self, email: str) -> None: + super().__init__( + error_code=AuthErrorCode.unauthorized_user.name, + detail={"email": email}, + ) + + +class UnauthorizedOperationError(AuthError): + def __init__(self, user_id: str) -> None: + super().__init__( + error_code=AuthErrorCode.unauthorized_operation.name, + detail={"user_id": user_id}, + ) + + +class UnauthorizedDataAccessError(AuthError): + def __init__(self, user_id: str) -> None: + super().__init__( + error_code=AuthErrorCode.unauthorized_data_access.name, + detail={"user_id": user_id}, + ) + + +class InvalidOrRevokedAPIKeyError(AuthError): + def __init__(self, key_id: str) -> None: + super().__init__( + error_code=AuthErrorCode.invalid_or_revoked_key.name, + detail={"key_id": key_id}, + ) + + +class BearerTokenExpiredError(AuthError): + def __init__(self) -> None: + super().__init__(error_code=AuthErrorCode.bearer_token_expired.name) + + +class InvalidBearerTokenError(AuthError): + def __init__(self) -> None: + super().__init__(error_code=AuthErrorCode.invalid_bearer_token.name) + + +class PyJWKClientError(AuthError): + def __init__(self) -> None: + super().__init__(error_code=AuthErrorCode.py_jwk_client_error.name) + + +class DecodeError(AuthError): + def __init__(self) -> None: + super().__init__(error_code=AuthErrorCode.decode_error.name) diff --git a/apps/ai/server/modules/db_connection/models/exceptions.py b/apps/ai/server/modules/db_connection/models/exceptions.py new file mode 100644 index 00000000..b4ee16f7 --- /dev/null +++ b/apps/ai/server/modules/db_connection/models/exceptions.py @@ -0,0 +1,41 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class DBConnectionErrorCode(BaseErrorCode): + db_connection_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Database connection not found" + ) + db_connection_alias_exists = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Existing database connection already has alias", + ) + + +class DBConnectionError(BaseError): + """ + Base class for database connection exceptions + """ + + ERROR_CODES: BaseErrorCode = DBConnectionErrorCode + + +class DBConnectionNotFoundError(DBConnectionError): + def __init__(self, db_connection_id: str, org_id: str) -> None: + super().__init__( + error_code=DBConnectionErrorCode.db_connection_not_found.name, + detail={"db_connection_id": db_connection_id, "organization_id": org_id}, + ) + + +class DBConnectionAliasExistsError(DBConnectionError): + def __init__(self, db_connection_id: str, org_id: str) -> None: + super().__init__( + error_code=DBConnectionErrorCode.db_connection_alias_exists.name, + detail={"db_connection_id": db_connection_id, "organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/db_connection/repository.py b/apps/ai/server/modules/db_connection/repository.py index e4971eac..3568a7fb 100644 --- a/apps/ai/server/modules/db_connection/repository.py +++ b/apps/ai/server/modules/db_connection/repository.py @@ -29,3 +29,17 @@ def get_db_connection(self, db_connection_id: str, org_id: str) -> DBConnection: if db_connection else None ) + + def get_db_connection_by_alias(self, alias: str, org_id: str) -> DBConnection: + db_connection = MongoDB.find_one( + DATABASE_CONNECTION_COL, + { + "alias": alias, + "metadata.dh_internal.organization_id": org_id, + }, + ) + return ( + DBConnection(id=str(db_connection["_id"]), **db_connection) + if db_connection + else None + ) diff --git a/apps/ai/server/modules/db_connection/service.py b/apps/ai/server/modules/db_connection/service.py index aeec88c7..7bf14338 100644 --- a/apps/ai/server/modules/db_connection/service.py +++ b/apps/ai/server/modules/db_connection/service.py @@ -1,19 +1,23 @@ import json import httpx -from fastapi import HTTPException, UploadFile, status +from fastapi import UploadFile from config import settings, ssh_settings +from exceptions.exception_handlers import raise_engine_exception from modules.db_connection.models.entities import ( DBConnection, DBConnectionMetadata, DHDBConnectionMetadata, ) +from modules.db_connection.models.exceptions import ( + DBConnectionAliasExistsError, + DBConnectionNotFoundError, +) from modules.db_connection.models.requests import DBConnectionRequest from modules.db_connection.models.responses import DBConnectionResponse from modules.db_connection.repository import DBConnectionRepository from utils.analytics import Analytics, EventName, EventType -from utils.exception import raise_for_status from utils.misc import reserved_key_in_metadata from utils.s3 import S3 @@ -52,6 +56,12 @@ async def add_db_connection( file: UploadFile | None = None, ) -> DBConnectionResponse: reserved_key_in_metadata(db_connection_request.metadata) + db_connection = self.repo.get_db_connection_by_alias( + db_connection_request.alias, org_id + ) + if db_connection: + raise DBConnectionAliasExistsError(db_connection.id, org_id) + db_connection_internal_request = DBConnection( **db_connection_request.dict(exclude_unset=True) ) @@ -79,7 +89,7 @@ async def add_db_connection( timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) db_connection = DBConnectionResponse(**response.json()) self.analytics.track( org_id, @@ -130,7 +140,7 @@ async def update_db_connection( json=db_connection_internal_request.dict(), timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return DBConnectionResponse(**response.json()) def get_db_connection_in_org( @@ -138,10 +148,7 @@ def get_db_connection_in_org( ) -> DBConnection: db_connection = self.repo.get_db_connection(db_connection_id, org_id) if not db_connection: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Database connection not found", - ) + raise DBConnectionNotFoundError(db_connection_id, org_id) return db_connection def get_database_type(self, connection_uri: str) -> str: diff --git a/apps/ai/server/modules/finetuning/models/exceptions.py b/apps/ai/server/modules/finetuning/models/exceptions.py new file mode 100644 index 00000000..2677bbb5 --- /dev/null +++ b/apps/ai/server/modules/finetuning/models/exceptions.py @@ -0,0 +1,41 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class FinetuningErrorCode(BaseErrorCode): + finetuning_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Finetuning not found" + ) + finetuning_alias_exists = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Existing finetuning already has alias", + ) + + +class FinetuningError(BaseError): + """ + Base class for finetuning exceptions + """ + + ERROR_CODES: BaseErrorCode = FinetuningErrorCode + + +class FinetuningNotFoundError(FinetuningError): + def __init__(self, finetuning_id: str, org_id: str) -> None: + super().__init__( + error_code=FinetuningErrorCode.finetuning_not_found.name, + detail={"finetuning_id": finetuning_id, "organization_id": org_id}, + ) + + +class FinetuningAliasExistsError(FinetuningError): + def __init__(self, finetuning_id: str, org_id: str) -> None: + super().__init__( + error_code=FinetuningErrorCode.finetuning_alias_exists.name, + detail={"finetuning_id": finetuning_id, "organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/finetuning/repository.py b/apps/ai/server/modules/finetuning/repository.py index a50860de..40a0dcc7 100644 --- a/apps/ai/server/modules/finetuning/repository.py +++ b/apps/ai/server/modules/finetuning/repository.py @@ -36,3 +36,17 @@ def get_finetuning_job(self, finetuning_id: str, org_id: str) -> Finetuning: if finetuning_job else None ) + + def get_finetuning_job_by_alias(self, alias: str, org_id: str) -> Finetuning: + finetuning_job = MongoDB.find_one( + FINETUNING_COL, + { + "alias": alias, + "metadata.dh_internal.organization_id": org_id, + }, + ) + return ( + Finetuning(id=str(finetuning_job["_id"]), **finetuning_job) + if finetuning_job + else None + ) diff --git a/apps/ai/server/modules/finetuning/service.py b/apps/ai/server/modules/finetuning/service.py index 469ad9ed..232e9b0a 100644 --- a/apps/ai/server/modules/finetuning/service.py +++ b/apps/ai/server/modules/finetuning/service.py @@ -1,19 +1,22 @@ import httpx -from fastapi import HTTPException, status from config import settings +from exceptions.exception_handlers import raise_engine_exception from modules.db_connection.service import DBConnectionService from modules.finetuning.models.entities import ( DHFinetuningMetadata, Finetuning, FinetuningMetadata, ) +from modules.finetuning.models.exceptions import ( + FinetuningAliasExistsError, + FinetuningNotFoundError, +) from modules.finetuning.models.requests import FinetuningRequest from modules.finetuning.models.responses import AggrFinetuning from modules.finetuning.repository import FinetuningRepository from modules.golden_sql.service import GoldenSQLService from utils.analytics import Analytics, EventName, EventType -from utils.exception import raise_for_status from utils.misc import reserved_key_in_metadata @@ -43,7 +46,7 @@ async def get_finetuning_jobs( params={"db_connection_id": db_connection.id}, timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) finetuning_jobs += [ AggrFinetuning( **finetuning_job, @@ -62,7 +65,7 @@ async def get_finetuning_job( settings.engine_url + f"/finetunings/{finetuning_id}", timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) finetuning_job = Finetuning(**response.json()) db_connection = self.db_connection_service.get_db_connection_in_org( finetuning_job.db_connection_id, org_id @@ -79,6 +82,12 @@ async def create_finetuning_job( finetuning_request.db_connection_id, org_id ) + finetuning = self.repo.get_finetuning_job_by_alias( + finetuning_request.alias, org_id + ) + if finetuning: + raise FinetuningAliasExistsError(finetuning.id, org_id) + finetuning_request.metadata = FinetuningMetadata( **finetuning_request.metadata, dh_internal=DHFinetuningMetadata(organization_id=org_id), @@ -89,7 +98,7 @@ async def create_finetuning_job( settings.engine_url + "/finetunings", json=finetuning_request.dict(exclude_unset=True), ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) aggr_finetuning = AggrFinetuning( **response.json(), db_connection_alias=db_connection.alias @@ -124,7 +133,7 @@ async def cancel_finetuning_job( response = await client.post( settings.engine_url + f"/finetunings/{finetuning_id}/cancel", ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return AggrFinetuning( **response.json(), db_connection_alias=db_connection.alias ) @@ -132,10 +141,7 @@ async def cancel_finetuning_job( def get_finetuning_job_in_org(self, finetuning_id: str, org_id: str) -> Finetuning: finetuning_job = self.repo.get_finetuning_job(finetuning_id, org_id) if not finetuning_job: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Finetuning not found", - ) + raise FinetuningNotFoundError(finetuning_id, org_id) return finetuning_job def is_gpt_4_model(self, model_name: str) -> bool: diff --git a/apps/ai/server/modules/generation/aggr_service.py b/apps/ai/server/modules/generation/aggr_service.py index ac77235d..c7e59151 100644 --- a/apps/ai/server/modules/generation/aggr_service.py +++ b/apps/ai/server/modules/generation/aggr_service.py @@ -1,10 +1,10 @@ from datetime import datetime import httpx -from fastapi import HTTPException, status from fastapi.responses import StreamingResponse from config import settings +from exceptions.exception_handlers import raise_engine_exception from modules.db_connection.service import DBConnectionService from modules.generation.models.entities import ( DHNLGenerationMetadata, @@ -21,6 +21,12 @@ SQLGenerationMetadata, SQLGenerationStatus, ) +from modules.generation.models.exceptions import ( + GenerationVerifiedOrRejectedError, + InvalidSqlGenerationError, + PromptNotFoundError, + SqlGenerationNotFoundError, +) from modules.generation.models.requests import ( GenerationUpdateRequest, NLGenerationRequest, @@ -47,7 +53,6 @@ from modules.user.models.responses import UserResponse from modules.user.service import UserService from utils.analytics import Analytics, EventName, EventType -from utils.exception import GenerationEngineError, raise_for_status from utils.slack import SlackWebClient, remove_slack_mentions CONFIDENCE_CAP = 0.95 @@ -79,9 +84,7 @@ async def get_generation(self, prompt_id: str, org_id: str) -> GenerationRespons prompt, sql_generation, nl_generation ) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found" - ) + raise PromptNotFoundError(prompt_id, org_id) def get_generation_list( self, @@ -192,7 +195,7 @@ async def create_generation( json=generation_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, display_id=display_id) + raise_engine_exception(response, org_id=organization.id) nl_generation = NLGeneration(**response.json()) sql_generation = self.repo.get_sql_generation( @@ -278,7 +281,7 @@ async def create_prompt_sql_generation_result( json=generation_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response) + raise_engine_exception(response, org_id=org_id) sql_generation = SQLGeneration(**response.json()) self.repo.update_prompt_dh_metadata( @@ -294,15 +297,13 @@ async def create_prompt_sql_generation_result( prompt = self.repo.get_prompt(sql_generation.prompt_id, organization.id) if sql_generation.status == SQLGenerationStatus.VALID: - sql_result_response = await client.get( + response = await client.get( settings.engine_url + f"/sql-generations/{sql_generation.id}/execute", timeout=settings.default_engine_timeout, ) - raise_for_status( - sql_result_response.status_code, sql_result_response.text - ) - sql_result = sql_result_response.json() + raise_engine_exception(response, org_id=org_id) + sql_result = response.json() else: sql_result = None @@ -323,9 +324,7 @@ async def update_generation( ) -> GenerationResponse: prompt = self.repo.get_prompt(prompt_id, org_id) if not prompt: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found" - ) + raise PromptNotFoundError(prompt_id, org_id) sql_generation = self.repo.get_latest_sql_generation(prompt_id, org_id) nl_generation = ( @@ -407,9 +406,7 @@ async def create_sql_nl_generation( ) -> GenerationResponse: prompt = self.repo.get_prompt(prompt_id, org_id) if not prompt: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found" - ) + raise PromptNotFoundError(prompt_id, org_id) generation_request = SQLNLGenerationRequest( metadata=NLGenerationMetadata( @@ -435,7 +432,7 @@ async def create_sql_nl_generation( json=generation_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, prompt=prompt) + raise_engine_exception(response, org_id=org_id) nl_generation = NLGeneration(**response.json()) sql_generation = self.repo.get_sql_generation( nl_generation.sql_generation_id, org_id @@ -459,15 +456,13 @@ async def create_sql_nl_generation( prompt = self.repo.get_prompt(prompt_id, org_id) if sql_generation.status == SQLGenerationStatus.VALID: - sql_result_response = await client.get( + response = await client.get( settings.engine_url + f"/sql-generations/{sql_generation.id}/execute", timeout=settings.default_engine_timeout, ) - raise_for_status( - sql_result_response.status_code, sql_result_response.text - ) - sql_result = sql_result_response.json() + raise_engine_exception(response, org_id=org_id) + sql_result = response.json() else: sql_result = None @@ -484,23 +479,18 @@ async def create_sql_generation_result( self, prompt_id: str, sql_request: SQLRequest, - org_id, + org_id: str, user: UserResponse = None, ) -> GenerationResponse: prompt = self.repo.get_prompt(prompt_id, org_id) if not prompt: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found" - ) + raise PromptNotFoundError(prompt_id, org_id) if prompt.metadata.dh_internal.generation_status in { GenerationStatus.VERIFIED, GenerationStatus.REJECTED, }: - raise_for_status( - status_code=status.HTTP_400_BAD_REQUEST, - message="generation has already been verified or rejected", - ) + raise GenerationVerifiedOrRejectedError(prompt_id, org_id) generation_request = SQLGenerationRequest( sql=sql_request.sql, @@ -516,7 +506,7 @@ async def create_sql_generation_result( json=generation_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, prompt=prompt) + raise_engine_exception(response, org_id=org_id) sql_generation = SQLGeneration(**response.json()) self.repo.update_prompt_dh_metadata( @@ -537,15 +527,13 @@ async def create_sql_generation_result( prompt = self.repo.get_prompt(prompt_id, org_id) if sql_generation.status == SQLGenerationStatus.VALID: - sql_result_response = await client.get( + response = await client.get( settings.engine_url + f"/sql-generations/{sql_generation.id}/execute", timeout=settings.default_engine_timeout, ) - raise_for_status( - sql_result_response.status_code, sql_result_response.text - ) - sql_result = sql_result_response.json() + raise_engine_exception(response, org_id=org_id) + sql_result = response.json() else: sql_result = None @@ -562,14 +550,10 @@ async def create_nl_generation( ) -> NLGenerationResponse: prompt = self.repo.get_prompt(prompt_id, org_id) if not prompt: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found" - ) + raise PromptNotFoundError(prompt_id, org_id) sql_generation = self.repo.get_latest_sql_generation(prompt_id, org_id) if not sql_generation: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="SQL Prompt not found" - ) + raise SqlGenerationNotFoundError(prompt_id, org_id) generation_request = NLGenerationRequest( metadata=NLGenerationMetadata( @@ -584,7 +568,7 @@ async def create_nl_generation( json=generation_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) nl_generation = NLGeneration(**response.json()) self.repo.update_prompt_dh_metadata( @@ -632,27 +616,19 @@ async def send_message(self, prompt_id: str, org_id: str): async def export_csv_file(self, prompt_id: str, org_id: str) -> StreamingResponse: if not self.repo.get_prompt(prompt_id, org_id): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Prompt not found" - ) + raise PromptNotFoundError(prompt_id, org_id) sql_generation = self.repo.get_latest_sql_generation(prompt_id, org_id) if not sql_generation: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="SQL Generation not found", - ) + raise SqlGenerationNotFoundError(prompt_id, org_id) if sql_generation.status != SQLGenerationStatus.VALID: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="SQL Generation is not valid", - ) + raise InvalidSqlGenerationError(sql_generation.id, org_id) async with httpx.AsyncClient() as client: response = await client.get( settings.engine_url + f"/sql-generations/{sql_generation.id}/csv-file", timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return StreamingResponse( content=response.iter_bytes(), headers=response.headers, @@ -721,28 +697,3 @@ def _get_mapped_generation_response( sql_result=sql_result, **prompt.metadata.dh_internal.dict(exclude_unset=True), ) - - def _raise_for_generation_status( - self, response: httpx.Response, prompt: Prompt = None, display_id: str = None - ): - response_json = response.json() - if response.status_code != status.HTTP_201_CREATED: - if prompt or ("prompt_id" in response_json and response_json["prompt_id"]): - prompt_id = prompt.id if prompt else response_json["prompt_id"] - self.repo.update_prompt_dh_metadata( - prompt_id, - DHPromptMetadata( - generation_status=GenerationStatus.ERROR, - ), - ) - raise GenerationEngineError( - status_code=response.status_code, - prompt_id=prompt_id, - display_id=( - display_id or prompt.metadata.dh_internal.display_id - if prompt - else None - ), - error_message=response_json["message"], - ) - raise_for_status(response.status_code, response.text) diff --git a/apps/ai/server/modules/generation/models/exceptions.py b/apps/ai/server/modules/generation/models/exceptions.py new file mode 100644 index 00000000..25002e25 --- /dev/null +++ b/apps/ai/server/modules/generation/models/exceptions.py @@ -0,0 +1,74 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class GenerationErrorCode(BaseErrorCode): + prompt_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Prompt not found" + ) + sql_generation_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="SQL generation not found" + ) + nl_generation_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="NL generation not found" + ) + generation_verified_or_rejected = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Cannot modify verified or rejected generation", + ) + invalid_sql_generation = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Invalid SQL generation" + ) + + +class GenerationError(BaseError): + """ + Base class for generation exceptions + """ + + ERROR_CODES: BaseErrorCode = GenerationErrorCode + + +class PromptNotFoundError(GenerationError): + def __init__(self, prompt_id: str, org_id: str) -> None: + super().__init__( + error_code=GenerationErrorCode.prompt_not_found.name, + detail={"prompt_id": prompt_id, "organization_id": org_id}, + ) + + +class SqlGenerationNotFoundError(GenerationError): + def __init__(self, sql_generation_id: str, org_id: str) -> None: + super().__init__( + error_code=GenerationErrorCode.sql_generation_not_found.name, + detail={"sql_generation_id": sql_generation_id, "organization_id": org_id}, + ) + + +class NlGenerationNotFoundError(GenerationError): + def __init__(self, nl_generation_id: str, org_id: str) -> None: + super().__init__( + error_code=GenerationErrorCode.nl_generation_not_found.name, + detail={"nl_generation_id": nl_generation_id, "organization_id": org_id}, + ) + + +class GenerationVerifiedOrRejectedError(GenerationError): + def __init__(self, nl_generation_id: str, org_id: str) -> None: + super().__init__( + error_code=GenerationErrorCode.generation_verified_or_rejected.name, + detail={"nl_generation_id": nl_generation_id, "organization_id": org_id}, + ) + + +class InvalidSqlGenerationError(GenerationError): + def __init__(self, sql_generation_id: str, org_id: str) -> None: + super().__init__( + error_code=GenerationErrorCode.invalid_sql_generation.name, + detail={"sql_generation_id": sql_generation_id, "organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/generation/service.py b/apps/ai/server/modules/generation/service.py index 6d93499d..509c909a 100644 --- a/apps/ai/server/modules/generation/service.py +++ b/apps/ai/server/modules/generation/service.py @@ -1,8 +1,8 @@ import httpx -from fastapi import HTTPException, status from fastapi.responses import StreamingResponse from config import settings +from exceptions.exception_handlers import raise_engine_exception from modules.db_connection.service import DBConnectionService from modules.generation.models.entities import ( DHNLGenerationMetadata, @@ -18,6 +18,13 @@ SQLGenerationMetadata, SQLGenerationStatus, ) +from modules.generation.models.exceptions import ( + GenerationVerifiedOrRejectedError, + InvalidSqlGenerationError, + NlGenerationNotFoundError, + PromptNotFoundError, + SqlGenerationNotFoundError, +) from modules.generation.models.requests import ( NLGenerationRequest, PromptRequest, @@ -33,7 +40,6 @@ ) from modules.generation.repository import GenerationRepository from utils.analytics import Analytics, EventName, EventType -from utils.exception import GenerationEngineError, raise_for_status from utils.misc import reserved_key_in_metadata @@ -128,7 +134,7 @@ async def create_prompt( settings.engine_url + "/prompts", json=create_request.dict(exclude_unset=True), ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return PromptResponse(**response.json()) async def create_prompt_sql_generation( @@ -155,7 +161,7 @@ async def create_prompt_sql_generation( json=create_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, org_id) + raise_engine_exception(response, org_id=org_id) sql_generation = SQLGenerationResponse(**response.json()) self._update_generation_status( @@ -195,7 +201,7 @@ async def create_prompt_sql_nl_generation( json=create_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, org_id) + raise_engine_exception(response, org_id=org_id) nl_generation = NLGenerationResponse(**response.json()) sql_generation = self.repo.get_sql_generation( nl_generation.sql_generation_id, org_id @@ -218,11 +224,7 @@ async def create_sql_generation( GenerationStatus.REJECTED, GenerationStatus.VERIFIED, }: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot create SQL generation for a prompt that has been verified or rejected", - ) - + raise GenerationVerifiedOrRejectedError(prompt_id, org_id) create_request.metadata = SQLGenerationMetadata( **create_request.metadata, dh_internal=DHSQLGenerationMetadata(organization_id=org_id), @@ -237,7 +239,7 @@ async def create_sql_generation( json=create_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, org_id, prompt) + raise_engine_exception(response, org_id=org_id) sql_generation = SQLGenerationResponse(**response.json()) self._update_generation_status(prompt_id, sql_generation.status) @@ -256,10 +258,7 @@ async def create_sql_nl_generation( GenerationStatus.REJECTED, GenerationStatus.VERIFIED, }: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot create SQL generation for a prompt that has been verified or rejected", - ) + raise GenerationVerifiedOrRejectedError(prompt_id, org_id=org_id) create_request.sql_generation.metadata = SQLGenerationMetadata( **create_request.sql_generation.metadata, dh_internal=DHSQLGenerationMetadata(organization_id=org_id), @@ -281,7 +280,7 @@ async def create_sql_nl_generation( json=create_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - self._raise_for_generation_status(response, org_id, prompt) + raise_engine_exception(response, org_id=org_id) nl_generation = NLGenerationResponse(**response.json()) sql_generation = self.repo.get_sql_generation( nl_generation.sql_generation_id, org_id @@ -311,7 +310,7 @@ async def create_nl_generation( json=create_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return NLGenerationResponse(**response.json()) async def execute_sql_generation( @@ -322,17 +321,14 @@ async def execute_sql_generation( ) -> list[dict]: sql_generation = self.get_sql_generation_in_org(sql_generation_id, org_id) if sql_generation.status != SQLGenerationStatus.VALID: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="SQL Generation is not valid", - ) + raise InvalidSqlGenerationError(sql_generation_id, org_id) async with httpx.AsyncClient() as client: response = await client.get( settings.engine_url + f"/sql-generations/{sql_generation_id}/execute", params={"max_rows": max_rows}, timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return response.json() async def export_csv_file( @@ -340,16 +336,13 @@ async def export_csv_file( ) -> StreamingResponse: sql_generation = self.get_sql_generation_in_org(sql_generation_id, org_id) if sql_generation.status != SQLGenerationStatus.VALID: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="SQL Generation is not valid", - ) + raise InvalidSqlGenerationError(sql_generation_id, org_id) async with httpx.AsyncClient() as client: response = await client.get( settings.engine_url + f"/sql-generations/{sql_generation_id}/csv-file", timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return StreamingResponse( content=response.iter_bytes(), headers=response.headers, @@ -360,10 +353,7 @@ async def export_csv_file( def get_prompt_in_org(self, prompt_id: str, org_id: str) -> Prompt: prompt = self.repo.get_prompt(prompt_id, org_id) if not prompt: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Prompt not found", - ) + raise PromptNotFoundError(prompt_id, org_id) return prompt def get_sql_generation_in_org( @@ -371,10 +361,7 @@ def get_sql_generation_in_org( ) -> SQLGeneration: sql_generation = self.repo.get_sql_generation(sql_generation_id, org_id) if not sql_generation: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="SQL Generation not found", - ) + raise SqlGenerationNotFoundError(sql_generation_id, org_id) return sql_generation def get_nl_generation_in_org( @@ -382,32 +369,9 @@ def get_nl_generation_in_org( ) -> NLGeneration: nl_generation = self.repo.get_nl_generation(nl_generation_id, org_id) if not nl_generation: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="NL Generation not found", - ) + raise NlGenerationNotFoundError(nl_generation_id, org_id) return nl_generation - def _raise_for_generation_status( - self, response: httpx.Response, org_id: str, prompt: Prompt = None - ): - response_json = response.json() - if response.status_code != status.HTTP_201_CREATED: - if "prompt_id" in response_json and response_json["prompt_id"]: - prompt = self.get_prompt(response_json["prompt_id"], org_id) - if prompt: - self.repo.update_prompt_dh_metadata( - prompt.id, - DHPromptMetadata(generation_status=GenerationStatus.ERROR), - ) - raise GenerationEngineError( - status_code=response.status_code, - prompt_id=prompt.id, - display_id=prompt.metadata.dh_internal.display_id, - error_message=response_json["message"], - ) - raise_for_status(response.status_code, response.text) - def _update_generation_status(self, prompt_id: str, status: SQLGenerationStatus): self.repo.update_prompt_dh_metadata( prompt_id, diff --git a/apps/ai/server/modules/golden_sql/models/exceptions.py b/apps/ai/server/modules/golden_sql/models/exceptions.py new file mode 100644 index 00000000..4c578688 --- /dev/null +++ b/apps/ai/server/modules/golden_sql/models/exceptions.py @@ -0,0 +1,40 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class GoldenSQLErrorCode(BaseErrorCode): + golden_sql_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Golden SQL not found" + ) + cannot_delete_golden_sql = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot delete golden SQL" + ) + + +class GoldenSQLError(BaseError): + """ + Base class for golden SQL exceptions + """ + + ERROR_CODES: BaseErrorCode = GoldenSQLErrorCode + + +class GoldenSqlNotFoundError(GoldenSQLError): + def __init__(self, golden_sql_id: str, org_id: str) -> None: + super().__init__( + error_code=GoldenSQLErrorCode.golden_sql_not_found.name, + detail={"golden_sql_id": golden_sql_id, "organization_id": org_id}, + ) + + +class CannotDeleteGoldenSqlError(GoldenSQLError): + def __init__(self, golden_sql_id: str, org_id: str) -> None: + super().__init__( + error_code=GoldenSQLErrorCode.cannot_delete_golden_sql.name, + detail={"golden_sql_id": golden_sql_id, "organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/golden_sql/service.py b/apps/ai/server/modules/golden_sql/service.py index 427eb007..00344d1f 100644 --- a/apps/ai/server/modules/golden_sql/service.py +++ b/apps/ai/server/modules/golden_sql/service.py @@ -1,9 +1,9 @@ from typing import List import httpx -from fastapi import HTTPException, status from config import settings +from exceptions.exception_handlers import raise_engine_exception from modules.generation.models.entities import GenerationStatus from modules.generation.service import DBConnectionService from modules.golden_sql.models.entities import ( @@ -12,11 +12,14 @@ GoldenSQLMetadata, GoldenSQLSource, ) +from modules.golden_sql.models.exceptions import ( + CannotDeleteGoldenSqlError, + GoldenSqlNotFoundError, +) from modules.golden_sql.models.requests import GoldenSQLRequest from modules.golden_sql.models.responses import AggrGoldenSQL from modules.golden_sql.repository import GoldenSQLRepository from utils.analytics import Analytics, EventName, EventType -from utils.exception import raise_for_status from utils.misc import reserved_key_in_metadata @@ -26,8 +29,8 @@ def __init__(self): self.db_connection_service = DBConnectionService() self.analytics = Analytics() - def get_golden_sql(self, golden_id: str, org_id: str) -> AggrGoldenSQL: - golden_sql = self.get_golden_sql_in_org(golden_id, org_id) + def get_golden_sql(self, golden_sql_id: str, org_id: str) -> AggrGoldenSQL: + golden_sql = self.get_golden_sql_in_org(golden_sql_id, org_id) return AggrGoldenSQL( **golden_sql.dict(), db_connection_alias=self.db_connection_service.get_db_connection_in_org( @@ -102,7 +105,7 @@ async def add_user_upload_golden_sql( ], timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) response_jsons = response.json() golden_sqls = [ @@ -131,24 +134,24 @@ async def add_user_upload_golden_sql( # we can avoid cyclic import if we avoid deleting verified golden sql async def delete_golden_sql( - self, golden_id: str, org_id: str, query_status: GenerationStatus = None + self, golden_sql_id: str, org_id: str, query_status: GenerationStatus = None ) -> dict: - golden_sql = self.get_golden_sql_in_org(golden_id, org_id) + golden_sql = self.get_golden_sql_in_org(golden_sql_id, org_id) async with httpx.AsyncClient() as client: response = await client.delete( - settings.engine_url + f"/golden-sqls/{golden_id}", + settings.engine_url + f"/golden-sqls/{golden_sql_id}", timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) if response.json()["status"]: if query_status: self.repo.update_generation_status( golden_sql.metadata.dh_internal.prompt_id, query_status ) - return {"id": golden_id} + return {"id": golden_sql_id} - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) + raise CannotDeleteGoldenSqlError(golden_sql_id, org_id) def get_verified_golden_sql(self, prompt_id: str) -> GoldenSQL: return self.repo.get_verified_golden_sql(prompt_id) @@ -182,7 +185,7 @@ async def add_verified_golden_sql( json=[golden_sql_request.dict(exclude_unset=True)], timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) response_json = response.json()[0] self.analytics.track( @@ -193,11 +196,8 @@ async def add_verified_golden_sql( return GoldenSQL(**response_json) - def get_golden_sql_in_org(self, golden_id: str, org_id: str) -> GoldenSQL: - golden_sql = self.repo.get_golden_sql(golden_id, org_id) + def get_golden_sql_in_org(self, golden_sql_id: str, org_id: str) -> GoldenSQL: + golden_sql = self.repo.get_golden_sql(golden_sql_id, org_id) if not golden_sql: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Golden sql not found", - ) + raise GoldenSqlNotFoundError(golden_sql_id, org_id) return golden_sql diff --git a/apps/ai/server/modules/instruction/models/exceptions.py b/apps/ai/server/modules/instruction/models/exceptions.py new file mode 100644 index 00000000..51f5bce4 --- /dev/null +++ b/apps/ai/server/modules/instruction/models/exceptions.py @@ -0,0 +1,49 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class InstructionErrorCode(BaseErrorCode): + instruction_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Instruction not found" + ) + single_instruction_only = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Only one instruction allowed per database connection", + ) + + +class InstructionError(BaseError): + """ + Base class for instruction exceptions + """ + + ERROR_CODES: BaseErrorCode = InstructionErrorCode + + +class InstructionNotFoundError(InstructionError): + def __init__( + self, org_id: str, instruction_id: str | None, db_connection_id: str | None + ) -> None: + if instruction_id: + detail = {"db_connection_id": db_connection_id, "organization_id": org_id} + elif db_connection_id: + detail = {"db_connection_id": db_connection_id, "organization_id": org_id} + else: + raise ValueError("instruction_id or db_connection_id must be provided") + super().__init__( + error_code=InstructionErrorCode.instruction_not_found.name, + detail=detail, + ) + + +class SingleInstructionOnlyError(InstructionError): + def __init__(self, db_connection_id: str, org_id: str) -> None: + super().__init__( + error_code=InstructionErrorCode.single_instruction_only.name, + detail={"db_connection_id": db_connection_id, "organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/instruction/service.py b/apps/ai/server/modules/instruction/service.py index 882d6648..a58ff8a5 100644 --- a/apps/ai/server/modules/instruction/service.py +++ b/apps/ai/server/modules/instruction/service.py @@ -1,17 +1,20 @@ import httpx -from fastapi import HTTPException, status from config import settings +from exceptions.exception_handlers import raise_engine_exception from modules.db_connection.service import DBConnectionService from modules.instruction.models.entities import ( DHInstructionMetadata, Instruction, InstructionMetadata, ) +from modules.instruction.models.exceptions import ( + InstructionNotFoundError, + SingleInstructionOnlyError, +) from modules.instruction.models.requests import InstructionRequest from modules.instruction.models.responses import AggrInstruction from modules.instruction.repository import InstructionRepository -from utils.exception import raise_for_status from utils.misc import reserved_key_in_metadata @@ -57,10 +60,7 @@ def get_first_instruction( ) instructions = self.repo.get_instructions(db_connection_id, org_id) if len(instructions) == 0: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Instruction not found", - ) + raise InstructionNotFoundError(org_id, db_connection_id=db_connection_id) return AggrInstruction( **instructions[0].dict(), db_connection_alias=db_connection.alias ) @@ -71,9 +71,8 @@ async def add_instruction( reserved_key_in_metadata(instruction_request.metadata) if self.repo.get_instructions(instruction_request.db_connection_id, org_id): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Instruction already exists for this db connection", + raise SingleInstructionOnlyError( + instruction_request.db_connection_id, org_id ) db_connection = self.db_connection_service.get_db_connection_in_org( @@ -89,7 +88,7 @@ async def add_instruction( settings.engine_url + "/instructions", json=instruction_request.dict(exclude_unset=True), ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return AggrInstruction( **response.json(), db_connection_alias=db_connection.alias ) @@ -120,7 +119,7 @@ async def update_instruction( settings.engine_url + f"/instructions/{instruction_id}", json=instruction_request.dict(exclude_unset=True), ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return AggrInstruction( **response.json(), db_connection_alias=db_connection.alias ) @@ -135,14 +134,11 @@ async def delete_instruction(self, instruction_id: str, org_id: str): response = await client.delete( settings.engine_url + f"/instructions/{instruction_id}", ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return {"id": instruction_id} def get_instruction_in_org(self, instruction_id: str, org_id: str) -> Instruction: instruction = self.repo.get_instruction(instruction_id, org_id) if not instruction: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Instruction not found", - ) + raise InstructionNotFoundError(org_id, instruction_id=instruction_id) return instruction diff --git a/apps/ai/server/modules/key/models/exceptions.py b/apps/ai/server/modules/key/models/exceptions.py new file mode 100644 index 00000000..2e769fcd --- /dev/null +++ b/apps/ai/server/modules/key/models/exceptions.py @@ -0,0 +1,62 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_401_UNAUTHORIZED, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class KeyErrorCode(BaseErrorCode): + key_not_found = ErrorCodeData( + status_code=HTTP_401_UNAUTHORIZED, message="API key not found" + ) + key_name_exists = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Existing key already has name" + ) + cannot_revoke_key = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot revoke api key" + ) + cannot_create_key = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot create api key" + ) + + +class KeyError(BaseError): + """ + Base class for api key exceptions + """ + + ERROR_CODES: BaseErrorCode = KeyErrorCode + + +class KeyNotFoundError(KeyError): + def __init__(self, key_id: str, org_id: str) -> None: + super().__init__( + error_code=KeyErrorCode.key_not_found.name, + detail={"key_id": key_id, "organization_id": org_id}, + ) + + +class KeyNameExistsError(KeyError): + def __init__(self, key_id: str, org_id: str) -> None: + super().__init__( + error_code=KeyErrorCode.key_name_exists.name, + detail={"key_id": key_id, "organization_id": org_id}, + ) + + +class CannotRevokeKeyError(KeyError): + def __init__(self, key_id: str, org_id: str) -> None: + super().__init__( + error_code=KeyErrorCode.cannot_revoke_key.name, + detail={"key_id": key_id, "organization_id": org_id}, + ) + + +class CannotCreateKeyError(KeyError): + def __init__(self, org_id: str) -> None: + super().__init__( + error_code=KeyErrorCode.cannot_create_key.name, + detail={"organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/key/repository.py b/apps/ai/server/modules/key/repository.py index 5d637087..4e7fcdb3 100644 --- a/apps/ai/server/modules/key/repository.py +++ b/apps/ai/server/modules/key/repository.py @@ -12,6 +12,10 @@ def get_key(self, key_id: str, org_id: str) -> APIKey: ) return APIKey(id=str(key["_id"]), **key) if key else None + def get_key_by_name(self, name: str, org_id: str) -> APIKey: + key = MongoDB.find_one(KEY_COL, {"name": name, "organization_id": org_id}) + return APIKey(id=str(key["_id"]), **key) if key else None + def get_keys(self, org_id: str) -> list[APIKey]: return [ APIKey(id=str(key["_id"]), **key) @@ -22,8 +26,8 @@ def get_key_by_hash(self, key_hash: str) -> APIKey: key = MongoDB.find_one(KEY_COL, {"key_hash": key_hash}) return APIKey(id=str(key["_id"]), **key) if key else None - def add_key(self, key: dict) -> str: - return str(MongoDB.insert_one(KEY_COL, key)) + def add_key(self, key: APIKey) -> str: + return str(MongoDB.insert_one(KEY_COL, key.dict(exclude={"id"}))) def delete_key(self, key_id: str, org_id: str) -> int: return MongoDB.delete_one( diff --git a/apps/ai/server/modules/key/service.py b/apps/ai/server/modules/key/service.py index af252bce..d859aecf 100644 --- a/apps/ai/server/modules/key/service.py +++ b/apps/ai/server/modules/key/service.py @@ -1,11 +1,13 @@ import hashlib import secrets -from datetime import datetime - -from fastapi import HTTPException, status from config import settings from modules.key.models.entities import APIKey +from modules.key.models.exceptions import ( + CannotCreateKeyError, + CannotRevokeKeyError, + KeyNameExistsError, +) from modules.key.models.requests import KeyGenerationRequest from modules.key.models.responses import KeyPreviewResponse, KeyResponse from modules.key.repository import KeyRepository @@ -24,16 +26,18 @@ def get_keys(self, org_id: str) -> list[KeyPreviewResponse]: def add_key( self, key_request: KeyGenerationRequest, org_id: str, api_key: str = None ) -> KeyResponse: + key = self.repo.get_key_by_name(key_request.name, org_id) + if key: + raise KeyNameExistsError(key.id, org_id) if not api_key: api_key = KEY_PREFIX + self.generate_new_key() key = APIKey( key_hash=self.hash_key(key=api_key), organization_id=org_id, - created_at=datetime.now(), name=key_request.name, key_preview=KEY_PREFIX + "························" + api_key[-3:], ) - key_id = self.repo.add_key(key.dict(exclude_unset=True)) + key_id = self.repo.add_key(key) if key_id: return KeyResponse( @@ -45,10 +49,7 @@ def add_key( api_key=api_key, ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Could not create key", - ) + raise CannotCreateKeyError(org_id) def validate_key(self, api_key: str) -> APIKey: return self.repo.get_key_by_hash(key_hash=self.hash_key(api_key)) @@ -71,6 +72,4 @@ def revoke_key(self, key_id: str, org_id: str): if self.repo.delete_key(key_id, org_id) == 1: return {"id": key_id} - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Key not found" - ) + raise CannotRevokeKeyError(key_id, org_id) diff --git a/apps/ai/server/modules/organization/invoice/exception/error_codes.py b/apps/ai/server/modules/organization/invoice/exception/error_codes.py new file mode 100644 index 00000000..d3c8f4d9 --- /dev/null +++ b/apps/ai/server/modules/organization/invoice/exception/error_codes.py @@ -0,0 +1,49 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_402_PAYMENT_REQUIRED, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData + + +class InvoiceErrorCode(BaseErrorCode): + no_payment_method = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="No payment method on file", + ) + last_payment_method = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Last payment method" + ) + spending_limit_exceeded = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, message="Spending limit exceeded" + ) + hard_spending_limit_exceeded = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="Hard spending limit exceeded", + ) + subscription_past_due = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="Stripe subscription past due", + ) + subscription_canceled = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="Stripe subscription canceled", + ) + unknown_subscription_status = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Unknown stripe subscription status", + ) + is_enterprise_plan = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Cannot perform action for enterprise plan", + ) + cannot_update_spending_limit = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot update spending limit" + ) + cannot_update_payment_method = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot update payment method" + ) + missing_invoice_details = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Organization missing invoice details", + ) diff --git a/apps/ai/server/modules/organization/invoice/models/exceptions.py b/apps/ai/server/modules/organization/invoice/models/exceptions.py new file mode 100644 index 00000000..3b7924fa --- /dev/null +++ b/apps/ai/server/modules/organization/invoice/models/exceptions.py @@ -0,0 +1,146 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_402_PAYMENT_REQUIRED, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class InvoiceErrorCode(BaseErrorCode): + no_payment_method = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="No payment method on file", + ) + last_payment_method = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Last payment method" + ) + spending_limit_exceeded = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, message="Spending limit exceeded" + ) + hard_spending_limit_exceeded = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="Hard spending limit exceeded", + ) + subscription_past_due = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="Stripe subscription past due", + ) + subscription_canceled = ErrorCodeData( + status_code=HTTP_402_PAYMENT_REQUIRED, + message="Stripe subscription canceled", + ) + unknown_subscription_status = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Unknown stripe subscription status", + ) + is_enterprise_plan = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Cannot perform action for enterprise plan", + ) + cannot_update_spending_limit = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot update spending limit" + ) + cannot_update_payment_method = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot update payment method" + ) + missing_invoice_details = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, + message="Organization missing invoice details", + ) + + +class InvoiceError(BaseError): + """ + Base class for invoice exceptions + """ + + ERROR_CODES: BaseErrorCode = InvoiceErrorCode + + +class NoPaymentMethodError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.no_payment_method.name, + detail={"organization_id": organization_id}, + ) + + +class LastPaymentMethodError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.last_payment_method.name, + detail={"organization_id": organization_id}, + ) + + +class SpendingLimitExceededError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.spending_limit_exceeded.name, + detail={"organization_id": organization_id}, + ) + + +class HardSpendingLimitExceededError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.hard_spending_limit_exceeded.name, + detail={"organization_id": organization_id}, + ) + + +class SubscriptionPastDueError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.subscription_past_due.name, + detail={"organization_id": organization_id}, + ) + + +class SubscriptionCanceledError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.subscription_canceled.name, + detail={"organization_id": organization_id}, + ) + + +class UnknownSubscriptionStatusError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.unknown_subscription_status.name, + detail={"organization_id": organization_id}, + ) + + +class IsEnterprisePlanError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.is_enterprise_plan.name, + detail={"organization_id": organization_id}, + ) + + +class CannotUpdateSpendingLimitError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.cannot_update_spending_limit.name, + detail={"organization_id": organization_id}, + ) + + +class CannotUpdatePaymentMethodError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.cannot_update_payment_method.name, + detail={"organization_id": organization_id}, + ) + + +class MissingInvoiceDetailsError(InvoiceError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=InvoiceErrorCode.missing_invoice_details.name, + detail={"organization_id": organization_id}, + ) diff --git a/apps/ai/server/modules/organization/invoice/repository.py b/apps/ai/server/modules/organization/invoice/repository.py index 583483e3..f3d8ed51 100644 --- a/apps/ai/server/modules/organization/invoice/repository.py +++ b/apps/ai/server/modules/organization/invoice/repository.py @@ -88,8 +88,8 @@ def get_positive_credits(self, org_id: str) -> list[Credit]: ) ] - def create_usage(self, usage: dict) -> str: - return str(MongoDB.insert_one(USAGE_COL, usage)) + def create_usage(self, usage: Usage) -> str: + return str(MongoDB.insert_one(USAGE_COL, usage.dict(exclude={"id"}))) def update_spending_limit(self, org_id: str, spending_limit: int) -> int: return MongoDB.update_one( @@ -119,8 +119,8 @@ def update_billing_cyce_anchor(self, org_id: str, billing_cycle_anchor: int) -> {"invoice_details.billing_cycle_anchor": billing_cycle_anchor}, ) - def create_credit(self, credit: dict) -> str: - return str(MongoDB.insert_one(CREDIT_COL, credit)) + def create_credit(self, credit: Credit) -> str: + return str(MongoDB.insert_one(CREDIT_COL, credit.dict(exclude={"id"}))) def update_available_credits(self, org_id: str, credit: int) -> int: return MongoDB.update_one( diff --git a/apps/ai/server/modules/organization/invoice/service.py b/apps/ai/server/modules/organization/invoice/service.py index 57263cc3..ae94ad05 100644 --- a/apps/ai/server/modules/organization/invoice/service.py +++ b/apps/ai/server/modules/organization/invoice/service.py @@ -1,6 +1,5 @@ from datetime import datetime -from fastapi import HTTPException, status from stripe import PaymentMethod from config import invoice_settings @@ -13,6 +12,19 @@ UsageInvoice, UsageType, ) +from modules.organization.invoice.models.exceptions import ( + CannotUpdatePaymentMethodError, + CannotUpdateSpendingLimitError, + HardSpendingLimitExceededError, + IsEnterprisePlanError, + LastPaymentMethodError, + MissingInvoiceDetailsError, + NoPaymentMethodError, + SpendingLimitExceededError, + SubscriptionCanceledError, + SubscriptionPastDueError, + UnknownSubscriptionStatusError, +) from modules.organization.invoice.models.requests import ( CreditRequest, PaymentMethodRequest, @@ -28,7 +40,6 @@ from modules.organization.repository import OrganizationRepository from utils.analytics import Analytics, EventName, EventType from utils.billing import Billing -from utils.exception import ErrorCode class InvoiceService: @@ -70,10 +81,7 @@ def update_spending_limit( hard_spending_limit=organization.invoice_details.hard_spending_limit, ) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Unable to update spending limit", - ) + raise CannotUpdateSpendingLimitError(org_id) def get_pending_invoice(self, org_id: str) -> InvoiceResponse: @@ -194,10 +202,7 @@ def set_default_payment_method( ): return {"success": True} - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Unable to set default payment method", - ) + raise CannotUpdatePaymentMethodError(org_id) def detach_payment_method( self, org_id: str, payment_method_id: str @@ -210,10 +215,7 @@ def detach_payment_method( organization.invoice_details.stripe_customer_id ) if len(payment_methods) <= 1: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot detach last payment method", - ) + raise LastPaymentMethodError(org_id) # check if payment method exists for customer, avoids using stripe api payment_method = None @@ -233,9 +235,7 @@ def detach_payment_method( break return self._get_mapped_payment_method_response(payment_method, False) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Payment method not found" - ) + raise NoPaymentMethodError(org_id) def record_usage( self, @@ -256,7 +256,7 @@ def record_usage( description=description, status=RecordStatus.UNRECORDED, ) - usage_id = self.repo.create_usage(usage.dict(exclude={"id"})) + usage_id = self.repo.create_usage(usage) print(f"New usage created: {usage_id}") available_credits = organization.invoice_details.available_credits self._apply_unrecorded_credits( @@ -286,15 +286,24 @@ def check_usage( # check if organization has payment method organization = self.org_repo.get_organization(org_id) if not organization.invoice_details: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail="Organization does not have invoice details", - ) + raise MissingInvoiceDetailsError(org_id) # skip check if enterprise if organization.invoice_details.plan != PaymentPlan.ENTERPRISE: - self._check_subscription_status( + if ( organization.invoice_details.stripe_subscription_status - ) + != StripeSubscriptionStatus.ACTIVE + ): + if ( + organization.invoice_details.stripe_subscription_status + == StripeSubscriptionStatus.PAST_DUE + ): + raise SubscriptionPastDueError(org_id) + if ( + organization.invoice_details.stripe_subscription_status + == StripeSubscriptionStatus.CANCELED + ): + raise SubscriptionCanceledError(org_id) + raise UnknownSubscriptionStatusError(org_id) start_date, end_date = ( self.billing.get_current_subscription_period_with_anchor( organization.invoice_details.billing_cycle_anchor @@ -316,27 +325,23 @@ def check_usage( ) > organization.invoice_details.available_credits ): - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=ErrorCode.no_payment_method, - ) + raise NoPaymentMethodError(org_id) # for usage based and credit only - self._check_spending_limit_from_usage( - usages, - organization.invoice_details.spending_limit, - organization.invoice_details.hard_spending_limit, + total_usage_cost = self._calculate_total_usage_cost( + self._get_invoice_from_usages(usages) ) + if total_usage_cost > organization.invoice_details.hard_spending_limit: + raise HardSpendingLimitExceededError(org_id) + if total_usage_cost > organization.invoice_details.spending_limit: + raise SpendingLimitExceededError(org_id) def add_credits( self, org_id: str, user_id: str, credit_request: CreditRequest ) -> CreditResponse: organization = self.org_repo.get_organization(org_id) if organization.invoice_details.plan == PaymentPlan.ENTERPRISE: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Cannot add credits to enterprise plan", - ) + raise IsEnterprisePlanError(org_id) credit_id = self.repo.create_credit( Credit( @@ -344,7 +349,7 @@ def add_credits( amount=credit_request.amount, status=RecordStatus.RECORDED, description=f"added by {user_id}: {credit_request.description}", - ).dict(exclude={"id"}) + ) ) print(f"New credit created: {credit_id}") # apply credits to recorded usage @@ -359,7 +364,7 @@ def add_credits( amount=-credits_due, status=RecordStatus.RECORDED, description=f"negative credits for stripe pending invoice; used from new credit {credit_id}", - ).dict(exclude={"id"}) + ) ) self.billing.create_balance_transaction( organization.invoice_details.stripe_customer_id, @@ -381,23 +386,6 @@ def add_credits( ) return self.repo.get_credit(credit_id) - def _check_spending_limit_from_usage( - self, usages: list[Usage], spending_limit: int, hard_spending_limit: int - ): - total_usage_cost = self._calculate_total_usage_cost( - self._get_invoice_from_usages(usages) - ) - if total_usage_cost > hard_spending_limit: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=ErrorCode.hard_spending_limit_exceeded, - ) - if total_usage_cost > spending_limit: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=ErrorCode.spending_limit_exceeded, - ) - def _get_invoice_from_usages(self, usages: list[Usage]) -> UsageInvoice: usage_invoice = { UsageType.SQL_GENERATION: 0, @@ -436,23 +424,6 @@ def _get_mapped_payment_method_response( is_default=is_defualt, ) - def _check_subscription_status(self, subscription_status: str): - if subscription_status != StripeSubscriptionStatus.ACTIVE: - if subscription_status == StripeSubscriptionStatus.PAST_DUE: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=ErrorCode.subscription_past_due, - ) - if subscription_status == StripeSubscriptionStatus.CANCELED: - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=ErrorCode.subscription_canceled, - ) - raise HTTPException( - status_code=status.HTTP_402_PAYMENT_REQUIRED, - detail=ErrorCode.unknown_subscription_status, - ) - def _apply_unrecorded_credits( self, org_id: str, @@ -469,7 +440,7 @@ def _apply_unrecorded_credits( amount=-credits_due, status=RecordStatus.UNRECORDED, description=description, - ).dict(exclude={"id"}) + ) ) print(f"New negative credit created: {neg_credit_id}") self.repo.update_available_credits(org_id, available_credits - credits_due) diff --git a/apps/ai/server/modules/organization/models/exceptions.py b/apps/ai/server/modules/organization/models/exceptions.py new file mode 100644 index 00000000..aaf3bb2e --- /dev/null +++ b/apps/ai/server/modules/organization/models/exceptions.py @@ -0,0 +1,90 @@ +from starlette.status import ( + HTTP_400_BAD_REQUEST, + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class OrganizationErrorCode(BaseErrorCode): + organization_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Organization not found" + ) + slack_installation_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Slack installation not found" + ) + cannot_create_organization = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot create organization" + ) + cannot_update_organization = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot update organization" + ) + cannot_delete_organization = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot delete organization" + ) + invalid_llm_api_key = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Invalid LLM API key" + ) + + +class OrganizationError(BaseError): + """ + Base class for organization exceptions + """ + + ERROR_CODES: BaseErrorCode = OrganizationErrorCode + + +class OrganizationNotFoundError(OrganizationError): + def __init__( + self, slack_workspace_id: str | None, organization_id: str | None + ) -> None: + if slack_workspace_id: + detail = {"slack_workspace_id": slack_workspace_id} + elif organization_id: + detail = {"organization_id": organization_id} + else: + raise ValueError("workspace_id or organization_id must be provided") + super().__init__( + error_code=OrganizationErrorCode.organization_not_found.name, + detail=detail, + ) + + +class SlackInstallationNotFoundError(OrganizationError): + def __init__(self, slack_workspace_id: str) -> None: + super().__init__( + error_code=OrganizationErrorCode.slack_installation_not_found.name, + detail={"slack_workspace_id": slack_workspace_id}, + ) + + +class CannotCreateOrganizationError(OrganizationError): + def __init__(self) -> None: + super().__init__( + error_code=OrganizationErrorCode.cannot_create_organization.name, + ) + + +class CannotUpdateOrganizationError(OrganizationError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=OrganizationErrorCode.cannot_update_organization.name, + detail={"organization_id": organization_id}, + ) + + +class CannotDeleteOrganizationError(OrganizationError): + def __init__(self, organization_id: str) -> None: + super().__init__( + error_code=OrganizationErrorCode.cannot_delete_organization.name, + detail={"organization_id": organization_id}, + ) + + +class InvalidLlmApiKeyError(OrganizationError): + def __init__(self) -> None: + super().__init__( + error_code=OrganizationErrorCode.invalid_llm_api_key.name, + ) diff --git a/apps/ai/server/modules/organization/repository.py b/apps/ai/server/modules/organization/repository.py index 54c2bc04..b810200f 100644 --- a/apps/ai/server/modules/organization/repository.py +++ b/apps/ai/server/modules/organization/repository.py @@ -52,8 +52,7 @@ def update_organization(self, org_id: str, new_org_data: dict) -> int: ORGANIZATION_COL, {"_id": ObjectId(org_id)}, new_org_data ) - def add_organization(self, new_org_data: dict) -> str: - # each organization should have unique name - if MongoDB.find_one(ORGANIZATION_COL, {"name": new_org_data["name"]}): - return None - return str(MongoDB.insert_one(ORGANIZATION_COL, new_org_data)) + def add_organization(self, new_org_data: Organization) -> str: + return str( + MongoDB.insert_one(ORGANIZATION_COL, new_org_data.dict(exclude={"id"})) + ) diff --git a/apps/ai/server/modules/organization/service.py b/apps/ai/server/modules/organization/service.py index 54d53c8c..bd14ce5b 100644 --- a/apps/ai/server/modules/organization/service.py +++ b/apps/ai/server/modules/organization/service.py @@ -1,5 +1,4 @@ import openai -from fastapi import HTTPException, status from config import invoice_settings from modules.organization.invoice.models.entities import ( @@ -14,6 +13,14 @@ SlackConfig, SlackInstallation, ) +from modules.organization.models.exceptions import ( + CannotCreateOrganizationError, + CannotDeleteOrganizationError, + CannotUpdateOrganizationError, + InvalidLlmApiKeyError, + OrganizationNotFoundError, + SlackInstallationNotFoundError, +) from modules.organization.models.requests import OrganizationRequest from modules.organization.models.responses import OrganizationResponse from modules.organization.repository import OrganizationRepository @@ -36,15 +43,14 @@ def get_organization(self, org_id: str) -> OrganizationResponse: return self.repo.get_organization(org_id) def get_organization_by_slack_workspace_id( - self, workspace_id: str + self, slack_workspace_id: str ) -> OrganizationResponse: - organization = self.repo.get_organization_by_slack_workspace_id(workspace_id) + organization = self.repo.get_organization_by_slack_workspace_id( + slack_workspace_id + ) if organization: return organization - - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Organization not found" - ) + raise OrganizationNotFoundError(slack_workspace_id=slack_workspace_id) def add_organization( self, org_request: OrganizationRequest @@ -69,7 +75,7 @@ def add_organization( hard_spending_limit=invoice_settings.default_hard_spending_limit, available_credits=invoice_settings.signup_credits, ) - new_id = self.repo.add_organization(organization.dict(exclude_unset=True)) + new_id = self.repo.add_organization(organization) if new_id: new_organization = self.repo.get_organization(new_id) # create signup credit, mark as recorded @@ -79,7 +85,7 @@ def add_organization( amount=invoice_settings.signup_credits, status=RecordStatus.RECORDED, description="Signup credits", - ).dict(exclude={"id"}) + ) ) print(f"New credit created: {credit_id}") self.analytics.track( @@ -93,10 +99,7 @@ def add_organization( ) return OrganizationResponse(**new_organization.dict()) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Organization exists or cannot add organization", - ) + raise CannotCreateOrganizationError() def add_user_organization(self, user_id: str, user_email: str) -> str: new_organization = self.add_organization( @@ -120,19 +123,13 @@ def update_organization( ): return self.repo.get_organization(org_id) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Organization not found or cannot be updated", - ) + raise CannotUpdateOrganizationError(org_id) def delete_organization(self, org_id: str) -> dict: if self.repo.delete_organization(org_id) == 1: return {"id": org_id} - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Organization not found or cannot be deleted", - ) + raise CannotDeleteOrganizationError(org_id) def add_organization_by_slack_installation( self, slack_installation_request: SlackInstallation @@ -153,10 +150,7 @@ def add_organization_by_slack_installation( updated_org = self.repo.get_organization(str(current_org.id)) return OrganizationResponse(**updated_org.dict()) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="An error ocurred while updating organization", - ) + raise CannotUpdateOrganizationError(current_org.id) organization = Organization( name=slack_installation_request.team.name, @@ -179,24 +173,31 @@ def add_organization_by_slack_installation( available_credits=invoice_settings.signup_credits, ) - new_id = self.repo.add_organization(organization.dict(exclude={"id"})) + new_id = self.repo.add_organization(organization) if new_id: # create signup credit, mark as recorded + new_organization = self.repo.get_organization(new_id) credit_id = self.invoice_repo.create_credit( Credit( organization_id=new_id, amount=invoice_settings.signup_credits, status=RecordStatus.RECORDED, description="Signup credits", - ).dict(exclude={"id"}) + ) ) print(f"New credit created: {credit_id}") - return self.repo.get_organization(new_id) + self.analytics.track( + new_organization.id, + EventName.organization_created, + EventType.organization_event( + id=new_organization.id, + name=new_organization.name, + owner=new_organization.owner, + ), + ) + return new_organization - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Organization exists or cannot add organization", - ) + raise CannotCreateOrganizationError() def get_slack_installation_by_slack_workspace_id( self, slack_workspace_id: str @@ -207,9 +208,7 @@ def get_slack_installation_by_slack_workspace_id( if organization: return organization.slack_config.slack_installation - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="slack installation not found" - ) + raise SlackInstallationNotFoundError(slack_workspace_id) def get_organization_by_customer_id(self, customer_id: str) -> Organization: return self.repo.get_organization_by_customer_id(customer_id) @@ -223,7 +222,4 @@ def _validate_api_key(self, llm_api_key: str): try: openai.Model.list() except openai.error.AuthenticationError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid LLM API key", - ) from e + raise InvalidLlmApiKeyError() from e diff --git a/apps/ai/server/modules/table_description/models/exceptions.py b/apps/ai/server/modules/table_description/models/exceptions.py new file mode 100644 index 00000000..6afcc31a --- /dev/null +++ b/apps/ai/server/modules/table_description/models/exceptions.py @@ -0,0 +1,31 @@ +from starlette.status import ( + HTTP_404_NOT_FOUND, +) + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class TableDescriptionErrorCode(BaseErrorCode): + table_description_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="Table description not found" + ) + + +class TableDescriptionError(BaseError): + """ + Base class for table description exceptions + """ + + ERROR_CODES: BaseErrorCode = TableDescriptionErrorCode + + +class TableDescriptionNotFoundError(TableDescriptionError): + def __init__(self, table_description_id: str, org_id: str) -> None: + super().__init__( + error_code=TableDescriptionErrorCode.table_description_not_found.name, + detail={ + "table_description_id": table_description_id, + "organization_id": org_id, + }, + ) diff --git a/apps/ai/server/modules/table_description/service.py b/apps/ai/server/modules/table_description/service.py index 42018f63..8821bee3 100644 --- a/apps/ai/server/modules/table_description/service.py +++ b/apps/ai/server/modules/table_description/service.py @@ -1,13 +1,14 @@ import httpx -from fastapi import HTTPException, status from config import settings +from exceptions.exception_handlers import raise_engine_exception from modules.db_connection.service import DBConnectionService from modules.table_description.models.entities import ( DHTableDescriptionMetadata, TableDescription, TableDescriptionMetadata, ) +from modules.table_description.models.exceptions import TableDescriptionNotFoundError from modules.table_description.models.requests import ( ScanRequest, TableDescriptionRequest, @@ -18,7 +19,6 @@ DatabaseDescriptionResponse, ) from modules.table_description.repository import TableDescriptionRepository -from utils.exception import raise_for_status from utils.misc import reserved_key_in_metadata @@ -39,7 +39,7 @@ async def get_table_descriptions( params={"db_connection_id": db_connection_id, "table_name": table_name}, timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) table_descriptions = [ AggrTableDescription( **table_description, db_connection_alias=db_connection.alias @@ -68,7 +68,7 @@ async def get_table_description( settings.engine_url + f"/table-descriptions/{table_description_id}", timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) table_description = AggrTableDescription( **response.json(), db_connection_alias=db_connection.alias ) @@ -94,7 +94,7 @@ async def refresh_table_description( ) try: - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) table_descriptions = [ AggrTableDescription(**table_description) for table_description in response.json() @@ -139,7 +139,7 @@ async def get_database_description_list( params={"db_connection_id": db_connection.id}, timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) table_descriptions = [ AggrTableDescription(**table_description) for table_description in response.json() @@ -187,7 +187,7 @@ async def sync_databases_schemas( json=scan_request.dict(exclude_unset=True), timeout=settings.default_engine_timeout, ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) table_descriptions = [ AggrTableDescription(**table_description) for table_description in response.json() @@ -218,7 +218,7 @@ async def update_table_description( settings.engine_url + f"/table-descriptions/{table_description_id}", json=table_description_request.dict(exclude_unset=True), ) - raise_for_status(response.status_code, response.text) + raise_engine_exception(response, org_id=org_id) return AggrTableDescription( **response.json(), db_connection_alias=db_connection.alias ) @@ -230,8 +230,5 @@ def get_table_description_in_org( table_description_id, org_id ) if not table_description: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Table Description not found", - ) + raise TableDescriptionNotFoundError(table_description_id, org_id) return table_description diff --git a/apps/ai/server/modules/user/controller.py b/apps/ai/server/modules/user/controller.py index fe1c1861..33d1b9b2 100644 --- a/apps/ai/server/modules/user/controller.py +++ b/apps/ai/server/modules/user/controller.py @@ -35,7 +35,7 @@ async def get_user( async def add_user( new_user_request: UserRequest, token: str = Depends(token_auth_scheme) ) -> UserResponse: - authorize.is_admin_user(VerifyToken(token.credentials).verify()) + authorize.is_admin_user(authorize.user(VerifyToken(token.credentials).verify())) return user_service.add_user(new_user_request) diff --git a/apps/ai/server/modules/user/models/entities.py b/apps/ai/server/modules/user/models/entities.py index 9ec46be2..d4681190 100644 --- a/apps/ai/server/modules/user/models/entities.py +++ b/apps/ai/server/modules/user/models/entities.py @@ -1,7 +1,7 @@ from datetime import datetime from enum import Enum -from pydantic import BaseModel, Field +from pydantic import BaseModel from utils.validation import ObjectIdString @@ -24,4 +24,4 @@ class BaseUser(BaseModel): class User(BaseUser): id: ObjectIdString | None role: Roles | None - created_at: datetime = Field(default_factory=datetime.now) + created_at: datetime = datetime.now() diff --git a/apps/ai/server/modules/user/models/exceptions.py b/apps/ai/server/modules/user/models/exceptions.py new file mode 100644 index 00000000..b87354e5 --- /dev/null +++ b/apps/ai/server/modules/user/models/exceptions.py @@ -0,0 +1,83 @@ +from starlette.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND, HTTP_409_CONFLICT + +from exceptions.error_codes import BaseErrorCode, ErrorCodeData +from exceptions.exceptions import BaseError + + +class UserErrorCode(BaseErrorCode): + user_not_found = ErrorCodeData( + status_code=HTTP_404_NOT_FOUND, message="User not found" + ) + user_exists_in_org = ErrorCodeData( + status_code=HTTP_409_CONFLICT, + message="User already exists in organization", + ) + user_exists_in_other_org = ErrorCodeData( + status_code=HTTP_409_CONFLICT, + message="User already exists in other organization", + ) + cannot_create_user = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot create user" + ) + cannot_update_user = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot update user" + ) + cannot_delete_user = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Cannot delete user" + ) + + +class UserError(BaseError): + """ + Base class for user exceptions + """ + + ERROR_CODES: BaseErrorCode = UserErrorCode + + +class UserNotFoundError(UserError): + def __init__(self, user_id: str, org_id: str) -> None: + super().__init__( + error_code=UserErrorCode.user_not_found.name, + detail={"user_id": user_id, "organization_id": org_id}, + ) + + +class UserExistsInOrgError(UserError): + def __init__(self, user_id: str, org_id: str) -> None: + super().__init__( + error_code=UserErrorCode.user_exists_in_org.name, + detail={"user_id": user_id, "organization_id": org_id}, + ) + + +class UserExistsInOtherOrgError(UserError): + def __init__(self, user_id: str, org_id: str) -> None: + super().__init__( + error_code=UserErrorCode.user_exists_in_other_org.name, + detail={"user_id": user_id, "organization_id": org_id}, + ) + + +class CannotCreateUserError(UserError): + def __init__(self, org_id: str) -> None: + super().__init__( + error_code=UserErrorCode.cannot_create_user.name, + detail={"organization_id": org_id}, + ) + + +class CannotUpdateUserError(UserError): + def __init__(self, user_id: str, org_id: str) -> None: + super().__init__( + error_code=UserErrorCode.cannot_update_user.name, + detail={"user_id": user_id, "organization_id": org_id}, + ) + + +class CannotDeleteUserError(UserError): + def __init__(self, user_id: str, org_id: str) -> None: + super().__init__( + error_code=UserErrorCode.cannot_delete_user.name, + detail={"user_id": user_id, "organization_id": org_id}, + ) diff --git a/apps/ai/server/modules/user/service.py b/apps/ai/server/modules/user/service.py index 8cf54e9e..4b27be57 100644 --- a/apps/ai/server/modules/user/service.py +++ b/apps/ai/server/modules/user/service.py @@ -1,7 +1,13 @@ from bson import ObjectId -from fastapi import HTTPException, status from modules.user.models.entities import User +from modules.user.models.exceptions import ( + CannotCreateUserError, + CannotDeleteUserError, + CannotUpdateUserError, + UserExistsInOrgError, + UserExistsInOtherOrgError, +) from modules.user.models.requests import UserOrganizationRequest, UserRequest from modules.user.models.responses import UserResponse from modules.user.repository import UserRepository @@ -34,10 +40,7 @@ def add_user(self, user_request: UserRequest) -> UserResponse: added_user = self.repo.get_user({"_id": ObjectId(new_user_id)}) return UserResponse(**added_user.dict()) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User exists or cannot add user", - ) + raise CannotCreateUserError(user_request.organization_id) def invite_user_to_org( self, user_request: UserRequest, org_id: str @@ -45,24 +48,15 @@ def invite_user_to_org( stored_user = self.repo.get_user_by_email(user_request.email) if stored_user: if stored_user.organization_id == org_id: - error_code = "USER_ALREADY_EXISTS_IN_ORG" - else: - error_code = "USER_ALREADY_EXISTS_IN_OTHER_ORG" - - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail=error_code, - ) + raise UserExistsInOrgError(stored_user.id) + raise UserExistsInOtherOrgError(stored_user.id, stored_user.organization_id) new_user_data = User( **user_request.dict(exclude={"organization_id"}), organization_id=org_id ) new_user_id = self.repo.add_user(new_user_data) if not new_user_id: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="An error occurred while trying to create the user", - ) + raise CannotCreateUserError(org_id) new_user = self.repo.get_user({"_id": ObjectId(new_user_id)}) @@ -90,10 +84,7 @@ def update_user(self, user_id: str, user_request: UserRequest) -> UserResponse: new_user = self.repo.get_user({"_id": ObjectId(user_id)}) return UserResponse(**new_user.dict()) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User not found or cannot be updated", - ) + raise CannotUpdateUserError(user_id) def update_user_organization( self, user_id: str, user_organization_request: UserOrganizationRequest @@ -108,10 +99,7 @@ def update_user_organization( new_user = self.repo.get_user({"_id": ObjectId(user_id)}) return UserResponse(**new_user.dict()) - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User not found or cannot be updated", - ) + raise CannotUpdateUserError(user_id) def delete_user(self, user_id: str, org_id: str) -> dict: if ( @@ -130,7 +118,4 @@ def delete_user(self, user_id: str, org_id: str) -> dict: ): return {"id": user_id} - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="User not found or cannot be deleted", - ) + raise CannotDeleteUserError(user_id) diff --git a/apps/ai/server/utils/auth.py b/apps/ai/server/utils/auth.py index aadcf05a..11550e22 100644 --- a/apps/ai/server/utils/auth.py +++ b/apps/ai/server/utils/auth.py @@ -1,6 +1,6 @@ import jwt from bson import ObjectId -from fastapi import HTTPException, Security, status +from fastapi import Security from fastapi.security import APIKeyHeader from config import ( @@ -8,9 +8,21 @@ auth_settings, ) from database.mongo import MongoDB +from exceptions.exceptions import UnknownError +from modules.auth.models.exceptions import ( + BearerTokenExpiredError, + DecodeError, + InvalidBearerTokenError, + InvalidOrRevokedAPIKeyError, + PyJWKClientError, + UnauthorizedDataAccessError, + UnauthorizedOperationError, + UnauthorizedUserError, +) from modules.key.service import KeyService from modules.organization.service import OrganizationService from modules.user.models.entities import Roles +from modules.user.models.exceptions import UserNotFoundError from modules.user.models.responses import UserResponse from modules.user.service import UserService @@ -38,13 +50,9 @@ def _fetch_signing_key(self): try: self.signing_key = self.jwks_client.get_signing_key_from_jwt(self.token).key except jwt.exceptions.PyJWKClientError as error: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error) - ) from error + raise PyJWKClientError() from error except jwt.exceptions.DecodeError as error: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=str(error) - ) from error + raise DecodeError() from error def _decode_payload(self): try: @@ -56,21 +64,13 @@ def _decode_payload(self): issuer=auth_settings.auth0_issuer, ) except jwt.ExpiredSignatureError as error: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Token expired" - ) from error + raise BearerTokenExpiredError() from error except (jwt.InvalidAudienceError, jwt.InvalidIssuerError) as error: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Token is invalid" - ) from error + raise InvalidBearerTokenError() from error except (jwt.DecodeError, jwt.InvalidTokenError) as error: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Token is invalid" - ) from error - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) - ) from e + raise InvalidBearerTokenError() from error + except Exception as error: + raise UnknownError(str(error)) from error class Authorize: @@ -78,52 +78,29 @@ def user(self, payload: dict) -> UserResponse: email = payload[auth_settings.auth0_issuer + "email"] user = user_service.get_user_by_email(email) if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized User" - ) + raise UnauthorizedUserError(email=email) return user def user_in_organization(self, user_id: str, org_id: str): - self._item_in_organization(USER_COL, user_id, org_id) + if not MongoDB.find_one( + USER_COL, + {"_id": ObjectId(user_id), "organization_id": org_id}, + ): + raise UserNotFoundError(user_id, org_id) def is_admin_user(self, user: UserResponse): if user.role != Roles.admin: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="User not authorized" - ) + raise UnauthorizedOperationError(user_id=user.id) - def is_self(self, id_a: str, id_b: str): - if id_a != id_b: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not authorized to access other user data", - ) + def is_self(self, user_a_id: str, user_b_id: str): + # TODO - fix param names to clear up confusion + if user_a_id != user_b_id: + raise UnauthorizedDataAccessError(user_id=user_a_id) - def is_not_self(self, id_a: str, id_b: str): - if id_a == id_b: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not authorized to self modify user data", - ) - - def _item_in_organization( - self, - collection: str, - id: str, - org_id: str, - key: str = "_id", - is_metadata: bool = False, - ): - metadata_prefix = "metadata" if is_metadata else "" - item = MongoDB.find_one( - collection, - {key: ObjectId(id), f"{metadata_prefix}organization_id": org_id}, - ) - - if not item: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail="Item not found" - ) + def is_not_self(self, user_a_id: str, user_b_id: str): + # TODO - fix param names to clear up confusion + if user_a_id == user_b_id: + raise UnauthorizedOperationError(user_id=user_a_id) api_key_header = APIKeyHeader(name="X-API-Key") @@ -134,6 +111,4 @@ def get_api_key(api_key: str = Security(api_key_header)) -> str: if validated_key: return validated_key - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="API Key does not exist" - ) + raise InvalidOrRevokedAPIKeyError(key_id=api_key) diff --git a/apps/ai/server/utils/exception.py b/apps/ai/server/utils/exception.py deleted file mode 100644 index 637f6a90..00000000 --- a/apps/ai/server/utils/exception.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging -from enum import Enum - -from fastapi import HTTPException, Request, status -from fastapi.responses import JSONResponse - -logger = logging.getLogger(__name__) - - -class ErrorCode(str, Enum): - no_payment_method = "no_payment_method" - spending_limit_exceeded = "spending_limit_exceeded" - hard_spending_limit_exceeded = "hard_spending_limit_exceeded" - subscription_past_due = "subscription_past_due" - subscription_canceled = "subscription_canceled" - unknown_subscription_status = "unknown_subscription_status" - - -class GenerationEngineError(Exception): - def __init__( - self, status_code: int, prompt_id: str, display_id: str, error_message: str - ): - self.status_code = status_code - self.prompt_id = prompt_id - self.display_id = display_id - self.error_message = error_message - - -async def query_engine_exception_handler( - request: Request, exc: GenerationEngineError # noqa: ARG001 -): - return JSONResponse( - status_code=exc.status_code, - content={ - "prompt_id": exc.prompt_id, - "display_id": exc.display_id, - "error_message": exc.error_message, - }, - ) - - -def raise_for_status(status_code: int, message: str = None): - if status_code < status.HTTP_400_BAD_REQUEST: - return - - logger.error("Error from K2-Engine: %s", message) - raise HTTPException( - status_code=status_code, detail=f"Error from K2-Engine: {message}" - ) diff --git a/apps/ai/server/utils/misc.py b/apps/ai/server/utils/misc.py index 675ae755..9e06f67b 100644 --- a/apps/ai/server/utils/misc.py +++ b/apps/ai/server/utils/misc.py @@ -1,6 +1,5 @@ -from fastapi import HTTPException, status - from database.mongo import DESCENDING, MongoDB +from exceptions.exceptions import ReservedMetadataKeyError MAX_DISPLAY_ID = 99999 RESERVED_KEY = "dh_internal" @@ -31,7 +30,4 @@ def get_next_display_id(collection, org_id: str, prefix: str) -> str: def reserved_key_in_metadata(metadata: dict): if RESERVED_KEY in metadata: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Metadata cannot contain reserved key: {RESERVED_KEY}", - ) + raise ReservedMetadataKeyError()