Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add binpack subsampling. #187

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 38 additions & 34 deletions src/datagen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ use crate::{
evaluation::{is_game_theoretic_score, is_mate_score},
Board, GameOutcome,
},
chessmove::Move,
datagen::dataformat::Game,
nnue::network::NNUEParams,
piece::{Colour, PieceType},
Expand Down Expand Up @@ -652,17 +651,8 @@ pub fn run_splat(
bail!("Output file already exists.");
}

let filter_state = if let Some(path) = cfg_path {
let text =
std::fs::read_to_string(path).with_context(|| format!("Failed to read filter config file at {path:?}"))?;
toml::from_str(&text).with_context(|| {
let default = toml::to_string_pretty(&Filter::default()).unwrap();
format!("Failed to parse filter config file at {path:?} \nNote: the config file must be in TOML format. The default config looks like this: \n```\n{default}```")
})?
} else {
Filter::default()
};
let filter_fn = |mv: Move, eval: i32, board: &Board, wdl: WDL| !filter_state.should_filter(mv, eval, board, wdl);
let filter = cfg_path.map_or_else(|| Ok(Filter::default()), Filter::from_path)?;
let mut rng = rand::thread_rng();

// open the input file
let input_file = File::open(input).with_context(|| "Failed to create input file")?;
Expand All @@ -684,7 +674,8 @@ pub fn run_splat(
.write_all(&packed_board.as_bytes())
.with_context(|| "Failed to write PackedBoard into buffered writer.")
},
filter_fn,
&filter,
&mut rng,
)?;
} else {
game.splat_to_bulletformat(
Expand All @@ -695,7 +686,8 @@ pub fn run_splat(
.write_all(&bytes)
.with_context(|| "Failed to write bulletformat::ChessBoard into buffered writer.")
},
filter_fn,
&filter,
&mut rng,
)?;
}
move_buffer = game.into_move_buffer();
Expand Down Expand Up @@ -1006,22 +998,26 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
let stdout_lock = Mutex::new(());
let stdout_lock = &stdout_lock;

let filter_state = Filter::default();
let (total_count, filtered_count) = std::thread::scope(|s| -> anyhow::Result<(u64, u64)> {
let mut thread_handles = Vec::new();
for path in paths {
thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64)> {
let filter = &Filter::default();
let (total_count, filtered_count, pass_count_buckets) = std::thread::scope(
|s| -> anyhow::Result<(u64, u64, Vec<u64>)> {
let mut thread_handles = Vec::new();
for path in paths {
thread_handles.push(s.spawn(move || -> anyhow::Result<(u64, u64, Vec<u64>)> {
let file = File::open(&path)?;
let len = file.metadata().with_context(|| "Failed to get file metadata!")?.len();
let mut reader = BufReader::new(file);
let mut count = 0u64;
let mut filtered = 0u64;
let mut pass_count_buckets = vec![0u64; Game::MAX_SPLATTABLE_GAME_SIZE];
let mut move_buffer = Vec::new();
loop {
match dataformat::Game::deserialise_from(&mut reader, std::mem::take(&mut move_buffer)) {
Ok(game) => {
count += game.len() as u64;
filtered += game.filter_pass_count(|mv, eval, board, wdl| !filter_state.should_filter(mv, eval, board, wdl));
let pass_count = game.filter_pass_count(filter);
filtered += pass_count;
pass_count_buckets[usize::try_from(pass_count).unwrap().min(Game::MAX_SPLATTABLE_GAME_SIZE - 1)] += 1;
move_buffer = game.into_move_buffer();
}
Err(error) => {
Expand All @@ -1036,25 +1032,33 @@ pub fn dataset_count(path: &Path) -> anyhow::Result<()> {
let lock = stdout_lock.lock().map_err(|_| anyhow!("Failed to lock mutex."))?;
println!("{:mpl$}: {} | {}", path.display(), count, filtered);
std::mem::drop(lock);
Ok((count, filtered))
Ok((count, filtered, pass_count_buckets))
}));
}
let (mut total_count, mut filtered_count) = (0, 0);
for handle in thread_handles {
let (count, filtered) = handle
.join()
.map_err(|_| anyhow!("Thread panicked."))
.with_context(|| "Failed to join processing thread")?
.with_context(|| "A processing job failed")?;
total_count += count;
filtered_count += filtered;
}
Ok((total_count, filtered_count))
})?;
}
let (mut total_count, mut filtered_count) = (0, 0);
let mut total_pass_count_buckets = vec![0u64; Game::MAX_SPLATTABLE_GAME_SIZE];
for handle in thread_handles {
let (count, filtered, pass_count_buckets) = handle
.join()
.map_err(|_| anyhow!("Thread panicked."))
.with_context(|| "Failed to join processing thread")?
.with_context(|| "A processing job failed")?;
total_count += count;
filtered_count += filtered;
for (i, count) in pass_count_buckets.into_iter().enumerate() {
total_pass_count_buckets[i] += count;
}
}
Ok((total_count, filtered_count, total_pass_count_buckets))
},
)?;

println!();
println!("Total: {total_count}");
println!("Total that pass the filter: {filtered_count}");
for (i, c) in pass_count_buckets.chunks(16).enumerate() {
println!("Games with {:3} to {:3} filtered positions: {}", i * 16, i * 16 + 15, c.iter().sum::<u64>());
}

Ok(())
}
86 changes: 75 additions & 11 deletions src/datagen/dataformat.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::path::Path;

use crate::{
board::{evaluation::MINIMUM_TB_WIN_SCORE, Board, GameOutcome},
chessmove::Move,
Expand All @@ -7,12 +9,14 @@ use crate::{

use self::marlinformat::{util::I16Le, PackedBoard};
use anyhow::{anyhow, Context};
use arrayvec::ArrayVec;
use rand::prelude::SliceRandom;
use serde::{Deserialize, Serialize};

mod marlinformat;

/// The configuration for a filter that can be applied to a game during unpacking.
#[derive(Clone, Copy, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[allow(clippy::struct_field_names)]
#[serde(default)]
pub struct Filter {
Expand All @@ -30,6 +34,8 @@ pub struct Filter {
filter_castling: bool,
/// Filter out positions where eval diverges from WDL by more than this value.
max_eval_incorrectness: u32,
/// Take at most this many positions per game, chosen at random from those that pass the filter.
sample_size: usize,
}

impl Default for Filter {
Expand All @@ -42,11 +48,23 @@ impl Default for Filter {
filter_check: true,
filter_castling: false,
max_eval_incorrectness: u32::MAX,
sample_size: usize::MAX,
}
}
}

impl Filter {
/// A filter configuration with every restriction relaxed: the minimum thresholds are
/// zero, the maximum thresholds are `u32::MAX`, all boolean filters are switched off,
/// and `sample_size` is `usize::MAX` so no per-game sampling cap applies.
// NOTE(review): presumably this makes `should_filter` reject nothing — confirm against its full body.
const UNRESTRICTED: Self = Self {
min_ply: 0,
min_pieces: 0,
max_eval: u32::MAX,
filter_tactical: false,
filter_check: false,
filter_castling: false,
max_eval_incorrectness: u32::MAX,
sample_size: usize::MAX,
};

pub fn should_filter(&self, mv: Move, eval: i32, board: &Board, wdl: WDL) -> bool {
if board.ply() < self.min_ply {
return true;
Expand Down Expand Up @@ -87,6 +105,15 @@ impl Filter {
}
false
}

pub fn from_path(path: &Path) -> Result<Self, anyhow::Error> {
let text =
std::fs::read_to_string(path).with_context(|| format!("Failed to read filter config file at {path:?}"))?;
toml::from_str(&text).with_context(|| {
let default = toml::to_string_pretty(&Self::default()).unwrap();
format!("Failed to parse filter config file at {path:?} \nNote: the config file must be in TOML format. The default config looks like this: \n```\n{default}```")
})
}
}

/// A game annotated with evaluations starting from a potentially custom position, with support for efficient binary serialisation and deserialisation.
Expand All @@ -112,6 +139,8 @@ impl WDL {
}

impl Game {
/// Upper bound on the number of moves in a game that the splat functions will process.
/// Games with more moves than this are skipped without emitting any positions, and the
/// same value is the fixed capacity of the per-game sample buffers.
pub const MAX_SPLATTABLE_GAME_SIZE: usize = 512;

pub fn new(initial_position: &Board) -> Self {
Self { initial_position: initial_position.pack(0, 0, 0), moves: Vec::new() }
}
Expand Down Expand Up @@ -217,13 +246,13 @@ impl Game {
}

/// Internally counts how many positions would pass the filter in this game.
pub fn filter_pass_count(&self, filter: impl Fn(Move, i32, &Board, WDL) -> bool) -> u64 {
pub fn filter_pass_count(&self, filter: &Filter) -> u64 {
let mut cnt = 0;
let (mut board, _, wdl, _) = self.initial_position.unpack();
let outcome = WDL::from_packed(wdl);
for (mv, eval) in &self.moves {
let eval = eval.get();
if filter(*mv, i32::from(eval), &board, outcome) {
if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
cnt += 1;
}
board.make_move_simple(*mv);
Expand All @@ -236,32 +265,57 @@ impl Game {
pub fn splat_to_marlinformat(
&self,
mut callback: impl FnMut(marlinformat::PackedBoard) -> anyhow::Result<()>,
filter: impl Fn(Move, i32, &Board, WDL) -> bool,
filter: &Filter,
rng: &mut impl rand::Rng,
) -> anyhow::Result<()> {
// we don't allow buffers of more than this size.
if self.moves.len() > Self::MAX_SPLATTABLE_GAME_SIZE {
return Ok(());
}

let mut sample_buffer = ArrayVec::<marlinformat::PackedBoard, { Self::MAX_SPLATTABLE_GAME_SIZE }>::new();
let (mut board, _, wdl, _) = self.initial_position.unpack();
let outcome = WDL::from_packed(wdl);

// record all the positions that pass the filter.
for (mv, eval) in &self.moves {
let eval = eval.get();
if filter(*mv, i32::from(eval), &board, outcome) {
callback(board.pack(eval, wdl, 0))?;
if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
sample_buffer.push(board.pack(eval, wdl, 0));
}
board.make_move_simple(*mv);
}

// sample down to the requested number of positions.
let samples_to_take = filter.sample_size.min(sample_buffer.len());
let (selected, _) = sample_buffer.partial_shuffle(rng, samples_to_take);
for board in selected {
callback(*board)?;
}

Ok(())
}

/// Converts the game into a sequence of bulletformat `ChessBoard` objects, yielding only those positions that pass the filter.
pub fn splat_to_bulletformat(
&self,
mut callback: impl FnMut(bulletformat::ChessBoard) -> anyhow::Result<()>,
filter: impl Fn(Move, i32, &Board, WDL) -> bool,
filter: &Filter,
rng: &mut impl rand::Rng,
) -> anyhow::Result<()> {
// we don't allow buffers of more than this size.
if self.moves.len() > Self::MAX_SPLATTABLE_GAME_SIZE {
return Ok(());
}

let mut sample_buffer = ArrayVec::<bulletformat::ChessBoard, { Self::MAX_SPLATTABLE_GAME_SIZE }>::new();
let (mut board, _, wdl, _) = self.initial_position.unpack();
let outcome = WDL::from_packed(wdl);

// record all the positions that pass the filter.
for (mv, eval) in &self.moves {
let eval = eval.get();
if filter(*mv, i32::from(eval), &board, outcome) {
if !filter.should_filter(*mv, i32::from(eval), &board, outcome) {
let mut bbs = [0; 8];
let piece_layout = &board.pieces;
bbs[0] = piece_layout.occupied_co(Colour::White).inner();
Expand All @@ -272,7 +326,7 @@ impl Game {
bbs[5] = piece_layout.of_type(PieceType::Rook).inner();
bbs[6] = piece_layout.of_type(PieceType::Queen).inner();
bbs[7] = piece_layout.of_type(PieceType::King).inner();
callback(
sample_buffer.push(
bulletformat::ChessBoard::from_raw(
bbs,
(board.turn() != Colour::White).into(),
Expand All @@ -281,11 +335,18 @@ impl Game {
)
.map_err(|e| anyhow!(e))
.with_context(|| "Failed to convert raw components into bulletformat::ChessBoard.")?,
)?;
);
}
board.make_move_simple(*mv);
}

// sample down to the requested number of positions.
let samples_to_take = filter.sample_size.min(sample_buffer.len());
let (selected, _) = sample_buffer.partial_shuffle(rng, samples_to_take);
for board in selected {
callback(*board)?;
}

Ok(())
}

Expand Down Expand Up @@ -339,12 +400,15 @@ mod tests {
game.add_move(Move::new(Square::G1, Square::F3), 200);

let mut boards = Vec::new();
let filter = Filter::UNRESTRICTED;
let mut rng = rand::thread_rng();
game.splat_to_marlinformat(
|board| {
boards.push(board);
Ok(())
},
|_, _, _, _| true,
&filter,
&mut rng,
)
.unwrap();
assert_eq!(boards.len(), 3);
Expand Down
Loading