diff --git a/VSharp.ML.AIAgent/config.py b/VSharp.ML.AIAgent/config.py index acd780ae3..5aa6f9999 100644 --- a/VSharp.ML.AIAgent/config.py +++ b/VSharp.ML.AIAgent/config.py @@ -55,7 +55,6 @@ class FeatureConfig: VERBOSE_TABLES = True SHOW_SUCCESSORS = True NAME_LEN = 7 - N_BEST_SAVED_EACH_GEN = 2 DISABLE_MESSAGE_CHECKS = True DUMP_BY_TIMEOUT = DumpByTimeoutFeature( enabled=True, timeout_sec=600, save_path=Path("./report/timeouted_agents/") diff --git a/VSharp.ML.AIAgent/learning/genetic_learning.py b/VSharp.ML.AIAgent/learning/genetic_learning.py index 8df2aaabc..8056df0b8 100644 --- a/VSharp.ML.AIAgent/learning/genetic_learning.py +++ b/VSharp.ML.AIAgent/learning/genetic_learning.py @@ -62,9 +62,7 @@ def on_generation(ga_instance): print(f"Generation = {ga_instance.generations_completed};") epoch_subdir = create_epoch_subdir(ga_instance.generations_completed) - for weights in get_n_best_weights_in_last_generation( - ga_instance, FeatureConfig.N_BEST_SAVED_EACH_GEN - ): + for weights in ga_instance.population: save_model( GeneralConfig.MODEL_INIT(), to=epoch_subdir / f"{sum(weights)}.pth", diff --git a/VSharp.ML.AIAgent/ml/fileop.py b/VSharp.ML.AIAgent/ml/fileop.py index 350cd29d7..c33a3dd1b 100644 --- a/VSharp.ML.AIAgent/ml/fileop.py +++ b/VSharp.ML.AIAgent/ml/fileop.py @@ -3,7 +3,7 @@ import pygad import torch -import ml +from ml.onnx.onnx_import import create_torch_dummy_input from common.constants import DEVICE @@ -11,7 +11,7 @@ def save_model(model: torch.nn.Module, to: Path, weights=None): if weights is None: torch.save(model.state_dict(), to) else: - model.forward(*ml.onnx.onnx_import.create_torch_dummy_input()) + model.forward(*create_torch_dummy_input()) state_dict = pygad.torchga.model_weights_as_dict(model, weights) torch.save(state_dict, to) diff --git a/VSharp.ML.AIAgent/ml/onnx/onnx_import.py b/VSharp.ML.AIAgent/ml/onnx/onnx_import.py index 07af5adc2..24bc7b77d 100644 --- a/VSharp.ML.AIAgent/ml/onnx/onnx_import.py +++ b/VSharp.ML.AIAgent/ml/onnx/onnx_import.py @@ -57,21 +57,12 @@ def create_torch_dummy_input(): def export_onnx_model(model: torch.nn.Module, save_path: str): torch.onnx.export( model=model, - args=create_onnx_dummy_input(), + args=(*create_torch_dummy_input(), {}), f=save_path, verbose=False, export_params=True, - input_names=["x_dict", "edge_index_dict"], - # input_names=[ - # "game_vertex", - # "state_vertex", - # "gv2gv", - # "sv_in_gv", - # "gv_in_sv", - # "sv_his_gv", - # "gv_his_sv", - # "sv_parentof_sv", - # ], + input_names=["x_dict", "edge_index_dict", "edge_attr_dict"], + opset_version=16, ) torch_model_out = model(*create_torch_dummy_input())