Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ml searcher add pretrained #76

Merged
merged 8 commits into from
Nov 21, 2023
Merged

Conversation

katyacyfra
Copy link

No description provided.

@emnigma
Copy link
Collaborator

emnigma commented Nov 5, 2023

@katyacyfra , Замечаешь, что у тебя не проходит CI? красный крест и написано "Al checks have failed". У нас только один чек, это проверка на кодстайл, чтобы код был однообразный и его было легче читать

Тебе надо при активном питоновском окружении прописать

pip install pre-commit
pre-commit install

Тогда после написания тобой команды "commit" инструмент проверит код и поправит его. После этого поправленный код снова нужно добавить в индекс командой git add и снова прописать git commit

@@ -127,6 +129,8 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int
data["game_vertex", "to", "game_vertex"].edge_index = (
torch.tensor(np.array(edges_index_v_v), dtype=torch.long).t().contiguous()
)
data['game_vertex', 'to', 'game_vertex'].edge_attr = torch.tensor(np.array(edges_attr_v_v), dtype=torch.long)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

вроде как хотели из таплов строки сделать?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Через пробел пойдет? Поправлю

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

лучше через нижние подчеркивания

return self.lin(state_x)

class StateModelEncoder(torch.nn.Module):
def __init__(self, hidden_channels, out_channels):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

я знаю, что можно сделать без енкодеров, вот так например Аня делала.

или z-dict тебе очень нужен?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Можно и без них, вопрос в том, что нужно? Какой формат?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

на вход форварда 8 параметров (сейчас ты добавила восьмой, поэтому 8, по ссылке 7), на выход - результат работы модели. Отдельную обертку (Encoder) делать не надо, z_dict делать не надо

return correct / len(loader.dataset)

@staticmethod
def predict_state(model, data: HeteroData, state_map: Dict[int, int]) -> int:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

обязательно ли функция должна статически принадлежать классу?

"""Gets state id from model and heterogeneous graph
data.state_map - maps real state id to state index"""
state_map = {v: k for k, v in state_map.items()} # inversion for prediction
out = model(data.x_dict, data.edge_index_dict)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

у нас вроде уже 8 входов. этот код все еще актуален?

torch.save(model, filepath)


if __name__ == "__main__":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

в файлах, которые не используются как исполняемые, то есть те, в которых только функции лежат, лучше не писать if__name__ == "__main__":, иначе у читающего создается впечатление, что это исполняемый файл


@staticmethod
def predict_state(model, data: HeteroData, state_map: dict[int, int]) -> int:
def predict_state(model, data: HeteroData, state_map: Dict[int, int]) -> int:
"""Gets state id from model and heterogeneous graph
data.state_map - maps real state id to state index"""
state_map = {v: k for k, v in state_map.items()} # inversion for prediction
out = model(data.x_dict, data.edge_index_dict)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

тут все еще старый предикт?

torch.tensor(np.array(edges_index_v_v), dtype=torch.long).t().contiguous()
)
data["state_vertex", "in", "game_vertex"].edge_index = (
data["game_vertex_to_game_vertex"].edge_attr = torch.tensor(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

тут если edges_attr_v_v, edges_types_v_v могут быть пустыми, то нужно соответствующие тензоры оборачивать в null_if_empty. но это ладно, сделаю сам

@gsvgit gsvgit merged commit 1fce131 into PySymGym:mlSearcher Nov 21, 2023
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants