Strong typing for SIMD (#184)
* use stronk typing for simd

Bench: 13292315

* register type names
cosmobobak authored Aug 13, 2024
1 parent 5f2478e commit fa39a7b
Showing 5 changed files with 398 additions and 378 deletions.
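Note: this commit replaces the free, intrinsic-style helpers (`vec_add_epi16`, `vec_load_ps`, ...) and lowercase register aliases (`vepi32`, `vepi8`, `vps32`) with a `simd` module exposing named register types (`VecI32`, `VecI8`, `VecF32`) and typed operations (`simd::add_i16`, `simd::load_f32`, ...). The type definitions themselves live in src/nnue/simd.rs, which is not rendered below; the following is only a minimal sketch of the idea, assuming AVX2, with illustrative names and signatures:

```rust
// Sketch of the strongly-typed wrapper approach (illustrative, not the real
// src/nnue/simd.rs): a #[repr(transparent)] newtype per lane type makes mixing
// register widths a compile error instead of a silent reinterpretation.
use std::arch::x86_64::*;

/// One AVX2 register, viewed as 16 lanes of i16.
#[derive(Clone, Copy)]
#[repr(transparent)]
pub struct VecI16(__m256i);

pub const I16_CHUNK_SIZE: usize = std::mem::size_of::<VecI16>() / std::mem::size_of::<i16>();

#[inline]
pub unsafe fn zero_i16() -> VecI16 {
    VecI16(_mm256_setzero_si256())
}

#[inline]
pub unsafe fn load_i16(src: *const i16) -> VecI16 {
    // The accumulators are 64-byte aligned (Align64), so an aligned load is fine.
    VecI16(_mm256_load_si256(src.cast()))
}

#[inline]
pub unsafe fn store_i16(dst: *mut i16, v: VecI16) {
    _mm256_store_si256(dst.cast(), v.0);
}

#[inline]
pub unsafe fn add_i16(a: VecI16, b: VecI16) -> VecI16 {
    VecI16(_mm256_add_epi16(a.0, b.0))
}

#[inline]
pub unsafe fn sub_i16(a: VecI16, b: VecI16) -> VecI16 {
    VecI16(_mm256_sub_epi16(a.0, b.0))
}
```

With a distinct wrapper type per lane width, passing an i16 vector where an i32 vector is expected no longer type-checks, which is the point of the change.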
13 changes: 6 additions & 7 deletions src/datagen.rs
@@ -199,7 +199,7 @@ pub fn gen_data_main(cli_config: DataGenOptionsBuilder) -> anyhow::Result<()> {
         .collect::<Vec<_>>();
     for handle in thread_handles {
         if let Ok(res) = handle.join() {
-            counters.push(res);
+            counters.push(res?);
         } else {
             bail!("Thread failed to join!");
         }
@@ -211,16 +211,15 @@ pub fn gen_data_main(cli_config: DataGenOptionsBuilder) -> anyhow::Result<()> {
         println!("Done!");
     }
 
-    let counters = counters.into_iter().reduce(|a, b| {
-        let mut a = a?;
-        for (key, value) in b? {
-            *a.entry(key).or_insert(0) += value;
+    let counters = counters.into_iter().reduce(|mut acc, e| {
+        for (key, value) in e {
+            *acc.entry(key).or_insert(0) += value;
         }
-        Ok(a)
+        acc
     });
 
     if let Some(counters) = counters {
-        print_game_stats(&counters?);
+        print_game_stats(&counters);
     }
 
     Ok(())
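Note: the datagen change hoists error propagation to the join site. `res?` surfaces a failed worker immediately, so `counters` holds plain maps rather than `Result`s, and the reduce no longer needs `?` or `Ok(...)`. A self-contained model of the new merge (assuming `HashMap<String, u64>` counters; the real key and count types are not shown in this hunk):

```rust
use std::collections::HashMap;

// Merge per-thread counters by summing values key-by-key; None if no threads ran.
fn merge_counters(counters: Vec<HashMap<String, u64>>) -> Option<HashMap<String, u64>> {
    counters.into_iter().reduce(|mut acc, e| {
        for (key, value) in e {
            *acc.entry(key).or_insert(0) += value;
        }
        acc
    })
}
```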
66 changes: 32 additions & 34 deletions src/nnue/accumulator.rs
@@ -48,9 +48,7 @@ unsafe fn slice_to_aligned<'a>(slice: &'a [i16]) -> &'a Align64<[i16; L1_SIZE]>
 #[cfg(target_feature = "avx2")]
 mod avx2 {
     use super::{slice_to_aligned, Align64, FeatureIndex, INPUT, L1_SIZE};
-    use crate::nnue::simd::{
-        vec_add_epi16, vec_load_epi16, vec_store_epi16, vec_sub_epi16, vec_zero_epi16, I16_CHUNK_SIZE,
-    };
+    use crate::nnue::simd::{self, I16_CHUNK_SIZE};
 
     /// Apply add/subtract updates in place.
     pub fn vector_update_inplace(
@@ -64,30 +62,30 @@ mod avx2 {
         // SAFETY: we never hold multiple mutable references, we never mutate immutable memory,
         // we use iterators to ensure that we're staying in-bounds, etc.
         unsafe {
-            let mut registers = [vec_zero_epi16(); 16];
+            let mut registers = [simd::zero_i16(); 16];
             for i in 0..L1_SIZE / UNROLL {
                 let unroll_offset = i * UNROLL;
                 for (r_idx, reg) in registers.iter_mut().enumerate() {
-                    *reg = vec_load_epi16(input.get_unchecked(unroll_offset + r_idx * I16_CHUNK_SIZE));
+                    *reg = simd::load_i16(input.get_unchecked(unroll_offset + r_idx * I16_CHUNK_SIZE));
                 }
                 for &sub_index in subs {
                     let sub_index = sub_index.index() * L1_SIZE;
                     let sub_block = slice_to_aligned(bucket.get_unchecked(sub_index..sub_index + L1_SIZE));
                     for (r_idx, reg) in registers.iter_mut().enumerate() {
-                        let sub = vec_load_epi16(sub_block.get_unchecked(unroll_offset + r_idx * I16_CHUNK_SIZE));
-                        *reg = vec_sub_epi16(*reg, sub);
+                        let sub = simd::load_i16(sub_block.get_unchecked(unroll_offset + r_idx * I16_CHUNK_SIZE));
+                        *reg = simd::sub_i16(*reg, sub);
                     }
                 }
                 for &add_index in adds {
                     let add_index = add_index.index() * L1_SIZE;
                     let add_block = slice_to_aligned(bucket.get_unchecked(add_index..add_index + L1_SIZE));
                     for (r_idx, reg) in registers.iter_mut().enumerate() {
-                        let add = vec_load_epi16(add_block.get_unchecked(unroll_offset + r_idx * I16_CHUNK_SIZE));
-                        *reg = vec_add_epi16(*reg, add);
+                        let add = simd::load_i16(add_block.get_unchecked(unroll_offset + r_idx * I16_CHUNK_SIZE));
+                        *reg = simd::add_i16(*reg, add);
                     }
                 }
                 for (r_idx, reg) in registers.iter().enumerate() {
-                    vec_store_epi16(input.get_unchecked_mut(unroll_offset + r_idx * I16_CHUNK_SIZE), *reg);
+                    simd::store_i16(input.get_unchecked_mut(unroll_offset + r_idx * I16_CHUNK_SIZE), *reg);
                 }
             }
         }
@@ -115,12 +113,12 @@ mod avx2 {
         for i in 0..L1_SIZE / I16_CHUNK_SIZE {
             // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
             unsafe {
-                let x = vec_load_epi16(input.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_sub = vec_load_epi16(s_block.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_add = vec_load_epi16(a_block.get_unchecked(i * I16_CHUNK_SIZE));
-                let t = vec_sub_epi16(x, w_sub);
-                let t = vec_add_epi16(t, w_add);
-                vec_store_epi16(output.get_unchecked_mut(i * I16_CHUNK_SIZE), t);
+                let x = simd::load_i16(input.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_sub = simd::load_i16(s_block.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_add = simd::load_i16(a_block.get_unchecked(i * I16_CHUNK_SIZE));
+                let t = simd::sub_i16(x, w_sub);
+                let t = simd::add_i16(t, w_add);
+                simd::store_i16(output.get_unchecked_mut(i * I16_CHUNK_SIZE), t);
             }
         }
     }
@@ -151,14 +149,14 @@ mod avx2 {
         for i in 0..L1_SIZE / I16_CHUNK_SIZE {
             // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
             unsafe {
-                let x = vec_load_epi16(input.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_sub1 = vec_load_epi16(s_block1.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_sub2 = vec_load_epi16(s_block2.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_add = vec_load_epi16(a_block.get_unchecked(i * I16_CHUNK_SIZE));
-                let t = vec_sub_epi16(x, w_sub1);
-                let t = vec_sub_epi16(t, w_sub2);
-                let t = vec_add_epi16(t, w_add);
-                vec_store_epi16(output.get_unchecked_mut(i * I16_CHUNK_SIZE), t);
+                let x = simd::load_i16(input.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_sub1 = simd::load_i16(s_block1.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_sub2 = simd::load_i16(s_block2.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_add = simd::load_i16(a_block.get_unchecked(i * I16_CHUNK_SIZE));
+                let t = simd::sub_i16(x, w_sub1);
+                let t = simd::sub_i16(t, w_sub2);
+                let t = simd::add_i16(t, w_add);
+                simd::store_i16(output.get_unchecked_mut(i * I16_CHUNK_SIZE), t);
             }
         }
     }
@@ -193,16 +191,16 @@ mod avx2 {
         for i in 0..L1_SIZE / I16_CHUNK_SIZE {
             // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
             unsafe {
-                let x = vec_load_epi16(input.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_sub1 = vec_load_epi16(s_block1.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_sub2 = vec_load_epi16(s_block2.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_add1 = vec_load_epi16(a_block1.get_unchecked(i * I16_CHUNK_SIZE));
-                let w_add2 = vec_load_epi16(a_block2.get_unchecked(i * I16_CHUNK_SIZE));
-                let t = vec_sub_epi16(x, w_sub1);
-                let t = vec_sub_epi16(t, w_sub2);
-                let t = vec_add_epi16(t, w_add1);
-                let t = vec_add_epi16(t, w_add2);
-                vec_store_epi16(output.get_unchecked_mut(i * I16_CHUNK_SIZE), t);
+                let x = simd::load_i16(input.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_sub1 = simd::load_i16(s_block1.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_sub2 = simd::load_i16(s_block2.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_add1 = simd::load_i16(a_block1.get_unchecked(i * I16_CHUNK_SIZE));
+                let w_add2 = simd::load_i16(a_block2.get_unchecked(i * I16_CHUNK_SIZE));
+                let t = simd::sub_i16(x, w_sub1);
+                let t = simd::sub_i16(t, w_sub2);
+                let t = simd::add_i16(t, w_add1);
+                let t = simd::add_i16(t, w_add2);
+                simd::store_i16(output.get_unchecked_mut(i * I16_CHUNK_SIZE), t);
             }
         }
     }
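Note: all of these accumulator kernels compute the same thing at different arities: an NNUE accumulator update of the form out = in − Σ sub_blocks + Σ add_blocks, lane by lane over i16 weights. A scalar reference for the single-sub/single-add case (slice names taken from the diff context):

```rust
// Scalar equivalent of the add/sub kernel above; the SIMD version does the same
// arithmetic I16_CHUNK_SIZE lanes at a time. Wrapping ops mirror the modular
// behavior of the packed-integer instructions.
fn add_sub_scalar(input: &[i16], output: &mut [i16], s_block: &[i16], a_block: &[i16]) {
    for i in 0..input.len() {
        output[i] = input[i].wrapping_sub(s_block[i]).wrapping_add(a_block[i]);
    }
}
```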
103 changes: 51 additions & 52 deletions src/nnue/network/layers.rs
@@ -117,13 +117,7 @@ mod x86simd {
     use super::super::{Align64, L1_SIZE, L2_SIZE, L3_SIZE, QA, QB};
     use crate::nnue::{
         network::L1_CHUNK_PER_32,
-        simd::{
-            vec_cvtepi32_ps, vec_dpbusd_epi32, vec_load_epi16, vec_load_epi32, vec_load_ps, vec_max_epi16, vec_max_ps,
-            vec_min_epi16, vec_min_ps, vec_mul_add_ps, vec_mul_ps, vec_mulhi_epi16, vec_nnz_mask,
-            vec_packus_permute_epi16, vec_reduce_add_ps, vec_set1_epi16, vec_set1_epi32, vec_set1_ps, vec_slli_epi16,
-            vec_store_epiu8, vec_store_ps, vec_zero_epi16, vec_zero_epi32, vec_zero_ps, vepi32, vepi8, vps32, S,
-            F32_CHUNK_SIZE, I16_CHUNK_SIZE, I32_CHUNK_SIZE,
-        },
+        simd::{self, VecF32, VecI32, VecI8, F32_CHUNK_SIZE, I16_CHUNK_SIZE, I32_CHUNK_SIZE, S},
     };
     use std::mem::MaybeUninit;
@@ -168,7 +162,7 @@ mod x86simd {
     use std::arch::x86_64::_mm_setzero_si128 as vec128_zero;
     use std::arch::x86_64::_mm_storeu_si128 as vec128_storeu;
 
-    const INPUT_SIMD_WIDTH: usize = std::mem::size_of::<vepi32>() / std::mem::size_of::<i32>();
+    const INPUT_SIMD_WIDTH: usize = std::mem::size_of::<VecI32>() / std::mem::size_of::<i32>();
     const CHUNK_SIZE: usize = max!(INPUT_SIMD_WIDTH, 8);
     const NUM_CHUNKS: usize = (L1_SIZE / L1_CHUNK_PER_32) / CHUNK_SIZE;
     const INPUTS_PER_CHUNK: usize = CHUNK_SIZE / INPUT_SIMD_WIDTH;
@@ -181,8 +175,8 @@ mod x86simd {
             // bitmask of nonzero values in this chunk
             let mut nnz = 0;
             for j in 0..INPUTS_PER_CHUNK {
-                let input_chunk = vec_load_epi32(input.get_unchecked((i * INPUTS_PER_CHUNK + j) * I32_CHUNK_SIZE));
-                nnz |= u32::from(vec_nnz_mask(input_chunk)) << (j * INPUT_SIMD_WIDTH);
+                let input_chunk = simd::load_i32(input.get_unchecked((i * INPUTS_PER_CHUNK + j) * I32_CHUNK_SIZE));
+                nnz |= u32::from(simd::nonzero_mask_i32(input_chunk)) << (j * INPUT_SIMD_WIDTH);
             }
             for j in 0..OUTPUTS_PER_CHUNK {
                 let lookup = (nnz >> (j * 8)) & 0xFF;
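Note: `simd::nonzero_mask_i32` (previously `vec_nnz_mask`) packs one bit per i32 lane into an integer mask, which `find_nnz` then consumes byte by byte to gather the indices of nonzero inputs. Its scalar meaning, as a sketch:

```rust
// Scalar model of nonzero_mask_i32: bit j of the result is set iff lane j of the
// input vector is nonzero. The SIMD version presumably uses a compare-with-zero
// followed by a movemask-style extraction.
fn nonzero_mask(lanes: &[i32]) -> u32 {
    lanes.iter().enumerate().fold(0u32, |mask, (j, &x)| {
        if x != 0 { mask | (1 << j) } else { mask }
    })
}
```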
@@ -203,6 +197,7 @@ mod x86simd {
         clippy::cast_possible_truncation,
         clippy::cast_precision_loss,
         clippy::cast_ptr_alignment,
+        clippy::cast_possible_wrap,
         clippy::needless_range_loop,
         clippy::similar_names
     )]
@@ -218,29 +213,29 @@ mod x86simd {
 
         // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
         unsafe {
-            let ft_zero = vec_zero_epi16();
-            let ft_one = vec_set1_epi16(QA as i16);
+            let ft_zero = simd::zero_i16();
+            let ft_one = simd::splat_i16(QA as i16);
 
             let mut ft_outputs: Align64<[MaybeUninit<u8>; L1_SIZE]> = MaybeUninit::uninit().assume_init();
 
             let mut offset = 0;
             for acc in [us, them] {
                 for i in (0..L1_PAIR_COUNT).step_by(I16_CHUNK_SIZE * 2) {
-                    let input0a = vec_load_epi16(acc.get_unchecked(i + 0 + 0));
-                    let input0b = vec_load_epi16(acc.get_unchecked(i + I16_CHUNK_SIZE + 0));
-                    let input1a = vec_load_epi16(acc.get_unchecked(i + 0 + L1_PAIR_COUNT));
-                    let input1b = vec_load_epi16(acc.get_unchecked(i + I16_CHUNK_SIZE + L1_PAIR_COUNT));
-
-                    let clipped0a = vec_min_epi16(vec_max_epi16(input0a, ft_zero), ft_one);
-                    let clipped0b = vec_min_epi16(vec_max_epi16(input0b, ft_zero), ft_one);
-                    let clipped1a = vec_min_epi16(input1a, ft_one);
-                    let clipped1b = vec_min_epi16(input1b, ft_one);
-
-                    let producta = vec_mulhi_epi16(vec_slli_epi16::<{ 16 - FT_SHIFT as S }>(clipped0a), clipped1a);
-                    let productb = vec_mulhi_epi16(vec_slli_epi16::<{ 16 - FT_SHIFT as S }>(clipped0b), clipped1b);
-                    vec_store_epiu8(
+                    let input0a = simd::load_i16(acc.get_unchecked(i + 0 + 0));
+                    let input0b = simd::load_i16(acc.get_unchecked(i + I16_CHUNK_SIZE + 0));
+                    let input1a = simd::load_i16(acc.get_unchecked(i + 0 + L1_PAIR_COUNT));
+                    let input1b = simd::load_i16(acc.get_unchecked(i + I16_CHUNK_SIZE + L1_PAIR_COUNT));
+
+                    let clipped0a = simd::min_i16(simd::max_i16(input0a, ft_zero), ft_one);
+                    let clipped0b = simd::min_i16(simd::max_i16(input0b, ft_zero), ft_one);
+                    let clipped1a = simd::min_i16(input1a, ft_one);
+                    let clipped1b = simd::min_i16(input1b, ft_one);
+
+                    let producta = simd::mul_high_i16(simd::shl_i16::<{ 16 - FT_SHIFT as S }>(clipped0a), clipped1a);
+                    let productb = simd::mul_high_i16(simd::shl_i16::<{ 16 - FT_SHIFT as S }>(clipped0b), clipped1b);
+                    simd::store_u8(
                         std::ptr::from_mut(ft_outputs.get_unchecked_mut(offset + i)).cast(),
-                        vec_packus_permute_epi16(producta, productb),
+                        simd::pack_i16_to_unsigned_and_permute(producta, productb),
                     );
                 }
                 offset += L1_PAIR_COUNT;
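Note: the shift-then-`mulhi` pair is a fixed-point trick: for i16 values that don't overflow the shift, `mulhi(a << (16 - FT_SHIFT), b)` equals `(a * b) >> FT_SHIFT` without widening to i32. In scalar form, with the unsigned-saturating pack modeled by a clamp (QA and FT_SHIFT as in the source; this sketch ignores the lane permute that `pack_i16_to_unsigned_and_permute` also performs):

```rust
// Scalar model of one feature-transformer activation pair.
fn ft_activate(x0: i16, x1: i16, qa: i16, ft_shift: u32) -> u8 {
    let c0 = i32::from(x0.clamp(0, qa)); // clipped from both sides
    let c1 = i32::from(x1.min(qa));      // second half only clipped from above
    ((c0 * c1) >> ft_shift).clamp(0, 255) as u8
}
```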
@@ -262,28 +257,32 @@ mod x86simd {
 
             let nnz_count = find_nnz(input32, &mut nnz);
 
-            let mut sums = [vec_zero_epi32(); L2_SIZE / F32_CHUNK_SIZE];
+            let mut sums = [simd::zero_i32(); L2_SIZE / F32_CHUNK_SIZE];
 
             for &i in nnz.get_unchecked(..nnz_count) {
                 let i = i.assume_init();
-                let input = vec_set1_epi32(*input32.get_unchecked(i as usize));
+                let input = simd::splat_i32(*input32.get_unchecked(i as usize));
                 let i_col = i as usize * L2_SIZE * L1_CHUNK_PER_32;
-                let col = std::ptr::from_ref(weights.get_unchecked(i_col)).cast::<vepi8>();
+                let col = std::ptr::from_ref(weights.get_unchecked(i_col)).cast::<VecI8>();
                 for k in 0..L2_SIZE / F32_CHUNK_SIZE {
-                    *sums.get_unchecked_mut(k) = vec_dpbusd_epi32(*sums.get_unchecked(k), input, *col.add(k));
+                    *sums.get_unchecked_mut(k) = simd::mul_add_u8_to_i32(
+                        *sums.get_unchecked(k),
+                        simd::reinterpret_i32s_as_i8s(input),
+                        *col.add(k),
+                    );
                 }
             }
 
-            let zero = vec_zero_ps();
-            let one = vec_set1_ps(1.0);
-            let sum_mul = vec_set1_ps(L1_MUL);
+            let zero = simd::zero_f32();
+            let one = simd::splat_f32(1.0);
+            let sum_mul = simd::splat_f32(L1_MUL);
             for i in 0..L2_SIZE / F32_CHUNK_SIZE {
                 // Convert into floats, and activate L1
-                let bias_vec = vec_load_ps(biases.get_unchecked(i * F32_CHUNK_SIZE));
-                let sum_ps = vec_mul_add_ps(vec_cvtepi32_ps(*sums.get_unchecked(i)), sum_mul, bias_vec);
-                let clipped = vec_min_ps(vec_max_ps(sum_ps, zero), one);
-                let squared = vec_mul_ps(clipped, clipped);
-                vec_store_ps(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
+                let bias_vec = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
+                let sum_ps = simd::mul_add_f32(simd::i32_to_f32(*sums.get_unchecked(i)), sum_mul, bias_vec);
+                let clipped = simd::min_f32(simd::max_f32(sum_ps, zero), one);
+                let squared = simd::mul_f32(clipped, clipped);
+                simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
             }
         }
     }
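Note: `simd::mul_add_u8_to_i32` is the strongly-typed face of the old `vec_dpbusd_epi32` (a VPDPBUSD-style operation), and the new `reinterpret_i32s_as_i8s` makes explicit that the splatted i32 input is consumed as four u8 lanes. Per i32 accumulator lane it computes, roughly:

```rust
// Scalar model of one lane of mul_add_u8_to_i32: four unsigned-u8 × signed-i8
// products summed into the i32 accumulator (intermediate saturation ignored here).
fn dpbusd_lane(acc: i32, a: [u8; 4], b: [i8; 4]) -> i32 {
    acc + (0..4).map(|k| i32::from(a[k]) * i32::from(b[k])).sum::<i32>()
}
```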
@@ -297,27 +296,27 @@ mod x86simd {
     ) {
         // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
         unsafe {
-            let mut sum_vecs = [vec_zero_ps(); L3_SIZE / F32_CHUNK_SIZE];
+            let mut sum_vecs = [simd::zero_f32(); L3_SIZE / F32_CHUNK_SIZE];
 
             for i in 0..L3_SIZE / F32_CHUNK_SIZE {
-                *sum_vecs.get_unchecked_mut(i) = vec_load_ps(biases.get_unchecked(i * F32_CHUNK_SIZE));
+                *sum_vecs.get_unchecked_mut(i) = simd::load_f32(biases.get_unchecked(i * F32_CHUNK_SIZE));
             }
 
             for i in 0..L2_SIZE {
-                let input_vec = vec_set1_ps(*inputs.get_unchecked(i));
-                let weight = std::ptr::from_ref(weights.get_unchecked(i * L3_SIZE)).cast::<vps32>();
+                let input_vec = simd::splat_f32(*inputs.get_unchecked(i));
+                let weight = std::ptr::from_ref(weights.get_unchecked(i * L3_SIZE)).cast::<VecF32>();
                 for j in 0..L3_SIZE / F32_CHUNK_SIZE {
                     *sum_vecs.get_unchecked_mut(j) =
-                        vec_mul_add_ps(input_vec, *weight.add(j), *sum_vecs.get_unchecked(j));
+                        simd::mul_add_f32(input_vec, *weight.add(j), *sum_vecs.get_unchecked(j));
                 }
             }
 
             // Activate L2
-            let one = vec_set1_ps(1.0);
+            let one = simd::splat_f32(1.0);
             for i in 0..L3_SIZE / F32_CHUNK_SIZE {
-                let clipped = vec_min_ps(vec_max_ps(*sum_vecs.get_unchecked(i), vec_zero_ps()), one);
-                let squared = vec_mul_ps(clipped, clipped);
-                vec_store_ps(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
+                let clipped = simd::min_f32(simd::max_f32(*sum_vecs.get_unchecked(i), simd::zero_f32()), one);
+                let squared = simd::mul_f32(clipped, clipped);
+                simd::store_f32(output.get_unchecked_mut(i * F32_CHUNK_SIZE), squared);
             }
         }
     }
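Note: the L1 and L2 activations use the same nonlinearity: clamp to [0, 1], then square (a squared clipped ReLU, commonly called SCReLU in engine NNUE code). A scalar sketch of what each vector lane computes:

```rust
// Scalar model of the per-lane activation in the loops above.
fn screlu(x: f32) -> f32 {
    let clipped = x.clamp(0.0, 1.0);
    clipped * clipped
}
```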
@@ -330,16 +329,16 @@ mod x86simd {
     ) {
         // SAFETY: we never hold multiple mutable references, we never mutate immutable memory, etc.
         unsafe {
-            let mut sum_vec = vec_zero_ps();
+            let mut sum_vec = simd::zero_f32();
 
             // Affine transform for L3
             for i in (0..L3_SIZE).step_by(F32_CHUNK_SIZE) {
-                let weight_vec = vec_load_ps(weights.get_unchecked(i));
-                let input_vec = vec_load_ps(inputs.get_unchecked(i));
-                sum_vec = vec_mul_add_ps(input_vec, weight_vec, sum_vec);
+                let weight_vec = simd::load_f32(weights.get_unchecked(i));
+                let input_vec = simd::load_f32(inputs.get_unchecked(i));
+                sum_vec = simd::mul_add_f32(input_vec, weight_vec, sum_vec);
             }
 
-            *output = bias + vec_reduce_add_ps(sum_vec);
+            *output = bias + simd::sum_f32(sum_vec);
         }
     }
 }
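Note: the output layer is a plain dot product accumulated with fused multiply-adds, finished by a horizontal sum (`simd::sum_f32`, previously `vec_reduce_add_ps`) and a bias. Scalar equivalent:

```rust
// Scalar equivalent of the final layer: dot(inputs, weights) + bias.
fn output_layer(inputs: &[f32], weights: &[f32], bias: f32) -> f32 {
    bias + inputs.iter().zip(weights).map(|(x, w)| x * w).sum::<f32>()
}
```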
(Diffs for the remaining two changed files are not rendered on this page.)
