Skip to content

Commit

Permalink
Add --zstd-workers option (#116)
Browse files Browse the repository at this point in the history
Co-authored-by: Koichi Akabe <[email protected]>
  • Loading branch information
akirakubo and vbkaisetsu authored Nov 9, 2023
1 parent b27abca commit f01f5e9
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion convert_kytea_model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ edition = "2021"
[dependencies]
clap = { version = "4.2", features = ["derive"] } # MIT or Apache-2.0
vaporetto = { path = "../vaporetto", features = ["kytea"] } # MIT or Apache-2.0
zstd = "0.13" # MIT
zstd = { version = "0.13", features = ["zstdmt"] } # MIT
5 changes: 5 additions & 0 deletions convert_kytea_model/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ struct Args {
/// Vaporetto model file
#[clap(long)]
model_out: PathBuf,

/// The number of workers for zstd (0 means multithreaded will be disabled)
#[arg(long, default_value="0")]
zstd_workers: u32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -31,6 +35,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("Saving model file...");
let model = Model::try_from(model)?;
let mut f = zstd::Encoder::new(fs::File::create(args.model_out)?, 19)?;
f.multithread(args.zstd_workers)?;
model.write(&mut f)?;
f.finish()?;

Expand Down
2 changes: 1 addition & 1 deletion manipulate_model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ clap = { version = "4.2", features = ["derive"] } # MIT or Apache-2.0
csv = "1.2" # Unlicense or MIT
serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0
vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0
zstd = "0.13" # MIT
zstd = { version = "0.13", features = ["zstdmt"] } # MIT
5 changes: 5 additions & 0 deletions manipulate_model/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ struct Args {
/// Replace a dictionary if the argument is specified.
#[arg(long)]
replace_dict: Option<PathBuf>,

/// The number of workers for zstd (0 means multithreaded will be disabled)
#[arg(long, default_value="0")]
zstd_workers: u32,
}

#[derive(Deserialize, Serialize)]
Expand Down Expand Up @@ -72,6 +76,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
if let Some(path) = args.model_out {
eprintln!("Saving model file...");
let mut f = zstd::Encoder::new(fs::File::create(path)?, 19)?;
f.multithread(args.zstd_workers)?;
model.write(&mut f)?;
f.finish()?;
}
Expand Down
2 changes: 1 addition & 1 deletion train/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ edition = "2021"
clap = { version = "4.2", features = ["derive"] } # MIT or Apache-2.0
vaporetto = { path = "../vaporetto", features = ["train"] } # MIT or Apache-2.0
vaporetto_rules = { path = "../vaporetto_rules" } # MIT or Apache-2.0
zstd = "0.13" # MIT
zstd = { version = "0.13", features = ["zstdmt"] } # MIT
5 changes: 5 additions & 0 deletions train/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ struct Args {
/// Do not normalize training data.
#[arg(long)]
no_norm: bool,

/// The number of workers for zstd (0 means multithreaded will be disabled)
#[arg(long, default_value="0")]
zstd_workers: u32,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -179,6 +183,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
eprintln!("Finish training.");

let mut f = zstd::Encoder::new(File::create(args.model)?, 19)?;
f.multithread(args.zstd_workers)?;
model.write(&mut f)?;
f.finish()?;

Expand Down

0 comments on commit f01f5e9

Please sign in to comment.