Skip to content

Commit

Permalink
Merge pull request #76 from katyacyfra/mlSearcher
Browse files Browse the repository at this point in the history
Ml searcher add pretrained
  • Loading branch information
gsvgit authored Nov 21, 2023
2 parents 16e54f8 + a31cb41 commit 1fce131
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 873 deletions.
47 changes: 18 additions & 29 deletions VSharp.ML.AIAgent/ml/data_loader_compact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import json
import os.path
import pickle
Expand Down Expand Up @@ -46,6 +45,7 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int
edges_index_s_v_history = []
edges_index_v_s_history = []
edges_attr_v_v = []
edges_types_v_v = []

edges_attr_s_v = []
edges_attr_v_s = []
Expand Down Expand Up @@ -77,9 +77,8 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int
edges_index_v_v.append(
np.array([vertex_map[e.VertexFrom], vertex_map[e.VertexTo]])
)
edges_attr_v_v.append(
np.array([e.Label.Token])
) # TODO: consider token in a model
edges_attr_v_v.append(np.array([e.Label.Token]))
edges_types_v_v.append(e.Label.Token)

state_doubles = 0

Expand All @@ -101,7 +100,7 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int
)
)
# history edges: state -> vertex and back
for h in s.History: # TODO: process NumOfVisits as edge label
for h in s.History:
v_to = vertex_map[h.GraphVertexId]
edges_index_s_v_history.append(np.array([state_index, v_to]))
edges_index_v_s_history.append(np.array([v_to, state_index]))
Expand All @@ -124,15 +123,21 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int

data["game_vertex"].x = torch.tensor(np.array(nodes_vertex), dtype=torch.float)
data["state_vertex"].x = torch.tensor(np.array(nodes_state), dtype=torch.float)
data["game_vertex", "to", "game_vertex"].edge_index = (
data["game_vertex_to_game_vertex"].edge_index = (
torch.tensor(np.array(edges_index_v_v), dtype=torch.long).t().contiguous()
)
data["state_vertex", "in", "game_vertex"].edge_index = (
data["game_vertex_to_game_vertex"].edge_attr = torch.tensor(
np.array(edges_attr_v_v), dtype=torch.long
)
data["game_vertex_to_game_vertex"].edge_type = torch.tensor(
np.array(edges_types_v_v), dtype=torch.long
)
data["state_vertex_in_game_vertex"].edge_index = (
torch.tensor(np.array(edges_index_s_v_in), dtype=torch.long)
.t()
.contiguous()
)
data["game_vertex", "in", "state_vertex"].edge_index = (
data["game_vertex_in_state_vertex"].edge_index = (
torch.tensor(np.array(edges_index_v_s_in), dtype=torch.long)
.t()
.contiguous()
Expand All @@ -149,24 +154,24 @@ def null_if_empty(tensor):
else torch.empty((2, 0), dtype=torch.int64)
)

data["state_vertex", "history", "game_vertex"].edge_index = null_if_empty(
data["state_vertex_history_game_vertex"].edge_index = null_if_empty(
torch.tensor(np.array(edges_index_s_v_history), dtype=torch.long)
.t()
.contiguous()
)
data["game_vertex", "history", "state_vertex"].edge_index = null_if_empty(
data["game_vertex_history_state_vertex"].edge_index = null_if_empty(
torch.tensor(np.array(edges_index_v_s_history), dtype=torch.long)
.t()
.contiguous()
)
data["state_vertex", "history", "game_vertex"].edge_attr = torch.tensor(
data["state_vertex_history_game_vertex"].edge_attr = torch.tensor(
np.array(edges_attr_s_v), dtype=torch.long
)
data["game_vertex", "history", "state_vertex"].edge_attr = torch.tensor(
data["game_vertex_history_state_vertex"].edge_attr = torch.tensor(
np.array(edges_attr_v_s), dtype=torch.long
)
# if (edges_index_s_s): #TODO: empty?
data["state_vertex", "parent_of", "state_vertex"].edge_index = null_if_empty(
data["state_vertex_parent_of_state_vertex"].edge_index = null_if_empty(
torch.tensor(np.array(edges_index_s_s), dtype=torch.long).t().contiguous()
)
# print(data['state', 'parent_of', 'state'].edge_index)
Expand Down Expand Up @@ -242,19 +247,3 @@ def __process_files(self):
with open(PIK, "wb") as f:
pickle.dump(self.dataset, f)
self.dataset = []


def parse_cmd_line_args():
parser = argparse.ArgumentParser(
prog="V# pytorch-geometric data conversion", description="Symbolic execution"
)
parser.add_argument("--dataset", required=True, help="Dataset folder")
parser.add_argument(
"--mode", help="heterogeneous or homogeneous graph model (het|hom)"
)


def get_data_hetero_vector():
dl = ServerDataloaderHeteroVector("../../GNN_V#/all")
# dl = ServerDataloaderHetero("../../GNN_V#/Serialized_test")
return dl.dataset
Loading

0 comments on commit 1fce131

Please sign in to comment.