Merge pull request #81 from Anya497/optuna
Optuna
emnigma authored Dec 5, 2023
2 parents 408a938 + 61f5817 commit 7dab8d4
Showing 5 changed files with 128 additions and 65 deletions.
2 changes: 1 addition & 1 deletion VSharp.ML.AIAgent/ml/common_model/dataset.py
@@ -124,7 +124,7 @@ def update(
         x.to("cpu")
         filtered_map_steps = self.filter_map_steps(map_steps)
         if map_name in self.maps_data.keys():
-            if self.maps_data[map_name][0] < map_result:
+            if self.maps_data[map_name][0] <= map_result:
                 logging.info(
                     f"The model with result = {self.maps_data[map_name][0]} was replaced with the model with "
                     f"result = {map_result} on the map {map_name}"
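The one-line change above relaxes the replacement condition from < to <=, so a new model that ties the stored best result on a map now replaces the stale entry instead of being discarded. A minimal sketch of the effect, outside the diff; the (result, steps) tuple layout of maps_data is an assumption:

    maps_data = {"SomeMap": (95.0, "old_steps")}   # hypothetical stored best
    map_result, new_steps = 95.0, "new_steps"      # a fresh model ties the stored result
    if maps_data["SomeMap"][0] <= map_result:      # with '<' a tie kept the stale entry
        maps_data["SomeMap"] = (map_result, new_steps)
    print(maps_data)                               # {'SomeMap': (95.0, 'new_steps')}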
5 changes: 1 addition & 4 deletions VSharp.ML.AIAgent/ml/common_model/utils.py
@@ -130,10 +130,7 @@ def load_dataset_state_dict(path):
     return dataset_state_dict


-def get_model(
-    path_to_weights: Path, model_init: t.Callable[[], torch.nn.Module], random_seed: int
-):
-    np.random.seed(random_seed)
+def get_model(path_to_weights: Path, model_init: t.Callable[[], torch.nn.Module]):
     model = model_init()
     weights = torch.load(path_to_weights)
     weights["lin_last.weight"] = torch.tensor(np.random.random([1, 8]))
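With random_seed dropped from get_model's signature, seeding becomes the caller's responsibility (train() below calls np.random.seed(config.random_seed) before loading the model). A sketch of the new call pattern, assuming path_to_weights points at a saved state dict:

    import numpy as np
    from pathlib import Path

    np.random.seed(937)  # seed first: get_model re-initializes lin_last.weight from np.random.random
    model = get_model(
        Path(path_to_weights),
        lambda: StateModelEncoderLastLayer(hidden_channels=64, out_channels=8),
    )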
36 changes: 36 additions & 0 deletions VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG2VerticesDouble/model_modified.py
@@ -0,0 +1,36 @@
+from torch_geometric.nn import Linear
+from torch.nn.functional import softmax
+from .model import StateModelEncoder
+
+
+class StateModelEncoderLastLayer(StateModelEncoder):
+    def __init__(self, hidden_channels, out_channels):
+        super().__init__(hidden_channels, out_channels)
+        self.lin_last = Linear(out_channels, 1)
+
+    def forward(
+        self,
+        game_x,
+        state_x,
+        edge_index_v_v,
+        edge_type_v_v,
+        edge_index_history_v_s,
+        edge_attr_history_v_s,
+        edge_index_in_v_s,
+        edge_index_s_s,
+    ):
+        return softmax(
+            self.lin_last(
+                super().forward(
+                    game_x=game_x,
+                    state_x=state_x,
+                    edge_index_v_v=edge_index_v_v,
+                    edge_type_v_v=edge_type_v_v,
+                    edge_index_history_v_s=edge_index_history_v_s,
+                    edge_attr_history_v_s=edge_attr_history_v_s,
+                    edge_index_in_v_s=edge_index_in_v_s,
+                    edge_index_s_s=edge_index_s_s,
+                )
+            ),
+            dim=0,
+        )
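The subclass only appends a scalar scoring head: lin_last projects each out_channels-dimensional state embedding to a single logit, and softmax with dim=0 normalizes across states rather than across features. A standalone sketch of what the head computes, with plain torch.nn.Linear standing in and the shapes as assumptions:

    import torch
    from torch.nn.functional import softmax

    state_embeddings = torch.randn(5, 8)                # 5 candidate states, out_channels=8
    lin_last = torch.nn.Linear(8, 1)                    # stand-in for the appended layer
    probs = softmax(lin_last(state_embeddings), dim=0)  # shape [5, 1]; the column sums to 1
    print(probs.sum().item())                           # ~1.0: a distribution over the states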
36 changes: 36 additions & 0 deletions VSharp.ML.AIAgent/ml/models/StateGNNEncoderConvEdgeAttr/model_modified.py
@@ -0,0 +1,36 @@
+from torch_geometric.nn import Linear
+from torch.nn.functional import softmax
+from .model import StateModelEncoder
+
+
+class StateModelEncoderLastLayer(StateModelEncoder):
+    def __init__(self, hidden_channels, out_channels):
+        super().__init__(hidden_channels, out_channels)
+        self.lin_last = Linear(out_channels, 1)
+
+    def forward(
+        self,
+        game_x,
+        state_x,
+        edge_index_v_v,
+        edge_type_v_v,
+        edge_index_history_v_s,
+        edge_attr_history_v_s,
+        edge_index_in_v_s,
+        edge_index_s_s,
+    ):
+        return softmax(
+            self.lin_last(
+                super().forward(
+                    game_x=game_x,
+                    state_x=state_x,
+                    edge_index_v_v=edge_index_v_v,
+                    edge_type_v_v=edge_type_v_v,
+                    edge_index_history_v_s=edge_index_history_v_s,
+                    edge_attr_history_v_s=edge_attr_history_v_s,
+                    edge_index_in_v_s=edge_index_in_v_s,
+                    edge_index_s_s=edge_index_s_s,
+                )
+            ),
+            dim=0,
+        )
114 changes: 54 additions & 60 deletions VSharp.ML.AIAgent/run_common_model_training.py
@@ -30,7 +30,15 @@
 )
 from ml.common_model.utils import csv2best_models, get_model
 from ml.common_model.wrapper import BestModelsWrapper, CommonModelWrapper
-from ml.models.TAGSageSimple.model_modified import StateModelEncoderLastLayer
+from ml.models.RGCNEdgeTypeTAG2VerticesDouble.model_modified import (
+    StateModelEncoderLastLayer,
+)
+from ml.models.StateGNNEncoderConvEdgeAttr.model_modified import (
+    StateModelEncoderLastLayer as RefStateModelEncoderLastLayer,
+)
+import optuna
+from functools import partial
+import joblib

 LOG_PATH = Path("./ml_app.log")
 TABLES_PATH = Path("./ml_tables.log")
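Two of the new imports carry the orchestration further down: functools.partial binds the dataset argument so that the two-argument train(trial, dataset) becomes the single-argument callable Optuna expects, and joblib persists the finished study. A minimal sketch of the partial pattern with a stand-in body:

    from functools import partial

    def train(trial, dataset):
        return len(dataset)  # stand-in for the real objective value

    objective = partial(train, dataset=[1, 2, 3])
    print(objective(trial=None))  # -> 3; study.optimize supplies the trial argument itself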
@@ -82,19 +90,45 @@ class TrainConfig:
     lr: float
     epochs: int
     batch_size: int
-
-
-def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
+    optimizer: torch.optim.Optimizer
+    loss: any
+    random_seed: int
+
+
+def train(trial: optuna.trial.Trial, dataset: FullDataset):
+    config = TrainConfig(
+        lr=trial.suggest_float("lr", 1e-7, 1e-3),
+        batch_size=trial.suggest_int("batch_size", 32, 1024),
+        epochs=10,
+        optimizer=trial.suggest_categorical("optimizer", [torch.optim.Adam]),
+        loss=trial.suggest_categorical("loss", [nn.KLDivLoss]),
+        random_seed=937,
+    )
+    np.random.seed(config.random_seed)
     # for name, param in model.named_parameters():
     #     if "lin_last" not in name:
     #         param.requires_grad = False

+    path_to_weights = os.path.join(
+        PRETRAINED_MODEL_PATH,
+        "RGCNEdgeTypeTAG2VerticesDouble",
+        "64ch",
+        "100e",
+        "GNN_state_pred_het_dict",
+    )
+    model = get_model(
+        Path(path_to_weights),
+        lambda: StateModelEncoderLastLayer(hidden_channels=64, out_channels=8),
+    )
+
     model.to(GeneralConfig.DEVICE)
-    optimizer = torch.optim.Adam(model.parameters(), lr=train_config.lr)
-    criterion = nn.KLDivLoss()
+    optimizer = config.optimizer(model.parameters(), lr=config.lr)
+    criterion = config.loss()

     timestamp = datetime.now().timestamp()
-    run_name = f"{datetime.fromtimestamp(timestamp)}_{train_config.batch_size}_Adam_{train_config.lr}_KLDL"
+    run_name = (
+        f"{datetime.fromtimestamp(timestamp)}_{config.batch_size}_Adam_{config.lr}_KLDL"
+    )

     print(run_name)
     path_to_saved_models = os.path.join(COMMON_MODELS_PATH, run_name)
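The trial.suggest_* calls above define the search space that replaces the old random.choice loop: lr is drawn from [1e-7, 1e-3] (uniform on a linear scale, since log=True is not passed), batch_size from [32, 1024], and the single-element categorical lists pin the optimizer and loss while still recording them on the trial. A self-contained sketch of the same pattern with a stand-in score:

    import optuna

    def objective(trial: optuna.trial.Trial) -> float:
        lr = trial.suggest_float("lr", 1e-7, 1e-3)           # linear scale without log=True
        batch_size = trial.suggest_int("batch_size", 32, 1024)
        return -abs(lr - 1e-4) - abs(batch_size - 256) / 1024  # stand-in for max average coverage

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=5)
    print(study.best_params)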
@@ -118,11 +152,9 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
     # p = Pool(GeneralConfig.SERVER_COUNT)

     all_average_results = []
-    for epoch in range(train_config.epochs):
+    for epoch in range(config.epochs):
         data_list = dataset.get_plain_data()
-        data_loader = DataLoader(
-            data_list, batch_size=train_config.batch_size, shuffle=True
-        )
+        data_loader = DataLoader(data_list, batch_size=config.batch_size, shuffle=True)
         print("DataLoader size", len(data_loader))

         model.train()
@@ -178,8 +210,9 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
             list(map(lambda x: x.game_result.actual_coverage_percent, all_results))
         )
         all_average_results.append(average_result)
-        all_results = sorted(all_results, key=lambda x: x.map.MapName)
-        table, _, _ = create_pivot_table({cmwrapper: all_results})
+        table, _, _ = create_pivot_table(
+            {cmwrapper: sorted(all_results, key=lambda x: x.map.MapName)}
+        )
         table = table_to_string(table)
         append_to_file(
             TABLES_PATH,
@@ -193,7 +226,7 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
         del data_loader
     # p.close()

-    return all_average_results
+    return max(all_average_results)


 def get_dataset(
@@ -222,57 +255,18 @@ def get_dataset(

 def main():
     print(GeneralConfig.DEVICE)
-    path_to_weights = os.path.join(
-        PRETRAINED_MODEL_PATH,
-        "TAGSageSimple",
-        "32ch",
-        "20e",
-        "GNN_state_pred_het_dict",
-    )
-    model_initializer = lambda: StateModelEncoderLastLayer(
+    ref_model_initializer = lambda: RefStateModelEncoderLastLayer(
         hidden_channels=32, out_channels=8
     )

-    best_result = {"average_coverage": 0, "config": dict(), "epoch": 0}
     generate_dataset = False
-    dataset = get_dataset(generate_dataset, ref_model_init=model_initializer)
+    dataset = get_dataset(generate_dataset, ref_model_init=ref_model_initializer)

-    while True:
-        config = TrainConfig(
-            lr=random.choice([10 ** (-i) for i in range(3, 8)]),
-            batch_size=random.choice([2**i for i in range(5, 10)]),
-            epochs=20,
-        )
-        print("Current hyperparameters")
-        data_frame = pd.DataFrame(
-            data=[asdict(config).values()],
-            columns=asdict(config).keys(),
-            index=["value"],
-        )
-        print(data_frame)
-
-        model = get_model(
-            Path(path_to_weights),
-            model_initializer,
-            random_seed=937,
-        )
-
-        results = train(train_config=config, model=model, dataset=dataset)
-        max_value = max(results)
-        max_ind = results.index(max_value)
-        if best_result["average_coverage"] < max_value:
-            best_result["average_coverage"] = max_value
-            best_result["config"] = asdict(config)
-            best_result["epoch"] = max_ind + 1
-            print(
-                f"The best result for now:\nAverage coverage: {best_result['average_coverage']}"
-            )
-            data_frame = pd.DataFrame(
-                data=[best_result["config"].values()],
-                columns=best_result["config"].keys(),
-                index=["value"],
-            )
-            print(data_frame)
+    sampler = optuna.samplers.TPESampler(n_startup_trials=10)
+    study = optuna.create_study(sampler=sampler, direction="maximize")
+    objective = partial(train, dataset=dataset)
+    study.optimize(objective, n_trials=100)
+    joblib.dump(study, f"{datetime.fromtimestamp(datetime.now().timestamp())}.pkl")


 if __name__ == "__main__":
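Because main() dumps the finished study with joblib, it can be reloaded later for analysis; a sketch, with the pickle filename (a timestamp, per the code above) as a placeholder:

    import joblib

    study = joblib.load("2023-12-05 18:00:00.000000.pkl")  # placeholder filename
    print(study.best_value)    # the best max(all_average_results) reached across trials
    print(study.best_params)   # e.g. lr, batch_size, optimizer, loss
    print(len(study.trials))   # up to n_trials=100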
