diff --git a/convert_kytea_model/Cargo.toml b/convert_kytea_model/Cargo.toml index d827d547..070740dd 100644 --- a/convert_kytea_model/Cargo.toml +++ b/convert_kytea_model/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" [dependencies] clap = { version = "4.2", features = ["derive"] } # MIT or Apache-2.0 vaporetto = { path = "../vaporetto", features = ["kytea"] } # MIT or Apache-2.0 -zstd = "0.13" # MIT +zstd = { version = "0.13", features = ["zstdmt"] } # MIT diff --git a/convert_kytea_model/src/main.rs b/convert_kytea_model/src/main.rs index b97efd25..b390e4af 100644 --- a/convert_kytea_model/src/main.rs +++ b/convert_kytea_model/src/main.rs @@ -19,6 +19,10 @@ struct Args { /// Vaporetto model file #[clap(long)] model_out: PathBuf, + + /// The number of workers for zstd (0 means multithreaded will be disabled) + #[arg(long, default_value="0")] + zstd_workers: u32, } fn main() -> Result<(), Box> { @@ -31,6 +35,7 @@ fn main() -> Result<(), Box> { eprintln!("Saving model file..."); let model = Model::try_from(model)?; let mut f = zstd::Encoder::new(fs::File::create(args.model_out)?, 19)?; + f.multithread(args.zstd_workers)?; model.write(&mut f)?; f.finish()?; diff --git a/manipulate_model/Cargo.toml b/manipulate_model/Cargo.toml index a1c45b17..1e966a23 100644 --- a/manipulate_model/Cargo.toml +++ b/manipulate_model/Cargo.toml @@ -8,4 +8,4 @@ clap = { version = "4.2", features = ["derive"] } # MIT or Apache-2.0 csv = "1.2" # Unlicense or MIT serde = { version = "1.0", features = ["derive"] } # MIT or Apache-2.0 vaporetto = { path = "../vaporetto" } # MIT or Apache-2.0 -zstd = "0.13" # MIT +zstd = { version = "0.13", features = ["zstdmt"] } # MIT diff --git a/manipulate_model/src/main.rs b/manipulate_model/src/main.rs index d5348452..35d9c6e7 100644 --- a/manipulate_model/src/main.rs +++ b/manipulate_model/src/main.rs @@ -23,6 +23,10 @@ struct Args { /// Replace a dictionary if the argument is specified. #[arg(long)] replace_dict: Option, + + /// The number of workers for zstd (0 means multithreaded will be disabled) + #[arg(long, default_value="0")] + zstd_workers: u32, } #[derive(Deserialize, Serialize)] @@ -72,6 +76,7 @@ fn main() -> Result<(), Box> { if let Some(path) = args.model_out { eprintln!("Saving model file..."); let mut f = zstd::Encoder::new(fs::File::create(path)?, 19)?; + f.multithread(args.zstd_workers)?; model.write(&mut f)?; f.finish()?; } diff --git a/train/Cargo.toml b/train/Cargo.toml index e14f2561..fa145d5d 100644 --- a/train/Cargo.toml +++ b/train/Cargo.toml @@ -7,4 +7,4 @@ edition = "2021" clap = { version = "4.2", features = ["derive"] } # MIT or Apache-2.0 vaporetto = { path = "../vaporetto", features = ["train"] } # MIT or Apache-2.0 vaporetto_rules = { path = "../vaporetto_rules" } # MIT or Apache-2.0 -zstd = "0.13" # MIT +zstd = { version = "0.13", features = ["zstdmt"] } # MIT diff --git a/train/src/main.rs b/train/src/main.rs index 7958935a..94001908 100644 --- a/train/src/main.rs +++ b/train/src/main.rs @@ -65,6 +65,10 @@ struct Args { /// Do not normalize training data. #[arg(long)] no_norm: bool, + + /// The number of workers for zstd (0 means multithreaded will be disabled) + #[arg(long, default_value="0")] + zstd_workers: u32, } fn main() -> Result<(), Box> { @@ -179,6 +183,7 @@ fn main() -> Result<(), Box> { eprintln!("Finish training."); let mut f = zstd::Encoder::new(File::create(args.model)?, 19)?; + f.multithread(args.zstd_workers)?; model.write(&mut f)?; f.finish()?;