Skip to content

Commit

Permalink
Merge pull request #180 from mayjs/larger_concats
Browse files Browse the repository at this point in the history
Allow larger output sizes for Concat
  • Loading branch information
pixelspark authored Jul 30, 2023
2 parents 78958f2 + 4806f31 commit 07b7a9e
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 8 deletions.
7 changes: 5 additions & 2 deletions wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::utils::{
ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape,
};
use num::integer::gcd;
use num::integer::{gcd, Roots};
use tera::{Context, Tera};
use thiserror::Error;

Expand Down Expand Up @@ -792,10 +792,13 @@ pub fn compile(
}
context.insert("cum_len", &input_cumulative_len);

let root = output_lengths[0].sqrt() + 1;
let per_dim = ceil(root, 16) + 1;

NodeTemplate {
scalar_type: agreed_type(input_shapes, output_shapes)?,
template: "matrix/concat.wgsl",
threads: (ceil(output_lengths[0], 256) as u32, 1, 1),
threads: (per_dim as u32, per_dim as u32, 1),
}
}
op @ ("MaxPool" | "AveragePool" | "Conv" | "ConvRelu" | "ConvLeakyRelu" | "ConvMish"
Expand Down
17 changes: 11 additions & 6 deletions wonnx/templates/matrix/concat.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,24 @@ var<storage, read> input_{{ loop.index0 }}: Array;
@group({{ binding_len / 4 | int }}) @binding({{ binding_len % 4 }})
var<storage, read_write> output_0: Array;

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(num_workgroups) num_workgroups: vec3<u32>) {
let gidx = global_id.x;
let gidy = global_id.y;

let nx = num_workgroups.x;

let actual_idx = gidx + gidy * nx;

{% for input in i_lens %}
{% if loop.first %}
if (gidx < {{ i_lens[0] }}u) {
output_0.data[gidx] = input_0.data[gidx];
if (actual_idx < {{ i_lens[0] }}u) {
output_0.data[actual_idx] = input_0.data[actual_idx];
}

{% else %}
if ((gidx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (gidx < {{ cum_len | nth(n=loop.index0)}}u)) {
output_0.data[gidx] = input_{{ loop.index0 }}.data[gidx - {{ cum_len | nth(n=loop.index0 -1) }}u];
if ((actual_idx >= {{ cum_len | nth(n=loop.index0 -1) }}u) && (actual_idx < {{ cum_len | nth(n=loop.index0)}}u)) {
output_0.data[actual_idx] = input_{{ loop.index0 }}.data[actual_idx - {{ cum_len | nth(n=loop.index0 -1) }}u];
}

{% endif %}
Expand Down
80 changes: 80 additions & 0 deletions wonnx/tests/concat.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::{collections::HashMap, convert::TryInto};
use wonnx::utils::{graph, model, node, tensor};
mod common;

#[test]
fn test_concat() {
    // Two 1-D inputs of equal length; the second continues counting where the first stops.
    let len: usize = 16;
    let first: Vec<f32> = (0..len).map(|v| v as f32).collect();
    let second: Vec<f32> = (len..2 * len).map(|v| v as f32).collect();

    // The expected output is simply X followed by Y.
    let expected: Vec<f32> = [first.as_slice(), second.as_slice()].concat();

    let in_shape = vec![len as i64];
    let out_shape = vec![(len * 2) as i64];

    // A graph consisting of a single Concat node: (X, Y) -> Z.
    let concat_node = node(vec!["X", "Y"], vec!["Z"], "a", "Concat", vec![]);
    let onnx_model = model(graph(
        vec![tensor("X", &in_shape), tensor("Y", &in_shape)],
        vec![tensor("Z", &out_shape)],
        vec![],
        vec![],
        vec![concat_node],
    ));

    let feeds = HashMap::from([
        ("X".into(), first.as_slice().into()),
        ("Y".into(), second.as_slice().into()),
    ]);

    let session = pollster::block_on(wonnx::Session::from_model(onnx_model))
        .expect("Session did not create");
    let outputs = pollster::block_on(session.run(&feeds)).unwrap();

    common::assert_eq_vector((&outputs["Z"]).try_into().unwrap(), &expected);
}

#[test]
fn test_concat4() {
    // Four 1-D inputs of 13 elements each; segment k holds the values k*13 .. (k+1)*13,
    // so the concatenated output is the plain sequence 0 .. 52.
    let len: usize = 13;
    let segments: Vec<Vec<f32>> = (0..4)
        .map(|k| (k * len..(k + 1) * len).map(|v| v as f32).collect())
        .collect();

    // Expected output: all segments back to back, in input order X, Y, Z, W.
    let expected: Vec<f32> = segments.concat();

    let in_shape = vec![len as i64];
    let out_shape = vec![(len * 4) as i64];

    // A graph consisting of a single Concat node: (X, Y, Z, W) -> O.
    let graph_inputs = vec![
        tensor("X", &in_shape),
        tensor("Y", &in_shape),
        tensor("Z", &in_shape),
        tensor("W", &in_shape),
    ];
    let concat_node = node(vec!["X", "Y", "Z", "W"], vec!["O"], "a", "Concat", vec![]);
    let onnx_model = model(graph(
        graph_inputs,
        vec![tensor("O", &out_shape)],
        vec![],
        vec![],
        vec![concat_node],
    ));

    let feeds = HashMap::from([
        ("X".into(), segments[0].as_slice().into()),
        ("Y".into(), segments[1].as_slice().into()),
        ("Z".into(), segments[2].as_slice().into()),
        ("W".into(), segments[3].as_slice().into()),
    ]);

    let session = pollster::block_on(wonnx::Session::from_model(onnx_model))
        .expect("Session did not create");
    let outputs = pollster::block_on(session.run(&feeds)).unwrap();

    common::assert_eq_vector((&outputs["O"]).try_into().unwrap(), &expected);
}

0 comments on commit 07b7a9e

Please sign in to comment.