Commit

Copy over Kevin's align_s2s changes, documented in
https://docs.google.com/document/d/1LRMwpUt8TH96oRVi0zz7OOTbejrMk9zs4VH1SEX59ZM/edit?tab=t.0
and
/lustre/fsw/portfolios/llmservice/users/kevinhu/works/mod_speech_llm/code/NeMo_s2s_align_debug/

Signed-off-by: zhehuaichen <dian.chenzhehuai@gmail.com>
zhehuaichen committed Nov 12, 2024
1 parent f94b13f commit 713c931
Showing 11 changed files with 1,128 additions and 44 deletions.
52 changes: 31 additions & 21 deletions examples/multimodal/speech_llm/conf/s2s/pt_salm_1a.yaml
@@ -238,35 +238,40 @@ model:
train_ds:
input_cfg:
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/es/
# shar_path: /workspace/data/s2s_shars/es/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/msmacro/
# shar_path: /workspace/data/s2s_shars/msmacro/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/
weight: 0.2
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/alpaca/
# shar_path: /workspace/data/s2s_shars/alpaca/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/
weight: 0.1
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/squadv2/
# shar_path: /workspace/data/s2s_shars/squadv2/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/
weight: 0.05
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/
# shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/
weight: 0.2
tags:
lang: en
s2s: True
s2s_align: True
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
@@ -326,38 +331,43 @@ model:
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
input_cfg:
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/es_validation/ # ST
# shar_path: /workspace/data/s2s_shars/es_validation/ # ST
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/timestamp/
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/msmacro/ # text SQA
# shar_path: /workspace/data/s2s_shars/msmacro/ # text SQA
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro/timestamp/
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/alpaca/ # text SQA
# shar_path: /workspace/data/s2s_shars/alpaca/ # text SQA
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/alpaca/timestamp/
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/squadv2/ # speech SQA
# shar_path: /workspace/data/s2s_shars/squadv2/ # speech SQA
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/squadv2/timestamp/
weight: 0.1
tags:
lang: en
s2s: True
s2s_align: True
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ # speech SQA
# shar_path: /workspace/data/s2s_shars/msmacro_speech_instruct/ # speech SQA
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/msmacro_speech_instruct/timestamp/
weight: 0.9
tags:
lang: en
s2s: True
s2s_align: True

global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
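
For reference, a minimal sketch (not part of the commit) of loading the updated config with OmegaConf and listing the new shar paths and s2s_align tags; it assumes the train_ds block sits under model.data, as in other NeMo speech_llm configs:

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/multimodal/speech_llm/conf/s2s/pt_salm_1a.yaml")
for entry in cfg.model.data.train_ds.input_cfg:
    # Each entry is a lhotse_shar source; the s2s_align tag switches the dataset
    # to the timestamp-aligned target format exercised by this commit.
    print(entry.shar_path, entry.tags.get("s2s_align", False))
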
431 changes: 431 additions & 0 deletions examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s.yaml

Large diffs are not rendered by default.

431 changes: 431 additions & 0 deletions examples/multimodal/speech_llm/conf/s2s/pt_salm_1a_s2s_direct.yaml

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions examples/multimodal/speech_llm/conf/s2s/pt_salm_1b.yaml
@@ -14,6 +14,8 @@

name: megatron_audio_gpt_s2s_lhotse

log_level: INFO

trainer:
devices: 1
accelerator: gpu
@@ -236,11 +238,12 @@ model:
train_ds:
input_cfg:
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/es/
# shar_path: /workspace/data/s2s_shars/es/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/src_timestamp
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
@@ -300,14 +303,15 @@ model:
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
input_cfg:
- type: lhotse_shar
shar_path: /workspace/data/s2s_shars/es_validation/
# shar_path: /workspace/data/s2s_shars/es_validation/
shar_path: /lustre/fsw/portfolios/llmservice/users/kevinhu/s2s/data/es/src_timestamp
weight: 1.0
tags:
lang: en
s2s: True
s2s_align: True
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: False
@@ -338,7 +342,6 @@ model:
source_target_text_ratio_limit: 4.0
# ASR configs
sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate}

log_every_n_steps: 10
metrics:
- name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
4 changes: 4 additions & 0 deletions examples/multimodal/speech_llm/modular_audio_gpt_train.py
@@ -47,6 +47,10 @@

@hydra_runner(config_path="conf", config_name="modular_audio_gpt_config_peft")
def main(cfg) -> None:
# Set up logging with the specified log level
logging_level = getattr(logging, cfg.log_level.upper(), logging.INFO)
logging.setLevel(logging_level)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
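
A small sketch (not part of the commit) of the string-to-level lookup used above, assuming standard-library-style level constants; unknown names fall back to INFO:

import logging

for name in ("debug", "INFO", "bogus"):
    # getattr falls back to logging.INFO when the upper-cased name is not a valid level.
    level = getattr(logging, name.upper(), logging.INFO)
    print(name, level)  # debug -> 10, INFO -> 20, bogus -> 20 (fallback)
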
196 changes: 187 additions & 9 deletions nemo/collections/multimodal/speech_llm/data/lhotse_dataset.py
@@ -1,5 +1,5 @@
import logging
import random
import re

import torch.utils.data
from lhotse import CutSet
@@ -12,6 +12,7 @@
build_loss_mask,
ceil_to_nearest,
)
from nemo.utils import logging


def collate_vectors(items, max_length: int, padding_value):
@@ -108,25 +109,79 @@ def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]:
instructions, instruction_lengths = [], []
source_texts, source_text_lengths = [], [] # Not used in the current implementation
target_texts, target_text_lengths = [], []
start_time_tokens, word_lengths = [], []
remove_ids = []
for id, cut in enumerate(cuts):
metadata.append({'audio_filepath': cut.id + '.wav'})
if id == 0:
logging.debug(f'audio_filepath: {cut.id}.wav')
logging.debug(f'cut: {cut}')
metadata.append({'audio_filepath': cut.id + '.wav'})
# TODO: the following use of _process_example is not ideal. Should update
instruction = self.text_processor._process_example(context=cut.supervisions[0].text, output="")
instruction, instruction_length = torch.as_tensor(instruction["input_ids"][:-1]), torch.as_tensor(
len(instruction["input_ids"]) - 1
)
if id == 0:
logging.debug(f'instruction: {cut.supervisions[0].text}')

source_text = self.text_processor._process_example(context=cut.supervisions[1].text, output="")
source_text, source_text_length = torch.as_tensor(source_text["input_ids"]), torch.as_tensor(
len(source_text["input_ids"])
)

target_text = self.text_processor._process_example(context="", output=cut.supervisions[2].text)
# -1 to remove the eos token added by the text processor
target_text, target_text_length = torch.as_tensor(target_text["answer_ids"][:-1]), torch.as_tensor(
len(target_text["answer_ids"]) - 1
)
if id == 0:
logging.debug(f'source_text: {cut.supervisions[1].text}')

def extract_text_and_time_tokens(input_sequence):
# Regular expression to match time tokens (e.g., <|x|> where x is an integer)
time_token_pattern = r"<\|(\d+)\|>"
# Find all time tokens
time_tokens = re.findall(time_token_pattern, input_sequence)
# Only keep the first token of every pair (i.e., start time tokens)
start_time_token = [int(time_tokens[i]) for i in range(0, len(time_tokens), 2)]
# Remove all time tokens to isolate words
words = re.sub(time_token_pattern, '', input_sequence).split()
# Process each word, tokenize it, and calculate token lengths
tokenized_words = []
word_length = []
for idx, word in enumerate(words):
# Tokenize the word using the provided text processor
if id == 0:
logging.debug(f'word: {word}')
tokenized_word = self.text_processor._process_example(context="", output=word)
# Remove the EOS token (assuming the EOS token is at the end of "answer_ids")
token_ids = tokenized_word["answer_ids"][:-1] # Remove EOS token
if idx != 0: # If not the first word, remove the first token
token_ids = token_ids[1:]
if id == 0:
logging.debug(f'token_ids: {token_ids}')
token_length = len(token_ids) # Calculate the length
tokenized_words.extend(token_ids)
word_length.append(token_length)
return (
torch.as_tensor(tokenized_words),
torch.as_tensor(start_time_token),
torch.as_tensor(word_length),
)
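# Illustrative example (not part of the commit): for an input such as
#   "<|5|><|9|> hello <|12|><|15|> world"
# the pattern r"<\|(\d+)\|>" finds [5, 9, 12, 15]; keeping every other entry
# gives start_time_token = [5, 12], stripping the time tokens leaves
# words = ["hello", "world"], and each word is tokenized in turn with
# word_length recording how many token ids it contributes.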

# import pdb; pdb.set_trace()
use_timestamp = getattr(cut, "s2s_align", False)
if not use_timestamp:
pattern = r"<\|\d+\|>"
output_text = re.sub(pattern, "", cut.supervisions[2].text)
output_text = re.sub(r'\s+', ' ', output_text).strip()
target_text = self.text_processor._process_example(context="", output=output_text)
# -1 to remove the eos token added by the text processor
target_text, target_text_length = torch.as_tensor(target_text["answer_ids"][:-1]), torch.as_tensor(
len(target_text["answer_ids"]) - 1
)
if id == 0:
logging.debug(f'target_text: {output_text}')
else:
target_text, start_time_token, word_length = extract_text_and_time_tokens(cut.supervisions[2].text)
target_text_length = len(target_text)
# import pdb; pdb.set_trace()
if id == 0:
logging.debug(f'target_text: {cut.supervisions[2].text}')

if self.filter_by_source_target_text_ratio:
if (
@@ -142,6 +197,9 @@ def __getitem__(self, cuts) -> dict[str, torch.Tensor | list[str] | dict]:
source_text_lengths.append(source_text_length)
target_texts.append(target_text)
target_text_lengths.append(target_text_length)
if use_timestamp:
word_lengths.append(word_length)
start_time_tokens.append(start_time_token)

cuts = [c for i, c in enumerate(cuts) if i not in remove_ids]

@@ -215,13 +273,17 @@ def collate_and_pad(inputs):
# Loop through cuts and build target_codec
for i, cut in enumerate(cuts):
feat_i = cut.target_codes.load()
# logging.debug(f'frame_shift: {cut.target_codes.frame_shift}')
# logging.debug(f'feat_i.shape: {feat_i.shape}')
target_codec[i, : feat_i.shape[0], 0] = text_unk_id
feat_i = feat_i[: features_lens[i] * self.decoder_reduction_factor, : self.n_speech_codebooks]
feat_i = feat_i.reshape((-1, self.n_speech_codebooks * self.decoder_reduction_factor))
target_codec[i, : feat_i.shape[0], 1:] = torch.tensor(feat_i)
target_codec[i, feat_i.shape[0], :] = eos_tensor

target_codec = target_codec.to(torch.int)
logging.debug(f'target_codec.shape: {target_codec.shape} ')
logging.debug(f'features_lens.shape: {features_lens.shape} ')

source_texts, source_text_lengths = collate_and_pad(source_texts)

@@ -243,6 +305,9 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):
texts_expanded = texts_expanded[:, :-1]
return texts, text_lengths, texts_expanded

# import pdb; pdb.set_trace()

unpadded_target_texts = target_texts
target_texts, target_text_lengths, target_texts_expanded = _convert_text_to_3d_tensor(target_texts)
instructions, instruction_lengths, instructions_expanded_no_eos = _convert_text_to_3d_tensor(
# tokens_to_generate is used in inference
@@ -253,12 +318,125 @@ def _convert_text_to_3d_tensor(texts, include_eos=True, tokens_to_generate=0):

# answers = torch.concat([speaker_context, bos_tensor, target_codec], 1)

if getattr(cut, "s2s", False):
logging.debug(f'target_texts_expanded.shape: {target_texts_expanded.shape} ')
logging.debug(f'target_text_lengths.shape: {target_text_lengths.shape} ')
logging.debug(f'cut: {cut} ')

def discretize_time(start_token, speech_resolution=0.08, timestamp_resolution=0.08):
"""Convert the start token into a time index based on the resolution."""
return int(start_token * timestamp_resolution / speech_resolution)
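# Illustrative example (not part of the commit): index = start_token *
# timestamp_resolution / speech_resolution; with both at the default 0.08 s the
# mapping is the identity (token 50 -> frame 50), while 0.04 s speech frames
# would map the same token to frame 100.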

def _expand_text_with_timestamps_and_word_lengths(
word_tokens, word_lengths, start_time_tokens, features_lens, frame_rate=0.08, pad_id=None
):
"""
Expand word tokens according to start time tokens and word lengths for a batch of sequences.
Args:
- word_tokens: List of lists of token sequences (each inner list is a word's token IDs), shape [batch][time].
- word_lengths: List of lists of word lengths, shape [batch][time].
- start_time_tokens: List of lists of start times, shape [batch][time].
- features_lens: Per-sample feature lengths; their max gives the maximum length in the time dimension (number of frames).
- frame_rate: Frame rate resolution.
- pad_id: Padding ID to use for empty positions in the tensor.
Returns:
- 2D tensor [batch, max_length] where each row is the expanded token sequence for that batch.
"""
if pad_id is None:
raise ValueError("pad_id must be provided.")

batch_size = len(word_tokens)
max_length = max(features_lens).item()

# Create the empty 2D tensor [batch, max_length] with pad_id as the default value
texts_expanded = torch.full((batch_size, max_length), fill_value=pad_id, dtype=torch.long)

# Iterate over each batch
for batch_idx in range(batch_size):
batch_max_length = features_lens[batch_idx]
word_start_idx = 0 # Start index to keep track of the position within the concatenated word tokens

# Iterate over the words in the current batch
for word_idx, word_length in enumerate(word_lengths[batch_idx]):
start_token = start_time_tokens[batch_idx][word_idx]

# Convert the start time token into a time index based on frame rate
start_time_index = discretize_time(start_token, frame_rate)

# Reduction of start time index due to stacking of frames
start_time_index = int(start_time_index / self.decoder_reduction_factor)
if batch_idx == 0:
logging.debug(f'start_time_index[0]: {start_time_index}')

# Calculate the end time index based on word length
end_time_index = start_time_index + word_length
end_time_index = min(end_time_index, max_length) # Ensure it doesn't exceed max length

# Get the word tokens for the current word
word_token_ids = word_tokens[batch_idx][word_start_idx : word_start_idx + word_length]

# Populate the tokens in the expanded tensor at the correct positions
for t_idx in range(start_time_index, end_time_index):
if t_idx - start_time_index < len(word_token_ids): # Ensure tokens are within bounds
token_id = word_token_ids[t_idx - start_time_index] # Get token for this time step
texts_expanded[batch_idx][t_idx] = token_id # Directly assign the token ID

# Move to the next word in the concatenated word tokens
word_start_idx += word_length

# Overwrite padding tokens
texts_expanded[batch_idx][batch_max_length:] = text_pad_id

return texts_expanded
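# Illustrative example (not part of the commit): for a single cut with
#   word_lengths      = [[2, 1]]   (first word -> 2 tokens, second word -> 1 token)
#   start_time_tokens = [[5, 12]]
#   features_lens     = [20]
# each start token is discretized to a frame index (and divided by
# self.decoder_reduction_factor), the word's token ids are written into
# texts_expanded[0, start:start + length], and all remaining frames keep pad_id.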

# import pdb; pdb.set_trace()

# TODO(huk): Consider smaller reduction factor
if getattr(cut, "s2s_align", False):
max_feat_len = max(features_lens).item() + 1
# [batch, max_feat_len]
target_text_expanded = _expand_text_with_timestamps_and_word_lengths(
unpadded_target_texts,
word_lengths,
start_time_tokens,
features_lens + 1,
cut.target_codes.frame_shift,
pad_id=text_unk_id,
)
# import pdb; pdb.set_trace()
logging.debug(f'start_time_token: {start_time_tokens[0]}')
logging.debug(f'word_length: {word_lengths[0]}')
logging.debug(f'target_tokens: {unpadded_target_texts[0]}')
logging.debug(f'target_text_expanded: {target_text_expanded[0,:]}')
# [batch, max_feat_len, 1+V], where V = #codebooks * reduction_factor
target_codec[:, :, 0] = target_text_expanded
token_list = target_codec

logging.debug(f'token_list[0].shape: {token_list[0].shape}')
if not self.t5_style:
token_list = [
torch.concat([it[:itl], tt], 0)
for tt, it, itl in zip(token_list, instructions_expanded_no_eos, instruction_lengths)
]
tokens, _ = collate_and_pad(token_list)
speech_loss_mask = tokens[:, :, 1:] != self.speech_pad_id
# Make the text loss mask the same as speech since they are aligned
loss_mask = torch.cat([speech_loss_mask[..., :1], speech_loss_mask], dim=-1)
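# Illustrative note (not part of the commit): tokens is [B, T, 1 + V] with the text
# channel first; speech_loss_mask over the V codec channels is [B, T, V], and
# prepending its first channel yields a [B, T, V + 1] mask whose extra column
# covers the aligned text channel.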
if not self.t5_style:
for itl in instruction_lengths:
loss_mask[:, :itl, :] = False
# loss_mask = torch.cat([text_loss_mask, speech_loss_mask], 2)
# full_lengths = target_text_lengths + 1 + features_lens + 1 + instruction_length
full_lengths = features_lens + 1 + instruction_length
elif getattr(cut, "s2s", False):
# Add 1 for eos token
token_list = [
torch.concat([tt[: ttl + 1], tc[: tcl + 1]], 0)
for tt, ttl, tc, tcl in zip(target_texts_expanded, target_text_lengths, target_codec, features_lens)
]
# import pdb; pdb.set_trace()
logging.debug(f'token_list[0].shape: {token_list[0].shape}')
if not self.t5_style:
token_list = [
torch.concat([it[:itl], tt], 0)
@@ -14,6 +14,7 @@

import itertools
import json
import logging
import os
from functools import partial
from typing import List, Optional, Union
@@ -289,6 +290,8 @@ def inject_perception_input(
attention_mask = self._create_attention_mask(encoder_input)
position_ids = build_position_ids(encoder_input[:, :, 0])

# import pdb; pdb.set_trace()

# Add position embeddings
if (
getattr(lm_embedding, "position_embeddings", None) is not None
@@ -328,6 +331,7 @@ def _get_text_embeddings(self, text_tokens, position_ids):
def prepare_llm_input(self, audio_batch):
"""Prepare input for the LLM."""
input_signal = audio_batch['audio_signal']
logging.debug(f'input_signal.shape: {input_signal.shape}')
input_signal_length = audio_batch['audio_signal_length']

input_ids, input_length, labels, loss_mask = (
@@ -348,6 +352,9 @@ def prepare_llm_input(self, audio_batch):
processed_signal_length=None,
)

logging.debug(f'encoded.shape: {encoded.shape}')
logging.debug(f'encoded_len.shape: {encoded_len.shape}')
logging.debug(f'num_audios: {num_audios}')
if num_audios is not None:
# split the encoded and encoded_len by num_audios, used when there're multiple audio files per sample
encoded = encoded.split(num_audios.tolist())
@@ -604,6 +611,8 @@ def fwd_output_only_func(dataloader_iter, model):
**extra_arg,
)

# import pdb; pdb.set_trace()

# Advance inference sequence offset.
if self.inference_params:
# if last stage, then (final) output is [b, s, h], otherwise it's [s, b, h]
@@ -62,6 +62,7 @@ def __init__(
self.proj_head_dims = proj_head_dims

def forward(self, input_):

if input_.ndim == 3:
assert input_.shape[2] == len(self.proj_head_dims)
input_ = input_.clone()
@@ -311,6 +312,8 @@ def inference_step(self, dataloader_iter, mode):
"""
Used for validation and test steps, added postprocessing after calling self.predict_step().
"""
# import pdb; pdb.set_trace()

# Evaluation of multimodal data follows the same pattern as training except predict_step
batch, batch_idx, dataloader_idx = next(dataloader_iter)
data_cfg = self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds
@@ -403,15 +406,20 @@ def parse_decoder_outputs(
decoder_output = input_decoder_output[-1:].tile([max_len, 1])
decoder_output[: max_len - context_length] = input_decoder_output[context_length:]

# Do not split because text and speech are now aligned
# Split text and speech part based on the position of the first separator token
sep_pos = (decoder_output[:, 0] == text_separator).long()
if torch.any(sep_pos):
first_sep_pos = torch.argmax(sep_pos)
text_tokens = decoder_output[:first_sep_pos, 0]
speech_tokens = decoder_output[first_sep_pos + 1 :, 1:]
else:
text_tokens = decoder_output[:, 0]
speech_tokens = decoder_output[:, 1:]
# sep_pos = (decoder_output[:, 0] == text_separator).long()
# if torch.any(sep_pos):
# first_sep_pos = torch.argmax(sep_pos)
# text_tokens = decoder_output[:first_sep_pos, 0]
# speech_tokens = decoder_output[first_sep_pos + 1 :, 1:]
# else:
# text_tokens = decoder_output[:, 0]
# speech_tokens = decoder_output[:, 1:]
text_tokens = decoder_output[:, 0]
speech_tokens = decoder_output[:, 1:]
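# Illustrative note (not part of the commit): decoder_output is [T, 1 + n_codebooks];
# in the aligned s2s format column 0 already carries the per-frame text token and
# columns 1: carry the speech codec tokens, so no separator-based split is needed.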

# import pdb; pdb.set_trace()

# Get speech token ids
n_speech_codebooks = self.model.n_proj_heads - 1
@@ -816,6 +824,7 @@ def write_predictions_to_file(self, outputs, output_file_path_prefix, output_dir
def de_concat_multiproj_logits(self, logits):
logits_list = []
prev = 0
# import pdb; pdb.set_trace()
for i in self.model.proj_head_dims:
logits_list.append(logits[:, prev : prev + i])
prev += i
@@ -140,6 +140,7 @@ def end_of_generation_condition(
returns:
a boolean tensor indicating whether the generation should stop
"""
# import pdb; pdb.set_trace()
if len(end_strings) == 1 and end_strings[0] == END_OF_SEQ:
return prev == eod_id
else:
@@ -787,6 +787,8 @@ def get_prev(logits, started, temperature, extra):
prev = torch.multinomial(probs, num_samples=1).view(-1)
return prev

# import pdb; pdb.set_trace()

prev = [get_prev(logits_i, started, temperature, extra) for logits_i in logits]
prev = torch.stack(prev, dim=1)
started_expand = started.unsqueeze(1).expand(-1, prev.size(1))
@@ -833,6 +835,8 @@ def get_prev(logits, started, temperature, extra):
model.cfg.speech_eos_id,
)

# import pdb; pdb.set_trace()

done_token = done_token.byte() & started.byte()

just_finished = (done_token & ~is_done).bool()
@@ -310,6 +310,8 @@ def _process_example(self, context: str, output: str):
else:
text = context + ' ' + output

# logging.debug(f'text: {text}')

if self.virtual_tokens:
# (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context
# these pad/eos tokens are placeholders for virtual tokens
@@ -321,6 +323,8 @@ def _process_example(self, context: str, output: str):
if self.end_string:
answer_ids += self.tokenizer.text_to_ids(self.end_string)

# logging.debug(f'answer_text: {answer_text}')

if self.audio_locator is None:
# single audio case
context_ids = self.tokenizer.text_to_ids(context)
