From 4806f31176cf30acae8eaf818e90490a949adf4d Mon Sep 17 00:00:00 2001 From: Johannes May Date: Sun, 30 Jul 2023 11:06:28 +0200 Subject: [PATCH] Allow larger output sizes for Concat UNet caused errors because the Concat operator could not handle the large output sizes of some Concat operations. The solution for this is to change the workgroup dimensions to also use the y dimension. Since there is a limit of threads per dimension, this allows us to quadratically increase the maximum possible output size. --- wonnx/src/compiler.rs | 7 ++- wonnx/templates/matrix/concat.wgsl | 17 ++++--- wonnx/tests/concat.rs | 80 ++++++++++++++++++++++++++++++ 3 files changed, 96 insertions(+), 8 deletions(-) create mode 100644 wonnx/tests/concat.rs diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index bed2b865..ba757c20 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -2,7 +2,7 @@ use crate::utils::{ ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape, }; -use num::integer::gcd; +use num::integer::{gcd, Roots}; use tera::{Context, Tera}; use thiserror::Error; @@ -792,10 +792,13 @@ pub fn compile( } context.insert("cum_len", &input_cumulative_len); + let root = output_lengths[0].sqrt() + 1; + let per_dim = ceil(root, 16) + 1; + NodeTemplate { scalar_type: agreed_type(input_shapes, output_shapes)?, template: "matrix/concat.wgsl", - threads: (ceil(output_lengths[0], 256) as u32, 1, 1), + threads: (per_dim as u32, per_dim as u32, 1), } } op @ ("MaxPool" | "AveragePool" | "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish" diff --git a/wonnx/templates/matrix/concat.wgsl b/wonnx/templates/matrix/concat.wgsl index 76872cf7..4c282723 100644 --- a/wonnx/templates/matrix/concat.wgsl +++ b/wonnx/templates/matrix/concat.wgsl @@ -12,19 +12,24 @@ var input_{{ loop.index0 }}: Array; @group({{ binding_len / 4 | int }}) @binding({{ binding_len % 4 }}) var output_0: Array; -@compute @workgroup_size(256, 1, 1) -fn 
main(@builtin(global_invocation_id) global_id: vec3) { +@compute @workgroup_size(16, 16, 1) +fn main(@builtin(global_invocation_id) global_id: vec3, @builtin(num_workgroups) num_workgroups: vec3) { let gidx = global_id.x; + let gidy = global_id.y; + + let nx = num_workgroups.x * 16u; /* row stride = invocations along x: workgroup count times workgroup_size.x (16), so every (x, y) invocation maps to a unique flat index */ + + let actual_idx = gidx + gidy * nx; {% for input in i_lens %} {% if loop.first %} - if (gidx < {{ i_lens[0] }}u) { - output_0.data[gidx] = input_0.data[gidx]; + if (actual_idx < {{ i_lens[0] }}u) { + output_0.data[actual_idx] = input_0.data[actual_idx]; } {% else %} - if ((gidx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (gidx < {{ cum_len | nth(n=loop.index0)}}u)) { - output_0.data[gidx] = input_{{ loop.index0 }}.data[gidx - {{ cum_len | nth(n=loop.index0 -1) }}u]; + if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) { + output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u]; } {% endif %} diff --git a/wonnx/tests/concat.rs b/wonnx/tests/concat.rs new file mode 100644 index 00000000..980ad039 --- /dev/null +++ b/wonnx/tests/concat.rs @@ -0,0 +1,80 @@ +use std::{collections::HashMap, convert::TryInto}; +use wonnx::utils::{graph, model, node, tensor}; +mod common; + +#[test] +fn test_concat() { + let n: usize = 16; + + let xdata: Vec = (0..n).map(|x| x as f32).collect(); + let mut ydata: Vec = (n..2 * n).map(|x| x as f32).collect(); + let input_dims = vec![n as i64]; + let output_dims = vec![(n * 2) as i64]; + + let input_data = HashMap::from([ + ("X".into(), xdata.as_slice().into()), + ("Y".into(), ydata.as_slice().into()), + ]); + + let model = model(graph( + vec![tensor("X", &input_dims), tensor("Y", &input_dims)], + vec![tensor("Z", &output_dims)], + vec![], + vec![], + vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + + let result = 
pollster::block_on(session.run(&input_data)).unwrap(); + + let mut expected_result = xdata.clone(); + expected_result.append(&mut ydata); + + common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result); +} + +#[test] +fn test_concat4() { + let n: usize = 13; + + let xdata: Vec = (0..n).map(|x| x as f32).collect(); + let mut ydata: Vec = (n..2 * n).map(|x| x as f32).collect(); + let mut zdata: Vec = (n * 2..3 * n).map(|x| x as f32).collect(); + let mut wdata: Vec = (n * 3..4 * n).map(|x| x as f32).collect(); + let input_dims = vec![n as i64]; + let output_dims = vec![(n * 4) as i64]; + + let input_data = HashMap::from([ + ("X".into(), xdata.as_slice().into()), + ("Y".into(), ydata.as_slice().into()), + ("Z".into(), zdata.as_slice().into()), + ("W".into(), wdata.as_slice().into()), + ]); + + let model = model(graph( + vec![ + tensor("X", &input_dims), + tensor("Y", &input_dims), + tensor("Z", &input_dims), + tensor("W", &input_dims), + ], + vec![tensor("O", &output_dims)], + vec![], + vec![], + vec![node(vec!["X", "Y", "Z", "W"], vec!["O"], "a", "Concat", vec![])], + )); + + let session = + pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create"); + + let result = pollster::block_on(session.run(&input_data)).unwrap(); + + let mut expected_result = xdata.clone(); + expected_result.append(&mut ydata); + expected_result.append(&mut zdata); + expected_result.append(&mut wdata); + + common::assert_eq_vector((&result["O"]).try_into().unwrap(), &expected_result); +}