Skip to content

Commit

Permalink
AIP-81 Add Insert Multiple Pools API (#44121)
Browse files Browse the repository at this point in the history
* Add bulk post pools, refactor post pool

* Add 409 case for TestPostPool

* Add test for bulk post pools

* Remove unused status code, rename post_body to body

* Refactor duplicate pool insert handling

- handle exception from db level instead of application level

* Add global database exception handler for fastapi

* Remove manual handle for unique constraint exc

* Refactor test_pools

* Fix bound for TypeVar, type for comment
  • Loading branch information
jason810496 authored Nov 22, 2024
1 parent fc52d7d commit f33166a
Show file tree
Hide file tree
Showing 12 changed files with 439 additions and 7 deletions.
9 changes: 8 additions & 1 deletion airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from fastapi import FastAPI
from starlette.routing import Mount

from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views
from airflow.api_fastapi.core_api.app import (
init_config,
init_dag_bag,
init_error_handlers,
init_plugins,
init_views,
)
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.configuration import conf
Expand Down Expand Up @@ -61,6 +67,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_dag_bag(app)
init_views(app)
init_plugins(app)
init_error_handlers(app)
init_auth_manager()

if "execution" in apps_list or "all" in apps_list:
Expand Down
64 changes: 64 additions & 0 deletions airflow/api_fastapi/common/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from fastapi import HTTPException, Request, status
from sqlalchemy.exc import IntegrityError

T = TypeVar("T", bound=Exception)


class BaseErrorHandler(Generic[T], ABC):
"""Base class for error handlers."""

def __init__(self, exception_cls: T) -> None:
self.exception_cls = exception_cls

@abstractmethod
def exception_handler(self, request: Request, exc: T):
"""exception_handler method."""
raise NotImplementedError


class _UniqueConstraintErrorHandler(BaseErrorHandler[IntegrityError]):
"""Exception raised when trying to insert a duplicate value in a unique column."""

def __init__(self):
super().__init__(IntegrityError)
self.unique_constraint_error_messages = [
"UNIQUE constraint failed", # SQLite
"Duplicate entry", # MySQL
"violates unique constraint", # PostgreSQL
]

def exception_handler(self, request: Request, exc: IntegrityError):
"""Handle IntegrityError exception."""
exc_orig_str = str(exc.orig)
if any(error_msg in exc_orig_str for error_msg in self.unique_constraint_error_messages):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Unique constraint violation",
)


DatabaseErrorHandlers = [
_UniqueConstraintErrorHandler(),
]
8 changes: 8 additions & 0 deletions airflow/api_fastapi/core_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,11 @@ def init_config(app: FastAPI) -> None:
app.add_middleware(GZipMiddleware, minimum_size=1024, compresslevel=5)

app.state.secret_key = conf.get("webserver", "secret_key")


def init_error_handlers(app: FastAPI) -> None:
from airflow.api_fastapi.common.exceptions import DatabaseErrorHandlers

# register database error handlers
for handler in DatabaseErrorHandlers:
app.add_exception_handler(handler.exception_cls, handler.exception_handler)
8 changes: 7 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class PoolPatchBody(BaseModel):
class PoolPostBody(BasePool):
"""Pool serializer for post bodies."""

pool: str = Field(alias="name")
pool: str = Field(alias="name", max_length=256)
description: str | None = None
include_deferred: bool = False


class PoolPostBulkBody(BaseModel):
"""Pools serializer for post bodies."""

pools: list[PoolPostBody]
63 changes: 63 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3278,6 +3278,56 @@ paths:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'409':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Conflict
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/pools/bulk:
post:
tags:
- Pool
summary: Post Pools
description: Create multiple pools.
operationId: post_pools
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/PoolPostBulkBody'
required: true
responses:
'201':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/PoolCollectionResponse'
'401':
description: Unauthorized
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'403':
description: Forbidden
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'409':
description: Conflict
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
'422':
description: Validation Error
content:
Expand Down Expand Up @@ -6544,6 +6594,7 @@ components:
properties:
name:
type: string
maxLength: 256
title: Name
slots:
type: integer
Expand All @@ -6563,6 +6614,18 @@ components:
- slots
title: PoolPostBody
description: Pool serializer for post bodies.
PoolPostBulkBody:
properties:
pools:
items:
$ref: '#/components/schemas/PoolPostBody'
type: array
title: Pools
type: object
required:
- pools
title: PoolPostBulkBody
description: Pools serializer for post bodies.
PoolResponse:
properties:
name:
Expand Down
31 changes: 28 additions & 3 deletions airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
PoolCollectionResponse,
PoolPatchBody,
PoolPostBody,
PoolPostBulkBody,
PoolResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
Expand Down Expand Up @@ -160,14 +161,38 @@ def patch_pool(
@pools_router.post(
"",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc(
[status.HTTP_409_CONFLICT]
), # handled by global exception handler
)
def post_pool(
post_body: PoolPostBody,
body: PoolPostBody,
session: Annotated[Session, Depends(get_session)],
) -> PoolResponse:
"""Create a Pool."""
pool = Pool(**post_body.model_dump())

pool = Pool(**body.model_dump())
session.add(pool)

return PoolResponse.model_validate(pool, from_attributes=True)


@pools_router.post(
"/bulk",
status_code=status.HTTP_201_CREATED,
responses=create_openapi_http_exception_doc(
[
status.HTTP_409_CONFLICT, # handled by global exception handler
]
),
)
def post_pools(
body: PoolPostBulkBody,
session: Annotated[Session, Depends(get_session)],
) -> PoolCollectionResponse:
"""Create multiple pools."""
pools = [Pool(**body.model_dump()) for body in body.pools]
session.add_all(pools)
return PoolCollectionResponse(
pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools],
total_entries=len(pools),
)
3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1330,6 +1330,9 @@ export type DagRunServiceClearDagRunMutationResult = Awaited<
export type PoolServicePostPoolMutationResult = Awaited<
ReturnType<typeof PoolService.postPool>
>;
export type PoolServicePostPoolsMutationResult = Awaited<
ReturnType<typeof PoolService.postPools>
>;
export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited<
ReturnType<typeof TaskInstanceService.getTaskInstancesBatch>
>;
Expand Down
38 changes: 38 additions & 0 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
DagWarningType,
PoolPatchBody,
PoolPostBody,
PoolPostBulkBody,
TaskInstancesBatchBody,
VariableBody,
} from "../requests/types.gen";
Expand Down Expand Up @@ -2363,6 +2364,43 @@ export const usePoolServicePostPool = <
PoolService.postPool({ requestBody }) as unknown as Promise<TData>,
...options,
});
/**
* Post Pools
* Create multiple pools.
* @param data The data for the request.
* @param data.requestBody
* @returns PoolCollectionResponse Successful Response
* @throws ApiError
*/
export const usePoolServicePostPools = <
TData = Common.PoolServicePostPoolsMutationResult,
TError = unknown,
TContext = unknown,
>(
options?: Omit<
UseMutationOptions<
TData,
TError,
{
requestBody: PoolPostBulkBody;
},
TContext
>,
"mutationFn"
>,
) =>
useMutation<
TData,
TError,
{
requestBody: PoolPostBulkBody;
},
TContext
>({
mutationFn: ({ requestBody }) =>
PoolService.postPools({ requestBody }) as unknown as Promise<TData>,
...options,
});
/**
* Get Task Instances Batch
* Get list of task instances.
Expand Down
17 changes: 17 additions & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2848,6 +2848,7 @@ export const $PoolPostBody = {
properties: {
name: {
type: "string",
maxLength: 256,
title: "Name",
},
slots: {
Expand Down Expand Up @@ -2877,6 +2878,22 @@ export const $PoolPostBody = {
description: "Pool serializer for post bodies.",
} as const;

export const $PoolPostBulkBody = {
properties: {
pools: {
items: {
$ref: "#/components/schemas/PoolPostBody",
},
type: "array",
title: "Pools",
},
},
type: "object",
required: ["pools"],
title: "PoolPostBulkBody",
description: "Pools serializer for post bodies.",
} as const;

export const $PoolResponse = {
properties: {
name: {
Expand Down
28 changes: 28 additions & 0 deletions airflow/ui/openapi-gen/requests/services.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ import type {
GetPoolsResponse,
PostPoolData,
PostPoolResponse,
PostPoolsData,
PostPoolsResponse,
GetProvidersData,
GetProvidersResponse,
GetTaskInstanceData,
Expand Down Expand Up @@ -1790,6 +1792,32 @@ export class PoolService {
errors: {
401: "Unauthorized",
403: "Forbidden",
409: "Conflict",
422: "Validation Error",
},
});
}

/**
* Post Pools
* Create multiple pools.
* @param data The data for the request.
* @param data.requestBody
* @returns PoolCollectionResponse Successful Response
* @throws ApiError
*/
public static postPools(
data: PostPoolsData,
): CancelablePromise<PostPoolsResponse> {
return __request(OpenAPI, {
method: "POST",
url: "/public/pools/bulk",
body: data.requestBody,
mediaType: "application/json",
errors: {
401: "Unauthorized",
403: "Forbidden",
409: "Conflict",
422: "Validation Error",
},
});
Expand Down
Loading

0 comments on commit f33166a

Please sign in to comment.