Skip to content

Commit

Permalink
Multithreaded binpack scanning (#171)
Browse files Browse the repository at this point in the history
Bench: 14655488
  • Loading branch information
cosmobobak authored Jul 24, 2024
1 parent 80e8b1f commit caf478c
Showing 1 changed file with 49 additions and 34 deletions.
83 changes: 49 additions & 34 deletions src/datagen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ use std::{
io::{BufReader, BufWriter, Seek, Write},
path::{Path, PathBuf},
str::FromStr,
sync::atomic::{AtomicBool, AtomicU64, Ordering},
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Mutex,
},
time::Instant,
};

use anyhow::{bail, Context};
use anyhow::{anyhow, bail, Context};
use bulletformat::ChessBoard;
use rand::Rng;

Expand Down Expand Up @@ -983,11 +986,6 @@ pub fn dataset_stats(dataset_path: &Path) -> anyhow::Result<()> {

/// Scans one or more variable-length game format files and prints the position counts.
pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
let mut move_buffer = Vec::new();

let mut total_count = 0u64;
let mut filtered_count = 0u64;

let paths = if path.is_dir() {
fs::read_dir(path).map_or_else(
|_| Vec::new(),
Expand All @@ -1014,36 +1012,53 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
}

let mpl = paths.iter().map(|path| path.display().to_string().len()).max().unwrap();
let start = Instant::now();

for path in paths {
let file = File::open(&path)?;
let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len();
let mut reader = BufReader::new(file);
let mut count = 0u64;
let mut filtered = 0u64;
loop {
match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) {
Ok(game) => {
count += game.len() as u64;
filtered += game.filter_pass_count(|mv, eval, board| !should_filter(mv, eval, board));
move_buffer = game.into_move_buffer();
}
Err(error) => {
match error.kind() {
std::io::ErrorKind::UnexpectedEof => {}
_ => eprintln!("[WARN] dataset_count encountered an unexpected error wile reading {file}: {error}\n[WARN] this occured at an offset of {:?} into the file (but probably earlier than this, as we use buffered IO)\n[WARN] for reference, {file} is {} bytes long.", reader.into_inner().stream_position(), len, file = path.file_name().map_or(Cow::Borrowed("<???>"), |oss| oss.to_string_lossy()))
let stdout_lock = Mutex::new(());
let stdout_lock = &stdout_lock;

let (total_count, filtered_count) = std::thread::scope(|s| -> anyhow::Result<(u64, u64)> {
let mut thread_handles = Vec::new();
for path in paths {
thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64)> {
let file = File::open(&path)?;
let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len();
let mut reader = BufReader::new(file);
let mut count = 0u64;
let mut filtered = 0u64;
let mut move_buffer = Vec::new();
loop {
match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) {
Ok(game) => {
count += game.len() as u64;
filtered += game.filter_pass_count(|mv, eval, board| !should_filter(mv, eval, board));
move_buffer = game.into_move_buffer();
}
Err(error) => {
match error.kind() {
std::io::ErrorKind::UnexpectedEof => {}
_ => eprintln!("[WARN] dataset_count encountered an unexpected error wile reading {file}: {error}\n[WARN] this occured at an offset of {:?} into the file (but probably earlier than this, as we use buffered IO)\n[WARN] for reference, {file} is {} bytes long.", reader.into_inner().stream_position(), len, file = path.file_name().map_or(Cow::Borrowed("<???>"), |oss| oss.to_string_lossy()))
}
break;
}
}
break;
}
}
let lock = stdout_lock.lock().map_err(|_| anyhow!("Failed to lock mutex."))?;
println!("{:mpl$}: {} | {}", path.display(), count, filtered);
std::mem::drop(lock);
Ok((count, filtered))
}));
}
total_count += count;
filtered_count += filtered;
println!("{:mpl$}: {} | {}", path.display(), count, filtered);
print!(" {} pos/s\r", u128::from(total_count) * 1000 / start.elapsed().as_millis());
std::io::stdout().flush()?;
}
let (mut total_count, mut filtered_count) = (0, 0);
for handle in thread_handles {
let (count, filtered) = handle
.join()
.map_err(|_| anyhow!("Thread panicked."))
.with_context(|| "Failed to join processing thread")?
.with_context(|| "A processing job failed")?;
total_count += count;
filtered_count += filtered;
}
Ok((total_count, filtered_count))
})?;

println!();
println!("Total: {total_count}");
Expand Down

0 comments on commit caf478c

Please sign in to comment.