From caf478c9a3057ee925421853e7ef947a159816be Mon Sep 17 00:00:00 2001 From: Cosmo Bobak <56003038+cosmobobak@users.noreply.github.com> Date: Wed, 24 Jul 2024 16:35:13 +0100 Subject: [PATCH] Multithreaded binpack scanning (#171) Bench: 14655488 --- src/datagen.rs | 83 +++++++++++++++++++++++++++++--------------------- 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/src/datagen.rs b/src/datagen.rs index c4e14886..a012247e 100644 --- a/src/datagen.rs +++ b/src/datagen.rs @@ -12,11 +12,14 @@ use std::{ io::{BufReader, BufWriter, Seek, Write}, path::{Path, PathBuf}, str::FromStr, - sync::atomic::{AtomicBool, AtomicU64, Ordering}, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Mutex, + }, time::Instant, }; -use anyhow::{bail, Context}; +use anyhow::{anyhow, bail, Context}; use bulletformat::ChessBoard; use rand::Rng; @@ -983,11 +986,6 @@ pub fn dataset_stats(dataset_path: &Path) -> anyhow::Result<()> { /// Scans one or more variable-length game format files and prints the position counts. pub fn dataset_count(path: &Path) -> anyhow::Result<()> { - let mut move_buffer = Vec::new(); - - let mut total_count = 0u64; - let mut filtered_count = 0u64; - let paths = if path.is_dir() { fs::read_dir(path).map_or_else( |_| Vec::new(), @@ -1014,36 +1012,53 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> { } let mpl = paths.iter().map(|path| path.display().to_string().len()).max().unwrap(); - let start = Instant::now(); - - for path in paths { - let file = File::open(&path)?; - let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len(); - let mut reader = BufReader::new(file); - let mut count = 0u64; - let mut filtered = 0u64; - loop { - match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) { - Ok(game) => { - count += game.len() as u64; - filtered += game.filter_pass_count(|mv, eval, board| !should_filter(mv, eval, board)); - move_buffer = game.into_move_buffer(); - } - Err(error) => { - match error.kind() { - std::io::ErrorKind::UnexpectedEof => {} - _ => eprintln!("[WARN] dataset_count encountered an unexpected error wile reading {file}: {error}\n[WARN] this occured at an offset of {:?} into the file (but probably earlier than this, as we use buffered IO)\n[WARN] for reference, {file} is {} bytes long.", reader.into_inner().stream_position(), len, file = path.file_name().map_or(Cow::Borrowed(""), |oss| oss.to_string_lossy())) + let stdout_lock = Mutex::new(()); + let stdout_lock = &stdout_lock; + + let (total_count, filtered_count) = std::thread::scope(|s| -> anyhow::Result<(u64, u64)> { + let mut thread_handles = Vec::new(); + for path in paths { + thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64)> { + let file = File::open(&path)?; + let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len(); + let mut reader = BufReader::new(file); + let mut count = 0u64; + let mut filtered = 0u64; + let mut move_buffer = Vec::new(); + loop { + match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) { + Ok(game) => { + count += game.len() as u64; + filtered += game.filter_pass_count(|mv, eval, board| !should_filter(mv, eval, board)); + move_buffer = game.into_move_buffer(); + } + Err(error) => { + match error.kind() { + std::io::ErrorKind::UnexpectedEof => {} + _ => eprintln!("[WARN] dataset_count encountered an unexpected error wile reading {file}: {error}\n[WARN] this occured at an offset of {:?} into the file (but probably earlier than this, as we use buffered IO)\n[WARN] for reference, {file} is {} bytes long.", reader.into_inner().stream_position(), len, file = path.file_name().map_or(Cow::Borrowed(""), |oss| oss.to_string_lossy())) + } + break; + } } - break; } - } + let lock = stdout_lock.lock().map_err(|_| anyhow!("Failed to lock mutex."))?; + println!("{:mpl$}: {} | {}", path.display(), count, filtered); + std::mem::drop(lock); + Ok((count, filtered)) + })); } - total_count += count; - filtered_count += filtered; - println!("{:mpl$}: {} | {}", path.display(), count, filtered); - print!(" {} pos/s\r", u128::from(total_count) * 1000 / start.elapsed().as_millis()); - std::io::stdout().flush()?; - } + let (mut total_count, mut filtered_count) = (0, 0); + for handle in thread_handles { + let (count, filtered) = handle + .join() + .map_err(|_| anyhow!("Thread panicked.")) + .with_context(|| "Failed to join processing thread")? + .with_context(|| "A processing job failed")?; + total_count += count; + filtered_count += filtered; + } + Ok((total_count, filtered_count)) + })?; println!(); println!("Total: {total_count}");