-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
benchdnn: graph: extend --dt to support specifying tensor id #2331
base: main
Are you sure you want to change the base?
Conversation
086c6dd
to
ab80ed6
Compare
ab80ed6
to
539767a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've left a few comments, please incorporate as you see fit, thanks!
539767a
to
143ce80
Compare
Thank you @ranukund . Suggestions are incorporated now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thank you!
make test |
// format like --dt=f32,bf16,f16 | ||
const std::vector<dnnl_data_type_t> def_dt = {dnnl_data_type_undef}; | ||
parser::parse_dt(dts, def_dt, str, option_name); | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Based on this logic it looks like the following will work:
--dt=f16 --dt=0:f32+... --case=some_case
if
branch will update dt
, then else
branch will update dt_map
, then dt_rewrite
will update the whole graph, and then dt_map_rewrite
will update specific IDs.
Is it an expected flow?
If yes, should it be properly documented?
If no, then it probably should be restricted here by additional checking whether the other object was updated or not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! I did not expect the two --dt flavors would be specified in a single command line. I just added a check in the bench()
function to error out.
3e42fa0
to
7d3f6c3
Compare
tests/benchdnn/graph/bench_graph.cpp
Outdated
BENCHDNN_PRINT(0, "%s\n", | ||
"Error: --dt is specified twice with different styles."); | ||
SAFE_V(FAIL); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would move it into parse_dt
since it has both objects and can check their state to figure out if both options were passed. Afterall, it's a parser obligation to make sure the input is valid (when it can).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, it's moved to the parser function now.
7d3f6c3
to
2fd06e0
Compare
rewrite sdpa-compressed-kv-int4-gs32.json for it.
2fd06e0
to
4ee753e
Compare
We introduced
--dt
in 461388e to support floating-point data type rewrite for the whole graph.In this PR, the knob is further extended to support data type rewrite for specified tensor IDs, in a format of
--dt=ID0:DT0+ID1:DT1...
. It provides more flexibility to test a json file with different data types, especially for the various zps/scales data types for quantizations which were excluded previously in whole graph data type rewrite. Due to the flexibility of this knob and the restriction of op definition in the library, users need to be cautious when rewriting the data type of a specific tensor in a graph as it may lead to graph construction failures in the library during testing (eg. incompatible input and output tensor data types).An example as below:
benchdnn --graph --dt=0:s8+2:s8+6:s8+8:s8 --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
It can be used to test int8 compressed KV graph by rewriting rewrite the int4 compressed KV graph (0 is the key tensor, 2 is the zps pf key tensor, 6 is the value tensor, and 8 is the zps of values tensor).