Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quantise Policy Network L1 #43

Merged
merged 5 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@ strip = true
lto = true
codegen-units = 1

[dependencies]
goober = { git = 'https://github.com/jw1912/goober.git' }

[features]
embed = []
datagen = []
Expand Down
6 changes: 3 additions & 3 deletions src/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,16 @@ impl ChessState {
self.stm()
}

pub fn get_policy_feats(&self) -> (goober::SparseVector, u64) {
let mut feats = goober::SparseVector::with_capacity(32);
pub fn get_policy_feats(&self) -> (Vec<usize>, u64) {
let mut feats = Vec::with_capacity(32);
self.board.map_policy_features(|feat| feats.push(feat));
(feats, self.board.threats())
}

pub fn get_policy(
&self,
mov: Move,
(feats, threats): &(goober::SparseVector, u64),
(feats, threats): &(Vec<usize>, u64),
policy: &PolicyNetwork,
) -> f32 {
policy.get(&self.board, &mov, feats, *threats)
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ mod uci;
pub use chess::{Board, Castling, ChessState, GameState, Move};
pub use mcts::{Limits, MctsParams, Searcher};
pub use networks::{
PolicyFileDefaultName, PolicyNetwork, SubNet, UnquantisedValueNetwork, ValueFileDefaultName,
ValueNetwork,
PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork, UnquantisedValueNetwork,
ValueFileDefaultName, ValueNetwork,
};
pub use tree::Tree;
pub use uci::Uci;
Expand Down
3 changes: 2 additions & 1 deletion src/networks.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
mod accumulator;
mod activation;
mod layer;
mod policy;
mod value;

pub use policy::{PolicyFileDefaultName, PolicyNetwork, SubNet};
pub use policy::{PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork};
pub use value::{UnquantisedValueNetwork, ValueFileDefaultName, ValueNetwork};

const QA: i16 = 512;
22 changes: 19 additions & 3 deletions src/networks/accumulator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use std::ops::{AddAssign, Mul};

#[derive(Clone, Copy)]
use super::activation::Activation;

#[repr(C)]
#[derive(Clone, Copy)]
pub struct Accumulator<T: Copy, const N: usize>(pub [T; N]);

impl<T: AddAssign<T> + Copy + Mul<T, Output = T>, const N: usize> Accumulator<T, N> {
Expand All @@ -19,11 +21,25 @@ impl<T: AddAssign<T> + Copy + Mul<T, Output = T>, const N: usize> Accumulator<T,
}

impl<const N: usize> Accumulator<f32, N> {
pub fn dot<T: Activation>(&self, other: &Self) -> f32 {
let mut res = 0.0;

for (i, j) in self.0.iter().zip(other.0.iter()) {
res += T::activate(*i) * T::activate(*j);
}

res
}

pub fn quantise(&self, qa: i16) -> Accumulator<i16, N> {
let mut res = Accumulator([0; N]);

for (i, j) in res.0.iter_mut().zip(self.0.iter()) {
*i = (*j * f32::from(qa)) as i16;
for (i, &j) in res.0.iter_mut().zip(self.0.iter()) {
if j > 1.98 {
println!("{j}")
}

*i = (j * f32::from(qa)) as i16;
}

res
Expand Down
19 changes: 19 additions & 0 deletions src/networks/activation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
pub trait Activation {
fn activate(x: f32) -> f32;
}

pub struct ReLU;
impl Activation for ReLU {
#[inline]
fn activate(x: f32) -> f32 {
x.max(0.0)
}
}

pub struct SCReLU;
impl Activation for SCReLU {
#[inline]
fn activate(x: f32) -> f32 {
x.clamp(0.0, 1.0).powi(2)
}
}
39 changes: 29 additions & 10 deletions src/networks/layer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::Board;

use super::{accumulator::Accumulator, QA};
use super::{accumulator::Accumulator, activation::Activation, QA};

#[derive(Clone, Copy)]
pub struct Layer<T: Copy, const M: usize, const N: usize> {
Expand All @@ -16,30 +16,38 @@ impl<const M: usize, const N: usize> Layer<i16, M, N> {

out
}
}

impl<const M: usize, const N: usize> Layer<f32, M, N> {
#[inline]
fn screlu(x: f32) -> f32 {
x.clamp(0.0, 1.0).powi(2)
pub fn forward_from_slice(&self, feats: &[usize]) -> Accumulator<i16, N> {
let mut out = self.biases;

for &feat in feats {
out.add(&self.weights[feat])
}

out
}
}

pub fn forward(&self, inputs: &Accumulator<f32, M>) -> Accumulator<f32, N> {
impl<const M: usize, const N: usize> Layer<f32, M, N> {
pub fn forward<T: Activation>(&self, inputs: &Accumulator<f32, M>) -> Accumulator<f32, N> {
let mut fwd = self.biases;

for (i, d) in inputs.0.iter().zip(self.weights.iter()) {
let act = Self::screlu(*i);
let act = T::activate(*i);
fwd.madd(act, d);
}

fwd
}

pub fn forward_from_i16(&self, inputs: &Accumulator<i16, M>) -> Accumulator<f32, N> {
pub fn forward_from_i16<T: Activation>(
&self,
inputs: &Accumulator<i16, M>,
) -> Accumulator<f32, N> {
let mut fwd = self.biases;

for (i, d) in inputs.0.iter().zip(self.weights.iter()) {
let act = Self::screlu(f32::from(*i) / f32::from(QA));
let act = T::activate(f32::from(*i) / f32::from(QA));
fwd.madd(act, d);
}

Expand All @@ -53,4 +61,15 @@ impl<const M: usize, const N: usize> Layer<f32, M, N> {

dest.biases = self.biases.quantise(qa);
}

pub fn quantise(&self, qa: i16) -> Layer<i16, M, N> {
let mut res = Layer {
weights: [Accumulator([0; N]); M],
biases: Accumulator([0; N]),
};

self.quantise_into(&mut res, qa);

res
}
}
100 changes: 60 additions & 40 deletions src/networks/policy.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,37 @@
use crate::chess::{Board, Move};
use crate::{
boxed_and_zeroed,
chess::{Board, Move},
};

use goober::{activation, layer, FeedForwardNetwork, Matrix, SparseVector, Vector};
use super::{accumulator::Accumulator, activation::ReLU, layer::Layer, QA};

// DO NOT MOVE
#[allow(non_upper_case_globals)]
pub const PolicyFileDefaultName: &str = "nn-6b5dc1d7fff9.network";
pub const PolicyFileDefaultName: &str = "nn-e2a03baa505c.network";

#[repr(C)]
#[derive(Clone, Copy, FeedForwardNetwork)]
pub struct SubNet {
ft: layer::SparseConnected<activation::ReLU, 768, 16>,
l2: layer::DenseConnected<activation::ReLU, 16, 16>,
#[derive(Clone, Copy)]
struct SubNet {
ft: Layer<i16, 768, 16>,
l2: Layer<f32, 16, 16>,
}

impl SubNet {
pub const fn zeroed() -> Self {
Self {
ft: layer::SparseConnected::zeroed(),
l2: layer::DenseConnected::zeroed(),
}
}

pub fn from_fn<F: FnMut() -> f32>(mut f: F) -> Self {
let matrix = Matrix::from_fn(|_, _| f());
let vector = Vector::from_fn(|_| f());

let matrix2 = Matrix::from_fn(|_, _| f());
let vector2 = Vector::from_fn(|_| f());

Self {
ft: layer::SparseConnected::from_raw(matrix, vector),
l2: layer::DenseConnected::from_raw(matrix2, vector2),
}
fn out(&self, feats: &[usize]) -> Accumulator<f32, 16> {
let l2 = self.ft.forward_from_slice(feats);
self.l2.forward_from_i16::<ReLU>(&l2)
}
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct PolicyNetwork {
pub subnets: [[SubNet; 2]; 448],
pub hce: layer::DenseConnected<activation::Identity, 4, 1>,
subnets: [[SubNet; 2]; 448],
hce: Layer<f32, 4, 1>,
}

impl PolicyNetwork {
pub const fn zeroed() -> Self {
Self {
subnets: [[SubNet::zeroed(); 2]; 448],
hce: layer::DenseConnected::zeroed(),
}
}

pub fn get(&self, pos: &Board, mov: &Move, feats: &SparseVector, threats: u64) -> f32 {
pub fn get(&self, pos: &Board, mov: &Move, feats: &[usize], threats: u64) -> f32 {
let flip = pos.flip_val();
let pc = pos.get_pc(1 << mov.src()) - 1;

Expand All @@ -62,18 +43,57 @@ impl PolicyNetwork {
let to_subnet = &self.subnets[64 * pc + usize::from(mov.to() ^ flip)][good_see];
let to_vec = to_subnet.out(feats);

let hce = self.hce.out(&Self::get_hce_feats(pos, mov))[0];
let hce = self.hce.forward::<ReLU>(&Self::get_hce_feats(pos, mov)).0[0];

from_vec.dot(&to_vec) + hce
from_vec.dot::<ReLU>(&to_vec) + hce
}

pub fn get_hce_feats(_: &Board, mov: &Move) -> Vector<4> {
let mut feats = Vector::zeroed();
pub fn get_hce_feats(_: &Board, mov: &Move) -> Accumulator<f32, 4> {
let mut feats = [0.0; 4];

if mov.is_promo() {
feats[mov.promo_pc() - 3] = 1.0;
}

feats
Accumulator(feats)
}
}

#[repr(C)]
#[derive(Clone, Copy)]
struct UnquantisedSubNet {
ft: Layer<f32, 768, 16>,
l2: Layer<f32, 16, 16>,
}

impl UnquantisedSubNet {
fn quantise(&self, qa: i16) -> SubNet {
SubNet {
ft: self.ft.quantise(qa),
l2: self.l2,
}
}
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct UnquantisedPolicyNetwork {
subnets: [[UnquantisedSubNet; 2]; 448],
hce: Layer<f32, 4, 1>,
}

impl UnquantisedPolicyNetwork {
pub fn quantise(&self) -> Box<PolicyNetwork> {
let mut quant: Box<PolicyNetwork> = unsafe { boxed_and_zeroed() };

for (qpair, unqpair) in quant.subnets.iter_mut().zip(self.subnets.iter()) {
for (qsubnet, unqsubnet) in qpair.iter_mut().zip(unqpair.iter()) {
*qsubnet = unqsubnet.quantise(QA);
}
}

quant.hce = self.hce;

quant
}
}
22 changes: 11 additions & 11 deletions src/networks/value.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{boxed_and_zeroed, Board};

use super::{layer::Layer, QA};
use super::{activation::SCReLU, layer::Layer, QA};

// DO NOT MOVE
#[allow(non_upper_case_globals)]
Expand All @@ -26,16 +26,16 @@ pub struct ValueNetwork {
impl ValueNetwork {
pub fn eval(&self, board: &Board) -> i32 {
let l2 = self.l1.forward(board);
let l3 = self.l2.forward_from_i16(&l2);
let l4 = self.l3.forward(&l3);
let l5 = self.l4.forward(&l4);
let l6 = self.l5.forward(&l5);
let l7 = self.l6.forward(&l6);
let l8 = self.l7.forward(&l7);
let l9 = self.l8.forward(&l8);
let l10 = self.l9.forward(&l9);
let l11 = self.l10.forward(&l10);
let out = self.l11.forward(&l11);
let l3 = self.l2.forward_from_i16::<SCReLU>(&l2);
let l4 = self.l3.forward::<SCReLU>(&l3);
let l5 = self.l4.forward::<SCReLU>(&l4);
let l6 = self.l5.forward::<SCReLU>(&l5);
let l7 = self.l6.forward::<SCReLU>(&l6);
let l8 = self.l7.forward::<SCReLU>(&l7);
let l9 = self.l8.forward::<SCReLU>(&l8);
let l10 = self.l9.forward::<SCReLU>(&l9);
let l11 = self.l10.forward::<SCReLU>(&l10);
let out = self.l11.forward::<SCReLU>(&l11);

(out.0[0] * SCALE as f32) as i32
}
Expand Down
Loading
Loading