Skip to content

Commit

Permalink
Upd new model code
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Aug 4, 2023
1 parent 71cc7fb commit 8fbbbe7
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 11 deletions.
8 changes: 2 additions & 6 deletions VSharp.ML.AIAgent/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,8 @@ def _build_bar_format() -> str:
return f"{custom_left} {custom_bar} - {custom_info}"


IMPORTED_FULL_MODEL_PATH = Path(
"ml/imported/GNN_state_pred_het_full_TAGConv_20e_2xAll_10h"
)
IMPORTED_DICT_MODEL_PATH = Path(
"ml/imported/GNN_state_pred_het_dict_TAGConv_20e_2xAll_10h"
)
IMPORTED_FULL_MODEL_PATH = Path("ml/imported/GNN_state_pred_het_full_compact.zip")
IMPORTED_DICT_MODEL_PATH = Path("ml/imported/GNN_state_pred_het_dict_compact.zip")

BASE_REPORT_DIR = Path("./report")
TABLES_LOG_FILE = BASE_REPORT_DIR / "tables.log"
Expand Down
4 changes: 4 additions & 0 deletions VSharp.ML.AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

import ml.model_modified
import ml.models


class GeneralConfig:
Expand All @@ -17,6 +18,9 @@ class GeneralConfig:
MAX_STEPS = 100
MUTATION_PERCENT_GENES = 4
LOGGER_LEVEL = logging.INFO
IMPORT_MODEL_INIT = lambda: ml.models.StateModelEncoder(
hidden_channels=64, out_channels=8
)
MODEL_INIT = lambda: ml.model_modified.StateModelEncoderExport(
hidden_channels=64, out_channels=8
)
Expand Down
2 changes: 1 addition & 1 deletion VSharp.ML.AIAgent/connection/broker_conn/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def acquire_instance() -> ServerInstanceInfo:
response, content = httplib2.Http().request(WebsocketSourceLinks.GET_WS)
aquired_instance = ServerInstanceInfo.from_json(content.decode("utf-8"))
aquired_instance = ServerInstanceInfo.from_json(json.loads(content.decode("utf-8")))
logging.info(f"acquired ws: {aquired_instance}")
return aquired_instance

Expand Down
11 changes: 7 additions & 4 deletions VSharp.ML.AIAgent/main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import learning.entry_point as ga
import ml.onnx.onnx_import
from common.constants import IMPORTED_DICT_MODEL_PATH
from config import GeneralConfig
from learning.selection.crossover_type import CrossoverType
from learning.selection.mutation_type import MutationType
from learning.selection.parent_selection_type import ParentSelectionType
from ml.utils import (
create_population,
load_model,
model_weights_with_last_layer,
model_weights_with_random_last_layer,
)
Expand All @@ -16,9 +18,10 @@ def main():

model.forward(*ml.onnx.onnx_import.create_torch_dummy_input())

random_population = create_population(lo=-5, hi=5, model=model, population_size=4)
random_population = create_population(lo=-5, hi=5, model=model, population_size=60)
with_random_last_layer = [
model_weights_with_random_last_layer(lo=-1, hi=1, model=model) for _ in range(2)
model_weights_with_random_last_layer(lo=-1, hi=1, model=model)
for _ in range(18)
]
with_last_layer1 = model_weights_with_last_layer(
[
Expand All @@ -31,7 +34,7 @@ def main():
0.95478059636744,
0.27937866719070503,
],
model,
load_model(path=IMPORTED_DICT_MODEL_PATH, model=GeneralConfig.MODEL_INIT()),
)
with_last_layer2 = model_weights_with_last_layer(
[
Expand All @@ -44,7 +47,7 @@ def main():
0.9555442824877577,
0.2793786892860371,
],
model,
load_model(path=IMPORTED_DICT_MODEL_PATH, model=GeneralConfig.MODEL_INIT()),
)

initial_population = [
Expand Down
Binary file not shown.
Binary file not shown.
7 changes: 7 additions & 0 deletions VSharp.ML.AIAgent/ml/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import random
from pathlib import Path

import numpy
import pygad.torchga
Expand All @@ -14,6 +15,12 @@
from ml.onnx.onnx_import import create_torch_dummy_input


def load_model(path: Path, model: torch.nn.Module):
model.load_state_dict(torch.load(path))
model.eval()
return model


def model_weights_with_last_layer(
last_layer_weights: list[float], model: ml.models.StateGNNEncoderConvEdgeAttr
) -> npt.NDArray:
Expand Down

0 comments on commit 8fbbbe7

Please sign in to comment.