Fixes and verbosity improvements for device mapping #332

Merged · 1 commit · May 19, 2024
1 change: 0 additions & 1 deletion mistralrs-bench/src/main.rs
@@ -310,7 +310,6 @@ fn main() -> anyhow::Result<()> {
         candle_core::utils::with_f16c()
     );
     info!("Sampling method: penalties -> temperature -> topk -> topp -> multinomial");
-    info!("Loading model `{}` on {device:?}...", loader.get_id());
     if use_flash_attn {
         info!("Using flash attention.");
     }
10 changes: 7 additions & 3 deletions mistralrs-core/src/device_map.rs
@@ -41,11 +41,15 @@ impl DeviceMapMetadata {
             }));
         };
         // How many host (cpu) layers, defaulting to automatically filling the rest.
-        let n_host_layers = self.host_layers.unwrap_or(model_layers - n_device_layers);
+        // If n_device_layers > model_layers, n_host_layers = 0
+        let n_host_layers = self
+            .host_layers
+            .unwrap_or(model_layers.saturating_sub(n_device_layers));
         if n_device_layers + n_host_layers != model_layers {
-            candle_core::bail!("Expected the number of device ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
+            candle_core::bail!("Expected the number of GPU ({n_device_layers}) and host layers ({n_host_layers}) to sum to the number of model hidden layers ({model_layers})");
         }
-        info!("Using {n_device_layers} layers on device and {n_host_layers} on host.");
+        info!("Model has {model_layers} repeating layers.");
+        info!("Using {n_device_layers} repeating layers on GPU and {n_host_layers} repeating layers on host.");
         let mut combined = vec![device.clone(); n_device_layers];
         // Always put the CPU layers at the end so that we reduce dtoh and htod copies
         combined.extend(vec![Device::Cpu; n_host_layers]);
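The substantive fix in this hunk is the switch from plain subtraction to saturating_sub: if a caller requests more device layers than the model has, model_layers - n_device_layers underflows (an overflow panic in debug builds) before the sum check can ever produce its error message. A minimal standalone sketch of the new behavior, with illustrative names and a plain String error rather than the crate's real signature:

    /// Sketch of the layer-split logic above; names and the error type are
    /// illustrative, not the crate's API.
    fn split_layers(
        model_layers: usize,
        n_device_layers: usize,
        host_layers: Option<usize>,
    ) -> Result<(usize, usize), String> {
        // saturating_sub clamps at zero, so an oversized device-layer request
        // reaches the explicit error below instead of panicking on underflow.
        let n_host_layers =
            host_layers.unwrap_or(model_layers.saturating_sub(n_device_layers));
        if n_device_layers + n_host_layers != model_layers {
            return Err(format!(
                "Expected the number of GPU ({n_device_layers}) and host layers \
                 ({n_host_layers}) to sum to the number of model hidden layers \
                 ({model_layers})"
            ));
        }
        Ok((n_device_layers, n_host_layers))
    }

    fn main() {
        // 32-layer model, 24 layers on GPU: the remaining 8 go to the host.
        assert_eq!(split_layers(32, 24, None), Ok((24, 8)));
        // 40 GPU layers requested for a 32-layer model: host count clamps to 0
        // and the mismatch is reported as an error rather than a panic.
        assert!(split_layers(32, 40, None).is_err());
    }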
1 change: 1 addition & 0 deletions mistralrs-core/src/pipeline/ggml.rs
@@ -228,6 +228,7 @@ impl Loader for GGMLLoader {
         if !mapper.is_dummy() {
             warn!("GGML models do not support device mapping. Device mapping will not work. Please consider using a GGUF model.");
         }
+        info!("Loading model `{}` on {device:?}...", self.get_id());

         let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
         let model = ggml_file::Content::read(&mut file, device)
5 changes: 5 additions & 0 deletions mistralrs-core/src/pipeline/gguf.rs
@@ -314,6 +314,11 @@ impl Loader for GGUFLoader {
                 "You are trying to in-situ quantize a GGUF model. This will not do anything."
             );
         }
+        // Otherwise, the device mapper will print it
+        if mapper.is_dummy() {
+            info!("Loading model `{}` on {device:?}...", self.get_id());
+        }
+
         let mut file = std::fs::File::open(paths.get_weight_filenames().first().unwrap())?;
         let model = gguf_file::Content::read(&mut file)
             .map_err(|e| e.with_path(paths.get_weight_filenames().first().unwrap()))?;
4 changes: 4 additions & 0 deletions mistralrs-core/src/pipeline/normal.rs
@@ -211,6 +211,10 @@ impl Loader for NormalLoader {
         } else {
             DType::F32
         };
+        // Otherwise, the device mapper will print it
+        if mapper.is_dummy() {
+            info!("Loading model `{}` on {device:?}...", self.get_id());
+        }

         info!(
             "Model config: {:?}",
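Both gguf.rs and normal.rs gate the same log line on mapper.is_dummy(): when a real device map is active, device_map.rs now reports the repeating-layer total and the GPU/host split itself, so a separate "Loading model ... on {device}" line would be redundant and misleading about where layers actually end up. A small sketch of the pattern, using a simplified stand-in for the crate's mapper type:

    // `Mapper` is a stand-in for the crate's device-mapper metadata; only the
    // is_dummy() gate mirrors the real code.
    struct Mapper {
        device_layers: Option<usize>,
    }

    impl Mapper {
        // A dummy mapper means no device mapping was requested.
        fn is_dummy(&self) -> bool {
            self.device_layers.is_none()
        }
    }

    fn announce_load(mapper: &Mapper, model_id: &str, device: &str) {
        if mapper.is_dummy() {
            // No mapper output will follow, so the loader names the device.
            println!("Loading model `{model_id}` on {device}...");
        }
        // Otherwise the mapper prints the repeating-layer total and the
        // GPU/host split, which already says where the layers will live.
    }

    fn main() {
        let unmapped = Mapper { device_layers: None };
        let mapped = Mapper { device_layers: Some(24) };
        announce_load(&unmapped, "my-model", "Cuda(0)"); // prints the line
        announce_load(&mapped, "my-model", "Cuda(0)"); // silent; mapper logs instead
    }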
1 change: 0 additions & 1 deletion mistralrs-server/src/main.rs
@@ -90,7 +90,7 @@
     chat_template: Option<String>,

     /// Source of the token for authentication.
     /// Can be in the formats: "literal:<value>", "env:<value>", "path:<value>", "cache" to use a cached token or "none" to use no token.
     /// Defaults to using a cached token.
     #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
     token_source: TokenSource,

GitHub Actions / Docs — warning on line 93 in mistralrs-server/src/main.rs (reported three times): unclosed HTML tag `value`
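The three identical Docs warnings share one cause: rustdoc parses the bare <value> placeholders in the doc comment as unclosed HTML tags. A possible fix, sketched here as an assumption rather than taken from this PR, is to wrap the placeholders in backticks so rustdoc treats them as inline code:

    /// Source of the token for authentication.
    /// Can be in the formats: `literal:<value>`, `env:<value>`, `path:<value>`,
    /// `cache` to use a cached token, or `none` to use no token.
    /// Defaults to using a cached token.
    #[arg(long, default_value_t = TokenSource::CacheToken, value_parser = parse_token_source)]
    token_source: TokenSource,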
@@ -266,7 +266,6 @@
         candle_core::utils::with_f16c()
     );
     info!("Sampling method: penalties -> temperature -> topk -> topp -> multinomial");
-    info!("Loading model `{}` on {device:?}...", loader.get_id());
     if use_flash_attn {
         info!("Using flash attention.");
     }