parler-tts support (#2431)
* Start sketching parler-tts support.

* Implement the attention.

* Add the example code.

* Fix the example.

* Add the description prompt and encode it with T5.

* More of the parler forward pass.

* Fix the positional embeddings.

* Support random sampling in generation.

* Handle EOS.

* Add the Python decoder.

* Proper causality mask (see the sketch below).
LaurentMazare authored Aug 18, 2024
1 parent 736d8eb commit 58197e1
Showing 5 changed files with 658 additions and 0 deletions.
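The last bullet of the commit message mentions the decoder's proper causality mask. The model implementation itself (candle-transformers/src/models/parler_tts.rs) is not rendered on this page, so the following is only a minimal sketch of the general pattern, assuming candle's public Tensor API rather than the exact construction used in the commit:

use candle::{Device, Tensor};

/// Additive causal mask: 0.0 on and below the diagonal and
/// f32::NEG_INFINITY above it, so that once the mask is added to the
/// attention scores, position i can only attend to positions j <= i.
fn causal_mask(t: usize, device: &Device) -> candle::Result<Tensor> {
    let mask: Vec<f32> = (0..t)
        .flat_map(|i| (0..t).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
        .collect();
    Tensor::from_slice(&mask, (t, t), device)
}

Such a mask is broadcast over the batch and head dimensions and added to the raw attention scores just before the softmax.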
1 change: 1 addition & 0 deletions .gitignore
@@ -41,3 +41,4 @@ candle-wasm-examples/**/config*.json
 .DS_Store
 .idea/*
 __pycache__
+out.safetensors
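The ignored out.safetensors is the codes file written by the new example (see main.rs below) and consumed by decode.py.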
29 changes: 29 additions & 0 deletions candle-examples/examples/parler-tts/decode.py
@@ -0,0 +1,29 @@
import torch
import torchaudio
from safetensors.torch import load_file
from parler_tts import DACModel

# Load the audio codes generated by the Rust example (main.rs).
tensors = load_file("out.safetensors")
dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
output_ids = tensors["codes"][None, None]
print(output_ids, "\n", output_ids.shape)
batch_size = 1
with torch.no_grad():
    output_values = []
    for sample_id in range(batch_size):
        sample = output_ids[:, sample_id]
        # Keep only the timesteps where every codebook entry is a valid
        # audio code (below the codebook size); special tokens such as
        # padding sit above that range.
        sample_mask = (sample >= dac_model.config.codebook_size).sum(dim=(0, 1)) == 0
        if sample_mask.sum() > 0:
            sample = sample[:, :, sample_mask]
            sample = dac_model.decode(sample[None, ...], [None]).audio_values
            output_values.append(sample.transpose(0, 2))
        else:
            output_values.append(torch.zeros((1, 1, 1)).to(dac_model.device))
    output_lengths = [audio.shape[0] for audio in output_values]
    # Pad the per-sample waveforms to a common length and drop the
    # trailing singleton dimensions.
    pcm = (
        torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0)
        .squeeze(-1)
        .squeeze(-1)
    )
print(pcm.shape, pcm.dtype)
torchaudio.save("out.wav", pcm.cpu(), sample_rate=44100)
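This script consumes the out.safetensors file produced by the Rust example below and writes the decoded audio to out.wav at 44.1 kHz. Besides torch, torchaudio and safetensors, it needs the parler_tts Python package, which provides the DACModel wrapper around the DAC audio codec.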
175 changes: 175 additions & 0 deletions candle-examples/examples/parler-tts/main.rs
@@ -0,0 +1,175 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Error as E;
use clap::Parser;

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::parler_tts::{Config, Model};
use tokenizers::Tokenizer;

#[derive(Parser)]
struct Args {
    /// Run on CPU rather than on GPU.
    #[arg(long)]
    cpu: bool,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Display the token for the specified prompt.
    #[arg(long)]
    verbose_prompt: bool,

    #[arg(long, default_value = "Hey, how are you doing today?")]
    prompt: String,

    #[arg(
        long,
        default_value = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
    )]
    description: String,

    /// The temperature used to generate samples.
    #[arg(long, default_value_t = 1.0)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 0)]
    seed: u64,

    #[arg(long, default_value_t = 5000)]
    sample_len: usize,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.0)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    #[arg(long)]
    model_id: Option<String>,

    #[arg(long)]
    revision: Option<String>,

    #[arg(long)]
    quantized: bool,

    /// Use f16 precision for all the computations rather than f32.
    #[arg(long)]
    f16: bool,

    #[arg(long)]
    model_file: Option<String>,

    #[arg(long)]
    tokenizer_file: Option<String>,

    #[arg(long)]
    config_file: Option<String>,

    #[arg(long, default_value_t = 512)]
    max_steps: usize,
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();

    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let start = std::time::Instant::now();
    let api = hf_hub::api::sync::Api::new()?;
    let model_id = match args.model_id {
        Some(model_id) => model_id.to_string(),
        None => "parler-tts/parler-tts-large-v1".to_string(),
    };
    let revision = match args.revision {
        Some(r) => r,
        None => "main".to_string(),
    };
    let repo = api.repo(hf_hub::Repo::with_revision(
        model_id,
        hf_hub::RepoType::Model,
        revision,
    ));
    let model_files = match args.model_file {
        Some(m) => vec![m.into()],
        None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
    };
    let config = match args.config_file {
        Some(m) => m.into(),
        None => repo.get("config.json")?,
    };
    let tokenizer = match args.tokenizer_file {
        Some(m) => m.into(),
        None => repo.get("tokenizer.json")?,
    };
    println!("retrieved the files in {:?}", start.elapsed());
    let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;

    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&model_files, DType::F32, &device)? };
    let config: Config = serde_json::from_reader(std::fs::File::open(config)?)?;
    let mut model = Model::new(&config, vb)?;
    println!("loaded the model in {:?}", start.elapsed());

    let description_tokens = tokenizer
        .encode(args.description, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
    println!("{description_tokens}");

    let prompt_tokens = tokenizer
        .encode(args.prompt, true)
        .map_err(E::msg)?
        .get_ids()
        .to_vec();
    let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
    println!("{prompt_tokens}");

    let lp = candle_transformers::generation::LogitsProcessor::new(
        args.seed,
        Some(args.temperature),
        args.top_p,
    );
    let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
    println!("{codes}");
    let codes = codes.to_dtype(DType::I64)?;
    codes.save_safetensors("codes", "out.safetensors")?;
    Ok(())
}
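Assuming cargo's usual example discovery, this can be run from the candle-examples crate with something like cargo run --example parler-tts --release (the exact invocation is not part of the commit). The generated codes land in out.safetensors, which the decode.py script above turns into out.wav.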
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -40,6 +40,7 @@ pub mod mobileone;
 pub mod moondream;
 pub mod mpt;
 pub mod olmo;
+pub mod parler_tts;
 pub mod persimmon;
 pub mod phi;
 pub mod phi3;
452 changes: 452 additions & 0 deletions candle-transformers/src/models/parler_tts.rs
(diff not rendered on this page; the file name and line count are inferred from the mod.rs change and the 658-addition total)
