From 7b92be7883d6d2e7a38eb408dd0769d3086de15a Mon Sep 17 00:00:00 2001 From: Muzhen Gaming <61100393+XInTheDark@users.noreply.github.com> Date: Thu, 15 Aug 2024 18:15:12 +0800 Subject: [PATCH] Scale exploration with Gini impurity of policy values (#44) First, when computing policy values in a position, we also calculate the Gini impurity, defined as (1 - sum of squares of policy values). A high Gini impurity indicates that there are many strong candidate moves in a position, and vice versa. This Gini impurity is then used to adjust the exploration scaling using a logarithmic formula. For higher values of Gini impurity, we decrease the exploration value so that the search focuses more on exploring variations with high q values. Conversely, for positions with low Gini impurity where one move is much better than the others, we increase the exploration value to ensure that other potential lines are not prematurely discarded. The idea to use Gini impurity was first proposed and tested by @Viren6. This patch was shown to affect the quality of data produced, so it has been intentionally excluded for datagen. Passed STC: https://montychess.org/tests/view/66aef6280f6f1e65cfa2b1f8 LLR: 2.93 (-2.94,2.94) <0.00,4.00> Total: 6944 W: 1626 L: 1460 D: 3858 Ptnml(0-2): 54, 787, 1643, 915, 73 Passed LTC: https://montychess.org/tests/view/66af1fd90f6f1e65cfa2b235 LLR: 2.93 (-2.94,2.94) <1.00,5.00> Total: 8718 W: 1877 L: 1705 D: 5136 Ptnml(0-2): 40, 931, 2255, 1083, 50 Rebased STC: https://montychess.org/tests/view/66b053380f6f1e65cfa2b657 LLR: 2.91 (-2.94,2.94) <0.00,4.00> Total: 9216 W: 2083 L: 1913 D: 5220 Ptnml(0-2): 82, 998, 2291, 1142, 95 2nd Rebased STC: https://montychess.org/tests/view/66bd91bf68e8f7e2fe23cfde LLR: 3.03 (-2.94,2.94) <0.00,4.00> Total: 4448 W: 1158 L: 982 D: 2308 Ptnml(0-2): 55, 487, 997, 597, 88 Bench: 1317423 --- src/mcts.rs | 3 ++- src/mcts/helpers.rs | 19 +++++++++++++++++-- src/tree.rs | 1 + src/tree/node.rs | 23 ++++++++++++++++++++++- 4 files changed, 42 insertions(+), 4 deletions(-) diff --git a/src/mcts.rs b/src/mcts.rs index 87600c6c..e2e95401 100644 --- a/src/mcts.rs +++ b/src/mcts.rs @@ -364,7 +364,8 @@ impl<'a> Searcher<'a> { let cpuct = SearchHelpers::get_cpuct(self.params, node_stats, is_root); let fpu = SearchHelpers::get_fpu(node_stats); - let expl_scale = SearchHelpers::get_explore_scaling(self.params, node_stats); + let expl_scale = + SearchHelpers::get_explore_scaling(self.params, node_stats, &self.tree[ptr]); let expl = cpuct * expl_scale; diff --git a/src/mcts/helpers.rs b/src/mcts/helpers.rs index a5015214..3d12a2a0 100644 --- a/src/mcts/helpers.rs +++ b/src/mcts/helpers.rs @@ -2,7 +2,7 @@ use std::time::Instant; use crate::{ mcts::{MctsParams, Searcher}, - tree::{ActionStats, Edge}, + tree::{ActionStats, Edge, Node}, }; pub struct SearchHelpers; @@ -35,10 +35,25 @@ impl SearchHelpers { /// Exploration Scaling /// /// Larger value implies more exploration. - pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 { + fn base_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 { (params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp() } + #[allow(unused_variables)] + pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 { + #[cfg(not(feature = "datagen"))] + { + let mut scale = Self::base_explore_scaling(params, node_stats); + + let gini = node.gini_impurity(); + scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1); + scale + } + + #[cfg(feature = "datagen")] + Self::base_explore_scaling(params, node_stats) + } + /// First Play Urgency /// /// #### Note diff --git a/src/tree.rs b/src/tree.rs index 4df4f68f..6f9580d4 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -72,6 +72,7 @@ impl Tree { let t = &mut *self[to].actions_mut(); self[to].set_state(self[from].state()); + self[to].set_gini_impurity(self[from].gini_impurity()); if f.is_empty() { return; diff --git a/src/tree/node.rs b/src/tree/node.rs index a89e96f9..bd852826 100644 --- a/src/tree/node.rs +++ b/src/tree/node.rs @@ -1,5 +1,5 @@ use std::sync::{ - atomic::{AtomicU16, Ordering}, + atomic::{AtomicU16, AtomicU32, Ordering}, RwLock, RwLockReadGuard, RwLockWriteGuard, }; @@ -14,6 +14,9 @@ pub struct Node { actions: RwLock>, state: AtomicU16, threads: AtomicU16, + + // heuristics used in search + gini_impurity: AtomicU32, } impl Node { @@ -22,12 +25,14 @@ impl Node { actions: RwLock::new(Vec::new()), state: AtomicU16::new(u16::from(state)), threads: AtomicU16::new(0), + gini_impurity: AtomicU32::new(0), } } pub fn set_new(&self, state: GameState) { *self.actions_mut() = Vec::new(); self.set_state(state); + self.set_gini_impurity(0.0); } pub fn is_terminal(&self) -> bool { @@ -74,9 +79,19 @@ impl Node { self.state() == GameState::Ongoing && !self.has_children() } + pub fn gini_impurity(&self) -> f32 { + f32::from_bits(self.gini_impurity.load(Ordering::Relaxed)) + } + + pub fn set_gini_impurity(&self, gini_impurity: f32) { + self.gini_impurity + .store(f32::to_bits(gini_impurity), Ordering::Relaxed); + } + pub fn clear(&self) { *self.actions.write().unwrap() = Vec::new(); self.set_state(GameState::Ongoing); + self.set_gini_impurity(0.0); } pub fn expand( @@ -122,11 +137,17 @@ impl Node { total += policy; } + let mut sum_of_squares = 0.0; + for action in actions.iter_mut() { let policy = f32::from_bits(action.ptr().inner()) / total; action.set_ptr(NodePtr::NULL); action.set_policy(policy); + sum_of_squares += policy * policy; } + + let gini_impurity = (1.0 - sum_of_squares).clamp(0.0, 1.0); + self.set_gini_impurity(gini_impurity); } pub fn relabel_policy(&self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork) {