Skip to content

Commit

Permalink
Add binpack subsampling (#187)
Browse files Browse the repository at this point in the history
Bench: 13292315
  • Loading branch information
cosmobobak authored Aug 19, 2024
1 parent a431907 commit 55420a9
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 45 deletions.
72 changes: 38 additions & 34 deletions src/datagen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use crate::{
evaluation::{is_game_theoretic_score, is_mate_score},
Board, GameOutcome,
},
chessmove::Move,
datagen::dataformat::Game,
nnue::network::NNUEParams,
piece::{Colour, PieceType},
Expand Down Expand Up @@ -652,17 +651,8 @@ pub fn run_splat(
bail!("Output file already exists.");
}

let filter_state = if let Some(path) = cfg_path {
let text =
std::fs::read_to_string(path).with_context(|| format!("Failed to read filter config file at {path:?}"))?;
toml::from_str(&text).with_context(|| {
let default = toml::to_string_pretty(&Filter::default()).unwrap();
format!("Failed to parse filter config file at {path:?} \nNote: the config file must be in TOML format. The default config looks like this: \n```\n{default}```")
})?
} else {
Filter::default()
};
let filter_fn = |mv: Move, eval: i32, board: &Board, wdl: WDL| !filter_state.should_filter(mv, eval, board, wdl);
let filter = cfg_path.map_or_else(|| Ok(Filter::default()), Filter::from_path)?;
let mut rng = rand::thread_rng();

// open the input file
let input_file = File::open(input).with_context(|| "Failed to create input file")?;
Expand All @@ -684,7 +674,8 @@ pub fn run_splat(
.write_all(&packed_board.as_bytes())
.with_context(|| "Failed to write PackedBoard into buffered writer.")
},
filter_fn,
&filter,
&mut rng,
)?;
} else {
game.splat_to_bulletformat(
Expand All @@ -695,7 +686,8 @@ pub fn run_splat(
.write_all(&bytes)
.with_context(|| "Failed to write bulletformat::ChessBoard into buffered writer.")
},
filter_fn,
&filter,
&mut rng,
)?;
}
move_buffer = game.into_move_buffer();
Expand Down Expand Up @@ -1006,22 +998,26 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
let stdout_lock = Mutex::new(());
let stdout_lock = &stdout_lock;

let filter_state = Filter::default();
let (total_count, filtered_count) = std::thread::scope(|s| -> anyhow::Result<(u64, u64)> {
let mut thread_handles = Vec::new();
for path in paths {
thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64)> {
let filter = &Filter::default();
let (total_count, filtered_count, pass_count_buckets) = std::thread::scope(
|s| -> anyhow::Result<(u64, u64, Vec<u64>)> {
let mut thread_handles = Vec::new();
for path in paths {
thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64, Vec<u64>)> {
let file = File::open(&path)?;
let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len();
let mut reader = BufReader::new(file);
let mut count = 0u64;
let mut filtered = 0u64;
let mut pass_count_buckets = vec![0u64; Game::MAX_SPLATTABLE_GAME_SIZE];
let mut move_buffer = Vec::new();
loop {
match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) {
Ok(game) => {
count += game.len() as u64;
filtered += game.filter_pass_count(|mv, eval, board, wdl| !filter_state.should_filter(mv, eval, board, wdl));
let pass_count = game.filter_pass_count(filter);
filtered += pass_count;
pass_count_buckets[usize::try_from(pass_count).unwrap().min(Game::MAX_SPLATTABLE_GAME_SIZE - 1)] += 1;
move_buffer = game.into_move_buffer();
}
Err(error) => {
Expand All @@ -1036,25 +1032,33 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
let lock = stdout_lock.lock().map_err(|_| anyhow!("Failed to lock mutex."))?;
println!("{:mpl$}: {} | {}", path.display(), count, filtered);
std::mem::drop(lock);
Ok((count, filtered))
Ok((count, filtered, pass_count_buckets))
}));
}
let (mut total_count, mut filtered_count) = (0, 0);
for handle in thread_handles {
let (count, filtered) = handle
.join()
.map_err(|_| anyhow!("Thread panicked."))
.with_context(|| "Failed to join processing thread")?
.with_context(|| "A processing job failed")?;
total_count += count;
filtered_count += filtered;
}
Ok((total_count, filtered_count))
})?;
}
let (mut total_count, mut filtered_count) = (0, 0);
let mut total_pass_count_buckets = vec![0u64; Game::MAX_SPLATTABLE_GAME_SIZE];
for handle in thread_handles {
let (count, filtered, pass_count_buckets) = handle
.join()
.map_err(|_| anyhow!("Thread panicked."))
.with_context(|| "Failed to join processing thread")?
.with_context(|| "A processing job failed")?;
total_count += count;
filtered_count += filtered;
for (i, count) in pass_count_buckets.into_iter().enumerate() {
total_pass_count_buckets[i] += count;
}
}
Ok((total_count, filtered_count, total_pass_count_buckets))
},
)?;

println!();
println!("Total: {total_count}");
println!("Total that pass the filter: {filtered_count}");
for (i, c) in pass_count_buckets.chunks(16).enumerate() {
println!("Games with {:3} to {:3} filtered positions: {}", i * 16, i * 16 + 15, c.iter().sum::<u64>());
}

Ok(())
}
86 changes: 75 additions & 11 deletions src/datagen/dataformat.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::Path;

use crate::{
board::{evaluation::MINIMUM_TB_WIN_SCORE, Board, GameOutcome},
chessmove::Move,
Expand All @@ -7,12 +9,14 @@ use crate::{

use self::marlinformat::{util::I16Le, PackedBoard};
use anyhow::{anyhow, Context};
use arrayvec::ArrayVec;
use rand::prelude::SliceRandom;
use serde::{Deserialize, Serialize};

mod marlinformat;

/// The configuration for a filter that can be applied to a game during unpacking.
#[derive(Clone, Copy, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[allow(clippy::struct_field_names)]
#[serde(default)]
pub struct Filter {
Expand All @@ -30,6 +34,8 @@ pub struct Filter {
filter_castling: bool,
/// Filter out positions where eval diverges from WDL by more than this value.
max_eval_incorrectness: u32,
/// Take this many positions per game.
sample_size: usize,
}

impl Default for Filter {
Expand All @@ -42,11 +48,23 @@ impl Default for Filter {
filter_check: true,
filter_castling: false,
max_eval_incorrectness: u32::MAX,
sample_size: usize::MAX,
}
}
}

impl Filter {
/// The most permissive filter configuration: both minimums are zero, both maximums are
/// saturated at `u32::MAX`, every boolean filter is switched off, and `sample_size` is
/// `usize::MAX` so the per-game sample is never capped below the number of passing positions.
///
/// Unlike `Filter::default()`, this also disables the tactical and in-check filters.
/// It is used by the unit test in this module that expects every position to be emitted.
const UNRESTRICTED: Self = Self {
min_ply: 0,
min_pieces: 0,
max_eval: u32::MAX,
filter_tactical: false,
filter_check: false,
filter_castling: false,
max_eval_incorrectness: u32::MAX,
sample_size: usize::MAX,
};

pub fn should_filter(&self, mv: Move, eval: i32, board: &Board, wdl: WDL) -> bool {
if board.ply() < self.min_ply {
return true;
Expand Down Expand Up @@ -87,6 +105,15 @@ impl Filter {
}
false
}

/// Loads a filter configuration from the TOML file at `path`.
///
/// # Errors
///
/// Returns an error if the file cannot be read, or if its contents cannot be parsed as a
/// `Filter`. The parse error is annotated with a rendering of the default configuration
/// to show the expected shape of the file.
pub fn from_path(path: &Path) -> Result<Self, anyhow::Error> {
    let text =
        std::fs::read_to_string(path).with_context(|| format!("Failed to read filter config file at {path:?}"))?;
    toml::from_str(&text).with_context(|| {
        // Rendering the default config is best-effort: we are already on the error path, and
        // a serialisation failure here (e.g. an integer field that does not fit TOML's signed
        // 64-bit range) must not turn a reportable parse error into a panic.
        let default = toml::to_string_pretty(&Self::default())
            .unwrap_or_else(|err| format!("<default config unavailable: {err}>\n"));
        format!("Failed to parse filter config file at {path:?} \nNote: the config file must be in TOML format. The default config looks like this: \n```\n{default}```")
    })
}
}

/// A game annotated with evaluations starting from a potentially custom position, with support for efficient binary serialisation and deserialisation.
Expand All @@ -112,6 +139,8 @@ impl WDL {
}

impl Game {
/// Upper bound on the number of moves a game may contain and still be splatted.
///
/// The splat functions use this as the fixed capacity of their per-game sample buffer and
/// return without emitting anything for games with more moves than this. `dataset_count`
/// also uses it as the number of pass-count histogram buckets.
pub const MAX_SPLATTABLE_GAME_SIZE: usize = 512;

/// Creates a game record with an empty move list, rooted at `initial_position`.
///
/// The starting board is stored in packed form, with zero passed for each of the three
/// `pack` arguments.
pub fn new(initial_position: &Board) -> Self {
    let packed_start = initial_position.pack(0, 0, 0);
    Self { initial_position: packed_start, moves: Vec::new() }
}
Expand Down Expand Up @@ -217,13 +246,13 @@ impl Game {
}

/// Internally counts how many positions would pass the filter in this game.
pub fn filter_pass_count(&self, filter: impl Fn(Move, i32, &Board, WDL) -> bool) -> u64 {
pub fn filter_pass_count(&self, filter: &Filter) -> u64 {
let mut cnt = 0;
let (mut board, _, wdl, _) = self.initial_position.unpack();
let outcome = WDL::from_packed(wdl);
for (mv, eval) in &self.moves {
let eval = eval.get();
if filter(*mv, i32::from(eval), &board, outcome) {
if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
cnt += 1;
}
board.make_move_simple(*mv);
Expand All @@ -236,32 +265,57 @@ impl Game {
pub fn splat_to_marlinformat(
&self,
// Receives each selected position; the first error it returns aborts the splat and is propagated.
mut callback: impl FnMut(marlinformat::PackedBoard) -> anyhow::Result<()>,
// NOTE(review): this span shows the closure-typed `filter` parameter (pre-change) directly
// followed by the `&Filter` parameter that replaces it — confirm only the latter is kept.
filter: impl Fn(Move, i32, &Board, WDL) -> bool,
filter: &Filter,
// Source of randomness used to pick which passing positions are emitted.
rng: &mut impl rand::Rng,
) -> anyhow::Result<()> {
// Games longer than the sample buffer's fixed capacity are skipped entirely (nothing is
// emitted and no error is reported) rather than being truncated.
if self.moves.len() > Self::MAX_SPLATTABLE_GAME_SIZE {
return Ok(());
}

// Fixed-capacity buffer of candidate positions; the length check above guarantees that
// one push per move cannot exceed its capacity.
let mut sample_buffer = ArrayVec::<marlinformat::PackedBoard, { Self::MAX_SPLATTABLE_GAME_SIZE }>::new();
let (mut board, _, wdl, _) = self.initial_position.unpack();
let outcome = WDL::from_packed(wdl);

// Walk the game, buffering the position *before* each move is played (paired with that
// move's eval and the game's WDL) whenever the filter does not reject it.
for (mv, eval) in &self.moves {
let eval = eval.get();
// NOTE(review): the next two lines are the pre-change form (closure filter, immediate
// callback); the `should_filter` check and buffer push below are their replacement.
if filter(*mv, i32::from(eval), &board, outcome) {
callback(board.pack(eval, wdl, 0))?;
if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
sample_buffer.push(board.pack(eval, wdl, 0));
}
board.make_move_simple(*mv);
}

// Emit at most `filter.sample_size` of the buffered positions. `partial_shuffle` picks
// them at random, so they reach `callback` in shuffled order rather than game order.
let samples_to_take = filter.sample_size.min(sample_buffer.len());
let (selected, _) = sample_buffer.partial_shuffle(rng, samples_to_take);
for board in selected {
callback(*board)?;
}

Ok(())
}

/// Converts the game into a sequence of bulletformat `ChessBoard` objects, yielding only those positions that pass the filter.
pub fn splat_to_bulletformat(
&self,
mut callback: impl FnMut(bulletformat::ChessBoard) -> anyhow::Result<()>,
filter: impl Fn(Move, i32, &Board, WDL) -> bool,
filter: &Filter,
rng: &mut impl rand::Rng,
) -> anyhow::Result<()> {
// we don't allow buffers of more than this size.
if self.moves.len() > Self::MAX_SPLATTABLE_GAME_SIZE {
return Ok(());
}

let mut sample_buffer = ArrayVec::<bulletformat::ChessBoard, { Self::MAX_SPLATTABLE_GAME_SIZE }>::new();
let (mut board, _, wdl, _) = self.initial_position.unpack();
let outcome = WDL::from_packed(wdl);

// record all the positions that pass the filter.
for (mv, eval) in &self.moves {
let eval = eval.get();
if filter(*mv, i32::from(eval), &board, outcome) {
if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
let mut bbs = [0; 8];
let piece_layout = &board.pieces;
bbs[0] = piece_layout.occupied_co(Colour::White).inner();
Expand All @@ -272,7 +326,7 @@ impl Game {
bbs[5] = piece_layout.of_type(PieceType::Rook).inner();
bbs[6] = piece_layout.of_type(PieceType::Queen).inner();
bbs[7] = piece_layout.of_type(PieceType::King).inner();
callback(
sample_buffer.push(
bulletformat::ChessBoard::from_raw(
bbs,
(board.turn() != Colour::White).into(),
Expand All @@ -281,11 +335,18 @@ impl Game {
)
.map_err(|e| anyhow!(e))
.with_context(|| "Failed to convert raw components into bulletformat::ChessBoard.")?,
)?;
);
}
board.make_move_simple(*mv);
}

// sample down to the requested number of positions.
let samples_to_take = filter.sample_size.min(sample_buffer.len());
let (selected, _) = sample_buffer.partial_shuffle(rng, samples_to_take);
for board in selected {
callback(*board)?;
}

Ok(())
}

Expand Down Expand Up @@ -339,12 +400,15 @@ mod tests {
game.add_move(Move::new(Square::G1, Square::F3), 200);

let mut boards = Vec::new();
let filter = Filter::UNRESTRICTED;
let mut rng = rand::thread_rng();
game.splat_to_marlinformat(
|board| {
boards.push(board);
Ok(())
},
|_, _, _, _| true,
&filter,
&mut rng,
)
.unwrap();
assert_eq!(boards.len(), 3);
Expand Down

0 comments on commit 55420a9

Please sign in to comment.