Skip to content

Commit

Permalink
Merge pull request #199 from AmineDiro/erf-op
Browse files Browse the repository at this point in the history
Erf operator  Implementation
  • Loading branch information
pixelspark authored May 18, 2024
2 parents cbe7e7c + 19906f9 commit c62f5d3
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ fn test_matmul_square_matrix() {
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Einsum">Einsum</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Einsum-12">12</a>|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Elu">Elu</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Elu-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Elu-1">1</a>|||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Equal">Equal</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-11">11</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-7">7</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Equal-1">1</a>||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Erf">Erf</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Erf-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Erf-9">9</a>|||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Erf">Erf</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Erf-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Erf-9">9</a>|||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp">Exp</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6">6</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-1">1</a>|||
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand">Expand</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Expand-13">13</a>, <a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Expand-8">8</a>|
|<a href="https://github.com/onnx/onnx/blob/main/docs/Operators.md#EyeLike">EyeLike</a>|<a href="https://github.com/onnx/onnx/blob/main/docs/Changelog.md#EyeLike-9">9</a>|
Expand Down
1 change: 1 addition & 0 deletions wonnx/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@ ndarray = "0.15.4"
approx = "0.5.1"
pollster = "0.3.0"
env_logger = "0.10.0"
ndarray-rand = "0.14.0"
2 changes: 1 addition & 1 deletion wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ pub fn compile(
}
}
op @ ("Relu" | "Sigmoid" | "Softsign" | "Softplus" | "Clip" | "Celu" | "Elu"
| "LeakyRelu" | "HardSigmoid") => {
| "LeakyRelu" | "HardSigmoid" | "Erf") => {
let alpha = match op {
"LeakyRelu" => node.get_attribute_value("alpha", Some(0.01))?,
"HardSigmoid" => node.get_attribute_value("alpha", Some(0.2))?,
Expand Down
2 changes: 2 additions & 0 deletions wonnx/templates/endomorphism/activation.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ var<storage, read> input_0: ArrayVector;
@group(0) @binding(1)
var<storage, read_write> output_0: ArrayVector;

const pi: f32 = 3.1415;

@compute @workgroup_size({{ workgroup_size_x }})
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let gidx = global_id.x;
Expand Down
5 changes: 5 additions & 0 deletions wonnx/templates/snippets/activation_vec.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@
{{ activation_output }} = max({{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()))
+ min({{ scalar_type }}({{ alpha }}) * {{ activation_input }}, Vec4(Scalar(), Scalar(), Scalar(), Scalar()));

{%- elif activation_type == "Erf" -%}
var intermediate = 2.0/sqrt(pi)*({{ activation_input }}+ pow({{activation_input}},vec4(3.0,3.0,3.0,3.0))*0.08943 );
intermediate = clamp(intermediate,vec4(-10.0,-10.0,-10.0,-10.0),vec4(10.0,10.0,10.0,10.0));
{{ activation_output }} = tanh(intermediate);

{%- elif activation_type == "HardSigmoid" -%}
{{ activation_output }} = max(
Vec4(Scalar(), Scalar(), Scalar(), Scalar()),
Expand Down
30 changes: 30 additions & 0 deletions wonnx/tests/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,38 @@ use wonnx::{
},
};

use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

mod common;
#[test]
fn test_erf() {
let n: usize = 16;
let mut input_data = HashMap::new();
let data = ndarray::Array1::<f32>::random(n, Uniform::new(-100f32, 100f32));

let shape = vec![n as i64];
input_data.insert("X".to_string(), data.as_slice().unwrap().into());

// Model: X -> Cos -> Y
let model = model(graph(
vec![tensor("X", &shape)],
vec![tensor("Y", &shape)],
vec![],
vec![],
vec![node(vec!["X"], vec!["Y"], "erf", "Erf", vec![])],
));

let session =
pollster::block_on(wonnx::Session::from_model(model)).expect("Session did not create");

let result = pollster::block_on(session.run(&input_data)).unwrap();
if let OutputTensor::F32(vec) = &result["Y"] {
for e in vec {
assert_ne!(*e, f32::NAN);
}
}
}
#[test]
fn test_cos() {
let n: usize = 16;
Expand Down

0 comments on commit c62f5d3

Please sign in to comment.