fix audio for vllm #755

Closed · wants to merge 3 commits
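Taken together, the diffs below thread three new arguments (`use_refine`, `spk_emb`, `text_mask`) from `ChatTTS/core.py` through the vendored vLLM ("velocity") engine into the model runner, so that the speaker embedding can actually be applied when running on the vLLM backend. A rough caller-side sketch of what this is meant to enable, using the public ChatTTS API as documented in the project README; exact load options for the version this PR targets may differ:

```python
import ChatTTS

chat = ChatTTS.Chat()
chat.load(compile=False)  # how the vLLM/velocity backend is enabled is version-dependent and omitted here

spk = chat.sample_random_speaker()  # serialized speaker-embedding string
params_infer_code = ChatTTS.Chat.InferCodeParams(spk_emb=spk)

# With this PR, _infer_code forwards spk_emb (plus text_mask, use_refine=False)
# into gpt.llm.generate(), so the chosen voice is applied on the vLLM path too.
wavs = chat.infer(
    ["Hello from the velocity backend."],
    params_infer_code=params_infer_code,
)
```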
5 changes: 4 additions & 1 deletion ChatTTS/core.py
@@ -523,6 +523,9 @@ def _infer_code(
None,
sample_params,
input_ids,
use_refine=False,
spk_emb=params.spk_emb,
text_mask=text_mask,
)

token_ids = []
@@ -625,7 +628,7 @@ def _refine_text(
del input_ids

result = gpt.llm.generate(
- None, sample_params, input_ids_list, params.show_tqdm
+ None, sample_params, input_ids_list, params.show_tqdm, use_refine=True
)
token_ids = []
hidden_states = []
1 change: 1 addition & 0 deletions ChatTTS/model/speaker.py
@@ -64,6 +64,7 @@ def decorate_code_prompts(
t.replace("[Stts]", "")
.replace("[spk_emb]", "")
.replace("[empty_spk]", "")
.replace("[Ebreak]", "")
.strip()
)
"""
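For illustration only, the added `.replace("[Ebreak]", "")` makes the prompt clean-up in `decorate_code_prompts` strip `[Ebreak]` markers the same way it already strips the other control tokens; the input literal below is made up:

```python
t = " [Stts][spk_emb]hello world[Ebreak] "
cleaned = (
    t.replace("[Stts]", "")
    .replace("[spk_emb]", "")
    .replace("[empty_spk]", "")
    .replace("[Ebreak]", "")
    .strip()
)
print(cleaned)  # -> "hello world"
```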
10 changes: 8 additions & 2 deletions ChatTTS/model/velocity/llm.py
@@ -125,6 +125,9 @@ def generate(
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
use_refine: bool = False,
spk_emb: str = None,
text_mask = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.

@@ -166,7 +169,7 @@
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[i]
- self._add_request(prompt, sampling_params, token_ids)
+ self._add_request(prompt, sampling_params, token_ids, use_refine, spk_emb, text_mask)

rtns = self._run_engine(use_tqdm)
for i, rtn in enumerate(rtns):
@@ -184,10 +187,13 @@ def _add_request(
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
use_refine: bool,
spk_emb: str,
text_mask = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
- request_id, prompt, sampling_params, prompt_token_ids
+ request_id, prompt, sampling_params, prompt_token_ids, use_refine=use_refine, spk_emb=spk_emb, text_mask=text_mask
)

def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
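A minimal sketch of calling the patched velocity `LLM.generate` with the new keyword arguments; the variable values are placeholders, and in the real code path they are built by `_refine_text` / `_infer_code` in `ChatTTS/core.py`:

```python
# Code-infer pass: speaker embedding and text mask are threaded through to the
# model runner (the refine-text pass instead passes use_refine=True and no spk_emb).
outputs = llm.generate(
    prompts=None,
    sampling_params=sample_params,
    prompt_token_ids=input_ids,   # List[List[int]] built by core.py
    use_tqdm=False,
    use_refine=False,             # this is not the refine-text pass
    spk_emb=spk_emb_str,          # serialized speaker embedding, or None
    text_mask=text_mask,          # precomputed text mask tensor, or None
)
```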
5 changes: 4 additions & 1 deletion ChatTTS/model/velocity/llm_engine.py
@@ -330,6 +330,9 @@ def add_request(
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
use_refine: bool = False,
spk_emb: str = None,
text_mask = None,
) -> None:
"""Add a request to the engine's request pool.

@@ -354,7 +357,7 @@
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
- seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
+ seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, use_refine, spk_emb, text_mask)

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time)
44 changes: 38 additions & 6 deletions ChatTTS/model/velocity/model_runner.py
@@ -77,6 +77,19 @@ def __init__(
# cache in_wsl result
self.in_wsl = in_wsl()

from ...config import Config
self.config = Config()
Review comment (Member): A very bad solution that completely breaks the configuration's inheritance and dependency structure. This config should be passed in from the outside rather than re-imported and instantiated here a second time.


from ..speaker import Speaker
from ...utils import select_device
device = None
if device is None:
self.device = select_device()

self.speaker = Speaker(
self.config.gpt.hidden_size, self.config.spk_stat, device
)

def load_model(self) -> None:
self.model = get_model(self.model_config)
self.post_model = Embed(
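One way to address the review comment above would be to inject the already-constructed `Config` (and the target device) from whatever component builds the model runner, instead of re-importing and instantiating `Config()` here. A sketch of how the tail of `__init__` might then look; every parameter name not visible in this diff is hypothetical, and this is not the actual velocity `ModelRunner` signature:

```python
def __init__(self, model_config, parallel_config, scheduler_config,
             tts_config=None, device=None):  # tts_config/device injected by the caller
    ...
    self.config = tts_config  # reuse the Config the Chat object already owns
    self.device = device if device is not None else select_device()
    self.speaker = Speaker(
        self.config.gpt.hidden_size, self.config.spk_stat, self.device
    )
```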
@@ -122,6 +135,9 @@ def _prepare_prompt(
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
use_refine = seq_data.use_refine
spk_emb = seq_data.spk_emb
text_mask = seq_data.text_mask

input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
@@ -174,7 +190,7 @@
block_tables=None,
use_cuda_graph=False,
)
- return input_tokens, input_positions, input_metadata, prompt_lens
+ return input_tokens, input_positions, input_metadata, prompt_lens, use_refine, spk_emb, text_mask

def _prepare_decode(
self,
@@ -354,13 +370,16 @@ def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
use_refine = False
spk_emb = None
text_mask = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
- (input_tokens, input_positions, input_metadata, prompt_lens) = (
+ (input_tokens, input_positions, input_metadata, prompt_lens, use_refine, spk_emb, text_mask) = (
self._prepare_prompt(seq_group_metadata_list)
)
else:
@@ -454,15 +473,15 @@ def get_size_or_none(x: Optional[torch.Tensor]):
perform_sampling=False,
)

- return input_tokens, input_positions, input_metadata, sampling_metadata
+ return input_tokens, input_positions, input_metadata, sampling_metadata, use_refine, spk_emb, text_mask

@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
- input_tokens, input_positions, input_metadata, sampling_metadata = (
+ input_tokens, input_positions, input_metadata, sampling_metadata, use_refine, spk_emb, text_mask = (
self.prepare_input_tensors(seq_group_metadata_list)
)
# print(sampling_metadata.seq_data)
@@ -495,8 +514,11 @@ def execute_model(
input_tokens_history = input_tokens_history.unsqueeze(2).repeat(1, 1, 4)
# print(input_tokens_history.shape)
# print("it2",input_tokens.shape)
- text_mask = input_tokens != 0
- text_mask = text_mask[:, :, 0]
+ # text_mask = input_tokens != 0
+ # text_mask = text_mask[:, :, 0]
+ if text_mask is None:
+     text_mask = input_tokens != 0
+     text_mask = text_mask[:, :, 0]

if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
@@ -533,6 +555,16 @@
)
else:
input_emb = self.post_model(input_tokens, text_mask)
if not use_refine:
if spk_emb is not None:
self.speaker.apply(
input_emb,
spk_emb,
input_tokens,
21143,
self.device,
)

# print(input_emb.shape)
hidden_states = model_executable(
input_emb=input_emb,
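For context, `self.speaker.apply(input_emb, spk_emb, input_tokens, 21143, self.device)` is presumably meant to mirror the non-vLLM ChatTTS path: decode the serialized speaker embedding and write it over the input embedding wherever the token id equals the `[spk_emb]` placeholder (21143 in this diff). The new `if text_mask is None` branch keeps the old behaviour of deriving the mask from non-zero token ids whenever no precomputed mask was passed down. A simplified sketch of the embedding overwrite, not a verbatim copy of `speaker.py`:

```python
import torch
import torch.nn.functional as F

def apply_speaker_embedding(
    input_emb: torch.Tensor,     # (batch, seq_len, hidden)
    spk_emb_vec: torch.Tensor,   # (hidden,), decoded from the spk_emb string
    input_tokens: torch.Tensor,  # (batch, seq_len, num_vq)
    spk_token_id: int = 21143,   # id of the [spk_emb] placeholder token
) -> torch.Tensor:
    """Overwrite the embedding at every [spk_emb] placeholder position."""
    n = F.normalize(spk_emb_vec.to(input_emb.dtype), p=2.0, dim=-1)
    cond = input_tokens[..., 0] == spk_token_id  # (batch, seq_len) placeholder mask
    input_emb[cond] = n                          # broadcast over matched positions
    return input_emb
```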
11 changes: 10 additions & 1 deletion ChatTTS/model/velocity/sequence.py
@@ -65,12 +65,18 @@ class SequenceData:
def __init__(
self,
prompt_token_ids: List[int],
use_refine: bool = True,
spk_emb: str = None,
text_mask = None,
) -> None:
self.prompt_token_ids = prompt_token_ids
self.output_token_ids: List[int] = []
self.cumulative_logprob = 0.0
self.hidden_states: Optional[torch.Tensor] = None
self.finished = False
self.use_refine = use_refine
self.spk_emb = spk_emb
self.text_mask = text_mask

def append_token_id(self, token_id: int, logprob: float) -> None:
if isinstance(self.cumulative_logprob, float):
@@ -132,12 +138,15 @@ def __init__(
prompt: str,
prompt_token_ids: List[int],
block_size: int,
use_refine: bool,
spk_emb: str,
text_mask,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.block_size = block_size

- self.data = SequenceData(prompt_token_ids)
+ self.data = SequenceData(prompt_token_ids, use_refine, spk_emb, text_mask)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
