Commit

chore(format): run black on dev
github-actions[bot] committed Sep 10, 2024
1 parent 6feb586 commit 1a60b69
Showing 3 changed files with 48 additions and 26 deletions.
9 changes: 5 additions & 4 deletions ChatTTS/core.py
@@ -402,6 +402,7 @@ async def _infer(
else:
# Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop.
import librosa

silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
silence_left = 0
if len(silence_intervals) == 0:
@@ -532,7 +533,9 @@ async def _infer_code(
async for i in results_generator:
token_ids = []
hidden_states = []
- if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+ if (
+     stream and len(i.outputs[0].token_ids) % stream_batch_size == 0
+ ) or i.finished:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -568,9 +571,7 @@ async def _infer_code(
hidden_states = []
if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
token_ids.append(i.ids[0])
- hidden_states.append(
-     i.hiddens[0].to(torch.float32).to(self.device)
- )
+ hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device))
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
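For context on the silence check touched by the core.py hunk above: librosa.effects.split returns the non-silent intervals of a signal (an (m, 2) array of sample indices), so an empty result means the buffered tail is still entirely below the threshold and the loop should keep waiting, while a non-empty result gives a safe point to cut the stream. A minimal, self-contained sketch of that idea follows; the helper name find_stream_cut and the demo signal are illustrative and not part of the repository, while top_db=10 matches the value visible in the diff.

from typing import Optional

import numpy as np
import librosa


def find_stream_cut(tail: np.ndarray, top_db: float = 10.0) -> Optional[int]:
    # librosa.effects.split returns the *non-silent* [start, end) intervals;
    # everything between them is more than top_db below the peak level.
    intervals = librosa.effects.split(tail, top_db=top_db)
    if len(intervals) == 0:
        # The whole tail is below the threshold: nothing worth emitting yet.
        return None
    # Cut right after the last non-silent interval so trailing silence is dropped.
    return int(intervals[-1][1])


if __name__ == "__main__":
    sr = 24_000
    tone = np.sin(2 * np.pi * 440 * np.linspace(0, 0.5, sr // 2))
    tail = np.concatenate([tone, np.zeros(sr // 4)])  # audio followed by silence
    print(find_stream_cut(tail))  # roughly 12000, just after the tone ends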
12 changes: 3 additions & 9 deletions ChatTTS/model/gpt.py
@@ -68,7 +68,7 @@ def from_pretrained(
num_audio_tokens=self.num_audio_tokens,
num_text_tokens=self.num_text_tokens,
post_model_path=embed_file_path,
dtype="float32"
dtype="float32",
)
self.logger.info("vLLM model loaded")
return
@@ -585,7 +585,7 @@ async def generate(
attentions,
hiddens,
infer_text,
- False
+ False,
)
del not_finished

@@ -609,11 +609,5 @@ async def generate(
del finish, inputs_ids_buf

yield self._prepare_generation_outputs(
- inputs_ids,
- start_idx,
- end_idx,
- attentions,
- hiddens,
- infer_text,
- True
+ inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True
)
53 changes: 40 additions & 13 deletions ChatTTS/model/velocity/model_runner.py
@@ -107,7 +107,9 @@ def set_block_size(self, block_size: int) -> None:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
- ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+ ) -> tuple[
+     list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+ ]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
@@ -360,17 +362,23 @@ def _prepare_sample(
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
- ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+ ) -> Tuple[
+     torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+ ]:
speaker_embedding = None
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
- (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-     self._prepare_prompt(seq_group_metadata_list)
- )
+ (
+     input_tokens,
+     input_positions,
+     input_metadata,
+     prompt_lens,
+     speaker_embedding,
+ ) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, input_metadata) = self._prepare_decode(
seq_group_metadata_list
@@ -462,7 +470,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
perform_sampling=False,
)

- return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+ return (
+     input_tokens,
+     input_positions,
+     input_metadata,
+     sampling_metadata,
+     speaker_embedding,
+ )

@torch.inference_mode()
def execute_model(
@@ -471,9 +485,13 @@ def execute_model(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:

- input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-     self.prepare_input_tensors(seq_group_metadata_list)
- )
+ (
+     input_tokens,
+     input_positions,
+     input_metadata,
+     sampling_metadata,
+     speaker_embedding,
+ ) = self.prepare_input_tensors(seq_group_metadata_list)
# print(sampling_metadata.seq_data)
seq_groups = []
for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -522,7 +540,9 @@ def execute_model(
if speaker_embedding_params is None:
speaker_embedding_params = speaker_embedding[i]
else:
- speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+ speaker_embedding_params = torch.cat(
+     (speaker_embedding_params, speaker_embedding[i])
+ )

else:
speaker_embedding_params = self.post_model(input_tokens, text_mask)
@@ -560,7 +580,7 @@ def execute_model(
# sampling_metadata=sampling_metadata,
# )
results = []
- for i,val in enumerate(seq_groups):
+ for i, val in enumerate(seq_groups):
idx_next_i = idx_next[i, 0, :].tolist()
logprob_i = logprob[i].tolist()
tmp_hidden_states = hidden_states[i]
@@ -781,7 +801,9 @@ def _make_tensor_with_pad(
for x_i in x:
pad_i = pad
if isinstance(x[0][0], list):
- pad_i = [0,] * len(x[0][0])
+ pad_i = [
+     0,
+ ] * len(x[0][0])
elif isinstance(x[0][0], tuple):
pad_i = (0,) * len(x[0][0])
padded_x.append(_pad_to_max(x_i, max_len, pad_i))
@@ -791,6 +813,7 @@ def _make_tensor_with_pad(
device=device,
)


def _make_with_pad(
x: List[torch.Tensor],
max_len: int,
@@ -805,11 +828,15 @@ def _make_with_pad(
padded_x.append(x_i)
else:
padded_x.append(
- torch.cat((torch.zeros(1, max_len-x_i.shape[-2], 768).to(device), x_i), dim=1)
+ torch.cat(
+     (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+     dim=1,
+ )
)

return padded_x


def _get_graph_batch_size(batch_size: int) -> int:
if batch_size <= 2:
return batch_size
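The _make_with_pad hunk above only re-wraps an existing torch.cat call for line length, but the underlying padding scheme is easy to miss in diff form: shorter hidden-state tensors are left-padded with zero rows along the time axis until every sequence reaches max_len. A hedged sketch of that scheme follows; left_pad_hidden_states is an illustrative name, the 768 embedding width is taken from the diff, and the final cat that stacks the batch is added here for the demo (the repository helper returns the padded list instead).

from typing import List

import torch


def left_pad_hidden_states(hiddens: List[torch.Tensor], emb_dim: int = 768) -> torch.Tensor:
    # Each tensor is (1, T_i, emb_dim); prepend zero rows on the time axis
    # (dim=1) so every sequence reaches the length of the longest one.
    max_len = max(h.shape[-2] for h in hiddens)
    padded = []
    for h in hiddens:
        pad_rows = max_len - h.shape[-2]
        if pad_rows == 0:
            padded.append(h)
        else:
            zeros = torch.zeros(1, pad_rows, emb_dim, dtype=h.dtype, device=h.device)
            padded.append(torch.cat((zeros, h), dim=1))
    # Stack into a single (batch, max_len, emb_dim) tensor for convenience.
    return torch.cat(padded, dim=0)


if __name__ == "__main__":
    a = torch.randn(1, 5, 768)
    b = torch.randn(1, 3, 768)
    print(left_pad_hidden_states([a, b]).shape)  # torch.Size([2, 5, 768])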
