Skip to content

Commit

Permalink
Attempting to cluster correlation outputs using OPTICS
Browse files Browse the repository at this point in the history
  • Loading branch information
OmnipotentEntity committed Jul 28, 2023
1 parent 71139d1 commit 8a0de60
Showing 1 changed file with 72 additions and 1 deletion.
73 changes: 72 additions & 1 deletion python/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import colorsys
import json
import numpy as np
from sklearn.cluster import OPTICS as optics
from sklearn.metrics.pairwise import pairwise_distances
from scipy.spatial.distance import cosine as cosine_similarity

from board import Board
from features import Features
Expand Down Expand Up @@ -40,6 +43,9 @@
# Hardcoded max board size
pos_len = 19

# Hardcoded correlation feature length size
corr_feature_len = 32

# Model ----------------------------------------------------------------

logging.root.handlers = []
Expand Down Expand Up @@ -138,6 +144,8 @@ def get_outputs(gs, rules):
seki = seki_probs[1] - seki_probs[2]
seki2 = torch.sigmoid(seki_logits[3,:,:]).cpu().numpy()
scorebelief = torch.nn.functional.softmax(scorebelief_logits,dim=0).cpu().numpy()
ownership_corr = torch.tanh(out_ownership_corr).cpu().numpy()
futurepos_corr = torch.tanh(out_futurepos_corr).cpu().numpy()

board = gs.board

Expand Down Expand Up @@ -271,7 +279,9 @@ def get_outputs(gs, rules):
"seki2": seki2,
"seki_by_loc2": seki_by_loc2,
"scorebelief": scorebelief,
"genmove_result": genmove_result
"genmove_result": genmove_result,
"ownership_corr": ownership_corr,
"futurepos_corr": futurepos_corr
}

def get_input_feature(gs, rules, feature_idx):
Expand Down Expand Up @@ -531,6 +541,19 @@ def print_scorebelief(gs,outputs):
return ret


def abs_cosine_metric(x, y):
return 1 - abs(cosine_similarity(x, y))


def corr_distances_by_cosine_metric(corr_input):
clustering = optics(metric=abs_cosine_metric).fit(corr_input)
if max(clustering.labels_ == -1):
raise RuntimeError("No clusters found")
centers = np.vstack([np.average(corr_input[clustering.labels_ == i], axis=0) for i in range(max(clustering.labels_) + 1)])
results = pairwise_distances(corr_input, centers, metric='cosine')
return np.transpose(results)


# Basic parsing --------------------------------------------------------
colstr = 'ABCDEFGHJKLMNOPQRST'
def parse_coord(s,board):
Expand Down Expand Up @@ -578,6 +601,7 @@ def str_coord(loc,board):
'scorebelief',
'passalive',
]

known_analyze_commands = [
'gfx/Policy/policy',
'gfx/Policy1/policy1',
Expand Down Expand Up @@ -675,6 +699,24 @@ def get_board_matrix_str(matrix, scale, formatstr):
gs.boards.append(gs.board.copy())
ret = str_coord(loc,gs.board)

elif command[0] == "genmoves":
count = 10
if len(command) > 1:
count = int(command[1])
if count < 0 or count > pos_len**2:
count = 10

for i in range(count):
outputs = get_outputs(gs, rules)
loc = outputs["genmove_result"]
pla = gs.board.pla

gs.board.play(pla,loc)
gs.moves.append((pla,loc))
gs.boards.append(gs.board.copy())
ret += str_coord(loc,gs.board)


elif command[0] == "name":
ret = 'KataGo Raw Neural Net Debug/Test Script'
elif command[0] == "version":
Expand Down Expand Up @@ -804,6 +846,35 @@ def get_board_matrix_str(matrix, scale, formatstr):
elif command[0] == "futurepos1_raw":
outputs = get_outputs(gs, rules)
ret = get_board_matrix_str(outputs["futurepos"][1], 100.0, "%+7.3f")

elif command[0] == "ownership_corr":
outputs = get_outputs(gs, rules)
corr = np.reshape(outputs["ownership_corr"], (corr_feature_len, features.pos_len ** 2))
corr = np.transpose(corr)
try:
distances = corr_distances_by_cosine_metric(corr)
ret = '\n\n'.join(list(get_board_matrix_str(i, 100.0, "%+7.3f") for i in distances))
except RuntimeError:
ret = "No clusters found returning raw output instead\n\n"
corr = np.transpose(corr)
for i in range(corr_feature_len):
ret += get_board_matrix_str(corr[i], 100.0, "%+7.3f")
ret += '\n'

elif command[0] == "futurepos_corr":
outputs = get_outputs(gs, rules)
corr = np.reshape(outputs["futurepos_corr"], (corr_feature_len, features.pos_len ** 2))
corr = np.transpose(corr)
try:
distances = corr_distances_by_cosine_metric(corr)
ret = '\n\n'.join(list(get_board_matrix_str(i, 100.0, "%+7.3f") for i in distances))
except RuntimeError:
ret = "No clusters found returning raw output instead\n\n"
corr = np.transpose(corr)
for i in range(corr_feature_len):
ret += get_board_matrix_str(corr[i], 100.0, "%+7.3f")
ret += '\n'

elif command[0] == "seki_raw":
outputs = get_outputs(gs, rules)
ret = get_board_matrix_str(outputs["seki"], 100.0, "%+7.3f")
Expand Down

0 comments on commit 8a0de60

Please sign in to comment.