Skip to content

Commit

Permalink
Separate sum vector for reproducible bench (#210)
Browse files Browse the repository at this point in the history
* determinism

Bench: 5402525

* aaaaaaaaaaaaaaaaaa

Bench: 5685651
  • Loading branch information
cosmobobak authored Nov 5, 2024
1 parent 7f42700 commit 9feceb0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 32 deletions.
12 changes: 4 additions & 8 deletions src/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,7 @@ impl Board {

#[cfg(feature = "datagen")]
pub fn regenerate_zobrist(&mut self) {
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) =
self.generate_pos_keys();
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) = self.generate_pos_keys();
}

#[cfg(feature = "datagen")]
Expand Down Expand Up @@ -359,8 +358,7 @@ impl Board {
bk: Some(Square::from_rank_file(Rank::Eight, File::from_index(kingside_file as u8).unwrap())),
bq: Some(Square::from_rank_file(Rank::Eight, File::from_index(queenside_file as u8).unwrap())),
};
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) =
self.generate_pos_keys();
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) = self.generate_pos_keys();
self.threats = self.generate_threats(self.side.flip());
}

Expand Down Expand Up @@ -418,8 +416,7 @@ impl Board {
bk: Some(Square::from_rank_file(Rank::Eight, File::from_index(black_kingside_file as u8).unwrap())),
bq: Some(Square::from_rank_file(Rank::Eight, File::from_index(black_queenside_file as u8).unwrap())),
};
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) =
self.generate_pos_keys();
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) = self.generate_pos_keys();
self.threats = self.generate_threats(self.side.flip());
}

Expand Down Expand Up @@ -561,8 +558,7 @@ impl Board {
self.set_halfmove(info_parts.next())?;
self.set_fullmove(info_parts.next())?;

(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) =
self.generate_pos_keys();
(self.key, self.pawn_key, self.non_pawn_key, self.minor_key, self.major_key) = self.generate_pos_keys();
self.threats = self.generate_threats(self.side.flip());

Ok(())
Expand Down
13 changes: 8 additions & 5 deletions src/nnue/network/layers.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#[allow(dead_code)]
const AVX512CHUNK: usize = 512 / 32;
const FT_SHIFT: u32 = 10;
#[allow(clippy::cast_precision_loss)]
Expand Down Expand Up @@ -151,7 +150,7 @@ mod generic {
mod x86simd {
use super::{
super::{Align64, L1_SIZE, L2_SIZE, L3_SIZE, QA},
FT_SHIFT, L1_MUL,
AVX512CHUNK, FT_SHIFT, L1_MUL,
};
use crate::nnue::{
network::L1_CHUNK_PER_32,
Expand Down Expand Up @@ -466,22 +465,26 @@ mod x86simd {
bias: f32,
output: &mut f32,
) {
// These weird multiple-sum shenanigans are here to make sure we add the floats in exactly the same
// manner and order on ALL architectures, so that behaviour is deterministic.
// We multiply the weights by the inputs, and sum them up
const NUM_SUMS: usize = AVX512CHUNK / F32_CHUNK_SIZE;
// SAFETY: Breaking it down by unsafe operations:
// 1. get_unchecked[_mut]: We only ever index at most (L3_SIZE / F32_CHUNK_SIZE - 1) * F32_CHUNK_SIZE
// into the `weights` and `inputs` arrays. This is in bounds, as `weights` has length L3_SIZE and
// `inputs` has length L3_SIZE.
// 2. SIMD instructions: All of our loads and stores are aligned.
unsafe {
let mut sum = simd::zero_f32();
let mut sum_vecs = [simd::zero_f32(); NUM_SUMS];

// affine transform
for i in 0..L3_SIZE / F32_CHUNK_SIZE {
let weight_vec = simd::load_f32(weights.get_unchecked(i * F32_CHUNK_SIZE));
let input_vec = simd::load_f32(inputs.get_unchecked(i * F32_CHUNK_SIZE));
sum = simd::mul_add_f32(input_vec, weight_vec, sum);
sum_vecs[i % NUM_SUMS] = simd::mul_add_f32(input_vec, weight_vec, sum_vecs[i % NUM_SUMS]);
}

*output = simd::sum_f32(sum) + bias;
*output = simd::reduce_add_f32s(&sum_vecs) + bias;
}
}
}
Expand Down
26 changes: 7 additions & 19 deletions src/nnue/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ macro_rules! wrap_simd_register {
#[cfg(target_feature = "avx512f")]
mod avx512 {
#![allow(non_camel_case_types)]
use crate::nnue::network::Align64;
use std::arch::x86_64::*;

wrap_simd_register!(__m512i, i8, VecI8);
Expand Down Expand Up @@ -260,8 +259,8 @@ mod avx512 {
return _mm512_reduce_add_ps(vec.inner());
}
#[inline]
pub unsafe fn reduce_add_f32s(vec: &Align64<[f32; 1 * F32_CHUNK_SIZE]>) -> f32 {
return _mm512_reduce_add_ps(load_f32(vec.get_unchecked(0 * F32_CHUNK_SIZE)).inner());
pub unsafe fn reduce_add_f32s(vec: &[VecF32; 1]) -> f32 {
return _mm512_reduce_add_ps(vec.get_unchecked(0).inner());
}

pub const U8_CHUNK_SIZE: usize = std::mem::size_of::<VecI8>() / std::mem::size_of::<u8>();
Expand All @@ -274,7 +273,6 @@ mod avx512 {
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
mod avx2 {
#![allow(non_camel_case_types)]
use crate::nnue::network::Align64;
use std::arch::x86_64::*;

wrap_simd_register!(__m256i, i8, VecI8);
Expand Down Expand Up @@ -473,11 +471,8 @@ mod avx2 {
return _mm_cvtss_f32(sum_32);
}
#[inline]
pub unsafe fn reduce_add_f32s(vec: &Align64<[f32; 2 * F32_CHUNK_SIZE]>) -> f32 {
let vec = _mm256_add_ps(
load_f32(vec.get_unchecked(0 * F32_CHUNK_SIZE)).inner(),
load_f32(vec.get_unchecked(1 * F32_CHUNK_SIZE)).inner(),
);
pub unsafe fn reduce_add_f32s(vec: &[VecF32; 2]) -> f32 {
let vec = _mm256_add_ps(vec.get_unchecked(0).inner(), vec.get_unchecked(1).inner());

let upper_128 = _mm256_extractf128_ps(vec, 1);
let lower_128 = _mm256_castps256_ps128(vec);
Expand All @@ -502,7 +497,6 @@ mod avx2 {
#[cfg(all(target_feature = "ssse3", not(target_feature = "avx2"), not(target_feature = "avx512f")))]
mod ssse3 {
#![allow(non_camel_case_types)]
use crate::nnue::network::Align64;
use std::arch::x86_64::*;

wrap_simd_register!(__m128i, i8, VecI8);
Expand Down Expand Up @@ -695,15 +689,9 @@ mod ssse3 {
return _mm_cvtss_f32(sum_32);
}
#[inline]
pub unsafe fn reduce_add_f32s(vec: &Align64<[f32; 4 * F32_CHUNK_SIZE]>) -> f32 {
let vec_a = _mm_add_ps(
load_f32(vec.get_unchecked(0 * F32_CHUNK_SIZE)).inner(),
load_f32(vec.get_unchecked(2 * F32_CHUNK_SIZE)).inner(),
);
let vec_b = _mm_add_ps(
load_f32(vec.get_unchecked(1 * F32_CHUNK_SIZE)).inner(),
load_f32(vec.get_unchecked(3 * F32_CHUNK_SIZE)).inner(),
);
pub unsafe fn reduce_add_f32s(vec: &[VecF32; 4]) -> f32 {
let vec_a = _mm_add_ps(vec.get_unchecked(0).inner(), vec.get_unchecked(2).inner());
let vec_b = _mm_add_ps(vec.get_unchecked(1).inner(), vec.get_unchecked(3).inner());
let vec = _mm_add_ps(vec_a, vec_b);
let upper_64 = _mm_movehl_ps(vec, vec);
let sum_64 = _mm_add_ps(vec, upper_64);
Expand Down

0 comments on commit 9feceb0

Please sign in to comment.