Skip to content

Commit

Permalink
Merge pull request #3 from emnigma/batching_in_common_model_training
Browse files Browse the repository at this point in the history
Сохранение моделек и рефакторинг
  • Loading branch information
Anya497 authored Dec 5, 2023
2 parents e2cebdc + 8647634 commit 6128645
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 91 deletions.
11 changes: 2 additions & 9 deletions VSharp.ML.AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@

import torch

import ml.model_modified
import ml.models


class GeneralConfig:
SERVER_COUNT = 16
Expand All @@ -19,12 +16,8 @@ class GeneralConfig:
MAX_STEPS = 5000
MUTATION_PERCENT_GENES = 5
LOGGER_LEVEL = logging.INFO
IMPORT_MODEL_INIT = lambda: ml.models.StateModelEncoder(
hidden_channels=32, out_channels=8
)
EXPORT_MODEL_INIT = lambda: ml.model_modified.StateModelEncoderExport(
hidden_channels=32, out_channels=8
)
IMPORT_MODEL_INIT = ...
EXPORT_MODEL_INIT = ...
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


Expand Down
19 changes: 10 additions & 9 deletions VSharp.ML.AIAgent/ml/common_model/paths.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import os
import pathlib

csv_path = os.path.join("report", "epochs_tables")
models_path = os.path.join("report", "epochs_best")
common_models_path = os.path.join("report", "common_models")
best_models_dict_path = os.path.join("report", "updated_best_models_dicts")
dataset_root_path = os.path.join("report", "dataset")
dataset_map_results_file_name = os.path.join("report", "dataset_state.csv")
training_data_path = os.path.join("report", "run_tables")
pretrained_models_path = os.path.join("ml", "models")
CSV_PATH = os.path.join("report", "epochs_tables")
MODELS_PATH = os.path.join("report", "epochs_best")
COMMON_MODELS_PATH = os.path.join("report", "common_models")
BEST_MODELS_DICT_PATH = os.path.join("report", "updated_best_models_dicts")
DATASET_ROOT_PATH = os.path.join("report", "dataset")
DATASET_MAP_RESULTS_FILENAME = os.path.join("report", "dataset_state.csv")
TRAINING_DATA_PATH = os.path.join("report", "run_tables")
PRETRAINED_MODEL_PATH = os.path.join("ml", "models")

path_to_models_for_parallel_architecture = os.path.join(
PATH_TO_MODELS_FOR_PARALLEL_ARCHITECTURE = os.path.join(
"ml", "pretrained_models", "models_for_parallel_architecture"
)
52 changes: 22 additions & 30 deletions VSharp.ML.AIAgent/ml/common_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,15 @@
import csv
import os
import re
import typing as t
from pathlib import Path

import torch
import numpy as np
import torch

from config import GeneralConfig
from ml.common_model.paths import (
csv_path,
models_path,
common_models_path,
)
from ml.common_model.paths import COMMON_MODELS_PATH, CSV_PATH, MODELS_PATH
from ml.utils import load_model
from ml.models.StateGNNEncoderConvEdgeAttr.model_modified import (
StateModelEncoderLastLayer,
)


def euclidean_dist(y_pred, y_true):
Expand All @@ -43,10 +37,10 @@ def get_tuple_for_max(t):
return tuple(values_list)


def csv2best_models():
def csv2best_models(ref_model_init: t.Callable[[], torch.nn.Module]):
best_models = {}
for epoch_num in range(1, len(os.listdir(csv_path)) + 1):
path_to_csv = os.path.join(csv_path, str(epoch_num) + ".csv")
for epoch_num in range(1, len(os.listdir(CSV_PATH)) + 1):
path_to_csv = os.path.join(CSV_PATH, str(epoch_num) + ".csv")
with open(path_to_csv, "r") as csv_file:
csv_reader = csv.reader(csv_file)
map_names = next(csv_reader)[1:]
Expand All @@ -59,18 +53,15 @@ def csv2best_models():
for i in range(len(int_row)):
models_stat[map_names[i]] = int_row[i]
models.append((row[0], models_stat))

for map_name in map_names:
best_model = max(models, key=(lambda m: m[1][map_name]))
best_model_name, best_model_score = best_model[0], best_model[1]
path_to_model = os.path.join(
models_path,
"epoch_" + str(epoch_num),
best_model_name + ".pth",
)
ref_model = load_model(
Path(path_to_model), model=StateModelEncoderLastLayer(32, 8)
)
for map_name in map_names:
best_model = max(models, key=(lambda m: m[1][map_name]))
best_model_name, best_model_score = best_model[0], best_model[1]
path_to_model = os.path.join(
MODELS_PATH,
"epoch_" + str(epoch_num),
best_model_name + ".pth",
)
ref_model = load_model(Path(path_to_model), model=ref_model_init())

ref_model.to(GeneralConfig.DEVICE)
best_models[map_name] = (
Expand Down Expand Up @@ -116,16 +107,14 @@ def save_best_models2csv(best_models: dict, path):
writer.writerows(values_for_csv)


def load_best_models_dict(path):
def load_best_models_dict(path, model_init: t.Callable[[], torch.nn.Module]):
best_models = csv2best_models()
with open(path, "r") as csv_file:
csv_reader = csv.reader(csv_file)
for row in csv_reader:
if row[1] != best_models[row[0]][2]:
path_to_model = os.path.join(common_models_path, row[1])
ref_model = load_model(
Path(path_to_model), model=StateModelEncoderLastLayer(32, 8)
)
path_to_model = os.path.join(COMMON_MODELS_PATH, row[1])
ref_model = load_model(Path(path_to_model), model=model_init())
ref_model.load_state_dict(torch.load(path_to_model))
ref_model.to(GeneralConfig.DEVICE)
best_models[row[0]] = (ref_model, ast.literal_eval(row[2]), row[1])
Expand All @@ -141,8 +130,11 @@ def load_dataset_state_dict(path):
return dataset_state_dict


def get_model(path_to_weights: Path, model: torch.nn.Module, random_seed: int):
def get_model(
path_to_weights: Path, model_init: t.Callable[[], torch.nn.Module], random_seed: int
):
np.random.seed(random_seed)
model = model_init()
weights = torch.load(path_to_weights)
weights["lin_last.weight"] = torch.tensor(np.random.random([1, 8]))
weights["lin_last.bias"] = torch.tensor(np.random.random([1]))
Expand Down
43 changes: 43 additions & 0 deletions VSharp.ML.AIAgent/ml/models/modelop/filemanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import typing as t
from datetime import datetime
from pathlib import Path

import torch
from torch import Tensor


class PathSelectorNNProtocol(t.Protocol):
    """Structural interface for networks exposing the path-selector ``forward``.

    Any object whose ``forward`` method accepts the seven inputs listed
    below and returns a ``Tensor`` satisfies this protocol; inheriting from
    it is not required.
    """

    def forward(
        self,
        game_x,
        state_x,
        edge_index_v_v,
        edge_index_history_v_s,
        edge_attr_history_v_s,
        edge_index_in_v_s,
        edge_index_s_s,
    ) -> Tensor:
        """Run the network on the given inputs and return a tensor.

        NOTE(review): the parameter names suggest node features (``*_x``)
        plus edge indices/attributes of a heterogeneous graph; their exact
        shapes and dtypes are not visible here -- confirm against the
        concrete model implementations.
        """
        ...


def save_model(model: torch.nn.Module, /, **initargs) -> Path:
    """Save ``model``'s weights into the directory of the package defining it.

    The target directory is built from ``model.__module__`` with its last
    component dropped (``ml.models.TAGSageSimple.model`` becomes
    ``ml/models/TAGSageSimple``) and is relative to the current working
    directory. It is created if it does not exist yet.

    The file name has the form
    ``<module>.<ClassName>[_<param>_<value>...]_<timestamp>.pt``.

    Args:
        model: The network whose ``state_dict()`` is written.
        **initargs: Constructor arguments of the model, e.g.
            ``hidden_channels=32, out_channels=8``; they are encoded into
            the file name in the order given.

    Returns:
        The path of the file that was written.
    """
    module_name = model.__module__

    # ml.models.TAGSageSimple.model -> ml/models/TAGSageSimple
    save_dir = Path(*module_name.split(".")[:-1])

    # ml.models.TAGSageSimple.model.StateModelEncoder
    name_parts = [f"{module_name}.{type(model).__name__}"]

    # hidden_channels=32, out_channels=8 -> hidden_channels_32, out_channels_8
    name_parts.extend(f"{param}_{value}" for param, value in initargs.items())

    # str(datetime) contains spaces and colons, which are not valid in file
    # names on every platform; use an explicit filesystem-safe format.
    # Microseconds keep quick successive saves from overwriting each other.
    name_parts.append(datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f"))

    # Joining all parts with "_" keeps a separator before the timestamp even
    # when no initargs were passed.
    save_path = save_dir / ("_".join(name_parts) + ".pt")

    save_dir.mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), save_path)
    return save_path
84 changes: 41 additions & 43 deletions VSharp.ML.AIAgent/run_common_model_training.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,42 @@
import copy
from dataclasses import dataclass, asdict
import logging
from multiprocessing import Pool
import multiprocessing

import multiprocessing as mp
import os
from pathlib import Path
from datetime import datetime
import random
import typing as t
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import numpy as np
import tqdm
from torch_geometric.loader import DataLoader

from config import GeneralConfig
from connection.broker_conn.socket_manager import game_server_socket_manager
from connection.game_server_conn.utils import MapsType, get_maps
from epochs_statistics.tables import create_pivot_table, table_to_string
from learning.play_game import play_game
from ml.common_model.utils import (
csv2best_models,
get_model,
)
from ml.common_model.wrapper import CommonModelWrapper, BestModelsWrapper
from ml.common_model.dataset import FullDataset
from ml.common_model.paths import (
common_models_path,
best_models_dict_path,
dataset_root_path,
dataset_map_results_file_name,
training_data_path,
pretrained_models_path,
BEST_MODELS_DICT_PATH,
COMMON_MODELS_PATH,
DATASET_MAP_RESULTS_FILENAME,
DATASET_ROOT_PATH,
PRETRAINED_MODEL_PATH,
TRAINING_DATA_PATH,
)
import numpy as np
from ml.common_model.dataset import FullDataset
from torch_geometric.loader import DataLoader
import tqdm
import pandas as pd
from ml.common_model.utils import csv2best_models, get_model
from ml.common_model.wrapper import BestModelsWrapper, CommonModelWrapper
from ml.models.TAGSageSimple.model_modified import StateModelEncoderLastLayer


LOG_PATH = Path("./ml_app.log")
TABLES_PATH = Path("./ml_tables.log")
COMMON_MODELS_PATH = Path(common_models_path)
BEST_MODELS_DICT = Path(best_models_dict_path)
TRAINING_DATA_PATH = Path(training_data_path)
COMMON_MODELS_PATH = Path(COMMON_MODELS_PATH)
BEST_MODELS_DICT = Path(BEST_MODELS_DICT_PATH)
TRAINING_DATA_PATH = Path(TRAINING_DATA_PATH)


logging.basicConfig(
Expand All @@ -54,13 +47,13 @@
)

if not COMMON_MODELS_PATH.exists():
os.makedirs(common_models_path)
os.makedirs(COMMON_MODELS_PATH)

if not BEST_MODELS_DICT.exists():
os.makedirs(best_models_dict_path)
os.makedirs(BEST_MODELS_DICT_PATH)

if not TRAINING_DATA_PATH.exists():
os.makedirs(training_data_path)
os.makedirs(TRAINING_DATA_PATH)


def create_file(file: Path):
Expand Down Expand Up @@ -104,9 +97,9 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDatase
run_name = f"{datetime.fromtimestamp(timestamp)}_{train_config.batch_size}_Adam_{train_config.lr}_KLDL"

print(run_name)
path_to_saved_models = os.path.join(common_models_path, run_name)
path_to_saved_models = os.path.join(COMMON_MODELS_PATH, run_name)
os.makedirs(path_to_saved_models)
TABLES_PATH = Path(os.path.join(training_data_path, run_name + ".log"))
TABLES_PATH = Path(os.path.join(TRAINING_DATA_PATH, run_name + ".log"))
create_file(TABLES_PATH)
create_file(LOG_PATH)

Expand All @@ -121,7 +114,7 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDatase
for i in range(GeneralConfig.SERVER_COUNT)
]

multiprocessing.set_start_method("spawn", force=True)
mp.set_start_method("spawn", force=True)
# p = Pool(GeneralConfig.SERVER_COUNT)

all_average_results = []
Expand Down Expand Up @@ -164,7 +157,7 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDatase
model.eval()
cmwrapper.make_copy(str(epoch + 1))

with Pool(GeneralConfig.SERVER_COUNT) as p:
with mp.Pool(GeneralConfig.SERVER_COUNT) as p:
result = list(p.map(play_game_task, tasks))

all_results = []
Expand Down Expand Up @@ -194,7 +187,7 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDatase
)
append_to_file(TABLES_PATH, table + "\n")

path_to_model = os.path.join(common_models_path, run_name, str(epoch + 1))
path_to_model = os.path.join(COMMON_MODELS_PATH, run_name, str(epoch + 1))
torch.save(model.state_dict(), Path(path_to_model))
del data_list
del data_loader
Expand All @@ -203,14 +196,16 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDatase
return all_average_results


def get_dataset(generate_dataset: bool):
dataset = FullDataset(dataset_root_path, dataset_map_results_file_name)
def get_dataset(
generate_dataset: bool, ref_model_init: t.Callable[[], torch.nn.Module]
):
dataset = FullDataset(DATASET_ROOT_PATH, DATASET_MAP_RESULTS_FILENAME)

if generate_dataset:
# creating new dataset
with game_server_socket_manager() as ws:
all_maps = get_maps(websocket=ws, type=MapsType.TRAIN)
best_models_dict = csv2best_models()
# creating new dataset
best_models_dict = csv2best_models(ref_model_init=ref_model_init)
play_game(
with_predictor=BestModelsWrapper(best_models_dict),
max_steps=GeneralConfig.MAX_STEPS,
Expand All @@ -228,16 +223,19 @@ def get_dataset(generate_dataset: bool):
def main():
print(GeneralConfig.DEVICE)
path_to_weights = os.path.join(
pretrained_models_path,
PRETRAINED_MODEL_PATH,
"TAGSageSimple",
"32ch",
"20e",
"GNN_state_pred_het_dict",
)
model_initializer = lambda: StateModelEncoderLastLayer(
hidden_channels=32, out_channels=8
)

best_result = {"average_coverage": 0, "config": dict(), "epoch": 0}
generate_dataset = False
dataset = get_dataset(generate_dataset)
dataset = get_dataset(generate_dataset, ref_model_init=model_initializer)

while True:
config = TrainConfig(
Expand All @@ -255,7 +253,7 @@ def main():

model = get_model(
Path(path_to_weights),
StateModelEncoderLastLayer(hidden_channels=32, out_channels=8),
model_initializer,
random_seed=937,
)

Expand Down

0 comments on commit 6128645

Please sign in to comment.