model.predict for graph classification #369
Digital-Chemist asked this question in Q&A
-
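Hi, that example uses DisjointLoader and graph-level predictions, so the model expects a third input: a batch index that tells the global pooling layer which graph each node belongs to. For a single graph, every node belongs to graph 0, so the index is all zeros. Change the call to model.predict([graphX, graphA, tf.zeros(graphX.shape[0], dtype=tf.int64)])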
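Why an all-zero index works: in disjoint mode the batch index maps each node to the graph it came from, and global pooling reduces the node features to one row per graph id. The NumPy sketch below assumes sum pooling; global_sum_pool is a hypothetical stand-in for Spektral's pooling layers, not their actual code. The ids are integers because they are used to pick output rows. With a single graph every node gets id 0, so pooling yields exactly one row, i.e. one graph-level prediction.

```python
import numpy as np

def global_sum_pool(x, i):
    """Hypothetical sketch: sum node features x (N, F) into one row per graph id in i (N,)."""
    n_graphs = int(i.max()) + 1
    out = np.zeros((n_graphs, x.shape[1]), dtype=x.dtype)
    np.add.at(out, i, x)  # rows of x sharing a graph id accumulate into the same output row
    return out

# One graph with 4 nodes and 2 features: an all-zero index puts every node in graph 0.
x = np.arange(8, dtype=np.float32).reshape(4, 2)
i = np.zeros(x.shape[0], dtype=np.int64)
pooled = global_sum_pool(x, i)
print(pooled.shape)  # (1, 2)
print(pooled)        # [[12. 16.]]
```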
-
I'm not sure how to use model.predict() for graph classification. I've trained a GeneralGNN model on some custom data (about 130 graphs), following the framework in this example: https://github.com/danielegrattarola/spektral/blob/master/examples/graph_prediction/general_gnn.py. Training and testing work without issues. If I then have a single unlabeled Graph (no y label), created exactly the same way as the training/testing graphs except without y=label, shouldn't I just be able to feed it into model.predict directly?
This gives an error:

global_pool.py", line 30, in call *
    I = inputs[1]
IndexError: list index out of range
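The traceback suggests the pooling layer's call() reads the batch index from inputs[1], so it fails whenever it receives a list with fewer than two entries. A minimal pure-Python reproduction, where pool_call is a hypothetical stand-in that unpacks its inputs the same way (it is not Spektral's code):

```python
import numpy as np

def pool_call(inputs):
    """Hypothetical stand-in for a disjoint-mode pooling call(): expects [x, i]."""
    x = inputs[0]
    i = inputs[1]  # IndexError here if the batch index never reached this layer
    return x, i

x = np.ones((5, 3), dtype=np.float32)

raised = False
try:
    pool_call([x])  # only node features arrive, as in the traceback
except IndexError as err:
    raised = True
    print("IndexError:", err)  # IndexError: list index out of range

# Supplying an integer index (all zeros for a single graph) avoids the error.
x_out, i_out = pool_call([x, np.zeros(x.shape[0], dtype=np.int64)])
print(i_out)  # [0 0 0 0 0]
```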
If I instead put the single Graph into a DisjointLoader created the same way as the one for the training/testing set, I get an error:
Both x and a are definitely in the single-graph dataset: print(pred_dataset[0].x) and print(pred_dataset[0].a) show the expected arrays and sizes.
What am I missing here? Any help would be greatly appreciated.
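To predict on several unlabeled graphs at once without a loader, the same three inputs can in principle be assembled by hand: stacked node features, a block-diagonal adjacency, and a batch index with one graph id per node. This sketch assumes dense NumPy adjacency matrices for simplicity (Spektral normally works with sparse ones) and uses a hypothetical helper, to_disjoint_inputs; it only mirrors the idea of disjoint batching, not DisjointLoader's implementation.

```python
import numpy as np

def to_disjoint_inputs(xs, adjs):
    """Hypothetical helper: merge graphs into one disjoint graph.

    xs: list of (N_k, F) node-feature arrays; adjs: list of dense (N_k, N_k) adjacencies.
    Returns stacked features, block-diagonal adjacency, and an integer batch index.
    """
    sizes = [x_k.shape[0] for x_k in xs]
    x = np.concatenate(xs, axis=0)
    n = x.shape[0]
    a = np.zeros((n, n), dtype=adjs[0].dtype)
    start = 0
    for a_k, size in zip(adjs, sizes):
        a[start:start + size, start:start + size] = a_k  # place each graph on the diagonal
        start += size
    i = np.repeat(np.arange(len(xs)), sizes)  # node -> id of the graph it came from
    return x, a, i

# Two unlabeled graphs with 2 and 3 nodes, 4 features each.
xs = [np.ones((2, 4)), np.ones((3, 4))]
adjs = [np.eye(2), np.eye(3)]
x, a, i = to_disjoint_inputs(xs, adjs)
print(x.shape, a.shape)  # (5, 4) (5, 5)
print(i)                 # [0 0 1 1 1]
```

The resulting x, a and i would then presumably be passed as [x, a, i], with a converted to whatever format (e.g. sparse) the trained model expects; the pooled output then has one row per graph.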