Feature/text to audio #241

Open · wants to merge 2 commits into base: main
6 changes: 6 additions & 0 deletions go.mod
@@ -44,7 +44,13 @@ require (
go.opentelemetry.io/otel/metric v1.31.0 // indirect
go.opentelemetry.io/otel/sdk v1.31.0 // indirect
go.opentelemetry.io/otel/trace v1.31.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
gotest.tools/v3 v3.5.1 // indirect
)

replace github.com/livepeer/ai-worker => ../home/user/test2/ai-worker
6 changes: 6 additions & 0 deletions go.sum
@@ -124,6 +124,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -151,6 +153,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -166,6 +170,8 @@ google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojt
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
6 changes: 6 additions & 0 deletions runner/app/main.py
@@ -61,6 +61,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.image_to_text import ImageToTextPipeline

return ImageToTextPipeline(model_id)
case "text-to-audio":
from app.pipelines.text_to_audio import TextToAudioPipeline
return TextToAudioPipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -101,6 +104,9 @@ def load_route(pipeline: str) -> any:
case "image-to-text":
from app.routes import image_to_text
return image_to_text.router
case "text-to-audio":
from app.routes import text_to_audio
return text_to_audio.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

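Both new match arms follow the existing dispatch pattern. A minimal sketch of how they would be exercised together (a hypothetical driver; in the real runner both loaders run at startup from environment configuration, and load_pipeline pulls the full model into memory):

# Sketch: exercise the new dispatch branches the way app startup would.
# Assumes the stabilityai/stable-audio-open-1.0 weights are already
# cached under the model directory (see dl_checkpoints.sh below).
from app.main import load_pipeline, load_route

pipeline = load_pipeline("text-to-audio", "stabilityai/stable-audio-open-1.0")
router = load_route("text-to-audio")  # FastAPI APIRouter serving /text-to-audio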
123 changes: 123 additions & 0 deletions runner/app/pipelines/text_to_audio.py
@@ -0,0 +1,123 @@
import logging
import os
from typing import Optional, Tuple

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.utils.errors import InferenceError
from diffusers import StableAudioPipeline
from huggingface_hub import file_download
import soundfile as sf
import io

logger = logging.getLogger(__name__)

class TextToAudioPipeline(Pipeline):
def __init__(self, model_id: str):
"""Initialize the text to audio pipeline.

Args:
model_id: The model ID to use for audio generation.
"""
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(get_model_dir(), folder_name)

# Load fp16 variant if available
has_fp16_variant = any(
".fp16.safetensors" in fname
for _, _, files in os.walk(folder_path)
for fname in files
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("TextToAudioPipeline loading fp16 variant for %s", model_id)
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

if os.environ.get("BFLOAT16"):
logger.info("TextToAudioPipeline using bfloat16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.bfloat16

# Initialize the pipeline
self.pipeline = StableAudioPipeline.from_pretrained(
model_id,
**kwargs
).to(torch_device)

# Enable optimization if configured
sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
if sfast_enabled:
logger.info(
"TextToAudioPipeline will be dynamically compiled with stable-fast for %s",
model_id,
)
from app.pipelines.optim.sfast import compile_model
self.pipeline = compile_model(self.pipeline)

def __call__(
self,
prompt: str,
duration: float = 5.0,
num_inference_steps: int = 10,
guidance_scale: float = 3.0,
        negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
**kwargs
) -> Tuple[bytes, str]:
"""Generate audio from text.

Args:
prompt: The text prompt for audio generation.
duration: Duration of the generated audio in seconds.
num_inference_steps: Number of denoising steps.
guidance_scale: Scale for classifier-free guidance.
negative_prompt: Optional text prompt to guide what to exclude.
seed: Optional seed for reproducible generation.

Returns:
Tuple containing the audio data as bytes and the file format.
"""
try:
# Set seed if provided
if seed is not None:
torch.manual_seed(seed)

# Set default steps if invalid
if num_inference_steps is None or num_inference_steps < 1:
num_inference_steps = 10

# Validate duration
if duration < 1.0 or duration > 30.0:
raise ValueError("Duration must be between 1 and 30 seconds")

            # Generate audio; StableAudioPipeline returns waveforms shaped
            # (channels, samples) on the inference device via the .audios field.
            audio = self.pipeline(
                prompt,
                negative_prompt=negative_prompt,
                audio_length_in_s=duration,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                **kwargs
            ).audios[0]

            # Move to CPU and transpose to (samples, channels) for soundfile.
            audio = audio.T.float().cpu().numpy()

            # Convert to bytes using soundfile (stable-audio-open is 44.1 kHz).
            buffer = io.BytesIO()
            sf.write(buffer, audio, samplerate=44100, format='WAV')
            buffer.seek(0)

            return buffer.read(), 'wav'

except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
raise InferenceError(original_exception=e)

def __str__(self) -> str:
return f"TextToAudioPipeline model_id={self.model_id}"
149 changes: 149 additions & 0 deletions runner/app/routes/text_to_audio.py
@@ -0,0 +1,149 @@
import logging
import os
from typing import Annotated, Dict, Optional, Tuple, Union

import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import (
HTTPError,
http_error,
handle_pipeline_exception,
)
from fastapi import APIRouter, Depends, Form, Response, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

router = APIRouter()
logger = logging.getLogger(__name__)

# Pipeline specific error handling configuration
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
# Specific error types
"OutOfMemoryError": (
"Out of memory error. Try reducing audio duration.",
status.HTTP_500_INTERNAL_SERVER_ERROR,
),
"ValueError": (
None, # Use the error message from the exception
status.HTTP_400_BAD_REQUEST,
),
}

RESPONSES = {
status.HTTP_200_OK: {
"content": {
"audio/wav": {
"schema": {
"type": "string",
"format": "binary"
}
},
},
"description": "Successfully generated audio"
},
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}

@router.post(
"/text-to-audio",
responses=RESPONSES,
description="Generate audio from text prompts using Stable Audio.",
operation_id="genTextToAudio",
summary="Text To Audio",
tags=["generate"],
openapi_extra={"x-speakeasy-name-override": "textToAudio"},
)
@router.post(
"/text-to-audio/",
responses=RESPONSES,
include_in_schema=False,
)
async def text_to_audio(
prompt: Annotated[
str,
Form(description="Text prompt for audio generation."),
],
model_id: Annotated[
str,
Form(description="Hugging Face model ID used for audio generation."),
] = "",
duration: Annotated[
float,
Form(description="Duration of generated audio in seconds (between 1 and 30 seconds)."),
] = 5.0,
num_inference_steps: Annotated[
int,
Form(description="Number of denoising steps. More steps usually lead to higher quality audio but slower inference."),
] = 10,
guidance_scale: Annotated[
float,
Form(description="Scale for classifier-free guidance. Higher values result in audio that better matches the prompt but may be lower quality."),
] = 3.0,
    negative_prompt: Annotated[
        Optional[str],
        Form(description="Text prompt to guide what to exclude from audio generation."),
    ] = None,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
# Validate auth token if configured
auth_token = os.environ.get("AUTH_TOKEN")
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token."),
)

# Validate model ID matches pipeline
if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}."
),
)

# Validate duration
if duration < 1.0 or duration > 30.0:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
"Duration must be between 1 and 30 seconds."
),
)

try:
# Generate audio
audio_data, audio_format = pipeline(
prompt=prompt,
duration=duration,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
)

# Return audio file
return Response(
content=audio_data,
media_type=f"audio/{audio_format}",
headers={
"Content-Disposition": f"attachment; filename=generated_audio.{audio_format}"
}
)

except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
logger.error(f"TextToAudio pipeline error: {e}")
return handle_pipeline_exception(
e,
default_error_message="Text-to-audio pipeline error.",
custom_error_config=PIPELINE_ERROR_CONFIG,
)
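An illustrative client call against the new endpoint, assuming a runner listening on localhost:8000 with no AUTH_TOKEN configured; the form fields mirror the parameters declared above:

import requests

# Parameters are sent form-encoded, matching the Form(...) declarations.
resp = requests.post(
    "http://localhost:8000/text-to-audio",
    data={
        "prompt": "Rain falling on a tin roof",
        "duration": 5.0,
        "num_inference_steps": 10,
        "guidance_scale": 3.0,
    },
)
resp.raise_for_status()

# The route returns raw WAV bytes with a Content-Disposition header.
with open("generated_audio.wav", "wb") as f:
    f.write(resp.content)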
4 changes: 3 additions & 1 deletion runner/dl_checkpoints.sh
@@ -37,6 +37,7 @@ function download_beta_models() {

# Download custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models


printf "\nDownloading token-gated models...\n"

@@ -81,7 +82,8 @@ function download_restricted_models() {
huggingface-cli download black-forest-labs/FLUX.1-dev --include "*.safetensors" "*.json" "*.txt" "*.model" --exclude ".onnx" ".onnx_data" --cache-dir models ${TOKEN_FLAG:+"$TOKEN_FLAG"}
# Download LLM models (Warning: large model size)
huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "*.json" "*.bin" "*.safetensors" "*.txt" --cache-dir models

# Download text-to-audio model
huggingface-cli download stabilityai/stable-audio-open-1.0 --include "*.safetensors" "*.bin" "*.json" --cache-dir models
}

# Enable HF transfer acceleration.