Skip to content

Commit

Permalink
Scale exploration with Gini impurity of policy values (#44)
Browse files Browse the repository at this point in the history
First, when computing policy values in a position, we also calculate the Gini impurity, defined as (1 - sum of squares of policy values). A high Gini impurity indicates that there are many strong candidate moves in a position, and vice versa.

This Gini impurity is then used to adjust the exploration scaling using a logarithmic formula. For higher values of Gini impurity, we decrease the exploration value so that the search focuses more on exploring variations with high q values. Conversely, for positions with low Gini impurity where one move is much better than the others, we increase the exploration value to ensure that other potential lines are not prematurely discarded.

The idea to use Gini impurity was first proposed and tested by @Viren6.

This patch was shown to affect the quality of data produced, so it has been intentionally excluded for datagen.

Passed STC: https://montychess.org/tests/view/66aef6280f6f1e65cfa2b1f8
LLR: 2.93 (-2.94,2.94) <0.00,4.00>
Total: 6944 W: 1626 L: 1460 D: 3858
Ptnml(0-2): 54, 787, 1643, 915, 73

Passed LTC: https://montychess.org/tests/view/66af1fd90f6f1e65cfa2b235
LLR: 2.93 (-2.94,2.94) <1.00,5.00>
Total: 8718 W: 1877 L: 1705 D: 5136
Ptnml(0-2): 40, 931, 2255, 1083, 50

Rebased STC: https://montychess.org/tests/view/66b053380f6f1e65cfa2b657
LLR: 2.91 (-2.94,2.94) <0.00,4.00>
Total: 9216 W: 2083 L: 1913 D: 5220
Ptnml(0-2): 82, 998, 2291, 1142, 95

2nd Rebased STC: https://montychess.org/tests/view/66bd91bf68e8f7e2fe23cfde
LLR: 3.03 (-2.94,2.94) <0.00,4.00>
Total: 4448 W: 1158 L: 982 D: 2308
Ptnml(0-2): 55, 487, 997, 597, 88

Bench: 1317423
  • Loading branch information
XInTheDark authored and Viren6 committed Aug 16, 2024
1 parent 9e8e7c5 commit 7a3dc7e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ impl<'a> Searcher<'a> {

let cpuct = SearchHelpers::get_cpuct(self.params, node_stats, is_root);
let fpu = SearchHelpers::get_fpu(node_stats);
let expl_scale = SearchHelpers::get_explore_scaling(self.params, node_stats);
let expl_scale =
SearchHelpers::get_explore_scaling(self.params, node_stats, &self.tree[ptr]);

let expl = cpuct * expl_scale;

Expand Down
19 changes: 17 additions & 2 deletions src/mcts/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::time::Instant;

use crate::{
mcts::{MctsParams, Searcher},
tree::{ActionStats, Edge},
tree::{ActionStats, Edge, Node},
};

pub struct SearchHelpers;
Expand Down Expand Up @@ -35,10 +35,25 @@ impl SearchHelpers {
/// Exploration Scaling
///
/// Larger value implies more exploration.
pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 {
fn base_explore_scaling(params: &MctsParams, node_stats: &ActionStats) -> f32 {
(params.expl_tau() * (node_stats.visits().max(1) as f32).ln()).exp()
}

#[allow(unused_variables)]
pub fn get_explore_scaling(params: &MctsParams, node_stats: &ActionStats, node: &Node) -> f32 {
#[cfg(not(feature = "datagen"))]
{
let mut scale = Self::base_explore_scaling(params, node_stats);

let gini = node.gini_impurity();
scale *= (0.679 - 1.634 * (gini + 0.001).ln()).min(2.1);
scale
}

#[cfg(feature = "datagen")]
Self::base_explore_scaling(params, node_stats)
}

/// First Play Urgency
///
/// #### Note
Expand Down
1 change: 1 addition & 0 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ impl Tree {
let t = &mut *self[to].actions_mut();

self[to].set_state(self[from].state());
self[to].set_gini_impurity(self[from].gini_impurity());

if f.is_empty() {
return;
Expand Down
23 changes: 22 additions & 1 deletion src/tree/node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::sync::{
atomic::{AtomicU16, Ordering},
atomic::{AtomicU16, AtomicU32, Ordering},
RwLock, RwLockReadGuard, RwLockWriteGuard,
};

Expand All @@ -14,6 +14,9 @@ pub struct Node {
actions: RwLock<Vec<Edge>>,
state: AtomicU16,
threads: AtomicU16,

// heuristics used in search
gini_impurity: AtomicU32,
}

impl Node {
Expand All @@ -22,12 +25,14 @@ impl Node {
actions: RwLock::new(Vec::new()),
state: AtomicU16::new(u16::from(state)),
threads: AtomicU16::new(0),
gini_impurity: AtomicU32::new(0),
}
}

pub fn set_new(&self, state: GameState) {
*self.actions_mut() = Vec::new();
self.set_state(state);
self.set_gini_impurity(0.0);
}

pub fn is_terminal(&self) -> bool {
Expand Down Expand Up @@ -74,9 +79,19 @@ impl Node {
self.state() == GameState::Ongoing && !self.has_children()
}

pub fn gini_impurity(&self) -> f32 {
f32::from_bits(self.gini_impurity.load(Ordering::Relaxed))
}

pub fn set_gini_impurity(&self, gini_impurity: f32) {
self.gini_impurity
.store(f32::to_bits(gini_impurity), Ordering::Relaxed);
}

pub fn clear(&self) {
*self.actions.write().unwrap() = Vec::new();
self.set_state(GameState::Ongoing);
self.set_gini_impurity(0.0);
}

pub fn expand<const ROOT: bool>(
Expand Down Expand Up @@ -122,11 +137,17 @@ impl Node {
total += policy;
}

let mut sum_of_squares = 0.0;

for action in actions.iter_mut() {
let policy = f32::from_bits(action.ptr().inner()) / total;
action.set_ptr(NodePtr::NULL);
action.set_policy(policy);
sum_of_squares += policy * policy;
}

let gini_impurity = (1.0 - sum_of_squares).clamp(0.0, 1.0);
self.set_gini_impurity(gini_impurity);
}

pub fn relabel_policy(&self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork) {
Expand Down

0 comments on commit 7a3dc7e

Please sign in to comment.