Skip to content

Commit

Permalink
Quantise Policy Network L1 (#43)
Browse files Browse the repository at this point in the history
Quantises Policy Network L1.

Passed STC:
LLR: 2.92 (-2.94,2.94) <0.00,4.00>
Total: 4000 W: 1028 L: 867 D: 2105
Ptnml(0-2): 40, 417, 946, 536, 61
https://montychess.org/tests/view/66bd67dc68e8f7e2fe23ccdc

Passed LTC:
LLR: 2.94 (-2.94,2.94) <1.00,5.00>
Total: 5382 W: 1181 L: 1021 D: 3180
Ptnml(0-2): 30, 550, 1376, 700, 35
https://montychess.org/tests/view/66bd6d5568e8f7e2fe23cd4b

Bench: 2233112
  • Loading branch information
jw1912 authored Aug 15, 2024
1 parent 3bcadeb commit 9e8e7c5
Show file tree
Hide file tree
Showing 13 changed files with 243 additions and 83 deletions.
11 changes: 4 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ strip = true
lto = true
codegen-units = 1

[dependencies]
goober = { git = 'https://github.com/jw1912/goober.git' }

[features]
embed = []
datagen = []
Expand Down
6 changes: 3 additions & 3 deletions src/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,16 @@ impl ChessState {
self.stm()
}

pub fn get_policy_feats(&self) -> (goober::SparseVector, u64) {
let mut feats = goober::SparseVector::with_capacity(32);
pub fn get_policy_feats(&self) -> (Vec<usize>, u64) {
let mut feats = Vec::with_capacity(32);
self.board.map_policy_features(|feat| feats.push(feat));
(feats, self.board.threats())
}

pub fn get_policy(
&self,
mov: Move,
(feats, threats): &(goober::SparseVector, u64),
(feats, threats): &(Vec<usize>, u64),
policy: &PolicyNetwork,
) -> f32 {
policy.get(&self.board, &mov, feats, *threats)
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ mod uci;
pub use chess::{Board, Castling, ChessState, GameState, Move};
pub use mcts::{Limits, MctsParams, Searcher};
pub use networks::{
PolicyFileDefaultName, PolicyNetwork, SubNet, UnquantisedValueNetwork, ValueFileDefaultName,
ValueNetwork,
PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork, UnquantisedValueNetwork,
ValueFileDefaultName, ValueNetwork,
};
pub use tree::Tree;
pub use uci::Uci;
Expand Down
3 changes: 2 additions & 1 deletion src/networks.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
mod accumulator;
mod activation;
mod layer;
mod policy;
mod value;

pub use policy::{PolicyFileDefaultName, PolicyNetwork, SubNet};
pub use policy::{PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork};
pub use value::{UnquantisedValueNetwork, ValueFileDefaultName, ValueNetwork};

const QA: i16 = 512;
22 changes: 19 additions & 3 deletions src/networks/accumulator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::ops::{AddAssign, Mul};

#[derive(Clone, Copy)]
use super::activation::Activation;

#[repr(C)]
#[derive(Clone, Copy)]
pub struct Accumulator<T: Copy, const N: usize>(pub [T; N]);

impl<T: AddAssign<T> + Copy + Mul<T, Output = T>, const N: usize> Accumulator<T, N> {
Expand All @@ -19,11 +21,25 @@ impl<T: AddAssign<T> + Copy + Mul<T, Output = T>, const N: usize> Accumulator<T,
}

impl<const N: usize> Accumulator<f32, N> {
/// Activated dot product of two accumulators.
///
/// Each element of `self` and `other` is passed through `T::activate`
/// before the pairwise products are summed, left to right, starting
/// from `0.0`.
pub fn dot<T: Activation>(&self, other: &Self) -> f32 {
    self.0
        .iter()
        .zip(other.0.iter())
        .fold(0.0_f32, |sum, (&lhs, &rhs)| sum + T::activate(lhs) * T::activate(rhs))
}

pub fn quantise(&self, qa: i16) -> Accumulator<i16, N> {
let mut res = Accumulator([0; N]);

for (i, j) in res.0.iter_mut().zip(self.0.iter()) {
*i = (*j * f32::from(qa)) as i16;
for (i, &j) in res.0.iter_mut().zip(self.0.iter()) {
if j > 1.98 {
println!("{j}")
}

*i = (j * f32::from(qa)) as i16;
}

res
Expand Down
19 changes: 19 additions & 0 deletions src/networks/activation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/// A scalar activation function, selected at compile time via a type
/// parameter (e.g. `forward::<ReLU>(..)`).
pub trait Activation {
    /// Maps a single pre-activation value to its activated value.
    fn activate(x: f32) -> f32;
}

/// Rectified linear unit: `max(x, 0)`.
pub struct ReLU;

impl Activation for ReLU {
    #[inline]
    fn activate(x: f32) -> f32 {
        f32::max(x, 0.0)
    }
}

/// Squared clipped ReLU: clamp the input to `[0, 1]`, then square it.
pub struct SCReLU;

impl Activation for SCReLU {
    #[inline]
    fn activate(x: f32) -> f32 {
        let clipped = x.clamp(0.0, 1.0);
        clipped.powi(2)
    }
}
39 changes: 29 additions & 10 deletions src/networks/layer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::Board;

use super::{accumulator::Accumulator, QA};
use super::{accumulator::Accumulator, activation::Activation, QA};

#[derive(Clone, Copy)]
pub struct Layer<T: Copy, const M: usize, const N: usize> {
Expand All @@ -16,30 +16,38 @@ impl<const M: usize, const N: usize> Layer<i16, M, N> {

out
}
}

impl<const M: usize, const N: usize> Layer<f32, M, N> {
#[inline]
fn screlu(x: f32) -> f32 {
x.clamp(0.0, 1.0).powi(2)
/// Sparse forward pass: starts from a copy of the biases and adds the
/// weight row of every index listed in `feats`.
///
/// Indexing `self.weights` panics if any entry of `feats` is out of
/// range for the weight array.
pub fn forward_from_slice(&self, feats: &[usize]) -> Accumulator<i16, N> {
    let mut acc = self.biases;

    feats
        .iter()
        .copied()
        .for_each(|idx| acc.add(&self.weights[idx]));

    acc
}
}

pub fn forward(&self, inputs: &Accumulator<f32, M>) -> Accumulator<f32, N> {
impl<const M: usize, const N: usize> Layer<f32, M, N> {
pub fn forward<T: Activation>(&self, inputs: &Accumulator<f32, M>) -> Accumulator<f32, N> {
let mut fwd = self.biases;

for (i, d) in inputs.0.iter().zip(self.weights.iter()) {
let act = Self::screlu(*i);
let act = T::activate(*i);
fwd.madd(act, d);
}

fwd
}

pub fn forward_from_i16(&self, inputs: &Accumulator<i16, M>) -> Accumulator<f32, N> {
pub fn forward_from_i16<T: Activation>(
&self,
inputs: &Accumulator<i16, M>,
) -> Accumulator<f32, N> {
let mut fwd = self.biases;

for (i, d) in inputs.0.iter().zip(self.weights.iter()) {
let act = Self::screlu(f32::from(*i) / f32::from(QA));
let act = T::activate(f32::from(*i) / f32::from(QA));
fwd.madd(act, d);
}

Expand All @@ -53,4 +61,15 @@ impl<const M: usize, const N: usize> Layer<f32, M, N> {

dest.biases = self.biases.quantise(qa);
}

/// Returns a freshly built `i16` copy of this layer, quantised with
/// factor `qa`.
///
/// A zero-initialised layer is created first and then filled in by
/// `quantise_into`.
pub fn quantise(&self, qa: i16) -> Layer<i16, M, N> {
    let zero_row = Accumulator([0_i16; N]);

    let mut quantised = Layer {
        weights: [zero_row; M],
        biases: zero_row,
    };

    self.quantise_into(&mut quantised, qa);
    quantised
}
}
100 changes: 60 additions & 40 deletions src/networks/policy.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,37 @@
use crate::chess::{Board, Move};
use crate::{
boxed_and_zeroed,
chess::{Board, Move},
};

use goober::{activation, layer, FeedForwardNetwork, Matrix, SparseVector, Vector};
use super::{accumulator::Accumulator, activation::ReLU, layer::Layer, QA};

// DO NOT MOVE
#[allow(non_upper_case_globals)]
pub const PolicyFileDefaultName: &str = "nn-6b5dc1d7fff9.network";
pub const PolicyFileDefaultName: &str = "nn-e2a03baa505c.network";

#[repr(C)]
#[derive(Clone, Copy, FeedForwardNetwork)]
pub struct SubNet {
ft: layer::SparseConnected<activation::ReLU, 768, 16>,
l2: layer::DenseConnected<activation::ReLU, 16, 16>,
#[derive(Clone, Copy)]
struct SubNet {
ft: Layer<i16, 768, 16>,
l2: Layer<f32, 16, 16>,
}

impl SubNet {
pub const fn zeroed() -> Self {
Self {
ft: layer::SparseConnected::zeroed(),
l2: layer::DenseConnected::zeroed(),
}
}

pub fn from_fn<F: FnMut() -> f32>(mut f: F) -> Self {
let matrix = Matrix::from_fn(|_, _| f());
let vector = Vector::from_fn(|_| f());

let matrix2 = Matrix::from_fn(|_, _| f());
let vector2 = Vector::from_fn(|_| f());

Self {
ft: layer::SparseConnected::from_raw(matrix, vector),
l2: layer::DenseConnected::from_raw(matrix2, vector2),
}
/// Runs the subnet: the sparse `ft` layer over `feats`, then the dense
/// `l2` layer (ReLU applied to its `i16` inputs) on that result.
fn out(&self, feats: &[usize]) -> Accumulator<f32, 16> {
    self.l2
        .forward_from_i16::<ReLU>(&self.ft.forward_from_slice(feats))
}
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct PolicyNetwork {
pub subnets: [[SubNet; 2]; 448],
pub hce: layer::DenseConnected<activation::Identity, 4, 1>,
subnets: [[SubNet; 2]; 448],
hce: Layer<f32, 4, 1>,
}

impl PolicyNetwork {
pub const fn zeroed() -> Self {
Self {
subnets: [[SubNet::zeroed(); 2]; 448],
hce: layer::DenseConnected::zeroed(),
}
}

pub fn get(&self, pos: &Board, mov: &Move, feats: &SparseVector, threats: u64) -> f32 {
pub fn get(&self, pos: &Board, mov: &Move, feats: &[usize], threats: u64) -> f32 {
let flip = pos.flip_val();
let pc = pos.get_pc(1 << mov.src()) - 1;

Expand All @@ -62,18 +43,57 @@ impl PolicyNetwork {
let to_subnet = &self.subnets[64 * pc + usize::from(mov.to() ^ flip)][good_see];
let to_vec = to_subnet.out(feats);

let hce = self.hce.out(&Self::get_hce_feats(pos, mov))[0];
let hce = self.hce.forward::<ReLU>(&Self::get_hce_feats(pos, mov)).0[0];

from_vec.dot(&to_vec) + hce
from_vec.dot::<ReLU>(&to_vec) + hce
}

pub fn get_hce_feats(_: &Board, mov: &Move) -> Vector<4> {
let mut feats = Vector::zeroed();
pub fn get_hce_feats(_: &Board, mov: &Move) -> Accumulator<f32, 4> {
let mut feats = [0.0; 4];

if mov.is_promo() {
feats[mov.promo_pc() - 3] = 1.0;
}

feats
Accumulator(feats)
}
}

/// All-`f32` form of a policy subnet, converted to `SubNet` by `quantise`.
// NOTE(review): `#[repr(C)]` presumably pins the layout to match the raw
// network file — confirm against the loading code.
#[repr(C)]
#[derive(Clone, Copy)]
struct UnquantisedSubNet {
// 768-input, 16-output layer; this is the layer that `quantise` converts to i16.
ft: Layer<f32, 768, 16>,
// 16-input, 16-output layer; copied unchanged (still f32) by `quantise`.
l2: Layer<f32, 16, 16>,
}

impl UnquantisedSubNet {
    /// Builds the runtime `SubNet`: `ft` is quantised to `i16` with
    /// factor `qa`, while `l2` is carried over as-is in `f32`.
    fn quantise(&self, qa: i16) -> SubNet {
        let Self { ft, l2 } = self;

        SubNet {
            ft: ft.quantise(qa),
            l2: *l2,
        }
    }
}

/// All-`f32` form of the policy network, converted to `PolicyNetwork` by
/// `quantise`.
// NOTE(review): `#[repr(C)]` presumably pins the layout to match the raw
// network file — confirm against the loading code.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct UnquantisedPolicyNetwork {
// Same 448 x 2 shape as `PolicyNetwork::subnets`; each entry is quantised
// element-for-element.
subnets: [[UnquantisedSubNet; 2]; 448],
// 4-input, 1-output layer; copied unchanged into the quantised network.
hce: Layer<f32, 4, 1>,
}

impl UnquantisedPolicyNetwork {
    /// Produces the heap-allocated runtime `PolicyNetwork`.
    ///
    /// Every subnet is quantised with the shared factor `QA`; the `hce`
    /// layer is copied across unchanged.
    pub fn quantise(&self) -> Box<PolicyNetwork> {
        // SAFETY (assumed — confirm against `boxed_and_zeroed`'s contract):
        // `PolicyNetwork` appears to consist solely of `i16`/`f32` arrays,
        // for which an all-zero bit pattern is a valid value.
        let mut quant: Box<PolicyNetwork> = unsafe { boxed_and_zeroed() };

        // Walk both 448 x 2 tables in lockstep, one subnet at a time.
        let targets = quant.subnets.iter_mut().flatten();
        let sources = self.subnets.iter().flatten();

        for (dst, src) in targets.zip(sources) {
            *dst = src.quantise(QA);
        }

        quant.hce = self.hce;

        quant
    }
}
22 changes: 11 additions & 11 deletions src/networks/value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{boxed_and_zeroed, Board};

use super::{layer::Layer, QA};
use super::{activation::SCReLU, layer::Layer, QA};

// DO NOT MOVE
#[allow(non_upper_case_globals)]
Expand All @@ -26,16 +26,16 @@ pub struct ValueNetwork {
impl ValueNetwork {
pub fn eval(&self, board: &Board) -> i32 {
let l2 = self.l1.forward(board);
let l3 = self.l2.forward_from_i16(&l2);
let l4 = self.l3.forward(&l3);
let l5 = self.l4.forward(&l4);
let l6 = self.l5.forward(&l5);
let l7 = self.l6.forward(&l6);
let l8 = self.l7.forward(&l7);
let l9 = self.l8.forward(&l8);
let l10 = self.l9.forward(&l9);
let l11 = self.l10.forward(&l10);
let out = self.l11.forward(&l11);
let l3 = self.l2.forward_from_i16::<SCReLU>(&l2);
let l4 = self.l3.forward::<SCReLU>(&l3);
let l5 = self.l4.forward::<SCReLU>(&l4);
let l6 = self.l5.forward::<SCReLU>(&l5);
let l7 = self.l6.forward::<SCReLU>(&l6);
let l8 = self.l7.forward::<SCReLU>(&l7);
let l9 = self.l8.forward::<SCReLU>(&l8);
let l10 = self.l9.forward::<SCReLU>(&l9);
let l11 = self.l10.forward::<SCReLU>(&l10);
let out = self.l11.forward::<SCReLU>(&l11);

(out.0[0] * SCALE as f32) as i32
}
Expand Down
Loading

0 comments on commit 9e8e7c5

Please sign in to comment.