Skip to content

Commit

Permalink
Policy Softmax Temperature at depth=2 (#64)
Browse files Browse the repository at this point in the history
Applies a small amount of policy softmax temperature (PST) to the moves of the root node's children (i.e. nodes at depth 2).

Passed STC (short time control):
LLR: 2.93 (-2.94,2.94) <0.00,4.00>
Total: 19552 W: 4594 L: 4380 D: 10578
Ptnml(0-2): 237, 2229, 4663, 2377, 270
https://tests.montychess.org/tests/view/670799dcb12c9e78f1354ad3

Passed LTC (long time control):
LLR: 2.95 (-2.94,2.94) <1.00,5.00>
Total: 32646 W: 6533 L: 6258 D: 19855
Ptnml(0-2): 128, 3593, 8634, 3812, 156
https://tests.montychess.org/tests/view/6707bc17b12c9e78f1354b00

Bench: 1358297
  • Loading branch information
TomaszJaworski777 authored Oct 11, 2024
1 parent 4d0a7a5 commit 42052c8
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
16 changes: 13 additions & 3 deletions src/mcts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,19 @@ impl<'a> Searcher<'a> {

// relabel root policies with root PST value
if self.tree[node].has_children() {
self.tree[node].relabel_policy(&self.root_position, self.params, self.policy);
self.tree[node].relabel_policy(&self.root_position, self.params, self.policy, 1);

for action in &*self.tree[node].actions() {
if action.ptr().is_null() || !self.tree[action.ptr()].has_children() {
continue;
}

let mut position = self.root_position.clone();
position.make_move(Move::from(action.mov()));
self.tree[action.ptr()].relabel_policy(&position, self.params, self.policy, 2);
}
} else {
self.tree[node].expand::<true>(&self.root_position, self.params, self.policy);
self.tree[node].expand(&self.root_position, self.params, self.policy, 1);
}

let search_stats = SearchStats::default();
Expand Down Expand Up @@ -310,7 +320,7 @@ impl<'a> Searcher<'a> {
} else {
// expand node on the second visit
if self.tree[ptr].is_not_expanded() {
self.tree[ptr].expand::<false>(pos, self.params, self.policy);
self.tree[ptr].expand(pos, self.params, self.policy, *depth);
}

// select action to take via PUCT
Expand Down
1 change: 1 addition & 0 deletions src/mcts/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ macro_rules! make_mcts_params {

make_mcts_params! {
root_pst: f32 = 3.64, 1.0, 10.0, 0.4, 0.002;
depth_2_pst: f32 = 1.2, 1.0, 10.0, 0.4, 0.002;
root_cpuct: f32 = 0.314, 0.1, 5.0, 0.065, 0.002;
cpuct: f32 = 0.314, 0.1, 5.0, 0.065, 0.002;
cpuct_var_weight: f32 = 0.851, 0.0, 2.0, 0.085, 0.002;
Expand Down
33 changes: 25 additions & 8 deletions src/tree/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ impl Node {
self.set_gini_impurity(0.0);
}

pub fn expand<const ROOT: bool>(
pub fn expand(
&self,
pos: &ChessState,
params: &MctsParams,
policy: &PolicyNetwork,
depth: usize,
) {
let mut actions = self.actions_mut();

Expand All @@ -121,16 +122,19 @@ impl Node {
max = max.max(policy);
});

let pst = match depth {
0 => unreachable!(),
1 => params.root_pst(),
2 => params.depth_2_pst(),
3.. => 1.0,
};

let mut total = 0.0;

for action in actions.iter_mut() {
let mut policy = f32::from_bits(action.ptr().inner());

policy = if ROOT {
((policy - max) / params.root_pst()).exp()
} else {
(policy - max).exp()
};
policy = ((policy - max) / pst).exp();

action.set_ptr(NodePtr::from_raw(f32::to_bits(policy)));

Expand All @@ -150,7 +154,13 @@ impl Node {
self.set_gini_impurity(gini_impurity);
}

pub fn relabel_policy(&self, pos: &ChessState, params: &MctsParams, policy: &PolicyNetwork) {
pub fn relabel_policy(
&self,
pos: &ChessState,
params: &MctsParams,
policy: &PolicyNetwork,
depth: u8,
) {
let feats = pos.get_policy_feats();
let mut max = f32::NEG_INFINITY;

Expand All @@ -163,10 +173,17 @@ impl Node {
max = max.max(policy);
}

let pst = match depth {
0 => unreachable!(),
1 => params.root_pst(),
2 => params.depth_2_pst(),
3.. => unreachable!(),
};

let mut total = 0.0;

for policy in &mut policies {
*policy = ((*policy - max) / params.root_pst()).exp();
*policy = ((*policy - max) / pst).exp();
total += *policy;
}

Expand Down

0 comments on commit 42052c8

Please sign in to comment.