Merge pull request #79 from Anya497/batching_in_common_model_training
Batching in common model training
emnigma authored Dec 5, 2023
2 parents d512831 + 6128645 commit 0df023e
Showing 17 changed files with 740 additions and 218 deletions.
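The hunks below collect every inference step as a torch_geometric HeteroData object and accumulate them per map in a new FullDataset, so the common model can be trained on mini-batches of steps instead of map by map. A minimal sketch of that batching idea follows; the actual training loop lives in files not expanded on this page, and the loader choice and batch size are assumptions, not part of this commit:

    # Sketch only: batching the collected HeteroData steps for common-model training.
    # `dataset` is assumed to be a loaded FullDataset (see ml/common_model/dataset.py below).
    from torch_geometric.loader import DataLoader

    from config import GeneralConfig

    steps = dataset.get_plain_data()               # HeteroData steps flagged use_for_train
    loader = DataLoader(steps, batch_size=64, shuffle=True)
    for batch in loader:                           # a collated HeteroData mini-batch
        batch = batch.to(GeneralConfig.DEVICE)
        ...                                        # one optimisation step of the common model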
1 change: 1 addition & 0 deletions VSharp.ML.AIAgent/.gitignore
@@ -1,5 +1,6 @@
# python cache and venv
.env
nvidia_env
__pycache__/
report**/
ml/pretrained_models/
11 changes: 2 additions & 9 deletions VSharp.ML.AIAgent/config.py
@@ -5,9 +5,6 @@

import torch

import ml.model_modified
import ml.models


class GeneralConfig:
SERVER_COUNT = 16
@@ -19,12 +16,8 @@ class GeneralConfig:
MAX_STEPS = 5000
MUTATION_PERCENT_GENES = 5
LOGGER_LEVEL = logging.INFO
IMPORT_MODEL_INIT = lambda: ml.models.StateModelEncoder(
hidden_channels=32, out_channels=8
)
EXPORT_MODEL_INIT = lambda: ml.model_modified.StateModelEncoderExport(
hidden_channels=32, out_channels=8
)
IMPORT_MODEL_INIT = ...
EXPORT_MODEL_INIT = ...
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


55 changes: 39 additions & 16 deletions VSharp.ML.AIAgent/learning/play_game.py
@@ -2,29 +2,30 @@
from statistics import StatisticsError
from time import perf_counter
from typing import TypeAlias
import random

import tqdm
from func_timeout import FunctionTimedOut, func_set_timeout

from common.classes import GameResult, Map2Result
from common.constants import TQDM_FORMAT_DICT
from common.game import GameMap
from common.utils import get_states
from config import FeatureConfig
from config import FeatureConfig, GeneralConfig
from connection.broker_conn.socket_manager import game_server_socket_manager
from connection.game_server_conn.connector import Connector
from connection.game_server_conn.utils import MapsType, get_maps
from connection.game_server_conn.utils import MapsType
from learning.timer.resources_manager import manage_map_inference_times_array
from learning.timer.stats import compute_statistics
from learning.timer.utils import get_map_inference_times
from ml.data_loader_compact import ServerDataloaderHeteroVector
from ml.fileop import save_model
from ml.model_wrappers.protocols import Predictor

TimeDuration: TypeAlias = float


def play_map(
with_connector: Connector, with_predictor: Predictor
with_connector: Connector, with_predictor: Predictor, with_dataset
) -> tuple[GameResult, TimeDuration]:
steps_count = 0
game_state = None
@@ -33,12 +34,23 @@ def play_map(

start_time = perf_counter()

map_steps = []

def add_single_step(input, output):
hetero_input, _ = ServerDataloaderHeteroVector.convert_input_to_tensor(input)
hetero_input["y_true"] = output
hetero_input.to(GeneralConfig.DEVICE)
map_steps.append(hetero_input)

try:
for _ in range(steps):
game_state = with_connector.recv_state_or_throw_gameover()
predicted_state_id = with_predictor.predict(
predicted_state_id, nn_output = with_predictor.predict(
game_state, with_connector.map.MapName
)

add_single_step(game_state, nn_output)

logging.debug(
f"<{with_predictor.name()}> step: {steps_count}, available states: {get_states(game_state)}, predicted: {predicted_state_id}"
)
@@ -83,15 +95,21 @@ def play_map(
errors_count=errors_count,
actual_coverage_percent=actual_coverage,
)

with_predictor.update(with_connector.map.MapName, model_result)
if with_dataset is not None:
map_result = (
model_result.actual_coverage_percent,
-model_result.tests_count,
model_result.errors_count,
-model_result.steps_count,
)
with_dataset.update(with_connector.map.MapName, map_result, map_steps)
return model_result, end_time - start_time


def play_map_with_stats(
with_connector: Connector, with_predictor: Predictor
with_connector: Connector, with_predictor: Predictor, with_dataset
) -> tuple[GameResult, TimeDuration]:
model_result, time_duration = play_map(with_connector, with_predictor)
model_result, time_duration = play_map(with_connector, with_predictor, with_dataset)

with manage_map_inference_times_array():
try:
@@ -110,15 +128,19 @@ def play_map_with_stats(

@func_set_timeout(FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec)
def play_map_with_timeout(
with_connector: Connector, with_predictor: Predictor
with_connector: Connector, with_predictor: Predictor, with_dataset
) -> tuple[GameResult, TimeDuration]:
return play_map_with_stats(with_connector, with_predictor)
return play_map_with_stats(with_connector, with_predictor, with_dataset)


def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType):
with game_server_socket_manager() as ws:
maps = get_maps(websocket=ws, type=maps_type)
random.shuffle(maps)
def play_game(
with_predictor: Predictor,
max_steps: int,
maps: list[GameMap],
maps_type: MapsType,
with_dataset=None,
):
# random.shuffle(maps)
with tqdm.tqdm(
total=len(maps),
desc=f"{with_predictor.name():20}: {maps_type.value}",
@@ -138,6 +160,7 @@ def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType):
game_result, time = play_func(
with_connector=Connector(ws, game_map, max_steps),
with_predictor=with_predictor,
with_dataset=with_dataset,
)
logging.info(
f"<{with_predictor.name()}> finished map {game_map.MapName} "
@@ -159,4 +182,4 @@ def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType):
)
list_of_map2result.append(Map2Result(game_map, game_result))
pbar.update(1)
return list_of_map2result
return (list_of_map2result, with_dataset.maps_data)
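With this change the caller supplies the map list and, optionally, a FullDataset that accumulates the played steps. A hypothetical invocation as a sketch — the map retrieval, paths, predictor instance, and the MapsType member name are assumptions, not shown in these hunks:

    # Hypothetical sketch of driving the updated play_game; maps are now fetched by the
    # caller (get_maps was previously called inside play_game, see the removed lines above).
    from config import GeneralConfig
    from connection.broker_conn.socket_manager import game_server_socket_manager
    from connection.game_server_conn.utils import MapsType, get_maps
    from learning.play_game import play_game
    from ml.common_model.dataset import FullDataset

    dataset = FullDataset("report/dataset", "report/dataset_state.csv")
    with game_server_socket_manager() as ws:
        maps = get_maps(websocket=ws, type=MapsType.TRAIN)

    results, maps_data = play_game(
        with_predictor=predictor,              # some Predictor instance (assumed)
        max_steps=GeneralConfig.MAX_STEPS,
        maps=maps,
        maps_type=MapsType.TRAIN,
        with_dataset=dataset,                  # the final return reads with_dataset.maps_data
    )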
136 changes: 136 additions & 0 deletions VSharp.ML.AIAgent/ml/common_model/dataset.py
@@ -0,0 +1,136 @@
from collections.abc import Sequence
import torch

import os
import numpy as np

import tqdm
import logging
from ml.common_model.utils import load_dataset_state_dict
import csv
from torch_geometric.data import HeteroData
from typing import TypeAlias


MapName: TypeAlias = str
GameStatistics: TypeAlias = tuple[int, int, int, int]
GameStepHeteroData: TypeAlias = HeteroData
GameStepsOnMapInfo: TypeAlias = tuple[GameStatistics, Sequence[GameStepHeteroData]]


class FullDataset:
def __init__(
self,
dataset_root_path,
dataset_map_results_file_name,
similar_steps_save_prob=0,
):
self.dataset_map_results_file_name = dataset_map_results_file_name
self.dataset_root_path = dataset_root_path
self.maps_data: dict[str, GameStepsOnMapInfo] = dict()
self.similar_steps_save_prob = similar_steps_save_prob

def load(self):
maps_results = load_dataset_state_dict(self.dataset_map_results_file_name)
for file_with_map_steps in tqdm.tqdm(
os.listdir(self.dataset_root_path), desc="data loading"
):
map_steps = torch.load(
os.path.join(self.dataset_root_path, file_with_map_steps),
map_location="cpu",
)
map_name = file_with_map_steps[:-3]
filtered_map_steps = self.filter_map_steps(map_steps)
filtered_map_steps = self.remove_similar_steps(filtered_map_steps)
self.maps_data[map_name] = (maps_results[map_name], filtered_map_steps)

def remove_similar_steps(self, map_steps):
filtered_map_steps = []
for step in map_steps:
if (
len(filtered_map_steps) != 0
and step["y_true"].size() == filtered_map_steps[-1]["y_true"].size()
):
cos_d = 1 - torch.sum(
(step["y_true"] / torch.linalg.vector_norm(step["y_true"]))
* (
filtered_map_steps[-1]["y_true"]
/ torch.linalg.vector_norm(filtered_map_steps[-1]["y_true"])
)
)
if (
cos_d < 1e-7
and step["game_vertex"]["x"].size()[0]
== filtered_map_steps[-1]["game_vertex"]["x"].size()[0]
):
step.use_for_train = np.random.choice(
[True, False],
p=[
self.similar_steps_save_prob,
1 - self.similar_steps_save_prob,
],
)
else:
step.use_for_train = True
else:
step.use_for_train = True
filtered_map_steps.append(step)
return filtered_map_steps

def filter_map_steps(self, map_steps):
filtered_map_steps = []
for step in map_steps:
if step["y_true"].size()[0] != 1 and not step["y_true"].isnan().any():
max_ind = torch.argmax(step["y_true"])
step["y_true"] = torch.zeros_like(step["y_true"])
step["y_true"][max_ind] = 1.0
filtered_map_steps.append(step)
return filtered_map_steps

def get_plain_data(self):
result = []
for _, map_steps in self.maps_data.values():
for step in map_steps:
if step.use_for_train:
result.append(step)
return result

def save(self):
values_for_csv = []
for map_name in self.maps_data.keys():
values_for_csv.append(
{
"map_name": map_name,
"result": self.maps_data[map_name][0],
}
)
torch.save(
self.maps_data[map_name][1],
os.path.join(self.dataset_root_path, map_name + ".pt"),
)
with open(self.dataset_map_results_file_name, "w") as csv_file:
writer = csv.DictWriter(csv_file, fieldnames=["map_name", "result"])
writer.writerows(values_for_csv)

def update(
self,
map_name,
map_result: tuple[int, int, int, int],
map_steps,
move_to_cpu=False,
):
if move_to_cpu:
for x in map_steps:
x.to("cpu")
filtered_map_steps = self.filter_map_steps(map_steps)
if map_name in self.maps_data.keys():
if self.maps_data[map_name][0] < map_result:
logging.info(
f"The model with result = {self.maps_data[map_name][0]} was replaced with the model with "
f"result = {map_result} on the map {map_name}"
)
filtered_map_steps = self.remove_similar_steps(filtered_map_steps)
self.maps_data[map_name] = (map_result, filtered_map_steps)
else:
filtered_map_steps = self.remove_similar_steps(filtered_map_steps)
self.maps_data[map_name] = (map_result, filtered_map_steps)
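A hypothetical end-to-end use of the new FullDataset, as a sketch — the paths and save probability are assumptions; the result tuple mirrors the one built in play_map above:

    # Sketch of the FullDataset lifecycle: accumulate steps per map, persist, reload,
    # and extract the flat step list for batched training (see the DataLoader sketch above).
    from ml.common_model.dataset import FullDataset

    dataset = FullDataset(
        dataset_root_path="report/dataset",
        dataset_map_results_file_name="report/dataset_state.csv",
        similar_steps_save_prob=0.3,
    )

    # During play, play_map builds
    #   map_result = (actual_coverage_percent, -tests_count, errors_count, -steps_count)
    # and calls dataset.update(map_name, map_result, map_steps); for a map already present,
    # update replaces the stored steps only when the new tuple compares greater.

    dataset.save()                           # one <map_name>.pt per map plus a csv of results
    dataset.load()                           # restores maps_data, filtering and de-duplicating
    train_steps = dataset.get_plain_data()   # only steps with use_for_train == True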
