Skip to content

Commit

Permalink
Restructure Tree Layout (#65)
Browse files Browse the repository at this point in the history
Rather than the tree being composed of nodes that each contain a vector of edges, the tree is now constructed solely of nodes. The children of a node are stored contiguously in the tree, so each node only needs to store the index of its first child and the number of children it has.

This massively reduces the allocation and freeing of vectors (which was previously necessary to save as much memory as possible). It also allows the **exact** size of the tree to be set (which is a big deal for CCC and TCEC).

Passed STC:
LLR: 3.01 (-2.94,2.94) <-3.50,0.50>
Total: 3520 W: 881 L: 733 D: 1906
Ptnml(0-2): 42, 370, 812, 470, 66
https://tests.montychess.org/tests/view/671e812fb12c9e78f1354e34

Passed LTC:
LLR: 2.93 (-2.94,2.94) <-3.50,0.50>
Total: 1608 W: 392 L: 256 D: 960
Ptnml(0-2): 8, 128, 402, 252, 14
https://tests.montychess.org/tests/view/671e8537b12c9e78f1354e3a

Passed STC SMP:
LLR: 3.21 (-2.94,2.94) <-3.50,0.50>
Total: 1416 W: 441 L: 273 D: 702
Ptnml(0-2): 12, 106, 319, 244, 27
https://tests.montychess.org/tests/view/671e9a68b12c9e78f1354e55

Bench: 1488195
  • Loading branch information
jw1912 authored Oct 28, 2024
1 parent 379305f commit 078280e
Show file tree
Hide file tree
Showing 10 changed files with 463 additions and 598 deletions.
9 changes: 6 additions & 3 deletions datagen/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,12 @@ impl<'a> DatagenThread<'a> {
} else {
let mut dist = Vec::new();

for action in tree[tree.root_node()].actions().iter() {
let mov = montyformat::chess::Move::from(action.mov());
dist.push((mov, action.visits() as u32));
let actions = { *tree[tree.root_node()].actions() };

for action in 0..tree[tree.root_node()].num_actions() {
let node = &tree[actions + action];
let mov = montyformat::chess::Move::from(u16::from(node.parent_move()));
dist.push((mov, node.visits() as u32));
}

assert_eq!(root_count, dist.len());
Expand Down
181 changes: 110 additions & 71 deletions src/mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use params::MctsParams;

use crate::{
chess::Move,
tree::{ActionStats, Edge, NodePtr, Tree},
tree::{Node, NodePtr, Tree},
ChessState, GameState, PolicyNetwork, ValueNetwork,
};

Expand Down Expand Up @@ -104,14 +104,10 @@ impl<'a> Searcher<'a> {
let mut pos = self.root_position.clone();
let mut this_depth = 0;

if let Some(u) = self.perform_one_iteration(
&mut pos,
self.tree.root_node(),
self.tree.root_stats(),
&mut this_depth,
) {
self.tree.root_stats().update(u);
} else {
if self
.perform_one_iteration(&mut pos, self.tree.root_node(), &mut this_depth)
.is_none()
{
return false;
}

Expand Down Expand Up @@ -163,7 +159,7 @@ impl<'a> Searcher<'a> {
}
}

let new_best_move = self.get_best_move();
let (_, new_best_move, _) = self.get_best_action(self.tree.root_node());
if new_best_move != *best_move {
*best_move = new_best_move;
*best_move_changes += 1;
Expand Down Expand Up @@ -229,24 +225,41 @@ impl<'a> Searcher<'a> {
) -> (Move, f32) {
let timer = Instant::now();

// attempt to reuse the current tree stored in memory
let node = self.tree.root_node();

// relabel root policies with root PST value
if self.tree[node].has_children() {
self.tree[node].relabel_policy(&self.root_position, self.params, self.policy, 1);
// the root node is added to an empty tree, **and not counted** towards the
// total node count, in order for `go nodes 1` to give the expected result
if self.tree.is_empty() {
let ptr = self.tree.push_new_node().unwrap();

assert_eq!(node, ptr);

self.tree[ptr].clear();
self.tree
.expand_node(ptr, &self.root_position, self.params, self.policy, 1);

let root_eval = self.root_position.get_value_wdl(self.value, self.params);
self.tree[ptr].update(1.0 - root_eval);
}
// relabel preexisting root policies with root PST value
else if self.tree[node].has_children() {
self.tree
.relabel_policy(node, &self.root_position, self.params, self.policy, 1);

let first_child_ptr = { *self.tree[node].actions() };

for action in 0..self.tree[node].num_actions() {
let ptr = first_child_ptr + action;

for action in &*self.tree[node].actions() {
if action.ptr().is_null() || !self.tree[action.ptr()].has_children() {
if ptr.is_null() || !self.tree[ptr].has_children() {
continue;
}

let mut position = self.root_position.clone();
position.make_move(Move::from(action.mov()));
self.tree[action.ptr()].relabel_policy(&position, self.params, self.policy, 2);
position.make_move(self.tree[ptr].parent_move());
self.tree
.relabel_policy(ptr, &position, self.params, self.policy, 2);
}
} else {
self.tree[node].expand(&self.root_position, self.params, self.policy, 1);
}

let search_stats = SearchStats::default();
Expand Down Expand Up @@ -291,24 +304,28 @@ impl<'a> Searcher<'a> {
);
}

let best_action = self.get_best_action();
(Move::from(best_action.mov()), best_action.q())
let (_, mov, q) = self.get_best_action(self.tree.root_node());
(mov, q)
}

fn perform_one_iteration(
&self,
pos: &mut ChessState,
ptr: NodePtr,
node_stats: &ActionStats,
depth: &mut usize,
) -> Option<f32> {
*depth += 1;

let hash = pos.hash();
let node = &self.tree[ptr];

let mut u = if node.is_terminal() || node.visits() == 0 {
if node.visits() == 0 {
node.set_state(pos.game_state());
}

let u = if self.tree[ptr].is_terminal() || node_stats.visits() == 0 {
// probe hash table to use in place of network
if self.tree[ptr].state() == GameState::Ongoing {
if node.state() == GameState::Ongoing {
if let Some(entry) = self.tree.probe_hash(hash) {
entry.q()
} else {
Expand All @@ -319,38 +336,59 @@ impl<'a> Searcher<'a> {
}
} else {
// expand node on the second visit
if self.tree[ptr].is_not_expanded() {
self.tree[ptr].expand(pos, self.params, self.policy, *depth);
if node.is_not_expanded() {
self.tree
.expand_node(ptr, pos, self.params, self.policy, *depth)?;
}

// this node has now been accessed so we need to move its
// children across if they are in the other tree half
self.tree.fetch_children(ptr)?;

// select action to take via PUCT
let action = self.pick_action(ptr, node_stats);
let action = self.pick_action(ptr, node);

let edge = self.tree.edge_copy(ptr, action);
let first_child_ptr = { *node.actions() };
let child_ptr = first_child_ptr + action;

pos.make_move(Move::from(edge.mov()));
let mov = self.tree[child_ptr].parent_move();

let child_ptr = self.tree.fetch_node(pos, ptr, edge.ptr(), action)?;
pos.make_move(mov);

self.tree[child_ptr].inc_threads();

// acquire lock to avoid issues with desynced setting of
// game state between threads when threads > 1
let lock = if self.tree[child_ptr].visits() == 0 {
Some(node.actions_mut())
} else {
None
};

// descend further
let maybe_u = self.perform_one_iteration(pos, child_ptr, &edge.stats(), depth);
let maybe_u = self.perform_one_iteration(pos, child_ptr, depth);

drop(lock);

self.tree[child_ptr].dec_threads();

let u = maybe_u?;

let new_q = self.tree.update_edge_stats(ptr, action, u);
self.tree.push_hash(hash, new_q);

self.tree
.propogate_proven_mates(ptr, self.tree[child_ptr].state());

u
};

Some(1.0 - u)
// node scores are stored from the perspective
// **of the parent**, as they are usually only
// accessed from the parent's POV
u = 1.0 - u;

let new_q = node.update(u);
self.tree.push_hash(hash, 1.0 - new_q);

Some(u)
}

fn get_utility(&self, ptr: NodePtr, pos: &ChessState) -> f32 {
Expand All @@ -362,30 +400,27 @@ impl<'a> Searcher<'a> {
}
}

fn pick_action(&self, ptr: NodePtr, node_stats: &ActionStats) -> usize {
fn pick_action(&self, ptr: NodePtr, node: &Node) -> usize {
let is_root = ptr == self.tree.root_node();

let cpuct = SearchHelpers::get_cpuct(self.params, node_stats, is_root);
let fpu = SearchHelpers::get_fpu(node_stats);
let expl_scale =
SearchHelpers::get_explore_scaling(self.params, node_stats, &self.tree[ptr]);
let cpuct = SearchHelpers::get_cpuct(self.params, node, is_root);
let fpu = SearchHelpers::get_fpu(node);
let expl_scale = SearchHelpers::get_explore_scaling(self.params, node);

let expl = cpuct * expl_scale;

self.tree.get_best_child_by_key(ptr, |action| {
let mut q = SearchHelpers::get_action_value(action, fpu);
self.tree.get_best_child_by_key(ptr, |child| {
let mut q = SearchHelpers::get_action_value(child, fpu);

// virtual loss
if !action.ptr().is_null() {
let threads = f64::from(self.tree[action.ptr()].threads());
if threads > 0.0 {
let visits = f64::from(action.visits());
let q2 = f64::from(q) * visits / (visits + threads);
q = q2 as f32;
}
let threads = f64::from(child.threads());
if threads > 0.0 {
let visits = f64::from(child.visits());
let q2 = f64::from(q) * visits / (visits + threads);
q = q2 as f32;
}

let u = expl * action.policy() / (1 + action.visits()) as f32;
let u = expl * child.policy() / (1 + child.visits()) as f32;

q + u
})
Expand Down Expand Up @@ -420,55 +455,59 @@ impl<'a> Searcher<'a> {
fn get_pv(&self, mut depth: usize) -> (Vec<Move>, f32) {
let mate = self.tree[self.tree.root_node()].is_terminal();

let mut action = self.get_best_action();
let (mut ptr, mut mov, q) = self.get_best_action(self.tree.root_node());

let score = if !action.ptr().is_null() {
match self.tree[action.ptr()].state() {
let score = if !ptr.is_null() {
match self.tree[ptr].state() {
GameState::Lost(_) => 1.1,
GameState::Won(_) => -0.1,
GameState::Draw => 0.5,
GameState::Ongoing => action.q(),
GameState::Ongoing => q,
}
} else {
action.q()
q
};

let mut pv = Vec::new();
let half = self.tree.half() > 0;

while (mate || depth > 0) && !action.ptr().is_null() && action.ptr().half() == half {
pv.push(Move::from(action.mov()));
let idx = self.tree.get_best_child(action.ptr());
while (mate || depth > 0) && !ptr.is_null() && ptr.half() == half {
pv.push(mov);
let idx = self.tree.get_best_child(ptr);

if idx == usize::MAX {
break;
}

action = self.tree.edge_copy(action.ptr(), idx);
(ptr, mov, _) = self.get_best_action(ptr);
depth = depth.saturating_sub(1);
}

(pv, score)
}

fn get_best_action(&self) -> Edge {
let idx = self.tree.get_best_child(self.tree.root_node());
self.tree.edge_copy(self.tree.root_node(), idx)
}

fn get_best_move(&self) -> Move {
Move::from(self.get_best_action().mov())
fn get_best_action(&self, node: NodePtr) -> (NodePtr, Move, f32) {
let idx = self.tree.get_best_child(node);
let ptr = *self.tree[node].actions() + idx;
let child = &self.tree[ptr];
(ptr, child.parent_move(), child.q())
}

/// Converts a win-probability style `score` into a centipawn-like value.
///
/// The score is first clamped to `[0.0, 1.0]`, then mapped through
/// `-400 * ln(1 / score - 1)`. A score of `0.5` maps to zero; the clamped
/// extremes `0.0` and `1.0` map to negative and positive infinity.
fn get_cp(score: f32) -> f32 {
    let bounded = score.clamp(0.0, 1.0);
    // odds *against* the side the score belongs to
    let inverse_odds = 1.0 / bounded - 1.0;
    -400.0 * inverse_odds.ln()
}

pub fn display_moves(&self) {
for action in self.tree[self.tree.root_node()].actions().iter() {
let mov = self.root_position.conv_mov_to_str(action.mov().into());
let q = action.q() * 100.0;
println!("{mov} -> {q:.2}%");
let first_child_ptr = { *self.tree[self.tree.root_node()].actions() };
for action in 0..self.tree[self.tree.root_node()].num_actions() {
let child = &self.tree[first_child_ptr + action];
let mov = self.root_position.conv_mov_to_str(child.parent_move());
let q = child.q() * 100.0;
println!(
"{mov} -> {q:.2}% V({}) S({})",
child.visits(),
child.state()
);
}
}
}
Loading

0 comments on commit 078280e

Please sign in to comment.