Skip to content

Commit

Permalink
Upd code squashed
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Aug 4, 2023
1 parent 621bfa9 commit 5ccf8a9
Show file tree
Hide file tree
Showing 21 changed files with 523 additions and 293 deletions.
6 changes: 3 additions & 3 deletions VSharp.ML.AIAgent/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ def _build_bar_format() -> str:
"dynamic_ncols": True,
}

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BASE_NN_OUT_FEATURES_NUM = 8


class WebsocketSourceLinks:
GET_WS = f"http://0.0.0.0:{BrokerConfig.BROKER_PORT}/get_ws"
Expand All @@ -49,3 +46,6 @@ class ResultsHandlerLinks:
DUMMY_INPUT_PATH = Path("ml/onnx/dummy_input.json")
BEST_MODEL_ONNX_SAVE_PATH = Path("ml/onnx/StateModelEncoder.onnx")
TEMP_EPOCH_INFERENCE_TIMES_DIR = Path(".epoch_inference_times/")

# assuming we start from /VSharp/VSharp.ML.AIAgent
SERVER_WORKING_DIR = "../VSharp.ML.GameServer.Runner/bin/Release/net6.0/"
24 changes: 21 additions & 3 deletions VSharp.ML.AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,22 @@
from pathlib import Path
from shutil import rmtree

import torch

import ml.models


class GeneralConfig:
SERVER_COUNT = 8
NUM_GENERATIONS = 20
NUM_PARENTS_MATING = 10
KEEP_ELITISM = 2
NUM_SOLUTIONS = 60
MAX_STEPS = 3000
MUTATION_PERCENT_GENES = 30
LOGGER_LEVEL = logging.INFO
MODEL_INIT = lambda: ml.models.SAGEConvModel(16)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class BrokerConfig:
Expand All @@ -24,7 +32,7 @@ class ServerConfig:
@dataclass(slots=True, frozen=True)
class DumpByTimeoutFeature:
enabled: bool
timeout_seconds: int
timeout_sec: int
save_path: Path

def create_save_path_if_not_exists(self):
Expand All @@ -50,11 +58,21 @@ class FeatureConfig:
VERBOSE_TABLES = True
SHOW_SUCCESSORS = True
NAME_LEN = 7
N_BEST_SAVED_EACH_GEN = 2
DISABLE_MESSAGE_CHECKS = True
DUMP_BY_TIMEOUT = DumpByTimeoutFeature(
enabled=True, timeout_seconds=1200, save_path=Path("./report/timeouted_agents/")
enabled=True, timeout_sec=600, save_path=Path("./report/timeouted_agents/")
)
SAVE_EPOCHS_COVERAGES = SaveEpochsCoveragesFeature(
enabled=True, save_path=Path("./report/epochs_tables/")
)
ON_GAME_SERVER_RESTART = True


class GameServerConnectorConfig:
CREATE_CONNECTION_TIMEOUT = 1
RESPONCE_TIMEOUT_SEC = (
FeatureConfig.DUMP_BY_TIMEOUT.timeout_sec + 1
if FeatureConfig.DUMP_BY_TIMEOUT.enabled
else 1000
)
SKIP_UTF_VALIDATION = True
13 changes: 12 additions & 1 deletion VSharp.ML.AIAgent/connection/broker_conn/classes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable
from typing import Callable, TypeAlias

from dataclasses_json import config, dataclass_json

Expand All @@ -8,6 +8,17 @@
from connection.game_server_conn.unsafe_json import asdict
from ml.model_wrappers.nnwrapper import NNWrapper, decode, encode

WSUrl: TypeAlias = str
Undefined: TypeAlias = None


@dataclass_json
@dataclass(slots=True, frozen=True)
class ServerInstanceInfo:
port: int
ws_url: WSUrl
pid: int | Undefined


def custom_encoder_if_disable_message_checks() -> Callable | None:
return asdict if FeatureConfig.DISABLE_MESSAGE_CHECKS else None
Expand Down
27 changes: 15 additions & 12 deletions VSharp.ML.AIAgent/connection/broker_conn/requests.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
import json
import logging

import httplib2

from common.constants import ResultsHandlerLinks, WebsocketSourceLinks

from .classes import Agent2ResultsOnMaps
from .classes import Agent2ResultsOnMaps, ServerInstanceInfo


def aquire_ws() -> str:
def acquire_instance() -> ServerInstanceInfo:
while True:
response, content = httplib2.Http().request(WebsocketSourceLinks.GET_WS)
aquired_ws_url = content.decode("utf-8")
if aquired_ws_url == "":
if content.decode("utf-8") == "":
logging.warning(f"all sockets are in use")
continue
logging.info(f"aquired ws: {aquired_ws_url}")
return aquired_ws_url
aquired_instance = ServerInstanceInfo.from_json(
json.loads(content.decode("utf-8"))
)
logging.info(f"acquired ws: {aquired_instance}")
return aquired_instance


def return_ws(ws_url: str):
logging.info(f"returning: {ws_url}")
def return_instance(instance: ServerInstanceInfo):
logging.info(f"returning: {instance}")

response, content = httplib2.Http().request(
WebsocketSourceLinks.POST_WS,
method="POST",
body=ws_url,
body=instance.to_json(),
)

if response.status == 200:
logging.info(f"{ws_url} is returned")
logging.info(f"{instance} is returned")
else:
logging.error(f"{response.status} on returning {ws_url}")
logging.error(f"{response.status} on returning {instance}")
raise RuntimeError(f"Not ok response: {response.status}")


Expand All @@ -51,5 +54,5 @@ def send_game_results(data: Agent2ResultsOnMaps):
def recv_game_result_list() -> str:
response, content = httplib2.Http().request(ResultsHandlerLinks.GET_RES)
games_data = content.decode("utf-8")
logging.info(f"Aquired games data")
logging.info(f"Acquired games data")
return games_data
33 changes: 28 additions & 5 deletions VSharp.ML.AIAgent/connection/broker_conn/socket_manager.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,38 @@
from contextlib import contextmanager
import time
from contextlib import contextmanager, suppress

import websocket
from .requests import aquire_ws, return_ws

from config import GameServerConnectorConfig
from connection.broker_conn.classes import WSUrl

from .requests import acquire_instance, return_instance

websocket.setdefaulttimeout(GameServerConnectorConfig.RESPONCE_TIMEOUT_SEC)


def wait_for_connection(url: WSUrl):
ws = websocket.WebSocket()

while True:
with suppress(ConnectionRefusedError, ConnectionResetError):
ws.settimeout(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT)
ws.connect(
url, skip_utf8_validation=GameServerConnectorConfig.SKIP_UTF_VALIDATION
)
if ws.connected:
return ws
time.sleep(GameServerConnectorConfig.CREATE_CONNECTION_TIMEOUT)


@contextmanager
def game_server_socket_manager():
socket_url = aquire_ws()
socket = websocket.create_connection(socket_url, skip_utf8_validation=True)
server_instance = acquire_instance()

socket = None
try:
socket = wait_for_connection(server_instance.ws_url)
yield socket
finally:
socket.close()
return_ws(socket_url)
return_instance(server_instance)
13 changes: 7 additions & 6 deletions VSharp.ML.AIAgent/connection/game_server_conn/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,21 @@ def send_step(self, next_state_id: int, predicted_usefullness: int):
self._sent_state_id = next_state_id

def recv_reward_or_throw_gameover(self) -> Reward:
data = RewardServerMessage.from_json_handle(
self._raise_if_gameover(self.ws.recv()),
received = self.ws.recv()
decoded = RewardServerMessage.from_json_handle(
self._raise_if_gameover(received),
expected=RewardServerMessage,
)
logging.debug(f"<-- MoveReward : {data.MessageBody}")
logging.debug(f"<-- MoveReward : {decoded.MessageBody}")

return self._process_reward_server_message(data)
return self._process_reward_server_message(decoded)

def _process_reward_server_message(self, msg):
match msg.MessageType:
case ServerMessageType.INCORRECT_PREDICTED_STATEID:
raise Connector.IncorrectSentStateError(
f"Sending state_id={self._sent_state_id} \
at step #{self._current_step} resulted in {msg.MessageType}"
f"Sending state_id={self._sent_state_id} "
f"at step #{self._current_step} resulted in {msg.MessageType}"
)

case ServerMessageType.MOVE_REVARD:
Expand Down
9 changes: 8 additions & 1 deletion VSharp.ML.AIAgent/epochs_statistics/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,14 @@ def create_pivot_table(
for mutable2result in mutable2result_list:
name_results_dict[map_obj.Id].append(convert_to_view_model(mutable2result))
epoch_percents_dict[map_obj.Id].append(
mutable2result.game_result.actual_coverage_percent
str(
(
mutable2result.game_result.actual_coverage_percent,
mutable2result.game_result.tests_count,
mutable2result.game_result.errors_count,
mutable2result.game_result.steps_count,
)
)
)

mutable_names = get_model_names_in_order(name_results_dict)
Expand Down
1 change: 1 addition & 0 deletions VSharp.ML.AIAgent/install_script.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
conda install numpy pandas tabulate
conda install -c pytorch pytorch=1.13.1 torchvision=0.14.1 torchaudio=0.13.1
python3 -m pip install --force-reinstall -v "torch-scatter==2.1.0" "torch-geometric==2.2.0" "torch-sparse==0.6.16"
python3 -m pip install func_timeout
conda install -c conda-forge dataclasses-json websocket-client pre_commit aiohttp cchardet pygad httplib2 onnx onnxruntime
pre-commit install
Loading

0 comments on commit 5ccf8a9

Please sign in to comment.