-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ml searcher add pretrained #76
Conversation
@katyacyfra , Замечаешь, что у тебя не проходит CI? красный крест и написано "Al checks have failed". У нас только один чек, это проверка на кодстайл, чтобы код был однообразный и его было легче читать Тебе надо при активном питоновском окружении прописать
Тогда после написания тобой команды "commit" инструмент проверит код и поправит его. После этого поправленный код снова нужно добавить в индекс командой |
@@ -127,6 +129,8 @@ def convert_input_to_tensor(input: GameState) -> Tuple[HeteroData, Dict[int, int | |||
data["game_vertex", "to", "game_vertex"].edge_index = ( | |||
torch.tensor(np.array(edges_index_v_v), dtype=torch.long).t().contiguous() | |||
) | |||
data['game_vertex', 'to', 'game_vertex'].edge_attr = torch.tensor(np.array(edges_attr_v_v), dtype=torch.long) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
вроде как хотели из таплов строки сделать?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Через пробел пойдет? Поправлю
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
лучше через нижние подчеркивания
return self.lin(state_x) | ||
|
||
class StateModelEncoder(torch.nn.Module): | ||
def __init__(self, hidden_channels, out_channels): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
я знаю, что можно сделать без енкодеров, вот так например Аня делала.
или z-dict тебе очень нужен?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Можно и без них, вопрос в том, что нужно? Какой формат?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
на вход форварда 8 параметров (сейчас ты добавила восьмой, поэтому 8, по ссылке 7), на выход - результат работы модели. Отдельную обертку (Encoder) делать не надо, z_dict делать не надо
return correct / len(loader.dataset) | ||
|
||
@staticmethod | ||
def predict_state(model, data: HeteroData, state_map: Dict[int, int]) -> int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
обязательно ли функция должна статически принадлежать классу?
"""Gets state id from model and heterogeneous graph | ||
data.state_map - maps real state id to state index""" | ||
state_map = {v: k for k, v in state_map.items()} # inversion for prediction | ||
out = model(data.x_dict, data.edge_index_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
у нас вроде уже 8 входов. этот код все еще актуален?
torch.save(model, filepath) | ||
|
||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
в файлах, которые не используются как исполняемые, то есть те, в которых только функции лежат, лучше не писать if__name__ == "__main__":
, иначе у читающего создается впечатление, что это исполняемый файл
|
||
@staticmethod | ||
def predict_state(model, data: HeteroData, state_map: dict[int, int]) -> int: | ||
def predict_state(model, data: HeteroData, state_map: Dict[int, int]) -> int: | ||
"""Gets state id from model and heterogeneous graph | ||
data.state_map - maps real state id to state index""" | ||
state_map = {v: k for k, v in state_map.items()} # inversion for prediction | ||
out = model(data.x_dict, data.edge_index_dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
тут все еще старый предикт?
torch.tensor(np.array(edges_index_v_v), dtype=torch.long).t().contiguous() | ||
) | ||
data["state_vertex", "in", "game_vertex"].edge_index = ( | ||
data["game_vertex_to_game_vertex"].edge_attr = torch.tensor( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
тут если edges_attr_v_v, edges_types_v_v могут быть пустыми, то нужно соответствующие тензоры оборачивать в null_if_empty. но это ладно, сделаю сам
No description provided.