diff --git a/Cargo.lock b/Cargo.lock index 2ca4e08c..8abfde48 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -140,7 +140,7 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "goober" version = "0.1.0" -source = "git+https://github.com/jw1912/goober.git#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" +source = "git+https://github.com/jw1912/goober#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" dependencies = [ "goober-core", "goober-derive", @@ -150,12 +150,12 @@ dependencies = [ [[package]] name = "goober-core" version = "0.1.0" -source = "git+https://github.com/jw1912/goober.git#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" +source = "git+https://github.com/jw1912/goober#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" [[package]] name = "goober-derive" version = "0.1.0" -source = "git+https://github.com/jw1912/goober.git#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" +source = "git+https://github.com/jw1912/goober#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" dependencies = [ "proc-macro2", "quote", @@ -165,7 +165,7 @@ dependencies = [ [[package]] name = "goober-layer" version = "0.1.0" -source = "git+https://github.com/jw1912/goober.git#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" +source = "git+https://github.com/jw1912/goober#32b9b52e68ef03d9d706548fb21bb3c4535c4dd2" dependencies = [ "goober-core", ] @@ -240,9 +240,6 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "monty" version = "1.0.0" -dependencies = [ - "goober", -] [[package]] name = "montyformat" diff --git a/Cargo.toml b/Cargo.toml index 253ef440..ab1b14b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,6 @@ strip = true lto = true codegen-units = 1 -[dependencies] -goober = { git = 'https://github.com/jw1912/goober.git' } - [features] embed = [] datagen = [] diff --git a/src/chess.rs b/src/chess.rs index d355307a..c679de5b 100644 --- a/src/chess.rs +++ b/src/chess.rs @@ -140,8 +140,8 @@ impl ChessState { self.stm() } - pub fn get_policy_feats(&self) -> (goober::SparseVector, u64) { - let mut feats = goober::SparseVector::with_capacity(32); + pub fn get_policy_feats(&self) -> (Vec, u64) { + let mut feats = Vec::with_capacity(32); self.board.map_policy_features(|feat| feats.push(feat)); (feats, self.board.threats()) } @@ -149,7 +149,7 @@ impl ChessState { pub fn get_policy( &self, mov: Move, - (feats, threats): &(goober::SparseVector, u64), + (feats, threats): &(Vec, u64), policy: &PolicyNetwork, ) -> f32 { policy.get(&self.board, &mov, feats, *threats) diff --git a/src/lib.rs b/src/lib.rs index 68bd2c1c..73600c7c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,8 +7,8 @@ mod uci; pub use chess::{Board, Castling, ChessState, GameState, Move}; pub use mcts::{Limits, MctsParams, Searcher}; pub use networks::{ - PolicyFileDefaultName, PolicyNetwork, SubNet, UnquantisedValueNetwork, ValueFileDefaultName, - ValueNetwork, + PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork, UnquantisedValueNetwork, + ValueFileDefaultName, ValueNetwork, }; pub use tree::Tree; pub use uci::Uci; diff --git a/src/networks.rs b/src/networks.rs index d776a888..0afea34e 100644 --- a/src/networks.rs +++ b/src/networks.rs @@ -1,9 +1,10 @@ mod accumulator; +mod activation; mod layer; mod policy; mod value; -pub use policy::{PolicyFileDefaultName, PolicyNetwork, SubNet}; +pub use policy::{PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork}; pub use value::{UnquantisedValueNetwork, ValueFileDefaultName, ValueNetwork}; const QA: i16 = 512; diff --git a/src/networks/accumulator.rs b/src/networks/accumulator.rs index 096bf823..775aab26 100644 --- a/src/networks/accumulator.rs +++ b/src/networks/accumulator.rs @@ -1,7 +1,9 @@ use std::ops::{AddAssign, Mul}; -#[derive(Clone, Copy)] +use super::activation::Activation; + #[repr(C)] +#[derive(Clone, Copy)] pub struct Accumulator(pub [T; N]); impl + Copy + Mul, const N: usize> Accumulator { @@ -19,11 +21,25 @@ impl + Copy + Mul, const N: usize> Accumulator Accumulator { + pub fn dot(&self, other: &Self) -> f32 { + let mut res = 0.0; + + for (i, j) in self.0.iter().zip(other.0.iter()) { + res += T::activate(*i) * T::activate(*j); + } + + res + } + pub fn quantise(&self, qa: i16) -> Accumulator { let mut res = Accumulator([0; N]); - for (i, j) in res.0.iter_mut().zip(self.0.iter()) { - *i = (*j * f32::from(qa)) as i16; + for (i, &j) in res.0.iter_mut().zip(self.0.iter()) { + if j > 1.98 { + println!("{j}") + } + + *i = (j * f32::from(qa)) as i16; } res diff --git a/src/networks/activation.rs b/src/networks/activation.rs new file mode 100644 index 00000000..af0a4295 --- /dev/null +++ b/src/networks/activation.rs @@ -0,0 +1,19 @@ +pub trait Activation { + fn activate(x: f32) -> f32; +} + +pub struct ReLU; +impl Activation for ReLU { + #[inline] + fn activate(x: f32) -> f32 { + x.max(0.0) + } +} + +pub struct SCReLU; +impl Activation for SCReLU { + #[inline] + fn activate(x: f32) -> f32 { + x.clamp(0.0, 1.0).powi(2) + } +} diff --git a/src/networks/layer.rs b/src/networks/layer.rs index dcf92fd6..8a470d20 100644 --- a/src/networks/layer.rs +++ b/src/networks/layer.rs @@ -1,6 +1,6 @@ use crate::Board; -use super::{accumulator::Accumulator, QA}; +use super::{accumulator::Accumulator, activation::Activation, QA}; #[derive(Clone, Copy)] pub struct Layer { @@ -16,30 +16,38 @@ impl Layer { out } -} -impl Layer { - #[inline] - fn screlu(x: f32) -> f32 { - x.clamp(0.0, 1.0).powi(2) + pub fn forward_from_slice(&self, feats: &[usize]) -> Accumulator { + let mut out = self.biases; + + for &feat in feats { + out.add(&self.weights[feat]) + } + + out } +} - pub fn forward(&self, inputs: &Accumulator) -> Accumulator { +impl Layer { + pub fn forward(&self, inputs: &Accumulator) -> Accumulator { let mut fwd = self.biases; for (i, d) in inputs.0.iter().zip(self.weights.iter()) { - let act = Self::screlu(*i); + let act = T::activate(*i); fwd.madd(act, d); } fwd } - pub fn forward_from_i16(&self, inputs: &Accumulator) -> Accumulator { + pub fn forward_from_i16( + &self, + inputs: &Accumulator, + ) -> Accumulator { let mut fwd = self.biases; for (i, d) in inputs.0.iter().zip(self.weights.iter()) { - let act = Self::screlu(f32::from(*i) / f32::from(QA)); + let act = T::activate(f32::from(*i) / f32::from(QA)); fwd.madd(act, d); } @@ -53,4 +61,15 @@ impl Layer { dest.biases = self.biases.quantise(qa); } + + pub fn quantise(&self, qa: i16) -> Layer { + let mut res = Layer { + weights: [Accumulator([0; N]); M], + biases: Accumulator([0; N]), + }; + + self.quantise_into(&mut res, qa); + + res + } } diff --git a/src/networks/policy.rs b/src/networks/policy.rs index e7df5adf..d3659a85 100644 --- a/src/networks/policy.rs +++ b/src/networks/policy.rs @@ -1,56 +1,37 @@ -use crate::chess::{Board, Move}; +use crate::{ + boxed_and_zeroed, + chess::{Board, Move}, +}; -use goober::{activation, layer, FeedForwardNetwork, Matrix, SparseVector, Vector}; +use super::{accumulator::Accumulator, activation::ReLU, layer::Layer, QA}; // DO NOT MOVE #[allow(non_upper_case_globals)] -pub const PolicyFileDefaultName: &str = "nn-6b5dc1d7fff9.network"; +pub const PolicyFileDefaultName: &str = "nn-e2a03baa505c.network"; #[repr(C)] -#[derive(Clone, Copy, FeedForwardNetwork)] -pub struct SubNet { - ft: layer::SparseConnected, - l2: layer::DenseConnected, +#[derive(Clone, Copy)] +struct SubNet { + ft: Layer, + l2: Layer, } impl SubNet { - pub const fn zeroed() -> Self { - Self { - ft: layer::SparseConnected::zeroed(), - l2: layer::DenseConnected::zeroed(), - } - } - - pub fn from_fn f32>(mut f: F) -> Self { - let matrix = Matrix::from_fn(|_, _| f()); - let vector = Vector::from_fn(|_| f()); - - let matrix2 = Matrix::from_fn(|_, _| f()); - let vector2 = Vector::from_fn(|_| f()); - - Self { - ft: layer::SparseConnected::from_raw(matrix, vector), - l2: layer::DenseConnected::from_raw(matrix2, vector2), - } + fn out(&self, feats: &[usize]) -> Accumulator { + let l2 = self.ft.forward_from_slice(feats); + self.l2.forward_from_i16::(&l2) } } #[repr(C)] #[derive(Clone, Copy)] pub struct PolicyNetwork { - pub subnets: [[SubNet; 2]; 448], - pub hce: layer::DenseConnected, + subnets: [[SubNet; 2]; 448], + hce: Layer, } impl PolicyNetwork { - pub const fn zeroed() -> Self { - Self { - subnets: [[SubNet::zeroed(); 2]; 448], - hce: layer::DenseConnected::zeroed(), - } - } - - pub fn get(&self, pos: &Board, mov: &Move, feats: &SparseVector, threats: u64) -> f32 { + pub fn get(&self, pos: &Board, mov: &Move, feats: &[usize], threats: u64) -> f32 { let flip = pos.flip_val(); let pc = pos.get_pc(1 << mov.src()) - 1; @@ -62,18 +43,57 @@ impl PolicyNetwork { let to_subnet = &self.subnets[64 * pc + usize::from(mov.to() ^ flip)][good_see]; let to_vec = to_subnet.out(feats); - let hce = self.hce.out(&Self::get_hce_feats(pos, mov))[0]; + let hce = self.hce.forward::(&Self::get_hce_feats(pos, mov)).0[0]; - from_vec.dot(&to_vec) + hce + from_vec.dot::(&to_vec) + hce } - pub fn get_hce_feats(_: &Board, mov: &Move) -> Vector<4> { - let mut feats = Vector::zeroed(); + pub fn get_hce_feats(_: &Board, mov: &Move) -> Accumulator { + let mut feats = [0.0; 4]; if mov.is_promo() { feats[mov.promo_pc() - 3] = 1.0; } - feats + Accumulator(feats) + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +struct UnquantisedSubNet { + ft: Layer, + l2: Layer, +} + +impl UnquantisedSubNet { + fn quantise(&self, qa: i16) -> SubNet { + SubNet { + ft: self.ft.quantise(qa), + l2: self.l2, + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct UnquantisedPolicyNetwork { + subnets: [[UnquantisedSubNet; 2]; 448], + hce: Layer, +} + +impl UnquantisedPolicyNetwork { + pub fn quantise(&self) -> Box { + let mut quant: Box = unsafe { boxed_and_zeroed() }; + + for (qpair, unqpair) in quant.subnets.iter_mut().zip(self.subnets.iter()) { + for (qsubnet, unqsubnet) in qpair.iter_mut().zip(unqpair.iter()) { + *qsubnet = unqsubnet.quantise(QA); + } + } + + quant.hce = self.hce; + + quant } } diff --git a/src/networks/value.rs b/src/networks/value.rs index 86708973..70546d09 100644 --- a/src/networks/value.rs +++ b/src/networks/value.rs @@ -1,6 +1,6 @@ use crate::{boxed_and_zeroed, Board}; -use super::{layer::Layer, QA}; +use super::{activation::SCReLU, layer::Layer, QA}; // DO NOT MOVE #[allow(non_upper_case_globals)] @@ -26,16 +26,16 @@ pub struct ValueNetwork { impl ValueNetwork { pub fn eval(&self, board: &Board) -> i32 { let l2 = self.l1.forward(board); - let l3 = self.l2.forward_from_i16(&l2); - let l4 = self.l3.forward(&l3); - let l5 = self.l4.forward(&l4); - let l6 = self.l5.forward(&l5); - let l7 = self.l6.forward(&l6); - let l8 = self.l7.forward(&l7); - let l9 = self.l8.forward(&l8); - let l10 = self.l9.forward(&l9); - let l11 = self.l10.forward(&l10); - let out = self.l11.forward(&l11); + let l3 = self.l2.forward_from_i16::(&l2); + let l4 = self.l3.forward::(&l3); + let l5 = self.l4.forward::(&l4); + let l6 = self.l5.forward::(&l5); + let l7 = self.l6.forward::(&l6); + let l8 = self.l7.forward::(&l7); + let l9 = self.l8.forward::(&l8); + let l10 = self.l9.forward::(&l9); + let l11 = self.l10.forward::(&l10); + let out = self.l11.forward::(&l11); (out.0[0] * SCALE as f32) as i32 } diff --git a/train/policy/src/bin/quantise.rs b/train/policy/src/bin/quantise.rs new file mode 100644 index 00000000..728b0161 --- /dev/null +++ b/train/policy/src/bin/quantise.rs @@ -0,0 +1,19 @@ +use std::io::Write; + +use monty::{read_into_struct_unchecked, PolicyNetwork, UnquantisedPolicyNetwork}; + +fn main() { + let unquantised: Box = + unsafe { read_into_struct_unchecked("nn-6b5dc1d7fff9.network") }; + + let quantised = unquantised.quantise(); + + let mut file = std::fs::File::create("quantised.network").unwrap(); + + unsafe { + let ptr: *const PolicyNetwork = quantised.as_ref(); + let slice_ptr: *const u8 = std::mem::transmute(ptr); + let slice = std::slice::from_raw_parts(slice_ptr, std::mem::size_of::()); + file.write_all(slice).unwrap(); + } +} diff --git a/train/policy/src/chess.rs b/train/policy/src/chess.rs index 784bb9a7..c9f603bd 100644 --- a/train/policy/src/chess.rs +++ b/train/policy/src/chess.rs @@ -1,9 +1,81 @@ use datagen::{PolicyData, Rand}; -use goober::{FeedForwardNetwork, OutputLayer, SparseVector, Vector}; -use monty::{Board, Move, PolicyNetwork, SubNet}; +use goober::{activation, layer, FeedForwardNetwork, Matrix, OutputLayer, SparseVector, Vector}; +use monty::{Board, Move}; use crate::TrainablePolicy; +#[repr(C)] +#[derive(Clone, Copy, FeedForwardNetwork)] +pub struct SubNet { + ft: layer::SparseConnected, + l2: layer::DenseConnected, +} + +impl SubNet { + pub const fn zeroed() -> Self { + Self { + ft: layer::SparseConnected::zeroed(), + l2: layer::DenseConnected::zeroed(), + } + } + + pub fn from_fn f32>(mut f: F) -> Self { + let matrix = Matrix::from_fn(|_, _| f()); + let vector = Vector::from_fn(|_| f()); + + let matrix2 = Matrix::from_fn(|_, _| f()); + let vector2 = Vector::from_fn(|_| f()); + + Self { + ft: layer::SparseConnected::from_raw(matrix, vector), + l2: layer::DenseConnected::from_raw(matrix2, vector2), + } + } +} + +#[repr(C)] +#[derive(Clone, Copy)] +pub struct PolicyNetwork { + pub subnets: [[SubNet; 2]; 448], + pub hce: layer::DenseConnected, +} + +impl PolicyNetwork { + pub const fn zeroed() -> Self { + Self { + subnets: [[SubNet::zeroed(); 2]; 448], + hce: layer::DenseConnected::zeroed(), + } + } + + pub fn get(&self, pos: &Board, mov: &Move, feats: &SparseVector, threats: u64) -> f32 { + let flip = pos.flip_val(); + let pc = pos.get_pc(1 << mov.src()) - 1; + + let from_threat = usize::from(threats & (1 << mov.src()) > 0); + let from_subnet = &self.subnets[usize::from(mov.src() ^ flip)][from_threat]; + let from_vec = from_subnet.out(feats); + + let good_see = usize::from(pos.see(mov, -108)); + let to_subnet = &self.subnets[64 * pc + usize::from(mov.to() ^ flip)][good_see]; + let to_vec = to_subnet.out(feats); + + let hce = self.hce.out(&Self::get_hce_feats(pos, mov))[0]; + + from_vec.dot(&to_vec) + hce + } + + pub fn get_hce_feats(_: &Board, mov: &Move) -> Vector<4> { + let mut feats = Vector::zeroed(); + + if mov.is_promo() { + feats[mov.promo_pc() - 3] = 1.0; + } + + feats + } +} + impl TrainablePolicy for PolicyNetwork { type Data = PolicyData; diff --git a/train/policy/src/main.rs b/train/policy/src/main.rs index 6966d0db..28ff32f4 100644 --- a/train/policy/src/main.rs +++ b/train/policy/src/main.rs @@ -1,4 +1,4 @@ -use monty::PolicyNetwork; +use policy::chess::PolicyNetwork; fn main() { let mut args = std::env::args();