From ca479a873eeebb21cdef1e2a95d11fea742390f4 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Thu, 27 Jul 2023 20:05:02 +0200
Subject: [PATCH] Upgrading hf-hub to `0.2.0` (Modified API to not pass the
 Repo around all the time)
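
The `0.1.x` API threaded a `Repo` through every download
(`api.get(&repo, "config.json")?`). With `0.2.0` the `Api` is scoped to a
repo once, and files are then fetched by name alone. A minimal sketch of
the new pattern, assuming a `String` model id and revision, and assuming
`anyhow` for error handling as the examples below use:

    use hf_hub::{api::sync::Api, Repo, RepoType};

    fn fetch_weights(model_id: String, revision: String) -> anyhow::Result<()> {
        let api = Api::new()?;
        // Build the repo-scoped handle once...
        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
        // ...then fetch files without passing `&repo` to every call.
        let tokenizer = repo.get("tokenizer.json")?;
        let weights = repo.get("model.safetensors")?;
        println!("downloaded {tokenizer:?} and {weights:?}");
        Ok(())
    }

`api.model(id)` replaces `Repo::new(id, RepoType::Model)` for the default
revision, and `api.dataset(id)` is the dataset equivalent.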
---
 Cargo.toml                               |  2 +-
 candle-examples/examples/bert/main.rs    |  7 ++++---
 candle-examples/examples/falcon/main.rs  | 10 +++++++---
 candle-examples/examples/llama/main.rs   |  8 ++++----
 candle-examples/examples/whisper/main.rs | 19 +++++++------------
 5 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 0dec835b7..05c6240b0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,7 +22,7 @@ clap = { version = "4.2.4", features = ["derive"] }
 cudarc = { git = "https://github.com/LaurentMazare/cudarc.git", branch = "cublas-bf16", features = ["f16"] }
 # TODO: Switch back to the official gemm implementation if we manage to upstream the changes.
 gemm = { git = "https://github.com/LaurentMazare/gemm.git" }
-hf-hub = "0.1.3"
+hf-hub = "0.2.0"
 half = { version = "2.3.1", features = ["num-traits", "rand_distr"] }
 intel-mkl-src = { version = "0.8.1", features = ["mkl-static-lp64-iomp"] }
 libc = { version = "0.2.147" }
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 6672ad092..79c789682 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -69,10 +69,11 @@ impl Args {
             )
         } else {
             let api = Api::new()?;
+            let api = api.repo(repo);
             (
-                api.get(&repo, "config.json")?,
-                api.get(&repo, "tokenizer.json")?,
-                api.get(&repo, "model.safetensors")?,
+                api.get("config.json")?,
+                api.get("tokenizer.json")?,
+                api.get("model.safetensors")?,
             )
         };
         let config = std::fs::read_to_string(config_filename)?;
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 3a284c860..a01191a58 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -123,14 +123,18 @@ fn main() -> Result<()> {
     let device = candle_examples::device(args.cpu)?;
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = Repo::with_revision(args.model_id, RepoType::Model, args.revision);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let repo = api.repo(Repo::with_revision(
+        args.model_id,
+        RepoType::Model,
+        args.revision,
+    ));
+    let tokenizer_filename = repo.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = repo.get(rfilename)?;
         filenames.push(filename);
     }
     println!("retrieved the files in {:?}", start.elapsed());
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 582ac3f81..d9d1e21ae 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -18,7 +18,7 @@ use clap::Parser;
 use candle::{DType, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::generation::LogitsProcessor;
-use hf_hub::{api::sync::Api, Repo, RepoType};
+use hf_hub::api::sync::Api;
 
 mod model;
 use model::{Config, Llama};
@@ -146,14 +146,14 @@ fn main() -> Result<()> {
         }
     });
     println!("loading the model weights from {model_id}");
-    let repo = Repo::new(model_id, RepoType::Model);
-    let tokenizer_filename = api.get(&repo, "tokenizer.json")?;
+    let api = api.model(model_id);
+    let tokenizer_filename = api.get("tokenizer.json")?;
     let mut filenames = vec![];
     for rfilename in [
         "model-00001-of-00002.safetensors",
         "model-00002-of-00002.safetensors",
     ] {
-        let filename = api.get(&repo, rfilename)?;
+        let filename = api.get(rfilename)?;
         filenames.push(filename);
     }
 
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index 079424e38..c03779e71 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -282,28 +282,23 @@ fn main() -> Result<()> {
             std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
         )
     } else {
-        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
         let api = Api::new()?;
+        let dataset = api.dataset("Narsil/candle-examples".to_string());
+        let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
         let sample = if let Some(input) = args.input {
             if let Some(sample) = input.strip_prefix("sample:") {
-                api.get(
-                    &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                    &format!("samples_{sample}.wav"),
-                )?
+                dataset.get(&format!("samples_{sample}.wav"))?
             } else {
                 std::path::PathBuf::from(input)
             }
         } else {
             println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
-            api.get(
-                &Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
-                "samples_jfk.wav",
-            )?
+            dataset.get("samples_jfk.wav")?
         };
         (
-            api.get(&repo, "config.json")?,
-            api.get(&repo, "tokenizer.json")?,
-            api.get(&repo, "model.safetensors")?,
+            repo.get("config.json")?,
+            repo.get("tokenizer.json")?,
+            repo.get("model.safetensors")?,
             sample,
         )
     };
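
The whisper change above also scopes a handle to a dataset repo. A minimal
standalone sketch of that accessor, with the repo name and file name taken
from the diff and `anyhow` again assumed for error handling:

    use hf_hub::api::sync::Api;

    fn fetch_sample() -> anyhow::Result<std::path::PathBuf> {
        let api = Api::new()?;
        // `dataset` scopes the Api to a RepoType::Dataset repo.
        let dataset = api.dataset("Narsil/candle-examples".to_string());
        Ok(dataset.get("samples_jfk.wav")?)
    }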