Concat: dispatch over a 2-D grid of 16x16 workgroups.

The shader flattens the 2-D invocation id into a linear output index. The row
stride of that flattening must be the number of invocations per row
(num_workgroups.x * 16), not the number of workgroups per row; otherwise
outputs longer than 94 elements are only partially written. A regression test
with a longer output is included. The `index` lines below are carried over
from the original patch and do not reflect the amended contents.

diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs
index bed2b865..ba757c20 100644
--- a/wonnx/src/compiler.rs
+++ b/wonnx/src/compiler.rs
@@ -2,7 +2,7 @@
 use crate::utils::{
     ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape,
 };
-use num::integer::gcd;
+use num::integer::{gcd, Roots};
 use tera::{Context, Tera};
 use thiserror::Error;
 
@@ -792,10 +792,17 @@ pub fn compile(
             }
             context.insert("cum_len", &input_cumulative_len);
 
+            // Dispatch a square grid of 16x16 workgroups instead of a single row of
+            // 256-wide ones, so the workgroup count per dimension grows with the square
+            // root of the output length (presumably to stay under the per-dimension
+            // dispatch limit). The shader's row stride must be `per_dim * 16` invocations.
+            let root = output_lengths[0].sqrt() + 1;
+            let per_dim = ceil(root, 16) + 1;
+
             NodeTemplate {
                 scalar_type: agreed_type(input_shapes, output_shapes)?,
                 template: "matrix/concat.wgsl",
-                threads: (ceil(output_lengths[0], 256) as u32, 1, 1),
+                threads: (per_dim as u32, per_dim as u32, 1),
             }
         }
         op @ ("MaxPool" | "AveragePool" | "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish"
diff --git a/wonnx/templates/matrix/concat.wgsl b/wonnx/templates/matrix/concat.wgsl
index 76872cf7..4c282723 100644
--- a/wonnx/templates/matrix/concat.wgsl
+++ b/wonnx/templates/matrix/concat.wgsl
@@ -12,19 +12,27 @@ var<storage, read> input_{{ loop.index0 }}: Array;
 @group({{ binding_len / 4 | int }}) @binding({{ binding_len % 4 }})
 var<storage, read_write> output_0: Array;
 
-@compute @workgroup_size(256, 1, 1)
-fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
+@compute @workgroup_size(16, 16, 1)
+fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
 	let gidx = global_id.x;
+	let gidy = global_id.y;
+
+	// Row stride of the dispatch grid in invocations: global_id.x ranges over
+	// num_workgroups.x * 16 values (16 = workgroup_size.x), so consecutive rows
+	// must be that far apart for every output index to be covered exactly once.
+	let nx = num_workgroups.x * 16u;
+
+	let actual_idx = gidx + gidy * nx;
 
 	{% for input in i_lens %}
 		{% if loop.first %}
-	if (gidx < {{ i_lens[0] }}u) {
-		output_0.data[gidx] = input_0.data[gidx];
+	if (actual_idx < {{ i_lens[0] }}u) {
+		output_0.data[actual_idx] = input_0.data[actual_idx];
 	}
 
 		{% else %}
-	if ((gidx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (gidx < {{ cum_len | nth(n=loop.index0)}}u)) {
-		output_0.data[gidx] = input_{{ loop.index0 }}.data[gidx - {{ cum_len | nth(n=loop.index0 -1) }}u];
+	if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) {
+		output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u];
 	}
 
 		{% endif %}
diff --git a/wonnx/tests/concat.rs b/wonnx/tests/concat.rs
new file mode 100644
index 00000000..980ad039
--- /dev/null
+++ b/wonnx/tests/concat.rs
@@ -0,0 +1,118 @@
+use std::{collections::HashMap, convert::TryInto};
+use wonnx::utils::{graph, model, node, tensor};
+mod common;
+
+/// Concatenates two 16-element vectors and checks the result element by element.
+#[test]
+fn test_concat() {
+    let n: usize = 16;
+
+    let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
+    let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
+    let input_dims = vec![n as i64];
+    let output_dims = vec![(n * 2) as i64];
+
+    let input_data = HashMap::from([
+        ("X".into(), xdata.as_slice().into()),
+        ("Y".into(), ydata.as_slice().into()),
+    ]);
+
+    let model = model(graph(
+        vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
+        vec![tensor("Z", &output_dims)],
+        vec![],
+        vec![],
+        vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
+    ));
+
+    let session =
+        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+    let mut expected_result = xdata.clone();
+    expected_result.append(&mut ydata);
+
+    common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
+}
+
+/// Concatenates four 13-element vectors, covering the shader branches for non-first inputs.
+#[test]
+fn test_concat4() {
+    let n: usize = 13;
+
+    let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
+    let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
+    let mut zdata: Vec<f32> = (n * 2..3 * n).map(|x| x as f32).collect();
+    let mut wdata: Vec<f32> = (n * 3..4 * n).map(|x| x as f32).collect();
+    let input_dims = vec![n as i64];
+    let output_dims = vec![(n * 4) as i64];
+
+    let input_data = HashMap::from([
+        ("X".into(), xdata.as_slice().into()),
+        ("Y".into(), ydata.as_slice().into()),
+        ("Z".into(), zdata.as_slice().into()),
+        ("W".into(), wdata.as_slice().into()),
+    ]);
+
+    let model = model(graph(
+        vec![
+            tensor("X", &input_dims),
+            tensor("Y", &input_dims),
+            tensor("Z", &input_dims),
+            tensor("W", &input_dims),
+        ],
+        vec![tensor("O", &output_dims)],
+        vec![],
+        vec![],
+        vec![node(vec!["X", "Y", "Z", "W"], vec!["O"], "a", "Concat", vec![])],
+    ));
+
+    let session =
+        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+    let mut expected_result = xdata.clone();
+    expected_result.append(&mut ydata);
+    expected_result.append(&mut zdata);
+    expected_result.append(&mut wdata);
+
+    common::assert_eq_vector((&result["O"]).try_into().unwrap(), &expected_result);
+}
+
+/// Concatenates two 4096-element vectors so the output is far longer than a single
+/// 16x16 workgroup, which exercises the shader's flattening of the 2-D dispatch
+/// grid into a linear output index.
+#[test]
+fn test_concat_long() {
+    let n: usize = 4096;
+
+    let xdata: Vec<f32> = (0..n).map(|x| x as f32).collect();
+    let mut ydata: Vec<f32> = (n..2 * n).map(|x| x as f32).collect();
+    let input_dims = vec![n as i64];
+    let output_dims = vec![(n * 2) as i64];
+
+    let input_data = HashMap::from([
+        ("X".into(), xdata.as_slice().into()),
+        ("Y".into(), ydata.as_slice().into()),
+    ]);
+
+    let model = model(graph(
+        vec![tensor("X", &input_dims), tensor("Y", &input_dims)],
+        vec![tensor("Z", &output_dims)],
+        vec![],
+        vec![],
+        vec![node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![])],
+    ));
+
+    let session =
+        pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");
+
+    let result = pollster::block_on(session.run(&input_data)).unwrap();
+
+    let mut expected_result = xdata.clone();
+    expected_result.append(&mut ydata);
+
+    common::assert_eq_vector((&result["Z"]).try_into().unwrap(), &expected_result);
+}