Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Viren6 committed Oct 4, 2024
1 parent 4eb59ee commit 7bdb12e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 23 deletions.
16 changes: 11 additions & 5 deletions datagen/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
use datagen::{parse_args, run_datagen};
use monty::{read_into_struct_unchecked, ChessState, MctsParams, Uci};
use monty::{read_into_struct_unchecked, ChessState, MappedStruct, MctsParams, Uci};

fn main() {
let mut args = std::env::args();
args.next();

let policy = unsafe { read_into_struct_unchecked(monty::PolicyFileDefaultName) };
let value = unsafe { read_into_struct_unchecked(monty::ValueFileDefaultName) };
let policy_mapped: MappedStruct<monty::PolicyNetwork> =
unsafe { read_into_struct_unchecked(monty::PolicyFileDefaultName) };

let value_mapped: MappedStruct<monty::ValueNetwork> =
unsafe { read_into_struct_unchecked(monty::ValueFileDefaultName) };

let policy = &policy_mapped.data;
let value = &value_mapped.data;

let params = MctsParams::default();

if let Some(opts) = parse_args(args) {
run_datagen(params, opts, &policy, &value);
run_datagen(params, opts, policy, value);
} else {
Uci::bench(ChessState::BENCH_DEPTH, &policy, &value, &params);
Uci::bench(ChessState::BENCH_DEPTH, policy, value, &params);
}
}
32 changes: 20 additions & 12 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,19 @@ mod uci;

pub use chess::{Board, Castling, ChessState, GameState, Move};
pub use mcts::{Limits, MctsParams, Searcher};
use memmap2::Mmap;
pub use networks::{
PolicyFileDefaultName, PolicyNetwork, UnquantisedPolicyNetwork, UnquantisedValueNetwork,
ValueFileDefaultName, ValueNetwork,
};
pub use tree::Tree;
pub use uci::Uci;

pub struct MappedStruct<'a, T> {
pub mmap: Mmap, // The memory-mapped file
pub data: &'a T, // A reference to the data in the mmap
}

// Macro for calculating tables (until const fn pointers are stable).
#[macro_export]
macro_rules! init {
Expand Down Expand Up @@ -50,24 +56,26 @@ pub unsafe fn boxed_and_zeroed<T>() -> Box<T> {

/// # Safety
/// Only to be used internally.
pub unsafe fn read_into_struct_unchecked<T>(path: &str) -> Box<T> {
use memmap2::Mmap;

pub unsafe fn read_into_struct_unchecked<'a, T>(path: &str) -> MappedStruct<'a, T> {
let f = std::fs::File::open(path).unwrap();
let mmap = Mmap::map(&f).unwrap();

let mut x: Box<T> = boxed_and_zeroed();

let size = std::mem::size_of::<T>();
let file_size = mmap.len();
assert_eq!(
file_size, size,
"File size does not match the size of the structure"
);

let file_size = f.metadata().unwrap().len();
let ptr = mmap.as_ptr() as *const T;

assert_eq!(file_size as usize, size);

unsafe {
let slice = std::slice::from_raw_parts_mut(x.as_mut() as *mut T as *mut u8, size);
slice.copy_from_slice(&mmap[..size]);
// Check if the pointer is properly aligned
if (ptr as usize) % std::mem::align_of::<T>() != 0 {
panic!("Memory is not properly aligned for the type");
}

x
MappedStruct {
mmap, // This ensures the memory is valid as long as MappedStruct exists
data: &*ptr,
}
}
17 changes: 11 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,31 @@ mod net {

#[cfg(not(feature = "embed"))]
mod nonet {
use monty::{read_into_struct_unchecked, ChessState, MctsParams, Uci};
use monty::{read_into_struct_unchecked, ChessState, MappedStruct, MctsParams, Uci};

pub fn run() {
let mut args = std::env::args();
let arg1 = args.nth(1);

let policy = unsafe { read_into_struct_unchecked(monty::PolicyFileDefaultName) };
let policy_mapped: MappedStruct<monty::PolicyNetwork> =
unsafe { read_into_struct_unchecked(monty::PolicyFileDefaultName) };

let value = unsafe { read_into_struct_unchecked(monty::ValueFileDefaultName) };
let value_mapped: MappedStruct<monty::ValueNetwork> =
unsafe { read_into_struct_unchecked(monty::ValueFileDefaultName) };

let policy = policy_mapped.data;
let value = value_mapped.data;

if let Some("bench") = arg1.as_deref() {
Uci::bench(
ChessState::BENCH_DEPTH,
&policy,
&value,
policy,
value,
&MctsParams::default(),
);
return;
}

Uci::run(&policy, &value);
Uci::run(policy, value);
}
}

0 comments on commit 7bdb12e

Please sign in to comment.