diff --git a/VSharp.ML.AIAgent/ml/data_loader_compact.py b/VSharp.ML.AIAgent/ml/data_loader_compact.py index 778e90049..135998206 100644 --- a/VSharp.ML.AIAgent/ml/data_loader_compact.py +++ b/VSharp.ML.AIAgent/ml/data_loader_compact.py @@ -1,4 +1,3 @@ -import argparse import json import os.path import pickle @@ -46,6 +45,7 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int edges_index_s_v_history = [] edges_index_v_s_history = [] edges_attr_v_v = [] + edges_types_v_v = [] edges_attr_s_v = [] edges_attr_v_s = [] @@ -77,9 +77,8 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int edges_index_v_v.append( np.array([vertex_map[e.VertexFrom], vertex_map[e.VertexTo]]) ) - edges_attr_v_v.append( - np.array([e.Label.Token]) - ) # TODO: consider token in a model + edges_attr_v_v.append(np.array([e.Label.Token])) + edges_types_v_v.append(e.Label.Token) state_doubles = 0 @@ -101,7 +100,7 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int ) ) # history edges: state -> vertex and back - for h in s.History: # TODO: process NumOfVisits as edge label + for h in s.History: v_to = vertex_map[h.GraphVertexId] edges_index_s_v_history.append(np.array([state_index, v_to])) edges_index_v_s_history.append(np.array([v_to, state_index])) @@ -124,15 +123,21 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int data["game_vertex"].x = torch.tensor(np.array(nodes_vertex), dtype=torch.float) data["state_vertex"].x = torch.tensor(np.array(nodes_state), dtype=torch.float) - data["game_vertex", "to", "game_vertex"].edge_index = ( + data["game_vertex_to_game_vertex"].edge_index = ( torch.tensor(np.array(edges_index_v_v), dtype=torch.long).t().contiguous() ) - data["state_vertex", "in", "game_vertex"].edge_index = ( + data["game_vertex_to_game_vertex"].edge_attr = torch.tensor( + np.array(edges_attr_v_v), dtype=torch.long + ) + data["game_vertex_to_game_vertex"].edge_type = torch.tensor( + np.array(edges_types_v_v), dtype=torch.long + ) + data["state_vertex_in_game_vertex"].edge_index = ( torch.tensor(np.array(edges_index_s_v_in), dtype=torch.long) .t() .contiguous() ) - data["game_vertex", "in", "state_vertex"].edge_index = ( + data["game_vertex_in_state_vertex"].edge_index = ( torch.tensor(np.array(edges_index_v_s_in), dtype=torch.long) .t() .contiguous() @@ -149,24 +154,24 @@ def null_if_empty(tensor): else torch.empty((2, 0), dtype=torch.int64) ) - data["state_vertex", "history", "game_vertex"].edge_index = null_if_empty( + data["state_vertex_history_game_vertex"].edge_index = null_if_empty( torch.tensor(np.array(edges_index_s_v_history), dtype=torch.long) .t() .contiguous() ) - data["game_vertex", "history", "state_vertex"].edge_index = null_if_empty( + data["game_vertex_history_state_vertex"].edge_index = null_if_empty( torch.tensor(np.array(edges_index_v_s_history), dtype=torch.long) .t() .contiguous() ) - data["state_vertex", "history", "game_vertex"].edge_attr = torch.tensor( + data["state_vertex_history_game_vertex"].edge_attr = torch.tensor( np.array(edges_attr_s_v), dtype=torch.long ) - data["game_vertex", "history", "state_vertex"].edge_attr = torch.tensor( + data["game_vertex_history_state_vertex"].edge_attr = torch.tensor( np.array(edges_attr_v_s), dtype=torch.long ) # if (edges_index_s_s): #TODO: empty? - data["state_vertex", "parent_of", "state_vertex"].edge_index = null_if_empty( + data["state_vertex_parent_of_state_vertex"].edge_index = null_if_empty( torch.tensor(np.array(edges_index_s_s), dtype=torch.long).t().contiguous() ) # print(data['state', 'parent_of', 'state'].edge_index) @@ -242,19 +247,3 @@ def __process_files(self): with open(PIK, "wb") as f: pickle.dump(self.dataset, f) self.dataset = [] - - -def parse_cmd_line_args(): - parser = argparse.ArgumentParser( - prog="V# pytorch-geometric data conversion", description="Symbolic execution" - ) - parser.add_argument("--dataset", required=True, help="Dataset folder") - parser.add_argument( - "--mode", help="heterogeneous or homogeneous graph model (het|hom)" - ) - - -def get_data_hetero_vector(): - dl = ServerDataloaderHeteroVector("../../GNN_V#/all") - # dl = ServerDataloaderHetero("../../GNN_V#/Serialized_test") - return dl.dataset diff --git a/VSharp.ML.AIAgent/ml/models.py b/VSharp.ML.AIAgent/ml/models.py deleted file mode 100644 index 45841c3b7..000000000 --- a/VSharp.ML.AIAgent/ml/models.py +++ /dev/null @@ -1,765 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import nn -from torch.nn import Linear -from torch_geometric.nn import ( - ARMAConv, - FeaStConv, - GATConv, - GCNConv, - GraphConv, - HeteroConv, - Linear, - ResGatedGraphConv, - SAGEConv, - TAGConv, - TransformerConv, - global_mean_pool, - to_hetero, -) - -from torchvision.ops import MLP - - -from learning.timer.wrapper import timeit - -NUM_PREDICTED_VALUES = 4 - - -class ARMANet(torch.nn.Module): - def __init__(self, hidden_channels, num_classes, num_stacks=1, num_layers=1): - super(ARMANet, self).__init__() - - self.conv1 = ARMAConv( - -1, - out_channels=hidden_channels, - num_stacks=num_stacks, - num_layers=num_layers, - ) - - self.conv2 = ARMAConv( - -1, - out_channels=hidden_channels, - num_stacks=num_stacks, - num_layers=num_layers, - ) - - self.fc1 = nn.Linear(64, num_classes) - - """def forward(self, x, edge_index): - x = F.relu(self.conv1(x, edge_index)) - x = F.relu(self.conv2(x, edge_index)) - #x = global_mean_pool(x, x.batch) - x = F.dropout(x) - x = self.fc1(x) - return x""" - - def forward(self, x, edge_index): - x = self.conv1(x, edge_index).relu() - x = self.conv2(x, edge_index) - return x - - -class GATNet(torch.nn.Module): - def __init__(self, in_channels, hidden_channels=512, num_classes=2): - super(GATNet, self).__init__() - - # self.fc0 = nn.Linear(in_channels, hidden_channels) - self.conv1 = GATConv(in_channels, hidden_channels) - self.conv2 = GATConv(hidden_channels, 64) - self.fc1 = nn.Linear(64, num_classes) - - self.reset_parameters() - - def reset_parameters(self): - self.conv1.reset_parameters() - self.conv2.reset_parameters() - - def forward(self, x, edge_index): - # x = F.relu(self.fc0(x)) - x = F.relu(self.conv1(x, edge_index)) - x = F.relu(self.conv2(x, edge_index)) - # x = global_mean_pool(x, data.batch) - x = F.dropout(x, training=self.training) - x = self.fc1(x) - return x - - -class FeaStNet(torch.nn.Module): - def __init__( - self, in_channels, hidden_channels=512, num_classes=2, heads=1, t_inv=True - ): - super(FeaStNet, self).__init__() - - # self.fc0 = nn.Linear(in_channels, hidden_channels) - self.conv1 = FeaStConv(in_channels, hidden_channels, heads=heads, t_inv=t_inv) - self.conv2 = FeaStConv(hidden_channels, 64, heads=heads, t_inv=t_inv) - self.fc1 = nn.Linear(64, num_classes) - - self.reset_parameters() - - def reset_parameters(self): - self.conv1.reset_parameters() - self.conv2.reset_parameters() - - def forward(self, data): - x, edge_index = data.x, data.edge_index - - # x = F.relu(self.fc0(x)) - x = F.relu(self.conv1(x, edge_index)) - x = F.relu(self.conv2(x, edge_index)) - x = global_mean_pool(x, data.batch) - x = F.dropout(x, training=self.training) - x = self.fc1(x) - return x - - -class RGGCN(torch.nn.Module): - def __init__(self, in_channels, hidden_channels=512, num_classes=2): - super(RGGCN, self).__init__() - - # self.fc0 = nn.Linear(in_channels, hidden_channels) - self.conv1 = ResGatedGraphConv(in_channels, hidden_channels) - self.conv2 = ResGatedGraphConv(hidden_channels, 64) - self.fc1 = nn.Linear(64, num_classes) - - self.reset_parameters() - - def reset_parameters(self): - self.conv1.reset_parameters() - self.conv2.reset_parameters() - - def forward(self, data): - x, edge_index = data.x, data.edge_index - - # x = F.relu(self.fc0(x)) - x = F.relu(self.conv1(x, edge_index)) - x = F.relu(self.conv2(x, edge_index)) - x = global_mean_pool(x, data.batch) - x = F.dropout(x, training=self.training) - x = self.fc1(x) - return x - - -class UniMP(torch.nn.Module): - def __init__(self, in_channels, hidden_channels=512, num_classes=2): - super(UniMP, self).__init__() - # self.fc0 = nn.Linear(in_channels, hidden_channels) - self.conv1 = TransformerConv(in_channels, hidden_channels) - self.conv2 = TransformerConv(hidden_channels, 64) - self.fc1 = nn.Linear(64, num_classes) - - self.reset_parameters() - - def reset_parameters(self): - self.conv1.reset_parameters() - self.conv2.reset_parameters() - - def forward(self, data): - x, edge_index = data.x, data.edge_index - # x = F.relu(self.fc0(x)) - x = F.relu(self.conv1(x, edge_index)) - x = F.relu(self.conv2(x, edge_index)) - x = global_mean_pool(x, data.batch) - x = F.dropout(x, training=self.training) - x = self.fc1(x) - return x - - -class HeteroGNN(torch.nn.Module): - def __init__(self, metadata, hidden_channels, out_channels, num_layers): - super().__init__() - - self.convs = torch.nn.ModuleList() - for _ in range(num_layers): - conv = HeteroConv( - { - edge_type: SAGEConv((-1, -1), hidden_channels) - for edge_type in metadata[1] - } - ) - self.convs.append(conv) - - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict): - for conv in self.convs: - x_dict = conv(x_dict, edge_index_dict) - x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()} - return self.lin(x_dict["author"]) - - -class GCN_SimpleNoEdgeLabels(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv1 = GCNConv(NUM_NODE_FEATURES, 16) - self.conv2 = GCNConv(16, 1) - - def forward(self, data): - x, edge_index = data.x, data.edge_index - - x = self.conv1(x, edge_index) - x = F.relu(x) - x = F.dropout(x, training=self.training) - x = self.conv2(x, edge_index) - - return F.log_softmax(x, dim=1) - - -class GNN_MultipleOutput(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = SAGEConv((-1, -1), hidden_channels) - self.conv2 = SAGEConv((-1, -1), out_channels) - - def forward(self, x, edge_index): - x1 = self.conv1(x, edge_index).relu() - x1 = self.conv2(x1, edge_index) - - x2 = self.conv1(x, edge_index).relu() - x2 = self.conv2(x2, edge_index) - - x3 = self.conv1(x, edge_index).relu() - x3 = self.conv2(x3, edge_index) - - x4 = self.conv1(x, edge_index).relu() - x4 = self.conv2(x4, edge_index) - return x1, x2, x3, x4 - - -class GCN_SimpleMultipleOutput(torch.nn.Module): - # https://discuss.pytorch.org/t/a-model-with-multiple-outputs/10440 - def __init__(self, hidden_channels): - super(GCN_SimpleMultipleOutput, self).__init__() - torch.manual_seed(12345) - self.conv1 = GCNConv(NUM_NODE_FEATURES, hidden_channels) - self.conv1.add_self_loops = False - self.conv2 = GCNConv(hidden_channels, NUM_PREDICTED_VALUES) - self.conv2.add_self_loops = False - - def forward(self, x, edge_index): - x1 = self.conv1(x, edge_index) - x1 = F.relu(x1) - x1 = F.dropout(x1, training=self.training) - x1 = self.conv2(x1, edge_index) - - x2 = self.conv1(x, edge_index) - x2 = F.relu(x2) - x2 = F.dropout(x2, training=self.training) - x2 = self.conv2(x2, edge_index) - - x3 = self.conv1(x, edge_index) - x3 = F.relu(x3) - x3 = F.dropout(x3, training=self.training) - x3 = self.conv2(x3, edge_index) - - x4 = self.conv1(x, edge_index) - x4 = F.relu(x4) - x4 = F.dropout(x4, training=self.training) - x4 = self.conv2(x4, edge_index) - - return ( - F.log_softmax(x1, dim=1), - F.log_softmax(x2, dim=1), - F.log_softmax(x3, dim=1), - F.log_softmax(x4, dim=1), - ) - - -class GNN_Het(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = SAGEConv((-1, -1), hidden_channels) - self.conv2 = SAGEConv((-1, -1), out_channels) - - def forward(self, x, edge_index): - x = self.conv1(x, edge_index).relu() - x = self.conv2(x, edge_index) - return x - - -class ARMA_Het(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = ARMAConv(-1, hidden_channels) - self.conv2 = ARMAConv(-1, out_channels) - - def forward(self, x, edge_index): - x = self.conv1(x, edge_index).relu() - x = self.conv2(x, edge_index) - return x - - -class GNN_Het_EA(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = SAGEConv((-1, -1), hidden_channels) - self.conv2 = SAGEConv((-1, -1), out_channels) - - def forward(self, x, edge_index, edge_attr): - x = self.conv1(x, edge_index, edge_attr).relu() - x = self.conv2(x, edge_index, edge_attr) - return x - - -class GCN(torch.nn.Module): - def __init__(self, hidden_channels): - super(GCN, self).__init__() - torch.manual_seed(12345) - self.conv1 = GCNConv(NUM_NODE_FEATURES, hidden_channels) - self.conv2 = GCNConv(hidden_channels, hidden_channels) - self.conv3 = GCNConv(hidden_channels, hidden_channels) - self.lin = Linear(hidden_channels, 3421) - - def forward(self, x, edge_index, batch): - x = self.conv1(x, edge_index) - x = x.relu() - x = self.conv2(x, edge_index) - x = x.relu() - x = self.conv3(x, edge_index) - - x = global_mean_pool(x, batch) # [batch_size, hidden_channels] - - x = F.dropout(x, p=0.5, training=self.training) - x = self.lin(x) - - return x - - -class GAT(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) - self.lin1 = Linear(-1, hidden_channels) - self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False) - self.lin2 = Linear(-1, out_channels) - - def forward(self, x, edge_index): - x = self.conv1(x, edge_index) + self.lin1(x) - x = x.relu() - x = self.conv2(x, edge_index) + self.lin2(x) - return x - - -class VertexGNNEncoder(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - - self.conv1 = SAGEConv(-1, hidden_channels) - self.conv2 = SAGEConv(hidden_channels, hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x, edge_index): - x = self.conv1(x, edge_index).relu() - x = self.conv2(x, edge_index).relu() - return self.lin(x) - - -class StateGNNEncoder(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = SAGEConv((-1, -1), hidden_channels) - self.conv2 = SAGEConv((-1, -1), hidden_channels) - self.conv3 = SAGEConv((-1, -1), hidden_channels) - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict): - game_x = self.conv1( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.conv2( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv3( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv4( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - return self.lin(state_x) - - -class StateGNNEncoderConv(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - # self.conv1 = GCNConv(5, hidden_channels) - # self.conv2 = GCNConv(6, hidden_channels) - # GravNetConv - # self.conv1 = GravNetConv(-1, hidden_channels, 2, 2, 2) - # self.conv2 = GravNetConv(-1, hidden_channels, 2, 2, 2) - # GatedGraphConv - self.conv1 = TAGConv(5, hidden_channels) - self.conv2 = TAGConv(6, hidden_channels) - self.conv3 = SAGEConv((-1, -1), hidden_channels) # SAGEConv - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict): - game_x = self.conv1( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.conv2( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv3( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv4( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - return self.lin(state_x) - - -class StateGNNEncoderConvTAG100hops(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = TAGConv(5, hidden_channels, k=100) - self.conv2 = TAGConv(6, hidden_channels, k=100) - self.conv3 = SAGEConv((-1, -1), hidden_channels) # SAGEConv - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict): - game_x = self.conv1( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.conv2( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv3( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv4( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - return self.lin(state_x) - - -class StateGNNEncoderConvEdgeAttr(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = TAGConv(5, hidden_channels, 2) - self.conv2 = TAGConv(6, hidden_channels, 3) # TAGConv - self.conv3 = GraphConv((-1, -1), hidden_channels) # SAGEConv - self.conv32 = GraphConv((-1, -1), hidden_channels) - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.conv42 = SAGEConv((-1, -1), hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict, edge_attr=None): - game_x = self.conv1( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.conv2( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv3( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - edge_attr[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv32( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - edge_attr[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv4( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - state_x = self.conv42( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - return self.lin(state_x) - - -class StateGNNEncoderConvExp(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - # self.conv1 = GCNConv(5, hidden_channels) - # self.conv2 = GCNConv(6, hidden_channels) - # GravNetConv - # self.conv1 = GravNetConv(-1, hidden_channels, 2, 2, 2) - # self.conv2 = GravNetConv(-1, hidden_channels, 2, 2, 2) - # GatedGraphConv - self.conv1 = TAGConv(5, hidden_channels, 10) - self.conv12 = TAGConv(hidden_channels, hidden_channels, 10) # TAGConv - self.conv22 = TAGConv(hidden_channels, hidden_channels, 10) # TAGConv - self.conv2 = TAGConv(6, hidden_channels, 10) # TAGConv - self.conv3 = SAGEConv((-1, -1), hidden_channels) # SAGEConv - self.conv32 = SAGEConv((-1, -1), hidden_channels) - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.conv42 = SAGEConv((-1, -1), hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict): - game_x = self.conv1( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - game_x = self.conv12( - game_x, - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.conv2( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv22( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv3( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv32( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv4( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - state_x = self.conv42( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - return self.lin(state_x) - - -class StateGNNEncoderConv(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.conv1 = GCNConv(-1, hidden_channels) - self.conv2 = GCNConv(-1, hidden_channels) - self.conv12 = GCNConv(-1, hidden_channels) - self.conv22 = GCNConv(-1, hidden_channels) - self.conv3 = SAGEConv((-1, -1), hidden_channels) # SAGEConv - self.conv4 = SAGEConv((-1, -1), hidden_channels) - self.conv32 = SAGEConv((-1, -1), hidden_channels) # SAGEConv - self.conv42 = SAGEConv((-1, -1), hidden_channels) - self.lin = Linear(hidden_channels, out_channels) - - def forward(self, x_dict, edge_index_dict): - game_x = self.conv1( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - game_x = self.conv12( - game_x, - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.conv2( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv22( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - state_x = self.conv3( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv32( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - ).relu() - - state_x = self.conv4( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - state_x = self.conv42( - (game_x, state_x), - edge_index_dict[("game_vertex", "in", "state_vertex")], - ).relu() - - return self.lin(state_x) - - -class VerStateModel(torch.nn.Module): - def __init__(self, metadata, hidden_channels, out_channels): - super().__init__() - self.vertex_encoder = VertexGNNEncoder(hidden_channels, out_channels) - self.state_encoder = StateGNNEncoder(hidden_channels, out_channels) - self.decoder = GNN_Het(hidden_channels, out_channels) - self.decoder = to_hetero(self.decoder, metadata, aggr="sum") - - def forward(self, x_dict, edge_index_dict): - z_dict = {} - # x_dict['game_vertex'] = self.user_emb(x_dict['game_vertex']) - z_dict["state_vertex"] = self.state_encoder(x_dict, edge_index_dict) - z_dict["game_vertex"] = x_dict["game_vertex"] - # print(edge_index_dict) - # z_dict['state_vertex'] = self.state_encoder( - # x_dict['state_vertex'], - # edge_index_dict[('state_vertex', 'parent_of', 'state_vertex')], - # ) - - # return self.decoder(z_dict, edge_index_dict) # TODO: process separately - return z_dict - - -class StateModelEncoder(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - # self.vertex_encoder = VertexGNNEncoder(hidden_channels, out_channels) - self.state_encoder = StateGNNEncoderConvEdgeAttr(hidden_channels, out_channels) - # self.decoder = GNN_Het(hidden_channels, out_channels) - # self.decoder = to_hetero(self.decoder, metadata, aggr='sum') - - @timeit - def forward(self, x_dict, edge_index_dict, edge_attr=None): - z_dict = {} - # x_dict['game_vertex'] = self.user_emb(x_dict['game_vertex']) - # print(x_dict, edge_index_dict) - z_dict["state_vertex"] = self.state_encoder(x_dict, edge_index_dict, edge_attr) - z_dict["game_vertex"] = x_dict["game_vertex"] - # print(edge_index_dict) - # z_dict['state_vertex'] = self.state_encoder( - # x_dict['state_vertex'], - # edge_index_dict[('state_vertex', 'parent_of', 'state_vertex')], - # ) - - # return self.decoder(z_dict, edge_index_dict) # TODO: process separately - return z_dict - - -class StateModelEncoderTAG100hops(torch.nn.Module): - def __init__(self, hidden_channels, out_channels): - super().__init__() - self.state_encoder = StateGNNEncoderConvTAG100hops( - hidden_channels, out_channels - ) - - def forward(self, x_dict, edge_index_dict): - z_dict = {} - z_dict["state_vertex"] = self.state_encoder(x_dict, edge_index_dict) - z_dict["game_vertex"] = x_dict["game_vertex"] - return z_dict - - -class SAGEConvModel(torch.nn.Module): - def __init__( - self, - hidden_channels, - num_gv_layers=2, - num_sv_layers=2, - ): - super().__init__() - self.gv_layers = nn.ModuleList() - self.gv_layers.append(SAGEConv(-1, hidden_channels)) - for i in range(num_gv_layers - 1): - sage_gv = SAGEConv(-1, hidden_channels) - self.gv_layers.append(sage_gv) - - self.sv_layers = nn.ModuleList() - self.sv_layers.append(SAGEConv(-1, hidden_channels)) - for i in range(num_sv_layers - 1): - sage_sv = SAGEConv(-1, hidden_channels) - self.sv_layers.append(sage_sv) - - self.history1 = GATConv((-1, -1), hidden_channels, add_self_loops=False) - self.in1 = SAGEConv((-1, -1), hidden_channels) - - self.sv_layers2 = nn.ModuleList() - self.sv_layers2.append(SAGEConv(-1, hidden_channels)) - for i in range(num_sv_layers - 1): - sage_sv = SAGEConv(-1, hidden_channels) - self.sv_layers2.append(sage_sv) - self.mlp = MLP(hidden_channels, [1]) - - @timeit - def forward(self, x_dict, edge_index_dict, edge_attr_dict): - game_x = self.gv_layers[0]( - x_dict["game_vertex"], - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - for layer in self.gv_layers[1:]: - game_x = layer( - game_x, - edge_index_dict[("game_vertex", "to", "game_vertex")], - ).relu() - - state_x = self.sv_layers[0]( - x_dict["state_vertex"], - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - for layer in self.sv_layers[1:]: - state_x = layer( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - - history_x = self.history1( - (game_x, state_x), - edge_index_dict[("game_vertex", "history", "state_vertex")], - edge_attr_dict, - size=(game_x.size(0), state_x.size(0)), - ).relu() - - in_x = self.in1( - (game_x, history_x), edge_index_dict[("game_vertex", "in", "state_vertex")] - ).relu() - - state_x = self.sv_layers2[0]( - in_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - for layer in self.sv_layers2[1:]: - state_x = layer( - state_x, - edge_index_dict[("state_vertex", "parent_of", "state_vertex")], - ).relu() - x = self.mlp(in_x) - return x diff --git a/VSharp.ML.AIAgent/ml/models/README.md b/VSharp.ML.AIAgent/ml/models/README.md new file mode 100644 index 000000000..a515da2cc --- /dev/null +++ b/VSharp.ML.AIAgent/ml/models/README.md @@ -0,0 +1,32 @@ + + +## Pretrained Models Summary Balanced Dataset + +| Name | Train Loss | Test Loss |#parameters|#epochs|#hidden|lr|Comments| +|--|--|--|--|--|--|--|--| +|[EdgeTypeRGCNSageSubgraphs64ch](./models/EdgeTypeRGCNSageSubgraphs64ch/) |0.009276 |0.009997 |22,920|20|64|0.0001|| +|[EdgeTypeRGCNSageSubgraphs64ch](./models/EdgeTypeRGCNSageSubgraphs64ch/) |0.006144 |0.006247 |22,920|50|64|0.0001|| +|[EdgeTypeRGCNSageSubgraphs64ch](./models/EdgeTypeRGCNSageSubgraphs64ch/) |0.004633 |0.004714 |22,920|100|64|0.0001|| + + +## Inference +```python +model = StateModelEncoder(hidden_channels=#hidden, out_channels=8) +``` + + +## Pretrained Models Summary Old + +| Name | Train Loss | Test Loss |#parameters|#epochs|#hidden|lr|Comments| +|--|--|--|--|--|--|--|--| +|[StateGNNEncoderConvEdgeAttrSPA](./models/StateGNNEncoderConvEdgeAttrSPA/) |0.006074 |0.006214 |12,392|20|32|0.0001|| +|[StateGNNEncoderConvEdgeAttrBasic32Ch](./models/StateGNNEncoderConvEdgeAttrBasic32Ch/) |0.007862 |0.007962 |9,896|20|32|0.0001|| +|[StateGNNEncoderConvEdgeAttrBasic](./models/StateGNNEncoderConvEdgeAttrBasic/) |0.007512 |0.007499 |36,168|20|64|0.0001|| +|[StatesAfterAllCompact32ch](./models/StatesAfterAllCompact32ch/) |0.006934 |0.006977 |8,296|20|32|0.0001|| +|[StatesAfterAllCompact32ch](./models/StatesAfterAllCompact32ch/) |0.006135 |0.006215 |8,296|50|32|0.0001|| + +## Inference +```python +model = StateModelEncoder(hidden_channels=#hidden, out_channels=8) +``` + diff --git a/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/GNN_state_pred_het_dict b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/GNN_state_pred_het_dict new file mode 100644 index 000000000..303de7806 Binary files /dev/null and b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/GNN_state_pred_het_dict differ diff --git a/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/GNN_state_pred_het_full b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/GNN_state_pred_het_full new file mode 100644 index 000000000..dc3816cd1 Binary files /dev/null and b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/GNN_state_pred_het_full differ diff --git a/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/train_res b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/train_res new file mode 100644 index 000000000..22d77d965 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/32ch/20e/train_res @@ -0,0 +1,20 @@ +Epoch: 001, Train Loss: 0.131891, Test Loss: 0.126640 +Epoch: 002, Train Loss: 0.079984, Test Loss: 0.078351 +Epoch: 003, Train Loss: 0.058005, Test Loss: 0.057235 +Epoch: 004, Train Loss: 0.042217, Test Loss: 0.041487 +Epoch: 005, Train Loss: 0.035443, Test Loss: 0.034651 +Epoch: 006, Train Loss: 0.030673, Test Loss: 0.029894 +Epoch: 007, Train Loss: 0.027275, Test Loss: 0.026522 +Epoch: 008, Train Loss: 0.024535, Test Loss: 0.023877 +Epoch: 009, Train Loss: 0.022836, Test Loss: 0.022223 +Epoch: 010, Train Loss: 0.021300, Test Loss: 0.020765 +Epoch: 011, Train Loss: 0.019915, Test Loss: 0.019415 +Epoch: 012, Train Loss: 0.018757, Test Loss: 0.018282 +Epoch: 013, Train Loss: 0.018087, Test Loss: 0.017616 +Epoch: 014, Train Loss: 0.017681, Test Loss: 0.017230 +Epoch: 015, Train Loss: 0.016887, Test Loss: 0.016468 +Epoch: 016, Train Loss: 0.015917, Test Loss: 0.015553 +Epoch: 017, Train Loss: 0.015471, Test Loss: 0.015129 +Epoch: 018, Train Loss: 0.015431, Test Loss: 0.015111 +Epoch: 019, Train Loss: 0.014583, Test Loss: 0.014261 +Epoch: 020, Train Loss: 0.014174, Test Loss: 0.013858 \ No newline at end of file diff --git a/VSharp.ML.AIAgent/ml/models/TAGSageSimple/model.py b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/model.py new file mode 100644 index 000000000..7ce1082d3 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/models/TAGSageSimple/model.py @@ -0,0 +1,47 @@ +import torch +from torch.nn import Linear +from torch_geometric.nn import TAGConv, GraphConv, SAGEConv + + +class StateModelEncoder(torch.nn.Module): + def __init__(self, hidden_channels, out_channels): + super().__init__() + self.conv1 = TAGConv(5, hidden_channels, 2) + self.conv2 = TAGConv(hidden_channels, hidden_channels, 3) # TAGConv + self.conv3 = GraphConv((-1, -1), hidden_channels) # SAGEConv + self.conv4 = SAGEConv((-1, -1), hidden_channels) + self.lin = Linear(hidden_channels, out_channels) + + def forward( + self, + game_x, + state_x, + edge_index_v_v, + edge_type_v_v, + edge_index_history_v_s, + edge_attr_history_v_s, + edge_index_in_v_s, + edge_index_s_s, + ): + game_x = self.conv1( + game_x, + edge_index_v_v, + ).relu() + + state_x = self.conv3( + (game_x, state_x), + edge_index_history_v_s, + edge_attr_history_v_s, + ).relu() + + state_x = self.conv4( + (game_x, state_x), + edge_index_in_v_s, + ).relu() + + state_x = self.conv2( + state_x, + edge_index_s_s, + ).relu() + + return self.lin(state_x) diff --git a/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py b/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py index fff3d4b92..05c8c4104 100644 --- a/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py +++ b/VSharp.ML.AIAgent/ml/predict_state_vector_hetero.py @@ -1,29 +1,45 @@ import os.path -from collections import namedtuple +import pickle + +from typing import Dict import torch -import torch.nn.functional as F from torch_geometric.data import HeteroData + +from random import shuffle from torch_geometric.loader import DataLoader -from torch_geometric.nn import to_hetero +import torch.nn.functional as F -from config import GeneralConfig -from ml import data_loader_compact -from ml.models import GNN_Het +BALANCE_DATASET = False -StateVectorMapping = namedtuple("StateVectorMapping", ["state", "vector"]) + +def get_module_name(clazz): + return clazz.__module__.split(".")[-2] class PredictStateVectorHetGNN: """predicts ExpectedStateNumber using Heterogeneous GNN""" - def __init__(self): + def __init__(self, model_class, hidden): self.state_maps = {} - self.start() + self.model_class = model_class + self.hidden = hidden + + def train_and_save(self, dataset_dir, epochs, dir_to_save): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(device) + dataset = [] + for file in os.listdir(dataset_dir): + print(file) + with open(os.path.join(dataset_dir, file), "rb") as f: + dat = pickle.load(f) + if BALANCE_DATASET and len(dat) > 5000: + dat = dat[:5000] + print("Part of the dataset is chosen!") + dataset.extend(dat) - def start(self): - dataset = data_loader_compact.get_data_hetero_vector() torch.manual_seed(12345) + shuffle(dataset) split_at = round(len(dataset) * 0.85) @@ -34,107 +50,88 @@ def start(self): print(f"Number of test graphs: {len(test_dataset)}") train_loader = DataLoader( - train_dataset, batch_size=1, shuffle=True + train_dataset, batch_size=1, shuffle=False ) # TODO: add learning by batches! test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) - model = GNN_Het(hidden_channels=64, out_channels=8) - model = to_hetero(model, dataset[0].metadata(), aggr="sum") - optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + model = self.model_class(hidden_channels=self.hidden, out_channels=8).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) - for epoch in range(1, 31): - self.train(model, train_loader, optimizer) + for epoch in range(1, epochs + 1): + self.train(model, train_loader, optimizer, device) train_acc = self.tst(model, train_loader) test_acc = self.tst(model, test_loader) print( - f"Epoch: {epoch:03d}, Train Loss: {train_acc:.4f}, Test Loss: {test_acc:.4f}" + f"Epoch: {epoch:03d}, Train Loss: {train_acc:.6f}, Test Loss: {test_acc:.6f}" ) - self.save(model, "./saved_models") + self.save_simple(model, dir_to_save, epochs) # loss function from link prediction example def weighted_mse_loss(self, pred, target, weight=None): weight = 1.0 if weight is None else weight[target].to(pred.dtype) return (weight * (pred - target.to(pred.dtype)).pow(2)).mean() - def train(self, model, train_loader, optimizer): + def train(self, model, train_loader, optimizer, device): model.train() - for data in train_loader: # Iterate in batches over the training dataset. - out = model(data.x_dict, data.edge_index_dict) - pred = out["state_vertex"] + data = data.to(device) + optimizer.zero_grad() # Clear gradients. + out = model( + data.x_dict["game_vertex"], + data.x_dict["state_vertex"], + data.edge_index_dict["game_vertex_to_game_vertex"], + data["game_vertex_to_game_vertex"].edge_type, + data["game_vertex_history_state_vertex"].edge_index, + data["game_vertex_history_state_vertex"].edge_attr, + data["game_vertex_in_state_vertex"].edge_index, + data["state_vertex_parent_of_state_vertex"].edge_index, + ) target = data.y - loss = F.mse_loss(pred, target) + loss = F.mse_loss(out, target) loss.backward() # Derive gradients. optimizer.step() # Update parameters based on gradients. - optimizer.zero_grad() # Clear gradients. + @torch.no_grad() def tst(self, model, loader): model.eval() + total_loss = 0 + number_of_states_total = 0 for data in loader: - out = model(data.x_dict, data.edge_index_dict) - pred = out["state_vertex"] + out = model( + data.x_dict["game_vertex"], + data.x_dict["state_vertex"], + data["game_vertex_to_game_vertex"].edge_index, + data["game_vertex_to_game_vertex"].edge_type, + data["game_vertex_history_state_vertex"].edge_index, + data["game_vertex_history_state_vertex"].edge_attr, + data["game_vertex_in_state_vertex"].edge_index, + data["state_vertex_parent_of_state_vertex"].edge_index, + ) target = data.y - loss = F.mse_loss(pred, target) - return loss + for i, x in enumerate(out): + loss = F.mse_loss(x, target[i]) + total_loss += loss + number_of_states_total += 1 + return total_loss / number_of_states_total # correct / len(loader.dataset) @staticmethod - def predict_state(model, data: HeteroData, state_map: dict[int, int]) -> int: + def predict_state(model, data: HeteroData, state_map: Dict[int, int]) -> int: """Gets state id from model and heterogeneous graph data.state_map - maps real state id to state index""" state_map = {v: k for k, v in state_map.items()} # inversion for prediction out = model(data.x_dict, data.edge_index_dict) return state_map[int(out["state_vertex"].argmax(dim=0)[0])] - @staticmethod - def predict_state_with_dict( - model: torch.nn.Module, data: HeteroData, state_map: dict[int, int] - ) -> int: - """Gets state id from model and heterogeneous graph - data.state_map - maps real state id to state index""" - - data.to(GeneralConfig.DEVICE) - reversed_state_map = {v: k for k, v in state_map.items()} - - with torch.no_grad(): - out = model.forward(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - - remapped = [] - - for index, vector in enumerate(out["state_vertex"]): - state_vector_mapping = StateVectorMapping( - state=reversed_state_map[index], - vector=(vector.detach().cpu().numpy()).tolist(), - ) - remapped.append(state_vector_mapping) - - return max(remapped, key=lambda mapping: sum(mapping.vector)).state - - def predict_state_single_out( - model: torch.nn.Module, data: HeteroData, state_map: dict[int, int] - ) -> int: - """Gets state id from model and heterogeneous graph - data.state_map - maps real state id to state index""" - - data.to(GeneralConfig.DEVICE) - reversed_state_map = {v: k for k, v in state_map.items()} - - with torch.no_grad(): - out = model.forward(data.x_dict, data.edge_index_dict, data.edge_attr_dict) - - remapped = [] - if type(out) is dict: - out = out["state_vertex"] - for index, vector in enumerate(out): - state_vector_mapping = StateVectorMapping( - state=reversed_state_map[index], - vector=(vector.detach().cpu().numpy()).tolist(), - ) - remapped.append(state_vector_mapping) - - return max(remapped, key=lambda mapping: sum(mapping.vector)).state - - def save(self, model, dir): + def save_simple(self, model, dir, epochs): + dir = os.path.join( + dir, + get_module_name(self.model_class), + str(self.hidden) + "ch", + str(epochs) + "e", + ) + if not os.path.exists(dir): + os.makedirs(dir) filepath = os.path.join(dir, "GNN_state_pred_het_dict") # case 1 torch.save(model.state_dict(), filepath) diff --git a/VSharp.ML.AIAgent/pretrain.py b/VSharp.ML.AIAgent/pretrain.py new file mode 100644 index 000000000..fdee3e77a --- /dev/null +++ b/VSharp.ML.AIAgent/pretrain.py @@ -0,0 +1,14 @@ +from ml.data_loader_compact import ServerDataloaderHeteroVector +from ml.models.TAGSageSimple.model import StateModelEncoder +from ml.predict_state_vector_hetero import PredictStateVectorHetGNN + + +def get_data_hetero_vector(): + dl = ServerDataloaderHeteroVector("../serialized") + return dl.dataset + + +if __name__ == "__main__": + # get_data_hetero_vector() + pr = PredictStateVectorHetGNN(StateModelEncoder, 32) + pr.train_and_save("../dataset", 20, "./ml/models/")