MatMul Broadcasting #197
Hi @AmineDiro! In WONNX, shape inference and model inference are two completely separate components. Model inference requires all shapes to be inferred up front (which you can do using the shape inference module).

Conceptually, to implement broadcasting for MatMul, you need to add the calculation in the shader and invoke it the right way (by setting parameters or switching to a different shader altogether). To avoid reinventing the wheel, I suggest having a look at other GPU-based frameworks and their implementations (I often peeked at ONNX Runtime to see their test implementations as well).

Feel free to go ahead and implement this, and if you have any questions, let me know!
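For reference, the broadcast rule MatMul would have to follow is the numpy/ONNX one: the last two dimensions are multiplied as matrices, and all leading "stack" dimensions are broadcast against each other. Below is a minimal sketch of that shape computation; the function name and signature are illustrative only and not part of WONNX's API:

```rust
fn matmul_broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    if a.len() < 2 || b.len() < 2 {
        return None; // the 1-D promotion cases are left out for brevity
    }
    let (m, ka) = (a[a.len() - 2], a[a.len() - 1]);
    let (kb, n) = (b[b.len() - 2], b[b.len() - 1]);
    if ka != kb {
        return None; // inner dimensions must match
    }
    // Broadcast the leading (batch) dimensions right-to-left, numpy-style.
    let (batch_a, batch_b) = (&a[..a.len() - 2], &b[..b.len() - 2]);
    let rank = batch_a.len().max(batch_b.len());
    let mut out = vec![0usize; rank];
    for i in 0..rank {
        let da = if i < batch_a.len() { batch_a[batch_a.len() - 1 - i] } else { 1 };
        let db = if i < batch_b.len() { batch_b[batch_b.len() - 1 - i] } else { 1 };
        if da != db && da != 1 && db != 1 {
            return None; // shapes are not broadcastable
        }
        out[rank - 1 - i] = da.max(db);
    }
    out.extend_from_slice(&[m, n]);
    Some(out)
}

// The case from this issue:
// matmul_broadcast_shape(&[1, 512, 384], &[384, 384]) == Some(vec![1, 512, 384])
```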
Hi @pixelspark, thanks a lot for the detailed explanation. I'll start working on this; it's a good excuse to start learning.
Hi there @pixelspark, I started working on MatMul broadcasting this weekend and I encountered an issue with the existing implementation (I am new to this, so I may be missing something). Here is the test:

```rust
#[test]
fn test_matmul_square_random_matrix() {
let tol = 1e-3f32;
let n = 10;
let _ = env_logger::builder().is_test(true).try_init();
// Generate random arrays
let x_data = ndarray::Array2::<f32>::random((n, n), Uniform::new(0f32, 100f32));
let y_data = ndarray::Array2::<f32>::random((n, n), Uniform::new(0f32, 10f32));
let expected = x_data.dot(&y_data).as_slice().unwrap().to_owned();
let mut input_data = HashMap::new();
input_data.insert("X".to_string(), x_data.as_slice().unwrap().into());
input_data.insert("Y".to_string(), y_data.as_slice().unwrap().into());
let n = n as i64;
let shape = [1, n, n];
let model = model(graph(
vec![tensor("X", &shape), tensor("Y", &shape)],
vec![tensor("Z", &shape)],
vec![],
vec![],
vec![node(vec!["X", "Y"], vec!["Z"], "matmul", "MatMul", vec![])],
));
let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
let result = pollster::block_on(session.run(&input_data)).unwrap();
match result.get("Z").unwrap() {
OutputTensor::F32(v) => {
let diff: Vec<f32> = v.iter().zip(expected.iter()).map(|(l, r)| l - r).collect();
let diff = diff
.iter()
.filter(|e| f32::abs(**e) > tol)
.collect::<Vec<_>>();
assert_eq!(diff.len(), 0);
}
_ => unreachable!(),
}
}
```

This test fails for `n = 10` as written above, but passes when `n = 16`.
Hm, these are always tough issues to figure out... First of all, you should establish that this is not actually some rounding error (differences are expected, given that wgpu is a cake of abstractions on top of abstractions, so it might depend on your tolerance).

That said, the fact that it works with n=16 suggests this has something to do with the chunk sizes (it could even be an off-by-one somewhere). I am unfortunately not too familiar with the matmul implementation (@haixuanTao wrote it, iirc; he might have a clue). I suspect he wrote this for the n=16 case for a specific model and then called it a day (as for most ONNX operators, there are just so many variations that this is a very reasonable thing to do).

For most of the shaders I implemented, I let the shader coordinates refer to the output coordinate and let each invocation 'gather' all the input data it needs (there are probably more efficient ways to do it, but this for me was the easiest to reason about). This way the group size should not really matter for the result in any way (I also did not use any complicated stuff like sharing data within work groups, or barriers). You could of course try to write something of your own this way for the case where n != 16.
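To illustrate the "gather per output coordinate" idea, here is a CPU sketch of the strategy, not the actual WONNX shader; in a real shader, `row` and `col` would come from the global invocation id:

```rust
// Each output element independently gathers the row of `a` and the column of
// `b` it needs, so the workgroup/chunk size has no effect on correctness.
// `a` is m x k and `b` is k x n, both row-major.
fn naive_matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..n {
            let mut acc = 0.0f32;
            for i in 0..k {
                acc += a[row * k + i] * b[i * n + col];
            }
            out[row * n + col] = acc;
        }
    }
    out
}
```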
Thanks for your response!

100%, I used a pretty high tolerance for testing. Usually, matching float multiplication precision between two pieces of hardware doesn't require this much tolerance.

I am not sure I follow here. If I reformulate: you are trying to eliminate picking the group sizes entirely by setting them to a fixed size and having the logic access some arbitrarily long input array? Thanks again for your help :)
So, you should probably look at the matrix multiplication implementation. Basically, each shader invocation can utilize up to 16 elements as a mat4x4; this is the fastest way to do matmul. Using an arbitrary matrix whose size is not a multiple of 16 will probably not work, as the data is no longer contiguous. You can revert to a simpler matmul kernel, or try to do some form of padding.
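As a minimal sketch of the padding idea (an illustrative helper, not WONNX code): zero-pad a row-major matrix so both dimensions become a multiple of 4, so a mat4x4-based kernel only ever sees full blocks. The extra zero rows/columns do not change the product and can be sliced off the output afterwards:

```rust
fn pad_to_multiple_of_4(data: &[f32], rows: usize, cols: usize) -> (Vec<f32>, usize, usize) {
    let padded_rows = (rows + 3) / 4 * 4;
    let padded_cols = (cols + 3) / 4 * 4;
    let mut padded = vec![0.0f32; padded_rows * padded_cols];
    // Copy each original row into the top-left corner of the padded buffer.
    for r in 0..rows {
        padded[r * padded_cols..r * padded_cols + cols]
            .copy_from_slice(&data[r * cols..(r + 1) * cols]);
    }
    (padded, padded_rows, padded_cols)
}
```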
Describe the bug
Hello,

Thanks for this amazing project!

I was trying to run the `sentence-transformers/all-MiniLM-L6-v2` model and I had the following error: it seems that you can't multiply 1x512x384 @ 384x384 = 1x512x384. I don't know if there is a need to match the same dimensions.

I understand that MatMul doesn't support shape inference, and I don't know if this issue is related to that. I might be doing something wrong when preparing the model, so could you help me with this?
`all-MiniLM-L6-v2` is basically a BERT model, and I saw that BERT ran successfully using wonnx.

To Reproduce

1. Convert `sentence-transformers/all-MiniLM-L6-v2` to ONNX format using the optimum lib
2. Simplify the exported model with onnx-simplify

Expected behavior

Run session inference
Desktop (please complete the following information):
Thanks for your help! I'm also open to contributing if there is a fix to implement, if you point me in the right direction 😄