diff --git a/go.mod b/go.mod index 66381ac0..7af2fb93 100644 --- a/go.mod +++ b/go.mod @@ -44,7 +44,13 @@ require ( go.opentelemetry.io/otel/metric v1.31.0 // indirect go.opentelemetry.io/otel/sdk v1.31.0 // indirect go.opentelemetry.io/otel/trace v1.31.0 // indirect + golang.org/x/mod v0.17.0 // indirect golang.org/x/sys v0.26.0 // indirect + golang.org/x/text v0.19.0 // indirect + golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect gotest.tools/v3 v3.5.1 // indirect ) + +replace github.com/livepeer/ai-worker => ../home/user/test2/ai-worker diff --git a/go.sum b/go.sum index d5a9b577..14e2ccf4 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -151,6 +153,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -166,6 +170,8 @@ google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojt gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/runner/app/main.py b/runner/app/main.py index 52668990..cbd80e97 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -61,6 +61,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.image_to_text import ImageToTextPipeline return ImageToTextPipeline(model_id) + case "text-to-audio": + from app.pipelines.text_to_audio import TextToAudioPipeline + return TextToAudioPipeline(model_id) case _: raise EnvironmentError( f"{pipeline} is not a valid pipeline for model {model_id}" @@ -101,6 +104,9 @@ def load_route(pipeline: str) -> any: case "image-to-text": from app.routes import image_to_text return image_to_text.router + case "text-to-audio": + from app.routes import text_to_audio + return text_to_audio.router case _: raise EnvironmentError(f"{pipeline} is not a valid pipeline") diff --git a/runner/app/pipelines/text_to_audio.py b/runner/app/pipelines/text_to_audio.py new file mode 100644 index 00000000..c57ca8e3 --- /dev/null +++ b/runner/app/pipelines/text_to_audio.py @@ -0,0 +1,123 @@ +import logging +import os +from typing import Any, Dict, Optional, Tuple + +import torch +from app.pipelines.base import Pipeline +from app.pipelines.utils import get_model_dir, get_torch_device +from app.utils.errors import InferenceError +from diffusers import StableAudioPipeline +from huggingface_hub import file_download +import numpy as np +import soundfile as sf +import io + +logger = logging.getLogger(__name__) + +class TextToAudioPipeline(Pipeline): + def __init__(self, model_id: str): + """Initialize the text to audio pipeline. + + Args: + model_id: The model ID to use for audio generation. + """ + self.model_id = model_id + kwargs = {"cache_dir": get_model_dir()} + + torch_device = get_torch_device() + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model" + ) + folder_path = os.path.join(get_model_dir(), folder_name) + + # Load fp16 variant if available + has_fp16_variant = any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files + ) + if torch_device != "cpu" and has_fp16_variant: + logger.info("TextToAudioPipeline loading fp16 variant for %s", model_id) + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + + if os.environ.get("BFLOAT16"): + logger.info("TextToAudioPipeline using bfloat16 precision for %s", model_id) + kwargs["torch_dtype"] = torch.bfloat16 + + # Initialize the pipeline + self.pipeline = StableAudioPipeline.from_pretrained( + model_id, + **kwargs + ).to(torch_device) + + # Enable optimization if configured + sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true" + if sfast_enabled: + logger.info( + "TextToAudioPipeline will be dynamically compiled with stable-fast for %s", + model_id, + ) + from app.pipelines.optim.sfast import compile_model + self.pipeline = compile_model(self.pipeline) + + def __call__( + self, + prompt: str, + duration: float = 5.0, + num_inference_steps: int = 10, + guidance_scale: float = 3.0, + negative_prompt: str = None, + seed: Optional[int] = None, + **kwargs + ) -> Tuple[bytes, str]: + """Generate audio from text. + + Args: + prompt: The text prompt for audio generation. + duration: Duration of the generated audio in seconds. + num_inference_steps: Number of denoising steps. + guidance_scale: Scale for classifier-free guidance. + negative_prompt: Optional text prompt to guide what to exclude. + seed: Optional seed for reproducible generation. + + Returns: + Tuple containing the audio data as bytes and the file format. + """ + try: + # Set seed if provided + if seed is not None: + torch.manual_seed(seed) + + # Set default steps if invalid + if num_inference_steps is None or num_inference_steps < 1: + num_inference_steps = 10 + + # Validate duration + if duration < 1.0 or duration > 30.0: + raise ValueError("Duration must be between 1 and 30 seconds") + + # Generate audio + audio = self.pipeline( + prompt, + negative_prompt=negative_prompt, + audio_length_in_s=duration, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + **kwargs + ).audio[0] + + # Convert to bytes using soundfile + buffer = io.BytesIO() + sf.write(buffer, audio, samplerate=44100, format='WAV') + buffer.seek(0) + + return buffer.read(), 'wav' + + except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() + raise InferenceError(original_exception=e) + + def __str__(self) -> str: + return f"TextToAudioPipeline model_id={self.model_id}" \ No newline at end of file diff --git a/runner/app/routes/text_to_audio.py b/runner/app/routes/text_to_audio.py new file mode 100644 index 00000000..74ea8cd0 --- /dev/null +++ b/runner/app/routes/text_to_audio.py @@ -0,0 +1,149 @@ +import logging +import os +from typing import Annotated, Dict, Tuple, Union + +import torch +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.routes.utils import ( + HTTPError, + http_error, + handle_pipeline_exception, +) +from fastapi import APIRouter, Depends, Form, Response, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from pydantic import BaseModel + +router = APIRouter() +logger = logging.getLogger(__name__) + +# Pipeline specific error handling configuration +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types + "OutOfMemoryError": ( + "Out of memory error. Try reducing audio duration.", + status.HTTP_500_INTERNAL_SERVER_ERROR, + ), + "ValueError": ( + None, # Use the error message from the exception + status.HTTP_400_BAD_REQUEST, + ), +} + +RESPONSES = { + status.HTTP_200_OK: { + "content": { + "audio/wav": { + "schema": { + "type": "string", + "format": "binary" + } + }, + }, + "description": "Successfully generated audio" + }, + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + +@router.post( + "/text-to-audio", + responses=RESPONSES, + description="Generate audio from text prompts using Stable Audio.", + operation_id="genTextToAudio", + summary="Text To Audio", + tags=["generate"], + openapi_extra={"x-speakeasy-name-override": "textToAudio"}, +) +@router.post( + "/text-to-audio/", + responses=RESPONSES, + include_in_schema=False, +) +async def text_to_audio( + prompt: Annotated[ + str, + Form(description="Text prompt for audio generation."), + ], + model_id: Annotated[ + str, + Form(description="Hugging Face model ID used for audio generation."), + ] = "", + duration: Annotated[ + float, + Form(description="Duration of generated audio in seconds (between 1 and 30 seconds)."), + ] = 5.0, + num_inference_steps: Annotated[ + int, + Form(description="Number of denoising steps. More steps usually lead to higher quality audio but slower inference."), + ] = 10, + guidance_scale: Annotated[ + float, + Form(description="Scale for classifier-free guidance. Higher values result in audio that better matches the prompt but may be lower quality."), + ] = 3.0, + negative_prompt: Annotated[ + str, + Form(description="Text prompt to guide what to exclude from audio generation."), + ] = None, + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + # Validate auth token if configured + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token."), + ) + + # Validate model ID matches pipeline + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}." + ), + ) + + # Validate duration + if duration < 1.0 or duration > 30.0: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + "Duration must be between 1 and 30 seconds." + ), + ) + + try: + # Generate audio + audio_data, audio_format = pipeline( + prompt=prompt, + duration=duration, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + ) + + # Return audio file + return Response( + content=audio_data, + media_type=f"audio/{audio_format}", + headers={ + "Content-Disposition": f"attachment; filename=generated_audio.{audio_format}" + } + ) + + except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() + logger.error(f"TextToAudio pipeline error: {e}") + return handle_pipeline_exception( + e, + default_error_message="Text-to-audio pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) \ No newline at end of file diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 4a03c134..5d751caf 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -37,6 +37,7 @@ function download_beta_models() { # Download custom pipeline models. huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models + printf "\nDownloading token-gated models...\n" @@ -81,7 +82,8 @@ function download_restricted_models() { huggingface-cli download black-forest-labs/FLUX.1-dev --include "*.safetensors" "*.json" "*.txt" "*.model" --exclude ".onnx" ".onnx_data" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"} # Download LLM models (Warning: large model size) huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "*.json" "*.bin" "*.safetensors" "*.txt" --cache-dir models - + #Download text-to-audio model + huggingface-cli download stabilityai/stable-audio-open-1.0 --include "*.safetensors" "*.bin" "*.json" --cache-dir models } # Enable HF transfer acceleration. diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 019d65d2..726eb69d 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -411,6 +411,56 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: imageToText + /text-to-audio: + post: + tags: + - generate + summary: Text To Audio + description: Generate audio from text prompts using Stable Audio. + operationId: genTextToAudio + requestBody: + content: + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Body_genTextToAudio' + required: true + responses: + '200': + description: Successfully generated audio + content: + application/json: + schema: {} + audio/wav: + schema: + type: string + format: binary + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: textToAudio components: schemas: APIError: @@ -694,6 +744,43 @@ components: - image - model_id title: Body_genSegmentAnything2 + Body_genTextToAudio: + properties: + prompt: + type: string + title: Prompt + description: Text prompt for audio generation. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for audio generation. + default: '' + duration: + type: number + title: Duration + description: Duration of generated audio in seconds (between 1 and 30 seconds). + default: 5.0 + num_inference_steps: + type: integer + title: Num Inference Steps + description: Number of denoising steps. More steps usually lead to higher + quality audio but slower inference. + default: 10 + guidance_scale: + type: number + title: Guidance Scale + description: Scale for classifier-free guidance. Higher values result in + audio that better matches the prompt but may be lower quality. + default: 3.0 + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt to guide what to exclude from audio generation. + type: object + required: + - prompt + - model_id + title: Body_genTextToAudio Body_genUpscale: properties: prompt: diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index f6d0e5dc..cdd762d3 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -16,6 +16,7 @@ upscale, llm, image_to_text, + text_to_audio, ) from fastapi.openapi.utils import get_openapi @@ -104,6 +105,7 @@ def write_openapi(fname: str, entrypoint: str = "runner"): app.include_router(segment_anything_2.router) app.include_router(llm.router) app.include_router(image_to_text.router) + app.include_router(text_to_audio.router) logger.info(f"Generating OpenAPI schema for '{entrypoint}' entrypoint...") openapi = get_openapi( diff --git a/runner/go.mod b/runner/go.mod new file mode 100644 index 00000000..c0154ff0 --- /dev/null +++ b/runner/go.mod @@ -0,0 +1,3 @@ +module github.com/livepeer/ai-worker + +go 1.23.2 diff --git a/runner/openapi.yaml b/runner/openapi.yaml index d6a0d35a..e64d27c2 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -422,6 +422,56 @@ paths: security: - HTTPBearer: [] x-speakeasy-name-override: imageToText + /text-to-audio: + post: + tags: + - generate + summary: Text To Audio + description: Generate audio from text prompts using Stable Audio. + operationId: genTextToAudio + requestBody: + content: + application/x-www-form-urlencoded: + schema: + $ref: '#/components/schemas/Body_genTextToAudio' + required: true + responses: + '200': + description: Successfully generated audio + content: + application/json: + schema: {} + audio/wav: + schema: + type: string + format: binary + '400': + description: Bad Request + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '500': + description: Internal Server Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + security: + - HTTPBearer: [] + x-speakeasy-name-override: textToAudio components: schemas: APIError: @@ -699,6 +749,42 @@ components: required: - image title: Body_genSegmentAnything2 + Body_genTextToAudio: + properties: + prompt: + type: string + title: Prompt + description: Text prompt for audio generation. + model_id: + type: string + title: Model Id + description: Hugging Face model ID used for audio generation. + default: '' + duration: + type: number + title: Duration + description: Duration of generated audio in seconds (between 1 and 30 seconds). + default: 5.0 + num_inference_steps: + type: integer + title: Num Inference Steps + description: Number of denoising steps. More steps usually lead to higher + quality audio but slower inference. + default: 10 + guidance_scale: + type: number + title: Guidance Scale + description: Scale for classifier-free guidance. Higher values result in + audio that better matches the prompt but may be lower quality. + default: 3.0 + negative_prompt: + type: string + title: Negative Prompt + description: Text prompt to guide what to exclude from audio generation. + type: object + required: + - prompt + title: Body_genTextToAudio Body_genUpscale: properties: prompt: diff --git a/runner/requirements.txt b/runner/requirements.txt index 87f72e43..56086fa6 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -19,3 +19,4 @@ sentencepiece== 0.2.0 protobuf==5.27.2 bitsandbytes==0.43.3 psutil==6.0.0 +soundfile \ No newline at end of file diff --git a/worker/docker.go b/worker/docker.go index b94dce66..965c56a8 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -38,6 +38,7 @@ var containerHostPorts = map[string]string{ "llm": "8500", "segment-anything-2": "8600", "image-to-text": "8700", + "text-to-audio": "8008", } // Mapping for per pipeline container images. diff --git a/worker/multipart.go b/worker/multipart.go index c66970db..057ef05e 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -406,3 +406,46 @@ func NewImageToTextMultipartWriter(w io.Writer, req GenImageToTextMultipartReque return mw, nil } +func NewTextToAudioFormdataWriter(w io.Writer, req GenTextToAudioFormdataRequestBody) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + + if err := mw.WriteField("prompt", req.Prompt); err != nil { + return nil, fmt.Errorf("failed to write prompt field: %w", err) + } + + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, fmt.Errorf("failed to write model_id field: %w", err) + } + } + + if req.Duration != nil { + if err := mw.WriteField("duration", fmt.Sprintf("%f", *req.Duration)); err != nil { + return nil, fmt.Errorf("failed to write duration field: %w", err) + } + } + + if req.NumInferenceSteps != nil { + if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil { + return nil, fmt.Errorf("failed to write num_inference_steps field: %w", err) + } + } + + if req.GuidanceScale != nil { + if err := mw.WriteField("guidance_scale", fmt.Sprintf("%f", *req.GuidanceScale)); err != nil { + return nil, fmt.Errorf("failed to write guidance_scale field: %w", err) + } + } + + if req.NegativePrompt != nil { + if err := mw.WriteField("negative_prompt", *req.NegativePrompt); err != nil { + return nil, fmt.Errorf("failed to write negative_prompt field: %w", err) + } + } + + if err := mw.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 999da920..3c244375 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -169,6 +169,27 @@ type BodyGenSegmentAnything2 struct { ReturnLogits *bool `json:"return_logits,omitempty"` } +// BodyGenTextToAudio defines model for Body_genTextToAudio. +type BodyGenTextToAudio struct { + // Duration Duration of generated audio in seconds (between 1 and 30 seconds). + Duration *float32 `json:"duration,omitempty"` + + // GuidanceScale Scale for classifier-free guidance. Higher values result in audio that better matches the prompt but may be lower quality. + GuidanceScale *float32 `json:"guidance_scale,omitempty"` + + // ModelId Hugging Face model ID used for audio generation. + ModelId *string `json:"model_id,omitempty"` + + // NegativePrompt Text prompt to guide what to exclude from audio generation. + NegativePrompt *string `json:"negative_prompt,omitempty"` + + // NumInferenceSteps Number of denoising steps. More steps usually lead to higher quality audio but slower inference. + NumInferenceSteps *int `json:"num_inference_steps,omitempty"` + + // Prompt Text prompt for audio generation. + Prompt string `json:"prompt"` +} + // BodyGenUpscale defines model for Body_genUpscale. type BodyGenUpscale struct { // Image Uploaded image to modify with the pipeline. @@ -344,6 +365,9 @@ type GenLLMFormdataRequestBody = BodyGenLLM // GenSegmentAnything2MultipartRequestBody defines body for GenSegmentAnything2 for multipart/form-data ContentType. type GenSegmentAnything2MultipartRequestBody = BodyGenSegmentAnything2 +// GenTextToAudioFormdataRequestBody defines body for GenTextToAudio for application/x-www-form-urlencoded ContentType. +type GenTextToAudioFormdataRequestBody = BodyGenTextToAudio + // GenTextToImageJSONRequestBody defines body for GenTextToImage for application/json ContentType. type GenTextToImageJSONRequestBody = TextToImageParams @@ -508,6 +532,11 @@ type ClientInterface interface { // GenSegmentAnything2WithBody request with any body GenSegmentAnything2WithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenTextToAudioWithBody request with any body + GenTextToAudioWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + GenTextToAudioWithFormdataBody(ctx context.Context, body GenTextToAudioFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + // GenTextToImageWithBody request with any body GenTextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -613,6 +642,30 @@ func (c *Client) GenSegmentAnything2WithBody(ctx context.Context, contentType st return c.Client.Do(req) } +func (c *Client) GenTextToAudioWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenTextToAudioRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) GenTextToAudioWithFormdataBody(ctx context.Context, body GenTextToAudioFormdataRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewGenTextToAudioRequestWithFormdataBody(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) GenTextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewGenTextToImageRequestWithBody(c.Server, contentType, body) if err != nil { @@ -861,6 +914,46 @@ func NewGenSegmentAnything2RequestWithBody(server string, contentType string, bo return req, nil } +// NewGenTextToAudioRequestWithFormdataBody calls the generic GenTextToAudio builder with application/x-www-form-urlencoded body +func NewGenTextToAudioRequestWithFormdataBody(server string, body GenTextToAudioFormdataRequestBody) (*http.Request, error) { + var bodyReader io.Reader + bodyStr, err := runtime.MarshalForm(body, nil) + if err != nil { + return nil, err + } + bodyReader = strings.NewReader(bodyStr.Encode()) + return NewGenTextToAudioRequestWithBody(server, "application/x-www-form-urlencoded", bodyReader) +} + +// NewGenTextToAudioRequestWithBody generates requests for GenTextToAudio with any type of body +func NewGenTextToAudioRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/text-to-audio") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewGenTextToImageRequest calls the generic GenTextToImage builder with application/json body func NewGenTextToImageRequest(server string, body GenTextToImageJSONRequestBody) (*http.Request, error) { var bodyReader io.Reader @@ -996,6 +1089,11 @@ type ClientWithResponsesInterface interface { // GenSegmentAnything2WithBodyWithResponse request with any body GenSegmentAnything2WithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenSegmentAnything2Response, error) + // GenTextToAudioWithBodyWithResponse request with any body + GenTextToAudioWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenTextToAudioResponse, error) + + GenTextToAudioWithFormdataBodyWithResponse(ctx context.Context, body GenTextToAudioFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenTextToAudioResponse, error) + // GenTextToImageWithBodyWithResponse request with any body GenTextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenTextToImageResponse, error) @@ -1186,6 +1284,32 @@ func (r GenSegmentAnything2Response) StatusCode() int { return 0 } +type GenTextToAudioResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *interface{} + JSON400 *HTTPError + JSON401 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r GenTextToAudioResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r GenTextToAudioResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type GenTextToImageResponse struct { Body []byte HTTPResponse *http.Response @@ -1309,6 +1433,23 @@ func (c *ClientWithResponses) GenSegmentAnything2WithBodyWithResponse(ctx contex return ParseGenSegmentAnything2Response(rsp) } +// GenTextToAudioWithBodyWithResponse request with arbitrary body returning *GenTextToAudioResponse +func (c *ClientWithResponses) GenTextToAudioWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenTextToAudioResponse, error) { + rsp, err := c.GenTextToAudioWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseGenTextToAudioResponse(rsp) +} + +func (c *ClientWithResponses) GenTextToAudioWithFormdataBodyWithResponse(ctx context.Context, body GenTextToAudioFormdataRequestBody, reqEditors ...RequestEditorFn) (*GenTextToAudioResponse, error) { + rsp, err := c.GenTextToAudioWithFormdataBody(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseGenTextToAudioResponse(rsp) +} + // GenTextToImageWithBodyWithResponse request with arbitrary body returning *GenTextToImageResponse func (c *ClientWithResponses) GenTextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*GenTextToImageResponse, error) { rsp, err := c.GenTextToImageWithBody(ctx, contentType, body, reqEditors...) @@ -1706,6 +1847,63 @@ func ParseGenSegmentAnything2Response(rsp *http.Response) (*GenSegmentAnything2R return response, nil } +// ParseGenTextToAudioResponse parses an HTTP response from a GenTextToAudioWithResponse call +func ParseGenTextToAudioResponse(rsp *http.Response) (*GenTextToAudioResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &GenTextToAudioResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest interface{} + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + case rsp.StatusCode == 200: + // Content-type (audio/wav) unsupported + + } + + return response, nil +} + // ParseGenTextToImageResponse parses an HTTP response from a GenTextToImageWithResponse call func ParseGenTextToImageResponse(rsp *http.Response) (*GenTextToImageResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1837,6 +2035,9 @@ type ServerInterface interface { // Segment Anything 2 // (POST /segment-anything-2) GenSegmentAnything2(w http.ResponseWriter, r *http.Request) + // Text To Audio + // (POST /text-to-audio) + GenTextToAudio(w http.ResponseWriter, r *http.Request) // Text To Image // (POST /text-to-image) GenTextToImage(w http.ResponseWriter, r *http.Request) @@ -1891,6 +2092,12 @@ func (_ Unimplemented) GenSegmentAnything2(w http.ResponseWriter, r *http.Reques w.WriteHeader(http.StatusNotImplemented) } +// Text To Audio +// (POST /text-to-audio) +func (_ Unimplemented) GenTextToAudio(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Text To Image // (POST /text-to-image) func (_ Unimplemented) GenTextToImage(w http.ResponseWriter, r *http.Request) { @@ -2029,6 +2236,23 @@ func (siw *ServerInterfaceWrapper) GenSegmentAnything2(w http.ResponseWriter, r handler.ServeHTTP(w, r.WithContext(ctx)) } +// GenTextToAudio operation middleware +func (siw *ServerInterfaceWrapper) GenTextToAudio(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.GenTextToAudio(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // GenTextToImage operation middleware func (siw *ServerInterfaceWrapper) GenTextToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -2197,6 +2421,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/segment-anything-2", wrapper.GenSegmentAnything2) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/text-to-audio", wrapper.GenTextToAudio) + }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.GenTextToImage) }) @@ -2210,67 +2437,72 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xceW/cOLL/KoTeA+wAfdkznjwY2D+czBHj2ZnA7mxmkBgNtlSt5lgitSTl7t48f/cH", - "FimJOvrKOp7dTP+VtnjUxfpVsUjmcxCKNBMcuFbB+edAhXNIKf68eHf5k5RCmt8RqFCyTDPBg3PTQsA0", - "EQkqE1wBSUUEySDoBZkUGUjNAOdIVdwePp6DG56CUjQGM04znUBwHlyr2Py1yswfSkvG4+DxsRdI+EfO", - "JETB+Uec9a4aUjJajhPTPyDUwWMveCWi1SQGfpFHTIzFGJbaMFTnkprGNp/vs0TQCCKC7WTGEiBakCkQ", - "LSk3PacQGd5nQqZUB+fBlHEqV540SLYtTy9AfU1YZKnOaJ6Y8UGvwcKbPI4Zj8nPNHQ6Jpc/klxBRGZC", - "lnxg95oWbdeoi7QEnUs+0SwFpWmaqToPWubQ4uMGx5BqjCU/r6mCaFjqAbnNs0xIDRF5oEkO6pwcKeAa", - "eAhHPXK0EDI66hEhCSWWKTIVIgHKyfGRIX5k2o5mNFFw9GJAfrScEaaIaz6u5nsxKHqSFChXhAuPyYGj", - "5trM7/6UGuV5fTytOSnHlWa2rUS7cry12LXeNizLy5TGMBb4T3tdxjmLKA9hokKaQM1MLwdnTRv9xEOR", - "SxqDcitFCxIDB0k1EJZiQ5gIBcmKJIzfGzUIa0NYapJJkWaaHM9ZPAfpbEdSuiISojx0U5B/5DRhevXC", - "19svjk9yi3yW8vI8nYI08rJCwDUeZufWwnDOZiuyYHqOrGUsg4Rx2OxmVn8dax3nnWzQ40lbjz9CLAGZ", - "WcxZaNko9FhwyhTJcjVHFS6ojBT2YpxpRhPbZ9Dkj2xXUyIkVVsg4YJciZsLcnwlFv0byu/JRUQzTU3r", - "C2d4yiPCtCKhkBagI+NlC2DxXKPjWiGcUAY6yE9LmmYJnJPP5FOQUA1c90PBFVPG0VbDJEz7hru+ipbJ", - "p+CcnAxGPfIp4CDZH2qYsSUkfSp1v2g9ffQVcIWCfTUcbMmzIxRyiKlmDzCxi38LE+PKTY7VC3SvnEVA", - "FnOqzV+wDJM8AjKTIu1Q8WXMhTQraEbqC5J8ykej70Jy4rP91rFG3lnWurjP04n160kGskuGk6YIb3Gp", - "ETErAMHHiAykE6/GSJ6SS9v5HcgWO4xriO3qRX74DCSgaBoaoeVkNFrPTwRcMGVsjAMH5FpIsL9JrnKa", - "GNQCipjlIMpBUSHKNNdEJWIBkpRcmGmiPEHPna5MvAEe63lLvqI/uUWuu6Tz1bvLqti0JtfbVNEZ6NUk", - "nEN4X1OeCX1N7b0DaTDRBFIcRnAYLkWlWYq4P2til4GFPIlMGiNmM+DKLDIhyZzKdJYnPpu3dtbXyEzJ", - "rIvWyC1A1NbILTi3lJRHIiUW39aownTu1Hdhq5oWRoP/WQPXYmZTERskmOCEZlnCqiAnobCxtczxyLSc", - "1ALZbUGzhc2NuJ8VBrSBrSMBqEX27RlAd2K6c9gsRX+yyPmECWppkl1h+V9C4/Uk13ldw7bbTLpjTvd3", - "FoFom3TWAMUfeh27o5mkKSgEZAWh4BEu71oe8mCm96X7eQ1uzTHs12ieveykansSxgmGc7UD0Td28i66", - "O6/dMv5QOz/Gzz911Vo29k8nUmF6T6Z5eA+6ycXJ6csmG+8LgsbEzHw0TBmV01TkXBsD2DnL7ZafUKDN", - "bCg0TQ5mzc/UxE43csGSxIA949jUMuG17fYKma4J5od2wRRMaB5P1sDy6LSVp5Yi4GBCo6gC45rANl0m", - "b2obD7fpkKAgnSaYNq8daxNeHkqgqpC7FuKRgYs8JusBfnv6cnr2H5y9HPKKQhMLFjVW78no9PsuPMSe", - "e8HhB5y7TXXPCGNDx4YQc3V13Y4sc6a0kKs69H2889Ha9eiCLrqcaHEPvLnmf/CQgi7J2PbpUuxa7N0v", - "5O+QI2sJNK2RwRpQPY+jaffSWikN6aQsTHbweYtdSGclshdoSDNj/1xCAwNfVlOMvU475pIdy8GYecMq", - "uIU4Ba4v+ErPGY9P20tiKpYd1VuSIIyQ7wmVkq5IzB6AE6oIJVOxLApBDm3Rqj3jBb/9/tvvxMZkf82/", - "Esu1lZc28csi6ivL/JfGearuJ4xnue6UTyz6EpRIcgxtpjPBzg2h9CpjIWIzbtkpySQ8MJEr8yNiIY5m", - "2qFLr8qtER1Plm+WH8jxm799+Nvp2Q8ITLcX17X9xLWhfIls/tvVPtI8MViu7ici16UiN0SFS7PDyqFX", - "adDmFtLVhudmG2YmtMVhmk5ZnBtlWtXbZaV6RMw0cPNnlIdY/QWtQbqRek65iTuMxwl4ZqhJVXBOfrWc", - "d/k5N4sqYf+ESSiEjNR+4mWCcU1wJONUgyrTqHLeamNJeQzk46h3cueWCI52dAksMwi17T4F20GCMh/N", - "J2u+iKUmYgqu6nmLo0VeWxm6BPWJtZ3h7fLUebmYOamcIRq+sJiDBAI0dOwTZgxHjn/r/f6iioG17RR2", - "a3LmQToyltApJB2MXeH3Mq+tsVZwc0IYj1iI+qemK8RS5DxyvU3WN6p1mdLw3u/SZteS3XAskoiY6T1W", - "ix2mSM77xgPUXCQmz8XlaecijCttcj8xMywixmF7x9HDlaXetvOuGUQrJmyIH++zsh7+hWWHJ67WPw0g", - "5las6Murwls2Ai/P/kJlzJ20eahnbtt37F0/LJyzw3/fjMfv1pzMm6Ydj+Yj0JQlePydJL/OgvOPn4P/", - "ljALzoP/GlaXAobuRsCwPGV/vGuXYM1UEDnKjHvlt6bkjqwncSXOGln/ThMW4XSl1OtEYRpS/LRJkuZ8", - "jxUvVpKKEQydKIPPbXOCLr6BJnr+ulj2dX6VpjpvnPP9+r+1OjR26KpUVpW3ikAHfcTYG7cE2uvkprY4", - "1uaRHWFBdd/naDqlGb2TMa4hYtQ3gT1r6jJBKwAqfxnVJV6nElu+3UsxeDq+SS/a1e43aQWvRgT+BnG5", - "vRKNEzdFbEjQIejV1bUvYJ1Z6bVUyUdzMm+7ixv+iYmu/hBbByDv1S5oJ735vek8yXyWOyQy+yi1l9Hs", - "2GKXucZwfq7XNJ2kix7JuZfuV5sRRY7t0Bdl/oq7l/qJdz2Tq+9dt3pRaz5UQWeEDYVc55WojyMTIfmM", - "RZgZ2O7INyb7dZK1SGYn3noZyzGmiu5Oq3cN3jfaF0GgYyefmobCmKHgmjJb9uXeOd9UmJ19XX1mXNvg", - "XM0WbTIf5qCLIroluKCKzBIaxxARqsjb258/1HItM83u+YOxhGmxKap/4lFS3Klymcuke/L3N1dux1SJ", - "EFJuUiIahqCUvaZWEHgvk61WzbGPsqyg2nx7ork67Pj02BrOc77dW3Aa23XnwIPd/cDz2pJqBp7eV8b3", - "XiHjXX30Jn8x7e5s+R2V1Ar7rd4de8oDzNbNrA0HmIfLWIfLWN/uZayzv/RdLHILGUU9Yxk8w1qzLYti", - "Mevo/47M0lDlVebpqiqWHioff9qJawu/dzxxbZ+xtUNoR5zdWnlIRFgrO1C+cqWU5nr43GLx7tGH5BDJ", - "dGQf7qSyyr3wyUTnVg0/VF2RZzI2X7dlIkYOS8r19DS1Q7UDD6z3Svy6rtg0LkrhHahteVdxY8j0raV+", - "exYfmilfcanKMrGlGOFY9XVWU0iHxmz22bHjwQZc+AbLEIxo9Vxhn+IDTuA8CGfdnp+adkdpzZxFc2vi", - "Qt+e8sp3FNv0p/2OjZqHVVRLgwhZYS6ZXt0aY1plvBmP370CKkGWb5kQ5+yncpK51lnwaOYw+8gOK7jL", - "o9YnDQrLnJOLy/J8Q/nZFHuADECa9puccyT0AFLZuUaD0WBkNCsy4DRjwXnw3eBkMDKGpHqObA/xGUlf", - "i35hzUyoLquWb228J0n2IM9tP0Tm/OkyMrl18x2KUToo/UpEeE3EbKiBIyEbBqnUQxOH+hHVtHoOts2N", - "uh69PNaNbIIefrAugWKfjkYNLjytD/9QRuZdWajtmJB2I5LluBGe5QmpuvWC75+QhaqO3UH/FY3IjdW+", - "pXvyPHTfc5rruZDsnxAh4ZPvnoewE5b8xLXJC8dCkCsqY6v1k7Pnkr7K4BDrbTQ0LJyePikLrTOFNjNV", - "F1KeO5w91/q75Bokpwm5BfkAsuDAg1HMWnwA/Xj3eNcLVJ6mVK6Kt4xkLEgRO2isDHoXwdig9rKvMqD3", - "QNWqz2kKffEAUrIIsb+GDr1gOMcjDCxWAMpeRy97whF8RdDwz1B2xYxHXyWORZQGU1OD4eXpeTeIX2RZ", - "siqO0Gt31RHJqdmImKzGS3ZbqN54XPCVYb1G7ZlxvX6qcwD29cB+ALR9Ac3eRRwLUl5I2RPRWN0xfBDY", - "IZHDDb/Fge15XP3tyfM4/J+Rx3UdcR68/t88nTtAzxdDzxfmUqzmoT7wPJTPzjqR55eux1Z7JR3F44Tn", - "wSBL7ZlBqF69OcDPIen4Cp5fPvL5MtcvHKMXDJMk3cHhsRyY46kOJQnlcW4YKQ80Wu5uH5+s93Jfxcv+", - "YrHoo7fnMgEeisgeJ+zn84bkM7u6f/Ho4OgHR386R3ePt/b0buPL6NTu5lqfupv8/dP1Pu4u/bt7Uvhu", - "g/INkbzjkcBXjuYtis/s5vUbaAdHPzj60zl64X3F4ianX+D3qu0gvWBoYvYOJcVfGje1MKf3LmapThTw", - "TsB3DvT7H8rUz9gP1cOD238jbo93jf6F4qH23A+dPfee43W6uXsSVMZ2Ml0V/+sF3onWilSvnjtdvnpU", - "9JXjfUHo4O8Hf/9G/N17kLenp+e+MyhkQCG5xovo4iLJ60TkEXkt0jTnTK/IL1TDgq4C9wIAr6+o8+Ew", - "kkDTfmxbB4kbPgjNcLxxtmb+W40nueumLSdS2G9IMzacgqbDUt7Hu8f/DwAA//+2vdykOFgAAA==", + "H4sIAAAAAAAC/+xce28bOZL/KkTfAXYAPZ3x5GBg/3AyjxhnzwS2splBYghUd6nFcTfZS7ItaXP+7gcW", + "2S32Qy+v7dnN6K/xqEnWi/WrYhWZr0Eo0kxw4FoFZ18DFc4gpfjn+YeLH6UU0vwdgQolyzQTPDgzXwiY", + "T0SCygRXQFIRQdILOkEmRQZSM8A1UhU3p49m4KanoBSNwczTTCcQnAVXKjb/t8zM/ygtGY+Dh4dOIOEf", + "OZMQBWefcdXb1ZSS0XKemPwBoQ4eOsFbES3HMfDzPGJiJEaw0IahKpfUfGzy+TFLBI0gIvidTFkCRAsy", + "AaIl5WbkBCLD+1TIlOrgLJgwTuXSkwbJNuXpBKivMYss1SnNEzM/6NRYeJ/HMeMx+YmGTsfk4geSK4jI", + "VMiSDxxe0aIdGrWRlqBzyceapaA0TTNV5UHLHBp8XOMcsppjyc8qqiAaFrpHbvIsE1JDRO5pkoM6I0cK", + "uAYewlGHHM2FjI46REhCiWWKTIRIgHJyfGSIH5lvR1OaKDh61SM/WM4IU8R9Pl6t96pXjCQpUK4IFx6T", + "PUfNfTN/dyfUKM8b42nNSTlaaWbbTrQ7x9uLbfttw7a8SGkMI4H/ae7LOGcR5SGMVUgTqJjpTe+0bqMf", + "eShySWNQbqdoQWLgIKkGwlL8ECZCQbIkCeN3Rg3C2hAWmmRSpJkmxzMWz0A625GULomEKA/dEuQfOU2Y", + "Xr7y9faz45PcIJ+lvDxPJyCNvKwQcI2H2bW1MJyz6ZLMmZ4haxnLIGEcNruZ1V/LXsd1xxv0OGzq8QeI", + "JSAz8xkLLRuFHgtOmSJZrmaowjmVkcJRjDPNaGLH9Or8ke1qSoSkagsknJNLcX1Oji/FvHtN+R05j2im", + "qfn6yhme8ogwrUgopAXoyHjZHFg80+i4VggnlIEO8uOCplkCZ+Qr+RIkVAPX3VBwxZRxtGU/CdOu4a6r", + "okXyJTgjw96gQ74EHCT7Q/UztoCkS6XuFl9PHnwFXKJgz4aDDXl2hEIOMdXsHsZ2829hYrRyk2P1Ct0r", + "ZxGQ+Yxq83+wCJM8AjKVIm1R8UXMhTQ7aEqqG5J8yQeD1yEZ+mz/4lgjHyxrbdzn6dj69TgD2SbDsC7C", + "L7jViJgWgOBjRAbSiVdhJE/JhR38AWSDHcY1xHb3Ij98ChJQNA210DIcDNbzEwEXTBkb48QeuRIS7N8k", + "VzlNDGoBRcxyEOWgqBBlkmuiEjEHSUouzDJRnqDnTpYm3gCP9awhXzGe3CDXbdL56t1lV2zak+ttqugU", + "9HIcziC8qyjPhL669j6ANJhoAilOIzgNt6LSLEXcn9axy8BCnkQmjRHTKXBlNpmQZEZlOs0Tn80bu+o7", + "ZKZk1kVr5BYgamrkBpxbSsojkRKLb2tUYQa36ruwVUULg97/rIFrMbWpiA0STHBCsyxhqyAnobCxtczx", + "wHwZVgLZTUGzgc21uJ8VBrSBrSUBqET27RlAe2K6c9gsRX+yyPmECWppkl1h+V9C4/Uk13ldzbbbTLpj", + "Tvd3FoFomnRaA8XvOy2no6mkKSgEZAWh4BFu70oecm+W96X7aQ1uzTDsV2ievmmlakcSxgmGc7UD0fd2", + "8Ta6O+/dMv5Quz7Gzz9111o29k8nUmFGjyd5eAe6zsXw5E2djY8FQWNiZn40TBmV01TkXBsD2DXL45af", + "UKDNbCg0nxzMmj9TEzvdzDlLEgP2jOOnhgmv7LC3yHRFMD+0C6ZgTPN4vAaWByeNPLUUAScTGkUrMK4I", + "bNNl8r5y8HCHDgkK0kmCafPauTbh5aEEqgq5KyEeGTjPY7Ie4LenLyen/8HZyyGvKDQxZ1Ft9w4HJ9+1", + "4SGO3AsOP+HaTap7RhgbOjaEmMvLq2ZkmTGlhVxWoe/zrY/WbkQbdNHFWIs74PU9/72HFHRBRnZMm2LX", + "Yu9+IX+HHFlLoGmFDNaAqnkcTdu31lJpSMdlYbKFzxscQlorkZ1AQ5oZ++cSahj4ZrXEyBu0Yy7Zsh2M", + "mTfsghuIU+D6nC/1jPH4pLklJmLRUr0lCcII+Y5QKemSxOweOKGKUDIRi6IQ5NAWrdoxXvDb77/9TmxM", + "9vf8W7FYW3lpEr8oor6yzD82zlN1N2Y8y3WrfGLelaBEkmNoM4MJDq4JpZcZCxGb8chOSSbhnolcmT8i", + "FuJsph26dFa5NaLjcPF+8Ykcv//bp7+dnH6PwHRzflU5T1wZyhfI5r9d7SPNE4Pl6m4scl0qckNUuDAn", + "rBw6Kw3a3EK62vDMHMPMgrY4TNMJi3OjTKt6u61Uh4ipBm7+N8pDrP6C1iDdTD2j3MQdxuMEPDNUpCo4", + "J79aztv8nJtNlbB/wjgUQkZqP/EywbgmOJNxqkGVaVS57upgSXkM5POgM7x1WwRnO7oEFhmE2g6fgB0g", + "QZkfzU/WfBFLTcQUXFXzFkeLvLMytAnqE2s6wy+LE+flYuqkcoao+cJ8BhII0NCxT5gxHDn+rfP7q1UM", + "rByncFidMw/SkbGETiBpYewSfy/z2gprBTdDwnjEQtQ/NUMhliLnkRttsr5BZciEhnf+kCa7luyGtkgi", + "Yqb32C12miI57xoPUDORmDwXt6ddizCutMn9xNSwiBiH31taD5eWetPOu2YQjZiwIX6Yg+xInBfNr2ro", + "iHILJ9WDY6P44kYZ2VZ5ke2XMe4Or4ocT0DPATgZotFeD4ovFbQsFmtLzjeU8F/XmcLyuj1XJVQpNmUg", + "u1ODTsUqzcOGwhYTd6zrGdUrXNLhDGyB34VGk6undGnc2WbsLpffryfyNPBvGX6i0vea+sqWUvcmFnYt", + "Y28sGz/Lucty3Xrsep7q8FZz7VifWp8v+g69we8/ZqUTPbLc+MRduqfxhNyKFT2+G7RlI745/Qu1L3bS", + "5qGPsa3esHffoHDOFv99Pxp9WHMjx3za8UpOBJqyBK+9JMmv0+Ds89fgvyVMg7Pgv/qry0B9dxOoX96u", + "ebhttl7MUhA5yox7Zfe65I6sJ/FKnDWy/p0mLMLlSqnXicI0pPjTJknq6z14qYddqWQEU2aUwee2vkAb", + "30ATPXtXbPsqv0pTndf6+7/+b6X/hAPaIsCq4r4i0EIfMfbabYHmPrmubI6158eWsKDa73HVndLM3skY", + "VxAx6pvA9pjbTNBIfJW/jaoSr1OJbdvspRi8FbNJL9r17DZpBa9EBX5haLE9wuPCdRFrErQIenl55QtY", + "ZVZ6X1aHjvpiXpkLC31jE139Kbb+Rz6qXdBOeut7y3mS+Sy3SHRF1Z3ay2h2blFdWmM4/4xXN52k8w7J", + "uXfMXxUhFDm2U1+V51asWlRvulRPcNWa1VYvaqyHKmiNsKGQ67wS9XFkIiSfsggzAzsc+cZDfpVkJZLZ", + "hbdewnSMqWK40+ptjfeN9kUQaKngpeZDYcxQcE2Zbfdwr78/EbmuVeRxXtPgXE3nTTKfZqCL5pklOKeK", + "TBMax+YQq8gvNz99quRaZpnd8wdjCfPFpqh+p7OkuFPHIpdJ++Ifry9dpWQlQki5SYloGIJS9npqQeCj", + "TLZaNccxyrKCavPtieZqsePTY2s4y/l2b8Fl7NCdAw8O9wPPO0uqHng6z4zvnULG2+rsTf5iD3gYDj5Q", + "Sa2w3+qd0ae8uNC4kbnh4sLhEubhEua3ewnz9C99B5PcQEZRz9j+yhIo2yFYzDr6vyOzNVT5hGGyXDVJ", + "DpWPP+2mRQO/d7xp0ayVNkNoS5zdWnlIRFgpO1C+dKWU+n742mDx9sGH5LDW8iizD3dDYZV74VOp1qMa", + "/rAaijyTkfl1WyZi5LCk3EhPUztUO/Ciyl6JX9vVutoFSbz7uC3vKm4KmrGV1G/P4kM95SsuU1omthQj", + "HKu+zioKadGYzT5bTjz4ATe+wTIEI7p6prRP8QEXcB6Eq27PT813R2nNmsXnxsKFvj3lle+ntulP+wNr", + "NQ+rqIYGEbLCXDK9vDHGtMp4Pxp9eAtUgizfMCLO2Z/KRWZaZ8GDWcOcI1us4C6NW580KCxzTs4vyv6G", + "8rMpdg8ZgDTfr3POkdA9SGXXGvQGvYHRrMiA04wFZ8Hr3rA3MIakeoZs97Eb1NWiW1gzE6rNquUbO+8p", + "om3gu+OHyJw/XUQmt66/PzNKB6Xfigivh5kDNXAkZMMglbpv4lA3opqunoFuc6O2x24PVSOboIc/WJdA", + "sU8GgxoXntb7fyjbdN6NhcqJCWnXIlmOB+FpnpDVsE7w3ROysKpjt9B/SyNybbVv6Q5fhu5HTnM9E5L9", + "EyIkPHz9MoSdsORHrk1eOBKCXFIZW60PT19K+lUGh1hvo6Fh4eTkSVlo9BSazKyGkLLvcPpS+++Ca5Cc", + "JuQG5D3IggMPRjFr8QH08+3DbSdQeZpSuSzeMJORIEXsoLEy6F0EY4Pai67KgN4BVcsupyl0xT1IySLE", + "/go6dIL+DFsYWKwAlL2KXrbDETwjaPg9lF0x48FXiWMRpcHU1GB42T1vB/HzLEuWRQu98kYFkZyag4jJ", + "arxkt4HqtUdFzwzrFWovjOvVrs4B2NcD+wHQ9gU0ewd5JEh5IWVPRGNVx/BBYIdEDg/8Fge253HVN2cv", + "4/B/Rh7X1uI8eP2/eTp3gJ5HQ88jcylW8VAfeO7L56atyPNz2yPLvZKO4lHSy2CQpfbCIFSt3hzg55B0", + "PIPnl4/7Huf6hWN0gn6SpDs4PJYDc+zqUJJQHueGkbKh0XB3++hsvZf7Kl505/N5F709lwnwUES2nbCf", + "zxuSL+zq/sWjg6MfHP3pHN092tzTu40vo1O7m2td6l7wdE/W+7h77OPuSeF7Lco3RPKWx0HPHM0bFF/Y", + "zas30A6OfnD0p3P0wvuKzU1OHuH3qukgnaBvYrZJ6st/mHJbUm+7Qial9+5lKRf2bzSdJECwjtuKC9V3", + "Qy8a+n3ST48NZkHUTX9O76vMtTxTqnXKN2BFsqy/djyAxgE0dgENvKk0EqR8pLcnXuiKv3hQsaX78HPt", + "UmcDKzYAw/Z2w7/Wv61exzk0Gg7O/o05+2P7DNpzP3T23Hu52+rm7vVgeQwgk2XxD2Ph8wmNb93dP4zS", + "6vKr94fPfDQoCB38/eDv34i/e2939/T03HcGhQwoJFf7R1OKO2fvEpFH5J1I05wzvSQ/Uw1zarJYfCyE", + "N93UWb8fSaBpN7Zfe4mb3gvNdLycumb9G40ninXLlgspHNenGetPQNN+Ke/D7cP/BwAA//8OvhFFW2AA", + "AA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index f54f6b49..ae425018 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -688,3 +688,60 @@ func (w *Worker) handleStreamingResponse(ctx context.Context, c *RunnerContainer return outputChan, nil } +func (w *Worker) TextToAudio(ctx context.Context, req GenTextToAudioFormdataRequestBody) (*http.Response, error) { + c, err := w.borrowContainer(ctx, "text-to-audio", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + var buf bytes.Buffer + mw, err := NewTextToAudioFormdataWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.GenTextToAudioWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("text-to-audio container returned 400", slog.String("err", string(val))) + return nil, errors.New("text-to-audio container returned 400: " + resp.JSON400.Detail.Msg) + } + + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) + if err != nil { + return nil, err + } + slog.Error("text-to-audio container returned 401", slog.String("err", string(val))) + return nil, errors.New("text-to-audio container returned 401: " + resp.JSON401.Detail.Msg) + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("text-to-audio container returned 422", slog.String("err", string(val))) + return nil, errors.New("text-to-audio container returned 422: " + string(val)) + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("text-to-audio container returned 500", slog.String("err", string(val))) + return nil, errors.New("text-to-audio container returned 500: " + resp.JSON500.Detail.Msg) + } + + // Return the raw HTTP response since we're returning audio data + return resp.HTTPResponse, nil +}