From 3e750efc48484af5a69810b2634dcc51ab9360cd Mon Sep 17 00:00:00 2001 From: Jamie Whiting <99771266+jw1912@users.noreply.github.com> Date: Fri, 23 Aug 2024 10:43:49 +0100 Subject: [PATCH] Use Newer `bullet` Version (#46) Allows use of newer optimiser tech. Bench: 1317423 --- Cargo.lock | 6 +++--- train/value/src/main.rs | 24 +++++++++++++++++------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8abfde48..717402c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,7 +49,7 @@ checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "bullet_lib" version = "1.0.0" -source = "git+https://github.com/jw1912/bullet#630c9622b93dbee045524bd61b977618bd184f96" +source = "git+https://github.com/jw1912/bullet#f39de09c7c9f9f7634b8cf6977df7c02efe0fa74" dependencies = [ "bindgen", "bulletformat", @@ -60,9 +60,9 @@ dependencies = [ [[package]] name = "bulletformat" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e25220aef4d4194f3091b7cd93641d7e324063b287c351664e48cf19edbd42de" +checksum = "e5261a2681e729de3d341e007d038465ee3cacb62ee87487e0b3dbeb31aec3c2" [[package]] name = "cc" diff --git a/train/value/src/main.rs b/train/value/src/main.rs index 91d1f1d5..4258451a 100644 --- a/train/value/src/main.rs +++ b/train/value/src/main.rs @@ -1,7 +1,7 @@ use bullet::{ format::{chess::BoardIter, ChessBoard}, - inputs, outputs, Activation, LocalSettings, Loss, LrScheduler, TrainerBuilder, - TrainingSchedule, WdlScheduler, + inputs, loader, lr, optimiser, outputs, wdl, Activation, LocalSettings, Loss, TrainerBuilder, + TrainingSchedule, }; use monty::Board; @@ -9,6 +9,7 @@ const HIDDEN_SIZE: usize = 2048; fn main() { let mut trainer = TrainerBuilder::default() + .optimiser(optimiser::AdamW) .single_perspective() .input(ThreatInputs) .output_buckets(outputs::Single) @@ -43,23 +44,32 @@ fn main() { batches_per_superbatch: 6104, start_superbatch: 1, end_superbatch: 1200, - wdl_scheduler: WdlScheduler::Constant { value: 0.5 }, - lr_scheduler: LrScheduler::Step { + wdl_scheduler: wdl::ConstantWDL { value: 0.5 }, + lr_scheduler: lr::StepLR { start: 0.001, gamma: 0.1, step: 300, }, loss_function: Loss::SigmoidMSE, save_rate: 10, + optimiser_settings: optimiser::AdamWParams { + decay: 0.01, + beta1: 0.9, + beta2: 0.999, + min_weight: -1.98, + max_weight: 1.98, + }, }; let settings = LocalSettings { threads: 8, - data_file_paths: vec!["../monty-data/12-08-24.data"], + test_set: None, output_directory: "checkpoints", }; - trainer.run(&schedule, &settings); + let data_loader = loader::DirectSequentialDataLoader::new(&["../monty-data/12-08-24.data"]); + + trainer.run(&schedule, &settings, &data_loader); for fen in [ "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", @@ -109,7 +119,7 @@ impl inputs::InputType for ThreatInputs { bb[usize::from(2 + (pc & 7))] ^= bit; } - let board = Board::from_raw(bb, false, 0, 0, 0); + let board = Board::from_raw(bb, false, 0, 0, 0, 1); let threats = board.threats_by(1); let defences = board.threats_by(0);