Allow larger output sizes for Concat
UNet models caused errors because the Concat operator could not handle
the large output sizes of some Concat operations.
The fix is to change the workgroup dimensions so that the y dimension
is used as well.
Since there is a limit on the number of threads per dimension, this
allows us to increase the maximum possible output size quadratically.
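
For a rough sense of the new dispatch sizing, here is a small sketch (illustration only, not part of the commit). It assumes the WebGPU default limit of 65 535 workgroups per dispatch dimension, and the 2^25-element output length is only an example figure:

// Illustration only: derive the dispatch size the same way the new compiler
// code does. Instead of one row of 256-thread workgroups (capped by the
// per-dimension dispatch limit), a square grid of 16x16-thread workgroups
// is used, so the maximum output size grows with the square of `per_dim`.
fn concat_dispatch(len: u64) -> (u32, u32, u32) {
    let ceil = |a: u64, b: u64| (a + b - 1) / b; // ceiling division, like wonnx's `ceil`
    let root = (len as f64).sqrt().floor() as u64 + 1; // approximates `output_lengths[0].sqrt() + 1`
    let per_dim = ceil(root, 16) + 1;
    (per_dim as u32, per_dim as u32, 1)
}

fn main() {
    let len: u64 = 1 << 25; // e.g. a large feature-map concat
    let (x, y, _) = concat_dispatch(len);
    let invocations = (x as u64 * 16) * (y as u64 * 16); // 16x16 threads per workgroup
    assert!(invocations >= len); // enough invocations for every output element
    assert!(x <= 65_535 && y <= 65_535); // stays well under the per-dimension limit
    println!("{x} x {y} workgroups -> {invocations} invocations");
}

For comparison, the old (ceil(len, 256), 1, 1) layout would have needed 131 072 workgroups in x alone for this length, well past the 65 535-per-dimension default.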
mayjs committed Jul 30, 2023
1 parent 78958f2 commit 4806f31
Showing 3 changed files with 96 additions and 8 deletions.
7 changes: 5 additions & 2 deletions wonnx/src/compiler.rs
@@ -2,7 +2,7 @@
 use crate::utils::{
     ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape,
 };
-use num::integer::gcd;
+use num::integer::{gcd, Roots};
 use tera::{Context, Tera};
 use thiserror::Error;
 
@@ -792,10 +792,13 @@ pub fn compile(
             }
             context.insert("cum_len", &input_cumulative_len);
 
+            let root = output_lengths[0].sqrt() + 1;
+            let per_dim = ceil(root, 16) + 1;
+
             NodeTemplate {
                 scalar_type: agreed_type(input_shapes, output_shapes)?,
                 template: "matrix/concat.wgsl",
-                threads: (ceil(output_lengths[0], 256) as u32, 1, 1),
+                threads: (per_dim as u32, per_dim as u32, 1),
             }
         }
         op @ ("MaxPool" | "AveragePool" | "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish"
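For context (an aside, not part of the diff): the newly imported Roots trait from the num crate is what provides the integer sqrt called on output_lengths[0]; it returns the floor of the real square root, which is presumably why the diff adds 1 before rounding the workgroup count up. A minimal sketch:

// Illustration only (not part of the commit). `num::integer::Roots::sqrt` is an
// integer square root: it returns the floor of the real square root, so adding 1
// guarantees the squared result is never smaller than the output length.
use num::integer::Roots;

fn main() {
    assert_eq!(10u64.sqrt(), 3); // floor(sqrt(10))
    assert_eq!(1_000_000u64.sqrt(), 1000);

    let len = 1_000_001u64; // hypothetical Concat output length
    let root = len.sqrt() + 1; // same expression as in the diff
    assert!(root * root >= len); // the +1 compensates for the floor
}
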
17 changes: 11 additions & 6 deletions wonnx/templates/matrix/concat.wgsl
@@ -12,19 +12,24 @@ var<storage, read> input_{{ loop.index0 }}: Array;
 @group({{ binding_len / 4 | int }}) @binding({{ binding_len % 4 }})
 var<storage, read_write> output_0: Array;
 
-@compute @workgroup_size(256, 1, 1)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+@compute @workgroup_size(16, 16, 1)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
     let gidx = global_id.x;
+    let gidy = global_id.y;
 
+    let nx = num_workgroups.x;
+
+    let actual_idx = gidx + gidy * nx;
+
     {% for input in i_lens %}
     {% if loop.first %}
-    if (gidx < {{ i_lens[0] }}u) {
-        output_0.data[gidx] = input_0.data[gidx];
+    if (actual_idx < {{ i_lens[0] }}u) {
+        output_0.data[actual_idx] = input_0.data[actual_idx];
     }
 
     {% else %}
-    if ((gidx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (gidx < {{ cum_len | nth(n=loop.index0)}}u)) {
-        output_0.data[gidx] = input_{{ loop.index0 }}.data[gidx - {{ cum_len | nth(n=loop.index0 -1) }}u];
+    if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) {
+        output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u];
     }
 
     {% endif %}
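As an aside (not part of the commit): each invocation of the updated template derives a flat index actual_idx = gidx + gidy * nx and copies the element from whichever input's cumulative-length window contains that index. A minimal CPU-side sketch of that selection logic, using a hypothetical concat_element helper:

// Illustration only (not part of the commit): the per-element selection the
// generated WGSL performs. `start`/`end` play the role of the template's
// cumulative lengths (`cum_len`); `actual_idx` is the flat index a thread
// computes as `gidx + gidy * nx`.
fn concat_element(inputs: &[&[f32]], actual_idx: usize) -> Option<f32> {
    let mut start = 0usize;
    for input in inputs {
        let end = start + input.len(); // running cumulative length for this input
        if actual_idx >= start && actual_idx < end {
            // same offset the template subtracts: actual_idx - cum_len[i - 1]
            return Some(input[actual_idx - start]);
        }
        start = end;
    }
    None // indices past the last input are simply not written, as in the shader
}

fn main() {
    let x: &[f32] = &[0.0, 1.0];
    let y: &[f32] = &[2.0, 3.0, 4.0];
    // index 3 falls into the second input's window [2, 5), at offset 1
    assert_eq!(concat_element(&[x, y], 3), Some(3.0));
    assert_eq!(concat_element(&[x, y], 7), None);
}
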
80 changes: 80 additions & 0 deletions wonnx/tests/concat.rs
@@ -0,0 +1,80 @@
use std::{collections::HashMap, convert::TryInto};
use wonnx::utils::{graph, model, node, tensor};
mod common;

#[test]
fn test_concat() {
    let n: usize = 16;

    let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
    let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
    let input_dims = vec![n as i64];
    let output_dims = vec![(n * 2) as i64];

    let input_data = HashMap::from([
        ("X".into(), xdata.as_slice().into()),
        ("Y".into(), ydata.as_slice().into()),
    ]);

    let model = model(graph(
        vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
        vec![tensor("Z", &output_dims)],
        vec![],
        vec![],
        vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

    let result = pollster::block_on(session.run(&input_data)).unwrap();

    let mut expected_result = xdata.clone();
    expected_result.append(&mut ydata);

    common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
}

#[test]
fn test_concat4() {
    let n: usize = 13;

    let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
    let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
    let mut zdata: Vec<f32> = (n * 2..3 * n).map(|x| x as f32).collect();
    let mut wdata: Vec<f32> = (n * 3..4 * n).map(|x| x as f32).collect();
    let input_dims = vec![n as i64];
    let output_dims = vec![(n * 4) as i64];

    let input_data = HashMap::from([
        ("X".into(), xdata.as_slice().into()),
        ("Y".into(), ydata.as_slice().into()),
        ("Z".into(), zdata.as_slice().into()),
        ("W".into(), wdata.as_slice().into()),
    ]);

    let model = model(graph(
        vec![
            tensor("X", &input_dims),
            tensor("Y", &input_dims),
            tensor("Z", &input_dims),
            tensor("W", &input_dims),
        ],
        vec![tensor("O", &output_dims)],
        vec![],
        vec![],
        vec![node(vec!["X", "Y", "Z", "W"], vec!["O"], "a", "Concat", vec![])],
    ));

    let session =
        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

    let result = pollster::block_on(session.run(&input_data)).unwrap();

    let mut expected_result = xdata.clone();
    expected_result.append(&mut ydata);
    expected_result.append(&mut zdata);
    expected_result.append(&mut wdata);

    common::assert_eq_vector((&result["O"]).try_into().unwrap(), &expected_result);
}
