From 1a8ac8609ed055c0ddb5069e983a386141925a29 Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:49:13 +0800 Subject: [PATCH 1/7] Merge. create directly from master Bench: 1557738 --- src/mcts.rs | 3 ++- src/mcts/helpers.rs | 9 ++++++--- src/tree/node.rs | 17 ++++++++++++++++- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index a0ec79a6..2f5b65de 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -346,7 +346,8 @@ impl<'a> Searcher<'a> { let cpuct = SearchHelpers::get_cpuct(self.params, node_stats, is_root); let fpu = SearchHelpers::get_fpu(node_stats); - let expl_scale = SearchHelpers::get_explore_scaling(self.params, node_stats); + let expl_scale = + SearchHelpers::get_explore_scaling(self.params, node_stats, &self.tree[ptr]); let expl = cpuct * expl_scale; diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index a5015214..772aa518 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -2,7 +2,7 @@ use std::time::Instant; use crate::{ mcts::{MctsParams, Searcher}, - tree::{ActionStats, Edge}, + tree::{ActionStats, Edge, Node}, }; pub struct SearchHelpers; @@ -35,8 +35,11 @@ impl SearchHelpers { /// Exploration Scaling /// /// Larger value implies more exploration. - pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 { - (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp() + pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 { + let mut scale = (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp(); + let gini = node.gini_impurity(); + scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1); + scale } /// First Play Urgency diff --git a/src/tree/node.rs b/src/tree/node.rs index a89e96f9..08fbe062 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,5 +1,5 @@ use std::sync::{ - atomic::{AtomicU16, Ordering}, + atomic::{AtomicU16, AtomicU32, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard, }; @@ -14,6 +14,9 @@ pub struct Node { actions: RwLock>, state: AtomicU16, threads: AtomicU16, + + // heuristics used in search + gini_impurity: AtomicU32, } impl Node { @@ -22,6 +25,7 @@ impl Node { actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), threads: AtomicU16::new(0), + gini_impurity: AtomicU32::new(0), } } @@ -74,6 +78,10 @@ impl Node { self.state() == GameState::Ongoing && !self.has_children() } + pub fn gini_impurity(&self) -> f32 { + f32::from_bits(self.gini_impurity.load(Ordering::Relaxed)) + } + pub fn clear(&self) { *self.actions.write().unwrap() = Vec::new(); self.set_state(GameState::Ongoing); @@ -122,11 +130,18 @@ impl Node { total += policy; } + let mut sum_of_squares = 0.0; + for action in actions.iter_mut() { let policy = f32::from_bits(action.ptr().inner()) / total; action.set_ptr(NodePtr::NULL); action.set_policy(policy); + sum_of_squares += policy * policy; } + + let gini_impurity = (1.0 - sum_of_squares).clamp(0.0, 1.0); + self.gini_impurity + .store(f32::to_bits(gini_impurity), Ordering::Relaxed); } pub fn relabel_policy(&self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork) { From 21a0d4afbeb1c54f4573a0b6a91b2ccd9c37402f Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:58:17 +0800 Subject: [PATCH 2/7] merge, fixed Bench: 4361350 --- src/tree/node.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/tree/node.rs b/src/tree/node.rs index 08fbe062..086e1a4a 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -32,6 +32,7 @@ impl Node { pub fn set_new(&self, state: GameState) { *self.actions_mut() = Vec::new(); self.set_state(state); + self.set_gini_impurity(0.0); } pub fn is_terminal(&self) -> bool { @@ -82,6 +83,11 @@ impl Node { f32::from_bits(self.gini_impurity.load(Ordering::Relaxed)) } + pub fn set_gini_impurity(&self, gini_impurity: f32) { + self.gini_impurity + .store(f32::to_bits(gini_impurity), Ordering::Relaxed); + } + pub fn clear(&self) { *self.actions.write().unwrap() = Vec::new(); self.set_state(GameState::Ongoing); @@ -140,8 +146,7 @@ impl Node { } let gini_impurity = (1.0 - sum_of_squares).clamp(0.0, 1.0); - self.gini_impurity - .store(f32::to_bits(gini_impurity), Ordering::Relaxed); + self.set_gini_impurity(gini_impurity); } pub fn relabel_policy(&self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork) { From ac3b5c60223d61d2e741ce691e12ac78326cbe68 Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:59:59 +0800 Subject: [PATCH 3/7] merge, fixed Bench: 4361350 --- src/tree/node.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tree/node.rs b/src/tree/node.rs index 086e1a4a..bd852826 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -91,6 +91,7 @@ impl Node { pub fn clear(&self) { *self.actions.write().unwrap() = Vec::new(); self.set_state(GameState::Ongoing); + self.set_gini_impurity(0.0); } pub fn expand( From 1ab845221cbb1fae24016bb69a055323cb53f5f6 Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Mon, 5 Aug 2024 12:20:58 +0800 Subject: [PATCH 4/7] merge, fixed v2 Bench: 1490523 --- src/tree.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tree.rs b/src/tree.rs index 55095549..c2409695 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -72,6 +72,7 @@ impl Tree { let t = &mut *self[to].actions_mut(); self[to].set_state(self[from].state()); + self[to].set_gini_impurity(self[from].gini_impurity()); if f.is_empty() { return; From 24f497f811ff55263b4a83089ca73e788587ec94 Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Thu, 15 Aug 2024 13:27:12 +0800 Subject: [PATCH 5/7] Rebased Bench: 1281414 --- src/mcts/helpers.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index 772aa518..39d5991b 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -37,8 +37,13 @@ impl SearchHelpers { /// Larger value implies more exploration. pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 { let mut scale = (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp(); + let gini = node.gini_impurity(); - scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1); + #[cfg(not(feature = "datagen"))] + { + scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1); + } + scale } From 748048927eecbccd1701989ad32810128a7d19d3 Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Thu, 15 Aug 2024 17:43:00 +0800 Subject: [PATCH 6/7] fix CI. Bench: 1317423 --- src/mcts/helpers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index 39d5991b..789e1172 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -38,9 +38,9 @@ impl SearchHelpers { pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 { let mut scale = (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp(); - let gini = node.gini_impurity(); #[cfg(not(feature = "datagen"))] { + let gini = node.gini_impurity(); scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1); } From c3d3c19c4211c92466858b01961c4d1eb666c6cf Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Thu, 15 Aug 2024 18:13:17 +0800 Subject: [PATCH 7/7] fix CI (2) Bench: 1317423 --- src/mcts/helpers.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index 789e1172..3d12a2a0 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -35,16 +35,23 @@ impl SearchHelpers { /// Exploration Scaling /// /// Larger value implies more exploration. - pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 { - let mut scale = (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp(); + fn base_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 { + (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp() + } + #[allow(unused_variables)] + pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 { #[cfg(not(feature = "datagen"))] { + let mut scale = Self::base_explore_scaling(params, node_stats); + let gini = node.gini_impurity(); scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1); + scale } - scale + #[cfg(feature = "datagen")] + Self::base_explore_scaling(params, node_stats) } /// First Play Urgency