Skip to content

Commit

Permalink
Add TTS to OpenAI_API_Compatible (#11071)
Browse files Browse the repository at this point in the history
  • Loading branch information
taowang1993 authored Nov 26, 2024
1 parent 044e7b6 commit aa135a3
Show file tree
Hide file tree
Showing 6 changed files with 169 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
Model class for OpenAI text2speech model.
"""

def _invoke(
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/gitee_ai/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
Model class for OpenAI text2speech model.
"""

def _invoke(
Expand Down
2 changes: 1 addition & 1 deletion api/core/model_runtime/model_providers/openai/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
"""
Model class for OpenAI Speech to text model.
Model class for OpenAI text2speech model.
"""

def _invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ supported_model_types:
- text-embedding
- speech2text
- rerank
- tts
configurate_methods:
- customizable-model
model_credential_schema:
Expand Down Expand Up @@ -67,7 +68,7 @@ model_credential_schema:
- variable: __model_type
value: llm
type: text-input
default: '4096'
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
Expand All @@ -80,7 +81,7 @@ model_credential_schema:
- variable: __model_type
value: text-embedding
type: text-input
default: '4096'
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
Expand All @@ -93,7 +94,7 @@ model_credential_schema:
- variable: __model_type
value: rerank
type: text-input
default: '4096'
default: "4096"
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
Expand All @@ -104,7 +105,7 @@ model_credential_schema:
show_on:
- variable: __model_type
value: llm
default: '4096'
default: "4096"
type: text-input
- variable: function_calling_type
show_on:
Expand Down Expand Up @@ -174,3 +175,19 @@ model_credential_schema:
value: llm
default: '\n\n'
type: text-input
- variable: voices
show_on:
- variable: __model_type
value: tts
label:
en_US: Available Voices (comma-separated)
zh_Hans: 可用声音(用英文逗号分隔)
type: text-input
required: false
default: "alloy"
placeholder:
en_US: "alloy,echo,fable,onyx,nova,shimmer"
zh_Hans: "alloy,echo,fable,onyx,nova,shimmer"
help:
en_US: "List voice names separated by commas. First voice will be used as default."
zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。"
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from collections.abc import Iterable
from typing import Optional
from urllib.parse import urljoin

import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat


class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel):
    """
    Model class for OpenAI-compatible text2speech model.
    """

    def _invoke(
        self,
        model: str,
        tenant_id: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> Iterable[bytes]:
        """
        Invoke TTS model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials (must contain endpoint_url; api_key optional)
        :param content_text: text content to be synthesized
        :param voice: model voice/speaker
        :param user: unique user id
        :return: audio data as bytes iterator
        :raises InvokeBadRequestError: if endpoint_url is missing or the API returns a non-200 response
        """
        # Set up headers with authentication if provided
        headers = {}
        if api_key := credentials.get("api_key"):
            headers["Authorization"] = f"Bearer {api_key}"

        # Construct endpoint URL; fail fast with a clear error instead of an
        # AttributeError when endpoint_url is absent or empty
        endpoint_url = credentials.get("endpoint_url")
        if not endpoint_url:
            raise InvokeBadRequestError("endpoint_url is required in credentials")
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"
        endpoint_url = urljoin(endpoint_url, "audio/speech")

        # Get audio format from model properties
        audio_format = self._get_model_audio_type(model, credentials)

        # Split text into chunks if needed based on word limit
        word_limit = self._get_model_word_limit(model, credentials)
        sentences = self._split_text_into_sentences(content_text, word_limit)

        for sentence in sentences:
            # Prepare request payload
            payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format}

            # timeout=(connect, read): speech synthesis can be slow, so allow a
            # long read timeout, but never hang indefinitely on a dead endpoint
            response = requests.post(
                endpoint_url, headers=headers, json=payload, stream=True, timeout=(10, 300)
            )

            if response.status_code != 200:
                raise InvokeBadRequestError(response.text)

            # Stream the audio data
            for chunk in response.iter_content(chunk_size=4096):
                if chunk:
                    yield chunk

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by synthesizing a short test sentence.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the test invocation fails for any reason
        """
        try:
            # Get default voice for validation
            voice = self._get_model_default_voice(model, credentials)

            # Test with a simple text; next() forces the first audio chunk so a
            # bad endpoint/key fails here rather than lazily at first real use
            next(
                self._invoke(
                    model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice
                )
            )
        except Exception as ex:
            # Preserve the original traceback for debugging via exception chaining
            raise CredentialsValidateFailedError(str(ex)) from ex

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Build the customizable model schema, deriving the voice list from the
        comma-separated "voices" credential (first voice is the default).

        :param model: model name
        :param credentials: model credentials
        :return: model schema entity with audio type, word limit and voices
        """
        # Parse voices from comma-separated string; `or "alloy"` also covers a
        # present-but-None/empty value, which dict.get's default would not
        voice_names = (credentials.get("voices") or "alloy").strip().split(",")
        voices = []

        for voice in voice_names:
            voice = voice.strip()
            if not voice:
                continue

            # Use en-US for all voices
            voices.append(
                {
                    "name": voice,
                    "mode": voice,
                    "language": "en-US",
                }
            )

        # If no voices provided or all voices were empty strings, use 'alloy' as default
        if not voices:
            voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}]

        return AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={
                ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"),
                ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)),
                # First configured voice serves as the default
                ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"],
                ModelPropertyKey.VOICES: voices,
            },
        )

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        """
        Override base get_tts_model_voices to handle customizable voices.

        :param model: model name
        :param credentials: model credentials
        :param language: ignored — custom voices are returned regardless of language
        :return: list of {"name", "value"} voice entries
        :raises ValueError: if the model schema declares no voices
        """
        model_schema = self.get_customizable_model_schema(model, credentials)

        if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
            raise ValueError("this model does not support voice")

        voices = model_schema.model_properties[ModelPropertyKey.VOICES]

        # Always return all voices regardless of language
        return [{"name": d["name"], "value": d["mode"]} for d in voices]

0 comments on commit aa135a3

Please sign in to comment.