Skip to content

Commit

Permalink
Fix predictor, upd train class name
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma authored and gsvgit committed Nov 22, 2023
1 parent 1fce131 commit d512831
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 28 deletions.
13 changes: 6 additions & 7 deletions VSharp.ML.AIAgent/ml/common_model/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch
import copy
import logging

import torch
from predict import predict_state_single_out, predict_state_with_dict

from common.game import GameState
from ml.common_model.utils import back_prop
from ml.data_loader_compact import ServerDataloaderHeteroVector
from ml.model_wrappers.protocols import Predictor
from ml.predict_state_vector_hetero import PredictStateVectorHetGNN
from ml.common_model.utils import back_prop


class CommonModelWrapper(Predictor):
Expand Down Expand Up @@ -51,9 +52,7 @@ def predict(self, input: GameState, map_name):
)
assert self._model is not None

next_step_id = PredictStateVectorHetGNN.predict_state_with_dict(
self._model, hetero_input, state_map
)
next_step_id = predict_state_with_dict(self._model, hetero_input, state_map)

del hetero_input
return next_step_id
Expand Down Expand Up @@ -83,7 +82,7 @@ def predict(self, input: GameState, map_name):
)
assert self._model is not None

next_step_id = PredictStateVectorHetGNN.predict_state_single_out(
next_step_id = predict_state_single_out(
self.best_models[map_name][0], hetero_input, state_map
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import os.path
import pickle

from typing import Dict
from random import shuffle

import torch
from torch_geometric.data import HeteroData

from random import shuffle
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

BALANCE_DATASET = False

Expand All @@ -17,7 +13,7 @@ def get_module_name(clazz):
return clazz.__module__.split(".")[-2]


class PredictStateVectorHetGNN:
class HetGNNTestTrain:
"""predicts ExpectedStateNumber using Heterogeneous GNN"""

def __init__(self, model_class, hidden):
Expand Down Expand Up @@ -115,14 +111,6 @@ def tst(self, model, loader):
number_of_states_total += 1
return total_loss / number_of_states_total # correct / len(loader.dataset)

@staticmethod
def predict_state(model, data: HeteroData, state_map: Dict[int, int]) -> int:
"""Gets state id from model and heterogeneous graph
data.state_map - maps real state id to state index"""
state_map = {v: k for k, v in state_map.items()} # inversion for prediction
out = model(data.x_dict, data.edge_index_dict)
return state_map[int(out["state_vertex"].argmax(dim=0)[0])]

def save_simple(self, model, dir, epochs):
dir = os.path.join(
dir,
Expand Down
Binary file not shown.
Binary file not shown.
6 changes: 2 additions & 4 deletions VSharp.ML.AIAgent/ml/model_wrappers/nnwrapper.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import json

import torch.nn
from predict import predict_state_with_dict

from common.game import GameState
from ml.data_loader_compact import ServerDataloaderHeteroVector
from ml.model_wrappers.protocols import Predictor
from ml.predict_state_vector_hetero import PredictStateVectorHetGNN


class NNWrapper(Predictor):
Expand All @@ -26,9 +26,7 @@ def predict(self, input: GameState, map_name):
)
assert self._model is not None

next_step_id = PredictStateVectorHetGNN.predict_state_with_dict(
self._model, hetero_input, state_map
)
next_step_id = predict_state_with_dict(self._model, hetero_input, state_map)
del hetero_input
return next_step_id

Expand Down
66 changes: 66 additions & 0 deletions VSharp.ML.AIAgent/ml/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from collections import namedtuple

import torch
from torch_geometric.data import HeteroData

from config import GeneralConfig

StateVectorMapping = namedtuple("StateVectorMapping", ["state", "vector"])


def predict_state_with_dict(
model: torch.nn.Module, data: HeteroData, state_map: dict[int, int]
) -> int:
"""Gets state id from model and heterogeneous graph
data.state_map - maps real state id to state index"""

data.to(GeneralConfig.DEVICE)
reversed_state_map = {v: k for k, v in state_map.items()}

with torch.no_grad():
out = model(
data.x_dict["game_vertex"],
data.x_dict["state_vertex"],
data.edge_index_dict["game_vertex_to_game_vertex"],
data["game_vertex_to_game_vertex"].edge_type,
data["game_vertex_history_state_vertex"].edge_index,
data["game_vertex_history_state_vertex"].edge_attr,
data["game_vertex_in_state_vertex"].edge_index,
data["state_vertex_parent_of_state_vertex"].edge_index,
)

remapped = []

for index, vector in enumerate(out["state_vertex"]):
state_vector_mapping = StateVectorMapping(
state=reversed_state_map[index],
vector=(vector.detach().cpu().numpy()).tolist(),
)
remapped.append(state_vector_mapping)

return max(remapped, key=lambda mapping: sum(mapping.vector)).state, out


def predict_state_single_out(
model: torch.nn.Module, data: HeteroData, state_map: dict[int, int]
) -> int:
"""Gets state id from model and heterogeneous graph
data.state_map - maps real state id to state index"""

data.to(GeneralConfig.DEVICE)
reversed_state_map = {v: k for k, v in state_map.items()}

with torch.no_grad():
out = model.forward(data.x_dict, data.edge_index_dict, data.edge_attr_dict)

remapped = []
if type(out) is dict:
out = out["state_vertex"]
for index, vector in enumerate(out):
state_vector_mapping = StateVectorMapping(
state=reversed_state_map[index],
vector=(vector.detach().cpu().numpy()).tolist(),
)
remapped.append(state_vector_mapping)

return max(remapped, key=lambda mapping: sum(mapping.vector)).state
4 changes: 2 additions & 2 deletions VSharp.ML.AIAgent/pretrain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ml.data_loader_compact import ServerDataloaderHeteroVector
from ml.models.TAGSageSimple.model import StateModelEncoder
from ml.predict_state_vector_hetero import PredictStateVectorHetGNN
from ml.het_gnn_test_train import HetGNNTestTrain


def get_data_hetero_vector():
Expand All @@ -10,5 +10,5 @@ def get_data_hetero_vector():

if __name__ == "__main__":
# get_data_hetero_vector()
pr = PredictStateVectorHetGNN(StateModelEncoder, 32)
pr = HetGNNTestTrain(StateModelEncoder, 32)
pr.train_and_save("../dataset", 20, "./ml/models/")

0 comments on commit d512831

Please sign in to comment.