From 55420a9e3004605d65416f9dceb804669a489569 Mon Sep 17 00:00:00 2001
From: Cosmo Bobak <56003038+cosmobobak@users.noreply.github.com>
Date: Mon, 19 Aug 2024 16:13:55 +0100
Subject: [PATCH] Add binpack subsampling (#187)

Bench: 13292315
---
Notes (ignored by `git am`, not part of the commit message):
- `Filter` gains a `sample_size` field (default `usize::MAX`) and a
  `Filter::from_path` constructor that parses a TOML config file.
- `Game::splat_to_marlinformat` / `Game::splat_to_bulletformat` now take
  `&Filter` and an RNG, buffer every position that passes the filter, and
  emit a random subset of at most `sample_size` of them.
- Both splat functions return `Ok(())` without emitting anything for games
  with more than `Game::MAX_SPLATTABLE_GAME_SIZE` (512) moves.
- `dataset_count` additionally prints a histogram (buckets of 16) of the
  per-game count of positions that pass the default filter.

 src/datagen.rs            | 72 ++++++++++++++++----------------
 src/datagen/dataformat.rs | 86 ++++++++++++++++++++++++++++++++++-----
 2 files changed, 113 insertions(+), 45 deletions(-)

diff --git a/src/datagen.rs b/src/datagen.rs
index 74815ea5..2c847beb 100644
--- a/src/datagen.rs
+++ b/src/datagen.rs
@@ -29,7 +29,6 @@ use crate::{
         evaluation::{is_game_theoretic_score, is_mate_score},
         Board, GameOutcome,
     },
-    chessmove::Move,
     datagen::dataformat::Game,
     nnue::network::NNUEParams,
     piece::{Colour, PieceType},
@@ -652,17 +651,8 @@ pub fn run_splat(
         bail!("Output file already exists.");
     }
 
-    let filter_state = if let Some(path) = cfg_path {
-        let text =
-            std::fs::read_to_string(path).with_context(|| format!("Failed to read filter config file at {path:?}"))?;
-        toml::from_str(&text).with_context(|| {
-            let default = toml::to_string_pretty(&Filter::default()).unwrap();
-            format!("Failed to parse filter config file at {path:?} \nNote: the config file must be in TOML format. The default config looks like this: \n```\n{default}```")
-        })?
-    } else {
-        Filter::default()
-    };
-    let filter_fn = |mv: Move, eval: i32, board: &Board, wdl: WDL| !filter_state.should_filter(mv, eval, board, wdl);
+    let filter = cfg_path.map_or_else(|| Ok(Filter::default()), Filter::from_path)?;
+    let mut rng = rand::thread_rng();
 
     // open the input file
     let input_file = File::open(input).with_context(|| "Failed to create input file")?;
@@ -684,7 +674,8 @@ pub fn run_splat(
                         .write_all(&packed_board.as_bytes())
                         .with_context(|| "Failed to write PackedBoard into buffered writer.")
                 },
-                filter_fn,
+                &filter,
+                &mut rng,
             )?;
         } else {
             game.splat_to_bulletformat(
@@ -695,7 +686,8 @@ pub fn run_splat(
                         .write_all(&bytes)
                         .with_context(|| "Failed to write bulletformat::ChessBoard into buffered writer.")
                 },
-                filter_fn,
+                &filter,
+                &mut rng,
             )?;
         }
         move_buffer = game.into_move_buffer();
@@ -1006,22 +998,26 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
     let stdout_lock = Mutex::new(());
     let stdout_lock = &stdout_lock;
 
-    let filter_state = Filter::default();
-    let (total_count, filtered_count) = std::thread::scope(|s| -> anyhow::Result<(u64, u64)> {
-        let mut thread_handles = Vec::new();
-        for path in paths {
-            thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64)> {
+    let filter = &Filter::default();
+    let (total_count, filtered_count, pass_count_buckets) = std::thread::scope(
+        |s| -> anyhow::Result<(u64, u64, Vec<u64>)> {
+            let mut thread_handles = Vec::new();
+            for path in paths {
+                thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64, Vec<u64>)> {
                 let file = File::open(&path)?;
                 let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len();
                 let mut reader = BufReader::new(file);
                 let mut count = 0u64;
                 let mut filtered = 0u64;
+                let mut pass_count_buckets = vec![0u64; Game::MAX_SPLATTABLE_GAME_SIZE];
                 let mut move_buffer = Vec::new();
                 loop {
                     match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) {
                         Ok(game) => {
                             count += game.len() as u64;
-                            filtered += game.filter_pass_count(|mv, eval, board, wdl| !filter_state.should_filter(mv, eval, board, wdl));
+                            let pass_count = game.filter_pass_count(filter);
+                            filtered += pass_count;
+                            pass_count_buckets[usize::try_from(pass_count).unwrap().min(Game::MAX_SPLATTABLE_GAME_SIZE - 1)] += 1;
                             move_buffer = game.into_move_buffer();
                         }
                         Err(error) => {
@@ -1036,25 +1032,33 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
                 let lock = stdout_lock.lock().map_err(|_| anyhow!("Failed to lock mutex."))?;
                 println!("{:mpl$}: {} | {}", path.display(), count, filtered);
                 std::mem::drop(lock);
-                Ok((count, filtered))
+                Ok((count, filtered, pass_count_buckets))
             }));
-        }
-        let (mut total_count, mut filtered_count) = (0, 0);
-        for handle in thread_handles {
-            let (count, filtered) = handle
-                .join()
-                .map_err(|_| anyhow!("Thread panicked."))
-                .with_context(|| "Failed to join processing thread")?
-                .with_context(|| "A processing job failed")?;
-            total_count += count;
-            filtered_count += filtered;
-        }
-        Ok((total_count, filtered_count))
-    })?;
+            }
+            let (mut total_count, mut filtered_count) = (0, 0);
+            let mut total_pass_count_buckets = vec![0u64; Game::MAX_SPLATTABLE_GAME_SIZE];
+            for handle in thread_handles {
+                let (count, filtered, pass_count_buckets) = handle
+                    .join()
+                    .map_err(|_| anyhow!("Thread panicked."))
+                    .with_context(|| "Failed to join processing thread")?
+                    .with_context(|| "A processing job failed")?;
+                total_count += count;
+                filtered_count += filtered;
+                for (i, count) in pass_count_buckets.into_iter().enumerate() {
+                    total_pass_count_buckets[i] += count;
+                }
+            }
+            Ok((total_count, filtered_count, total_pass_count_buckets))
+        },
+    )?;
 
     println!();
     println!("Total: {total_count}");
     println!("Total that pass the filter: {filtered_count}");
+    for (i, c) in pass_count_buckets.chunks(16).enumerate() {
+        println!("Games with {:3} to {:3} filtered positions: {}", i * 16, i * 16 + 15, c.iter().sum::<u64>());
+    }
 
     Ok(())
 }
diff --git a/src/datagen/dataformat.rs b/src/datagen/dataformat.rs
index c4666f47..fd11da28 100644
--- a/src/datagen/dataformat.rs
+++ b/src/datagen/dataformat.rs
@@ -1,3 +1,5 @@
+use std::path::Path;
+
 use crate::{
     board::{evaluation::MINIMUM_TB_WIN_SCORE, Board, GameOutcome},
     chessmove::Move,
@@ -7,12 +9,14 @@ use crate::{
 
 use self::marlinformat::{util::I16Le, PackedBoard};
 use anyhow::{anyhow, Context};
+use arrayvec::ArrayVec;
+use rand::prelude::SliceRandom;
 use serde::{Deserialize, Serialize};
 
 mod marlinformat;
 
 /// The configuration for a filter that can be applied to a game during unpacking.
-#[derive(Clone, Copy, Debug, Hash, Serialize, Deserialize)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
 #[allow(clippy::struct_field_names)]
 #[serde(default)]
 pub struct Filter {
@@ -30,6 +34,8 @@ pub struct Filter {
     filter_castling: bool,
     /// Filter out positions where eval diverges from WDL by more than this value.
     max_eval_incorrectness: u32,
+    /// Take this many positions per game.
+    sample_size: usize,
 }
 
 impl Default for Filter {
@@ -42,11 +48,23 @@ impl Default for Filter {
             filter_check: true,
             filter_castling: false,
             max_eval_incorrectness: u32::MAX,
+            sample_size: usize::MAX,
         }
     }
 }
 
 impl Filter {
+    const UNRESTRICTED: Self = Self {
+        min_ply: 0,
+        min_pieces: 0,
+        max_eval: u32::MAX,
+        filter_tactical: false,
+        filter_check: false,
+        filter_castling: false,
+        max_eval_incorrectness: u32::MAX,
+        sample_size: usize::MAX,
+    };
+
     pub fn should_filter(&self, mv: Move, eval: i32, board: &Board, wdl: WDL) -> bool {
         if board.ply() < self.min_ply {
             return true;
@@ -87,6 +105,15 @@ impl Filter {
         }
         false
     }
+
+    pub fn from_path(path: &Path) -> Result<Self, anyhow::Error> {
+        let text =
+            std::fs::read_to_string(path).with_context(|| format!("Failed to read filter config file at {path:?}"))?;
+        toml::from_str(&text).with_context(|| {
+            let default = toml::to_string_pretty(&Self::default()).unwrap();
+            format!("Failed to parse filter config file at {path:?} \nNote: the config file must be in TOML format. The default config looks like this: \n```\n{default}```")
+        })
+    }
 }
 
 /// A game annotated with evaluations starting from a potentially custom position, with support for efficent binary serialisation and deserialisation.
@@ -112,6 +139,8 @@ impl WDL {
 }
 
 impl Game {
+    pub const MAX_SPLATTABLE_GAME_SIZE: usize = 512;
+
     pub fn new(initial_position: &Board) -> Self {
         Self { initial_position: initial_position.pack(0, 0, 0), moves: Vec::new() }
     }
@@ -217,13 +246,13 @@ impl Game {
     }
 
     /// Internally counts how many positions would pass the filter in this game.
-    pub fn filter_pass_count(&self, filter: impl Fn(Move, i32, &Board, WDL) -> bool) -> u64 {
+    pub fn filter_pass_count(&self, filter: &Filter) -> u64 {
         let mut cnt = 0;
         let (mut board, _, wdl, _) = self.initial_position.unpack();
         let outcome = WDL::from_packed(wdl);
         for (mv, eval) in &self.moves {
             let eval = eval.get();
-            if filter(*mv, i32::from(eval), &board, outcome) {
+            if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
                 cnt += 1;
             }
             board.make_move_simple(*mv);
@@ -236,18 +265,34 @@ impl Game {
     pub fn splat_to_marlinformat(
         &self,
         mut callback: impl FnMut(marlinformat::PackedBoard) -> anyhow::Result<()>,
-        filter: impl Fn(Move, i32, &Board, WDL) -> bool,
+        filter: &Filter,
+        rng: &mut impl rand::Rng,
     ) -> anyhow::Result<()> {
+        // we don't allow buffers of more than this size.
+        if self.moves.len() > Self::MAX_SPLATTABLE_GAME_SIZE {
+            return Ok(());
+        }
+
+        let mut sample_buffer = ArrayVec::<_, { Self::MAX_SPLATTABLE_GAME_SIZE }>::new();
         let (mut board, _, wdl, _) = self.initial_position.unpack();
         let outcome = WDL::from_packed(wdl);
+
+        // record all the positions that pass the filter.
         for (mv, eval) in &self.moves {
             let eval = eval.get();
-            if filter(*mv, i32::from(eval), &board, outcome) {
-                callback(board.pack(eval, wdl, 0))?;
+            if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
+                sample_buffer.push(board.pack(eval, wdl, 0));
             }
             board.make_move_simple(*mv);
         }
 
+        // sample down to the requested number of positions.
+        let samples_to_take = filter.sample_size.min(sample_buffer.len());
+        let (selected, _) = sample_buffer.partial_shuffle(rng, samples_to_take);
+        for board in selected {
+            callback(*board)?;
+        }
+
         Ok(())
     }
 
@@ -255,13 +300,22 @@ impl Game {
     pub fn splat_to_bulletformat(
         &self,
         mut callback: impl FnMut(bulletformat::ChessBoard) -> anyhow::Result<()>,
-        filter: impl Fn(Move, i32, &Board, WDL) -> bool,
+        filter: &Filter,
+        rng: &mut impl rand::Rng,
     ) -> anyhow::Result<()> {
+        // we don't allow buffers of more than this size.
+        if self.moves.len() > Self::MAX_SPLATTABLE_GAME_SIZE {
+            return Ok(());
+        }
+
+        let mut sample_buffer = ArrayVec::<_, { Self::MAX_SPLATTABLE_GAME_SIZE }>::new();
         let (mut board, _, wdl, _) = self.initial_position.unpack();
         let outcome = WDL::from_packed(wdl);
+
+        // record all the positions that pass the filter.
         for (mv, eval) in &self.moves {
             let eval = eval.get();
-            if filter(*mv, i32::from(eval), &board, outcome) {
+            if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
                 let mut bbs = [0; 8];
                 let piece_layout = &board.pieces;
                 bbs[0] = piece_layout.occupied_co(Colour::White).inner();
@@ -272,7 +326,7 @@ impl Game {
                 bbs[5] = piece_layout.of_type(PieceType::Rook).inner();
                 bbs[6] = piece_layout.of_type(PieceType::Queen).inner();
                 bbs[7] = piece_layout.of_type(PieceType::King).inner();
-                callback(
+                sample_buffer.push(
                     bulletformat::ChessBoard::from_raw(
                         bbs,
                         (board.turn() != Colour::White).into(),
@@ -281,11 +335,18 @@ impl Game {
                     )
                     .map_err(|e| anyhow!(e))
                     .with_context(|| "Failed to convert raw components into bulletformat::ChessBoard.")?,
-                )?;
+                );
             }
             board.make_move_simple(*mv);
         }
 
+        // sample down to the requested number of positions.
+        let samples_to_take = filter.sample_size.min(sample_buffer.len());
+        let (selected, _) = sample_buffer.partial_shuffle(rng, samples_to_take);
+        for board in selected {
+            callback(*board)?;
+        }
+
         Ok(())
     }
 
@@ -339,12 +400,15 @@ mod tests {
         game.add_move(Move::new(Square::G1, Square::F3), 200);
 
         let mut boards = Vec::new();
+        let filter = Filter::UNRESTRICTED;
+        let mut rng = rand::thread_rng();
         game.splat_to_marlinformat(
             |board| {
                 boards.push(board);
                 Ok(())
             },
-            |_, _, _, _| true,
+            &filter,
+            &mut rng,
         )
         .unwrap();
         assert_eq!(boards.len(), 3);