From 516830577566ce09a0f9183a6ed5845f19f2bb84 Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Fri, 17 Nov 2023 17:33:39 +0100 Subject: [PATCH 01/38] Implement SequenceTrait --- src/operators.cairo | 1 + src/operators/sequence.cairo | 11 +++ src/operators/sequence/core.cairo | 82 +++++++++++++++++++ src/operators/sequence/functional.cairo | 2 + .../functional/sequence_construct.cairo | 11 +++ .../sequence/functional/sequence_empty.cairo | 19 +++++ src/operators/sequence/implementations.cairo | 7 ++ .../implementations/sequence_fp16x16.cairo | 18 ++++ .../implementations/sequence_fp32x32.cairo | 18 ++++ .../implementations/sequence_fp64x64.cairo | 18 ++++ .../implementations/sequence_fp8x23.cairo | 18 ++++ .../implementations/sequence_i32.cairo | 18 ++++ .../implementations/sequence_i8.cairo | 18 ++++ .../implementations/sequence_u32.cairo | 17 ++++ 14 files changed, 258 insertions(+) create mode 100644 src/operators/sequence.cairo create mode 100644 src/operators/sequence/core.cairo create mode 100644 src/operators/sequence/functional.cairo create mode 100644 src/operators/sequence/functional/sequence_construct.cairo create mode 100644 src/operators/sequence/functional/sequence_empty.cairo create mode 100644 src/operators/sequence/implementations.cairo create mode 100644 src/operators/sequence/implementations/sequence_fp16x16.cairo create mode 100644 src/operators/sequence/implementations/sequence_fp32x32.cairo create mode 100644 src/operators/sequence/implementations/sequence_fp64x64.cairo create mode 100644 src/operators/sequence/implementations/sequence_fp8x23.cairo create mode 100644 src/operators/sequence/implementations/sequence_i32.cairo create mode 100644 src/operators/sequence/implementations/sequence_i8.cairo create mode 100644 src/operators/sequence/implementations/sequence_u32.cairo diff --git a/src/operators.cairo b/src/operators.cairo index 1ca6cdee2..f125386a2 100644 --- a/src/operators.cairo +++ b/src/operators.cairo @@ -3,3 +3,4 @@ mod nn; mod ml; mod matrix; mod vec; +mod sequence; diff --git a/src/operators/sequence.cairo b/src/operators/sequence.cairo new file mode 100644 index 000000000..d0e4583a1 --- /dev/null +++ b/src/operators/sequence.cairo @@ -0,0 +1,11 @@ +mod core; +mod implementations; +mod functional; + +use orion::operators::sequence::core::SequenceTrait; + +use orion::operators::sequence::implementations::sequence_fp8x23::FP8x23Sequence; +use orion::operators::sequence::implementations::sequence_fp16x16::FP16x16Sequence; +use orion::operators::sequence::implementations::sequence_i8::I8Sequence; +use orion::operators::sequence::implementations::sequence_i32::I32Sequence; +use orion::operators::sequence::implementations::sequence_u32::U32Sequence; diff --git a/src/operators/sequence/core.cairo b/src/operators/sequence/core.cairo new file mode 100644 index 000000000..f86b7b0dc --- /dev/null +++ b/src/operators/sequence/core.cairo @@ -0,0 +1,82 @@ +use orion::operators::tensor::core::Tensor; + +/// Trait +/// +/// sequence_construct – Constructs a tensor sequence containing the input tensors. +/// sequence_empty - Returns an empty tensor sequence. +trait SequenceTrait { + /// ## tensor.sequence_construct + /// + /// ```rust + /// fn sequence_construct(tensors: Array>) -> Array>; + /// ``` + /// + /// Constructs a tensor sequence containing the input tensors. + /// + /// ## Args + /// + /// * `tensors`(`Array>`) - The array of input tensors. + /// + /// ## Panics + /// + /// * Panics if input tensor array is empty. + /// + /// ## Returns + /// + /// A tensor sequence `Array>` containing the input tensors. + /// + /// ## Examples + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + /// + /// fn sequence_construct_example() -> Array> { + /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); + /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); + /// let result = TensorTrait::sequence_construct(tensors: array![tensor1, tensor2]); + /// return result; + /// } + /// >>> [[0, 1, 2, 3], [4, 5, 6, 7]] + /// ``` + /// + fn sequence_construct(tensors: Array>) -> Array>; + /// # tensor.sequence_empty + /// + /// ```rust + /// fn sequence_empty() -> Array>; + /// ``` + /// + /// Returns an empty tensor sequence. + /// + /// ## Args + /// + /// ## Returns + /// + /// An empty `Array>` instance. + /// + /// ## Examples + /// + /// Let's create a new empty sequence. + /// + /// ```rust + /// use array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{ + /// TensorTrait, // we import the trait + /// Tensor, // we import the type + /// U32Tensor // we import the implementation. + /// }; + /// + /// fn sequence_empty_example() -> Array> { + /// let sequence = TensorTrait::sequence_empty(); + /// + /// return sequence; + /// } + /// + /// >>> [] + /// ``` + /// + fn sequence_empty() -> Array>; +} diff --git a/src/operators/sequence/functional.cairo b/src/operators/sequence/functional.cairo new file mode 100644 index 000000000..17b7014a0 --- /dev/null +++ b/src/operators/sequence/functional.cairo @@ -0,0 +1,2 @@ +mod sequence_construct; +mod sequence_empty; \ No newline at end of file diff --git a/src/operators/sequence/functional/sequence_construct.cairo b/src/operators/sequence/functional/sequence_construct.cairo new file mode 100644 index 000000000..dd8ff3101 --- /dev/null +++ b/src/operators/sequence/functional/sequence_construct.cairo @@ -0,0 +1,11 @@ +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor}; + + +/// Cf: TensorTrait::sequence_construct docstring +fn sequence_construct>(tensors: Array>) -> Array> { + assert(tensors.len() >= 1, 'Input tensors must be >= 1'); + + return tensors; +} diff --git a/src/operators/sequence/functional/sequence_empty.cairo b/src/operators/sequence/functional/sequence_empty.cairo new file mode 100644 index 000000000..c823cabe8 --- /dev/null +++ b/src/operators/sequence/functional/sequence_empty.cairo @@ -0,0 +1,19 @@ +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor}; + + +/// Cf: TensorTrait::sequence_empty docstring +fn sequence_empty, impl TDrop: Drop>() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + let tensor = TensorTrait::new(shape.span(), data.span()); + + sequence.append(tensor); + + sequence +} diff --git a/src/operators/sequence/implementations.cairo b/src/operators/sequence/implementations.cairo new file mode 100644 index 000000000..d9c8974f0 --- /dev/null +++ b/src/operators/sequence/implementations.cairo @@ -0,0 +1,7 @@ +mod sequence_i8; +mod sequence_i32; +mod sequence_u32; +mod sequence_fp8x23; +mod sequence_fp16x16; +mod sequence_fp64x64; +mod sequence_fp32x32; diff --git a/src/operators/sequence/implementations/sequence_fp16x16.cairo b/src/operators/sequence/implementations/sequence_fp16x16.cairo new file mode 100644 index 000000000..29bbe4f56 --- /dev/null +++ b/src/operators/sequence/implementations/sequence_fp16x16.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16; +use orion::operators::tensor::implementations::tensor_fp16x16::FP16x16Tensor; + + +impl FP16x16Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_fp32x32.cairo b/src/operators/sequence/implementations/sequence_fp32x32.cairo new file mode 100644 index 000000000..b0ddc0e92 --- /dev/null +++ b/src/operators/sequence/implementations/sequence_fp32x32.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::fixed_point::implementations::fp32x32::core::FP32x32; +use orion::operators::tensor::implementations::tensor_fp32x32::FP32x32Tensor; + + +impl FP32x32Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_fp64x64.cairo b/src/operators/sequence/implementations/sequence_fp64x64.cairo new file mode 100644 index 000000000..f9a6759cd --- /dev/null +++ b/src/operators/sequence/implementations/sequence_fp64x64.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::fixed_point::implementations::fp64x64::core::FP64x64; +use orion::operators::tensor::implementations::tensor_fp64x64::FP64x64Tensor; + + +impl FP64x64Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_fp8x23.cairo b/src/operators/sequence/implementations/sequence_fp8x23.cairo new file mode 100644 index 000000000..ac70a7e47 --- /dev/null +++ b/src/operators/sequence/implementations/sequence_fp8x23.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::fixed_point::implementations::fp8x23::core::FP8x23; +use orion::operators::tensor::implementations::tensor_fp8x23::FP8x23Tensor; + + +impl FP8x23Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_i32.cairo b/src/operators/sequence/implementations/sequence_i32.cairo new file mode 100644 index 000000000..28d543fff --- /dev/null +++ b/src/operators/sequence/implementations/sequence_i32.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; + + +impl I32Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_i8.cairo b/src/operators/sequence/implementations/sequence_i8.cairo new file mode 100644 index 000000000..73b886299 --- /dev/null +++ b/src/operators/sequence/implementations/sequence_i8.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::signed_integer::i8::i8; +use orion::operators::tensor::implementations::tensor_i8::I8Tensor; + + +impl I8Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_u32.cairo b/src/operators/sequence/implementations/sequence_u32.cairo new file mode 100644 index 000000000..a6bdb25e9 --- /dev/null +++ b/src/operators/sequence/implementations/sequence_u32.cairo @@ -0,0 +1,17 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::operators::tensor::implementations::tensor_u32::U32Tensor; + + +impl U32Sequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} From 3edc3dfc813e9a15ff93b6650732aa81ad065e35 Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Fri, 17 Nov 2023 18:02:15 +0100 Subject: [PATCH 02/38] Adapt nodegen tests --- nodegen/helpers.py | 24 ++++- nodegen/node/sequence_construct.py | 12 +-- nodegen/node/sequence_empty.py | 17 ++-- tests/nodes/sequence_construct_fp16x16.cairo | 8 +- .../sequence_construct_fp16x16/input_0.cairo | 37 ++------ .../sequence_construct_fp16x16/output_0.cairo | 37 ++------ tests/nodes/sequence_construct_fp8x23.cairo | 8 +- .../sequence_construct_fp8x23/input_0.cairo | 40 ++------ .../sequence_construct_fp8x23/output_0.cairo | 40 ++------ tests/nodes/sequence_construct_i32.cairo | 10 +- .../sequence_construct_i32/input_0.cairo | 28 +++++- .../sequence_construct_i32/output_0.cairo | 28 +++++- tests/nodes/sequence_construct_i8.cairo | 8 +- .../nodes/sequence_construct_i8/input_0.cairo | 29 +++--- .../sequence_construct_i8/output_0.cairo | 29 +++--- tests/nodes/sequence_construct_u32.cairo | 8 +- .../sequence_construct_u32/input_0.cairo | 95 +++++++++++-------- .../sequence_construct_u32/output_0.cairo | 95 +++++++++++-------- tests/nodes/sequence_empty_fp16x16.cairo | 8 +- tests/nodes/sequence_empty_fp8x23.cairo | 8 +- tests/nodes/sequence_empty_i32.cairo | 10 +- tests/nodes/sequence_empty_i8.cairo | 8 +- tests/nodes/sequence_empty_u32.cairo | 10 +- 23 files changed, 318 insertions(+), 279 deletions(-) diff --git a/nodegen/helpers.py b/nodegen/helpers.py index 517954c14..db363d7e7 100644 --- a/nodegen/helpers.py +++ b/nodegen/helpers.py @@ -42,6 +42,7 @@ def __init__(self, dtype: Dtype, shape: tuple, data: np.ndarray): class Trait(Enum): TENSOR = 'TENSOR' NN = 'NN' + SEQUENCE = 'SEQUENCE' def make_test(inputs: list[Tensor | Sequence], output: Tensor | Sequence, func_sig: str, name: str, trait: Trait = Trait.TENSOR): @@ -158,11 +159,17 @@ def get_all_test_refs(dtypes: list[Dtype], trait: Trait) -> list[str]: return list(set(refs)) -def get_test_refs(dtype: Dtype, trait: Trait, is_sequence: bool) -> list[str]: +def get_test_refs(dtype: Dtype, trait: Trait) -> list[str]: if trait == Trait.NN and dtype == Dtype.BOOL: raise Exception("NN trait does not support bool dtype") - dtype_ref = dtype_to_nn[dtype] if trait == Trait.NN else dtype_to_tensor[dtype] + if trait == Trait.NN: + dtype_ref = dtype_to_nn[dtype] + elif trait == Trait.SEQUENCE: + dtype_ref = dtype_to_sequence[dtype] + else: + dtype_ref = dtype_to_tensor[dtype] + refs = [ *trait_to_ref[trait], *dtype_ref, @@ -193,6 +200,10 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]: "orion::numbers::FixedTrait", "orion::operators::nn::NNTrait", ], + Trait.SEQUENCE: [ + "array::{ArrayTrait, SpanTrait}", + "orion::operators::sequence::SequenceTrait", + ], } @@ -215,6 +226,15 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]: } +dtype_to_sequence = { + Dtype.U32: ["orion::operators::sequence::U32Sequence",], + Dtype.I32: ["orion::operators::sequence::I32Sequence",], + Dtype.I8: ["orion::operators::sequence::I8Sequence",], + Dtype.FP8x23: ["orion::operators::sequence::FP8x23Sequence",], + Dtype.FP16x16: ["orion::operators::sequence::FP16x16Sequence",], +} + + dtype_to_partial_eq = { Dtype.U32: ["orion::operators::tensor::U32TensorPartialEq",], Dtype.I32: ["orion::operators::tensor::I32TensorPartialEq",], diff --git a/nodegen/node/sequence_construct.py b/nodegen/node/sequence_construct.py index 7c76d1532..fba919f56 100644 --- a/nodegen/node/sequence_construct.py +++ b/nodegen/node/sequence_construct.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait class Sequence_construct(RunAll): @@ -18,7 +18,7 @@ def sequence_construct_u32(): sequence.append(tensor) name = "sequence_construct_u32" - make_test([sequence], sequence, "TensorTrait::sequence_construct(input_0)", name) + make_test([sequence], sequence, "SequenceTrait::sequence_construct(input_0)", name, Trait.SEQUENCE) @staticmethod @@ -34,7 +34,7 @@ def sequence_construct_i32(): sequence.append(tensor) name = "sequence_construct_i32" - make_test([sequence], sequence, "TensorTrait::sequence_construct(input_0)", name) + make_test([sequence], sequence, "SequenceTrait::sequence_construct(input_0)", name, Trait.SEQUENCE) @staticmethod @@ -50,7 +50,7 @@ def sequence_construct_i8(): sequence.append(tensor) name = "sequence_construct_i8" - make_test([sequence], sequence, "TensorTrait::sequence_construct(input_0)", name) + make_test([sequence], sequence, "SequenceTrait::sequence_construct(input_0)", name, Trait.SEQUENCE) @staticmethod @@ -66,7 +66,7 @@ def sequence_construct_fp8x23(): sequence.append(tensor) name = "sequence_construct_fp8x23" - make_test([sequence], sequence, "TensorTrait::sequence_construct(input_0)", name) + make_test([sequence], sequence, "SequenceTrait::sequence_construct(input_0)", name, Trait.SEQUENCE) @staticmethod @@ -82,4 +82,4 @@ def sequence_construct_fp16x16(): sequence.append(tensor) name = "sequence_construct_fp16x16" - make_test([sequence], sequence, "TensorTrait::sequence_construct(input_0)", name) + make_test([sequence], sequence, "SequenceTrait::sequence_construct(input_0)", name, Trait.SEQUENCE) diff --git a/nodegen/node/sequence_empty.py b/nodegen/node/sequence_empty.py index 91dc78dbc..779859fa5 100644 --- a/nodegen/node/sequence_empty.py +++ b/nodegen/node/sequence_empty.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, Dtype, Tensor +from ..helpers import make_test, Dtype, Tensor, Trait class Sequence_empty(RunAll): @@ -14,8 +14,9 @@ def default(): make_test( inputs=[], output=[t], - func_sig="TensorTrait::sequence_empty()", + func_sig="SequenceTrait::sequence_empty()", name="sequence_empty_u32", + trait=Trait.SEQUENCE ) default() @@ -29,8 +30,9 @@ def default(): make_test( inputs=[], output=[t], - func_sig="TensorTrait::sequence_empty()", + func_sig="SequenceTrait::sequence_empty()", name="sequence_empty_i32", + trait=Trait.SEQUENCE ) default() @@ -44,8 +46,9 @@ def default(): make_test( inputs=[], output=[t], - func_sig="TensorTrait::sequence_empty()", + func_sig="SequenceTrait::sequence_empty()", name="sequence_empty_i8", + trait=Trait.SEQUENCE ) default() @@ -59,8 +62,9 @@ def default(): make_test( inputs=[], output=[t], - func_sig="TensorTrait::sequence_empty()", + func_sig="SequenceTrait::sequence_empty()", name="sequence_empty_fp8x23", + trait=Trait.SEQUENCE ) default() @@ -74,8 +78,9 @@ def default(): make_test( inputs=[], output=[t], - func_sig="TensorTrait::sequence_empty()", + func_sig="SequenceTrait::sequence_empty()", name="sequence_empty_fp16x16", + trait=Trait.SEQUENCE ) default() diff --git a/tests/nodes/sequence_construct_fp16x16.cairo b/tests/nodes/sequence_construct_fp16x16.cairo index 926a57fe2..291196d2d 100644 --- a/tests/nodes/sequence_construct_fp16x16.cairo +++ b/tests/nodes/sequence_construct_fp16x16.cairo @@ -2,11 +2,11 @@ mod input_0; mod output_0; +use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::FP16x16TensorPartialEq; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::SequenceTrait; use array::{ArrayTrait, SpanTrait}; -use orion::utils::{assert_eq, assert_seq_eq}; -use orion::operators::tensor::FP16x16Tensor; +use orion::operators::sequence::FP16x16Sequence; #[test] #[available_gas(2000000000)] @@ -14,7 +14,7 @@ fn test_sequence_construct_fp16x16() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_construct(input_0); + let y = SequenceTrait::sequence_construct(input_0); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_construct_fp16x16/input_0.cairo b/tests/nodes/sequence_construct_fp16x16/input_0.cairo index 770942e5b..b82457030 100644 --- a/tests/nodes/sequence_construct_fp16x16/input_0.cairo +++ b/tests/nodes/sequence_construct_fp16x16/input_0.cairo @@ -7,68 +7,47 @@ fn input_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); - shape.append(3); shape.append(1); - - let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 131072, sign: true }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 262144, sign: false }); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(3); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 393216, sign: true }); data.append(FP16x16 { mag: 327680, sign: false }); - data.append(FP16x16 { mag: 65536, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 65536, sign: true }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 262144, sign: true }); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 327680, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 262144, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 327680, sign: false }); - data.append(FP16x16 { mag: 196608, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_fp16x16/output_0.cairo b/tests/nodes/sequence_construct_fp16x16/output_0.cairo index 1de556154..cbe5c31ae 100644 --- a/tests/nodes/sequence_construct_fp16x16/output_0.cairo +++ b/tests/nodes/sequence_construct_fp16x16/output_0.cairo @@ -7,68 +7,47 @@ fn output_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); - shape.append(3); shape.append(1); - - let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 131072, sign: true }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 262144, sign: false }); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(3); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 393216, sign: true }); data.append(FP16x16 { mag: 327680, sign: false }); - data.append(FP16x16 { mag: 65536, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 65536, sign: true }); - data.append(FP16x16 { mag: 131072, sign: false }); - data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 262144, sign: true }); - data.append(FP16x16 { mag: 0, sign: false }); - data.append(FP16x16 { mag: 196608, sign: true }); + data.append(FP16x16 { mag: 327680, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 65536, sign: false }); - data.append(FP16x16 { mag: 262144, sign: false }); - data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(3); + shape.append(1); shape.append(1); let mut data = ArrayTrait::new(); - data.append(FP16x16 { mag: 327680, sign: false }); - data.append(FP16x16 { mag: 196608, sign: false }); - data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_fp8x23.cairo b/tests/nodes/sequence_construct_fp8x23.cairo index 9a024885d..bf42ccf59 100644 --- a/tests/nodes/sequence_construct_fp8x23.cairo +++ b/tests/nodes/sequence_construct_fp8x23.cairo @@ -2,11 +2,11 @@ mod input_0; mod output_0; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP8x23Sequence; use orion::operators::tensor::FP8x23TensorPartialEq; -use orion::operators::tensor::FP8x23Tensor; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::SequenceTrait; use array::{ArrayTrait, SpanTrait}; -use orion::utils::{assert_eq, assert_seq_eq}; #[test] #[available_gas(2000000000)] @@ -14,7 +14,7 @@ fn test_sequence_construct_fp8x23() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_construct(input_0); + let y = SequenceTrait::sequence_construct(input_0); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_construct_fp8x23/input_0.cairo b/tests/nodes/sequence_construct_fp8x23/input_0.cairo index 37fb75902..ce92c4d84 100644 --- a/tests/nodes/sequence_construct_fp8x23/input_0.cairo +++ b/tests/nodes/sequence_construct_fp8x23/input_0.cairo @@ -12,7 +12,7 @@ fn input_0() -> Array> { let mut data = ArrayTrait::new(); data.append(FP8x23 { mag: 41943040, sign: true }); - data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -21,28 +21,8 @@ fn input_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 25165824, sign: false }); - data.append(FP8x23 { mag: 33554432, sign: false }); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(2); - - let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 8388608, sign: true }); - data.append(FP8x23 { mag: 25165824, sign: false }); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(2); - - let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 16777216, sign: false }); - data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: true }); + data.append(FP8x23 { mag: 41943040, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -51,8 +31,8 @@ fn input_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 33554432, sign: true }); - data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -61,8 +41,8 @@ fn input_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 0, sign: false }); data.append(FP8x23 { mag: 41943040, sign: true }); - data.append(FP8x23 { mag: 41943040, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -71,8 +51,8 @@ fn input_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 8388608, sign: true }); - data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -81,8 +61,8 @@ fn input_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 25165824, sign: true }); - data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_fp8x23/output_0.cairo b/tests/nodes/sequence_construct_fp8x23/output_0.cairo index 8bae335fe..8e33f82c4 100644 --- a/tests/nodes/sequence_construct_fp8x23/output_0.cairo +++ b/tests/nodes/sequence_construct_fp8x23/output_0.cairo @@ -12,7 +12,7 @@ fn output_0() -> Array> { let mut data = ArrayTrait::new(); data.append(FP8x23 { mag: 41943040, sign: true }); - data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -21,28 +21,8 @@ fn output_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 25165824, sign: false }); - data.append(FP8x23 { mag: 33554432, sign: false }); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(2); - - let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 8388608, sign: true }); - data.append(FP8x23 { mag: 25165824, sign: false }); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(2); - - let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 16777216, sign: false }); - data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: true }); + data.append(FP8x23 { mag: 41943040, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -51,8 +31,8 @@ fn output_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 33554432, sign: true }); - data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -61,8 +41,8 @@ fn output_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 0, sign: false }); data.append(FP8x23 { mag: 41943040, sign: true }); - data.append(FP8x23 { mag: 41943040, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -71,8 +51,8 @@ fn output_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 8388608, sign: true }); - data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); @@ -81,8 +61,8 @@ fn output_0() -> Array> { shape.append(2); let mut data = ArrayTrait::new(); - data.append(FP8x23 { mag: 25165824, sign: true }); - data.append(FP8x23 { mag: 16777216, sign: true }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_i32.cairo b/tests/nodes/sequence_construct_i32.cairo index 8c4f6a19f..ba9939c4a 100644 --- a/tests/nodes/sequence_construct_i32.cairo +++ b/tests/nodes/sequence_construct_i32.cairo @@ -2,11 +2,11 @@ mod input_0; mod output_0; -use orion::operators::tensor::I32Tensor; -use orion::operators::tensor::{TensorTrait, Tensor}; -use array::{ArrayTrait, SpanTrait}; -use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; +use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] @@ -14,7 +14,7 @@ fn test_sequence_construct_i32() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_construct(input_0); + let y = SequenceTrait::sequence_construct(input_0); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_construct_i32/input_0.cairo b/tests/nodes/sequence_construct_i32/input_0.cairo index 872cd3195..aa6875a42 100644 --- a/tests/nodes/sequence_construct_i32/input_0.cairo +++ b/tests/nodes/sequence_construct_i32/input_0.cairo @@ -7,22 +7,42 @@ fn input_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 5, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); shape.append(2); + shape.append(1); let mut data = ArrayTrait::new(); - data.append(i32 { mag: 2, sign: false }); - data.append(i32 { mag: 6, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 5, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); shape.append(2); + shape.append(1); let mut data = ArrayTrait::new(); - data.append(i32 { mag: 5, sign: false }); - data.append(i32 { mag: 6, sign: true }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 5, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_i32/output_0.cairo b/tests/nodes/sequence_construct_i32/output_0.cairo index 4f246f2d1..ce3a56bcd 100644 --- a/tests/nodes/sequence_construct_i32/output_0.cairo +++ b/tests/nodes/sequence_construct_i32/output_0.cairo @@ -7,22 +7,42 @@ fn output_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 5, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); shape.append(2); + shape.append(1); let mut data = ArrayTrait::new(); - data.append(i32 { mag: 2, sign: false }); - data.append(i32 { mag: 6, sign: true }); + data.append(i32 { mag: 2, sign: true }); + data.append(i32 { mag: 5, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); shape.append(2); + shape.append(1); let mut data = ArrayTrait::new(); - data.append(i32 { mag: 5, sign: false }); - data.append(i32 { mag: 6, sign: true }); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 5, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_i8.cairo b/tests/nodes/sequence_construct_i8.cairo index 0c0dda805..c52e20d15 100644 --- a/tests/nodes/sequence_construct_i8.cairo +++ b/tests/nodes/sequence_construct_i8.cairo @@ -2,11 +2,11 @@ mod input_0; mod output_0; +use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I8TensorPartialEq; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::SequenceTrait; use array::{ArrayTrait, SpanTrait}; -use orion::utils::{assert_eq, assert_seq_eq}; -use orion::operators::tensor::I8Tensor; +use orion::operators::sequence::I8Sequence; #[test] #[available_gas(2000000000)] @@ -14,7 +14,7 @@ fn test_sequence_construct_i8() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_construct(input_0); + let y = SequenceTrait::sequence_construct(input_0); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_construct_i8/input_0.cairo b/tests/nodes/sequence_construct_i8/input_0.cairo index b66c72889..33cc57006 100644 --- a/tests/nodes/sequence_construct_i8/input_0.cairo +++ b/tests/nodes/sequence_construct_i8/input_0.cairo @@ -7,35 +7,42 @@ fn input_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); - shape.append(3); let mut data = ArrayTrait::new(); - data.append(i8 { mag: 1, sign: false }); - data.append(i8 { mag: 5, sign: true }); - data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 6, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); - shape.append(3); let mut data = ArrayTrait::new(); - data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 0, sign: false }); data.append(i8 { mag: 4, sign: false }); - data.append(i8 { mag: 6, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); - shape.append(3); let mut data = ArrayTrait::new(); - data.append(i8 { mag: 5, sign: false }); - data.append(i8 { mag: 3, sign: true }); - data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 1, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 3, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_i8/output_0.cairo b/tests/nodes/sequence_construct_i8/output_0.cairo index e048ee27a..dea819b73 100644 --- a/tests/nodes/sequence_construct_i8/output_0.cairo +++ b/tests/nodes/sequence_construct_i8/output_0.cairo @@ -7,35 +7,42 @@ fn output_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); - shape.append(3); let mut data = ArrayTrait::new(); - data.append(i8 { mag: 1, sign: false }); - data.append(i8 { mag: 5, sign: true }); - data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 6, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); - shape.append(3); let mut data = ArrayTrait::new(); - data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 0, sign: false }); data.append(i8 { mag: 4, sign: false }); - data.append(i8 { mag: 6, sign: true }); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); + shape.append(2); shape.append(1); - shape.append(3); let mut data = ArrayTrait::new(); - data.append(i8 { mag: 5, sign: false }); - data.append(i8 { mag: 3, sign: true }); - data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 1, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 2, sign: true }); + data.append(i8 { mag: 3, sign: false }); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_u32.cairo b/tests/nodes/sequence_construct_u32.cairo index 9301350d7..280b62101 100644 --- a/tests/nodes/sequence_construct_u32.cairo +++ b/tests/nodes/sequence_construct_u32.cairo @@ -2,11 +2,11 @@ mod input_0; mod output_0; -use orion::operators::tensor::{TensorTrait, Tensor}; -use array::{ArrayTrait, SpanTrait}; -use orion::operators::tensor::U32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::sequence::SequenceTrait; +use array::{ArrayTrait, SpanTrait}; +use orion::operators::sequence::U32Sequence; #[test] #[available_gas(2000000000)] @@ -14,7 +14,7 @@ fn test_sequence_construct_u32() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_construct(input_0); + let y = SequenceTrait::sequence_construct(input_0); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_construct_u32/input_0.cairo b/tests/nodes/sequence_construct_u32/input_0.cairo index 6fa9b512b..50920f8ac 100644 --- a/tests/nodes/sequence_construct_u32/input_0.cairo +++ b/tests/nodes/sequence_construct_u32/input_0.cairo @@ -6,83 +6,104 @@ fn input_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); - data.append(4); + data.append(1); + data.append(3); + data.append(3); + data.append(5); + data.append(2); + data.append(2); + data.append(1); + data.append(0); + data.append(5); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); data.append(4); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); - - let mut data = ArrayTrait::new(); + data.append(0); + data.append(4); + data.append(4); + data.append(2); + data.append(3); data.append(5); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); data.append(3); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); - - let mut data = ArrayTrait::new(); + data.append(3); + data.append(4); data.append(0); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); - - let mut data = ArrayTrait::new(); + data.append(1); data.append(0); + data.append(2); + data.append(1); + data.append(3); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(0); + data.append(5); + data.append(1); + data.append(2); + data.append(5); + data.append(0); + data.append(3); + data.append(4); data.append(3); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(0); + data.append(5); + data.append(2); + data.append(1); + data.append(3); + data.append(2); + data.append(4); + data.append(1); data.append(4); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(5); + data.append(4); + data.append(5); data.append(1); + data.append(5); + data.append(0); + data.append(4); + data.append(2); + data.append(2); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_construct_u32/output_0.cairo b/tests/nodes/sequence_construct_u32/output_0.cairo index 78e4ce013..7fc1896a2 100644 --- a/tests/nodes/sequence_construct_u32/output_0.cairo +++ b/tests/nodes/sequence_construct_u32/output_0.cairo @@ -6,83 +6,104 @@ fn output_0() -> Array> { let mut sequence = ArrayTrait::new(); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); - data.append(4); + data.append(1); + data.append(3); + data.append(3); + data.append(5); + data.append(2); + data.append(2); + data.append(1); + data.append(0); + data.append(5); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); data.append(4); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); - - let mut data = ArrayTrait::new(); + data.append(0); + data.append(4); + data.append(4); + data.append(2); + data.append(3); data.append(5); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); data.append(3); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); - - let mut data = ArrayTrait::new(); + data.append(3); + data.append(4); data.append(0); - - sequence.append(TensorTrait::new(shape.span(), data.span())); - - let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); - - let mut data = ArrayTrait::new(); + data.append(1); data.append(0); + data.append(2); + data.append(1); + data.append(3); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(0); + data.append(5); + data.append(1); + data.append(2); + data.append(5); + data.append(0); + data.append(3); + data.append(4); data.append(3); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(0); + data.append(5); + data.append(2); + data.append(1); + data.append(3); + data.append(2); + data.append(4); + data.append(1); data.append(4); sequence.append(TensorTrait::new(shape.span(), data.span())); let mut shape = ArrayTrait::::new(); - shape.append(1); - shape.append(1); + shape.append(3); + shape.append(3); let mut data = ArrayTrait::new(); + data.append(5); + data.append(4); + data.append(5); data.append(1); + data.append(5); + data.append(0); + data.append(4); + data.append(2); + data.append(2); sequence.append(TensorTrait::new(shape.span(), data.span())); diff --git a/tests/nodes/sequence_empty_fp16x16.cairo b/tests/nodes/sequence_empty_fp16x16.cairo index 745c52c01..da7ae8750 100644 --- a/tests/nodes/sequence_empty_fp16x16.cairo +++ b/tests/nodes/sequence_empty_fp16x16.cairo @@ -1,18 +1,18 @@ mod output_0; -use array::{ArrayTrait, SpanTrait}; +use orion::operators::sequence::FP16x16Sequence; use orion::operators::tensor::FP16x16TensorPartialEq; -use orion::operators::tensor::FP16x16Tensor; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::SequenceTrait; use orion::utils::{assert_eq, assert_seq_eq}; +use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] fn test_sequence_empty_fp16x16() { let z = output_0::output_0(); - let y = TensorTrait::sequence_empty(); + let y = SequenceTrait::sequence_empty(); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_empty_fp8x23.cairo b/tests/nodes/sequence_empty_fp8x23.cairo index 8d8e3f047..dc3ac8cfb 100644 --- a/tests/nodes/sequence_empty_fp8x23.cairo +++ b/tests/nodes/sequence_empty_fp8x23.cairo @@ -1,18 +1,18 @@ mod output_0; -use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::FP8x23TensorPartialEq; -use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::FP8x23Tensor; +use orion::operators::sequence::FP8x23Sequence; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::SequenceTrait; +use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] fn test_sequence_empty_fp8x23() { let z = output_0::output_0(); - let y = TensorTrait::sequence_empty(); + let y = SequenceTrait::sequence_empty(); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_empty_i32.cairo b/tests/nodes/sequence_empty_i32.cairo index 68b8f1f5e..b8bfd4a39 100644 --- a/tests/nodes/sequence_empty_i32.cairo +++ b/tests/nodes/sequence_empty_i32.cairo @@ -1,18 +1,18 @@ mod output_0; -use array::{ArrayTrait, SpanTrait}; -use orion::operators::tensor::I32TensorPartialEq; -use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::I32Tensor; +use orion::operators::sequence::SequenceTrait; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::I32Sequence; +use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] fn test_sequence_empty_i32() { let z = output_0::output_0(); - let y = TensorTrait::sequence_empty(); + let y = SequenceTrait::sequence_empty(); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_empty_i8.cairo b/tests/nodes/sequence_empty_i8.cairo index c08a9bf45..efe2d43db 100644 --- a/tests/nodes/sequence_empty_i8.cairo +++ b/tests/nodes/sequence_empty_i8.cairo @@ -1,18 +1,18 @@ mod output_0; -use array::{ArrayTrait, SpanTrait}; -use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; use orion::utils::{assert_eq, assert_seq_eq}; -use orion::operators::tensor::I8Tensor; use orion::operators::tensor::I8TensorPartialEq; +use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] fn test_sequence_empty_i8() { let z = output_0::output_0(); - let y = TensorTrait::sequence_empty(); + let y = SequenceTrait::sequence_empty(); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_empty_u32.cairo b/tests/nodes/sequence_empty_u32.cairo index 0b20e2c65..85bc767ef 100644 --- a/tests/nodes/sequence_empty_u32.cairo +++ b/tests/nodes/sequence_empty_u32.cairo @@ -1,18 +1,18 @@ mod output_0; -use array::{ArrayTrait, SpanTrait}; -use orion::operators::tensor::{TensorTrait, Tensor}; -use orion::operators::tensor::U32Tensor; -use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::sequence::SequenceTrait; +use orion::utils::{assert_eq, assert_seq_eq}; +use array::{ArrayTrait, SpanTrait}; +use orion::operators::sequence::U32Sequence; #[test] #[available_gas(2000000000)] fn test_sequence_empty_u32() { let z = output_0::output_0(); - let y = TensorTrait::sequence_empty(); + let y = SequenceTrait::sequence_empty(); assert_seq_eq(y, z); } From fec3d19b2bc05812b9d3b7c2ccc135d987a520fb Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Fri, 17 Nov 2023 21:34:55 +0100 Subject: [PATCH 03/38] Add missing implementations --- src/operators/sequence.cairo | 3 +++ src/operators/sequence/implementations.cairo | 5 ++++- .../implementations/sequence_bool.cairo | 17 +++++++++++++++++ .../implementations/sequence_fp16x16wide.cairo | 18 ++++++++++++++++++ .../implementations/sequence_fp8x23wide.cairo | 18 ++++++++++++++++++ 5 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 src/operators/sequence/implementations/sequence_bool.cairo create mode 100644 src/operators/sequence/implementations/sequence_fp16x16wide.cairo create mode 100644 src/operators/sequence/implementations/sequence_fp8x23wide.cairo diff --git a/src/operators/sequence.cairo b/src/operators/sequence.cairo index d0e4583a1..eaf749c63 100644 --- a/src/operators/sequence.cairo +++ b/src/operators/sequence.cairo @@ -5,7 +5,10 @@ mod functional; use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::implementations::sequence_fp8x23::FP8x23Sequence; +use orion::operators::sequence::implementations::sequence_fp8x23wide::FP8x23WSequence; use orion::operators::sequence::implementations::sequence_fp16x16::FP16x16Sequence; +use orion::operators::sequence::implementations::sequence_fp16x16wide::FP16x16WSequence; use orion::operators::sequence::implementations::sequence_i8::I8Sequence; use orion::operators::sequence::implementations::sequence_i32::I32Sequence; use orion::operators::sequence::implementations::sequence_u32::U32Sequence; +use orion::operators::sequence::implementations::sequence_bool::BoolSequence; diff --git a/src/operators/sequence/implementations.cairo b/src/operators/sequence/implementations.cairo index d9c8974f0..68e3bdf6b 100644 --- a/src/operators/sequence/implementations.cairo +++ b/src/operators/sequence/implementations.cairo @@ -1,7 +1,10 @@ +mod sequence_bool; mod sequence_i8; mod sequence_i32; mod sequence_u32; mod sequence_fp8x23; +mod sequence_fp8x23wide; mod sequence_fp16x16; -mod sequence_fp64x64; +mod sequence_fp16x16wide; mod sequence_fp32x32; +mod sequence_fp64x64; diff --git a/src/operators/sequence/implementations/sequence_bool.cairo b/src/operators/sequence/implementations/sequence_bool.cairo new file mode 100644 index 000000000..0a554efda --- /dev/null +++ b/src/operators/sequence/implementations/sequence_bool.cairo @@ -0,0 +1,17 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::operators::tensor::implementations::tensor_bool::BoolTensor; + + +impl BoolSequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_fp16x16wide.cairo b/src/operators/sequence/implementations/sequence_fp16x16wide.cairo new file mode 100644 index 000000000..66b506a68 --- /dev/null +++ b/src/operators/sequence/implementations/sequence_fp16x16wide.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::fixed_point::implementations::fp16x16wide::core::FP16x16W; +use orion::operators::tensor::implementations::tensor_fp16x16wide::FP16x16WTensor; + + +impl FP16x16WSequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} diff --git a/src/operators/sequence/implementations/sequence_fp8x23wide.cairo b/src/operators/sequence/implementations/sequence_fp8x23wide.cairo new file mode 100644 index 000000000..544f931aa --- /dev/null +++ b/src/operators/sequence/implementations/sequence_fp8x23wide.cairo @@ -0,0 +1,18 @@ +use core::option::OptionTrait; + +use orion::operators::tensor::core::Tensor; +use orion::operators::sequence::core::SequenceTrait; +use orion::operators::sequence::functional; +use orion::numbers::fixed_point::implementations::fp8x23wide::core::FP8x23W; +use orion::operators::tensor::implementations::tensor_fp8x23wide::FP8x23WTensor; + + +impl FP8x23WSequence of SequenceTrait { + fn sequence_construct(tensors: Array>) -> Array> { + functional::sequence_construct::sequence_construct(tensors) + } + + fn sequence_empty() -> Array> { + functional::sequence_empty::sequence_empty::() + } +} From 5850eaedbb6ceba58cee2880e3f2a03ecd8f274c Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Fri, 17 Nov 2023 21:45:47 +0100 Subject: [PATCH 04/38] Remove sequence operators from TensorTrait --- src/operators/tensor/core.cairo | 76 ------------------- .../tensor/implementations/tensor_bool.cairo | 8 -- .../implementations/tensor_fp16x16.cairo | 8 -- .../implementations/tensor_fp16x16wide.cairo | 8 -- .../implementations/tensor_fp32x32.cairo | 8 -- .../implementations/tensor_fp64x64.cairo | 8 -- .../implementations/tensor_fp8x23.cairo | 8 -- .../implementations/tensor_fp8x23wide.cairo | 8 -- .../tensor/implementations/tensor_i32.cairo | 8 -- .../tensor/implementations/tensor_i8.cairo | 8 -- .../tensor/implementations/tensor_u32.cairo | 8 -- src/operators/tensor/math.cairo | 2 - .../tensor/math/sequence_construct.cairo | 11 --- .../tensor/math/sequence_empty.cairo | 19 ----- 14 files changed, 188 deletions(-) delete mode 100644 src/operators/tensor/math/sequence_construct.cairo delete mode 100644 src/operators/tensor/math/sequence_empty.cairo diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 6e7c728e3..2020a842e 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -97,9 +97,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde { /// # tensor.new @@ -3740,43 +3738,6 @@ trait TensorTrait { keepdims: Option, noop_with_empty_axes: Option ) -> Tensor; - /// # tensor.sequence_empty - /// - /// ```rust - /// fn sequence_empty() -> Array>; - /// ``` - /// - /// Returns an empty tensor sequence. - /// - /// ## Args - /// - /// ## Returns - /// - /// An empty `Array>` instance. - /// - /// ## Examples - /// - /// Let's create a new empty sequence. - /// - /// ```rust - /// use array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{ - /// TensorTrait, // we import the trait - /// Tensor, // we import the type - /// U32Tensor // we import the implementation. - /// }; - /// - /// fn sequence_empty_example() -> Array> { - /// let sequence = TensorTrait::sequence_empty(); - /// - /// return sequence; - /// } - /// - /// >>> [] - /// ``` - /// - fn sequence_empty() -> Array>; /// # tensor.shrink /// /// ```rust @@ -3828,43 +3789,6 @@ trait TensorTrait { /// ``` /// fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor; - /// ## tensor.sequence_construct - /// - /// ```rust - /// fn sequence_construct(tensors: Array>) -> Array>; - /// ``` - /// - /// Constructs a tensor sequence containing the input tensors. - /// - /// ## Args - /// - /// * `tensors`(`Array>`) - The array of input tensors. - /// - /// ## Panics - /// - /// * Panics if input tensor array is empty. - /// - /// ## Returns - /// - /// A tensor sequence `Array>` containing the input tensors. - /// - /// ## Examples - /// - /// ```rust - /// use array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - /// - /// fn sequence_construct_example() -> Array> { - /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); - /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); - /// let result = TensorTrait::sequence_construct(tensors: array![tensor1, tensor2]); - /// return result; - /// } - /// >>> [[0, 1, 2, 3], [4, 5, 6, 7]] - /// ``` - /// - fn sequence_construct(tensors: Array>) -> Array>; /// ## tensor.reduce_min /// /// ```rust diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d24d9a88b..e98d32432 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -312,18 +312,10 @@ impl BoolTensor of TensorTrait { constant_of_shape(shape, value) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 610ffba1e..1ec4e4c44 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -349,20 +349,12 @@ impl FP16x16Tensor of TensorTrait { math::scatter::scatter(self, updates, indices, axis, reduction) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 3a9873570..4359a0df6 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -339,20 +339,12 @@ impl FP16x16WTensor of TensorTrait { math::reduce_l2::reduce_l2(self, axis, keepdims) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 6b50b1074..88f8e0d78 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -350,20 +350,12 @@ impl FP32x32Tensor of TensorTrait { math::reduce_l2::reduce_l2(self, axis, keepdims) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 9931dd661..7ef6f70a4 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -350,20 +350,12 @@ impl FP64x64Tensor of TensorTrait { math::scatter::scatter(self, updates, indices, axis, reduction) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 5d5e2e06a..8b3e08f07 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -341,18 +341,10 @@ impl FP8x23Tensor of TensorTrait { math::reduce_l2::reduce_l2(self, axis, keepdims) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 8fd044343..1477f55a1 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -328,20 +328,12 @@ impl FP8x23WTensor of TensorTrait { math::scatter::scatter(self, updates, indices, axis, reduction) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 9c7c2dc3b..0ae8d9721 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -349,18 +349,10 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 2d8b7ab4c..dadb4ab4e 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -348,18 +348,10 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index d5ec1bd4b..4d27b515e 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -319,18 +319,10 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index 14625600c..9986650ff 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -45,7 +45,5 @@ mod reduce_l1; mod reduce_sum_square; mod bitwise_and; mod reduce_min; -mod sequence_construct; mod shrink; -mod sequence_empty; mod reduce_mean; diff --git a/src/operators/tensor/math/sequence_construct.cairo b/src/operators/tensor/math/sequence_construct.cairo deleted file mode 100644 index dd8ff3101..000000000 --- a/src/operators/tensor/math/sequence_construct.cairo +++ /dev/null @@ -1,11 +0,0 @@ -use array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor}; - - -/// Cf: TensorTrait::sequence_construct docstring -fn sequence_construct>(tensors: Array>) -> Array> { - assert(tensors.len() >= 1, 'Input tensors must be >= 1'); - - return tensors; -} diff --git a/src/operators/tensor/math/sequence_empty.cairo b/src/operators/tensor/math/sequence_empty.cairo deleted file mode 100644 index c823cabe8..000000000 --- a/src/operators/tensor/math/sequence_empty.cairo +++ /dev/null @@ -1,19 +0,0 @@ -use array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor}; - - -/// Cf: TensorTrait::sequence_empty docstring -fn sequence_empty, impl TDrop: Drop>() -> Array> { - let mut sequence = ArrayTrait::new(); - - let mut shape = ArrayTrait::::new(); - shape.append(0); - - let mut data = ArrayTrait::new(); - let tensor = TensorTrait::new(shape.span(), data.span()); - - sequence.append(tensor); - - sequence -} From 438e0b3bddf139a523108f0e02b714df42437207 Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Sat, 18 Nov 2023 14:40:05 +0100 Subject: [PATCH 05/38] Refactor docgen --- docgen/src/main.rs | 8 +++++ docs/SUMMARY.md | 5 +-- docs/framework/compatibility.md | 4 +-- docs/framework/operators/sequence/README.md | 28 +++++++++++++++ .../sequence/sequence.sequence_construct.md | 36 +++++++++++++++++++ .../sequence/sequence.sequence_empty.md | 36 +++++++++++++++++++ docs/framework/operators/tensor/README.md | 2 -- src/operators/sequence/core.cairo | 12 ++++--- 8 files changed, 120 insertions(+), 11 deletions(-) create mode 100644 docs/framework/operators/sequence/README.md create mode 100644 docs/framework/operators/sequence/sequence.sequence_construct.md create mode 100644 docs/framework/operators/sequence/sequence.sequence_empty.md diff --git a/docgen/src/main.rs b/docgen/src/main.rs index 97d11b112..5c9cc4c86 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -19,6 +19,14 @@ fn main() { doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + // SEQUENCE DOC + let trait_path = "src/operators/sequence/core.cairo"; + let doc_path = "docs/framework/operators/sequence"; + let label = "sequence"; + let trait_name = "SequenceTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); + // FIXED POINT DOC let trait_path = "src/numbers/fixed_point/core.cairo"; let doc_path = "docs/framework/numbers/fixed-point"; diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index c33854677..e6a030db9 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -103,9 +103,7 @@ * [tensor.reduce\_l2](framework/operators/tensor/tensor.reduce\_l2.md) * [tensor.reduce\_l1](framework/operators/tensor/tensor.reduce\_l1.md) * [tensor.reduce\_min](framework/operators/tensor/tensor.reduce\_min.md) - * [tensor.sequence\_construct](framework/operators/tensor/tensor.sequence\_construct.md) * [tensor.shrink](framework/operators/tensor/tensor.shrink.md) - * [tensor.sequence\_empty](framework/operators/tensor/tensor.sequence\_empty.md) * [tensor.reduce_mean](framework/operators/tensor/tensor.reduce\_mean.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) @@ -129,6 +127,9 @@ * [tree.predict_proba](framework/operators/machine-learning/tree-classifier/tree.predict_proba.md) * [XGBoost Regressor](framework/operators/machine-learning/xgboost-regressor/README.md) * [xgboost.predict](framework/operators/machine-learning/xgboost-regressor/xgboost.predict.md) + * [Sequence](framework/operators/sequence/README.md) + * [sequence.sequence\_construct](framework/operators/sequence/sequence.sequence\_construct.md) + * [sequence.sequence\_empty](framework/operators/sequence/sequence.sequence\_empty.md) ## πŸ› Hub diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index 7cd555358..d32618ff9 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -83,9 +83,9 @@ You can see below the list of current supported ONNX Operators: | [ConstantOfShape](operators/tensor/tensor.constant_of_shape.md) | :white\_check\_mark: | | [ReduceL1](operators/tensor/tensor.reduce\_l1.md) | :white\_check\_mark: | | [ReduceL2](operators/tensor/tensor.reduce\_l2.md) | :white\_check\_mark: | -| [SequenceConstruct](operators/tensor/tensor.sequence\_construct.md) | :white\_check\_mark: | +| [SequenceConstruct](operators/sequence/sequence.sequence\_construct.md) | :white\_check\_mark: | | [Shrink](operators/tensor/tensor.shrink.md) | :white\_check\_mark: | -| [SequenceEmpty](operators/tensor/tensor.sequence\_empty.md) | :white\_check\_mark: | +| [SequenceEmpty](operators/sequence/sequence.sequence\_empty.md) | :white\_check\_mark: | | [ReduceL2](operators/tensor/tensor.reduce\_l2.md) | :white\_check\_mark: | Current Operators support: **81/156 (52%)** diff --git a/docs/framework/operators/sequence/README.md b/docs/framework/operators/sequence/README.md new file mode 100644 index 000000000..2dab15cb1 --- /dev/null +++ b/docs/framework/operators/sequence/README.md @@ -0,0 +1,28 @@ +# Sequence + +A Sequence represents an array of tensors. + +```rust +use orion::operators::sequence; +``` + +### Data types + +Orion supports currently these `Sequence` types. + +| Data type | dtype | +| ------------------------- | -------------------------------------------------------- | +| 32-bit integer (signed) | `Array>` | +| 8-bit integer (signed) | `Array>` | +| 32-bit integer (unsigned) | `Array>` | +| Fixed point (signed) | `Array>` | + +### Sequence**Trait** + +`SequenceTrait` defines the operations that can be performed on a Sequence of tensors. + +| function | description | +| --- | --- | +| [`sequence.sequence_construct`](sequence.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. | +| [`sequence.sequence_empty`](sequence.sequence\_empty.md) | Returns an empty tensor sequence. | + diff --git a/docs/framework/operators/sequence/sequence.sequence_construct.md b/docs/framework/operators/sequence/sequence.sequence_construct.md new file mode 100644 index 000000000..92f005020 --- /dev/null +++ b/docs/framework/operators/sequence/sequence.sequence_construct.md @@ -0,0 +1,36 @@ +## sequence.sequence_construct + +```rust + fn sequence_construct(tensors: Array>) -> Array>; +``` + +Constructs a tensor sequence containing the input tensors. + +## Args + +* `tensors`(`Array>`) - The array of input tensors. + +## Panics + +* Panics if input tensor array is empty. + +## Returns + +A tensor sequence `Array>` containing the input tensors. + +## Examples + +```rust +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; +use orion::operators::sequence::SequenceTrait; + +fn sequence_construct_example() -> Array> { + let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); + let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); + let result = SequenceTrait::sequence_construct(tensors: array![tensor1, tensor2]); + return result; +} +>>> [[0, 1, 2, 3], [4, 5, 6, 7]] +``` diff --git a/docs/framework/operators/sequence/sequence.sequence_empty.md b/docs/framework/operators/sequence/sequence.sequence_empty.md new file mode 100644 index 000000000..8c2568759 --- /dev/null +++ b/docs/framework/operators/sequence/sequence.sequence_empty.md @@ -0,0 +1,36 @@ +## sequence.sequence_empty + +```rust + fn sequence_empty() -> Array>; +``` + +Returns an empty tensor sequence. + +## Args + +## Returns + +An empty `Array>` instance. + +## Examples + +Let's create a new empty sequence. + +```rust +use array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{ + TensorTrait, // we import the trait + Tensor, // we import the type + U32Tensor // we import the implementation. +}; +use orion::operators::sequence::SequenceTrait; + +fn sequence_empty_example() -> Array> { + let sequence = SequenceTrait::sequence_empty(); + + return sequence; +} + +>>> [] +``` diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index dbb70e15b..fb6fe2f34 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -101,9 +101,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.reduce_sum_square`](tensor.reduce\_sum\_square.md) | Computes the sum square of the input tensor's elements along the provided axes. | | [`tensor.reduce_l2`](tensor.reduce\_l2.md) | Computes the L2 norm of the input tensor's elements along the provided axes. | | [`tensor.reduce_min`](tensor.reduce\_min.md) | Computes the min of the input tensor's elements along the provided axes. | -| [`tensor.sequence_construct`](tensor.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. | | [`tensor.shrink`](tensor.shrink.md) | Shrinks the input tensor element-wise to the output tensor with the same datatype and shape based on a defined formula. | -| [`tensor.sequence_empty`](tensor.sequence\_empty.md) | Returns an empty tensor sequence. | | [`tensor.reduce_mean`](tensor.reduce\_mean.md) | Computes the mean of the input tensor's elements along the provided axes. | ## Arithmetic Operations diff --git a/src/operators/sequence/core.cairo b/src/operators/sequence/core.cairo index f86b7b0dc..df33c5ab6 100644 --- a/src/operators/sequence/core.cairo +++ b/src/operators/sequence/core.cairo @@ -2,10 +2,10 @@ use orion::operators::tensor::core::Tensor; /// Trait /// -/// sequence_construct – Constructs a tensor sequence containing the input tensors. +/// sequence_construct - Constructs a tensor sequence containing the input tensors. /// sequence_empty - Returns an empty tensor sequence. trait SequenceTrait { - /// ## tensor.sequence_construct + /// ## sequence.sequence_construct /// /// ```rust /// fn sequence_construct(tensors: Array>) -> Array>; @@ -31,18 +31,19 @@ trait SequenceTrait { /// use array::{ArrayTrait, SpanTrait}; /// /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + /// use orion::operators::sequence::SequenceTrait; /// /// fn sequence_construct_example() -> Array> { /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); - /// let result = TensorTrait::sequence_construct(tensors: array![tensor1, tensor2]); + /// let result = SequenceTrait::sequence_construct(tensors: array![tensor1, tensor2]); /// return result; /// } /// >>> [[0, 1, 2, 3], [4, 5, 6, 7]] /// ``` /// fn sequence_construct(tensors: Array>) -> Array>; - /// # tensor.sequence_empty + /// ## sequence.sequence_empty /// /// ```rust /// fn sequence_empty() -> Array>; @@ -68,9 +69,10 @@ trait SequenceTrait { /// Tensor, // we import the type /// U32Tensor // we import the implementation. /// }; + /// use orion::operators::sequence::SequenceTrait; /// /// fn sequence_empty_example() -> Array> { - /// let sequence = TensorTrait::sequence_empty(); + /// let sequence = SequenceTrait::sequence_empty(); /// /// return sequence; /// } From feca40f6e3bd78d4f192fa6152a27e23dfd76960 Mon Sep 17 00:00:00 2001 From: Daniel Voronov Date: Sun, 19 Nov 2023 15:11:10 +0100 Subject: [PATCH 06/38] Refactor docstring --- src/operators/sequence/functional/sequence_construct.cairo | 2 +- src/operators/sequence/functional/sequence_empty.cairo | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operators/sequence/functional/sequence_construct.cairo b/src/operators/sequence/functional/sequence_construct.cairo index dd8ff3101..fe4f21aa4 100644 --- a/src/operators/sequence/functional/sequence_construct.cairo +++ b/src/operators/sequence/functional/sequence_construct.cairo @@ -3,7 +3,7 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -/// Cf: TensorTrait::sequence_construct docstring +/// Cf: SequenceTrait::sequence_construct docstring fn sequence_construct>(tensors: Array>) -> Array> { assert(tensors.len() >= 1, 'Input tensors must be >= 1'); diff --git a/src/operators/sequence/functional/sequence_empty.cairo b/src/operators/sequence/functional/sequence_empty.cairo index c823cabe8..38b77564c 100644 --- a/src/operators/sequence/functional/sequence_empty.cairo +++ b/src/operators/sequence/functional/sequence_empty.cairo @@ -3,7 +3,7 @@ use array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -/// Cf: TensorTrait::sequence_empty docstring +/// Cf: SequenceTrait::sequence_empty docstring fn sequence_empty, impl TDrop: Drop>() -> Array> { let mut sequence = ArrayTrait::new(); From e1a0dd3729e429895133e4edc7621acbfd7e5ac2 Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Mon, 18 Dec 2023 15:29:36 +0800 Subject: [PATCH 07/38] feat: Add binary classs in TEC operator Add binary classs in tree_ensemble_classifier operator --- .../tree_ensemble_classifier.cairo | 150 ++++++++++++++++++ 1 file changed, 150 insertions(+) diff --git a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo index db83cb0aa..f882655e8 100644 --- a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo +++ b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo @@ -402,6 +402,156 @@ impl TreeEnsembleClassifierImpl< i += 1; }; + // Binary class + let mut binary = false; + let mut i: usize = 0; + let mut class_ids = self.class_ids; + let mut class_id: usize = 0; + // Get first class_id in class_ids + match class_ids.pop_front() { + Option::Some(c_id) => { + let mut class_id = *c_id; + }, + Option::None(_) => { + let mut class_id: usize = 0; + } + }; + loop { + if i == self.class_ids.len() { + break; + } + match class_ids.pop_front() { + Option::Some(c_id) => { + if *c_id == class_id { + binary = true; + continue; + }else{ + binary = false; + break; + } + + }, + Option::None(_) => { break; } + }; + + }; + + // Clone res + if binary{ + let mut new_res: MutMatrix = MutMatrixImpl::new(res.rows, res.cols); + let mut i: usize = 0; + loop { + if i == res.rows { + break; + } + // Exchange + let res_ele_1 = match res.get(i, 0) { + Option::Some(res_0) => { + new_res.set(i, 1, res_0); + }, + Option::None(_) => { + new_res.set(i, 1, NumberTrait::zero()); + }, + }; + i+=1; + }; + match self.post_transform { + POST_TRANSFORM::NONE => { + let mut i: usize = 0; + loop { + if i == res.rows { + break; + } + // Exchange + let res_ele_0 = match res.get(i, 1) { + Option::Some(res_1) => { + let value = NumberTrait::sub(NumberTrait::one(), res_1); + new_res.set(i, 0, value); + }, + Option::None(_) => { + new_res.set(i, 0, NumberTrait::zero()); + }, + }; + i+=1; + }; + }, + POST_TRANSFORM::SOFTMAX => { + let mut i: usize = 0; + loop { + if i == res.rows { + break; + } + // Exchange + let res_ele_0 = match res.get(i, 1) { + Option::Some(res_1) => { + new_res.set(i, 0, res_1.neg()); + }, + Option::None(_) => { + new_res.set(i, 0, NumberTrait::zero()); + }, + }; + i+=1; + }; + }, + POST_TRANSFORM::LOGISTIC => { + let mut i: usize = 0; + loop { + if i == res.rows { + break; + } + // Exchange + let res_ele_0 = match res.get(i, 1) { + Option::Some(res_1) => { + new_res.set(i, 0, res_1.neg()); + }, + Option::None(_) => { + new_res.set(i, 0, NumberTrait::zero()); + }, + }; + i+=1; + }; + }, + POST_TRANSFORM::SOFTMAXZERO => { + let mut i: usize = 0; + loop { + if i == res.rows { + break; + } + // Exchange + let res_ele_0 = match res.get(i, 1) { + Option::Some(res_1) => { + new_res.set(i, 0, res_1.neg()); + }, + Option::None(_) => { + new_res.set(i, 0, NumberTrait::zero()); + }, + }; + i+=1; + }; + }, + POST_TRANSFORM::PROBIT => { + let mut i: usize = 0; + loop { + if i == res.rows { + break; + } + // Exchange + let res_ele_0 = match res.get(i, 1) { + Option::Some(res_1) => { + let value = NumberTrait::sub(NumberTrait::one(), res_1); + new_res.set(i, 0, value); + }, + Option::None(_) => { + new_res.set(i, 0, NumberTrait::zero()); + }, + }; + i+=1; + }; + }, + }; + res = new_res; + } + // Post Transform let mut new_scores = match self.post_transform { POST_TRANSFORM::NONE => res, // No action required From 6747b0dcf12057e01574ee2499c5f211ef0161e4 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Thu, 21 Dec 2023 05:15:44 +0100 Subject: [PATCH 08/38] linear classifier --- docgen/src/main.rs | 8 + docs/SUMMARY.md | 7 +- .../linear-classifier/README.md | 0 .../linear_classifier.predict.md | 100 ++++++ docs/framework/operators/tensor/tensor.erf.md | 4 - .../operators/tensor/tensor.gather_nd.md | 8 - src/operators/ml.cairo | 9 +- src/operators/ml/linear.cairo | 1 + .../ml/linear/linear_classifier.cairo | 276 +++++++++++++++ .../tensor/implementations/tensor_bool.cairo | 4 +- .../implementations/tensor_fp16x16.cairo | 6 +- .../implementations/tensor_fp16x16wide.cairo | 6 +- .../implementations/tensor_fp32x32.cairo | 6 +- .../implementations/tensor_fp64x64.cairo | 6 +- .../implementations/tensor_fp8x23.cairo | 6 +- .../implementations/tensor_fp8x23wide.cairo | 6 +- .../tensor/implementations/tensor_i32.cairo | 6 +- .../tensor/implementations/tensor_i8.cairo | 6 +- .../tensor/implementations/tensor_u32.cairo | 6 +- src/operators/tensor/math/gather_nd.cairo | 45 ++- tests/lib.cairo | 11 +- tests/ml.cairo | 7 +- tests/ml/linear_classifier_test.cairo | 319 ++++++++++++++++++ .../gather_nd_fp16x16_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp16x16_3d_batch_dims2.cairo | 2 +- .../nodes/gather_nd_fp16x16_3d_default.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_fp8x23_3d_default.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims1.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_i32_3d_default.cairo | 2 +- tests/nodes/gather_nd_i8_3d_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_i8_3d_default.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_u32_default.cairo | 2 +- 37 files changed, 801 insertions(+), 80 deletions(-) create mode 100644 docs/framework/operators/machine-learning/linear-classifier/README.md create mode 100644 docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md create mode 100644 src/operators/ml/linear/linear_classifier.cairo create mode 100644 tests/ml/linear_classifier_test.cairo diff --git a/docgen/src/main.rs b/docgen/src/main.rs index 20426c832..b949f7e48 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -66,6 +66,14 @@ fn main() { let trait_name: &str = "LinearRegressorTrait"; doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + + // LINEAR REGRESSOR DOC + let trait_path = "src/operators/ml/linear/linear_classifier.cairo"; + let doc_path = "docs/framework/operators/machine-learning/linear-classifier"; + let label = "linear_classifier"; + let trait_name: &str = "LinearClassifierTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); } fn doc_trait(trait_path: &str, doc_path: &str, label: &str) { diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 37f90ece0..da8439fb3 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -142,7 +142,12 @@ * [Machine Learning](framework/operators/machine-learning/README.md) * [Tree Ensemble Classifier](framework/operators/machine-learning/tree-ensemble-classifier/README.md) * [tree\_ensemble\_classifier.predict](framework/operators/machine-learning/tree-ensemble-classifier/tree\_ensemble\_classifier.predict.md) - + * [Tree Ensemble Regressor](framework/operators/machine-learning/tree-ensemble-regressor/README.md) + * [tree\_ensemble\_regressor.predict](framework/operators/machine-learning/tree-ensemble-regressor/tree\_ensemble\_regressor.predict.md) + * [Linear Classifier](framework/operators/machine-learning/linear-classifier/README.md) + * [linear\_classifier.predict](framework/operators/machine-learning/linear-classifier/linear\_classifier.predict.md) + * [Linear Regressor](framework/operators/machine-learning/linear-regressor/README.md) + * [linear\_regressor.predict](framework/operators/machine-learning/linear-regressor/linear\_regressor.predict.md) ## πŸ› Hub * [Models](hub/algorithms.md) diff --git a/docs/framework/operators/machine-learning/linear-classifier/README.md b/docs/framework/operators/machine-learning/linear-classifier/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md b/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md new file mode 100644 index 000000000..3b9537b1c --- /dev/null +++ b/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md @@ -0,0 +1,100 @@ +# LinearClassifierTrait::predict + +```rust + fn predict(ref self: LinearClassifier, X: Tensor) -> Tensor; +``` + +Linear Regressor. Performs the linear classification. + +## Args + +* `self`: LinearClassifier - A LinearClassifier object. +* `X`: Input 2D tensor. + +## Returns + +* Tensor containing the generalized linear regression evaluation of the input X. + +## Type Constraints + +`LinearClassifier` and `X` must be fixed points + +## Examples + +```rust +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; + +use orion::operators::ml::linear::linear_classifier::{ + LinearClassifierTrait, POST_TRANSFORM, LinearClassifier +}; + +fn linear_classifier_helper( + post_transform: POST_TRANSFORM +) -> (LinearClassifier, Tensor) { + + let classlabels: Span = array![0, 1, 2].span(); + let classlabels = Option::Some(classlabels); + + let classlabels_strings: Option> = Option::None; + + let coefficients: Span = array![ + FP16x16 { mag: 38011, sign: true }, + FP16x16 { mag: 19005, sign: true }, + FP16x16 { mag: 5898, sign: true }, + FP16x16 { mag: 38011, sign: false }, + FP16x16 { mag: 19005, sign: false }, + FP16x16 { mag: 5898, sign: false }, + ] + .span(); + + let intercepts: Span = array![ + FP16x16 { mag: 176947, sign: false }, + FP16x16 { mag: 176947, sign: true }, + FP16x16 { mag: 32768, sign: false }, + ] + .span(); + let intercepts = Option::Some(intercepts); + + let multi_class: usize = 0; + + let mut classifier: LinearClassifier = LinearClassifier { + classlabels, + coefficients, + intercepts, + multi_class, + post_transform + }; + + let mut X: Tensor = TensorTrait::new( + array![3, 2].span(), + array![ + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 131072, sign: false }, + FP16x16 { mag: 196608, sign: false }, + FP16x16 { mag: 262144, sign: false }, + FP16x16 { mag: 327680, sign: false }, + ] + .span() + ); + + (classifier, X) +} + +fn linear_classifier_multi_softmax() -> (Span, Tensor) { + let (mut classifier, X) = linear_classifier_helper(POST_TRANSFORM::SOFTMAX); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + (labels, scores) +} + +>>> +([0, 2, 2], + [ + [0.852656, 0.009192, 0.138152], + [0.318722, 0.05216, 0.629118], + [0.036323, 0.090237, 0.87344] + ]) +``` \ No newline at end of file diff --git a/docs/framework/operators/tensor/tensor.erf.md b/docs/framework/operators/tensor/tensor.erf.md index 19ce86a94..384a941d0 100644 --- a/docs/framework/operators/tensor/tensor.erf.md +++ b/docs/framework/operators/tensor/tensor.erf.md @@ -6,10 +6,6 @@ Computes the mean of the input tensor's elements along the provided axes. -## Args - -* `self`(`@Tensor`) - The input tensor. - ## Returns A new `Tensor` of the same shape as the input tensor with diff --git a/docs/framework/operators/tensor/tensor.gather_nd.md b/docs/framework/operators/tensor/tensor.gather_nd.md index 021d4f235..ce6f94462 100644 --- a/docs/framework/operators/tensor/tensor.gather_nd.md +++ b/docs/framework/operators/tensor/tensor.gather_nd.md @@ -21,14 +21,6 @@ Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims ## Returns A new `Tensor` . - -## Example - -```rust -use array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - fn gather_nd_example() -> Tensor { let tensor = TensorTrait::::new( shape: array![2, 2].span(), diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 4bfd10060..1098e42f3 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -11,4 +11,11 @@ use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{ use orion::operators::ml::tree_ensemble::tree_ensemble_regressor::{ TreeEnsembleRegressor, TreeEnsembleRegressorImpl, TreeEnsembleRegressorTrait, AGGREGATE_FUNCTION }; -use orion::operators::ml::linear::linear_regressor::{LinearRegressorTrait, LinearRegressorImpl, LinearRegressor}; + +use orion::operators::ml::linear::linear_regressor::{ + LinearRegressorTrait, LinearRegressorImpl, LinearRegressor +}; + +use orion::operators::ml::linear::linear_classifier::{ + LinearClassifierTrait, LinearClassifierImpl, LinearClassifier +}; diff --git a/src/operators/ml/linear.cairo b/src/operators/ml/linear.cairo index bc8062bd1..ef33bae6b 100644 --- a/src/operators/ml/linear.cairo +++ b/src/operators/ml/linear.cairo @@ -1 +1,2 @@ mod linear_regressor; +mod linear_classifier; diff --git a/src/operators/ml/linear/linear_classifier.cairo b/src/operators/ml/linear/linear_classifier.cairo new file mode 100644 index 000000000..21e524e4c --- /dev/null +++ b/src/operators/ml/linear/linear_classifier.cairo @@ -0,0 +1,276 @@ +use core::array::ArrayTrait; +use core::array::SpanTrait; +use orion::numbers::FP16x16; + +use orion::operators::tensor::{Tensor, TensorTrait}; +use orion::numbers::NumberTrait; +use orion::operators::tensor::{I8Tensor, I32Tensor, U32Tensor, FP16x16Tensor, FP16x16TensorAdd}; +use orion::numbers::{FP32x32, FP32x32Impl, FixedTrait}; +use orion::operators::nn::{NNTrait, FP16x16NN}; + + +#[derive(Destruct)] +struct LinearClassifier { + classlabels: Option>, + coefficients: Span, + intercepts: Option>, + multi_class: usize, + post_transform: POST_TRANSFORM, +} + + +#[derive(Copy, Drop)] +enum POST_TRANSFORM { + NONE, + SOFTMAX, + LOGISTIC, + SOFTMAXZERO, + PROBIT, +} + +/// Trait +/// +/// predict - Performs the linear classification. +trait LinearClassifierTrait { + /// # LinearClassifierTrait::predict + /// + /// ```rust + /// fn predict(ref self: LinearClassifier, X: Tensor) -> Tensor; + /// ``` + /// + /// Linear Classifier. Performs the linear classification. + /// + /// ## Args + /// + /// * `self`: LinearClassifier - A LinearClassifier object. + /// * `X`: Input 2D tensor. + /// + /// ## Returns + /// + /// * Tensor containing the linear classification evaluation of the input X. + /// + /// ## Type Constraints + /// + /// `LinearClassifier` and `X` must be fixed points + /// + /// ## Examples + /// + /// ```rust + /// use orion::numbers::FP16x16; + /// use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; + /// + /// use orion::operators::ml::linear::linear_classifier::{ + /// LinearClassifierTrait, POST_TRANSFORM, LinearClassifier + /// }; + /// + /// fn linear_classifier_helper( + /// post_transform: POST_TRANSFORM + /// ) -> (LinearClassifier, Tensor) { + /// + /// let classlabels: Span = array![0, 1, 2].span(); + /// let classlabels = Option::Some(classlabels); + /// + /// let classlabels_strings: Option> = Option::None; + /// + /// let coefficients: Span = array![ + /// FP16x16 { mag: 38011, sign: true }, + /// FP16x16 { mag: 19005, sign: true }, + /// FP16x16 { mag: 5898, sign: true }, + /// FP16x16 { mag: 38011, sign: false }, + /// FP16x16 { mag: 19005, sign: false }, + /// FP16x16 { mag: 5898, sign: false }, + /// ] + /// .span(); + /// + /// let intercepts: Span = array![ + /// FP16x16 { mag: 176947, sign: false }, + /// FP16x16 { mag: 176947, sign: true }, + /// FP16x16 { mag: 32768, sign: false }, + /// ] + /// .span(); + /// let intercepts = Option::Some(intercepts); + /// + /// let multi_class: usize = 0; + /// + /// let mut classifier: LinearClassifier = LinearClassifier { + /// classlabels, + /// coefficients, + /// intercepts, + /// multi_class, + /// post_transform + /// }; + /// + /// let mut X: Tensor = TensorTrait::new( + /// array![3, 2].span(), + /// array![ + /// FP16x16 { mag: 0, sign: false }, + /// FP16x16 { mag: 65536, sign: false }, + /// FP16x16 { mag: 131072, sign: false }, + /// FP16x16 { mag: 196608, sign: false }, + /// FP16x16 { mag: 262144, sign: false }, + /// FP16x16 { mag: 327680, sign: false }, + /// ] + /// .span() + /// ); + /// + /// (classifier, X) + /// } + /// + /// fn linear_classifier_multi_softmax() -> (Span, Tensor) { + /// let (mut classifier, X) = linear_classifier_helper(POST_TRANSFORM::SOFTMAX); + /// + /// let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + /// + /// (labels, scores) + /// } + /// + /// >>> + /// ([0, 2, 2], + /// [ + /// [0.852656, 0.009192, 0.138152], + /// [0.318722, 0.05216, 0.629118], + /// [0.036323, 0.090237, 0.87344] + /// ]) + /// ``` + fn predict(ref self: LinearClassifier, X: Tensor) -> (Span, Tensor); +} + +impl LinearClassifierImpl< + T, + MAG, + +Drop, + +Copy, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Add, + +TensorTrait, + +TensorTrait, + +AddEq, + +Div, + +Mul, + +Add>, + +NNTrait +> of LinearClassifierTrait { + fn predict(ref self: LinearClassifier, X: Tensor) -> (Span, Tensor) { + let n: usize = self.coefficients.len() / *(X.shape).at(1); + let mut shape = ArrayTrait::::new(); + shape.append(n); + shape.append(*(X.shape).at(1)); + + let mut coefficients = TensorTrait::new(shape.span(), self.coefficients); + let coefficients = coefficients.transpose(array![1, 0].span()); + + let mut scores = X.matmul(@coefficients); + match self.intercepts { + Option::Some(intercepts) => { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(intercepts.len()); + let intercepts = TensorTrait::new(shape.span(), intercepts); + scores = TensorTrait::add(scores, intercepts); + }, + Option::None(_) => {}, + }; + + let (n_classes, classlabels) = match self.classlabels { + Option::Some(classlabels) => { (classlabels.len(), classlabels) }, + Option::None(_) => { (0, ArrayTrait::::new().span()) }, + }; + if *coefficients.shape.at(1) == 1 && n_classes == 2 { + let mut new_scores = ArrayTrait::new(); + + loop { + match scores.data.pop_front() { + Option::Some(item) => { + new_scores.append(NumberTrait::neg(*item)); + new_scores.append(*item); + }, + Option::None(_) => { break; }, + } + }; + scores = TensorTrait::new(array![*scores.shape.at(0), 2].span(), new_scores.span()); + } + // Post Transform + scores = match self.post_transform { + POST_TRANSFORM::NONE => { scores }, + POST_TRANSFORM::SOFTMAX => { NNTrait::softmax(@scores, 1) }, + POST_TRANSFORM::LOGISTIC => { NNTrait::sigmoid(@scores) }, + POST_TRANSFORM::SOFTMAXZERO => core::panic_with_felt252( + 'Softmax_zero not supported yet' + ), + POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), + }; + + // Labels + let mut labels_list = ArrayTrait::new(); + if *scores.shape.at(1) > 1 { + let mut labels = scores.argmax(1, Option::None, Option::None); + loop { + match labels.data.pop_front() { + Option::Some(i) => { labels_list.append(*classlabels[*i]); }, + Option::None(_) => { break; } + }; + }; + } else { + let mut i = 0; + match self.post_transform { + POST_TRANSFORM::NONE => { + loop { + if i == scores.data.len() { + break; + } + if *scores.data.at(i) >= NumberTrait::zero() { + labels_list.append(*classlabels[0]); + } else { + labels_list.append(0); + } + i += 1; + }; + }, + POST_TRANSFORM::SOFTMAX => { + loop { + if i == scores.data.len() { + break; + } + if *scores.data.at(i) >= NumberTrait::half() { + labels_list.append(*classlabels[0]); + } else { + labels_list.append(0); + } + i += 1; + }; + }, + POST_TRANSFORM::LOGISTIC => { + loop { + if i == scores.data.len() { + break; + } + if *scores.data.at(i) >= NumberTrait::half() { + labels_list.append(*classlabels[0]); + } else { + labels_list.append(0); + } + i += 1; + }; + }, + POST_TRANSFORM::SOFTMAXZERO => core::panic_with_felt252( + 'Softmax_zero not supported yet' + ), + POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), + }; + } + + (labels_list.span(), scores) + } +} + + +fn max(a: usize, b: usize) -> usize { + if a > b { + return a; + } else { + return b; + } +} + diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d2afe3fc5..1f11e1f3f 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -472,7 +472,9 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..ce489a45a 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -520,10 +520,12 @@ impl FP16x16Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..1ea152edf 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -486,10 +486,12 @@ impl FP16x16WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..5e9b5512a 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -521,10 +521,12 @@ impl FP32x32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..2b5142789 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -522,10 +522,12 @@ impl FP64x64Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..18cd91af7 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -520,10 +520,12 @@ impl FP8x23Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..a1f9bcbdf 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -473,10 +473,12 @@ impl FP8x23WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..c8753c5a9 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -517,10 +517,12 @@ impl I32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..fa67a7950 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -515,10 +515,12 @@ impl I8Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..9a90ef883 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -458,10 +458,12 @@ impl U32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/math/gather_nd.cairo b/src/operators/tensor/math/gather_nd.cairo index 120eff5be..737a4fe32 100644 --- a/src/operators/tensor/math/gather_nd.cairo +++ b/src/operators/tensor/math/gather_nd.cairo @@ -14,12 +14,7 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; /// Cf: TensorTrait::gather_nd docstring -fn gather_nd< - T, - impl TTensorTrait: TensorTrait, - impl TCopy: Copy, - impl TDrop: Drop, ->( +fn gather_nd, impl TCopy: Copy, impl TDrop: Drop,>( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { let batch_dims = match batch_dims { @@ -29,19 +24,22 @@ fn gather_nd< let data_rank = (*self.shape).len(); let indices_rank = (indices.shape).len(); - assert((data_rank >= 1 ) & (indices_rank >= 1), 'rank must > 1'); - + assert((data_rank >= 1) & (indices_rank >= 1), 'rank must > 1'); + let mut data_shape = *self.shape; let mut indices_shape = indices.shape; let mut data_shape_clone = data_shape.clone(); let mut indices_shape_clone = indices_shape.clone(); let indices_shape_last = indices_shape_clone.pop_back().unwrap(); - assert((*indices_shape_last >= 1) & (*indices_shape_last <= data_rank-batch_dims), 'check indices'); + assert( + (*indices_shape_last >= 1) & (*indices_shape_last <= data_rank - batch_dims), + 'check indices' + ); let mut batch_dims_shape = ArrayTrait::new(); let mut output_shape = ArrayTrait::new(); - let mut index_data = ArrayTrait::new(); + let mut index_data = ArrayTrait::new(); let mut output_data = ArrayTrait::new(); let mut batch_dims_size = batch_dims; @@ -51,7 +49,7 @@ fn gather_nd< let mut ind = 0; loop { if (ind == batch_dims) { - break(); + break (); } match indices_shape_clone.pop_front() { Option::Some(val) => { @@ -65,17 +63,14 @@ fn gather_nd< loop { match indices_shape_clone.pop_front() { - Option::Some(val) => { - batch_dims_shape.append(*val); - }, + Option::Some(val) => { batch_dims_shape.append(*val); }, Option::None(_) => { break; } }; }; if (*indices_shape_last == data_rank - batch_dims) { output_shape = batch_dims_shape; - } - else { + } else { let mut ind = 0; let mut multiple = 1; output_shape = batch_dims_shape; @@ -136,16 +131,18 @@ fn gather_nd< match data_indices.pop_front() { Option::Some(val) => { let index = ind % *indices_shape_last; - let incr= total_data_len * (ind / breaker); + let incr = total_data_len * (ind / breaker); result += (*val * total_data_len / *multiple_data_len.at(index)); ind += 1; - if (index == *indices_shape_last-1) { - let mut data_ind:usize = result ; + if (index == *indices_shape_last - 1) { + let mut data_ind: usize = result; loop { - if data_ind == result + incrementer { break; } + if data_ind == result + incrementer { + break; + } index_data.append(data_ind + incr); - data_ind+=1; + data_ind += 1; }; result = 0; }; @@ -156,13 +153,11 @@ fn gather_nd< loop { match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, + Option::Some(val) => { output_data.append(*self.data[val]); }, Option::None(_) => { break; } }; }; let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} \ No newline at end of file +} diff --git a/tests/lib.cairo b/tests/lib.cairo index f5cecb77d..defb16442 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -1,7 +1,8 @@ -mod numbers; -mod performance; -mod tensor_core; -mod nodes; +//mod numbers; +//mod performance; +//mod tensor_core; +//mod nodes; mod ml; -mod operators; +//mod operators; + diff --git a/tests/ml.cairo b/tests/ml.cairo index bb54ef56f..ef228f254 100644 --- a/tests/ml.cairo +++ b/tests/ml.cairo @@ -1,4 +1,5 @@ -mod tree_ensemble_classifier; -mod tree_ensemble_regressor; -mod linear_regressor_test; +//mod tree_ensemble_classifier; +//mod tree_ensemble_regressor; +//mod linear_regressor_test; +mod linear_classifier_test; diff --git a/tests/ml/linear_classifier_test.cairo b/tests/ml/linear_classifier_test.cairo new file mode 100644 index 000000000..3ed27cfcc --- /dev/null +++ b/tests/ml/linear_classifier_test.cairo @@ -0,0 +1,319 @@ +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; + +use orion::operators::ml::linear::linear_classifier::{ + LinearClassifierTrait, POST_TRANSFORM, LinearClassifier +}; +use core::debug::PrintTrait; + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_multi_none() { + let (mut classifier, X) = linear_classifier_helper(POST_TRANSFORM::NONE); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 0, 'labels[0]'); + assert(*labels[1] == 2, 'labels[1]'); + assert(*labels[2] == 2, 'labels[2]'); + assert(labels.len() == 3, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 157942, sign: false }, '*scores[0] == 2.41'); + assert(*scores.data[1] == FP16x16 { mag: 138936, sign: true }, '*scores[1] == -2.12'); + assert(*scores.data[2] == FP16x16 { mag: 38666, sign: false }, '*scores[2] == 0.59'); + assert(*scores.data[3] == FP16x16 { mag: 43910, sign: false }, '*scores[3] == 0.67'); + assert(*scores.data[4] == FP16x16 { mag: 74710, sign: true }, '*scores[4] == -1.14'); + assert(*scores.data[5] == FP16x16 { mag: 88472, sign: false }, '*scores[5] == 1.35'); + assert(*scores.data[6] == FP16x16 { mag: 70122, sign: true }, '*scores[6] == -1.07'); + assert(*scores.data[7] == FP16x16 { mag: 10484, sign: true }, '*scores[7] == -0.16'); + assert(*scores.data[8] == FP16x16 { mag: 138278, sign: false }, '*scores[8] == 2.11'); +} + + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_multi_softmax() { + let (mut classifier, X) = linear_classifier_helper(POST_TRANSFORM::SOFTMAX); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 0, 'labels[0]'); + assert(*labels[1] == 2, 'labels[1]'); + assert(*labels[2] == 2, 'labels[2]'); + assert(labels.len() == 3, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 55879, sign: false }, '*scores[0] == 0.852656'); + assert(*scores.data[1] == FP16x16 { mag: 602, sign: false }, '*scores[1] == 0.009192'); + assert(*scores.data[2] == FP16x16 { mag: 9053, sign: false }, '*scores[2] == 0.138152'); + assert(*scores.data[3] == FP16x16 { mag: 20888, sign: false }, '*scores[3] == 0.318722'); + assert(*scores.data[4] == FP16x16 { mag: 3418, sign: false }, '*scores[4] == 0.05216'); + assert(*scores.data[5] == FP16x16 { mag: 41229, sign: false }, '*scores[5] == 0.629118'); + assert(*scores.data[6] == FP16x16 { mag: 2380, sign: false }, '*scores[6] == 0.036323'); + assert(*scores.data[7] == FP16x16 { mag: 5914, sign: false }, '*scores[7] == 0.090237'); + assert(*scores.data[8] == FP16x16 { mag: 57241, sign: false }, '*scores[8] == 0.87344'); +} + + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_multi_logistic() { + let (mut classifier, X) = linear_classifier_helper(POST_TRANSFORM::LOGISTIC); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 0, 'labels[0] == 0'); + assert(*labels[1] == 2, 'labels[1] == 2'); + assert(*labels[2] == 2, 'labels[2] == 2'); + assert(labels.len() == 3, 'len(labels) == 3'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 60135, sign: false }, '*scores[0] == 0.917587'); + assert(*scores.data[1] == FP16x16 { mag: 7023, sign: false }, '*scores[1] == 0.107168'); + assert(*scores.data[2] == FP16x16 { mag: 42163, sign: false }, '*scores[2] == 0.643365'); + assert(*scores.data[3] == FP16x16 { mag: 43351, sign: false }, '*scores[3] == 0.661503'); + assert(*scores.data[4] == FP16x16 { mag: 15881, sign: false }, '*scores[4] == 0.24232'); + assert(*scores.data[5] == FP16x16 { mag: 52043, sign: false }, '*scores[5] == 0.79413'); + assert(*scores.data[6] == FP16x16 { mag: 16738, sign: false }, '*scores[6] == 0.255403'); + assert(*scores.data[7] == FP16x16 { mag: 30152, sign: false }, '*scores[7] == 0.460085'); + assert(*scores.data[8] == FP16x16 { mag: 58450, sign: false }, '*scores[8] == 0.891871'); +} + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_binary_none() { + let (mut classifier, X) = linear_classifier_helper_binary(POST_TRANSFORM::NONE); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 1, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 624559, sign: true }, '*scores[0] == -9.53'); + assert(*scores.data[1] == FP16x16 { mag: 624559, sign: false }, '*scores[1] == 9.53'); + assert(*scores.data[2] == FP16x16 { mag: 435817, sign: true }, '*scores[2] == -6.65'); + assert(*scores.data[3] == FP16x16 { mag: 435817, sign: false }, '*scores[3] == 6.65'); +} + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_binary_logistic() { + let (mut classifier, X) = linear_classifier_helper_binary(POST_TRANSFORM::LOGISTIC); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 1, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 4, sign: false }, '*scores[0] == 7.263436e-05'); + assert(*scores.data[1] == FP16x16 { mag: 65532, sign: false }, '*scores[1] == 9.999274e-01'); + assert(*scores.data[2] == FP16x16 { mag: 84, sign: false }, '*scores[2] == 1.292350e-03'); + assert(*scores.data[3] == FP16x16 { mag: 65452, sign: false }, '*scores[3] == 9.999983e-01'); +} + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_binary_softmax() { + let (mut classifier, X) = linear_classifier_helper_binary(POST_TRANSFORM::SOFTMAX); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 1, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 0, sign: false }, '*scores[0] == 5.276517e-09'); + assert(*scores.data[1] == FP16x16 { mag: 65535, sign: false }, '*scores[1] == 1.000000'); + assert(*scores.data[2] == FP16x16 { mag: 0, sign: false }, '*scores[2] == 1.674492e-06'); + assert(*scores.data[3] == FP16x16 { mag: 65535, sign: false }, '*scores[3] == 9.999983e-01'); +} + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_unary_none() { + let (mut classifier, X) = linear_classifier_helper_unary(POST_TRANSFORM::NONE); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 0, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 146146, sign: false }, '*scores[0] == 2.23'); + assert(*scores.data[1] == FP16x16 { mag: 42596, sign: true }, '*scores[1] == -0.65'); +} + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_unary_logistic() { + let (mut classifier, X) = linear_classifier_helper_unary(POST_TRANSFORM::LOGISTIC); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 0, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 59173, sign: false }, '*scores[0] == 0.902911'); + assert(*scores.data[1] == FP16x16 { mag: 22479, sign: false }, '*scores[1] == 0.34299'); +} + +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_unary_softmax() { + let (mut classifier, X) = linear_classifier_helper_unary(POST_TRANSFORM::SOFTMAX); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 1, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 65536, sign: false }, '*scores[0] == 1'); + assert(*scores.data[1] == FP16x16 { mag: 65536, sign: false }, '*scores[1] == 1'); +} + + +// ============ HELPER ============ // + +fn linear_classifier_helper( + post_transform: POST_TRANSFORM +) -> (LinearClassifier, Tensor) { + let classlabels: Span = array![0, 1, 2].span(); + let classlabels = Option::Some(classlabels); + + let classlabels_strings: Option> = Option::None; + + let coefficients: Span = array![ + FP16x16 { mag: 38011, sign: true }, + FP16x16 { mag: 19005, sign: true }, + FP16x16 { mag: 5898, sign: true }, + FP16x16 { mag: 38011, sign: false }, + FP16x16 { mag: 19005, sign: false }, + FP16x16 { mag: 5898, sign: false }, + ] + .span(); + + let intercepts: Span = array![ + FP16x16 { mag: 176947, sign: false }, + FP16x16 { mag: 176947, sign: true }, + FP16x16 { mag: 32768, sign: false }, + ] + .span(); + let intercepts = Option::Some(intercepts); + + let multi_class: usize = 0; + + let mut classifier: LinearClassifier = LinearClassifier { + classlabels, coefficients, intercepts, multi_class, post_transform + }; + + let mut X: Tensor = TensorTrait::new( + array![3, 2].span(), + array![ + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 131072, sign: false }, + FP16x16 { mag: 196608, sign: false }, + FP16x16 { mag: 262144, sign: false }, + FP16x16 { mag: 327680, sign: false }, + ] + .span() + ); + + (classifier, X) +} + + +fn linear_classifier_helper_binary( + post_transform: POST_TRANSFORM +) -> (LinearClassifier, Tensor) { + let classlabels: Span = array![0, 1].span(); + let classlabels = Option::Some(classlabels); + + let coefficients: Span = array![ + FP16x16 { mag: 38011, sign: true }, + FP16x16 { mag: 19005, sign: true }, + FP16x16 { mag: 5898, sign: true }, + ] + .span(); + + let intercepts: Span = array![FP16x16 { mag: 655360, sign: false },].span(); + let intercepts = Option::Some(intercepts); + + let multi_class: usize = 0; + + let mut classifier: LinearClassifier = LinearClassifier { + classlabels, coefficients, intercepts, multi_class, post_transform + }; + + let mut X: Tensor = TensorTrait::new( + array![2, 3].span(), + array![ + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 131072, sign: false }, + FP16x16 { mag: 196608, sign: false }, + FP16x16 { mag: 262144, sign: false }, + FP16x16 { mag: 327680, sign: false }, + ] + .span() + ); + + (classifier, X) +} + +fn linear_classifier_helper_unary( + post_transform: POST_TRANSFORM +) -> (LinearClassifier, Tensor) { + let classlabels: Span = array![1].span(); + let classlabels = Option::Some(classlabels); + + let coefficients: Span = array![ + FP16x16 { mag: 38011, sign: true }, + FP16x16 { mag: 19005, sign: true }, + FP16x16 { mag: 5898, sign: true }, + ] + .span(); + + let intercepts: Span = array![FP16x16 { mag: 176947, sign: false },].span(); + let intercepts = Option::Some(intercepts); + + let multi_class: usize = 0; + + let mut classifier: LinearClassifier = LinearClassifier { + classlabels, coefficients, intercepts, multi_class, post_transform + }; + + let mut X: Tensor = TensorTrait::new( + array![2, 3].span(), + array![ + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 131072, sign: false }, + FP16x16 { mag: 196608, sign: false }, + FP16x16 { mag: 262144, sign: false }, + FP16x16 { mag: 327680, sign: false }, + ] + .span() + ); + + (classifier, X) +} diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo index d2c0b80dd..025cc8261 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo index 507847851..677a40f6a 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_default.cairo b/tests/nodes/gather_nd_fp16x16_3d_default.cairo index ae4609a66..b8339a0d2 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo index b9a083796..65980d91f 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo index 5e42ca893..48c812baf 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_default.cairo b/tests/nodes/gather_nd_fp8x23_3d_default.cairo index 12b6408e0..342cd2b72 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_default.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo index 243b0ca16..318ccd62e 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo index d11370b94..177c8e40f 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_default.cairo b/tests/nodes/gather_nd_i32_3d_default.cairo index 35c054093..97212f737 100644 --- a/tests/nodes/gather_nd_i32_3d_default.cairo +++ b/tests/nodes/gather_nd_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo index ae83a8c7d..f849c8677 100644 --- a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_default.cairo b/tests/nodes/gather_nd_i8_3d_default.cairo index 73e1d91b2..ff7ad9252 100644 --- a/tests/nodes/gather_nd_i8_3d_default.cairo +++ b/tests/nodes/gather_nd_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims1.cairo b/tests/nodes/gather_nd_u32_batch_dims1.cairo index 0428ec1d5..860675f66 100644 --- a/tests/nodes/gather_nd_u32_batch_dims1.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims1.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims2.cairo b/tests/nodes/gather_nd_u32_batch_dims2.cairo index 39857ef1d..f0662be99 100644 --- a/tests/nodes/gather_nd_u32_batch_dims2.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims2.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_default.cairo b/tests/nodes/gather_nd_u32_default.cairo index f55b49d5e..be6edd699 100644 --- a/tests/nodes/gather_nd_u32_default.cairo +++ b/tests/nodes/gather_nd_u32_default.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } From d502465370f018061865f53fea9e696c228f9b92 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Thu, 21 Dec 2023 05:22:02 +0100 Subject: [PATCH 09/38] add readme --- .../linear-classifier/README.md | 22 +++++++++++++++++++ tests/lib.cairo | 10 ++++----- tests/ml.cairo | 6 ++--- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/docs/framework/operators/machine-learning/linear-classifier/README.md b/docs/framework/operators/machine-learning/linear-classifier/README.md index e69de29bb..7b68132c4 100644 --- a/docs/framework/operators/machine-learning/linear-classifier/README.md +++ b/docs/framework/operators/machine-learning/linear-classifier/README.md @@ -0,0 +1,22 @@ +# Linear Classifier + +`LinearClassifierTrait` provides a trait definition for linear classification problem. + +```rust +use orion::operators::ml::LinearClassificationTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `LinearClassificationTrait`. + +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `LinearClassifierTrait` | + + +*** + +| function | description | +| --- | --- | +| [`linear_classifier.predict`](linear_classifier.predict.md) | Performs the linear classification evaluation. | diff --git a/tests/lib.cairo b/tests/lib.cairo index defb16442..c408347ef 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -1,8 +1,8 @@ -//mod numbers; -//mod performance; -//mod tensor_core; -//mod nodes; +mod numbers; +mod performance; +mod tensor_core; +mod nodes; mod ml; -//mod operators; +mod operators; diff --git a/tests/ml.cairo b/tests/ml.cairo index ef228f254..78f6b370b 100644 --- a/tests/ml.cairo +++ b/tests/ml.cairo @@ -1,5 +1,5 @@ -//mod tree_ensemble_classifier; -//mod tree_ensemble_regressor; -//mod linear_regressor_test; +mod tree_ensemble_classifier; +mod tree_ensemble_regressor; +mod linear_regressor_test; mod linear_classifier_test; From e475cae87f3b159c34b3916b9dc9035f1e068cf5 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Fri, 22 Dec 2023 10:44:55 +0100 Subject: [PATCH 10/38] feat/softmax_zero --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + .../operators/neural-network/README.md | 1 + .../neural-network/nn.softmax_zero.md | 64 ++++++ docs/framework/operators/tensor/tensor.erf.md | 4 - .../operators/tensor/tensor.gather_nd.md | 8 - nodegen/node/softmax_zero.py | 48 +++++ src/operators/ml.cairo | 4 +- src/operators/nn/core.cairo | 67 ++++++ src/operators/nn/functional.cairo | 1 + .../nn/functional/softmax_zero.cairo | 199 ++++++++++++++++++ .../nn/implementations/nn_fp16x16.cairo | 4 + .../nn/implementations/nn_fp32x32.cairo | 4 + .../nn/implementations/nn_fp64x64.cairo | 4 + .../nn/implementations/nn_fp8x23.cairo | 4 + src/operators/nn/implementations/nn_i32.cairo | 4 + src/operators/nn/implementations/nn_i8.cairo | 4 + src/operators/nn/implementations/nn_u32.cairo | 4 + .../tensor/implementations/tensor_bool.cairo | 4 +- .../implementations/tensor_fp16x16.cairo | 6 +- .../implementations/tensor_fp16x16wide.cairo | 6 +- .../implementations/tensor_fp32x32.cairo | 6 +- .../implementations/tensor_fp64x64.cairo | 6 +- .../implementations/tensor_fp8x23.cairo | 6 +- .../implementations/tensor_fp8x23wide.cairo | 6 +- .../tensor/implementations/tensor_i32.cairo | 6 +- .../tensor/implementations/tensor_i8.cairo | 6 +- .../tensor/implementations/tensor_u32.cairo | 6 +- src/operators/tensor/math/gather_nd.cairo | 45 ++-- tests/lib.cairo | 1 + tests/nodes.cairo | 2 + .../gather_nd_fp16x16_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp16x16_3d_batch_dims2.cairo | 2 +- .../nodes/gather_nd_fp16x16_3d_default.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_fp8x23_3d_default.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims1.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_i32_3d_default.cairo | 2 +- tests/nodes/gather_nd_i8_3d_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_i8_3d_default.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_u32_default.cairo | 2 +- tests/nodes/softmax_zero_fp16x16.cairo | 20 ++ .../nodes/softmax_zero_fp16x16/input_0.cairo | 17 ++ .../nodes/softmax_zero_fp16x16/output_0.cairo | 17 ++ tests/nodes/softmax_zero_fp8x23.cairo | 20 ++ tests/nodes/softmax_zero_fp8x23/input_0.cairo | 17 ++ .../nodes/softmax_zero_fp8x23/output_0.cairo | 17 ++ 51 files changed, 597 insertions(+), 71 deletions(-) create mode 100644 docs/framework/operators/neural-network/nn.softmax_zero.md create mode 100644 nodegen/node/softmax_zero.py create mode 100644 src/operators/nn/functional/softmax_zero.cairo create mode 100644 tests/nodes/softmax_zero_fp16x16.cairo create mode 100644 tests/nodes/softmax_zero_fp16x16/input_0.cairo create mode 100644 tests/nodes/softmax_zero_fp16x16/output_0.cairo create mode 100644 tests/nodes/softmax_zero_fp8x23.cairo create mode 100644 tests/nodes/softmax_zero_fp8x23/input_0.cairo create mode 100644 tests/nodes/softmax_zero_fp8x23/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 37f90ece0..6f409d29c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -132,6 +132,7 @@ * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) * [nn.sigmoid](framework/operators/neural-network/nn.sigmoid.md) * [nn.softmax](framework/operators/neural-network/nn.softmax.md) + * [nn.softmax_zero](framework/operators/neural-network/nn.softmax_zero.md) * [nn.logsoftmax](framework/operators/neural-network/nn.logsoftmax.md) * [nn.softsign](framework/operators/neural-network/nn.softsign.md) * [nn.softplus](framework/operators/neural-network/nn.softplus.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index 68cd44241..fe539fde5 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -37,6 +37,7 @@ You can see below the list of current supported ONNX Operators: | [ThresholdedRelu](operators/neural-network/nn.thresholded\_relu.md) | :white\_check\_mark: | | [Sigmoid](operators/neural-network/nn.sigmoid.md) | :white\_check\_mark: | | [Softmax](operators/neural-network/nn.softmax.md) | :white\_check\_mark: | +| [Softmax_zero](operators/neural-network/nn.softmax_zero.md) | :white\_check\_mark: | | [LogSoftmax](operators/neural-network/nn.logsoftmax.md) | :white\_check\_mark: | | [Softsign](operators/neural-network/nn.softsign.md) | :white\_check\_mark: | | [Softplus](operators/neural-network/nn.softplus.md) | :white\_check\_mark: | diff --git a/docs/framework/operators/neural-network/README.md b/docs/framework/operators/neural-network/README.md index cd1c92f8d..8343d0c90 100644 --- a/docs/framework/operators/neural-network/README.md +++ b/docs/framework/operators/neural-network/README.md @@ -27,6 +27,7 @@ Orion supports currently these `NN` types. | [`nn.leaky_relu`](nn.leaky\_relu.md) | Applies the leaky rectified linear unit (Leaky ReLU) activation function element-wise. | | [`nn.sigmoid`](nn.sigmoid.md) | Applies the Sigmoid function to an n-dimensional input tensor. | | [`nn.softmax`](nn.softmax.md) | Computes softmax activations. | +| [`nn.softmax_zero`](nn.softmax\_zero.md) | Computes softmax zero. | | [`nn.logsoftmax`](nn.logsoftmax.md) | Applies the natural log to Softmax function to an n-dimensional input Tensor. | | [`nn.softsign`](nn.softsign.md) | Applies the Softsign function element-wise. | | [`nn.softplus`](nn.softplus.md) | Applies the Softplus function element-wise. | diff --git a/docs/framework/operators/neural-network/nn.softmax_zero.md b/docs/framework/operators/neural-network/nn.softmax_zero.md new file mode 100644 index 000000000..f5cc9159b --- /dev/null +++ b/docs/framework/operators/neural-network/nn.softmax_zero.md @@ -0,0 +1,64 @@ +# NNTrait::softmax_zero + +```rust + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor; +``` + +Applies the Softmax zero function to an n-dimensional input Tensor rescaling them so that the elements of the n-dimensional output Tensor lie in the range \[0,1] and sum to 1 while keeping the zero elements to zero. + +The softmax zero on the set $\mathbf{x} = (x_1, ..., x_n)$ is given by : + +$$ +\text{softmax zero}(x_i) = \begin{cases} +0 & \qquad x_i = 0 \\ +\frac{e^{x_i}}{ \sum_{x \in {S}} e^{x}} & \qquad \text{otherwise} +\end{cases} +$$ +where $S$ in a subset of $\mathbf{x}$ given by + +$$ + \ S = \{ (x_1, \ldots, x_k) \mid 1 \leq k \leq n, x_j \neq 0 \text{ for } 1 \leq j \leq k \} +$$ + +## Args + +* `tensor`(`@Tensor`) - The input tensor. +* `axis`(`usize`) - The axis along which to compute the softmax zero. + +## Returns + +A Tensor of fixed point numbers with the same shape than the input Tensor. + +## Type Constraints + +Constrain input and output types to fixed point tensors. + +## Examples + +```rust +use core::array::{ArrayTrait, SpanTrait}; + +use orion::operators::tensor::{TensorTrait, Tensor, FP8x23Tensor}; +use orion::operators::nn::{NNTrait, FP8x23NN}; +use orion::numbers::{FP8x23, FixedTrait}; + +use core::debug::PrintTrait; + +fn softmax_zero_example() -> Tensor { + let tensor = TensorTrait::::new( + shape: array![2, 2].span(), + data: array![ + FixedTrait::new(0, false), + FixedTrait::new(8388608, false), + FixedTrait::new(16777216, false), + FixedTrait::new(25165824, false), + ] + .span(), + ); + + return NNTrait::softmax_zero(@tensor, 1); +} +>>> [[0,0x800000],[2256043,6132564]] + // The fixed point representation of + // [[0, 1],[0.2689, 0.7311]] +``` diff --git a/docs/framework/operators/tensor/tensor.erf.md b/docs/framework/operators/tensor/tensor.erf.md index 19ce86a94..384a941d0 100644 --- a/docs/framework/operators/tensor/tensor.erf.md +++ b/docs/framework/operators/tensor/tensor.erf.md @@ -6,10 +6,6 @@ Computes the mean of the input tensor's elements along the provided axes. -## Args - -* `self`(`@Tensor`) - The input tensor. - ## Returns A new `Tensor` of the same shape as the input tensor with diff --git a/docs/framework/operators/tensor/tensor.gather_nd.md b/docs/framework/operators/tensor/tensor.gather_nd.md index 021d4f235..ce6f94462 100644 --- a/docs/framework/operators/tensor/tensor.gather_nd.md +++ b/docs/framework/operators/tensor/tensor.gather_nd.md @@ -21,14 +21,6 @@ Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims ## Returns A new `Tensor` . - -## Example - -```rust -use array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - fn gather_nd_example() -> Tensor { let tensor = TensorTrait::::new( shape: array![2, 2].span(), diff --git a/nodegen/node/softmax_zero.py b/nodegen/node/softmax_zero.py new file mode 100644 index 000000000..40e5528cb --- /dev/null +++ b/nodegen/node/softmax_zero.py @@ -0,0 +1,48 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait + + +def softmax_zero(x: np.ndarray, axis: int = -1) -> np.ndarray: + x_max = np.max(x, axis=axis, keepdims=True) + tmp = np.exp(x - x_max) + tmp = np.where(x == 0.0, 0.0, tmp) + + s = np.sum(tmp, axis=axis, keepdims=True) + s = np.where(s == 0.0, 1, s) + + + return tmp / s + + +class Softmax_zero(RunAll): + + + @staticmethod + def fp8x23(): + x = np.random.uniform(-3, 3, (2, 2)).astype(np.float64) + y = softmax_zero(x) + + x = Tensor(Dtype.FP8x23, x.shape, to_fp( + x.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, y.shape, to_fp( + y.flatten(), FixedImpl.FP8x23)) + + name = "softmax_zero_fp8x23" + make_test([x], y, "NNTrait::softmax_zero(@input_0, 1)", + name, Trait.NN) + + @staticmethod + def fp16x16(): + x = np.random.uniform(-3, 3, (2, 2)).astype(np.float64) + y = softmax_zero(x) + + x = Tensor(Dtype.FP16x16, x.shape, to_fp( + x.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "softmax_zero_fp16x16" + make_test([x], y, "NNTrait::softmax_zero(@input_0, 1)", + name, Trait.NN) + diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 4bfd10060..93a490bbe 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -11,4 +11,6 @@ use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{ use orion::operators::ml::tree_ensemble::tree_ensemble_regressor::{ TreeEnsembleRegressor, TreeEnsembleRegressorImpl, TreeEnsembleRegressorTrait, AGGREGATE_FUNCTION }; -use orion::operators::ml::linear::linear_regressor::{LinearRegressorTrait, LinearRegressorImpl, LinearRegressor}; +use orion::operators::ml::linear::linear_regressor::{ + LinearRegressorTrait, LinearRegressorImpl, LinearRegressor +}; diff --git a/src/operators/nn/core.cairo b/src/operators/nn/core.cairo index ae794c3a1..7594ba908 100644 --- a/src/operators/nn/core.cairo +++ b/src/operators/nn/core.cairo @@ -6,6 +6,7 @@ use orion::operators::tensor::core::Tensor; /// leaky_relu - Applies the leaky rectified linear unit (Leaky ReLU) activation function element-wise. /// sigmoid - Applies the Sigmoid function to an n-dimensional input tensor. /// softmax - Computes softmax activations. +/// softmax_zero - Computes softmax zero. /// logsoftmax - Applies the natural log to Softmax function to an n-dimensional input Tensor. /// softsign - Applies the Softsign function element-wise. /// softplus - Applies the Softplus function element-wise. @@ -115,6 +116,72 @@ trait NNTrait { /// ``` /// fn softmax(tensor: @Tensor, axis: usize) -> Tensor; + /// # NNTrait::softmax_zero + /// + /// ```rust + /// fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor; + /// ``` + /// + /// Applies the Softmax zero function to an n-dimensional input Tensor rescaling them so that the elements of the n-dimensional output Tensor lie in the range \[0,1] and sum to 1 while keeping the zero elements to zero. + /// + /// The softmax zero on the set $\mathbf{x} = (x_1, ..., x_n)$ is given by : + /// + /// $$ + /// \text{softmax zero}(x_i) = \begin{cases} + /// 0 & \qquad x_i = 0 \\ + /// \frac{e^{x_i}}{ \sum_{x \in {S}} e^{x}} & \qquad \text{otherwise} + /// \end{cases} + /// $$ + /// where $S$ in a subset of $\mathbf{x}$ given by + /// + /// $$ + /// \ S = \{ (x_1, \ldots, x_k) \mid 1 \leq k \leq n, x_j \neq 0 \text{ for } 1 \leq j \leq k \} + /// $$ + /// + /// ## Args + /// + /// * `tensor`(`@Tensor`) - The input tensor. + /// * `axis`(`usize`) - The axis along which to compute the softmax zero. + /// + /// ## Returns + /// + /// A Tensor of fixed point numbers with the same shape than the input Tensor. + /// + /// ## Type Constraints + /// + /// Constrain input and output types to fixed point tensors. + /// + /// ## Examples + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, FP8x23Tensor}; + /// use orion::operators::nn::{NNTrait, FP8x23NN}; + /// use orion::numbers::{FP8x23, FixedTrait}; + /// + /// use core::debug::PrintTrait; + /// + /// fn softmax_zero_example() -> Tensor { + /// let tensor = TensorTrait::::new( + /// shape: array![2, 2].span(), + /// data: array![ + /// FixedTrait::new(0, false), + /// FixedTrait::new(8388608, false), + /// FixedTrait::new(16777216, false), + /// FixedTrait::new(25165824, false), + /// ] + /// .span(), + /// ); + /// + /// return NNTrait::softmax_zero(@tensor, 1); + /// } + /// >>> [[0,0x800000],[2256043,6132564]] + /// // The fixed point representation of + /// // [[0, 1],[0.2689, 0.7311]] + /// ``` + /// + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor; /// # NNTrait::logsoftmax /// /// ```rust diff --git a/src/operators/nn/functional.cairo b/src/operators/nn/functional.cairo index 6e40e562f..a0fd96cc8 100644 --- a/src/operators/nn/functional.cairo +++ b/src/operators/nn/functional.cairo @@ -2,6 +2,7 @@ mod relu; mod leaky_relu; mod sigmoid; mod softmax; +mod softmax_zero; mod softsign; mod softplus; mod linear; diff --git a/src/operators/nn/functional/softmax_zero.cairo b/src/operators/nn/functional/softmax_zero.cairo new file mode 100644 index 000000000..c0aa13dac --- /dev/null +++ b/src/operators/nn/functional/softmax_zero.cairo @@ -0,0 +1,199 @@ +use core::traits::Into; +use core::option::OptionTrait; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::numbers::NumberTrait; + +use orion::operators::tensor::core::{Tensor, TensorTrait, ravel_index, unravel_index}; +use orion::operators::tensor::helpers::{reduce_output_shape, len_from_shape, combine_indices}; +use orion::operators::tensor::math::{reduce_sum::accumulate_sum, arithmetic::div_downcast}; + + +/// Cf: NNTrait::softmax_zero docstring +fn softmax_zero< + T, + MAG, + impl TTensor: TensorTrait, + impl TTensorDiv: Div>, + impl TPartialEq: PartialEq, + impl TNumber: NumberTrait, + impl TCopy: Copy, + impl TDrop: Drop, + impl TAddEq: AddEq, +>( + z: @Tensor, axis: usize +) -> Tensor { + let exp_tensor = exp_zero(*z); + let sum_no_zero = reduce_sum_no_zero(@exp_tensor, axis, true); + exp_tensor / sum_no_zero +} + +/// Cf: NNTrait::softmax_zero docstring +fn softmaxWide_zero< + T, + TMAG, + W, + WMAG, + impl TTensor: TensorTrait, + impl WTensor: TensorTrait, + impl TDiv: Div, + impl TIntoW: Into, + impl WTryIntoT: TryInto, + impl TCopy: Copy, + impl TDrop: Drop, + impl WCopy: Copy, + impl WDrop: Drop, + impl TNumber: NumberTrait, + impl WNumber: NumberTrait, + impl TPartialEq: PartialEq, + impl WPartialEq: PartialEq, + impl TAddEq: AddEq, + impl WAddEq: AddEq, +>( + z: @Tensor, axis: usize +) -> Tensor { + let exp_tensor: Tensor = exp_upcast_zero(*z); + let sum_no_zero = reduce_sum_no_zero(@exp_tensor, axis, true); + div_downcast(@exp_tensor, @sum_no_zero) +} + + +/// Helper function that compute the exponential of a tensor except if the value of an entry is zero, the value remains zero. +/// +/// # Arguments +/// * `z` - The input tensor. +/// +/// # Returns +/// * A Tensor representing the exponential of the tensor except for the entries equal to zero in the input tensor, they remain zero. +fn exp_zero< + T, + MAG, + impl TNumber: NumberTrait, + impl FTensor: TensorTrait, + impl TPartialEq: PartialEq, + impl FCopy: Copy, + impl FDrop: Drop, +>( + mut z: Tensor +) -> Tensor { + let mut result = ArrayTrait::new(); + + loop { + match z.data.pop_front() { + Option::Some(item) => { + if *item == NumberTrait::zero() { + result.append(NumberTrait::zero()); + } else { + result.append((*item).exp()); + } + }, + Option::None(_) => { break; } + }; + }; + + return TensorTrait::new(z.shape, result.span()); +} + +/// Helper function that compute the exponential of a tensor except if the value of an entry is zero, the value remains zero. +/// +/// # Arguments +/// * `z` - The input tensor. +/// +/// # Returns +/// * A Tensor representing the exponential of the tensor except for the entries equal to zero in the input tensor, they remain zero. +fn exp_upcast_zero< + T, + TMAG, + W, + WMAG, + impl TNumber: NumberTrait, + impl TTensor: TensorTrait, + impl TPartialEq: PartialEq, + impl TCopy: Copy, + impl TDrop: Drop, + impl WNumber: NumberTrait, + impl WTensor: TensorTrait, + impl WCopy: Copy, + impl WDrop: Drop, + impl TIntoW: Into, +>( + mut self: Tensor +) -> Tensor { + let mut result = ArrayTrait::new(); + + loop { + match self.data.pop_front() { + Option::Some(item) => { + if *item == NumberTrait::zero() { + result.append(NumberTrait::zero()); + } else { + result.append((TIntoW::into(*item)).exp()); + } + }, + Option::None(_) => { break; } + }; + }; + + return TensorTrait::new(self.shape, result.span()); +} + + +/// Helper function that compute the reduce sum making sure no none zero value are in the output tensor. +/// +/// # Arguments +/// * `z` - The input tensor. +/// +/// # Returns +/// * A Tensor representing the ereduce sum with no entries equal to zero. + +fn reduce_sum_no_zero< + T, + MAG, + impl TTensor: TensorTrait, + impl TNumber: NumberTrait, + impl TAddEq: AddEq, + impl TCopy: Copy, + impl TDrop: Drop, + impl TPartialEq: PartialEq, +>( + self: @Tensor, axis: usize, keepdims: bool +) -> Tensor { + let mut output_data = ArrayTrait::new(); + + if (*self.shape).len() == 1 { + assert(axis == 0, 'axis out of dimensions'); + let current_sum = accumulate_sum::(*self.data, *self.shape, *self.shape, axis); + output_data.append(current_sum); + + let mut output_shape = ArrayTrait::new(); + output_shape.append(1); + + return TensorTrait::new(output_shape.span(), output_data.span()); + } else { + assert(axis <= (*self.shape).len(), 'axis out of dimensions'); + let output_shape = reduce_output_shape(*self.shape, axis, false); + let output_data_len = len_from_shape(output_shape); + let mut index: usize = 0; + loop { + let output_indices = unravel_index(index, output_shape); + let current_sum = accumulate_sum::(*self.data, *self.shape, output_indices, axis); + + if current_sum == NumberTrait::zero() { + let current_sum: T = NumberTrait::one(); + } + output_data.append(current_sum); + + index += 1; + if index == output_data_len { + break (); + }; + }; + + if keepdims { + let output_shape = reduce_output_shape(*self.shape, axis, true); + return TensorTrait::::new(output_shape, output_data.span()); + } else { + return TensorTrait::::new(output_shape, output_data.span()); + } + } +} diff --git a/src/operators/nn/implementations/nn_fp16x16.cairo b/src/operators/nn/implementations/nn_fp16x16.cairo index c1bc6970b..785d3c9fa 100644 --- a/src/operators/nn/implementations/nn_fp16x16.cairo +++ b/src/operators/nn/implementations/nn_fp16x16.cairo @@ -27,6 +27,10 @@ impl FP16x16NN of NNTrait { functional::softmax::softmaxWide::(tensor, axis) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + functional::softmax_zero::softmaxWide_zero::(tensor, axis) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { functional::logsoftmax::logsoftmaxWide::(tensor, axis) } diff --git a/src/operators/nn/implementations/nn_fp32x32.cairo b/src/operators/nn/implementations/nn_fp32x32.cairo index 832c26dcf..0427ea5f7 100644 --- a/src/operators/nn/implementations/nn_fp32x32.cairo +++ b/src/operators/nn/implementations/nn_fp32x32.cairo @@ -21,6 +21,10 @@ impl FP32x32NN of NNTrait { functional::softmax::softmax(tensor, axis) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + functional::softmax_zero::softmax_zero(tensor, axis) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { functional::logsoftmax::logsoftmax(tensor, axis) } diff --git a/src/operators/nn/implementations/nn_fp64x64.cairo b/src/operators/nn/implementations/nn_fp64x64.cairo index 0a674fe47..fec810679 100644 --- a/src/operators/nn/implementations/nn_fp64x64.cairo +++ b/src/operators/nn/implementations/nn_fp64x64.cairo @@ -21,6 +21,10 @@ impl FP64x64NN of NNTrait { functional::softmax::softmax(tensor, axis) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + functional::softmax_zero::softmax_zero(tensor, axis) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { functional::logsoftmax::logsoftmax(tensor, axis) } diff --git a/src/operators/nn/implementations/nn_fp8x23.cairo b/src/operators/nn/implementations/nn_fp8x23.cairo index d246bf2cc..9f5416121 100644 --- a/src/operators/nn/implementations/nn_fp8x23.cairo +++ b/src/operators/nn/implementations/nn_fp8x23.cairo @@ -25,6 +25,10 @@ impl FP8x23NN of NNTrait { functional::softmax::softmaxWide::(tensor, axis) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + functional::softmax_zero::softmaxWide_zero::(tensor, axis) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { functional::logsoftmax::logsoftmaxWide::(tensor, axis) } diff --git a/src/operators/nn/implementations/nn_i32.cairo b/src/operators/nn/implementations/nn_i32.cairo index dee95ec9f..232ebee84 100644 --- a/src/operators/nn/implementations/nn_i32.cairo +++ b/src/operators/nn/implementations/nn_i32.cairo @@ -19,6 +19,10 @@ impl I32NN of NNTrait { panic(array!['not supported!']) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + panic(array!['not supported!']) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/nn/implementations/nn_i8.cairo b/src/operators/nn/implementations/nn_i8.cairo index c9057fcce..185ba4a4e 100644 --- a/src/operators/nn/implementations/nn_i8.cairo +++ b/src/operators/nn/implementations/nn_i8.cairo @@ -19,6 +19,10 @@ impl I8NN of NNTrait { panic(array!['not supported!']) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + panic(array!['not supported!']) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/nn/implementations/nn_u32.cairo b/src/operators/nn/implementations/nn_u32.cairo index 1a2883a16..370880e8d 100644 --- a/src/operators/nn/implementations/nn_u32.cairo +++ b/src/operators/nn/implementations/nn_u32.cairo @@ -18,6 +18,10 @@ impl U32NN of NNTrait { panic(array!['not supported!']) } + fn softmax_zero(tensor: @Tensor, axis: usize) -> Tensor { + panic(array!['not supported!']) + } + fn logsoftmax(tensor: @Tensor, axis: usize) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d2afe3fc5..1f11e1f3f 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -472,7 +472,9 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..ce489a45a 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -520,10 +520,12 @@ impl FP16x16Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..1ea152edf 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -486,10 +486,12 @@ impl FP16x16WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..5e9b5512a 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -521,10 +521,12 @@ impl FP32x32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..2b5142789 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -522,10 +522,12 @@ impl FP64x64Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..18cd91af7 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -520,10 +520,12 @@ impl FP8x23Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..a1f9bcbdf 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -473,10 +473,12 @@ impl FP8x23WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..c8753c5a9 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -517,10 +517,12 @@ impl I32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..fa67a7950 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -515,10 +515,12 @@ impl I8Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..9a90ef883 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -458,10 +458,12 @@ impl U32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/math/gather_nd.cairo b/src/operators/tensor/math/gather_nd.cairo index 120eff5be..737a4fe32 100644 --- a/src/operators/tensor/math/gather_nd.cairo +++ b/src/operators/tensor/math/gather_nd.cairo @@ -14,12 +14,7 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; /// Cf: TensorTrait::gather_nd docstring -fn gather_nd< - T, - impl TTensorTrait: TensorTrait, - impl TCopy: Copy, - impl TDrop: Drop, ->( +fn gather_nd, impl TCopy: Copy, impl TDrop: Drop,>( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { let batch_dims = match batch_dims { @@ -29,19 +24,22 @@ fn gather_nd< let data_rank = (*self.shape).len(); let indices_rank = (indices.shape).len(); - assert((data_rank >= 1 ) & (indices_rank >= 1), 'rank must > 1'); - + assert((data_rank >= 1) & (indices_rank >= 1), 'rank must > 1'); + let mut data_shape = *self.shape; let mut indices_shape = indices.shape; let mut data_shape_clone = data_shape.clone(); let mut indices_shape_clone = indices_shape.clone(); let indices_shape_last = indices_shape_clone.pop_back().unwrap(); - assert((*indices_shape_last >= 1) & (*indices_shape_last <= data_rank-batch_dims), 'check indices'); + assert( + (*indices_shape_last >= 1) & (*indices_shape_last <= data_rank - batch_dims), + 'check indices' + ); let mut batch_dims_shape = ArrayTrait::new(); let mut output_shape = ArrayTrait::new(); - let mut index_data = ArrayTrait::new(); + let mut index_data = ArrayTrait::new(); let mut output_data = ArrayTrait::new(); let mut batch_dims_size = batch_dims; @@ -51,7 +49,7 @@ fn gather_nd< let mut ind = 0; loop { if (ind == batch_dims) { - break(); + break (); } match indices_shape_clone.pop_front() { Option::Some(val) => { @@ -65,17 +63,14 @@ fn gather_nd< loop { match indices_shape_clone.pop_front() { - Option::Some(val) => { - batch_dims_shape.append(*val); - }, + Option::Some(val) => { batch_dims_shape.append(*val); }, Option::None(_) => { break; } }; }; if (*indices_shape_last == data_rank - batch_dims) { output_shape = batch_dims_shape; - } - else { + } else { let mut ind = 0; let mut multiple = 1; output_shape = batch_dims_shape; @@ -136,16 +131,18 @@ fn gather_nd< match data_indices.pop_front() { Option::Some(val) => { let index = ind % *indices_shape_last; - let incr= total_data_len * (ind / breaker); + let incr = total_data_len * (ind / breaker); result += (*val * total_data_len / *multiple_data_len.at(index)); ind += 1; - if (index == *indices_shape_last-1) { - let mut data_ind:usize = result ; + if (index == *indices_shape_last - 1) { + let mut data_ind: usize = result; loop { - if data_ind == result + incrementer { break; } + if data_ind == result + incrementer { + break; + } index_data.append(data_ind + incr); - data_ind+=1; + data_ind += 1; }; result = 0; }; @@ -156,13 +153,11 @@ fn gather_nd< loop { match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, + Option::Some(val) => { output_data.append(*self.data[val]); }, Option::None(_) => { break; } }; }; let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} \ No newline at end of file +} diff --git a/tests/lib.cairo b/tests/lib.cairo index f5cecb77d..c408347ef 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,3 +5,4 @@ mod nodes; mod ml; mod operators; + diff --git a/tests/nodes.cairo b/tests/nodes.cairo index c7155e942..8cd33f470 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -850,3 +850,5 @@ mod gather_nd_i8_3d_batch_dims1; mod gather_nd_u32_default; mod gather_nd_u32_batch_dims1; mod gather_nd_u32_batch_dims2; +mod softmax_zero_fp16x16; +mod softmax_zero_fp8x23; diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo index d2c0b80dd..025cc8261 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo index 507847851..677a40f6a 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_default.cairo b/tests/nodes/gather_nd_fp16x16_3d_default.cairo index ae4609a66..b8339a0d2 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo index b9a083796..65980d91f 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo index 5e42ca893..48c812baf 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_default.cairo b/tests/nodes/gather_nd_fp8x23_3d_default.cairo index 12b6408e0..342cd2b72 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_default.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo index 243b0ca16..318ccd62e 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo index d11370b94..177c8e40f 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_default.cairo b/tests/nodes/gather_nd_i32_3d_default.cairo index 35c054093..97212f737 100644 --- a/tests/nodes/gather_nd_i32_3d_default.cairo +++ b/tests/nodes/gather_nd_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo index ae83a8c7d..f849c8677 100644 --- a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_default.cairo b/tests/nodes/gather_nd_i8_3d_default.cairo index 73e1d91b2..ff7ad9252 100644 --- a/tests/nodes/gather_nd_i8_3d_default.cairo +++ b/tests/nodes/gather_nd_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims1.cairo b/tests/nodes/gather_nd_u32_batch_dims1.cairo index 0428ec1d5..860675f66 100644 --- a/tests/nodes/gather_nd_u32_batch_dims1.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims1.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims2.cairo b/tests/nodes/gather_nd_u32_batch_dims2.cairo index 39857ef1d..f0662be99 100644 --- a/tests/nodes/gather_nd_u32_batch_dims2.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims2.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_default.cairo b/tests/nodes/gather_nd_u32_default.cairo index f55b49d5e..be6edd699 100644 --- a/tests/nodes/gather_nd_u32_default.cairo +++ b/tests/nodes/gather_nd_u32_default.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/softmax_zero_fp16x16.cairo b/tests/nodes/softmax_zero_fp16x16.cairo new file mode 100644 index 000000000..41e229944 --- /dev/null +++ b/tests/nodes/softmax_zero_fp16x16.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::numbers::FixedTrait; +use orion::operators::nn::FP16x16NN; +use orion::operators::nn::NNTrait; +use orion::operators::tensor::FP16x16TensorPartialEq; + +#[test] +#[available_gas(2000000000)] +fn test_softmax_zero_fp16x16() { + let input_0 = input_0::input_0(); + let z_0 = output_0::output_0(); + + let y_0 = NNTrait::softmax_zero(@input_0, 1); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/softmax_zero_fp16x16/input_0.cairo b/tests/nodes/softmax_zero_fp16x16/input_0.cairo new file mode 100644 index 000000000..fb0512473 --- /dev/null +++ b/tests/nodes/softmax_zero_fp16x16/input_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 2278, sign: true }); + data.append(FP16x16 { mag: 76483, sign: false }); + data.append(FP16x16 { mag: 23998, sign: false }); + data.append(FP16x16 { mag: 808, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/softmax_zero_fp16x16/output_0.cairo b/tests/nodes/softmax_zero_fp16x16/output_0.cairo new file mode 100644 index 000000000..f72d7363c --- /dev/null +++ b/tests/nodes/softmax_zero_fp16x16/output_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 15148, sign: false }); + data.append(FP16x16 { mag: 50387, sign: false }); + data.append(FP16x16 { mag: 38896, sign: false }); + data.append(FP16x16 { mag: 26639, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/softmax_zero_fp8x23.cairo b/tests/nodes/softmax_zero_fp8x23.cairo new file mode 100644 index 000000000..88e2eb3a1 --- /dev/null +++ b/tests/nodes/softmax_zero_fp8x23.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::numbers::FixedTrait; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::operators::nn::FP8x23NN; +use orion::operators::nn::NNTrait; + +#[test] +#[available_gas(2000000000)] +fn test_softmax_zero_fp8x23() { + let input_0 = input_0::input_0(); + let z_0 = output_0::output_0(); + + let y_0 = NNTrait::softmax_zero(@input_0, 1); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/softmax_zero_fp8x23/input_0.cairo b/tests/nodes/softmax_zero_fp8x23/input_0.cairo new file mode 100644 index 000000000..c74e57ae0 --- /dev/null +++ b/tests/nodes/softmax_zero_fp8x23/input_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 21210813, sign: true }); + data.append(FP8x23 { mag: 18026313, sign: true }); + data.append(FP8x23 { mag: 11180685, sign: false }); + data.append(FP8x23 { mag: 9192264, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/softmax_zero_fp8x23/output_0.cairo b/tests/nodes/softmax_zero_fp8x23/output_0.cairo new file mode 100644 index 000000000..5a73df92d --- /dev/null +++ b/tests/nodes/softmax_zero_fp8x23/output_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 3407604, sign: false }); + data.append(FP8x23 { mag: 4981003, sign: false }); + data.append(FP8x23 { mag: 4689094, sign: false }); + data.append(FP8x23 { mag: 3699513, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} From cb8de919bc3b0805e956bb15fd66a06e8bda6a0a Mon Sep 17 00:00:00 2001 From: chachaleo Date: Fri, 22 Dec 2023 16:25:01 +0100 Subject: [PATCH 11/38] feat.tensor-complex --- nodegen/helpers.py | 16 + nodegen/node/reduce_l2.py | 26 + src/numbers/complex_number/complex64.cairo | 7 +- src/operators/ml.cairo | 4 +- src/operators/tensor.cairo | 4 + src/operators/tensor/implementations.cairo | 1 + .../tensor/implementations/tensor_bool.cairo | 4 +- .../implementations/tensor_complex64.cairo | 615 ++++++++++++++++++ .../implementations/tensor_fp16x16.cairo | 6 +- .../implementations/tensor_fp16x16wide.cairo | 6 +- .../implementations/tensor_fp32x32.cairo | 6 +- .../implementations/tensor_fp64x64.cairo | 6 +- .../implementations/tensor_fp8x23.cairo | 6 +- .../implementations/tensor_fp8x23wide.cairo | 6 +- .../tensor/implementations/tensor_i32.cairo | 6 +- .../tensor/implementations/tensor_i8.cairo | 6 +- .../tensor/implementations/tensor_u32.cairo | 6 +- src/operators/tensor/math/acos.cairo | 9 +- src/operators/tensor/math/acosh.cairo | 8 +- src/operators/tensor/math/asin.cairo | 9 +- src/operators/tensor/math/asinh.cairo | 8 +- src/operators/tensor/math/atan.cairo | 8 +- src/operators/tensor/math/cos.cairo | 8 +- src/operators/tensor/math/cosh.cairo | 8 +- src/operators/tensor/math/exp.cairo | 8 +- src/operators/tensor/math/gather_nd.cairo | 45 +- src/operators/tensor/math/log.cairo | 8 +- src/operators/tensor/math/reduce_l2.cairo | 30 +- src/operators/tensor/math/sin.cairo | 8 +- src/operators/tensor/math/sinh.cairo | 8 +- src/operators/tensor/math/sqrt.cairo | 8 +- src/operators/tensor/math/tanh.cairo | 8 +- tests/lib.cairo | 1 + tests/nodes.cairo | 1 + .../gather_nd_fp16x16_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp16x16_3d_batch_dims2.cairo | 2 +- .../nodes/gather_nd_fp16x16_3d_default.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_fp8x23_3d_default.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims1.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_i32_3d_default.cairo | 2 +- tests/nodes/gather_nd_i8_3d_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_i8_3d_default.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_u32_default.cairo | 2 +- tests/nodes/reduce_l2_complex64_axis_0.cairo | 24 + .../reduce_l2_complex64_axis_0/input_0.cairo | 56 ++ .../reduce_l2_complex64_axis_0/output_0.cairo | 35 + 51 files changed, 939 insertions(+), 118 deletions(-) create mode 100644 src/operators/tensor/implementations/tensor_complex64.cairo create mode 100644 tests/nodes/reduce_l2_complex64_axis_0.cairo create mode 100644 tests/nodes/reduce_l2_complex64_axis_0/input_0.cairo create mode 100644 tests/nodes/reduce_l2_complex64_axis_0/output_0.cairo diff --git a/nodegen/helpers.py b/nodegen/helpers.py index cf876ccc0..07ae07ed0 100644 --- a/nodegen/helpers.py +++ b/nodegen/helpers.py @@ -10,6 +10,8 @@ class FixedImpl(Enum): FP8x23 = 'FP8x23' FP16x16 = 'FP16x16' + FP64x64 = 'FP64x64' + def to_fp(x: np.ndarray, fp_impl: FixedImpl): @@ -18,15 +20,19 @@ def to_fp(x: np.ndarray, fp_impl: FixedImpl): return (x * 2**23).astype(np.int64) case FixedImpl.FP16x16: return (x * 2**16).astype(np.int64) + case FixedImpl.FP64x64: + return (x * 2**64) class Dtype(Enum): FP8x23 = 'FP8x23' FP16x16 = 'FP16x16' + FP64x64 = 'FP64x64' I8 = 'i8' I32 = 'i32' U32 = 'u32' BOOL = 'bool' + COMPLEX64 = 'complex64' class Tensor: @@ -166,8 +172,15 @@ def get_data_statement(data: np.ndarray, dtype: Dtype) -> list[str]: return ["FP8x23 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] case Dtype.FP16x16: return ["FP16x16 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] + case Dtype.FP64x64: + return ["FP64x64 { "+f"mag: {abs(int(x))}, sign: {str(x < 0).lower()} "+"}" for x in data.flatten()] case Dtype.BOOL: return [str(x).lower() for x in data.flatten()] + case Dtype.COMPLEX64: + return ["complex64 { "+"real: FP64x64 { "+f"mag: {abs(int(np.real(x)))}, sign: {str(np.real(x) < 0).lower()} "+"} , img: FP64x64 { "+f"mag: {abs(int(np.imag(x)))}, sign: {str(np.imag(x) < 0).lower()} "+"} }" for x in data.flatten()] + + + def get_data_statement_for_sequences(data: Sequence, dtype: Dtype) -> list[list[str]]: @@ -227,6 +240,7 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]: Dtype.FP8x23: ["orion::operators::tensor::FP8x23Tensor",], Dtype.FP16x16: ["orion::operators::tensor::FP16x16Tensor",], Dtype.BOOL: ["orion::operators::tensor::BoolTensor",], + Dtype.COMPLEX64: ["orion::operators::tensor::Complex64Tensor",], } @@ -246,6 +260,7 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]: Dtype.FP8x23: ["orion::operators::tensor::FP8x23TensorPartialEq",], Dtype.FP16x16: ["orion::operators::tensor::FP16x16TensorPartialEq",], Dtype.BOOL: ["orion::operators::tensor::BoolTensorPartialEq",], + Dtype.COMPLEX64: ["orion::operators::tensor::Complex64TensorPartialEq",], } @@ -256,4 +271,5 @@ def find_all_types(tensors: list[Tensor | Sequence]) -> list[Dtype]: Dtype.FP8x23: ["orion::numbers::{FixedTrait, FP8x23}",], Dtype.FP16x16: ["orion::numbers::{FixedTrait, FP16x16}",], Dtype.BOOL: [], + Dtype.COMPLEX64: ["orion::numbers::{NumberTrait, complex64}",], } \ No newline at end of file diff --git a/nodegen/node/reduce_l2.py b/nodegen/node/reduce_l2.py index 081125f54..67c8cfa10 100644 --- a/nodegen/node/reduce_l2.py +++ b/nodegen/node/reduce_l2.py @@ -4,6 +4,7 @@ import numpy as np + class Reduce_l2(RunAll): @staticmethod def reduce_l2_fp8x23(): @@ -107,4 +108,29 @@ def reduce_l2_axis_0(): reduce_l2_export_do_not_keepdims() reduce_l2_export_keepdims() + reduce_l2_axis_0() + + @staticmethod + def reduce_l2_complex64(): + + + + def reduce_l2_axis_0(): + shape = [2, 3] + axes = np.array([0], dtype=np.int64) + keepdims = True + x = np.reshape(np.array([1.+2.j, 2.-1.j, 3.-3.j, 3.-2.j, 3.+5.j, 4.- 1.j]), shape) + y = np.sqrt(np.sum(a=np.square(abs(x)), axis=tuple(axes), keepdims=True)) + print(to_fp(x.flatten(), FixedImpl.FP64x64)) + + x = Tensor(Dtype.COMPLEX64, x.shape, to_fp( + x.flatten(), FixedImpl.FP64x64)) + + y = Tensor(Dtype.COMPLEX64, y.shape, to_fp( + y.flatten(), FixedImpl.FP64x64)) + + name = "reduce_l2_complex64_axis_0" + make_test( + [x], y, "input_0.reduce_l2(0, true)", name) + reduce_l2_axis_0() \ No newline at end of file diff --git a/src/numbers/complex_number/complex64.cairo b/src/numbers/complex_number/complex64.cairo index fe7b70581..c3b649c6d 100644 --- a/src/numbers/complex_number/complex64.cairo +++ b/src/numbers/complex_number/complex64.cairo @@ -73,7 +73,12 @@ impl Complex64Impl of ComplexTrait { let y = self.img; let two = FP64x64Impl::new(TWO, false); let real = (((x.pow(two) + y.pow(two)).sqrt() + x) / two).sqrt(); - let img = (((x.pow(two) + y.pow(two)).sqrt() - x) / two).sqrt(); + let img = if y == FP64x64Impl::ZERO() { + FP64x64Impl::ZERO() + } else { + (((x.pow(two) + y.pow(two)).sqrt() - x) / two).sqrt() + }; + let img = FP64x64Impl::new(img.mag, y.sign); complex64 { real, img } } diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 4bfd10060..93a490bbe 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -11,4 +11,6 @@ use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{ use orion::operators::ml::tree_ensemble::tree_ensemble_regressor::{ TreeEnsembleRegressor, TreeEnsembleRegressorImpl, TreeEnsembleRegressorTrait, AGGREGATE_FUNCTION }; -use orion::operators::ml::linear::linear_regressor::{LinearRegressorTrait, LinearRegressorImpl, LinearRegressor}; +use orion::operators::ml::linear::linear_regressor::{ + LinearRegressorTrait, LinearRegressorImpl, LinearRegressor +}; diff --git a/src/operators/tensor.cairo b/src/operators/tensor.cairo index c217cb1f6..adf2076b5 100644 --- a/src/operators/tensor.cairo +++ b/src/operators/tensor.cairo @@ -39,3 +39,7 @@ use orion::operators::tensor::implementations::tensor_u32::{ use orion::operators::tensor::implementations::tensor_bool::{BoolTensor, BoolTensorPartialEq}; +use orion::operators::tensor::implementations::tensor_complex64::{ + Complex64Tensor, Complex64TensorAdd, Complex64TensorSub, Complex64TensorMul, Complex64TensorDiv, + Complex64TensorPartialEq, +}; diff --git a/src/operators/tensor/implementations.cairo b/src/operators/tensor/implementations.cairo index 61ee08dd3..f9a7b406f 100644 --- a/src/operators/tensor/implementations.cairo +++ b/src/operators/tensor/implementations.cairo @@ -8,3 +8,4 @@ mod tensor_fp64x64; mod tensor_fp32x32; mod tensor_fp16x16wide; mod tensor_fp8x23wide; +mod tensor_complex64; diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d2afe3fc5..1f11e1f3f 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -472,7 +472,9 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } } diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo new file mode 100644 index 000000000..810f347ca --- /dev/null +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -0,0 +1,615 @@ +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::option::OptionTrait; +use core::traits::{TryInto, Into}; + +use orion::numbers::fixed_point::core::FixedTrait; +use orion::operators::tensor::core::{ + new_tensor, constant_of_shape, stride, Tensor, TensorTrait, ravel_index, unravel_index, reshape, + at_tensor, +}; +use orion::operators::tensor::{math, linalg, quantization, core as core_tensor, ml}; +use orion::numbers::{i8, i32, NumberTrait, FP64x64, FP64x64Impl}; +use orion::numbers::fixed_point::implementations::fp64x64::core::ONE; +use orion::operators::tensor::implementations::{ + tensor_i8::I8Tensor, tensor_u32::U32Tensor, tensor_bool::BoolTensor +}; +use orion::numbers::complex_number::complex_trait::ComplexTrait; +use orion::numbers::complex_number::complex64::{Complex64Impl, complex64}; + + +impl Complex64Tensor of TensorTrait { + fn new(shape: Span, data: Span) -> Tensor { + new_tensor(shape, data) + } + + fn constant_of_shape(shape: Span, value: complex64) -> Tensor { + constant_of_shape(shape, value) + } + + fn at(self: @Tensor, indices: Span) -> complex64 { + *at_tensor(self, indices) + } + + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } + + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } + + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } + + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } + + fn min_in_tensor(self: @Tensor) -> complex64 { + panic(array!['not supported!']) + } + + fn min(tensors: Span>) -> Tensor { + panic(array!['not supported!']) + } + + fn max_in_tensor(self: @Tensor) -> complex64 { + panic(array!['not supported!']) + } + + fn max(tensors: Span>) -> Tensor { + panic(array!['not supported!']) + } + + fn stride(self: @Tensor) -> Span { + stride(*self.shape) + } + + fn ravel_index(self: @Tensor, indices: Span) -> usize { + ravel_index(*self.shape, indices) + } + + fn unravel_index(self: @Tensor, index: usize) -> Span { + unravel_index(index, *self.shape) + } + + fn reshape(self: @Tensor, target_shape: Span) -> Tensor { + reshape(self, target_shape) + } + + fn reduce_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_sum::reduce_sum(self, axis, keepdims) + } + + fn reduce_prod(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_prod::reduce_prod(self, axis, keepdims) + } + + fn argmax( + self: @Tensor, + axis: usize, + keepdims: Option, + select_last_index: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn argmin( + self: @Tensor, + axis: usize, + keepdims: Option, + select_last_index: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn transpose(self: @Tensor, axes: Span) -> Tensor { + linalg::transpose::transpose(self, axes) + } + + fn matmul(self: @Tensor, other: @Tensor) -> Tensor { + linalg::matmul::matmul(self, other) + } + + fn exp(self: @Tensor) -> Tensor { + math::exp::exp(*self) + } + + fn log(self: @Tensor) -> Tensor { + math::log::log(*self) + } + + fn equal(self: @Tensor, other: @Tensor) -> Tensor { + math::equal::equal(self, other) + } + + fn greater(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn greater_equal(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn less(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn less_equal(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn abs(self: @Tensor) -> Tensor { + math::abs::abs(*self) + } + + fn neg(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn ceil(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn sin(self: @Tensor) -> Tensor { + math::sin::sin(*self) + } + + fn cos(self: @Tensor) -> Tensor { + math::cos::cos(*self) + } + + fn asin(self: @Tensor) -> Tensor { + math::asin::asin(*self) + } + + fn cumsum( + self: @Tensor, axis: usize, exclusive: Option, reverse: Option + ) -> Tensor { + math::cumsum::cumsum(self, axis, exclusive, reverse) + } + + fn flatten(self: @Tensor, axis: usize) -> Tensor { + math::flatten::flatten(self, axis) + } + + fn sinh(self: @Tensor) -> Tensor { + math::sinh::sinh(*self) + } + + fn tanh(self: @Tensor) -> Tensor { + math::tanh::tanh(*self) + } + + fn cosh(self: @Tensor) -> Tensor { + math::cosh::cosh(*self) + } + + fn acosh(self: @Tensor) -> Tensor { + math::acosh::acosh(*self) + } + + fn asinh(self: @Tensor) -> Tensor { + math::asinh::asinh(*self) + } + + fn atan(self: @Tensor) -> Tensor { + math::atan::atan(*self) + } + + fn xor(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn or(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn acos(self: @Tensor) -> Tensor { + math::acos::acos(*self) + } + + fn onehot( + self: @Tensor, depth: usize, axis: Option, values: Span + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sqrt(self: @Tensor) -> Tensor { + math::sqrt::sqrt(*self) + } + + fn concat(tensors: Span>, axis: usize,) -> Tensor { + math::concat::concat(tensors, axis) + } + + fn quantize_linear( + self: @Tensor, y_scale: @Tensor, y_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn dequantize_linear( + self: @Tensor, x_scale: @Tensor, x_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn qlinear_add( + self: @Tensor, + a_scale: @Tensor, + a_zero_point: @Tensor, + b: @Tensor, + b_scale: @Tensor, + b_zero_point: @Tensor, + y_scale: @Tensor, + y_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn qlinear_mul( + self: @Tensor, + a_scale: @Tensor, + a_zero_point: @Tensor, + b: @Tensor, + b_scale: @Tensor, + b_zero_point: @Tensor, + y_scale: @Tensor, + y_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn qlinear_matmul( + self: @Tensor, + a_scale: @Tensor, + a_zero_point: @Tensor, + b: @Tensor, + b_scale: @Tensor, + b_zero_point: @Tensor, + y_scale: @Tensor, + y_zero_point: @Tensor + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn qlinear_concat( + tensors: Span>, + scales: Span>, + zero_points: Span>, + y_scale: @Tensor, + y_zero_point: @Tensor, + axis: usize + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn qlinear_leakyrelu( + self: @Tensor, + a_scale: @Tensor, + a_zero_point: @Tensor, + alpha: complex64 + ) -> Tensor:: { + panic(array!['not supported!']) + } + + fn slice( + self: @Tensor, + starts: Span, + ends: Span, + axes: Option>, + steps: Option> + ) -> Tensor { + core_tensor::slice::(self, starts, ends, axes, steps) + } + + fn gather( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather::gather(self, indices, axis) + } + + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { + math::gather_nd::gather_nd(self, indices, batch_dims) + } + + fn nonzero(self: @Tensor) -> Tensor { + core_tensor::nonzero(self) + } + + fn squeeze(self: @Tensor, axes: Option>) -> Tensor { + core_tensor::squeeze(self, axes) + } + + fn unsqueeze(self: @Tensor, axes: Span) -> Tensor { + core_tensor::unsqueeze(self, axes) + } + + fn sign(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn clip( + self: @Tensor, min: Option, max: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn and(self: @Tensor, other: @Tensor) -> Tensor { + math::and::and(self, other) + } + + fn identity(self: @Tensor) -> Tensor { + core_tensor::identity(self) + } + + fn where( + self: @Tensor, x: @Tensor, y: @Tensor + ) -> Tensor { + panic(array!['not supported!']) + } + + fn bitwise_and(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn bitwise_xor(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn bitwise_or(self: @Tensor, other: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn round(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn reduce_l1(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_l1::reduce_l1(self, axis, keepdims) + } + + fn array_feature_extractor( + self: @Tensor, indices: Tensor + ) -> Tensor { + ml::array_feature_extractor::array_feature_extractor(*self, indices) + } + + fn binarizer(self: @Tensor, threshold: Option) -> Tensor { + panic(array!['not supported!']) + } + + fn reduce_sum_square( + self: @Tensor, axis: usize, keepdims: bool + ) -> Tensor { + math::reduce_sum_square::reduce_sum_square(self, axis, keepdims) + } + + fn reduce_l2(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_l2::reduce_l2_complex(self, axis, keepdims) + } + + fn trilu(self: @Tensor, upper: bool, k: i64) -> Tensor { + linalg::trilu::trilu(self, upper, k) + } + + fn scatter( + self: @Tensor, + updates: Tensor, + indices: Tensor, + axis: Option, + reduction: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn not(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + + fn gather_elements( + self: @Tensor, indices: Tensor, axis: Option + ) -> Tensor { + math::gather_elements::gather_elements(self, indices, axis) + } + + fn sequence_length(self: Array>) -> Tensor { + math::sequence_length::sequence_length(self) + } + + fn shrink( + self: Tensor, bias: Option, lambd: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + math::sequence_at::sequence_at(sequence, position) + } + + fn sequence_construct(tensors: Array>) -> Array> { + math::sequence_construct::sequence_construct(tensors) + } + + + fn sequence_empty() -> Array> { + math::sequence_empty::sequence_empty::() + } + + fn reduce_mean( + self: @Tensor, + axes: Option>, + keepdims: Option, + noop_with_empty_axes: Option + ) -> Tensor { + math::reduce_mean::reduce_mean(self, axes, keepdims, noop_with_empty_axes) + } + + fn reduce_min( + self: @Tensor, + axes: Option>, + keepdims: Option, + noop_with_empty_axes: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn pow(self: @Tensor, other: @Tensor) -> Tensor { + math::pow::pow(self, other) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + math::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + math::sequence_insert::sequence_insert(self, tensor, position) + } + + fn is_inf( + self: @Tensor, detect_negative: Option, detect_positive: Option + ) -> Tensor { + panic(array!['not supported!']) + } + + fn is_nan(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } + + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { + math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) + } + + + fn erf(self: @Tensor) -> Tensor { + panic(array!['not supported!']) + } + + fn unique( + self: @Tensor, axis: Option, sorted: Option + ) -> (Tensor, Tensor, Tensor, Tensor) { + panic(array!['not supported!']) + } +} + +/// Implements addition for `Tensor` using the `Add` trait. +impl Complex64TensorAdd of Add> { + /// Adds two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise addition. + fn add(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::add(@lhs, @rhs) + } +} + +/// Implements subtraction for `Tensor` using the `Sub` trait. +impl Complex64TensorSub of Sub> { + /// Subtracts two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise subtraction. + fn sub(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::sub(@lhs, @rhs) + } +} + +/// Implements multiplication for `Tensor` using the `Mul` trait. +impl Complex64TensorMul of Mul> { + /// Multiplies two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise multiplication. + fn mul(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::mul(@lhs, @rhs) + } +} + +/// Implements division for `Tensor` using the `Div` trait. +impl Complex64TensorDiv of Div> { + /// Divides two `Tensor` instances element-wise. + /// + /// # Arguments + /// * `lhs` - The first tensor. + /// * `rhs` - The second tensor. + /// + /// # Returns + /// * A `Tensor` instance representing the result of the element-wise division. + fn div(lhs: Tensor, rhs: Tensor) -> Tensor { + math::arithmetic::div(@lhs, @rhs) + } +} + +/// Implements partial equal for two `Tensor` using the `complex64` trait. +impl Complex64TensorPartialEq of PartialEq> { + fn eq(lhs: @Tensor, rhs: @Tensor) -> bool { + tensor_eq(*lhs, *rhs) + } + + fn ne(lhs: @Tensor, rhs: @Tensor) -> bool { + !tensor_eq(*lhs, *rhs) + } +} + + +// Internals + +fn eq(lhs: @complex64, rhs: @complex64) -> bool { + let eq = (*lhs.real == *rhs.real) && (*lhs.img == *rhs.img); + eq +} + +fn tensor_eq(mut lhs: Tensor, mut rhs: Tensor,) -> bool { + let mut is_eq = true; + + loop { + if lhs.shape.len() == 0 || !is_eq { + break; + } + + is_eq = lhs.shape.pop_front().unwrap() == rhs.shape.pop_front().unwrap(); + }; + + if !is_eq { + return false; + } + + loop { + if lhs.data.len() == 0 || !is_eq { + break; + } + + is_eq = eq(lhs.data.pop_front().unwrap(), rhs.data.pop_front().unwrap()); + }; + + return is_eq; +} + diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..ce489a45a 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -520,10 +520,12 @@ impl FP16x16Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..1ea152edf 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -486,10 +486,12 @@ impl FP16x16WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..5e9b5512a 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -521,10 +521,12 @@ impl FP32x32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..2b5142789 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -522,10 +522,12 @@ impl FP64x64Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..18cd91af7 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -520,10 +520,12 @@ impl FP8x23Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..a1f9bcbdf 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -473,10 +473,12 @@ impl FP8x23WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..c8753c5a9 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -517,10 +517,12 @@ impl I32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..fa67a7950 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -515,10 +515,12 @@ impl I8Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..9a90ef883 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -458,10 +458,12 @@ impl U32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/math/acos.cairo b/src/operators/tensor/math/acos.cairo index 1e6fbcfe8..c36260752 100644 --- a/src/operators/tensor/math/acos.cairo +++ b/src/operators/tensor/math/acos.cairo @@ -2,6 +2,7 @@ use core::array::ArrayTrait; use core::array::SpanTrait; use core::option::OptionTrait; +use orion::numbers::NumberTrait; use orion::numbers::fixed_point::core::FixedTrait; use orion::operators::tensor::core::{Tensor, TensorTrait}; @@ -9,10 +10,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn acos< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/acosh.cairo b/src/operators/tensor/math/acosh.cairo index 78649e620..f486d5609 100644 --- a/src/operators/tensor/math/acosh.cairo +++ b/src/operators/tensor/math/acosh.cairo @@ -11,10 +11,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn acosh< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/asin.cairo b/src/operators/tensor/math/asin.cairo index 7018edcc6..b33132797 100644 --- a/src/operators/tensor/math/asin.cairo +++ b/src/operators/tensor/math/asin.cairo @@ -2,6 +2,7 @@ use core::array::ArrayTrait; use core::array::SpanTrait; use core::option::OptionTrait; +use orion::numbers::NumberTrait; use orion::numbers::fixed_point::core::FixedTrait; use orion::operators::tensor::core::{Tensor, TensorTrait}; @@ -9,10 +10,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn asin< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/asinh.cairo b/src/operators/tensor/math/asinh.cairo index 28fcb467f..8d015554d 100644 --- a/src/operators/tensor/math/asinh.cairo +++ b/src/operators/tensor/math/asinh.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn asinh< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/atan.cairo b/src/operators/tensor/math/atan.cairo index 25e991700..b5f93eb1c 100644 --- a/src/operators/tensor/math/atan.cairo +++ b/src/operators/tensor/math/atan.cairo @@ -11,10 +11,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn atan< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/cos.cairo b/src/operators/tensor/math/cos.cairo index aad6ea925..5abeb327d 100644 --- a/src/operators/tensor/math/cos.cairo +++ b/src/operators/tensor/math/cos.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn cos< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/cosh.cairo b/src/operators/tensor/math/cosh.cairo index f3f15a284..08c434f5a 100644 --- a/src/operators/tensor/math/cosh.cairo +++ b/src/operators/tensor/math/cosh.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn cosh< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/exp.cairo b/src/operators/tensor/math/exp.cairo index a7e04dca1..889082d56 100644 --- a/src/operators/tensor/math/exp.cairo +++ b/src/operators/tensor/math/exp.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn exp< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/gather_nd.cairo b/src/operators/tensor/math/gather_nd.cairo index 120eff5be..737a4fe32 100644 --- a/src/operators/tensor/math/gather_nd.cairo +++ b/src/operators/tensor/math/gather_nd.cairo @@ -14,12 +14,7 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; /// Cf: TensorTrait::gather_nd docstring -fn gather_nd< - T, - impl TTensorTrait: TensorTrait, - impl TCopy: Copy, - impl TDrop: Drop, ->( +fn gather_nd, impl TCopy: Copy, impl TDrop: Drop,>( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { let batch_dims = match batch_dims { @@ -29,19 +24,22 @@ fn gather_nd< let data_rank = (*self.shape).len(); let indices_rank = (indices.shape).len(); - assert((data_rank >= 1 ) & (indices_rank >= 1), 'rank must > 1'); - + assert((data_rank >= 1) & (indices_rank >= 1), 'rank must > 1'); + let mut data_shape = *self.shape; let mut indices_shape = indices.shape; let mut data_shape_clone = data_shape.clone(); let mut indices_shape_clone = indices_shape.clone(); let indices_shape_last = indices_shape_clone.pop_back().unwrap(); - assert((*indices_shape_last >= 1) & (*indices_shape_last <= data_rank-batch_dims), 'check indices'); + assert( + (*indices_shape_last >= 1) & (*indices_shape_last <= data_rank - batch_dims), + 'check indices' + ); let mut batch_dims_shape = ArrayTrait::new(); let mut output_shape = ArrayTrait::new(); - let mut index_data = ArrayTrait::new(); + let mut index_data = ArrayTrait::new(); let mut output_data = ArrayTrait::new(); let mut batch_dims_size = batch_dims; @@ -51,7 +49,7 @@ fn gather_nd< let mut ind = 0; loop { if (ind == batch_dims) { - break(); + break (); } match indices_shape_clone.pop_front() { Option::Some(val) => { @@ -65,17 +63,14 @@ fn gather_nd< loop { match indices_shape_clone.pop_front() { - Option::Some(val) => { - batch_dims_shape.append(*val); - }, + Option::Some(val) => { batch_dims_shape.append(*val); }, Option::None(_) => { break; } }; }; if (*indices_shape_last == data_rank - batch_dims) { output_shape = batch_dims_shape; - } - else { + } else { let mut ind = 0; let mut multiple = 1; output_shape = batch_dims_shape; @@ -136,16 +131,18 @@ fn gather_nd< match data_indices.pop_front() { Option::Some(val) => { let index = ind % *indices_shape_last; - let incr= total_data_len * (ind / breaker); + let incr = total_data_len * (ind / breaker); result += (*val * total_data_len / *multiple_data_len.at(index)); ind += 1; - if (index == *indices_shape_last-1) { - let mut data_ind:usize = result ; + if (index == *indices_shape_last - 1) { + let mut data_ind: usize = result; loop { - if data_ind == result + incrementer { break; } + if data_ind == result + incrementer { + break; + } index_data.append(data_ind + incr); - data_ind+=1; + data_ind += 1; }; result = 0; }; @@ -156,13 +153,11 @@ fn gather_nd< loop { match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, + Option::Some(val) => { output_data.append(*self.data[val]); }, Option::None(_) => { break; } }; }; let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} \ No newline at end of file +} diff --git a/src/operators/tensor/math/log.cairo b/src/operators/tensor/math/log.cairo index 817ca135a..e55291fca 100644 --- a/src/operators/tensor/math/log.cairo +++ b/src/operators/tensor/math/log.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn log< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/reduce_l2.cairo b/src/operators/tensor/math/reduce_l2.cairo index b18fa704c..d03499fec 100644 --- a/src/operators/tensor/math/reduce_l2.cairo +++ b/src/operators/tensor/math/reduce_l2.cairo @@ -11,12 +11,11 @@ use orion::numbers::fixed_point::core::FixedTrait; fn square< T, MAG, - impl FTensorTrait: TensorTrait, - impl FFixed: FixedTrait, - impl FNumber: NumberTrait, + impl TTensorTrait: TensorTrait, + impl TNumber: NumberTrait, impl TMul: Mul, - impl FCopy: Copy, - impl FDrop: Drop, + impl TCopy: Copy, + impl TDrop: Drop, >( self: @Tensor ) -> Tensor { @@ -41,7 +40,6 @@ fn reduce_l2< T, MAG, impl TTensor: TensorTrait, - impl FFixed: FixedTrait, impl TNumber: NumberTrait, impl TMul: Mul, impl TCopy: Copy, @@ -53,3 +51,23 @@ fn reduce_l2< let tensor_square_sum = tensor_square.reduce_sum(axis: axis, keepdims: keepdims); return tensor_square_sum.sqrt(); } + +fn reduce_l2_complex< + T, + MAG, + impl TTensor: TensorTrait, + impl TNumber: NumberTrait, + impl TMul: Mul, + impl TCopy: Copy, + impl TDrop: Drop, + impl TPrint: PrintTrait +>( + self: @Tensor, axis: usize, keepdims: bool +) -> Tensor { + let mut tensor_square = square(@self.abs()); + + let mut tensor_square_sum = tensor_square.reduce_sum(axis: axis, keepdims: keepdims); + + return tensor_square_sum.sqrt(); +} + diff --git a/src/operators/tensor/math/sin.cairo b/src/operators/tensor/math/sin.cairo index b1471d77e..cc810eab7 100644 --- a/src/operators/tensor/math/sin.cairo +++ b/src/operators/tensor/math/sin.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn sin< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/sinh.cairo b/src/operators/tensor/math/sinh.cairo index 5d3ef828b..7c3373288 100644 --- a/src/operators/tensor/math/sinh.cairo +++ b/src/operators/tensor/math/sinh.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn sinh< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/sqrt.cairo b/src/operators/tensor/math/sqrt.cairo index 84d23c150..f3111bed9 100644 --- a/src/operators/tensor/math/sqrt.cairo +++ b/src/operators/tensor/math/sqrt.cairo @@ -11,10 +11,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn sqrt< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/tanh.cairo b/src/operators/tensor/math/tanh.cairo index b94fba485..f6f3eb6e1 100644 --- a/src/operators/tensor/math/tanh.cairo +++ b/src/operators/tensor/math/tanh.cairo @@ -12,10 +12,10 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; fn tanh< T, MAG, - impl FFixedTrait: FixedTrait, - impl FTensor: TensorTrait, - impl FCopy: Copy, - impl FDrop: Drop, + impl TNumberTrait: NumberTrait, + impl TTensor: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, >( mut self: Tensor ) -> Tensor { diff --git a/tests/lib.cairo b/tests/lib.cairo index f5cecb77d..c408347ef 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,3 +5,4 @@ mod nodes; mod ml; mod operators; + diff --git a/tests/nodes.cairo b/tests/nodes.cairo index c7155e942..57312ec2f 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -850,3 +850,4 @@ mod gather_nd_i8_3d_batch_dims1; mod gather_nd_u32_default; mod gather_nd_u32_batch_dims1; mod gather_nd_u32_batch_dims2; +mod reduce_l2_complex64_axis_0; diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo index d2c0b80dd..025cc8261 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo index 507847851..677a40f6a 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_default.cairo b/tests/nodes/gather_nd_fp16x16_3d_default.cairo index ae4609a66..b8339a0d2 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo index b9a083796..65980d91f 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo index 5e42ca893..48c812baf 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_default.cairo b/tests/nodes/gather_nd_fp8x23_3d_default.cairo index 12b6408e0..342cd2b72 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_default.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo index 243b0ca16..318ccd62e 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo index d11370b94..177c8e40f 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_default.cairo b/tests/nodes/gather_nd_i32_3d_default.cairo index 35c054093..97212f737 100644 --- a/tests/nodes/gather_nd_i32_3d_default.cairo +++ b/tests/nodes/gather_nd_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo index ae83a8c7d..f849c8677 100644 --- a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_default.cairo b/tests/nodes/gather_nd_i8_3d_default.cairo index 73e1d91b2..ff7ad9252 100644 --- a/tests/nodes/gather_nd_i8_3d_default.cairo +++ b/tests/nodes/gather_nd_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims1.cairo b/tests/nodes/gather_nd_u32_batch_dims1.cairo index 0428ec1d5..860675f66 100644 --- a/tests/nodes/gather_nd_u32_batch_dims1.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims1.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims2.cairo b/tests/nodes/gather_nd_u32_batch_dims2.cairo index 39857ef1d..f0662be99 100644 --- a/tests/nodes/gather_nd_u32_batch_dims2.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims2.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_default.cairo b/tests/nodes/gather_nd_u32_default.cairo index f55b49d5e..be6edd699 100644 --- a/tests/nodes/gather_nd_u32_default.cairo +++ b/tests/nodes/gather_nd_u32_default.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/reduce_l2_complex64_axis_0.cairo b/tests/nodes/reduce_l2_complex64_axis_0.cairo new file mode 100644 index 000000000..bb1391263 --- /dev/null +++ b/tests/nodes/reduce_l2_complex64_axis_0.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::Complex64Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::Complex64TensorPartialEq; + +use orion::numbers::complex_number::complex64::Complex64Print; +use orion::numbers::{NumberTrait, complex64}; + +#[test] +#[available_gas(2000000000)] +fn test_reduce_l2_complex64_axis_0() { + let input_0 = input_0::input_0(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.reduce_l2(0, true); + + assert_eq(y_0, z_0); +} + diff --git a/tests/nodes/reduce_l2_complex64_axis_0/input_0.cairo b/tests/nodes/reduce_l2_complex64_axis_0/input_0.cairo new file mode 100644 index 000000000..6d7281bfb --- /dev/null +++ b/tests/nodes/reduce_l2_complex64_axis_0/input_0.cairo @@ -0,0 +1,56 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::Complex64Tensor; +use orion::numbers::{NumberTrait, complex64}; +use orion::numbers::{FixedTrait, FP64x64}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data + .append( + complex64 { + real: FP64x64 { mag: 18446744073709551616, sign: false }, + img: FP64x64 { mag: 36893488147419103232, sign: false } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 36893488147419103232, sign: false }, + img: FP64x64 { mag: 18446744073709551616, sign: true } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 55340232221128654848, sign: false }, + img: FP64x64 { mag: 55340232221128654848, sign: true } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 55340232221128654848, sign: false }, + img: FP64x64 { mag: 36893488147419103232, sign: true } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 55340232221128654848, sign: false }, + img: FP64x64 { mag: 92233720368547758080, sign: false } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 73786976294838206464, sign: false }, + img: FP64x64 { mag: 18446744073709551616, sign: true } + } + ); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/reduce_l2_complex64_axis_0/output_0.cairo b/tests/nodes/reduce_l2_complex64_axis_0/output_0.cairo new file mode 100644 index 000000000..ee6564432 --- /dev/null +++ b/tests/nodes/reduce_l2_complex64_axis_0/output_0.cairo @@ -0,0 +1,35 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::Complex64Tensor; +use orion::numbers::{NumberTrait, complex64}; +use orion::numbers::{FixedTrait, FP64x64}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(3); + + let mut data = ArrayTrait::new(); + data + .append( + complex64 { + real: FP64x64 { mag: 78262906948318920704, sign: false }, + img: FP64x64 { mag: 0, sign: false } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 115199879809953955840, sign: false }, + img: FP64x64 { mag: 0, sign: false } + } + ); + data + .append( + complex64 { + real: FP64x64 { mag: 109132409670155108352, sign: false }, + img: FP64x64 { mag: 0, sign: false } + } + ); + TensorTrait::new(shape.span(), data.span()) +} From da060ad678a45861669803d346706b05a20359d7 Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Mon, 25 Dec 2023 15:00:51 +0800 Subject: [PATCH 12/38] fix: Fix exchange logic 1. Fix res object exchange logic in binary 2. Add binary testcase in TEC operator --- .../tree_ensemble_classifier.cairo | 11 +- tests/ml/tree_ensemble_classifier.cairo | 334 ++++++++++++++++++ 2 files changed, 339 insertions(+), 6 deletions(-) diff --git a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo index f882655e8..eb50a2e14 100644 --- a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo +++ b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo @@ -270,7 +270,6 @@ impl TreeEnsembleClassifierImpl< fn predict(ref self: TreeEnsembleClassifier, X: Tensor) -> (Span, MutMatrix::) { let leaves_index = self.ensemble.leave_index_tree(X); let n_classes = self.classlabels.len(); - assert(n_classes > 1, 'binary class not supported yet'); let mut res: MutMatrix = MutMatrixImpl::new(*leaves_index.shape.at(0), n_classes); // Set base values @@ -463,7 +462,7 @@ impl TreeEnsembleClassifierImpl< break; } // Exchange - let res_ele_0 = match res.get(i, 1) { + let res_ele_0 = match new_res.get(i, 1) { Option::Some(res_1) => { let value = NumberTrait::sub(NumberTrait::one(), res_1); new_res.set(i, 0, value); @@ -482,7 +481,7 @@ impl TreeEnsembleClassifierImpl< break; } // Exchange - let res_ele_0 = match res.get(i, 1) { + let res_ele_0 = match new_res.get(i, 1) { Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, @@ -500,7 +499,7 @@ impl TreeEnsembleClassifierImpl< break; } // Exchange - let res_ele_0 = match res.get(i, 1) { + let res_ele_0 = match new_res.get(i, 1) { Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, @@ -518,7 +517,7 @@ impl TreeEnsembleClassifierImpl< break; } // Exchange - let res_ele_0 = match res.get(i, 1) { + let res_ele_0 = match new_res.get(i, 1) { Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, @@ -536,7 +535,7 @@ impl TreeEnsembleClassifierImpl< break; } // Exchange - let res_ele_0 = match res.get(i, 1) { + let res_ele_0 = match new_res.get(i, 1) { Option::Some(res_1) => { let value = NumberTrait::sub(NumberTrait::one(), res_1); new_res.set(i, 0, value); diff --git a/tests/ml/tree_ensemble_classifier.cairo b/tests/ml/tree_ensemble_classifier.cairo index cee9ee80d..6ee2afc11 100644 --- a/tests/ml/tree_ensemble_classifier.cairo +++ b/tests/ml/tree_ensemble_classifier.cairo @@ -216,6 +216,116 @@ fn test_tree_ensemble_classifier_multi_pt_logistic() { ); } +#[test] +#[available_gas(200000000000)] +fn test_tree_ensemble_classifier_binary_none() { + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::NONE); + + let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(labels.len() == 1, 'len(labels)'); + + // ASSERT SCORES + assert( + relative_eq(@scores.get(0, 0).unwrap(), @FP16x16 { mag: 0, sign: false }) == true, + 'score[0, 0]' + ); + assert( + relative_eq(@scores.get(0, 1).unwrap(), @FP16x16 { mag: 65536, sign: false }) == true, + 'score[0, 1]' + ); +} + +#[test] +#[available_gas(200000000000)] +fn test_tree_ensemble_classifier_binary_logistic() { + + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::LOGISTIC); + + let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(labels.len() == 1, 'len(labels)'); + + // ASSERT SCORES + assert( + relative_eq(@scores.get(0, 0).unwrap(), @FP16x16 { mag: 17625, sign: false }) == true, + 'score[0, 0]' + ); + assert( + relative_eq(@scores.get(0, 1).unwrap(), @FP16x16 { mag: 47910, sign: false }) == true, + 'score[0, 1]' + ); +} + +#[test] +#[available_gas(200000000000)] +fn test_tree_ensemble_classifier_binary_softmax() { + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::SOFTMAX); + + let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(labels.len() == 1, 'len(labels)'); + + // ASSERT SCORES + assert( + relative_eq(@scores.get(0, 0).unwrap(), @FP16x16 { mag: 7812, sign: false }) == true, + 'score[0, 0]' + ); + assert( + relative_eq(@scores.get(0, 1).unwrap(), @FP16x16 { mag: 57723, sign: false }) == true, + 'score[0, 1]' + ); +} + +#[test] +#[available_gas(200000000000)] +fn test_tree_ensemble_classifier_binary_softmax_zero() { + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::SOFTMAXZERO); + + let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(labels.len() == 1, 'len(labels)'); + + // ASSERT SCORES + assert( + relative_eq(@scores.get(0, 0).unwrap(), @FP16x16 { mag: 7812, sign: false }) == true, + 'score[0, 0]' + ); + assert( + relative_eq(@scores.get(0, 1).unwrap(), @FP16x16 { mag: 57723, sign: false }) == true, + 'score[0, 1]' + ); +} + +// #[test] +// #[available_gas(200000000000)] +// fn test_tree_ensemble_classifier_binary_probit() { +// let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::PROBIT); + +// let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); + +// // ASSERT LABELS +// assert(*labels[0] == 1, 'labels[0]'); +// assert(labels.len() == 1, 'len(labels)'); + +// // ASSERT SCORES +// assert( +// relative_eq(@scores.get(0, 0).unwrap(), @FP16x16 { mag: 0, sign: false }) == true, +// 'score[0, 0]' +// ); +// assert( +// relative_eq(@scores.get(0, 1).unwrap(), @FP16x16 { mag: 65536, sign: false }) == true, +// 'score[0, 1]' +// ); +// } // ============ HELPER ============ // @@ -369,3 +479,227 @@ fn tree_ensemble_classifier_helper( (classifier, X) } + +// ============ BINARY CLASS HELPER ============ // + +fn tree_ensemble_classifier_binary_class_helper( + post_transform: POST_TRANSFORM +) -> (TreeEnsembleClassifier, Tensor) { + let class_ids: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); + let class_nodeids: Span = array![4, 5, 7, 10, 12, 13, 15, 17, 19, 20, 24, 26, 29, 31, 32, 33, 37, 38, 39, 40, 46, 49, 50, 52, 56, 57, 58, 59, 62, 64, 66, 67, 68, 73, 74, 75, 76, 81, 82, 83, 84, 88, 89, 91, 93, 94, 95, 98, 99, 101, 104, 106, 107, 108, 112, 113, 114, 115, 119, 121, 124, 125, 127, 128, 130, 131, 138, 140, 141, 142, 143, 148, 149, 150, 151, 152, 153, 154].span(); + let class_treeids: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); + let class_weights: Span = array![FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 43690, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }].span(); + let classlabels: Span = array![0, 1].span(); + let nodes_falsenodeids: Span = array![116, 21, 6, 5, 0, 0, 8, 0, 14, 11, 0, 13, 0, 0, 16, 0, 18, 0, 20, 0, 0, 41, 34, 25, 0, 27, 0, 33, 30, 0, 32, 0, 0, 0, 40, 39, 38, 0, 0, 0, 0, 109, 96, 69, 60, 47, 0, 51, 50, 0, 0, 53, 0, 59, 58, 57, 0, 0, 0, 0, 68, 63, 0, 65, 0, 67, 0, 0, 0, 77, 76, 75, 74, 0, 0, 0, 0, 85, 84, 83, 82, 0, 0, 0, 0, 95, 90, 89, 0, 0, 92, 0, 94, 0, 0, 0, 100, 99, 0, 0, 102, 0, 108, 105, 0, 107, 0, 0, 0, 115, 114, 113, 0, 0, 0, 0, 132, 129, 120, 0, 122, 0, 126, 125, 0, 0, 128, 0, 0, 131, 0, 0, 154, 153, 144, 143, 142, 139, 0, 141, 0, 0, 0, 0, 152, 151, 150, 149, 0, 0, 0, 0, 0, 0, 0].span(); + let nodes_featureids: Span = array![3, 2, 4, 8, 0, 0, 1, 0, 2, 7, 0, 0, 0, 0, 7, 0, 0, 0, 6, 0, 0, 8, 0, 2, 0, 7, 0, 7, 2, 0, 2, 0, 0, 0, 2, 6, 7, 0, 0, 0, 0, 7, 7, 0, 7, 1, 0, 0, 2, 0, 0, 2, 0, 2, 2, 6, 0, 0, 0, 0, 2, 0, 0, 1, 0, 6, 0, 0, 0, 0, 2, 6, 7, 0, 0, 0, 0, 6, 7, 2, 0, 0, 0, 0, 0, 2, 2, 7, 0, 0, 2, 0, 0, 0, 0, 0, 6, 1, 0, 0, 4, 0, 2, 2, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 6, 0, 7, 0, 0, 0, 1, 3, 0, 0, 2, 0, 0, 8, 0, 0, 2, 2, 2, 4, 7, 3, 0, 1, 0, 0, 0, 0, 4, 3, 7, 8, 0, 0, 0, 0, 0, 0, 0].span(); + let nodes_missing_value_tracks_true: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); + let nodes_modes: Span = array![NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF].span(); + let nodes_nodeids: Span = array![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154].span(); + let nodes_treeids: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); + let nodes_truenodeids: Span = array![1, 2, 3, 4, 0, 0, 7, 0, 9, 10, 0, 12, 0, 0, 15, 0, 17, 0, 19, 0, 0, 22, 23, 24, 0, 26, 0, 28, 29, 0, 31, 0, 0, 0, 35, 36, 37, 0, 0, 0, 0, 42, 43, 44, 45, 46, 0, 48, 49, 0, 0, 52, 0, 54, 55, 56, 0, 0, 0, 0, 61, 62, 0, 64, 0, 66, 0, 0, 0, 70, 71, 72, 73, 0, 0, 0, 0, 78, 79, 80, 81, 0, 0, 0, 0, 86, 87, 88, 0, 0, 91, 0, 93, 0, 0, 0, 97, 98, 0, 0, 101, 0, 103, 104, 0, 106, 0, 0, 0, 110, 111, 112, 0, 0, 0, 0, 117, 118, 119, 0, 121, 0, 123, 124, 0, 0, 127, 0, 0, 130, 0, 0, 133, 134, 135, 136, 137, 138, 0, 140, 0, 0, 0, 0, 145, 146, 147, 148, 0, 0, 0, 0, 0, 0, 0].span(); + let nodes_values: Span = array![FP16x16 { mag: 4096, sign: false }, FP16x16 { mag: 22937, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 16384, sign: false }, FP16x16 { mag: 57344, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 19660, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 24576, sign: false }, FP16x16 { mag: 42598, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 62259, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 62259, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 40960, sign: false }, FP16x16 { mag: 24576, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 19660, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 42598, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 19660, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 42598, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 16384, sign: false }, FP16x16 { mag: 20480, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }].span(); + let base_values: Option> = Option::None; + + let tree_ids: Span = array![0].span(); + let mut root_index: Felt252Dict = Default::default(); + root_index.insert(0, 0); + let mut node_index: Felt252Dict = Default::default(); + node_index.insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); + node_index.insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); + node_index.insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); + node_index.insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); + node_index.insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); + node_index.insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); + node_index.insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); + node_index.insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); + node_index.insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); + node_index.insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); + node_index.insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); + node_index.insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); + node_index.insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); + node_index.insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); + node_index.insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); + node_index.insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); + node_index.insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); + node_index.insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); + node_index.insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); + node_index.insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); + node_index.insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); + node_index.insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); + node_index.insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); + node_index.insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); + node_index.insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); + node_index.insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); + node_index.insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); + node_index.insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); + node_index.insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); + node_index.insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); + node_index.insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); + node_index.insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); + node_index.insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); + node_index.insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); + node_index.insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); + node_index.insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); + node_index.insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); + node_index.insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); + node_index.insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); + node_index.insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); + node_index.insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); + node_index.insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); + node_index.insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); + node_index.insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); + node_index.insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); + node_index.insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); + node_index.insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); + node_index.insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); + node_index.insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); + node_index.insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); + node_index.insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); + node_index.insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); + node_index.insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); + node_index.insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); + node_index.insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); + node_index.insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); + node_index.insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); + node_index.insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); + node_index.insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); + node_index.insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); + node_index.insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); + node_index.insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); + node_index.insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); + node_index.insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); + node_index.insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); + node_index.insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); + node_index.insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); + node_index.insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); + node_index.insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); + node_index.insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); + node_index.insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); + node_index.insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); + node_index.insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); + node_index.insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); + node_index.insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); + node_index.insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); + node_index.insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); + node_index.insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); + node_index.insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); + node_index.insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); + node_index.insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); + node_index.insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); + node_index.insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); + node_index.insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); + node_index.insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); + node_index.insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); + node_index.insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); + node_index.insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); + node_index.insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); + node_index.insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); + node_index.insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); + node_index.insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); + node_index.insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); + node_index.insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); + node_index.insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); + node_index.insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); + node_index.insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); + node_index.insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); + node_index.insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); + node_index.insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); + node_index.insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); + node_index.insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); + node_index.insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); + node_index.insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); + node_index.insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); + node_index.insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); + node_index.insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); + node_index.insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); + node_index.insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); + node_index.insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); + node_index.insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); + node_index.insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); + node_index.insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); + node_index.insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); + node_index.insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); + node_index.insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); + node_index.insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); + node_index.insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); + node_index.insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); + node_index.insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); + node_index.insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); + node_index.insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); + node_index.insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); + node_index.insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); + node_index.insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); + node_index.insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); + node_index.insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); + node_index.insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); + node_index.insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); + node_index.insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); + node_index.insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); + node_index.insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); + node_index.insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); + node_index.insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); + node_index.insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); + node_index.insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); + node_index.insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); + node_index.insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); + node_index.insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); + node_index.insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); + node_index.insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); + node_index.insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); + node_index.insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); + node_index.insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); + node_index.insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); + node_index.insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); + node_index.insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); + node_index.insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); + node_index.insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); + node_index.insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); + node_index.insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); + node_index.insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); + node_index.insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); + node_index.insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); + node_index.insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); + + let atts = TreeEnsembleAttributes { + nodes_falsenodeids, + nodes_featureids, + nodes_missing_value_tracks_true, + nodes_modes, + nodes_nodeids, + nodes_treeids, + nodes_truenodeids, + nodes_values + }; + + let mut ensemble: TreeEnsemble = TreeEnsemble { + atts, tree_ids, root_index, node_index + }; + + let mut classifier: TreeEnsembleClassifier = TreeEnsembleClassifier { + ensemble, + class_ids, + class_nodeids, + class_treeids, + class_weights, + classlabels, + base_values, + post_transform + }; + + let mut X = TensorTrait::new( + array![1,9].span(), + array![ + FP16x16 { mag: 39321, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 52428, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 0, sign: false }, + ].span() + ); + + (classifier, X) +} \ No newline at end of file From 2c93d3a206f4004f7a361918685d43916639cc09 Mon Sep 17 00:00:00 2001 From: Hakeem Kazeem Date: Mon, 25 Dec 2023 10:10:48 +0100 Subject: [PATCH 13/38] compress operator --- src/operators/tensor/core.cairo | 3 + .../tensor/implementations/tensor_bool.cairo | 4 + .../implementations/tensor_fp16x16.cairo | 4 + .../implementations/tensor_fp16x16wide.cairo | 4 + .../implementations/tensor_fp32x32.cairo | 4 + .../implementations/tensor_fp64x64.cairo | 4 + .../implementations/tensor_fp8x23.cairo | 4 + .../implementations/tensor_fp8x23wide.cairo | 4 + .../tensor/implementations/tensor_i32.cairo | 4 + .../tensor/implementations/tensor_i8.cairo | 4 + .../tensor/implementations/tensor_u32.cairo | 4 + src/operators/tensor/math.cairo | 1 + src/operators/tensor/math/compress.cairo | 239 +++ tests/nodes.cairo | 1704 ++++++++--------- 14 files changed, 1135 insertions(+), 852 deletions(-) create mode 100644 src/operators/tensor/math/compress.cairo diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index decc2e343..d3eef4158 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -5077,6 +5077,9 @@ trait TensorTrait { /// ``` /// fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor; + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor; + } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d2afe3fc5..c011b5f6c 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -475,6 +475,10 @@ impl BoolTensor of TensorTrait { fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..e9ac30c2b 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -537,6 +537,10 @@ impl FP16x16Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..53c60d745 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -503,6 +503,10 @@ impl FP16x16WTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..6cf2113c5 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -538,6 +538,10 @@ impl FP32x32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..313299072 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -539,6 +539,10 @@ impl FP64x64Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..22608e870 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -537,6 +537,10 @@ impl FP8x23Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..bebbc4075 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -490,6 +490,10 @@ impl FP8x23WTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..e8f7756e2 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -534,6 +534,10 @@ impl I32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..d6304ef27 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -532,6 +532,10 @@ impl I8Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..136dfe2c9 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -475,6 +475,10 @@ impl U32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, indices, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index e9822a21f..f64669ec3 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -65,3 +65,4 @@ mod is_inf; mod gather_nd; mod reduce_log_sum; mod erf; +mod compress; \ No newline at end of file diff --git a/src/operators/tensor/math/compress.cairo b/src/operators/tensor/math/compress.cairo new file mode 100644 index 000000000..0f6a3d9b6 --- /dev/null +++ b/src/operators/tensor/math/compress.cairo @@ -0,0 +1,239 @@ +use alexandria_data_structures::array_ext::SpanTraitExt; +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::option::OptionTrait; + +use core::traits::Into; +use core::debug::PrintTrait; +use core::traits::TryInto; +use core::serde::Serde; +use core::traits::Destruct; + +use orion::numbers::NumberTrait; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + +/// Cf: TensorTrait::gather_nd docstring +fn compress< + T, + impl TTensorTrait: TensorTrait, + impl TCopy: Copy, + impl TDrop: Drop, +>( + self: @Tensor, condition: Tensor, axis: Option +) -> Tensor { + let axis = match axis { + Option::Some(val) => val, + Option::None(_) => 999 + }; + + let data_rank = (*self.shape).len(); + let condition_rank = (condition.shape).len(); + assert((data_rank >= 1 ), 'data rank must > 1'); + assert((condition_rank == 1), 'condition rank must be 1'); + + + let mut data_shape = *self.shape; + let mut condition_shape = condition.shape; + // let mut data_shape_clone = data_shape.clone(); + // let mut condition_shape_clone = condition_shape.clone(); + assert(*data_shape.at(axis) >= condition.data.len(), 'index out of bound'); + + let mut output_shape = ArrayTrait::new(); + let mut index_data = ArrayTrait::new(); + let mut output_data = ArrayTrait::new(); + + let mut data = *self.data; + let mut condition_data = condition.data; + + let mut ind = 0; + let mut condition_data_clone = condition_data.clone(); + let mut output = 0; + loop { + match condition_data_clone.pop_front() { + Option::Some(val) => { + if (*val != 0) { + output += 1; + } + ind += 1; + }, + Option::None(_) => { break; } + }; + }; + + let mut ind = 0; + let mut loop_breaker = 1; + let mut other_loop_breaker = 1; + let mut multiplier = 1; + + let mut data_shape_clone = data_shape.clone(); + loop { + match data_shape_clone.pop_front() { + Option::Some(val) => { + if (ind == axis) { + output_shape.append(output); + } + else { + output_shape.append(*val); + if (ind > axis) { + loop_breaker *= *val; + } + if (ind >= axis) { + multiplier *= *val; + } + if (ind < axis) { + other_loop_breaker *= *val; + } + } + ind += 1; + }, + Option::None(_) => { break; } + }; + }; + + let mut ind = 0; + let mut inner_index: usize = 0; + let mut condition_data_clone = condition_data.clone(); + + loop { + if (ind == other_loop_breaker) {break;} + + let mut condition_data_clone = condition_data.clone(); + + loop { + match condition_data_clone.pop_front() { + Option::Some(val) => { + if (*val != 0){ + let result = inner_index * loop_breaker ; + // + multiplier * ind + // 'Start'.print(); + // (inner_index).print(); + // (loop_breaker).print(); + // (multiplier).print(); + // (ind).print(); + // (result).print(); + + + let mut data_ind:usize = result ; + loop { + if data_ind == result + loop_breaker { break; } + index_data.append(data_ind); + data_ind+=1; + }; + } + inner_index += 1; + }, + + Option::None(_) => { break; } + }; + }; + + ind += 1; + }; + + loop { + match index_data.pop_front() { + Option::Some(val) => { + output_data.append(*self.data[val]); + }, + Option::None(_) => { break; } + }; + }; + + let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); + return output_tensor; +} + +// Tests-------------------------------------------------------------------------------------------------------------- + +use orion::utils::assert_eq; + +fn indices() -> Tensor { + let mut sizes = ArrayTrait::new(); + sizes.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + + let tensor = TensorTrait::::new(sizes.span(), data.span()); + return tensor; + +} + +fn data() -> Tensor { + let mut sizes = ArrayTrait::new(); + sizes.append(2); + sizes.append(2); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + + let tensor = TensorTrait::::new(sizes.span(), data.span()); + return tensor; +} + +fn data1() -> Tensor { + let mut sizes = ArrayTrait::new(); + sizes.append(3); + sizes.append(3); + sizes.append(3); + + let mut data = ArrayTrait::new(); + + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + + let tensor = TensorTrait::::new(sizes.span(), data.span()); + return tensor; +} + +#[test] +#[available_gas(20000000000)] +fn test_gather_elements_default() { + let data = data1(); + let indices = indices(); + + let y = data.compress(indices: indices, axis:Option::Some(0)); + let mut output = y.data; + + loop { + match output.pop_front() { + Option::Some(val) => { + (*val).print(); + + }, + Option::None(_) => { break; } + }; + }; + +} \ No newline at end of file diff --git a/tests/nodes.cairo b/tests/nodes.cairo index c7155e942..48ca8e25f 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -1,852 +1,852 @@ -mod abs_fp16x16; -mod abs_fp8x23; -mod abs_i32; -mod abs_i8; -mod acos_fp16x16; -mod acos_fp8x23; -mod acosh_fp16x16; -mod acosh_fp8x23; -mod add_fp16x16; -mod add_fp16x16_broadcast; -mod add_fp8x23; -mod add_fp8x23_broadcast; -mod add_i32; -mod add_i32_broadcast; -mod add_i8; -mod add_i8_broadcast; -mod add_u32; -mod add_u32_broadcast; -mod argmax_fp16x16_1D_default; -mod argmax_fp16x16_1D_keepdims_false; -mod argmax_fp16x16_1D_last_index; -mod argmax_fp16x16_2D_default; -mod argmax_fp16x16_2D_keepdims_false; -mod argmax_fp16x16_2D_last_index; -mod argmax_fp16x16_3D_default; -mod argmax_fp16x16_3D_keepdims_false; -mod argmax_fp16x16_3D_last_index; -mod argmax_fp8x23_1D_default; -mod argmax_fp8x23_1D_keepdims_false; -mod argmax_fp8x23_1D_last_index; -mod argmax_fp8x23_2D_default; -mod argmax_fp8x23_2D_keepdims_false; -mod argmax_fp8x23_2D_last_index; -mod argmax_fp8x23_3D_default; -mod argmax_fp8x23_3D_keepdims_false; -mod argmax_fp8x23_3D_last_index; -mod argmax_i32_1D_default; -mod argmax_i32_1D_keepdims_false; -mod argmax_i32_1D_last_index; -mod argmax_i32_2D_default; -mod argmax_i32_2D_keepdims_false; -mod argmax_i32_2D_last_index; -mod argmax_i32_3D_default; -mod argmax_i32_3D_keepdims_false; -mod argmax_i32_3D_last_index; -mod argmax_i8_1D_default; -mod argmax_i8_1D_keepdims_false; -mod argmax_i8_1D_last_index; -mod argmax_i8_2D_default; -mod argmax_i8_2D_keepdims_false; -mod argmax_i8_2D_last_index; -mod argmax_i8_3D_default; -mod argmax_i8_3D_keepdims_false; -mod argmax_i8_3D_last_index; -mod argmax_u32_1D_default; -mod argmax_u32_1D_keepdims_false; -mod argmax_u32_1D_last_index; -mod argmax_u32_2D_default; -mod argmax_u32_2D_keepdims_false; -mod argmax_u32_2D_last_index; -mod argmax_u32_3D_default; -mod argmax_u32_3D_keepdims_false; -mod argmax_u32_3D_last_index; -mod argmin_fp16x16_1D_default; -mod argmin_fp16x16_1D_keepdims_false; -mod argmin_fp16x16_1D_last_index; -mod argmin_fp16x16_2D_default; -mod argmin_fp16x16_2D_keepdims_false; -mod argmin_fp16x16_2D_last_index; -mod argmin_fp16x16_3D_default; -mod argmin_fp16x16_3D_keepdims_false; -mod argmin_fp16x16_3D_last_index; -mod argmin_fp8x23_1D_default; -mod argmin_fp8x23_1D_keepdims_false; -mod argmin_fp8x23_1D_last_index; -mod argmin_fp8x23_2D_default; -mod argmin_fp8x23_2D_keepdims_false; -mod argmin_fp8x23_2D_last_index; -mod argmin_fp8x23_3D_default; -mod argmin_fp8x23_3D_keepdims_false; -mod argmin_fp8x23_3D_last_index; -mod argmin_i32_1D_default; -mod argmin_i32_1D_keepdims_false; -mod argmin_i32_1D_last_index; -mod argmin_i32_2D_default; -mod argmin_i32_2D_keepdims_false; -mod argmin_i32_2D_last_index; -mod argmin_i32_3D_default; -mod argmin_i32_3D_keepdims_false; -mod argmin_i32_3D_last_index; -mod argmin_i8_1D_default; -mod argmin_i8_1D_keepdims_false; -mod argmin_i8_1D_last_index; -mod argmin_i8_2D_default; -mod argmin_i8_2D_keepdims_false; -mod argmin_i8_2D_last_index; -mod argmin_i8_3D_default; -mod argmin_i8_3D_keepdims_false; -mod argmin_i8_3D_last_index; -mod argmin_u32_1D_default; -mod argmin_u32_1D_keepdims_false; -mod argmin_u32_1D_last_index; -mod argmin_u32_2D_default; -mod argmin_u32_2D_keepdims_false; -mod argmin_u32_2D_last_index; -mod argmin_u32_3D_default; -mod argmin_u32_3D_keepdims_false; -mod argmin_u32_3D_last_index; -mod asin_fp16x16; -mod asin_fp8x23; -mod asinh_fp16x16; -mod asinh_fp8x23; -mod atan_fp16x16; -mod atan_fp8x23; -mod ceil_fp16x16; -mod ceil_fp8x23; -mod concat_fp16x16_1d; -mod concat_fp16x16_2d; -mod concat_fp16x16_3d_default; -mod concat_fp16x16_3d_axis_1; -mod concat_fp16x16_3d_axis_2; -mod concat_fp16x16_3d_three_tensors_axis_1; -mod concat_fp16x16_3d_three_tensors_axis_2; -mod concat_fp8x23_1d; -mod concat_fp8x23_2d; -mod concat_fp8x23_3d_default; -mod concat_fp8x23_3d_axis_1; -mod concat_fp8x23_3d_axis_2; -mod concat_fp8x23_3d_three_tensors_axis_1; -mod concat_fp8x23_3d_three_tensors_axis_2; -mod concat_i32_1d; -mod concat_i32_2d; -mod concat_i32_3d_default; -mod concat_i32_3d_axis_1; -mod concat_i32_3d_axis_2; -mod concat_i32_3d_three_tensors_axis_1; -mod concat_i32_3d_three_tensors_axis_2; -mod concat_i8_1d; -mod concat_i8_2d; -mod concat_i8_3d_default; -mod concat_i8_3d_axis_1; -mod concat_i8_3d_axis_2; -mod concat_i8_3d_three_tensors_axis_1; -mod concat_i8_3d_three_tensors_axis_2; -mod concat_u32_1d; -mod concat_u32_2d; -mod concat_u32_3d_default; -mod concat_u32_3d_axis_1; -mod concat_u32_3d_axis_2; -mod concat_u32_3d_three_tensors_axis_1; -mod concat_u32_3d_three_tensors_axis_2; -mod cos_fp16x16; -mod cos_fp8x23; -mod cosh_fp16x16; -mod cosh_fp8x23; -mod cumsum_fp16x16_1d_default; -mod cumsum_fp16x16_1d_exclusive; -mod cumsum_fp16x16_1d_reverse; -mod cumsum_fp16x16_1d_reverse_exclusive; -mod cumsum_fp16x16_2d_axis_0; -mod cumsum_fp16x16_2d_axis_1; -mod cumsum_fp8x23_1d_default; -mod cumsum_fp8x23_1d_exclusive; -mod cumsum_fp8x23_1d_reverse; -mod cumsum_fp8x23_1d_reverse_exclusive; -mod cumsum_fp8x23_2d_axis_0; -mod cumsum_fp8x23_2d_axis_1; -mod cumsum_i32_1d_default; -mod cumsum_i32_1d_exclusive; -mod cumsum_i32_1d_reverse; -mod cumsum_i32_1d_reverse_exclusive; -mod cumsum_i32_2d_axis_0; -mod cumsum_i32_2d_axis_1; -mod cumsum_i8_1d_default; -mod cumsum_i8_1d_exclusive; -mod cumsum_i8_1d_reverse; -mod cumsum_i8_1d_reverse_exclusive; -mod cumsum_i8_2d_axis_0; -mod cumsum_i8_2d_axis_1; -mod cumsum_u32_1d_default; -mod cumsum_u32_1d_exclusive; -mod cumsum_u32_1d_reverse; -mod cumsum_u32_1d_reverse_exclusive; -mod cumsum_u32_2d_axis_0; -mod cumsum_u32_2d_axis_1; -mod div_fp16x16; -mod div_fp16x16_broadcast; -mod div_fp8x23; -mod div_fp8x23_broadcast; -mod div_i32; -mod div_i32_broadcast; -mod div_i8; -mod div_i8_broadcast; -mod div_u32; -mod div_u32_broadcast; -mod equal_fp16x16; -mod equal_fp16x16_broadcast; -mod equal_fp8x23; -mod equal_fp8x23_broadcast; -mod equal_i32; -mod equal_i32_broadcast; -mod equal_i8; -mod equal_i8_broadcast; -mod equal_u32; -mod equal_u32_broadcast; -mod exp_fp16x16; -mod exp_fp8x23; -mod less_equal_fp16x16; -mod less_equal_fp16x16_broadcast; -mod less_equal_fp8x23; -mod less_equal_fp8x23_broadcast; -mod less_equal_i32; -mod less_equal_i32_broadcast; -mod less_equal_i8; -mod less_equal_i8_broadcast; -mod less_equal_u32; -mod less_equal_u32_broadcast; -mod greater_fp16x16; -mod greater_fp16x16_broadcast; -mod greater_fp8x23; -mod greater_fp8x23_broadcast; -mod greater_i32; -mod greater_i32_broadcast; -mod greater_i8; -mod greater_i8_broadcast; -mod greater_u32; -mod greater_u32_broadcast; -mod leaky_relu_fp16x16; -mod leaky_relu_fp8x23; -mod linear_fp16x16; -mod linear_fp8x23; -mod linear_i32; -mod linear_i8; -mod linear_u32; -mod log_fp16x16; -mod log_fp8x23; -mod logsoftmax_fp16x16_axis_0; -mod logsoftmax_fp16x16_axis_1; -mod logsoftmax_fp8x23_axis_0; -mod logsoftmax_fp8x23_axis_1; -mod matmul_fp16x16_1d; -mod matmul_fp16x16_2x2; -mod matmul_fp16x16_2x1; -mod matmul_fp16x16_1x2; -mod matmul_fp8x23_1d; -mod matmul_fp8x23_2x2; -mod matmul_fp8x23_2x1; -mod matmul_fp8x23_1x2; -mod matmul_i32_1d; -mod matmul_i32_2x2; -mod matmul_i32_2x1; -mod matmul_i32_1x2; -mod matmul_i8_1d; -mod matmul_i8_2x2; -mod matmul_i8_2x1; -mod matmul_i8_1x2; -mod matmul_u32_1d; -mod matmul_u32_2x2; -mod matmul_u32_2x1; -mod matmul_u32_1x2; -mod mul_fp16x16; -mod mul_fp16x16_broadcast; -mod mul_fp8x23; -mod mul_fp8x23_broadcast; -mod mul_i32; -mod mul_i32_broadcast; -mod mul_i8; -mod mul_i8_broadcast; -mod mul_u32; -mod mul_u32_broadcast; -mod or_fp16x16; -mod or_fp16x16_broadcast; -mod or_fp8x23; -mod or_fp8x23_broadcast; -mod or_i32; -mod or_i32_broadcast; -mod or_i8; -mod or_i8_broadcast; -mod or_u32; -mod or_u32_broadcast; -mod reduce_sum_fp16x16_1D; -mod reduce_sum_fp16x16_2D_default; -mod reduce_sum_fp16x16_2D_keepdims; -mod reduce_sum_fp16x16_2D_axis_1; -mod reduce_sum_fp8x23_1D; -mod reduce_sum_fp8x23_2D_default; -mod reduce_sum_fp8x23_2D_keepdims; -mod reduce_sum_fp8x23_2D_axis_1; -mod reduce_sum_i32_1D; -mod reduce_sum_i32_2D_default; -mod reduce_sum_i32_2D_keepdims; -mod reduce_sum_i32_2D_axis_1; -mod reduce_sum_i8_1D; -mod reduce_sum_i8_2D_default; -mod reduce_sum_i8_2D_keepdims; -mod reduce_sum_i8_2D_axis_1; -mod reduce_sum_u32_1D; -mod reduce_sum_u32_2D_default; -mod reduce_sum_u32_2D_keepdims; -mod reduce_sum_u32_2D_axis_1; -mod relu_fp16x16; -mod relu_fp8x23; -mod relu_i32; -mod relu_i8; -mod sigmoid_fp16x16; -mod sigmoid_fp8x23; -mod sin_fp16x16; -mod sin_fp8x23; -mod sinh_fp16x16; -mod sinh_fp8x23; -mod softmax_fp16x16; -mod softmax_fp8x23; -mod softplus_fp8x23; -mod softplus_fp16x16; -mod softsign_fp8x23; -mod softsign_fp16x16; -mod sqrt_fp16x16; -mod sqrt_fp8x23; -mod sub_fp16x16; -mod sub_fp16x16_broadcast; -mod sub_fp8x23; -mod sub_fp8x23_broadcast; -mod sub_i32; -mod sub_i32_broadcast; -mod sub_i8; -mod sub_i8_broadcast; -mod sub_u32; -mod sub_u32_broadcast; -mod tanh_fp16x16; -mod tanh_fp8x23; -mod transpose_fp16x16_2d; -mod transpose_fp16x16_3d; -mod transpose_fp8x23_2d; -mod transpose_fp8x23_3d; -mod transpose_i32_2d; -mod transpose_i32_3d; -mod transpose_i8_2d; -mod transpose_i8_3d; -mod transpose_u32_2d; -mod transpose_u32_3d; -mod xor_fp16x16; -mod xor_fp16x16_broadcast; -mod xor_fp8x23; -mod xor_fp8x23_broadcast; -mod xor_i32; -mod xor_i32_broadcast; -mod xor_i8; -mod xor_i8_broadcast; -mod xor_u32; -mod xor_u32_broadcast; -mod less_fp16x16; -mod less_fp16x16_broadcast; -mod less_fp8x23; -mod less_fp8x23_broadcast; -mod less_i32; -mod less_i32_broadcast; -mod less_i8; -mod less_i8_broadcast; -mod less_u32; -mod less_u32_broadcast; -mod greater_equal_fp16x16; -mod greater_equal_fp16x16_broadcast; -mod greater_equal_fp8x23; -mod greater_equal_fp8x23_broadcast; -mod greater_equal_i32; -mod greater_equal_i32_broadcast; -mod greater_equal_i8; -mod greater_equal_i8_broadcast; -mod greater_equal_u32; -mod greater_equal_u32_broadcast; -mod slice_fp16x16_2d; -mod slice_fp16x16_3d; -mod slice_fp8x23_2d; -mod slice_fp8x23_3d; -mod slice_i32_2d; -mod slice_i32_3d; -mod slice_i8_2d; -mod slice_i8_3d; -mod slice_u32_2d; -mod slice_u32_3d; -mod gather_fp8x23_3d_default; -mod gather_fp8x23_3d_axis1; -mod gather_fp8x23_3d_axis2; -mod gather_fp16x16_3d_default; -mod gather_fp16x16_3d_axis1; -mod gather_fp16x16_3d_axis2; -mod gather_i8_3d_default; -mod gather_i8_3d_axis1; -mod gather_i8_3d_axis2; -mod gather_i32_3d_default; -mod gather_i32_3d_axis1; -mod gather_i32_3d_axis2; -mod gather_u32_3d_default; -mod gather_u32_3d_axis1; -mod gather_u32_3d_axis2; -mod nonzero_fp16x16_2d; -mod nonzero_fp16x16_3d; -mod nonzero_fp8x23_2d; -mod nonzero_fp8x23_3d; -mod nonzero_i32_2d; -mod nonzero_i32_3d; -mod nonzero_i8_2d; -mod nonzero_i8_3d; -mod nonzero_u32_2d; -mod nonzero_u32_3d; -mod squeeze_fP16x16; -mod squeeze_fP8x23; -mod squeeze_i32; -mod squeeze_i8; -mod squeeze_u32; -mod unsqueeze_fp16x16_2d; -mod unsqueeze_fp16x16_3d; -mod unsqueeze_fp8x23_2d; -mod unsqueeze_fp8x23_3d; -mod unsqueeze_i32_2d; -mod unsqueeze_i32_3d; -mod unsqueeze_i8_2d; -mod unsqueeze_i8_3d; -mod unsqueeze_u32_2d; -mod unsqueeze_u32_3d; -mod sign_fP16x16; -mod sign_fP8x23; -mod sign_fail; -mod sign_i32; -mod sign_i8; -mod clip_fp16x16_2d; -mod clip_fp16x16_3d; -mod clip_fp8x23_2d; -mod clip_fp8x23_3d; -mod clip_i32_2d; -mod clip_i32_3d; -mod clip_i8_2d; -mod clip_i8_3d; -mod clip_u32_2d; -mod clip_u32_3d; -mod identity_fP16x16; -mod identity_fP8x23; -mod identity_i32; -mod identity_i8; -mod identity_u32; -mod thresholded_relu_fp16x16; -mod thresholded_relu_fp8x23; -mod hard_sigmoid_fp8x23; -mod hard_sigmoid_fp16x16; -mod neg_fp16x16; -mod neg_fp8x23; -mod neg_i32; -mod neg_i8; -mod gemm_all_attributes; -mod gemm_alpha; -mod gemm_beta; -mod gemm_default_matrix_bias; -mod gemm_default_vector_bias; -mod gemm_default_no_bias; -mod gemm_transposeA; -mod gemm_transposeB; -mod min_fp16x16_three_tensors; -mod min_fp16x16_broadcast_three_tensors; -mod min_fp16x16_two_tensors; -mod min_fp16x16_broadcast_two_tensors; -mod min_fp8x23_three_tensors; -mod min_fp8x23_broadcast_three_tensors; -mod min_fp8x23_two_tensors; -mod min_fp8x23_broadcast_two_tensors; -mod min_i32_three_tensors; -mod min_i32_broadcast_three_tensors; -mod min_i32_two_tensors; -mod min_i32_broadcast_two_tensors; -mod min_i8_three_tensors; -mod min_i8_broadcast_three_tensors; -mod min_i8_two_tensors; -mod min_i8_broadcast_two_tensors; -mod min_u32_three_tensors; -mod min_u32_broadcast_three_tensors; -mod min_u32_two_tensors; -mod min_u32_broadcast_two_tensors; -mod where_fp16x16; -mod where_fp16x16_broadcast; -mod where_fp8x23; -mod where_fp8x23_broadcast; -mod where_i32; -mod where_i32_broadcast; -mod where_i8; -mod where_i8_broadcast; -mod where_u32; -mod where_u32_broadcast; -mod not_bool; -mod round_fp16x16; -mod round_fp8x23; -mod max_fp16x16_three_tensors; -mod max_fp16x16_broadcast_three_tensors; -mod max_fp16x16_two_tensors; -mod max_fp16x16_broadcast_two_tensors; -mod max_fp8x23_three_tensors; -mod max_fp8x23_broadcast_three_tensors; -mod max_fp8x23_two_tensors; -mod max_fp8x23_broadcast_two_tensors; -mod max_i32_three_tensors; -mod max_i32_broadcast_three_tensors; -mod max_i32_two_tensors; -mod max_i32_broadcast_two_tensors; -mod max_i8_three_tensors; -mod max_i8_broadcast_three_tensors; -mod max_i8_two_tensors; -mod max_i8_broadcast_two_tensors; -mod max_u32_three_tensors; -mod max_u32_broadcast_three_tensors; -mod max_u32_two_tensors; -mod max_u32_broadcast_two_tensors; -mod scatter_fp16x16_3d_default; -mod scatter_fp16x16_3d_axis1; -mod scatter_fp16x16_3d_axis1_add; -mod scatter_fp8x23_default; -mod scatter_fp8x23_axis1; -mod scatter_fp8x23_mul; -mod scatter_i8_default; -mod scatter_i8_axis1; -mod scatter_i8_axis1_max; -mod scatter_u32_default; -mod scatter_u32_axis1; -mod scatter_u32_add; -mod array_feature_extractor_1D_i32; -mod array_feature_extractor_1D_fp8x23; -mod array_feature_extractor_1D_fp16x16; -mod array_feature_extractor_2D_i32; -mod array_feature_extractor_2D_fp8x23; -mod array_feature_extractor_2D_fp16x16; -mod array_feature_extractor_3D_i32; -mod array_feature_extractor_3D_fp8x23; -mod array_feature_extractor_3D_fp16x16; -mod binarizer_fp16x16; -mod binarizer_fp8x23; -mod tril_fp16x16; -mod tril_fp16x16_neg; -mod tril_fp16x16_one_row; -mod tril_fp16x16_out_neg; -mod tril_fp16x16_out_pos; -mod tril_fp16x16_pos; -mod tril_fp16x16_square; -mod tril_fp16x16_square_neg; -mod tril_fp16x16_zero; -mod triu_fp16x16; -mod triu_fp16x16_neg; -mod triu_fp16x16_one_row; -mod triu_fp16x16_out_neg; -mod triu_fp16x16_out_pos; -mod triu_fp16x16_pos; -mod triu_fp16x16_square; -mod triu_fp16x16_square_neg; -mod triu_fp16x16_zero; -mod tril_fp8x23; -mod tril_fp8x23_neg; -mod tril_fp8x23_one_row; -mod tril_fp8x23_out_neg; -mod tril_fp8x23_out_pos; -mod tril_fp8x23_pos; -mod tril_fp8x23_square; -mod tril_fp8x23_square_neg; -mod tril_fp8x23_zero; -mod triu_fp8x23; -mod triu_fp8x23_neg; -mod triu_fp8x23_one_row; -mod triu_fp8x23_out_neg; -mod triu_fp8x23_out_pos; -mod triu_fp8x23_pos; -mod triu_fp8x23_square; -mod triu_fp8x23_square_neg; -mod triu_fp8x23_zero; -mod tril_i32; -mod tril_neg_i32; -mod tril_i32_one_row; -mod tril_i32_out_neg; -mod tril_i32_out_pos; -mod tril_i32_pos; -mod tril_i32_square; -mod tril_i32_square_neg; -mod tril_i32_zero; -mod triu_i32; -mod triu_i32_neg; -mod triu_i32_one_row; -mod triu_i32_out_neg; -mod triu_i32_out_pos; -mod triu_i32_pos; -mod triu_i32_square; -mod triu_i32_square_neg; -mod triu_i32_zero; -mod tril_i8; -mod tril_i8_neg; -mod tril_i8_one_row; -mod tril_i8_out_neg; -mod tril_i8_out_pos; -mod tril_i8_pos; -mod tril_i8_square; -mod tril_i8_square_neg; -mod tril_i8_zero; -mod triu_i8; -mod triu_i8_neg; -mod triu_i8_one_row; -mod triu_i8_out_neg; -mod triu_i8_out_pos; -mod triu_i8_pos; -mod triu_i8_square; -mod triu_i8_square_neg; -mod triu_i8_zero; -mod tril_u32; -mod tril_u32_neg; -mod tril_u32_one_row; -mod tril_u32_out_neg; -mod tril_u32_out_pos; -mod tril_u32_pos; -mod tril_u32_square; -mod tril_u32_square_neg; -mod tril_u32_zero; -mod triu_u32; -mod triu_u32_neg; -mod triu_u32_one_row; -mod triu_u32_out_neg; -mod triu_u32_out_pos; -mod triu_u32_pos; -mod triu_u32_square; -mod triu_u32_square_neg; -mod triu_u32_zero; -mod reduce_sum_square_fp16x16_export_do_not_keepdims; -mod reduce_sum_square_fp16x16_export_keepdims; -mod reduce_sum_square_fp16x16_export_negative_axes_keepdims; -mod reduce_sum_square_fp8x23_export_do_not_keepdims; -mod reduce_sum_square_fp8x23_export_keepdims; -mod reduce_sum_square_fp8x23_export_negative_axes_keepdims; -mod reduce_sum_square_i32_export_do_not_keepdims; -mod reduce_sum_square_i32_export_keepdims; -mod reduce_sum_square_i32_export_negative_axes_keepdims; -mod reduce_sum_square_i8_export_do_not_keepdims; -mod reduce_sum_square_i8_export_keepdims; -mod reduce_sum_square_i8_export_negative_axes_keepdims; -mod reduce_sum_square_u32_export_do_not_keepdims; -mod reduce_sum_square_u32_export_keepdims; -mod reduce_sum_square_u32_export_negative_axes_keepdims; -mod reduce_l2_fp16x16_export_do_not_keepdims; -mod reduce_l2_fp16x16_export_keepdims; -mod reduce_l2_fp16x16_export_negative_axes_keepdims; -mod reduce_l2_fp8x23_export_do_not_keepdims; -mod reduce_l2_fp8x23_export_keepdims; -mod reduce_l2_fp8x23_export_negative_axes_keepdims; -mod reduce_l1_fp16x16_export_do_not_keepdims; -mod reduce_l1_fp16x16_export_keepdims; -mod reduce_l1_fp16x16_export_negative_axes_keepdims; -mod reduce_l1_fp8x23_export_do_not_keepdims; -mod reduce_l1_fp8x23_export_keepdims; -mod reduce_l1_fp8x23_export_negative_axes_keepdims; -mod reduce_l1_i32_export_do_not_keepdims; -mod reduce_l1_i32_export_keepdims; -mod reduce_l1_i32_export_negative_axes_keepdims; -mod reduce_l1_i8_export_do_not_keepdims; -mod reduce_l1_i8_export_keepdims; -mod reduce_l1_i8_export_negative_axes_keepdims; -mod reduce_l1_u32_export_do_not_keepdims; -mod reduce_l1_u32_export_keepdims; -mod reduce_l1_u32_export_negative_axes_keepdims; -mod reduce_prod_fp16x16_1D; -mod reduce_prod_fp16x16_2D_default; -mod reduce_prod_fp16x16_2D_keepdims; -mod reduce_prod_fp16x16_2D_axis_1; -mod reduce_prod_fp8x23_1D; -mod reduce_prod_fp8x23_2D_default; -mod reduce_prod_fp8x23_2D_keepdims; -mod reduce_prod_fp8x23_2D_axis_1; -mod reduce_prod_i32_1D; -mod reduce_prod_i32_2D_default; -mod reduce_prod_i32_2D_keepdims; -mod reduce_prod_i32_2D_axis_1; -mod reduce_prod_i8_1D; -mod reduce_prod_i8_2D_default; -mod reduce_prod_i8_2D_keepdims; -mod reduce_prod_i8_2D_axis_1; -mod reduce_prod_u32_1D; -mod reduce_prod_u32_2D_default; -mod reduce_prod_u32_2D_keepdims; -mod reduce_prod_u32_2D_axis_1; -mod gather_elements_fp16x16_3d_default; -mod gather_elements_fp16x16_3d_axis1; -mod gather_elements_fp16x16_3d_axis2; -mod gather_elements_fp8x23_3d_default; -mod gather_elements_fp8x23_3d_axis1; -mod gather_elements_fp8x23_3d_axis2; -mod gather_elements_i8_3d_default; -mod gather_elements_i8_3d_axis1; -mod gather_elements_i32_3d_default; -mod gather_elements_i32_3d_axis1; -mod gather_elements_i32_3d_axis2; -mod gather_elements_u32_default; -mod gather_elements_u32_axis1; -mod gather_elements_u32_axis2; -mod gather_elements_u32_axis3; -mod sequence_length_fp16x16; -mod sequence_length_fp16x16_broadcast; -mod sequence_length_fp8x23; -mod sequence_length_fp8x23_broadcast; -mod sequence_length_i32; -mod sequence_length_i32_broadcast; -mod sequence_length_i8; -mod sequence_length_i8_broadcast; -mod sequence_length_u32; -mod sequence_length_u32_broadcast; -mod sequence_at_u32_positive; -mod sequence_at_u32_negative; -mod sequence_at_fp16x16_positive; -mod sequence_at_fp16x16_negative; -mod sequence_at_fp8x23_positive; -mod sequence_at_fp8x23_negative; -mod sequence_at_i32_positive; -mod sequence_at_i32_negative; -mod sequence_at_i8_positive; -mod sequence_at_i8_negative; -mod reduce_min_fp16x16_1D; -mod reduce_min_fp16x16_2D_default; -mod reduce_min_fp16x16_2D_keepdims; -mod reduce_min_fp16x16_2D_axis_1; -mod reduce_min_fp8x23_1D; -mod reduce_min_fp8x23_2D_default; -mod reduce_min_fp8x23_2D_keepdims; -mod reduce_min_fp8x23_2D_axis_1; -mod reduce_min_i32_1D; -mod reduce_min_i32_2D_default; -mod reduce_min_i32_2D_keepdims; -mod reduce_min_i32_2D_axis_1; -mod reduce_min_i8_1D; -mod reduce_min_i8_2D_default; -mod reduce_min_i8_2D_keepdims; -mod reduce_min_i8_2D_axis_1; -mod reduce_min_u32_1D; -mod reduce_min_u32_2D_default; -mod reduce_min_u32_2D_keepdims; -mod reduce_min_u32_2D_axis_1; -mod sequence_construct_fp16x16; -mod sequence_construct_fp8x23; -mod sequence_construct_i32; -mod sequence_construct_i8; -mod sequence_construct_u32; -mod shrink_hard_fp16x16; -mod shrink_soft_fp16x16; -mod shrink_hard_fp8x23; -mod shrink_soft_fp8x23; -mod sequence_empty_fp16x16; -mod sequence_empty_fp8x23; -mod sequence_empty_i32; -mod sequence_empty_i8; -mod sequence_empty_u32; -mod reduce_mean_fp16x16_1D; -mod reduce_mean_fp16x16_2D_default; -mod reduce_mean_fp16x16_2D_keepdims; -mod reduce_mean_fp16x16_2D_axis_1; -mod reduce_mean_fp8x23_1D; -mod reduce_mean_fp8x23_2D_default; -mod reduce_mean_fp8x23_2D_keepdims; -mod reduce_mean_fp8x23_2D_axis_1; -mod reduce_mean_i32_1D; -mod reduce_mean_i32_2D_default; -mod reduce_mean_i32_2D_keepdims; -mod reduce_mean_i32_2D_axis_1; -mod reduce_mean_i8_1D; -mod reduce_mean_i8_2D_default; -mod reduce_mean_i8_2D_keepdims; -mod reduce_mean_i8_2D_axis_1; -mod reduce_mean_u32_1D; -mod reduce_mean_u32_2D_default; -mod reduce_mean_u32_2D_keepdims; -mod reduce_mean_u32_2D_axis_1; -mod pow_fp16x16; -mod pow_fp16x16_broadcast; -mod pow_fp8x23; -mod pow_fp8x23_broadcast; -mod sequence_erase_u32_positive; -mod sequence_erase_u32_negative; -mod sequence_erase_u32_empty; -mod sequence_erase_fp16x16_positive; -mod sequence_erase_fp16x16_negative; -mod sequence_erase_fp16x16_empty; -mod sequence_erase_fp8x23_positive; -mod sequence_erase_fp8x23_negative; -mod sequence_erase_fp8x23_empty; -mod sequence_erase_i32_positive; -mod sequence_erase_i32_negative; -mod sequence_erase_i32_empty; -mod sequence_erase_i8_positive; -mod sequence_erase_i8_negative; -mod sequence_erase_i8_empty; -mod sequence_insert_fp16x16; -mod sequence_insert_fp8x23; -mod sequence_insert_i32; -mod sequence_insert_i8; -mod sequence_insert_u32; -mod concat_from_sequence_fp8x23_new_axis_zero; -mod concat_from_sequence_fp8x23_new_axis_one; -mod concat_from_sequence_fp8x23_new_axis_default; -mod concat_from_sequence_fp16x16_new_axis_zero; -mod concat_from_sequence_fp16x16_new_axis_one; -mod concat_from_sequence_fp16x16_new_axis_default; -mod concat_from_sequence_i32_new_axis_zero; -mod concat_from_sequence_i32_new_axis_one; -mod concat_from_sequence_i32_new_axis_default; -mod concat_from_sequence_i8_new_axis_zero; -mod concat_from_sequence_i8_new_axis_one; -mod concat_from_sequence_i8_new_axis_default; -mod concat_from_sequence_u32_new_axis_zero; -mod concat_from_sequence_u32_new_axis_one; -mod concat_from_sequence_u32_new_axis_default; -mod is_nan_fp16x16; -mod is_nan_fp8x23; -mod is_inf_fp16x16; -mod is_inf_fp8x23; -mod is_inf_i32; -mod is_inf_i8; -mod is_inf_u32; -mod is_pos_inf_fp16x16; -mod is_neg_inf_fp16x16; -mod is_pos_inf_fp8x23; -mod is_neg_inf_fp8x23; -mod is_pos_inf_i32; -mod is_neg_inf_i32; -mod is_pos_inf_i8; -mod is_neg_inf_i8; -mod reduce_log_sum_fp8x23_export_do_not_keepdims; -mod reduce_log_sum_fp8x23_export_keepdims; -mod reduce_log_sum_fp8x23_export_negative_axes_keepdims; -mod reduce_log_sum_fp16x16_export_do_not_keepdims; -mod reduce_log_sum_fp16x16_export_keepdims; -mod reduce_log_sum_fp16x16_export_negative_axes_keepdims; -mod and_bool; -mod erf_fp16x16; -mod erf_fp8x23; -mod unique_fp16x16_without_axis_sorted; -mod unique_fp16x16_with_axis_zero_sorted; -mod unique_u32_without_axis_sorted; -mod unique_u32_without_axis_not_sorted; -mod unique_u32_with_axis_zero_sorted; -mod unique_u32_with_axis_zero_not_sorted; -mod unique_u32_with_axis_one_sorted; -mod unique_u32_with_axis_one_not_sorted; -mod gather_nd_fp16x16_3d_default; -mod gather_nd_fp16x16_3d_batch_dims1; -mod gather_nd_fp16x16_3d_batch_dims2; -mod gather_nd_fp8x23_3d_default; -mod gather_nd_fp8x23_3d_batch_dims1; -mod gather_nd_fp8x23_3d_batch_dims2; -mod gather_nd_i32_3d_default; -mod gather_nd_i32_3d_batch_dims1; -mod gather_nd_i32_3d_batch_dims2; -mod gather_nd_i8_3d_default; -mod gather_nd_i8_3d_batch_dims1; -mod gather_nd_u32_default; -mod gather_nd_u32_batch_dims1; -mod gather_nd_u32_batch_dims2; +// mod abs_fp16x16; +// mod abs_fp8x23; +// mod abs_i32; +// mod abs_i8; +// mod acos_fp16x16; +// mod acos_fp8x23; +// mod acosh_fp16x16; +// mod acosh_fp8x23; +// mod add_fp16x16; +// mod add_fp16x16_broadcast; +// mod add_fp8x23; +// mod add_fp8x23_broadcast; +// mod add_i32; +// mod add_i32_broadcast; +// mod add_i8; +// mod add_i8_broadcast; +// mod add_u32; +// mod add_u32_broadcast; +// mod argmax_fp16x16_1D_default; +// mod argmax_fp16x16_1D_keepdims_false; +// mod argmax_fp16x16_1D_last_index; +// mod argmax_fp16x16_2D_default; +// mod argmax_fp16x16_2D_keepdims_false; +// mod argmax_fp16x16_2D_last_index; +// mod argmax_fp16x16_3D_default; +// mod argmax_fp16x16_3D_keepdims_false; +// mod argmax_fp16x16_3D_last_index; +// mod argmax_fp8x23_1D_default; +// mod argmax_fp8x23_1D_keepdims_false; +// mod argmax_fp8x23_1D_last_index; +// mod argmax_fp8x23_2D_default; +// mod argmax_fp8x23_2D_keepdims_false; +// mod argmax_fp8x23_2D_last_index; +// mod argmax_fp8x23_3D_default; +// mod argmax_fp8x23_3D_keepdims_false; +// mod argmax_fp8x23_3D_last_index; +// mod argmax_i32_1D_default; +// mod argmax_i32_1D_keepdims_false; +// mod argmax_i32_1D_last_index; +// mod argmax_i32_2D_default; +// mod argmax_i32_2D_keepdims_false; +// mod argmax_i32_2D_last_index; +// mod argmax_i32_3D_default; +// mod argmax_i32_3D_keepdims_false; +// mod argmax_i32_3D_last_index; +// mod argmax_i8_1D_default; +// mod argmax_i8_1D_keepdims_false; +// mod argmax_i8_1D_last_index; +// mod argmax_i8_2D_default; +// mod argmax_i8_2D_keepdims_false; +// mod argmax_i8_2D_last_index; +// mod argmax_i8_3D_default; +// mod argmax_i8_3D_keepdims_false; +// mod argmax_i8_3D_last_index; +// mod argmax_u32_1D_default; +// mod argmax_u32_1D_keepdims_false; +// mod argmax_u32_1D_last_index; +// mod argmax_u32_2D_default; +// mod argmax_u32_2D_keepdims_false; +// mod argmax_u32_2D_last_index; +// mod argmax_u32_3D_default; +// mod argmax_u32_3D_keepdims_false; +// mod argmax_u32_3D_last_index; +// mod argmin_fp16x16_1D_default; +// mod argmin_fp16x16_1D_keepdims_false; +// mod argmin_fp16x16_1D_last_index; +// mod argmin_fp16x16_2D_default; +// mod argmin_fp16x16_2D_keepdims_false; +// mod argmin_fp16x16_2D_last_index; +// mod argmin_fp16x16_3D_default; +// mod argmin_fp16x16_3D_keepdims_false; +// mod argmin_fp16x16_3D_last_index; +// mod argmin_fp8x23_1D_default; +// mod argmin_fp8x23_1D_keepdims_false; +// mod argmin_fp8x23_1D_last_index; +// mod argmin_fp8x23_2D_default; +// mod argmin_fp8x23_2D_keepdims_false; +// mod argmin_fp8x23_2D_last_index; +// mod argmin_fp8x23_3D_default; +// mod argmin_fp8x23_3D_keepdims_false; +// mod argmin_fp8x23_3D_last_index; +// mod argmin_i32_1D_default; +// mod argmin_i32_1D_keepdims_false; +// mod argmin_i32_1D_last_index; +// mod argmin_i32_2D_default; +// mod argmin_i32_2D_keepdims_false; +// mod argmin_i32_2D_last_index; +// mod argmin_i32_3D_default; +// mod argmin_i32_3D_keepdims_false; +// mod argmin_i32_3D_last_index; +// mod argmin_i8_1D_default; +// mod argmin_i8_1D_keepdims_false; +// mod argmin_i8_1D_last_index; +// mod argmin_i8_2D_default; +// mod argmin_i8_2D_keepdims_false; +// mod argmin_i8_2D_last_index; +// mod argmin_i8_3D_default; +// mod argmin_i8_3D_keepdims_false; +// mod argmin_i8_3D_last_index; +// mod argmin_u32_1D_default; +// mod argmin_u32_1D_keepdims_false; +// mod argmin_u32_1D_last_index; +// mod argmin_u32_2D_default; +// mod argmin_u32_2D_keepdims_false; +// mod argmin_u32_2D_last_index; +// mod argmin_u32_3D_default; +// mod argmin_u32_3D_keepdims_false; +// mod argmin_u32_3D_last_index; +// mod asin_fp16x16; +// mod asin_fp8x23; +// mod asinh_fp16x16; +// mod asinh_fp8x23; +// mod atan_fp16x16; +// mod atan_fp8x23; +// mod ceil_fp16x16; +// mod ceil_fp8x23; +// mod concat_fp16x16_1d; +// mod concat_fp16x16_2d; +// mod concat_fp16x16_3d_default; +// mod concat_fp16x16_3d_axis_1; +// mod concat_fp16x16_3d_axis_2; +// mod concat_fp16x16_3d_three_tensors_axis_1; +// mod concat_fp16x16_3d_three_tensors_axis_2; +// mod concat_fp8x23_1d; +// mod concat_fp8x23_2d; +// mod concat_fp8x23_3d_default; +// mod concat_fp8x23_3d_axis_1; +// mod concat_fp8x23_3d_axis_2; +// mod concat_fp8x23_3d_three_tensors_axis_1; +// mod concat_fp8x23_3d_three_tensors_axis_2; +// mod concat_i32_1d; +// mod concat_i32_2d; +// mod concat_i32_3d_default; +// mod concat_i32_3d_axis_1; +// mod concat_i32_3d_axis_2; +// mod concat_i32_3d_three_tensors_axis_1; +// mod concat_i32_3d_three_tensors_axis_2; +// mod concat_i8_1d; +// mod concat_i8_2d; +// mod concat_i8_3d_default; +// mod concat_i8_3d_axis_1; +// mod concat_i8_3d_axis_2; +// mod concat_i8_3d_three_tensors_axis_1; +// mod concat_i8_3d_three_tensors_axis_2; +// mod concat_u32_1d; +// mod concat_u32_2d; +// mod concat_u32_3d_default; +// mod concat_u32_3d_axis_1; +// mod concat_u32_3d_axis_2; +// mod concat_u32_3d_three_tensors_axis_1; +// mod concat_u32_3d_three_tensors_axis_2; +// mod cos_fp16x16; +// mod cos_fp8x23; +// mod cosh_fp16x16; +// mod cosh_fp8x23; +// mod cumsum_fp16x16_1d_default; +// mod cumsum_fp16x16_1d_exclusive; +// mod cumsum_fp16x16_1d_reverse; +// mod cumsum_fp16x16_1d_reverse_exclusive; +// mod cumsum_fp16x16_2d_axis_0; +// mod cumsum_fp16x16_2d_axis_1; +// mod cumsum_fp8x23_1d_default; +// mod cumsum_fp8x23_1d_exclusive; +// mod cumsum_fp8x23_1d_reverse; +// mod cumsum_fp8x23_1d_reverse_exclusive; +// mod cumsum_fp8x23_2d_axis_0; +// mod cumsum_fp8x23_2d_axis_1; +// mod cumsum_i32_1d_default; +// mod cumsum_i32_1d_exclusive; +// mod cumsum_i32_1d_reverse; +// mod cumsum_i32_1d_reverse_exclusive; +// mod cumsum_i32_2d_axis_0; +// mod cumsum_i32_2d_axis_1; +// mod cumsum_i8_1d_default; +// mod cumsum_i8_1d_exclusive; +// mod cumsum_i8_1d_reverse; +// mod cumsum_i8_1d_reverse_exclusive; +// mod cumsum_i8_2d_axis_0; +// mod cumsum_i8_2d_axis_1; +// mod cumsum_u32_1d_default; +// mod cumsum_u32_1d_exclusive; +// mod cumsum_u32_1d_reverse; +// mod cumsum_u32_1d_reverse_exclusive; +// mod cumsum_u32_2d_axis_0; +// mod cumsum_u32_2d_axis_1; +// mod div_fp16x16; +// mod div_fp16x16_broadcast; +// mod div_fp8x23; +// mod div_fp8x23_broadcast; +// mod div_i32; +// mod div_i32_broadcast; +// mod div_i8; +// mod div_i8_broadcast; +// mod div_u32; +// mod div_u32_broadcast; +// mod equal_fp16x16; +// mod equal_fp16x16_broadcast; +// mod equal_fp8x23; +// mod equal_fp8x23_broadcast; +// mod equal_i32; +// mod equal_i32_broadcast; +// mod equal_i8; +// mod equal_i8_broadcast; +// mod equal_u32; +// mod equal_u32_broadcast; +// mod exp_fp16x16; +// mod exp_fp8x23; +// mod less_equal_fp16x16; +// mod less_equal_fp16x16_broadcast; +// mod less_equal_fp8x23; +// mod less_equal_fp8x23_broadcast; +// mod less_equal_i32; +// mod less_equal_i32_broadcast; +// mod less_equal_i8; +// mod less_equal_i8_broadcast; +// mod less_equal_u32; +// mod less_equal_u32_broadcast; +// mod greater_fp16x16; +// mod greater_fp16x16_broadcast; +// mod greater_fp8x23; +// mod greater_fp8x23_broadcast; +// mod greater_i32; +// mod greater_i32_broadcast; +// mod greater_i8; +// mod greater_i8_broadcast; +// mod greater_u32; +// mod greater_u32_broadcast; +// mod leaky_relu_fp16x16; +// mod leaky_relu_fp8x23; +// mod linear_fp16x16; +// mod linear_fp8x23; +// mod linear_i32; +// mod linear_i8; +// mod linear_u32; +// mod log_fp16x16; +// mod log_fp8x23; +// mod logsoftmax_fp16x16_axis_0; +// mod logsoftmax_fp16x16_axis_1; +// mod logsoftmax_fp8x23_axis_0; +// mod logsoftmax_fp8x23_axis_1; +// mod matmul_fp16x16_1d; +// mod matmul_fp16x16_2x2; +// mod matmul_fp16x16_2x1; +// mod matmul_fp16x16_1x2; +// mod matmul_fp8x23_1d; +// mod matmul_fp8x23_2x2; +// mod matmul_fp8x23_2x1; +// mod matmul_fp8x23_1x2; +// mod matmul_i32_1d; +// mod matmul_i32_2x2; +// mod matmul_i32_2x1; +// mod matmul_i32_1x2; +// mod matmul_i8_1d; +// mod matmul_i8_2x2; +// mod matmul_i8_2x1; +// mod matmul_i8_1x2; +// mod matmul_u32_1d; +// mod matmul_u32_2x2; +// mod matmul_u32_2x1; +// mod matmul_u32_1x2; +// mod mul_fp16x16; +// mod mul_fp16x16_broadcast; +// mod mul_fp8x23; +// mod mul_fp8x23_broadcast; +// mod mul_i32; +// mod mul_i32_broadcast; +// mod mul_i8; +// mod mul_i8_broadcast; +// mod mul_u32; +// mod mul_u32_broadcast; +// mod or_fp16x16; +// mod or_fp16x16_broadcast; +// mod or_fp8x23; +// mod or_fp8x23_broadcast; +// mod or_i32; +// mod or_i32_broadcast; +// mod or_i8; +// mod or_i8_broadcast; +// mod or_u32; +// mod or_u32_broadcast; +// mod reduce_sum_fp16x16_1D; +// mod reduce_sum_fp16x16_2D_default; +// mod reduce_sum_fp16x16_2D_keepdims; +// mod reduce_sum_fp16x16_2D_axis_1; +// mod reduce_sum_fp8x23_1D; +// mod reduce_sum_fp8x23_2D_default; +// mod reduce_sum_fp8x23_2D_keepdims; +// mod reduce_sum_fp8x23_2D_axis_1; +// mod reduce_sum_i32_1D; +// mod reduce_sum_i32_2D_default; +// mod reduce_sum_i32_2D_keepdims; +// mod reduce_sum_i32_2D_axis_1; +// mod reduce_sum_i8_1D; +// mod reduce_sum_i8_2D_default; +// mod reduce_sum_i8_2D_keepdims; +// mod reduce_sum_i8_2D_axis_1; +// mod reduce_sum_u32_1D; +// mod reduce_sum_u32_2D_default; +// mod reduce_sum_u32_2D_keepdims; +// mod reduce_sum_u32_2D_axis_1; +// mod relu_fp16x16; +// mod relu_fp8x23; +// mod relu_i32; +// mod relu_i8; +// mod sigmoid_fp16x16; +// mod sigmoid_fp8x23; +// mod sin_fp16x16; +// mod sin_fp8x23; +// mod sinh_fp16x16; +// mod sinh_fp8x23; +// mod softmax_fp16x16; +// mod softmax_fp8x23; +// mod softplus_fp8x23; +// mod softplus_fp16x16; +// mod softsign_fp8x23; +// mod softsign_fp16x16; +// mod sqrt_fp16x16; +// mod sqrt_fp8x23; +// mod sub_fp16x16; +// mod sub_fp16x16_broadcast; +// mod sub_fp8x23; +// mod sub_fp8x23_broadcast; +// mod sub_i32; +// mod sub_i32_broadcast; +// mod sub_i8; +// mod sub_i8_broadcast; +// mod sub_u32; +// mod sub_u32_broadcast; +// mod tanh_fp16x16; +// mod tanh_fp8x23; +// mod transpose_fp16x16_2d; +// mod transpose_fp16x16_3d; +// mod transpose_fp8x23_2d; +// mod transpose_fp8x23_3d; +// mod transpose_i32_2d; +// mod transpose_i32_3d; +// mod transpose_i8_2d; +// mod transpose_i8_3d; +// mod transpose_u32_2d; +// mod transpose_u32_3d; +// mod xor_fp16x16; +// mod xor_fp16x16_broadcast; +// mod xor_fp8x23; +// mod xor_fp8x23_broadcast; +// mod xor_i32; +// mod xor_i32_broadcast; +// mod xor_i8; +// mod xor_i8_broadcast; +// mod xor_u32; +// mod xor_u32_broadcast; +// mod less_fp16x16; +// mod less_fp16x16_broadcast; +// mod less_fp8x23; +// mod less_fp8x23_broadcast; +// mod less_i32; +// mod less_i32_broadcast; +// mod less_i8; +// mod less_i8_broadcast; +// mod less_u32; +// mod less_u32_broadcast; +// mod greater_equal_fp16x16; +// mod greater_equal_fp16x16_broadcast; +// mod greater_equal_fp8x23; +// mod greater_equal_fp8x23_broadcast; +// mod greater_equal_i32; +// mod greater_equal_i32_broadcast; +// mod greater_equal_i8; +// mod greater_equal_i8_broadcast; +// mod greater_equal_u32; +// mod greater_equal_u32_broadcast; +// mod slice_fp16x16_2d; +// mod slice_fp16x16_3d; +// mod slice_fp8x23_2d; +// mod slice_fp8x23_3d; +// mod slice_i32_2d; +// mod slice_i32_3d; +// mod slice_i8_2d; +// mod slice_i8_3d; +// mod slice_u32_2d; +// mod slice_u32_3d; +// mod gather_fp8x23_3d_default; +// mod gather_fp8x23_3d_axis1; +// mod gather_fp8x23_3d_axis2; +// mod gather_fp16x16_3d_default; +// mod gather_fp16x16_3d_axis1; +// mod gather_fp16x16_3d_axis2; +// mod gather_i8_3d_default; +// mod gather_i8_3d_axis1; +// mod gather_i8_3d_axis2; +// mod gather_i32_3d_default; +// mod gather_i32_3d_axis1; +// mod gather_i32_3d_axis2; +// mod gather_u32_3d_default; +// mod gather_u32_3d_axis1; +// mod gather_u32_3d_axis2; +// mod nonzero_fp16x16_2d; +// mod nonzero_fp16x16_3d; +// mod nonzero_fp8x23_2d; +// mod nonzero_fp8x23_3d; +// mod nonzero_i32_2d; +// mod nonzero_i32_3d; +// mod nonzero_i8_2d; +// mod nonzero_i8_3d; +// mod nonzero_u32_2d; +// mod nonzero_u32_3d; +// mod squeeze_fP16x16; +// mod squeeze_fP8x23; +// mod squeeze_i32; +// mod squeeze_i8; +// mod squeeze_u32; +// mod unsqueeze_fp16x16_2d; +// mod unsqueeze_fp16x16_3d; +// mod unsqueeze_fp8x23_2d; +// mod unsqueeze_fp8x23_3d; +// mod unsqueeze_i32_2d; +// mod unsqueeze_i32_3d; +// mod unsqueeze_i8_2d; +// mod unsqueeze_i8_3d; +// mod unsqueeze_u32_2d; +// mod unsqueeze_u32_3d; +// mod sign_fP16x16; +// mod sign_fP8x23; +// mod sign_fail; +// mod sign_i32; +// mod sign_i8; +// mod clip_fp16x16_2d; +// mod clip_fp16x16_3d; +// mod clip_fp8x23_2d; +// mod clip_fp8x23_3d; +// mod clip_i32_2d; +// mod clip_i32_3d; +// mod clip_i8_2d; +// mod clip_i8_3d; +// mod clip_u32_2d; +// mod clip_u32_3d; +// mod identity_fP16x16; +// mod identity_fP8x23; +// mod identity_i32; +// mod identity_i8; +// mod identity_u32; +// mod thresholded_relu_fp16x16; +// mod thresholded_relu_fp8x23; +// mod hard_sigmoid_fp8x23; +// mod hard_sigmoid_fp16x16; +// mod neg_fp16x16; +// mod neg_fp8x23; +// mod neg_i32; +// mod neg_i8; +// mod gemm_all_attributes; +// mod gemm_alpha; +// mod gemm_beta; +// mod gemm_default_matrix_bias; +// mod gemm_default_vector_bias; +// mod gemm_default_no_bias; +// mod gemm_transposeA; +// mod gemm_transposeB; +// mod min_fp16x16_three_tensors; +// mod min_fp16x16_broadcast_three_tensors; +// mod min_fp16x16_two_tensors; +// mod min_fp16x16_broadcast_two_tensors; +// mod min_fp8x23_three_tensors; +// mod min_fp8x23_broadcast_three_tensors; +// mod min_fp8x23_two_tensors; +// mod min_fp8x23_broadcast_two_tensors; +// mod min_i32_three_tensors; +// mod min_i32_broadcast_three_tensors; +// mod min_i32_two_tensors; +// mod min_i32_broadcast_two_tensors; +// mod min_i8_three_tensors; +// mod min_i8_broadcast_three_tensors; +// mod min_i8_two_tensors; +// mod min_i8_broadcast_two_tensors; +// mod min_u32_three_tensors; +// mod min_u32_broadcast_three_tensors; +// mod min_u32_two_tensors; +// mod min_u32_broadcast_two_tensors; +// mod where_fp16x16; +// mod where_fp16x16_broadcast; +// mod where_fp8x23; +// mod where_fp8x23_broadcast; +// mod where_i32; +// mod where_i32_broadcast; +// mod where_i8; +// mod where_i8_broadcast; +// mod where_u32; +// mod where_u32_broadcast; +// mod not_bool; +// mod round_fp16x16; +// mod round_fp8x23; +// mod max_fp16x16_three_tensors; +// mod max_fp16x16_broadcast_three_tensors; +// mod max_fp16x16_two_tensors; +// mod max_fp16x16_broadcast_two_tensors; +// mod max_fp8x23_three_tensors; +// mod max_fp8x23_broadcast_three_tensors; +// mod max_fp8x23_two_tensors; +// mod max_fp8x23_broadcast_two_tensors; +// mod max_i32_three_tensors; +// mod max_i32_broadcast_three_tensors; +// mod max_i32_two_tensors; +// mod max_i32_broadcast_two_tensors; +// mod max_i8_three_tensors; +// mod max_i8_broadcast_three_tensors; +// mod max_i8_two_tensors; +// mod max_i8_broadcast_two_tensors; +// mod max_u32_three_tensors; +// mod max_u32_broadcast_three_tensors; +// mod max_u32_two_tensors; +// mod max_u32_broadcast_two_tensors; +// mod scatter_fp16x16_3d_default; +// mod scatter_fp16x16_3d_axis1; +// mod scatter_fp16x16_3d_axis1_add; +// mod scatter_fp8x23_default; +// mod scatter_fp8x23_axis1; +// mod scatter_fp8x23_mul; +// mod scatter_i8_default; +// mod scatter_i8_axis1; +// mod scatter_i8_axis1_max; +// mod scatter_u32_default; +// mod scatter_u32_axis1; +// mod scatter_u32_add; +// mod array_feature_extractor_1D_i32; +// mod array_feature_extractor_1D_fp8x23; +// mod array_feature_extractor_1D_fp16x16; +// mod array_feature_extractor_2D_i32; +// mod array_feature_extractor_2D_fp8x23; +// mod array_feature_extractor_2D_fp16x16; +// mod array_feature_extractor_3D_i32; +// mod array_feature_extractor_3D_fp8x23; +// mod array_feature_extractor_3D_fp16x16; +// mod binarizer_fp16x16; +// mod binarizer_fp8x23; +// mod tril_fp16x16; +// mod tril_fp16x16_neg; +// mod tril_fp16x16_one_row; +// mod tril_fp16x16_out_neg; +// mod tril_fp16x16_out_pos; +// mod tril_fp16x16_pos; +// mod tril_fp16x16_square; +// mod tril_fp16x16_square_neg; +// mod tril_fp16x16_zero; +// mod triu_fp16x16; +// mod triu_fp16x16_neg; +// mod triu_fp16x16_one_row; +// mod triu_fp16x16_out_neg; +// mod triu_fp16x16_out_pos; +// mod triu_fp16x16_pos; +// mod triu_fp16x16_square; +// mod triu_fp16x16_square_neg; +// mod triu_fp16x16_zero; +// mod tril_fp8x23; +// mod tril_fp8x23_neg; +// mod tril_fp8x23_one_row; +// mod tril_fp8x23_out_neg; +// mod tril_fp8x23_out_pos; +// mod tril_fp8x23_pos; +// mod tril_fp8x23_square; +// mod tril_fp8x23_square_neg; +// mod tril_fp8x23_zero; +// mod triu_fp8x23; +// mod triu_fp8x23_neg; +// mod triu_fp8x23_one_row; +// mod triu_fp8x23_out_neg; +// mod triu_fp8x23_out_pos; +// mod triu_fp8x23_pos; +// mod triu_fp8x23_square; +// mod triu_fp8x23_square_neg; +// mod triu_fp8x23_zero; +// mod tril_i32; +// mod tril_neg_i32; +// mod tril_i32_one_row; +// mod tril_i32_out_neg; +// mod tril_i32_out_pos; +// mod tril_i32_pos; +// mod tril_i32_square; +// mod tril_i32_square_neg; +// mod tril_i32_zero; +// mod triu_i32; +// mod triu_i32_neg; +// mod triu_i32_one_row; +// mod triu_i32_out_neg; +// mod triu_i32_out_pos; +// mod triu_i32_pos; +// mod triu_i32_square; +// mod triu_i32_square_neg; +// mod triu_i32_zero; +// mod tril_i8; +// mod tril_i8_neg; +// mod tril_i8_one_row; +// mod tril_i8_out_neg; +// mod tril_i8_out_pos; +// mod tril_i8_pos; +// mod tril_i8_square; +// mod tril_i8_square_neg; +// mod tril_i8_zero; +// mod triu_i8; +// mod triu_i8_neg; +// mod triu_i8_one_row; +// mod triu_i8_out_neg; +// mod triu_i8_out_pos; +// mod triu_i8_pos; +// mod triu_i8_square; +// mod triu_i8_square_neg; +// mod triu_i8_zero; +// mod tril_u32; +// mod tril_u32_neg; +// mod tril_u32_one_row; +// mod tril_u32_out_neg; +// mod tril_u32_out_pos; +// mod tril_u32_pos; +// mod tril_u32_square; +// mod tril_u32_square_neg; +// mod tril_u32_zero; +// mod triu_u32; +// mod triu_u32_neg; +// mod triu_u32_one_row; +// mod triu_u32_out_neg; +// mod triu_u32_out_pos; +// mod triu_u32_pos; +// mod triu_u32_square; +// mod triu_u32_square_neg; +// mod triu_u32_zero; +// mod reduce_sum_square_fp16x16_export_do_not_keepdims; +// mod reduce_sum_square_fp16x16_export_keepdims; +// mod reduce_sum_square_fp16x16_export_negative_axes_keepdims; +// mod reduce_sum_square_fp8x23_export_do_not_keepdims; +// mod reduce_sum_square_fp8x23_export_keepdims; +// mod reduce_sum_square_fp8x23_export_negative_axes_keepdims; +// mod reduce_sum_square_i32_export_do_not_keepdims; +// mod reduce_sum_square_i32_export_keepdims; +// mod reduce_sum_square_i32_export_negative_axes_keepdims; +// mod reduce_sum_square_i8_export_do_not_keepdims; +// mod reduce_sum_square_i8_export_keepdims; +// mod reduce_sum_square_i8_export_negative_axes_keepdims; +// mod reduce_sum_square_u32_export_do_not_keepdims; +// mod reduce_sum_square_u32_export_keepdims; +// mod reduce_sum_square_u32_export_negative_axes_keepdims; +// mod reduce_l2_fp16x16_export_do_not_keepdims; +// mod reduce_l2_fp16x16_export_keepdims; +// mod reduce_l2_fp16x16_export_negative_axes_keepdims; +// mod reduce_l2_fp8x23_export_do_not_keepdims; +// mod reduce_l2_fp8x23_export_keepdims; +// mod reduce_l2_fp8x23_export_negative_axes_keepdims; +// mod reduce_l1_fp16x16_export_do_not_keepdims; +// mod reduce_l1_fp16x16_export_keepdims; +// mod reduce_l1_fp16x16_export_negative_axes_keepdims; +// mod reduce_l1_fp8x23_export_do_not_keepdims; +// mod reduce_l1_fp8x23_export_keepdims; +// mod reduce_l1_fp8x23_export_negative_axes_keepdims; +// mod reduce_l1_i32_export_do_not_keepdims; +// mod reduce_l1_i32_export_keepdims; +// mod reduce_l1_i32_export_negative_axes_keepdims; +// mod reduce_l1_i8_export_do_not_keepdims; +// mod reduce_l1_i8_export_keepdims; +// mod reduce_l1_i8_export_negative_axes_keepdims; +// mod reduce_l1_u32_export_do_not_keepdims; +// mod reduce_l1_u32_export_keepdims; +// mod reduce_l1_u32_export_negative_axes_keepdims; +// mod reduce_prod_fp16x16_1D; +// mod reduce_prod_fp16x16_2D_default; +// mod reduce_prod_fp16x16_2D_keepdims; +// mod reduce_prod_fp16x16_2D_axis_1; +// mod reduce_prod_fp8x23_1D; +// mod reduce_prod_fp8x23_2D_default; +// mod reduce_prod_fp8x23_2D_keepdims; +// mod reduce_prod_fp8x23_2D_axis_1; +// mod reduce_prod_i32_1D; +// mod reduce_prod_i32_2D_default; +// mod reduce_prod_i32_2D_keepdims; +// mod reduce_prod_i32_2D_axis_1; +// mod reduce_prod_i8_1D; +// mod reduce_prod_i8_2D_default; +// mod reduce_prod_i8_2D_keepdims; +// mod reduce_prod_i8_2D_axis_1; +// mod reduce_prod_u32_1D; +// mod reduce_prod_u32_2D_default; +// mod reduce_prod_u32_2D_keepdims; +// mod reduce_prod_u32_2D_axis_1; +// mod gather_elements_fp16x16_3d_default; +// mod gather_elements_fp16x16_3d_axis1; +// mod gather_elements_fp16x16_3d_axis2; +// mod gather_elements_fp8x23_3d_default; +// mod gather_elements_fp8x23_3d_axis1; +// mod gather_elements_fp8x23_3d_axis2; +// mod gather_elements_i8_3d_default; +// mod gather_elements_i8_3d_axis1; +// mod gather_elements_i32_3d_default; +// mod gather_elements_i32_3d_axis1; +// mod gather_elements_i32_3d_axis2; +// mod gather_elements_u32_default; +// mod gather_elements_u32_axis1; +// mod gather_elements_u32_axis2; +// mod gather_elements_u32_axis3; +// mod sequence_length_fp16x16; +// mod sequence_length_fp16x16_broadcast; +// mod sequence_length_fp8x23; +// mod sequence_length_fp8x23_broadcast; +// mod sequence_length_i32; +// mod sequence_length_i32_broadcast; +// mod sequence_length_i8; +// mod sequence_length_i8_broadcast; +// mod sequence_length_u32; +// mod sequence_length_u32_broadcast; +// mod sequence_at_u32_positive; +// mod sequence_at_u32_negative; +// mod sequence_at_fp16x16_positive; +// mod sequence_at_fp16x16_negative; +// mod sequence_at_fp8x23_positive; +// mod sequence_at_fp8x23_negative; +// mod sequence_at_i32_positive; +// mod sequence_at_i32_negative; +// mod sequence_at_i8_positive; +// mod sequence_at_i8_negative; +// mod reduce_min_fp16x16_1D; +// mod reduce_min_fp16x16_2D_default; +// mod reduce_min_fp16x16_2D_keepdims; +// mod reduce_min_fp16x16_2D_axis_1; +// mod reduce_min_fp8x23_1D; +// mod reduce_min_fp8x23_2D_default; +// mod reduce_min_fp8x23_2D_keepdims; +// mod reduce_min_fp8x23_2D_axis_1; +// mod reduce_min_i32_1D; +// mod reduce_min_i32_2D_default; +// mod reduce_min_i32_2D_keepdims; +// mod reduce_min_i32_2D_axis_1; +// mod reduce_min_i8_1D; +// mod reduce_min_i8_2D_default; +// mod reduce_min_i8_2D_keepdims; +// mod reduce_min_i8_2D_axis_1; +// mod reduce_min_u32_1D; +// mod reduce_min_u32_2D_default; +// mod reduce_min_u32_2D_keepdims; +// mod reduce_min_u32_2D_axis_1; +// mod sequence_construct_fp16x16; +// mod sequence_construct_fp8x23; +// mod sequence_construct_i32; +// mod sequence_construct_i8; +// mod sequence_construct_u32; +// mod shrink_hard_fp16x16; +// mod shrink_soft_fp16x16; +// mod shrink_hard_fp8x23; +// mod shrink_soft_fp8x23; +// mod sequence_empty_fp16x16; +// mod sequence_empty_fp8x23; +// mod sequence_empty_i32; +// mod sequence_empty_i8; +// mod sequence_empty_u32; +// mod reduce_mean_fp16x16_1D; +// mod reduce_mean_fp16x16_2D_default; +// mod reduce_mean_fp16x16_2D_keepdims; +// mod reduce_mean_fp16x16_2D_axis_1; +// mod reduce_mean_fp8x23_1D; +// mod reduce_mean_fp8x23_2D_default; +// mod reduce_mean_fp8x23_2D_keepdims; +// mod reduce_mean_fp8x23_2D_axis_1; +// mod reduce_mean_i32_1D; +// mod reduce_mean_i32_2D_default; +// mod reduce_mean_i32_2D_keepdims; +// mod reduce_mean_i32_2D_axis_1; +// mod reduce_mean_i8_1D; +// mod reduce_mean_i8_2D_default; +// mod reduce_mean_i8_2D_keepdims; +// mod reduce_mean_i8_2D_axis_1; +// mod reduce_mean_u32_1D; +// mod reduce_mean_u32_2D_default; +// mod reduce_mean_u32_2D_keepdims; +// mod reduce_mean_u32_2D_axis_1; +// mod pow_fp16x16; +// mod pow_fp16x16_broadcast; +// mod pow_fp8x23; +// mod pow_fp8x23_broadcast; +// mod sequence_erase_u32_positive; +// mod sequence_erase_u32_negative; +// mod sequence_erase_u32_empty; +// mod sequence_erase_fp16x16_positive; +// mod sequence_erase_fp16x16_negative; +// mod sequence_erase_fp16x16_empty; +// mod sequence_erase_fp8x23_positive; +// mod sequence_erase_fp8x23_negative; +// mod sequence_erase_fp8x23_empty; +// mod sequence_erase_i32_positive; +// mod sequence_erase_i32_negative; +// mod sequence_erase_i32_empty; +// mod sequence_erase_i8_positive; +// mod sequence_erase_i8_negative; +// mod sequence_erase_i8_empty; +// mod sequence_insert_fp16x16; +// mod sequence_insert_fp8x23; +// mod sequence_insert_i32; +// mod sequence_insert_i8; +// mod sequence_insert_u32; +// mod concat_from_sequence_fp8x23_new_axis_zero; +// mod concat_from_sequence_fp8x23_new_axis_one; +// mod concat_from_sequence_fp8x23_new_axis_default; +// mod concat_from_sequence_fp16x16_new_axis_zero; +// mod concat_from_sequence_fp16x16_new_axis_one; +// mod concat_from_sequence_fp16x16_new_axis_default; +// mod concat_from_sequence_i32_new_axis_zero; +// mod concat_from_sequence_i32_new_axis_one; +// mod concat_from_sequence_i32_new_axis_default; +// mod concat_from_sequence_i8_new_axis_zero; +// mod concat_from_sequence_i8_new_axis_one; +// mod concat_from_sequence_i8_new_axis_default; +// mod concat_from_sequence_u32_new_axis_zero; +// mod concat_from_sequence_u32_new_axis_one; +// mod concat_from_sequence_u32_new_axis_default; +// mod is_nan_fp16x16; +// mod is_nan_fp8x23; +// mod is_inf_fp16x16; +// mod is_inf_fp8x23; +// mod is_inf_i32; +// mod is_inf_i8; +// mod is_inf_u32; +// mod is_pos_inf_fp16x16; +// mod is_neg_inf_fp16x16; +// mod is_pos_inf_fp8x23; +// mod is_neg_inf_fp8x23; +// mod is_pos_inf_i32; +// mod is_neg_inf_i32; +// mod is_pos_inf_i8; +// mod is_neg_inf_i8; +// mod reduce_log_sum_fp8x23_export_do_not_keepdims; +// mod reduce_log_sum_fp8x23_export_keepdims; +// mod reduce_log_sum_fp8x23_export_negative_axes_keepdims; +// mod reduce_log_sum_fp16x16_export_do_not_keepdims; +// mod reduce_log_sum_fp16x16_export_keepdims; +// mod reduce_log_sum_fp16x16_export_negative_axes_keepdims; +// mod and_bool; +// mod erf_fp16x16; +// mod erf_fp8x23; +// mod unique_fp16x16_without_axis_sorted; +// mod unique_fp16x16_with_axis_zero_sorted; +// mod unique_u32_without_axis_sorted; +// mod unique_u32_without_axis_not_sorted; +// mod unique_u32_with_axis_zero_sorted; +// mod unique_u32_with_axis_zero_not_sorted; +// mod unique_u32_with_axis_one_sorted; +// mod unique_u32_with_axis_one_not_sorted; +// mod gather_nd_fp16x16_3d_default; +// mod gather_nd_fp16x16_3d_batch_dims1; +// mod gather_nd_fp16x16_3d_batch_dims2; +// mod gather_nd_fp8x23_3d_default; +// mod gather_nd_fp8x23_3d_batch_dims1; +// mod gather_nd_fp8x23_3d_batch_dims2; +// mod gather_nd_i32_3d_default; +// mod gather_nd_i32_3d_batch_dims1; +// mod gather_nd_i32_3d_batch_dims2; +// mod gather_nd_i8_3d_default; +// mod gather_nd_i8_3d_batch_dims1; +// mod gather_nd_u32_default; +// mod gather_nd_u32_batch_dims1; +// mod gather_nd_u32_batch_dims2; From 7e224088def0952e3de962f9fc295dc3626c8a61 Mon Sep 17 00:00:00 2001 From: Hakeem Kazeem Date: Mon, 25 Dec 2023 14:40:08 +0100 Subject: [PATCH 14/38] compress tests --- nodegen/node/compress.py | 326 ++++ src/operators/tensor/core.cairo | 2 +- .../tensor/implementations/tensor_bool.cairo | 4 +- .../implementations/tensor_fp16x16.cairo | 4 +- .../implementations/tensor_fp16x16wide.cairo | 4 +- .../implementations/tensor_fp32x32.cairo | 4 +- .../implementations/tensor_fp64x64.cairo | 4 +- .../implementations/tensor_fp8x23.cairo | 4 +- .../implementations/tensor_fp8x23wide.cairo | 4 +- .../tensor/implementations/tensor_i32.cairo | 4 +- .../tensor/implementations/tensor_i8.cairo | 4 +- .../tensor/implementations/tensor_u32.cairo | 4 +- src/operators/tensor/math/compress.cairo | 255 +-- tests/nodes.cairo | 1723 +++++++++-------- tests/nodes/compress_fp16x16_3d_axis1.cairo | 24 + .../compress_fp16x16_3d_axis1/input_0.cairo | 195 ++ .../compress_fp16x16_3d_axis1/input_1.cairo | 15 + .../compress_fp16x16_3d_axis1/output_0.cairo | 150 ++ tests/nodes/compress_fp16x16_3d_axis2.cairo | 24 + .../compress_fp16x16_3d_axis2/input_0.cairo | 62 + .../compress_fp16x16_3d_axis2/input_1.cairo | 15 + .../compress_fp16x16_3d_axis2/output_0.cairo | 50 + tests/nodes/compress_fp16x16_3d_axis3.cairo | 24 + .../compress_fp16x16_3d_axis3/input_0.cairo | 111 ++ .../compress_fp16x16_3d_axis3/input_1.cairo | 13 + .../compress_fp16x16_3d_axis3/output_0.cairo | 63 + tests/nodes/compress_fp16x16_3d_default.cairo | 24 + .../compress_fp16x16_3d_default/input_0.cairo | 41 + .../compress_fp16x16_3d_default/input_1.cairo | 14 + .../output_0.cairo | 32 + tests/nodes/compress_fp16x16_3d_noaxis.cairo | 24 + .../compress_fp16x16_3d_noaxis/input_0.cairo | 41 + .../compress_fp16x16_3d_noaxis/input_1.cairo | 20 + .../compress_fp16x16_3d_noaxis/output_0.cairo | 19 + tests/nodes/compress_fp8x23_3d_axis1.cairo | 24 + .../compress_fp8x23_3d_axis1/input_0.cairo | 41 + .../compress_fp8x23_3d_axis1/input_1.cairo | 14 + .../compress_fp8x23_3d_axis1/output_0.cairo | 32 + tests/nodes/compress_fp8x23_3d_axis2.cairo | 24 + .../compress_fp8x23_3d_axis2/input_0.cairo | 41 + .../compress_fp8x23_3d_axis2/input_1.cairo | 14 + .../compress_fp8x23_3d_axis2/output_0.cairo | 32 + tests/nodes/compress_fp8x23_3d_default.cairo | 24 + .../compress_fp8x23_3d_default/input_0.cairo | 41 + .../compress_fp8x23_3d_default/input_1.cairo | 14 + .../compress_fp8x23_3d_default/output_0.cairo | 32 + tests/nodes/compress_i32_3d_axis1.cairo | 24 + .../nodes/compress_i32_3d_axis1/input_0.cairo | 41 + .../nodes/compress_i32_3d_axis1/input_1.cairo | 14 + .../compress_i32_3d_axis1/output_0.cairo | 32 + tests/nodes/compress_i32_3d_axis2.cairo | 24 + .../nodes/compress_i32_3d_axis2/input_0.cairo | 41 + .../nodes/compress_i32_3d_axis2/input_1.cairo | 14 + .../compress_i32_3d_axis2/output_0.cairo | 32 + tests/nodes/compress_i32_3d_default.cairo | 24 + .../compress_i32_3d_default/input_0.cairo | 41 + .../compress_i32_3d_default/input_1.cairo | 14 + .../compress_i32_3d_default/output_0.cairo | 32 + tests/nodes/compress_i8_3d_axis1.cairo | 24 + .../nodes/compress_i8_3d_axis1/input_0.cairo | 41 + .../nodes/compress_i8_3d_axis1/input_1.cairo | 14 + .../nodes/compress_i8_3d_axis1/output_0.cairo | 32 + tests/nodes/compress_i8_3d_axis2.cairo | 24 + .../nodes/compress_i8_3d_axis2/input_0.cairo | 41 + .../nodes/compress_i8_3d_axis2/input_1.cairo | 14 + .../nodes/compress_i8_3d_axis2/output_0.cairo | 32 + tests/nodes/compress_i8_3d_default.cairo | 24 + .../compress_i8_3d_default/input_0.cairo | 41 + .../compress_i8_3d_default/input_1.cairo | 14 + .../compress_i8_3d_default/output_0.cairo | 32 + tests/nodes/compress_u32_3d_axis1.cairo | 22 + .../nodes/compress_u32_3d_axis1/input_0.cairo | 49 + .../nodes/compress_u32_3d_axis1/input_1.cairo | 14 + .../compress_u32_3d_axis1/output_0.cairo | 31 + tests/nodes/compress_u32_3d_axis2.cairo | 22 + .../nodes/compress_u32_3d_axis2/input_0.cairo | 61 + .../nodes/compress_u32_3d_axis2/input_1.cairo | 14 + .../compress_u32_3d_axis2/output_0.cairo | 37 + tests/nodes/compress_u32_3d_axis2_2.cairo | 22 + .../compress_u32_3d_axis2_2/input_0.cairo | 73 + .../compress_u32_3d_axis2_2/input_1.cairo | 14 + .../compress_u32_3d_axis2_2/output_0.cairo | 37 + tests/nodes/compress_u32_3d_axis3.cairo | 22 + .../nodes/compress_u32_3d_axis3/input_0.cairo | 284 +++ .../nodes/compress_u32_3d_axis3/input_1.cairo | 17 + .../compress_u32_3d_axis3/output_0.cairo | 194 ++ tests/nodes/compress_u32_3d_default.cairo | 22 + .../compress_u32_3d_default/input_0.cairo | 61 + .../compress_u32_3d_default/input_1.cairo | 13 + .../compress_u32_3d_default/output_0.cairo | 37 + 90 files changed, 4315 insertions(+), 1037 deletions(-) create mode 100644 nodegen/node/compress.py create mode 100644 tests/nodes/compress_fp16x16_3d_axis1.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis1/input_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis1/input_1.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis1/output_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis2.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis2/input_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis2/input_1.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis2/output_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis3.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis3/input_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis3/input_1.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_axis3/output_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_default.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_default/input_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_default/input_1.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_default/output_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_noaxis.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_noaxis/input_0.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_noaxis/input_1.cairo create mode 100644 tests/nodes/compress_fp16x16_3d_noaxis/output_0.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis1.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis1/input_0.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis1/input_1.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis1/output_0.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis2.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis2/input_0.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis2/input_1.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_axis2/output_0.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_default.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_default/input_0.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_default/input_1.cairo create mode 100644 tests/nodes/compress_fp8x23_3d_default/output_0.cairo create mode 100644 tests/nodes/compress_i32_3d_axis1.cairo create mode 100644 tests/nodes/compress_i32_3d_axis1/input_0.cairo create mode 100644 tests/nodes/compress_i32_3d_axis1/input_1.cairo create mode 100644 tests/nodes/compress_i32_3d_axis1/output_0.cairo create mode 100644 tests/nodes/compress_i32_3d_axis2.cairo create mode 100644 tests/nodes/compress_i32_3d_axis2/input_0.cairo create mode 100644 tests/nodes/compress_i32_3d_axis2/input_1.cairo create mode 100644 tests/nodes/compress_i32_3d_axis2/output_0.cairo create mode 100644 tests/nodes/compress_i32_3d_default.cairo create mode 100644 tests/nodes/compress_i32_3d_default/input_0.cairo create mode 100644 tests/nodes/compress_i32_3d_default/input_1.cairo create mode 100644 tests/nodes/compress_i32_3d_default/output_0.cairo create mode 100644 tests/nodes/compress_i8_3d_axis1.cairo create mode 100644 tests/nodes/compress_i8_3d_axis1/input_0.cairo create mode 100644 tests/nodes/compress_i8_3d_axis1/input_1.cairo create mode 100644 tests/nodes/compress_i8_3d_axis1/output_0.cairo create mode 100644 tests/nodes/compress_i8_3d_axis2.cairo create mode 100644 tests/nodes/compress_i8_3d_axis2/input_0.cairo create mode 100644 tests/nodes/compress_i8_3d_axis2/input_1.cairo create mode 100644 tests/nodes/compress_i8_3d_axis2/output_0.cairo create mode 100644 tests/nodes/compress_i8_3d_default.cairo create mode 100644 tests/nodes/compress_i8_3d_default/input_0.cairo create mode 100644 tests/nodes/compress_i8_3d_default/input_1.cairo create mode 100644 tests/nodes/compress_i8_3d_default/output_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis1.cairo create mode 100644 tests/nodes/compress_u32_3d_axis1/input_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis1/input_1.cairo create mode 100644 tests/nodes/compress_u32_3d_axis1/output_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2/input_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2/input_1.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2/output_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2_2.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2_2/input_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2_2/input_1.cairo create mode 100644 tests/nodes/compress_u32_3d_axis2_2/output_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis3.cairo create mode 100644 tests/nodes/compress_u32_3d_axis3/input_0.cairo create mode 100644 tests/nodes/compress_u32_3d_axis3/input_1.cairo create mode 100644 tests/nodes/compress_u32_3d_axis3/output_0.cairo create mode 100644 tests/nodes/compress_u32_3d_default.cairo create mode 100644 tests/nodes/compress_u32_3d_default/input_0.cairo create mode 100644 tests/nodes/compress_u32_3d_default/input_1.cairo create mode 100644 tests/nodes/compress_u32_3d_default/output_0.cairo diff --git a/nodegen/node/compress.py b/nodegen/node/compress.py new file mode 100644 index 000000000..cca518f54 --- /dev/null +++ b/nodegen/node/compress.py @@ -0,0 +1,326 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait + +class Compress(RunAll): + + @staticmethod + def compress_fp16x16(): + + def compress_3D(): + def default(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=0) + + x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "compress_fp16x16_3d_default" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(0))", + name= name) + + def axis1(): + x1 = np.arange(0,180).reshape(3,4,3,5).astype(np.int64) + x2 = np.array([1, 1, 1, 0]).astype(np.int64) + y = x1.compress(x2, axis=1) + + x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "compress_fp16x16_3d_axis1" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(1))", + name= name) + + def axis2(): + x1 = np.arange(0,48).reshape(4,3,4).astype(np.int64) + x2 = np.array([1, 0, 1, 1]).astype(np.int64) + y = x1.compress(x2, axis=2) + + x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "compress_fp16x16_3d_axis2" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(2))", + name= name) + + def axis3(): + x1 = np.arange(0,96).reshape(4,3,4, 2).astype(np.int64) + x2 = np.array([1, 0]).astype(np.int64) + y = x1.compress(x2, axis=3) + + x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "compress_fp16x16_3d_axis3" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(3))", + name= name) + + def noaxis(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64) + x2 = np.array([1, 0, 1, 0, 1, 1, 1, 1, 1]).astype(np.int64) + y = x1.compress(x2) + + x1 = Tensor(Dtype.FP16x16, x1.shape, to_fp(x1.flatten(), FixedImpl.FP16x16)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP16x16, y.shape, to_fp( + y.flatten(), FixedImpl.FP16x16)) + + name = "compress_fp16x16_3d_noaxis" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::None(()))", + name= name) + + default() + axis1() + axis2() + axis3() + noaxis() + compress_3D() + + @staticmethod + def compress_fp8x23(): + + def compress_3D(): + def default(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=0) + + x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23)) + + name = "compress_fp8x23_3d_default" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(0))", + name= name) + + def axis1(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=1) + + x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23)) + + name = "compress_fp8x23_3d_axis1" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(1))", + name= name) + + def axis2(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int64) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=2) + + x1 = Tensor(Dtype.FP8x23, x1.shape, to_fp(x1.flatten(), FixedImpl.FP8x23)) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.FP8x23, y.shape, to_fp(y.flatten(), FixedImpl.FP8x23)) + + name = "compress_fp8x23_3d_axis2" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(2))", + name= name) + + default() + axis1() + axis2() + compress_3D() + + @staticmethod + def compress_i8(): + + def compress_3D(): + def default(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8) + x2 = np.array([0, 1, 1]).astype(np.uint8) + y = x1.compress(x2, axis=0) + + x1 = Tensor(Dtype.I8, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "compress_i8_3d_default" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(0))", + name= name) + + def axis1(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8) + x2 = np.array([0, 1, 1]).astype(np.uint8) + y = x1.compress(x2, axis=1) + + x1 = Tensor(Dtype.I8, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "compress_i8_3d_axis1" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(1))", + name= name) + + def axis2(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int8) + x2 = np.array([0, 1, 1]).astype(np.uint8) + y = x1.compress(x2, axis=2) + + x1 = Tensor(Dtype.I8, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I8, y.shape, y.flatten()) + + name = "compress_i8_3d_axis2" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(2))", + name= name) + + default() + axis1() + axis2() + compress_3D() + + + @staticmethod + def compress_i32(): + + def compress_3D(): + def default(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32) + x2 = np.array([0, 1, 1]).astype(np.int32) + y = x1.compress(x2, axis=0) + + x1 = Tensor(Dtype.I32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "compress_i32_3d_default" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(0))", + name= name) + + def axis1(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32) + x2 = np.array([0, 1, 1]).astype(np.int32) + y = x1.compress(x2, axis=1) + + x1 = Tensor(Dtype.I32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "compress_i32_3d_axis1" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(1))", + name= name) + + def axis2(): + x1 = np.arange(0,27).reshape(3,3,3).astype(np.int32) + x2 = np.array([0, 1, 1]).astype(np.int32) + y = x1.compress(x2, axis=2) + + x1 = Tensor(Dtype.I32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.I32, y.shape, y.flatten()) + + name = "compress_i32_3d_axis2" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(2))", + name= name) + + default() + axis1() + axis2() + compress_3D() + + @staticmethod + def compress_u32(): + + def compress_3D(): + def default(): + x1 = np.arange(0,48).reshape(4,4,3).astype(np.uint32) + x2 = np.array([1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=0) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "compress_u32_3d_default" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(0))", + name= name) + + def axis1(): + x1 = np.arange(0,36).reshape(3,4,3).astype(np.uint32) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=1) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "compress_u32_3d_axis1" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(1))", + name= name) + + def axis2(): + x1 = np.arange(0,48).reshape(3,4,4).astype(np.uint32) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=2) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "compress_u32_3d_axis2" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(2))", + name= name) + + def axis2_2(): + x1 = np.arange(0,60).reshape(3,4,5).astype(np.uint32) + x2 = np.array([0, 1, 1]).astype(np.uint32) + y = x1.compress(x2, axis=2) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "compress_u32_3d_axis2_2" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(2))", + name= name) + + def axis3(): + x1 = np.arange(0,270).reshape(3,3,5,6).astype(np.uint32) + x2 = np.array([0, 1, 1,1,0,1]).astype(np.uint32) + y = x1.compress(x2, axis=3) + + x1 = Tensor(Dtype.U32, x1.shape, x1.flatten()) + x2 = Tensor(Dtype.U32, x2.shape, x2.flatten()) + y = Tensor(Dtype.U32, y.shape, y.flatten()) + + name = "compress_u32_3d_axis3" + make_test( + inputs = [x1, x2], output = y, func_sig = "input_0.compress(condition:input_1, axis:Option::Some(3))", + name= name) + + default() + axis1() + axis2() + axis2_2() + axis3() + compress_3D() diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index d3eef4158..91f3ecda8 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -5078,7 +5078,7 @@ trait TensorTrait { /// fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor; - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor; + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor; } diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index c011b5f6c..18f5185f1 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -476,8 +476,8 @@ impl BoolTensor of TensorTrait { math::gather_nd::gather_nd(self, indices, batch_dims) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index e9ac30c2b..1b0c94c01 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -538,8 +538,8 @@ impl FP16x16Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 53c60d745..a8dea5614 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -504,8 +504,8 @@ impl FP16x16WTensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 6cf2113c5..90c4b207b 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -539,8 +539,8 @@ impl FP32x32Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 313299072..962002f26 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -540,8 +540,8 @@ impl FP64x64Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 22608e870..93d3345cc 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -538,8 +538,8 @@ impl FP8x23Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index bebbc4075..95f80b234 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -491,8 +491,8 @@ impl FP8x23WTensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index e8f7756e2..95ff04b5d 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -535,8 +535,8 @@ impl I32Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index d6304ef27..ae9d539e7 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -533,8 +533,8 @@ impl I8Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 136dfe2c9..496e653e5 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -476,8 +476,8 @@ impl U32Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, indices: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, indices, axis) + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/math/compress.cairo b/src/operators/tensor/math/compress.cairo index 0f6a3d9b6..6380d5d15 100644 --- a/src/operators/tensor/math/compress.cairo +++ b/src/operators/tensor/math/compress.cairo @@ -13,7 +13,7 @@ use orion::numbers::NumberTrait; use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; -/// Cf: TensorTrait::gather_nd docstring +/// Cf: TensorTrait::compare docstring fn compress< T, impl TTensorTrait: TensorTrait, @@ -32,12 +32,12 @@ fn compress< assert((data_rank >= 1 ), 'data rank must > 1'); assert((condition_rank == 1), 'condition rank must be 1'); - let mut data_shape = *self.shape; let mut condition_shape = condition.shape; - // let mut data_shape_clone = data_shape.clone(); - // let mut condition_shape_clone = condition_shape.clone(); - assert(*data_shape.at(axis) >= condition.data.len(), 'index out of bound'); + + if (axis != 999) { + assert(*data_shape.at(axis) >= condition.data.len(), 'index out of bound'); + } let mut output_shape = ArrayTrait::new(); let mut index_data = ArrayTrait::new(); @@ -61,179 +61,106 @@ fn compress< }; }; - let mut ind = 0; - let mut loop_breaker = 1; - let mut other_loop_breaker = 1; - let mut multiplier = 1; + if (axis == 999) { + output_shape.append(output); - let mut data_shape_clone = data_shape.clone(); - loop { - match data_shape_clone.pop_front() { - Option::Some(val) => { - if (ind == axis) { - output_shape.append(output); - } - else { - output_shape.append(*val); - if (ind > axis) { - loop_breaker *= *val; + let mut total_shape = 1; + loop { + match data_shape.pop_front() { + Option::Some(val) => { + total_shape *= *val; + }, + Option::None(_) => { break; } + }; + }; + + let mut ind = 0; + loop { + match condition_data.pop_front() { + Option::Some(val) => { + if (ind == total_shape) {break; } + if (*val != 0){ + output_data.append(*self.data[ind]); } - if (ind >= axis) { - multiplier *= *val; + ind += 1; + }, + Option::None(_) => { break; } + }; + }; + } else { + let mut ind = 0; + let mut loop_breaker = 1; + let mut other_loop_breaker = 1; + let mut multiplier = 1; + + let mut data_shape_clone = data_shape.clone(); + loop { + match data_shape_clone.pop_front() { + Option::Some(val) => { + if (ind == axis) { + output_shape.append(output); } - if (ind < axis) { - other_loop_breaker *= *val; + else { + output_shape.append(*val); + if (ind > axis) { + loop_breaker *= *val; + } + if (ind >= axis) { + multiplier *= *val; + } + if (ind < axis) { + other_loop_breaker *= *val; + } } - } - ind += 1; - }, - Option::None(_) => { break; } + ind += 1; + }, + Option::None(_) => { break; } + }; }; - }; - - let mut ind = 0; - let mut inner_index: usize = 0; - let mut condition_data_clone = condition_data.clone(); - - loop { - if (ind == other_loop_breaker) {break;} + let mut ind = 0; + let mut ind_loop = 0; + + let mut inner_index: usize = 0; let mut condition_data_clone = condition_data.clone(); loop { - match condition_data_clone.pop_front() { - Option::Some(val) => { - if (*val != 0){ - let result = inner_index * loop_breaker ; - // + multiplier * ind - // 'Start'.print(); - // (inner_index).print(); - // (loop_breaker).print(); - // (multiplier).print(); - // (ind).print(); - // (result).print(); + if (ind == other_loop_breaker) {break;} + let mut condition_data_clone = condition_data.clone(); + inner_index = *data_shape.at(axis) * ind; + loop { + + match condition_data_clone.pop_front() { + Option::Some(val) => { + if (*val != 0){ + let result = inner_index * loop_breaker ; - - let mut data_ind:usize = result ; - loop { - if data_ind == result + loop_breaker { break; } - index_data.append(data_ind); - data_ind+=1; - }; - } - inner_index += 1; - }, - - Option::None(_) => { break; } + let mut data_ind:usize = result ; + loop { + if data_ind == result + loop_breaker { break; } + index_data.append(data_ind); + data_ind+=1; + }; + } + inner_index += 1; + }, + Option::None(_) => { break; } + }; }; - }; - ind += 1; - }; - - loop { - match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, - Option::None(_) => { break; } + ind += 1; }; - }; + + loop { + match index_data.pop_front() { + Option::Some(val) => { + output_data.append(*self.data[val]); + }, + Option::None(_) => { break; } + }; + }; + } let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} - -// Tests-------------------------------------------------------------------------------------------------------------- - -use orion::utils::assert_eq; - -fn indices() -> Tensor { - let mut sizes = ArrayTrait::new(); - sizes.append(3); - - let mut data = ArrayTrait::new(); - data.append(0); - data.append(1); - data.append(1); - - let tensor = TensorTrait::::new(sizes.span(), data.span()); - return tensor; - -} - -fn data() -> Tensor { - let mut sizes = ArrayTrait::new(); - sizes.append(2); - sizes.append(2); - - let mut data = ArrayTrait::new(); - data.append(0); - data.append(1); - data.append(2); - data.append(3); - - let tensor = TensorTrait::::new(sizes.span(), data.span()); - return tensor; -} - -fn data1() -> Tensor { - let mut sizes = ArrayTrait::new(); - sizes.append(3); - sizes.append(3); - sizes.append(3); - - let mut data = ArrayTrait::new(); - - data.append(1); - data.append(2); - data.append(3); - data.append(4); - data.append(5); - data.append(6); - data.append(7); - data.append(8); - data.append(9); - data.append(1); - data.append(2); - data.append(3); - data.append(4); - data.append(5); - data.append(6); - data.append(7); - data.append(8); - data.append(9); - data.append(1); - data.append(2); - data.append(3); - data.append(4); - data.append(5); - data.append(6); - data.append(7); - data.append(8); - data.append(9); - - let tensor = TensorTrait::::new(sizes.span(), data.span()); - return tensor; -} - -#[test] -#[available_gas(20000000000)] -fn test_gather_elements_default() { - let data = data1(); - let indices = indices(); - - let y = data.compress(indices: indices, axis:Option::Some(0)); - let mut output = y.data; - - loop { - match output.pop_front() { - Option::Some(val) => { - (*val).print(); - - }, - Option::None(_) => { break; } - }; - }; - } \ No newline at end of file diff --git a/tests/nodes.cairo b/tests/nodes.cairo index 48ca8e25f..e3eccf547 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -1,852 +1,871 @@ -// mod abs_fp16x16; -// mod abs_fp8x23; -// mod abs_i32; -// mod abs_i8; -// mod acos_fp16x16; -// mod acos_fp8x23; -// mod acosh_fp16x16; -// mod acosh_fp8x23; -// mod add_fp16x16; -// mod add_fp16x16_broadcast; -// mod add_fp8x23; -// mod add_fp8x23_broadcast; -// mod add_i32; -// mod add_i32_broadcast; -// mod add_i8; -// mod add_i8_broadcast; -// mod add_u32; -// mod add_u32_broadcast; -// mod argmax_fp16x16_1D_default; -// mod argmax_fp16x16_1D_keepdims_false; -// mod argmax_fp16x16_1D_last_index; -// mod argmax_fp16x16_2D_default; -// mod argmax_fp16x16_2D_keepdims_false; -// mod argmax_fp16x16_2D_last_index; -// mod argmax_fp16x16_3D_default; -// mod argmax_fp16x16_3D_keepdims_false; -// mod argmax_fp16x16_3D_last_index; -// mod argmax_fp8x23_1D_default; -// mod argmax_fp8x23_1D_keepdims_false; -// mod argmax_fp8x23_1D_last_index; -// mod argmax_fp8x23_2D_default; -// mod argmax_fp8x23_2D_keepdims_false; -// mod argmax_fp8x23_2D_last_index; -// mod argmax_fp8x23_3D_default; -// mod argmax_fp8x23_3D_keepdims_false; -// mod argmax_fp8x23_3D_last_index; -// mod argmax_i32_1D_default; -// mod argmax_i32_1D_keepdims_false; -// mod argmax_i32_1D_last_index; -// mod argmax_i32_2D_default; -// mod argmax_i32_2D_keepdims_false; -// mod argmax_i32_2D_last_index; -// mod argmax_i32_3D_default; -// mod argmax_i32_3D_keepdims_false; -// mod argmax_i32_3D_last_index; -// mod argmax_i8_1D_default; -// mod argmax_i8_1D_keepdims_false; -// mod argmax_i8_1D_last_index; -// mod argmax_i8_2D_default; -// mod argmax_i8_2D_keepdims_false; -// mod argmax_i8_2D_last_index; -// mod argmax_i8_3D_default; -// mod argmax_i8_3D_keepdims_false; -// mod argmax_i8_3D_last_index; -// mod argmax_u32_1D_default; -// mod argmax_u32_1D_keepdims_false; -// mod argmax_u32_1D_last_index; -// mod argmax_u32_2D_default; -// mod argmax_u32_2D_keepdims_false; -// mod argmax_u32_2D_last_index; -// mod argmax_u32_3D_default; -// mod argmax_u32_3D_keepdims_false; -// mod argmax_u32_3D_last_index; -// mod argmin_fp16x16_1D_default; -// mod argmin_fp16x16_1D_keepdims_false; -// mod argmin_fp16x16_1D_last_index; -// mod argmin_fp16x16_2D_default; -// mod argmin_fp16x16_2D_keepdims_false; -// mod argmin_fp16x16_2D_last_index; -// mod argmin_fp16x16_3D_default; -// mod argmin_fp16x16_3D_keepdims_false; -// mod argmin_fp16x16_3D_last_index; -// mod argmin_fp8x23_1D_default; -// mod argmin_fp8x23_1D_keepdims_false; -// mod argmin_fp8x23_1D_last_index; -// mod argmin_fp8x23_2D_default; -// mod argmin_fp8x23_2D_keepdims_false; -// mod argmin_fp8x23_2D_last_index; -// mod argmin_fp8x23_3D_default; -// mod argmin_fp8x23_3D_keepdims_false; -// mod argmin_fp8x23_3D_last_index; -// mod argmin_i32_1D_default; -// mod argmin_i32_1D_keepdims_false; -// mod argmin_i32_1D_last_index; -// mod argmin_i32_2D_default; -// mod argmin_i32_2D_keepdims_false; -// mod argmin_i32_2D_last_index; -// mod argmin_i32_3D_default; -// mod argmin_i32_3D_keepdims_false; -// mod argmin_i32_3D_last_index; -// mod argmin_i8_1D_default; -// mod argmin_i8_1D_keepdims_false; -// mod argmin_i8_1D_last_index; -// mod argmin_i8_2D_default; -// mod argmin_i8_2D_keepdims_false; -// mod argmin_i8_2D_last_index; -// mod argmin_i8_3D_default; -// mod argmin_i8_3D_keepdims_false; -// mod argmin_i8_3D_last_index; -// mod argmin_u32_1D_default; -// mod argmin_u32_1D_keepdims_false; -// mod argmin_u32_1D_last_index; -// mod argmin_u32_2D_default; -// mod argmin_u32_2D_keepdims_false; -// mod argmin_u32_2D_last_index; -// mod argmin_u32_3D_default; -// mod argmin_u32_3D_keepdims_false; -// mod argmin_u32_3D_last_index; -// mod asin_fp16x16; -// mod asin_fp8x23; -// mod asinh_fp16x16; -// mod asinh_fp8x23; -// mod atan_fp16x16; -// mod atan_fp8x23; -// mod ceil_fp16x16; -// mod ceil_fp8x23; -// mod concat_fp16x16_1d; -// mod concat_fp16x16_2d; -// mod concat_fp16x16_3d_default; -// mod concat_fp16x16_3d_axis_1; -// mod concat_fp16x16_3d_axis_2; -// mod concat_fp16x16_3d_three_tensors_axis_1; -// mod concat_fp16x16_3d_three_tensors_axis_2; -// mod concat_fp8x23_1d; -// mod concat_fp8x23_2d; -// mod concat_fp8x23_3d_default; -// mod concat_fp8x23_3d_axis_1; -// mod concat_fp8x23_3d_axis_2; -// mod concat_fp8x23_3d_three_tensors_axis_1; -// mod concat_fp8x23_3d_three_tensors_axis_2; -// mod concat_i32_1d; -// mod concat_i32_2d; -// mod concat_i32_3d_default; -// mod concat_i32_3d_axis_1; -// mod concat_i32_3d_axis_2; -// mod concat_i32_3d_three_tensors_axis_1; -// mod concat_i32_3d_three_tensors_axis_2; -// mod concat_i8_1d; -// mod concat_i8_2d; -// mod concat_i8_3d_default; -// mod concat_i8_3d_axis_1; -// mod concat_i8_3d_axis_2; -// mod concat_i8_3d_three_tensors_axis_1; -// mod concat_i8_3d_three_tensors_axis_2; -// mod concat_u32_1d; -// mod concat_u32_2d; -// mod concat_u32_3d_default; -// mod concat_u32_3d_axis_1; -// mod concat_u32_3d_axis_2; -// mod concat_u32_3d_three_tensors_axis_1; -// mod concat_u32_3d_three_tensors_axis_2; -// mod cos_fp16x16; -// mod cos_fp8x23; -// mod cosh_fp16x16; -// mod cosh_fp8x23; -// mod cumsum_fp16x16_1d_default; -// mod cumsum_fp16x16_1d_exclusive; -// mod cumsum_fp16x16_1d_reverse; -// mod cumsum_fp16x16_1d_reverse_exclusive; -// mod cumsum_fp16x16_2d_axis_0; -// mod cumsum_fp16x16_2d_axis_1; -// mod cumsum_fp8x23_1d_default; -// mod cumsum_fp8x23_1d_exclusive; -// mod cumsum_fp8x23_1d_reverse; -// mod cumsum_fp8x23_1d_reverse_exclusive; -// mod cumsum_fp8x23_2d_axis_0; -// mod cumsum_fp8x23_2d_axis_1; -// mod cumsum_i32_1d_default; -// mod cumsum_i32_1d_exclusive; -// mod cumsum_i32_1d_reverse; -// mod cumsum_i32_1d_reverse_exclusive; -// mod cumsum_i32_2d_axis_0; -// mod cumsum_i32_2d_axis_1; -// mod cumsum_i8_1d_default; -// mod cumsum_i8_1d_exclusive; -// mod cumsum_i8_1d_reverse; -// mod cumsum_i8_1d_reverse_exclusive; -// mod cumsum_i8_2d_axis_0; -// mod cumsum_i8_2d_axis_1; -// mod cumsum_u32_1d_default; -// mod cumsum_u32_1d_exclusive; -// mod cumsum_u32_1d_reverse; -// mod cumsum_u32_1d_reverse_exclusive; -// mod cumsum_u32_2d_axis_0; -// mod cumsum_u32_2d_axis_1; -// mod div_fp16x16; -// mod div_fp16x16_broadcast; -// mod div_fp8x23; -// mod div_fp8x23_broadcast; -// mod div_i32; -// mod div_i32_broadcast; -// mod div_i8; -// mod div_i8_broadcast; -// mod div_u32; -// mod div_u32_broadcast; -// mod equal_fp16x16; -// mod equal_fp16x16_broadcast; -// mod equal_fp8x23; -// mod equal_fp8x23_broadcast; -// mod equal_i32; -// mod equal_i32_broadcast; -// mod equal_i8; -// mod equal_i8_broadcast; -// mod equal_u32; -// mod equal_u32_broadcast; -// mod exp_fp16x16; -// mod exp_fp8x23; -// mod less_equal_fp16x16; -// mod less_equal_fp16x16_broadcast; -// mod less_equal_fp8x23; -// mod less_equal_fp8x23_broadcast; -// mod less_equal_i32; -// mod less_equal_i32_broadcast; -// mod less_equal_i8; -// mod less_equal_i8_broadcast; -// mod less_equal_u32; -// mod less_equal_u32_broadcast; -// mod greater_fp16x16; -// mod greater_fp16x16_broadcast; -// mod greater_fp8x23; -// mod greater_fp8x23_broadcast; -// mod greater_i32; -// mod greater_i32_broadcast; -// mod greater_i8; -// mod greater_i8_broadcast; -// mod greater_u32; -// mod greater_u32_broadcast; -// mod leaky_relu_fp16x16; -// mod leaky_relu_fp8x23; -// mod linear_fp16x16; -// mod linear_fp8x23; -// mod linear_i32; -// mod linear_i8; -// mod linear_u32; -// mod log_fp16x16; -// mod log_fp8x23; -// mod logsoftmax_fp16x16_axis_0; -// mod logsoftmax_fp16x16_axis_1; -// mod logsoftmax_fp8x23_axis_0; -// mod logsoftmax_fp8x23_axis_1; -// mod matmul_fp16x16_1d; -// mod matmul_fp16x16_2x2; -// mod matmul_fp16x16_2x1; -// mod matmul_fp16x16_1x2; -// mod matmul_fp8x23_1d; -// mod matmul_fp8x23_2x2; -// mod matmul_fp8x23_2x1; -// mod matmul_fp8x23_1x2; -// mod matmul_i32_1d; -// mod matmul_i32_2x2; -// mod matmul_i32_2x1; -// mod matmul_i32_1x2; -// mod matmul_i8_1d; -// mod matmul_i8_2x2; -// mod matmul_i8_2x1; -// mod matmul_i8_1x2; -// mod matmul_u32_1d; -// mod matmul_u32_2x2; -// mod matmul_u32_2x1; -// mod matmul_u32_1x2; -// mod mul_fp16x16; -// mod mul_fp16x16_broadcast; -// mod mul_fp8x23; -// mod mul_fp8x23_broadcast; -// mod mul_i32; -// mod mul_i32_broadcast; -// mod mul_i8; -// mod mul_i8_broadcast; -// mod mul_u32; -// mod mul_u32_broadcast; -// mod or_fp16x16; -// mod or_fp16x16_broadcast; -// mod or_fp8x23; -// mod or_fp8x23_broadcast; -// mod or_i32; -// mod or_i32_broadcast; -// mod or_i8; -// mod or_i8_broadcast; -// mod or_u32; -// mod or_u32_broadcast; -// mod reduce_sum_fp16x16_1D; -// mod reduce_sum_fp16x16_2D_default; -// mod reduce_sum_fp16x16_2D_keepdims; -// mod reduce_sum_fp16x16_2D_axis_1; -// mod reduce_sum_fp8x23_1D; -// mod reduce_sum_fp8x23_2D_default; -// mod reduce_sum_fp8x23_2D_keepdims; -// mod reduce_sum_fp8x23_2D_axis_1; -// mod reduce_sum_i32_1D; -// mod reduce_sum_i32_2D_default; -// mod reduce_sum_i32_2D_keepdims; -// mod reduce_sum_i32_2D_axis_1; -// mod reduce_sum_i8_1D; -// mod reduce_sum_i8_2D_default; -// mod reduce_sum_i8_2D_keepdims; -// mod reduce_sum_i8_2D_axis_1; -// mod reduce_sum_u32_1D; -// mod reduce_sum_u32_2D_default; -// mod reduce_sum_u32_2D_keepdims; -// mod reduce_sum_u32_2D_axis_1; -// mod relu_fp16x16; -// mod relu_fp8x23; -// mod relu_i32; -// mod relu_i8; -// mod sigmoid_fp16x16; -// mod sigmoid_fp8x23; -// mod sin_fp16x16; -// mod sin_fp8x23; -// mod sinh_fp16x16; -// mod sinh_fp8x23; -// mod softmax_fp16x16; -// mod softmax_fp8x23; -// mod softplus_fp8x23; -// mod softplus_fp16x16; -// mod softsign_fp8x23; -// mod softsign_fp16x16; -// mod sqrt_fp16x16; -// mod sqrt_fp8x23; -// mod sub_fp16x16; -// mod sub_fp16x16_broadcast; -// mod sub_fp8x23; -// mod sub_fp8x23_broadcast; -// mod sub_i32; -// mod sub_i32_broadcast; -// mod sub_i8; -// mod sub_i8_broadcast; -// mod sub_u32; -// mod sub_u32_broadcast; -// mod tanh_fp16x16; -// mod tanh_fp8x23; -// mod transpose_fp16x16_2d; -// mod transpose_fp16x16_3d; -// mod transpose_fp8x23_2d; -// mod transpose_fp8x23_3d; -// mod transpose_i32_2d; -// mod transpose_i32_3d; -// mod transpose_i8_2d; -// mod transpose_i8_3d; -// mod transpose_u32_2d; -// mod transpose_u32_3d; -// mod xor_fp16x16; -// mod xor_fp16x16_broadcast; -// mod xor_fp8x23; -// mod xor_fp8x23_broadcast; -// mod xor_i32; -// mod xor_i32_broadcast; -// mod xor_i8; -// mod xor_i8_broadcast; -// mod xor_u32; -// mod xor_u32_broadcast; -// mod less_fp16x16; -// mod less_fp16x16_broadcast; -// mod less_fp8x23; -// mod less_fp8x23_broadcast; -// mod less_i32; -// mod less_i32_broadcast; -// mod less_i8; -// mod less_i8_broadcast; -// mod less_u32; -// mod less_u32_broadcast; -// mod greater_equal_fp16x16; -// mod greater_equal_fp16x16_broadcast; -// mod greater_equal_fp8x23; -// mod greater_equal_fp8x23_broadcast; -// mod greater_equal_i32; -// mod greater_equal_i32_broadcast; -// mod greater_equal_i8; -// mod greater_equal_i8_broadcast; -// mod greater_equal_u32; -// mod greater_equal_u32_broadcast; -// mod slice_fp16x16_2d; -// mod slice_fp16x16_3d; -// mod slice_fp8x23_2d; -// mod slice_fp8x23_3d; -// mod slice_i32_2d; -// mod slice_i32_3d; -// mod slice_i8_2d; -// mod slice_i8_3d; -// mod slice_u32_2d; -// mod slice_u32_3d; -// mod gather_fp8x23_3d_default; -// mod gather_fp8x23_3d_axis1; -// mod gather_fp8x23_3d_axis2; -// mod gather_fp16x16_3d_default; -// mod gather_fp16x16_3d_axis1; -// mod gather_fp16x16_3d_axis2; -// mod gather_i8_3d_default; -// mod gather_i8_3d_axis1; -// mod gather_i8_3d_axis2; -// mod gather_i32_3d_default; -// mod gather_i32_3d_axis1; -// mod gather_i32_3d_axis2; -// mod gather_u32_3d_default; -// mod gather_u32_3d_axis1; -// mod gather_u32_3d_axis2; -// mod nonzero_fp16x16_2d; -// mod nonzero_fp16x16_3d; -// mod nonzero_fp8x23_2d; -// mod nonzero_fp8x23_3d; -// mod nonzero_i32_2d; -// mod nonzero_i32_3d; -// mod nonzero_i8_2d; -// mod nonzero_i8_3d; -// mod nonzero_u32_2d; -// mod nonzero_u32_3d; -// mod squeeze_fP16x16; -// mod squeeze_fP8x23; -// mod squeeze_i32; -// mod squeeze_i8; -// mod squeeze_u32; -// mod unsqueeze_fp16x16_2d; -// mod unsqueeze_fp16x16_3d; -// mod unsqueeze_fp8x23_2d; -// mod unsqueeze_fp8x23_3d; -// mod unsqueeze_i32_2d; -// mod unsqueeze_i32_3d; -// mod unsqueeze_i8_2d; -// mod unsqueeze_i8_3d; -// mod unsqueeze_u32_2d; -// mod unsqueeze_u32_3d; -// mod sign_fP16x16; -// mod sign_fP8x23; -// mod sign_fail; -// mod sign_i32; -// mod sign_i8; -// mod clip_fp16x16_2d; -// mod clip_fp16x16_3d; -// mod clip_fp8x23_2d; -// mod clip_fp8x23_3d; -// mod clip_i32_2d; -// mod clip_i32_3d; -// mod clip_i8_2d; -// mod clip_i8_3d; -// mod clip_u32_2d; -// mod clip_u32_3d; -// mod identity_fP16x16; -// mod identity_fP8x23; -// mod identity_i32; -// mod identity_i8; -// mod identity_u32; -// mod thresholded_relu_fp16x16; -// mod thresholded_relu_fp8x23; -// mod hard_sigmoid_fp8x23; -// mod hard_sigmoid_fp16x16; -// mod neg_fp16x16; -// mod neg_fp8x23; -// mod neg_i32; -// mod neg_i8; -// mod gemm_all_attributes; -// mod gemm_alpha; -// mod gemm_beta; -// mod gemm_default_matrix_bias; -// mod gemm_default_vector_bias; -// mod gemm_default_no_bias; -// mod gemm_transposeA; -// mod gemm_transposeB; -// mod min_fp16x16_three_tensors; -// mod min_fp16x16_broadcast_three_tensors; -// mod min_fp16x16_two_tensors; -// mod min_fp16x16_broadcast_two_tensors; -// mod min_fp8x23_three_tensors; -// mod min_fp8x23_broadcast_three_tensors; -// mod min_fp8x23_two_tensors; -// mod min_fp8x23_broadcast_two_tensors; -// mod min_i32_three_tensors; -// mod min_i32_broadcast_three_tensors; -// mod min_i32_two_tensors; -// mod min_i32_broadcast_two_tensors; -// mod min_i8_three_tensors; -// mod min_i8_broadcast_three_tensors; -// mod min_i8_two_tensors; -// mod min_i8_broadcast_two_tensors; -// mod min_u32_three_tensors; -// mod min_u32_broadcast_three_tensors; -// mod min_u32_two_tensors; -// mod min_u32_broadcast_two_tensors; -// mod where_fp16x16; -// mod where_fp16x16_broadcast; -// mod where_fp8x23; -// mod where_fp8x23_broadcast; -// mod where_i32; -// mod where_i32_broadcast; -// mod where_i8; -// mod where_i8_broadcast; -// mod where_u32; -// mod where_u32_broadcast; -// mod not_bool; -// mod round_fp16x16; -// mod round_fp8x23; -// mod max_fp16x16_three_tensors; -// mod max_fp16x16_broadcast_three_tensors; -// mod max_fp16x16_two_tensors; -// mod max_fp16x16_broadcast_two_tensors; -// mod max_fp8x23_three_tensors; -// mod max_fp8x23_broadcast_three_tensors; -// mod max_fp8x23_two_tensors; -// mod max_fp8x23_broadcast_two_tensors; -// mod max_i32_three_tensors; -// mod max_i32_broadcast_three_tensors; -// mod max_i32_two_tensors; -// mod max_i32_broadcast_two_tensors; -// mod max_i8_three_tensors; -// mod max_i8_broadcast_three_tensors; -// mod max_i8_two_tensors; -// mod max_i8_broadcast_two_tensors; -// mod max_u32_three_tensors; -// mod max_u32_broadcast_three_tensors; -// mod max_u32_two_tensors; -// mod max_u32_broadcast_two_tensors; -// mod scatter_fp16x16_3d_default; -// mod scatter_fp16x16_3d_axis1; -// mod scatter_fp16x16_3d_axis1_add; -// mod scatter_fp8x23_default; -// mod scatter_fp8x23_axis1; -// mod scatter_fp8x23_mul; -// mod scatter_i8_default; -// mod scatter_i8_axis1; -// mod scatter_i8_axis1_max; -// mod scatter_u32_default; -// mod scatter_u32_axis1; -// mod scatter_u32_add; -// mod array_feature_extractor_1D_i32; -// mod array_feature_extractor_1D_fp8x23; -// mod array_feature_extractor_1D_fp16x16; -// mod array_feature_extractor_2D_i32; -// mod array_feature_extractor_2D_fp8x23; -// mod array_feature_extractor_2D_fp16x16; -// mod array_feature_extractor_3D_i32; -// mod array_feature_extractor_3D_fp8x23; -// mod array_feature_extractor_3D_fp16x16; -// mod binarizer_fp16x16; -// mod binarizer_fp8x23; -// mod tril_fp16x16; -// mod tril_fp16x16_neg; -// mod tril_fp16x16_one_row; -// mod tril_fp16x16_out_neg; -// mod tril_fp16x16_out_pos; -// mod tril_fp16x16_pos; -// mod tril_fp16x16_square; -// mod tril_fp16x16_square_neg; -// mod tril_fp16x16_zero; -// mod triu_fp16x16; -// mod triu_fp16x16_neg; -// mod triu_fp16x16_one_row; -// mod triu_fp16x16_out_neg; -// mod triu_fp16x16_out_pos; -// mod triu_fp16x16_pos; -// mod triu_fp16x16_square; -// mod triu_fp16x16_square_neg; -// mod triu_fp16x16_zero; -// mod tril_fp8x23; -// mod tril_fp8x23_neg; -// mod tril_fp8x23_one_row; -// mod tril_fp8x23_out_neg; -// mod tril_fp8x23_out_pos; -// mod tril_fp8x23_pos; -// mod tril_fp8x23_square; -// mod tril_fp8x23_square_neg; -// mod tril_fp8x23_zero; -// mod triu_fp8x23; -// mod triu_fp8x23_neg; -// mod triu_fp8x23_one_row; -// mod triu_fp8x23_out_neg; -// mod triu_fp8x23_out_pos; -// mod triu_fp8x23_pos; -// mod triu_fp8x23_square; -// mod triu_fp8x23_square_neg; -// mod triu_fp8x23_zero; -// mod tril_i32; -// mod tril_neg_i32; -// mod tril_i32_one_row; -// mod tril_i32_out_neg; -// mod tril_i32_out_pos; -// mod tril_i32_pos; -// mod tril_i32_square; -// mod tril_i32_square_neg; -// mod tril_i32_zero; -// mod triu_i32; -// mod triu_i32_neg; -// mod triu_i32_one_row; -// mod triu_i32_out_neg; -// mod triu_i32_out_pos; -// mod triu_i32_pos; -// mod triu_i32_square; -// mod triu_i32_square_neg; -// mod triu_i32_zero; -// mod tril_i8; -// mod tril_i8_neg; -// mod tril_i8_one_row; -// mod tril_i8_out_neg; -// mod tril_i8_out_pos; -// mod tril_i8_pos; -// mod tril_i8_square; -// mod tril_i8_square_neg; -// mod tril_i8_zero; -// mod triu_i8; -// mod triu_i8_neg; -// mod triu_i8_one_row; -// mod triu_i8_out_neg; -// mod triu_i8_out_pos; -// mod triu_i8_pos; -// mod triu_i8_square; -// mod triu_i8_square_neg; -// mod triu_i8_zero; -// mod tril_u32; -// mod tril_u32_neg; -// mod tril_u32_one_row; -// mod tril_u32_out_neg; -// mod tril_u32_out_pos; -// mod tril_u32_pos; -// mod tril_u32_square; -// mod tril_u32_square_neg; -// mod tril_u32_zero; -// mod triu_u32; -// mod triu_u32_neg; -// mod triu_u32_one_row; -// mod triu_u32_out_neg; -// mod triu_u32_out_pos; -// mod triu_u32_pos; -// mod triu_u32_square; -// mod triu_u32_square_neg; -// mod triu_u32_zero; -// mod reduce_sum_square_fp16x16_export_do_not_keepdims; -// mod reduce_sum_square_fp16x16_export_keepdims; -// mod reduce_sum_square_fp16x16_export_negative_axes_keepdims; -// mod reduce_sum_square_fp8x23_export_do_not_keepdims; -// mod reduce_sum_square_fp8x23_export_keepdims; -// mod reduce_sum_square_fp8x23_export_negative_axes_keepdims; -// mod reduce_sum_square_i32_export_do_not_keepdims; -// mod reduce_sum_square_i32_export_keepdims; -// mod reduce_sum_square_i32_export_negative_axes_keepdims; -// mod reduce_sum_square_i8_export_do_not_keepdims; -// mod reduce_sum_square_i8_export_keepdims; -// mod reduce_sum_square_i8_export_negative_axes_keepdims; -// mod reduce_sum_square_u32_export_do_not_keepdims; -// mod reduce_sum_square_u32_export_keepdims; -// mod reduce_sum_square_u32_export_negative_axes_keepdims; -// mod reduce_l2_fp16x16_export_do_not_keepdims; -// mod reduce_l2_fp16x16_export_keepdims; -// mod reduce_l2_fp16x16_export_negative_axes_keepdims; -// mod reduce_l2_fp8x23_export_do_not_keepdims; -// mod reduce_l2_fp8x23_export_keepdims; -// mod reduce_l2_fp8x23_export_negative_axes_keepdims; -// mod reduce_l1_fp16x16_export_do_not_keepdims; -// mod reduce_l1_fp16x16_export_keepdims; -// mod reduce_l1_fp16x16_export_negative_axes_keepdims; -// mod reduce_l1_fp8x23_export_do_not_keepdims; -// mod reduce_l1_fp8x23_export_keepdims; -// mod reduce_l1_fp8x23_export_negative_axes_keepdims; -// mod reduce_l1_i32_export_do_not_keepdims; -// mod reduce_l1_i32_export_keepdims; -// mod reduce_l1_i32_export_negative_axes_keepdims; -// mod reduce_l1_i8_export_do_not_keepdims; -// mod reduce_l1_i8_export_keepdims; -// mod reduce_l1_i8_export_negative_axes_keepdims; -// mod reduce_l1_u32_export_do_not_keepdims; -// mod reduce_l1_u32_export_keepdims; -// mod reduce_l1_u32_export_negative_axes_keepdims; -// mod reduce_prod_fp16x16_1D; -// mod reduce_prod_fp16x16_2D_default; -// mod reduce_prod_fp16x16_2D_keepdims; -// mod reduce_prod_fp16x16_2D_axis_1; -// mod reduce_prod_fp8x23_1D; -// mod reduce_prod_fp8x23_2D_default; -// mod reduce_prod_fp8x23_2D_keepdims; -// mod reduce_prod_fp8x23_2D_axis_1; -// mod reduce_prod_i32_1D; -// mod reduce_prod_i32_2D_default; -// mod reduce_prod_i32_2D_keepdims; -// mod reduce_prod_i32_2D_axis_1; -// mod reduce_prod_i8_1D; -// mod reduce_prod_i8_2D_default; -// mod reduce_prod_i8_2D_keepdims; -// mod reduce_prod_i8_2D_axis_1; -// mod reduce_prod_u32_1D; -// mod reduce_prod_u32_2D_default; -// mod reduce_prod_u32_2D_keepdims; -// mod reduce_prod_u32_2D_axis_1; -// mod gather_elements_fp16x16_3d_default; -// mod gather_elements_fp16x16_3d_axis1; -// mod gather_elements_fp16x16_3d_axis2; -// mod gather_elements_fp8x23_3d_default; -// mod gather_elements_fp8x23_3d_axis1; -// mod gather_elements_fp8x23_3d_axis2; -// mod gather_elements_i8_3d_default; -// mod gather_elements_i8_3d_axis1; -// mod gather_elements_i32_3d_default; -// mod gather_elements_i32_3d_axis1; -// mod gather_elements_i32_3d_axis2; -// mod gather_elements_u32_default; -// mod gather_elements_u32_axis1; -// mod gather_elements_u32_axis2; -// mod gather_elements_u32_axis3; -// mod sequence_length_fp16x16; -// mod sequence_length_fp16x16_broadcast; -// mod sequence_length_fp8x23; -// mod sequence_length_fp8x23_broadcast; -// mod sequence_length_i32; -// mod sequence_length_i32_broadcast; -// mod sequence_length_i8; -// mod sequence_length_i8_broadcast; -// mod sequence_length_u32; -// mod sequence_length_u32_broadcast; -// mod sequence_at_u32_positive; -// mod sequence_at_u32_negative; -// mod sequence_at_fp16x16_positive; -// mod sequence_at_fp16x16_negative; -// mod sequence_at_fp8x23_positive; -// mod sequence_at_fp8x23_negative; -// mod sequence_at_i32_positive; -// mod sequence_at_i32_negative; -// mod sequence_at_i8_positive; -// mod sequence_at_i8_negative; -// mod reduce_min_fp16x16_1D; -// mod reduce_min_fp16x16_2D_default; -// mod reduce_min_fp16x16_2D_keepdims; -// mod reduce_min_fp16x16_2D_axis_1; -// mod reduce_min_fp8x23_1D; -// mod reduce_min_fp8x23_2D_default; -// mod reduce_min_fp8x23_2D_keepdims; -// mod reduce_min_fp8x23_2D_axis_1; -// mod reduce_min_i32_1D; -// mod reduce_min_i32_2D_default; -// mod reduce_min_i32_2D_keepdims; -// mod reduce_min_i32_2D_axis_1; -// mod reduce_min_i8_1D; -// mod reduce_min_i8_2D_default; -// mod reduce_min_i8_2D_keepdims; -// mod reduce_min_i8_2D_axis_1; -// mod reduce_min_u32_1D; -// mod reduce_min_u32_2D_default; -// mod reduce_min_u32_2D_keepdims; -// mod reduce_min_u32_2D_axis_1; -// mod sequence_construct_fp16x16; -// mod sequence_construct_fp8x23; -// mod sequence_construct_i32; -// mod sequence_construct_i8; -// mod sequence_construct_u32; -// mod shrink_hard_fp16x16; -// mod shrink_soft_fp16x16; -// mod shrink_hard_fp8x23; -// mod shrink_soft_fp8x23; -// mod sequence_empty_fp16x16; -// mod sequence_empty_fp8x23; -// mod sequence_empty_i32; -// mod sequence_empty_i8; -// mod sequence_empty_u32; -// mod reduce_mean_fp16x16_1D; -// mod reduce_mean_fp16x16_2D_default; -// mod reduce_mean_fp16x16_2D_keepdims; -// mod reduce_mean_fp16x16_2D_axis_1; -// mod reduce_mean_fp8x23_1D; -// mod reduce_mean_fp8x23_2D_default; -// mod reduce_mean_fp8x23_2D_keepdims; -// mod reduce_mean_fp8x23_2D_axis_1; -// mod reduce_mean_i32_1D; -// mod reduce_mean_i32_2D_default; -// mod reduce_mean_i32_2D_keepdims; -// mod reduce_mean_i32_2D_axis_1; -// mod reduce_mean_i8_1D; -// mod reduce_mean_i8_2D_default; -// mod reduce_mean_i8_2D_keepdims; -// mod reduce_mean_i8_2D_axis_1; -// mod reduce_mean_u32_1D; -// mod reduce_mean_u32_2D_default; -// mod reduce_mean_u32_2D_keepdims; -// mod reduce_mean_u32_2D_axis_1; -// mod pow_fp16x16; -// mod pow_fp16x16_broadcast; -// mod pow_fp8x23; -// mod pow_fp8x23_broadcast; -// mod sequence_erase_u32_positive; -// mod sequence_erase_u32_negative; -// mod sequence_erase_u32_empty; -// mod sequence_erase_fp16x16_positive; -// mod sequence_erase_fp16x16_negative; -// mod sequence_erase_fp16x16_empty; -// mod sequence_erase_fp8x23_positive; -// mod sequence_erase_fp8x23_negative; -// mod sequence_erase_fp8x23_empty; -// mod sequence_erase_i32_positive; -// mod sequence_erase_i32_negative; -// mod sequence_erase_i32_empty; -// mod sequence_erase_i8_positive; -// mod sequence_erase_i8_negative; -// mod sequence_erase_i8_empty; -// mod sequence_insert_fp16x16; -// mod sequence_insert_fp8x23; -// mod sequence_insert_i32; -// mod sequence_insert_i8; -// mod sequence_insert_u32; -// mod concat_from_sequence_fp8x23_new_axis_zero; -// mod concat_from_sequence_fp8x23_new_axis_one; -// mod concat_from_sequence_fp8x23_new_axis_default; -// mod concat_from_sequence_fp16x16_new_axis_zero; -// mod concat_from_sequence_fp16x16_new_axis_one; -// mod concat_from_sequence_fp16x16_new_axis_default; -// mod concat_from_sequence_i32_new_axis_zero; -// mod concat_from_sequence_i32_new_axis_one; -// mod concat_from_sequence_i32_new_axis_default; -// mod concat_from_sequence_i8_new_axis_zero; -// mod concat_from_sequence_i8_new_axis_one; -// mod concat_from_sequence_i8_new_axis_default; -// mod concat_from_sequence_u32_new_axis_zero; -// mod concat_from_sequence_u32_new_axis_one; -// mod concat_from_sequence_u32_new_axis_default; -// mod is_nan_fp16x16; -// mod is_nan_fp8x23; -// mod is_inf_fp16x16; -// mod is_inf_fp8x23; -// mod is_inf_i32; -// mod is_inf_i8; -// mod is_inf_u32; -// mod is_pos_inf_fp16x16; -// mod is_neg_inf_fp16x16; -// mod is_pos_inf_fp8x23; -// mod is_neg_inf_fp8x23; -// mod is_pos_inf_i32; -// mod is_neg_inf_i32; -// mod is_pos_inf_i8; -// mod is_neg_inf_i8; -// mod reduce_log_sum_fp8x23_export_do_not_keepdims; -// mod reduce_log_sum_fp8x23_export_keepdims; -// mod reduce_log_sum_fp8x23_export_negative_axes_keepdims; -// mod reduce_log_sum_fp16x16_export_do_not_keepdims; -// mod reduce_log_sum_fp16x16_export_keepdims; -// mod reduce_log_sum_fp16x16_export_negative_axes_keepdims; -// mod and_bool; -// mod erf_fp16x16; -// mod erf_fp8x23; -// mod unique_fp16x16_without_axis_sorted; -// mod unique_fp16x16_with_axis_zero_sorted; -// mod unique_u32_without_axis_sorted; -// mod unique_u32_without_axis_not_sorted; -// mod unique_u32_with_axis_zero_sorted; -// mod unique_u32_with_axis_zero_not_sorted; -// mod unique_u32_with_axis_one_sorted; -// mod unique_u32_with_axis_one_not_sorted; -// mod gather_nd_fp16x16_3d_default; -// mod gather_nd_fp16x16_3d_batch_dims1; -// mod gather_nd_fp16x16_3d_batch_dims2; -// mod gather_nd_fp8x23_3d_default; -// mod gather_nd_fp8x23_3d_batch_dims1; -// mod gather_nd_fp8x23_3d_batch_dims2; -// mod gather_nd_i32_3d_default; -// mod gather_nd_i32_3d_batch_dims1; -// mod gather_nd_i32_3d_batch_dims2; -// mod gather_nd_i8_3d_default; -// mod gather_nd_i8_3d_batch_dims1; -// mod gather_nd_u32_default; -// mod gather_nd_u32_batch_dims1; -// mod gather_nd_u32_batch_dims2; +mod abs_fp16x16; +mod abs_fp8x23; +mod abs_i32; +mod abs_i8; +mod acos_fp16x16; +mod acos_fp8x23; +mod acosh_fp16x16; +mod acosh_fp8x23; +mod add_fp16x16; +mod add_fp16x16_broadcast; +mod add_fp8x23; +mod add_fp8x23_broadcast; +mod add_i32; +mod add_i32_broadcast; +mod add_i8; +mod add_i8_broadcast; +mod add_u32; +mod add_u32_broadcast; +mod argmax_fp16x16_1D_default; +mod argmax_fp16x16_1D_keepdims_false; +mod argmax_fp16x16_1D_last_index; +mod argmax_fp16x16_2D_default; +mod argmax_fp16x16_2D_keepdims_false; +mod argmax_fp16x16_2D_last_index; +mod argmax_fp16x16_3D_default; +mod argmax_fp16x16_3D_keepdims_false; +mod argmax_fp16x16_3D_last_index; +mod argmax_fp8x23_1D_default; +mod argmax_fp8x23_1D_keepdims_false; +mod argmax_fp8x23_1D_last_index; +mod argmax_fp8x23_2D_default; +mod argmax_fp8x23_2D_keepdims_false; +mod argmax_fp8x23_2D_last_index; +mod argmax_fp8x23_3D_default; +mod argmax_fp8x23_3D_keepdims_false; +mod argmax_fp8x23_3D_last_index; +mod argmax_i32_1D_default; +mod argmax_i32_1D_keepdims_false; +mod argmax_i32_1D_last_index; +mod argmax_i32_2D_default; +mod argmax_i32_2D_keepdims_false; +mod argmax_i32_2D_last_index; +mod argmax_i32_3D_default; +mod argmax_i32_3D_keepdims_false; +mod argmax_i32_3D_last_index; +mod argmax_i8_1D_default; +mod argmax_i8_1D_keepdims_false; +mod argmax_i8_1D_last_index; +mod argmax_i8_2D_default; +mod argmax_i8_2D_keepdims_false; +mod argmax_i8_2D_last_index; +mod argmax_i8_3D_default; +mod argmax_i8_3D_keepdims_false; +mod argmax_i8_3D_last_index; +mod argmax_u32_1D_default; +mod argmax_u32_1D_keepdims_false; +mod argmax_u32_1D_last_index; +mod argmax_u32_2D_default; +mod argmax_u32_2D_keepdims_false; +mod argmax_u32_2D_last_index; +mod argmax_u32_3D_default; +mod argmax_u32_3D_keepdims_false; +mod argmax_u32_3D_last_index; +mod argmin_fp16x16_1D_default; +mod argmin_fp16x16_1D_keepdims_false; +mod argmin_fp16x16_1D_last_index; +mod argmin_fp16x16_2D_default; +mod argmin_fp16x16_2D_keepdims_false; +mod argmin_fp16x16_2D_last_index; +mod argmin_fp16x16_3D_default; +mod argmin_fp16x16_3D_keepdims_false; +mod argmin_fp16x16_3D_last_index; +mod argmin_fp8x23_1D_default; +mod argmin_fp8x23_1D_keepdims_false; +mod argmin_fp8x23_1D_last_index; +mod argmin_fp8x23_2D_default; +mod argmin_fp8x23_2D_keepdims_false; +mod argmin_fp8x23_2D_last_index; +mod argmin_fp8x23_3D_default; +mod argmin_fp8x23_3D_keepdims_false; +mod argmin_fp8x23_3D_last_index; +mod argmin_i32_1D_default; +mod argmin_i32_1D_keepdims_false; +mod argmin_i32_1D_last_index; +mod argmin_i32_2D_default; +mod argmin_i32_2D_keepdims_false; +mod argmin_i32_2D_last_index; +mod argmin_i32_3D_default; +mod argmin_i32_3D_keepdims_false; +mod argmin_i32_3D_last_index; +mod argmin_i8_1D_default; +mod argmin_i8_1D_keepdims_false; +mod argmin_i8_1D_last_index; +mod argmin_i8_2D_default; +mod argmin_i8_2D_keepdims_false; +mod argmin_i8_2D_last_index; +mod argmin_i8_3D_default; +mod argmin_i8_3D_keepdims_false; +mod argmin_i8_3D_last_index; +mod argmin_u32_1D_default; +mod argmin_u32_1D_keepdims_false; +mod argmin_u32_1D_last_index; +mod argmin_u32_2D_default; +mod argmin_u32_2D_keepdims_false; +mod argmin_u32_2D_last_index; +mod argmin_u32_3D_default; +mod argmin_u32_3D_keepdims_false; +mod argmin_u32_3D_last_index; +mod asin_fp16x16; +mod asin_fp8x23; +mod asinh_fp16x16; +mod asinh_fp8x23; +mod atan_fp16x16; +mod atan_fp8x23; +mod ceil_fp16x16; +mod ceil_fp8x23; +mod concat_fp16x16_1d; +mod concat_fp16x16_2d; +mod concat_fp16x16_3d_default; +mod concat_fp16x16_3d_axis_1; +mod concat_fp16x16_3d_axis_2; +mod concat_fp16x16_3d_three_tensors_axis_1; +mod concat_fp16x16_3d_three_tensors_axis_2; +mod concat_fp8x23_1d; +mod concat_fp8x23_2d; +mod concat_fp8x23_3d_default; +mod concat_fp8x23_3d_axis_1; +mod concat_fp8x23_3d_axis_2; +mod concat_fp8x23_3d_three_tensors_axis_1; +mod concat_fp8x23_3d_three_tensors_axis_2; +mod concat_i32_1d; +mod concat_i32_2d; +mod concat_i32_3d_default; +mod concat_i32_3d_axis_1; +mod concat_i32_3d_axis_2; +mod concat_i32_3d_three_tensors_axis_1; +mod concat_i32_3d_three_tensors_axis_2; +mod concat_i8_1d; +mod concat_i8_2d; +mod concat_i8_3d_default; +mod concat_i8_3d_axis_1; +mod concat_i8_3d_axis_2; +mod concat_i8_3d_three_tensors_axis_1; +mod concat_i8_3d_three_tensors_axis_2; +mod concat_u32_1d; +mod concat_u32_2d; +mod concat_u32_3d_default; +mod concat_u32_3d_axis_1; +mod concat_u32_3d_axis_2; +mod concat_u32_3d_three_tensors_axis_1; +mod concat_u32_3d_three_tensors_axis_2; +mod cos_fp16x16; +mod cos_fp8x23; +mod cosh_fp16x16; +mod cosh_fp8x23; +mod cumsum_fp16x16_1d_default; +mod cumsum_fp16x16_1d_exclusive; +mod cumsum_fp16x16_1d_reverse; +mod cumsum_fp16x16_1d_reverse_exclusive; +mod cumsum_fp16x16_2d_axis_0; +mod cumsum_fp16x16_2d_axis_1; +mod cumsum_fp8x23_1d_default; +mod cumsum_fp8x23_1d_exclusive; +mod cumsum_fp8x23_1d_reverse; +mod cumsum_fp8x23_1d_reverse_exclusive; +mod cumsum_fp8x23_2d_axis_0; +mod cumsum_fp8x23_2d_axis_1; +mod cumsum_i32_1d_default; +mod cumsum_i32_1d_exclusive; +mod cumsum_i32_1d_reverse; +mod cumsum_i32_1d_reverse_exclusive; +mod cumsum_i32_2d_axis_0; +mod cumsum_i32_2d_axis_1; +mod cumsum_i8_1d_default; +mod cumsum_i8_1d_exclusive; +mod cumsum_i8_1d_reverse; +mod cumsum_i8_1d_reverse_exclusive; +mod cumsum_i8_2d_axis_0; +mod cumsum_i8_2d_axis_1; +mod cumsum_u32_1d_default; +mod cumsum_u32_1d_exclusive; +mod cumsum_u32_1d_reverse; +mod cumsum_u32_1d_reverse_exclusive; +mod cumsum_u32_2d_axis_0; +mod cumsum_u32_2d_axis_1; +mod div_fp16x16; +mod div_fp16x16_broadcast; +mod div_fp8x23; +mod div_fp8x23_broadcast; +mod div_i32; +mod div_i32_broadcast; +mod div_i8; +mod div_i8_broadcast; +mod div_u32; +mod div_u32_broadcast; +mod equal_fp16x16; +mod equal_fp16x16_broadcast; +mod equal_fp8x23; +mod equal_fp8x23_broadcast; +mod equal_i32; +mod equal_i32_broadcast; +mod equal_i8; +mod equal_i8_broadcast; +mod equal_u32; +mod equal_u32_broadcast; +mod exp_fp16x16; +mod exp_fp8x23; +mod less_equal_fp16x16; +mod less_equal_fp16x16_broadcast; +mod less_equal_fp8x23; +mod less_equal_fp8x23_broadcast; +mod less_equal_i32; +mod less_equal_i32_broadcast; +mod less_equal_i8; +mod less_equal_i8_broadcast; +mod less_equal_u32; +mod less_equal_u32_broadcast; +mod greater_fp16x16; +mod greater_fp16x16_broadcast; +mod greater_fp8x23; +mod greater_fp8x23_broadcast; +mod greater_i32; +mod greater_i32_broadcast; +mod greater_i8; +mod greater_i8_broadcast; +mod greater_u32; +mod greater_u32_broadcast; +mod leaky_relu_fp16x16; +mod leaky_relu_fp8x23; +mod linear_fp16x16; +mod linear_fp8x23; +mod linear_i32; +mod linear_i8; +mod linear_u32; +mod log_fp16x16; +mod log_fp8x23; +mod logsoftmax_fp16x16_axis_0; +mod logsoftmax_fp16x16_axis_1; +mod logsoftmax_fp8x23_axis_0; +mod logsoftmax_fp8x23_axis_1; +mod matmul_fp16x16_1d; +mod matmul_fp16x16_2x2; +mod matmul_fp16x16_2x1; +mod matmul_fp16x16_1x2; +mod matmul_fp8x23_1d; +mod matmul_fp8x23_2x2; +mod matmul_fp8x23_2x1; +mod matmul_fp8x23_1x2; +mod matmul_i32_1d; +mod matmul_i32_2x2; +mod matmul_i32_2x1; +mod matmul_i32_1x2; +mod matmul_i8_1d; +mod matmul_i8_2x2; +mod matmul_i8_2x1; +mod matmul_i8_1x2; +mod matmul_u32_1d; +mod matmul_u32_2x2; +mod matmul_u32_2x1; +mod matmul_u32_1x2; +mod mul_fp16x16; +mod mul_fp16x16_broadcast; +mod mul_fp8x23; +mod mul_fp8x23_broadcast; +mod mul_i32; +mod mul_i32_broadcast; +mod mul_i8; +mod mul_i8_broadcast; +mod mul_u32; +mod mul_u32_broadcast; +mod or_fp16x16; +mod or_fp16x16_broadcast; +mod or_fp8x23; +mod or_fp8x23_broadcast; +mod or_i32; +mod or_i32_broadcast; +mod or_i8; +mod or_i8_broadcast; +mod or_u32; +mod or_u32_broadcast; +mod reduce_sum_fp16x16_1D; +mod reduce_sum_fp16x16_2D_default; +mod reduce_sum_fp16x16_2D_keepdims; +mod reduce_sum_fp16x16_2D_axis_1; +mod reduce_sum_fp8x23_1D; +mod reduce_sum_fp8x23_2D_default; +mod reduce_sum_fp8x23_2D_keepdims; +mod reduce_sum_fp8x23_2D_axis_1; +mod reduce_sum_i32_1D; +mod reduce_sum_i32_2D_default; +mod reduce_sum_i32_2D_keepdims; +mod reduce_sum_i32_2D_axis_1; +mod reduce_sum_i8_1D; +mod reduce_sum_i8_2D_default; +mod reduce_sum_i8_2D_keepdims; +mod reduce_sum_i8_2D_axis_1; +mod reduce_sum_u32_1D; +mod reduce_sum_u32_2D_default; +mod reduce_sum_u32_2D_keepdims; +mod reduce_sum_u32_2D_axis_1; +mod relu_fp16x16; +mod relu_fp8x23; +mod relu_i32; +mod relu_i8; +mod sigmoid_fp16x16; +mod sigmoid_fp8x23; +mod sin_fp16x16; +mod sin_fp8x23; +mod sinh_fp16x16; +mod sinh_fp8x23; +mod softmax_fp16x16; +mod softmax_fp8x23; +mod softplus_fp8x23; +mod softplus_fp16x16; +mod softsign_fp8x23; +mod softsign_fp16x16; +mod sqrt_fp16x16; +mod sqrt_fp8x23; +mod sub_fp16x16; +mod sub_fp16x16_broadcast; +mod sub_fp8x23; +mod sub_fp8x23_broadcast; +mod sub_i32; +mod sub_i32_broadcast; +mod sub_i8; +mod sub_i8_broadcast; +mod sub_u32; +mod sub_u32_broadcast; +mod tanh_fp16x16; +mod tanh_fp8x23; +mod transpose_fp16x16_2d; +mod transpose_fp16x16_3d; +mod transpose_fp8x23_2d; +mod transpose_fp8x23_3d; +mod transpose_i32_2d; +mod transpose_i32_3d; +mod transpose_i8_2d; +mod transpose_i8_3d; +mod transpose_u32_2d; +mod transpose_u32_3d; +mod xor_fp16x16; +mod xor_fp16x16_broadcast; +mod xor_fp8x23; +mod xor_fp8x23_broadcast; +mod xor_i32; +mod xor_i32_broadcast; +mod xor_i8; +mod xor_i8_broadcast; +mod xor_u32; +mod xor_u32_broadcast; +mod less_fp16x16; +mod less_fp16x16_broadcast; +mod less_fp8x23; +mod less_fp8x23_broadcast; +mod less_i32; +mod less_i32_broadcast; +mod less_i8; +mod less_i8_broadcast; +mod less_u32; +mod less_u32_broadcast; +mod greater_equal_fp16x16; +mod greater_equal_fp16x16_broadcast; +mod greater_equal_fp8x23; +mod greater_equal_fp8x23_broadcast; +mod greater_equal_i32; +mod greater_equal_i32_broadcast; +mod greater_equal_i8; +mod greater_equal_i8_broadcast; +mod greater_equal_u32; +mod greater_equal_u32_broadcast; +mod slice_fp16x16_2d; +mod slice_fp16x16_3d; +mod slice_fp8x23_2d; +mod slice_fp8x23_3d; +mod slice_i32_2d; +mod slice_i32_3d; +mod slice_i8_2d; +mod slice_i8_3d; +mod slice_u32_2d; +mod slice_u32_3d; +mod gather_fp8x23_3d_default; +mod gather_fp8x23_3d_axis1; +mod gather_fp8x23_3d_axis2; +mod gather_fp16x16_3d_default; +mod gather_fp16x16_3d_axis1; +mod gather_fp16x16_3d_axis2; +mod gather_i8_3d_default; +mod gather_i8_3d_axis1; +mod gather_i8_3d_axis2; +mod gather_i32_3d_default; +mod gather_i32_3d_axis1; +mod gather_i32_3d_axis2; +mod gather_u32_3d_default; +mod gather_u32_3d_axis1; +mod gather_u32_3d_axis2; +mod nonzero_fp16x16_2d; +mod nonzero_fp16x16_3d; +mod nonzero_fp8x23_2d; +mod nonzero_fp8x23_3d; +mod nonzero_i32_2d; +mod nonzero_i32_3d; +mod nonzero_i8_2d; +mod nonzero_i8_3d; +mod nonzero_u32_2d; +mod nonzero_u32_3d; +mod squeeze_fP16x16; +mod squeeze_fP8x23; +mod squeeze_i32; +mod squeeze_i8; +mod squeeze_u32; +mod unsqueeze_fp16x16_2d; +mod unsqueeze_fp16x16_3d; +mod unsqueeze_fp8x23_2d; +mod unsqueeze_fp8x23_3d; +mod unsqueeze_i32_2d; +mod unsqueeze_i32_3d; +mod unsqueeze_i8_2d; +mod unsqueeze_i8_3d; +mod unsqueeze_u32_2d; +mod unsqueeze_u32_3d; +mod sign_fP16x16; +mod sign_fP8x23; +mod sign_fail; +mod sign_i32; +mod sign_i8; +mod clip_fp16x16_2d; +mod clip_fp16x16_3d; +mod clip_fp8x23_2d; +mod clip_fp8x23_3d; +mod clip_i32_2d; +mod clip_i32_3d; +mod clip_i8_2d; +mod clip_i8_3d; +mod clip_u32_2d; +mod clip_u32_3d; +mod identity_fP16x16; +mod identity_fP8x23; +mod identity_i32; +mod identity_i8; +mod identity_u32; +mod thresholded_relu_fp16x16; +mod thresholded_relu_fp8x23; +mod hard_sigmoid_fp8x23; +mod hard_sigmoid_fp16x16; +mod neg_fp16x16; +mod neg_fp8x23; +mod neg_i32; +mod neg_i8; +mod gemm_all_attributes; +mod gemm_alpha; +mod gemm_beta; +mod gemm_default_matrix_bias; +mod gemm_default_vector_bias; +mod gemm_default_no_bias; +mod gemm_transposeA; +mod gemm_transposeB; +mod min_fp16x16_three_tensors; +mod min_fp16x16_broadcast_three_tensors; +mod min_fp16x16_two_tensors; +mod min_fp16x16_broadcast_two_tensors; +mod min_fp8x23_three_tensors; +mod min_fp8x23_broadcast_three_tensors; +mod min_fp8x23_two_tensors; +mod min_fp8x23_broadcast_two_tensors; +mod min_i32_three_tensors; +mod min_i32_broadcast_three_tensors; +mod min_i32_two_tensors; +mod min_i32_broadcast_two_tensors; +mod min_i8_three_tensors; +mod min_i8_broadcast_three_tensors; +mod min_i8_two_tensors; +mod min_i8_broadcast_two_tensors; +mod min_u32_three_tensors; +mod min_u32_broadcast_three_tensors; +mod min_u32_two_tensors; +mod min_u32_broadcast_two_tensors; +mod where_fp16x16; +mod where_fp16x16_broadcast; +mod where_fp8x23; +mod where_fp8x23_broadcast; +mod where_i32; +mod where_i32_broadcast; +mod where_i8; +mod where_i8_broadcast; +mod where_u32; +mod where_u32_broadcast; +mod not_bool; +mod round_fp16x16; +mod round_fp8x23; +mod max_fp16x16_three_tensors; +mod max_fp16x16_broadcast_three_tensors; +mod max_fp16x16_two_tensors; +mod max_fp16x16_broadcast_two_tensors; +mod max_fp8x23_three_tensors; +mod max_fp8x23_broadcast_three_tensors; +mod max_fp8x23_two_tensors; +mod max_fp8x23_broadcast_two_tensors; +mod max_i32_three_tensors; +mod max_i32_broadcast_three_tensors; +mod max_i32_two_tensors; +mod max_i32_broadcast_two_tensors; +mod max_i8_three_tensors; +mod max_i8_broadcast_three_tensors; +mod max_i8_two_tensors; +mod max_i8_broadcast_two_tensors; +mod max_u32_three_tensors; +mod max_u32_broadcast_three_tensors; +mod max_u32_two_tensors; +mod max_u32_broadcast_two_tensors; +mod scatter_fp16x16_3d_default; +mod scatter_fp16x16_3d_axis1; +mod scatter_fp16x16_3d_axis1_add; +mod scatter_fp8x23_default; +mod scatter_fp8x23_axis1; +mod scatter_fp8x23_mul; +mod scatter_i8_default; +mod scatter_i8_axis1; +mod scatter_i8_axis1_max; +mod scatter_u32_default; +mod scatter_u32_axis1; +mod scatter_u32_add; +mod array_feature_extractor_1D_i32; +mod array_feature_extractor_1D_fp8x23; +mod array_feature_extractor_1D_fp16x16; +mod array_feature_extractor_2D_i32; +mod array_feature_extractor_2D_fp8x23; +mod array_feature_extractor_2D_fp16x16; +mod array_feature_extractor_3D_i32; +mod array_feature_extractor_3D_fp8x23; +mod array_feature_extractor_3D_fp16x16; +mod binarizer_fp16x16; +mod binarizer_fp8x23; +mod tril_fp16x16; +mod tril_fp16x16_neg; +mod tril_fp16x16_one_row; +mod tril_fp16x16_out_neg; +mod tril_fp16x16_out_pos; +mod tril_fp16x16_pos; +mod tril_fp16x16_square; +mod tril_fp16x16_square_neg; +mod tril_fp16x16_zero; +mod triu_fp16x16; +mod triu_fp16x16_neg; +mod triu_fp16x16_one_row; +mod triu_fp16x16_out_neg; +mod triu_fp16x16_out_pos; +mod triu_fp16x16_pos; +mod triu_fp16x16_square; +mod triu_fp16x16_square_neg; +mod triu_fp16x16_zero; +mod tril_fp8x23; +mod tril_fp8x23_neg; +mod tril_fp8x23_one_row; +mod tril_fp8x23_out_neg; +mod tril_fp8x23_out_pos; +mod tril_fp8x23_pos; +mod tril_fp8x23_square; +mod tril_fp8x23_square_neg; +mod tril_fp8x23_zero; +mod triu_fp8x23; +mod triu_fp8x23_neg; +mod triu_fp8x23_one_row; +mod triu_fp8x23_out_neg; +mod triu_fp8x23_out_pos; +mod triu_fp8x23_pos; +mod triu_fp8x23_square; +mod triu_fp8x23_square_neg; +mod triu_fp8x23_zero; +mod tril_i32; +mod tril_neg_i32; +mod tril_i32_one_row; +mod tril_i32_out_neg; +mod tril_i32_out_pos; +mod tril_i32_pos; +mod tril_i32_square; +mod tril_i32_square_neg; +mod tril_i32_zero; +mod triu_i32; +mod triu_i32_neg; +mod triu_i32_one_row; +mod triu_i32_out_neg; +mod triu_i32_out_pos; +mod triu_i32_pos; +mod triu_i32_square; +mod triu_i32_square_neg; +mod triu_i32_zero; +mod tril_i8; +mod tril_i8_neg; +mod tril_i8_one_row; +mod tril_i8_out_neg; +mod tril_i8_out_pos; +mod tril_i8_pos; +mod tril_i8_square; +mod tril_i8_square_neg; +mod tril_i8_zero; +mod triu_i8; +mod triu_i8_neg; +mod triu_i8_one_row; +mod triu_i8_out_neg; +mod triu_i8_out_pos; +mod triu_i8_pos; +mod triu_i8_square; +mod triu_i8_square_neg; +mod triu_i8_zero; +mod tril_u32; +mod tril_u32_neg; +mod tril_u32_one_row; +mod tril_u32_out_neg; +mod tril_u32_out_pos; +mod tril_u32_pos; +mod tril_u32_square; +mod tril_u32_square_neg; +mod tril_u32_zero; +mod triu_u32; +mod triu_u32_neg; +mod triu_u32_one_row; +mod triu_u32_out_neg; +mod triu_u32_out_pos; +mod triu_u32_pos; +mod triu_u32_square; +mod triu_u32_square_neg; +mod triu_u32_zero; +mod reduce_sum_square_fp16x16_export_do_not_keepdims; +mod reduce_sum_square_fp16x16_export_keepdims; +mod reduce_sum_square_fp16x16_export_negative_axes_keepdims; +mod reduce_sum_square_fp8x23_export_do_not_keepdims; +mod reduce_sum_square_fp8x23_export_keepdims; +mod reduce_sum_square_fp8x23_export_negative_axes_keepdims; +mod reduce_sum_square_i32_export_do_not_keepdims; +mod reduce_sum_square_i32_export_keepdims; +mod reduce_sum_square_i32_export_negative_axes_keepdims; +mod reduce_sum_square_i8_export_do_not_keepdims; +mod reduce_sum_square_i8_export_keepdims; +mod reduce_sum_square_i8_export_negative_axes_keepdims; +mod reduce_sum_square_u32_export_do_not_keepdims; +mod reduce_sum_square_u32_export_keepdims; +mod reduce_sum_square_u32_export_negative_axes_keepdims; +mod reduce_l2_fp16x16_export_do_not_keepdims; +mod reduce_l2_fp16x16_export_keepdims; +mod reduce_l2_fp16x16_export_negative_axes_keepdims; +mod reduce_l2_fp8x23_export_do_not_keepdims; +mod reduce_l2_fp8x23_export_keepdims; +mod reduce_l2_fp8x23_export_negative_axes_keepdims; +mod reduce_l1_fp16x16_export_do_not_keepdims; +mod reduce_l1_fp16x16_export_keepdims; +mod reduce_l1_fp16x16_export_negative_axes_keepdims; +mod reduce_l1_fp8x23_export_do_not_keepdims; +mod reduce_l1_fp8x23_export_keepdims; +mod reduce_l1_fp8x23_export_negative_axes_keepdims; +mod reduce_l1_i32_export_do_not_keepdims; +mod reduce_l1_i32_export_keepdims; +mod reduce_l1_i32_export_negative_axes_keepdims; +mod reduce_l1_i8_export_do_not_keepdims; +mod reduce_l1_i8_export_keepdims; +mod reduce_l1_i8_export_negative_axes_keepdims; +mod reduce_l1_u32_export_do_not_keepdims; +mod reduce_l1_u32_export_keepdims; +mod reduce_l1_u32_export_negative_axes_keepdims; +mod reduce_prod_fp16x16_1D; +mod reduce_prod_fp16x16_2D_default; +mod reduce_prod_fp16x16_2D_keepdims; +mod reduce_prod_fp16x16_2D_axis_1; +mod reduce_prod_fp8x23_1D; +mod reduce_prod_fp8x23_2D_default; +mod reduce_prod_fp8x23_2D_keepdims; +mod reduce_prod_fp8x23_2D_axis_1; +mod reduce_prod_i32_1D; +mod reduce_prod_i32_2D_default; +mod reduce_prod_i32_2D_keepdims; +mod reduce_prod_i32_2D_axis_1; +mod reduce_prod_i8_1D; +mod reduce_prod_i8_2D_default; +mod reduce_prod_i8_2D_keepdims; +mod reduce_prod_i8_2D_axis_1; +mod reduce_prod_u32_1D; +mod reduce_prod_u32_2D_default; +mod reduce_prod_u32_2D_keepdims; +mod reduce_prod_u32_2D_axis_1; +mod gather_elements_fp16x16_3d_default; +mod gather_elements_fp16x16_3d_axis1; +mod gather_elements_fp16x16_3d_axis2; +mod gather_elements_fp8x23_3d_default; +mod gather_elements_fp8x23_3d_axis1; +mod gather_elements_fp8x23_3d_axis2; +mod gather_elements_i8_3d_default; +mod gather_elements_i8_3d_axis1; +mod gather_elements_i32_3d_default; +mod gather_elements_i32_3d_axis1; +mod gather_elements_i32_3d_axis2; +mod gather_elements_u32_default; +mod gather_elements_u32_axis1; +mod gather_elements_u32_axis2; +mod gather_elements_u32_axis3; +mod sequence_length_fp16x16; +mod sequence_length_fp16x16_broadcast; +mod sequence_length_fp8x23; +mod sequence_length_fp8x23_broadcast; +mod sequence_length_i32; +mod sequence_length_i32_broadcast; +mod sequence_length_i8; +mod sequence_length_i8_broadcast; +mod sequence_length_u32; +mod sequence_length_u32_broadcast; +mod sequence_at_u32_positive; +mod sequence_at_u32_negative; +mod sequence_at_fp16x16_positive; +mod sequence_at_fp16x16_negative; +mod sequence_at_fp8x23_positive; +mod sequence_at_fp8x23_negative; +mod sequence_at_i32_positive; +mod sequence_at_i32_negative; +mod sequence_at_i8_positive; +mod sequence_at_i8_negative; +mod reduce_min_fp16x16_1D; +mod reduce_min_fp16x16_2D_default; +mod reduce_min_fp16x16_2D_keepdims; +mod reduce_min_fp16x16_2D_axis_1; +mod reduce_min_fp8x23_1D; +mod reduce_min_fp8x23_2D_default; +mod reduce_min_fp8x23_2D_keepdims; +mod reduce_min_fp8x23_2D_axis_1; +mod reduce_min_i32_1D; +mod reduce_min_i32_2D_default; +mod reduce_min_i32_2D_keepdims; +mod reduce_min_i32_2D_axis_1; +mod reduce_min_i8_1D; +mod reduce_min_i8_2D_default; +mod reduce_min_i8_2D_keepdims; +mod reduce_min_i8_2D_axis_1; +mod reduce_min_u32_1D; +mod reduce_min_u32_2D_default; +mod reduce_min_u32_2D_keepdims; +mod reduce_min_u32_2D_axis_1; +mod sequence_construct_fp16x16; +mod sequence_construct_fp8x23; +mod sequence_construct_i32; +mod sequence_construct_i8; +mod sequence_construct_u32; +mod shrink_hard_fp16x16; +mod shrink_soft_fp16x16; +mod shrink_hard_fp8x23; +mod shrink_soft_fp8x23; +mod sequence_empty_fp16x16; +mod sequence_empty_fp8x23; +mod sequence_empty_i32; +mod sequence_empty_i8; +mod sequence_empty_u32; +mod reduce_mean_fp16x16_1D; +mod reduce_mean_fp16x16_2D_default; +mod reduce_mean_fp16x16_2D_keepdims; +mod reduce_mean_fp16x16_2D_axis_1; +mod reduce_mean_fp8x23_1D; +mod reduce_mean_fp8x23_2D_default; +mod reduce_mean_fp8x23_2D_keepdims; +mod reduce_mean_fp8x23_2D_axis_1; +mod reduce_mean_i32_1D; +mod reduce_mean_i32_2D_default; +mod reduce_mean_i32_2D_keepdims; +mod reduce_mean_i32_2D_axis_1; +mod reduce_mean_i8_1D; +mod reduce_mean_i8_2D_default; +mod reduce_mean_i8_2D_keepdims; +mod reduce_mean_i8_2D_axis_1; +mod reduce_mean_u32_1D; +mod reduce_mean_u32_2D_default; +mod reduce_mean_u32_2D_keepdims; +mod reduce_mean_u32_2D_axis_1; +mod pow_fp16x16; +mod pow_fp16x16_broadcast; +mod pow_fp8x23; +mod pow_fp8x23_broadcast; +mod sequence_erase_u32_positive; +mod sequence_erase_u32_negative; +mod sequence_erase_u32_empty; +mod sequence_erase_fp16x16_positive; +mod sequence_erase_fp16x16_negative; +mod sequence_erase_fp16x16_empty; +mod sequence_erase_fp8x23_positive; +mod sequence_erase_fp8x23_negative; +mod sequence_erase_fp8x23_empty; +mod sequence_erase_i32_positive; +mod sequence_erase_i32_negative; +mod sequence_erase_i32_empty; +mod sequence_erase_i8_positive; +mod sequence_erase_i8_negative; +mod sequence_erase_i8_empty; +mod sequence_insert_fp16x16; +mod sequence_insert_fp8x23; +mod sequence_insert_i32; +mod sequence_insert_i8; +mod sequence_insert_u32; +mod concat_from_sequence_fp8x23_new_axis_zero; +mod concat_from_sequence_fp8x23_new_axis_one; +mod concat_from_sequence_fp8x23_new_axis_default; +mod concat_from_sequence_fp16x16_new_axis_zero; +mod concat_from_sequence_fp16x16_new_axis_one; +mod concat_from_sequence_fp16x16_new_axis_default; +mod concat_from_sequence_i32_new_axis_zero; +mod concat_from_sequence_i32_new_axis_one; +mod concat_from_sequence_i32_new_axis_default; +mod concat_from_sequence_i8_new_axis_zero; +mod concat_from_sequence_i8_new_axis_one; +mod concat_from_sequence_i8_new_axis_default; +mod concat_from_sequence_u32_new_axis_zero; +mod concat_from_sequence_u32_new_axis_one; +mod concat_from_sequence_u32_new_axis_default; +mod is_nan_fp16x16; +mod is_nan_fp8x23; +mod is_inf_fp16x16; +mod is_inf_fp8x23; +mod is_inf_i32; +mod is_inf_i8; +mod is_inf_u32; +mod is_pos_inf_fp16x16; +mod is_neg_inf_fp16x16; +mod is_pos_inf_fp8x23; +mod is_neg_inf_fp8x23; +mod is_pos_inf_i32; +mod is_neg_inf_i32; +mod is_pos_inf_i8; +mod is_neg_inf_i8; +mod reduce_log_sum_fp8x23_export_do_not_keepdims; +mod reduce_log_sum_fp8x23_export_keepdims; +mod reduce_log_sum_fp8x23_export_negative_axes_keepdims; +mod reduce_log_sum_fp16x16_export_do_not_keepdims; +mod reduce_log_sum_fp16x16_export_keepdims; +mod reduce_log_sum_fp16x16_export_negative_axes_keepdims; +mod and_bool; +mod erf_fp16x16; +mod erf_fp8x23; +mod unique_fp16x16_without_axis_sorted; +mod unique_fp16x16_with_axis_zero_sorted; +mod unique_u32_without_axis_sorted; +mod unique_u32_without_axis_not_sorted; +mod unique_u32_with_axis_zero_sorted; +mod unique_u32_with_axis_zero_not_sorted; +mod unique_u32_with_axis_one_sorted; +mod unique_u32_with_axis_one_not_sorted; +mod gather_nd_fp16x16_3d_default; +mod gather_nd_fp16x16_3d_batch_dims1; +mod gather_nd_fp16x16_3d_batch_dims2; +mod gather_nd_fp8x23_3d_default; +mod gather_nd_fp8x23_3d_batch_dims1; +mod gather_nd_fp8x23_3d_batch_dims2; +mod gather_nd_i32_3d_default; +mod gather_nd_i32_3d_batch_dims1; +mod gather_nd_i32_3d_batch_dims2; +mod gather_nd_i8_3d_default; +mod gather_nd_i8_3d_batch_dims1; +mod gather_nd_u32_default; +mod gather_nd_u32_batch_dims1; +mod gather_nd_u32_batch_dims2; +mod compress_fp16x16_3d_default; +mod compress_fp16x16_3d_axis1; +mod compress_fp16x16_3d_axis2; +mod compress_fp16x16_3d_axis3; +mod compress_fp16x16_3d_noaxis; +mod compress_fp8x23_3d_default; +mod compress_fp8x23_3d_axis1; +mod compress_fp8x23_3d_axis2; +mod compress_i32_3d_default; +mod compress_i32_3d_axis1; +mod compress_i32_3d_axis2; +mod compress_i8_3d_default; +mod compress_i8_3d_axis1; +mod compress_i8_3d_axis2; +mod compress_u32_3d_default; +mod compress_u32_3d_axis1; +mod compress_u32_3d_axis2; +mod compress_u32_3d_axis2_2; +mod compress_u32_3d_axis3; diff --git a/tests/nodes/compress_fp16x16_3d_axis1.cairo b/tests/nodes/compress_fp16x16_3d_axis1.cairo new file mode 100644 index 000000000..de0c173ed --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis1.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp16x16_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp16x16_3d_axis1/input_0.cairo b/tests/nodes/compress_fp16x16_3d_axis1/input_0.cairo new file mode 100644 index 000000000..46162489f --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis1/input_0.cairo @@ -0,0 +1,195 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + data.append(FP16x16 { mag: 1769472, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 1900544, sign: false }); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 2031616, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 2162688, sign: false }); + data.append(FP16x16 { mag: 2228224, sign: false }); + data.append(FP16x16 { mag: 2293760, sign: false }); + data.append(FP16x16 { mag: 2359296, sign: false }); + data.append(FP16x16 { mag: 2424832, sign: false }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 2555904, sign: false }); + data.append(FP16x16 { mag: 2621440, sign: false }); + data.append(FP16x16 { mag: 2686976, sign: false }); + data.append(FP16x16 { mag: 2752512, sign: false }); + data.append(FP16x16 { mag: 2818048, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 2949120, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + data.append(FP16x16 { mag: 3080192, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 3211264, sign: false }); + data.append(FP16x16 { mag: 3276800, sign: false }); + data.append(FP16x16 { mag: 3342336, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3473408, sign: false }); + data.append(FP16x16 { mag: 3538944, sign: false }); + data.append(FP16x16 { mag: 3604480, sign: false }); + data.append(FP16x16 { mag: 3670016, sign: false }); + data.append(FP16x16 { mag: 3735552, sign: false }); + data.append(FP16x16 { mag: 3801088, sign: false }); + data.append(FP16x16 { mag: 3866624, sign: false }); + data.append(FP16x16 { mag: 3932160, sign: false }); + data.append(FP16x16 { mag: 3997696, sign: false }); + data.append(FP16x16 { mag: 4063232, sign: false }); + data.append(FP16x16 { mag: 4128768, sign: false }); + data.append(FP16x16 { mag: 4194304, sign: false }); + data.append(FP16x16 { mag: 4259840, sign: false }); + data.append(FP16x16 { mag: 4325376, sign: false }); + data.append(FP16x16 { mag: 4390912, sign: false }); + data.append(FP16x16 { mag: 4456448, sign: false }); + data.append(FP16x16 { mag: 4521984, sign: false }); + data.append(FP16x16 { mag: 4587520, sign: false }); + data.append(FP16x16 { mag: 4653056, sign: false }); + data.append(FP16x16 { mag: 4718592, sign: false }); + data.append(FP16x16 { mag: 4784128, sign: false }); + data.append(FP16x16 { mag: 4849664, sign: false }); + data.append(FP16x16 { mag: 4915200, sign: false }); + data.append(FP16x16 { mag: 4980736, sign: false }); + data.append(FP16x16 { mag: 5046272, sign: false }); + data.append(FP16x16 { mag: 5111808, sign: false }); + data.append(FP16x16 { mag: 5177344, sign: false }); + data.append(FP16x16 { mag: 5242880, sign: false }); + data.append(FP16x16 { mag: 5308416, sign: false }); + data.append(FP16x16 { mag: 5373952, sign: false }); + data.append(FP16x16 { mag: 5439488, sign: false }); + data.append(FP16x16 { mag: 5505024, sign: false }); + data.append(FP16x16 { mag: 5570560, sign: false }); + data.append(FP16x16 { mag: 5636096, sign: false }); + data.append(FP16x16 { mag: 5701632, sign: false }); + data.append(FP16x16 { mag: 5767168, sign: false }); + data.append(FP16x16 { mag: 5832704, sign: false }); + data.append(FP16x16 { mag: 5898240, sign: false }); + data.append(FP16x16 { mag: 5963776, sign: false }); + data.append(FP16x16 { mag: 6029312, sign: false }); + data.append(FP16x16 { mag: 6094848, sign: false }); + data.append(FP16x16 { mag: 6160384, sign: false }); + data.append(FP16x16 { mag: 6225920, sign: false }); + data.append(FP16x16 { mag: 6291456, sign: false }); + data.append(FP16x16 { mag: 6356992, sign: false }); + data.append(FP16x16 { mag: 6422528, sign: false }); + data.append(FP16x16 { mag: 6488064, sign: false }); + data.append(FP16x16 { mag: 6553600, sign: false }); + data.append(FP16x16 { mag: 6619136, sign: false }); + data.append(FP16x16 { mag: 6684672, sign: false }); + data.append(FP16x16 { mag: 6750208, sign: false }); + data.append(FP16x16 { mag: 6815744, sign: false }); + data.append(FP16x16 { mag: 6881280, sign: false }); + data.append(FP16x16 { mag: 6946816, sign: false }); + data.append(FP16x16 { mag: 7012352, sign: false }); + data.append(FP16x16 { mag: 7077888, sign: false }); + data.append(FP16x16 { mag: 7143424, sign: false }); + data.append(FP16x16 { mag: 7208960, sign: false }); + data.append(FP16x16 { mag: 7274496, sign: false }); + data.append(FP16x16 { mag: 7340032, sign: false }); + data.append(FP16x16 { mag: 7405568, sign: false }); + data.append(FP16x16 { mag: 7471104, sign: false }); + data.append(FP16x16 { mag: 7536640, sign: false }); + data.append(FP16x16 { mag: 7602176, sign: false }); + data.append(FP16x16 { mag: 7667712, sign: false }); + data.append(FP16x16 { mag: 7733248, sign: false }); + data.append(FP16x16 { mag: 7798784, sign: false }); + data.append(FP16x16 { mag: 7864320, sign: false }); + data.append(FP16x16 { mag: 7929856, sign: false }); + data.append(FP16x16 { mag: 7995392, sign: false }); + data.append(FP16x16 { mag: 8060928, sign: false }); + data.append(FP16x16 { mag: 8126464, sign: false }); + data.append(FP16x16 { mag: 8192000, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: false }); + data.append(FP16x16 { mag: 8323072, sign: false }); + data.append(FP16x16 { mag: 8388608, sign: false }); + data.append(FP16x16 { mag: 8454144, sign: false }); + data.append(FP16x16 { mag: 8519680, sign: false }); + data.append(FP16x16 { mag: 8585216, sign: false }); + data.append(FP16x16 { mag: 8650752, sign: false }); + data.append(FP16x16 { mag: 8716288, sign: false }); + data.append(FP16x16 { mag: 8781824, sign: false }); + data.append(FP16x16 { mag: 8847360, sign: false }); + data.append(FP16x16 { mag: 8912896, sign: false }); + data.append(FP16x16 { mag: 8978432, sign: false }); + data.append(FP16x16 { mag: 9043968, sign: false }); + data.append(FP16x16 { mag: 9109504, sign: false }); + data.append(FP16x16 { mag: 9175040, sign: false }); + data.append(FP16x16 { mag: 9240576, sign: false }); + data.append(FP16x16 { mag: 9306112, sign: false }); + data.append(FP16x16 { mag: 9371648, sign: false }); + data.append(FP16x16 { mag: 9437184, sign: false }); + data.append(FP16x16 { mag: 9502720, sign: false }); + data.append(FP16x16 { mag: 9568256, sign: false }); + data.append(FP16x16 { mag: 9633792, sign: false }); + data.append(FP16x16 { mag: 9699328, sign: false }); + data.append(FP16x16 { mag: 9764864, sign: false }); + data.append(FP16x16 { mag: 9830400, sign: false }); + data.append(FP16x16 { mag: 9895936, sign: false }); + data.append(FP16x16 { mag: 9961472, sign: false }); + data.append(FP16x16 { mag: 10027008, sign: false }); + data.append(FP16x16 { mag: 10092544, sign: false }); + data.append(FP16x16 { mag: 10158080, sign: false }); + data.append(FP16x16 { mag: 10223616, sign: false }); + data.append(FP16x16 { mag: 10289152, sign: false }); + data.append(FP16x16 { mag: 10354688, sign: false }); + data.append(FP16x16 { mag: 10420224, sign: false }); + data.append(FP16x16 { mag: 10485760, sign: false }); + data.append(FP16x16 { mag: 10551296, sign: false }); + data.append(FP16x16 { mag: 10616832, sign: false }); + data.append(FP16x16 { mag: 10682368, sign: false }); + data.append(FP16x16 { mag: 10747904, sign: false }); + data.append(FP16x16 { mag: 10813440, sign: false }); + data.append(FP16x16 { mag: 10878976, sign: false }); + data.append(FP16x16 { mag: 10944512, sign: false }); + data.append(FP16x16 { mag: 11010048, sign: false }); + data.append(FP16x16 { mag: 11075584, sign: false }); + data.append(FP16x16 { mag: 11141120, sign: false }); + data.append(FP16x16 { mag: 11206656, sign: false }); + data.append(FP16x16 { mag: 11272192, sign: false }); + data.append(FP16x16 { mag: 11337728, sign: false }); + data.append(FP16x16 { mag: 11403264, sign: false }); + data.append(FP16x16 { mag: 11468800, sign: false }); + data.append(FP16x16 { mag: 11534336, sign: false }); + data.append(FP16x16 { mag: 11599872, sign: false }); + data.append(FP16x16 { mag: 11665408, sign: false }); + data.append(FP16x16 { mag: 11730944, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis1/input_1.cairo b/tests/nodes/compress_fp16x16_3d_axis1/input_1.cairo new file mode 100644 index 000000000..b97e1003b --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis1/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis1/output_0.cairo b/tests/nodes/compress_fp16x16_3d_axis1/output_0.cairo new file mode 100644 index 000000000..1ff0323a5 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis1/output_0.cairo @@ -0,0 +1,150 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + data.append(FP16x16 { mag: 1769472, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 1900544, sign: false }); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 2031616, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 2162688, sign: false }); + data.append(FP16x16 { mag: 2228224, sign: false }); + data.append(FP16x16 { mag: 2293760, sign: false }); + data.append(FP16x16 { mag: 2359296, sign: false }); + data.append(FP16x16 { mag: 2424832, sign: false }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 2555904, sign: false }); + data.append(FP16x16 { mag: 2621440, sign: false }); + data.append(FP16x16 { mag: 2686976, sign: false }); + data.append(FP16x16 { mag: 2752512, sign: false }); + data.append(FP16x16 { mag: 2818048, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 3932160, sign: false }); + data.append(FP16x16 { mag: 3997696, sign: false }); + data.append(FP16x16 { mag: 4063232, sign: false }); + data.append(FP16x16 { mag: 4128768, sign: false }); + data.append(FP16x16 { mag: 4194304, sign: false }); + data.append(FP16x16 { mag: 4259840, sign: false }); + data.append(FP16x16 { mag: 4325376, sign: false }); + data.append(FP16x16 { mag: 4390912, sign: false }); + data.append(FP16x16 { mag: 4456448, sign: false }); + data.append(FP16x16 { mag: 4521984, sign: false }); + data.append(FP16x16 { mag: 4587520, sign: false }); + data.append(FP16x16 { mag: 4653056, sign: false }); + data.append(FP16x16 { mag: 4718592, sign: false }); + data.append(FP16x16 { mag: 4784128, sign: false }); + data.append(FP16x16 { mag: 4849664, sign: false }); + data.append(FP16x16 { mag: 4915200, sign: false }); + data.append(FP16x16 { mag: 4980736, sign: false }); + data.append(FP16x16 { mag: 5046272, sign: false }); + data.append(FP16x16 { mag: 5111808, sign: false }); + data.append(FP16x16 { mag: 5177344, sign: false }); + data.append(FP16x16 { mag: 5242880, sign: false }); + data.append(FP16x16 { mag: 5308416, sign: false }); + data.append(FP16x16 { mag: 5373952, sign: false }); + data.append(FP16x16 { mag: 5439488, sign: false }); + data.append(FP16x16 { mag: 5505024, sign: false }); + data.append(FP16x16 { mag: 5570560, sign: false }); + data.append(FP16x16 { mag: 5636096, sign: false }); + data.append(FP16x16 { mag: 5701632, sign: false }); + data.append(FP16x16 { mag: 5767168, sign: false }); + data.append(FP16x16 { mag: 5832704, sign: false }); + data.append(FP16x16 { mag: 5898240, sign: false }); + data.append(FP16x16 { mag: 5963776, sign: false }); + data.append(FP16x16 { mag: 6029312, sign: false }); + data.append(FP16x16 { mag: 6094848, sign: false }); + data.append(FP16x16 { mag: 6160384, sign: false }); + data.append(FP16x16 { mag: 6225920, sign: false }); + data.append(FP16x16 { mag: 6291456, sign: false }); + data.append(FP16x16 { mag: 6356992, sign: false }); + data.append(FP16x16 { mag: 6422528, sign: false }); + data.append(FP16x16 { mag: 6488064, sign: false }); + data.append(FP16x16 { mag: 6553600, sign: false }); + data.append(FP16x16 { mag: 6619136, sign: false }); + data.append(FP16x16 { mag: 6684672, sign: false }); + data.append(FP16x16 { mag: 6750208, sign: false }); + data.append(FP16x16 { mag: 6815744, sign: false }); + data.append(FP16x16 { mag: 7864320, sign: false }); + data.append(FP16x16 { mag: 7929856, sign: false }); + data.append(FP16x16 { mag: 7995392, sign: false }); + data.append(FP16x16 { mag: 8060928, sign: false }); + data.append(FP16x16 { mag: 8126464, sign: false }); + data.append(FP16x16 { mag: 8192000, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: false }); + data.append(FP16x16 { mag: 8323072, sign: false }); + data.append(FP16x16 { mag: 8388608, sign: false }); + data.append(FP16x16 { mag: 8454144, sign: false }); + data.append(FP16x16 { mag: 8519680, sign: false }); + data.append(FP16x16 { mag: 8585216, sign: false }); + data.append(FP16x16 { mag: 8650752, sign: false }); + data.append(FP16x16 { mag: 8716288, sign: false }); + data.append(FP16x16 { mag: 8781824, sign: false }); + data.append(FP16x16 { mag: 8847360, sign: false }); + data.append(FP16x16 { mag: 8912896, sign: false }); + data.append(FP16x16 { mag: 8978432, sign: false }); + data.append(FP16x16 { mag: 9043968, sign: false }); + data.append(FP16x16 { mag: 9109504, sign: false }); + data.append(FP16x16 { mag: 9175040, sign: false }); + data.append(FP16x16 { mag: 9240576, sign: false }); + data.append(FP16x16 { mag: 9306112, sign: false }); + data.append(FP16x16 { mag: 9371648, sign: false }); + data.append(FP16x16 { mag: 9437184, sign: false }); + data.append(FP16x16 { mag: 9502720, sign: false }); + data.append(FP16x16 { mag: 9568256, sign: false }); + data.append(FP16x16 { mag: 9633792, sign: false }); + data.append(FP16x16 { mag: 9699328, sign: false }); + data.append(FP16x16 { mag: 9764864, sign: false }); + data.append(FP16x16 { mag: 9830400, sign: false }); + data.append(FP16x16 { mag: 9895936, sign: false }); + data.append(FP16x16 { mag: 9961472, sign: false }); + data.append(FP16x16 { mag: 10027008, sign: false }); + data.append(FP16x16 { mag: 10092544, sign: false }); + data.append(FP16x16 { mag: 10158080, sign: false }); + data.append(FP16x16 { mag: 10223616, sign: false }); + data.append(FP16x16 { mag: 10289152, sign: false }); + data.append(FP16x16 { mag: 10354688, sign: false }); + data.append(FP16x16 { mag: 10420224, sign: false }); + data.append(FP16x16 { mag: 10485760, sign: false }); + data.append(FP16x16 { mag: 10551296, sign: false }); + data.append(FP16x16 { mag: 10616832, sign: false }); + data.append(FP16x16 { mag: 10682368, sign: false }); + data.append(FP16x16 { mag: 10747904, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis2.cairo b/tests/nodes/compress_fp16x16_3d_axis2.cairo new file mode 100644 index 000000000..765bcb5ea --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis2.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp16x16_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp16x16_3d_axis2/input_0.cairo b/tests/nodes/compress_fp16x16_3d_axis2/input_0.cairo new file mode 100644 index 000000000..2ddf33488 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis2/input_0.cairo @@ -0,0 +1,62 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + data.append(FP16x16 { mag: 1769472, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 1900544, sign: false }); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 2031616, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 2162688, sign: false }); + data.append(FP16x16 { mag: 2228224, sign: false }); + data.append(FP16x16 { mag: 2293760, sign: false }); + data.append(FP16x16 { mag: 2359296, sign: false }); + data.append(FP16x16 { mag: 2424832, sign: false }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 2555904, sign: false }); + data.append(FP16x16 { mag: 2621440, sign: false }); + data.append(FP16x16 { mag: 2686976, sign: false }); + data.append(FP16x16 { mag: 2752512, sign: false }); + data.append(FP16x16 { mag: 2818048, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 2949120, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + data.append(FP16x16 { mag: 3080192, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis2/input_1.cairo b/tests/nodes/compress_fp16x16_3d_axis2/input_1.cairo new file mode 100644 index 000000000..f1549f634 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis2/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis2/output_0.cairo b/tests/nodes/compress_fp16x16_3d_axis2/output_0.cairo new file mode 100644 index 000000000..4a9d7b81f --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis2/output_0.cairo @@ -0,0 +1,50 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + data.append(FP16x16 { mag: 1769472, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 2031616, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 2228224, sign: false }); + data.append(FP16x16 { mag: 2293760, sign: false }); + data.append(FP16x16 { mag: 2359296, sign: false }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 2555904, sign: false }); + data.append(FP16x16 { mag: 2621440, sign: false }); + data.append(FP16x16 { mag: 2752512, sign: false }); + data.append(FP16x16 { mag: 2818048, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + data.append(FP16x16 { mag: 3080192, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis3.cairo b/tests/nodes/compress_fp16x16_3d_axis3.cairo new file mode 100644 index 000000000..ffa9c8321 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis3.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp16x16_3d_axis3() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(3)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp16x16_3d_axis3/input_0.cairo b/tests/nodes/compress_fp16x16_3d_axis3/input_0.cairo new file mode 100644 index 000000000..a189b31f1 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis3/input_0.cairo @@ -0,0 +1,111 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(3); + shape.append(4); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + data.append(FP16x16 { mag: 1769472, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 1900544, sign: false }); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 2031616, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 2162688, sign: false }); + data.append(FP16x16 { mag: 2228224, sign: false }); + data.append(FP16x16 { mag: 2293760, sign: false }); + data.append(FP16x16 { mag: 2359296, sign: false }); + data.append(FP16x16 { mag: 2424832, sign: false }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 2555904, sign: false }); + data.append(FP16x16 { mag: 2621440, sign: false }); + data.append(FP16x16 { mag: 2686976, sign: false }); + data.append(FP16x16 { mag: 2752512, sign: false }); + data.append(FP16x16 { mag: 2818048, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 2949120, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + data.append(FP16x16 { mag: 3080192, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 3211264, sign: false }); + data.append(FP16x16 { mag: 3276800, sign: false }); + data.append(FP16x16 { mag: 3342336, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3473408, sign: false }); + data.append(FP16x16 { mag: 3538944, sign: false }); + data.append(FP16x16 { mag: 3604480, sign: false }); + data.append(FP16x16 { mag: 3670016, sign: false }); + data.append(FP16x16 { mag: 3735552, sign: false }); + data.append(FP16x16 { mag: 3801088, sign: false }); + data.append(FP16x16 { mag: 3866624, sign: false }); + data.append(FP16x16 { mag: 3932160, sign: false }); + data.append(FP16x16 { mag: 3997696, sign: false }); + data.append(FP16x16 { mag: 4063232, sign: false }); + data.append(FP16x16 { mag: 4128768, sign: false }); + data.append(FP16x16 { mag: 4194304, sign: false }); + data.append(FP16x16 { mag: 4259840, sign: false }); + data.append(FP16x16 { mag: 4325376, sign: false }); + data.append(FP16x16 { mag: 4390912, sign: false }); + data.append(FP16x16 { mag: 4456448, sign: false }); + data.append(FP16x16 { mag: 4521984, sign: false }); + data.append(FP16x16 { mag: 4587520, sign: false }); + data.append(FP16x16 { mag: 4653056, sign: false }); + data.append(FP16x16 { mag: 4718592, sign: false }); + data.append(FP16x16 { mag: 4784128, sign: false }); + data.append(FP16x16 { mag: 4849664, sign: false }); + data.append(FP16x16 { mag: 4915200, sign: false }); + data.append(FP16x16 { mag: 4980736, sign: false }); + data.append(FP16x16 { mag: 5046272, sign: false }); + data.append(FP16x16 { mag: 5111808, sign: false }); + data.append(FP16x16 { mag: 5177344, sign: false }); + data.append(FP16x16 { mag: 5242880, sign: false }); + data.append(FP16x16 { mag: 5308416, sign: false }); + data.append(FP16x16 { mag: 5373952, sign: false }); + data.append(FP16x16 { mag: 5439488, sign: false }); + data.append(FP16x16 { mag: 5505024, sign: false }); + data.append(FP16x16 { mag: 5570560, sign: false }); + data.append(FP16x16 { mag: 5636096, sign: false }); + data.append(FP16x16 { mag: 5701632, sign: false }); + data.append(FP16x16 { mag: 5767168, sign: false }); + data.append(FP16x16 { mag: 5832704, sign: false }); + data.append(FP16x16 { mag: 5898240, sign: false }); + data.append(FP16x16 { mag: 5963776, sign: false }); + data.append(FP16x16 { mag: 6029312, sign: false }); + data.append(FP16x16 { mag: 6094848, sign: false }); + data.append(FP16x16 { mag: 6160384, sign: false }); + data.append(FP16x16 { mag: 6225920, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis3/input_1.cairo b/tests/nodes/compress_fp16x16_3d_axis3/input_1.cairo new file mode 100644 index 000000000..8d36b9136 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis3/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_axis3/output_0.cairo b/tests/nodes/compress_fp16x16_3d_axis3/output_0.cairo new file mode 100644 index 000000000..edfec2268 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_axis3/output_0.cairo @@ -0,0 +1,63 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(3); + shape.append(4); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 2097152, sign: false }); + data.append(FP16x16 { mag: 2228224, sign: false }); + data.append(FP16x16 { mag: 2359296, sign: false }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 2621440, sign: false }); + data.append(FP16x16 { mag: 2752512, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 3276800, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3538944, sign: false }); + data.append(FP16x16 { mag: 3670016, sign: false }); + data.append(FP16x16 { mag: 3801088, sign: false }); + data.append(FP16x16 { mag: 3932160, sign: false }); + data.append(FP16x16 { mag: 4063232, sign: false }); + data.append(FP16x16 { mag: 4194304, sign: false }); + data.append(FP16x16 { mag: 4325376, sign: false }); + data.append(FP16x16 { mag: 4456448, sign: false }); + data.append(FP16x16 { mag: 4587520, sign: false }); + data.append(FP16x16 { mag: 4718592, sign: false }); + data.append(FP16x16 { mag: 4849664, sign: false }); + data.append(FP16x16 { mag: 4980736, sign: false }); + data.append(FP16x16 { mag: 5111808, sign: false }); + data.append(FP16x16 { mag: 5242880, sign: false }); + data.append(FP16x16 { mag: 5373952, sign: false }); + data.append(FP16x16 { mag: 5505024, sign: false }); + data.append(FP16x16 { mag: 5636096, sign: false }); + data.append(FP16x16 { mag: 5767168, sign: false }); + data.append(FP16x16 { mag: 5898240, sign: false }); + data.append(FP16x16 { mag: 6029312, sign: false }); + data.append(FP16x16 { mag: 6160384, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_default.cairo b/tests/nodes/compress_fp16x16_3d_default.cairo new file mode 100644 index 000000000..d9b837a19 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_default.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp16x16_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp16x16_3d_default/input_0.cairo b/tests/nodes/compress_fp16x16_3d_default/input_0.cairo new file mode 100644 index 000000000..22bb148eb --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_default/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_default/input_1.cairo b/tests/nodes/compress_fp16x16_3d_default/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_default/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_default/output_0.cairo b/tests/nodes/compress_fp16x16_3d_default/output_0.cairo new file mode 100644 index 000000000..32e64e952 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_default/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_noaxis.cairo b/tests/nodes/compress_fp16x16_3d_noaxis.cairo new file mode 100644 index 000000000..2bd536e08 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_noaxis.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp16x16_3d_noaxis() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::None(())); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp16x16_3d_noaxis/input_0.cairo b/tests/nodes/compress_fp16x16_3d_noaxis/input_0.cairo new file mode 100644 index 000000000..22bb148eb --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_noaxis/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1114112, sign: false }); + data.append(FP16x16 { mag: 1179648, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: false }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 1376256, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: false }); + data.append(FP16x16 { mag: 1507328, sign: false }); + data.append(FP16x16 { mag: 1572864, sign: false }); + data.append(FP16x16 { mag: 1638400, sign: false }); + data.append(FP16x16 { mag: 1703936, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_noaxis/input_1.cairo b/tests/nodes/compress_fp16x16_3d_noaxis/input_1.cairo new file mode 100644 index 000000000..2f28d9179 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_noaxis/input_1.cairo @@ -0,0 +1,20 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(9); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(0); + data.append(1); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp16x16_3d_noaxis/output_0.cairo b/tests/nodes/compress_fp16x16_3d_noaxis/output_0.cairo new file mode 100644 index 000000000..e443f84b4 --- /dev/null +++ b/tests/nodes/compress_fp16x16_3d_noaxis/output_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(7); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_axis1.cairo b/tests/nodes/compress_fp8x23_3d_axis1.cairo new file mode 100644 index 000000000..edd013f54 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis1.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp8x23_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp8x23_3d_axis1/input_0.cairo b/tests/nodes/compress_fp8x23_3d_axis1/input_0.cairo new file mode 100644 index 000000000..158782e65 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis1/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 125829120, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 150994944, sign: false }); + data.append(FP8x23 { mag: 159383552, sign: false }); + data.append(FP8x23 { mag: 167772160, sign: false }); + data.append(FP8x23 { mag: 176160768, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 201326592, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_axis1/input_1.cairo b/tests/nodes/compress_fp8x23_3d_axis1/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis1/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_axis1/output_0.cairo b/tests/nodes/compress_fp8x23_3d_axis1/output_0.cairo new file mode 100644 index 000000000..83aae904e --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis1/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 125829120, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 176160768, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 201326592, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_axis2.cairo b/tests/nodes/compress_fp8x23_3d_axis2.cairo new file mode 100644 index 000000000..580a6272a --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis2.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp8x23_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp8x23_3d_axis2/input_0.cairo b/tests/nodes/compress_fp8x23_3d_axis2/input_0.cairo new file mode 100644 index 000000000..158782e65 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis2/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 125829120, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 150994944, sign: false }); + data.append(FP8x23 { mag: 159383552, sign: false }); + data.append(FP8x23 { mag: 167772160, sign: false }); + data.append(FP8x23 { mag: 176160768, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 201326592, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_axis2/input_1.cairo b/tests/nodes/compress_fp8x23_3d_axis2/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis2/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_axis2/output_0.cairo b/tests/nodes/compress_fp8x23_3d_axis2/output_0.cairo new file mode 100644 index 000000000..340000317 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_axis2/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 159383552, sign: false }); + data.append(FP8x23 { mag: 167772160, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_default.cairo b/tests/nodes/compress_fp8x23_3d_default.cairo new file mode 100644 index 000000000..a927f7fe8 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_default.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP8x23Tensor; +use orion::operators::tensor::FP8x23TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_fp8x23_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_fp8x23_3d_default/input_0.cairo b/tests/nodes/compress_fp8x23_3d_default/input_0.cairo new file mode 100644 index 000000000..158782e65 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_default/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 0, sign: false }); + data.append(FP8x23 { mag: 8388608, sign: false }); + data.append(FP8x23 { mag: 16777216, sign: false }); + data.append(FP8x23 { mag: 25165824, sign: false }); + data.append(FP8x23 { mag: 33554432, sign: false }); + data.append(FP8x23 { mag: 41943040, sign: false }); + data.append(FP8x23 { mag: 50331648, sign: false }); + data.append(FP8x23 { mag: 58720256, sign: false }); + data.append(FP8x23 { mag: 67108864, sign: false }); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 125829120, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 150994944, sign: false }); + data.append(FP8x23 { mag: 159383552, sign: false }); + data.append(FP8x23 { mag: 167772160, sign: false }); + data.append(FP8x23 { mag: 176160768, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 201326592, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_default/input_1.cairo b/tests/nodes/compress_fp8x23_3d_default/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_default/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_fp8x23_3d_default/output_0.cairo b/tests/nodes/compress_fp8x23_3d_default/output_0.cairo new file mode 100644 index 000000000..baf4ef69f --- /dev/null +++ b/tests/nodes/compress_fp8x23_3d_default/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 75497472, sign: false }); + data.append(FP8x23 { mag: 83886080, sign: false }); + data.append(FP8x23 { mag: 92274688, sign: false }); + data.append(FP8x23 { mag: 100663296, sign: false }); + data.append(FP8x23 { mag: 109051904, sign: false }); + data.append(FP8x23 { mag: 117440512, sign: false }); + data.append(FP8x23 { mag: 125829120, sign: false }); + data.append(FP8x23 { mag: 134217728, sign: false }); + data.append(FP8x23 { mag: 142606336, sign: false }); + data.append(FP8x23 { mag: 150994944, sign: false }); + data.append(FP8x23 { mag: 159383552, sign: false }); + data.append(FP8x23 { mag: 167772160, sign: false }); + data.append(FP8x23 { mag: 176160768, sign: false }); + data.append(FP8x23 { mag: 184549376, sign: false }); + data.append(FP8x23 { mag: 192937984, sign: false }); + data.append(FP8x23 { mag: 201326592, sign: false }); + data.append(FP8x23 { mag: 209715200, sign: false }); + data.append(FP8x23 { mag: 218103808, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_axis1.cairo b/tests/nodes/compress_i32_3d_axis1.cairo new file mode 100644 index 000000000..f69cf2e2a --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis1.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::tensor::I32Tensor; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_i32_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_i32_3d_axis1/input_0.cairo b/tests/nodes/compress_i32_3d_axis1/input_0.cairo new file mode 100644 index 000000000..9b4392c28 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis1/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_axis1/input_1.cairo b/tests/nodes/compress_i32_3d_axis1/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis1/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_axis1/output_0.cairo b/tests/nodes/compress_i32_3d_axis1/output_0.cairo new file mode 100644 index 000000000..4e04ef235 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis1/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_axis2.cairo b/tests/nodes/compress_i32_3d_axis2.cairo new file mode 100644 index 000000000..bfe01e5a0 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis2.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::tensor::I32Tensor; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_i32_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_i32_3d_axis2/input_0.cairo b/tests/nodes/compress_i32_3d_axis2/input_0.cairo new file mode 100644 index 000000000..9b4392c28 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis2/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_axis2/input_1.cairo b/tests/nodes/compress_i32_3d_axis2/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis2/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_axis2/output_0.cairo b/tests/nodes/compress_i32_3d_axis2/output_0.cairo new file mode 100644 index 000000000..697fe6ee3 --- /dev/null +++ b/tests/nodes/compress_i32_3d_axis2/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_default.cairo b/tests/nodes/compress_i32_3d_default.cairo new file mode 100644 index 000000000..b07d95010 --- /dev/null +++ b/tests/nodes/compress_i32_3d_default.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::tensor::I32Tensor; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_i32_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_i32_3d_default/input_0.cairo b/tests/nodes/compress_i32_3d_default/input_0.cairo new file mode 100644 index 000000000..9b4392c28 --- /dev/null +++ b/tests/nodes/compress_i32_3d_default/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 0, sign: false }); + data.append(i32 { mag: 1, sign: false }); + data.append(i32 { mag: 2, sign: false }); + data.append(i32 { mag: 3, sign: false }); + data.append(i32 { mag: 4, sign: false }); + data.append(i32 { mag: 5, sign: false }); + data.append(i32 { mag: 6, sign: false }); + data.append(i32 { mag: 7, sign: false }); + data.append(i32 { mag: 8, sign: false }); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_default/input_1.cairo b/tests/nodes/compress_i32_3d_default/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_i32_3d_default/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i32_3d_default/output_0.cairo b/tests/nodes/compress_i32_3d_default/output_0.cairo new file mode 100644 index 000000000..97e177389 --- /dev/null +++ b/tests/nodes/compress_i32_3d_default/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I32Tensor; +use orion::numbers::{IntegerTrait, i32}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i32 { mag: 9, sign: false }); + data.append(i32 { mag: 10, sign: false }); + data.append(i32 { mag: 11, sign: false }); + data.append(i32 { mag: 12, sign: false }); + data.append(i32 { mag: 13, sign: false }); + data.append(i32 { mag: 14, sign: false }); + data.append(i32 { mag: 15, sign: false }); + data.append(i32 { mag: 16, sign: false }); + data.append(i32 { mag: 17, sign: false }); + data.append(i32 { mag: 18, sign: false }); + data.append(i32 { mag: 19, sign: false }); + data.append(i32 { mag: 20, sign: false }); + data.append(i32 { mag: 21, sign: false }); + data.append(i32 { mag: 22, sign: false }); + data.append(i32 { mag: 23, sign: false }); + data.append(i32 { mag: 24, sign: false }); + data.append(i32 { mag: 25, sign: false }); + data.append(i32 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_axis1.cairo b/tests/nodes/compress_i8_3d_axis1.cairo new file mode 100644 index 000000000..6a4197ce1 --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis1.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_i8_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_i8_3d_axis1/input_0.cairo b/tests/nodes/compress_i8_3d_axis1/input_0.cairo new file mode 100644 index 000000000..c095ab988 --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis1/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_axis1/input_1.cairo b/tests/nodes/compress_i8_3d_axis1/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis1/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_axis1/output_0.cairo b/tests/nodes/compress_i8_3d_axis1/output_0.cairo new file mode 100644 index 000000000..0a6e60910 --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis1/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_axis2.cairo b/tests/nodes/compress_i8_3d_axis2.cairo new file mode 100644 index 000000000..4dd7b5a8f --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis2.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_i8_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_i8_3d_axis2/input_0.cairo b/tests/nodes/compress_i8_3d_axis2/input_0.cairo new file mode 100644 index 000000000..c095ab988 --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis2/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_axis2/input_1.cairo b/tests/nodes/compress_i8_3d_axis2/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis2/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_axis2/output_0.cairo b/tests/nodes/compress_i8_3d_axis2/output_0.cairo new file mode 100644 index 000000000..f2d25b67d --- /dev/null +++ b/tests/nodes/compress_i8_3d_axis2/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_default.cairo b/tests/nodes/compress_i8_3d_default.cairo new file mode 100644 index 000000000..14b684377 --- /dev/null +++ b/tests/nodes/compress_i8_3d_default.cairo @@ -0,0 +1,24 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::I8Tensor; +use orion::operators::tensor::I8TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_compress_i8_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_i8_3d_default/input_0.cairo b/tests/nodes/compress_i8_3d_default/input_0.cairo new file mode 100644 index 000000000..c095ab988 --- /dev/null +++ b/tests/nodes/compress_i8_3d_default/input_0.cairo @@ -0,0 +1,41 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 0, sign: false }); + data.append(i8 { mag: 1, sign: false }); + data.append(i8 { mag: 2, sign: false }); + data.append(i8 { mag: 3, sign: false }); + data.append(i8 { mag: 4, sign: false }); + data.append(i8 { mag: 5, sign: false }); + data.append(i8 { mag: 6, sign: false }); + data.append(i8 { mag: 7, sign: false }); + data.append(i8 { mag: 8, sign: false }); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_default/input_1.cairo b/tests/nodes/compress_i8_3d_default/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_i8_3d_default/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_i8_3d_default/output_0.cairo b/tests/nodes/compress_i8_3d_default/output_0.cairo new file mode 100644 index 000000000..9472aea9f --- /dev/null +++ b/tests/nodes/compress_i8_3d_default/output_0.cairo @@ -0,0 +1,32 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::I8Tensor; +use orion::numbers::{IntegerTrait, i8}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(i8 { mag: 9, sign: false }); + data.append(i8 { mag: 10, sign: false }); + data.append(i8 { mag: 11, sign: false }); + data.append(i8 { mag: 12, sign: false }); + data.append(i8 { mag: 13, sign: false }); + data.append(i8 { mag: 14, sign: false }); + data.append(i8 { mag: 15, sign: false }); + data.append(i8 { mag: 16, sign: false }); + data.append(i8 { mag: 17, sign: false }); + data.append(i8 { mag: 18, sign: false }); + data.append(i8 { mag: 19, sign: false }); + data.append(i8 { mag: 20, sign: false }); + data.append(i8 { mag: 21, sign: false }); + data.append(i8 { mag: 22, sign: false }); + data.append(i8 { mag: 23, sign: false }); + data.append(i8 { mag: 24, sign: false }); + data.append(i8 { mag: 25, sign: false }); + data.append(i8 { mag: 26, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis1.cairo b/tests/nodes/compress_u32_3d_axis1.cairo new file mode 100644 index 000000000..dda59bead --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis1.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_u32_3d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_u32_3d_axis1/input_0.cairo b/tests/nodes/compress_u32_3d_axis1/input_0.cairo new file mode 100644 index 000000000..3b921bc94 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis1/input_0.cairo @@ -0,0 +1,49 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis1/input_1.cairo b/tests/nodes/compress_u32_3d_axis1/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis1/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis1/output_0.cairo b/tests/nodes/compress_u32_3d_axis1/output_0.cairo new file mode 100644 index 000000000..41b41498a --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis1/output_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis2.cairo b/tests/nodes/compress_u32_3d_axis2.cairo new file mode 100644 index 000000000..ba8fa77ef --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_u32_3d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_u32_3d_axis2/input_0.cairo b/tests/nodes/compress_u32_3d_axis2/input_0.cairo new file mode 100644 index 000000000..bb021b506 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2/input_0.cairo @@ -0,0 +1,61 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + data.append(36); + data.append(37); + data.append(38); + data.append(39); + data.append(40); + data.append(41); + data.append(42); + data.append(43); + data.append(44); + data.append(45); + data.append(46); + data.append(47); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis2/input_1.cairo b/tests/nodes/compress_u32_3d_axis2/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis2/output_0.cairo b/tests/nodes/compress_u32_3d_axis2/output_0.cairo new file mode 100644 index 000000000..029f6af3a --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2/output_0.cairo @@ -0,0 +1,37 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(2); + data.append(5); + data.append(6); + data.append(9); + data.append(10); + data.append(13); + data.append(14); + data.append(17); + data.append(18); + data.append(21); + data.append(22); + data.append(25); + data.append(26); + data.append(29); + data.append(30); + data.append(33); + data.append(34); + data.append(37); + data.append(38); + data.append(41); + data.append(42); + data.append(45); + data.append(46); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis2_2.cairo b/tests/nodes/compress_u32_3d_axis2_2.cairo new file mode 100644 index 000000000..aa283b2cc --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2_2.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_u32_3d_axis2_2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_u32_3d_axis2_2/input_0.cairo b/tests/nodes/compress_u32_3d_axis2_2/input_0.cairo new file mode 100644 index 000000000..f92e7b071 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2_2/input_0.cairo @@ -0,0 +1,73 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + data.append(36); + data.append(37); + data.append(38); + data.append(39); + data.append(40); + data.append(41); + data.append(42); + data.append(43); + data.append(44); + data.append(45); + data.append(46); + data.append(47); + data.append(48); + data.append(49); + data.append(50); + data.append(51); + data.append(52); + data.append(53); + data.append(54); + data.append(55); + data.append(56); + data.append(57); + data.append(58); + data.append(59); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis2_2/input_1.cairo b/tests/nodes/compress_u32_3d_axis2_2/input_1.cairo new file mode 100644 index 000000000..5b1c8b963 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2_2/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis2_2/output_0.cairo b/tests/nodes/compress_u32_3d_axis2_2/output_0.cairo new file mode 100644 index 000000000..e80139ba6 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis2_2/output_0.cairo @@ -0,0 +1,37 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(2); + data.append(6); + data.append(7); + data.append(11); + data.append(12); + data.append(16); + data.append(17); + data.append(21); + data.append(22); + data.append(26); + data.append(27); + data.append(31); + data.append(32); + data.append(36); + data.append(37); + data.append(41); + data.append(42); + data.append(46); + data.append(47); + data.append(51); + data.append(52); + data.append(56); + data.append(57); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis3.cairo b/tests/nodes/compress_u32_3d_axis3.cairo new file mode 100644 index 000000000..62684b39f --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis3.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_u32_3d_axis3() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(3)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_u32_3d_axis3/input_0.cairo b/tests/nodes/compress_u32_3d_axis3/input_0.cairo new file mode 100644 index 000000000..05ec2b44a --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis3/input_0.cairo @@ -0,0 +1,284 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(5); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + data.append(36); + data.append(37); + data.append(38); + data.append(39); + data.append(40); + data.append(41); + data.append(42); + data.append(43); + data.append(44); + data.append(45); + data.append(46); + data.append(47); + data.append(48); + data.append(49); + data.append(50); + data.append(51); + data.append(52); + data.append(53); + data.append(54); + data.append(55); + data.append(56); + data.append(57); + data.append(58); + data.append(59); + data.append(60); + data.append(61); + data.append(62); + data.append(63); + data.append(64); + data.append(65); + data.append(66); + data.append(67); + data.append(68); + data.append(69); + data.append(70); + data.append(71); + data.append(72); + data.append(73); + data.append(74); + data.append(75); + data.append(76); + data.append(77); + data.append(78); + data.append(79); + data.append(80); + data.append(81); + data.append(82); + data.append(83); + data.append(84); + data.append(85); + data.append(86); + data.append(87); + data.append(88); + data.append(89); + data.append(90); + data.append(91); + data.append(92); + data.append(93); + data.append(94); + data.append(95); + data.append(96); + data.append(97); + data.append(98); + data.append(99); + data.append(100); + data.append(101); + data.append(102); + data.append(103); + data.append(104); + data.append(105); + data.append(106); + data.append(107); + data.append(108); + data.append(109); + data.append(110); + data.append(111); + data.append(112); + data.append(113); + data.append(114); + data.append(115); + data.append(116); + data.append(117); + data.append(118); + data.append(119); + data.append(120); + data.append(121); + data.append(122); + data.append(123); + data.append(124); + data.append(125); + data.append(126); + data.append(127); + data.append(128); + data.append(129); + data.append(130); + data.append(131); + data.append(132); + data.append(133); + data.append(134); + data.append(135); + data.append(136); + data.append(137); + data.append(138); + data.append(139); + data.append(140); + data.append(141); + data.append(142); + data.append(143); + data.append(144); + data.append(145); + data.append(146); + data.append(147); + data.append(148); + data.append(149); + data.append(150); + data.append(151); + data.append(152); + data.append(153); + data.append(154); + data.append(155); + data.append(156); + data.append(157); + data.append(158); + data.append(159); + data.append(160); + data.append(161); + data.append(162); + data.append(163); + data.append(164); + data.append(165); + data.append(166); + data.append(167); + data.append(168); + data.append(169); + data.append(170); + data.append(171); + data.append(172); + data.append(173); + data.append(174); + data.append(175); + data.append(176); + data.append(177); + data.append(178); + data.append(179); + data.append(180); + data.append(181); + data.append(182); + data.append(183); + data.append(184); + data.append(185); + data.append(186); + data.append(187); + data.append(188); + data.append(189); + data.append(190); + data.append(191); + data.append(192); + data.append(193); + data.append(194); + data.append(195); + data.append(196); + data.append(197); + data.append(198); + data.append(199); + data.append(200); + data.append(201); + data.append(202); + data.append(203); + data.append(204); + data.append(205); + data.append(206); + data.append(207); + data.append(208); + data.append(209); + data.append(210); + data.append(211); + data.append(212); + data.append(213); + data.append(214); + data.append(215); + data.append(216); + data.append(217); + data.append(218); + data.append(219); + data.append(220); + data.append(221); + data.append(222); + data.append(223); + data.append(224); + data.append(225); + data.append(226); + data.append(227); + data.append(228); + data.append(229); + data.append(230); + data.append(231); + data.append(232); + data.append(233); + data.append(234); + data.append(235); + data.append(236); + data.append(237); + data.append(238); + data.append(239); + data.append(240); + data.append(241); + data.append(242); + data.append(243); + data.append(244); + data.append(245); + data.append(246); + data.append(247); + data.append(248); + data.append(249); + data.append(250); + data.append(251); + data.append(252); + data.append(253); + data.append(254); + data.append(255); + data.append(256); + data.append(257); + data.append(258); + data.append(259); + data.append(260); + data.append(261); + data.append(262); + data.append(263); + data.append(264); + data.append(265); + data.append(266); + data.append(267); + data.append(268); + data.append(269); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis3/input_1.cairo b/tests/nodes/compress_u32_3d_axis3/input_1.cairo new file mode 100644 index 000000000..ae751534e --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis3/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(1); + data.append(1); + data.append(0); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_axis3/output_0.cairo b/tests/nodes/compress_u32_3d_axis3/output_0.cairo new file mode 100644 index 000000000..069e45a81 --- /dev/null +++ b/tests/nodes/compress_u32_3d_axis3/output_0.cairo @@ -0,0 +1,194 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(3); + shape.append(5); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(2); + data.append(3); + data.append(5); + data.append(7); + data.append(8); + data.append(9); + data.append(11); + data.append(13); + data.append(14); + data.append(15); + data.append(17); + data.append(19); + data.append(20); + data.append(21); + data.append(23); + data.append(25); + data.append(26); + data.append(27); + data.append(29); + data.append(31); + data.append(32); + data.append(33); + data.append(35); + data.append(37); + data.append(38); + data.append(39); + data.append(41); + data.append(43); + data.append(44); + data.append(45); + data.append(47); + data.append(49); + data.append(50); + data.append(51); + data.append(53); + data.append(55); + data.append(56); + data.append(57); + data.append(59); + data.append(61); + data.append(62); + data.append(63); + data.append(65); + data.append(67); + data.append(68); + data.append(69); + data.append(71); + data.append(73); + data.append(74); + data.append(75); + data.append(77); + data.append(79); + data.append(80); + data.append(81); + data.append(83); + data.append(85); + data.append(86); + data.append(87); + data.append(89); + data.append(91); + data.append(92); + data.append(93); + data.append(95); + data.append(97); + data.append(98); + data.append(99); + data.append(101); + data.append(103); + data.append(104); + data.append(105); + data.append(107); + data.append(109); + data.append(110); + data.append(111); + data.append(113); + data.append(115); + data.append(116); + data.append(117); + data.append(119); + data.append(121); + data.append(122); + data.append(123); + data.append(125); + data.append(127); + data.append(128); + data.append(129); + data.append(131); + data.append(133); + data.append(134); + data.append(135); + data.append(137); + data.append(139); + data.append(140); + data.append(141); + data.append(143); + data.append(145); + data.append(146); + data.append(147); + data.append(149); + data.append(151); + data.append(152); + data.append(153); + data.append(155); + data.append(157); + data.append(158); + data.append(159); + data.append(161); + data.append(163); + data.append(164); + data.append(165); + data.append(167); + data.append(169); + data.append(170); + data.append(171); + data.append(173); + data.append(175); + data.append(176); + data.append(177); + data.append(179); + data.append(181); + data.append(182); + data.append(183); + data.append(185); + data.append(187); + data.append(188); + data.append(189); + data.append(191); + data.append(193); + data.append(194); + data.append(195); + data.append(197); + data.append(199); + data.append(200); + data.append(201); + data.append(203); + data.append(205); + data.append(206); + data.append(207); + data.append(209); + data.append(211); + data.append(212); + data.append(213); + data.append(215); + data.append(217); + data.append(218); + data.append(219); + data.append(221); + data.append(223); + data.append(224); + data.append(225); + data.append(227); + data.append(229); + data.append(230); + data.append(231); + data.append(233); + data.append(235); + data.append(236); + data.append(237); + data.append(239); + data.append(241); + data.append(242); + data.append(243); + data.append(245); + data.append(247); + data.append(248); + data.append(249); + data.append(251); + data.append(253); + data.append(254); + data.append(255); + data.append(257); + data.append(259); + data.append(260); + data.append(261); + data.append(263); + data.append(265); + data.append(266); + data.append(267); + data.append(269); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_default.cairo b/tests/nodes/compress_u32_3d_default.cairo new file mode 100644 index 000000000..058750c53 --- /dev/null +++ b/tests/nodes/compress_u32_3d_default.cairo @@ -0,0 +1,22 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_compress_u32_3d_default() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let z_0 = output_0::output_0(); + + let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/compress_u32_3d_default/input_0.cairo b/tests/nodes/compress_u32_3d_default/input_0.cairo new file mode 100644 index 000000000..f2805e81e --- /dev/null +++ b/tests/nodes/compress_u32_3d_default/input_0.cairo @@ -0,0 +1,61 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + data.append(24); + data.append(25); + data.append(26); + data.append(27); + data.append(28); + data.append(29); + data.append(30); + data.append(31); + data.append(32); + data.append(33); + data.append(34); + data.append(35); + data.append(36); + data.append(37); + data.append(38); + data.append(39); + data.append(40); + data.append(41); + data.append(42); + data.append(43); + data.append(44); + data.append(45); + data.append(46); + data.append(47); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_default/input_1.cairo b/tests/nodes/compress_u32_3d_default/input_1.cairo new file mode 100644 index 000000000..43059b669 --- /dev/null +++ b/tests/nodes/compress_u32_3d_default/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/compress_u32_3d_default/output_0.cairo b/tests/nodes/compress_u32_3d_default/output_0.cairo new file mode 100644 index 000000000..875197f4f --- /dev/null +++ b/tests/nodes/compress_u32_3d_default/output_0.cairo @@ -0,0 +1,37 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(4); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(0); + data.append(1); + data.append(2); + data.append(3); + data.append(4); + data.append(5); + data.append(6); + data.append(7); + data.append(8); + data.append(9); + data.append(10); + data.append(11); + data.append(12); + data.append(13); + data.append(14); + data.append(15); + data.append(16); + data.append(17); + data.append(18); + data.append(19); + data.append(20); + data.append(21); + data.append(22); + data.append(23); + TensorTrait::new(shape.span(), data.span()) +} From 58ec7e19405b0df77b50f455e23ac5a16516daa6 Mon Sep 17 00:00:00 2001 From: Hakeem Kazeem Date: Mon, 25 Dec 2023 19:34:54 +0100 Subject: [PATCH 15/38] summary, readme --- docs/CHANGELOG.md | 4 ++ docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 5 ++- .../operators/tensor/tensor.compress.md | 39 ++++++++++++++++ docs/framework/operators/tensor/tensor.erf.md | 4 -- .../operators/tensor/tensor.gather_nd.md | 10 +---- src/operators/tensor/core.cairo | 44 +++++++++++++++++-- 7 files changed, 89 insertions(+), 18 deletions(-) create mode 100644 docs/framework/operators/tensor/tensor.compress.md diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 9c22caa01..a7ec58d5f 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -4,7 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] - 2023-12-25 +## Added +- Compress Operator. + ## [Unreleased] - 2023-12-14 ## Added diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 62ae2a2b3..d20c843c6 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -127,6 +127,7 @@ * [tensor.erf](framework/operators/tensor/tensor.erf.md) * [tensor.reduce\_log\_sum](framework/operators/tensor/tensor.reduce\_log\_sum.md) * [tensor.unique](framework/operators/tensor/tensor.unique.md) + * [tensor.compress](framework/operators/tensor/tensor.compress.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index 68cd44241..1df3207fe 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -103,6 +103,7 @@ You can see below the list of current supported ONNX Operators: | [Not](operators/tensor/tensor.not.md) | :white\_check\_mark: | | [GatherND](operators/tensor/tensor.gather/_nd.md) | :white\_check\_mark: | | [ReduceLogSum](operators/tensor/tensor.reduce\_log\_sum.md) | :white\_check\_mark: | -| [Erf](operators/tensor/tensor.erf.md) | :white\_check\_mark: | +| [Erf](operators/tensor/tensor.erf.md) | :white\_check\_mark: | +| [Compress](operators/tensor/tensor.compress.md) | :white\_check\_mark: | -Current Operators support: **96/156 (62%)** +Current Operators support: **97/156 (62%)** diff --git a/docs/framework/operators/tensor/tensor.compress.md b/docs/framework/operators/tensor/tensor.compress.md new file mode 100644 index 000000000..59cb043b3 --- /dev/null +++ b/docs/framework/operators/tensor/tensor.compress.md @@ -0,0 +1,39 @@ +# tensor.compress + +```rust + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor; +``` + +Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index. In case axis is not provided, input is flattened before elements are selected. + +## Args + +* `self`(`@Tensor`) - The input tensor. +* `condition`(`Tensor`) - Rank 1 tensor of booleans to indicate which slices or data elements to be selected. Its length can be less than the input length along the axis or the flattened input size if axis is not specified. In such cases data slices or elements exceeding the condition length are discarded. +* `axis`(`Option`) - (Optional) Axis along which to take slices. If not specified, input is flattened before elements being selected. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input). + +## Panics + +* Panics if condition rank is not equal to 1. + +## Returns + +A new `Tensor` . +fn compress_example() -> Tensor { + let tensor = TensorTrait::::new( + shape: array![3, 2].span(), + data: array![[1, 2], [3, 4], [5, 6]].span(), + ); + let condition = TensorTrait::::new( + shape: array![3].span(), + data: array![0, 1, 1].span(), + ); + + return tensor.compress( + condition: condition, + axis: Option::Some((0)), + ); +} +>>> [[3, 4], + [5, 6]] +``` diff --git a/docs/framework/operators/tensor/tensor.erf.md b/docs/framework/operators/tensor/tensor.erf.md index 19ce86a94..384a941d0 100644 --- a/docs/framework/operators/tensor/tensor.erf.md +++ b/docs/framework/operators/tensor/tensor.erf.md @@ -6,10 +6,6 @@ Computes the mean of the input tensor's elements along the provided axes. -## Args - -* `self`(`@Tensor`) - The input tensor. - ## Returns A new `Tensor` of the same shape as the input tensor with diff --git a/docs/framework/operators/tensor/tensor.gather_nd.md b/docs/framework/operators/tensor/tensor.gather_nd.md index 021d4f235..a922b41ad 100644 --- a/docs/framework/operators/tensor/tensor.gather_nd.md +++ b/docs/framework/operators/tensor/tensor.gather_nd.md @@ -21,14 +21,6 @@ Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims ## Returns A new `Tensor` . - -## Example - -```rust -use array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - fn gather_nd_example() -> Tensor { let tensor = TensorTrait::::new( shape: array![2, 2].span(), @@ -41,7 +33,7 @@ fn gather_nd_example() -> Tensor { return tensor.gather_nd( indices: indices, - axis: Option::None((0)), + axis: Option::Some((0)), ); } >>> [[0, 1], diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 91f3ecda8..8aad8ab18 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -5067,7 +5067,7 @@ trait TensorTrait { /// /// return tensor.gather_nd( /// indices: indices, - /// axis: Option::None((0)), + /// axis: Option::Some((0)), /// ); /// } /// >>> [[0, 1], @@ -5077,9 +5077,47 @@ trait TensorTrait { /// ``` /// fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor; - + /// # tensor.compress + /// + /// ```rust + /// fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor; + /// ``` + /// + /// Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index. In case axis is not provided, input is flattened before elements are selected. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// * `condition`(`Tensor`) - Rank 1 tensor of booleans to indicate which slices or data elements to be selected. Its length can be less than the input length along the axis or the flattened input size if axis is not specified. In such cases data slices or elements exceeding the condition length are discarded. + /// * `axis`(`Option`) - (Optional) Axis along which to take slices. If not specified, input is flattened before elements being selected. Negative value means counting dimensions from the back. Accepted range is [-r, r-1] where r = rank(input). + /// + /// ## Panics + /// + /// * Panics if condition rank is not equal to 1. + /// + /// ## Returns + /// + /// A new `Tensor` . + /// fn compress_example() -> Tensor { + /// let tensor = TensorTrait::::new( + /// shape: array![3, 2].span(), + /// data: array![[1, 2], [3, 4], [5, 6]].span(), + /// ); + /// let condition = TensorTrait::::new( + /// shape: array![3].span(), + /// data: array![0, 1, 1].span(), + /// ); + /// + /// return tensor.compress( + /// condition: condition, + /// axis: Option::Some((0)), + /// ); + /// } + /// >>> [[3, 4], + /// [5, 6]] + /// ``` + /// fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor; - } /// Cf: TensorTrait::new docstring From 201cb22b48cabe1e17596e817e53e6e36c8d1784 Mon Sep 17 00:00:00 2001 From: hattizai <150505746+hattizai@users.noreply.github.com> Date: Sat, 30 Dec 2023 15:19:14 +0800 Subject: [PATCH 16/38] Update provable-mlr-forecasting-aaves-lifetime-repayments.md --- .../provable-mlr-forecasting-aaves-lifetime-repayments.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/academy/tutorials/provable-mlr-forecasting-aaves-lifetime-repayments.md b/docs/academy/tutorials/provable-mlr-forecasting-aaves-lifetime-repayments.md index 54b35181e..68c4daac6 100644 --- a/docs/academy/tutorials/provable-mlr-forecasting-aaves-lifetime-repayments.md +++ b/docs/academy/tutorials/provable-mlr-forecasting-aaves-lifetime-repayments.md @@ -205,7 +205,7 @@ X_min = np.min(X_original, axis=0) X_max = np.max(X_original, axis=0) X_range = X_max - X_min df_forecast_data_normalized = (df_forecast_data - X_min) / X_range -# tranpose the matrix and add bias +# transpose the matrix and add bias df_forecast_data_normalized_transposed= df_forecast_data_normalized.T df_forecast_data_normalized_transposed_with_bias = np.vstack((df_forecast_data_normalized_transposed, np.ones(df_forecast_data_normalized_transposed.shape[1]))) #normalized forecasts @@ -913,7 +913,7 @@ fn normalize_user_x_inputs( shape: array![data_len].span(), data: x_range.span() ); - // for normalizing 2D user inputed feature vals + // for normalizing 2D user inputted feature vals if x_inputs.shape.len() > 1 { let mut j: u32 = 0; loop { @@ -1069,7 +1069,7 @@ let mut rescale_forecasts = rescale_predictions(forecast_results, main_y_vals); Our model will get tested under the `multiple_linear_regression_test()` function which will follow these steps: -1. Data retrival: The function initiates by fetching the AAVE dataset's x and y values. +1. Data retrieval: The function initiates by fetching the AAVE dataset's x and y values. 2. Dataset construction and normalization: A new Dataset object gets initialized by passing the x and y variables. It is then normalized using the built-in `normalize_dataset()` method. 3. Model fitting: Using the `MultipleLinearRegression` function we fit the normalized dataset and compute the regression coefficients. 4. Computing accuracy of the model: To calculate the accuracy we utilize the `predict` method to compute the dot product between the model's regression coefficients and the x values. We then compute the R-squared score to measure the accuracy of our model. From a24a8c1fec60f8837295dfc0e76f1d2b709852ad Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Fri, 5 Jan 2024 07:34:05 -0800 Subject: [PATCH 17/38] fix compilation error --- src/operators/tensor/core.cairo | 126 ------------------ .../tensor/implementations/tensor_bool.cairo | 4 - .../implementations/tensor_fp16x16.cairo | 8 -- .../implementations/tensor_fp16x16wide.cairo | 8 -- .../implementations/tensor_fp32x32.cairo | 8 -- .../implementations/tensor_fp64x64.cairo | 9 -- .../implementations/tensor_fp8x23.cairo | 9 -- .../implementations/tensor_fp8x23wide.cairo | 9 -- .../tensor/implementations/tensor_i32.cairo | 9 -- .../tensor/implementations/tensor_i8.cairo | 9 -- .../tensor/implementations/tensor_u32.cairo | 8 -- tests/nodes/sequence_construct_fp16x16.cairo | 2 - tests/nodes/sequence_construct_fp8x23.cairo | 2 - tests/nodes/sequence_construct_i32.cairo | 2 - tests/nodes/sequence_construct_i8.cairo | 2 - tests/nodes/sequence_construct_u32.cairo | 1 - tests/nodes/sequence_empty_fp16x16.cairo | 1 - tests/nodes/sequence_empty_fp8x23.cairo | 1 - tests/nodes/sequence_empty_i32.cairo | 2 - tests/nodes/sequence_empty_i8.cairo | 1 - tests/nodes/sequence_empty_u32.cairo | 2 - 21 files changed, 223 deletions(-) diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 173192987..d20f8911d 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -103,15 +103,12 @@ impl TensorSerde, impl TDrop: Drop> of Serde { /// ``` /// fn array_feature_extractor(self: @Tensor, indices: Tensor) -> Tensor; - /// ## tensor.reduce_mean - /// - /// ```rust - /// fn reduce_mean(self: @Tensor, axes: Option>, keepdims: Option, noop_with_empty_axes: Option) -> Tensor; - /// ``` - /// - /// Computes the mean of the input tensor's elements along the provided axes. - /// - /// ## Args - /// - /// * `self`(`@Tensor`) - The input tensor. - /// * `axes`(`Option>`) - Optional input list of integers, along which to reduce. The default is to reduce over all the dimensions of the input tensor if 'noop_with_empty_axes' is false, else act as an Identity op when 'noop_with_empty_axes' is true. - /// * `keepdims`(`Option`) - Keep the reduced dimension or not, default true means keep reduced dimension. - /// * `noop_with_empty_axes`(`Option`) - Defines behavior if 'axes' is empty. Default behavior with 'false' is to reduce all axes. When axes is empty and this attribute is set to true, input tensor will not be reduced,and the output tensor would be equivalent to input tensor. - /// - /// ## Panics - /// - /// * Panics if axis is not in the range of the input tensor's dimensions. - /// - /// ## Returns - /// - /// A new `Tensor` instance with the specified axes reduced by meaning its elements. - /// - /// ## Examples - /// - /// ```rust - /// use array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - /// - /// fn reduce_mean_example() -> Tensor { - /// let tensor = TensorTrait::::new( - /// shape: array![2, 2, 2].span(), data: array![0, 1, 2, 3, 4, 5, 6, 7].span(), - /// ); - /// - /// // We can call `reduce_mean` function as follows. - /// return tensor.reduce_mean(axes: array![1].span(), - /// keepdims: Option::None(()), - /// noop_with_empty_axes: Option::None(())); - /// } - /// >>> [[1,2],[5,6]] - /// ``` - /// - fn reduce_mean( - self: @Tensor, - axes: Option>, - keepdims: Option, - noop_with_empty_axes: Option - ) -> Tensor; - /// # tensor.sequence_empty - /// - /// ```rust - /// fn sequence_empty() -> Array>; - /// ``` - /// - /// Returns an empty tensor sequence. - /// - /// ## Args - /// - /// ## Returns - /// - /// An empty `Array>` instance. - /// - /// ## Examples - /// - /// Let's create a new empty sequence. - /// - /// ```rust - /// use core::array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{ - /// TensorTrait, // we import the trait - /// Tensor, // we import the type - /// U32Tensor // we import the implementation. - /// }; - /// - /// fn sequence_empty_example() -> Array> { - /// let sequence = TensorTrait::sequence_empty(); - /// - /// return sequence; - /// } - /// - /// >>> [] - /// ``` - /// - fn sequence_empty() -> Array>; /// # tensor.shrink /// /// ```rust @@ -4282,43 +4193,6 @@ trait TensorTrait { /// ``` /// fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor; - /// ## tensor.sequence_construct - /// - /// ```rust - /// fn sequence_construct(tensors: Array>) -> Array>; - /// ``` - /// - /// Constructs a tensor sequence containing the input tensors. - /// - /// ## Args - /// - /// * `tensors`(`Array>`) - The array of input tensors. - /// - /// ## Panics - /// - /// * Panics if input tensor array is empty. - /// - /// ## Returns - /// - /// A tensor sequence `Array>` containing the input tensors. - /// - /// ## Examples - /// - /// ```rust - /// use core::array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - /// - /// fn sequence_construct_example() -> Array> { - /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); - /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); - /// let result = TensorTrait::sequence_construct(tensors: array![tensor1, tensor2]); - /// return result; - /// } - /// >>> [[0, 1, 2, 3], [4, 5, 6, 7]] - /// ``` - /// - fn sequence_construct(tensors: Array>) -> Array>; /// ## tensor.reduce_mean /// /// ```rust diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index b364457af..99cfcee76 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -388,10 +388,6 @@ impl BoolTensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..03c022121 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -462,14 +462,6 @@ impl FP16x16Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..109f956ad 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -428,14 +428,6 @@ impl FP16x16WTensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..a780cd6c6 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -463,14 +463,6 @@ impl FP32x32Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..24a635532 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -463,15 +463,6 @@ impl FP64x64Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..40fa0a95b 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -448,15 +448,6 @@ impl FP8x23Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..09ef81e1a 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -414,15 +414,6 @@ impl FP8x23WTensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..dacf2b733 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -458,15 +458,6 @@ impl I32Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..e5b0c299e 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -456,15 +456,6 @@ impl I8Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..a8e989eb1 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -400,14 +400,6 @@ impl U32Tensor of TensorTrait { math::sequence_at::sequence_at(sequence, position) } - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, diff --git a/tests/nodes/sequence_construct_fp16x16.cairo b/tests/nodes/sequence_construct_fp16x16.cairo index 50596c598..c6d2238fb 100644 --- a/tests/nodes/sequence_construct_fp16x16.cairo +++ b/tests/nodes/sequence_construct_fp16x16.cairo @@ -6,9 +6,7 @@ use orion::operators::tensor::FP16x16TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor}; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; -use orion::operators::tensor::FP16x16TensorPartialEq; use orion::operators::sequence::SequenceTrait; -use array::{ArrayTrait, SpanTrait}; use orion::operators::sequence::FP16x16Sequence; #[test] diff --git a/tests/nodes/sequence_construct_fp8x23.cairo b/tests/nodes/sequence_construct_fp8x23.cairo index abafe8c27..fa08e589f 100644 --- a/tests/nodes/sequence_construct_fp8x23.cairo +++ b/tests/nodes/sequence_construct_fp8x23.cairo @@ -8,9 +8,7 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::sequence::FP8x23Sequence; -use orion::operators::tensor::FP8x23TensorPartialEq; use orion::operators::sequence::SequenceTrait; -use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_construct_i32.cairo b/tests/nodes/sequence_construct_i32.cairo index 623d3c0e5..f59b4df1b 100644 --- a/tests/nodes/sequence_construct_i32.cairo +++ b/tests/nodes/sequence_construct_i32.cairo @@ -7,10 +7,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I32TensorPartialEq; -use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::sequence::I32Sequence; use orion::operators::sequence::SequenceTrait; -use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_construct_i8.cairo b/tests/nodes/sequence_construct_i8.cairo index ea79e3fd9..4d4a3f5ac 100644 --- a/tests/nodes/sequence_construct_i8.cairo +++ b/tests/nodes/sequence_construct_i8.cairo @@ -6,9 +6,7 @@ use orion::operators::tensor::I8TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor}; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; -use orion::operators::tensor::I8TensorPartialEq; use orion::operators::sequence::SequenceTrait; -use array::{ArrayTrait, SpanTrait}; use orion::operators::sequence::I8Sequence; #[test] diff --git a/tests/nodes/sequence_construct_u32.cairo b/tests/nodes/sequence_construct_u32.cairo index 09fc37f78..a6a783f51 100644 --- a/tests/nodes/sequence_construct_u32.cairo +++ b/tests/nodes/sequence_construct_u32.cairo @@ -8,7 +8,6 @@ use orion::operators::tensor::U32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::U32TensorPartialEq; use orion::operators::sequence::SequenceTrait; -use array::{ArrayTrait, SpanTrait}; use orion::operators::sequence::U32Sequence; #[test] diff --git a/tests/nodes/sequence_empty_fp16x16.cairo b/tests/nodes/sequence_empty_fp16x16.cairo index 06141a0eb..b9348bb4f 100644 --- a/tests/nodes/sequence_empty_fp16x16.cairo +++ b/tests/nodes/sequence_empty_fp16x16.cairo @@ -6,7 +6,6 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::FP16x16TensorPartialEq; use orion::operators::sequence::SequenceTrait; use orion::utils::{assert_eq, assert_seq_eq}; -use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_empty_fp8x23.cairo b/tests/nodes/sequence_empty_fp8x23.cairo index dad959ff7..6c7248e7f 100644 --- a/tests/nodes/sequence_empty_fp8x23.cairo +++ b/tests/nodes/sequence_empty_fp8x23.cairo @@ -6,7 +6,6 @@ use orion::operators::tensor::FP8x23TensorPartialEq; use orion::operators::sequence::FP8x23Sequence; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::sequence::SequenceTrait; -use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_empty_i32.cairo b/tests/nodes/sequence_empty_i32.cairo index 24eba7d3e..798bc794a 100644 --- a/tests/nodes/sequence_empty_i32.cairo +++ b/tests/nodes/sequence_empty_i32.cairo @@ -7,9 +7,7 @@ use orion::operators::tensor::I32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; -use orion::operators::tensor::I32TensorPartialEq; use orion::operators::sequence::I32Sequence; -use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_empty_i8.cairo b/tests/nodes/sequence_empty_i8.cairo index 8cb08d668..07e9f8a57 100644 --- a/tests/nodes/sequence_empty_i8.cairo +++ b/tests/nodes/sequence_empty_i8.cairo @@ -7,7 +7,6 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I8TensorPartialEq; -use array::{ArrayTrait, SpanTrait}; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_empty_u32.cairo b/tests/nodes/sequence_empty_u32.cairo index 7d4ba3803..ecd7cefef 100644 --- a/tests/nodes/sequence_empty_u32.cairo +++ b/tests/nodes/sequence_empty_u32.cairo @@ -7,8 +7,6 @@ use orion::operators::tensor::U32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::U32TensorPartialEq; use orion::operators::sequence::SequenceTrait; -use orion::utils::{assert_eq, assert_seq_eq}; -use array::{ArrayTrait, SpanTrait}; use orion::operators::sequence::U32Sequence; #[test] From 303dc17e21533c44b38eac78961d8706e568394d Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Fri, 5 Jan 2024 08:40:32 -0800 Subject: [PATCH 18/38] update sequence trait + tests --- nodegen/node/concat_from_sequence.py | 32 +-- nodegen/node/sequence_at.py | 22 +- nodegen/node/sequence_erase.py | 32 +-- nodegen/node/sequence_insert.py | 12 +- nodegen/node/sequence_length.py | 22 +- src/operators/sequence/core.cairo | 264 ++++++++++++++++++ src/operators/sequence/functional.cairo | 7 +- .../functional}/concat_from_sequence.cairo | 0 .../functional}/sequence_at.cairo | 2 +- .../functional}/sequence_erase.cairo | 2 +- .../functional}/sequence_insert.cairo | 2 +- .../functional}/sequence_length.cairo | 2 +- .../implementations/sequence_bool.cairo | 30 +- .../implementations/sequence_fp16x16.cairo | 28 ++ .../sequence_fp16x16wide.cairo | 28 ++ .../implementations/sequence_fp32x32.cairo | 28 ++ .../implementations/sequence_fp64x64.cairo | 28 ++ .../implementations/sequence_fp8x23.cairo | 28 ++ .../implementations/sequence_fp8x23wide.cairo | 29 ++ .../implementations/sequence_i32.cairo | 26 ++ .../implementations/sequence_i8.cairo | 28 ++ .../implementations/sequence_u32.cairo | 28 ++ src/operators/tensor/core.cairo | 263 ----------------- .../tensor/implementations/tensor_bool.cairo | 26 -- .../implementations/tensor_fp16x16.cairo | 26 -- .../implementations/tensor_fp16x16wide.cairo | 26 -- .../implementations/tensor_fp32x32.cairo | 26 -- .../implementations/tensor_fp64x64.cairo | 26 -- .../implementations/tensor_fp8x23.cairo | 26 -- .../implementations/tensor_fp8x23wide.cairo | 26 -- .../tensor/implementations/tensor_i32.cairo | 26 -- .../tensor/implementations/tensor_i8.cairo | 26 -- .../tensor/implementations/tensor_u32.cairo | 26 -- src/operators/tensor/math.cairo | 5 - ...om_sequence_fp16x16_new_axis_default.cairo | 4 +- ...t_from_sequence_fp16x16_new_axis_one.cairo | 4 +- ..._from_sequence_fp16x16_new_axis_zero.cairo | 4 +- ...rom_sequence_fp8x23_new_axis_default.cairo | 4 +- ...at_from_sequence_fp8x23_new_axis_one.cairo | 4 +- ...t_from_sequence_fp8x23_new_axis_zero.cairo | 4 +- ...t_from_sequence_i32_new_axis_default.cairo | 4 +- ...oncat_from_sequence_i32_new_axis_one.cairo | 4 +- ...ncat_from_sequence_i32_new_axis_zero.cairo | 4 +- ...at_from_sequence_i8_new_axis_default.cairo | 4 +- ...concat_from_sequence_i8_new_axis_one.cairo | 4 +- ...oncat_from_sequence_i8_new_axis_zero.cairo | 4 +- ...t_from_sequence_u32_new_axis_default.cairo | 4 +- ...oncat_from_sequence_u32_new_axis_one.cairo | 4 +- ...ncat_from_sequence_u32_new_axis_zero.cairo | 4 +- .../nodes/sequence_at_fp16x16_negative.cairo | 4 +- .../nodes/sequence_at_fp16x16_positive.cairo | 4 +- tests/nodes/sequence_at_fp8x23_negative.cairo | 4 +- tests/nodes/sequence_at_fp8x23_positive.cairo | 4 +- tests/nodes/sequence_at_i32_negative.cairo | 4 +- tests/nodes/sequence_at_i32_positive.cairo | 4 +- tests/nodes/sequence_at_i8_negative.cairo | 4 +- tests/nodes/sequence_at_i8_positive.cairo | 4 +- tests/nodes/sequence_at_u32_negative.cairo | 4 +- tests/nodes/sequence_at_u32_positive.cairo | 4 +- .../nodes/sequence_erase_fp16x16_empty.cairo | 4 +- .../sequence_erase_fp16x16_negative.cairo | 4 +- .../sequence_erase_fp16x16_positive.cairo | 4 +- tests/nodes/sequence_erase_fp8x23_empty.cairo | 4 +- .../sequence_erase_fp8x23_negative.cairo | 4 +- .../sequence_erase_fp8x23_positive.cairo | 4 +- tests/nodes/sequence_erase_i32_empty.cairo | 4 +- tests/nodes/sequence_erase_i32_negative.cairo | 4 +- tests/nodes/sequence_erase_i32_positive.cairo | 4 +- tests/nodes/sequence_erase_i8_empty.cairo | 4 +- tests/nodes/sequence_erase_i8_negative.cairo | 4 +- tests/nodes/sequence_erase_i8_positive.cairo | 4 +- tests/nodes/sequence_erase_u32_empty.cairo | 4 +- tests/nodes/sequence_erase_u32_negative.cairo | 4 +- tests/nodes/sequence_erase_u32_positive.cairo | 4 +- tests/nodes/sequence_insert_fp16x16.cairo | 2 + tests/nodes/sequence_insert_fp8x23.cairo | 2 + tests/nodes/sequence_insert_i32.cairo | 2 + tests/nodes/sequence_insert_i8.cairo | 2 + tests/nodes/sequence_insert_u32.cairo | 2 + tests/nodes/sequence_length_fp16x16.cairo | 2 + .../sequence_length_fp16x16_broadcast.cairo | 2 + tests/nodes/sequence_length_fp8x23.cairo | 2 + .../sequence_length_fp8x23_broadcast.cairo | 2 + tests/nodes/sequence_length_i32.cairo | 2 + .../nodes/sequence_length_i32_broadcast.cairo | 2 + tests/nodes/sequence_length_i8.cairo | 2 + .../nodes/sequence_length_i8_broadcast.cairo | 2 + tests/nodes/sequence_length_u32.cairo | 2 + .../nodes/sequence_length_u32_broadcast.cairo | 2 + 89 files changed, 764 insertions(+), 634 deletions(-) rename src/operators/{tensor/math => sequence/functional}/concat_from_sequence.cairo (100%) rename src/operators/{tensor/math => sequence/functional}/sequence_at.cairo (95%) rename src/operators/{tensor/math => sequence/functional}/sequence_erase.cairo (97%) rename src/operators/{tensor/math => sequence/functional}/sequence_insert.cairo (97%) rename src/operators/{tensor/math => sequence/functional}/sequence_length.cairo (88%) diff --git a/nodegen/node/concat_from_sequence.py b/nodegen/node/concat_from_sequence.py index 4918785ea..eb6d6c9e1 100644 --- a/nodegen/node/concat_from_sequence.py +++ b/nodegen/node/concat_from_sequence.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait class Concat_from_sequence(RunAll): @@ -25,7 +25,7 @@ def new_axis_zero(): concatenated_tensor = Tensor(Dtype.U32, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_u32_new_axis_zero" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name, Trait.SEQUENCE) def new_axis_one(): sequence = [] @@ -45,7 +45,7 @@ def new_axis_one(): concatenated_tensor = Tensor(Dtype.U32, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_u32_new_axis_one" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name, Trait.SEQUENCE) def new_axis_default(): sequence = [] @@ -65,7 +65,7 @@ def new_axis_default(): concatenated_tensor = Tensor(Dtype.U32, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_u32_new_axis_default" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name, Trait.SEQUENCE) new_axis_zero() new_axis_one() @@ -92,7 +92,7 @@ def new_axis_zero(): concatenated_tensor = Tensor(Dtype.I32, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_i32_new_axis_zero" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name, Trait.SEQUENCE) def new_axis_one(): sequence = [] @@ -112,7 +112,7 @@ def new_axis_one(): concatenated_tensor = Tensor(Dtype.I32, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_i32_new_axis_one" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name, Trait.SEQUENCE) def new_axis_default(): sequence = [] @@ -132,7 +132,7 @@ def new_axis_default(): concatenated_tensor = Tensor(Dtype.I32, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_i32_new_axis_default" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name, Trait.SEQUENCE) new_axis_zero() new_axis_one() @@ -159,7 +159,7 @@ def new_axis_zero(): concatenated_tensor = Tensor(Dtype.I8, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_i8_new_axis_zero" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name, Trait.SEQUENCE) def new_axis_one(): sequence = [] @@ -179,7 +179,7 @@ def new_axis_one(): concatenated_tensor = Tensor(Dtype.I8, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_i8_new_axis_one" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name, Trait.SEQUENCE) def new_axis_default(): sequence = [] @@ -199,7 +199,7 @@ def new_axis_default(): concatenated_tensor = Tensor(Dtype.I8, concatenated_tensor.shape, concatenated_tensor.flatten()) name = "concat_from_sequence_i8_new_axis_default" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name, Trait.SEQUENCE) new_axis_zero() new_axis_one() @@ -226,7 +226,7 @@ def new_axis_zero(): concatenated_tensor = Tensor(Dtype.FP8x23, concatenated_tensor.shape, to_fp(concatenated_tensor.flatten(), FixedImpl.FP8x23)) name = "concat_from_sequence_fp8x23_new_axis_zero" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name, Trait.SEQUENCE) def new_axis_one(): sequence = [] @@ -246,7 +246,7 @@ def new_axis_one(): concatenated_tensor = Tensor(Dtype.FP8x23, concatenated_tensor.shape, to_fp(concatenated_tensor.flatten(), FixedImpl.FP8x23)) name = "concat_from_sequence_fp8x23_new_axis_one" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name, Trait.SEQUENCE) def new_axis_default(): sequence = [] @@ -266,7 +266,7 @@ def new_axis_default(): concatenated_tensor = Tensor(Dtype.FP8x23, concatenated_tensor.shape, to_fp(concatenated_tensor.flatten(), FixedImpl.FP8x23)) name = "concat_from_sequence_fp8x23_new_axis_default" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name, Trait.SEQUENCE) new_axis_zero() new_axis_one() @@ -293,7 +293,7 @@ def new_axis_zero(): concatenated_tensor = Tensor(Dtype.FP16x16, concatenated_tensor.shape, to_fp(concatenated_tensor.flatten(), FixedImpl.FP16x16)) name = "concat_from_sequence_fp16x16_new_axis_zero" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(0))", name, Trait.SEQUENCE) def new_axis_one(): sequence = [] @@ -313,7 +313,7 @@ def new_axis_one(): concatenated_tensor = Tensor(Dtype.FP16x16, concatenated_tensor.shape, to_fp(concatenated_tensor.flatten(), FixedImpl.FP16x16)) name = "concat_from_sequence_fp16x16_new_axis_one" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::Some(1))", name, Trait.SEQUENCE) def new_axis_default(): sequence = [] @@ -333,7 +333,7 @@ def new_axis_default(): concatenated_tensor = Tensor(Dtype.FP16x16, concatenated_tensor.shape, to_fp(concatenated_tensor.flatten(), FixedImpl.FP16x16)) name = "concat_from_sequence_fp16x16_new_axis_default" - make_test([sequence], concatenated_tensor, "TensorTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name) + make_test([sequence], concatenated_tensor, "SequenceTrait::concat_from_sequence(input_0, IntegerTrait::::new(1, false), Option::None(()))", name, Trait.SEQUENCE) new_axis_zero() new_axis_one() diff --git a/nodegen/node/sequence_at.py b/nodegen/node/sequence_at.py index b65ac5dae..a108ad52a 100644 --- a/nodegen/node/sequence_at.py +++ b/nodegen/node/sequence_at.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait scalar = lambda x: Tensor(Dtype.I32, (), np.array([x]).astype(np.int32).flatten()) @@ -23,7 +23,7 @@ def positive_position(): position = scalar(2) name = "sequence_at_u32_positive" - make_test([sequence, position], sequence[2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -38,7 +38,7 @@ def negative_position(): position = scalar(-2) name = "sequence_at_u32_negative" - make_test([sequence, position], sequence[-2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[-2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) positive_position() negative_position() @@ -59,7 +59,7 @@ def positive_position(): position = scalar(2) name = "sequence_at_i32_positive" - make_test([sequence, position], sequence[2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -74,7 +74,7 @@ def negative_position(): position = scalar(-2) name = "sequence_at_i32_negative" - make_test([sequence, position], sequence[-2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[-2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) positive_position() negative_position() @@ -95,7 +95,7 @@ def positive_position(): position = scalar(2) name = "sequence_at_i8_positive" - make_test([sequence, position], sequence[2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -110,7 +110,7 @@ def negative_position(): position = scalar(-2) name = "sequence_at_i8_negative" - make_test([sequence, position], sequence[-2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[-2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) positive_position() negative_position() @@ -131,7 +131,7 @@ def positive_position(): position = scalar(2) name = "sequence_at_fp8x23_positive" - make_test([sequence, position], sequence[2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -146,7 +146,7 @@ def negative_position(): position = scalar(-2) name = "sequence_at_fp8x23_negative" - make_test([sequence, position], sequence[-2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[-2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) positive_position() negative_position() @@ -167,7 +167,7 @@ def positive_position(): position = scalar(2) name = "sequence_at_fp16x16_positive" - make_test([sequence, position], sequence[2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -182,7 +182,7 @@ def negative_position(): position = scalar(-2) name = "sequence_at_fp16x16_negative" - make_test([sequence, position], sequence[-2], "TensorTrait::sequence_at(input_0, input_1)", name) + make_test([sequence, position], sequence[-2], "SequenceTrait::sequence_at(input_0, input_1)", name, Trait.SEQUENCE) positive_position() negative_position() diff --git a/nodegen/node/sequence_erase.py b/nodegen/node/sequence_erase.py index 8f8f07c46..a57b2e7d2 100644 --- a/nodegen/node/sequence_erase.py +++ b/nodegen/node/sequence_erase.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait scalar = lambda x: Tensor(Dtype.I32, (), np.array([x]).astype(np.int32).flatten()) @@ -26,7 +26,7 @@ def positive_position(): output_sequence.pop(2) name = "sequence_erase_u32_positive" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -44,7 +44,7 @@ def negative_position(): output_sequence.pop(-2) name = "sequence_erase_u32_negative" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def empty_position(): sequence = [] @@ -60,7 +60,7 @@ def empty_position(): output_sequence.pop(-1) name = "sequence_erase_u32_empty" - make_test([sequence], output_sequence, "TensorTrait::sequence_erase(input_0, Option::None(()))", name) + make_test([sequence], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::None(()))", name, Trait.SEQUENCE) positive_position() negative_position() @@ -85,7 +85,7 @@ def positive_position(): output_sequence.pop(2) name = "sequence_erase_i32_positive" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -103,7 +103,7 @@ def negative_position(): output_sequence.pop(-2) name = "sequence_erase_i32_negative" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def empty_position(): sequence = [] @@ -119,7 +119,7 @@ def empty_position(): output_sequence.pop(-1) name = "sequence_erase_i32_empty" - make_test([sequence], output_sequence, "TensorTrait::sequence_erase(input_0, Option::None(()))", name) + make_test([sequence], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::None(()))", name, Trait.SEQUENCE) positive_position() negative_position() @@ -144,7 +144,7 @@ def positive_position(): output_sequence.pop(2) name = "sequence_erase_i8_positive" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -162,7 +162,7 @@ def negative_position(): output_sequence.pop(-2) name = "sequence_erase_i8_negative" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def empty_position(): sequence = [] @@ -178,7 +178,7 @@ def empty_position(): output_sequence.pop(-1) name = "sequence_erase_i8_empty" - make_test([sequence], output_sequence, "TensorTrait::sequence_erase(input_0, Option::None(()))", name) + make_test([sequence], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::None(()))", name, Trait.SEQUENCE) positive_position() negative_position() @@ -203,7 +203,7 @@ def positive_position(): output_sequence.pop(2) name = "sequence_erase_fp8x23_positive" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -221,7 +221,7 @@ def negative_position(): output_sequence.pop(-2) name = "sequence_erase_fp8x23_negative" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def empty_position(): sequence = [] @@ -237,7 +237,7 @@ def empty_position(): output_sequence.pop(-1) name = "sequence_erase_fp8x23_empty" - make_test([sequence], output_sequence, "TensorTrait::sequence_erase(input_0, Option::None(()))", name) + make_test([sequence], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::None(()))", name, Trait.SEQUENCE) positive_position() negative_position() @@ -262,7 +262,7 @@ def positive_position(): output_sequence.pop(2) name = "sequence_erase_fp16x16_positive" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def negative_position(): sequence = [] @@ -280,7 +280,7 @@ def negative_position(): output_sequence.pop(-2) name = "sequence_erase_fp16x16_negative" - make_test([sequence, position], output_sequence, "TensorTrait::sequence_erase(input_0, Option::Some(input_1))", name) + make_test([sequence, position], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::Some(input_1))", name, Trait.SEQUENCE) def empty_position(): sequence = [] @@ -296,7 +296,7 @@ def empty_position(): output_sequence.pop(-1) name = "sequence_erase_fp16x16_empty" - make_test([sequence], output_sequence, "TensorTrait::sequence_erase(input_0, Option::None(()))", name) + make_test([sequence], output_sequence, "SequenceTrait::sequence_erase(input_0, Option::None(()))", name, Trait.SEQUENCE) positive_position() negative_position() diff --git a/nodegen/node/sequence_insert.py b/nodegen/node/sequence_insert.py index 9cc2ceb82..e4ae7be55 100644 --- a/nodegen/node/sequence_insert.py +++ b/nodegen/node/sequence_insert.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait scalar = lambda x: Tensor(Dtype.I32, (), np.array([x]).astype(np.int32).flatten()) @@ -30,7 +30,7 @@ def default(): expected_sequence.insert(position, tensor) name = "sequence_insert_u32" - make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name) + make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name, Trait.SEQUENCE) default() @@ -56,7 +56,7 @@ def default(): expected_sequence.insert(position, tensor) name = "sequence_insert_i32" - make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name) + make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name, Trait.SEQUENCE) default() @@ -82,7 +82,7 @@ def default(): expected_sequence.insert(position, tensor) name = "sequence_insert_i8" - make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name) + make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name, Trait.SEQUENCE) default() @@ -110,7 +110,7 @@ def default(): expected_sequence.insert(position, tensor) name = "sequence_insert_fp8x23" - make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name) + make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name, Trait.SEQUENCE) default() @@ -138,6 +138,6 @@ def default(): expected_sequence.insert(position, tensor) name = "sequence_insert_fp16x16" - make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name) + make_test([sequence, tensor, scalar(position)], expected_sequence, "input_0.sequence_insert(@input_1,Option::Some(input_2))", name, Trait.SEQUENCE) default() diff --git a/nodegen/node/sequence_length.py b/nodegen/node/sequence_length.py index 87f0dcd01..7b8993309 100644 --- a/nodegen/node/sequence_length.py +++ b/nodegen/node/sequence_length.py @@ -1,6 +1,6 @@ import numpy as np from nodegen.node import RunAll -from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait scalar = lambda x: Tensor(Dtype.U32, (), np.array([x]).astype(np.uint32).flatten()) @@ -21,7 +21,7 @@ def default(): sequence.append(tensor) name = "sequence_length_u32" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) def broadcast(): sequence = [] @@ -35,7 +35,7 @@ def broadcast(): sequence.append(tensor) name = "sequence_length_u32_broadcast" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) default() broadcast() @@ -54,7 +54,7 @@ def default(): sequence.append(tensor) name = "sequence_length_i32" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) def broadcast(): sequence = [] @@ -68,7 +68,7 @@ def broadcast(): sequence.append(tensor) name = "sequence_length_i32_broadcast" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) default() broadcast() @@ -87,7 +87,7 @@ def default(): sequence.append(tensor) name = "sequence_length_i8" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) def broadcast(): sequence = [] @@ -101,7 +101,7 @@ def broadcast(): sequence.append(tensor) name = "sequence_length_i8_broadcast" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) default() broadcast() @@ -120,7 +120,7 @@ def default(): sequence.append(tensor) name = "sequence_length_fp8x23" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) def broadcast(): sequence = [] @@ -134,7 +134,7 @@ def broadcast(): sequence.append(tensor) name = "sequence_length_fp8x23_broadcast" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) default() broadcast() @@ -153,7 +153,7 @@ def default(): sequence.append(tensor) name = "sequence_length_fp16x16" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) def broadcast(): sequence = [] @@ -167,7 +167,7 @@ def broadcast(): sequence.append(tensor) name = "sequence_length_fp16x16_broadcast" - make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name) + make_test([sequence], scalar(len(sequence)), "input_0.sequence_length()", name, Trait.SEQUENCE) default() broadcast() diff --git a/src/operators/sequence/core.cairo b/src/operators/sequence/core.cairo index df33c5ab6..268e25b16 100644 --- a/src/operators/sequence/core.cairo +++ b/src/operators/sequence/core.cairo @@ -1,9 +1,15 @@ use orion::operators::tensor::core::Tensor; +use orion::numbers::signed_integer::i32::i32; /// Trait /// /// sequence_construct - Constructs a tensor sequence containing the input tensors. /// sequence_empty - Returns an empty tensor sequence. +/// sequence_length - Returns the length of the input sequence. +/// sequence_insert - Insert a tensor into a sequence. +/// sequence_at - Outputs the tensor at the specified position in the input sequence. +/// sequence_erase – Outputs the tensor sequence with the erased tensor at the specified position. +/// concat_from_sequence - Concatenate a sequence of tensors into a single tensor. trait SequenceTrait { /// ## sequence.sequence_construct /// @@ -81,4 +87,262 @@ trait SequenceTrait { /// ``` /// fn sequence_empty() -> Array>; + /// # tensor.sequence_length + /// + /// ```rust + /// fn sequence_length(self: Array>) -> Tensor; + /// ``` + /// + /// Returns the length of the input sequence. + /// + /// ## Args + /// + /// * `self`(`Array>`) - The input sequence. + /// + /// ## Returns + /// + /// The length of the sequence as scalar, i.e. a tensor of shape []. + /// + /// ## Examples + /// + /// Let's create new u32 Tensor with constant 42. + /// + /// ```rust + /// let mut sequence = ArrayTrait::new(); + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(1); + /// shape.append(2); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(3); + /// data.append(1); + /// + /// sequence.append(TensorTrait::new(shape.span(), data.span())); + /// + /// sequence.sequence_length() + /// >>> [1] + /// ``` + /// + fn sequence_length(self: Array>) -> Tensor; + /// # tensor.sequence_insert + /// + /// ```rust + /// fn sequence_insert(self: Array>, tensor: @Tensor, position: Option>) -> Array>; + /// ``` + /// + /// Returns a tensor sequence that inserts 'tensor' into 'self' at 'position'. + /// + /// ## Args + /// + /// * `self`(`Array>`) - input sequence. + /// * `tensor` (`@Tensor`) - the tensor to insert. + /// * `position` (`@Tensor`) - the index for insertion (default: -1). + /// + /// ## Returns + /// + /// Tensor sequence containing 'tensor' inserted into 'self' at 'position'. + /// + /// ## Examples + /// + /// Let's insert the tensor [2] into the sequence [[1], [3]] at position 1. + /// use orion::operators::tensor::{TensorTrait, Tensor, I32Tensor, U32Tensor}; + /// + /// fn sequence_insert_example() -> Array> { + /// // Prepare sequence + /// let mut sequence = ArrayTrait::new(); + /// let mut shape = ArrayTrait::::new(); + /// shape.append(1); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(1); + /// sequence.append(TensorTrait::new(shape.span(), data.span())); + /// let mut data = ArrayTrait::new(); + /// data.append(3); + /// + /// sequence.append(TensorTrait::new(shape.span(), data.span())); + /// + /// // Prepare input tensor + /// let mut data = ArrayTrait::new(); + /// data.append(2); + /// let tensor = TensorTrait::new(shape.span(), data.span()); + /// + /// // Prepare position + /// let mut shape = ArrayTrait::::new(); + /// let mut data = ArrayTrait::::new(); + /// data.append(i32 { mag: 1, sign: false }); + /// let position = TensorTrait::::new(shape.span(), data.span()) + /// + /// let sequence = self.sequence_insert(tensor, Option::Some(position)); + /// + /// return sequence; + /// } + /// + /// >>> [[1], [2], [3]] + /// ``` + /// + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array>; + /// ## tensor.sequence_at + /// + /// ```rust + /// fn sequence_at(sequence: Array>, position: Tensor) -> Tensor; + /// ``` + /// + /// Outputs the tensor at the specified position in the input sequence. + /// + /// ## Args + /// + /// * `tensors`(`Array>`) - The tensor sequence. + /// * `position`(`Tensor`) - The position tensor. + /// + /// ## Panics + /// + /// * Panics if position is not a scalar + /// * Panics if position is out of bounds [-n, n - 1] + /// + /// ## Returns + /// + /// The tensor `Tensor` from the sequence at the specified position. + /// + /// ## Examples + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor, I32Tensor}; + /// use orion::numbers::{i32, IntegerTrait}; + /// + /// fn sequence_at_example() -> Tensor { + /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); + /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); + /// + /// let mut sequence = ArrayTrait::new(); + /// sequence.append(tensor1); + /// sequence.append(tensor2); + /// + /// let position = TensorTrait::new(shape: array![].span(), data: array![IntegerTrait::new(1, false)].span()); + /// + /// let result = TensorTrait::sequence_at(sequence, position); + /// return result; + /// } + /// >>> [4, 5, 6, 7] + /// ``` + /// + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor; + /// ## tensor.sequence_erase + /// + /// ```rust + /// fn sequence_erase(sequence: Array>, position: Option>) -> Array>; + /// ``` + /// + /// Outputs the tensor sequence with the erased tensor at the specified position. + /// + /// ## Args + /// + /// * `tensors`(`Array>`) - The tensor sequence. + /// * `position`(`Option>`) - The optional position tensor (by default erases the last tensor). + /// + /// ## Panics + /// + /// * Panics if position is not a scalar + /// * Panics if position is out of bounds [-n, n - 1] + /// + /// ## Returns + /// + /// The tensor sequence `Array>` with the erased tensor at the specified position. + /// + /// ## Examples + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor, I32Tensor}; + /// use orion::numbers::{i32, IntegerTrait}; + /// + /// fn sequence_erase_example() -> Tensor { + /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); + /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); + /// let tensor3 = TensorTrait::new(shape: array![2, 2].span(), data: array![8, 9, 10, 11].span()); + /// + /// let mut sequence = ArrayTrait::new(); + /// sequence.append(tensor1); + /// sequence.append(tensor2); + /// sequence.append(tensor3); + /// + /// let position = TensorTrait::new(shape: array![].span(), data: array![IntegerTrait::new(1, false)].span()); + /// + /// let result = TensorTrait::sequence_erase(sequence, position); + /// return result; + /// } + /// >>> [[0, 1, 2, 3], [8, 9, 10, 11]] + /// ``` + /// + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array>; + /// # tensor.concat_from_sequence + /// + /// ```rust + /// fn concat_from_sequence(sequence: Array>, axis: i32, new_axis: Option) -> Tensor; + /// ``` + /// + /// Concatenate a sequence of tensors into a single tensor. + /// + /// ## Args + /// + /// * `sequence`(`Array>`) - The input sequence. + /// * `axis`(`i32`) - Axis to concat on. + /// * `new_axis`(`Option`) - Optionally added new axis. + /// + /// ## Panics + /// + /// * Panics if new_axis not 0 or 1 (if value provided). + /// * Panics if axis not in accepted ranges. + /// * Panics if sequence length is not greater than 1. + /// + /// ## Returns + /// + /// A new `Tensor` concatenated tensor from the input tensor sequence. + /// + /// ## Example + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + /// + /// fn concat_example() -> Tensor { + /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span(),); + /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span(),); + /// + /// let mut sequence = ArrayTrait::new(); + /// sequence.append(tensor1); + /// sequence.append(tensor2); + /// + /// let result = TensorTrait::concat_from_sequence(sequence: sequence, axis: 0, new_axis: Option::Some(0)); + /// return result; + /// } + /// >>> [[0. 1.] + /// [2. 3.], + /// [0. 1.] + /// [2. 3.]] + /// + /// result.shape + /// >>> (4, 2) + /// + /// let result = TensorTrait::concat_from_sequence(sequence: sequence, axis: 1, new_axis: Option::Some(0)); + /// return result; + /// } + /// >>> [[0. 1., 0., 1.] + /// [2. 3., 2., 3.]] + /// + /// result.shape + /// >>> (2, 4 ) + /// ``` + /// + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor; } diff --git a/src/operators/sequence/functional.cairo b/src/operators/sequence/functional.cairo index 17b7014a0..84f30cfc7 100644 --- a/src/operators/sequence/functional.cairo +++ b/src/operators/sequence/functional.cairo @@ -1,2 +1,7 @@ mod sequence_construct; -mod sequence_empty; \ No newline at end of file +mod sequence_empty; +mod sequence_at; +mod sequence_erase; +mod sequence_insert; +mod sequence_length; +mod concat_from_sequence; \ No newline at end of file diff --git a/src/operators/tensor/math/concat_from_sequence.cairo b/src/operators/sequence/functional/concat_from_sequence.cairo similarity index 100% rename from src/operators/tensor/math/concat_from_sequence.cairo rename to src/operators/sequence/functional/concat_from_sequence.cairo diff --git a/src/operators/tensor/math/sequence_at.cairo b/src/operators/sequence/functional/sequence_at.cairo similarity index 95% rename from src/operators/tensor/math/sequence_at.cairo rename to src/operators/sequence/functional/sequence_at.cairo index 81f6776ac..6108b97d4 100644 --- a/src/operators/tensor/math/sequence_at.cairo +++ b/src/operators/sequence/functional/sequence_at.cairo @@ -5,7 +5,7 @@ use orion::operators::tensor::core::{Tensor, TensorTrait}; use orion::numbers::NumberTrait; use orion::numbers::signed_integer::i32::i32; -/// Cf: TensorTrait::sequence_at docstring +/// Cf: SequenceTrait::sequence_at docstring fn sequence_at, impl TCopy: Copy, impl TDrop: Drop>( sequence: Array>, position: Tensor ) -> Tensor { diff --git a/src/operators/tensor/math/sequence_erase.cairo b/src/operators/sequence/functional/sequence_erase.cairo similarity index 97% rename from src/operators/tensor/math/sequence_erase.cairo rename to src/operators/sequence/functional/sequence_erase.cairo index d2f8de500..69349b8ad 100644 --- a/src/operators/tensor/math/sequence_erase.cairo +++ b/src/operators/sequence/functional/sequence_erase.cairo @@ -6,7 +6,7 @@ use orion::operators::tensor::I32Tensor; use orion::numbers::NumberTrait; use orion::numbers::signed_integer::i32::i32; -/// Cf: TensorTrait::sequence_erase docstring +/// Cf: SequenceTrait::sequence_erase docstring fn sequence_erase, impl TCopy: Copy, impl TDrop: Drop>( sequence: Array>, position: Option> ) -> Array> { diff --git a/src/operators/tensor/math/sequence_insert.cairo b/src/operators/sequence/functional/sequence_insert.cairo similarity index 97% rename from src/operators/tensor/math/sequence_insert.cairo rename to src/operators/sequence/functional/sequence_insert.cairo index a5f7ce673..a140938dc 100644 --- a/src/operators/tensor/math/sequence_insert.cairo +++ b/src/operators/sequence/functional/sequence_insert.cairo @@ -6,7 +6,7 @@ use orion::operators::tensor::I32Tensor; use orion::numbers::NumberTrait; use orion::numbers::signed_integer::i32::i32; -/// Cf: TensorTrait::sequence_insert docstring +/// Cf: SequenceTrait::sequence_insert docstring fn sequence_insert, impl TCopy: Copy, impl TDrop: Drop>( self: Array>, tensor: @Tensor, position: Option> ) -> Array> { diff --git a/src/operators/tensor/math/sequence_length.cairo b/src/operators/sequence/functional/sequence_length.cairo similarity index 88% rename from src/operators/tensor/math/sequence_length.cairo rename to src/operators/sequence/functional/sequence_length.cairo index d87aaf357..84f91e48f 100644 --- a/src/operators/tensor/math/sequence_length.cairo +++ b/src/operators/sequence/functional/sequence_length.cairo @@ -3,7 +3,7 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; -/// Cf: TensorTrait::sequence_length docstring +/// Cf: SequenceTrait::sequence_length docstring fn sequence_length>(self: Array>) -> Tensor { let mut shape = ArrayTrait::::new(); let mut result = ArrayTrait::new(); diff --git a/src/operators/sequence/implementations/sequence_bool.cairo b/src/operators/sequence/implementations/sequence_bool.cairo index 0a554efda..1ac241e41 100644 --- a/src/operators/sequence/implementations/sequence_bool.cairo +++ b/src/operators/sequence/implementations/sequence_bool.cairo @@ -4,7 +4,8 @@ use orion::operators::tensor::core::Tensor; use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::operators::tensor::implementations::tensor_bool::BoolTensor; - +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl BoolSequence of SequenceTrait { fn sequence_construct(tensors: Array>) -> Array> { @@ -14,4 +15,31 @@ impl BoolSequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } + } diff --git a/src/operators/sequence/implementations/sequence_fp16x16.cairo b/src/operators/sequence/implementations/sequence_fp16x16.cairo index 29bbe4f56..5da4af097 100644 --- a/src/operators/sequence/implementations/sequence_fp16x16.cairo +++ b/src/operators/sequence/implementations/sequence_fp16x16.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::fixed_point::implementations::fp16x16::core::FP16x16; use orion::operators::tensor::implementations::tensor_fp16x16::FP16x16Tensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl FP16x16Sequence of SequenceTrait { @@ -15,4 +17,30 @@ impl FP16x16Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_fp16x16wide.cairo b/src/operators/sequence/implementations/sequence_fp16x16wide.cairo index 66b506a68..3b2273edb 100644 --- a/src/operators/sequence/implementations/sequence_fp16x16wide.cairo +++ b/src/operators/sequence/implementations/sequence_fp16x16wide.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::fixed_point::implementations::fp16x16wide::core::FP16x16W; use orion::operators::tensor::implementations::tensor_fp16x16wide::FP16x16WTensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl FP16x16WSequence of SequenceTrait { @@ -15,4 +17,30 @@ impl FP16x16WSequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_fp32x32.cairo b/src/operators/sequence/implementations/sequence_fp32x32.cairo index b0ddc0e92..fd63f9639 100644 --- a/src/operators/sequence/implementations/sequence_fp32x32.cairo +++ b/src/operators/sequence/implementations/sequence_fp32x32.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::fixed_point::implementations::fp32x32::core::FP32x32; use orion::operators::tensor::implementations::tensor_fp32x32::FP32x32Tensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl FP32x32Sequence of SequenceTrait { @@ -15,4 +17,30 @@ impl FP32x32Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_fp64x64.cairo b/src/operators/sequence/implementations/sequence_fp64x64.cairo index f9a6759cd..43424f9eb 100644 --- a/src/operators/sequence/implementations/sequence_fp64x64.cairo +++ b/src/operators/sequence/implementations/sequence_fp64x64.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::fixed_point::implementations::fp64x64::core::FP64x64; use orion::operators::tensor::implementations::tensor_fp64x64::FP64x64Tensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl FP64x64Sequence of SequenceTrait { @@ -15,4 +17,30 @@ impl FP64x64Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_fp8x23.cairo b/src/operators/sequence/implementations/sequence_fp8x23.cairo index ac70a7e47..08ad26a10 100644 --- a/src/operators/sequence/implementations/sequence_fp8x23.cairo +++ b/src/operators/sequence/implementations/sequence_fp8x23.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::fixed_point::implementations::fp8x23::core::FP8x23; use orion::operators::tensor::implementations::tensor_fp8x23::FP8x23Tensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl FP8x23Sequence of SequenceTrait { @@ -15,4 +17,30 @@ impl FP8x23Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_fp8x23wide.cairo b/src/operators/sequence/implementations/sequence_fp8x23wide.cairo index 544f931aa..eaebb072d 100644 --- a/src/operators/sequence/implementations/sequence_fp8x23wide.cairo +++ b/src/operators/sequence/implementations/sequence_fp8x23wide.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::fixed_point::implementations::fp8x23wide::core::FP8x23W; use orion::operators::tensor::implementations::tensor_fp8x23wide::FP8x23WTensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl FP8x23WSequence of SequenceTrait { @@ -15,4 +17,31 @@ impl FP8x23WSequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } + } diff --git a/src/operators/sequence/implementations/sequence_i32.cairo b/src/operators/sequence/implementations/sequence_i32.cairo index 28d543fff..40892070a 100644 --- a/src/operators/sequence/implementations/sequence_i32.cairo +++ b/src/operators/sequence/implementations/sequence_i32.cairo @@ -15,4 +15,30 @@ impl I32Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_i8.cairo b/src/operators/sequence/implementations/sequence_i8.cairo index 73b886299..2510f3cf2 100644 --- a/src/operators/sequence/implementations/sequence_i8.cairo +++ b/src/operators/sequence/implementations/sequence_i8.cairo @@ -5,6 +5,8 @@ use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::numbers::signed_integer::i8::i8; use orion::operators::tensor::implementations::tensor_i8::I8Tensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl I8Sequence of SequenceTrait { @@ -15,4 +17,30 @@ impl I8Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/sequence/implementations/sequence_u32.cairo b/src/operators/sequence/implementations/sequence_u32.cairo index a6bdb25e9..ff0e57212 100644 --- a/src/operators/sequence/implementations/sequence_u32.cairo +++ b/src/operators/sequence/implementations/sequence_u32.cairo @@ -4,6 +4,8 @@ use orion::operators::tensor::core::Tensor; use orion::operators::sequence::core::SequenceTrait; use orion::operators::sequence::functional; use orion::operators::tensor::implementations::tensor_u32::U32Tensor; +use orion::numbers::signed_integer::i32::i32; +use orion::operators::tensor::implementations::tensor_i32::I32Tensor; impl U32Sequence of SequenceTrait { @@ -14,4 +16,30 @@ impl U32Sequence of SequenceTrait { fn sequence_empty() -> Array> { functional::sequence_empty::sequence_empty::() } + + fn sequence_length(self: Array>) -> Tensor { + functional::sequence_length::sequence_length(self) + } + + fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { + functional::sequence_at::sequence_at(sequence, position) + } + + fn sequence_erase( + sequence: Array>, position: Option> + ) -> Array> { + functional::sequence_erase::sequence_erase(sequence, position) + } + + fn sequence_insert( + self: Array>, tensor: @Tensor, position: Option> + ) -> Array> { + functional::sequence_insert::sequence_insert(self, tensor, position) + } + + fn concat_from_sequence( + sequence: Array>, axis: i32, new_axis: Option + ) -> Tensor { + functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) + } } diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index d20f8911d..92389b303 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -103,16 +103,11 @@ impl TensorSerde, impl TDrop: Drop> of Serde { keepdims: Option, noop_with_empty_axes: Option ) -> Tensor; - /// # tensor.sequence_insert - /// - /// ```rust - /// fn sequence_insert(self: Array>, tensor: @Tensor, position: Option>) -> Array>; - /// ``` - /// - /// Returns a tensor sequence that inserts 'tensor' into 'self' at 'position'. - /// - /// ## Args - /// - /// * `self`(`Array>`) - input sequence. - /// * `tensor` (`@Tensor`) - the tensor to insert. - /// * `position` (`@Tensor`) - the index for insertion (default: -1). - /// - /// ## Returns - /// - /// Tensor sequence containing 'tensor' inserted into 'self' at 'position'. - /// - /// ## Examples - /// - /// Let's insert the tensor [2] into the sequence [[1], [3]] at position 1. - /// use orion::operators::tensor::{TensorTrait, Tensor, I32Tensor, U32Tensor}; - /// - /// fn sequence_insert_example() -> Array> { - /// // Prepare sequence - /// let mut sequence = ArrayTrait::new(); - /// let mut shape = ArrayTrait::::new(); - /// shape.append(1); - /// - /// let mut data = ArrayTrait::new(); - /// data.append(1); - /// sequence.append(TensorTrait::new(shape.span(), data.span())); - /// let mut data = ArrayTrait::new(); - /// data.append(3); - /// - /// sequence.append(TensorTrait::new(shape.span(), data.span())); - /// - /// // Prepare input tensor - /// let mut data = ArrayTrait::new(); - /// data.append(2); - /// let tensor = TensorTrait::new(shape.span(), data.span()); - /// - /// // Prepare position - /// let mut shape = ArrayTrait::::new(); - /// let mut data = ArrayTrait::::new(); - /// data.append(i32 { mag: 1, sign: false }); - /// let position = TensorTrait::::new(shape.span(), data.span()) - /// - /// let sequence = self.sequence_insert(tensor, Option::Some(position)); - /// - /// return sequence; - /// } - /// - /// >>> [[1], [2], [3]] - /// ``` - /// - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array>; - /// ## tensor.sequence_at - /// - /// ```rust - /// fn sequence_at(sequence: Array>, position: Tensor) -> Tensor; - /// ``` - /// - /// Outputs the tensor at the specified position in the input sequence. - /// - /// ## Args - /// - /// * `tensors`(`Array>`) - The tensor sequence. - /// * `position`(`Tensor`) - The position tensor. - /// - /// ## Panics - /// - /// * Panics if position is not a scalar - /// * Panics if position is out of bounds [-n, n - 1] - /// - /// ## Returns - /// - /// The tensor `Tensor` from the sequence at the specified position. - /// - /// ## Examples - /// - /// ```rust - /// use core::array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor, I32Tensor}; - /// use orion::numbers::{i32, IntegerTrait}; - /// - /// fn sequence_at_example() -> Tensor { - /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); - /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); - /// - /// let mut sequence = ArrayTrait::new(); - /// sequence.append(tensor1); - /// sequence.append(tensor2); - /// - /// let position = TensorTrait::new(shape: array![].span(), data: array![IntegerTrait::new(1, false)].span()); - /// - /// let result = TensorTrait::sequence_at(sequence, position); - /// return result; - /// } - /// >>> [4, 5, 6, 7] - /// ``` - /// - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor; - /// ## tensor.sequence_erase - /// - /// ```rust - /// fn sequence_erase(sequence: Array>, position: Option>) -> Array>; - /// ``` - /// - /// Outputs the tensor sequence with the erased tensor at the specified position. - /// - /// ## Args - /// - /// * `tensors`(`Array>`) - The tensor sequence. - /// * `position`(`Option>`) - The optional position tensor (by default erases the last tensor). - /// - /// ## Panics - /// - /// * Panics if position is not a scalar - /// * Panics if position is out of bounds [-n, n - 1] - /// - /// ## Returns - /// - /// The tensor sequence `Array>` with the erased tensor at the specified position. - /// - /// ## Examples - /// - /// ```rust - /// use core::array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor, I32Tensor}; - /// use orion::numbers::{i32, IntegerTrait}; - /// - /// fn sequence_erase_example() -> Tensor { - /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); - /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); - /// let tensor3 = TensorTrait::new(shape: array![2, 2].span(), data: array![8, 9, 10, 11].span()); - /// - /// let mut sequence = ArrayTrait::new(); - /// sequence.append(tensor1); - /// sequence.append(tensor2); - /// sequence.append(tensor3); - /// - /// let position = TensorTrait::new(shape: array![].span(), data: array![IntegerTrait::new(1, false)].span()); - /// - /// let result = TensorTrait::sequence_erase(sequence, position); - /// return result; - /// } - /// >>> [[0, 1, 2, 3], [8, 9, 10, 11]] - /// ``` - /// - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array>; /// #tensor.pow /// /// ```rust @@ -4517,44 +4355,6 @@ trait TensorTrait { /// ``` /// fn pow(self: @Tensor, other: @Tensor) -> Tensor; - /// # tensor.sequence_length - /// - /// ```rust - /// fn sequence_length(self: Array>) -> Tensor; - /// ``` - /// - /// Returns the length of the input sequence. - /// - /// ## Args - /// - /// * `self`(`Array>`) - The input sequence. - /// - /// ## Returns - /// - /// The length of the sequence as scalar, i.e. a tensor of shape []. - /// - /// ## Examples - /// - /// Let's create new u32 Tensor with constant 42. - /// - /// ```rust - /// let mut sequence = ArrayTrait::new(); - /// - /// let mut shape = ArrayTrait::::new(); - /// shape.append(1); - /// shape.append(2); - /// - /// let mut data = ArrayTrait::new(); - /// data.append(3); - /// data.append(1); - /// - /// sequence.append(TensorTrait::new(shape.span(), data.span())); - /// - /// sequence.sequence_length() - /// >>> [1] - /// ``` - /// - fn sequence_length(self: Array>) -> Tensor; /// ## tensor.reduce_prod /// /// ```rust @@ -4673,69 +4473,6 @@ trait TensorTrait { /// ``` /// fn is_nan(self: @Tensor) -> Tensor; - /// # tensor.concat_from_sequence - /// - /// ```rust - /// fn concat_from_sequence(sequence: Array>, axis: i32, new_axis: Option) -> Tensor; - /// ``` - /// - /// Concatenate a sequence of tensors into a single tensor. - /// - /// ## Args - /// - /// * `sequence`(`Array>`) - The input sequence. - /// * `axis`(`i32`) - Axis to concat on. - /// * `new_axis`(`Option`) - Optionally added new axis. - /// - /// ## Panics - /// - /// * Panics if new_axis not 0 or 1 (if value provided). - /// * Panics if axis not in accepted ranges. - /// * Panics if sequence length is not greater than 1. - /// - /// ## Returns - /// - /// A new `Tensor` concatenated tensor from the input tensor sequence. - /// - /// ## Example - /// - /// ```rust - /// use core::array::{ArrayTrait, SpanTrait}; - /// - /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - /// - /// fn concat_example() -> Tensor { - /// let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span(),); - /// let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span(),); - /// - /// let mut sequence = ArrayTrait::new(); - /// sequence.append(tensor1); - /// sequence.append(tensor2); - /// - /// let result = TensorTrait::concat_from_sequence(sequence: sequence, axis: 0, new_axis: Option::Some(0)); - /// return result; - /// } - /// >>> [[0. 1.] - /// [2. 3.], - /// [0. 1.] - /// [2. 3.]] - /// - /// result.shape - /// >>> (4, 2) - /// - /// let result = TensorTrait::concat_from_sequence(sequence: sequence, axis: 1, new_axis: Option::Some(0)); - /// return result; - /// } - /// >>> [[0. 1., 0., 1.] - /// [2. 3., 2., 3.]] - /// - /// result.shape - /// >>> (2, 4 ) - /// ``` - /// - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor; /// #tensor.not /// /// ```rust diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index 99cfcee76..5e14ff92d 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -380,14 +380,6 @@ impl BoolTensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } @@ -422,18 +414,6 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -444,12 +424,6 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - panic(array!['not supported!']) - } - fn erf(self: @Tensor) -> Tensor { panic(array!['not supported!']) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 03c022121..995f43272 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -448,20 +448,12 @@ impl FP16x16Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -484,18 +476,6 @@ impl FP16x16Tensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -506,12 +486,6 @@ impl FP16x16Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 109f956ad..f4e90a1cd 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -414,20 +414,12 @@ impl FP16x16WTensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -450,18 +442,6 @@ impl FP16x16WTensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -472,12 +452,6 @@ impl FP16x16WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index a780cd6c6..7121be2ad 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -449,20 +449,12 @@ impl FP32x32Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -485,18 +477,6 @@ impl FP32x32Tensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -507,12 +487,6 @@ impl FP32x32Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 24a635532..447d00013 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -449,20 +449,12 @@ impl FP64x64Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -485,18 +477,6 @@ impl FP64x64Tensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -507,12 +487,6 @@ impl FP64x64Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 40fa0a95b..0e9ceae4c 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -436,18 +436,10 @@ impl FP8x23Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -483,18 +475,6 @@ impl FP8x23Tensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -505,12 +485,6 @@ impl FP8x23Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 09ef81e1a..5db6c986a 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -400,20 +400,12 @@ impl FP8x23WTensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { math::shrink::shrink(self, bias, lambd) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -436,18 +428,6 @@ impl FP8x23WTensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -458,12 +438,6 @@ impl FP8x23WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index dacf2b733..524700cfe 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -446,18 +446,10 @@ impl I32Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -480,18 +472,6 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -502,12 +482,6 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index e5b0c299e..19d47cc4b 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -444,18 +444,10 @@ impl I8Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -478,18 +470,6 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -500,12 +480,6 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index a8e989eb1..165eee7d5 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -388,18 +388,10 @@ impl U32Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink(self: Tensor, bias: Option, lambd: Option) -> Tensor { panic(array!['not supported!']) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -422,18 +414,6 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -444,12 +424,6 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index c29ce69e4..a7cf39a30 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -49,15 +49,10 @@ mod bitwise_and; mod bitwise_xor; mod bitwise_or; mod gather_elements; -mod sequence_length; -mod sequence_at; mod reduce_min; mod shrink; mod reduce_mean; mod pow; -mod sequence_erase; -mod sequence_insert; -mod concat_from_sequence; mod is_nan; mod is_inf; mod gather_nd; diff --git a/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo b/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo index ce6b56929..adaf44a2d 100644 --- a/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_fp16x16_new_axis_default.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::FP16x16Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_fp16x16_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::None(()) ); diff --git a/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo b/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo index bbb5d9fc0..0bb989ed7 100644 --- a/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_fp16x16_new_axis_one.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::FP16x16Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_fp16x16_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(1) ); diff --git a/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo index 7ae21d053..a2856a1cc 100644 --- a/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_fp16x16_new_axis_zero.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::FP16x16Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_fp16x16_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(0) ); diff --git a/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo b/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo index 3383ce9f6..e569736f6 100644 --- a/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_fp8x23_new_axis_default.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::FP8x23Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_fp8x23_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::None(()) ); diff --git a/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo b/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo index a032dd956..12ef40787 100644 --- a/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_fp8x23_new_axis_one.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::FP8x23Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_fp8x23_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(1) ); diff --git a/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo index 3696d14af..a249cd309 100644 --- a/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_fp8x23_new_axis_zero.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::FP8x23Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_fp8x23_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(0) ); diff --git a/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo b/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo index 8c53ebf67..e6fbf0409 100644 --- a/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_i32_new_axis_default.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_i32_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::None(()) ); diff --git a/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo b/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo index 855321bf4..703d0b928 100644 --- a/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_i32_new_axis_one.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_i32_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(1) ); diff --git a/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo index ca190e242..d1ffd1caf 100644 --- a/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_i32_new_axis_zero.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_i32_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(0) ); diff --git a/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo b/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo index 0cab162e7..e1bff9308 100644 --- a/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_i8_new_axis_default.cairo @@ -8,6 +8,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I8Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_i8_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::None(()) ); diff --git a/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo b/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo index 59d295d3c..003b5446d 100644 --- a/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_i8_new_axis_one.cairo @@ -8,6 +8,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I8Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_i8_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(1) ); diff --git a/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo index 0d4e4daea..1299495df 100644 --- a/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_i8_new_axis_zero.cairo @@ -8,6 +8,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I8Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_i8_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(0) ); diff --git a/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo b/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo index 8a787ea18..474cda7db 100644 --- a/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo +++ b/tests/nodes/concat_from_sequence_u32_new_axis_default.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::U32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_u32_new_axis_default() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::None(()) ); diff --git a/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo b/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo index fa1fa4d9c..7d86fa0b3 100644 --- a/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo +++ b/tests/nodes/concat_from_sequence_u32_new_axis_one.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::U32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_u32_new_axis_one() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(1) ); diff --git a/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo b/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo index 5a2b1fa30..9bfbeacb0 100644 --- a/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo +++ b/tests/nodes/concat_from_sequence_u32_new_axis_zero.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::U32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; use orion::numbers::signed_integer::{integer_trait::IntegerTrait, i32::i32}; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -15,7 +17,7 @@ fn test_concat_from_sequence_u32_new_axis_zero() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::concat_from_sequence( + let y = SequenceTrait::concat_from_sequence( input_0, IntegerTrait::::new(1, false), Option::Some(0) ); diff --git a/tests/nodes/sequence_at_fp16x16_negative.cairo b/tests/nodes/sequence_at_fp16x16_negative.cairo index d65c1e56f..12bd5507c 100644 --- a/tests/nodes/sequence_at_fp16x16_negative.cairo +++ b/tests/nodes/sequence_at_fp16x16_negative.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::I32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_fp16x16_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_fp16x16_positive.cairo b/tests/nodes/sequence_at_fp16x16_positive.cairo index be329f13a..c4f35cac4 100644 --- a/tests/nodes/sequence_at_fp16x16_positive.cairo +++ b/tests/nodes/sequence_at_fp16x16_positive.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::I32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_fp16x16_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_fp8x23_negative.cairo b/tests/nodes/sequence_at_fp8x23_negative.cairo index e222aad40..4a04b333c 100644 --- a/tests/nodes/sequence_at_fp8x23_negative.cairo +++ b/tests/nodes/sequence_at_fp8x23_negative.cairo @@ -10,6 +10,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::FP8x23TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_fp8x23_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_fp8x23_positive.cairo b/tests/nodes/sequence_at_fp8x23_positive.cairo index f1b4b627e..543711645 100644 --- a/tests/nodes/sequence_at_fp8x23_positive.cairo +++ b/tests/nodes/sequence_at_fp8x23_positive.cairo @@ -10,6 +10,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::FP8x23TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_fp8x23_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_i32_negative.cairo b/tests/nodes/sequence_at_i32_negative.cairo index 6ba6b2594..d45079d6a 100644 --- a/tests/nodes/sequence_at_i32_negative.cairo +++ b/tests/nodes/sequence_at_i32_negative.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::I32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -16,7 +18,7 @@ fn test_sequence_at_i32_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_i32_positive.cairo b/tests/nodes/sequence_at_i32_positive.cairo index 36647a780..c25e289b7 100644 --- a/tests/nodes/sequence_at_i32_positive.cairo +++ b/tests/nodes/sequence_at_i32_positive.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::I32Tensor; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -16,7 +18,7 @@ fn test_sequence_at_i32_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_i8_negative.cairo b/tests/nodes/sequence_at_i8_negative.cairo index 1e176feda..33007f491 100644 --- a/tests/nodes/sequence_at_i8_negative.cairo +++ b/tests/nodes/sequence_at_i8_negative.cairo @@ -10,6 +10,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I8Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I8TensorPartialEq; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_i8_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_i8_positive.cairo b/tests/nodes/sequence_at_i8_positive.cairo index b45bbe178..4cd397fff 100644 --- a/tests/nodes/sequence_at_i8_positive.cairo +++ b/tests/nodes/sequence_at_i8_positive.cairo @@ -10,6 +10,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I8Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I8TensorPartialEq; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_i8_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_u32_negative.cairo b/tests/nodes/sequence_at_u32_negative.cairo index ae5a20d90..dc22c2cbc 100644 --- a/tests/nodes/sequence_at_u32_negative.cairo +++ b/tests/nodes/sequence_at_u32_negative.cairo @@ -10,6 +10,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_u32_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_at_u32_positive.cairo b/tests/nodes/sequence_at_u32_positive.cairo index 6e3922cd1..b55489cb7 100644 --- a/tests/nodes/sequence_at_u32_positive.cairo +++ b/tests/nodes/sequence_at_u32_positive.cairo @@ -10,6 +10,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; use orion::operators::tensor::I32TensorPartialEq; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_at_u32_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_at(input_0, input_1); + let y = SequenceTrait::sequence_at(input_0, input_1); assert_eq(y, z); } diff --git a/tests/nodes/sequence_erase_fp16x16_empty.cairo b/tests/nodes/sequence_erase_fp16x16_empty.cairo index 23f42b7b5..3c6492874 100644 --- a/tests/nodes/sequence_erase_fp16x16_empty.cairo +++ b/tests/nodes/sequence_erase_fp16x16_empty.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::FP16x16TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -14,7 +16,7 @@ fn test_sequence_erase_fp16x16_empty() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::None(())); + let y = SequenceTrait::sequence_erase(input_0, Option::None(())); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_fp16x16_negative.cairo b/tests/nodes/sequence_erase_fp16x16_negative.cairo index 183643c99..f8deff4bb 100644 --- a/tests/nodes/sequence_erase_fp16x16_negative.cairo +++ b/tests/nodes/sequence_erase_fp16x16_negative.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::FP16x16TensorPartialEq; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_fp16x16_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_fp16x16_positive.cairo b/tests/nodes/sequence_erase_fp16x16_positive.cairo index dc17a00b3..dbc41c494 100644 --- a/tests/nodes/sequence_erase_fp16x16_positive.cairo +++ b/tests/nodes/sequence_erase_fp16x16_positive.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::FP16x16TensorPartialEq; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_fp16x16_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_fp8x23_empty.cairo b/tests/nodes/sequence_erase_fp8x23_empty.cairo index 86dbf1cdc..bcaabf9ad 100644 --- a/tests/nodes/sequence_erase_fp8x23_empty.cairo +++ b/tests/nodes/sequence_erase_fp8x23_empty.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::FP8x23TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor}; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -14,7 +16,7 @@ fn test_sequence_erase_fp8x23_empty() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::None(())); + let y = SequenceTrait::sequence_erase(input_0, Option::None(())); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_fp8x23_negative.cairo b/tests/nodes/sequence_erase_fp8x23_negative.cairo index 0a427225b..408c8e897 100644 --- a/tests/nodes/sequence_erase_fp8x23_negative.cairo +++ b/tests/nodes/sequence_erase_fp8x23_negative.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_fp8x23_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_fp8x23_positive.cairo b/tests/nodes/sequence_erase_fp8x23_positive.cairo index ed1ca666d..ce5f2e1a9 100644 --- a/tests/nodes/sequence_erase_fp8x23_positive.cairo +++ b/tests/nodes/sequence_erase_fp8x23_positive.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_fp8x23_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_i32_empty.cairo b/tests/nodes/sequence_erase_i32_empty.cairo index 800a27af0..bd7f7da59 100644 --- a/tests/nodes/sequence_erase_i32_empty.cairo +++ b/tests/nodes/sequence_erase_i32_empty.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -14,7 +16,7 @@ fn test_sequence_erase_i32_empty() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::None(())); + let y = SequenceTrait::sequence_erase(input_0, Option::None(())); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_i32_negative.cairo b/tests/nodes/sequence_erase_i32_negative.cairo index 4bff09eb8..8d839c32d 100644 --- a/tests/nodes/sequence_erase_i32_negative.cairo +++ b/tests/nodes/sequence_erase_i32_negative.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -16,7 +18,7 @@ fn test_sequence_erase_i32_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_i32_positive.cairo b/tests/nodes/sequence_erase_i32_positive.cairo index aee5ef58b..928acffcb 100644 --- a/tests/nodes/sequence_erase_i32_positive.cairo +++ b/tests/nodes/sequence_erase_i32_positive.cairo @@ -8,6 +8,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -16,7 +18,7 @@ fn test_sequence_erase_i32_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_i8_empty.cairo b/tests/nodes/sequence_erase_i8_empty.cairo index 1ad6d9f1b..e7a55fb4d 100644 --- a/tests/nodes/sequence_erase_i8_empty.cairo +++ b/tests/nodes/sequence_erase_i8_empty.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I8Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -14,7 +16,7 @@ fn test_sequence_erase_i8_empty() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::None(())); + let y = SequenceTrait::sequence_erase(input_0, Option::None(())); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_i8_negative.cairo b/tests/nodes/sequence_erase_i8_negative.cairo index 0caad87fa..a41daf502 100644 --- a/tests/nodes/sequence_erase_i8_negative.cairo +++ b/tests/nodes/sequence_erase_i8_negative.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::I8Tensor; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_i8_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_i8_positive.cairo b/tests/nodes/sequence_erase_i8_positive.cairo index 52d1a62ee..d0c7dfb04 100644 --- a/tests/nodes/sequence_erase_i8_positive.cairo +++ b/tests/nodes/sequence_erase_i8_positive.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::I8Tensor; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_i8_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_u32_empty.cairo b/tests/nodes/sequence_erase_u32_empty.cairo index 915fdfd41..8065bda49 100644 --- a/tests/nodes/sequence_erase_u32_empty.cairo +++ b/tests/nodes/sequence_erase_u32_empty.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor}; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -14,7 +16,7 @@ fn test_sequence_erase_u32_empty() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::None(())); + let y = SequenceTrait::sequence_erase(input_0, Option::None(())); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_u32_negative.cairo b/tests/nodes/sequence_erase_u32_negative.cairo index 673724468..feab9c1ee 100644 --- a/tests/nodes/sequence_erase_u32_negative.cairo +++ b/tests/nodes/sequence_erase_u32_negative.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_u32_negative() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_erase_u32_positive.cairo b/tests/nodes/sequence_erase_u32_positive.cairo index d5dc5f353..a742989c2 100644 --- a/tests/nodes/sequence_erase_u32_positive.cairo +++ b/tests/nodes/sequence_erase_u32_positive.cairo @@ -10,6 +10,8 @@ use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] @@ -18,7 +20,7 @@ fn test_sequence_erase_u32_positive() { let input_1 = input_1::input_1(); let z = output_0::output_0(); - let y = TensorTrait::sequence_erase(input_0, Option::Some(input_1)); + let y = SequenceTrait::sequence_erase(input_0, Option::Some(input_1)); assert_seq_eq(y, z); } diff --git a/tests/nodes/sequence_insert_fp16x16.cairo b/tests/nodes/sequence_insert_fp16x16.cairo index ca636990f..9f9fe2754 100644 --- a/tests/nodes/sequence_insert_fp16x16.cairo +++ b/tests/nodes/sequence_insert_fp16x16.cairo @@ -11,6 +11,8 @@ use orion::operators::tensor::FP16x16TensorPartialEq; use orion::operators::tensor::FP16x16Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_insert_fp8x23.cairo b/tests/nodes/sequence_insert_fp8x23.cairo index be838402d..e18208548 100644 --- a/tests/nodes/sequence_insert_fp8x23.cairo +++ b/tests/nodes/sequence_insert_fp8x23.cairo @@ -11,6 +11,8 @@ use orion::operators::tensor::FP8x23TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::FP8x23Tensor; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_insert_i32.cairo b/tests/nodes/sequence_insert_i32.cairo index 0eac00d7b..f23a2ea7c 100644 --- a/tests/nodes/sequence_insert_i32.cairo +++ b/tests/nodes/sequence_insert_i32.cairo @@ -9,6 +9,8 @@ use orion::utils::{assert_eq, assert_seq_eq}; use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_insert_i8.cairo b/tests/nodes/sequence_insert_i8.cairo index 0d25a649b..604c39326 100644 --- a/tests/nodes/sequence_insert_i8.cairo +++ b/tests/nodes/sequence_insert_i8.cairo @@ -11,6 +11,8 @@ use orion::operators::tensor::I32TensorPartialEq; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I8TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_insert_u32.cairo b/tests/nodes/sequence_insert_u32.cairo index 0a8a55f18..8791e81d6 100644 --- a/tests/nodes/sequence_insert_u32.cairo +++ b/tests/nodes/sequence_insert_u32.cairo @@ -11,6 +11,8 @@ use orion::operators::tensor::U32Tensor; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::{TensorTrait, Tensor}; use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_fp16x16.cairo b/tests/nodes/sequence_length_fp16x16.cairo index 5d39a1df6..b43f3494b 100644 --- a/tests/nodes/sequence_length_fp16x16.cairo +++ b/tests/nodes/sequence_length_fp16x16.cairo @@ -9,6 +9,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::FP16x16Tensor; use orion::operators::tensor::U32Tensor; use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_fp16x16_broadcast.cairo b/tests/nodes/sequence_length_fp16x16_broadcast.cairo index 0b7bba50b..f8c59e131 100644 --- a/tests/nodes/sequence_length_fp16x16_broadcast.cairo +++ b/tests/nodes/sequence_length_fp16x16_broadcast.cairo @@ -9,6 +9,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::FP16x16Tensor; use orion::operators::tensor::U32Tensor; use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::sequence::FP16x16Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_fp8x23.cairo b/tests/nodes/sequence_length_fp8x23.cairo index 64c3291c9..cbe83fbc7 100644 --- a/tests/nodes/sequence_length_fp8x23.cairo +++ b/tests/nodes/sequence_length_fp8x23.cairo @@ -9,6 +9,8 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_fp8x23_broadcast.cairo b/tests/nodes/sequence_length_fp8x23_broadcast.cairo index f61131b65..4c78e99bb 100644 --- a/tests/nodes/sequence_length_fp8x23_broadcast.cairo +++ b/tests/nodes/sequence_length_fp8x23_broadcast.cairo @@ -9,6 +9,8 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; +use orion::operators::sequence::FP8x23Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_i32.cairo b/tests/nodes/sequence_length_i32.cairo index 412e23dbd..566c70e9a 100644 --- a/tests/nodes/sequence_length_i32.cairo +++ b/tests/nodes/sequence_length_i32.cairo @@ -9,6 +9,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32Tensor; use orion::operators::tensor::I32TensorPartialEq; use orion::operators::tensor::U32Tensor; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_i32_broadcast.cairo b/tests/nodes/sequence_length_i32_broadcast.cairo index d773749eb..89ae3fad8 100644 --- a/tests/nodes/sequence_length_i32_broadcast.cairo +++ b/tests/nodes/sequence_length_i32_broadcast.cairo @@ -9,6 +9,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::I32Tensor; use orion::operators::tensor::I32TensorPartialEq; use orion::operators::tensor::U32Tensor; +use orion::operators::sequence::I32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_i8.cairo b/tests/nodes/sequence_length_i8.cairo index 3aba0cfb4..3de36247b 100644 --- a/tests/nodes/sequence_length_i8.cairo +++ b/tests/nodes/sequence_length_i8.cairo @@ -9,6 +9,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; use orion::operators::tensor::I8TensorPartialEq; use orion::operators::tensor::I8Tensor; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_i8_broadcast.cairo b/tests/nodes/sequence_length_i8_broadcast.cairo index f8af9b5cf..03b2a9d90 100644 --- a/tests/nodes/sequence_length_i8_broadcast.cairo +++ b/tests/nodes/sequence_length_i8_broadcast.cairo @@ -9,6 +9,8 @@ use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; use orion::operators::tensor::I8TensorPartialEq; use orion::operators::tensor::I8Tensor; +use orion::operators::sequence::I8Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_u32.cairo b/tests/nodes/sequence_length_u32.cairo index edf9dced9..caf367c71 100644 --- a/tests/nodes/sequence_length_u32.cairo +++ b/tests/nodes/sequence_length_u32.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] diff --git a/tests/nodes/sequence_length_u32_broadcast.cairo b/tests/nodes/sequence_length_u32_broadcast.cairo index b5432cfaf..67b26c272 100644 --- a/tests/nodes/sequence_length_u32_broadcast.cairo +++ b/tests/nodes/sequence_length_u32_broadcast.cairo @@ -7,6 +7,8 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::utils::{assert_eq, assert_seq_eq}; use core::array::{ArrayTrait, SpanTrait}; use orion::operators::tensor::U32Tensor; +use orion::operators::sequence::U32Sequence; +use orion::operators::sequence::SequenceTrait; #[test] #[available_gas(2000000000)] From 77100119f930fc98b84cac94d4e95aa95463cba8 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Fri, 5 Jan 2024 08:49:40 -0800 Subject: [PATCH 19/38] update doc --- docs/SUMMARY.md | 13 +++---- docs/framework/compatibility.md | 20 +++++------ docs/framework/operators/sequence/README.md | 4 +++ .../sequence.concat_from_sequence.md} | 0 .../sequence.sequence_at.md} | 0 .../sequence.sequence_erase.md} | 0 .../sequence.sequence_insert.md} | 0 .../sequence.sequence_length.md} | 0 docs/framework/operators/tensor/README.md | 8 ----- docs/framework/operators/tensor/tensor.erf.md | 4 --- .../operators/tensor/tensor.gather_nd.md | 8 ----- .../tensor/tensor.sequence_construct.md | 35 ------------------- .../operators/tensor/tensor.sequence_empty.md | 35 ------------------- 13 files changed, 21 insertions(+), 106 deletions(-) rename docs/framework/operators/{tensor/tensor.concat_from_sequence.md => sequence/sequence.concat_from_sequence.md} (100%) rename docs/framework/operators/{tensor/tensor.sequence_at.md => sequence/sequence.sequence_at.md} (100%) rename docs/framework/operators/{tensor/tensor.sequence_erase.md => sequence/sequence.sequence_erase.md} (100%) rename docs/framework/operators/{tensor/tensor.sequence_insert.md => sequence/sequence.sequence_insert.md} (100%) rename docs/framework/operators/{tensor/tensor.sequence_length.md => sequence/sequence.sequence_length.md} (100%) delete mode 100644 docs/framework/operators/tensor/tensor.sequence_construct.md delete mode 100644 docs/framework/operators/tensor/tensor.sequence_empty.md diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index f1f9f2b04..a40674622 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -110,16 +110,10 @@ * [tensor.reduce\_prod](framework/operators/tensor/tensor.reduce\_prod.md) * [tensor.gather\_elements](framework/operators/tensor/tensor.gather\_elements.md) * [tensor.gather\_nd](framework/operators/tensor/tensor.gather\_nd.md) - * [tensor.sequence\_length](framework/operators/tensor/tensor.sequence\_length.md) - * [tensor.sequence\_at](framework/operators/tensor/tensor.sequence\_at.md) * [tensor.reduce\_min](framework/operators/tensor/tensor.reduce\_min.md) * [tensor.shrink](framework/operators/tensor/tensor.shrink.md) - * [tensor.sequence\_empty](framework/operators/tensor/tensor.sequence\_empty.md) * [tensor.reduce\_mean](framework/operators/tensor/tensor.reduce\_mean.md) * [tensor.pow](framework/operators/tensor/tensor.pow.md) - * [tensor.sequence\_erase](framework/operators/tensor/tensor.sequence\_erase.md) - * [tensor.sequence\_insert](framework/operators/tensor/tensor.sequence\_insert.md) - * [tensor.concat\_from\_sequence](framework/operators/tensor/tensor.concat\_from\_sequence.md) * [tensor.is\_nan](framework/operators/tensor/tensor.is\_nan.md) * [tensor.is\_inf](framework/operators/tensor/tensor.is\_inf.md) * [tensor.not](framework/operators/tensor/tensor.not.md) @@ -144,6 +138,13 @@ * [Sequence](framework/operators/sequence/README.md) * [sequence.sequence\_construct](framework/operators/sequence/sequence.sequence\_construct.md) * [sequence.sequence\_empty](framework/operators/sequence/sequence.sequence\_empty.md) + * [tensor.sequence\_length](framework/operators/sequence/sequence.sequence\_length.md) + * [tensor.sequence\_at](framework/operators/sequence/sequence.sequence\_at.md) + * [tensor.sequence\_empty](framework/operators/sequence/sequence.sequence\_empty.md) + * [tensor.sequence\_erase](framework/operators/sequence/sequence.sequence\_erase.md) + * [tensor.sequence\_insert](framework/operators/sequence/sequence.sequence\_insert.md) + * [tensor.concat\_from\_sequence](framework/operators/sequence/sequence.concat\_from\_sequence.md) + ## πŸ› Hub diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index 68cd44241..9dff192fe 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -89,20 +89,20 @@ You can see below the list of current supported ONNX Operators: | [ReduceL1](operators/tensor/tensor.reduce\_l1.md) | :white\_check\_mark: | | [ReduceL2](operators/tensor/tensor.reduce\_l2.md) | :white\_check\_mark: | | [GatherElements](operators/tensor/tensor.gather/_elements.md) | :white\_check\_mark: | -| [SequenceLength](operators/tensor/tensor.sequence\_length.md) | :white\_check\_mark: | -| [SequenceAt](operators/tensor/tensor.sequence\_at.md) | :white\_check\_mark: | -| [SequenceConstruct](operators/tensor/tensor.sequence\_construct.md) | :white\_check\_mark: | +| [SequenceLength](operators/sequence/sequence.sequence\_length.md) | :white\_check\_mark: | +| [SequenceAt](operators/sequence/sequence.sequence\_at.md) | :white\_check\_mark: | +| [SequenceConstruct](operators/sequence/sequence.sequence\_construct.md) | :white\_check\_mark: | | [Shrink](operators/tensor/tensor.shrink.md) | :white\_check\_mark: | -| [SequenceEmpty](operators/tensor/tensor.sequence\_empty.md) | :white\_check\_mark: | +| [SequenceEmpty](operators/sequence/sequence.sequence\_empty.md) | :white\_check\_mark: | | [ReduceL2](operators/tensor/tensor.reduce\_l2.md) | :white\_check\_mark: | -| [SequenceErase](operators/tensor/tensor.sequence\_erase.md) | :white\_check\_mark: | -| [SequenceInsert](operators/tensor/tensor.sequence\_insert.md) | :white\_check\_mark: | -| [ConcatFromSequence](operators/tensor/tensor.concat\_from\_sequence.md) | :white\_check\_mark: | +| [SequenceErase](operators/sequence/sequence.sequence\_erase.md) | :white\_check\_mark: | +| [SequenceInsert](operators/sequence/sequence.sequence\_insert.md) | :white\_check\_mark: | +| [ConcatFromSequence](operators/sequence/sequence.concat\_from\_sequence.md) | :white\_check\_mark: | | [IsNaN](operators/tensor/tensor.is\_nan.md) | :white\_check\_mark: | | [IsInf](operators/tensor/tensor.is\_inf.md) | :white\_check\_mark: | | [Not](operators/tensor/tensor.not.md) | :white\_check\_mark: | -| [GatherND](operators/tensor/tensor.gather/_nd.md) | :white\_check\_mark: | -| [ReduceLogSum](operators/tensor/tensor.reduce\_log\_sum.md) | :white\_check\_mark: | -| [Erf](operators/tensor/tensor.erf.md) | :white\_check\_mark: | +| [GatherND](operators/tensor/tensor.gather/_nd.md) | :white\_check\_mark: | +| [ReduceLogSum](operators/tensor/tensor.reduce\_log\_sum.md) | :white\_check\_mark: | +| [Erf](operators/tensor/tensor.erf.md) | :white\_check\_mark: | Current Operators support: **96/156 (62%)** diff --git a/docs/framework/operators/sequence/README.md b/docs/framework/operators/sequence/README.md index 2dab15cb1..6b3599e84 100644 --- a/docs/framework/operators/sequence/README.md +++ b/docs/framework/operators/sequence/README.md @@ -25,4 +25,8 @@ Orion supports currently these `Sequence` types. | --- | --- | | [`sequence.sequence_construct`](sequence.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. | | [`sequence.sequence_empty`](sequence.sequence\_empty.md) | Returns an empty tensor sequence. | +| [`sequence.sequence_length`](sequence.sequence\_length.md) | Returns the length of the input sequence. | +| [`sequence.sequence_insert`](sequence.sequence\_insert.md) | Insert a tensor into a sequence. | +| [`sequence.sequence_at`](sequence.sequence\_at.md) | Outputs the tensor at the specified position in the input sequence. | +| [`sequence.concat_from_sequence`](sequence.concat\_from\_sequence.md) | Concatenate a sequence of tensors into a single tensor. | diff --git a/docs/framework/operators/tensor/tensor.concat_from_sequence.md b/docs/framework/operators/sequence/sequence.concat_from_sequence.md similarity index 100% rename from docs/framework/operators/tensor/tensor.concat_from_sequence.md rename to docs/framework/operators/sequence/sequence.concat_from_sequence.md diff --git a/docs/framework/operators/tensor/tensor.sequence_at.md b/docs/framework/operators/sequence/sequence.sequence_at.md similarity index 100% rename from docs/framework/operators/tensor/tensor.sequence_at.md rename to docs/framework/operators/sequence/sequence.sequence_at.md diff --git a/docs/framework/operators/tensor/tensor.sequence_erase.md b/docs/framework/operators/sequence/sequence.sequence_erase.md similarity index 100% rename from docs/framework/operators/tensor/tensor.sequence_erase.md rename to docs/framework/operators/sequence/sequence.sequence_erase.md diff --git a/docs/framework/operators/tensor/tensor.sequence_insert.md b/docs/framework/operators/sequence/sequence.sequence_insert.md similarity index 100% rename from docs/framework/operators/tensor/tensor.sequence_insert.md rename to docs/framework/operators/sequence/sequence.sequence_insert.md diff --git a/docs/framework/operators/tensor/tensor.sequence_length.md b/docs/framework/operators/sequence/sequence.sequence_length.md similarity index 100% rename from docs/framework/operators/tensor/tensor.sequence_length.md rename to docs/framework/operators/sequence/sequence.sequence_length.md diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index 777e00b78..75e094f99 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -106,19 +106,11 @@ use orion::operators::tensor::TensorTrait; | [`tensor.reduce_l2`](tensor.reduce\_l2.md) | Computes the L2 norm of the input tensor's elements along the provided axes. | | [`tensor.gather_elements`](tensor.gather\_elements.md) | GatherElements is an indexing operation that produces its output by indexing into the input data tensor at index positions determined by elements of the indices tensor. | | [`tensor.reduce_min`](tensor.reduce\_min.md) | Computes the min of the input tensor's elements along the provided axes. | -| [`tensor.sequence_empty`](tensor.sequence\_empty.md) | Returns an empty tensor sequence. | -| [`tensor.sequence_length`](tensor.sequence\_length.md) | Returns the length of the input sequence. | -| [`tensor.sequence_insert`](tensor.sequence\_insert.md) | Insert a tensor into a sequence. | -| [`tensor.sequence_at`](tensor.sequence\_at.md) | Outputs the tensor at the specified position in the input sequence. | -| [`tensor.sequence_construct`](tensor.sequence\_construct.md) | Constructs a tensor sequence containing the input tensors. | -| [`tensor.shrink`](tensor.shrink.md) | Shrinks the input tensor element-wise to the output tensor. | | [`tensor.reduce_mean`](tensor.reduce\_mean.md) | Computes the mean of the input tensor's elements along the provided axes. | | [`tensor.pow`](tensor.pow.md) | Pow takes input data (Tensor) and exponent Tensor, and produces one output data (Tensor) where the function f(x) = x^exponent, is applied to the data tensor elementwise. | -| [`tensor.sequence_empty`](tensor.sequence\_empty.md) | Returns an empty tensor sequence. | | [`tensor.binarizer`](tensor.binarizer.md) | Maps the values of a tensor element-wise to 0 or 1 based on the comparison against a threshold value. | | [`tensor.array_feature_extractor`](tensor.array\_feature\_extractor.md) | Selects elements of the input tensor based on the indices passed applied to the last tensor axis. | | [`tensor.reduce_min`](tensor.reduce\_min.md) | Computes the min of the input tensor's elements along the provided axes. | -| [`tensor.concat_from_sequence`](tensor.concat\_from\_sequence.md) | Concatenate a sequence of tensors into a single tensor. | | [`tensor.is_nan`](tensor.is\_nan.md) | Returns which elements of the input are NaN. | | [`tensor.is_inf`](tensor.is\_inf.md) | Maps infinity to true and other values to false. | | [`tensor.not`](tensor.not.md) | Computes the logical negation of all elements in the input tensor. | diff --git a/docs/framework/operators/tensor/tensor.erf.md b/docs/framework/operators/tensor/tensor.erf.md index 19ce86a94..384a941d0 100644 --- a/docs/framework/operators/tensor/tensor.erf.md +++ b/docs/framework/operators/tensor/tensor.erf.md @@ -6,10 +6,6 @@ Computes the mean of the input tensor's elements along the provided axes. -## Args - -* `self`(`@Tensor`) - The input tensor. - ## Returns A new `Tensor` of the same shape as the input tensor with diff --git a/docs/framework/operators/tensor/tensor.gather_nd.md b/docs/framework/operators/tensor/tensor.gather_nd.md index 021d4f235..ce6f94462 100644 --- a/docs/framework/operators/tensor/tensor.gather_nd.md +++ b/docs/framework/operators/tensor/tensor.gather_nd.md @@ -21,14 +21,6 @@ Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims ## Returns A new `Tensor` . - -## Example - -```rust -use array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - fn gather_nd_example() -> Tensor { let tensor = TensorTrait::::new( shape: array![2, 2].span(), diff --git a/docs/framework/operators/tensor/tensor.sequence_construct.md b/docs/framework/operators/tensor/tensor.sequence_construct.md deleted file mode 100644 index d5e627bd1..000000000 --- a/docs/framework/operators/tensor/tensor.sequence_construct.md +++ /dev/null @@ -1,35 +0,0 @@ -## tensor.sequence_construct - -```rust - fn sequence_construct(tensors: Array>) -> Array>; -``` - -Constructs a tensor sequence containing the input tensors. - -## Args - -* `tensors`(`Array>`) - The array of input tensors. - -## Panics - -* Panics if input tensor array is empty. - -## Returns - -A tensor sequence `Array>` containing the input tensors. - -## Examples - -```rust -use core::array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; - -fn sequence_construct_example() -> Array> { - let tensor1 = TensorTrait::new(shape: array![2, 2].span(), data: array![0, 1, 2, 3].span()); - let tensor2 = TensorTrait::new(shape: array![2, 2].span(), data: array![4, 5, 6, 7].span()); - let result = TensorTrait::sequence_construct(tensors: array![tensor1, tensor2]); - return result; -} ->>> [[0, 1, 2, 3], [4, 5, 6, 7]] -``` diff --git a/docs/framework/operators/tensor/tensor.sequence_empty.md b/docs/framework/operators/tensor/tensor.sequence_empty.md deleted file mode 100644 index 60ea380e5..000000000 --- a/docs/framework/operators/tensor/tensor.sequence_empty.md +++ /dev/null @@ -1,35 +0,0 @@ -# tensor.sequence_empty - -```rust - fn sequence_empty() -> Array>; -``` - -Returns an empty tensor sequence. - -## Args - -## Returns - -An empty `Array>` instance. - -## Examples - -Let's create a new empty sequence. - -```rust -use core::array::{ArrayTrait, SpanTrait}; - -use orion::operators::tensor::{ - TensorTrait, // we import the trait - Tensor, // we import the type - U32Tensor // we import the implementation. -}; - -fn sequence_empty_example() -> Array> { - let sequence = TensorTrait::sequence_empty(); - - return sequence; -} - ->>> [] -``` From dbc5794727f4d6cb6b7ec145af7c4692fc27b87b Mon Sep 17 00:00:00 2001 From: chachaleo Date: Sat, 6 Jan 2024 10:56:54 +0100 Subject: [PATCH 20/38] feat: resize --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + docs/framework/operators/tensor/README.md | 1 + .../operators/tensor/tensor.resize.md | 229 ++ nodegen/node/resize.py | 2204 +++++++++++++++++ src/operators/ml.cairo | 4 +- src/operators/tensor/core.cairo | 251 ++ .../tensor/implementations/tensor_bool.cairo | 22 +- .../implementations/tensor_fp16x16.cairo | 38 +- .../implementations/tensor_fp16x16wide.cairo | 24 +- .../implementations/tensor_fp32x32.cairo | 38 +- .../implementations/tensor_fp64x64.cairo | 38 +- .../implementations/tensor_fp8x23.cairo | 38 +- .../implementations/tensor_fp8x23wide.cairo | 24 +- .../tensor/implementations/tensor_i32.cairo | 24 +- .../tensor/implementations/tensor_i8.cairo | 24 +- .../tensor/implementations/tensor_u32.cairo | 24 +- src/operators/tensor/math.cairo | 1 + src/operators/tensor/math/gather_nd.cairo | 45 +- src/operators/tensor/math/resize.cairo | 1470 +++++++++++ tests/lib.cairo | 11 +- tests/nodes.cairo | 37 + .../gather_nd_fp16x16_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp16x16_3d_batch_dims2.cairo | 2 +- .../nodes/gather_nd_fp16x16_3d_default.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims1.cairo | 2 +- .../gather_nd_fp8x23_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_fp8x23_3d_default.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims1.cairo | 2 +- .../nodes/gather_nd_i32_3d_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_i32_3d_default.cairo | 2 +- tests/nodes/gather_nd_i8_3d_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_i8_3d_default.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims1.cairo | 2 +- tests/nodes/gather_nd_u32_batch_dims2.cairo | 2 +- tests/nodes/gather_nd_u32_default.cairo | 2 +- .../resize_downsample_scales_cubic.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 24 + ..._scales_cubic_A_n0p5_exclude_outside.cairo | 40 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../input_2.cairo | 13 + .../output_0.cairo | 24 + ...ownsample_scales_cubic_align_corners.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 24 + ...ze_downsample_scales_cubic_antialias.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 19 + ...wnsample_scales_linear_align_corners.cairo | 39 + .../input_0.cairo | 23 + .../input_1.cairo | 16 + .../output_0.cairo | 17 + ...e_downsample_scales_linear_antialias.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 19 + ...e_scales_linear_half_pixel_symmetric.cairo | 39 + .../input_0.cairo | 19 + .../input_1.cairo | 16 + .../output_0.cairo | 17 + .../resize_downsample_scales_nearest.cairo | 39 + .../input_0.cairo | 23 + .../input_1.cairo | 16 + .../output_0.cairo | 17 + .../nodes/resize_downsample_sizes_cubic.cairo | 40 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../output_0.cairo | 24 + ...ize_downsample_sizes_cubic_antialias.cairo | 41 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../output_0.cairo | 24 + ...mple_sizes_linear_pytorch_half_pixel.cairo | 41 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../output_0.cairo | 18 + .../resize_downsample_sizes_nearest.cairo | 43 + .../input_0.cairo | 23 + .../input_1.cairo | 15 + .../output_0.cairo | 18 + ..._downsample_sizes_nearest_not_larger.cairo | 44 + .../input_0.cairo | 23 + .../input_1.cairo | 13 + .../input_2.cairo | 13 + .../output_0.cairo | 17 + ...downsample_sizes_nearest_not_smaller.cairo | 43 + .../input_0.cairo | 23 + .../input_1.cairo | 13 + .../input_2.cairo | 13 + .../output_0.cairo | 21 + tests/nodes/resize_tf_crop_and_resize.cairo | 43 + .../resize_tf_crop_and_resize/input_0.cairo | 31 + .../resize_tf_crop_and_resize/input_1.cairo | 15 + .../resize_tf_crop_and_resize/input_2.cairo | 20 + .../resize_tf_crop_and_resize/output_0.cairo | 24 + .../resize_tf_crop_and_resize_axes_2_3.cairo | 45 + .../input_0.cairo | 31 + .../input_1.cairo | 13 + .../input_2.cairo | 16 + .../input_3.cairo | 13 + .../output_0.cairo | 24 + .../resize_tf_crop_and_resize_axes_3_2.cairo | 45 + .../input_0.cairo | 31 + .../input_1.cairo | 13 + .../input_2.cairo | 16 + .../input_3.cairo | 13 + .../output_0.cairo | 24 + ..._crop_and_resize_extrapolation_value.cairo | 44 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../input_2.cairo | 20 + .../output_0.cairo | 24 + .../nodes/resize_upsample_scales_cubic.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 79 + ..._scales_cubic_A_n0p5_exclude_outside.cairo | 40 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 79 + ..._upsample_scales_cubic_align_corners.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 79 + ...ize_upsample_scales_cubic_asymmetric.cairo | 39 + .../input_0.cairo | 31 + .../input_1.cairo | 16 + .../output_0.cairo | 79 + .../nodes/resize_upsample_scales_linear.cairo | 39 + .../input_0.cairo | 19 + .../input_1.cairo | 16 + .../output_0.cairo | 31 + ...upsample_scales_linear_align_corners.cairo | 39 + .../input_0.cairo | 19 + .../input_1.cairo | 16 + .../output_0.cairo | 31 + ...e_scales_linear_half_pixel_symmetric.cairo | 39 + .../input_0.cairo | 19 + .../input_1.cairo | 16 + .../output_0.cairo | 35 + .../resize_upsample_scales_nearest.cairo | 39 + .../input_0.cairo | 19 + .../input_1.cairo | 16 + .../output_0.cairo | 39 + ...ize_upsample_scales_nearest_axes_2_3.cairo | 41 + .../input_0.cairo | 19 + .../input_1.cairo | 14 + .../input_2.cairo | 13 + .../output_0.cairo | 39 + ...ize_upsample_scales_nearest_axes_3_2.cairo | 41 + .../input_0.cairo | 19 + .../input_1.cairo | 14 + .../input_2.cairo | 13 + .../output_0.cairo | 39 + tests/nodes/resize_upsample_sizes_cubic.cairo | 39 + .../resize_upsample_sizes_cubic/input_0.cairo | 31 + .../resize_upsample_sizes_cubic/input_1.cairo | 15 + .../output_0.cairo | 105 + .../nodes/resize_upsample_sizes_nearest.cairo | 39 + .../input_0.cairo | 19 + .../input_1.cairo | 15 + .../output_0.cairo | 71 + ...size_upsample_sizes_nearest_axes_2_3.cairo | 41 + .../input_0.cairo | 19 + .../input_1.cairo | 13 + .../input_2.cairo | 13 + .../output_0.cairo | 71 + ...size_upsample_sizes_nearest_axes_3_2.cairo | 45 + .../input_0.cairo | 19 + .../input_1.cairo | 13 + .../input_2.cairo | 13 + .../output_0.cairo | 71 + ...sample_sizes_nearest_ceil_half_pixel.cairo | 43 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../output_0.cairo | 79 + ...le_sizes_nearest_floor_align_corners.cairo | 41 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../output_0.cairo | 79 + ...ze_upsample_sizes_nearest_not_larger.cairo | 45 + .../input_0.cairo | 19 + .../input_1.cairo | 13 + .../input_2.cairo | 13 + .../output_0.cairo | 64 + ...e_upsample_sizes_nearest_not_smaller.cairo | 45 + .../input_0.cairo | 19 + .../input_1.cairo | 13 + .../input_2.cairo | 13 + .../output_0.cairo | 79 + ...nearest_round_prefer_ceil_asymmetric.cairo | 41 + .../input_0.cairo | 31 + .../input_1.cairo | 15 + .../output_0.cairo | 79 + 199 files changed, 9367 insertions(+), 64 deletions(-) create mode 100644 docs/framework/operators/tensor/tensor.resize.md create mode 100644 nodegen/node/resize.py create mode 100644 src/operators/tensor/math/resize.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_2.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_align_corners.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_align_corners/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_align_corners/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_align_corners/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_antialias.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_antialias/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_antialias/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_cubic_antialias/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_align_corners.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_align_corners/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_align_corners/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_align_corners/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_antialias.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_antialias/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_antialias/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_antialias/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/output_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_nearest.cairo create mode 100644 tests/nodes/resize_downsample_scales_nearest/input_0.cairo create mode 100644 tests/nodes/resize_downsample_scales_nearest/input_1.cairo create mode 100644 tests/nodes/resize_downsample_scales_nearest/output_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic/input_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic/input_1.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic/output_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic_antialias.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic_antialias/input_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic_antialias/input_1.cairo create mode 100644 tests/nodes/resize_downsample_sizes_cubic_antialias/output_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel.cairo create mode 100644 tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_1.cairo create mode 100644 tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/output_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest/input_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest/input_1.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest/output_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_larger.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_larger/input_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_larger/input_1.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_larger/input_2.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_larger/output_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_smaller.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_0.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_1.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_2.cairo create mode 100644 tests/nodes/resize_downsample_sizes_nearest_not_smaller/output_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize/input_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize/input_1.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize/input_2.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize/output_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_2_3.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_1.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_2.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_3.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_2_3/output_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_3_2.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_1.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_2.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_3.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_axes_3_2/output_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_extrapolation_value.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_0.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_1.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_2.cairo create mode 100644 tests/nodes/resize_tf_crop_and_resize_extrapolation_value/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_align_corners.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_align_corners/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_align_corners/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_align_corners/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_asymmetric.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_asymmetric/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_asymmetric/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_cubic_asymmetric/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_align_corners.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_align_corners/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_align_corners/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_align_corners/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_2_3.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_2.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_2_3/output_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_3_2.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_0.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_1.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_2.cairo create mode 100644 tests/nodes/resize_upsample_scales_nearest_axes_3_2/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_cubic.cairo create mode 100644 tests/nodes/resize_upsample_sizes_cubic/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_cubic/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_cubic/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_2_3.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_2.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_2_3/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_3_2.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_2.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_axes_3_2/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_floor_align_corners.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_larger.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_larger/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_larger/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_larger/input_2.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_larger/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_smaller.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_2.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_not_smaller/output_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_0.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_1.cairo create mode 100644 tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index 62ae2a2b3..00859d29c 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -100,6 +100,7 @@ * [tensor.bitwise\_and](framework/operators/tensor/tensor.bitwise\_and.md) * [tensor.bitwise\_xor](framework/operators/tensor/tensor.bitwise\_xor.md) * [tensor.bitwise\_or](framework/operators/tensor/tensor.bitwise\_or.md) + * [tensor.resize](framework/operators/tensor/tensor.resize.md) * [tensor.round](framework/operators/tensor/tensor.round.md) * [tensor.scatter](framework/operators/tensor/tensor.scatter.md) * [tensor.array\_feature\_extractor](framework/operators/tensor/tensor.array\_feature\_extractor.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index 68cd44241..0126ced4e 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -77,6 +77,7 @@ You can see below the list of current supported ONNX Operators: | [BitwiseAnd](operators/tensor/tensor.bitwise_and.md) | :white\_check\_mark: | | [BitwiseOr](operators/tensor/tensor.bitwise_or.md) | :white\_check\_mark: | | [BitwiseXor](operators/tensor/tensor.bitwise_xor.md) | :white\_check\_mark: | +| [Resize](operators/tensor/tensor.resize.md) | :white\_check\_mark: | | [Round](operators/tensor/tensor.round.md) | :white\_check\_mark: | | [MaxInTensor](operators/tensor/tensor.max\_in\_tensor.md) | :white\_check\_mark: | | [Max](operators/tensor/tensor.max.md) | :white\_check\_mark: | diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index 777e00b78..94c11362c 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -98,6 +98,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.bitwise_and`](tensor.bitwise\_and.md) | Computes the bitwise AND of two tensors element-wise. | | [`tensor.bitwise_xor`](tensor.bitwise\_xor.md) | Computes the bitwise XOR of two tensors element-wise. | | [`tensor.bitwise_or`](tensor.bitwise\_or.md) | Computes the bitwise OR of two tensors element-wise. | +| [`tensor.resize`](tensor.resize.md) | Resizes the input tensor. | | [`tensor.round`](tensor.round.md) | Computes the round value of all elements in the input tensor. | | [`tensor.reduce_l1`](tensor.reduce\_l1.md) | Computes the L1 norm of the input tensor's elements along the provided axes. | | [`tensor.trilu`](tensor.trilu.md) | Returns the upper or lower triangular part of a tensor or a batch of 2D matrices. | diff --git a/docs/framework/operators/tensor/tensor.resize.md b/docs/framework/operators/tensor/tensor.resize.md new file mode 100644 index 000000000..35164970d --- /dev/null +++ b/docs/framework/operators/tensor/tensor.resize.md @@ -0,0 +1,229 @@ +#tensor.resize + +```rust + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor; +``` + +Resizes the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood in the input tensor. + +## Args + +* `self`(`@Tensor`) - The input tensor. +* `roi` (`Option>`) (optional) - 1-D tensor given as [start1, ..., startN, end1, ..., endN], where N is the rank of X or the length of axes, if provided. It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize" +* `scales` (`Option>`) (optional) - The scale array along each dimension. It takes value greater than 0. If it's less than 1, it's sampling down, otherwise, it's upsampling. The number of elements of 'scales' should be the same as the rank of input 'X' or the length of 'axes', if provided. One and only one of 'scales' and 'sizes' MUST be specified. +* `sizes` (`Option>`) (optional) - Target size of the output tensor. Its interpretation depends on the 'keep_aspect_ratio_policy' value. The number of elements of 'sizes' should be the same as the rank of input 'X', or the length of 'axes', if provided. One and only one of 'scales' and 'sizes' MUST be specified. +* `antialias` (`Option`) (default is 0) - If set to 1, "linear" and "cubic" interpolation modes will use an antialiasing filter when downscaling. Antialiasing is achieved by stretching the resampling filter by a factor max(1, 1 / scale). +* `axes`(`Option>`) - If provided, it specifies a subset of axes that 'roi', 'scales' and 'sizes' refer to. If not provided, all axes are assumed [0, 1, ..., r-1], where r = rank(data). +* `coordinate_transformation_mode` (`Option`) (default is half_pixel) - This attribute describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. +* `cubic_coeff_a` (`Option`) (default is -0.75) - The coefficient 'a' used in cubic interpolation. +* `exclude_outside` (`Option`) (default is false) - If set to true, the weight of sampling locations outside the tensor will be set to 0 and the weight will be renormalized so that their sum is 1.0. +* `extrapolation_value` (`Option`) (default is 0.0) - When coordinate_transformation_mode is "tf_crop_and_resize" and x_original is outside the range [0, length_original - 1], this value is used as the corresponding output value. +* `keep_aspect_ratio_policy` (`Option`) (default is stretch) - This attribute describes how to interpret the `sizes` input with regard to keeping the original aspect ratio of the input, and it is not applicable when the `scales` input is used. +* `mode` (`Option`) (default is nearest) - Three interpolation modes: "nearest", "linear" and "cubic". +* `nearest_mode` (`Option`) (default is round_prefer_floor) - Four modes: "round_prefer_floor" (as known as round half down), "round_prefer_ceil" (as known as round half up), "floor", "ceil". Only used by nearest interpolation. + +## Panics + +* Panics if both scales and sizes are `Option::None`. +* Panics if roi is `Option::None` for the coordinate_transformation_mode `tf_crop_and_resize`. +* Panics if antialias is not `Option::None` for mode `nearest`. + +## Returns + +A new resized `Tensor` of the dimension given by output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) is scale is specified, or output_size if size is specified (note that some value of the parameter `keep_aspect_ratio_policy` can change sizes and therefore the dimension of the output tensor) + +## Example + +```rust +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor, FP16x16Tensor, FP16x16TensorPartialEq}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; +use orion::numbers::{FP16x16, FP16x16Impl, FixedTrait}; +use core::debug::PrintTrait; + +fn example_resize_downsample_scales_linear() -> Tensor{ + let mut data = TensorTrait::< + FP16x16 + >::new( + shape: array![1, 1, 2, 4].span(), + data: array![ + FixedTrait::::new(65536, false), //1 + FixedTrait::::new(131072, false), //2 + FixedTrait::::new(196608, false), //3 + FixedTrait::::new(262144, false), //4 + FixedTrait::::new(327680, false), //5 + FixedTrait::::new(393216, false), //6 + FixedTrait::::new(458752, false), //7 + FixedTrait::::new(524288, false), //8 + ] + .span(), + ); + let mut scales = array![ + FixedTrait::::new(65536, false), //1 + FixedTrait::::new(65536, false), + FixedTrait::::new(39322, false), //0.6 + FixedTrait::::new(39322, false) + ] + .span(); + + let scales = Option::Some(scales); + + return data.resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + +} +>>> [[[[2.6666665 4.3333331]]]] + + + +fn example_resize_tf_crop_and_resize_extrapolation_value() -> Tensor { + let mut data = TensorTrait::< + FP16x16 + >::new( + shape: array![1, 1, 4, 4].span(), + data: array![ + FixedTrait::::new(65536, false), + FixedTrait::::new(131072, false), + FixedTrait::::new(196608, false), + FixedTrait::::new(262144, false), + FixedTrait::::new(327680, false), + FixedTrait::::new(393216, false), + FixedTrait::::new(458752, false), + FixedTrait::::new(524288, false), + FixedTrait::::new(589824, false), + FixedTrait::::new(655360, false), + FixedTrait::::new(720896, false), + FixedTrait::::new(786432, false), + FixedTrait::::new(851968, false), + FixedTrait::::new(917504, false), + FixedTrait::::new(983040, false), + FixedTrait::::new(1048576, false), + ] + .span(), + ); + + let mut roi = TensorTrait::< + FP16x16 + >::new( + shape: array![8].span(), + data: array![ + FixedTrait::::new(0, false), + FixedTrait::::new(0, false), + FixedTrait::::new(26214, false), + FixedTrait::::new(39322, false), + FixedTrait::::new(65536, false), + FixedTrait::::new(65536, false), + FixedTrait::::new(78643, false), + FixedTrait::::new(111411, false), + ] + .span(), + ); + let roi = Option::Some(roi); + + let mut sizes = array![1, 1, 3, 3].span(); + let sizes = Option::Some(sizes); + + let extrapolation_value = Option::Some(FixedTrait::::new(655360, false)); + + return data.resize( + roi, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE), + Option::None, + Option::None, + extrapolation_value, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + +} +>>> [[[[ 7.6000004 10. 10. ] + [12.400001 10. 10. ] + [10. 10. 10. ]]]] + + + +fn example_resize_downsample_sizes_cubic_antialias() -> Tensor { + let mut data = TensorTrait::< + FP16x16 + >::new( + shape: array![1, 1, 4, 4].span(), + data: array![ + FixedTrait::::new(65536, false), + FixedTrait::::new(131072, false), + FixedTrait::::new(196608, false), + FixedTrait::::new(262144, false), + FixedTrait::::new(327680, false), + FixedTrait::::new(393216, false), + FixedTrait::::new(458752, false), + FixedTrait::::new(524288, false), + FixedTrait::::new(589824, false), + FixedTrait::::new(655360, false), + FixedTrait::::new(720896, false), + FixedTrait::::new(786432, false), + FixedTrait::::new(851968, false), + FixedTrait::::new(917504, false), + FixedTrait::::new(983040, false), + FixedTrait::::new(1048576, false), + ] + .span(), + ); + + let antialias = Option::Some(1); + + let mut sizes = array![1, 1, 3, 3].span(); + let sizes = Option::Some(sizes); + + return data.resize( + Option::None, + Option::None, + sizes, + antialias, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); +} + +>>> [[[[ 1.7750092 3.1200073 4.4650054] + [ 7.1550016 8.5 9.844998 ] + [12.534994 13.8799925 15.224991 ]]]] + +``` diff --git a/nodegen/node/resize.py b/nodegen/node/resize.py new file mode 100644 index 000000000..65cafa9ba --- /dev/null +++ b/nodegen/node/resize.py @@ -0,0 +1,2204 @@ +# Python test implementation from ONNX library : https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_resize.py + +import numpy as np +from typing import Any, Callable + +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl + + +def _cartesian(arrays: list[np.ndarray], out: np.ndarray | None = None) -> np.ndarray: + #From https://stackoverflow.com/a/1235363 + arrays = [np.asarray(x) for x in arrays] + dtype = arrays[0].dtype + + n = np.prod([x.size for x in arrays]) + if out is None: + out = np.zeros([n, len(arrays)], dtype=dtype) + + m = n // arrays[0].size + out[:, 0] = np.repeat(arrays[0], m) + if arrays[1:]: + _cartesian(arrays[1:], out=out[0:m, 1:]) + for j in range(1, arrays[0].size): + out[j * m : (j + 1) * m, 1:] = out[0:m, 1:] + return out + + +def _get_neighbor_idxes(x: float, n: int, limit: int) -> np.ndarray: + idxes = sorted(range(limit), key=lambda idx: (abs(x - idx), idx))[:n] + idxes = sorted(idxes) + return np.array(idxes) + + +def _get_neighbor(x: float, n: int, data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + + pad_width = np.ceil(n / 2).astype(int) + padded = np.pad(data, pad_width, mode="edge") + x += pad_width + + idxes = _get_neighbor_idxes(x, n, len(padded)) + + + ret = padded[idxes] + return idxes - pad_width, ret + +def linear_coeffs(ratio: float, scale: float | None = None) -> np.ndarray: + del scale + return np.array([1 - ratio, ratio]) + + +def linear_coeffs_antialias(ratio: float, scale: float) -> np.ndarray: + scale = min(scale, 1.0) + + start = int(np.floor(-1 / scale) + 1) + footprint = 2 - 2 * start + args = (np.arange(start, start + footprint) - ratio) * scale + coeffs = np.clip(1 - np.abs(args), 0, 1) + + return np.array(coeffs) / sum(coeffs) + +def cubic_coeffs_antialias(ratio: float, scale: float, A: float = -0.75) -> np.ndarray: + scale = min(scale, 1.0) + + def compute_coeff(x: float) -> float: + x = abs(x) + x_2 = x * x + x_3 = x * x_2 + if x <= 1: + return (A + 2) * x_3 - (A + 3) * x_2 + 1 + if x < 2: + return A * x_3 - 5 * A * x_2 + 8 * A * x - 4 * A + return 0.0 + + i_start = int(np.floor(-2 / scale) + 1) + i_end = 2 - i_start + args = [scale * (i - ratio) for i in range(i_start, i_end)] + coeffs = [compute_coeff(x) for x in args] + return np.array(coeffs) / sum(coeffs) + +def nearest_coeffs( + ratio: float | int | np.ndarray, mode: str = "round_prefer_floor" +) -> np.ndarray: + if isinstance(ratio, int) or ratio.is_integer(): + return np.array([0, 1]) + if mode == "round_prefer_floor": + return np.array([ratio <= 0.5, ratio > 0.5]) + if mode == "round_prefer_ceil": + return np.array([ratio < 0.5, ratio >= 0.5]) + if mode == "floor": + return np.array([1, 0]) + if mode == "ceil": + return np.array([0, 1]) + raise ValueError(f"Unexpected value {mode!r}.") + + + +def _interpolate_1d_with_x( + data: np.ndarray, + scale_factor: float, + output_width_int: int, + x: float, + get_coeffs: Callable[[float, float], np.ndarray], + roi: np.ndarray | None = None, + extrapolation_value: float = 0.0, + coordinate_transformation_mode: str = "half_pixel", + exclude_outside: bool = False, +) -> np.ndarray: + + input_width = len(data) + output_width = scale_factor * input_width + + if coordinate_transformation_mode == "align_corners": + if output_width == 1: + x_ori = 0.0 + else: + x_ori = x * (input_width - 1) / (output_width - 1) + elif coordinate_transformation_mode == "asymmetric": + x_ori = x / scale_factor + elif coordinate_transformation_mode == "tf_crop_and_resize": + if roi is None: + raise ValueError("roi cannot be None.") + if output_width == 1: + x_ori = (roi[1] - roi[0]) * (input_width - 1) / 2 + else: + x_ori = x * (roi[1] - roi[0]) * (input_width - 1) / (output_width - 1) + x_ori += roi[0] * (input_width - 1) + + if x_ori < 0 or x_ori > input_width - 1: + return np.array(extrapolation_value) + elif coordinate_transformation_mode == "pytorch_half_pixel": + if output_width == 1: + x_ori = -0.5 + else: + x_ori = (x + 0.5) / scale_factor - 0.5 + elif coordinate_transformation_mode == "half_pixel": + x_ori = (x + 0.5) / scale_factor - 0.5 + elif coordinate_transformation_mode == "half_pixel_symmetric": + adjustment = output_width_int / output_width + center = input_width / 2 + offset = center * (1 - adjustment) + x_ori = offset + (x + 0.5) / scale_factor - 0.5 + else: + raise ValueError( + f"Invalid coordinate_transformation_mode: {coordinate_transformation_mode!r}." + ) + + x_ori_int = np.floor(x_ori).astype(int).item() + + if x_ori.is_integer(): + ratio = 1 + else: + ratio = x_ori - x_ori_int + + coeffs = get_coeffs(ratio, scale_factor) + n = len(coeffs) + + idxes, points = _get_neighbor(x_ori, n, data) + + if exclude_outside: + for i, idx in enumerate(idxes): + if idx < 0 or idx >= input_width: + coeffs[i] = 0 + coeffs /= sum(coeffs) + + return np.dot(coeffs, points).item() + + +def _interpolate_nd_with_x( + data: np.ndarray, + n: int, + scale_factors: list[float], + output_size: list[int], + x: list[float], + get_coeffs: Callable[[float, float], np.ndarray], + roi: np.ndarray | None = None, + exclude_outside: bool = False, + **kwargs: Any, +) -> np.ndarray: + + if n == 1: + return _interpolate_1d_with_x( + data, + scale_factors[0], + output_size[0], + x[0], + get_coeffs, + roi=roi, + exclude_outside=exclude_outside, + **kwargs, + ) + res1d = [] + + for i in range(data.shape[0]): + r = _interpolate_nd_with_x( + data[i], + n - 1, + scale_factors[1:], + output_size[1:], + x[1:], + get_coeffs, + roi=None if roi is None else np.concatenate([roi[1:n], roi[n + 1 :]]), + exclude_outside=exclude_outside, + **kwargs, + ) + res1d.append(r) + + + return _interpolate_1d_with_x( + res1d, + scale_factors[0], + output_size[0], + x[0], + get_coeffs, + roi=None if roi is None else [roi[0], roi[n]], + exclude_outside=exclude_outside, + **kwargs, + ) + + +def _get_all_coords(data: np.ndarray) -> np.ndarray: + return _cartesian( + [list(range(data.shape[i])) for i in range(len(data.shape))] + ) + + +def interpolate_nd( + data: np.ndarray, + get_coeffs: Callable[[float, float], np.ndarray], + output_size: list[int] | None = None, + scale_factors: list[float] | None = None, + axes: list[int] | None = None, + roi: np.ndarray | None = None, + keep_aspect_ratio_policy: str | None = "stretch", + exclude_outside: bool = False, + **kwargs: Any, +) -> np.ndarray: + if output_size is None and scale_factors is None: + raise ValueError("output_size is None and scale_factors is None.") + + r = len(data.shape) + if axes is not None: + if scale_factors is not None: + new_scale_factors = [1.0] * r + for i, d in enumerate(axes): + new_scale_factors[d] = scale_factors[i] + scale_factors = new_scale_factors + + if output_size is not None: + new_output_size = [data.shape[i] for i in range(r)] + for i, d in enumerate(axes): + new_output_size[d] = output_size[i] + output_size = new_output_size + + + if roi is not None: + new_roi = ([0.0] * r) + ([1.0] * r) + naxes = len(axes) + for i, d in enumerate(axes): + new_roi[d] = roi[i] + new_roi[r + d] = roi[naxes + i] + roi = new_roi + else: + axes = list(range(r)) + + if output_size is not None: + scale_factors = [output_size[i] / data.shape[i] for i in range(r)] + if keep_aspect_ratio_policy != "stretch": + if keep_aspect_ratio_policy == "not_larger": + scale = np.array(scale_factors)[axes].min() + elif keep_aspect_ratio_policy == "not_smaller": + scale = np.array(scale_factors)[axes].max() + else: + raise ValueError( + f"Invalid keep_aspect_ratio_policy={keep_aspect_ratio_policy!r}" + ) + + scale_factors = [scale if i in axes else 1.0 for i in range(r)] + + def round_half_up(x: float) -> int: + return int(x + 0.5) + + output_size = [ + round_half_up(scale * data.shape[i]) if i in axes else data.shape[i] + for i in range(r) + ] + + else: + output_size = (scale_factors * np.array(data.shape)).astype(int) + if scale_factors is None: + raise ValueError("scale_factors is None.") + if output_size is None: + raise ValueError("output_size is None.") + + ret = np.zeros(output_size) + for x in _get_all_coords(ret): + ret[tuple(x)] = _interpolate_nd_with_x( + data, + len(data.shape), + scale_factors, + output_size, + x, + get_coeffs, + roi=roi, + exclude_outside=exclude_outside, + **kwargs, + ) + return ret + + +def cubic_coeffs( + ratio: float, scale: float | None = None, A: float = -0.75 +) -> np.ndarray: + del scale # Unused + coeffs = [ + ((A * (ratio + 1) - 5 * A) * (ratio + 1) + 8 * A) * (ratio + 1) - 4 * A, + ((A + 2) * ratio - (A + 3)) * ratio * ratio + 1, + ((A + 2) * (1 - ratio) - (A + 3)) * (1 - ratio) * (1 - ratio) + 1, + ((A * ((1 - ratio) + 1) - 5 * A) * ((1 - ratio) + 1) + 8 * A) + * ((1 - ratio) + 1) + - 4 * A, + ] + return np.array(coeffs) + + + + +class Resize(RunAll): + + @staticmethod + def resize_upsample_scales_nearest() -> None: + + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 3.0], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_nearest" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_scales_nearest() -> None: + + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_nearest" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_upsample_sizes_nearest() -> None: + + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 7, 8], dtype=np.int64) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), output_size=sizes + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_sizes_nearest" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_sizes_nearest() -> None: + + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 1, 3], dtype=np.int64) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), output_size=sizes + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_sizes_nearest" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_upsample_scales_linear() -> None: + + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: linear_coeffs(x, None), scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_linear" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_upsample_scales_linear_align_corners() -> None: + + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x, None), + scale_factors=scales, + coordinate_transformation_mode="align_corners", + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_linear_align_corners" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_scales_linear() -> None: + + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: linear_coeffs(x, None), scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_linear" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_downsample_scales_linear_align_corners() -> None: + + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x, None), + scale_factors=scales, + coordinate_transformation_mode="align_corners", + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_linear_align_corners" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_upsample_scales_cubic() -> None: + + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: cubic_coeffs(x, None), scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_cubic" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_upsample_scales_cubic_align_corners() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + + output = interpolate_nd( + data, + lambda x, _: cubic_coeffs(x), + scale_factors=scales, + coordinate_transformation_mode="align_corners", + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_cubic_align_corners" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_downsample_scales_cubic() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.8, 0.8], dtype=np.float32) + output = interpolate_nd( + data, lambda x, _: cubic_coeffs(x), scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_cubic" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_scales_cubic_align_corners() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.8, 0.8], dtype=np.float32) + + output = interpolate_nd( + data, + lambda x, _: cubic_coeffs(x), + scale_factors=scales, + coordinate_transformation_mode="align_corners", + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_cubic_align_corners" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_upsample_sizes_cubic() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 9, 10], dtype=np.int64) + output = interpolate_nd( + data, lambda x, _: cubic_coeffs(x), output_size=sizes + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_sizes_cubic" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_sizes_cubic() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 3, 3], dtype=np.int64) + + output = interpolate_nd( + data, lambda x, _: cubic_coeffs(x), output_size=sizes + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_sizes_cubic" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + + + + @staticmethod + def resize_upsample_scales_cubic_A_n0p5_exclude_outside() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + output = interpolate_nd( + data, + lambda x, _: cubic_coeffs(x, A=-0.5), + scale_factors=scales, + exclude_outside=True, + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_cubic_A_n0p5_exclude_outside" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(FixedTrait::::new(32768, true))," + func_sig += "Option::Some(true)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_scales_cubic_A_n0p5_exclude_outside() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.8, 0.8], dtype=np.float32) + output = interpolate_nd( + data, + lambda x, _: cubic_coeffs(x, A=-0.5), + scale_factors=scales, + exclude_outside=True, + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_cubic_A_n0p5_exclude_outside" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(FixedTrait::::new(32768, true))," + func_sig += "Option::Some(true)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + + @staticmethod + def resize_upsample_scales_cubic_asymmetric() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32) + + output = interpolate_nd( + data, + lambda x, _: cubic_coeffs(x, A=-0.75), + scale_factors=scales, + coordinate_transformation_mode="asymmetric", + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_cubic_asymmetric" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ASYMMETRIC)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + + @staticmethod + def resize_tf_crop_and_resize() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + roi = np.array([0, 0, 0.4, 0.6, 1, 1, 0.6, 0.8], dtype=np.float32) + sizes = np.array([1, 1, 3, 3], dtype=np.int64) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + output_size=sizes, + roi=roi, + coordinate_transformation_mode="tf_crop_and_resize", + ).astype(np.float32) + x = [data, sizes, roi] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.FP16x16, x[2].shape, to_fp(x[2].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_tf_crop_and_resize" + func_sig = "data.resize(" + func_sig += "roi," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + @staticmethod + def resize_tf_crop_and_resize_extrapolation_value() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + roi = np.array([0, 0, 0.4, 0.6, 1, 1, 1.2, 1.7], dtype=np.float32) + sizes = np.array([1, 1, 3, 3], dtype=np.int64) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + output_size=sizes, + roi=roi, + coordinate_transformation_mode="tf_crop_and_resize", + extrapolation_value=10.0, + ).astype(np.float32) + + x = [data, sizes, roi] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.FP16x16, x[2].shape, to_fp(x[2].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_tf_crop_and_resize_extrapolation_value" + func_sig = "data.resize(" + func_sig += "roi," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(FixedTrait::::new(655360, false))," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + @staticmethod + def resize_downsample_sizes_linear_pytorch_half_pixel() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 3, 1], dtype=np.int64) + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + output_size=sizes, + coordinate_transformation_mode="pytorch_half_pixel", + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_downsample_sizes_linear_pytorch_half_pixel" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::PYTORCH_HALF_PIXEL)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + + @staticmethod + def resize_upsample_sizes_nearest_floor_align_corners() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 8, 8], dtype=np.int64) + output = interpolate_nd( + data, + lambda x, _: nearest_coeffs(x, mode="floor"), + output_size=sizes, + coordinate_transformation_mode="align_corners", + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_upsample_sizes_nearest_floor_align_corners" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::Some(NEAREST_MODE::FLOOR),)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 8, 8], dtype=np.int64) + + output = interpolate_nd( + data, + lambda x, _: nearest_coeffs(x, mode="round_prefer_ceil"), + output_size=sizes, + coordinate_transformation_mode="asymmetric", + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::ASYMMETRIC)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::Some(NEAREST_MODE::ROUND_PREFER_CEIL),)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_upsample_sizes_nearest_ceil_half_pixel() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 8, 8], dtype=np.int64) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x, mode="ceil"), output_size=sizes + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_upsample_sizes_nearest_ceil_half_pixel" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::HALF_PIXEL)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::Some(NEAREST_MODE::CEIL),)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_downsample_scales_linear_antialias() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + + output = interpolate_nd( + data, linear_coeffs_antialias, scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_linear_antialias" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::Some(1)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_downsample_sizes_linear_antialias() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 3, 3], dtype=np.int64) + + output = interpolate_nd( + data, linear_coeffs_antialias, output_size=sizes + ).astype(np.float32) + + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_downsample_sizes_linear_pytorch_half_pixel" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::Some(1)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + @staticmethod + def resize_downsample_scales_cubic_antialias() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([1.0, 1.0, 0.6, 0.6], dtype=np.float32) + + output = interpolate_nd( + data, cubic_coeffs_antialias, scale_factors=scales + ).astype(np.float32) + + x = [data, scales] + y = output + + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_cubic_antialias" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::Some(1)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_downsample_sizes_cubic_antialias() -> None: + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 1, 3, 3], dtype=np.int64) + + output = interpolate_nd(data, cubic_coeffs_antialias, output_size=sizes).astype( + np.float32 + ) + x = [data, sizes] + y = output + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + name = "resize_downsample_sizes_cubic_antialias" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::Some(1)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::CUBIC)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + @staticmethod + def resize_upsample_scales_nearest_axes_2_3() -> None: + axes = np.array([2, 3], dtype=np.int64) + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + scales = np.array([2.0, 3.0], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), scale_factors=scales, axes=axes + ).astype(np.float32) + + x = [data, scales, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.FP16x16, x[1].shape, to_fp(x[1].flatten(), FixedImpl.FP16x16)) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_nearest_axes_2_3" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + @staticmethod + def resize_upsample_scales_nearest_axes_3_2() -> None: + + axes = np.array([3, 2], dtype=np.int64) + data = np.array([[[[1, 2],[3, 4],]]],dtype=np.float32,) + + scales = np.array([3.0, 2.0], dtype=np.float32) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), scale_factors=scales, axes=axes + ).astype(np.float32) + x = [data, scales, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.FP16x16, x[1].shape, to_fp(x[1].flatten(), FixedImpl.FP16x16)) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_nearest_axes_3_2" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + @staticmethod + def resize_upsample_sizes_nearest_axes_2_3() -> None: + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([7, 8], dtype=np.int64) + axes = np.array([2, 3], dtype=np.int64) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), output_size=sizes, axes=axes + ).astype(np.float32) + + x = [data, sizes, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_sizes_nearest_axes_2_3" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + @staticmethod + def resize_upsample_sizes_nearest_axes_3_2() -> None: + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([8, 7], dtype=np.int64) + axes = np.array([3, 2], dtype=np.int64) + + output = interpolate_nd( + data, lambda x, _: nearest_coeffs(x), output_size=sizes, axes=axes + ).astype(np.float32) + + x = [data, sizes, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_sizes_nearest_axes_3_2" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + + @staticmethod + def resize_tf_crop_and_resize_axes_2_3() -> None: + axes = np.array([2, 3], dtype=np.int64) + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + roi = np.array([0.4, 0.6, 0.6, 0.8], dtype=np.float32) + sizes = np.array([3, 3], dtype=np.int64) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + output_size=sizes, + roi=roi, + axes=axes, + coordinate_transformation_mode="tf_crop_and_resize", + ).astype(np.float32) + + x = [data, sizes, roi, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.FP16x16, x[2].shape, to_fp(x[2].flatten(), FixedImpl.FP16x16)) + x[3] = Tensor(Dtype.U32, x[3].shape, x[3].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_tf_crop_and_resize_axes_2_3" + func_sig = "data.resize(" + func_sig += "roi," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2], x[3]], y, func_sig, name) + + @staticmethod + def resize_tf_crop_and_resize_axes_3_2() -> None: + axes = np.array([3, 2], dtype=np.int64) + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12], + [13, 14, 15, 16], + ] + ] + ], + dtype=np.float32, + ) + + roi = np.array([0.6, 0.4, 0.8, 0.6], dtype=np.float32) + sizes = np.array([3, 3], dtype=np.int64) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + output_size=sizes, + roi=roi, + axes=axes, + coordinate_transformation_mode="tf_crop_and_resize", + ).astype(np.float32) + + x = [data, sizes, roi, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.FP16x16, x[2].shape, to_fp(x[2].flatten(), FixedImpl.FP16x16)) + x[3] = Tensor(Dtype.U32, x[3].shape, x[3].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_tf_crop_and_resize_axes_3_2" + func_sig = "data.resize(" + func_sig += "roi," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2], x[3]], y, func_sig, name) + + + + @staticmethod + def resize_upsample_sizes_nearest_not_larger() -> None: + keep_aspect_ratio_policy = "not_larger" + axes = np.array([2, 3], dtype=np.int64) + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([7, 8], dtype=np.int64) + output = interpolate_nd( + data, + lambda x, _: nearest_coeffs(x), + output_size=sizes, + axes=axes, + keep_aspect_ratio_policy=keep_aspect_ratio_policy, + ).astype(np.float32) + + x = [data, sizes, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_sizes_nearest_not_larger" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_LARGER)," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + + + @staticmethod + def resize_upsample_sizes_nearest_not_smaller() -> None: + keep_aspect_ratio_policy = "not_smaller" + axes = np.array([2, 3], dtype=np.int64) + data = np.array( + [ + [ + [ + [1, 2], + [3, 4], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([7, 8], dtype=np.int64) # Results in 8x8 + + output = interpolate_nd( + data, + lambda x, _: nearest_coeffs(x), + output_size=sizes, + axes=axes, + keep_aspect_ratio_policy=keep_aspect_ratio_policy, + ).astype(np.float32) + + x = [data, sizes, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_sizes_nearest_not_smaller" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER)," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + + + + @staticmethod + def resize_downsample_sizes_nearest_not_larger() -> None: + keep_aspect_ratio_policy = "not_larger" + axes = np.array([2, 3], dtype=np.int64) + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 3], dtype=np.int64) + + output = interpolate_nd( + data, + lambda x, _: nearest_coeffs(x), + output_size=sizes, + axes=axes, + keep_aspect_ratio_policy=keep_aspect_ratio_policy, + ).astype(np.float32) + + x = [data, sizes, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_sizes_nearest_not_larger" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_LARGER)," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + + + @staticmethod + def resize_downsample_sizes_nearest_not_smaller() -> None: + keep_aspect_ratio_policy = "not_smaller" + axes = np.array([2, 3], dtype=np.int64) + + data = np.array( + [ + [ + [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + ] + ], + dtype=np.float32, + ) + + sizes = np.array([1, 3], dtype=np.int64) + output = interpolate_nd( + data, + lambda x, _: nearest_coeffs(x), + output_size=sizes, + axes=axes, + keep_aspect_ratio_policy=keep_aspect_ratio_policy, + ).astype(np.float32) + + x = [data, sizes, axes] + y = output + + x[0] = Tensor(Dtype.FP16x16, x[0].shape, to_fp(x[0].flatten(), FixedImpl.FP16x16)) + x[1] = Tensor(Dtype.U32, x[1].shape, x[1].flatten()) + x[2] = Tensor(Dtype.U32, x[2].shape, x[2].flatten()) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_sizes_nearest_not_smaller" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "sizes," + func_sig += "Option::None," + func_sig += "axes," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER)," + func_sig += "Option::Some(MODE::NEAREST)," + func_sig += "Option::None,)" + make_test([x[0], x[1], x[2]], y, func_sig, name) + + + + + @staticmethod + def resize_downsample_scales_linear_half_pixel_symmetric() -> None: + data = np.array([[[[1, 2, 3, 4]]]], dtype=np.float32) + scales = np.array([1.0, 1.0, 1.0, 0.6], dtype=np.float32) + + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + scale_factors=scales, + coordinate_transformation_mode="half_pixel_symmetric", + ).astype(np.float32) + + x = [data, scales] + y = output + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_downsample_scales_linear_half_pixel_symmetric" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + + @staticmethod + def resize_upsample_scales_linear_half_pixel_symmetric() -> None: + data = np.array([[[[1, 2], [3, 4]]]], dtype=np.float32) + scales = np.array([1.0, 1.0, 2.3, 2.94], dtype=np.float32) + + output = interpolate_nd( + data, + lambda x, _: linear_coeffs(x), + scale_factors=scales, + coordinate_transformation_mode="half_pixel_symmetric", + ).astype(np.float32) + + x = [data, scales] + y = output + for i in range(len(x)): + x[i] = Tensor(Dtype.FP16x16, x[i].shape, to_fp(x[i].flatten(), FixedImpl.FP16x16)) + + y = Tensor(Dtype.FP16x16, y.shape, to_fp(y.flatten(), FixedImpl.FP16x16)) + + name = "resize_upsample_scales_linear_half_pixel_symmetric" + func_sig = "data.resize(" + func_sig += "Option::None," + func_sig += "scales," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC)," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::None," + func_sig += "Option::Some(MODE::LINEAR)," + func_sig += "Option::None,)" + make_test([x[0], x[1]], y, func_sig, name) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 4bfd10060..93a490bbe 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -11,4 +11,6 @@ use orion::operators::ml::tree_ensemble::tree_ensemble_classifier::{ use orion::operators::ml::tree_ensemble::tree_ensemble_regressor::{ TreeEnsembleRegressor, TreeEnsembleRegressorImpl, TreeEnsembleRegressorTrait, AGGREGATE_FUNCTION }; -use orion::operators::ml::linear::linear_regressor::{LinearRegressorTrait, LinearRegressorImpl, LinearRegressor}; +use orion::operators::ml::linear::linear_regressor::{ + LinearRegressorTrait, LinearRegressorImpl, LinearRegressor +}; diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index decc2e343..4eb0ea865 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -3,6 +3,7 @@ use core::serde::Serde; use core::option::OptionTrait; use alexandria_data_structures::array_ext::{SpanTraitExt}; +//::resize::{MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE}; use orion::operators::tensor::helpers::{len_from_shape, check_shape}; use orion::numbers::{i8, i32, NumberTrait}; @@ -94,6 +95,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde { /// ``` /// fn where(self: @Tensor, x: @Tensor, y: @Tensor) -> Tensor; + /// #tensor.resize + /// + /// ```rust + /// fn resize( + /// self: @Tensor, + /// roi: Option>, + /// scales: Option>, + /// sizes: Option>, + /// antialias: Option, + /// axes: Option>, + /// coordinate_transformation_mode: Option, + /// cubic_coeff_a: Option, + /// exclude_outside: Option, + /// extrapolation_value: Option, + /// keep_aspect_ratio_policy: Option, + /// mode: Option, + /// nearest_mode: Option, + /// ) -> Tensor; + /// ``` + /// + /// Resizes the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood in the input tensor. + /// + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// * `roi` (`Option>`) (optional) - 1-D tensor given as [start1, ..., startN, end1, ..., endN], where N is the rank of X or the length of axes, if provided. It only takes effect when coordinate_transformation_mode is "tf_crop_and_resize" + /// * `scales` (`Option>`) (optional) - The scale array along each dimension. It takes value greater than 0. If it's less than 1, it's sampling down, otherwise, it's upsampling. The number of elements of 'scales' should be the same as the rank of input 'X' or the length of 'axes', if provided. One and only one of 'scales' and 'sizes' MUST be specified. + /// * `sizes` (`Option>`) (optional) - Target size of the output tensor. Its interpretation depends on the 'keep_aspect_ratio_policy' value. The number of elements of 'sizes' should be the same as the rank of input 'X', or the length of 'axes', if provided. One and only one of 'scales' and 'sizes' MUST be specified. + /// * `antialias` (`Option`) (default is 0) - If set to 1, "linear" and "cubic" interpolation modes will use an antialiasing filter when downscaling. Antialiasing is achieved by stretching the resampling filter by a factor max(1, 1 / scale). + /// * `axes`(`Option>`) - If provided, it specifies a subset of axes that 'roi', 'scales' and 'sizes' refer to. If not provided, all axes are assumed [0, 1, ..., r-1], where r = rank(data). + /// * `coordinate_transformation_mode` (`Option`) (default is half_pixel) - This attribute describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. + /// * `cubic_coeff_a` (`Option`) (default is -0.75) - The coefficient 'a' used in cubic interpolation. + /// * `exclude_outside` (`Option`) (default is false) - If set to true, the weight of sampling locations outside the tensor will be set to 0 and the weight will be renormalized so that their sum is 1.0. + /// * `extrapolation_value` (`Option`) (default is 0.0) - When coordinate_transformation_mode is "tf_crop_and_resize" and x_original is outside the range [0, length_original - 1], this value is used as the corresponding output value. + /// * `keep_aspect_ratio_policy` (`Option`) (default is stretch) - This attribute describes how to interpret the `sizes` input with regard to keeping the original aspect ratio of the input, and it is not applicable when the `scales` input is used. + /// * `mode` (`Option`) (default is nearest) - Three interpolation modes: "nearest", "linear" and "cubic". + /// * `nearest_mode` (`Option`) (default is round_prefer_floor) - Four modes: "round_prefer_floor" (as known as round half down), "round_prefer_ceil" (as known as round half up), "floor", "ceil". Only used by nearest interpolation. + /// + /// ## Panics + /// + /// * Panics if both scales and sizes are `Option::None`. + /// * Panics if roi is `Option::None` for the coordinate_transformation_mode `tf_crop_and_resize`. + /// * Panics if antialias is not `Option::None` for mode `nearest`. + /// + /// ## Returns + /// + /// A new resized `Tensor` of the dimension given by output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) is scale is specified, or output_size if size is specified (note that some value of the parameter `keep_aspect_ratio_policy` can change sizes and therefore the dimension of the output tensor) + /// + /// ## Example + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// use orion::operators::tensor::{TensorTrait, Tensor, FP16x16Tensor, FP16x16TensorPartialEq}; + /// use orion::operators::tensor::math::resize::{ + /// MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE + /// }; + /// use orion::numbers::{FP16x16, FP16x16Impl, FixedTrait}; + /// use core::debug::PrintTrait; + /// + /// fn example_resize_downsample_scales_linear() -> Tensor{ + /// let mut data = TensorTrait::< + /// FP16x16 + /// >::new( + /// shape: array![1, 1, 2, 4].span(), + /// data: array![ + /// FixedTrait::::new(65536, false), //1 + /// FixedTrait::::new(131072, false), //2 + /// FixedTrait::::new(196608, false), //3 + /// FixedTrait::::new(262144, false), //4 + /// FixedTrait::::new(327680, false), //5 + /// FixedTrait::::new(393216, false), //6 + /// FixedTrait::::new(458752, false), //7 + /// FixedTrait::::new(524288, false), //8 + /// ] + /// .span(), + /// ); + /// let mut scales = array![ + /// FixedTrait::::new(65536, false), //1 + /// FixedTrait::::new(65536, false), + /// FixedTrait::::new(39322, false), //0.6 + /// FixedTrait::::new(39322, false) + /// ] + /// .span(); + /// + /// let scales = Option::Some(scales); + /// + /// return data.resize( + /// Option::None, + /// scales, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::Some(MODE::LINEAR), + /// Option::None, + /// ); + /// + /// } + /// >>> [[[[2.6666665 4.3333331]]]] + /// + /// + /// + /// fn example_resize_tf_crop_and_resize_extrapolation_value() -> Tensor { + /// let mut data = TensorTrait::< + /// FP16x16 + /// >::new( + /// shape: array![1, 1, 4, 4].span(), + /// data: array![ + /// FixedTrait::::new(65536, false), + /// FixedTrait::::new(131072, false), + /// FixedTrait::::new(196608, false), + /// FixedTrait::::new(262144, false), + /// FixedTrait::::new(327680, false), + /// FixedTrait::::new(393216, false), + /// FixedTrait::::new(458752, false), + /// FixedTrait::::new(524288, false), + /// FixedTrait::::new(589824, false), + /// FixedTrait::::new(655360, false), + /// FixedTrait::::new(720896, false), + /// FixedTrait::::new(786432, false), + /// FixedTrait::::new(851968, false), + /// FixedTrait::::new(917504, false), + /// FixedTrait::::new(983040, false), + /// FixedTrait::::new(1048576, false), + /// ] + /// .span(), + /// ); + /// + /// let mut roi = TensorTrait::< + /// FP16x16 + /// >::new( + /// shape: array![8].span(), + /// data: array![ + /// FixedTrait::::new(0, false), + /// FixedTrait::::new(0, false), + /// FixedTrait::::new(26214, false), + /// FixedTrait::::new(39322, false), + /// FixedTrait::::new(65536, false), + /// FixedTrait::::new(65536, false), + /// FixedTrait::::new(78643, false), + /// FixedTrait::::new(111411, false), + /// ] + /// .span(), + /// ); + /// let roi = Option::Some(roi); + /// + /// let mut sizes = array![1, 1, 3, 3].span(); + /// let sizes = Option::Some(sizes); + /// + /// let extrapolation_value = Option::Some(FixedTrait::::new(655360, false)); + /// + /// return data.resize( + /// roi, + /// Option::None, + /// sizes, + /// Option::None, + /// Option::None, + /// Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE), + /// Option::None, + /// Option::None, + /// extrapolation_value, + /// Option::None, + /// Option::Some(MODE::LINEAR), + /// Option::None, + /// ); + /// + /// } + /// >>> [[[[ 7.6000004 10. 10. ] + /// [12.400001 10. 10. ] + /// [10. 10. 10. ]]]] + /// + /// + /// + /// fn example_resize_downsample_sizes_cubic_antialias() -> Tensor { + /// let mut data = TensorTrait::< + /// FP16x16 + /// >::new( + /// shape: array![1, 1, 4, 4].span(), + /// data: array![ + /// FixedTrait::::new(65536, false), + /// FixedTrait::::new(131072, false), + /// FixedTrait::::new(196608, false), + /// FixedTrait::::new(262144, false), + /// FixedTrait::::new(327680, false), + /// FixedTrait::::new(393216, false), + /// FixedTrait::::new(458752, false), + /// FixedTrait::::new(524288, false), + /// FixedTrait::::new(589824, false), + /// FixedTrait::::new(655360, false), + /// FixedTrait::::new(720896, false), + /// FixedTrait::::new(786432, false), + /// FixedTrait::::new(851968, false), + /// FixedTrait::::new(917504, false), + /// FixedTrait::::new(983040, false), + /// FixedTrait::::new(1048576, false), + /// ] + /// .span(), + /// ); + /// + /// let antialias = Option::Some(1); + /// + /// let mut sizes = array![1, 1, 3, 3].span(); + /// let sizes = Option::Some(sizes); + /// + /// return data.resize( + /// Option::None, + /// Option::None, + /// sizes, + /// antialias, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::None, + /// Option::Some(MODE::CUBIC), + /// Option::None, + /// ); + /// } + /// + /// >>> [[[[ 1.7750092 3.1200073 4.4650054] + /// [ 7.1550016 8.5 9.844998 ] + /// [12.534994 13.8799925 15.224991 ]]]] + /// + /// ``` + /// + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option< + orion::operators::tensor::math::resize::TRANSFORMATION_MODE + >, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option< + orion::operators::tensor::math::resize::KEEP_ASPECT_RATIO_POLICY + >, + mode: Option, + nearest_mode: Option, + ) -> Tensor; /// #tensor.round /// /// ```rust diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d2afe3fc5..589690203 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -472,9 +472,29 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..a14c2704b 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -520,10 +520,12 @@ impl FP16x16Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } @@ -537,6 +539,38 @@ impl FP16x16Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + math::resize::resize( + self, + roi, + scales, + sizes, + antialias, + axes, + coordinate_transformation_mode, + cubic_coeff_a, + exclude_outside, + extrapolation_value, + keep_aspect_ratio_policy, + mode, + nearest_mode + ) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..566da24b2 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -486,10 +486,12 @@ impl FP16x16WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } @@ -503,6 +505,24 @@ impl FP16x16WTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..3bb86aaae 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -521,10 +521,12 @@ impl FP32x32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } @@ -538,6 +540,38 @@ impl FP32x32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + math::resize::resize( + self, + roi, + scales, + sizes, + antialias, + axes, + coordinate_transformation_mode, + cubic_coeff_a, + exclude_outside, + extrapolation_value, + keep_aspect_ratio_policy, + mode, + nearest_mode + ) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..d52922aab 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -522,10 +522,12 @@ impl FP64x64Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } @@ -539,6 +541,38 @@ impl FP64x64Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + math::resize::resize( + self, + roi, + scales, + sizes, + antialias, + axes, + coordinate_transformation_mode, + cubic_coeff_a, + exclude_outside, + extrapolation_value, + keep_aspect_ratio_policy, + mode, + nearest_mode + ) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..e83ed3f4f 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -520,10 +520,12 @@ impl FP8x23Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } @@ -537,6 +539,38 @@ impl FP8x23Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + math::resize::resize( + self, + roi, + scales, + sizes, + antialias, + axes, + coordinate_transformation_mode, + cubic_coeff_a, + exclude_outside, + extrapolation_value, + keep_aspect_ratio_policy, + mode, + nearest_mode + ) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..a71b0b307 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -473,10 +473,12 @@ impl FP8x23WTensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } @@ -490,6 +492,24 @@ impl FP8x23WTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..a64622c6c 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -517,10 +517,12 @@ impl I32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } @@ -534,6 +536,24 @@ impl I32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..adffef420 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -515,10 +515,12 @@ impl I8Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } @@ -532,6 +534,24 @@ impl I8Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..feb8b501e 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -458,10 +458,12 @@ impl U32Tensor of TensorTrait { math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } - + fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { panic(array!['not supported!']) } @@ -475,6 +477,24 @@ impl U32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index e9822a21f..ad40bd1b2 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -65,3 +65,4 @@ mod is_inf; mod gather_nd; mod reduce_log_sum; mod erf; +mod resize; diff --git a/src/operators/tensor/math/gather_nd.cairo b/src/operators/tensor/math/gather_nd.cairo index 120eff5be..737a4fe32 100644 --- a/src/operators/tensor/math/gather_nd.cairo +++ b/src/operators/tensor/math/gather_nd.cairo @@ -14,12 +14,7 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; /// Cf: TensorTrait::gather_nd docstring -fn gather_nd< - T, - impl TTensorTrait: TensorTrait, - impl TCopy: Copy, - impl TDrop: Drop, ->( +fn gather_nd, impl TCopy: Copy, impl TDrop: Drop,>( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { let batch_dims = match batch_dims { @@ -29,19 +24,22 @@ fn gather_nd< let data_rank = (*self.shape).len(); let indices_rank = (indices.shape).len(); - assert((data_rank >= 1 ) & (indices_rank >= 1), 'rank must > 1'); - + assert((data_rank >= 1) & (indices_rank >= 1), 'rank must > 1'); + let mut data_shape = *self.shape; let mut indices_shape = indices.shape; let mut data_shape_clone = data_shape.clone(); let mut indices_shape_clone = indices_shape.clone(); let indices_shape_last = indices_shape_clone.pop_back().unwrap(); - assert((*indices_shape_last >= 1) & (*indices_shape_last <= data_rank-batch_dims), 'check indices'); + assert( + (*indices_shape_last >= 1) & (*indices_shape_last <= data_rank - batch_dims), + 'check indices' + ); let mut batch_dims_shape = ArrayTrait::new(); let mut output_shape = ArrayTrait::new(); - let mut index_data = ArrayTrait::new(); + let mut index_data = ArrayTrait::new(); let mut output_data = ArrayTrait::new(); let mut batch_dims_size = batch_dims; @@ -51,7 +49,7 @@ fn gather_nd< let mut ind = 0; loop { if (ind == batch_dims) { - break(); + break (); } match indices_shape_clone.pop_front() { Option::Some(val) => { @@ -65,17 +63,14 @@ fn gather_nd< loop { match indices_shape_clone.pop_front() { - Option::Some(val) => { - batch_dims_shape.append(*val); - }, + Option::Some(val) => { batch_dims_shape.append(*val); }, Option::None(_) => { break; } }; }; if (*indices_shape_last == data_rank - batch_dims) { output_shape = batch_dims_shape; - } - else { + } else { let mut ind = 0; let mut multiple = 1; output_shape = batch_dims_shape; @@ -136,16 +131,18 @@ fn gather_nd< match data_indices.pop_front() { Option::Some(val) => { let index = ind % *indices_shape_last; - let incr= total_data_len * (ind / breaker); + let incr = total_data_len * (ind / breaker); result += (*val * total_data_len / *multiple_data_len.at(index)); ind += 1; - if (index == *indices_shape_last-1) { - let mut data_ind:usize = result ; + if (index == *indices_shape_last - 1) { + let mut data_ind: usize = result; loop { - if data_ind == result + incrementer { break; } + if data_ind == result + incrementer { + break; + } index_data.append(data_ind + incr); - data_ind+=1; + data_ind += 1; }; result = 0; }; @@ -156,13 +153,11 @@ fn gather_nd< loop { match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, + Option::Some(val) => { output_data.append(*self.data[val]); }, Option::None(_) => { break; } }; }; let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} \ No newline at end of file +} diff --git a/src/operators/tensor/math/resize.cairo b/src/operators/tensor/math/resize.cairo new file mode 100644 index 000000000..ecae7b186 --- /dev/null +++ b/src/operators/tensor/math/resize.cairo @@ -0,0 +1,1470 @@ +use core::traits::TryInto; +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::option::OptionTrait; +use core::traits::Into; +use orion::numbers::NumberTrait; +use alexandria_sorting::bubble_sort; +use orion::operators::tensor::{ + TensorTrait, Tensor, I8Tensor, I32Tensor, U32Tensor, FP16x16Tensor, BoolTensor +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; +use core::debug::PrintTrait; + +#[derive(Copy, Drop)] +enum MODE { + NEAREST, + LINEAR, + CUBIC, +} + +#[derive(Copy, Drop)] +enum NEAREST_MODE { + ROUND_PREFER_FLOOR, + ROUND_PREFER_CEIL, + FLOOR, + CEIL +} + +#[derive(Copy, Drop)] +enum KEEP_ASPECT_RATIO_POLICY { + STRETCH, + NOT_LARGER, + NOT_SMALLER +} + +#[derive(Copy, Drop)] +enum TRANSFORMATION_MODE { + HALF_PIXEL, + ALIGN_CORNERS, + ASYMMETRIC, + TF_CROP_AND_RESIZE, + PYTORCH_HALF_PIXEL, + HALF_PIXEL_SYMMETRIC +} + + +/// Cf: TensorTrait::resize docstring +fn resize< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, +) -> Tensor { + let output = interpolate_nd( + self, + antialias, + mode, + nearest_mode, + scales, + sizes, + roi, + keep_aspect_ratio_policy, + exclude_outside, + coordinate_transformation_mode, + extrapolation_value, + axes, + cubic_coeff_a + ); + return output; +} + +fn interpolate_nd< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + data: @Tensor, + antialias: Option, + mode: Option, + nearest_mode: Option, + scale_factors: Option>, + output_size: Option>, + roi: Option>, + keep_aspect_ratio_policy: Option, + exclude_outside: Option, + coordinate_transformation_mode: Option, + extrapolation_value: Option, + axes: Option>, + cubic_coeff_a: Option, +) -> Tensor { + let mode = match mode { + Option::Some(mode) => mode, + Option::None => { MODE::NEAREST }, + }; + + let keep_aspect_ratio_policy = match keep_aspect_ratio_policy { + Option::Some(keep_aspect_ratio_policy) => keep_aspect_ratio_policy, + Option::None => { KEEP_ASPECT_RATIO_POLICY::STRETCH }, + }; + + let exclude_outside = match exclude_outside { + Option::Some(exclude_outside) => exclude_outside, + Option::None => { false }, + }; + + let extrapolation_value = match extrapolation_value { + Option::Some(extrapolation_value) => extrapolation_value, + Option::None => { NumberTrait::zero() }, + }; + + if output_size.is_none() && scale_factors.is_none() { + core::panic_with_felt252('size and scale are None'); + } + + let r = (*data).shape.len(); + + let (axes, scale_factors, output_size, roi) = match axes { + Option::Some(axes) => { + let mut scale_factors = match scale_factors { + Option::Some(scale_factors) => { + let mut new_scale_factors = ArrayTrait::::new(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break NumberTrait::one(); + } + if *axes.at(i) == d { + break *scale_factors.at(i); + } + i += 1; + }; + new_scale_factors.append(item); + d += 1; + }; + + Option::Some(new_scale_factors.span()) + }, + Option::None => { Option::None }, + }; + + let mut output_size = match output_size { + Option::Some(output_size) => { + let mut new_output_size = ArrayTrait::new(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break *(*data).shape.at(d); + } + if *axes.at(i) == d { + break *output_size.at(i); + } + i += 1; + }; + new_output_size.append(item); + d += 1; + }; + Option::Some(new_output_size.span()) + }, + Option::None => { Option::None }, + }; + + let mut roi = match roi { + Option::Some(roi) => { + let mut new_roi_data = ArrayTrait::new(); + let naxes = axes.len(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break NumberTrait::zero(); + } + if *axes.at(i) == d { + break *roi.data.at(i); + } + i += 1; + }; + new_roi_data.append(item); + d += 1; + }; + + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break NumberTrait::one(); + } + if *axes.at(i) == d { + break *roi.data.at(i + naxes); + } + i += 1; + }; + new_roi_data.append(item); + d += 1; + }; + let mut shape = ArrayTrait::new(); + shape.append(r * 2); + Option::Some(TensorTrait::new(shape.span(), new_roi_data.span())) + }, + Option::None => { Option::None }, + }; + (axes, scale_factors, output_size, roi) + }, + Option::None => { + let mut axes = ArrayTrait::new(); + let mut i = 0; + loop { + if i == r { + break; + } + axes.append(i); + i += 1; + }; + (axes.span(), scale_factors, output_size, roi) + } + }; + let (mut output_size, mut scale_factors) = match output_size { + Option::Some(output_size) => { + let mut scale_factors = ArrayTrait::::new(); + let mut i = 0; + loop { + if i == r { + break; + } + + let output_size_i: T = NumberTrait::new_unscaled( + (*output_size.at(i)).into(), false + ); + let data_shape_i: T = NumberTrait::new_unscaled( + (*(*data).shape.at(i)).into(), false + ); + + scale_factors.append(output_size_i / data_shape_i); + i += 1; + }; + + let (mut output_size, mut scale_factors) = match keep_aspect_ratio_policy { + KEEP_ASPECT_RATIO_POLICY::STRETCH => { (output_size, scale_factors.span()) }, + KEEP_ASPECT_RATIO_POLICY::NOT_LARGER => { + let mut scale = *scale_factors.at(*axes.at(0)); + let mut i = 1; + loop { + if i == axes.len() { + break; + } + if scale > *scale_factors.at(*axes.at(i)) { + scale = *scale_factors.at(*axes.at(i)); + } + i += 1; + }; + + let mut scale_factors = ArrayTrait::::new(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break NumberTrait::one(); + } + if *axes.at(i) == d { + break scale; + } + i += 1; + }; + scale_factors.append(item); + d += 1; + }; + + let mut output_size = ArrayTrait::new(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break *(*data).shape.at(d); + } + if *axes.at(i) == d { + break NumberTrait::round( + scale + * NumberTrait::new_unscaled( + (*(*data).shape.at(d)).into(), false + ) + ) + .try_into() + .unwrap(); + } + i += 1; + }; + output_size.append(item); + d += 1; + }; + (output_size.span(), scale_factors.span()) + }, + KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER => { + let mut scale = *scale_factors.at(*axes.at(0)); + let mut i = 1; + loop { + if i == axes.len() { + break; + } + if scale < *scale_factors.at(*axes.at(i)) { + scale = *scale_factors.at(*axes.at(i)); + } + i += 1; + }; + let mut scale_factors = ArrayTrait::::new(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break NumberTrait::one(); + } + if *axes.at(i) == d { + break scale; + } + i += 1; + }; + scale_factors.append(item); + d += 1; + }; + + let mut output_size = ArrayTrait::new(); + let mut d = 0; + loop { + if d == r { + break; + } + let mut i = 0; + let item = loop { + if i == axes.len() { + break *(*data).shape.at(d); + } + if *axes.at(i) == d { + break NumberTrait::round( + scale + * NumberTrait::new_unscaled( + (*(*data).shape.at(d)).into(), false + ) + ) + .try_into() + .unwrap(); + } + i += 1; + }; + output_size.append(item); + d += 1; + }; + (output_size.span(), scale_factors.span()) + }, + }; + + (output_size, scale_factors) + }, + Option::None => { + let mut output_size = ArrayTrait::::new(); + + let scale_factors = match scale_factors { + Option::Some(scale_factors) => scale_factors, + Option::None => { core::panic_with_felt252('size and scale None') }, + }; + + let mut i = 0; + loop { + if i == scale_factors.len() { + break; + } + let scale: usize = (*scale_factors.at(i)).try_into().unwrap(); + + let item = *scale_factors.at(i) + * NumberTrait::new_unscaled((*(*(data).shape).at(i)).into(), false); + output_size.append(item.try_into().unwrap()); + i += 1; + }; + (output_size.span(), scale_factors) + }, + }; + + let mut ret = ArrayTrait::>::new(); + let mut i = 0; + loop { + let mut temp = ArrayTrait::::new(); + if i == output_size.len() { + break; + } + let mut j = 0; + loop { + if j == *output_size.at(i) { + break; + } + temp.append(j); + j += 1; + }; + ret.append(temp.span()); + i += 1; + }; + + let mut ret = cartesian(ret.span()); + let mut ret_data = ArrayTrait::new(); + + loop { + match ret.pop_front() { + Option::Some(X) => { + let mut x = ArrayTrait::::new(); + let mut i = 0; + loop { + if i == X.len() { + break; + } + x.append(NumberTrait::new_unscaled((*X.at(i)).into(), false)); + i += 1; + }; + let mut x = x.span(); + let item = interpolate_nd_with_x( + data, + (*data).shape.len(), + scale_factors, + output_size, + x, + antialias, + mode, + nearest_mode, + roi, + extrapolation_value, + coordinate_transformation_mode, + exclude_outside, + cubic_coeff_a + ); + + ret_data.append(*item.data.at(0)); + }, + Option::None => { break; } + } + }; + + let mut shape = ArrayTrait::new(); + shape.append(ret_data.len()); + return TensorTrait::new(output_size, ret_data.span()); +} + +fn cartesian(mut arrays: Span>,) -> Array> { + let mut n = 1; + let mut i = arrays.len() - 1; + loop { + n = n * (*(arrays.at(i))).len(); + if i == 0 { + break; + } + i -= 1; + }; + + let mut i = 0; + let mut size_arrays = ArrayTrait::new(); + let mut m = n; + loop { + if i == arrays.len() { + break; + } + size_arrays.append((*(arrays.at(i))).len()); + + i += 1; + }; + let size_arrays = size_arrays.span(); + let mut output_arrays = ArrayTrait::>::new(); + let mut m = n; + + let mut i = 0; + loop { + if i == arrays.len() { + break; + } + m = m / (*(arrays.at(i))).len(); + let mut out = repeat(*(arrays.at(i)), m); + out = repeat_2(out, size_arrays, i); + + output_arrays.append(out); + i += 1; + }; + let output_arrays = output_arrays.span(); + + let mut i = 0; + let mut ret = ArrayTrait::>::new(); + loop { + if i == n { + break; + } + let mut j = 0; + let mut x = ArrayTrait::new(); + loop { + if j == arrays.len() { + break; + } + + x.append(*(output_arrays.at(j)).at(i)); + j += 1; + }; + ret.append(x); + i += 1; + }; + + return ret; +} + + +fn repeat_2(mut array: Array, size_array: Span, index: usize) -> Array { + let mut size = array.len(); + let mut i = 0; + loop { + if i == index { + break; + } + let mut j = 1; + loop { + if j == *size_array.at(index - 1 - i) { + break; + } + let mut k = 0; + loop { + if k == size { + break; + } + array.append(*array.at(k)); + k += 1; + }; + j += 1; + }; + size = size * *size_array.at(index - 1 - i); + i += 1; + }; + array +} + +fn repeat(array: Span, m: usize,) -> Array { + let mut out = ArrayTrait::new(); + let mut j = 0; + loop { + if j == array.len() { + break; + } + let mut k = 0; + loop { + if k == m { + break; + } + out.append(*array.at(j)); + k += 1; + }; + j += 1; + }; + + out +} + +fn interpolate_nd_with_x< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + data: @Tensor, + n: usize, + mut scale_factor: Span, + mut output_size: Span, + mut x: Span, + antialias: Option, + mode: MODE, + nearest_mode: Option, + roi: Option>, + extrapolation_value: T, + coordinate_transformation_mode: Option, + exclude_outside: bool, + cubic_coeff_a: Option, +) -> Tensor { + if n == 1 { + return interpolate_1d_with_x( + data, + *scale_factor.at(0), + *output_size.at(0), + *x.at(0), + antialias, + mode, + nearest_mode, + roi, + extrapolation_value, + coordinate_transformation_mode, + exclude_outside, + cubic_coeff_a + ); + } + let mut res1d = ArrayTrait::new(); + + let scale_factor_zero = match scale_factor.pop_front() { + Option::Some(item) => { *item }, + Option::None => core::panic_with_felt252('scale factor empty') + }; + let output_size_zero = match output_size.pop_front() { + Option::Some(item) => { *item }, + Option::None => core::panic_with_felt252('output_size empty') + }; + let x_zero = match x.pop_front() { + Option::Some(item) => { *item }, + Option::None => core::panic_with_felt252('x empty') + }; + + let reduced_roi = match roi { + Option::Some(roi) => { + let mut reduced_roi = ArrayTrait::new(); + let mut reduced_roi_shape = ArrayTrait::new(); + reduced_roi_shape.append(roi.data.len() - 2); + + let mut i = 1; + loop { + if i == 2 * n { + break; + } + if i != n { + reduced_roi.append(*roi.data.at(i)); + } + i += 1; + }; + Option::Some(TensorTrait::new(reduced_roi_shape.span(), reduced_roi.span())) + }, + Option::None => { Option::None } + }; + + let mut i = 0; + loop { + if i == *(*data).shape.at(0) { + break; + } + let data = get_row_n(data, i); + + let mut r = interpolate_nd_with_x( + @data, + n - 1, + scale_factor, + output_size, + x, + antialias, + mode, + nearest_mode, + reduced_roi, + extrapolation_value, + coordinate_transformation_mode, + exclude_outside, + cubic_coeff_a + ); + loop { + match r.data.pop_front() { + Option::Some(item) => { res1d.append(*item); }, + Option::None => { break; } + } + }; + i += 1; + }; + + let mut shape = ArrayTrait::new(); + shape.append(res1d.len()); + + let res1d = TensorTrait::new(shape.span(), res1d.span()); + + let reduced_roi = match roi { + Option::Some(roi) => { + let mut reduced_roi = ArrayTrait::new(); + let mut reduced_roi_shape = ArrayTrait::new(); + + reduced_roi_shape.append(2); + reduced_roi.append(*roi.data.at(0)); + reduced_roi.append(*roi.data.at(n)); + + Option::Some(TensorTrait::new(reduced_roi_shape.span(), reduced_roi.span())) + }, + Option::None => { Option::None } + }; + + let a = interpolate_1d_with_x( + @res1d, + scale_factor_zero, + output_size_zero, + x_zero, + antialias, + mode, + nearest_mode, + reduced_roi, + extrapolation_value, + coordinate_transformation_mode, + exclude_outside, + cubic_coeff_a + ); + + //let mut ret = ArrayTrait::new(); + //let mut shape = ArrayTrait::new(); + //shape.append(2); + //ret.append(NumberTrait::zero()); + return a; +} + +fn get_row_n, +Copy, +Drop,>( + data: @Tensor, index: usize, +) -> Tensor { + let mut output_data = ArrayTrait::new(); + let mut output_shape = ArrayTrait::new(); + let mut stride_output = 1; + let mut n: usize = 0; + + let mut i = 0; + loop { + if i == (*data).shape.len() { + break; + } + if i != 0 { + output_shape.append(*(*data).shape.at(i)); + stride_output = stride_output * *(*data).shape.at(i); + } + i += 1; + }; + + let mut i = 0; + loop { + if i == stride_output { + break; + } + output_data.append(*(*data).data.at(index * stride_output + i)); + i += 1; + }; + + return TensorTrait::new(output_shape.span(), output_data.span()); +} + + +fn interpolate_1d_with_x< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + data: @Tensor, + scale_factor: T, + output_width_int: usize, + x: T, + antialias: Option, + mode: MODE, + nearest_mode: Option, + roi: Option>, + extrapolation_value: T, + coordinate_transformation_mode: Option, + exclude_outside: bool, + cubic_coeff_a: Option, +) -> Tensor { + let coordinate_transformation_mode = match coordinate_transformation_mode { + Option::Some(coordinate_transformation_mode) => coordinate_transformation_mode, + Option::None(_) => { TRANSFORMATION_MODE::HALF_PIXEL }, + }; + + let input_width = (*data).data.len(); + let output_width = (scale_factor * NumberTrait::new_unscaled((input_width).into(), false)); + + let x_ori: T = match coordinate_transformation_mode { + TRANSFORMATION_MODE::HALF_PIXEL => { + (x + NumberTrait::half()) / scale_factor - NumberTrait::half() + }, + TRANSFORMATION_MODE::ALIGN_CORNERS => { + let mut x_ori = NumberTrait::zero(); + if output_width != NumberTrait::one() { + x_ori = x + * (NumberTrait::new_unscaled(input_width.into(), false) - NumberTrait::one()) + / (output_width - NumberTrait::one()); + } + x_ori + }, + TRANSFORMATION_MODE::ASYMMETRIC => { x / scale_factor }, + TRANSFORMATION_MODE::TF_CROP_AND_RESIZE => { + let x_ori = match roi { + Option::Some(roi) => { + let mut x_ori = if output_width == NumberTrait::one() { + (*roi.data.at(1) - *roi.data.at(0)) + * (NumberTrait::new_unscaled(input_width.into(), false) + - NumberTrait::one()) + / (NumberTrait::one() + NumberTrait::one()) + } else { + x + * (*roi.data.at(1) - *roi.data.at(0)) + * (NumberTrait::new_unscaled(input_width.into(), false) + - NumberTrait::one()) + / (output_width - NumberTrait::one()) + }; + + x_ori = x_ori + + *roi.data.at(0) + * (NumberTrait::new_unscaled(input_width.into(), false) + - NumberTrait::one()); + + if x_ori < NumberTrait::zero() + || x_ori > (NumberTrait::new_unscaled(input_width.into(), false) + - NumberTrait::one()) { + let mut ret = ArrayTrait::new(); + let mut shape = ArrayTrait::new(); + shape.append(1); + ret.append(extrapolation_value); + return TensorTrait::new(shape.span(), ret.span()); + }; + x_ori + }, + Option::None(_) => { core::panic_with_felt252('roi cannot be None.') }, + }; + x_ori + }, + TRANSFORMATION_MODE::PYTORCH_HALF_PIXEL => { + if output_width == NumberTrait::one() { + NumberTrait::neg(NumberTrait::::half()) + } else { + (x + NumberTrait::half()) / scale_factor - NumberTrait::half() + } + }, + TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC => { + let adjustement: T = NumberTrait::new_unscaled(output_width_int.into(), false) + / output_width; + let center: T = NumberTrait::new_unscaled(input_width.into(), false) + / (NumberTrait::one() + NumberTrait::one()); + let offset = center * (NumberTrait::one() - adjustement); + offset + (x + NumberTrait::half()) / scale_factor - NumberTrait::half() + }, + }; + + let x_ori_int = x_ori.floor(); + + let ratio = if x_ori_int.try_into().unwrap() == x_ori { + NumberTrait::one() + } else { + x_ori - x_ori_int.try_into().unwrap() + }; + + let mut coeffs = match mode { + MODE::NEAREST => { + let coeffs = match antialias { + Option::Some(antialias) => core::panic_with_felt252( + 'antialias not for mode NEAREST' + ), + Option::None(_) => { nearest_coeffs(ratio, nearest_mode) }, + }; + coeffs + }, + MODE::LINEAR => { + let coeffs = match antialias { + Option::Some(antialias) => { + let coeffs = if antialias == 0 { + linear_coeffs(ratio) + } else { + linear_coeffs_antialias(ratio, scale_factor) + }; + coeffs + }, + Option::None(_) => { linear_coeffs(ratio) }, + }; + coeffs + }, + MODE::CUBIC => { + let coeffs = match antialias { + Option::Some(antialias) => { + cubic_coeffs_antialias(ratio, scale_factor, cubic_coeff_a) + }, + Option::None(_) => { cubic_coeffs(ratio, cubic_coeff_a) }, + }; + coeffs + }, + }; + + let n = coeffs.data.len(); + + let (idxes, points) = get_neighbor(x_ori, n, data); + + if exclude_outside { + let mut coeffs_exclude_outside = ArrayTrait::::new(); + let mut sum = NumberTrait::zero(); + let mut i = 0; + loop { + if i == idxes.data.len() { + break; + } + if *idxes.data.at(i) { + coeffs_exclude_outside.append(NumberTrait::zero()); + sum += NumberTrait::zero(); + } else { + coeffs_exclude_outside.append(*coeffs.data.at(i)); + sum += *coeffs.data.at(i); + } + i += 1; + }; + + let mut coeff_div = ArrayTrait::::new(); + let mut i = 0; + loop { + if i == n { + break; + } + coeff_div.append(*coeffs_exclude_outside.at(i) / sum); + i += 1; + }; + coeffs = TensorTrait::new(coeffs.shape, coeff_div.span()); + } + return TensorTrait::matmul(@coeffs, @points); +} + + +fn get_neighbor< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut x: T, n: usize, data: @Tensor, +) -> (Tensor, Tensor) { + let pad_width: usize = NumberTrait::ceil( + NumberTrait::new_unscaled(n.into(), false) + / (NumberTrait::::one() + NumberTrait::::one()) + ) + .try_into() + .unwrap(); + let mut padded = ArrayTrait::new(); + + let mut i = 0; + loop { + if i == pad_width { + break; + } + padded.append(*(*data).data.at(0)); + i += 1; + }; + let mut i = 0; + loop { + if i == (*data).data.len() { + break; + } + padded.append(*(*data).data.at(i)); + i += 1; + }; + let mut i = 0; + loop { + if i == pad_width { + break; + } + padded.append(*(*data).data.at((*data).data.len() - 1)); + i += 1; + }; + + x = x + NumberTrait::new_unscaled(pad_width.into(), false); + + let mut idxes = get_neighbor_idxes(x, n, padded.len()); + + let mut idxes_centered = ArrayTrait::new(); + let mut ret = ArrayTrait::new(); + let mut i = 0; + loop { + if i == idxes.data.len() { + break; + } + ret.append(*padded.at(*idxes.data.at(i))); + + if *idxes.data.at(i) >= pad_width { + if (*idxes.data.at(i) - pad_width) >= (*data).data.len() { + idxes_centered.append(true); + } else { + idxes_centered.append(false); + } + } else { + idxes_centered.append(true); + } + i += 1; + }; + + let mut shape = ArrayTrait::new(); + shape.append(idxes.data.len()); + + return ( + TensorTrait::new(shape.span(), idxes_centered.span()), + TensorTrait::new(shape.span(), ret.span()) + ); +} + +fn get_neighbor_idxes< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut x: T, n: usize, limit: usize, +) -> Tensor { + let pad_width: usize = NumberTrait::< + T + >::ceil( + NumberTrait::new_unscaled(n.into(), false) + / (NumberTrait::::one() + NumberTrait::::one()) + ) + .try_into() + .unwrap(); + let mut idxes = ArrayTrait::new(); + + if n % 2 == 0 { + let (mut i_low, mut i_high) = if x < NumberTrait::zero() { + (0, 1) + } else { + (NumberTrait::floor(x).try_into().unwrap(), NumberTrait::ceil(x).try_into().unwrap()) + }; + + if i_high >= limit { + i_low = limit - 2; + i_high = limit - 1; + } + + if i_low == i_high { + if i_low == 0 { + i_high = i_high + 1; + } else { + i_low = i_low - 1; + } + } + + let mut i = 0; + loop { + if i == n / 2 { + break; + } + if i_low - i < 0 { + idxes.append(i_high + i); + i_high += 1; + } else { + idxes.append(i_low - i); + } + if i_high + i >= limit { + i_low -= 1; + idxes.append(i_low - i); + } else { + idxes.append(i_high + i); + } + i += 1; + } + } else { + core::panic_with_felt252('MUST BE EVEN'); + } + + idxes = bubble_sort::bubble_sort_elements(idxes); + + let mut shape = ArrayTrait::new(); + shape.append(n); + + return TensorTrait::new(shape.span(), idxes.span()); +} + +fn linear_coeffs< + T, + MAG, + +NumberTrait, + +PartialOrd, + +PartialEq, + +TensorTrait, + +Copy, + +Drop, + +Sub +>( + mut ratio: T +) -> Tensor { + let mut ret = ArrayTrait::new(); + let mut shape = ArrayTrait::new(); + shape.append(2); + ret.append(NumberTrait::one() - ratio); + ret.append(ratio); + return TensorTrait::new(shape.span(), ret.span()); +} + + +fn linear_coeffs_antialias< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut ratio: T, scale: T +) -> Tensor { + let scale = NumberTrait::min(scale, NumberTrait::one()); + let start = (NumberTrait::floor(NumberTrait::neg(NumberTrait::one()) / scale) + + NumberTrait::one()); + let footprint = (NumberTrait::one() + NumberTrait::one()) + - (NumberTrait::one() + NumberTrait::one()) * start; + + let mut coeffs = ArrayTrait::::new(); + let mut sum = NumberTrait::zero(); + + // arange and clip + compute sum + let mut i = start; + loop { + if i == start + footprint { + break; + } + let value = NumberTrait::one() - NumberTrait::abs((i - ratio) * scale); + + if value < NumberTrait::zero() { + coeffs.append(NumberTrait::zero()); + } else if value > NumberTrait::one() { + coeffs.append(NumberTrait::one()); + sum += NumberTrait::one(); + } else { + coeffs.append(value); + sum += value; + } + i += NumberTrait::one(); + }; + + let n = coeffs.len(); + + let mut coeff_div = ArrayTrait::::new(); + let mut i = 0; + loop { + if i == n { + break; + } + coeff_div.append(*coeffs.at(i) / sum); + i += 1; + }; + + let mut shape = ArrayTrait::new(); + shape.append(n); + + return TensorTrait::new(shape.span(), coeff_div.span()); +} + +fn cubic_coeffs< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut ratio: T, A: Option +) -> Tensor { + let one = NumberTrait::one(); + let two = one + NumberTrait::one(); + let three = two + NumberTrait::one(); + let four = three + NumberTrait::one(); + let five = four + NumberTrait::one(); + let five = four + NumberTrait::one(); + let eigth = four + four; + + let A = match A { + Option::Some(A) => A, + Option::None(_) => { NumberTrait::neg(three / four) }, + }; + + let mut coeffs = ArrayTrait::new(); + let mut shape = ArrayTrait::new(); + + coeffs + .append( + ((A * (ratio + one) - five * A) * (ratio + one) + eigth * A) * (ratio + one) - four * A + ); + coeffs.append(((A + two) * ratio - (A + three)) * ratio * ratio + one); + coeffs.append(((A + two) * (one - ratio) - (A + three)) * (one - ratio) * (one - ratio) + one); + coeffs + .append( + ((A * ((one - ratio) + one) - five * A) * ((one - ratio) + one) + eigth * A) + * ((one - ratio) + one) + - four * A + ); + + shape.append(4); + return TensorTrait::new(shape.span(), coeffs.span()); +} + +fn cubic_coeffs_antialias< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +TryInto, + +Into, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut ratio: T, scale: T, A: Option +) -> Tensor { + let one = NumberTrait::one(); + let two = one + NumberTrait::one(); + let three = two + NumberTrait::one(); + let four = three + NumberTrait::one(); + let five = four + NumberTrait::one(); + let five = four + NumberTrait::one(); + let eigth = four + four; + + let scale = NumberTrait::min(scale, NumberTrait::one()); + + let i_start = NumberTrait::floor(NumberTrait::neg(two) / scale) + NumberTrait::one(); + let i_end = two - i_start; + assert(i_end > i_start, 'i_end must be greater'); + + let A = match A { + Option::Some(A) => A, + Option::None(_) => { NumberTrait::neg(three / four) }, + }; + + let mut coeffs = ArrayTrait::new(); + let mut sum = NumberTrait::zero(); + + let mut i = i_start; + loop { + if i == i_end { + break; + } + let value = compute_coeff(scale * (i - ratio), A); + coeffs.append(value); + sum += value; + + i += NumberTrait::one(); + }; + + let n = coeffs.len(); + + let mut coeff_div = ArrayTrait::::new(); + let mut i = 0; + loop { + if i == n { + break; + } + coeff_div.append(*coeffs.at(i) / sum); + i += 1; + }; + + let mut shape = ArrayTrait::new(); + shape.append(n); + + return TensorTrait::new(shape.span(), coeff_div.span()); +} + +fn compute_coeff< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut x: T, A: T +) -> T { + let one = NumberTrait::one(); + let two = one + NumberTrait::one(); + let three = two + NumberTrait::one(); + let four = three + NumberTrait::one(); + let five = four + NumberTrait::one(); + let eigth = four + four; + + x = x.abs(); + let mut x_2 = x * x; + let mut x_3 = x * x_2; + if x <= one { + return (A + two) * x_3 - (A + three) * x_2 + one; + } + if x < two { + return A * x_3 - five * A * x_2 + eigth * A * x - four * A; + } + return NumberTrait::zero(); +} + + +fn nearest_coeffs< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Copy, + +Drop, + +AddEq, + +Add, + +Div, + +Mul, + +Sub, +>( + mut ratio: T, nearest_mode: Option +) -> Tensor { + let nearest_mode = match nearest_mode { + Option::Some(nearest_mode) => { nearest_mode }, + Option::None(_) => { NEAREST_MODE::ROUND_PREFER_FLOOR }, + }; + + let mut ret = ArrayTrait::new(); + let mut shape = ArrayTrait::new(); + shape.append(2); + + // CHECK SI C'EST UNE CONDITION ASSEZ GENERALE + if ratio == NumberTrait::one() { + ret.append(NumberTrait::zero()); + ret.append(NumberTrait::one()); + return TensorTrait::new(shape.span(), ret.span()); + } + + match nearest_mode { + NEAREST_MODE::ROUND_PREFER_FLOOR => { + if ratio <= NumberTrait::half() { + ret.append(NumberTrait::one()); + ret.append(NumberTrait::zero()); + return TensorTrait::new(shape.span(), ret.span()); + } else { + ret.append(NumberTrait::zero()); + ret.append(NumberTrait::one()); + return TensorTrait::new(shape.span(), ret.span()); + } + }, + NEAREST_MODE::ROUND_PREFER_CEIL => { + if ratio < NumberTrait::half() { + ret.append(NumberTrait::one()); + ret.append(NumberTrait::zero()); + return TensorTrait::new(shape.span(), ret.span()); + } else { + ret.append(NumberTrait::zero()); + ret.append(NumberTrait::one()); + return TensorTrait::new(shape.span(), ret.span()); + } + }, + NEAREST_MODE::FLOOR => { + ret.append(NumberTrait::one()); + ret.append(NumberTrait::zero()); + return TensorTrait::new(shape.span(), ret.span()); + }, + NEAREST_MODE::CEIL => { + ret.append(NumberTrait::zero()); + ret.append(NumberTrait::one()); + return TensorTrait::new(shape.span(), ret.span()); + }, + } +} + diff --git a/tests/lib.cairo b/tests/lib.cairo index f5cecb77d..661d3cf0e 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -1,7 +1,8 @@ -mod numbers; -mod performance; -mod tensor_core; +//mod numbers; +//mod performance; +//mod tensor_core; mod nodes; -mod ml; -mod operators; +//mod ml; +//mod operators; + diff --git a/tests/nodes.cairo b/tests/nodes.cairo index c7155e942..cc9f82387 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -850,3 +850,40 @@ mod gather_nd_i8_3d_batch_dims1; mod gather_nd_u32_default; mod gather_nd_u32_batch_dims1; mod gather_nd_u32_batch_dims2; +mod resize_upsample_scales_nearest; +mod resize_downsample_scales_cubic; +mod resize_downsample_scales_cubic_A_n0p5_exclude_outside; +mod resize_downsample_scales_cubic_align_corners; +mod resize_upsample_scales_linear; +mod resize_downsample_scales_linear_align_corners; +mod resize_downsample_scales_nearest; +mod resize_upsample_scales_cubic; +mod resize_upsample_scales_cubic_A_n0p5_exclude_outside; +mod resize_upsample_scales_cubic_align_corners; +mod resize_upsample_scales_cubic_asymmetric; +mod resize_upsample_scales_linear_align_corners; +mod resize_upsample_sizes_nearest; +mod resize_upsample_sizes_cubic; +mod resize_downsample_sizes_cubic; +mod resize_downsample_sizes_nearest; +mod resize_upsample_scales_linear_half_pixel_symmetric; +mod resize_downsample_scales_cubic_antialias; +mod resize_downsample_scales_linear_antialias; +mod resize_downsample_sizes_cubic_antialias; +mod resize_downsample_sizes_linear_pytorch_half_pixel; +mod resize_tf_crop_and_resize; +mod resize_tf_crop_and_resize_extrapolation_value; +mod resize_upsample_scales_nearest_axes_2_3; +mod resize_upsample_scales_nearest_axes_3_2; +mod resize_upsample_sizes_nearest_axes_2_3; +mod resize_upsample_sizes_nearest_ceil_half_pixel; +mod resize_upsample_sizes_nearest_floor_align_corners; +mod resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric; +mod resize_downsample_scales_linear_half_pixel_symmetric; +mod resize_downsample_sizes_nearest_not_larger; +mod resize_downsample_sizes_nearest_not_smaller; +mod resize_tf_crop_and_resize_axes_2_3; +mod resize_tf_crop_and_resize_axes_3_2; +mod resize_upsample_sizes_nearest_axes_3_2; +mod resize_upsample_sizes_nearest_not_larger; +mod resize_upsample_sizes_nearest_not_smaller; diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo index d2c0b80dd..025cc8261 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo index 507847851..677a40f6a 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp16x16_3d_default.cairo b/tests/nodes/gather_nd_fp16x16_3d_default.cairo index ae4609a66..b8339a0d2 100644 --- a/tests/nodes/gather_nd_fp16x16_3d_default.cairo +++ b/tests/nodes/gather_nd_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo index b9a083796..65980d91f 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo index 5e42ca893..48c812baf 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_fp8x23_3d_default.cairo b/tests/nodes/gather_nd_fp8x23_3d_default.cairo index 12b6408e0..342cd2b72 100644 --- a/tests/nodes/gather_nd_fp8x23_3d_default.cairo +++ b/tests/nodes/gather_nd_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo index 243b0ca16..318ccd62e 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo index d11370b94..177c8e40f 100644 --- a/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo +++ b/tests/nodes/gather_nd_i32_3d_batch_dims2.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i32_3d_default.cairo b/tests/nodes/gather_nd_i32_3d_default.cairo index 35c054093..97212f737 100644 --- a/tests/nodes/gather_nd_i32_3d_default.cairo +++ b/tests/nodes/gather_nd_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo index ae83a8c7d..f849c8677 100644 --- a/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo +++ b/tests/nodes/gather_nd_i8_3d_batch_dims1.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_i8_3d_default.cairo b/tests/nodes/gather_nd_i8_3d_default.cairo index 73e1d91b2..ff7ad9252 100644 --- a/tests/nodes/gather_nd_i8_3d_default.cairo +++ b/tests/nodes/gather_nd_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_gather_nd_i8_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims1.cairo b/tests/nodes/gather_nd_u32_batch_dims1.cairo index 0428ec1d5..860675f66 100644 --- a/tests/nodes/gather_nd_u32_batch_dims1.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims1.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(1)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_batch_dims2.cairo b/tests/nodes/gather_nd_u32_batch_dims2.cairo index 39857ef1d..f0662be99 100644 --- a/tests/nodes/gather_nd_u32_batch_dims2.cairo +++ b/tests/nodes/gather_nd_u32_batch_dims2.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_batch_dims2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(2)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/gather_nd_u32_default.cairo b/tests/nodes/gather_nd_u32_default.cairo index f55b49d5e..be6edd699 100644 --- a/tests/nodes/gather_nd_u32_default.cairo +++ b/tests/nodes/gather_nd_u32_default.cairo @@ -16,7 +16,7 @@ fn test_gather_nd_u32_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.gather_nd(indices:input_1, batch_dims:Option::Some(0)); + let y_0 = input_0.gather_nd(indices: input_1, batch_dims: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/resize_downsample_scales_cubic.cairo b/tests/nodes/resize_downsample_scales_cubic.cairo new file mode 100644 index 000000000..29d3b6281 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_cubic() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_cubic/input_0.cairo b/tests/nodes/resize_downsample_scales_cubic/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic/input_1.cairo b/tests/nodes/resize_downsample_scales_cubic/input_1.cairo new file mode 100644 index 000000000..f861f7954 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic/output_0.cairo b/tests/nodes/resize_downsample_scales_cubic/output_0.cairo new file mode 100644 index 000000000..46d7658e5 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 96416, sign: false }); + data.append(FP16x16 { mag: 182272, sign: false }); + data.append(FP16x16 { mag: 267552, sign: false }); + data.append(FP16x16 { mag: 439840, sign: false }); + data.append(FP16x16 { mag: 525696, sign: false }); + data.append(FP16x16 { mag: 610976, sign: false }); + data.append(FP16x16 { mag: 780960, sign: false }); + data.append(FP16x16 { mag: 866816, sign: false }); + data.append(FP16x16 { mag: 952096, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside.cairo b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside.cairo new file mode 100644 index 000000000..785d81971 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside.cairo @@ -0,0 +1,40 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_cubic_A_n0p5_exclude_outside() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(FixedTrait::::new(32768, true)), + Option::Some(true), + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo new file mode 100644 index 000000000..f861f7954 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_2.cairo b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_2.cairo new file mode 100644 index 000000000..775bcc066 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo new file mode 100644 index 000000000..4ce3f227c --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 89661, sign: false }); + data.append(FP16x16 { mag: 174948, sign: false }); + data.append(FP16x16 { mag: 263018, sign: false }); + data.append(FP16x16 { mag: 430809, sign: false }); + data.append(FP16x16 { mag: 516096, sign: false }); + data.append(FP16x16 { mag: 604165, sign: false }); + data.append(FP16x16 { mag: 783087, sign: false }); + data.append(FP16x16 { mag: 868374, sign: false }); + data.append(FP16x16 { mag: 956443, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_align_corners.cairo b/tests/nodes/resize_downsample_scales_cubic_align_corners.cairo new file mode 100644 index 000000000..200326825 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_align_corners.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_cubic_align_corners() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_cubic_align_corners/input_0.cairo b/tests/nodes/resize_downsample_scales_cubic_align_corners/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_align_corners/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_align_corners/input_1.cairo b/tests/nodes/resize_downsample_scales_cubic_align_corners/input_1.cairo new file mode 100644 index 000000000..f861f7954 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_align_corners/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_align_corners/output_0.cairo b/tests/nodes/resize_downsample_scales_cubic_align_corners/output_0.cairo new file mode 100644 index 000000000..713b36edd --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_align_corners/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 156971, sign: false }); + data.append(FP16x16 { mag: 248406, sign: false }); + data.append(FP16x16 { mag: 431277, sign: false }); + data.append(FP16x16 { mag: 522712, sign: false }); + data.append(FP16x16 { mag: 614147, sign: false }); + data.append(FP16x16 { mag: 797018, sign: false }); + data.append(FP16x16 { mag: 888453, sign: false }); + data.append(FP16x16 { mag: 979888, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_antialias.cairo b/tests/nodes/resize_downsample_scales_cubic_antialias.cairo new file mode 100644 index 000000000..230171410 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_antialias.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_cubic_antialias() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::Some(1), + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_cubic_antialias/input_0.cairo b/tests/nodes/resize_downsample_scales_cubic_antialias/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_antialias/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_antialias/input_1.cairo b/tests/nodes/resize_downsample_scales_cubic_antialias/input_1.cairo new file mode 100644 index 000000000..401a2622c --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_antialias/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_cubic_antialias/output_0.cairo b/tests/nodes/resize_downsample_scales_cubic_antialias/output_0.cairo new file mode 100644 index 000000000..d2cb459b1 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_cubic_antialias/output_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 165024, sign: false }); + data.append(FP16x16 { mag: 280879, sign: false }); + data.append(FP16x16 { mag: 628446, sign: false }); + data.append(FP16x16 { mag: 744301, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_align_corners.cairo b/tests/nodes/resize_downsample_scales_linear_align_corners.cairo new file mode 100644 index 000000000..c4b4d21a0 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_align_corners.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_linear_align_corners() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_linear_align_corners/input_0.cairo b/tests/nodes/resize_downsample_scales_linear_align_corners/input_0.cairo new file mode 100644 index 000000000..90a00e3c1 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_align_corners/input_0.cairo @@ -0,0 +1,23 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_align_corners/input_1.cairo b/tests/nodes/resize_downsample_scales_linear_align_corners/input_1.cairo new file mode 100644 index 000000000..401a2622c --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_align_corners/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_align_corners/output_0.cairo b/tests/nodes/resize_downsample_scales_linear_align_corners/output_0.cairo new file mode 100644 index 000000000..fdcbd59d8 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_align_corners/output_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 205970, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_antialias.cairo b/tests/nodes/resize_downsample_scales_linear_antialias.cairo new file mode 100644 index 000000000..be072a81d --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_antialias.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_linear_antialias() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::Some(1), + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_linear_antialias/input_0.cairo b/tests/nodes/resize_downsample_scales_linear_antialias/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_antialias/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_antialias/input_1.cairo b/tests/nodes/resize_downsample_scales_linear_antialias/input_1.cairo new file mode 100644 index 000000000..401a2622c --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_antialias/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_antialias/output_0.cairo b/tests/nodes/resize_downsample_scales_linear_antialias/output_0.cairo new file mode 100644 index 000000000..a5fa2d1e6 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_antialias/output_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 188416, sign: false }); + data.append(FP16x16 { mag: 294912, sign: false }); + data.append(FP16x16 { mag: 614400, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric.cairo b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric.cairo new file mode 100644 index 000000000..3c5394cb2 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_linear_half_pixel_symmetric() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_0.cairo b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_0.cairo new file mode 100644 index 000000000..79947f581 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(1); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_1.cairo b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_1.cairo new file mode 100644 index 000000000..4ec2dc2c5 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/output_0.cairo b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/output_0.cairo new file mode 100644 index 000000000..865747871 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_linear_half_pixel_symmetric/output_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 109226, sign: false }); + data.append(FP16x16 { mag: 218453, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_nearest.cairo b/tests/nodes/resize_downsample_scales_nearest.cairo new file mode 100644 index 000000000..8a1a089f4 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_nearest.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_scales_nearest() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_scales_nearest/input_0.cairo b/tests/nodes/resize_downsample_scales_nearest/input_0.cairo new file mode 100644 index 000000000..90a00e3c1 --- /dev/null +++ b/tests/nodes/resize_downsample_scales_nearest/input_0.cairo @@ -0,0 +1,23 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_nearest/input_1.cairo b/tests/nodes/resize_downsample_scales_nearest/input_1.cairo new file mode 100644 index 000000000..401a2622c --- /dev/null +++ b/tests/nodes/resize_downsample_scales_nearest/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_scales_nearest/output_0.cairo b/tests/nodes/resize_downsample_scales_nearest/output_0.cairo new file mode 100644 index 000000000..ba0aaa9bf --- /dev/null +++ b/tests/nodes/resize_downsample_scales_nearest/output_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_cubic.cairo b/tests/nodes/resize_downsample_sizes_cubic.cairo new file mode 100644 index 000000000..ac560cf4c --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic.cairo @@ -0,0 +1,40 @@ +mod input_0; +mod input_1; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_sizes_cubic() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_sizes_cubic/input_0.cairo b/tests/nodes/resize_downsample_sizes_cubic/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_cubic/input_1.cairo b/tests/nodes/resize_downsample_sizes_cubic/input_1.cairo new file mode 100644 index 000000000..9eab8d083 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(3); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_cubic/output_0.cairo b/tests/nodes/resize_downsample_sizes_cubic/output_0.cairo new file mode 100644 index 000000000..14e3cf8c5 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 106875, sign: false }); + data.append(FP16x16 { mag: 196911, sign: false }); + data.append(FP16x16 { mag: 286947, sign: false }); + data.append(FP16x16 { mag: 467019, sign: false }); + data.append(FP16x16 { mag: 557056, sign: false }); + data.append(FP16x16 { mag: 647092, sign: false }); + data.append(FP16x16 { mag: 827164, sign: false }); + data.append(FP16x16 { mag: 917200, sign: false }); + data.append(FP16x16 { mag: 1007236, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_cubic_antialias.cairo b/tests/nodes/resize_downsample_sizes_cubic_antialias.cairo new file mode 100644 index 000000000..6f31c3678 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic_antialias.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_sizes_cubic_antialias() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::Some(1), + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_sizes_cubic_antialias/input_0.cairo b/tests/nodes/resize_downsample_sizes_cubic_antialias/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic_antialias/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_cubic_antialias/input_1.cairo b/tests/nodes/resize_downsample_sizes_cubic_antialias/input_1.cairo new file mode 100644 index 000000000..9eab8d083 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic_antialias/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(3); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_cubic_antialias/output_0.cairo b/tests/nodes/resize_downsample_sizes_cubic_antialias/output_0.cairo new file mode 100644 index 000000000..3e02030e8 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_cubic_antialias/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 116327, sign: false }); + data.append(FP16x16 { mag: 204472, sign: false }); + data.append(FP16x16 { mag: 292618, sign: false }); + data.append(FP16x16 { mag: 468910, sign: false }); + data.append(FP16x16 { mag: 557056, sign: false }); + data.append(FP16x16 { mag: 645201, sign: false }); + data.append(FP16x16 { mag: 821493, sign: false }); + data.append(FP16x16 { mag: 909639, sign: false }); + data.append(FP16x16 { mag: 997785, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel.cairo b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel.cairo new file mode 100644 index 000000000..7064afbcf --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_sizes_linear_pytorch_half_pixel() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::PYTORCH_HALF_PIXEL), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_0.cairo b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_1.cairo b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_1.cairo new file mode 100644 index 000000000..9513937ed --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(3); + data.append(1); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/output_0.cairo b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/output_0.cairo new file mode 100644 index 000000000..6443746ef --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_linear_pytorch_half_pixel/output_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 109226, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 808277, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest.cairo b/tests/nodes/resize_downsample_sizes_nearest.cairo new file mode 100644 index 000000000..0c512b8a4 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest.cairo @@ -0,0 +1,43 @@ +mod input_0; +mod input_1; +mod output_0; + + +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32TensorPartialEq; + +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_sizes_nearest() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_sizes_nearest/input_0.cairo b/tests/nodes/resize_downsample_sizes_nearest/input_0.cairo new file mode 100644 index 000000000..90a00e3c1 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest/input_0.cairo @@ -0,0 +1,23 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest/input_1.cairo b/tests/nodes/resize_downsample_sizes_nearest/input_1.cairo new file mode 100644 index 000000000..ccf02e251 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(1); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest/output_0.cairo b/tests/nodes/resize_downsample_sizes_nearest/output_0.cairo new file mode 100644 index 000000000..eecfc2472 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest/output_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(1); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_larger.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_larger.cairo new file mode 100644 index 000000000..4c097083e --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_larger.cairo @@ -0,0 +1,44 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_sizes_nearest_not_larger() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_LARGER), + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_0.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_0.cairo new file mode 100644 index 000000000..90a00e3c1 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_0.cairo @@ -0,0 +1,23 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_1.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_1.cairo new file mode 100644 index 000000000..7b523e95b --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_2.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_2.cairo new file mode 100644 index 000000000..6388a13b0 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_larger/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_larger/output_0.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_larger/output_0.cairo new file mode 100644 index 000000000..ba0aaa9bf --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_larger/output_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(1); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_smaller.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_smaller.cairo new file mode 100644 index 000000000..523472ef3 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_smaller.cairo @@ -0,0 +1,43 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_downsample_sizes_nearest_not_smaller() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER), + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_0.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_0.cairo new file mode 100644 index 000000000..90a00e3c1 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_0.cairo @@ -0,0 +1,23 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_1.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_1.cairo new file mode 100644 index 000000000..7b523e95b --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_2.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_2.cairo new file mode 100644 index 000000000..6388a13b0 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_downsample_sizes_nearest_not_smaller/output_0.cairo b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/output_0.cairo new file mode 100644 index 000000000..18da97817 --- /dev/null +++ b/tests/nodes/resize_downsample_sizes_nearest_not_smaller/output_0.cairo @@ -0,0 +1,21 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize.cairo b/tests/nodes/resize_tf_crop_and_resize.cairo new file mode 100644 index 000000000..7ffd9a5b0 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize.cairo @@ -0,0 +1,43 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_tf_crop_and_resize() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let roi = Option::Some(input_2::input_2()); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + roi, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_tf_crop_and_resize/input_0.cairo b/tests/nodes/resize_tf_crop_and_resize/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize/input_1.cairo b/tests/nodes/resize_tf_crop_and_resize/input_1.cairo new file mode 100644 index 000000000..9eab8d083 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(3); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize/input_2.cairo b/tests/nodes/resize_tf_crop_and_resize/input_2.cairo new file mode 100644 index 000000000..c6d98970c --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize/input_2.cairo @@ -0,0 +1,20 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 26214, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize/output_0.cairo b/tests/nodes/resize_tf_crop_and_resize/output_0.cairo new file mode 100644 index 000000000..46668e78c --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 498073, sign: false }); + data.append(FP16x16 { mag: 517734, sign: false }); + data.append(FP16x16 { mag: 537395, sign: false }); + data.append(FP16x16 { mag: 576716, sign: false }); + data.append(FP16x16 { mag: 596377, sign: false }); + data.append(FP16x16 { mag: 616038, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 675020, sign: false }); + data.append(FP16x16 { mag: 694681, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_2_3.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_2_3.cairo new file mode 100644 index 000000000..8d30b2993 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_2_3.cairo @@ -0,0 +1,45 @@ +mod input_0; +mod input_1; +mod input_2; +mod input_3; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_tf_crop_and_resize_axes_2_3() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let roi = Option::Some(input_2::input_2()); + let axes = Option::Some(input_3::input_3().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + roi, + Option::None, + sizes, + Option::None, + axes, + Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_0.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_1.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_1.cairo new file mode 100644 index 000000000..3e04171f5 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_2.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_2.cairo new file mode 100644 index 000000000..7ec99b507 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_2.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 26214, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_3.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_3.cairo new file mode 100644 index 000000000..6fc07eeff --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/input_3.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_3() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_2_3/output_0.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/output_0.cairo new file mode 100644 index 000000000..46668e78c --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_2_3/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 498073, sign: false }); + data.append(FP16x16 { mag: 517734, sign: false }); + data.append(FP16x16 { mag: 537395, sign: false }); + data.append(FP16x16 { mag: 576716, sign: false }); + data.append(FP16x16 { mag: 596377, sign: false }); + data.append(FP16x16 { mag: 616038, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 675020, sign: false }); + data.append(FP16x16 { mag: 694681, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_3_2.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_3_2.cairo new file mode 100644 index 000000000..f0e59f4e5 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_3_2.cairo @@ -0,0 +1,45 @@ +mod input_0; +mod input_1; +mod input_2; +mod input_3; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_tf_crop_and_resize_axes_3_2() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let roi = Option::Some(input_2::input_2()); + let axes = Option::Some(input_3::input_3().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + roi, + Option::None, + sizes, + Option::None, + axes, + Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_0.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_1.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_1.cairo new file mode 100644 index 000000000..3e04171f5 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_2.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_2.cairo new file mode 100644 index 000000000..4369ef099 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_2.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 26214, sign: false }); + data.append(FP16x16 { mag: 52428, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_3.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_3.cairo new file mode 100644 index 000000000..804c8f5c6 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/input_3.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_3() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_axes_3_2/output_0.cairo b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/output_0.cairo new file mode 100644 index 000000000..46668e78c --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_axes_3_2/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 498073, sign: false }); + data.append(FP16x16 { mag: 517734, sign: false }); + data.append(FP16x16 { mag: 537395, sign: false }); + data.append(FP16x16 { mag: 576716, sign: false }); + data.append(FP16x16 { mag: 596377, sign: false }); + data.append(FP16x16 { mag: 616038, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 675020, sign: false }); + data.append(FP16x16 { mag: 694681, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_extrapolation_value.cairo b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value.cairo new file mode 100644 index 000000000..b3fb2e6d0 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value.cairo @@ -0,0 +1,44 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_tf_crop_and_resize_extrapolation_value() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let roi = Option::Some(input_2::input_2()); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + roi, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::TF_CROP_AND_RESIZE), + Option::None, + Option::None, + Option::Some(FixedTrait::::new(655360, false)), + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_0.cairo b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_1.cairo b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_1.cairo new file mode 100644 index 000000000..9eab8d083 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(3); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_2.cairo b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_2.cairo new file mode 100644 index 000000000..26627c0b1 --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/input_2.cairo @@ -0,0 +1,20 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 0, sign: false }); + data.append(FP16x16 { mag: 26214, sign: false }); + data.append(FP16x16 { mag: 39321, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 78643, sign: false }); + data.append(FP16x16 { mag: 111411, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/output_0.cairo b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/output_0.cairo new file mode 100644 index 000000000..80acb5c8f --- /dev/null +++ b/tests/nodes/resize_tf_crop_and_resize_extrapolation_value/output_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(3); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 498073, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 812646, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic.cairo b/tests/nodes/resize_upsample_scales_cubic.cairo new file mode 100644 index 000000000..6d7b1da71 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_cubic() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_cubic/input_0.cairo b/tests/nodes/resize_upsample_scales_cubic/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic/input_1.cairo b/tests/nodes/resize_upsample_scales_cubic/input_1.cairo new file mode 100644 index 000000000..75bd50d8a --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic/output_0.cairo b/tests/nodes/resize_upsample_scales_cubic/output_0.cairo new file mode 100644 index 000000000..8d8b8d5b4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 30976, sign: false }); + data.append(FP16x16 { mag: 50432, sign: false }); + data.append(FP16x16 { mag: 81664, sign: false }); + data.append(FP16x16 { mag: 122880, sign: false }); + data.append(FP16x16 { mag: 149504, sign: false }); + data.append(FP16x16 { mag: 190720, sign: false }); + data.append(FP16x16 { mag: 221952, sign: false }); + data.append(FP16x16 { mag: 241408, sign: false }); + data.append(FP16x16 { mag: 108800, sign: false }); + data.append(FP16x16 { mag: 128256, sign: false }); + data.append(FP16x16 { mag: 159488, sign: false }); + data.append(FP16x16 { mag: 200704, sign: false }); + data.append(FP16x16 { mag: 227328, sign: false }); + data.append(FP16x16 { mag: 268544, sign: false }); + data.append(FP16x16 { mag: 299776, sign: false }); + data.append(FP16x16 { mag: 319232, sign: false }); + data.append(FP16x16 { mag: 233728, sign: false }); + data.append(FP16x16 { mag: 253184, sign: false }); + data.append(FP16x16 { mag: 284416, sign: false }); + data.append(FP16x16 { mag: 325632, sign: false }); + data.append(FP16x16 { mag: 352256, sign: false }); + data.append(FP16x16 { mag: 393472, sign: false }); + data.append(FP16x16 { mag: 424704, sign: false }); + data.append(FP16x16 { mag: 444160, sign: false }); + data.append(FP16x16 { mag: 398592, sign: false }); + data.append(FP16x16 { mag: 418048, sign: false }); + data.append(FP16x16 { mag: 449280, sign: false }); + data.append(FP16x16 { mag: 490496, sign: false }); + data.append(FP16x16 { mag: 517120, sign: false }); + data.append(FP16x16 { mag: 558336, sign: false }); + data.append(FP16x16 { mag: 589568, sign: false }); + data.append(FP16x16 { mag: 609024, sign: false }); + data.append(FP16x16 { mag: 505088, sign: false }); + data.append(FP16x16 { mag: 524544, sign: false }); + data.append(FP16x16 { mag: 555776, sign: false }); + data.append(FP16x16 { mag: 596992, sign: false }); + data.append(FP16x16 { mag: 623616, sign: false }); + data.append(FP16x16 { mag: 664832, sign: false }); + data.append(FP16x16 { mag: 696064, sign: false }); + data.append(FP16x16 { mag: 715520, sign: false }); + data.append(FP16x16 { mag: 669952, sign: false }); + data.append(FP16x16 { mag: 689408, sign: false }); + data.append(FP16x16 { mag: 720640, sign: false }); + data.append(FP16x16 { mag: 761856, sign: false }); + data.append(FP16x16 { mag: 788480, sign: false }); + data.append(FP16x16 { mag: 829696, sign: false }); + data.append(FP16x16 { mag: 860928, sign: false }); + data.append(FP16x16 { mag: 880384, sign: false }); + data.append(FP16x16 { mag: 794880, sign: false }); + data.append(FP16x16 { mag: 814336, sign: false }); + data.append(FP16x16 { mag: 845568, sign: false }); + data.append(FP16x16 { mag: 886784, sign: false }); + data.append(FP16x16 { mag: 913408, sign: false }); + data.append(FP16x16 { mag: 954624, sign: false }); + data.append(FP16x16 { mag: 985856, sign: false }); + data.append(FP16x16 { mag: 1005312, sign: false }); + data.append(FP16x16 { mag: 872704, sign: false }); + data.append(FP16x16 { mag: 892160, sign: false }); + data.append(FP16x16 { mag: 923392, sign: false }); + data.append(FP16x16 { mag: 964608, sign: false }); + data.append(FP16x16 { mag: 991232, sign: false }); + data.append(FP16x16 { mag: 1032448, sign: false }); + data.append(FP16x16 { mag: 1063680, sign: false }); + data.append(FP16x16 { mag: 1083136, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside.cairo b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside.cairo new file mode 100644 index 000000000..1a7e89dd4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside.cairo @@ -0,0 +1,40 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_cubic_A_n0p5_exclude_outside() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(FixedTrait::::new(32768, true)), + Option::Some(true), + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo new file mode 100644 index 000000000..75bd50d8a --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo new file mode 100644 index 000000000..70c95fe02 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_A_n0p5_exclude_outside/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 36623, sign: false }); + data.append(FP16x16 { mag: 53408, sign: false }); + data.append(FP16x16 { mag: 88931, sign: false }); + data.append(FP16x16 { mag: 124325, sign: false }); + data.append(FP16x16 { mag: 157093, sign: false }); + data.append(FP16x16 { mag: 192488, sign: false }); + data.append(FP16x16 { mag: 228011, sign: false }); + data.append(FP16x16 { mag: 244796, sign: false }); + data.append(FP16x16 { mag: 103762, sign: false }); + data.append(FP16x16 { mag: 120547, sign: false }); + data.append(FP16x16 { mag: 156071, sign: false }); + data.append(FP16x16 { mag: 191465, sign: false }); + data.append(FP16x16 { mag: 224233, sign: false }); + data.append(FP16x16 { mag: 259628, sign: false }); + data.append(FP16x16 { mag: 295151, sign: false }); + data.append(FP16x16 { mag: 311936, sign: false }); + data.append(FP16x16 { mag: 245855, sign: false }); + data.append(FP16x16 { mag: 262640, sign: false }); + data.append(FP16x16 { mag: 298163, sign: false }); + data.append(FP16x16 { mag: 333558, sign: false }); + data.append(FP16x16 { mag: 366326, sign: false }); + data.append(FP16x16 { mag: 401720, sign: false }); + data.append(FP16x16 { mag: 437243, sign: false }); + data.append(FP16x16 { mag: 454028, sign: false }); + data.append(FP16x16 { mag: 387433, sign: false }); + data.append(FP16x16 { mag: 404218, sign: false }); + data.append(FP16x16 { mag: 439741, sign: false }); + data.append(FP16x16 { mag: 475136, sign: false }); + data.append(FP16x16 { mag: 507904, sign: false }); + data.append(FP16x16 { mag: 543298, sign: false }); + data.append(FP16x16 { mag: 578821, sign: false }); + data.append(FP16x16 { mag: 595606, sign: false }); + data.append(FP16x16 { mag: 518505, sign: false }); + data.append(FP16x16 { mag: 535290, sign: false }); + data.append(FP16x16 { mag: 570813, sign: false }); + data.append(FP16x16 { mag: 606208, sign: false }); + data.append(FP16x16 { mag: 638976, sign: false }); + data.append(FP16x16 { mag: 674370, sign: false }); + data.append(FP16x16 { mag: 709893, sign: false }); + data.append(FP16x16 { mag: 726678, sign: false }); + data.append(FP16x16 { mag: 660083, sign: false }); + data.append(FP16x16 { mag: 676868, sign: false }); + data.append(FP16x16 { mag: 712391, sign: false }); + data.append(FP16x16 { mag: 747785, sign: false }); + data.append(FP16x16 { mag: 780553, sign: false }); + data.append(FP16x16 { mag: 815948, sign: false }); + data.append(FP16x16 { mag: 851471, sign: false }); + data.append(FP16x16 { mag: 868256, sign: false }); + data.append(FP16x16 { mag: 802175, sign: false }); + data.append(FP16x16 { mag: 818960, sign: false }); + data.append(FP16x16 { mag: 854484, sign: false }); + data.append(FP16x16 { mag: 889878, sign: false }); + data.append(FP16x16 { mag: 922646, sign: false }); + data.append(FP16x16 { mag: 958040, sign: false }); + data.append(FP16x16 { mag: 993564, sign: false }); + data.append(FP16x16 { mag: 1010349, sign: false }); + data.append(FP16x16 { mag: 869315, sign: false }); + data.append(FP16x16 { mag: 886100, sign: false }); + data.append(FP16x16 { mag: 921623, sign: false }); + data.append(FP16x16 { mag: 957018, sign: false }); + data.append(FP16x16 { mag: 989786, sign: false }); + data.append(FP16x16 { mag: 1025180, sign: false }); + data.append(FP16x16 { mag: 1060704, sign: false }); + data.append(FP16x16 { mag: 1077489, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_align_corners.cairo b/tests/nodes/resize_upsample_scales_cubic_align_corners.cairo new file mode 100644 index 000000000..77c9befd1 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_align_corners.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_cubic_align_corners() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_cubic_align_corners/input_0.cairo b/tests/nodes/resize_upsample_scales_cubic_align_corners/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_align_corners/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_align_corners/input_1.cairo b/tests/nodes/resize_upsample_scales_cubic_align_corners/input_1.cairo new file mode 100644 index 000000000..75bd50d8a --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_align_corners/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_align_corners/output_0.cairo b/tests/nodes/resize_upsample_scales_cubic_align_corners/output_0.cairo new file mode 100644 index 000000000..87f61ce58 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_align_corners/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 87890, sign: false }); + data.append(FP16x16 { mag: 117983, sign: false }); + data.append(FP16x16 { mag: 152662, sign: false }); + data.append(FP16x16 { mag: 175017, sign: false }); + data.append(FP16x16 { mag: 209696, sign: false }); + data.append(FP16x16 { mag: 239789, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 154955, sign: false }); + data.append(FP16x16 { mag: 177310, sign: false }); + data.append(FP16x16 { mag: 207403, sign: false }); + data.append(FP16x16 { mag: 242081, sign: false }); + data.append(FP16x16 { mag: 264436, sign: false }); + data.append(FP16x16 { mag: 299115, sign: false }); + data.append(FP16x16 { mag: 329208, sign: false }); + data.append(FP16x16 { mag: 351563, sign: false }); + data.append(FP16x16 { mag: 275327, sign: false }); + data.append(FP16x16 { mag: 297682, sign: false }); + data.append(FP16x16 { mag: 327775, sign: false }); + data.append(FP16x16 { mag: 362454, sign: false }); + data.append(FP16x16 { mag: 384809, sign: false }); + data.append(FP16x16 { mag: 419487, sign: false }); + data.append(FP16x16 { mag: 449580, sign: false }); + data.append(FP16x16 { mag: 471935, sign: false }); + data.append(FP16x16 { mag: 414042, sign: false }); + data.append(FP16x16 { mag: 436397, sign: false }); + data.append(FP16x16 { mag: 466490, sign: false }); + data.append(FP16x16 { mag: 501168, sign: false }); + data.append(FP16x16 { mag: 523523, sign: false }); + data.append(FP16x16 { mag: 558202, sign: false }); + data.append(FP16x16 { mag: 588295, sign: false }); + data.append(FP16x16 { mag: 610650, sign: false }); + data.append(FP16x16 { mag: 503461, sign: false }); + data.append(FP16x16 { mag: 525816, sign: false }); + data.append(FP16x16 { mag: 555909, sign: false }); + data.append(FP16x16 { mag: 590588, sign: false }); + data.append(FP16x16 { mag: 612943, sign: false }); + data.append(FP16x16 { mag: 647621, sign: false }); + data.append(FP16x16 { mag: 677714, sign: false }); + data.append(FP16x16 { mag: 700069, sign: false }); + data.append(FP16x16 { mag: 642176, sign: false }); + data.append(FP16x16 { mag: 664531, sign: false }); + data.append(FP16x16 { mag: 694624, sign: false }); + data.append(FP16x16 { mag: 729302, sign: false }); + data.append(FP16x16 { mag: 751657, sign: false }); + data.append(FP16x16 { mag: 786336, sign: false }); + data.append(FP16x16 { mag: 816429, sign: false }); + data.append(FP16x16 { mag: 838784, sign: false }); + data.append(FP16x16 { mag: 762548, sign: false }); + data.append(FP16x16 { mag: 784903, sign: false }); + data.append(FP16x16 { mag: 814996, sign: false }); + data.append(FP16x16 { mag: 849675, sign: false }); + data.append(FP16x16 { mag: 872030, sign: false }); + data.append(FP16x16 { mag: 906708, sign: false }); + data.append(FP16x16 { mag: 936801, sign: false }); + data.append(FP16x16 { mag: 959156, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 874322, sign: false }); + data.append(FP16x16 { mag: 904415, sign: false }); + data.append(FP16x16 { mag: 939094, sign: false }); + data.append(FP16x16 { mag: 961449, sign: false }); + data.append(FP16x16 { mag: 996128, sign: false }); + data.append(FP16x16 { mag: 1026221, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_asymmetric.cairo b/tests/nodes/resize_upsample_scales_cubic_asymmetric.cairo new file mode 100644 index 000000000..3b07d61e8 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_asymmetric.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_cubic_asymmetric() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ASYMMETRIC), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_cubic_asymmetric/input_0.cairo b/tests/nodes/resize_upsample_scales_cubic_asymmetric/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_asymmetric/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_asymmetric/input_1.cairo b/tests/nodes/resize_upsample_scales_cubic_asymmetric/input_1.cairo new file mode 100644 index 000000000..75bd50d8a --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_asymmetric/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_cubic_asymmetric/output_0.cairo b/tests/nodes/resize_upsample_scales_cubic_asymmetric/output_0.cairo new file mode 100644 index 000000000..61321f6df --- /dev/null +++ b/tests/nodes/resize_upsample_scales_cubic_asymmetric/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 92160, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 163840, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 235520, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 268288, sign: false }); + data.append(FP16x16 { mag: 172032, sign: false }); + data.append(FP16x16 { mag: 198656, sign: false }); + data.append(FP16x16 { mag: 237568, sign: false }); + data.append(FP16x16 { mag: 270336, sign: false }); + data.append(FP16x16 { mag: 303104, sign: false }); + data.append(FP16x16 { mag: 342016, sign: false }); + data.append(FP16x16 { mag: 368640, sign: false }); + data.append(FP16x16 { mag: 374784, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 354304, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 425984, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 497664, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 530432, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 485376, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 557056, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 628736, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 661504, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 616448, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 688128, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 759808, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 792576, sign: false }); + data.append(FP16x16 { mag: 745472, sign: false }); + data.append(FP16x16 { mag: 772096, sign: false }); + data.append(FP16x16 { mag: 811008, sign: false }); + data.append(FP16x16 { mag: 843776, sign: false }); + data.append(FP16x16 { mag: 876544, sign: false }); + data.append(FP16x16 { mag: 915456, sign: false }); + data.append(FP16x16 { mag: 942080, sign: false }); + data.append(FP16x16 { mag: 948224, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 878592, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 950272, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1021952, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1054720, sign: false }); + data.append(FP16x16 { mag: 876544, sign: false }); + data.append(FP16x16 { mag: 903168, sign: false }); + data.append(FP16x16 { mag: 942080, sign: false }); + data.append(FP16x16 { mag: 974848, sign: false }); + data.append(FP16x16 { mag: 1007616, sign: false }); + data.append(FP16x16 { mag: 1046528, sign: false }); + data.append(FP16x16 { mag: 1073152, sign: false }); + data.append(FP16x16 { mag: 1079296, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear.cairo b/tests/nodes/resize_upsample_scales_linear.cairo new file mode 100644 index 000000000..650cbe7d1 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_linear() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_linear/input_0.cairo b/tests/nodes/resize_upsample_scales_linear/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear/input_1.cairo b/tests/nodes/resize_upsample_scales_linear/input_1.cairo new file mode 100644 index 000000000..75bd50d8a --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear/output_0.cairo b/tests/nodes/resize_upsample_scales_linear/output_0.cairo new file mode 100644 index 000000000..67f6375ca --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear/output_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 81920, sign: false }); + data.append(FP16x16 { mag: 114688, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 98304, sign: false }); + data.append(FP16x16 { mag: 114688, sign: false }); + data.append(FP16x16 { mag: 147456, sign: false }); + data.append(FP16x16 { mag: 163840, sign: false }); + data.append(FP16x16 { mag: 163840, sign: false }); + data.append(FP16x16 { mag: 180224, sign: false }); + data.append(FP16x16 { mag: 212992, sign: false }); + data.append(FP16x16 { mag: 229376, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 212992, sign: false }); + data.append(FP16x16 { mag: 245760, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear_align_corners.cairo b/tests/nodes/resize_upsample_scales_linear_align_corners.cairo new file mode 100644 index 000000000..d58d1e15c --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_align_corners.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_linear_align_corners() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_linear_align_corners/input_0.cairo b/tests/nodes/resize_upsample_scales_linear_align_corners/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_align_corners/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear_align_corners/input_1.cairo b/tests/nodes/resize_upsample_scales_linear_align_corners/input_1.cairo new file mode 100644 index 000000000..75bd50d8a --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_align_corners/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear_align_corners/output_0.cairo b/tests/nodes/resize_upsample_scales_linear_align_corners/output_0.cairo new file mode 100644 index 000000000..9a136353c --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_align_corners/output_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 87381, sign: false }); + data.append(FP16x16 { mag: 109226, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 109226, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 152917, sign: false }); + data.append(FP16x16 { mag: 174762, sign: false }); + data.append(FP16x16 { mag: 152917, sign: false }); + data.append(FP16x16 { mag: 174762, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 218453, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 218453, sign: false }); + data.append(FP16x16 { mag: 240298, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric.cairo b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric.cairo new file mode 100644 index 000000000..aa07b6dd3 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_linear_half_pixel_symmetric() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::LINEAR), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_0.cairo b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_1.cairo b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_1.cairo new file mode 100644 index 000000000..ef375c64b --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 150732, sign: false }); + data.append(FP16x16 { mag: 192675, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/output_0.cairo b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/output_0.cairo new file mode 100644 index 000000000..4ab735561 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_linear_half_pixel_symmetric/output_0.cairo @@ -0,0 +1,35 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 76012, sign: false }); + data.append(FP16x16 { mag: 98304, sign: false }); + data.append(FP16x16 { mag: 120595, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 102578, sign: false }); + data.append(FP16x16 { mag: 113054, sign: false }); + data.append(FP16x16 { mag: 135346, sign: false }); + data.append(FP16x16 { mag: 157637, sign: false }); + data.append(FP16x16 { mag: 168114, sign: false }); + data.append(FP16x16 { mag: 159565, sign: false }); + data.append(FP16x16 { mag: 170042, sign: false }); + data.append(FP16x16 { mag: 192333, sign: false }); + data.append(FP16x16 { mag: 214625, sign: false }); + data.append(FP16x16 { mag: 225101, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 207084, sign: false }); + data.append(FP16x16 { mag: 229376, sign: false }); + data.append(FP16x16 { mag: 251667, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest.cairo b/tests/nodes/resize_upsample_scales_nearest.cairo new file mode 100644 index 000000000..f4a6c422f --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_nearest() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::HALF_PIXEL_SYMMETRIC), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_nearest/input_0.cairo b/tests/nodes/resize_upsample_scales_nearest/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest/input_1.cairo b/tests/nodes/resize_upsample_scales_nearest/input_1.cairo new file mode 100644 index 000000000..2f880ced5 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest/output_0.cairo b/tests/nodes/resize_upsample_scales_nearest/output_0.cairo new file mode 100644 index 000000000..61e0d0daf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest/output_0.cairo @@ -0,0 +1,39 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_2_3.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_2_3.cairo new file mode 100644 index 000000000..b492b7128 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_2_3.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_nearest_axes_2_3() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_0.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_1.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_1.cairo new file mode 100644 index 000000000..01532dbdb --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_2.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_2.cairo new file mode 100644 index 000000000..6388a13b0 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_2_3/output_0.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/output_0.cairo new file mode 100644 index 000000000..61e0d0daf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_2_3/output_0.cairo @@ -0,0 +1,39 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_3_2.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_3_2.cairo new file mode 100644 index 000000000..51b857b25 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_3_2.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_scales_nearest_axes_3_2() { + let data = input_0::input_0(); + let mut scales = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + scales, + Option::None, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_0.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_1.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_1.cairo new file mode 100644 index 000000000..f6a98ccbd --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_1.cairo @@ -0,0 +1,14 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_2.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_2.cairo new file mode 100644 index 000000000..775bcc066 --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_scales_nearest_axes_3_2/output_0.cairo b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/output_0.cairo new file mode 100644 index 000000000..61e0d0daf --- /dev/null +++ b/tests/nodes/resize_upsample_scales_nearest_axes_3_2/output_0.cairo @@ -0,0 +1,39 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_cubic.cairo b/tests/nodes/resize_upsample_sizes_cubic.cairo new file mode 100644 index 000000000..cfeee00ba --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_cubic.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_cubic() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::CUBIC), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_cubic/input_0.cairo b/tests/nodes/resize_upsample_sizes_cubic/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_cubic/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_cubic/input_1.cairo b/tests/nodes/resize_upsample_sizes_cubic/input_1.cairo new file mode 100644 index 000000000..ac41c83c4 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_cubic/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(9); + data.append(10); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_cubic/output_0.cairo b/tests/nodes/resize_upsample_sizes_cubic/output_0.cairo new file mode 100644 index 000000000..7363ef635 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_cubic/output_0.cairo @@ -0,0 +1,105 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(9); + shape.append(10); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 29824, sign: false }); + data.append(FP16x16 { mag: 41981, sign: false }); + data.append(FP16x16 { mag: 63673, sign: false }); + data.append(FP16x16 { mag: 93230, sign: false }); + data.append(FP16x16 { mag: 124998, sign: false }); + data.append(FP16x16 { mag: 145708, sign: false }); + data.append(FP16x16 { mag: 177476, sign: false }); + data.append(FP16x16 { mag: 207033, sign: false }); + data.append(FP16x16 { mag: 228725, sign: false }); + data.append(FP16x16 { mag: 240882, sign: false }); + data.append(FP16x16 { mag: 91382, sign: false }); + data.append(FP16x16 { mag: 103538, sign: false }); + data.append(FP16x16 { mag: 125231, sign: false }); + data.append(FP16x16 { mag: 154788, sign: false }); + data.append(FP16x16 { mag: 186556, sign: false }); + data.append(FP16x16 { mag: 207266, sign: false }); + data.append(FP16x16 { mag: 239034, sign: false }); + data.append(FP16x16 { mag: 268591, sign: false }); + data.append(FP16x16 { mag: 290283, sign: false }); + data.append(FP16x16 { mag: 302440, sign: false }); + data.append(FP16x16 { mag: 193416, sign: false }); + data.append(FP16x16 { mag: 205573, sign: false }); + data.append(FP16x16 { mag: 227266, sign: false }); + data.append(FP16x16 { mag: 256822, sign: false }); + data.append(FP16x16 { mag: 288591, sign: false }); + data.append(FP16x16 { mag: 309300, sign: false }); + data.append(FP16x16 { mag: 341069, sign: false }); + data.append(FP16x16 { mag: 370626, sign: false }); + data.append(FP16x16 { mag: 392318, sign: false }); + data.append(FP16x16 { mag: 404475, sign: false }); + data.append(FP16x16 { mag: 341131, sign: false }); + data.append(FP16x16 { mag: 353288, sign: false }); + data.append(FP16x16 { mag: 374980, sign: false }); + data.append(FP16x16 { mag: 404537, sign: false }); + data.append(FP16x16 { mag: 436305, sign: false }); + data.append(FP16x16 { mag: 457015, sign: false }); + data.append(FP16x16 { mag: 488783, sign: false }); + data.append(FP16x16 { mag: 518340, sign: false }); + data.append(FP16x16 { mag: 540033, sign: false }); + data.append(FP16x16 { mag: 552190, sign: false }); + data.append(FP16x16 { mag: 451526, sign: false }); + data.append(FP16x16 { mag: 463683, sign: false }); + data.append(FP16x16 { mag: 485376, sign: false }); + data.append(FP16x16 { mag: 514932, sign: false }); + data.append(FP16x16 { mag: 546701, sign: false }); + data.append(FP16x16 { mag: 567410, sign: false }); + data.append(FP16x16 { mag: 599179, sign: false }); + data.append(FP16x16 { mag: 628736, sign: false }); + data.append(FP16x16 { mag: 650428, sign: false }); + data.append(FP16x16 { mag: 662585, sign: false }); + data.append(FP16x16 { mag: 561922, sign: false }); + data.append(FP16x16 { mag: 574078, sign: false }); + data.append(FP16x16 { mag: 595771, sign: false }); + data.append(FP16x16 { mag: 625328, sign: false }); + data.append(FP16x16 { mag: 657096, sign: false }); + data.append(FP16x16 { mag: 677806, sign: false }); + data.append(FP16x16 { mag: 709574, sign: false }); + data.append(FP16x16 { mag: 739131, sign: false }); + data.append(FP16x16 { mag: 760823, sign: false }); + data.append(FP16x16 { mag: 772980, sign: false }); + data.append(FP16x16 { mag: 709636, sign: false }); + data.append(FP16x16 { mag: 721793, sign: false }); + data.append(FP16x16 { mag: 743485, sign: false }); + data.append(FP16x16 { mag: 773042, sign: false }); + data.append(FP16x16 { mag: 804811, sign: false }); + data.append(FP16x16 { mag: 825520, sign: false }); + data.append(FP16x16 { mag: 857289, sign: false }); + data.append(FP16x16 { mag: 886845, sign: false }); + data.append(FP16x16 { mag: 908538, sign: false }); + data.append(FP16x16 { mag: 920695, sign: false }); + data.append(FP16x16 { mag: 811671, sign: false }); + data.append(FP16x16 { mag: 823828, sign: false }); + data.append(FP16x16 { mag: 845520, sign: false }); + data.append(FP16x16 { mag: 875077, sign: false }); + data.append(FP16x16 { mag: 906845, sign: false }); + data.append(FP16x16 { mag: 927555, sign: false }); + data.append(FP16x16 { mag: 959323, sign: false }); + data.append(FP16x16 { mag: 988880, sign: false }); + data.append(FP16x16 { mag: 1010573, sign: false }); + data.append(FP16x16 { mag: 1022729, sign: false }); + data.append(FP16x16 { mag: 873229, sign: false }); + data.append(FP16x16 { mag: 885386, sign: false }); + data.append(FP16x16 { mag: 907078, sign: false }); + data.append(FP16x16 { mag: 936635, sign: false }); + data.append(FP16x16 { mag: 968403, sign: false }); + data.append(FP16x16 { mag: 989113, sign: false }); + data.append(FP16x16 { mag: 1020881, sign: false }); + data.append(FP16x16 { mag: 1050438, sign: false }); + data.append(FP16x16 { mag: 1072131, sign: false }); + data.append(FP16x16 { mag: 1084287, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest.cairo b/tests/nodes/resize_upsample_sizes_nearest.cairo new file mode 100644 index 000000000..2c766a209 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest.cairo @@ -0,0 +1,39 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest/input_1.cairo new file mode 100644 index 000000000..816910620 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(7); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest/output_0.cairo new file mode 100644 index 000000000..592bfbcd8 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest/output_0.cairo @@ -0,0 +1,71 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(7); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_2_3.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3.cairo new file mode 100644 index 000000000..8074ecf36 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_axes_2_3() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_1.cairo new file mode 100644 index 000000000..5fd88f168 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(7); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_2.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_2.cairo new file mode 100644 index 000000000..6388a13b0 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/output_0.cairo new file mode 100644 index 000000000..592bfbcd8 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_2_3/output_0.cairo @@ -0,0 +1,71 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(7); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_3_2.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2.cairo new file mode 100644 index 000000000..8edc4daa2 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2.cairo @@ -0,0 +1,45 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; + +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_axes_3_2() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_1.cairo new file mode 100644 index 000000000..a1e135ae0 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(8); + data.append(7); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_2.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_2.cairo new file mode 100644 index 000000000..775bcc066 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(3); + data.append(2); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/output_0.cairo new file mode 100644 index 000000000..592bfbcd8 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_axes_3_2/output_0.cairo @@ -0,0 +1,71 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(7); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel.cairo b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel.cairo new file mode 100644 index 000000000..51a9f5216 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel.cairo @@ -0,0 +1,43 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; + +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_ceil_half_pixel() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::HALF_PIXEL), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::Some(NEAREST_MODE::CEIL), + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_1.cairo new file mode 100644 index 000000000..1b0cce2c6 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(8); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/output_0.cairo new file mode 100644 index 000000000..9c8571539 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_ceil_half_pixel/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners.cairo b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners.cairo new file mode 100644 index 000000000..172beaa84 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_floor_align_corners() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ALIGN_CORNERS), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::Some(NEAREST_MODE::FLOOR), + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_1.cairo new file mode 100644 index 000000000..1b0cce2c6 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(8); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/output_0.cairo new file mode 100644 index 000000000..bcfaad1bd --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_floor_align_corners/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_larger.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_larger.cairo new file mode 100644 index 000000000..4e5e6221b --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_larger.cairo @@ -0,0 +1,45 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; + +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_not_larger() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_LARGER), + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_1.cairo new file mode 100644 index 000000000..5fd88f168 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(7); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_2.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_2.cairo new file mode 100644 index 000000000..6388a13b0 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_larger/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_larger/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_larger/output_0.cairo new file mode 100644 index 000000000..4bbc98a19 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_larger/output_0.cairo @@ -0,0 +1,64 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(7); + shape.append(7); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_smaller.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_smaller.cairo new file mode 100644 index 000000000..25563fa16 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_smaller.cairo @@ -0,0 +1,45 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::operators::tensor::U32Tensor; + +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_not_smaller() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let axes = Option::Some(input_2::input_2().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + axes, + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(KEEP_ASPECT_RATIO_POLICY::NOT_SMALLER), + Option::Some(MODE::NEAREST), + Option::None, + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_0.cairo new file mode 100644 index 000000000..84a850bc4 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_1.cairo new file mode 100644 index 000000000..5fd88f168 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_1.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(7); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_2.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_2.cairo new file mode 100644 index 000000000..6388a13b0 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/input_2.cairo @@ -0,0 +1,13 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(2); + data.append(3); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_not_smaller/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/output_0.cairo new file mode 100644 index 000000000..aeac198d9 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_not_smaller/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric.cairo b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric.cairo new file mode 100644 index 000000000..27cc03ec8 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric.cairo @@ -0,0 +1,41 @@ +mod input_0; +mod input_1; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32TensorPartialEq; +use orion::operators::tensor::math::resize::{ + MODE, NEAREST_MODE, KEEP_ASPECT_RATIO_POLICY, TRANSFORMATION_MODE +}; + +#[test] +#[available_gas(2000000000)] +fn test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric() { + let data = input_0::input_0(); + let mut sizes = Option::Some(input_1::input_1().data); + let z_0 = output_0::output_0(); + + let y_0 = data + .resize( + Option::None, + Option::None, + sizes, + Option::None, + Option::None, + Option::Some(TRANSFORMATION_MODE::ASYMMETRIC), + Option::None, + Option::None, + Option::None, + Option::None, + Option::Some(MODE::NEAREST), + Option::Some(NEAREST_MODE::ROUND_PREFER_CEIL), + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_0.cairo new file mode 100644 index 000000000..cd17b9adf --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(4); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_1.cairo b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_1.cairo new file mode 100644 index 000000000..1b0cce2c6 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/input_1.cairo @@ -0,0 +1,15 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(1); + data.append(1); + data.append(8); + data.append(8); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/output_0.cairo b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/output_0.cairo new file mode 100644 index 000000000..9c8571539 --- /dev/null +++ b/tests/nodes/resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric/output_0.cairo @@ -0,0 +1,79 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(1); + shape.append(1); + shape.append(8); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 65536, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 196608, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 327680, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 393216, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 458752, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 524288, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 589824, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 655360, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 720896, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 851968, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 917504, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 983040, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + data.append(FP16x16 { mag: 1048576, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} From 5c783620455f098aae2320917267db5fdbf969bd Mon Sep 17 00:00:00 2001 From: chachaleo Date: Sat, 6 Jan 2024 11:00:57 +0100 Subject: [PATCH 21/38] small fix --- tests/lib.cairo | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/lib.cairo b/tests/lib.cairo index 661d3cf0e..c408347ef 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -1,8 +1,8 @@ -//mod numbers; -//mod performance; -//mod tensor_core; +mod numbers; +mod performance; +mod tensor_core; mod nodes; -//mod ml; -//mod operators; +mod ml; +mod operators; From 14bdb6bfb5af95cfbbbfcca62fd93951c43f6dd5 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sat, 6 Jan 2024 09:49:35 -0800 Subject: [PATCH 22/38] remove unused prints --- src/operators/ml/tree_ensemble/tree_ensemble_regressor.cairo | 1 - src/operators/tensor/math/bitwise_and.cairo | 2 -- tests/ml/tree_ensemble_regressor.cairo | 3 --- 3 files changed, 6 deletions(-) diff --git a/src/operators/ml/tree_ensemble/tree_ensemble_regressor.cairo b/src/operators/ml/tree_ensemble/tree_ensemble_regressor.cairo index 7c2efcb97..9848efd9e 100644 --- a/src/operators/ml/tree_ensemble/tree_ensemble_regressor.cairo +++ b/src/operators/ml/tree_ensemble/tree_ensemble_regressor.cairo @@ -434,7 +434,6 @@ fn compute_res_AVERAGE< ) { let n_trees_felt: felt252 = (n_trees * 65536).into(); let n_trees: T = NumberTrait::from_felt(n_trees_felt); - n_trees_felt.print(); loop { match t_index.pop_front() { Option::Some(its) => { diff --git a/src/operators/tensor/math/bitwise_and.cairo b/src/operators/tensor/math/bitwise_and.cairo index c27d77938..e3487568b 100644 --- a/src/operators/tensor/math/bitwise_and.cairo +++ b/src/operators/tensor/math/bitwise_and.cairo @@ -20,12 +20,10 @@ fn bitwise_and< >( y: @Tensor, z: @Tensor ) -> Tensor { - 'check_compatibility'.print(); let broadcasted_shape = broadcast_shape(*y.shape, *z.shape); let mut result: Array = ArrayTrait::::new(); let num_elements = len_from_shape(broadcasted_shape); - 'checked'.print(); let mut n: usize = 0; loop { diff --git a/tests/ml/tree_ensemble_regressor.cairo b/tests/ml/tree_ensemble_regressor.cairo index 23c9274a7..2ee505774 100644 --- a/tests/ml/tree_ensemble_regressor.cairo +++ b/tests/ml/tree_ensemble_regressor.cairo @@ -38,7 +38,6 @@ fn test_tree_ensemble_regressor_AVERAGE() { let mut res = TreeEnsembleRegressorTrait::predict(ref regressor, X); let check = @res.get(1, 0).unwrap().mag; - (*check).print(); // ASSERT RES assert( @@ -63,7 +62,6 @@ fn test_tree_ensemble_regressor_MIN() { let mut res = TreeEnsembleRegressorTrait::predict(ref regressor, X); let check = @res.get(1, 0).unwrap().mag; - (*check).print(); // ASSERT RES assert( @@ -88,7 +86,6 @@ fn test_tree_ensemble_regressor_MAX() { let mut res = TreeEnsembleRegressorTrait::predict(ref regressor, X); let check = @res.get(1, 0).unwrap().mag; - (*check).print(); // ASSERT RES assert( From 4a0de581fadb21254b4c6416d33f69eb2c56e48d Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 7 Jan 2024 19:47:33 -0500 Subject: [PATCH 23/38] remove sequences functions from tensor complex --- .../implementations/tensor_complex64.cairo | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 810f347ca..31ff3f9e6 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -419,29 +419,12 @@ impl Complex64Tensor of TensorTrait { math::gather_elements::gather_elements(self, indices, axis) } - fn sequence_length(self: Array>) -> Tensor { - math::sequence_length::sequence_length(self) - } - fn shrink( self: Tensor, bias: Option, lambd: Option ) -> Tensor { panic(array!['not supported!']) } - fn sequence_at(sequence: Array>, position: Tensor) -> Tensor { - math::sequence_at::sequence_at(sequence, position) - } - - fn sequence_construct(tensors: Array>) -> Array> { - math::sequence_construct::sequence_construct(tensors) - } - - - fn sequence_empty() -> Array> { - math::sequence_empty::sequence_empty::() - } - fn reduce_mean( self: @Tensor, axes: Option>, @@ -464,18 +447,6 @@ impl Complex64Tensor of TensorTrait { math::pow::pow(self, other) } - fn sequence_erase( - sequence: Array>, position: Option> - ) -> Array> { - math::sequence_erase::sequence_erase(sequence, position) - } - - fn sequence_insert( - self: Array>, tensor: @Tensor, position: Option> - ) -> Array> { - math::sequence_insert::sequence_insert(self, tensor, position) - } - fn is_inf( self: @Tensor, detect_negative: Option, detect_positive: Option ) -> Tensor { @@ -486,12 +457,6 @@ impl Complex64Tensor of TensorTrait { panic(array!['not supported!']) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn reduce_log_sum(self: @Tensor, axis: usize, keepdims: bool) -> Tensor { math::reduce_log_sum::reduce_log_sum(self, axis, keepdims) } From f504b867c1db4a61b090b98fad692be2c4bca713 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Sun, 7 Jan 2024 19:55:54 -0500 Subject: [PATCH 24/38] Update tensor_complex64.cairo --- src/operators/tensor/implementations/tensor_complex64.cairo | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 31ff3f9e6..53feb8980 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -471,6 +471,10 @@ impl Complex64Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { panic(array!['not supported!']) } + + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + math::compress::compress(self, condition, axis) + } } /// Implements addition for `Tensor` using the `Add` trait. From 1a73ddb2cf1c524df6d9eafd3a3a06a30f814299 Mon Sep 17 00:00:00 2001 From: Vid Kersic Date: Sat, 6 Jan 2024 23:12:05 +0100 Subject: [PATCH 25/38] fix: enable 1d tensor transpose --- src/operators/tensor/linalg/transpose.cairo | 5 ++++- tests/operators/transpose_test.cairo | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/operators/tensor/linalg/transpose.cairo b/src/operators/tensor/linalg/transpose.cairo index 44df929ba..c7bb96da7 100644 --- a/src/operators/tensor/linalg/transpose.cairo +++ b/src/operators/tensor/linalg/transpose.cairo @@ -11,7 +11,10 @@ use orion::numbers::NumberTrait; fn transpose, impl TCopy: Copy, impl TDrop: Drop>( self: @Tensor, axes: Span ) -> Tensor { - assert((*self.shape).len() > 1, 'cannot transpose a 1D tensor'); + if (*self.shape).len() == 1 { + return self.identity(); + } + assert(axes.len() == (*self.shape).len(), 'shape and axes length unequal'); if (*self.shape).len() == 2 { diff --git a/tests/operators/transpose_test.cairo b/tests/operators/transpose_test.cairo index ffe4e663c..ee720be28 100644 --- a/tests/operators/transpose_test.cairo +++ b/tests/operators/transpose_test.cairo @@ -25,6 +25,18 @@ fn transpose_test_values() { assert(result.data == array![0, 2, 4, 6, 1, 3, 5, 7].span(), 'wrong data'); } +#[test] +#[available_gas(200000000000)] +fn transpose_test_1D() { + let tensor = TensorTrait::< + u32 + >::new(shape: array![4].span(), data: array![0, 1, 2, 3].span(),); + + let result = tensor.transpose(axes: array![0].span()); + + assert(result.shape == array![4].span(), 'wrong shape'); + assert(result.data == array![0, 1, 2, 3].span(), 'wrong data'); +} #[test] #[available_gas(200000000000)] From a913e5fb8ebaae82c067f60f7cf038c08edf3736 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Mon, 8 Jan 2024 11:10:43 +0100 Subject: [PATCH 26/38] google colab for tree ensemble reg and softmax0 for lin reg and classifier --- .../tree-ensemble-regressor/README.md | 3 + .../ml/linear/linear_classifier.cairo | 20 ++++-- .../ml/linear/linear_regressor.cairo | 4 +- tests/ml/linear_classifier_test.cairo | 61 +++++++++++++++++++ 4 files changed, 79 insertions(+), 9 deletions(-) diff --git a/docs/framework/operators/machine-learning/tree-ensemble-regressor/README.md b/docs/framework/operators/machine-learning/tree-ensemble-regressor/README.md index cbd3747cf..ae21bfcb1 100644 --- a/docs/framework/operators/machine-learning/tree-ensemble-regressor/README.md +++ b/docs/framework/operators/machine-learning/tree-ensemble-regressor/README.md @@ -14,6 +14,9 @@ Orion supports currently only fixed point data types for `TreeEnsembleRegressorT | -------------------- | ------------------------------------------------------------- | | Fixed point (signed) | `TreeRegressorTrait` | +### How to construct `TreeEnsembleRegressor` + +You can utilize [this notebook](https://colab.research.google.com/drive/1zZC0tM7I5Mt542_cBsxaWcGPWzgxybGs?usp=sharing#scrollTo=VkXxLxDejrf3) to translate parameters from your ONNX TreeEnsembleRegressor model into Cairo code. Efforts are underway to integrate this functionality into Giza-CLI, aiming to enhance the user experience. *** diff --git a/src/operators/ml/linear/linear_classifier.cairo b/src/operators/ml/linear/linear_classifier.cairo index 21e524e4c..2230500bc 100644 --- a/src/operators/ml/linear/linear_classifier.cairo +++ b/src/operators/ml/linear/linear_classifier.cairo @@ -196,9 +196,7 @@ impl LinearClassifierImpl< POST_TRANSFORM::NONE => { scores }, POST_TRANSFORM::SOFTMAX => { NNTrait::softmax(@scores, 1) }, POST_TRANSFORM::LOGISTIC => { NNTrait::sigmoid(@scores) }, - POST_TRANSFORM::SOFTMAXZERO => core::panic_with_felt252( - 'Softmax_zero not supported yet' - ), + POST_TRANSFORM::SOFTMAXZERO => { NNTrait::softmax_zero(@scores, 1)}, POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), }; @@ -254,9 +252,19 @@ impl LinearClassifierImpl< i += 1; }; }, - POST_TRANSFORM::SOFTMAXZERO => core::panic_with_felt252( - 'Softmax_zero not supported yet' - ), + POST_TRANSFORM::SOFTMAXZERO => { + loop { + if i == scores.data.len() { + break; + } + if *scores.data.at(i) >= NumberTrait::half() { + labels_list.append(*classlabels[0]); + } else { + labels_list.append(0); + } + i += 1; + }; + }, POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), }; } diff --git a/src/operators/ml/linear/linear_regressor.cairo b/src/operators/ml/linear/linear_regressor.cairo index c9374a4cd..85aec9560 100644 --- a/src/operators/ml/linear/linear_regressor.cairo +++ b/src/operators/ml/linear/linear_regressor.cairo @@ -215,9 +215,7 @@ impl LinearRegressorImpl< POST_TRANSFORM::NONE => score, // No action required POST_TRANSFORM::SOFTMAX => NNTrait::softmax(@score, 1), POST_TRANSFORM::LOGISTIC => NNTrait::sigmoid(@score), - POST_TRANSFORM::SOFTMAXZERO => core::panic_with_felt252( - 'Softmax_zero not supported yet' - ), + POST_TRANSFORM::SOFTMAXZERO => NNTrait::softmax_zero(@score, 1), POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), }; diff --git a/tests/ml/linear_classifier_test.cairo b/tests/ml/linear_classifier_test.cairo index 3ed27cfcc..1a1c90e7d 100644 --- a/tests/ml/linear_classifier_test.cairo +++ b/tests/ml/linear_classifier_test.cairo @@ -57,6 +57,31 @@ fn test_linear_classifier_multi_softmax() { assert(*scores.data[8] == FP16x16 { mag: 57241, sign: false }, '*scores[8] == 0.87344'); } +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_multi_softmax_zero() { + let (mut classifier, X) = linear_classifier_helper(POST_TRANSFORM::SOFTMAXZERO); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 0, 'labels[0]'); + assert(*labels[1] == 2, 'labels[1]'); + assert(*labels[2] == 2, 'labels[2]'); + assert(labels.len() == 3, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 55879, sign: false }, '*scores[0] == 0.852656'); + assert(*scores.data[1] == FP16x16 { mag: 602, sign: false }, '*scores[1] == 0.009192'); + assert(*scores.data[2] == FP16x16 { mag: 9053, sign: false }, '*scores[2] == 0.138152'); + assert(*scores.data[3] == FP16x16 { mag: 20888, sign: false }, '*scores[3] == 0.318722'); + assert(*scores.data[4] == FP16x16 { mag: 3418, sign: false }, '*scores[4] == 0.05216'); + assert(*scores.data[5] == FP16x16 { mag: 41229, sign: false }, '*scores[5] == 0.629118'); + assert(*scores.data[6] == FP16x16 { mag: 2380, sign: false }, '*scores[6] == 0.036323'); + assert(*scores.data[7] == FP16x16 { mag: 5914, sign: false }, '*scores[7] == 0.090237'); + assert(*scores.data[8] == FP16x16 { mag: 57241, sign: false }, '*scores[8] == 0.87344'); +} + #[test] #[available_gas(200000000000)] @@ -139,6 +164,25 @@ fn test_linear_classifier_binary_softmax() { assert(*scores.data[3] == FP16x16 { mag: 65535, sign: false }, '*scores[3] == 9.999983e-01'); } +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_binary_softmax_zero() { + let (mut classifier, X) = linear_classifier_helper_binary(POST_TRANSFORM::SOFTMAXZERO); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 1, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 0, sign: false }, '*scores[0] == 5.276517e-09'); + assert(*scores.data[1] == FP16x16 { mag: 65535, sign: false }, '*scores[1] == 1.000000'); + assert(*scores.data[2] == FP16x16 { mag: 0, sign: false }, '*scores[2] == 1.674492e-06'); + assert(*scores.data[3] == FP16x16 { mag: 65535, sign: false }, '*scores[3] == 9.999983e-01'); + +} + #[test] #[available_gas(200000000000)] fn test_linear_classifier_unary_none() { @@ -190,6 +234,23 @@ fn test_linear_classifier_unary_softmax() { assert(*scores.data[1] == FP16x16 { mag: 65536, sign: false }, '*scores[1] == 1'); } +#[test] +#[available_gas(200000000000)] +fn test_linear_classifier_unary_softmax_zero() { + let (mut classifier, X) = linear_classifier_helper_unary(POST_TRANSFORM::SOFTMAXZERO); + + let (labels, mut scores) = LinearClassifierTrait::predict(ref classifier, X); + + // ASSERT LABELS + assert(*labels[0] == 1, 'labels[0]'); + assert(*labels[1] == 1, 'labels[1]'); + assert(labels.len() == 2, 'len(labels)'); + + // ASSERT SCORES + assert(*scores.data[0] == FP16x16 { mag: 65536, sign: false }, '*scores[0] == 1'); + assert(*scores.data[1] == FP16x16 { mag: 65536, sign: false }, '*scores[1] == 1'); +} + // ============ HELPER ============ // From 1cb65e2ec1e732daafb242a25f3404747a0fb29e Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Wed, 10 Jan 2024 16:05:37 +0800 Subject: [PATCH 27/38] feat: feat split operator in orion --- .../operators/tensor/tensor.split.md | 47 ++++ nodegen/node/split.py | 261 ++++++++++++++++++ src/operators/tensor/core.cairo | 51 ++++ .../tensor/implementations/tensor_bool.cairo | 4 + .../implementations/tensor_fp16x16.cairo | 6 + .../implementations/tensor_fp16x16wide.cairo | 6 + .../implementations/tensor_fp32x32.cairo | 6 + .../implementations/tensor_fp64x64.cairo | 6 + .../implementations/tensor_fp8x23.cairo | 6 + .../implementations/tensor_fp8x23wide.cairo | 6 + .../tensor/implementations/tensor_i32.cairo | 6 + .../tensor/implementations/tensor_i8.cairo | 6 + .../tensor/implementations/tensor_u32.cairo | 6 + src/operators/tensor/manipulation.cairo | 1 + src/operators/tensor/manipulation/split.cairo | 195 +++++++++++++ tests/nodes.cairo | 14 + .../input_0.cairo | 18 ++ .../output_0.cairo | 37 +++ .../split_fp16x16_1d_uneven/input_0.cairo | 19 ++ .../split_fp16x16_1d_uneven/output_0.cairo | 45 +++ .../input_0.cairo | 18 ++ .../output_0.cairo | 30 ++ .../input_0.cairo | 25 ++ .../output_0.cairo | 38 +++ .../split_fp16x16_2d_uneven/input_0.cairo | 29 ++ .../split_fp16x16_2d_uneven/output_0.cairo | 50 ++++ .../input_0.cairo | 25 ++ .../output_0.cairo | 38 +++ .../split_fp16x16_zero_size/input_0.cairo | 12 + .../split_fp16x16_zero_size/output_0.cairo | 31 +++ .../split_u32_1d_equal_parts/input_0.cairo | 17 ++ .../split_u32_1d_equal_parts/output_0.cairo | 36 +++ tests/nodes/split_u32_1d_uneven/input_0.cairo | 18 ++ .../nodes/split_u32_1d_uneven/output_0.cairo | 44 +++ .../split_u32_1d_variable_parts/input_0.cairo | 17 ++ .../output_0.cairo | 29 ++ .../split_u32_2d_equal_parts/input_0.cairo | 24 ++ .../split_u32_2d_equal_parts/output_0.cairo | 37 +++ tests/nodes/split_u32_2d_uneven/input_0.cairo | 28 ++ .../nodes/split_u32_2d_uneven/output_0.cairo | 49 ++++ .../split_u32_2d_variable_parts/input_0.cairo | 24 ++ .../output_0.cairo | 37 +++ tests/nodes/split_u32_zero_size/input_0.cairo | 11 + .../nodes/split_u32_zero_size/output_0.cairo | 30 ++ 44 files changed, 1443 insertions(+) create mode 100644 docs/framework/operators/tensor/tensor.split.md create mode 100644 nodegen/node/split.py create mode 100644 src/operators/tensor/manipulation/split.cairo create mode 100644 tests/nodes/split_fp16x16_1d_equal_parts/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_1d_equal_parts/output_0.cairo create mode 100644 tests/nodes/split_fp16x16_1d_uneven/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_1d_uneven/output_0.cairo create mode 100644 tests/nodes/split_fp16x16_1d_variable_parts/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_1d_variable_parts/output_0.cairo create mode 100644 tests/nodes/split_fp16x16_2d_equal_parts/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_2d_equal_parts/output_0.cairo create mode 100644 tests/nodes/split_fp16x16_2d_uneven/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_2d_uneven/output_0.cairo create mode 100644 tests/nodes/split_fp16x16_2d_variable_parts/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_2d_variable_parts/output_0.cairo create mode 100644 tests/nodes/split_fp16x16_zero_size/input_0.cairo create mode 100644 tests/nodes/split_fp16x16_zero_size/output_0.cairo create mode 100644 tests/nodes/split_u32_1d_equal_parts/input_0.cairo create mode 100644 tests/nodes/split_u32_1d_equal_parts/output_0.cairo create mode 100644 tests/nodes/split_u32_1d_uneven/input_0.cairo create mode 100644 tests/nodes/split_u32_1d_uneven/output_0.cairo create mode 100644 tests/nodes/split_u32_1d_variable_parts/input_0.cairo create mode 100644 tests/nodes/split_u32_1d_variable_parts/output_0.cairo create mode 100644 tests/nodes/split_u32_2d_equal_parts/input_0.cairo create mode 100644 tests/nodes/split_u32_2d_equal_parts/output_0.cairo create mode 100644 tests/nodes/split_u32_2d_uneven/input_0.cairo create mode 100644 tests/nodes/split_u32_2d_uneven/output_0.cairo create mode 100644 tests/nodes/split_u32_2d_variable_parts/input_0.cairo create mode 100644 tests/nodes/split_u32_2d_variable_parts/output_0.cairo create mode 100644 tests/nodes/split_u32_zero_size/input_0.cairo create mode 100644 tests/nodes/split_u32_zero_size/output_0.cairo diff --git a/docs/framework/operators/tensor/tensor.split.md b/docs/framework/operators/tensor/tensor.split.md new file mode 100644 index 000000000..26b4a546f --- /dev/null +++ b/docs/framework/operators/tensor/tensor.split.md @@ -0,0 +1,47 @@ +# tensor.split + +```rust + fn split(self: @Tensor, axis: usize, num_outputs: Option, split: Option> + ) -> Array>; +``` + +Split a tensor into a list of tensors, along the specified β€˜axis’ + + +* `self`(`@Tensor`) - The input tensor. +* `axis`(`usize`) - The axis along which to split on. +* `num_outputs `(Option) - Number of outputs to split parts of the tensor into. +* `split `(Option>) - Optional length of each output. + +## Panics + +* Panics if the 'axis' accepted range is not [-rank, rank-1] where r = rank(input). +* Panics if the 'split' values not >= 0. Sum of the values is not equal to the dim value at β€˜axis’ specified. +* Panics if the input 'split' or the attribute 'num_outputs' both are specified or not. + +## Returns + +One or more outputs forming list of tensors after splitting. + +## Examples + +```rust +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; +use core::option::OptionTrait; +fn split_tensor_example() -> Array> { + let tensor: Tensor = TensorTrait::::new( + shape: array![2,4].span(), + data: array![ + 0, 1, 2, 3, 4, 5, 6, 7 + ].span(), + ); + let num_outputs = Option::Some(2); + // split = Option::Some(array![1, 1].span()); + let split_num: Option> = Option::None(()); + // We can call `split` function as follows. + return tensor.split(0, num_outputs, split_num); +} +>>> [[0,1],[4,5]] + [[2,3],[6,7]] +``` diff --git a/nodegen/node/split.py b/nodegen/node/split.py new file mode 100644 index 000000000..8e765141e --- /dev/null +++ b/nodegen/node/split.py @@ -0,0 +1,261 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl + + +class Split(RunAll): + @staticmethod + def split_u32(): + def split_1D(): + x = np.random.randint(0, 255, 6).astype(np.uint32) + y = [ + np.array(x[0:2]).astype(np.uint32), + np.array(x[2:4]).astype(np.uint32), + np.array(x[4:6]).astype(np.uint32), + ] + + _x = Tensor(Dtype.U32, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + Tensor(Dtype.U32, y[2].shape, y[2].flatten()), + ] + + name = "split_u32_1d_equal_parts" + make_test( + [_x], _y, "input_0.split(0, Option::Some(3), Option::None(()))", name) + y = [ + np.array(x[0:2]).astype(np.uint32), + np.array(x[2:6]).astype(np.uint32), + ] + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + ] + name = "split_u32_1d_variable_parts" + make_test( + [_x], _y, "input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),)))", name) + def split_2D(): + x = np.random.randint(0, 255, (2, 6)).astype(np.uint32) + y = [ + np.array(x[0:2, 0:3]).astype(np.uint32), + np.array(x[0:2, 3:6]).astype(np.uint32), + ] + _x = Tensor(Dtype.U32, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + ] + name = "split_u32_2d_equal_parts" + make_test( + [_x], _y, "input_0.split(1, Option::Some(2), Option::None(()))", name) + + y = [ + np.array(x[0:2, 0:2]).astype(np.uint32), + np.array(x[0:2, 2:6]).astype(np.uint32) + ] + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + ] + name = "split_u32_2d_variable_parts" + make_test( + [_x], _y, "input_0.split(1, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),)))", name) + + def split_zero_size(): + # 1-dimensional tensor with dimension_size=0 + x = np.array([]).astype(np.uint32) + y = [ + np.array([]).astype(np.uint32), + np.array([]).astype(np.uint32), + np.array([]).astype(np.uint32), + ] + _x = Tensor(Dtype.U32, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + Tensor(Dtype.U32, y[2].shape, y[2].flatten()), + ] + # Split emtpy tensor to tensors of size zero + name = "split_u32_zero_size" + make_test( + [_x], _y, "input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),)))", name) + + + def split_1d_uneven(): + x = np.random.randint(0, 255, 7).astype(np.uint32) + y = [ + np.array(x[0:2]).astype(np.uint32), + np.array(x[2:4]).astype(np.uint32), + np.array(x[4:6]).astype(np.uint32), + np.array(x[6:7]).astype(np.uint32), + ] + + _x = Tensor(Dtype.U32, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + Tensor(Dtype.U32, y[2].shape, y[2].flatten()), + Tensor(Dtype.U32, y[3].shape, y[3].flatten()), + ] + + name = "split_u32_1d_uneven" + make_test( + [_x], _y, "input_0.split(0, Option::Some(4), Option::None(()))", name) + + + def split_2d_uneven(): + x = np.random.randint(0, 255, (2, 8)).astype(np.uint32) + y = [ + np.array(x[0:2, 0:3]).astype(np.uint32), + np.array(x[0:2, 3:6]).astype(np.uint32), + np.array(x[0:2, 6:8]).astype(np.uint32) + ] + _x = Tensor(Dtype.U32, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.U32, y[0].shape, y[0].flatten()), + Tensor(Dtype.U32, y[1].shape, y[1].flatten()), + Tensor(Dtype.U32, y[2].shape, y[2].flatten()), + ] + + name = "split_u32_2d_uneven" + make_test( + [_x], _y, "input_0.split(1, Option::Some(3), Option::None(()))", name) + + + split_1D() + split_2D() + split_zero_size() + split_1d_uneven() + split_2d_uneven() + + @staticmethod + def split_fp16x16(): + def split_1D(): + x = to_fp(np.random.randint(-127, 127, 6 + ).astype(np.int64), FixedImpl.FP16x16) + y = [ + np.array(x[0:2]).astype(np.int64), + np.array(x[2:4]).astype(np.int64), + np.array(x[4:6]).astype(np.int64), + ] + + _x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + Tensor(Dtype.FP16x16, y[2].shape, y[2].flatten()), + ] + + name = "split_fp16x16_1d_equal_parts" + make_test( + [_x], _y, "input_0.split(0, Option::Some(3), Option::None(()))", name) + y = [ + np.array(x[0:2]).astype(np.int64), + np.array(x[2:6]).astype(np.int64), + ] + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + ] + name = "split_fp16x16_1d_variable_parts" + make_test( + [_x], _y, "input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),)))", name) + def split_2D(): + x = to_fp(np.random.randint(-127, 127, (2, 6) + ).astype(np.int64), FixedImpl.FP16x16) + y = [ + np.array(x[0:2, 0:3]).astype(np.int64), + np.array(x[0:2, 3:6]).astype(np.int64), + ] + _x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + ] + name = "split_fp16x16_2d_equal_parts" + make_test( + [_x], _y, "input_0.split(1, Option::Some(2), Option::None(()))", name) + + y = [ + np.array(x[0:2, 0:2]).astype(np.int64), + np.array(x[0:2, 2:6]).astype(np.int64) + ] + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + ] + name = "split_fp16x16_2d_variable_parts" + make_test( + [_x], _y, "input_0.split(1, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),)))", name) + + def split_zero_size(): + # 1-dimensional tensor with dimension_size=0 + x = to_fp(np.array([]).astype(np.int64 + ).astype(np.int64), FixedImpl.FP16x16) + y = [ + np.array([]).astype(np.int64), + np.array([]).astype(np.int64), + np.array([]).astype(np.int64), + ] + _x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + Tensor(Dtype.FP16x16, y[2].shape, y[2].flatten()), + ] + # Split emtpy tensor to tensors of size zero + name = "split_fp16x16_zero_size" + make_test( + [_x], _y, "input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),)))", name) + + + def split_1d_uneven(): + x = to_fp(np.random.randint(-127, 127, 7 + ).astype(np.int64), FixedImpl.FP16x16) + y = [ + np.array(x[0:2]).astype(np.int64), + np.array(x[2:4]).astype(np.int64), + np.array(x[4:6]).astype(np.int64), + np.array(x[6:7]).astype(np.int64), + ] + + _x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + Tensor(Dtype.FP16x16, y[2].shape, y[2].flatten()), + Tensor(Dtype.FP16x16, y[3].shape, y[3].flatten()), + ] + + name = "split_fp16x16_1d_uneven" + make_test( + [_x], _y, "input_0.split(0, Option::Some(4), Option::None(()))", name) + + + def split_2d_uneven(): + x = to_fp(np.random.randint(-127, 127, (2, 8) + ).astype(np.int64), FixedImpl.FP16x16) + y = [ + np.array(x[0:2, 0:3]).astype(np.int64), + np.array(x[0:2, 3:6]).astype(np.int64), + np.array(x[0:2, 6:8]).astype(np.int64) + ] + _x = Tensor(Dtype.FP16x16, x.shape, x.flatten()) + _y = [ + Tensor(Dtype.FP16x16, y[0].shape, y[0].flatten()), + Tensor(Dtype.FP16x16, y[1].shape, y[1].flatten()), + Tensor(Dtype.FP16x16, y[2].shape, y[2].flatten()), + ] + + name = "split_fp16x16_2d_uneven" + make_test( + [_x], _y, "input_0.split(1, Option::Some(3), Option::None(()))", name) + + + split_1D() + split_2D() + split_zero_size() + split_1d_uneven() + split_2d_uneven() + \ No newline at end of file diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index decc2e343..a55279b33 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -124,6 +124,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde= 1, indices tensor of rank q >= 1, and batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b. /// reduce_log_sum - Computes the log sum of the input tensor's elements along the provided axes. /// erf - Computes the error function of the given input tensor element-wise. +/// split - Split a tensor into a list of tensors, along the specified β€˜axis’. trait TensorTrait { /// # tensor.new /// @@ -5077,6 +5078,56 @@ trait TensorTrait { /// ``` /// fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor; + /// # tensor.split + /// + /// ```rust + /// fn split(self: @Tensor, axis: usize, num_outputs: Option, split: Option> + /// ) -> Array>; + /// ``` + /// + /// Split a tensor into a list of tensors, along the specified β€˜axis’ + /// + /// + /// * `self`(`@Tensor`) - The input tensor. + /// * `axis`(`usize`) - The axis along which to split on. + /// * `num_outputs `(Option) - Number of outputs to split parts of the tensor into. + /// * `split `(Option>) - Optional length of each output. + /// + /// ## Panics + /// + /// * Panics if the 'axis' accepted range is not [-rank, rank-1] where r = rank(input). + /// * Panics if the 'split' values not >= 0. Sum of the values is not equal to the dim value at β€˜axis’ specified. + /// * Panics if the input 'split' or the attribute 'num_outputs' both are specified or not. + /// + /// ## Returns + /// + /// One or more outputs forming list of tensors after splitting. + /// + /// ## Examples + /// + /// ```rust + /// use core::array::{ArrayTrait, SpanTrait}; + /// use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; + /// use core::option::OptionTrait; + /// fn split_tensor_example() -> Array> { + /// let tensor: Tensor = TensorTrait::::new( + /// shape: array![2,4].span(), + /// data: array![ + /// 0, 1, 2, 3, 4, 5, 6, 7 + /// ].span(), + /// ); + /// let num_outputs = Option::Some(2); + /// // split = Option::Some(array![1, 1].span()); + /// let split_num: Option> = Option::None(()); + /// // We can call `split` function as follows. + /// return tensor.split(0, num_outputs, split_num); + /// } + /// >>> [[0,1],[4,5]] + /// [[2,3],[6,7]] + /// ``` + /// + fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array>; } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index d2afe3fc5..759de3073 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -475,6 +475,10 @@ impl BoolTensor of TensorTrait { fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } + + fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option>) -> Array> { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index ccaf5903d..3398c07ae 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -537,6 +537,12 @@ impl FP16x16Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index dc32202ed..c71cfe8d2 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -503,6 +503,12 @@ impl FP16x16WTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 9100d6f82..81f2f682f 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -538,6 +538,12 @@ impl FP32x32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index ee6441058..a483d3f01 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -539,6 +539,12 @@ impl FP64x64Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 17a601f7b..a3fe7bd84 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -537,6 +537,12 @@ impl FP8x23Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index a7d19901b..098a44f35 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -490,6 +490,12 @@ impl FP8x23WTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index a987b0633..177ec73bc 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -534,6 +534,12 @@ impl I32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 8c1e2fd32..eafaea150 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -532,6 +532,12 @@ impl I8Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 5b2058401..43ff21f48 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -475,6 +475,12 @@ impl U32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor, Tensor) { manipulation::unique::unique(self, axis, sorted) } + + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { + manipulation::split::split(self, axis, num_outputs, spl) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/manipulation.cairo b/src/operators/tensor/manipulation.cairo index b517c624e..584eae027 100644 --- a/src/operators/tensor/manipulation.cairo +++ b/src/operators/tensor/manipulation.cairo @@ -1 +1,2 @@ mod unique; +mod split; diff --git a/src/operators/tensor/manipulation/split.cairo b/src/operators/tensor/manipulation/split.cairo new file mode 100644 index 000000000..436de61ca --- /dev/null +++ b/src/operators/tensor/manipulation/split.cairo @@ -0,0 +1,195 @@ +use orion::operators::tensor::{Tensor, TensorTrait, U32Tensor}; +use core::array::{ArrayTrait, SpanTrait}; +use core::option::OptionTrait; +use orion::operators::matrix::{MutMatrixTrait, MutMatrix, MutMatrixImpl}; + +/// Cf: NNTrait::split docstring +fn split< + T, + +Copy, + +Drop, + +TensorTrait, + +PartialOrd, + +PartialEq, + +PartialEq>, + +PartialOrd> +>( + self: @Tensor, axis: usize, num_outputs: Option, split: Option> +) -> Array> { + let has_num_outputs = match num_outputs { + Option::Some(value) => { + true + }, + Option::None => false, + }; + let has_split = match split { + Option::Some(value) => { + true + }, + Option::None => false, + }; + assert(!(has_num_outputs && has_split), 'split or num_outputs not both.'); + assert(has_num_outputs || has_split, 'split or num_outputs not both.'); + + let mut splited_t: Array> = array![]; + + let rank = (*self).shape.len(); + // assert(axis < rank && axis > -rank, 'axis out of dimensions'); + assert(axis < rank, 'axis out of dimensions'); + + if (has_num_outputs){ + splited_t = split_num_outputs(self, axis, num_outputs.unwrap()); + }else{ + splited_t = split_has_split(self, axis, split.unwrap()); + } + splited_t +} + +/// Subfunction split for tensors (wth num_outputs). +/// Cf: TensorTrait::split docstring +fn split_num_outputs, +Drop, +TensorTrait, +PartialOrd, +PartialEq,>( + t: @Tensor, mut axis: usize, num_outputs: usize +) -> Array> { + let mut splited_t: Array> = array![]; + let mut div: usize = 0; + // consturct split array + let mut split: Array = array![]; + // if axis==0 { + // axis = 1; + // } + if (*(*t).shape.at(axis) % num_outputs == 0){ + div = *(*t).shape.at(axis) / num_outputs; + let mut i = 0; + loop { + if (i>=num_outputs) { + break; + } + split.append(div); + i += 1; + }; + } else { + div = *(*t).shape.at(axis) / num_outputs+1; + let mut i = 0; + loop { + if (i>=num_outputs) { + break; + } + split.append(div); + i += 1; + }; + match split.pop_front(){ + Option::Some(split_last_one) => { + split.append(split_last_one + *(*t).shape.at(axis) - div*(num_outputs-1)); + }, + Option::None(_) => { assert(false, 'split is none array'); } + } + } + + let mut sli: MutMatrix = MutMatrixImpl::new((*t).shape.len(), 2); + let mut pos: usize = 0; + let mut i = 0; + loop { + if (i>=(*t).shape.len()) { + break; + } + let s: usize = *(*t).shape.at(i); + sli.set(i,0,0); + sli.set(i,1,s); + i += 1; + }; + let mut i: usize = 0; + loop { + if (i>=split.len()) { + break; + } + let spl = *split.at(i); + sli.set(axis, 0, pos); + pos += spl; + sli.set(axis, 1, pos); + + let end_ele_0 = match sli.get(axis, 0) { + Option::Some(res) => { + res + }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let end_ele_1 = match sli.get(axis, 1) { + Option::Some(res) => { + res + }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let starts: Span = array![sli.get(0,0).unwrap(),end_ele_0].span(); + let ends: Span = array![ sli.get(0,1).unwrap(), end_ele_1].span(); + let axes: Option> = Option::None(()); + let steps: Option> = Option::None(()); + let sub_t: Tensor = t.slice(starts, ends, axes, steps); + splited_t.append(sub_t); + i += 1; + }; + splited_t + +} + +/// Subfunction split for tensors (wth split). +/// Cf: TensorTrait::split docstring +fn split_has_split, +Drop, +TensorTrait, +PartialOrd, +PartialEq,>( + t: @Tensor, axis: usize, split: Tensor +) -> Array> { + let mut splited_t: Array> = array![]; + let mut sli: MutMatrix = MutMatrixImpl::new((*t).shape.len(), 2); + let mut pos: usize = 0; + let mut i = 0; + loop { + if (i>=(*t).shape.len()) { + break; + } + let s: usize = *(*t).shape.at(i); + sli.set(i,0,0); + sli.set(i,1,s); + i += 1; + }; + let mut i: usize = 0; + loop { + if (i>=split.data.len()) { + break; + } + let spl: usize = split.at(indices: array![i].span()); + sli.set(axis, 0, pos); + pos += spl; + sli.set(axis, 1, pos); + + let end_ele_0 = match sli.get(axis, 0) { + Option::Some(res) => { + res + }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let end_ele_1 = match sli.get(axis, 1) { + Option::Some(res) => { + res + }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, + }; + let starts: Span = array![sli.get(0,0).unwrap(),end_ele_0].span(); + let ends: Span = array![ sli.get(0,1).unwrap(), end_ele_1].span(); + let axes: Option> = Option::None(()); + let steps: Option> = Option::None(()); + let sub_t: Tensor = t.slice(starts, ends, axes, steps); + splited_t.append(sub_t); + i += 1; + }; + splited_t +} \ No newline at end of file diff --git a/tests/nodes.cairo b/tests/nodes.cairo index c7155e942..ac064641c 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -850,3 +850,17 @@ mod gather_nd_i8_3d_batch_dims1; mod gather_nd_u32_default; mod gather_nd_u32_batch_dims1; mod gather_nd_u32_batch_dims2; +mod split_u32_1d_equal_parts; +mod split_u32_2d_equal_parts; +mod split_u32_zero_size; +mod split_u32_1d_variable_parts; +mod split_u32_2d_variable_parts; +mod split_u32_1d_uneven; +mod split_u32_2d_uneven; +mod split_fp16x16_1d_equal_parts; +mod split_fp16x16_1d_variable_parts; +mod split_fp16x16_2d_equal_parts; +mod split_fp16x16_2d_variable_parts; +mod split_fp16x16_zero_size; +mod split_fp16x16_1d_uneven; +mod split_fp16x16_2d_uneven; diff --git a/tests/nodes/split_fp16x16_1d_equal_parts/input_0.cairo b/tests/nodes/split_fp16x16_1d_equal_parts/input_0.cairo new file mode 100644 index 000000000..c5859b6f8 --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_equal_parts/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: false }); + data.append(FP16x16 { mag: 7471104, sign: true }); + data.append(FP16x16 { mag: 4849664, sign: true }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_1d_equal_parts/output_0.cairo b/tests/nodes/split_fp16x16_1d_equal_parts/output_0.cairo new file mode 100644 index 000000000..43cf3eb64 --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_equal_parts/output_0.cairo @@ -0,0 +1,37 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7471104, sign: true }); + data.append(FP16x16 { mag: 4849664, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_fp16x16_1d_uneven/input_0.cairo b/tests/nodes/split_fp16x16_1d_uneven/input_0.cairo new file mode 100644 index 000000000..8426e2fe9 --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_uneven/input_0.cairo @@ -0,0 +1,19 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(7); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 5701632, sign: true }); + data.append(FP16x16 { mag: 6946816, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + data.append(FP16x16 { mag: 2883584, sign: true }); + data.append(FP16x16 { mag: 6160384, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_1d_uneven/output_0.cairo b/tests/nodes/split_fp16x16_1d_uneven/output_0.cairo new file mode 100644 index 000000000..b5c8384ab --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_uneven/output_0.cairo @@ -0,0 +1,45 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 5701632, sign: true }); + data.append(FP16x16 { mag: 6946816, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 2883584, sign: false }); + data.append(FP16x16 { mag: 131072, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 2883584, sign: true }); + data.append(FP16x16 { mag: 6160384, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 8257536, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_fp16x16_1d_variable_parts/input_0.cairo b/tests/nodes/split_fp16x16_1d_variable_parts/input_0.cairo new file mode 100644 index 000000000..c5859b6f8 --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_variable_parts/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: false }); + data.append(FP16x16 { mag: 7471104, sign: true }); + data.append(FP16x16 { mag: 4849664, sign: true }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_1d_variable_parts/output_0.cairo b/tests/nodes/split_fp16x16_1d_variable_parts/output_0.cairo new file mode 100644 index 000000000..49474c9ff --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_variable_parts/output_0.cairo @@ -0,0 +1,30 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 1966080, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7471104, sign: true }); + data.append(FP16x16 { mag: 4849664, sign: true }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3014656, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_fp16x16_2d_equal_parts/input_0.cairo b/tests/nodes/split_fp16x16_2d_equal_parts/input_0.cairo new file mode 100644 index 000000000..f9e723397 --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_equal_parts/input_0.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7274496, sign: true }); + data.append(FP16x16 { mag: 1507328, sign: true }); + data.append(FP16x16 { mag: 3604480, sign: false }); + data.append(FP16x16 { mag: 7864320, sign: true }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 6422528, sign: false }); + data.append(FP16x16 { mag: 917504, sign: true }); + data.append(FP16x16 { mag: 3538944, sign: true }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_2d_equal_parts/output_0.cairo b/tests/nodes/split_fp16x16_2d_equal_parts/output_0.cairo new file mode 100644 index 000000000..cc8491e5d --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_equal_parts/output_0.cairo @@ -0,0 +1,38 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7274496, sign: true }); + data.append(FP16x16 { mag: 1507328, sign: true }); + data.append(FP16x16 { mag: 3604480, sign: false }); + data.append(FP16x16 { mag: 917504, sign: true }); + data.append(FP16x16 { mag: 3538944, sign: true }); + data.append(FP16x16 { mag: 1835008, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7864320, sign: true }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 6422528, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_fp16x16_2d_uneven/input_0.cairo b/tests/nodes/split_fp16x16_2d_uneven/input_0.cairo new file mode 100644 index 000000000..341bfc1af --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_uneven/input_0.cairo @@ -0,0 +1,29 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 4980736, sign: false }); + data.append(FP16x16 { mag: 6750208, sign: true }); + data.append(FP16x16 { mag: 6488064, sign: true }); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: true }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 2686976, sign: true }); + data.append(FP16x16 { mag: 3801088, sign: false }); + data.append(FP16x16 { mag: 4849664, sign: false }); + data.append(FP16x16 { mag: 3538944, sign: true }); + data.append(FP16x16 { mag: 7077888, sign: false }); + data.append(FP16x16 { mag: 262144, sign: true }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 8192000, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: true }); + data.append(FP16x16 { mag: 6553600, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_2d_uneven/output_0.cairo b/tests/nodes/split_fp16x16_2d_uneven/output_0.cairo new file mode 100644 index 000000000..fb660f51c --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_uneven/output_0.cairo @@ -0,0 +1,50 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 4980736, sign: false }); + data.append(FP16x16 { mag: 6750208, sign: true }); + data.append(FP16x16 { mag: 6488064, sign: true }); + data.append(FP16x16 { mag: 4849664, sign: false }); + data.append(FP16x16 { mag: 3538944, sign: true }); + data.append(FP16x16 { mag: 7077888, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 2490368, sign: false }); + data.append(FP16x16 { mag: 1245184, sign: true }); + data.append(FP16x16 { mag: 1310720, sign: false }); + data.append(FP16x16 { mag: 262144, sign: true }); + data.append(FP16x16 { mag: 262144, sign: false }); + data.append(FP16x16 { mag: 8192000, sign: false }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 2686976, sign: true }); + data.append(FP16x16 { mag: 3801088, sign: false }); + data.append(FP16x16 { mag: 1441792, sign: true }); + data.append(FP16x16 { mag: 6553600, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_fp16x16_2d_variable_parts/input_0.cairo b/tests/nodes/split_fp16x16_2d_variable_parts/input_0.cairo new file mode 100644 index 000000000..f9e723397 --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_variable_parts/input_0.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7274496, sign: true }); + data.append(FP16x16 { mag: 1507328, sign: true }); + data.append(FP16x16 { mag: 3604480, sign: false }); + data.append(FP16x16 { mag: 7864320, sign: true }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 6422528, sign: false }); + data.append(FP16x16 { mag: 917504, sign: true }); + data.append(FP16x16 { mag: 3538944, sign: true }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_2d_variable_parts/output_0.cairo b/tests/nodes/split_fp16x16_2d_variable_parts/output_0.cairo new file mode 100644 index 000000000..3c7c5282f --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_variable_parts/output_0.cairo @@ -0,0 +1,38 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 7274496, sign: true }); + data.append(FP16x16 { mag: 1507328, sign: true }); + data.append(FP16x16 { mag: 917504, sign: true }); + data.append(FP16x16 { mag: 3538944, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 3604480, sign: false }); + data.append(FP16x16 { mag: 7864320, sign: true }); + data.append(FP16x16 { mag: 786432, sign: false }); + data.append(FP16x16 { mag: 6422528, sign: false }); + data.append(FP16x16 { mag: 1835008, sign: false }); + data.append(FP16x16 { mag: 3407872, sign: false }); + data.append(FP16x16 { mag: 3145728, sign: false }); + data.append(FP16x16 { mag: 8257536, sign: true }); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_fp16x16_zero_size/input_0.cairo b/tests/nodes/split_fp16x16_zero_size/input_0.cairo new file mode 100644 index 000000000..99f4ba73b --- /dev/null +++ b/tests/nodes/split_fp16x16_zero_size/input_0.cairo @@ -0,0 +1,12 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_fp16x16_zero_size/output_0.cairo b/tests/nodes/split_fp16x16_zero_size/output_0.cairo new file mode 100644 index 000000000..09f338786 --- /dev/null +++ b/tests/nodes/split_fp16x16_zero_size/output_0.cairo @@ -0,0 +1,31 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_1d_equal_parts/input_0.cairo b/tests/nodes/split_u32_1d_equal_parts/input_0.cairo new file mode 100644 index 000000000..7cb4b0bef --- /dev/null +++ b/tests/nodes/split_u32_1d_equal_parts/input_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(191); + data.append(83); + data.append(144); + data.append(69); + data.append(77); + data.append(34); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_1d_equal_parts/output_0.cairo b/tests/nodes/split_u32_1d_equal_parts/output_0.cairo new file mode 100644 index 000000000..65c19f5ff --- /dev/null +++ b/tests/nodes/split_u32_1d_equal_parts/output_0.cairo @@ -0,0 +1,36 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(191); + data.append(83); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(144); + data.append(69); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(77); + data.append(34); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_1d_uneven/input_0.cairo b/tests/nodes/split_u32_1d_uneven/input_0.cairo new file mode 100644 index 000000000..9ef03f0fe --- /dev/null +++ b/tests/nodes/split_u32_1d_uneven/input_0.cairo @@ -0,0 +1,18 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(7); + + let mut data = ArrayTrait::new(); + data.append(203); + data.append(159); + data.append(108); + data.append(166); + data.append(98); + data.append(220); + data.append(233); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_1d_uneven/output_0.cairo b/tests/nodes/split_u32_1d_uneven/output_0.cairo new file mode 100644 index 000000000..c39b4bc14 --- /dev/null +++ b/tests/nodes/split_u32_1d_uneven/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(203); + data.append(159); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(108); + data.append(166); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(98); + data.append(220); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(1); + + let mut data = ArrayTrait::new(); + data.append(233); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_1d_variable_parts/input_0.cairo b/tests/nodes/split_u32_1d_variable_parts/input_0.cairo new file mode 100644 index 000000000..7cb4b0bef --- /dev/null +++ b/tests/nodes/split_u32_1d_variable_parts/input_0.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(191); + data.append(83); + data.append(144); + data.append(69); + data.append(77); + data.append(34); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_1d_variable_parts/output_0.cairo b/tests/nodes/split_u32_1d_variable_parts/output_0.cairo new file mode 100644 index 000000000..9407802e9 --- /dev/null +++ b/tests/nodes/split_u32_1d_variable_parts/output_0.cairo @@ -0,0 +1,29 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(191); + data.append(83); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(144); + data.append(69); + data.append(77); + data.append(34); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_2d_equal_parts/input_0.cairo b/tests/nodes/split_u32_2d_equal_parts/input_0.cairo new file mode 100644 index 000000000..4da0f4c2f --- /dev/null +++ b/tests/nodes/split_u32_2d_equal_parts/input_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(189); + data.append(74); + data.append(230); + data.append(245); + data.append(231); + data.append(162); + data.append(11); + data.append(159); + data.append(108); + data.append(92); + data.append(6); + data.append(61); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_2d_equal_parts/output_0.cairo b/tests/nodes/split_u32_2d_equal_parts/output_0.cairo new file mode 100644 index 000000000..7abbf7096 --- /dev/null +++ b/tests/nodes/split_u32_2d_equal_parts/output_0.cairo @@ -0,0 +1,37 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(189); + data.append(74); + data.append(230); + data.append(11); + data.append(159); + data.append(108); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(245); + data.append(231); + data.append(162); + data.append(92); + data.append(6); + data.append(61); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_2d_uneven/input_0.cairo b/tests/nodes/split_u32_2d_uneven/input_0.cairo new file mode 100644 index 000000000..c6f1e2857 --- /dev/null +++ b/tests/nodes/split_u32_2d_uneven/input_0.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(8); + + let mut data = ArrayTrait::new(); + data.append(114); + data.append(94); + data.append(130); + data.append(214); + data.append(213); + data.append(226); + data.append(218); + data.append(47); + data.append(173); + data.append(181); + data.append(108); + data.append(140); + data.append(123); + data.append(14); + data.append(181); + data.append(7); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_2d_uneven/output_0.cairo b/tests/nodes/split_u32_2d_uneven/output_0.cairo new file mode 100644 index 000000000..8fdbd5b48 --- /dev/null +++ b/tests/nodes/split_u32_2d_uneven/output_0.cairo @@ -0,0 +1,49 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(114); + data.append(94); + data.append(130); + data.append(173); + data.append(181); + data.append(108); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + + let mut data = ArrayTrait::new(); + data.append(214); + data.append(213); + data.append(226); + data.append(140); + data.append(123); + data.append(14); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(218); + data.append(47); + data.append(181); + data.append(7); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_2d_variable_parts/input_0.cairo b/tests/nodes/split_u32_2d_variable_parts/input_0.cairo new file mode 100644 index 000000000..4da0f4c2f --- /dev/null +++ b/tests/nodes/split_u32_2d_variable_parts/input_0.cairo @@ -0,0 +1,24 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(6); + + let mut data = ArrayTrait::new(); + data.append(189); + data.append(74); + data.append(230); + data.append(245); + data.append(231); + data.append(162); + data.append(11); + data.append(159); + data.append(108); + data.append(92); + data.append(6); + data.append(61); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_2d_variable_parts/output_0.cairo b/tests/nodes/split_u32_2d_variable_parts/output_0.cairo new file mode 100644 index 000000000..8e835404a --- /dev/null +++ b/tests/nodes/split_u32_2d_variable_parts/output_0.cairo @@ -0,0 +1,37 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(2); + + let mut data = ArrayTrait::new(); + data.append(189); + data.append(74); + data.append(11); + data.append(159); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(230); + data.append(245); + data.append(231); + data.append(162); + data.append(108); + data.append(92); + data.append(6); + data.append(61); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} diff --git a/tests/nodes/split_u32_zero_size/input_0.cairo b/tests/nodes/split_u32_zero_size/input_0.cairo new file mode 100644 index 000000000..442c49919 --- /dev/null +++ b/tests/nodes/split_u32_zero_size/input_0.cairo @@ -0,0 +1,11 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/split_u32_zero_size/output_0.cairo b/tests/nodes/split_u32_zero_size/output_0.cairo new file mode 100644 index 000000000..9ab25acaa --- /dev/null +++ b/tests/nodes/split_u32_zero_size/output_0.cairo @@ -0,0 +1,30 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +fn output_0() -> Array> { + let mut sequence = ArrayTrait::new(); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + let mut shape = ArrayTrait::::new(); + shape.append(0); + + let mut data = ArrayTrait::new(); + + sequence.append(TensorTrait::new(shape.span(), data.span())); + + sequence +} From cf9b480b78bca7d87f7087da31d4c45ddfc27f84 Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Wed, 10 Jan 2024 16:32:26 +0800 Subject: [PATCH 28/38] feat: supplement node test --- .../nodes/split_fp16x16_1d_equal_parts.cairo | 20 ++++++++++++++++++ tests/nodes/split_fp16x16_1d_uneven.cairo | 20 ++++++++++++++++++ .../split_fp16x16_1d_variable_parts.cairo | 21 +++++++++++++++++++ .../nodes/split_fp16x16_2d_equal_parts.cairo | 20 ++++++++++++++++++ tests/nodes/split_fp16x16_2d_uneven.cairo | 20 ++++++++++++++++++ .../split_fp16x16_2d_variable_parts.cairo | 21 +++++++++++++++++++ tests/nodes/split_fp16x16_zero_size.cairo | 21 +++++++++++++++++++ tests/nodes/split_u32_1d_equal_parts.cairo | 20 ++++++++++++++++++ tests/nodes/split_u32_1d_uneven.cairo | 20 ++++++++++++++++++ tests/nodes/split_u32_1d_variable_parts.cairo | 20 ++++++++++++++++++ tests/nodes/split_u32_2d_equal_parts.cairo | 20 ++++++++++++++++++ tests/nodes/split_u32_2d_uneven.cairo | 20 ++++++++++++++++++ tests/nodes/split_u32_2d_variable_parts.cairo | 20 ++++++++++++++++++ tests/nodes/split_u32_zero_size.cairo | 20 ++++++++++++++++++ 14 files changed, 283 insertions(+) create mode 100644 tests/nodes/split_fp16x16_1d_equal_parts.cairo create mode 100644 tests/nodes/split_fp16x16_1d_uneven.cairo create mode 100644 tests/nodes/split_fp16x16_1d_variable_parts.cairo create mode 100644 tests/nodes/split_fp16x16_2d_equal_parts.cairo create mode 100644 tests/nodes/split_fp16x16_2d_uneven.cairo create mode 100644 tests/nodes/split_fp16x16_2d_variable_parts.cairo create mode 100644 tests/nodes/split_fp16x16_zero_size.cairo create mode 100644 tests/nodes/split_u32_1d_equal_parts.cairo create mode 100644 tests/nodes/split_u32_1d_uneven.cairo create mode 100644 tests/nodes/split_u32_1d_variable_parts.cairo create mode 100644 tests/nodes/split_u32_2d_equal_parts.cairo create mode 100644 tests/nodes/split_u32_2d_uneven.cairo create mode 100644 tests/nodes/split_u32_2d_variable_parts.cairo create mode 100644 tests/nodes/split_u32_zero_size.cairo diff --git a/tests/nodes/split_fp16x16_1d_equal_parts.cairo b/tests/nodes/split_fp16x16_1d_equal_parts.cairo new file mode 100644 index 000000000..03b36c507 --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_equal_parts.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_1d_equal_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::Some(3), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_fp16x16_1d_uneven.cairo b/tests/nodes/split_fp16x16_1d_uneven.cairo new file mode 100644 index 000000000..e0221f924 --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_uneven.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_1d_uneven() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::Some(4), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_fp16x16_1d_variable_parts.cairo b/tests/nodes/split_fp16x16_1d_variable_parts.cairo new file mode 100644 index 000000000..f5f46e75d --- /dev/null +++ b/tests/nodes/split_fp16x16_1d_variable_parts.cairo @@ -0,0 +1,21 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_1d_variable_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_fp16x16_2d_equal_parts.cairo b/tests/nodes/split_fp16x16_2d_equal_parts.cairo new file mode 100644 index 000000000..a1c441b60 --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_equal_parts.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_2d_equal_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(1, Option::Some(2), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_fp16x16_2d_uneven.cairo b/tests/nodes/split_fp16x16_2d_uneven.cairo new file mode 100644 index 000000000..fbf5c12d7 --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_uneven.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_2d_uneven() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(1, Option::Some(3), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_fp16x16_2d_variable_parts.cairo b/tests/nodes/split_fp16x16_2d_variable_parts.cairo new file mode 100644 index 000000000..d627014e2 --- /dev/null +++ b/tests/nodes/split_fp16x16_2d_variable_parts.cairo @@ -0,0 +1,21 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_2d_variable_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(1, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_fp16x16_zero_size.cairo b/tests/nodes/split_fp16x16_zero_size.cairo new file mode 100644 index 000000000..c9056376b --- /dev/null +++ b/tests/nodes/split_fp16x16_zero_size.cairo @@ -0,0 +1,21 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16Tensor; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::U32Tensor; + +#[test] +#[available_gas(2000000000)] +fn test_split_fp16x16_zero_size() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),))); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_1d_equal_parts.cairo b/tests/nodes/split_u32_1d_equal_parts.cairo new file mode 100644 index 000000000..d4530bac9 --- /dev/null +++ b/tests/nodes/split_u32_1d_equal_parts.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_1d_equal_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::Some(3), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_1d_uneven.cairo b/tests/nodes/split_u32_1d_uneven.cairo new file mode 100644 index 000000000..a4180af74 --- /dev/null +++ b/tests/nodes/split_u32_1d_uneven.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_1d_uneven() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::Some(4), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_1d_variable_parts.cairo b/tests/nodes/split_u32_1d_variable_parts.cairo new file mode 100644 index 000000000..2680a6f77 --- /dev/null +++ b/tests/nodes/split_u32_1d_variable_parts.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_1d_variable_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_2d_equal_parts.cairo b/tests/nodes/split_u32_2d_equal_parts.cairo new file mode 100644 index 000000000..d5fa6d43a --- /dev/null +++ b/tests/nodes/split_u32_2d_equal_parts.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_2d_equal_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(1, Option::Some(2), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_2d_uneven.cairo b/tests/nodes/split_u32_2d_uneven.cairo new file mode 100644 index 000000000..f663cc95f --- /dev/null +++ b/tests/nodes/split_u32_2d_uneven.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_2d_uneven() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(1, Option::Some(3), Option::None(())); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_2d_variable_parts.cairo b/tests/nodes/split_u32_2d_variable_parts.cairo new file mode 100644 index 000000000..b38f87122 --- /dev/null +++ b/tests/nodes/split_u32_2d_variable_parts.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_2d_variable_parts() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(1, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + + assert_seq_eq(y, z); +} diff --git a/tests/nodes/split_u32_zero_size.cairo b/tests/nodes/split_u32_zero_size.cairo new file mode 100644 index 000000000..39eeb9d67 --- /dev/null +++ b/tests/nodes/split_u32_zero_size.cairo @@ -0,0 +1,20 @@ +mod input_0; +mod output_0; + + +use orion::operators::tensor::U32TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::U32Tensor; +use orion::operators::tensor::{TensorTrait, Tensor}; + +#[test] +#[available_gas(2000000000)] +fn test_split_u32_zero_size() { + let input_0 = input_0::input_0(); + let z = output_0::output_0(); + + let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),))); + + assert_seq_eq(y, z); +} From af7594a48a4f05ef211c62371348a491d5f13ee7 Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Wed, 10 Jan 2024 17:02:16 +0800 Subject: [PATCH 29/38] fixed: add split in tensor_complex64 --- src/operators/tensor/implementations/tensor_complex64.cairo | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 53feb8980..ca005d711 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -475,6 +475,10 @@ impl Complex64Tensor of TensorTrait { fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } + + fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option>) -> Array> { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. From 40c884bf4492159dfac9d2195673f6bc222aba6c Mon Sep 17 00:00:00 2001 From: chachaleo Date: Wed, 10 Jan 2024 16:42:25 +0100 Subject: [PATCH 30/38] feat: layer_normalization --- docs/SUMMARY.md | 1 + docs/framework/compatibility.md | 1 + .../linear-classifier/README.md | 3 +- .../linear_classifier.predict.md | 4 +- docs/framework/operators/tensor/README.md | 1 + .../tensor/tensor.layer_normalization.md | 89 + nodegen/node/layer_normalization.py | 152 ++ .../tree_ensemble_classifier.cairo | 68 +- src/operators/sequence/functional.cairo | 2 +- .../implementations/sequence_bool.cairo | 1 - .../implementations/sequence_fp8x23wide.cairo | 1 - src/operators/tensor/core.cairo | 99 + .../tensor/implementations/tensor_bool.cairo | 15 +- .../implementations/tensor_complex64.cairo | 15 +- .../implementations/tensor_fp16x16.cairo | 19 +- .../implementations/tensor_fp16x16wide.cairo | 19 +- .../implementations/tensor_fp32x32.cairo | 19 +- .../implementations/tensor_fp64x64.cairo | 19 +- .../implementations/tensor_fp8x23.cairo | 19 +- .../implementations/tensor_fp8x23wide.cairo | 19 +- .../tensor/implementations/tensor_i32.cairo | 15 +- .../tensor/implementations/tensor_i8.cairo | 15 +- .../tensor/implementations/tensor_u32.cairo | 15 +- src/operators/tensor/math.cairo | 3 +- src/operators/tensor/math/compress.cairo | 69 +- .../tensor/math/layer_normalization.cairo | 191 ++ tests/lib.cairo | 1 - tests/ml/tree_ensemble_classifier.cairo | 2107 +++++++++++++++-- tests/nodes.cairo | 16 + tests/nodes/compress_fp16x16_3d_axis1.cairo | 2 +- tests/nodes/compress_fp16x16_3d_axis2.cairo | 2 +- tests/nodes/compress_fp16x16_3d_axis3.cairo | 2 +- tests/nodes/compress_fp16x16_3d_default.cairo | 2 +- tests/nodes/compress_fp16x16_3d_noaxis.cairo | 2 +- tests/nodes/compress_fp8x23_3d_axis1.cairo | 2 +- tests/nodes/compress_fp8x23_3d_axis2.cairo | 2 +- tests/nodes/compress_fp8x23_3d_default.cairo | 2 +- tests/nodes/compress_i32_3d_axis1.cairo | 2 +- tests/nodes/compress_i32_3d_axis2.cairo | 2 +- tests/nodes/compress_i32_3d_default.cairo | 2 +- tests/nodes/compress_i8_3d_axis1.cairo | 2 +- tests/nodes/compress_i8_3d_axis2.cairo | 2 +- tests/nodes/compress_i8_3d_default.cairo | 2 +- tests/nodes/compress_u32_3d_axis1.cairo | 2 +- tests/nodes/compress_u32_3d_axis2.cairo | 2 +- tests/nodes/compress_u32_3d_axis2_2.cairo | 2 +- tests/nodes/compress_u32_3d_axis3.cairo | 2 +- tests/nodes/compress_u32_3d_default.cairo | 2 +- ...layer_normalization_3d_axis0_epsilon.cairo | 33 + .../input_0.cairo | 44 + .../input_1.cairo | 44 + .../input_2.cairo | 44 + .../output_0.cairo | 44 + ...layer_normalization_3d_axis1_epsilon.cairo | 32 + .../input_0.cairo | 44 + .../input_1.cairo | 28 + .../input_2.cairo | 28 + .../output_0.cairo | 44 + ...layer_normalization_3d_axis2_epsilon.cairo | 32 + .../input_0.cairo | 44 + .../input_1.cairo | 17 + .../input_2.cairo | 17 + .../output_0.cairo | 44 + ...alization_3d_axis_negative_1_epsilon.cairo | 32 + .../input_0.cairo | 44 + .../input_1.cairo | 17 + .../input_2.cairo | 17 + .../output_0.cairo | 44 + ...alization_3d_axis_negative_2_epsilon.cairo | 33 + .../input_0.cairo | 44 + .../input_1.cairo | 28 + .../input_2.cairo | 28 + .../output_0.cairo | 44 + ...alization_3d_axis_negative_3_epsilon.cairo | 33 + .../input_0.cairo | 44 + .../input_1.cairo | 44 + .../input_2.cairo | 44 + .../output_0.cairo | 44 + .../input_0.cairo | 135 ++ .../input_1.cairo | 17 + .../input_2.cairo | 17 + .../output_0.cairo | 135 ++ .../input_0.cairo | 135 ++ .../input_1.cairo | 33 + .../input_2.cairo | 33 + .../output_0.cairo | 135 ++ .../input_0.cairo | 135 ++ .../input_1.cairo | 74 + .../input_2.cairo | 74 + .../output_0.cairo | 135 ++ .../input_0.cairo | 135 ++ .../input_1.cairo | 135 ++ .../input_2.cairo | 135 ++ .../output_0.cairo | 135 ++ .../nodes/layer_normalization_4d_axis0.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 135 ++ .../input_2.cairo | 135 ++ .../output_0.cairo | 135 ++ .../nodes/layer_normalization_4d_axis1.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 74 + .../input_2.cairo | 74 + .../output_0.cairo | 135 ++ .../nodes/layer_normalization_4d_axis2.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 33 + .../input_2.cairo | 33 + .../output_0.cairo | 135 ++ .../nodes/layer_normalization_4d_axis3.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 17 + .../input_2.cairo | 17 + .../output_0.cairo | 135 ++ ...yer_normalization_4d_axis_negative_1.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 17 + .../input_2.cairo | 17 + .../output_0.cairo | 135 ++ ...yer_normalization_4d_axis_negative_2.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 33 + .../input_2.cairo | 33 + .../output_0.cairo | 135 ++ ...yer_normalization_4d_axis_negative_3.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 74 + .../input_2.cairo | 74 + .../output_0.cairo | 135 ++ ...yer_normalization_4d_axis_negative_4.cairo | 32 + .../input_0.cairo | 135 ++ .../input_1.cairo | 135 ++ .../input_2.cairo | 135 ++ .../output_0.cairo | 135 ++ .../layer_normalization_default_axis.cairo | 28 + .../input_0.cairo | 135 ++ .../input_1.cairo | 17 + .../input_2.cairo | 17 + .../output_0.cairo | 135 ++ tests/nodes/layer_normalization_test.cairo | 28 + .../layer_normalization_test/input_0.cairo | 25 + .../layer_normalization_test/input_1.cairo | 16 + .../layer_normalization_test/input_2.cairo | 16 + .../layer_normalization_test/output_0.cairo | 25 + 144 files changed, 9293 insertions(+), 314 deletions(-) create mode 100644 docs/framework/operators/tensor/tensor.layer_normalization.md create mode 100644 nodegen/node/layer_normalization.py create mode 100644 src/operators/tensor/math/layer_normalization.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis0_epsilon.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis0_epsilon/input_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis0_epsilon/input_1.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis0_epsilon/input_2.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis0_epsilon/output_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis1_epsilon.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis1_epsilon/input_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis1_epsilon/input_1.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis1_epsilon/input_2.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis1_epsilon/output_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis2_epsilon.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis2_epsilon/input_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis2_epsilon/input_1.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis2_epsilon/input_2.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis2_epsilon/output_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_1_epsilon.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_1.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_2.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/output_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_2_epsilon.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_1.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_2.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/output_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_3_epsilon.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_0.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_1.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_2.cairo create mode 100644 tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-1/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-1/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-1/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-1/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-2/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-2/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-2/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-2/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-3/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-3/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-3/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-3/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-4/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-4/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-4/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis-4/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis0/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis0/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis0/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis0/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis1/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis1/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis1/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis1/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis2/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis2/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis2/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis2/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis3.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis3/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis3/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis3/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis3/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_1/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_1/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_1/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_1/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_2/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_2/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_2/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_2/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_3.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_3/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_3/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_3/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_3/output_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_4.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_4/input_0.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_4/input_1.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_4/input_2.cairo create mode 100644 tests/nodes/layer_normalization_4d_axis_negative_4/output_0.cairo create mode 100644 tests/nodes/layer_normalization_default_axis.cairo create mode 100644 tests/nodes/layer_normalization_default_axis/input_0.cairo create mode 100644 tests/nodes/layer_normalization_default_axis/input_1.cairo create mode 100644 tests/nodes/layer_normalization_default_axis/input_2.cairo create mode 100644 tests/nodes/layer_normalization_default_axis/output_0.cairo create mode 100644 tests/nodes/layer_normalization_test.cairo create mode 100644 tests/nodes/layer_normalization_test/input_0.cairo create mode 100644 tests/nodes/layer_normalization_test/input_1.cairo create mode 100644 tests/nodes/layer_normalization_test/input_2.cairo create mode 100644 tests/nodes/layer_normalization_test/output_0.cairo diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index fa9998f2e..8b5a6497f 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -121,6 +121,7 @@ * [tensor.reduce\_log\_sum](framework/operators/tensor/tensor.reduce\_log\_sum.md) * [tensor.unique](framework/operators/tensor/tensor.unique.md) * [tensor.compress](framework/operators/tensor/tensor.compress.md) + * [tensor.layer_normalization](framework/operators/tensor/tensor.layer_normalization.md) * [Neural Network](framework/operators/neural-network/README.md) * [nn.relu](framework/operators/neural-network/nn.relu.md) * [nn.leaky\_relu](framework/operators/neural-network/nn.leaky\_relu.md) diff --git a/docs/framework/compatibility.md b/docs/framework/compatibility.md index e0153274a..b562f2d38 100644 --- a/docs/framework/compatibility.md +++ b/docs/framework/compatibility.md @@ -106,5 +106,6 @@ You can see below the list of current supported ONNX Operators: | [ReduceLogSum](operators/tensor/tensor.reduce\_log\_sum.md) | :white\_check\_mark: | | [Erf](operators/tensor/tensor.erf.md) | :white\_check\_mark: | | [Compress](operators/tensor/tensor.compress.md) | :white\_check\_mark: | +| [Layer_normalization](operators/tensor/tensor.layer_normalization.md) | :white\_check\_mark: | Current Operators support: **97/156 (62%)** diff --git a/docs/framework/operators/machine-learning/linear-classifier/README.md b/docs/framework/operators/machine-learning/linear-classifier/README.md index 7b68132c4..7323f8b7f 100644 --- a/docs/framework/operators/machine-learning/linear-classifier/README.md +++ b/docs/framework/operators/machine-learning/linear-classifier/README.md @@ -19,4 +19,5 @@ Orion supports currently only fixed point data types for `LinearClassificationTr | function | description | | --- | --- | -| [`linear_classifier.predict`](linear_classifier.predict.md) | Performs the linear classification evaluation. | +| [`linear_classifier.predict`](linear_classifier.predict.md) | Performs the linear classification. | + diff --git a/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md b/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md index 3b9537b1c..aec154f68 100644 --- a/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md +++ b/docs/framework/operators/machine-learning/linear-classifier/linear_classifier.predict.md @@ -4,7 +4,7 @@ fn predict(ref self: LinearClassifier, X: Tensor) -> Tensor; ``` -Linear Regressor. Performs the linear classification. +Linear Classifier. Performs the linear classification. ## Args @@ -13,7 +13,7 @@ Linear Regressor. Performs the linear classification. ## Returns -* Tensor containing the generalized linear regression evaluation of the input X. +* Tensor containing the linear classification evaluation of the input X. ## Type Constraints diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index 75e094f99..c30cb776e 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -117,6 +117,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.gather_nd`](tensor.gather\_nd.md) | Given data tensor of rank r >= 1, indices tensor of rank q >= 1, and batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b. | | [`tensor.reduce_log_sum`](tensor.reduce\_log\_sum.md) | Computes the log sum of the input tensor's elements along the provided axes. | | [`tensor.erf`](tensor.erf.md) | Computes the error function of the given input tensor element-wise. | +| [`tensor.layer_normalization`](tensor.layer\_normalization.md) | computes the layer normalization of the input tensor. | ## Arithmetic Operations diff --git a/docs/framework/operators/tensor/tensor.layer_normalization.md b/docs/framework/operators/tensor/tensor.layer_normalization.md new file mode 100644 index 000000000..e52476b1c --- /dev/null +++ b/docs/framework/operators/tensor/tensor.layer_normalization.md @@ -0,0 +1,89 @@ +# tensor.layer_normalization + +```rust + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, +) -> (Tensor, Tensor, Tensor); +``` + +Layer normalization of the input, in two stages. +The first stage is standardization, which makes the normalized elements have zero mean and unit variances. +The second stage then scales and shifts the outcome of the first stage +## Args + +* `self`(`@Tensor`) - The input tensor. +* `scale`(`@Tensor,`) - Scale tensor. +* `B`(`Option<@Tensor>`) - Bias tensor. +* `axis`(`Option`) (default is -1) - The first normalization dimension. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back. +* `epsilon`(`Option`) (default is 0) - The epsilon value to use to avoid division by zero. +* `stash_type`(`Option`) - Precise the computation precision - unused the precision is defined by the type of the tensor. +## Panics + +* Panics if condition rank is not equal to 1. + +## Returns + +A new normalized tensor`Tensor`. +A tensor containing the mean `Tensor`. +A tensor containing the inverse standard deviation `Tensor`. + +## Example + +```rust +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn layer_normalization_example() -> (Tensor, Tensor, Tensor) { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 41143, sign: true }); + data.append(FP16x16 { mag: 51803, sign: false }); + data.append(FP16x16 { mag: 113556, sign: false }); + data.append(FP16x16 { mag: 64774, sign: false }); + data.append(FP16x16 { mag: 866, sign: false }); + data.append(FP16x16 { mag: 698, sign: true }); + data.append(FP16x16 { mag: 106500, sign: false }); + data.append(FP16x16 { mag: 98929, sign: false }); + data.append(FP16x16 { mag: 7551, sign: false }); + data.append(FP16x16 { mag: 30689, sign: true }); + data.append(FP16x16 { mag: 38325, sign: false }); + data.append(FP16x16 { mag: 48164, sign: false }); + let X = TensorTrait::new(shape.span(), data.span()); + + let shape = ArrayTrait::::new(); + shape.append(4); + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 49855, sign: false }); + data.append(FP16x16 { mag: 150787, sign: false }); + data.append(FP16x16 { mag: 83498, sign: true }); + data.append(FP16x16 { mag: 30346, sign: false }); + let scale = TensorTrait::new(shape.span(), data.span()); + + + let mut shape = ArrayTrait::::new(); + shape.append(4); + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 54864, sign: true }); + data.append(FP16x16 { mag: 50952, sign: false }); + data.append(FP16x16 { mag: 8870, sign: true }); + data.append(FP16x16 { mag: 23216, sign: true }); + let bias = TensorTrait::new(shape.span(), data.span()); + + return X.layer_normalization(@scale,Option::Some(@bias),Option::None,Option::None,Option::None); +} +>>> [[-0.48926553 1.0185822 -0.02138367 -0.39223218] + [-0.7945549 0.99696046 0.04332176 -0.412645 ] + [-0.5664707 0.7491956 -0.7896356 -0.5320859 ]] + +``` diff --git a/nodegen/node/layer_normalization.py b/nodegen/node/layer_normalization.py new file mode 100644 index 000000000..54f6e63fb --- /dev/null +++ b/nodegen/node/layer_normalization.py @@ -0,0 +1,152 @@ +import numpy as np +from nodegen.node import RunAll +from ..helpers import make_test, to_fp, Tensor, Dtype, FixedImpl, Trait +import numpy as np + +import onnx +from onnx.backend.test.case.base import Base +from onnx.backend.test.case.node import expect + + +def _layer_normalization(X, W, B, axis=-1, epsilon=1e-5): + X_shape = X.shape + X_rank = len(X_shape) + if axis < 0: + axis = axis + X_rank + unsqueezed_rank = X_rank - axis + reduction_shape = X_shape[0:axis] + (1,) * unsqueezed_rank + + row_number = 1 + col_number = 1 + for i in range(X_rank): + if i < axis: + row_number *= X_shape[i] + else: + col_number *= X_shape[i] + x_mat = np.reshape(X, (row_number, col_number)) + x_mean = np.sum(x_mat, axis=1, keepdims=True) / col_number + x_diff = x_mat - x_mean + x_squared_diff = x_diff * x_diff + variance = np.sum(x_squared_diff, axis=1, keepdims=True) / col_number + variance_eps = variance + epsilon + std_dev = np.sqrt(variance_eps) + inv_std_dev = np.reciprocal(std_dev) + y_mat = x_diff * inv_std_dev + Y = np.reshape(y_mat, X_shape) * W + B + X_mean = np.reshape(x_mean, reduction_shape) + X_inv_std_dev = np.reshape(inv_std_dev, reduction_shape) + + return Y, X_mean, X_inv_std_dev + + +def calculate_normalized_shape(X_shape, axis): + X_rank = len(X_shape) + if axis < 0: + axis = axis + X_rank + return X_shape[axis:] + + +class Layer_normalization(RunAll): + @staticmethod + def export4d() -> None: + X = np.random.randn(2, 3, 4, 5).astype(np.float32) + + def case(axis: int) -> None: + normalized_shape = calculate_normalized_shape(X.shape, axis) + W = np.random.randn(*normalized_shape).astype(np.float32) + B = np.random.randn(*normalized_shape).astype(np.float32) + Y, mean, inv_std_dev = _layer_normalization(X, W, B, axis) + + if axis < 0: + name = f"layer_normalization_4d_axis_negative_{-axis}" + func_sig = f"input_0.layer_normalization(@input_1,Option::Some(@input_2),Option::Some(IntegerTrait::::new({-axis}, true)),Option::None,Option::None)" + else: + name = f"layer_normalization_4d_axis{axis}" + func_sig = f"input_0.layer_normalization(@input_1,Option::Some(@input_2),Option::Some(IntegerTrait::::new({axis}, false)),Option::None,Option::None)" + + + x = Tensor(Dtype.FP8x23, X.shape, to_fp(X.flatten(), FixedImpl.FP8x23)) + w = Tensor(Dtype.FP8x23, W.shape, to_fp(W.flatten(), FixedImpl.FP8x23)) + b = Tensor(Dtype.FP8x23, B.shape, to_fp(B.flatten(), FixedImpl.FP8x23)) + y = Tensor(Dtype.FP8x23, Y.shape, to_fp(Y.flatten(), FixedImpl.FP8x23)) + + make_test([x,w,b], y, func_sig, name) + + + for i in range(len(X.shape)): + case(i) + case(i - len(X.shape)) + + @staticmethod + def export_default_axis() -> None: + X = np.random.randn(2, 3, 4, 5).astype(np.float32) + + normalized_shape = calculate_normalized_shape(X.shape, -1) + W = np.random.randn(*normalized_shape).astype(np.float32) + B = np.random.randn(*normalized_shape).astype(np.float32) + Y, mean, inv_std_dev = _layer_normalization(X, W, B) + + x = Tensor(Dtype.FP16x16, X.shape, to_fp(X.flatten(), FixedImpl.FP16x16)) + w = Tensor(Dtype.FP16x16, W.shape, to_fp(W.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, B.shape, to_fp(B.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, Y.shape, to_fp(Y.flatten(), FixedImpl.FP16x16)) + + name = "layer_normalization_default_axis" + make_test([x,w,b], y, "input_0.layer_normalization(@input_1,Option::Some(@input_2),Option::None,Option::None,Option::None)", name) + + @staticmethod + def export3d_epsilon() -> None: + epsilon = 1e-1 + X = np.random.randn(2, 3, 5).astype(np.float32) + + def case(axis: int) -> None: + normalized_shape = calculate_normalized_shape(X.shape, axis) + W = np.random.randn(*normalized_shape).astype(np.float32) + B = np.random.randn(*normalized_shape).astype(np.float32) + Y, mean, inv_std_dev = _layer_normalization(X, W, B, axis, epsilon) + + if axis < 0: + name = f"layer_normalization_3d_axis_negative_{-axis}_epsilon" + func_sig = f"input_0.layer_normalization(@input_1,Option::Some(@input_2),Option::Some(IntegerTrait::::new({-axis}, true)),Option::Some(FixedTrait::new(6554, false)),Option::None)" + else: + name = f"layer_normalization_3d_axis{axis}_epsilon" + func_sig = f"input_0.layer_normalization(@input_1,Option::Some(@input_2),Option::Some(IntegerTrait::::new({axis}, false)),Option::Some(FixedTrait::new(6554, false)),Option::None)" + + x = Tensor(Dtype.FP16x16, X.shape, to_fp(X.flatten(), FixedImpl.FP16x16)) + w = Tensor(Dtype.FP16x16, W.shape, to_fp(W.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, B.shape, to_fp(B.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, Y.shape, to_fp(Y.flatten(), FixedImpl.FP16x16)) + + make_test([x,w,b], y, func_sig, name) + + + for i in range(len(X.shape)): + case(i) + case(i - len(X.shape)) + + @staticmethod + def test_2d_example() -> None: + X = np.random.randn(3, 4).astype(np.float32) + + def case(axis: int) -> None: + normalized_shape = calculate_normalized_shape(X.shape, axis) + W = np.random.randn(*normalized_shape).astype(np.float32) + B = np.random.randn(*normalized_shape).astype(np.float32) + Y, mean, inv_std_dev = _layer_normalization(X, W, B, axis=axis) + + node = onnx.helper.make_node( + "LayerNormalization", + inputs=["X", "W", "B"], + outputs=["Y", "Mean", "InvStdDev"], + axis=axis, + ) + + x = Tensor(Dtype.FP16x16, X.shape, to_fp(X.flatten(), FixedImpl.FP16x16)) + w = Tensor(Dtype.FP16x16, W.shape, to_fp(W.flatten(), FixedImpl.FP16x16)) + b = Tensor(Dtype.FP16x16, B.shape, to_fp(B.flatten(), FixedImpl.FP16x16)) + y = Tensor(Dtype.FP16x16, Y.shape, to_fp(Y.flatten(), FixedImpl.FP16x16)) + + name = "layer_normalization_test" + make_test([x,w,b], y, "input_0.layer_normalization(@input_1,Option::Some(@input_2),Option::None,Option::None,Option::None)", name) + + case(-1) \ No newline at end of file diff --git a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo index eb50a2e14..051965260 100644 --- a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo +++ b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo @@ -408,12 +408,8 @@ impl TreeEnsembleClassifierImpl< let mut class_id: usize = 0; // Get first class_id in class_ids match class_ids.pop_front() { - Option::Some(c_id) => { - let mut class_id = *c_id; - }, - Option::None(_) => { - let mut class_id: usize = 0; - } + Option::Some(c_id) => { let mut class_id = *c_id; }, + Option::None(_) => { let mut class_id: usize = 0; } }; loop { if i == self.class_ids.len() { @@ -424,19 +420,17 @@ impl TreeEnsembleClassifierImpl< if *c_id == class_id { binary = true; continue; - }else{ + } else { binary = false; break; } - }, Option::None(_) => { break; } }; - }; // Clone res - if binary{ + if binary { let mut new_res: MutMatrix = MutMatrixImpl::new(res.rows, res.cols); let mut i: usize = 0; loop { @@ -445,14 +439,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_1 = match res.get(i, 0) { - Option::Some(res_0) => { - new_res.set(i, 1, res_0); - }, - Option::None(_) => { - new_res.set(i, 1, NumberTrait::zero()); - }, + Option::Some(res_0) => { new_res.set(i, 1, res_0); }, + Option::None(_) => { new_res.set(i, 1, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; match self.post_transform { POST_TRANSFORM::NONE => { @@ -467,11 +457,9 @@ impl TreeEnsembleClassifierImpl< let value = NumberTrait::sub(NumberTrait::one(), res_1); new_res.set(i, 0, value); }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::SOFTMAX => { @@ -482,14 +470,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_0 = match new_res.get(i, 1) { - Option::Some(res_1) => { - new_res.set(i, 0, res_1.neg()); - }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::LOGISTIC => { @@ -500,14 +484,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_0 = match new_res.get(i, 1) { - Option::Some(res_1) => { - new_res.set(i, 0, res_1.neg()); - }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::SOFTMAXZERO => { @@ -518,14 +498,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_0 = match new_res.get(i, 1) { - Option::Some(res_1) => { - new_res.set(i, 0, res_1.neg()); - }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::PROBIT => { @@ -540,17 +516,15 @@ impl TreeEnsembleClassifierImpl< let value = NumberTrait::sub(NumberTrait::one(), res_1); new_res.set(i, 0, value); }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, }; res = new_res; } - + // Post Transform let mut new_scores = match self.post_transform { POST_TRANSFORM::NONE => res, // No action required diff --git a/src/operators/sequence/functional.cairo b/src/operators/sequence/functional.cairo index 84f30cfc7..e0b80db7c 100644 --- a/src/operators/sequence/functional.cairo +++ b/src/operators/sequence/functional.cairo @@ -4,4 +4,4 @@ mod sequence_at; mod sequence_erase; mod sequence_insert; mod sequence_length; -mod concat_from_sequence; \ No newline at end of file +mod concat_from_sequence; diff --git a/src/operators/sequence/implementations/sequence_bool.cairo b/src/operators/sequence/implementations/sequence_bool.cairo index 1ac241e41..b9d800123 100644 --- a/src/operators/sequence/implementations/sequence_bool.cairo +++ b/src/operators/sequence/implementations/sequence_bool.cairo @@ -41,5 +41,4 @@ impl BoolSequence of SequenceTrait { ) -> Tensor { functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - } diff --git a/src/operators/sequence/implementations/sequence_fp8x23wide.cairo b/src/operators/sequence/implementations/sequence_fp8x23wide.cairo index eaebb072d..64bb5576f 100644 --- a/src/operators/sequence/implementations/sequence_fp8x23wide.cairo +++ b/src/operators/sequence/implementations/sequence_fp8x23wide.cairo @@ -43,5 +43,4 @@ impl FP8x23WSequence of SequenceTrait { ) -> Tensor { functional::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) } - } diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 2074a1ecf..6a8b17e1c 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -114,6 +114,7 @@ impl TensorSerde, impl TDrop: Drop> of Serde= 1, indices tensor of rank q >= 1, and batch_dims integer b, this operator gathers slices of data into an output tensor of rank q + r - indices_shape[-1] - 1 - b. /// reduce_log_sum - Computes the log sum of the input tensor's elements along the provided axes. /// erf - Computes the error function of the given input tensor element-wise. +/// layer_normalization - computes the layer normalization of the input tensor. trait TensorTrait { /// # tensor.new /// @@ -4776,6 +4777,104 @@ trait TensorTrait { /// ``` /// fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor; + /// # tensor.layer_normalization + /// + /// ```rust + /// fn layer_normalization( + /// self: @Tensor, + /// scale: @Tensor, + /// B: Option<@Tensor>, + /// axis: Option, + /// epsilon: Option, + /// stash_type: Option, + /// ) -> (Tensor, Tensor, Tensor); + /// ``` + /// + /// Layer normalization of the input, in two stages. + /// The first stage is standardization, which makes the normalized elements have zero mean and unit variances. + /// The second stage then scales and shifts the outcome of the first stage + /// ## Args + /// + /// * `self`(`@Tensor`) - The input tensor. + /// * `scale`(`@Tensor,`) - Scale tensor. + /// * `B`(`Option<@Tensor>`) - Bias tensor. + /// * `axis`(`Option`) (default is -1) - The first normalization dimension. If rank(X) is r, axis' allowed range is [-r, r). Negative value means counting dimensions from the back. + /// * `epsilon`(`Option`) (default is 0) - The epsilon value to use to avoid division by zero. + /// * `stash_type`(`Option`) - Precise the computation precision - unused the precision is defined by the type of the tensor. + /// ## Panics + /// + /// * Panics if condition rank is not equal to 1. + /// + /// ## Returns + /// + /// A new normalized tensor`Tensor`. + /// A tensor containing the mean `Tensor`. + /// A tensor containing the inverse standard deviation `Tensor`. + /// + /// ## Example + /// + /// ```rust + /// use orion::operators::tensor::{TensorTrait, Tensor}; + /// use orion::operators::tensor::FP16x16TensorPartialEq; + /// use core::array::{ArrayTrait, SpanTrait}; + /// use orion::operators::tensor::FP16x16Tensor; + /// use orion::numbers::{FixedTrait, FP16x16}; + /// + /// fn layer_normalization_example() -> (Tensor, Tensor, Tensor) { + /// let mut shape = ArrayTrait::::new(); + /// shape.append(3); + /// shape.append(4); + /// + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 41143, sign: true }); + /// data.append(FP16x16 { mag: 51803, sign: false }); + /// data.append(FP16x16 { mag: 113556, sign: false }); + /// data.append(FP16x16 { mag: 64774, sign: false }); + /// data.append(FP16x16 { mag: 866, sign: false }); + /// data.append(FP16x16 { mag: 698, sign: true }); + /// data.append(FP16x16 { mag: 106500, sign: false }); + /// data.append(FP16x16 { mag: 98929, sign: false }); + /// data.append(FP16x16 { mag: 7551, sign: false }); + /// data.append(FP16x16 { mag: 30689, sign: true }); + /// data.append(FP16x16 { mag: 38325, sign: false }); + /// data.append(FP16x16 { mag: 48164, sign: false }); + /// let X = TensorTrait::new(shape.span(), data.span()); + /// + /// let shape = ArrayTrait::::new(); + /// shape.append(4); + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 49855, sign: false }); + /// data.append(FP16x16 { mag: 150787, sign: false }); + /// data.append(FP16x16 { mag: 83498, sign: true }); + /// data.append(FP16x16 { mag: 30346, sign: false }); + /// let scale = TensorTrait::new(shape.span(), data.span()); + /// + /// + /// let mut shape = ArrayTrait::::new(); + /// shape.append(4); + /// let mut data = ArrayTrait::new(); + /// data.append(FP16x16 { mag: 54864, sign: true }); + /// data.append(FP16x16 { mag: 50952, sign: false }); + /// data.append(FP16x16 { mag: 8870, sign: true }); + /// data.append(FP16x16 { mag: 23216, sign: true }); + /// let bias = TensorTrait::new(shape.span(), data.span()); + /// + /// return X.layer_normalization(@scale,Option::Some(@bias),Option::None,Option::None,Option::None); + /// } + /// >>> [[-0.48926553 1.0185822 -0.02138367 -0.39223218] + /// [-0.7945549 0.99696046 0.04332176 -0.412645 ] + /// [-0.5664707 0.7491956 -0.7896356 -0.5320859 ]] + /// + /// ``` + /// + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor); } /// Cf: TensorTrait::new docstring diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index 8ca90eef6..f3c55f305 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -444,9 +444,22 @@ impl BoolTensor of TensorTrait { math::gather_nd::gather_nd(self, indices, batch_dims) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements partial equal for two `Tensor` using the `PartialEq` trait. diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 53feb8980..0f854ec9e 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -472,9 +472,22 @@ impl Complex64Tensor of TensorTrait { panic(array!['not supported!']) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 2f13326c1..be7964b35 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -486,7 +486,9 @@ impl FP16x16Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -504,9 +506,22 @@ impl FP16x16Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 4070d0154..6c873b8ea 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -452,7 +452,9 @@ impl FP16x16WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -470,9 +472,22 @@ impl FP16x16WTensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index bc77c3e15..65c640b2e 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -487,7 +487,9 @@ impl FP32x32Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -505,9 +507,22 @@ impl FP32x32Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 7cac1e80f..a56062474 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -487,7 +487,9 @@ impl FP64x64Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -505,9 +507,22 @@ impl FP64x64Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 6b8a471f0..549675a2d 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -485,7 +485,9 @@ impl FP8x23Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -503,9 +505,22 @@ impl FP8x23Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 54118f17b..58f985e5f 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -438,7 +438,9 @@ impl FP8x23WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -456,9 +458,22 @@ impl FP8x23WTensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 67401fcb2..689cbd893 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -482,7 +482,9 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -503,6 +505,17 @@ impl I32Tensor of TensorTrait { fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 4077b9bd3..8b3cebb8a 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -480,7 +480,9 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -501,6 +503,17 @@ impl I8Tensor of TensorTrait { fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index b69c19b04..17d5facf1 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -424,7 +424,9 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -445,6 +447,17 @@ impl U32Tensor of TensorTrait { fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } + + fn layer_normalization( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, + ) -> (Tensor, Tensor, Tensor) { + panic(array!['not supported!']) + } } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index 4cb97feda..10c03fb90 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -58,4 +58,5 @@ mod is_inf; mod gather_nd; mod reduce_log_sum; mod erf; -mod compress; \ No newline at end of file +mod compress; +mod layer_normalization; diff --git a/src/operators/tensor/math/compress.cairo b/src/operators/tensor/math/compress.cairo index 6380d5d15..d22eb1d82 100644 --- a/src/operators/tensor/math/compress.cairo +++ b/src/operators/tensor/math/compress.cairo @@ -14,12 +14,7 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; /// Cf: TensorTrait::compare docstring -fn compress< - T, - impl TTensorTrait: TensorTrait, - impl TCopy: Copy, - impl TDrop: Drop, ->( +fn compress, impl TCopy: Copy, impl TDrop: Drop,>( self: @Tensor, condition: Tensor, axis: Option ) -> Tensor { let axis = match axis { @@ -29,7 +24,7 @@ fn compress< let data_rank = (*self.shape).len(); let condition_rank = (condition.shape).len(); - assert((data_rank >= 1 ), 'data rank must > 1'); + assert((data_rank >= 1), 'data rank must > 1'); assert((condition_rank == 1), 'condition rank must be 1'); let mut data_shape = *self.shape; @@ -67,9 +62,7 @@ fn compress< let mut total_shape = 1; loop { match data_shape.pop_front() { - Option::Some(val) => { - total_shape *= *val; - }, + Option::Some(val) => { total_shape *= *val; }, Option::None(_) => { break; } }; }; @@ -78,8 +71,10 @@ fn compress< loop { match condition_data.pop_front() { Option::Some(val) => { - if (ind == total_shape) {break; } - if (*val != 0){ + if (ind == total_shape) { + break; + } + if (*val != 0) { output_data.append(*self.data[ind]); } ind += 1; @@ -99,8 +94,7 @@ fn compress< Option::Some(val) => { if (ind == axis) { output_shape.append(output); - } - else { + } else { output_shape.append(*val); if (ind > axis) { loop_breaker *= *val; @@ -120,31 +114,34 @@ fn compress< let mut ind = 0; let mut ind_loop = 0; - + let mut inner_index: usize = 0; let mut condition_data_clone = condition_data.clone(); loop { - if (ind == other_loop_breaker) {break;} + if (ind == other_loop_breaker) { + break; + } let mut condition_data_clone = condition_data.clone(); - inner_index = *data_shape.at(axis) * ind; + inner_index = *data_shape.at(axis) * ind; loop { - match condition_data_clone.pop_front() { - Option::Some(val) => { - if (*val != 0){ - let result = inner_index * loop_breaker ; - - let mut data_ind:usize = result ; - loop { - if data_ind == result + loop_breaker { break; } - index_data.append(data_ind); - data_ind+=1; - }; - } - inner_index += 1; - }, - Option::None(_) => { break; } + Option::Some(val) => { + if (*val != 0) { + let result = inner_index * loop_breaker; + + let mut data_ind: usize = result; + loop { + if data_ind == result + loop_breaker { + break; + } + index_data.append(data_ind); + data_ind += 1; + }; + } + inner_index += 1; + }, + Option::None(_) => { break; } }; }; @@ -153,14 +150,12 @@ fn compress< loop { match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, + Option::Some(val) => { output_data.append(*self.data[val]); }, Option::None(_) => { break; } }; - }; + }; } let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} \ No newline at end of file +} diff --git a/src/operators/tensor/math/layer_normalization.cairo b/src/operators/tensor/math/layer_normalization.cairo new file mode 100644 index 000000000..6a20fd92b --- /dev/null +++ b/src/operators/tensor/math/layer_normalization.cairo @@ -0,0 +1,191 @@ +use core::traits::TryInto; +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::option::OptionTrait; +use core::traits::Into; +use orion::numbers::NumberTrait; +use orion::operators::tensor::{ + TensorTrait, Tensor, I8Tensor, I32Tensor, U32Tensor, FP16x16Tensor, BoolTensor +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; +use core::debug::PrintTrait; +use orion::numbers::{i8, i32, IntegerTrait}; +use orion::operators::vec::{VecTrait, NullableVec, NullableVecImpl}; + + +/// Cf: TensorTrait::layer_normalization docstring +fn layer_normalization< + T, + MAG, + +TensorTrait, + +NumberTrait, + +PartialEq, + +Copy, + +Drop, + +Div>, + +Sub>, + +Add>, + +Mul>, + +Into, +>( + self: @Tensor, + scale: @Tensor, + B: Option<@Tensor>, + axis: Option, + epsilon: Option, + stash_type: Option, +) -> (Tensor, Tensor, Tensor) { + let X_rank = (*self).shape.len(); + let X_shape = (*self).shape; + + let mut axis = match axis { + Option::Some(axis) => axis, + Option::None => IntegerTrait::::new(1, true), + }; + let epsilon = match epsilon { + Option::Some(epsilon) => epsilon, + Option::None => NumberTrait::zero(), // default of onnx is 1e-05 + }; + + let stash_type = match stash_type { + Option::Some(stash_type) => stash_type, + Option::None => 1, + }; + + let axis = if axis < IntegerTrait::::new(0, false) { + X_rank - axis.mag + } else { + axis.mag + }; + + let unsqueezed_rank = X_rank - axis; + let mut reduction_shape = ArrayTrait::new(); + let mut i = 0; + loop { + if i == axis { + break; + } + reduction_shape.append(*(*self).shape.at(i)); + i += 1; + }; + let mut i = 0; + loop { + if i == unsqueezed_rank { + break; + } + reduction_shape.append(1); + i += 1; + }; + + let mut row_number = 1; + let mut col_number = 1; + let mut i = 0; + loop { + if i == X_rank { + break; + } + if i < axis { + row_number *= *(*self).shape.at(i); + } else { + col_number *= *(*self).shape.at(i); + } + i += 1; + }; + + let mut shape_matrix = ArrayTrait::new(); + shape_matrix.append(row_number); + shape_matrix.append(col_number); + + // Shape [1, 1] to mutiply one element tensors with 2D matrices + let mut shape_one = ArrayTrait::new(); + shape_one.append(1); + shape_one.append(1); + + let mut col_number_tensor = ArrayTrait::new(); + col_number_tensor.append(NumberTrait::new_unscaled(col_number.into(), false)); + + let mut epsilon_tensor = ArrayTrait::new(); + epsilon_tensor.append(epsilon); + + let mut one_tensor = ArrayTrait::new(); + one_tensor.append(NumberTrait::one()); + + let x_mat = self.reshape(shape_matrix.span()); + let x_mean = x_mat.reduce_sum(1, true) + / TensorTrait::new(shape_one.span(), col_number_tensor.span()); + + let x_diff = x_mat - x_mean; + let x_squared_diff = x_diff * x_diff; + + let variance = x_squared_diff.reduce_sum(1, true) + / TensorTrait::new(shape_one.span(), col_number_tensor.span()); + let variance_eps = variance + TensorTrait::new(shape_one.span(), epsilon_tensor.span()); + + let std_dev = variance_eps.sqrt(); + + let inv_std_dev = TensorTrait::new(shape_one.span(), one_tensor.span()) / std_dev; + + let y_mat = x_diff * inv_std_dev; + + let scale = if (*scale).shape.len() < (*self).shape.len() { + // Append 1 in scale shape to make sure scale has a dimension compatible with Y for multiplication + let mut shape = ArrayTrait::new(); + let mut i = 0; + loop { + if i == (*self).shape.len() - (*scale).shape.len() { + break; + } + shape.append(1); + i += 1; + }; + let mut i = 0; + loop { + if i == (*scale).shape.len() { + break; + } + shape.append(*(*scale).shape.at(i)); + i += 1; + }; + TensorTrait::new(shape.span(), (*scale).data) + } else { + *scale + }; + + let Y = y_mat.reshape((*self).shape) * scale; + + let Y = match B { + Option::Some(B) => { + let B = if (*B).shape.len() < (*self).shape.len() { + // Append 1 in B shape to make sure scale has a dimension compatible with Y for multiplication + let mut shape = ArrayTrait::new(); + let mut i = 0; + loop { + if i == (*self).shape.len() - (*B).shape.len() { + break; + } + shape.append(1); + i += 1; + }; + let mut i = 0; + loop { + if i == (*B).shape.len() { + break; + } + shape.append(*(*B).shape.at(i)); + i += 1; + }; + TensorTrait::new(shape.span(), (*B).data) + } else { + *B + }; + Y + B + }, + Option::None => Y, + }; + + let X_mean = TensorTrait::new(reduction_shape.span(), x_mean.data); + let X_inv_std_dev = TensorTrait::new(reduction_shape.span(), inv_std_dev.data); + + return (Y, X_mean, X_inv_std_dev); +} + diff --git a/tests/lib.cairo b/tests/lib.cairo index c408347ef..f5cecb77d 100644 --- a/tests/lib.cairo +++ b/tests/lib.cairo @@ -5,4 +5,3 @@ mod nodes; mod ml; mod operators; - diff --git a/tests/ml/tree_ensemble_classifier.cairo b/tests/ml/tree_ensemble_classifier.cairo index 6ee2afc11..441aabb34 100644 --- a/tests/ml/tree_ensemble_classifier.cairo +++ b/tests/ml/tree_ensemble_classifier.cairo @@ -241,8 +241,9 @@ fn test_tree_ensemble_classifier_binary_none() { #[test] #[available_gas(200000000000)] fn test_tree_ensemble_classifier_binary_logistic() { - - let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::LOGISTIC); + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper( + POST_TRANSFORM::LOGISTIC + ); let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); @@ -282,11 +283,13 @@ fn test_tree_ensemble_classifier_binary_softmax() { 'score[0, 1]' ); } - + #[test] #[available_gas(200000000000)] fn test_tree_ensemble_classifier_binary_softmax_zero() { - let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::SOFTMAXZERO); + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper( + POST_TRANSFORM::SOFTMAXZERO + ); let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); @@ -485,180 +488,1911 @@ fn tree_ensemble_classifier_helper( fn tree_ensemble_classifier_binary_class_helper( post_transform: POST_TRANSFORM ) -> (TreeEnsembleClassifier, Tensor) { - let class_ids: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); - let class_nodeids: Span = array![4, 5, 7, 10, 12, 13, 15, 17, 19, 20, 24, 26, 29, 31, 32, 33, 37, 38, 39, 40, 46, 49, 50, 52, 56, 57, 58, 59, 62, 64, 66, 67, 68, 73, 74, 75, 76, 81, 82, 83, 84, 88, 89, 91, 93, 94, 95, 98, 99, 101, 104, 106, 107, 108, 112, 113, 114, 115, 119, 121, 124, 125, 127, 128, 130, 131, 138, 140, 141, 142, 143, 148, 149, 150, 151, 152, 153, 154].span(); - let class_treeids: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); - let class_weights: Span = array![FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 43690, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 65536, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 65536, sign: false }].span(); + let class_ids: Span = array![ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let class_nodeids: Span = array![ + 4, + 5, + 7, + 10, + 12, + 13, + 15, + 17, + 19, + 20, + 24, + 26, + 29, + 31, + 32, + 33, + 37, + 38, + 39, + 40, + 46, + 49, + 50, + 52, + 56, + 57, + 58, + 59, + 62, + 64, + 66, + 67, + 68, + 73, + 74, + 75, + 76, + 81, + 82, + 83, + 84, + 88, + 89, + 91, + 93, + 94, + 95, + 98, + 99, + 101, + 104, + 106, + 107, + 108, + 112, + 113, + 114, + 115, + 119, + 121, + 124, + 125, + 127, + 128, + 130, + 131, + 138, + 140, + 141, + 142, + 143, + 148, + 149, + 150, + 151, + 152, + 153, + 154 + ] + .span(); + let class_treeids: Span = array![ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let class_weights: Span = array![ + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 43690, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false } + ] + .span(); let classlabels: Span = array![0, 1].span(); - let nodes_falsenodeids: Span = array![116, 21, 6, 5, 0, 0, 8, 0, 14, 11, 0, 13, 0, 0, 16, 0, 18, 0, 20, 0, 0, 41, 34, 25, 0, 27, 0, 33, 30, 0, 32, 0, 0, 0, 40, 39, 38, 0, 0, 0, 0, 109, 96, 69, 60, 47, 0, 51, 50, 0, 0, 53, 0, 59, 58, 57, 0, 0, 0, 0, 68, 63, 0, 65, 0, 67, 0, 0, 0, 77, 76, 75, 74, 0, 0, 0, 0, 85, 84, 83, 82, 0, 0, 0, 0, 95, 90, 89, 0, 0, 92, 0, 94, 0, 0, 0, 100, 99, 0, 0, 102, 0, 108, 105, 0, 107, 0, 0, 0, 115, 114, 113, 0, 0, 0, 0, 132, 129, 120, 0, 122, 0, 126, 125, 0, 0, 128, 0, 0, 131, 0, 0, 154, 153, 144, 143, 142, 139, 0, 141, 0, 0, 0, 0, 152, 151, 150, 149, 0, 0, 0, 0, 0, 0, 0].span(); - let nodes_featureids: Span = array![3, 2, 4, 8, 0, 0, 1, 0, 2, 7, 0, 0, 0, 0, 7, 0, 0, 0, 6, 0, 0, 8, 0, 2, 0, 7, 0, 7, 2, 0, 2, 0, 0, 0, 2, 6, 7, 0, 0, 0, 0, 7, 7, 0, 7, 1, 0, 0, 2, 0, 0, 2, 0, 2, 2, 6, 0, 0, 0, 0, 2, 0, 0, 1, 0, 6, 0, 0, 0, 0, 2, 6, 7, 0, 0, 0, 0, 6, 7, 2, 0, 0, 0, 0, 0, 2, 2, 7, 0, 0, 2, 0, 0, 0, 0, 0, 6, 1, 0, 0, 4, 0, 2, 2, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 6, 0, 7, 0, 0, 0, 1, 3, 0, 0, 2, 0, 0, 8, 0, 0, 2, 2, 2, 4, 7, 3, 0, 1, 0, 0, 0, 0, 4, 3, 7, 8, 0, 0, 0, 0, 0, 0, 0].span(); - let nodes_missing_value_tracks_true: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); - let nodes_modes: Span = array![NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::BRANCH_LEQ, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF, NODE_MODES::LEAF].span(); - let nodes_nodeids: Span = array![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154].span(); - let nodes_treeids: Span = array![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0].span(); - let nodes_truenodeids: Span = array![1, 2, 3, 4, 0, 0, 7, 0, 9, 10, 0, 12, 0, 0, 15, 0, 17, 0, 19, 0, 0, 22, 23, 24, 0, 26, 0, 28, 29, 0, 31, 0, 0, 0, 35, 36, 37, 0, 0, 0, 0, 42, 43, 44, 45, 46, 0, 48, 49, 0, 0, 52, 0, 54, 55, 56, 0, 0, 0, 0, 61, 62, 0, 64, 0, 66, 0, 0, 0, 70, 71, 72, 73, 0, 0, 0, 0, 78, 79, 80, 81, 0, 0, 0, 0, 86, 87, 88, 0, 0, 91, 0, 93, 0, 0, 0, 97, 98, 0, 0, 101, 0, 103, 104, 0, 106, 0, 0, 0, 110, 111, 112, 0, 0, 0, 0, 117, 118, 119, 0, 121, 0, 123, 124, 0, 0, 127, 0, 0, 130, 0, 0, 133, 134, 135, 136, 137, 138, 0, 140, 0, 0, 0, 0, 145, 146, 147, 148, 0, 0, 0, 0, 0, 0, 0].span(); - let nodes_values: Span = array![FP16x16 { mag: 4096, sign: false }, FP16x16 { mag: 22937, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 16384, sign: false }, FP16x16 { mag: 57344, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 19660, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 24576, sign: false }, FP16x16 { mag: 42598, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 62259, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 62259, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 40960, sign: false }, FP16x16 { mag: 24576, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 19660, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 42598, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 19660, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 42598, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 29491, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 45875, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 58982, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 36044, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 16384, sign: false }, FP16x16 { mag: 20480, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 49152, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 8192, sign: false }, FP16x16 { mag: 32768, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 0, sign: false }].span(); + let nodes_falsenodeids: Span = array![ + 116, + 21, + 6, + 5, + 0, + 0, + 8, + 0, + 14, + 11, + 0, + 13, + 0, + 0, + 16, + 0, + 18, + 0, + 20, + 0, + 0, + 41, + 34, + 25, + 0, + 27, + 0, + 33, + 30, + 0, + 32, + 0, + 0, + 0, + 40, + 39, + 38, + 0, + 0, + 0, + 0, + 109, + 96, + 69, + 60, + 47, + 0, + 51, + 50, + 0, + 0, + 53, + 0, + 59, + 58, + 57, + 0, + 0, + 0, + 0, + 68, + 63, + 0, + 65, + 0, + 67, + 0, + 0, + 0, + 77, + 76, + 75, + 74, + 0, + 0, + 0, + 0, + 85, + 84, + 83, + 82, + 0, + 0, + 0, + 0, + 95, + 90, + 89, + 0, + 0, + 92, + 0, + 94, + 0, + 0, + 0, + 100, + 99, + 0, + 0, + 102, + 0, + 108, + 105, + 0, + 107, + 0, + 0, + 0, + 115, + 114, + 113, + 0, + 0, + 0, + 0, + 132, + 129, + 120, + 0, + 122, + 0, + 126, + 125, + 0, + 0, + 128, + 0, + 0, + 131, + 0, + 0, + 154, + 153, + 144, + 143, + 142, + 139, + 0, + 141, + 0, + 0, + 0, + 0, + 152, + 151, + 150, + 149, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let nodes_featureids: Span = array![ + 3, + 2, + 4, + 8, + 0, + 0, + 1, + 0, + 2, + 7, + 0, + 0, + 0, + 0, + 7, + 0, + 0, + 0, + 6, + 0, + 0, + 8, + 0, + 2, + 0, + 7, + 0, + 7, + 2, + 0, + 2, + 0, + 0, + 0, + 2, + 6, + 7, + 0, + 0, + 0, + 0, + 7, + 7, + 0, + 7, + 1, + 0, + 0, + 2, + 0, + 0, + 2, + 0, + 2, + 2, + 6, + 0, + 0, + 0, + 0, + 2, + 0, + 0, + 1, + 0, + 6, + 0, + 0, + 0, + 0, + 2, + 6, + 7, + 0, + 0, + 0, + 0, + 6, + 7, + 2, + 0, + 0, + 0, + 0, + 0, + 2, + 2, + 7, + 0, + 0, + 2, + 0, + 0, + 0, + 0, + 0, + 6, + 1, + 0, + 0, + 4, + 0, + 2, + 2, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 2, + 0, + 0, + 0, + 0, + 6, + 0, + 7, + 0, + 0, + 0, + 1, + 3, + 0, + 0, + 2, + 0, + 0, + 8, + 0, + 0, + 2, + 2, + 2, + 4, + 7, + 3, + 0, + 1, + 0, + 0, + 0, + 0, + 4, + 3, + 7, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let nodes_missing_value_tracks_true: Span = array![ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let nodes_modes: Span = array![ + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::BRANCH_LEQ, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF, + NODE_MODES::LEAF + ] + .span(); + let nodes_nodeids: Span = array![ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154 + ] + .span(); + let nodes_treeids: Span = array![ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let nodes_truenodeids: Span = array![ + 1, + 2, + 3, + 4, + 0, + 0, + 7, + 0, + 9, + 10, + 0, + 12, + 0, + 0, + 15, + 0, + 17, + 0, + 19, + 0, + 0, + 22, + 23, + 24, + 0, + 26, + 0, + 28, + 29, + 0, + 31, + 0, + 0, + 0, + 35, + 36, + 37, + 0, + 0, + 0, + 0, + 42, + 43, + 44, + 45, + 46, + 0, + 48, + 49, + 0, + 0, + 52, + 0, + 54, + 55, + 56, + 0, + 0, + 0, + 0, + 61, + 62, + 0, + 64, + 0, + 66, + 0, + 0, + 0, + 70, + 71, + 72, + 73, + 0, + 0, + 0, + 0, + 78, + 79, + 80, + 81, + 0, + 0, + 0, + 0, + 86, + 87, + 88, + 0, + 0, + 91, + 0, + 93, + 0, + 0, + 0, + 97, + 98, + 0, + 0, + 101, + 0, + 103, + 104, + 0, + 106, + 0, + 0, + 0, + 110, + 111, + 112, + 0, + 0, + 0, + 0, + 117, + 118, + 119, + 0, + 121, + 0, + 123, + 124, + 0, + 0, + 127, + 0, + 0, + 130, + 0, + 0, + 133, + 134, + 135, + 136, + 137, + 138, + 0, + 140, + 0, + 0, + 0, + 0, + 145, + 146, + 147, + 148, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + .span(); + let nodes_values: Span = array![ + FP16x16 { mag: 4096, sign: false }, + FP16x16 { mag: 22937, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 57344, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 19660, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 29491, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 24576, sign: false }, + FP16x16 { mag: 42598, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 62259, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 62259, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 40960, sign: false }, + FP16x16 { mag: 24576, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 19660, sign: false }, + FP16x16 { mag: 45875, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 29491, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 42598, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 36044, sign: false }, + FP16x16 { mag: 19660, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 45875, sign: false }, + FP16x16 { mag: 29491, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 36044, sign: false }, + FP16x16 { mag: 58982, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 58982, sign: false }, + FP16x16 { mag: 29491, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 45875, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 58982, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 42598, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 45875, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 29491, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 45875, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 36044, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 58982, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 36044, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 20480, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 49152, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 8192, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 0, sign: false } + ] + .span(); let base_values: Option> = Option::None; let tree_ids: Span = array![0].span(); let mut root_index: Felt252Dict = Default::default(); - root_index.insert(0, 0); + root_index.insert(0, 0); let mut node_index: Felt252Dict = Default::default(); - node_index.insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); - node_index.insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); - node_index.insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); - node_index.insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); - node_index.insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); - node_index.insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); - node_index.insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); - node_index.insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); - node_index.insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); - node_index.insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); - node_index.insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); - node_index.insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); - node_index.insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); - node_index.insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); - node_index.insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); - node_index.insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); - node_index.insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); - node_index.insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); - node_index.insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); - node_index.insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); - node_index.insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); - node_index.insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); - node_index.insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); - node_index.insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); - node_index.insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); - node_index.insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); - node_index.insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); - node_index.insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); - node_index.insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); - node_index.insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); - node_index.insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); - node_index.insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); - node_index.insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); - node_index.insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); - node_index.insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); - node_index.insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); - node_index.insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); - node_index.insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); - node_index.insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); - node_index.insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); - node_index.insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); - node_index.insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); - node_index.insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); - node_index.insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); - node_index.insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); - node_index.insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); - node_index.insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); - node_index.insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); - node_index.insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); - node_index.insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); - node_index.insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); - node_index.insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); - node_index.insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); - node_index.insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); - node_index.insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); - node_index.insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); - node_index.insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); - node_index.insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); - node_index.insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); - node_index.insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); - node_index.insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); - node_index.insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); - node_index.insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); - node_index.insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); - node_index.insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); - node_index.insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); - node_index.insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); - node_index.insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); - node_index.insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); - node_index.insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); - node_index.insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); - node_index.insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); - node_index.insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); - node_index.insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); - node_index.insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); - node_index.insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); - node_index.insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); - node_index.insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); - node_index.insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); - node_index.insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); - node_index.insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); - node_index.insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); - node_index.insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); - node_index.insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); - node_index.insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); - node_index.insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); - node_index.insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); - node_index.insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); - node_index.insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); - node_index.insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); - node_index.insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); - node_index.insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); - node_index.insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); - node_index.insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); - node_index.insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); - node_index.insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); - node_index.insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); - node_index.insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); - node_index.insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); - node_index.insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); - node_index.insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); - node_index.insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); - node_index.insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); - node_index.insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); - node_index.insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); - node_index.insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); - node_index.insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); - node_index.insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); - node_index.insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); - node_index.insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); - node_index.insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); - node_index.insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); - node_index.insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); - node_index.insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); - node_index.insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); - node_index.insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); - node_index.insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); - node_index.insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); - node_index.insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); - node_index.insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); - node_index.insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); - node_index.insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); - node_index.insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); - node_index.insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); - node_index.insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); - node_index.insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); - node_index.insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); - node_index.insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); - node_index.insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); - node_index.insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); - node_index.insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); - node_index.insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); - node_index.insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); - node_index.insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); - node_index.insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); - node_index.insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); - node_index.insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); - node_index.insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); - node_index.insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); - node_index.insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); - node_index.insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); - node_index.insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); - node_index.insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); - node_index.insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); - node_index.insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); - node_index.insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); - node_index.insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); - node_index.insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); - node_index.insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); - node_index.insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); - node_index.insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); - node_index.insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); - node_index.insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); - node_index.insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); - node_index.insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); + node_index + .insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); + node_index + .insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); + node_index + .insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); + node_index + .insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); + node_index + .insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); + node_index + .insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); + node_index + .insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); + node_index + .insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); + node_index + .insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); + node_index + .insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); + node_index + .insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); + node_index + .insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); + node_index + .insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); + node_index + .insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); + node_index + .insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); + node_index + .insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); + node_index + .insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); + node_index + .insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); + node_index + .insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); + node_index + .insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); + node_index + .insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); + node_index + .insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); + node_index + .insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); + node_index + .insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); + node_index + .insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); + node_index + .insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); + node_index + .insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); + node_index + .insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); + node_index + .insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); + node_index + .insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); + node_index + .insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); + node_index + .insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); + node_index + .insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); + node_index + .insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); + node_index + .insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); + node_index + .insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); + node_index + .insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); + node_index + .insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); + node_index + .insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); + node_index + .insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); + node_index + .insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); + node_index + .insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); + node_index + .insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); + node_index + .insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); + node_index + .insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); + node_index + .insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); + node_index + .insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); + node_index + .insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); + node_index + .insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); + node_index + .insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); + node_index + .insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); + node_index + .insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); + node_index + .insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); + node_index + .insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); + node_index + .insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); + node_index + .insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); + node_index + .insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); + node_index + .insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); + node_index + .insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); + node_index + .insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); + node_index + .insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); + node_index + .insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); + node_index + .insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); + node_index + .insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); + node_index + .insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); + node_index + .insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); + node_index + .insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); + node_index + .insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); + node_index + .insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); + node_index + .insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); + node_index + .insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); + node_index + .insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); + node_index + .insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); + node_index + .insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); + node_index + .insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); + node_index + .insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); + node_index + .insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); + node_index + .insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); + node_index + .insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); + node_index + .insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); + node_index + .insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); + node_index + .insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); + node_index + .insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); + node_index + .insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); + node_index + .insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); + node_index + .insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); + node_index + .insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); + node_index + .insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); + node_index + .insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); + node_index + .insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); + node_index + .insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); + node_index + .insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); + node_index + .insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); + node_index + .insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); + node_index + .insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); + node_index + .insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); + node_index + .insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); + node_index + .insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); + node_index + .insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); + node_index + .insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); + node_index + .insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); + node_index + .insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); + node_index + .insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); + node_index + .insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); + node_index + .insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); + node_index + .insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); + node_index + .insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); + node_index + .insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); + node_index + .insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); + node_index + .insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); + node_index + .insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); + node_index + .insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); + node_index + .insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); + node_index + .insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); + node_index + .insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); + node_index + .insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); + node_index + .insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); + node_index + .insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); + node_index + .insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); + node_index + .insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); + node_index + .insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); + node_index + .insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); + node_index + .insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); + node_index + .insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); + node_index + .insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); + node_index + .insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); + node_index + .insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); + node_index + .insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); + node_index + .insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); + node_index + .insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); + node_index + .insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); + node_index + .insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); + node_index + .insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); + node_index + .insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); + node_index + .insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); + node_index + .insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); + node_index + .insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); + node_index + .insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); + node_index + .insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); + node_index + .insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); + node_index + .insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); + node_index + .insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); + node_index + .insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); + node_index + .insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); + node_index + .insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); + node_index + .insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); + node_index + .insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); + node_index + .insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); + node_index + .insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); + node_index + .insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); + node_index + .insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); + node_index + .insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); + node_index + .insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); + node_index + .insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); + node_index + .insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); let atts = TreeEnsembleAttributes { nodes_falsenodeids, @@ -687,19 +2421,20 @@ fn tree_ensemble_classifier_binary_class_helper( }; let mut X = TensorTrait::new( - array![1,9].span(), - array![ - FP16x16 { mag: 39321, sign: false }, - FP16x16 { mag: 32768, sign: false }, - FP16x16 { mag: 52428, sign: false }, - FP16x16 { mag: 16384, sign: false }, - FP16x16 { mag: 0, sign: false }, - FP16x16 { mag: 65536, sign: false }, - FP16x16 { mag: 0, sign: false }, - FP16x16 { mag: 16384, sign: false }, - FP16x16 { mag: 0, sign: false }, - ].span() - ); + array![1, 9].span(), + array![ + FP16x16 { mag: 39321, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 52428, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 0, sign: false }, + ] + .span() + ); (classifier, X) -} \ No newline at end of file +} diff --git a/tests/nodes.cairo b/tests/nodes.cairo index e3eccf547..92eb6ba20 100644 --- a/tests/nodes.cairo +++ b/tests/nodes.cairo @@ -869,3 +869,19 @@ mod compress_u32_3d_axis1; mod compress_u32_3d_axis2; mod compress_u32_3d_axis2_2; mod compress_u32_3d_axis3; +mod layer_normalization_default_axis; +mod layer_normalization_4d_axis0; +mod layer_normalization_4d_axis1; +mod layer_normalization_4d_axis2; +mod layer_normalization_4d_axis3; +mod layer_normalization_3d_axis0_epsilon; +mod layer_normalization_3d_axis1_epsilon; +mod layer_normalization_3d_axis2_epsilon; +mod layer_normalization_4d_axis_negative_4; +mod layer_normalization_4d_axis_negative_3; +mod layer_normalization_4d_axis_negative_2; +mod layer_normalization_4d_axis_negative_1; +mod layer_normalization_3d_axis_negative_3_epsilon; +mod layer_normalization_3d_axis_negative_2_epsilon; +mod layer_normalization_3d_axis_negative_1_epsilon; +mod layer_normalization_test; diff --git a/tests/nodes/compress_fp16x16_3d_axis1.cairo b/tests/nodes/compress_fp16x16_3d_axis1.cairo index de0c173ed..f110fd66d 100644 --- a/tests/nodes/compress_fp16x16_3d_axis1.cairo +++ b/tests/nodes/compress_fp16x16_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_axis2.cairo b/tests/nodes/compress_fp16x16_3d_axis2.cairo index 765bcb5ea..1115fb557 100644 --- a/tests/nodes/compress_fp16x16_3d_axis2.cairo +++ b/tests/nodes/compress_fp16x16_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_axis3.cairo b/tests/nodes/compress_fp16x16_3d_axis3.cairo index ffa9c8321..76ef5f641 100644 --- a/tests/nodes/compress_fp16x16_3d_axis3.cairo +++ b/tests/nodes/compress_fp16x16_3d_axis3.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_axis3() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(3)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(3)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_default.cairo b/tests/nodes/compress_fp16x16_3d_default.cairo index d9b837a19..aff1849e2 100644 --- a/tests/nodes/compress_fp16x16_3d_default.cairo +++ b/tests/nodes/compress_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_noaxis.cairo b/tests/nodes/compress_fp16x16_3d_noaxis.cairo index 2bd536e08..3c9645b1d 100644 --- a/tests/nodes/compress_fp16x16_3d_noaxis.cairo +++ b/tests/nodes/compress_fp16x16_3d_noaxis.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_noaxis() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::None(())); + let y_0 = input_0.compress(condition: input_1, axis: Option::None(())); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp8x23_3d_axis1.cairo b/tests/nodes/compress_fp8x23_3d_axis1.cairo index edd013f54..f7edfd13a 100644 --- a/tests/nodes/compress_fp8x23_3d_axis1.cairo +++ b/tests/nodes/compress_fp8x23_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_fp8x23_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp8x23_3d_axis2.cairo b/tests/nodes/compress_fp8x23_3d_axis2.cairo index 580a6272a..369ffb8bf 100644 --- a/tests/nodes/compress_fp8x23_3d_axis2.cairo +++ b/tests/nodes/compress_fp8x23_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_fp8x23_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp8x23_3d_default.cairo b/tests/nodes/compress_fp8x23_3d_default.cairo index a927f7fe8..eab9aa1ac 100644 --- a/tests/nodes/compress_fp8x23_3d_default.cairo +++ b/tests/nodes/compress_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i32_3d_axis1.cairo b/tests/nodes/compress_i32_3d_axis1.cairo index f69cf2e2a..571e5beb5 100644 --- a/tests/nodes/compress_i32_3d_axis1.cairo +++ b/tests/nodes/compress_i32_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_i32_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i32_3d_axis2.cairo b/tests/nodes/compress_i32_3d_axis2.cairo index bfe01e5a0..be674ffba 100644 --- a/tests/nodes/compress_i32_3d_axis2.cairo +++ b/tests/nodes/compress_i32_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_i32_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i32_3d_default.cairo b/tests/nodes/compress_i32_3d_default.cairo index b07d95010..4bd05fce1 100644 --- a/tests/nodes/compress_i32_3d_default.cairo +++ b/tests/nodes/compress_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_i32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i8_3d_axis1.cairo b/tests/nodes/compress_i8_3d_axis1.cairo index 6a4197ce1..fae6c2356 100644 --- a/tests/nodes/compress_i8_3d_axis1.cairo +++ b/tests/nodes/compress_i8_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_i8_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i8_3d_axis2.cairo b/tests/nodes/compress_i8_3d_axis2.cairo index 4dd7b5a8f..f8e90c133 100644 --- a/tests/nodes/compress_i8_3d_axis2.cairo +++ b/tests/nodes/compress_i8_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_i8_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i8_3d_default.cairo b/tests/nodes/compress_i8_3d_default.cairo index 14b684377..1b4052d0e 100644 --- a/tests/nodes/compress_i8_3d_default.cairo +++ b/tests/nodes/compress_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_i8_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis1.cairo b/tests/nodes/compress_u32_3d_axis1.cairo index dda59bead..7cfadc989 100644 --- a/tests/nodes/compress_u32_3d_axis1.cairo +++ b/tests/nodes/compress_u32_3d_axis1.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis2.cairo b/tests/nodes/compress_u32_3d_axis2.cairo index ba8fa77ef..9c70291c5 100644 --- a/tests/nodes/compress_u32_3d_axis2.cairo +++ b/tests/nodes/compress_u32_3d_axis2.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis2_2.cairo b/tests/nodes/compress_u32_3d_axis2_2.cairo index aa283b2cc..850c10296 100644 --- a/tests/nodes/compress_u32_3d_axis2_2.cairo +++ b/tests/nodes/compress_u32_3d_axis2_2.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis2_2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis3.cairo b/tests/nodes/compress_u32_3d_axis3.cairo index 62684b39f..c53e3e1b1 100644 --- a/tests/nodes/compress_u32_3d_axis3.cairo +++ b/tests/nodes/compress_u32_3d_axis3.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis3() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(3)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(3)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_default.cairo b/tests/nodes/compress_u32_3d_default.cairo index 058750c53..a7d987eb6 100644 --- a/tests/nodes/compress_u32_3d_default.cairo +++ b/tests/nodes/compress_u32_3d_default.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/layer_normalization_3d_axis0_epsilon.cairo b/tests/nodes/layer_normalization_3d_axis0_epsilon.cairo new file mode 100644 index 000000000..9f0a27f1a --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis0_epsilon.cairo @@ -0,0 +1,33 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; + +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_3d_axis0_epsilon() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(0, false)), + Option::Some(FixedTrait::new(6554, false)), + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_3d_axis0_epsilon/input_0.cairo b/tests/nodes/layer_normalization_3d_axis0_epsilon/input_0.cairo new file mode 100644 index 000000000..2f4da77ea --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis0_epsilon/input_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 73220, sign: false }); + data.append(FP16x16 { mag: 15621, sign: false }); + data.append(FP16x16 { mag: 26862, sign: true }); + data.append(FP16x16 { mag: 63971, sign: false }); + data.append(FP16x16 { mag: 64826, sign: false }); + data.append(FP16x16 { mag: 18837, sign: false }); + data.append(FP16x16 { mag: 66021, sign: true }); + data.append(FP16x16 { mag: 42181, sign: true }); + data.append(FP16x16 { mag: 69342, sign: true }); + data.append(FP16x16 { mag: 72001, sign: true }); + data.append(FP16x16 { mag: 99818, sign: true }); + data.append(FP16x16 { mag: 63088, sign: false }); + data.append(FP16x16 { mag: 17845, sign: true }); + data.append(FP16x16 { mag: 37020, sign: true }); + data.append(FP16x16 { mag: 20567, sign: false }); + data.append(FP16x16 { mag: 1924, sign: true }); + data.append(FP16x16 { mag: 13154, sign: true }); + data.append(FP16x16 { mag: 88735, sign: false }); + data.append(FP16x16 { mag: 40464, sign: false }); + data.append(FP16x16 { mag: 96907, sign: false }); + data.append(FP16x16 { mag: 79699, sign: false }); + data.append(FP16x16 { mag: 91862, sign: true }); + data.append(FP16x16 { mag: 97396, sign: false }); + data.append(FP16x16 { mag: 23929, sign: false }); + data.append(FP16x16 { mag: 11785, sign: true }); + data.append(FP16x16 { mag: 7747, sign: false }); + data.append(FP16x16 { mag: 91889, sign: true }); + data.append(FP16x16 { mag: 16735, sign: true }); + data.append(FP16x16 { mag: 120303, sign: true }); + data.append(FP16x16 { mag: 116144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis0_epsilon/input_1.cairo b/tests/nodes/layer_normalization_3d_axis0_epsilon/input_1.cairo new file mode 100644 index 000000000..f169c9adb --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis0_epsilon/input_1.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 53884, sign: false }); + data.append(FP16x16 { mag: 80807, sign: true }); + data.append(FP16x16 { mag: 17881, sign: true }); + data.append(FP16x16 { mag: 18476, sign: true }); + data.append(FP16x16 { mag: 36283, sign: true }); + data.append(FP16x16 { mag: 61330, sign: true }); + data.append(FP16x16 { mag: 41039, sign: false }); + data.append(FP16x16 { mag: 82065, sign: false }); + data.append(FP16x16 { mag: 45401, sign: false }); + data.append(FP16x16 { mag: 128118, sign: true }); + data.append(FP16x16 { mag: 214898, sign: true }); + data.append(FP16x16 { mag: 16418, sign: false }); + data.append(FP16x16 { mag: 82143, sign: false }); + data.append(FP16x16 { mag: 573, sign: true }); + data.append(FP16x16 { mag: 48898, sign: false }); + data.append(FP16x16 { mag: 14511, sign: false }); + data.append(FP16x16 { mag: 11366, sign: false }); + data.append(FP16x16 { mag: 53881, sign: false }); + data.append(FP16x16 { mag: 27317, sign: true }); + data.append(FP16x16 { mag: 88557, sign: false }); + data.append(FP16x16 { mag: 14203, sign: false }); + data.append(FP16x16 { mag: 1404, sign: true }); + data.append(FP16x16 { mag: 30266, sign: false }); + data.append(FP16x16 { mag: 83574, sign: true }); + data.append(FP16x16 { mag: 82692, sign: false }); + data.append(FP16x16 { mag: 86496, sign: false }); + data.append(FP16x16 { mag: 101363, sign: true }); + data.append(FP16x16 { mag: 30107, sign: true }); + data.append(FP16x16 { mag: 40283, sign: true }); + data.append(FP16x16 { mag: 54260, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis0_epsilon/input_2.cairo b/tests/nodes/layer_normalization_3d_axis0_epsilon/input_2.cairo new file mode 100644 index 000000000..68ddc02f9 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis0_epsilon/input_2.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 12782, sign: true }); + data.append(FP16x16 { mag: 34829, sign: true }); + data.append(FP16x16 { mag: 85769, sign: true }); + data.append(FP16x16 { mag: 76891, sign: true }); + data.append(FP16x16 { mag: 52049, sign: true }); + data.append(FP16x16 { mag: 129289, sign: true }); + data.append(FP16x16 { mag: 92309, sign: true }); + data.append(FP16x16 { mag: 48090, sign: true }); + data.append(FP16x16 { mag: 1390, sign: false }); + data.append(FP16x16 { mag: 10093, sign: true }); + data.append(FP16x16 { mag: 6373, sign: true }); + data.append(FP16x16 { mag: 91002, sign: true }); + data.append(FP16x16 { mag: 9698, sign: false }); + data.append(FP16x16 { mag: 103992, sign: true }); + data.append(FP16x16 { mag: 26897, sign: true }); + data.append(FP16x16 { mag: 67478, sign: true }); + data.append(FP16x16 { mag: 5546, sign: false }); + data.append(FP16x16 { mag: 55870, sign: true }); + data.append(FP16x16 { mag: 35113, sign: true }); + data.append(FP16x16 { mag: 267167, sign: true }); + data.append(FP16x16 { mag: 51438, sign: true }); + data.append(FP16x16 { mag: 13667, sign: false }); + data.append(FP16x16 { mag: 17845, sign: false }); + data.append(FP16x16 { mag: 92263, sign: false }); + data.append(FP16x16 { mag: 114550, sign: true }); + data.append(FP16x16 { mag: 31510, sign: false }); + data.append(FP16x16 { mag: 24263, sign: true }); + data.append(FP16x16 { mag: 68737, sign: true }); + data.append(FP16x16 { mag: 61297, sign: true }); + data.append(FP16x16 { mag: 33386, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis0_epsilon/output_0.cairo b/tests/nodes/layer_normalization_3d_axis0_epsilon/output_0.cairo new file mode 100644 index 000000000..8ff9e98aa --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis0_epsilon/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 42429, sign: false }); + data.append(FP16x16 { mag: 49634, sign: true }); + data.append(FP16x16 { mag: 77948, sign: true }); + data.append(FP16x16 { mag: 93325, sign: true }); + data.append(FP16x16 { mag: 84777, sign: true }); + data.append(FP16x16 { mag: 143407, sign: true }); + data.append(FP16x16 { mag: 133737, sign: true }); + data.append(FP16x16 { mag: 102353, sign: true }); + data.append(FP16x16 { mag: 46642, sign: true }); + data.append(FP16x16 { mag: 130431, sign: false }); + data.append(FP16x16 { mag: 316659, sign: false }); + data.append(FP16x16 { mag: 76609, sign: true }); + data.append(FP16x16 { mag: 15411, sign: true }); + data.append(FP16x16 { mag: 103656, sign: true }); + data.append(FP16x16 { mag: 14405, sign: true }); + data.append(FP16x16 { mag: 68539, sign: true }); + data.append(FP16x16 { mag: 2850, sign: false }); + data.append(FP16x16 { mag: 11550, sign: false }); + data.append(FP16x16 { mag: 50032, sign: true }); + data.append(FP16x16 { mag: 145784, sign: true }); + data.append(FP16x16 { mag: 35540, sign: true }); + data.append(FP16x16 { mag: 15615, sign: false }); + data.append(FP16x16 { mag: 59548, sign: false }); + data.append(FP16x16 { mag: 66808, sign: false }); + data.append(FP16x16 { mag: 132508, sign: true }); + data.append(FP16x16 { mag: 37408, sign: false }); + data.append(FP16x16 { mag: 116363, sign: false }); + data.append(FP16x16 { mag: 60021, sign: true }); + data.append(FP16x16 { mag: 11310, sign: false }); + data.append(FP16x16 { mag: 56233, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis1_epsilon.cairo b/tests/nodes/layer_normalization_3d_axis1_epsilon.cairo new file mode 100644 index 000000000..d2c6400b6 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis1_epsilon.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_3d_axis1_epsilon() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(1, false)), + Option::Some(FixedTrait::new(6554, false)), + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_3d_axis1_epsilon/input_0.cairo b/tests/nodes/layer_normalization_3d_axis1_epsilon/input_0.cairo new file mode 100644 index 000000000..2f4da77ea --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis1_epsilon/input_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 73220, sign: false }); + data.append(FP16x16 { mag: 15621, sign: false }); + data.append(FP16x16 { mag: 26862, sign: true }); + data.append(FP16x16 { mag: 63971, sign: false }); + data.append(FP16x16 { mag: 64826, sign: false }); + data.append(FP16x16 { mag: 18837, sign: false }); + data.append(FP16x16 { mag: 66021, sign: true }); + data.append(FP16x16 { mag: 42181, sign: true }); + data.append(FP16x16 { mag: 69342, sign: true }); + data.append(FP16x16 { mag: 72001, sign: true }); + data.append(FP16x16 { mag: 99818, sign: true }); + data.append(FP16x16 { mag: 63088, sign: false }); + data.append(FP16x16 { mag: 17845, sign: true }); + data.append(FP16x16 { mag: 37020, sign: true }); + data.append(FP16x16 { mag: 20567, sign: false }); + data.append(FP16x16 { mag: 1924, sign: true }); + data.append(FP16x16 { mag: 13154, sign: true }); + data.append(FP16x16 { mag: 88735, sign: false }); + data.append(FP16x16 { mag: 40464, sign: false }); + data.append(FP16x16 { mag: 96907, sign: false }); + data.append(FP16x16 { mag: 79699, sign: false }); + data.append(FP16x16 { mag: 91862, sign: true }); + data.append(FP16x16 { mag: 97396, sign: false }); + data.append(FP16x16 { mag: 23929, sign: false }); + data.append(FP16x16 { mag: 11785, sign: true }); + data.append(FP16x16 { mag: 7747, sign: false }); + data.append(FP16x16 { mag: 91889, sign: true }); + data.append(FP16x16 { mag: 16735, sign: true }); + data.append(FP16x16 { mag: 120303, sign: true }); + data.append(FP16x16 { mag: 116144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis1_epsilon/input_1.cairo b/tests/nodes/layer_normalization_3d_axis1_epsilon/input_1.cairo new file mode 100644 index 000000000..6db5b252a --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis1_epsilon/input_1.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 71268, sign: false }); + data.append(FP16x16 { mag: 7868, sign: false }); + data.append(FP16x16 { mag: 95401, sign: true }); + data.append(FP16x16 { mag: 1087, sign: true }); + data.append(FP16x16 { mag: 1166, sign: false }); + data.append(FP16x16 { mag: 10185, sign: false }); + data.append(FP16x16 { mag: 52837, sign: true }); + data.append(FP16x16 { mag: 5760, sign: true }); + data.append(FP16x16 { mag: 21502, sign: true }); + data.append(FP16x16 { mag: 44185, sign: true }); + data.append(FP16x16 { mag: 39539, sign: false }); + data.append(FP16x16 { mag: 113293, sign: false }); + data.append(FP16x16 { mag: 24873, sign: false }); + data.append(FP16x16 { mag: 124246, sign: false }); + data.append(FP16x16 { mag: 20310, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis1_epsilon/input_2.cairo b/tests/nodes/layer_normalization_3d_axis1_epsilon/input_2.cairo new file mode 100644 index 000000000..86501b55a --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis1_epsilon/input_2.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 37055, sign: true }); + data.append(FP16x16 { mag: 26922, sign: false }); + data.append(FP16x16 { mag: 155904, sign: false }); + data.append(FP16x16 { mag: 33841, sign: true }); + data.append(FP16x16 { mag: 53256, sign: false }); + data.append(FP16x16 { mag: 22490, sign: false }); + data.append(FP16x16 { mag: 110070, sign: false }); + data.append(FP16x16 { mag: 90061, sign: true }); + data.append(FP16x16 { mag: 44130, sign: true }); + data.append(FP16x16 { mag: 8720, sign: true }); + data.append(FP16x16 { mag: 61513, sign: true }); + data.append(FP16x16 { mag: 42238, sign: true }); + data.append(FP16x16 { mag: 18154, sign: false }); + data.append(FP16x16 { mag: 88282, sign: false }); + data.append(FP16x16 { mag: 29231, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis1_epsilon/output_0.cairo b/tests/nodes/layer_normalization_3d_axis1_epsilon/output_0.cairo new file mode 100644 index 000000000..ac2abb924 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis1_epsilon/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 59799, sign: false }); + data.append(FP16x16 { mag: 29975, sign: false }); + data.append(FP16x16 { mag: 187208, sign: false }); + data.append(FP16x16 { mag: 35149, sign: true }); + data.append(FP16x16 { mag: 54676, sign: false }); + data.append(FP16x16 { mag: 26994, sign: false }); + data.append(FP16x16 { mag: 162286, sign: false }); + data.append(FP16x16 { mag: 86683, sign: true }); + data.append(FP16x16 { mag: 21676, sign: true }); + data.append(FP16x16 { mag: 39399, sign: false }); + data.append(FP16x16 { mag: 123115, sign: true }); + data.append(FP16x16 { mag: 92377, sign: false }); + data.append(FP16x16 { mag: 13773, sign: false }); + data.append(FP16x16 { mag: 26238, sign: false }); + data.append(FP16x16 { mag: 19656, sign: true }); + data.append(FP16x16 { mag: 51756, sign: true }); + data.append(FP16x16 { mag: 24121, sign: false }); + data.append(FP16x16 { mag: 60347, sign: false }); + data.append(FP16x16 { mag: 34230, sign: true }); + data.append(FP16x16 { mag: 54552, sign: false }); + data.append(FP16x16 { mag: 31465, sign: false }); + data.append(FP16x16 { mag: 184284, sign: false }); + data.append(FP16x16 { mag: 96496, sign: true }); + data.append(FP16x16 { mag: 47101, sign: true }); + data.append(FP16x16 { mag: 6198, sign: false }); + data.append(FP16x16 { mag: 64574, sign: true }); + data.append(FP16x16 { mag: 201408, sign: true }); + data.append(FP16x16 { mag: 8115, sign: false }); + data.append(FP16x16 { mag: 133313, sign: true }); + data.append(FP16x16 { mag: 1471, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis2_epsilon.cairo b/tests/nodes/layer_normalization_3d_axis2_epsilon.cairo new file mode 100644 index 000000000..2d3fe7a76 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis2_epsilon.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_3d_axis2_epsilon() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(2, false)), + Option::Some(FixedTrait::new(6554, false)), + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_3d_axis2_epsilon/input_0.cairo b/tests/nodes/layer_normalization_3d_axis2_epsilon/input_0.cairo new file mode 100644 index 000000000..2f4da77ea --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis2_epsilon/input_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 73220, sign: false }); + data.append(FP16x16 { mag: 15621, sign: false }); + data.append(FP16x16 { mag: 26862, sign: true }); + data.append(FP16x16 { mag: 63971, sign: false }); + data.append(FP16x16 { mag: 64826, sign: false }); + data.append(FP16x16 { mag: 18837, sign: false }); + data.append(FP16x16 { mag: 66021, sign: true }); + data.append(FP16x16 { mag: 42181, sign: true }); + data.append(FP16x16 { mag: 69342, sign: true }); + data.append(FP16x16 { mag: 72001, sign: true }); + data.append(FP16x16 { mag: 99818, sign: true }); + data.append(FP16x16 { mag: 63088, sign: false }); + data.append(FP16x16 { mag: 17845, sign: true }); + data.append(FP16x16 { mag: 37020, sign: true }); + data.append(FP16x16 { mag: 20567, sign: false }); + data.append(FP16x16 { mag: 1924, sign: true }); + data.append(FP16x16 { mag: 13154, sign: true }); + data.append(FP16x16 { mag: 88735, sign: false }); + data.append(FP16x16 { mag: 40464, sign: false }); + data.append(FP16x16 { mag: 96907, sign: false }); + data.append(FP16x16 { mag: 79699, sign: false }); + data.append(FP16x16 { mag: 91862, sign: true }); + data.append(FP16x16 { mag: 97396, sign: false }); + data.append(FP16x16 { mag: 23929, sign: false }); + data.append(FP16x16 { mag: 11785, sign: true }); + data.append(FP16x16 { mag: 7747, sign: false }); + data.append(FP16x16 { mag: 91889, sign: true }); + data.append(FP16x16 { mag: 16735, sign: true }); + data.append(FP16x16 { mag: 120303, sign: true }); + data.append(FP16x16 { mag: 116144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis2_epsilon/input_1.cairo b/tests/nodes/layer_normalization_3d_axis2_epsilon/input_1.cairo new file mode 100644 index 000000000..4bb7a1456 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis2_epsilon/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 49614, sign: false }); + data.append(FP16x16 { mag: 39079, sign: true }); + data.append(FP16x16 { mag: 59684, sign: true }); + data.append(FP16x16 { mag: 44628, sign: false }); + data.append(FP16x16 { mag: 45415, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis2_epsilon/input_2.cairo b/tests/nodes/layer_normalization_3d_axis2_epsilon/input_2.cairo new file mode 100644 index 000000000..797fd7460 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis2_epsilon/input_2.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 93216, sign: true }); + data.append(FP16x16 { mag: 37735, sign: false }); + data.append(FP16x16 { mag: 93039, sign: false }); + data.append(FP16x16 { mag: 65618, sign: false }); + data.append(FP16x16 { mag: 52063, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis2_epsilon/output_0.cairo b/tests/nodes/layer_normalization_3d_axis2_epsilon/output_0.cairo new file mode 100644 index 000000000..aa884adae --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis2_epsilon/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 53283, sign: true }); + data.append(FP16x16 { mag: 57948, sign: false }); + data.append(FP16x16 { mag: 182112, sign: false }); + data.append(FP16x16 { mag: 92064, sign: false }); + data.append(FP16x16 { mag: 79866, sign: false }); + data.append(FP16x16 { mag: 12556, sign: true }); + data.append(FP16x16 { mag: 57172, sign: false }); + data.append(FP16x16 { mag: 87126, sign: false }); + data.append(FP16x16 { mag: 39713, sign: false }); + data.append(FP16x16 { mag: 22679, sign: false }); + data.append(FP16x16 { mag: 165649, sign: true }); + data.append(FP16x16 { mag: 13773, sign: true }); + data.append(FP16x16 { mag: 96744, sign: false }); + data.append(FP16x16 { mag: 48255, sign: false }); + data.append(FP16x16 { mag: 78994, sign: false }); + data.append(FP16x16 { mag: 137324, sign: true }); + data.append(FP16x16 { mag: 81318, sign: false }); + data.append(FP16x16 { mag: 37094, sign: false }); + data.append(FP16x16 { mag: 64052, sign: false }); + data.append(FP16x16 { mag: 102111, sign: false }); + data.append(FP16x16 { mag: 51160, sign: true }); + data.append(FP16x16 { mag: 98975, sign: false }); + data.append(FP16x16 { mag: 27582, sign: false }); + data.append(FP16x16 { mag: 68415, sign: false }); + data.append(FP16x16 { mag: 32080, sign: false }); + data.append(FP16x16 { mag: 76560, sign: true }); + data.append(FP16x16 { mag: 70073, sign: false }); + data.append(FP16x16 { mag: 90063, sign: false }); + data.append(FP16x16 { mag: 13883, sign: false }); + data.append(FP16x16 { mag: 124780, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon.cairo b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon.cairo new file mode 100644 index 000000000..f197b1609 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_3d_axis_negative_1_epsilon() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(1, true)), + Option::Some(FixedTrait::new(6554, false)), + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_0.cairo b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_0.cairo new file mode 100644 index 000000000..2f4da77ea --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 73220, sign: false }); + data.append(FP16x16 { mag: 15621, sign: false }); + data.append(FP16x16 { mag: 26862, sign: true }); + data.append(FP16x16 { mag: 63971, sign: false }); + data.append(FP16x16 { mag: 64826, sign: false }); + data.append(FP16x16 { mag: 18837, sign: false }); + data.append(FP16x16 { mag: 66021, sign: true }); + data.append(FP16x16 { mag: 42181, sign: true }); + data.append(FP16x16 { mag: 69342, sign: true }); + data.append(FP16x16 { mag: 72001, sign: true }); + data.append(FP16x16 { mag: 99818, sign: true }); + data.append(FP16x16 { mag: 63088, sign: false }); + data.append(FP16x16 { mag: 17845, sign: true }); + data.append(FP16x16 { mag: 37020, sign: true }); + data.append(FP16x16 { mag: 20567, sign: false }); + data.append(FP16x16 { mag: 1924, sign: true }); + data.append(FP16x16 { mag: 13154, sign: true }); + data.append(FP16x16 { mag: 88735, sign: false }); + data.append(FP16x16 { mag: 40464, sign: false }); + data.append(FP16x16 { mag: 96907, sign: false }); + data.append(FP16x16 { mag: 79699, sign: false }); + data.append(FP16x16 { mag: 91862, sign: true }); + data.append(FP16x16 { mag: 97396, sign: false }); + data.append(FP16x16 { mag: 23929, sign: false }); + data.append(FP16x16 { mag: 11785, sign: true }); + data.append(FP16x16 { mag: 7747, sign: false }); + data.append(FP16x16 { mag: 91889, sign: true }); + data.append(FP16x16 { mag: 16735, sign: true }); + data.append(FP16x16 { mag: 120303, sign: true }); + data.append(FP16x16 { mag: 116144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_1.cairo b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_1.cairo new file mode 100644 index 000000000..908e17fa1 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 9463, sign: false }); + data.append(FP16x16 { mag: 34110, sign: true }); + data.append(FP16x16 { mag: 50067, sign: false }); + data.append(FP16x16 { mag: 4048, sign: false }); + data.append(FP16x16 { mag: 19840, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_2.cairo b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_2.cairo new file mode 100644 index 000000000..cf65126ba --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/input_2.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 41329, sign: true }); + data.append(FP16x16 { mag: 41872, sign: true }); + data.append(FP16x16 { mag: 18851, sign: true }); + data.append(FP16x16 { mag: 38714, sign: true }); + data.append(FP16x16 { mag: 67617, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/output_0.cairo b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/output_0.cairo new file mode 100644 index 000000000..ed4491dc1 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_1_epsilon/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 33712, sign: true }); + data.append(FP16x16 { mag: 24229, sign: true }); + data.append(FP16x16 { mag: 93572, sign: true }); + data.append(FP16x16 { mag: 36316, sign: true }); + data.append(FP16x16 { mag: 55471, sign: true }); + data.append(FP16x16 { mag: 25943, sign: true }); + data.append(FP16x16 { mag: 24906, sign: true }); + data.append(FP16x16 { mag: 13890, sign: true }); + data.append(FP16x16 { mag: 41064, sign: true }); + data.append(FP16x16 { mag: 80454, sign: true }); + data.append(FP16x16 { mag: 55145, sign: true }); + data.append(FP16x16 { mag: 86832, sign: true }); + data.append(FP16x16 { mag: 21959, sign: true }); + data.append(FP16x16 { mag: 40289, sign: true }); + data.append(FP16x16 { mag: 55852, sign: true }); + data.append(FP16x16 { mag: 49742, sign: true }); + data.append(FP16x16 { mag: 3830, sign: true }); + data.append(FP16x16 { mag: 28080, sign: false }); + data.append(FP16x16 { mag: 38856, sign: true }); + data.append(FP16x16 { mag: 45753, sign: true }); + data.append(FP16x16 { mag: 33307, sign: true }); + data.append(FP16x16 { mag: 11580, sign: false }); + data.append(FP16x16 { mag: 36059, sign: false }); + data.append(FP16x16 { mag: 38461, sign: true }); + data.append(FP16x16 { mag: 76347, sign: true }); + data.append(FP16x16 { mag: 38152, sign: true }); + data.append(FP16x16 { mag: 13646, sign: true }); + data.append(FP16x16 { mag: 16354, sign: true }); + data.append(FP16x16 { mag: 43407, sign: true }); + data.append(FP16x16 { mag: 35849, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon.cairo b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon.cairo new file mode 100644 index 000000000..0abf1233d --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon.cairo @@ -0,0 +1,33 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; + +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_3d_axis_negative_2_epsilon() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(2, true)), + Option::Some(FixedTrait::new(6554, false)), + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_0.cairo b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_0.cairo new file mode 100644 index 000000000..2f4da77ea --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 73220, sign: false }); + data.append(FP16x16 { mag: 15621, sign: false }); + data.append(FP16x16 { mag: 26862, sign: true }); + data.append(FP16x16 { mag: 63971, sign: false }); + data.append(FP16x16 { mag: 64826, sign: false }); + data.append(FP16x16 { mag: 18837, sign: false }); + data.append(FP16x16 { mag: 66021, sign: true }); + data.append(FP16x16 { mag: 42181, sign: true }); + data.append(FP16x16 { mag: 69342, sign: true }); + data.append(FP16x16 { mag: 72001, sign: true }); + data.append(FP16x16 { mag: 99818, sign: true }); + data.append(FP16x16 { mag: 63088, sign: false }); + data.append(FP16x16 { mag: 17845, sign: true }); + data.append(FP16x16 { mag: 37020, sign: true }); + data.append(FP16x16 { mag: 20567, sign: false }); + data.append(FP16x16 { mag: 1924, sign: true }); + data.append(FP16x16 { mag: 13154, sign: true }); + data.append(FP16x16 { mag: 88735, sign: false }); + data.append(FP16x16 { mag: 40464, sign: false }); + data.append(FP16x16 { mag: 96907, sign: false }); + data.append(FP16x16 { mag: 79699, sign: false }); + data.append(FP16x16 { mag: 91862, sign: true }); + data.append(FP16x16 { mag: 97396, sign: false }); + data.append(FP16x16 { mag: 23929, sign: false }); + data.append(FP16x16 { mag: 11785, sign: true }); + data.append(FP16x16 { mag: 7747, sign: false }); + data.append(FP16x16 { mag: 91889, sign: true }); + data.append(FP16x16 { mag: 16735, sign: true }); + data.append(FP16x16 { mag: 120303, sign: true }); + data.append(FP16x16 { mag: 116144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_1.cairo b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_1.cairo new file mode 100644 index 000000000..ed3eccc4c --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_1.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 76875, sign: true }); + data.append(FP16x16 { mag: 70595, sign: true }); + data.append(FP16x16 { mag: 48362, sign: true }); + data.append(FP16x16 { mag: 114023, sign: false }); + data.append(FP16x16 { mag: 68398, sign: true }); + data.append(FP16x16 { mag: 90609, sign: false }); + data.append(FP16x16 { mag: 11920, sign: false }); + data.append(FP16x16 { mag: 83372, sign: true }); + data.append(FP16x16 { mag: 131126, sign: false }); + data.append(FP16x16 { mag: 36226, sign: false }); + data.append(FP16x16 { mag: 15255, sign: true }); + data.append(FP16x16 { mag: 97966, sign: false }); + data.append(FP16x16 { mag: 156224, sign: false }); + data.append(FP16x16 { mag: 92550, sign: true }); + data.append(FP16x16 { mag: 120464, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_2.cairo b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_2.cairo new file mode 100644 index 000000000..27b6c770b --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/input_2.cairo @@ -0,0 +1,28 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 93757, sign: false }); + data.append(FP16x16 { mag: 4959, sign: true }); + data.append(FP16x16 { mag: 26505, sign: false }); + data.append(FP16x16 { mag: 1530, sign: false }); + data.append(FP16x16 { mag: 75165, sign: true }); + data.append(FP16x16 { mag: 97790, sign: true }); + data.append(FP16x16 { mag: 16969, sign: true }); + data.append(FP16x16 { mag: 61662, sign: false }); + data.append(FP16x16 { mag: 110907, sign: true }); + data.append(FP16x16 { mag: 17227, sign: true }); + data.append(FP16x16 { mag: 14091, sign: false }); + data.append(FP16x16 { mag: 52957, sign: true }); + data.append(FP16x16 { mag: 41342, sign: false }); + data.append(FP16x16 { mag: 34186, sign: false }); + data.append(FP16x16 { mag: 17811, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/output_0.cairo b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/output_0.cairo new file mode 100644 index 000000000..5da7eb6ac --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_2_epsilon/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 10716, sign: true }); + data.append(FP16x16 { mag: 32353, sign: true }); + data.append(FP16x16 { mag: 42375, sign: false }); + data.append(FP16x16 { mag: 138711, sign: false }); + data.append(FP16x16 { mag: 158441, sign: true }); + data.append(FP16x16 { mag: 57718, sign: true }); + data.append(FP16x16 { mag: 28750, sign: true }); + data.append(FP16x16 { mag: 110550, sign: false }); + data.append(FP16x16 { mag: 247834, sign: true }); + data.append(FP16x16 { mag: 56680, sign: true }); + data.append(FP16x16 { mag: 37859, sign: false }); + data.append(FP16x16 { mag: 63446, sign: false }); + data.append(FP16x16 { mag: 13827, sign: false }); + data.append(FP16x16 { mag: 80402, sign: false }); + data.append(FP16x16 { mag: 74600, sign: false }); + data.append(FP16x16 { mag: 109615, sign: false }); + data.append(FP16x16 { mag: 20165, sign: false }); + data.append(FP16x16 { mag: 21935, sign: true }); + data.append(FP16x16 { mag: 42407, sign: false }); + data.append(FP16x16 { mag: 151122, sign: true }); + data.append(FP16x16 { mag: 17941, sign: true }); + data.append(FP16x16 { mag: 33712, sign: true }); + data.append(FP16x16 { mag: 31466, sign: true }); + data.append(FP16x16 { mag: 92787, sign: true }); + data.append(FP16x16 { mag: 29459, sign: true }); + data.append(FP16x16 { mag: 15272, sign: false }); + data.append(FP16x16 { mag: 190594, sign: true }); + data.append(FP16x16 { mag: 21712, sign: true }); + data.append(FP16x16 { mag: 199251, sign: false }); + data.append(FP16x16 { mag: 182464, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon.cairo b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon.cairo new file mode 100644 index 000000000..2c56a29de --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon.cairo @@ -0,0 +1,33 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; + +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_3d_axis_negative_3_epsilon() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(3, true)), + Option::Some(FixedTrait::new(6554, false)), + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_0.cairo b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_0.cairo new file mode 100644 index 000000000..2f4da77ea --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 73220, sign: false }); + data.append(FP16x16 { mag: 15621, sign: false }); + data.append(FP16x16 { mag: 26862, sign: true }); + data.append(FP16x16 { mag: 63971, sign: false }); + data.append(FP16x16 { mag: 64826, sign: false }); + data.append(FP16x16 { mag: 18837, sign: false }); + data.append(FP16x16 { mag: 66021, sign: true }); + data.append(FP16x16 { mag: 42181, sign: true }); + data.append(FP16x16 { mag: 69342, sign: true }); + data.append(FP16x16 { mag: 72001, sign: true }); + data.append(FP16x16 { mag: 99818, sign: true }); + data.append(FP16x16 { mag: 63088, sign: false }); + data.append(FP16x16 { mag: 17845, sign: true }); + data.append(FP16x16 { mag: 37020, sign: true }); + data.append(FP16x16 { mag: 20567, sign: false }); + data.append(FP16x16 { mag: 1924, sign: true }); + data.append(FP16x16 { mag: 13154, sign: true }); + data.append(FP16x16 { mag: 88735, sign: false }); + data.append(FP16x16 { mag: 40464, sign: false }); + data.append(FP16x16 { mag: 96907, sign: false }); + data.append(FP16x16 { mag: 79699, sign: false }); + data.append(FP16x16 { mag: 91862, sign: true }); + data.append(FP16x16 { mag: 97396, sign: false }); + data.append(FP16x16 { mag: 23929, sign: false }); + data.append(FP16x16 { mag: 11785, sign: true }); + data.append(FP16x16 { mag: 7747, sign: false }); + data.append(FP16x16 { mag: 91889, sign: true }); + data.append(FP16x16 { mag: 16735, sign: true }); + data.append(FP16x16 { mag: 120303, sign: true }); + data.append(FP16x16 { mag: 116144, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_1.cairo b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_1.cairo new file mode 100644 index 000000000..07b6687de --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_1.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 51329, sign: true }); + data.append(FP16x16 { mag: 47264, sign: true }); + data.append(FP16x16 { mag: 78049, sign: false }); + data.append(FP16x16 { mag: 31332, sign: true }); + data.append(FP16x16 { mag: 64228, sign: false }); + data.append(FP16x16 { mag: 50183, sign: false }); + data.append(FP16x16 { mag: 111933, sign: true }); + data.append(FP16x16 { mag: 37549, sign: true }); + data.append(FP16x16 { mag: 48542, sign: true }); + data.append(FP16x16 { mag: 13252, sign: true }); + data.append(FP16x16 { mag: 63185, sign: true }); + data.append(FP16x16 { mag: 2871, sign: false }); + data.append(FP16x16 { mag: 57251, sign: false }); + data.append(FP16x16 { mag: 15125, sign: true }); + data.append(FP16x16 { mag: 75974, sign: false }); + data.append(FP16x16 { mag: 29448, sign: true }); + data.append(FP16x16 { mag: 118787, sign: false }); + data.append(FP16x16 { mag: 85238, sign: true }); + data.append(FP16x16 { mag: 6392, sign: true }); + data.append(FP16x16 { mag: 32667, sign: true }); + data.append(FP16x16 { mag: 306, sign: false }); + data.append(FP16x16 { mag: 53902, sign: true }); + data.append(FP16x16 { mag: 25940, sign: true }); + data.append(FP16x16 { mag: 38753, sign: true }); + data.append(FP16x16 { mag: 73289, sign: true }); + data.append(FP16x16 { mag: 47552, sign: false }); + data.append(FP16x16 { mag: 27826, sign: false }); + data.append(FP16x16 { mag: 47550, sign: false }); + data.append(FP16x16 { mag: 36199, sign: true }); + data.append(FP16x16 { mag: 43172, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_2.cairo b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_2.cairo new file mode 100644 index 000000000..85700c760 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/input_2.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 6143, sign: false }); + data.append(FP16x16 { mag: 4674, sign: false }); + data.append(FP16x16 { mag: 48051, sign: true }); + data.append(FP16x16 { mag: 18813, sign: false }); + data.append(FP16x16 { mag: 46995, sign: false }); + data.append(FP16x16 { mag: 20870, sign: true }); + data.append(FP16x16 { mag: 56843, sign: false }); + data.append(FP16x16 { mag: 81615, sign: false }); + data.append(FP16x16 { mag: 92340, sign: false }); + data.append(FP16x16 { mag: 84516, sign: false }); + data.append(FP16x16 { mag: 82019, sign: false }); + data.append(FP16x16 { mag: 51674, sign: false }); + data.append(FP16x16 { mag: 52303, sign: false }); + data.append(FP16x16 { mag: 7441, sign: true }); + data.append(FP16x16 { mag: 35138, sign: false }); + data.append(FP16x16 { mag: 78581, sign: false }); + data.append(FP16x16 { mag: 6660, sign: false }); + data.append(FP16x16 { mag: 137669, sign: true }); + data.append(FP16x16 { mag: 12790, sign: true }); + data.append(FP16x16 { mag: 144767, sign: false }); + data.append(FP16x16 { mag: 10893, sign: true }); + data.append(FP16x16 { mag: 26226, sign: true }); + data.append(FP16x16 { mag: 64470, sign: false }); + data.append(FP16x16 { mag: 22466, sign: false }); + data.append(FP16x16 { mag: 101996, sign: true }); + data.append(FP16x16 { mag: 46134, sign: true }); + data.append(FP16x16 { mag: 81851, sign: true }); + data.append(FP16x16 { mag: 176946, sign: false }); + data.append(FP16x16 { mag: 6446, sign: true }); + data.append(FP16x16 { mag: 77193, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/output_0.cairo b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/output_0.cairo new file mode 100644 index 000000000..7d9e964c0 --- /dev/null +++ b/tests/nodes/layer_normalization_3d_axis_negative_3_epsilon/output_0.cairo @@ -0,0 +1,44 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 46451, sign: true }); + data.append(FP16x16 { mag: 3984, sign: true }); + data.append(FP16x16 { mag: 82191, sign: true }); + data.append(FP16x16 { mag: 9056, sign: true }); + data.append(FP16x16 { mag: 104929, sign: false }); + data.append(FP16x16 { mag: 9318, sign: true }); + data.append(FP16x16 { mag: 169836, sign: false }); + data.append(FP16x16 { mag: 106443, sign: false }); + data.append(FP16x16 { mag: 143696, sign: false }); + data.append(FP16x16 { mag: 99052, sign: false }); + data.append(FP16x16 { mag: 176999, sign: false }); + data.append(FP16x16 { mag: 54192, sign: false }); + data.append(FP16x16 { mag: 34801, sign: false }); + data.append(FP16x16 { mag: 1418, sign: false }); + data.append(FP16x16 { mag: 54547, sign: false }); + data.append(FP16x16 { mag: 80734, sign: false }); + data.append(FP16x16 { mag: 21511, sign: true }); + data.append(FP16x16 { mag: 244326, sign: true }); + data.append(FP16x16 { mag: 16281, sign: true }); + data.append(FP16x16 { mag: 99990, sign: false }); + data.append(FP16x16 { mag: 10550, sign: true }); + data.append(FP16x16 { mag: 48534, sign: false }); + data.append(FP16x16 { mag: 28730, sign: false }); + data.append(FP16x16 { mag: 10662, sign: false }); + data.append(FP16x16 { mag: 86081, sign: true }); + data.append(FP16x16 { mag: 42891, sign: true }); + data.append(FP16x16 { mag: 120456, sign: true }); + data.append(FP16x16 { mag: 163181, sign: false }); + data.append(FP16x16 { mag: 58800, sign: false }); + data.append(FP16x16 { mag: 5886, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-1/input_0.cairo b/tests/nodes/layer_normalization_4d_axis-1/input_0.cairo new file mode 100644 index 000000000..796ab8161 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-1/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 50439, sign: true }); + data.append(FP16x16 { mag: 80584, sign: false }); + data.append(FP16x16 { mag: 108804, sign: false }); + data.append(FP16x16 { mag: 64736, sign: true }); + data.append(FP16x16 { mag: 35989, sign: false }); + data.append(FP16x16 { mag: 83918, sign: false }); + data.append(FP16x16 { mag: 80462, sign: true }); + data.append(FP16x16 { mag: 43318, sign: false }); + data.append(FP16x16 { mag: 65651, sign: true }); + data.append(FP16x16 { mag: 11331, sign: true }); + data.append(FP16x16 { mag: 25254, sign: true }); + data.append(FP16x16 { mag: 62677, sign: false }); + data.append(FP16x16 { mag: 8412, sign: false }); + data.append(FP16x16 { mag: 64067, sign: true }); + data.append(FP16x16 { mag: 11592, sign: true }); + data.append(FP16x16 { mag: 10557, sign: false }); + data.append(FP16x16 { mag: 33540, sign: true }); + data.append(FP16x16 { mag: 50007, sign: true }); + data.append(FP16x16 { mag: 66942, sign: true }); + data.append(FP16x16 { mag: 24781, sign: true }); + data.append(FP16x16 { mag: 70618, sign: true }); + data.append(FP16x16 { mag: 91944, sign: false }); + data.append(FP16x16 { mag: 25130, sign: false }); + data.append(FP16x16 { mag: 70069, sign: true }); + data.append(FP16x16 { mag: 19429, sign: false }); + data.append(FP16x16 { mag: 17607, sign: true }); + data.append(FP16x16 { mag: 20059, sign: true }); + data.append(FP16x16 { mag: 30637, sign: false }); + data.append(FP16x16 { mag: 110305, sign: true }); + data.append(FP16x16 { mag: 47227, sign: true }); + data.append(FP16x16 { mag: 52701, sign: false }); + data.append(FP16x16 { mag: 113586, sign: true }); + data.append(FP16x16 { mag: 28893, sign: false }); + data.append(FP16x16 { mag: 19309, sign: true }); + data.append(FP16x16 { mag: 9704, sign: true }); + data.append(FP16x16 { mag: 56793, sign: false }); + data.append(FP16x16 { mag: 44072, sign: false }); + data.append(FP16x16 { mag: 2660, sign: true }); + data.append(FP16x16 { mag: 29401, sign: true }); + data.append(FP16x16 { mag: 36736, sign: true }); + data.append(FP16x16 { mag: 86250, sign: false }); + data.append(FP16x16 { mag: 38971, sign: true }); + data.append(FP16x16 { mag: 138249, sign: true }); + data.append(FP16x16 { mag: 36048, sign: false }); + data.append(FP16x16 { mag: 37452, sign: false }); + data.append(FP16x16 { mag: 23181, sign: true }); + data.append(FP16x16 { mag: 35955, sign: true }); + data.append(FP16x16 { mag: 51748, sign: true }); + data.append(FP16x16 { mag: 57097, sign: true }); + data.append(FP16x16 { mag: 91645, sign: true }); + data.append(FP16x16 { mag: 727, sign: true }); + data.append(FP16x16 { mag: 26384, sign: true }); + data.append(FP16x16 { mag: 1331, sign: false }); + data.append(FP16x16 { mag: 98672, sign: false }); + data.append(FP16x16 { mag: 82596, sign: false }); + data.append(FP16x16 { mag: 100984, sign: true }); + data.append(FP16x16 { mag: 88934, sign: true }); + data.append(FP16x16 { mag: 110736, sign: false }); + data.append(FP16x16 { mag: 106050, sign: true }); + data.append(FP16x16 { mag: 43286, sign: true }); + data.append(FP16x16 { mag: 114427, sign: false }); + data.append(FP16x16 { mag: 26160, sign: false }); + data.append(FP16x16 { mag: 19914, sign: true }); + data.append(FP16x16 { mag: 168031, sign: false }); + data.append(FP16x16 { mag: 70834, sign: true }); + data.append(FP16x16 { mag: 62785, sign: false }); + data.append(FP16x16 { mag: 20711, sign: true }); + data.append(FP16x16 { mag: 150814, sign: true }); + data.append(FP16x16 { mag: 19909, sign: true }); + data.append(FP16x16 { mag: 4360, sign: true }); + data.append(FP16x16 { mag: 38550, sign: false }); + data.append(FP16x16 { mag: 155210, sign: false }); + data.append(FP16x16 { mag: 49387, sign: true }); + data.append(FP16x16 { mag: 48606, sign: false }); + data.append(FP16x16 { mag: 26297, sign: false }); + data.append(FP16x16 { mag: 5832, sign: true }); + data.append(FP16x16 { mag: 67019, sign: false }); + data.append(FP16x16 { mag: 108552, sign: false }); + data.append(FP16x16 { mag: 38302, sign: true }); + data.append(FP16x16 { mag: 67467, sign: false }); + data.append(FP16x16 { mag: 123524, sign: false }); + data.append(FP16x16 { mag: 38110, sign: true }); + data.append(FP16x16 { mag: 49459, sign: true }); + data.append(FP16x16 { mag: 89977, sign: false }); + data.append(FP16x16 { mag: 3116, sign: false }); + data.append(FP16x16 { mag: 139868, sign: true }); + data.append(FP16x16 { mag: 3777, sign: false }); + data.append(FP16x16 { mag: 93508, sign: false }); + data.append(FP16x16 { mag: 18473, sign: true }); + data.append(FP16x16 { mag: 9749, sign: false }); + data.append(FP16x16 { mag: 69280, sign: false }); + data.append(FP16x16 { mag: 31261, sign: true }); + data.append(FP16x16 { mag: 70810, sign: true }); + data.append(FP16x16 { mag: 53719, sign: true }); + data.append(FP16x16 { mag: 25490, sign: false }); + data.append(FP16x16 { mag: 76561, sign: true }); + data.append(FP16x16 { mag: 87810, sign: true }); + data.append(FP16x16 { mag: 54546, sign: true }); + data.append(FP16x16 { mag: 11956, sign: false }); + data.append(FP16x16 { mag: 53981, sign: true }); + data.append(FP16x16 { mag: 48734, sign: false }); + data.append(FP16x16 { mag: 81861, sign: true }); + data.append(FP16x16 { mag: 91276, sign: false }); + data.append(FP16x16 { mag: 104233, sign: false }); + data.append(FP16x16 { mag: 52396, sign: false }); + data.append(FP16x16 { mag: 37016, sign: false }); + data.append(FP16x16 { mag: 39784, sign: false }); + data.append(FP16x16 { mag: 16087, sign: true }); + data.append(FP16x16 { mag: 22464, sign: true }); + data.append(FP16x16 { mag: 98432, sign: true }); + data.append(FP16x16 { mag: 120832, sign: true }); + data.append(FP16x16 { mag: 25665, sign: true }); + data.append(FP16x16 { mag: 23338, sign: true }); + data.append(FP16x16 { mag: 18801, sign: false }); + data.append(FP16x16 { mag: 22675, sign: false }); + data.append(FP16x16 { mag: 79634, sign: true }); + data.append(FP16x16 { mag: 95608, sign: true }); + data.append(FP16x16 { mag: 73767, sign: false }); + data.append(FP16x16 { mag: 20235, sign: false }); + data.append(FP16x16 { mag: 19535, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-1/input_1.cairo b/tests/nodes/layer_normalization_4d_axis-1/input_1.cairo new file mode 100644 index 000000000..5796909a7 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-1/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 4256, sign: true }); + data.append(FP16x16 { mag: 12347, sign: true }); + data.append(FP16x16 { mag: 12223, sign: true }); + data.append(FP16x16 { mag: 35664, sign: false }); + data.append(FP16x16 { mag: 47729, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-1/input_2.cairo b/tests/nodes/layer_normalization_4d_axis-1/input_2.cairo new file mode 100644 index 000000000..5b36613cf --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-1/input_2.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 31681, sign: true }); + data.append(FP16x16 { mag: 39712, sign: false }); + data.append(FP16x16 { mag: 111813, sign: true }); + data.append(FP16x16 { mag: 73292, sign: false }); + data.append(FP16x16 { mag: 69974, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-1/output_0.cairo b/tests/nodes/layer_normalization_4d_axis-1/output_0.cairo new file mode 100644 index 000000000..b16b3e34c --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-1/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 27222, sign: true }); + data.append(FP16x16 { mag: 29264, sign: false }); + data.append(FP16x16 { mag: 127142, sign: true }); + data.append(FP16x16 { mag: 28560, sign: false }); + data.append(FP16x16 { mag: 60351, sign: true }); + data.append(FP16x16 { mag: 37788, sign: true }); + data.append(FP16x16 { mag: 54370, sign: false }); + data.append(FP16x16 { mag: 121438, sign: true }); + data.append(FP16x16 { mag: 39380, sign: false }); + data.append(FP16x16 { mag: 74002, sign: true }); + data.append(FP16x16 { mag: 29712, sign: true }); + data.append(FP16x16 { mag: 19392, sign: false }); + data.append(FP16x16 { mag: 116026, sign: true }); + data.append(FP16x16 { mag: 23611, sign: false }); + data.append(FP16x16 { mag: 76413, sign: true }); + data.append(FP16x16 { mag: 38774, sign: true }); + data.append(FP16x16 { mag: 39995, sign: false }); + data.append(FP16x16 { mag: 103821, sign: true }); + data.append(FP16x16 { mag: 26838, sign: false }); + data.append(FP16x16 { mag: 55050, sign: true }); + data.append(FP16x16 { mag: 26907, sign: true }); + data.append(FP16x16 { mag: 21300, sign: false }); + data.append(FP16x16 { mag: 116914, sign: true }); + data.append(FP16x16 { mag: 33608, sign: false }); + data.append(FP16x16 { mag: 54427, sign: true }); + data.append(FP16x16 { mag: 33093, sign: true }); + data.append(FP16x16 { mag: 36271, sign: false }); + data.append(FP16x16 { mag: 128658, sign: true }); + data.append(FP16x16 { mag: 13438, sign: false }); + data.append(FP16x16 { mag: 84790, sign: true }); + data.append(FP16x16 { mag: 36528, sign: true }); + data.append(FP16x16 { mag: 61679, sign: false }); + data.append(FP16x16 { mag: 120628, sign: true }); + data.append(FP16x16 { mag: 68843, sign: false }); + data.append(FP16x16 { mag: 67883, sign: true }); + data.append(FP16x16 { mag: 37338, sign: true }); + data.append(FP16x16 { mag: 27445, sign: false }); + data.append(FP16x16 { mag: 108887, sign: true }); + data.append(FP16x16 { mag: 39593, sign: false }); + data.append(FP16x16 { mag: 124310, sign: true }); + data.append(FP16x16 { mag: 36554, sign: true }); + data.append(FP16x16 { mag: 45301, sign: false }); + data.append(FP16x16 { mag: 90795, sign: true }); + data.append(FP16x16 { mag: 91286, sign: false }); + data.append(FP16x16 { mag: 45037, sign: true }); + data.append(FP16x16 { mag: 36959, sign: true }); + data.append(FP16x16 { mag: 31204, sign: false }); + data.append(FP16x16 { mag: 111906, sign: true }); + data.append(FP16x16 { mag: 65334, sign: false }); + data.append(FP16x16 { mag: 151771, sign: true }); + data.append(FP16x16 { mag: 28963, sign: true }); + data.append(FP16x16 { mag: 53952, sign: false }); + data.append(FP16x16 { mag: 104513, sign: true }); + data.append(FP16x16 { mag: 121643, sign: false }); + data.append(FP16x16 { mag: 20658, sign: true }); + data.append(FP16x16 { mag: 28787, sign: true }); + data.append(FP16x16 { mag: 46278, sign: false }); + data.append(FP16x16 { mag: 135335, sign: true }); + data.append(FP16x16 { mag: 46819, sign: false }); + data.append(FP16x16 { mag: 68555, sign: true }); + data.append(FP16x16 { mag: 35143, sign: true }); + data.append(FP16x16 { mag: 42180, sign: false }); + data.append(FP16x16 { mag: 102903, sign: true }); + data.append(FP16x16 { mag: 124252, sign: false }); + data.append(FP16x16 { mag: 132667, sign: true }); + data.append(FP16x16 { mag: 37173, sign: true }); + data.append(FP16x16 { mag: 38662, sign: false }); + data.append(FP16x16 { mag: 89894, sign: true }); + data.append(FP16x16 { mag: 76738, sign: false }); + data.append(FP16x16 { mag: 54649, sign: true }); + data.append(FP16x16 { mag: 31336, sign: true }); + data.append(FP16x16 { mag: 18722, sign: false }); + data.append(FP16x16 { mag: 94412, sign: true }); + data.append(FP16x16 { mag: 75879, sign: false }); + data.append(FP16x16 { mag: 82768, sign: true }); + data.append(FP16x16 { mag: 28065, sign: true }); + data.append(FP16x16 { mag: 33448, sign: false }); + data.append(FP16x16 { mag: 127469, sign: true }); + data.append(FP16x16 { mag: 21429, sign: false }); + data.append(FP16x16 { mag: 45363, sign: true }); + data.append(FP16x16 { mag: 37692, sign: true }); + data.append(FP16x16 { mag: 51120, sign: false }); + data.append(FP16x16 { mag: 98514, sign: true }); + data.append(FP16x16 { mag: 106370, sign: false }); + data.append(FP16x16 { mag: 85630, sign: true }); + data.append(FP16x16 { mag: 24340, sign: true }); + data.append(FP16x16 { mag: 37405, sign: false }); + data.append(FP16x16 { mag: 128694, sign: true }); + data.append(FP16x16 { mag: 69394, sign: false }); + data.append(FP16x16 { mag: 57263, sign: true }); + data.append(FP16x16 { mag: 38335, sign: true }); + data.append(FP16x16 { mag: 44227, sign: false }); + data.append(FP16x16 { mag: 98068, sign: true }); + data.append(FP16x16 { mag: 44884, sign: false }); + data.append(FP16x16 { mag: 35454, sign: true }); + data.append(FP16x16 { mag: 28683, sign: true }); + data.append(FP16x16 { mag: 52423, sign: false }); + data.append(FP16x16 { mag: 110980, sign: true }); + data.append(FP16x16 { mag: 139406, sign: false }); + data.append(FP16x16 { mag: 72447, sign: true }); + data.append(FP16x16 { mag: 32053, sign: true }); + data.append(FP16x16 { mag: 63056, sign: false }); + data.append(FP16x16 { mag: 120760, sign: true }); + data.append(FP16x16 { mag: 106395, sign: false }); + data.append(FP16x16 { mag: 63149, sign: true }); + data.append(FP16x16 { mag: 35826, sign: true }); + data.append(FP16x16 { mag: 27007, sign: false }); + data.append(FP16x16 { mag: 110829, sign: true }); + data.append(FP16x16 { mag: 65907, sign: false }); + data.append(FP16x16 { mag: 151855, sign: true }); + data.append(FP16x16 { mag: 23850, sign: true }); + data.append(FP16x16 { mag: 39710, sign: false }); + data.append(FP16x16 { mag: 112364, sign: true }); + data.append(FP16x16 { mag: 103957, sign: false }); + data.append(FP16x16 { mag: 25360, sign: true }); + data.append(FP16x16 { mag: 27646, sign: true }); + data.append(FP16x16 { mag: 54560, sign: false }); + data.append(FP16x16 { mag: 130110, sign: true }); + data.append(FP16x16 { mag: 96250, sign: false }); + data.append(FP16x16 { mag: 69503, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-2/input_0.cairo b/tests/nodes/layer_normalization_4d_axis-2/input_0.cairo new file mode 100644 index 000000000..796ab8161 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-2/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 50439, sign: true }); + data.append(FP16x16 { mag: 80584, sign: false }); + data.append(FP16x16 { mag: 108804, sign: false }); + data.append(FP16x16 { mag: 64736, sign: true }); + data.append(FP16x16 { mag: 35989, sign: false }); + data.append(FP16x16 { mag: 83918, sign: false }); + data.append(FP16x16 { mag: 80462, sign: true }); + data.append(FP16x16 { mag: 43318, sign: false }); + data.append(FP16x16 { mag: 65651, sign: true }); + data.append(FP16x16 { mag: 11331, sign: true }); + data.append(FP16x16 { mag: 25254, sign: true }); + data.append(FP16x16 { mag: 62677, sign: false }); + data.append(FP16x16 { mag: 8412, sign: false }); + data.append(FP16x16 { mag: 64067, sign: true }); + data.append(FP16x16 { mag: 11592, sign: true }); + data.append(FP16x16 { mag: 10557, sign: false }); + data.append(FP16x16 { mag: 33540, sign: true }); + data.append(FP16x16 { mag: 50007, sign: true }); + data.append(FP16x16 { mag: 66942, sign: true }); + data.append(FP16x16 { mag: 24781, sign: true }); + data.append(FP16x16 { mag: 70618, sign: true }); + data.append(FP16x16 { mag: 91944, sign: false }); + data.append(FP16x16 { mag: 25130, sign: false }); + data.append(FP16x16 { mag: 70069, sign: true }); + data.append(FP16x16 { mag: 19429, sign: false }); + data.append(FP16x16 { mag: 17607, sign: true }); + data.append(FP16x16 { mag: 20059, sign: true }); + data.append(FP16x16 { mag: 30637, sign: false }); + data.append(FP16x16 { mag: 110305, sign: true }); + data.append(FP16x16 { mag: 47227, sign: true }); + data.append(FP16x16 { mag: 52701, sign: false }); + data.append(FP16x16 { mag: 113586, sign: true }); + data.append(FP16x16 { mag: 28893, sign: false }); + data.append(FP16x16 { mag: 19309, sign: true }); + data.append(FP16x16 { mag: 9704, sign: true }); + data.append(FP16x16 { mag: 56793, sign: false }); + data.append(FP16x16 { mag: 44072, sign: false }); + data.append(FP16x16 { mag: 2660, sign: true }); + data.append(FP16x16 { mag: 29401, sign: true }); + data.append(FP16x16 { mag: 36736, sign: true }); + data.append(FP16x16 { mag: 86250, sign: false }); + data.append(FP16x16 { mag: 38971, sign: true }); + data.append(FP16x16 { mag: 138249, sign: true }); + data.append(FP16x16 { mag: 36048, sign: false }); + data.append(FP16x16 { mag: 37452, sign: false }); + data.append(FP16x16 { mag: 23181, sign: true }); + data.append(FP16x16 { mag: 35955, sign: true }); + data.append(FP16x16 { mag: 51748, sign: true }); + data.append(FP16x16 { mag: 57097, sign: true }); + data.append(FP16x16 { mag: 91645, sign: true }); + data.append(FP16x16 { mag: 727, sign: true }); + data.append(FP16x16 { mag: 26384, sign: true }); + data.append(FP16x16 { mag: 1331, sign: false }); + data.append(FP16x16 { mag: 98672, sign: false }); + data.append(FP16x16 { mag: 82596, sign: false }); + data.append(FP16x16 { mag: 100984, sign: true }); + data.append(FP16x16 { mag: 88934, sign: true }); + data.append(FP16x16 { mag: 110736, sign: false }); + data.append(FP16x16 { mag: 106050, sign: true }); + data.append(FP16x16 { mag: 43286, sign: true }); + data.append(FP16x16 { mag: 114427, sign: false }); + data.append(FP16x16 { mag: 26160, sign: false }); + data.append(FP16x16 { mag: 19914, sign: true }); + data.append(FP16x16 { mag: 168031, sign: false }); + data.append(FP16x16 { mag: 70834, sign: true }); + data.append(FP16x16 { mag: 62785, sign: false }); + data.append(FP16x16 { mag: 20711, sign: true }); + data.append(FP16x16 { mag: 150814, sign: true }); + data.append(FP16x16 { mag: 19909, sign: true }); + data.append(FP16x16 { mag: 4360, sign: true }); + data.append(FP16x16 { mag: 38550, sign: false }); + data.append(FP16x16 { mag: 155210, sign: false }); + data.append(FP16x16 { mag: 49387, sign: true }); + data.append(FP16x16 { mag: 48606, sign: false }); + data.append(FP16x16 { mag: 26297, sign: false }); + data.append(FP16x16 { mag: 5832, sign: true }); + data.append(FP16x16 { mag: 67019, sign: false }); + data.append(FP16x16 { mag: 108552, sign: false }); + data.append(FP16x16 { mag: 38302, sign: true }); + data.append(FP16x16 { mag: 67467, sign: false }); + data.append(FP16x16 { mag: 123524, sign: false }); + data.append(FP16x16 { mag: 38110, sign: true }); + data.append(FP16x16 { mag: 49459, sign: true }); + data.append(FP16x16 { mag: 89977, sign: false }); + data.append(FP16x16 { mag: 3116, sign: false }); + data.append(FP16x16 { mag: 139868, sign: true }); + data.append(FP16x16 { mag: 3777, sign: false }); + data.append(FP16x16 { mag: 93508, sign: false }); + data.append(FP16x16 { mag: 18473, sign: true }); + data.append(FP16x16 { mag: 9749, sign: false }); + data.append(FP16x16 { mag: 69280, sign: false }); + data.append(FP16x16 { mag: 31261, sign: true }); + data.append(FP16x16 { mag: 70810, sign: true }); + data.append(FP16x16 { mag: 53719, sign: true }); + data.append(FP16x16 { mag: 25490, sign: false }); + data.append(FP16x16 { mag: 76561, sign: true }); + data.append(FP16x16 { mag: 87810, sign: true }); + data.append(FP16x16 { mag: 54546, sign: true }); + data.append(FP16x16 { mag: 11956, sign: false }); + data.append(FP16x16 { mag: 53981, sign: true }); + data.append(FP16x16 { mag: 48734, sign: false }); + data.append(FP16x16 { mag: 81861, sign: true }); + data.append(FP16x16 { mag: 91276, sign: false }); + data.append(FP16x16 { mag: 104233, sign: false }); + data.append(FP16x16 { mag: 52396, sign: false }); + data.append(FP16x16 { mag: 37016, sign: false }); + data.append(FP16x16 { mag: 39784, sign: false }); + data.append(FP16x16 { mag: 16087, sign: true }); + data.append(FP16x16 { mag: 22464, sign: true }); + data.append(FP16x16 { mag: 98432, sign: true }); + data.append(FP16x16 { mag: 120832, sign: true }); + data.append(FP16x16 { mag: 25665, sign: true }); + data.append(FP16x16 { mag: 23338, sign: true }); + data.append(FP16x16 { mag: 18801, sign: false }); + data.append(FP16x16 { mag: 22675, sign: false }); + data.append(FP16x16 { mag: 79634, sign: true }); + data.append(FP16x16 { mag: 95608, sign: true }); + data.append(FP16x16 { mag: 73767, sign: false }); + data.append(FP16x16 { mag: 20235, sign: false }); + data.append(FP16x16 { mag: 19535, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-2/input_1.cairo b/tests/nodes/layer_normalization_4d_axis-2/input_1.cairo new file mode 100644 index 000000000..88404328d --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-2/input_1.cairo @@ -0,0 +1,33 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 12195, sign: false }); + data.append(FP16x16 { mag: 25593, sign: false }); + data.append(FP16x16 { mag: 14094, sign: true }); + data.append(FP16x16 { mag: 99626, sign: true }); + data.append(FP16x16 { mag: 48676, sign: true }); + data.append(FP16x16 { mag: 58459, sign: false }); + data.append(FP16x16 { mag: 96699, sign: true }); + data.append(FP16x16 { mag: 14935, sign: true }); + data.append(FP16x16 { mag: 2362, sign: true }); + data.append(FP16x16 { mag: 150235, sign: false }); + data.append(FP16x16 { mag: 65730, sign: true }); + data.append(FP16x16 { mag: 56267, sign: false }); + data.append(FP16x16 { mag: 83617, sign: true }); + data.append(FP16x16 { mag: 34940, sign: false }); + data.append(FP16x16 { mag: 14826, sign: false }); + data.append(FP16x16 { mag: 67759, sign: true }); + data.append(FP16x16 { mag: 88099, sign: true }); + data.append(FP16x16 { mag: 103290, sign: true }); + data.append(FP16x16 { mag: 50684, sign: true }); + data.append(FP16x16 { mag: 29161, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-2/input_2.cairo b/tests/nodes/layer_normalization_4d_axis-2/input_2.cairo new file mode 100644 index 000000000..d742874f7 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-2/input_2.cairo @@ -0,0 +1,33 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 44676, sign: false }); + data.append(FP16x16 { mag: 53361, sign: true }); + data.append(FP16x16 { mag: 33378, sign: true }); + data.append(FP16x16 { mag: 43061, sign: false }); + data.append(FP16x16 { mag: 24801, sign: false }); + data.append(FP16x16 { mag: 33406, sign: false }); + data.append(FP16x16 { mag: 54529, sign: true }); + data.append(FP16x16 { mag: 133687, sign: false }); + data.append(FP16x16 { mag: 44032, sign: true }); + data.append(FP16x16 { mag: 38747, sign: true }); + data.append(FP16x16 { mag: 34054, sign: true }); + data.append(FP16x16 { mag: 45436, sign: false }); + data.append(FP16x16 { mag: 80815, sign: false }); + data.append(FP16x16 { mag: 79372, sign: true }); + data.append(FP16x16 { mag: 17958, sign: true }); + data.append(FP16x16 { mag: 49483, sign: false }); + data.append(FP16x16 { mag: 46695, sign: true }); + data.append(FP16x16 { mag: 1816, sign: true }); + data.append(FP16x16 { mag: 43264, sign: true }); + data.append(FP16x16 { mag: 59187, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-2/output_0.cairo b/tests/nodes/layer_normalization_4d_axis-2/output_0.cairo new file mode 100644 index 000000000..330008c16 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-2/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 34980, sign: false }); + data.append(FP16x16 { mag: 14079, sign: true }); + data.append(FP16x16 { mag: 62083, sign: true }); + data.append(FP16x16 { mag: 147603, sign: false }); + data.append(FP16x16 { mag: 11308, sign: true }); + data.append(FP16x16 { mag: 126598, sign: false }); + data.append(FP16x16 { mag: 73983, sign: false }); + data.append(FP16x16 { mag: 120660, sign: false }); + data.append(FP16x16 { mag: 41515, sign: true }); + data.append(FP16x16 { mag: 53720, sign: true }); + data.append(FP16x16 { mag: 11229, sign: true }); + data.append(FP16x16 { mag: 113880, sign: false }); + data.append(FP16x16 { mag: 59791, sign: false }); + data.append(FP16x16 { mag: 115620, sign: true }); + data.append(FP16x16 { mag: 19505, sign: true }); + data.append(FP16x16 { mag: 29861, sign: false }); + data.append(FP16x16 { mag: 3122, sign: true }); + data.append(FP16x16 { mag: 79516, sign: false }); + data.append(FP16x16 { mag: 11908, sign: false }); + data.append(FP16x16 { mag: 49306, sign: true }); + data.append(FP16x16 { mag: 30901, sign: false }); + data.append(FP16x16 { mag: 4892, sign: true }); + data.append(FP16x16 { mag: 42556, sign: true }); + data.append(FP16x16 { mag: 154572, sign: false }); + data.append(FP16x16 { mag: 1734, sign: true }); + data.append(FP16x16 { mag: 25009, sign: false }); + data.append(FP16x16 { mag: 36229, sign: true }); + data.append(FP16x16 { mag: 122431, sign: false }); + data.append(FP16x16 { mag: 39621, sign: true }); + data.append(FP16x16 { mag: 143085, sign: true }); + data.append(FP16x16 { mag: 110561, sign: true }); + data.append(FP16x16 { mag: 63080, sign: true }); + data.append(FP16x16 { mag: 20513, sign: false }); + data.append(FP16x16 { mag: 85496, sign: true }); + data.append(FP16x16 { mag: 17909, sign: true }); + data.append(FP16x16 { mag: 34541, sign: true }); + data.append(FP16x16 { mag: 135099, sign: true }); + data.append(FP16x16 { mag: 15693, sign: true }); + data.append(FP16x16 { mag: 24867, sign: true }); + data.append(FP16x16 { mag: 44624, sign: true }); + data.append(FP16x16 { mag: 62432, sign: false }); + data.append(FP16x16 { mag: 61070, sign: true }); + data.append(FP16x16 { mag: 9497, sign: true }); + data.append(FP16x16 { mag: 31812, sign: true }); + data.append(FP16x16 { mag: 12740, sign: true }); + data.append(FP16x16 { mag: 28750, sign: false }); + data.append(FP16x16 { mag: 29494, sign: true }); + data.append(FP16x16 { mag: 140863, sign: false }); + data.append(FP16x16 { mag: 42720, sign: true }); + data.append(FP16x16 { mag: 195054, sign: true }); + data.append(FP16x16 { mag: 49531, sign: true }); + data.append(FP16x16 { mag: 38426, sign: false }); + data.append(FP16x16 { mag: 58711, sign: false }); + data.append(FP16x16 { mag: 22407, sign: true }); + data.append(FP16x16 { mag: 2868, sign: false }); + data.append(FP16x16 { mag: 128861, sign: false }); + data.append(FP16x16 { mag: 41612, sign: false }); + data.append(FP16x16 { mag: 187704, sign: true }); + data.append(FP16x16 { mag: 19714, sign: false }); + data.append(FP16x16 { mag: 48637, sign: true }); + data.append(FP16x16 { mag: 58990, sign: false }); + data.append(FP16x16 { mag: 53022, sign: true }); + data.append(FP16x16 { mag: 25027, sign: true }); + data.append(FP16x16 { mag: 144085, sign: true }); + data.append(FP16x16 { mag: 86230, sign: false }); + data.append(FP16x16 { mag: 62331, sign: false }); + data.append(FP16x16 { mag: 3779, sign: false }); + data.append(FP16x16 { mag: 168240, sign: false }); + data.append(FP16x16 { mag: 42633, sign: true }); + data.append(FP16x16 { mag: 97041, sign: true }); + data.append(FP16x16 { mag: 45633, sign: true }); + data.append(FP16x16 { mag: 141649, sign: false }); + data.append(FP16x16 { mag: 162760, sign: false }); + data.append(FP16x16 { mag: 68597, sign: true }); + data.append(FP16x16 { mag: 17735, sign: true }); + data.append(FP16x16 { mag: 77086, sign: false }); + data.append(FP16x16 { mag: 95189, sign: true }); + data.append(FP16x16 { mag: 115074, sign: true }); + data.append(FP16x16 { mag: 980, sign: true }); + data.append(FP16x16 { mag: 75411, sign: true }); + data.append(FP16x16 { mag: 69778, sign: false }); + data.append(FP16x16 { mag: 63413, sign: true }); + data.append(FP16x16 { mag: 25417, sign: true }); + data.append(FP16x16 { mag: 111323, sign: true }); + data.append(FP16x16 { mag: 13487, sign: false }); + data.append(FP16x16 { mag: 79762, sign: true }); + data.append(FP16x16 { mag: 77975, sign: true }); + data.append(FP16x16 { mag: 109742, sign: false }); + data.append(FP16x16 { mag: 43808, sign: true }); + data.append(FP16x16 { mag: 11284, sign: false }); + data.append(FP16x16 { mag: 115283, sign: true }); + data.append(FP16x16 { mag: 29181, sign: false }); + data.append(FP16x16 { mag: 155120, sign: false }); + data.append(FP16x16 { mag: 101365, sign: true }); + data.append(FP16x16 { mag: 9482, sign: true }); + data.append(FP16x16 { mag: 115606, sign: false }); + data.append(FP16x16 { mag: 54304, sign: false }); + data.append(FP16x16 { mag: 64495, sign: false }); + data.append(FP16x16 { mag: 61839, sign: true }); + data.append(FP16x16 { mag: 40715, sign: true }); + data.append(FP16x16 { mag: 54635, sign: false }); + data.append(FP16x16 { mag: 84491, sign: true }); + data.append(FP16x16 { mag: 54222, sign: true }); + data.append(FP16x16 { mag: 124370, sign: true }); + data.append(FP16x16 { mag: 17725, sign: true }); + data.append(FP16x16 { mag: 70484, sign: false }); + data.append(FP16x16 { mag: 120027, sign: true }); + data.append(FP16x16 { mag: 136560, sign: false }); + data.append(FP16x16 { mag: 43343, sign: true }); + data.append(FP16x16 { mag: 260233, sign: true }); + data.append(FP16x16 { mag: 85770, sign: false }); + data.append(FP16x16 { mag: 26220, sign: false }); + data.append(FP16x16 { mag: 106343, sign: false }); + data.append(FP16x16 { mag: 67118, sign: true }); + data.append(FP16x16 { mag: 11865, sign: true }); + data.append(FP16x16 { mag: 129550, sign: false }); + data.append(FP16x16 { mag: 79314, sign: false }); + data.append(FP16x16 { mag: 126419, sign: true }); + data.append(FP16x16 { mag: 62170, sign: true }); + data.append(FP16x16 { mag: 52010, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-3/input_0.cairo b/tests/nodes/layer_normalization_4d_axis-3/input_0.cairo new file mode 100644 index 000000000..796ab8161 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-3/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 50439, sign: true }); + data.append(FP16x16 { mag: 80584, sign: false }); + data.append(FP16x16 { mag: 108804, sign: false }); + data.append(FP16x16 { mag: 64736, sign: true }); + data.append(FP16x16 { mag: 35989, sign: false }); + data.append(FP16x16 { mag: 83918, sign: false }); + data.append(FP16x16 { mag: 80462, sign: true }); + data.append(FP16x16 { mag: 43318, sign: false }); + data.append(FP16x16 { mag: 65651, sign: true }); + data.append(FP16x16 { mag: 11331, sign: true }); + data.append(FP16x16 { mag: 25254, sign: true }); + data.append(FP16x16 { mag: 62677, sign: false }); + data.append(FP16x16 { mag: 8412, sign: false }); + data.append(FP16x16 { mag: 64067, sign: true }); + data.append(FP16x16 { mag: 11592, sign: true }); + data.append(FP16x16 { mag: 10557, sign: false }); + data.append(FP16x16 { mag: 33540, sign: true }); + data.append(FP16x16 { mag: 50007, sign: true }); + data.append(FP16x16 { mag: 66942, sign: true }); + data.append(FP16x16 { mag: 24781, sign: true }); + data.append(FP16x16 { mag: 70618, sign: true }); + data.append(FP16x16 { mag: 91944, sign: false }); + data.append(FP16x16 { mag: 25130, sign: false }); + data.append(FP16x16 { mag: 70069, sign: true }); + data.append(FP16x16 { mag: 19429, sign: false }); + data.append(FP16x16 { mag: 17607, sign: true }); + data.append(FP16x16 { mag: 20059, sign: true }); + data.append(FP16x16 { mag: 30637, sign: false }); + data.append(FP16x16 { mag: 110305, sign: true }); + data.append(FP16x16 { mag: 47227, sign: true }); + data.append(FP16x16 { mag: 52701, sign: false }); + data.append(FP16x16 { mag: 113586, sign: true }); + data.append(FP16x16 { mag: 28893, sign: false }); + data.append(FP16x16 { mag: 19309, sign: true }); + data.append(FP16x16 { mag: 9704, sign: true }); + data.append(FP16x16 { mag: 56793, sign: false }); + data.append(FP16x16 { mag: 44072, sign: false }); + data.append(FP16x16 { mag: 2660, sign: true }); + data.append(FP16x16 { mag: 29401, sign: true }); + data.append(FP16x16 { mag: 36736, sign: true }); + data.append(FP16x16 { mag: 86250, sign: false }); + data.append(FP16x16 { mag: 38971, sign: true }); + data.append(FP16x16 { mag: 138249, sign: true }); + data.append(FP16x16 { mag: 36048, sign: false }); + data.append(FP16x16 { mag: 37452, sign: false }); + data.append(FP16x16 { mag: 23181, sign: true }); + data.append(FP16x16 { mag: 35955, sign: true }); + data.append(FP16x16 { mag: 51748, sign: true }); + data.append(FP16x16 { mag: 57097, sign: true }); + data.append(FP16x16 { mag: 91645, sign: true }); + data.append(FP16x16 { mag: 727, sign: true }); + data.append(FP16x16 { mag: 26384, sign: true }); + data.append(FP16x16 { mag: 1331, sign: false }); + data.append(FP16x16 { mag: 98672, sign: false }); + data.append(FP16x16 { mag: 82596, sign: false }); + data.append(FP16x16 { mag: 100984, sign: true }); + data.append(FP16x16 { mag: 88934, sign: true }); + data.append(FP16x16 { mag: 110736, sign: false }); + data.append(FP16x16 { mag: 106050, sign: true }); + data.append(FP16x16 { mag: 43286, sign: true }); + data.append(FP16x16 { mag: 114427, sign: false }); + data.append(FP16x16 { mag: 26160, sign: false }); + data.append(FP16x16 { mag: 19914, sign: true }); + data.append(FP16x16 { mag: 168031, sign: false }); + data.append(FP16x16 { mag: 70834, sign: true }); + data.append(FP16x16 { mag: 62785, sign: false }); + data.append(FP16x16 { mag: 20711, sign: true }); + data.append(FP16x16 { mag: 150814, sign: true }); + data.append(FP16x16 { mag: 19909, sign: true }); + data.append(FP16x16 { mag: 4360, sign: true }); + data.append(FP16x16 { mag: 38550, sign: false }); + data.append(FP16x16 { mag: 155210, sign: false }); + data.append(FP16x16 { mag: 49387, sign: true }); + data.append(FP16x16 { mag: 48606, sign: false }); + data.append(FP16x16 { mag: 26297, sign: false }); + data.append(FP16x16 { mag: 5832, sign: true }); + data.append(FP16x16 { mag: 67019, sign: false }); + data.append(FP16x16 { mag: 108552, sign: false }); + data.append(FP16x16 { mag: 38302, sign: true }); + data.append(FP16x16 { mag: 67467, sign: false }); + data.append(FP16x16 { mag: 123524, sign: false }); + data.append(FP16x16 { mag: 38110, sign: true }); + data.append(FP16x16 { mag: 49459, sign: true }); + data.append(FP16x16 { mag: 89977, sign: false }); + data.append(FP16x16 { mag: 3116, sign: false }); + data.append(FP16x16 { mag: 139868, sign: true }); + data.append(FP16x16 { mag: 3777, sign: false }); + data.append(FP16x16 { mag: 93508, sign: false }); + data.append(FP16x16 { mag: 18473, sign: true }); + data.append(FP16x16 { mag: 9749, sign: false }); + data.append(FP16x16 { mag: 69280, sign: false }); + data.append(FP16x16 { mag: 31261, sign: true }); + data.append(FP16x16 { mag: 70810, sign: true }); + data.append(FP16x16 { mag: 53719, sign: true }); + data.append(FP16x16 { mag: 25490, sign: false }); + data.append(FP16x16 { mag: 76561, sign: true }); + data.append(FP16x16 { mag: 87810, sign: true }); + data.append(FP16x16 { mag: 54546, sign: true }); + data.append(FP16x16 { mag: 11956, sign: false }); + data.append(FP16x16 { mag: 53981, sign: true }); + data.append(FP16x16 { mag: 48734, sign: false }); + data.append(FP16x16 { mag: 81861, sign: true }); + data.append(FP16x16 { mag: 91276, sign: false }); + data.append(FP16x16 { mag: 104233, sign: false }); + data.append(FP16x16 { mag: 52396, sign: false }); + data.append(FP16x16 { mag: 37016, sign: false }); + data.append(FP16x16 { mag: 39784, sign: false }); + data.append(FP16x16 { mag: 16087, sign: true }); + data.append(FP16x16 { mag: 22464, sign: true }); + data.append(FP16x16 { mag: 98432, sign: true }); + data.append(FP16x16 { mag: 120832, sign: true }); + data.append(FP16x16 { mag: 25665, sign: true }); + data.append(FP16x16 { mag: 23338, sign: true }); + data.append(FP16x16 { mag: 18801, sign: false }); + data.append(FP16x16 { mag: 22675, sign: false }); + data.append(FP16x16 { mag: 79634, sign: true }); + data.append(FP16x16 { mag: 95608, sign: true }); + data.append(FP16x16 { mag: 73767, sign: false }); + data.append(FP16x16 { mag: 20235, sign: false }); + data.append(FP16x16 { mag: 19535, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-3/input_1.cairo b/tests/nodes/layer_normalization_4d_axis-3/input_1.cairo new file mode 100644 index 000000000..f0db6e787 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-3/input_1.cairo @@ -0,0 +1,74 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 72546, sign: true }); + data.append(FP16x16 { mag: 1934, sign: true }); + data.append(FP16x16 { mag: 26602, sign: false }); + data.append(FP16x16 { mag: 64666, sign: false }); + data.append(FP16x16 { mag: 113155, sign: false }); + data.append(FP16x16 { mag: 140657, sign: true }); + data.append(FP16x16 { mag: 8029, sign: false }); + data.append(FP16x16 { mag: 1972, sign: true }); + data.append(FP16x16 { mag: 13859, sign: true }); + data.append(FP16x16 { mag: 30242, sign: true }); + data.append(FP16x16 { mag: 5299, sign: false }); + data.append(FP16x16 { mag: 164784, sign: false }); + data.append(FP16x16 { mag: 46610, sign: false }); + data.append(FP16x16 { mag: 8040, sign: true }); + data.append(FP16x16 { mag: 2005, sign: false }); + data.append(FP16x16 { mag: 43621, sign: true }); + data.append(FP16x16 { mag: 23042, sign: false }); + data.append(FP16x16 { mag: 12823, sign: true }); + data.append(FP16x16 { mag: 93996, sign: true }); + data.append(FP16x16 { mag: 32415, sign: true }); + data.append(FP16x16 { mag: 78956, sign: true }); + data.append(FP16x16 { mag: 74322, sign: true }); + data.append(FP16x16 { mag: 3724, sign: true }); + data.append(FP16x16 { mag: 5105, sign: true }); + data.append(FP16x16 { mag: 54473, sign: false }); + data.append(FP16x16 { mag: 38101, sign: false }); + data.append(FP16x16 { mag: 12378, sign: false }); + data.append(FP16x16 { mag: 36960, sign: true }); + data.append(FP16x16 { mag: 87265, sign: true }); + data.append(FP16x16 { mag: 34339, sign: true }); + data.append(FP16x16 { mag: 109979, sign: false }); + data.append(FP16x16 { mag: 138917, sign: true }); + data.append(FP16x16 { mag: 93741, sign: false }); + data.append(FP16x16 { mag: 139243, sign: false }); + data.append(FP16x16 { mag: 42853, sign: true }); + data.append(FP16x16 { mag: 33084, sign: false }); + data.append(FP16x16 { mag: 23897, sign: true }); + data.append(FP16x16 { mag: 52766, sign: true }); + data.append(FP16x16 { mag: 38574, sign: true }); + data.append(FP16x16 { mag: 10485, sign: true }); + data.append(FP16x16 { mag: 60984, sign: true }); + data.append(FP16x16 { mag: 20096, sign: false }); + data.append(FP16x16 { mag: 24418, sign: false }); + data.append(FP16x16 { mag: 29015, sign: false }); + data.append(FP16x16 { mag: 49773, sign: false }); + data.append(FP16x16 { mag: 104302, sign: false }); + data.append(FP16x16 { mag: 39141, sign: false }); + data.append(FP16x16 { mag: 23473, sign: true }); + data.append(FP16x16 { mag: 33895, sign: false }); + data.append(FP16x16 { mag: 12032, sign: true }); + data.append(FP16x16 { mag: 60207, sign: false }); + data.append(FP16x16 { mag: 64204, sign: false }); + data.append(FP16x16 { mag: 34732, sign: true }); + data.append(FP16x16 { mag: 31629, sign: false }); + data.append(FP16x16 { mag: 3538, sign: true }); + data.append(FP16x16 { mag: 13382, sign: false }); + data.append(FP16x16 { mag: 29195, sign: true }); + data.append(FP16x16 { mag: 5944, sign: true }); + data.append(FP16x16 { mag: 103840, sign: true }); + data.append(FP16x16 { mag: 38426, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-3/input_2.cairo b/tests/nodes/layer_normalization_4d_axis-3/input_2.cairo new file mode 100644 index 000000000..bc434b717 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-3/input_2.cairo @@ -0,0 +1,74 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 140567, sign: false }); + data.append(FP16x16 { mag: 40459, sign: true }); + data.append(FP16x16 { mag: 24644, sign: false }); + data.append(FP16x16 { mag: 111555, sign: false }); + data.append(FP16x16 { mag: 8570, sign: false }); + data.append(FP16x16 { mag: 10992, sign: true }); + data.append(FP16x16 { mag: 212035, sign: false }); + data.append(FP16x16 { mag: 11465, sign: false }); + data.append(FP16x16 { mag: 73223, sign: false }); + data.append(FP16x16 { mag: 105723, sign: true }); + data.append(FP16x16 { mag: 75766, sign: true }); + data.append(FP16x16 { mag: 102598, sign: true }); + data.append(FP16x16 { mag: 94019, sign: true }); + data.append(FP16x16 { mag: 35770, sign: false }); + data.append(FP16x16 { mag: 78502, sign: false }); + data.append(FP16x16 { mag: 41664, sign: true }); + data.append(FP16x16 { mag: 29280, sign: true }); + data.append(FP16x16 { mag: 124649, sign: false }); + data.append(FP16x16 { mag: 99019, sign: true }); + data.append(FP16x16 { mag: 66970, sign: true }); + data.append(FP16x16 { mag: 8747, sign: true }); + data.append(FP16x16 { mag: 79443, sign: true }); + data.append(FP16x16 { mag: 68471, sign: false }); + data.append(FP16x16 { mag: 11126, sign: true }); + data.append(FP16x16 { mag: 24010, sign: true }); + data.append(FP16x16 { mag: 4088, sign: false }); + data.append(FP16x16 { mag: 49743, sign: false }); + data.append(FP16x16 { mag: 8246, sign: false }); + data.append(FP16x16 { mag: 84164, sign: false }); + data.append(FP16x16 { mag: 86259, sign: false }); + data.append(FP16x16 { mag: 42086, sign: false }); + data.append(FP16x16 { mag: 5197, sign: false }); + data.append(FP16x16 { mag: 1592, sign: false }); + data.append(FP16x16 { mag: 41261, sign: true }); + data.append(FP16x16 { mag: 52531, sign: false }); + data.append(FP16x16 { mag: 9495, sign: true }); + data.append(FP16x16 { mag: 18817, sign: true }); + data.append(FP16x16 { mag: 2626, sign: false }); + data.append(FP16x16 { mag: 7622, sign: true }); + data.append(FP16x16 { mag: 34140, sign: true }); + data.append(FP16x16 { mag: 29535, sign: true }); + data.append(FP16x16 { mag: 43504, sign: true }); + data.append(FP16x16 { mag: 71292, sign: true }); + data.append(FP16x16 { mag: 30928, sign: false }); + data.append(FP16x16 { mag: 113967, sign: false }); + data.append(FP16x16 { mag: 38791, sign: false }); + data.append(FP16x16 { mag: 89439, sign: true }); + data.append(FP16x16 { mag: 73665, sign: true }); + data.append(FP16x16 { mag: 103501, sign: false }); + data.append(FP16x16 { mag: 27579, sign: false }); + data.append(FP16x16 { mag: 81207, sign: false }); + data.append(FP16x16 { mag: 8128, sign: true }); + data.append(FP16x16 { mag: 7160, sign: true }); + data.append(FP16x16 { mag: 159276, sign: true }); + data.append(FP16x16 { mag: 93395, sign: false }); + data.append(FP16x16 { mag: 6090, sign: false }); + data.append(FP16x16 { mag: 29949, sign: false }); + data.append(FP16x16 { mag: 24359, sign: true }); + data.append(FP16x16 { mag: 44435, sign: false }); + data.append(FP16x16 { mag: 57014, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-3/output_0.cairo b/tests/nodes/layer_normalization_4d_axis-3/output_0.cairo new file mode 100644 index 000000000..896cfb65f --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-3/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 187342, sign: false }); + data.append(FP16x16 { mag: 43360, sign: true }); + data.append(FP16x16 { mag: 76815, sign: false }); + data.append(FP16x16 { mag: 54732, sign: false }); + data.append(FP16x16 { mag: 95653, sign: false }); + data.append(FP16x16 { mag: 229562, sign: true }); + data.append(FP16x16 { mag: 202912, sign: false }); + data.append(FP16x16 { mag: 9710, sign: false }); + data.append(FP16x16 { mag: 85609, sign: false }); + data.append(FP16x16 { mag: 105578, sign: true }); + data.append(FP16x16 { mag: 76999, sign: true }); + data.append(FP16x16 { mag: 96182, sign: false }); + data.append(FP16x16 { mag: 79183, sign: true }); + data.append(FP16x16 { mag: 42747, sign: false }); + data.append(FP16x16 { mag: 78484, sign: false }); + data.append(FP16x16 { mag: 57080, sign: true }); + data.append(FP16x16 { mag: 37764, sign: true }); + data.append(FP16x16 { mag: 132826, sign: false }); + data.append(FP16x16 { mag: 13030, sign: true }); + data.append(FP16x16 { mag: 59680, sign: true }); + data.append(FP16x16 { mag: 68231, sign: false }); + data.append(FP16x16 { mag: 204694, sign: true }); + data.append(FP16x16 { mag: 66266, sign: false }); + data.append(FP16x16 { mag: 6194, sign: true }); + data.append(FP16x16 { mag: 3149, sign: false }); + data.append(FP16x16 { mag: 6, sign: true }); + data.append(FP16x16 { mag: 47916, sign: false }); + data.append(FP16x16 { mag: 16960, sign: true }); + data.append(FP16x16 { mag: 225919, sign: false }); + data.append(FP16x16 { mag: 106594, sign: false }); + data.append(FP16x16 { mag: 156803, sign: false }); + data.append(FP16x16 { mag: 238316, sign: false }); + data.append(FP16x16 { mag: 62849, sign: false }); + data.append(FP16x16 { mag: 60107, sign: true }); + data.append(FP16x16 { mag: 51595, sign: false }); + data.append(FP16x16 { mag: 27228, sign: false }); + data.append(FP16x16 { mag: 40369, sign: true }); + data.append(FP16x16 { mag: 4608, sign: true }); + data.append(FP16x16 { mag: 3968, sign: false }); + data.append(FP16x16 { mag: 29731, sign: true }); + data.append(FP16x16 { mag: 126627, sign: true }); + data.append(FP16x16 { mag: 52689, sign: true }); + data.append(FP16x16 { mag: 122123, sign: true }); + data.append(FP16x16 { mag: 53286, sign: false }); + data.append(FP16x16 { mag: 153464, sign: false }); + data.append(FP16x16 { mag: 18067, sign: false }); + data.append(FP16x16 { mag: 105399, sign: true }); + data.append(FP16x16 { mag: 58027, sign: true }); + data.append(FP16x16 { mag: 77954, sign: false }); + data.append(FP16x16 { mag: 43451, sign: false }); + data.append(FP16x16 { mag: 91367, sign: false }); + data.append(FP16x16 { mag: 24251, sign: true }); + data.append(FP16x16 { mag: 14192, sign: true }); + data.append(FP16x16 { mag: 102491, sign: true }); + data.append(FP16x16 { mag: 87973, sign: false }); + data.append(FP16x16 { mag: 13606, sign: true }); + data.append(FP16x16 { mag: 67163, sign: false }); + data.append(FP16x16 { mag: 36206, sign: true }); + data.append(FP16x16 { mag: 205884, sign: false }); + data.append(FP16x16 { mag: 36735, sign: true }); + data.append(FP16x16 { mag: 26425, sign: false }); + data.append(FP16x16 { mag: 41090, sign: true }); + data.append(FP16x16 { mag: 16004, sign: false }); + data.append(FP16x16 { mag: 262275, sign: false }); + data.append(FP16x16 { mag: 109588, sign: true }); + data.append(FP16x16 { mag: 129666, sign: true }); + data.append(FP16x16 { mag: 209336, sign: false }); + data.append(FP16x16 { mag: 15754, sign: false }); + data.append(FP16x16 { mag: 77723, sign: false }); + data.append(FP16x16 { mag: 102547, sign: true }); + data.append(FP16x16 { mag: 73109, sign: true }); + data.append(FP16x16 { mag: 251621, sign: false }); + data.append(FP16x16 { mag: 128567, sign: true }); + data.append(FP16x16 { mag: 30597, sign: false }); + data.append(FP16x16 { mag: 79160, sign: false }); + data.append(FP16x16 { mag: 36176, sign: true }); + data.append(FP16x16 { mag: 8461, sign: true }); + data.append(FP16x16 { mag: 105536, sign: false }); + data.append(FP16x16 { mag: 44072, sign: true }); + data.append(FP16x16 { mag: 96463, sign: true }); + data.append(FP16x16 { mag: 143124, sign: true }); + data.append(FP16x16 { mag: 36197, sign: true }); + data.append(FP16x16 { mag: 71235, sign: false }); + data.append(FP16x16 { mag: 17395, sign: true }); + data.append(FP16x16 { mag: 23976, sign: true }); + data.append(FP16x16 { mag: 72861, sign: true }); + data.append(FP16x16 { mag: 49866, sign: false }); + data.append(FP16x16 { mag: 38981, sign: true }); + data.append(FP16x16 { mag: 110729, sign: false }); + data.append(FP16x16 { mag: 83018, sign: false }); + data.append(FP16x16 { mag: 144969, sign: false }); + data.append(FP16x16 { mag: 72585, sign: false }); + data.append(FP16x16 { mag: 96262, sign: true }); + data.append(FP16x16 { mag: 152991, sign: true }); + data.append(FP16x16 { mag: 38957, sign: false }); + data.append(FP16x16 { mag: 46719, sign: true }); + data.append(FP16x16 { mag: 11869, sign: false }); + data.append(FP16x16 { mag: 45582, sign: false }); + data.append(FP16x16 { mag: 12464, sign: true }); + data.append(FP16x16 { mag: 25688, sign: true }); + data.append(FP16x16 { mag: 68881, sign: true }); + data.append(FP16x16 { mag: 67620, sign: true }); + data.append(FP16x16 { mag: 40860, sign: true }); + data.append(FP16x16 { mag: 72401, sign: false }); + data.append(FP16x16 { mag: 148655, sign: false }); + data.append(FP16x16 { mag: 88817, sign: false }); + data.append(FP16x16 { mag: 69136, sign: true }); + data.append(FP16x16 { mag: 67311, sign: true }); + data.append(FP16x16 { mag: 91271, sign: false }); + data.append(FP16x16 { mag: 44836, sign: false }); + data.append(FP16x16 { mag: 24194, sign: true }); + data.append(FP16x16 { mag: 34197, sign: true }); + data.append(FP16x16 { mag: 5799, sign: false }); + data.append(FP16x16 { mag: 152246, sign: true }); + data.append(FP16x16 { mag: 92415, sign: false }); + data.append(FP16x16 { mag: 9547, sign: true }); + data.append(FP16x16 { mag: 70654, sign: false }); + data.append(FP16x16 { mag: 30297, sign: true }); + data.append(FP16x16 { mag: 19252, sign: false }); + data.append(FP16x16 { mag: 44739, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-4/input_0.cairo b/tests/nodes/layer_normalization_4d_axis-4/input_0.cairo new file mode 100644 index 000000000..796ab8161 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-4/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 50439, sign: true }); + data.append(FP16x16 { mag: 80584, sign: false }); + data.append(FP16x16 { mag: 108804, sign: false }); + data.append(FP16x16 { mag: 64736, sign: true }); + data.append(FP16x16 { mag: 35989, sign: false }); + data.append(FP16x16 { mag: 83918, sign: false }); + data.append(FP16x16 { mag: 80462, sign: true }); + data.append(FP16x16 { mag: 43318, sign: false }); + data.append(FP16x16 { mag: 65651, sign: true }); + data.append(FP16x16 { mag: 11331, sign: true }); + data.append(FP16x16 { mag: 25254, sign: true }); + data.append(FP16x16 { mag: 62677, sign: false }); + data.append(FP16x16 { mag: 8412, sign: false }); + data.append(FP16x16 { mag: 64067, sign: true }); + data.append(FP16x16 { mag: 11592, sign: true }); + data.append(FP16x16 { mag: 10557, sign: false }); + data.append(FP16x16 { mag: 33540, sign: true }); + data.append(FP16x16 { mag: 50007, sign: true }); + data.append(FP16x16 { mag: 66942, sign: true }); + data.append(FP16x16 { mag: 24781, sign: true }); + data.append(FP16x16 { mag: 70618, sign: true }); + data.append(FP16x16 { mag: 91944, sign: false }); + data.append(FP16x16 { mag: 25130, sign: false }); + data.append(FP16x16 { mag: 70069, sign: true }); + data.append(FP16x16 { mag: 19429, sign: false }); + data.append(FP16x16 { mag: 17607, sign: true }); + data.append(FP16x16 { mag: 20059, sign: true }); + data.append(FP16x16 { mag: 30637, sign: false }); + data.append(FP16x16 { mag: 110305, sign: true }); + data.append(FP16x16 { mag: 47227, sign: true }); + data.append(FP16x16 { mag: 52701, sign: false }); + data.append(FP16x16 { mag: 113586, sign: true }); + data.append(FP16x16 { mag: 28893, sign: false }); + data.append(FP16x16 { mag: 19309, sign: true }); + data.append(FP16x16 { mag: 9704, sign: true }); + data.append(FP16x16 { mag: 56793, sign: false }); + data.append(FP16x16 { mag: 44072, sign: false }); + data.append(FP16x16 { mag: 2660, sign: true }); + data.append(FP16x16 { mag: 29401, sign: true }); + data.append(FP16x16 { mag: 36736, sign: true }); + data.append(FP16x16 { mag: 86250, sign: false }); + data.append(FP16x16 { mag: 38971, sign: true }); + data.append(FP16x16 { mag: 138249, sign: true }); + data.append(FP16x16 { mag: 36048, sign: false }); + data.append(FP16x16 { mag: 37452, sign: false }); + data.append(FP16x16 { mag: 23181, sign: true }); + data.append(FP16x16 { mag: 35955, sign: true }); + data.append(FP16x16 { mag: 51748, sign: true }); + data.append(FP16x16 { mag: 57097, sign: true }); + data.append(FP16x16 { mag: 91645, sign: true }); + data.append(FP16x16 { mag: 727, sign: true }); + data.append(FP16x16 { mag: 26384, sign: true }); + data.append(FP16x16 { mag: 1331, sign: false }); + data.append(FP16x16 { mag: 98672, sign: false }); + data.append(FP16x16 { mag: 82596, sign: false }); + data.append(FP16x16 { mag: 100984, sign: true }); + data.append(FP16x16 { mag: 88934, sign: true }); + data.append(FP16x16 { mag: 110736, sign: false }); + data.append(FP16x16 { mag: 106050, sign: true }); + data.append(FP16x16 { mag: 43286, sign: true }); + data.append(FP16x16 { mag: 114427, sign: false }); + data.append(FP16x16 { mag: 26160, sign: false }); + data.append(FP16x16 { mag: 19914, sign: true }); + data.append(FP16x16 { mag: 168031, sign: false }); + data.append(FP16x16 { mag: 70834, sign: true }); + data.append(FP16x16 { mag: 62785, sign: false }); + data.append(FP16x16 { mag: 20711, sign: true }); + data.append(FP16x16 { mag: 150814, sign: true }); + data.append(FP16x16 { mag: 19909, sign: true }); + data.append(FP16x16 { mag: 4360, sign: true }); + data.append(FP16x16 { mag: 38550, sign: false }); + data.append(FP16x16 { mag: 155210, sign: false }); + data.append(FP16x16 { mag: 49387, sign: true }); + data.append(FP16x16 { mag: 48606, sign: false }); + data.append(FP16x16 { mag: 26297, sign: false }); + data.append(FP16x16 { mag: 5832, sign: true }); + data.append(FP16x16 { mag: 67019, sign: false }); + data.append(FP16x16 { mag: 108552, sign: false }); + data.append(FP16x16 { mag: 38302, sign: true }); + data.append(FP16x16 { mag: 67467, sign: false }); + data.append(FP16x16 { mag: 123524, sign: false }); + data.append(FP16x16 { mag: 38110, sign: true }); + data.append(FP16x16 { mag: 49459, sign: true }); + data.append(FP16x16 { mag: 89977, sign: false }); + data.append(FP16x16 { mag: 3116, sign: false }); + data.append(FP16x16 { mag: 139868, sign: true }); + data.append(FP16x16 { mag: 3777, sign: false }); + data.append(FP16x16 { mag: 93508, sign: false }); + data.append(FP16x16 { mag: 18473, sign: true }); + data.append(FP16x16 { mag: 9749, sign: false }); + data.append(FP16x16 { mag: 69280, sign: false }); + data.append(FP16x16 { mag: 31261, sign: true }); + data.append(FP16x16 { mag: 70810, sign: true }); + data.append(FP16x16 { mag: 53719, sign: true }); + data.append(FP16x16 { mag: 25490, sign: false }); + data.append(FP16x16 { mag: 76561, sign: true }); + data.append(FP16x16 { mag: 87810, sign: true }); + data.append(FP16x16 { mag: 54546, sign: true }); + data.append(FP16x16 { mag: 11956, sign: false }); + data.append(FP16x16 { mag: 53981, sign: true }); + data.append(FP16x16 { mag: 48734, sign: false }); + data.append(FP16x16 { mag: 81861, sign: true }); + data.append(FP16x16 { mag: 91276, sign: false }); + data.append(FP16x16 { mag: 104233, sign: false }); + data.append(FP16x16 { mag: 52396, sign: false }); + data.append(FP16x16 { mag: 37016, sign: false }); + data.append(FP16x16 { mag: 39784, sign: false }); + data.append(FP16x16 { mag: 16087, sign: true }); + data.append(FP16x16 { mag: 22464, sign: true }); + data.append(FP16x16 { mag: 98432, sign: true }); + data.append(FP16x16 { mag: 120832, sign: true }); + data.append(FP16x16 { mag: 25665, sign: true }); + data.append(FP16x16 { mag: 23338, sign: true }); + data.append(FP16x16 { mag: 18801, sign: false }); + data.append(FP16x16 { mag: 22675, sign: false }); + data.append(FP16x16 { mag: 79634, sign: true }); + data.append(FP16x16 { mag: 95608, sign: true }); + data.append(FP16x16 { mag: 73767, sign: false }); + data.append(FP16x16 { mag: 20235, sign: false }); + data.append(FP16x16 { mag: 19535, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-4/input_1.cairo b/tests/nodes/layer_normalization_4d_axis-4/input_1.cairo new file mode 100644 index 000000000..7505f0217 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-4/input_1.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 9950, sign: true }); + data.append(FP16x16 { mag: 28981, sign: true }); + data.append(FP16x16 { mag: 12325, sign: false }); + data.append(FP16x16 { mag: 111567, sign: true }); + data.append(FP16x16 { mag: 87492, sign: true }); + data.append(FP16x16 { mag: 2266, sign: false }); + data.append(FP16x16 { mag: 91808, sign: false }); + data.append(FP16x16 { mag: 99181, sign: false }); + data.append(FP16x16 { mag: 54619, sign: false }); + data.append(FP16x16 { mag: 56530, sign: true }); + data.append(FP16x16 { mag: 58746, sign: false }); + data.append(FP16x16 { mag: 11025, sign: true }); + data.append(FP16x16 { mag: 63919, sign: false }); + data.append(FP16x16 { mag: 16987, sign: true }); + data.append(FP16x16 { mag: 14843, sign: true }); + data.append(FP16x16 { mag: 9629, sign: false }); + data.append(FP16x16 { mag: 27461, sign: false }); + data.append(FP16x16 { mag: 34324, sign: false }); + data.append(FP16x16 { mag: 102809, sign: true }); + data.append(FP16x16 { mag: 75888, sign: false }); + data.append(FP16x16 { mag: 65510, sign: true }); + data.append(FP16x16 { mag: 54723, sign: true }); + data.append(FP16x16 { mag: 52244, sign: true }); + data.append(FP16x16 { mag: 54119, sign: false }); + data.append(FP16x16 { mag: 28309, sign: false }); + data.append(FP16x16 { mag: 30690, sign: false }); + data.append(FP16x16 { mag: 5155, sign: false }); + data.append(FP16x16 { mag: 1463, sign: false }); + data.append(FP16x16 { mag: 72974, sign: true }); + data.append(FP16x16 { mag: 62563, sign: true }); + data.append(FP16x16 { mag: 115530, sign: true }); + data.append(FP16x16 { mag: 3885, sign: true }); + data.append(FP16x16 { mag: 1274, sign: false }); + data.append(FP16x16 { mag: 20319, sign: true }); + data.append(FP16x16 { mag: 13396, sign: true }); + data.append(FP16x16 { mag: 12923, sign: false }); + data.append(FP16x16 { mag: 13623, sign: false }); + data.append(FP16x16 { mag: 95497, sign: true }); + data.append(FP16x16 { mag: 78541, sign: false }); + data.append(FP16x16 { mag: 61124, sign: false }); + data.append(FP16x16 { mag: 12527, sign: true }); + data.append(FP16x16 { mag: 128370, sign: true }); + data.append(FP16x16 { mag: 81782, sign: false }); + data.append(FP16x16 { mag: 47494, sign: true }); + data.append(FP16x16 { mag: 15027, sign: false }); + data.append(FP16x16 { mag: 76936, sign: false }); + data.append(FP16x16 { mag: 95112, sign: true }); + data.append(FP16x16 { mag: 37987, sign: false }); + data.append(FP16x16 { mag: 11759, sign: false }); + data.append(FP16x16 { mag: 128485, sign: true }); + data.append(FP16x16 { mag: 12506, sign: false }); + data.append(FP16x16 { mag: 16778, sign: false }); + data.append(FP16x16 { mag: 49483, sign: false }); + data.append(FP16x16 { mag: 11346, sign: false }); + data.append(FP16x16 { mag: 58647, sign: true }); + data.append(FP16x16 { mag: 71599, sign: false }); + data.append(FP16x16 { mag: 22777, sign: false }); + data.append(FP16x16 { mag: 17125, sign: true }); + data.append(FP16x16 { mag: 122340, sign: true }); + data.append(FP16x16 { mag: 14001, sign: true }); + data.append(FP16x16 { mag: 43279, sign: false }); + data.append(FP16x16 { mag: 74703, sign: false }); + data.append(FP16x16 { mag: 95648, sign: false }); + data.append(FP16x16 { mag: 9808, sign: true }); + data.append(FP16x16 { mag: 41586, sign: false }); + data.append(FP16x16 { mag: 55055, sign: false }); + data.append(FP16x16 { mag: 29114, sign: false }); + data.append(FP16x16 { mag: 80337, sign: true }); + data.append(FP16x16 { mag: 45090, sign: false }); + data.append(FP16x16 { mag: 21497, sign: true }); + data.append(FP16x16 { mag: 66453, sign: true }); + data.append(FP16x16 { mag: 41129, sign: true }); + data.append(FP16x16 { mag: 28771, sign: false }); + data.append(FP16x16 { mag: 38619, sign: true }); + data.append(FP16x16 { mag: 12052, sign: true }); + data.append(FP16x16 { mag: 71742, sign: true }); + data.append(FP16x16 { mag: 3122, sign: true }); + data.append(FP16x16 { mag: 7658, sign: false }); + data.append(FP16x16 { mag: 72650, sign: false }); + data.append(FP16x16 { mag: 125101, sign: true }); + data.append(FP16x16 { mag: 123350, sign: true }); + data.append(FP16x16 { mag: 17299, sign: true }); + data.append(FP16x16 { mag: 133664, sign: false }); + data.append(FP16x16 { mag: 95871, sign: false }); + data.append(FP16x16 { mag: 68174, sign: true }); + data.append(FP16x16 { mag: 11742, sign: true }); + data.append(FP16x16 { mag: 39757, sign: false }); + data.append(FP16x16 { mag: 107969, sign: true }); + data.append(FP16x16 { mag: 47625, sign: false }); + data.append(FP16x16 { mag: 25381, sign: false }); + data.append(FP16x16 { mag: 86657, sign: false }); + data.append(FP16x16 { mag: 124296, sign: true }); + data.append(FP16x16 { mag: 84481, sign: true }); + data.append(FP16x16 { mag: 39857, sign: false }); + data.append(FP16x16 { mag: 62123, sign: false }); + data.append(FP16x16 { mag: 51178, sign: true }); + data.append(FP16x16 { mag: 1008, sign: true }); + data.append(FP16x16 { mag: 58152, sign: false }); + data.append(FP16x16 { mag: 23649, sign: false }); + data.append(FP16x16 { mag: 21472, sign: true }); + data.append(FP16x16 { mag: 42397, sign: true }); + data.append(FP16x16 { mag: 9088, sign: true }); + data.append(FP16x16 { mag: 174345, sign: false }); + data.append(FP16x16 { mag: 36378, sign: true }); + data.append(FP16x16 { mag: 57507, sign: false }); + data.append(FP16x16 { mag: 9867, sign: false }); + data.append(FP16x16 { mag: 103912, sign: false }); + data.append(FP16x16 { mag: 14015, sign: false }); + data.append(FP16x16 { mag: 18604, sign: true }); + data.append(FP16x16 { mag: 5497, sign: false }); + data.append(FP16x16 { mag: 47155, sign: false }); + data.append(FP16x16 { mag: 46396, sign: true }); + data.append(FP16x16 { mag: 57120, sign: false }); + data.append(FP16x16 { mag: 74433, sign: true }); + data.append(FP16x16 { mag: 8714, sign: false }); + data.append(FP16x16 { mag: 52414, sign: true }); + data.append(FP16x16 { mag: 29603, sign: false }); + data.append(FP16x16 { mag: 34762, sign: false }); + data.append(FP16x16 { mag: 30832, sign: false }); + data.append(FP16x16 { mag: 71610, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-4/input_2.cairo b/tests/nodes/layer_normalization_4d_axis-4/input_2.cairo new file mode 100644 index 000000000..ebcaa1b95 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-4/input_2.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 74987, sign: false }); + data.append(FP16x16 { mag: 102257, sign: true }); + data.append(FP16x16 { mag: 78717, sign: false }); + data.append(FP16x16 { mag: 59379, sign: true }); + data.append(FP16x16 { mag: 11211, sign: true }); + data.append(FP16x16 { mag: 26255, sign: false }); + data.append(FP16x16 { mag: 30689, sign: false }); + data.append(FP16x16 { mag: 80491, sign: true }); + data.append(FP16x16 { mag: 14461, sign: true }); + data.append(FP16x16 { mag: 42696, sign: false }); + data.append(FP16x16 { mag: 56317, sign: false }); + data.append(FP16x16 { mag: 88135, sign: false }); + data.append(FP16x16 { mag: 45393, sign: false }); + data.append(FP16x16 { mag: 18363, sign: true }); + data.append(FP16x16 { mag: 20773, sign: false }); + data.append(FP16x16 { mag: 66375, sign: false }); + data.append(FP16x16 { mag: 35314, sign: true }); + data.append(FP16x16 { mag: 52902, sign: false }); + data.append(FP16x16 { mag: 86324, sign: true }); + data.append(FP16x16 { mag: 40151, sign: true }); + data.append(FP16x16 { mag: 22446, sign: false }); + data.append(FP16x16 { mag: 33948, sign: true }); + data.append(FP16x16 { mag: 30607, sign: false }); + data.append(FP16x16 { mag: 8106, sign: false }); + data.append(FP16x16 { mag: 42198, sign: false }); + data.append(FP16x16 { mag: 128853, sign: true }); + data.append(FP16x16 { mag: 17078, sign: true }); + data.append(FP16x16 { mag: 200766, sign: false }); + data.append(FP16x16 { mag: 8367, sign: false }); + data.append(FP16x16 { mag: 91621, sign: true }); + data.append(FP16x16 { mag: 34264, sign: false }); + data.append(FP16x16 { mag: 46301, sign: false }); + data.append(FP16x16 { mag: 106288, sign: false }); + data.append(FP16x16 { mag: 130170, sign: false }); + data.append(FP16x16 { mag: 5525, sign: false }); + data.append(FP16x16 { mag: 111843, sign: true }); + data.append(FP16x16 { mag: 1844, sign: true }); + data.append(FP16x16 { mag: 73523, sign: true }); + data.append(FP16x16 { mag: 33663, sign: false }); + data.append(FP16x16 { mag: 9757, sign: false }); + data.append(FP16x16 { mag: 16424, sign: true }); + data.append(FP16x16 { mag: 67890, sign: false }); + data.append(FP16x16 { mag: 103692, sign: true }); + data.append(FP16x16 { mag: 37978, sign: false }); + data.append(FP16x16 { mag: 52354, sign: true }); + data.append(FP16x16 { mag: 62808, sign: false }); + data.append(FP16x16 { mag: 73374, sign: true }); + data.append(FP16x16 { mag: 42169, sign: false }); + data.append(FP16x16 { mag: 42576, sign: false }); + data.append(FP16x16 { mag: 76141, sign: false }); + data.append(FP16x16 { mag: 12231, sign: false }); + data.append(FP16x16 { mag: 42136, sign: false }); + data.append(FP16x16 { mag: 47496, sign: true }); + data.append(FP16x16 { mag: 81264, sign: false }); + data.append(FP16x16 { mag: 15196, sign: true }); + data.append(FP16x16 { mag: 70791, sign: true }); + data.append(FP16x16 { mag: 7825, sign: true }); + data.append(FP16x16 { mag: 40333, sign: true }); + data.append(FP16x16 { mag: 30221, sign: false }); + data.append(FP16x16 { mag: 98040, sign: true }); + data.append(FP16x16 { mag: 39486, sign: true }); + data.append(FP16x16 { mag: 6407, sign: false }); + data.append(FP16x16 { mag: 24512, sign: true }); + data.append(FP16x16 { mag: 7948, sign: true }); + data.append(FP16x16 { mag: 11193, sign: true }); + data.append(FP16x16 { mag: 3711, sign: true }); + data.append(FP16x16 { mag: 35698, sign: true }); + data.append(FP16x16 { mag: 3213, sign: false }); + data.append(FP16x16 { mag: 38235, sign: false }); + data.append(FP16x16 { mag: 82210, sign: true }); + data.append(FP16x16 { mag: 6567, sign: true }); + data.append(FP16x16 { mag: 5350, sign: false }); + data.append(FP16x16 { mag: 5068, sign: true }); + data.append(FP16x16 { mag: 27441, sign: true }); + data.append(FP16x16 { mag: 17130, sign: false }); + data.append(FP16x16 { mag: 89900, sign: true }); + data.append(FP16x16 { mag: 77188, sign: false }); + data.append(FP16x16 { mag: 34983, sign: true }); + data.append(FP16x16 { mag: 5966, sign: false }); + data.append(FP16x16 { mag: 16995, sign: true }); + data.append(FP16x16 { mag: 34228, sign: true }); + data.append(FP16x16 { mag: 29003, sign: false }); + data.append(FP16x16 { mag: 9617, sign: false }); + data.append(FP16x16 { mag: 70272, sign: true }); + data.append(FP16x16 { mag: 82673, sign: true }); + data.append(FP16x16 { mag: 12155, sign: false }); + data.append(FP16x16 { mag: 25123, sign: true }); + data.append(FP16x16 { mag: 137966, sign: true }); + data.append(FP16x16 { mag: 69397, sign: true }); + data.append(FP16x16 { mag: 61271, sign: false }); + data.append(FP16x16 { mag: 12363, sign: false }); + data.append(FP16x16 { mag: 13484, sign: true }); + data.append(FP16x16 { mag: 85721, sign: false }); + data.append(FP16x16 { mag: 81931, sign: false }); + data.append(FP16x16 { mag: 15988, sign: false }); + data.append(FP16x16 { mag: 82208, sign: false }); + data.append(FP16x16 { mag: 94093, sign: true }); + data.append(FP16x16 { mag: 111240, sign: true }); + data.append(FP16x16 { mag: 95369, sign: false }); + data.append(FP16x16 { mag: 29800, sign: false }); + data.append(FP16x16 { mag: 18974, sign: false }); + data.append(FP16x16 { mag: 41527, sign: true }); + data.append(FP16x16 { mag: 12256, sign: false }); + data.append(FP16x16 { mag: 28697, sign: false }); + data.append(FP16x16 { mag: 92475, sign: true }); + data.append(FP16x16 { mag: 90746, sign: true }); + data.append(FP16x16 { mag: 4496, sign: true }); + data.append(FP16x16 { mag: 6464, sign: false }); + data.append(FP16x16 { mag: 45570, sign: true }); + data.append(FP16x16 { mag: 28209, sign: true }); + data.append(FP16x16 { mag: 33084, sign: false }); + data.append(FP16x16 { mag: 131340, sign: false }); + data.append(FP16x16 { mag: 15928, sign: true }); + data.append(FP16x16 { mag: 13786, sign: true }); + data.append(FP16x16 { mag: 89832, sign: false }); + data.append(FP16x16 { mag: 59414, sign: true }); + data.append(FP16x16 { mag: 27186, sign: true }); + data.append(FP16x16 { mag: 57062, sign: false }); + data.append(FP16x16 { mag: 8290, sign: false }); + data.append(FP16x16 { mag: 23747, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis-4/output_0.cairo b/tests/nodes/layer_normalization_4d_axis-4/output_0.cairo new file mode 100644 index 000000000..3dd5f2734 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis-4/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 81938, sign: false }); + data.append(FP16x16 { mag: 139115, sign: true }); + data.append(FP16x16 { mag: 99625, sign: false }); + data.append(FP16x16 { mag: 42553, sign: false }); + data.append(FP16x16 { mag: 63807, sign: true }); + data.append(FP16x16 { mag: 29251, sign: false }); + data.append(FP16x16 { mag: 74903, sign: true }); + data.append(FP16x16 { mag: 9937, sign: true }); + data.append(FP16x16 { mag: 65116, sign: true }); + data.append(FP16x16 { mag: 48943, sign: false }); + data.append(FP16x16 { mag: 37524, sign: false }); + data.append(FP16x16 { mag: 77082, sign: false }); + data.append(FP16x16 { mag: 57309, sign: false }); + data.append(FP16x16 { mag: 3013, sign: true }); + data.append(FP16x16 { mag: 22472, sign: false }); + data.append(FP16x16 { mag: 68481, sign: false }); + data.append(FP16x16 { mag: 47521, sign: true }); + data.append(FP16x16 { mag: 29144, sign: false }); + data.append(FP16x16 { mag: 11016, sign: false }); + data.append(FP16x16 { mag: 63886, sign: true }); + data.append(FP16x16 { mag: 88094, sign: false }); + data.append(FP16x16 { mag: 112894, sign: true }); + data.append(FP16x16 { mag: 7732, sign: false }); + data.append(FP16x16 { mag: 45680, sign: true }); + data.append(FP16x16 { mag: 52166, sign: false }); + data.append(FP16x16 { mag: 135141, sign: true }); + data.append(FP16x16 { mag: 18324, sign: true }); + data.append(FP16x16 { mag: 201528, sign: false }); + data.append(FP16x16 { mag: 125050, sign: false }); + data.append(FP16x16 { mag: 50933, sign: true }); + data.append(FP16x16 { mag: 64223, sign: true }); + data.append(FP16x16 { mag: 52706, sign: false }); + data.append(FP16x16 { mag: 106918, sign: false }); + data.append(FP16x16 { mag: 134854, sign: false }); + data.append(FP16x16 { mag: 6678, sign: false }); + data.append(FP16x16 { mag: 100030, sign: true }); + data.append(FP16x16 { mag: 8001, sign: false }); + data.append(FP16x16 { mag: 75425, sign: true }); + data.append(FP16x16 { mag: 3641, sign: false }); + data.append(FP16x16 { mag: 20349, sign: true }); + data.append(FP16x16 { mag: 33423, sign: true }); + data.append(FP16x16 { mag: 135434, sign: false }); + data.append(FP16x16 { mag: 268826, sign: true }); + data.append(FP16x16 { mag: 9384, sign: false }); + data.append(FP16x16 { mag: 42989, sign: true }); + data.append(FP16x16 { mag: 40596, sign: false }); + data.append(FP16x16 { mag: 27643, sign: true }); + data.append(FP16x16 { mag: 14881, sign: false }); + data.append(FP16x16 { mag: 33184, sign: false }); + data.append(FP16x16 { mag: 245527, sign: false }); + data.append(FP16x16 { mag: 12844, sign: false }); + data.append(FP16x16 { mag: 36484, sign: false }); + data.append(FP16x16 { mag: 43540, sign: true }); + data.append(FP16x16 { mag: 98782, sign: false }); + data.append(FP16x16 { mag: 91559, sign: true }); + data.append(FP16x16 { mag: 175239, sign: true }); + data.append(FP16x16 { mag: 36924, sign: true }); + data.append(FP16x16 { mag: 69880, sign: true }); + data.append(FP16x16 { mag: 218008, sign: false }); + data.append(FP16x16 { mag: 89765, sign: true }); + data.append(FP16x16 { mag: 37583, sign: false }); + data.append(FP16x16 { mag: 40273, sign: false }); + data.append(FP16x16 { mag: 47428, sign: true }); + data.append(FP16x16 { mag: 33322, sign: true }); + data.append(FP16x16 { mag: 53002, sign: true }); + data.append(FP16x16 { mag: 51571, sign: false }); + data.append(FP16x16 { mag: 43023, sign: true }); + data.append(FP16x16 { mag: 180611, sign: false }); + data.append(FP16x16 { mag: 27436, sign: false }); + data.append(FP16x16 { mag: 82088, sign: true }); + data.append(FP16x16 { mag: 49075, sign: true }); + data.append(FP16x16 { mag: 93116, sign: true }); + data.append(FP16x16 { mag: 24713, sign: true }); + data.append(FP16x16 { mag: 57986, sign: true }); + data.append(FP16x16 { mag: 11641, sign: false }); + data.append(FP16x16 { mag: 87906, sign: true }); + data.append(FP16x16 { mag: 73854, sign: false }); + data.append(FP16x16 { mag: 22022, sign: true }); + data.append(FP16x16 { mag: 31528, sign: true }); + data.append(FP16x16 { mag: 151422, sign: true }); + data.append(FP16x16 { mag: 270761, sign: true }); + data.append(FP16x16 { mag: 37881, sign: false }); + data.append(FP16x16 { mag: 81795, sign: true }); + data.append(FP16x16 { mag: 65199, sign: false }); + data.append(FP16x16 { mag: 89952, sign: true }); + data.append(FP16x16 { mag: 36152, sign: false }); + data.append(FP16x16 { mag: 20482, sign: true }); + data.append(FP16x16 { mag: 296268, sign: true }); + data.append(FP16x16 { mag: 79775, sign: true }); + data.append(FP16x16 { mag: 66513, sign: false }); + data.append(FP16x16 { mag: 107842, sign: false }); + data.append(FP16x16 { mag: 37504, sign: false }); + data.append(FP16x16 { mag: 170624, sign: false }); + data.append(FP16x16 { mag: 52118, sign: false }); + data.append(FP16x16 { mag: 43525, sign: false }); + data.append(FP16x16 { mag: 138068, sign: false }); + data.append(FP16x16 { mag: 92822, sign: true }); + data.append(FP16x16 { mag: 155459, sign: true }); + data.append(FP16x16 { mag: 101038, sign: false }); + data.append(FP16x16 { mag: 45944, sign: false }); + data.append(FP16x16 { mag: 14639, sign: true }); + data.append(FP16x16 { mag: 30883, sign: true }); + data.append(FP16x16 { mag: 262024, sign: false }); + data.append(FP16x16 { mag: 30507, sign: true }); + data.append(FP16x16 { mag: 43715, sign: true }); + data.append(FP16x16 { mag: 84662, sign: true }); + data.append(FP16x16 { mag: 63901, sign: false }); + data.append(FP16x16 { mag: 3913, sign: false }); + data.append(FP16x16 { mag: 40400, sign: true }); + data.append(FP16x16 { mag: 36018, sign: true }); + data.append(FP16x16 { mag: 49780, sign: true }); + data.append(FP16x16 { mag: 146468, sign: false }); + data.append(FP16x16 { mag: 32554, sign: true }); + data.append(FP16x16 { mag: 39291, sign: true }); + data.append(FP16x16 { mag: 93325, sign: false }); + data.append(FP16x16 { mag: 216, sign: false }); + data.append(FP16x16 { mag: 67978, sign: true }); + data.append(FP16x16 { mag: 97710, sign: false }); + data.append(FP16x16 { mag: 19520, sign: false }); + data.append(FP16x16 { mag: 6998, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis0.cairo b/tests/nodes/layer_normalization_4d_axis0.cairo new file mode 100644 index 000000000..2a3b930ec --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis0.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis0() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(0, false)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis0/input_0.cairo b/tests/nodes/layer_normalization_4d_axis0/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis0/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis0/input_1.cairo b/tests/nodes/layer_normalization_4d_axis0/input_1.cairo new file mode 100644 index 000000000..d5d19269f --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis0/input_1.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 7856434, sign: false }); + data.append(FP8x23 { mag: 6559353, sign: false }); + data.append(FP8x23 { mag: 2374243, sign: true }); + data.append(FP8x23 { mag: 9038396, sign: true }); + data.append(FP8x23 { mag: 4891209, sign: false }); + data.append(FP8x23 { mag: 10531353, sign: false }); + data.append(FP8x23 { mag: 2374580, sign: false }); + data.append(FP8x23 { mag: 15815884, sign: false }); + data.append(FP8x23 { mag: 8516677, sign: true }); + data.append(FP8x23 { mag: 812409, sign: false }); + data.append(FP8x23 { mag: 17686618, sign: false }); + data.append(FP8x23 { mag: 7197725, sign: false }); + data.append(FP8x23 { mag: 10954258, sign: false }); + data.append(FP8x23 { mag: 8350298, sign: false }); + data.append(FP8x23 { mag: 4666106, sign: true }); + data.append(FP8x23 { mag: 7256064, sign: false }); + data.append(FP8x23 { mag: 8635308, sign: false }); + data.append(FP8x23 { mag: 5024989, sign: false }); + data.append(FP8x23 { mag: 1263377, sign: true }); + data.append(FP8x23 { mag: 10669786, sign: false }); + data.append(FP8x23 { mag: 12365659, sign: true }); + data.append(FP8x23 { mag: 9240214, sign: true }); + data.append(FP8x23 { mag: 6033435, sign: true }); + data.append(FP8x23 { mag: 2813974, sign: false }); + data.append(FP8x23 { mag: 2864389, sign: true }); + data.append(FP8x23 { mag: 7531822, sign: true }); + data.append(FP8x23 { mag: 14534796, sign: false }); + data.append(FP8x23 { mag: 12879313, sign: false }); + data.append(FP8x23 { mag: 23143788, sign: true }); + data.append(FP8x23 { mag: 956451, sign: true }); + data.append(FP8x23 { mag: 140224, sign: false }); + data.append(FP8x23 { mag: 3524534, sign: false }); + data.append(FP8x23 { mag: 9520969, sign: true }); + data.append(FP8x23 { mag: 13151858, sign: true }); + data.append(FP8x23 { mag: 7994295, sign: false }); + data.append(FP8x23 { mag: 5842236, sign: false }); + data.append(FP8x23 { mag: 6780944, sign: false }); + data.append(FP8x23 { mag: 2066721, sign: false }); + data.append(FP8x23 { mag: 2105799, sign: true }); + data.append(FP8x23 { mag: 1298593, sign: true }); + data.append(FP8x23 { mag: 4229235, sign: true }); + data.append(FP8x23 { mag: 7262550, sign: true }); + data.append(FP8x23 { mag: 6216754, sign: false }); + data.append(FP8x23 { mag: 10747679, sign: true }); + data.append(FP8x23 { mag: 6150248, sign: true }); + data.append(FP8x23 { mag: 11662506, sign: true }); + data.append(FP8x23 { mag: 2114116, sign: false }); + data.append(FP8x23 { mag: 3345752, sign: false }); + data.append(FP8x23 { mag: 10971260, sign: false }); + data.append(FP8x23 { mag: 1397404, sign: true }); + data.append(FP8x23 { mag: 7777503, sign: false }); + data.append(FP8x23 { mag: 614354, sign: true }); + data.append(FP8x23 { mag: 2547461, sign: true }); + data.append(FP8x23 { mag: 9306342, sign: false }); + data.append(FP8x23 { mag: 8136902, sign: true }); + data.append(FP8x23 { mag: 2523700, sign: false }); + data.append(FP8x23 { mag: 5851470, sign: true }); + data.append(FP8x23 { mag: 2397709, sign: true }); + data.append(FP8x23 { mag: 2791362, sign: true }); + data.append(FP8x23 { mag: 11438024, sign: true }); + data.append(FP8x23 { mag: 2376167, sign: false }); + data.append(FP8x23 { mag: 11929178, sign: true }); + data.append(FP8x23 { mag: 2334309, sign: false }); + data.append(FP8x23 { mag: 11396586, sign: false }); + data.append(FP8x23 { mag: 9273483, sign: true }); + data.append(FP8x23 { mag: 5663225, sign: true }); + data.append(FP8x23 { mag: 7022748, sign: false }); + data.append(FP8x23 { mag: 3415334, sign: true }); + data.append(FP8x23 { mag: 5583578, sign: false }); + data.append(FP8x23 { mag: 3376007, sign: false }); + data.append(FP8x23 { mag: 9800954, sign: true }); + data.append(FP8x23 { mag: 557269, sign: true }); + data.append(FP8x23 { mag: 1332713, sign: false }); + data.append(FP8x23 { mag: 16394277, sign: false }); + data.append(FP8x23 { mag: 2404128, sign: false }); + data.append(FP8x23 { mag: 7264511, sign: true }); + data.append(FP8x23 { mag: 2959914, sign: false }); + data.append(FP8x23 { mag: 8465123, sign: false }); + data.append(FP8x23 { mag: 2179246, sign: false }); + data.append(FP8x23 { mag: 1403755, sign: true }); + data.append(FP8x23 { mag: 8841557, sign: true }); + data.append(FP8x23 { mag: 714874, sign: true }); + data.append(FP8x23 { mag: 3087264, sign: true }); + data.append(FP8x23 { mag: 3259074, sign: true }); + data.append(FP8x23 { mag: 10403024, sign: true }); + data.append(FP8x23 { mag: 9207515, sign: false }); + data.append(FP8x23 { mag: 10106154, sign: true }); + data.append(FP8x23 { mag: 22441932, sign: false }); + data.append(FP8x23 { mag: 2506462, sign: false }); + data.append(FP8x23 { mag: 7772606, sign: true }); + data.append(FP8x23 { mag: 10066387, sign: false }); + data.append(FP8x23 { mag: 13018190, sign: false }); + data.append(FP8x23 { mag: 551564, sign: true }); + data.append(FP8x23 { mag: 15173981, sign: false }); + data.append(FP8x23 { mag: 537916, sign: true }); + data.append(FP8x23 { mag: 1144261, sign: true }); + data.append(FP8x23 { mag: 10201860, sign: false }); + data.append(FP8x23 { mag: 14629321, sign: false }); + data.append(FP8x23 { mag: 10980985, sign: true }); + data.append(FP8x23 { mag: 8271637, sign: false }); + data.append(FP8x23 { mag: 12637380, sign: false }); + data.append(FP8x23 { mag: 7736607, sign: true }); + data.append(FP8x23 { mag: 1168929, sign: false }); + data.append(FP8x23 { mag: 7719282, sign: false }); + data.append(FP8x23 { mag: 8013562, sign: true }); + data.append(FP8x23 { mag: 1909398, sign: false }); + data.append(FP8x23 { mag: 7909437, sign: false }); + data.append(FP8x23 { mag: 9358508, sign: true }); + data.append(FP8x23 { mag: 2679626, sign: true }); + data.append(FP8x23 { mag: 10816482, sign: false }); + data.append(FP8x23 { mag: 5938359, sign: true }); + data.append(FP8x23 { mag: 3868879, sign: true }); + data.append(FP8x23 { mag: 17720398, sign: false }); + data.append(FP8x23 { mag: 8780306, sign: false }); + data.append(FP8x23 { mag: 8182772, sign: true }); + data.append(FP8x23 { mag: 8474158, sign: true }); + data.append(FP8x23 { mag: 10484711, sign: false }); + data.append(FP8x23 { mag: 6278095, sign: false }); + data.append(FP8x23 { mag: 13616979, sign: false }); + data.append(FP8x23 { mag: 683891, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis0/input_2.cairo b/tests/nodes/layer_normalization_4d_axis0/input_2.cairo new file mode 100644 index 000000000..576e62784 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis0/input_2.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 2658357, sign: true }); + data.append(FP8x23 { mag: 1102321, sign: true }); + data.append(FP8x23 { mag: 9403754, sign: false }); + data.append(FP8x23 { mag: 6280169, sign: true }); + data.append(FP8x23 { mag: 322990, sign: true }); + data.append(FP8x23 { mag: 1855990, sign: false }); + data.append(FP8x23 { mag: 7959357, sign: false }); + data.append(FP8x23 { mag: 3813579, sign: true }); + data.append(FP8x23 { mag: 17190620, sign: false }); + data.append(FP8x23 { mag: 7625388, sign: false }); + data.append(FP8x23 { mag: 6326665, sign: true }); + data.append(FP8x23 { mag: 2144242, sign: false }); + data.append(FP8x23 { mag: 1652040, sign: true }); + data.append(FP8x23 { mag: 1708413, sign: true }); + data.append(FP8x23 { mag: 5577434, sign: false }); + data.append(FP8x23 { mag: 12958978, sign: false }); + data.append(FP8x23 { mag: 14472776, sign: true }); + data.append(FP8x23 { mag: 12142315, sign: true }); + data.append(FP8x23 { mag: 5004137, sign: true }); + data.append(FP8x23 { mag: 10687620, sign: false }); + data.append(FP8x23 { mag: 11911291, sign: true }); + data.append(FP8x23 { mag: 10598514, sign: true }); + data.append(FP8x23 { mag: 2129772, sign: true }); + data.append(FP8x23 { mag: 6883996, sign: false }); + data.append(FP8x23 { mag: 4546950, sign: true }); + data.append(FP8x23 { mag: 3450860, sign: true }); + data.append(FP8x23 { mag: 1775511, sign: true }); + data.append(FP8x23 { mag: 512918, sign: false }); + data.append(FP8x23 { mag: 3693486, sign: false }); + data.append(FP8x23 { mag: 4656135, sign: true }); + data.append(FP8x23 { mag: 5018284, sign: false }); + data.append(FP8x23 { mag: 17579972, sign: false }); + data.append(FP8x23 { mag: 5429984, sign: true }); + data.append(FP8x23 { mag: 3455574, sign: false }); + data.append(FP8x23 { mag: 2776885, sign: false }); + data.append(FP8x23 { mag: 12394373, sign: false }); + data.append(FP8x23 { mag: 12252363, sign: true }); + data.append(FP8x23 { mag: 9655402, sign: false }); + data.append(FP8x23 { mag: 11694762, sign: false }); + data.append(FP8x23 { mag: 2352758, sign: false }); + data.append(FP8x23 { mag: 14214209, sign: false }); + data.append(FP8x23 { mag: 4904045, sign: true }); + data.append(FP8x23 { mag: 750954, sign: false }); + data.append(FP8x23 { mag: 5521106, sign: true }); + data.append(FP8x23 { mag: 8124032, sign: true }); + data.append(FP8x23 { mag: 3155244, sign: true }); + data.append(FP8x23 { mag: 2794749, sign: true }); + data.append(FP8x23 { mag: 11913728, sign: true }); + data.append(FP8x23 { mag: 3406536, sign: false }); + data.append(FP8x23 { mag: 17572360, sign: false }); + data.append(FP8x23 { mag: 3846474, sign: false }); + data.append(FP8x23 { mag: 4679753, sign: true }); + data.append(FP8x23 { mag: 3526965, sign: false }); + data.append(FP8x23 { mag: 2361175, sign: true }); + data.append(FP8x23 { mag: 6800555, sign: false }); + data.append(FP8x23 { mag: 8765096, sign: false }); + data.append(FP8x23 { mag: 2069919, sign: true }); + data.append(FP8x23 { mag: 10541365, sign: false }); + data.append(FP8x23 { mag: 5702066, sign: false }); + data.append(FP8x23 { mag: 1901244, sign: true }); + data.append(FP8x23 { mag: 4795815, sign: true }); + data.append(FP8x23 { mag: 5440175, sign: true }); + data.append(FP8x23 { mag: 6145038, sign: true }); + data.append(FP8x23 { mag: 8155758, sign: true }); + data.append(FP8x23 { mag: 12882916, sign: false }); + data.append(FP8x23 { mag: 8512317, sign: true }); + data.append(FP8x23 { mag: 8827800, sign: false }); + data.append(FP8x23 { mag: 4780680, sign: true }); + data.append(FP8x23 { mag: 10233073, sign: false }); + data.append(FP8x23 { mag: 3987405, sign: false }); + data.append(FP8x23 { mag: 9797702, sign: true }); + data.append(FP8x23 { mag: 11709163, sign: true }); + data.append(FP8x23 { mag: 5010523, sign: false }); + data.append(FP8x23 { mag: 5898579, sign: false }); + data.append(FP8x23 { mag: 3610233, sign: false }); + data.append(FP8x23 { mag: 4366597, sign: true }); + data.append(FP8x23 { mag: 4121505, sign: false }); + data.append(FP8x23 { mag: 1386837, sign: true }); + data.append(FP8x23 { mag: 2197332, sign: false }); + data.append(FP8x23 { mag: 1583611, sign: false }); + data.append(FP8x23 { mag: 1204922, sign: false }); + data.append(FP8x23 { mag: 1731995, sign: true }); + data.append(FP8x23 { mag: 1196413, sign: true }); + data.append(FP8x23 { mag: 11399971, sign: true }); + data.append(FP8x23 { mag: 1303591, sign: true }); + data.append(FP8x23 { mag: 1306507, sign: true }); + data.append(FP8x23 { mag: 8296480, sign: false }); + data.append(FP8x23 { mag: 6644877, sign: true }); + data.append(FP8x23 { mag: 7522441, sign: false }); + data.append(FP8x23 { mag: 13195509, sign: false }); + data.append(FP8x23 { mag: 5912413, sign: false }); + data.append(FP8x23 { mag: 6827008, sign: true }); + data.append(FP8x23 { mag: 6046144, sign: true }); + data.append(FP8x23 { mag: 8475255, sign: true }); + data.append(FP8x23 { mag: 15398520, sign: false }); + data.append(FP8x23 { mag: 2626721, sign: false }); + data.append(FP8x23 { mag: 4142304, sign: true }); + data.append(FP8x23 { mag: 6868178, sign: false }); + data.append(FP8x23 { mag: 1537408, sign: true }); + data.append(FP8x23 { mag: 4760507, sign: true }); + data.append(FP8x23 { mag: 16526955, sign: false }); + data.append(FP8x23 { mag: 4766056, sign: false }); + data.append(FP8x23 { mag: 4099142, sign: true }); + data.append(FP8x23 { mag: 2694894, sign: false }); + data.append(FP8x23 { mag: 6008844, sign: true }); + data.append(FP8x23 { mag: 15054491, sign: false }); + data.append(FP8x23 { mag: 2926959, sign: true }); + data.append(FP8x23 { mag: 4325199, sign: false }); + data.append(FP8x23 { mag: 1262017, sign: false }); + data.append(FP8x23 { mag: 8352226, sign: false }); + data.append(FP8x23 { mag: 5432226, sign: true }); + data.append(FP8x23 { mag: 7565847, sign: true }); + data.append(FP8x23 { mag: 1398099, sign: true }); + data.append(FP8x23 { mag: 10744126, sign: false }); + data.append(FP8x23 { mag: 10154080, sign: false }); + data.append(FP8x23 { mag: 543398, sign: true }); + data.append(FP8x23 { mag: 4199815, sign: false }); + data.append(FP8x23 { mag: 3981547, sign: false }); + data.append(FP8x23 { mag: 3843559, sign: true }); + data.append(FP8x23 { mag: 2002995, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis0/output_0.cairo b/tests/nodes/layer_normalization_4d_axis0/output_0.cairo new file mode 100644 index 000000000..ecfa40d38 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis0/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 11485245, sign: true }); + data.append(FP8x23 { mag: 8639736, sign: false }); + data.append(FP8x23 { mag: 9066761, sign: false }); + data.append(FP8x23 { mag: 1246285, sign: false }); + data.append(FP8x23 { mag: 8785751, sign: false }); + data.append(FP8x23 { mag: 22458580, sign: false }); + data.append(FP8x23 { mag: 7798869, sign: false }); + data.append(FP8x23 { mag: 9335609, sign: true }); + data.append(FP8x23 { mag: 8762101, sign: false }); + data.append(FP8x23 { mag: 7044940, sign: false }); + data.append(FP8x23 { mag: 17548656, sign: true }); + data.append(FP8x23 { mag: 1433318, sign: true }); + data.append(FP8x23 { mag: 29332314, sign: false }); + data.append(FP8x23 { mag: 9109826, sign: true }); + data.append(FP8x23 { mag: 2201931, sign: false }); + data.append(FP8x23 { mag: 8615226, sign: false }); + data.append(FP8x23 { mag: 10650534, sign: true }); + data.append(FP8x23 { mag: 13610117, sign: true }); + data.append(FP8x23 { mag: 5870152, sign: true }); + data.append(FP8x23 { mag: 282205, sign: true }); + data.append(FP8x23 { mag: 8192649, sign: true }); + data.append(FP8x23 { mag: 19142176, sign: true }); + data.append(FP8x23 { mag: 2631691, sign: true }); + data.append(FP8x23 { mag: 12334894, sign: false }); + data.append(FP8x23 { mag: 7174919, sign: true }); + data.append(FP8x23 { mag: 7215017, sign: false }); + data.append(FP8x23 { mag: 3412223, sign: false }); + data.append(FP8x23 { mag: 1059551, sign: false }); + data.append(FP8x23 { mag: 11556768, sign: false }); + data.append(FP8x23 { mag: 4721233, sign: true }); + data.append(FP8x23 { mag: 4836688, sign: false }); + data.append(FP8x23 { mag: 18053380, sign: false }); + data.append(FP8x23 { mag: 8322237, sign: true }); + data.append(FP8x23 { mag: 4030251, sign: true }); + data.append(FP8x23 { mag: 11736265, sign: false }); + data.append(FP8x23 { mag: 4467248, sign: false }); + data.append(FP8x23 { mag: 21199784, sign: true }); + data.append(FP8x23 { mag: 8911549, sign: false }); + data.append(FP8x23 { mag: 13309004, sign: false }); + data.append(FP8x23 { mag: 1906941, sign: false }); + data.append(FP8x23 { mag: 16291191, sign: false }); + data.append(FP8x23 { mag: 9102967, sign: true }); + data.append(FP8x23 { mag: 6868689, sign: false }); + data.append(FP8x23 { mag: 1406769, sign: false }); + data.append(FP8x23 { mag: 15510274, sign: true }); + data.append(FP8x23 { mag: 12477229, sign: false }); + data.append(FP8x23 { mag: 1801263, sign: true }); + data.append(FP8x23 { mag: 12190927, sign: true }); + data.append(FP8x23 { mag: 8150830, sign: true }); + data.append(FP8x23 { mag: 16296984, sign: false }); + data.append(FP8x23 { mag: 151497, sign: true }); + data.append(FP8x23 { mag: 4624143, sign: true }); + data.append(FP8x23 { mag: 4093028, sign: false }); + data.append(FP8x23 { mag: 3575281, sign: true }); + data.append(FP8x23 { mag: 3333834, sign: true }); + data.append(FP8x23 { mag: 10688711, sign: false }); + data.append(FP8x23 { mag: 3136147, sign: true }); + data.append(FP8x23 { mag: 13927308, sign: false }); + data.append(FP8x23 { mag: 5320165, sign: false }); + data.append(FP8x23 { mag: 4522953, sign: true }); + data.append(FP8x23 { mag: 6562536, sign: true }); + data.append(FP8x23 { mag: 302568, sign: false }); + data.append(FP8x23 { mag: 365858, sign: true }); + data.append(FP8x23 { mag: 8333322, sign: true }); + data.append(FP8x23 { mag: 24012324, sign: false }); + data.append(FP8x23 { mag: 3941235, sign: true }); + data.append(FP8x23 { mag: 12585527, sign: false }); + data.append(FP8x23 { mag: 2108236, sign: false }); + data.append(FP8x23 { mag: 2180042, sign: false }); + data.append(FP8x23 { mag: 809995, sign: false }); + data.append(FP8x23 { mag: 10706503, sign: true }); + data.append(FP8x23 { mag: 11768835, sign: true }); + data.append(FP8x23 { mag: 4368724, sign: false }); + data.append(FP8x23 { mag: 37121360, sign: false }); + data.append(FP8x23 { mag: 1690481, sign: false }); + data.append(FP8x23 { mag: 12472568, sign: true }); + data.append(FP8x23 { mag: 5896973, sign: false }); + data.append(FP8x23 { mag: 3128197, sign: true }); + data.append(FP8x23 { mag: 2706830, sign: false }); + data.append(FP8x23 { mag: 3121650, sign: false }); + data.append(FP8x23 { mag: 11032154, sign: true }); + data.append(FP8x23 { mag: 1400313, sign: true }); + data.append(FP8x23 { mag: 9391938, sign: true }); + data.append(FP8x23 { mag: 14493685, sign: true }); + data.append(FP8x23 { mag: 7510374, sign: false }); + data.append(FP8x23 { mag: 8488401, sign: false }); + data.append(FP8x23 { mag: 11134086, sign: false }); + data.append(FP8x23 { mag: 20440816, sign: true }); + data.append(FP8x23 { mag: 4187458, sign: false }); + data.append(FP8x23 { mag: 15662241, sign: false }); + data.append(FP8x23 { mag: 1227645, sign: true }); + data.append(FP8x23 { mag: 20756332, sign: true }); + data.append(FP8x23 { mag: 6096698, sign: true }); + data.append(FP8x23 { mag: 32393544, sign: true }); + data.append(FP8x23 { mag: 16110731, sign: false }); + data.append(FP8x23 { mag: 2419478, sign: false }); + data.append(FP8x23 { mag: 5840805, sign: true }); + data.append(FP8x23 { mag: 30053754, sign: false }); + data.append(FP8x23 { mag: 2156172, sign: false }); + data.append(FP8x23 { mag: 4688839, sign: true }); + data.append(FP8x23 { mag: 2943892, sign: false }); + data.append(FP8x23 { mag: 2536186, sign: false }); + data.append(FP8x23 { mag: 2204937, sign: true }); + data.append(FP8x23 { mag: 3688586, sign: false }); + data.append(FP8x23 { mag: 4819133, sign: true }); + data.append(FP8x23 { mag: 12194435, sign: false }); + data.append(FP8x23 { mag: 2196191, sign: true }); + data.append(FP8x23 { mag: 11984318, sign: false }); + data.append(FP8x23 { mag: 3501392, sign: true }); + data.append(FP8x23 { mag: 19567274, sign: false }); + data.append(FP8x23 { mag: 7228652, sign: true }); + data.append(FP8x23 { mag: 6148920, sign: true }); + data.append(FP8x23 { mag: 9655568, sign: true }); + data.append(FP8x23 { mag: 8030332, sign: false }); + data.append(FP8x23 { mag: 8497492, sign: false }); + data.append(FP8x23 { mag: 18119870, sign: true }); + data.append(FP8x23 { mag: 11121997, sign: true }); + data.append(FP8x23 { mag: 6433851, sign: false }); + data.append(FP8x23 { mag: 24247440, sign: true }); + data.append(FP8x23 { mag: 1697336, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis1.cairo b/tests/nodes/layer_normalization_4d_axis1.cairo new file mode 100644 index 000000000..d3d8808ce --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis1.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(1, false)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis1/input_0.cairo b/tests/nodes/layer_normalization_4d_axis1/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis1/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis1/input_1.cairo b/tests/nodes/layer_normalization_4d_axis1/input_1.cairo new file mode 100644 index 000000000..9097a814d --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis1/input_1.cairo @@ -0,0 +1,74 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 848181, sign: true }); + data.append(FP8x23 { mag: 81518, sign: true }); + data.append(FP8x23 { mag: 945519, sign: true }); + data.append(FP8x23 { mag: 6487248, sign: false }); + data.append(FP8x23 { mag: 9672423, sign: false }); + data.append(FP8x23 { mag: 1904835, sign: false }); + data.append(FP8x23 { mag: 2382261, sign: false }); + data.append(FP8x23 { mag: 7115567, sign: false }); + data.append(FP8x23 { mag: 2500882, sign: true }); + data.append(FP8x23 { mag: 7175046, sign: false }); + data.append(FP8x23 { mag: 9187167, sign: true }); + data.append(FP8x23 { mag: 8462170, sign: true }); + data.append(FP8x23 { mag: 696867, sign: false }); + data.append(FP8x23 { mag: 21124896, sign: true }); + data.append(FP8x23 { mag: 5267251, sign: true }); + data.append(FP8x23 { mag: 2079757, sign: true }); + data.append(FP8x23 { mag: 1167805, sign: false }); + data.append(FP8x23 { mag: 1152176, sign: false }); + data.append(FP8x23 { mag: 10891270, sign: false }); + data.append(FP8x23 { mag: 14322228, sign: false }); + data.append(FP8x23 { mag: 8522531, sign: false }); + data.append(FP8x23 { mag: 10739640, sign: false }); + data.append(FP8x23 { mag: 2595621, sign: false }); + data.append(FP8x23 { mag: 8005426, sign: true }); + data.append(FP8x23 { mag: 17252540, sign: false }); + data.append(FP8x23 { mag: 240957, sign: false }); + data.append(FP8x23 { mag: 11486831, sign: true }); + data.append(FP8x23 { mag: 586966, sign: false }); + data.append(FP8x23 { mag: 9136961, sign: true }); + data.append(FP8x23 { mag: 6509073, sign: false }); + data.append(FP8x23 { mag: 5880972, sign: true }); + data.append(FP8x23 { mag: 13107216, sign: true }); + data.append(FP8x23 { mag: 3496295, sign: true }); + data.append(FP8x23 { mag: 8424994, sign: true }); + data.append(FP8x23 { mag: 6071613, sign: false }); + data.append(FP8x23 { mag: 2613672, sign: false }); + data.append(FP8x23 { mag: 201175, sign: true }); + data.append(FP8x23 { mag: 9019674, sign: true }); + data.append(FP8x23 { mag: 8834173, sign: true }); + data.append(FP8x23 { mag: 3624311, sign: false }); + data.append(FP8x23 { mag: 11146336, sign: true }); + data.append(FP8x23 { mag: 2676569, sign: false }); + data.append(FP8x23 { mag: 2174599, sign: true }); + data.append(FP8x23 { mag: 7756654, sign: false }); + data.append(FP8x23 { mag: 4488299, sign: false }); + data.append(FP8x23 { mag: 13736548, sign: false }); + data.append(FP8x23 { mag: 8146999, sign: true }); + data.append(FP8x23 { mag: 14514717, sign: false }); + data.append(FP8x23 { mag: 194204, sign: true }); + data.append(FP8x23 { mag: 1722858, sign: true }); + data.append(FP8x23 { mag: 16956510, sign: true }); + data.append(FP8x23 { mag: 938525, sign: true }); + data.append(FP8x23 { mag: 12229289, sign: true }); + data.append(FP8x23 { mag: 10339694, sign: false }); + data.append(FP8x23 { mag: 5410863, sign: true }); + data.append(FP8x23 { mag: 2290298, sign: false }); + data.append(FP8x23 { mag: 523371, sign: true }); + data.append(FP8x23 { mag: 3917128, sign: true }); + data.append(FP8x23 { mag: 8846368, sign: false }); + data.append(FP8x23 { mag: 878873, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis1/input_2.cairo b/tests/nodes/layer_normalization_4d_axis1/input_2.cairo new file mode 100644 index 000000000..7f46fa541 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis1/input_2.cairo @@ -0,0 +1,74 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 17573086, sign: false }); + data.append(FP8x23 { mag: 8662834, sign: false }); + data.append(FP8x23 { mag: 184643, sign: true }); + data.append(FP8x23 { mag: 1609098, sign: false }); + data.append(FP8x23 { mag: 7906768, sign: false }); + data.append(FP8x23 { mag: 2763381, sign: false }); + data.append(FP8x23 { mag: 6890297, sign: false }); + data.append(FP8x23 { mag: 10693457, sign: true }); + data.append(FP8x23 { mag: 12330995, sign: false }); + data.append(FP8x23 { mag: 4256058, sign: false }); + data.append(FP8x23 { mag: 9437667, sign: true }); + data.append(FP8x23 { mag: 163731, sign: false }); + data.append(FP8x23 { mag: 7124038, sign: false }); + data.append(FP8x23 { mag: 9805754, sign: false }); + data.append(FP8x23 { mag: 28536, sign: false }); + data.append(FP8x23 { mag: 2693033, sign: true }); + data.append(FP8x23 { mag: 2150848, sign: false }); + data.append(FP8x23 { mag: 10084242, sign: true }); + data.append(FP8x23 { mag: 8917886, sign: true }); + data.append(FP8x23 { mag: 1425777, sign: false }); + data.append(FP8x23 { mag: 2606228, sign: true }); + data.append(FP8x23 { mag: 12616894, sign: false }); + data.append(FP8x23 { mag: 2222915, sign: true }); + data.append(FP8x23 { mag: 7726104, sign: true }); + data.append(FP8x23 { mag: 2657657, sign: true }); + data.append(FP8x23 { mag: 6077975, sign: false }); + data.append(FP8x23 { mag: 13813849, sign: true }); + data.append(FP8x23 { mag: 6292283, sign: false }); + data.append(FP8x23 { mag: 10348040, sign: true }); + data.append(FP8x23 { mag: 185170, sign: false }); + data.append(FP8x23 { mag: 5951156, sign: true }); + data.append(FP8x23 { mag: 17316840, sign: false }); + data.append(FP8x23 { mag: 10838919, sign: false }); + data.append(FP8x23 { mag: 1531418, sign: true }); + data.append(FP8x23 { mag: 2328055, sign: true }); + data.append(FP8x23 { mag: 2260306, sign: true }); + data.append(FP8x23 { mag: 6598532, sign: true }); + data.append(FP8x23 { mag: 13454879, sign: false }); + data.append(FP8x23 { mag: 3992621, sign: true }); + data.append(FP8x23 { mag: 2841648, sign: false }); + data.append(FP8x23 { mag: 2219350, sign: false }); + data.append(FP8x23 { mag: 1366422, sign: false }); + data.append(FP8x23 { mag: 18715506, sign: false }); + data.append(FP8x23 { mag: 818771, sign: true }); + data.append(FP8x23 { mag: 16038561, sign: true }); + data.append(FP8x23 { mag: 605075, sign: true }); + data.append(FP8x23 { mag: 9018968, sign: true }); + data.append(FP8x23 { mag: 12957025, sign: true }); + data.append(FP8x23 { mag: 7949557, sign: true }); + data.append(FP8x23 { mag: 4287990, sign: true }); + data.append(FP8x23 { mag: 6013668, sign: true }); + data.append(FP8x23 { mag: 5667271, sign: false }); + data.append(FP8x23 { mag: 3089234, sign: true }); + data.append(FP8x23 { mag: 8146239, sign: true }); + data.append(FP8x23 { mag: 3678288, sign: true }); + data.append(FP8x23 { mag: 4910526, sign: true }); + data.append(FP8x23 { mag: 7866112, sign: false }); + data.append(FP8x23 { mag: 7090234, sign: true }); + data.append(FP8x23 { mag: 7707693, sign: true }); + data.append(FP8x23 { mag: 3506913, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis1/output_0.cairo b/tests/nodes/layer_normalization_4d_axis1/output_0.cairo new file mode 100644 index 000000000..25fb0235f --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis1/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 18660112, sign: false }); + data.append(FP8x23 { mag: 8539773, sign: false }); + data.append(FP8x23 { mag: 253071, sign: true }); + data.append(FP8x23 { mag: 4686473, sign: true }); + data.append(FP8x23 { mag: 26410314, sign: false }); + data.append(FP8x23 { mag: 6599035, sign: false }); + data.append(FP8x23 { mag: 6528656, sign: false }); + data.append(FP8x23 { mag: 13917213, sign: true }); + data.append(FP8x23 { mag: 9881702, sign: false }); + data.append(FP8x23 { mag: 1799257, sign: true }); + data.append(FP8x23 { mag: 2470490, sign: true }); + data.append(FP8x23 { mag: 5336625, sign: false }); + data.append(FP8x23 { mag: 9177605, sign: false }); + data.append(FP8x23 { mag: 31518982, sign: false }); + data.append(FP8x23 { mag: 3629625, sign: true }); + data.append(FP8x23 { mag: 1195610, sign: true }); + data.append(FP8x23 { mag: 2611073, sign: false }); + data.append(FP8x23 { mag: 10535922, sign: true }); + data.append(FP8x23 { mag: 1795805, sign: true }); + data.append(FP8x23 { mag: 15467576, sign: true }); + data.append(FP8x23 { mag: 6025885, sign: true }); + data.append(FP8x23 { mag: 22387728, sign: false }); + data.append(FP8x23 { mag: 2198218, sign: true }); + data.append(FP8x23 { mag: 23681456, sign: true }); + data.append(FP8x23 { mag: 12906475, sign: false }); + data.append(FP8x23 { mag: 5693734, sign: false }); + data.append(FP8x23 { mag: 17287328, sign: true }); + data.append(FP8x23 { mag: 6272278, sign: false }); + data.append(FP8x23 { mag: 6300243, sign: true }); + data.append(FP8x23 { mag: 141751, sign: false }); + data.append(FP8x23 { mag: 2665067, sign: false }); + data.append(FP8x23 { mag: 16475111, sign: false }); + data.append(FP8x23 { mag: 9980473, sign: false }); + data.append(FP8x23 { mag: 5992447, sign: true }); + data.append(FP8x23 { mag: 4469792, sign: false }); + data.append(FP8x23 { mag: 6262509, sign: true }); + data.append(FP8x23 { mag: 6298524, sign: true }); + data.append(FP8x23 { mag: 17645284, sign: false }); + data.append(FP8x23 { mag: 3955277, sign: false }); + data.append(FP8x23 { mag: 3884818, sign: false }); + data.append(FP8x23 { mag: 8962234, sign: false }); + data.append(FP8x23 { mag: 2809372, sign: false }); + data.append(FP8x23 { mag: 16598741, sign: false }); + data.append(FP8x23 { mag: 6784943, sign: true }); + data.append(FP8x23 { mag: 10628059, sign: true }); + data.append(FP8x23 { mag: 21397334, sign: true }); + data.append(FP8x23 { mag: 12467617, sign: true }); + data.append(FP8x23 { mag: 15397510, sign: true }); + data.append(FP8x23 { mag: 7715232, sign: true }); + data.append(FP8x23 { mag: 5833417, sign: true }); + data.append(FP8x23 { mag: 4660253, sign: false }); + data.append(FP8x23 { mag: 5832772, sign: false }); + data.append(FP8x23 { mag: 790398, sign: false }); + data.append(FP8x23 { mag: 10411434, sign: true }); + data.append(FP8x23 { mag: 10458644, sign: true }); + data.append(FP8x23 { mag: 3224784, sign: true }); + data.append(FP8x23 { mag: 7805680, sign: false }); + data.append(FP8x23 { mag: 860368, sign: true }); + data.append(FP8x23 { mag: 7115954, sign: true }); + data.append(FP8x23 { mag: 3652584, sign: false }); + data.append(FP8x23 { mag: 18110536, sign: false }); + data.append(FP8x23 { mag: 8694257, sign: false }); + data.append(FP8x23 { mag: 2467611, sign: true }); + data.append(FP8x23 { mag: 1969761, sign: false }); + data.append(FP8x23 { mag: 2403990, sign: true }); + data.append(FP8x23 { mag: 1441612, sign: false }); + data.append(FP8x23 { mag: 8264830, sign: false }); + data.append(FP8x23 { mag: 23782516, sign: true }); + data.append(FP8x23 { mag: 15570296, sign: false }); + data.append(FP8x23 { mag: 1633225, sign: true }); + data.append(FP8x23 { mag: 10890572, sign: true }); + data.append(FP8x23 { mag: 1289518, sign: true }); + data.append(FP8x23 { mag: 6855306, sign: false }); + data.append(FP8x23 { mag: 29774126, sign: true }); + data.append(FP8x23 { mag: 3640459, sign: false }); + data.append(FP8x23 { mag: 5036650, sign: true }); + data.append(FP8x23 { mag: 2896261, sign: false }); + data.append(FP8x23 { mag: 10227605, sign: true }); + data.append(FP8x23 { mag: 5740720, sign: true }); + data.append(FP8x23 { mag: 12424902, sign: true }); + data.append(FP8x23 { mag: 9161889, sign: false }); + data.append(FP8x23 { mag: 8654373, sign: false }); + data.append(FP8x23 { mag: 4483841, sign: false }); + data.append(FP8x23 { mag: 15484597, sign: true }); + data.append(FP8x23 { mag: 15284265, sign: true }); + data.append(FP8x23 { mag: 6337630, sign: false }); + data.append(FP8x23 { mag: 11568081, sign: true }); + data.append(FP8x23 { mag: 5991922, sign: false }); + data.append(FP8x23 { mag: 520230, sign: false }); + data.append(FP8x23 { mag: 1312866, sign: true }); + data.append(FP8x23 { mag: 2415229, sign: true }); + data.append(FP8x23 { mag: 29673920, sign: false }); + data.append(FP8x23 { mag: 10289541, sign: false }); + data.append(FP8x23 { mag: 10450103, sign: false }); + data.append(FP8x23 { mag: 9512570, sign: true }); + data.append(FP8x23 { mag: 1628225, sign: true }); + data.append(FP8x23 { mag: 6580971, sign: true }); + data.append(FP8x23 { mag: 714841, sign: true }); + data.append(FP8x23 { mag: 1800550, sign: true }); + data.append(FP8x23 { mag: 3126343, sign: false }); + data.append(FP8x23 { mag: 12778888, sign: false }); + data.append(FP8x23 { mag: 2285159, sign: false }); + data.append(FP8x23 { mag: 15225974, sign: false }); + data.append(FP8x23 { mag: 672323, sign: false }); + data.append(FP8x23 { mag: 16353744, sign: true }); + data.append(FP8x23 { mag: 19120848, sign: true }); + data.append(FP8x23 { mag: 10304800, sign: true }); + data.append(FP8x23 { mag: 23183576, sign: true }); + data.append(FP8x23 { mag: 8290095, sign: true }); + data.append(FP8x23 { mag: 6100584, sign: true }); + data.append(FP8x23 { mag: 12063426, sign: true }); + data.append(FP8x23 { mag: 5926701, sign: false }); + data.append(FP8x23 { mag: 1446250, sign: false }); + data.append(FP8x23 { mag: 10444769, sign: true }); + data.append(FP8x23 { mag: 5096127, sign: true }); + data.append(FP8x23 { mag: 251532, sign: true }); + data.append(FP8x23 { mag: 8553469, sign: false }); + data.append(FP8x23 { mag: 8814550, sign: true }); + data.append(FP8x23 { mag: 19636340, sign: true }); + data.append(FP8x23 { mag: 3196814, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis2.cairo b/tests/nodes/layer_normalization_4d_axis2.cairo new file mode 100644 index 000000000..27748871e --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis2.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(2, false)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis2/input_0.cairo b/tests/nodes/layer_normalization_4d_axis2/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis2/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis2/input_1.cairo b/tests/nodes/layer_normalization_4d_axis2/input_1.cairo new file mode 100644 index 000000000..73de56051 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis2/input_1.cairo @@ -0,0 +1,33 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16045555, sign: false }); + data.append(FP8x23 { mag: 6797189, sign: true }); + data.append(FP8x23 { mag: 338571, sign: true }); + data.append(FP8x23 { mag: 14826208, sign: true }); + data.append(FP8x23 { mag: 6612261, sign: false }); + data.append(FP8x23 { mag: 2255963, sign: false }); + data.append(FP8x23 { mag: 7694826, sign: false }); + data.append(FP8x23 { mag: 8157877, sign: true }); + data.append(FP8x23 { mag: 10027904, sign: true }); + data.append(FP8x23 { mag: 4144258, sign: false }); + data.append(FP8x23 { mag: 12368555, sign: true }); + data.append(FP8x23 { mag: 1431810, sign: false }); + data.append(FP8x23 { mag: 993247, sign: true }); + data.append(FP8x23 { mag: 10015980, sign: true }); + data.append(FP8x23 { mag: 11250731, sign: false }); + data.append(FP8x23 { mag: 12224184, sign: true }); + data.append(FP8x23 { mag: 14407597, sign: true }); + data.append(FP8x23 { mag: 1255469, sign: true }); + data.append(FP8x23 { mag: 48578, sign: true }); + data.append(FP8x23 { mag: 14580561, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis2/input_2.cairo b/tests/nodes/layer_normalization_4d_axis2/input_2.cairo new file mode 100644 index 000000000..5dd984a06 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis2/input_2.cairo @@ -0,0 +1,33 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 10596986, sign: false }); + data.append(FP8x23 { mag: 13797276, sign: true }); + data.append(FP8x23 { mag: 726032, sign: true }); + data.append(FP8x23 { mag: 2944650, sign: true }); + data.append(FP8x23 { mag: 5288885, sign: true }); + data.append(FP8x23 { mag: 12046768, sign: true }); + data.append(FP8x23 { mag: 3375686, sign: false }); + data.append(FP8x23 { mag: 8744354, sign: true }); + data.append(FP8x23 { mag: 8940485, sign: false }); + data.append(FP8x23 { mag: 6541405, sign: true }); + data.append(FP8x23 { mag: 3256492, sign: false }); + data.append(FP8x23 { mag: 6889087, sign: false }); + data.append(FP8x23 { mag: 2560312, sign: true }); + data.append(FP8x23 { mag: 9717397, sign: true }); + data.append(FP8x23 { mag: 8774793, sign: true }); + data.append(FP8x23 { mag: 893052, sign: true }); + data.append(FP8x23 { mag: 7995400, sign: false }); + data.append(FP8x23 { mag: 9505615, sign: true }); + data.append(FP8x23 { mag: 541572, sign: false }); + data.append(FP8x23 { mag: 13005167, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis2/output_0.cairo b/tests/nodes/layer_normalization_4d_axis2/output_0.cairo new file mode 100644 index 000000000..b3de0fa1c --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis2/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 8703072, sign: true }); + data.append(FP8x23 { mag: 21681156, sign: true }); + data.append(FP8x23 { mag: 706829, sign: true }); + data.append(FP8x23 { mag: 10983862, sign: false }); + data.append(FP8x23 { mag: 4638530, sign: false }); + data.append(FP8x23 { mag: 8467598, sign: true }); + data.append(FP8x23 { mag: 1479070, sign: false }); + data.append(FP8x23 { mag: 4653312, sign: true }); + data.append(FP8x23 { mag: 1810198, sign: false }); + data.append(FP8x23 { mag: 9990925, sign: true }); + data.append(FP8x23 { mag: 12655580, sign: false }); + data.append(FP8x23 { mag: 5979269, sign: false }); + data.append(FP8x23 { mag: 4920753, sign: true }); + data.append(FP8x23 { mag: 178767, sign: false }); + data.append(FP8x23 { mag: 3487911, sign: true }); + data.append(FP8x23 { mag: 7999385, sign: false }); + data.append(FP8x23 { mag: 4888891, sign: false }); + data.append(FP8x23 { mag: 8940880, sign: true }); + data.append(FP8x23 { mag: 520413, sign: false }); + data.append(FP8x23 { mag: 29283250, sign: true }); + data.append(FP8x23 { mag: 5451075, sign: false }); + data.append(FP8x23 { mag: 21054176, sign: true }); + data.append(FP8x23 { mag: 764722, sign: true }); + data.append(FP8x23 { mag: 35781404, sign: true }); + data.append(FP8x23 { mag: 1716971, sign: false }); + data.append(FP8x23 { mag: 15621259, sign: true }); + data.append(FP8x23 { mag: 6641498, sign: false }); + data.append(FP8x23 { mag: 9299975, sign: true }); + data.append(FP8x23 { mag: 12600001, sign: false }); + data.append(FP8x23 { mag: 6138847, sign: true }); + data.append(FP8x23 { mag: 21157298, sign: false }); + data.append(FP8x23 { mag: 7135652, sign: false }); + data.append(FP8x23 { mag: 2922059, sign: true }); + data.append(FP8x23 { mag: 16377204, sign: true }); + data.append(FP8x23 { mag: 5736659, sign: false }); + data.append(FP8x23 { mag: 17655176, sign: false }); + data.append(FP8x23 { mag: 29246576, sign: false }); + data.append(FP8x23 { mag: 9018773, sign: true }); + data.append(FP8x23 { mag: 582792, sign: false }); + data.append(FP8x23 { mag: 7041778, sign: true }); + data.append(FP8x23 { mag: 361699, sign: true }); + data.append(FP8x23 { mag: 18572216, sign: true }); + data.append(FP8x23 { mag: 1141944, sign: true }); + data.append(FP8x23 { mag: 10129914, sign: false }); + data.append(FP8x23 { mag: 4692104, sign: false }); + data.append(FP8x23 { mag: 16070120, sign: true }); + data.append(FP8x23 { mag: 7702104, sign: false }); + data.append(FP8x23 { mag: 7488068, sign: true }); + data.append(FP8x23 { mag: 23095692, sign: false }); + data.append(FP8x23 { mag: 1833849, sign: true }); + data.append(FP8x23 { mag: 12071556, sign: false }); + data.append(FP8x23 { mag: 6654369, sign: false }); + data.append(FP8x23 { mag: 2228009, sign: true }); + data.append(FP8x23 { mag: 7557105, sign: true }); + data.append(FP8x23 { mag: 8856797, sign: false }); + data.append(FP8x23 { mag: 12395714, sign: true }); + data.append(FP8x23 { mag: 5265819, sign: false }); + data.append(FP8x23 { mag: 7149849, sign: true }); + data.append(FP8x23 { mag: 535227, sign: false }); + data.append(FP8x23 { mag: 9354983, sign: true }); + data.append(FP8x23 { mag: 1886297, sign: false }); + data.append(FP8x23 { mag: 11759874, sign: true }); + data.append(FP8x23 { mag: 1553244, sign: true }); + data.append(FP8x23 { mag: 4906809, sign: true }); + data.append(FP8x23 { mag: 11679092, sign: true }); + data.append(FP8x23 { mag: 13404629, sign: true }); + data.append(FP8x23 { mag: 8324403, sign: false }); + data.append(FP8x23 { mag: 5321254, sign: false }); + data.append(FP8x23 { mag: 20883912, sign: false }); + data.append(FP8x23 { mag: 9551024, sign: true }); + data.append(FP8x23 { mag: 377017, sign: false }); + data.append(FP8x23 { mag: 7241483, sign: false }); + data.append(FP8x23 { mag: 2262438, sign: true }); + data.append(FP8x23 { mag: 28881548, sign: true }); + data.append(FP8x23 { mag: 15456542, sign: true }); + data.append(FP8x23 { mag: 15339766, sign: true }); + data.append(FP8x23 { mag: 2135888, sign: true }); + data.append(FP8x23 { mag: 9450355, sign: true }); + data.append(FP8x23 { mag: 523906, sign: false }); + data.append(FP8x23 { mag: 25683136, sign: true }); + data.append(FP8x23 { mag: 32249876, sign: false }); + data.append(FP8x23 { mag: 11198256, sign: true }); + data.append(FP8x23 { mag: 1586058, sign: true }); + data.append(FP8x23 { mag: 16911206, sign: true }); + data.append(FP8x23 { mag: 10192207, sign: true }); + data.append(FP8x23 { mag: 9679471, sign: true }); + data.append(FP8x23 { mag: 1754443, sign: false }); + data.append(FP8x23 { mag: 4472428, sign: true }); + data.append(FP8x23 { mag: 20918500, sign: false }); + data.append(FP8x23 { mag: 7556645, sign: true }); + data.append(FP8x23 { mag: 10829383, sign: false }); + data.append(FP8x23 { mag: 5528458, sign: false }); + data.append(FP8x23 { mag: 2697706, sign: true }); + data.append(FP8x23 { mag: 4552741, sign: false }); + data.append(FP8x23 { mag: 22144528, sign: true }); + data.append(FP8x23 { mag: 3608822, sign: true }); + data.append(FP8x23 { mag: 9487844, sign: false }); + data.append(FP8x23 { mag: 11436110, sign: true }); + data.append(FP8x23 { mag: 554338, sign: false }); + data.append(FP8x23 { mag: 12122246, sign: true }); + data.append(FP8x23 { mag: 6229258, sign: true }); + data.append(FP8x23 { mag: 15782264, sign: true }); + data.append(FP8x23 { mag: 1268559, sign: true }); + data.append(FP8x23 { mag: 4948465, sign: true }); + data.append(FP8x23 { mag: 6197991, sign: true }); + data.append(FP8x23 { mag: 15351219, sign: true }); + data.append(FP8x23 { mag: 4140654, sign: false }); + data.append(FP8x23 { mag: 2247070, sign: true }); + data.append(FP8x23 { mag: 8678569, sign: true }); + data.append(FP8x23 { mag: 2279577, sign: true }); + data.append(FP8x23 { mag: 529337, sign: true }); + data.append(FP8x23 { mag: 6385535, sign: false }); + data.append(FP8x23 { mag: 2113548, sign: true }); + data.append(FP8x23 { mag: 6758004, sign: true }); + data.append(FP8x23 { mag: 6438428, sign: true }); + data.append(FP8x23 { mag: 25935804, sign: true }); + data.append(FP8x23 { mag: 28581364, sign: false }); + data.append(FP8x23 { mag: 9998688, sign: true }); + data.append(FP8x23 { mag: 612754, sign: false }); + data.append(FP8x23 { mag: 19290370, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis3.cairo b/tests/nodes/layer_normalization_4d_axis3.cairo new file mode 100644 index 000000000..2de09a8d6 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis3.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis3() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(3, false)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis3/input_0.cairo b/tests/nodes/layer_normalization_4d_axis3/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis3/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis3/input_1.cairo b/tests/nodes/layer_normalization_4d_axis3/input_1.cairo new file mode 100644 index 000000000..c4f5b5a1b --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis3/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 6200252, sign: false }); + data.append(FP8x23 { mag: 4645727, sign: true }); + data.append(FP8x23 { mag: 18131524, sign: false }); + data.append(FP8x23 { mag: 9670945, sign: true }); + data.append(FP8x23 { mag: 4399430, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis3/input_2.cairo b/tests/nodes/layer_normalization_4d_axis3/input_2.cairo new file mode 100644 index 000000000..6dac76016 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis3/input_2.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 538440, sign: false }); + data.append(FP8x23 { mag: 3468236, sign: false }); + data.append(FP8x23 { mag: 8790979, sign: false }); + data.append(FP8x23 { mag: 32869270, sign: false }); + data.append(FP8x23 { mag: 630553, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis3/output_0.cairo b/tests/nodes/layer_normalization_4d_axis3/output_0.cairo new file mode 100644 index 000000000..9be521f05 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis3/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 6861451, sign: true }); + data.append(FP8x23 { mag: 1101011, sign: true }); + data.append(FP8x23 { mag: 6298939, sign: false }); + data.append(FP8x23 { mag: 42064444, sign: false }); + data.append(FP8x23 { mag: 6341855, sign: false }); + data.append(FP8x23 { mag: 10637301, sign: false }); + data.append(FP8x23 { mag: 5512855, sign: false }); + data.append(FP8x23 { mag: 4407421, sign: true }); + data.append(FP8x23 { mag: 26673762, sign: false }); + data.append(FP8x23 { mag: 4214892, sign: true }); + data.append(FP8x23 { mag: 3698203, sign: true }); + data.append(FP8x23 { mag: 6179134, sign: false }); + data.append(FP8x23 { mag: 41981648, sign: false }); + data.append(FP8x23 { mag: 41245676, sign: false }); + data.append(FP8x23 { mag: 1961019, sign: false }); + data.append(FP8x23 { mag: 3735727, sign: true }); + data.append(FP8x23 { mag: 899718, sign: true }); + data.append(FP8x23 { mag: 4990078, sign: false }); + data.append(FP8x23 { mag: 20101102, sign: false }); + data.append(FP8x23 { mag: 5359211, sign: true }); + data.append(FP8x23 { mag: 7557954, sign: true }); + data.append(FP8x23 { mag: 2196963, sign: false }); + data.append(FP8x23 { mag: 5912856, sign: true }); + data.append(FP8x23 { mag: 17601684, sign: false }); + data.append(FP8x23 { mag: 1793852, sign: false }); + data.append(FP8x23 { mag: 11044257, sign: true }); + data.append(FP8x23 { mag: 1136295, sign: true }); + data.append(FP8x23 { mag: 17566640, sign: false }); + data.append(FP8x23 { mag: 34149192, sign: false }); + data.append(FP8x23 { mag: 2941641, sign: false }); + data.append(FP8x23 { mag: 10739165, sign: true }); + data.append(FP8x23 { mag: 3654829, sign: false }); + data.append(FP8x23 { mag: 11886320, sign: false }); + data.append(FP8x23 { mag: 28024164, sign: false }); + data.append(FP8x23 { mag: 5854203, sign: false }); + data.append(FP8x23 { mag: 5942399, sign: true }); + data.append(FP8x23 { mag: 8051302, sign: false }); + data.append(FP8x23 { mag: 18253056, sign: false }); + data.append(FP8x23 { mag: 34004268, sign: false }); + data.append(FP8x23 { mag: 7789614, sign: false }); + data.append(FP8x23 { mag: 6142160, sign: true }); + data.append(FP8x23 { mag: 1919472, sign: false }); + data.append(FP8x23 { mag: 24546494, sign: false }); + data.append(FP8x23 { mag: 45248064, sign: false }); + data.append(FP8x23 { mag: 5712510, sign: false }); + data.append(FP8x23 { mag: 7515486, sign: true }); + data.append(FP8x23 { mag: 237511, sign: true }); + data.append(FP8x23 { mag: 11646082, sign: false }); + data.append(FP8x23 { mag: 42217200, sign: false }); + data.append(FP8x23 { mag: 6395711, sign: false }); + data.append(FP8x23 { mag: 5252277, sign: true }); + data.append(FP8x23 { mag: 4592765, sign: false }); + data.append(FP8x23 { mag: 501418, sign: false }); + data.append(FP8x23 { mag: 35841232, sign: false }); + data.append(FP8x23 { mag: 9167666, sign: false }); + data.append(FP8x23 { mag: 7168818, sign: false }); + data.append(FP8x23 { mag: 2182334, sign: false }); + data.append(FP8x23 { mag: 25693876, sign: true }); + data.append(FP8x23 { mag: 30792410, sign: false }); + data.append(FP8x23 { mag: 2130807, sign: false }); + data.append(FP8x23 { mag: 3061656, sign: true }); + data.append(FP8x23 { mag: 5223668, sign: false }); + data.append(FP8x23 { mag: 43419840, sign: false }); + data.append(FP8x23 { mag: 33038392, sign: false }); + data.append(FP8x23 { mag: 3478012, sign: true }); + data.append(FP8x23 { mag: 1468348, sign: false }); + data.append(FP8x23 { mag: 4571455, sign: true }); + data.append(FP8x23 { mag: 14322581, sign: true }); + data.append(FP8x23 { mag: 38651728, sign: false }); + data.append(FP8x23 { mag: 596053, sign: false }); + data.append(FP8x23 { mag: 60762, sign: false }); + data.append(FP8x23 { mag: 3754916, sign: false }); + data.append(FP8x23 { mag: 3730281, sign: true }); + data.append(FP8x23 { mag: 14895674, sign: false }); + data.append(FP8x23 { mag: 3897275, sign: true }); + data.append(FP8x23 { mag: 8682959, sign: false }); + data.append(FP8x23 { mag: 558598, sign: false }); + data.append(FP8x23 { mag: 692740, sign: false }); + data.append(FP8x23 { mag: 31527354, sign: false }); + data.append(FP8x23 { mag: 6549321, sign: true }); + data.append(FP8x23 { mag: 3699828, sign: false }); + data.append(FP8x23 { mag: 7847593, sign: false }); + data.append(FP8x23 { mag: 36143524, sign: false }); + data.append(FP8x23 { mag: 31243174, sign: false }); + data.append(FP8x23 { mag: 4841996, sign: true }); + data.append(FP8x23 { mag: 11383225, sign: false }); + data.append(FP8x23 { mag: 3377710, sign: false }); + data.append(FP8x23 { mag: 1355152, sign: false }); + data.append(FP8x23 { mag: 45740296, sign: false }); + data.append(FP8x23 { mag: 509252, sign: false }); + data.append(FP8x23 { mag: 2763536, sign: false }); + data.append(FP8x23 { mag: 4688158, sign: false }); + data.append(FP8x23 { mag: 40319700, sign: false }); + data.append(FP8x23 { mag: 43844796, sign: false }); + data.append(FP8x23 { mag: 2450252, sign: true }); + data.append(FP8x23 { mag: 122285, sign: true }); + data.append(FP8x23 { mag: 6312799, sign: false }); + data.append(FP8x23 { mag: 43889444, sign: false }); + data.append(FP8x23 { mag: 41180916, sign: false }); + data.append(FP8x23 { mag: 942084, sign: true }); + data.append(FP8x23 { mag: 8298549, sign: true }); + data.append(FP8x23 { mag: 2797356, sign: false }); + data.append(FP8x23 { mag: 39226368, sign: false }); + data.append(FP8x23 { mag: 33248984, sign: false }); + data.append(FP8x23 { mag: 946515, sign: true }); + data.append(FP8x23 { mag: 7877337, sign: true }); + data.append(FP8x23 { mag: 3568616, sign: false }); + data.append(FP8x23 { mag: 5471631, sign: true }); + data.append(FP8x23 { mag: 19389094, sign: false }); + data.append(FP8x23 { mag: 4025465, sign: false }); + data.append(FP8x23 { mag: 8978928, sign: false }); + data.append(FP8x23 { mag: 6984708, sign: false }); + data.append(FP8x23 { mag: 10661943, sign: true }); + data.append(FP8x23 { mag: 38438520, sign: false }); + data.append(FP8x23 { mag: 5225164, sign: false }); + data.append(FP8x23 { mag: 11078083, sign: false }); + data.append(FP8x23 { mag: 7911388, sign: false }); + data.append(FP8x23 { mag: 16678516, sign: false }); + data.append(FP8x23 { mag: 42387840, sign: false }); + data.append(FP8x23 { mag: 224042, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_1.cairo b/tests/nodes/layer_normalization_4d_axis_negative_1.cairo new file mode 100644 index 000000000..693dc4311 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_1.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis_negative_1() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(1, true)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_1/input_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_1/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_1/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_1/input_1.cairo b/tests/nodes/layer_normalization_4d_axis_negative_1/input_1.cairo new file mode 100644 index 000000000..51b37b20c --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_1/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 7892951, sign: true }); + data.append(FP8x23 { mag: 7153170, sign: false }); + data.append(FP8x23 { mag: 6305733, sign: false }); + data.append(FP8x23 { mag: 6298263, sign: true }); + data.append(FP8x23 { mag: 924383, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_1/input_2.cairo b/tests/nodes/layer_normalization_4d_axis_negative_1/input_2.cairo new file mode 100644 index 000000000..3fb4ef013 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_1/input_2.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9596129, sign: false }); + data.append(FP8x23 { mag: 8524695, sign: false }); + data.append(FP8x23 { mag: 8030491, sign: true }); + data.append(FP8x23 { mag: 8640310, sign: false }); + data.append(FP8x23 { mag: 12854812, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_1/output_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_1/output_0.cairo new file mode 100644 index 000000000..5b69011e6 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_1/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 19016228, sign: false }); + data.append(FP8x23 { mag: 15560107, sign: false }); + data.append(FP8x23 { mag: 8897166, sign: true }); + data.append(FP8x23 { mag: 14628724, sign: false }); + data.append(FP8x23 { mag: 14054839, sign: false }); + data.append(FP8x23 { mag: 3259770, sign: true }); + data.append(FP8x23 { mag: 5376533, sign: false }); + data.append(FP8x23 { mag: 12620596, sign: true }); + data.append(FP8x23 { mag: 4605447, sign: false }); + data.append(FP8x23 { mag: 11836714, sign: false }); + data.append(FP8x23 { mag: 14989397, sign: false }); + data.append(FP8x23 { mag: 4350642, sign: false }); + data.append(FP8x23 { mag: 3512472, sign: false }); + data.append(FP8x23 { mag: 14095496, sign: false }); + data.append(FP8x23 { mag: 13134362, sign: false }); + data.append(FP8x23 { mag: 15037165, sign: false }); + data.append(FP8x23 { mag: 15250170, sign: false }); + data.append(FP8x23 { mag: 9352358, sign: true }); + data.append(FP8x23 { mag: 324961, sign: false }); + data.append(FP8x23 { mag: 11596276, sign: false }); + data.append(FP8x23 { mag: 19902878, sign: false }); + data.append(FP8x23 { mag: 10482114, sign: false }); + data.append(FP8x23 { mag: 13144152, sign: true }); + data.append(FP8x23 { mag: 1302801, sign: true }); + data.append(FP8x23 { mag: 13099238, sign: false }); + data.append(FP8x23 { mag: 24340960, sign: false }); + data.append(FP8x23 { mag: 15614434, sign: false }); + data.append(FP8x23 { mag: 4978514, sign: true }); + data.append(FP8x23 { mag: 9473866, sign: false }); + data.append(FP8x23 { mag: 13340405, sign: false }); + data.append(FP8x23 { mag: 23952576, sign: false }); + data.append(FP8x23 { mag: 8237391, sign: false }); + data.append(FP8x23 { mag: 6954001, sign: true }); + data.append(FP8x23 { mag: 5484904, sign: false }); + data.append(FP8x23 { mag: 13952376, sign: false }); + data.append(FP8x23 { mag: 17846270, sign: false }); + data.append(FP8x23 { mag: 1468006, sign: false }); + data.append(FP8x23 { mag: 4739795, sign: true }); + data.append(FP8x23 { mag: 9379485, sign: false }); + data.append(FP8x23 { mag: 14359034, sign: false }); + data.append(FP8x23 { mag: 18100568, sign: false }); + data.append(FP8x23 { mag: 10909375, sign: false }); + data.append(FP8x23 { mag: 2551079, sign: true }); + data.append(FP8x23 { mag: 16702075, sign: false }); + data.append(FP8x23 { mag: 13922604, sign: false }); + data.append(FP8x23 { mag: 19848816, sign: false }); + data.append(FP8x23 { mag: 14230550, sign: false }); + data.append(FP8x23 { mag: 7037550, sign: true }); + data.append(FP8x23 { mag: 14728206, sign: false }); + data.append(FP8x23 { mag: 14066155, sign: false }); + data.append(FP8x23 { mag: 16967740, sign: false }); + data.append(FP8x23 { mag: 6793223, sign: false }); + data.append(FP8x23 { mag: 10913412, sign: true }); + data.append(FP8x23 { mag: 10575819, sign: false }); + data.append(FP8x23 { mag: 14648582, sign: false }); + data.append(FP8x23 { mag: 1155626, sign: false }); + data.append(FP8x23 { mag: 10504639, sign: false }); + data.append(FP8x23 { mag: 20023544, sign: true }); + data.append(FP8x23 { mag: 7287742, sign: false }); + data.append(FP8x23 { mag: 13170037, sign: false }); + data.append(FP8x23 { mag: 14179070, sign: false }); + data.append(FP8x23 { mag: 5821802, sign: false }); + data.append(FP8x23 { mag: 4012642, sign: false }); + data.append(FP8x23 { mag: 8750452, sign: false }); + data.append(FP8x23 { mag: 11991543, sign: false }); + data.append(FP8x23 { mag: 8412352, sign: false }); + data.append(FP8x23 { mag: 20903656, sign: false }); + data.append(FP8x23 { mag: 16068864, sign: true }); + data.append(FP8x23 { mag: 12406173, sign: false }); + data.append(FP8x23 { mag: 12847563, sign: false }); + data.append(FP8x23 { mag: 10204216, sign: false }); + data.append(FP8x23 { mag: 8083285, sign: false }); + data.append(FP8x23 { mag: 12385102, sign: true }); + data.append(FP8x23 { mag: 3065106, sign: true }); + data.append(FP8x23 { mag: 11903450, sign: false }); + data.append(FP8x23 { mag: 771882, sign: true }); + data.append(FP8x23 { mag: 13004755, sign: false }); + data.append(FP8x23 { mag: 10846875, sign: true }); + data.append(FP8x23 { mag: 7766379, sign: false }); + data.append(FP8x23 { mag: 11346216, sign: false }); + data.append(FP8x23 { mag: 5571666, sign: false }); + data.append(FP8x23 { mag: 1781663, sign: false }); + data.append(FP8x23 { mag: 1482105, sign: false }); + data.append(FP8x23 { mag: 7581305, sign: false }); + data.append(FP8x23 { mag: 11704950, sign: false }); + data.append(FP8x23 { mag: 4209335, sign: true }); + data.append(FP8x23 { mag: 8664081, sign: false }); + data.append(FP8x23 { mag: 10616503, sign: true }); + data.append(FP8x23 { mag: 17022648, sign: false }); + data.append(FP8x23 { mag: 12829325, sign: false }); + data.append(FP8x23 { mag: 6763571, sign: false }); + data.append(FP8x23 { mag: 6646343, sign: false }); + data.append(FP8x23 { mag: 2934484, sign: false }); + data.append(FP8x23 { mag: 15788190, sign: false }); + data.append(FP8x23 { mag: 12207490, sign: false }); + data.append(FP8x23 { mag: 10437236, sign: false }); + data.append(FP8x23 { mag: 4144833, sign: false }); + data.append(FP8x23 { mag: 4175960, sign: false }); + data.append(FP8x23 { mag: 14053320, sign: false }); + data.append(FP8x23 { mag: 12524378, sign: false }); + data.append(FP8x23 { mag: 20845660, sign: false }); + data.append(FP8x23 { mag: 9557670, sign: false }); + data.append(FP8x23 { mag: 2554248, sign: false }); + data.append(FP8x23 { mag: 8887601, sign: false }); + data.append(FP8x23 { mag: 12523447, sign: false }); + data.append(FP8x23 { mag: 20309456, sign: false }); + data.append(FP8x23 { mag: 8370136, sign: false }); + data.append(FP8x23 { mag: 12990704, sign: true }); + data.append(FP8x23 { mag: 138738, sign: true }); + data.append(FP8x23 { mag: 13568132, sign: false }); + data.append(FP8x23 { mag: 1148653, sign: true }); + data.append(FP8x23 { mag: 3110274, sign: false }); + data.append(FP8x23 { mag: 14795777, sign: true }); + data.append(FP8x23 { mag: 12267317, sign: false }); + data.append(FP8x23 { mag: 13820206, sign: false }); + data.append(FP8x23 { mag: 3820888, sign: true }); + data.append(FP8x23 { mag: 1683436, sign: false }); + data.append(FP8x23 { mag: 5287384, sign: true }); + data.append(FP8x23 { mag: 14839339, sign: false }); + data.append(FP8x23 { mag: 12675249, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_2.cairo b/tests/nodes/layer_normalization_4d_axis_negative_2.cairo new file mode 100644 index 000000000..52785fddf --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_2.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis_negative_2() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(2, true)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_2/input_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_2/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_2/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_2/input_1.cairo b/tests/nodes/layer_normalization_4d_axis_negative_2/input_1.cairo new file mode 100644 index 000000000..3356d848b --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_2/input_1.cairo @@ -0,0 +1,33 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1860902, sign: true }); + data.append(FP8x23 { mag: 7189990, sign: false }); + data.append(FP8x23 { mag: 5594953, sign: false }); + data.append(FP8x23 { mag: 14949612, sign: false }); + data.append(FP8x23 { mag: 1598676, sign: true }); + data.append(FP8x23 { mag: 19332304, sign: true }); + data.append(FP8x23 { mag: 13237330, sign: true }); + data.append(FP8x23 { mag: 13876161, sign: true }); + data.append(FP8x23 { mag: 2710915, sign: false }); + data.append(FP8x23 { mag: 1998193, sign: false }); + data.append(FP8x23 { mag: 10029104, sign: true }); + data.append(FP8x23 { mag: 5128877, sign: true }); + data.append(FP8x23 { mag: 12692706, sign: false }); + data.append(FP8x23 { mag: 7217481, sign: false }); + data.append(FP8x23 { mag: 2729123, sign: true }); + data.append(FP8x23 { mag: 12888666, sign: false }); + data.append(FP8x23 { mag: 4258854, sign: false }); + data.append(FP8x23 { mag: 1006706, sign: true }); + data.append(FP8x23 { mag: 3116978, sign: false }); + data.append(FP8x23 { mag: 10767356, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_2/input_2.cairo b/tests/nodes/layer_normalization_4d_axis_negative_2/input_2.cairo new file mode 100644 index 000000000..f973cd66d --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_2/input_2.cairo @@ -0,0 +1,33 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 14435005, sign: false }); + data.append(FP8x23 { mag: 275345, sign: false }); + data.append(FP8x23 { mag: 7948101, sign: true }); + data.append(FP8x23 { mag: 124471, sign: true }); + data.append(FP8x23 { mag: 11083371, sign: true }); + data.append(FP8x23 { mag: 2513924, sign: true }); + data.append(FP8x23 { mag: 6387124, sign: true }); + data.append(FP8x23 { mag: 5452904, sign: false }); + data.append(FP8x23 { mag: 12271809, sign: true }); + data.append(FP8x23 { mag: 15327354, sign: true }); + data.append(FP8x23 { mag: 3795402, sign: true }); + data.append(FP8x23 { mag: 2307268, sign: false }); + data.append(FP8x23 { mag: 5731544, sign: false }); + data.append(FP8x23 { mag: 4011370, sign: true }); + data.append(FP8x23 { mag: 3178152, sign: false }); + data.append(FP8x23 { mag: 14982171, sign: false }); + data.append(FP8x23 { mag: 2850000, sign: true }); + data.append(FP8x23 { mag: 9445099, sign: false }); + data.append(FP8x23 { mag: 8149556, sign: false }); + data.append(FP8x23 { mag: 8935026, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_2/output_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_2/output_0.cairo new file mode 100644 index 000000000..d1025c8b9 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_2/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16673352, sign: false }); + data.append(FP8x23 { mag: 8614825, sign: false }); + data.append(FP8x23 { mag: 8265431, sign: true }); + data.append(FP8x23 { mag: 14168915, sign: true }); + data.append(FP8x23 { mag: 13483567, sign: true }); + data.append(FP8x23 { mag: 33185346, sign: true }); + data.append(FP8x23 { mag: 3124395, sign: true }); + data.append(FP8x23 { mag: 12411572, sign: false }); + data.append(FP8x23 { mag: 10344227, sign: true }); + data.append(FP8x23 { mag: 16990572, sign: true }); + data.append(FP8x23 { mag: 3825893, sign: false }); + data.append(FP8x23 { mag: 5566316, sign: false }); + data.append(FP8x23 { mag: 35895616, sign: false }); + data.append(FP8x23 { mag: 11142513, sign: true }); + data.append(FP8x23 { mag: 1895698, sign: false }); + data.append(FP8x23 { mag: 5606359, sign: false }); + data.append(FP8x23 { mag: 1931723, sign: true }); + data.append(FP8x23 { mag: 9897935, sign: false }); + data.append(FP8x23 { mag: 9507187, sign: false }); + data.append(FP8x23 { mag: 20955956, sign: false }); + data.append(FP8x23 { mag: 15031808, sign: false }); + data.append(FP8x23 { mag: 7951613, sign: false }); + data.append(FP8x23 { mag: 7308734, sign: true }); + data.append(FP8x23 { mag: 32985592, sign: false }); + data.append(FP8x23 { mag: 12777209, sign: true }); + data.append(FP8x23 { mag: 28117402, sign: false }); + data.append(FP8x23 { mag: 12005266, sign: true }); + data.append(FP8x23 { mag: 4507819, sign: false }); + data.append(FP8x23 { mag: 13261112, sign: true }); + data.append(FP8x23 { mag: 15133257, sign: true }); + data.append(FP8x23 { mag: 10719554, sign: false }); + data.append(FP8x23 { mag: 1424049, sign: false }); + data.append(FP8x23 { mag: 10354309, sign: false }); + data.append(FP8x23 { mag: 787664, sign: false }); + data.append(FP8x23 { mag: 341933, sign: true }); + data.append(FP8x23 { mag: 4574301, sign: true }); + data.append(FP8x23 { mag: 9131801, sign: true }); + data.append(FP8x23 { mag: 9835476, sign: false }); + data.append(FP8x23 { mag: 5504734, sign: false }); + data.append(FP8x23 { mag: 4531222, sign: false }); + data.append(FP8x23 { mag: 15705951, sign: false }); + data.append(FP8x23 { mag: 5326222, sign: false }); + data.append(FP8x23 { mag: 1075073, sign: true }); + data.append(FP8x23 { mag: 13307859, sign: true }); + data.append(FP8x23 { mag: 13496520, sign: true }); + data.append(FP8x23 { mag: 31963888, sign: false }); + data.append(FP8x23 { mag: 13829816, sign: true }); + data.append(FP8x23 { mag: 7589786, sign: false }); + data.append(FP8x23 { mag: 16098488, sign: true }); + data.append(FP8x23 { mag: 13057562, sign: true }); + data.append(FP8x23 { mag: 3352336, sign: false }); + data.append(FP8x23 { mag: 3148048, sign: false }); + data.append(FP8x23 { mag: 1485042, sign: false }); + data.append(FP8x23 { mag: 5568070, sign: true }); + data.append(FP8x23 { mag: 1098794, sign: true }); + data.append(FP8x23 { mag: 27110094, sign: false }); + data.append(FP8x23 { mag: 2043142, sign: true }); + data.append(FP8x23 { mag: 11334084, sign: false }); + data.append(FP8x23 { mag: 8556721, sign: false }); + data.append(FP8x23 { mag: 6239462, sign: false }); + data.append(FP8x23 { mag: 15445237, sign: false }); + data.append(FP8x23 { mag: 1879795, sign: true }); + data.append(FP8x23 { mag: 5721750, sign: false }); + data.append(FP8x23 { mag: 1854019, sign: false }); + data.append(FP8x23 { mag: 9538382, sign: true }); + data.append(FP8x23 { mag: 9122158, sign: false }); + data.append(FP8x23 { mag: 14900349, sign: true }); + data.append(FP8x23 { mag: 29377836, sign: false }); + data.append(FP8x23 { mag: 15500562, sign: true }); + data.append(FP8x23 { mag: 16778470, sign: true }); + data.append(FP8x23 { mag: 6130239, sign: true }); + data.append(FP8x23 { mag: 1044951, sign: false }); + data.append(FP8x23 { mag: 1925009, sign: false }); + data.append(FP8x23 { mag: 9798252, sign: false }); + data.append(FP8x23 { mag: 4798964, sign: false }); + data.append(FP8x23 { mag: 30214180, sign: false }); + data.append(FP8x23 { mag: 144786, sign: false }); + data.append(FP8x23 { mag: 9489410, sign: false }); + data.append(FP8x23 { mag: 9283074, sign: false }); + data.append(FP8x23 { mag: 18297370, sign: false }); + data.append(FP8x23 { mag: 11923786, sign: false }); + data.append(FP8x23 { mag: 2473868, sign: true }); + data.append(FP8x23 { mag: 6264012, sign: false }); + data.append(FP8x23 { mag: 13958334, sign: false }); + data.append(FP8x23 { mag: 9897873, sign: true }); + data.append(FP8x23 { mag: 22800294, sign: true }); + data.append(FP8x23 { mag: 3598116, sign: true }); + data.append(FP8x23 { mag: 12719248, sign: false }); + data.append(FP8x23 { mag: 15509912, sign: true }); + data.append(FP8x23 { mag: 15816862, sign: true }); + data.append(FP8x23 { mag: 2345113, sign: false }); + data.append(FP8x23 { mag: 7181165, sign: false }); + data.append(FP8x23 { mag: 7487303, sign: false }); + data.append(FP8x23 { mag: 14294384, sign: true }); + data.append(FP8x23 { mag: 6421289, sign: false }); + data.append(FP8x23 { mag: 17845564, sign: false }); + data.append(FP8x23 { mag: 3291163, sign: true }); + data.append(FP8x23 { mag: 7897120, sign: false }); + data.append(FP8x23 { mag: 7330476, sign: false }); + data.append(FP8x23 { mag: 8283012, sign: false }); + data.append(FP8x23 { mag: 16386448, sign: false }); + data.append(FP8x23 { mag: 2375043, sign: false }); + data.append(FP8x23 { mag: 1017275, sign: false }); + data.append(FP8x23 { mag: 1896022, sign: false }); + data.append(FP8x23 { mag: 10863572, sign: true }); + data.append(FP8x23 { mag: 25803316, sign: false }); + data.append(FP8x23 { mag: 7703091, sign: true }); + data.append(FP8x23 { mag: 16504475, sign: false }); + data.append(FP8x23 { mag: 7508723, sign: true }); + data.append(FP8x23 { mag: 13272474, sign: true }); + data.append(FP8x23 { mag: 6865161, sign: true }); + data.append(FP8x23 { mag: 4111037, sign: false }); + data.append(FP8x23 { mag: 22349, sign: false }); + data.append(FP8x23 { mag: 6143899, sign: true }); + data.append(FP8x23 { mag: 2611413, sign: false }); + data.append(FP8x23 { mag: 41386196, sign: false }); + data.append(FP8x23 { mag: 8935166, sign: true }); + data.append(FP8x23 { mag: 9049725, sign: false }); + data.append(FP8x23 { mag: 3582293, sign: false }); + data.append(FP8x23 { mag: 13576482, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_3.cairo b/tests/nodes/layer_normalization_4d_axis_negative_3.cairo new file mode 100644 index 000000000..3b01a1d4a --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_3.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis_negative_3() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(3, true)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_3/input_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_3/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_3/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_3/input_1.cairo b/tests/nodes/layer_normalization_4d_axis_negative_3/input_1.cairo new file mode 100644 index 000000000..99425c1b7 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_3/input_1.cairo @@ -0,0 +1,74 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 16380343, sign: false }); + data.append(FP8x23 { mag: 4257343, sign: false }); + data.append(FP8x23 { mag: 600530, sign: false }); + data.append(FP8x23 { mag: 2445196, sign: false }); + data.append(FP8x23 { mag: 9479607, sign: true }); + data.append(FP8x23 { mag: 978145, sign: true }); + data.append(FP8x23 { mag: 7810708, sign: false }); + data.append(FP8x23 { mag: 3402716, sign: true }); + data.append(FP8x23 { mag: 23762706, sign: true }); + data.append(FP8x23 { mag: 3969772, sign: true }); + data.append(FP8x23 { mag: 287406, sign: false }); + data.append(FP8x23 { mag: 17010742, sign: true }); + data.append(FP8x23 { mag: 13856942, sign: true }); + data.append(FP8x23 { mag: 8047151, sign: true }); + data.append(FP8x23 { mag: 7959897, sign: false }); + data.append(FP8x23 { mag: 13270756, sign: true }); + data.append(FP8x23 { mag: 2099848, sign: true }); + data.append(FP8x23 { mag: 10634067, sign: false }); + data.append(FP8x23 { mag: 12591738, sign: false }); + data.append(FP8x23 { mag: 3020733, sign: true }); + data.append(FP8x23 { mag: 6136434, sign: false }); + data.append(FP8x23 { mag: 3475208, sign: false }); + data.append(FP8x23 { mag: 5530273, sign: true }); + data.append(FP8x23 { mag: 3210239, sign: true }); + data.append(FP8x23 { mag: 1963138, sign: false }); + data.append(FP8x23 { mag: 12035630, sign: false }); + data.append(FP8x23 { mag: 1743576, sign: false }); + data.append(FP8x23 { mag: 1123297, sign: true }); + data.append(FP8x23 { mag: 3857322, sign: false }); + data.append(FP8x23 { mag: 10535923, sign: false }); + data.append(FP8x23 { mag: 22905288, sign: false }); + data.append(FP8x23 { mag: 6554809, sign: false }); + data.append(FP8x23 { mag: 4404682, sign: false }); + data.append(FP8x23 { mag: 2731185, sign: true }); + data.append(FP8x23 { mag: 1388899, sign: true }); + data.append(FP8x23 { mag: 6849233, sign: false }); + data.append(FP8x23 { mag: 2696271, sign: true }); + data.append(FP8x23 { mag: 4022892, sign: true }); + data.append(FP8x23 { mag: 10206850, sign: true }); + data.append(FP8x23 { mag: 1922098, sign: false }); + data.append(FP8x23 { mag: 5108807, sign: false }); + data.append(FP8x23 { mag: 6197766, sign: false }); + data.append(FP8x23 { mag: 3598686, sign: true }); + data.append(FP8x23 { mag: 12654023, sign: true }); + data.append(FP8x23 { mag: 5122151, sign: false }); + data.append(FP8x23 { mag: 7636231, sign: false }); + data.append(FP8x23 { mag: 3971308, sign: false }); + data.append(FP8x23 { mag: 798761, sign: false }); + data.append(FP8x23 { mag: 6869233, sign: false }); + data.append(FP8x23 { mag: 8740030, sign: true }); + data.append(FP8x23 { mag: 4664523, sign: false }); + data.append(FP8x23 { mag: 3437420, sign: false }); + data.append(FP8x23 { mag: 3958479, sign: false }); + data.append(FP8x23 { mag: 6298427, sign: true }); + data.append(FP8x23 { mag: 3669742, sign: false }); + data.append(FP8x23 { mag: 1711541, sign: false }); + data.append(FP8x23 { mag: 22958802, sign: false }); + data.append(FP8x23 { mag: 684171, sign: true }); + data.append(FP8x23 { mag: 1802568, sign: true }); + data.append(FP8x23 { mag: 1621301, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_3/input_2.cairo b/tests/nodes/layer_normalization_4d_axis_negative_3/input_2.cairo new file mode 100644 index 000000000..8a5d49e10 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_3/input_2.cairo @@ -0,0 +1,74 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9802310, sign: true }); + data.append(FP8x23 { mag: 18025368, sign: false }); + data.append(FP8x23 { mag: 17682696, sign: true }); + data.append(FP8x23 { mag: 7029352, sign: true }); + data.append(FP8x23 { mag: 2764347, sign: true }); + data.append(FP8x23 { mag: 7626198, sign: true }); + data.append(FP8x23 { mag: 7957030, sign: false }); + data.append(FP8x23 { mag: 7354525, sign: false }); + data.append(FP8x23 { mag: 2783866, sign: false }); + data.append(FP8x23 { mag: 3333849, sign: true }); + data.append(FP8x23 { mag: 906374, sign: true }); + data.append(FP8x23 { mag: 6164127, sign: true }); + data.append(FP8x23 { mag: 4297991, sign: true }); + data.append(FP8x23 { mag: 5427655, sign: false }); + data.append(FP8x23 { mag: 18113416, sign: true }); + data.append(FP8x23 { mag: 2255928, sign: false }); + data.append(FP8x23 { mag: 1389584, sign: false }); + data.append(FP8x23 { mag: 4815558, sign: false }); + data.append(FP8x23 { mag: 2817187, sign: true }); + data.append(FP8x23 { mag: 440217, sign: false }); + data.append(FP8x23 { mag: 190737, sign: false }); + data.append(FP8x23 { mag: 8139391, sign: false }); + data.append(FP8x23 { mag: 11498979, sign: true }); + data.append(FP8x23 { mag: 4351939, sign: true }); + data.append(FP8x23 { mag: 5194729, sign: false }); + data.append(FP8x23 { mag: 1153310, sign: true }); + data.append(FP8x23 { mag: 2229956, sign: true }); + data.append(FP8x23 { mag: 4621022, sign: true }); + data.append(FP8x23 { mag: 1791765, sign: false }); + data.append(FP8x23 { mag: 6755727, sign: false }); + data.append(FP8x23 { mag: 3423500, sign: true }); + data.append(FP8x23 { mag: 4643245, sign: true }); + data.append(FP8x23 { mag: 343398, sign: true }); + data.append(FP8x23 { mag: 18973418, sign: true }); + data.append(FP8x23 { mag: 22705418, sign: false }); + data.append(FP8x23 { mag: 4003045, sign: false }); + data.append(FP8x23 { mag: 14684607, sign: true }); + data.append(FP8x23 { mag: 9470858, sign: false }); + data.append(FP8x23 { mag: 4521178, sign: true }); + data.append(FP8x23 { mag: 8692518, sign: true }); + data.append(FP8x23 { mag: 5888156, sign: false }); + data.append(FP8x23 { mag: 7325508, sign: true }); + data.append(FP8x23 { mag: 10220334, sign: true }); + data.append(FP8x23 { mag: 408916, sign: true }); + data.append(FP8x23 { mag: 2895708, sign: true }); + data.append(FP8x23 { mag: 8610215, sign: false }); + data.append(FP8x23 { mag: 11150253, sign: true }); + data.append(FP8x23 { mag: 5360206, sign: false }); + data.append(FP8x23 { mag: 674939, sign: false }); + data.append(FP8x23 { mag: 11169360, sign: true }); + data.append(FP8x23 { mag: 4025671, sign: false }); + data.append(FP8x23 { mag: 9509782, sign: false }); + data.append(FP8x23 { mag: 5718525, sign: true }); + data.append(FP8x23 { mag: 8584055, sign: false }); + data.append(FP8x23 { mag: 5100969, sign: true }); + data.append(FP8x23 { mag: 13275665, sign: true }); + data.append(FP8x23 { mag: 14069922, sign: true }); + data.append(FP8x23 { mag: 9485429, sign: true }); + data.append(FP8x23 { mag: 3613096, sign: false }); + data.append(FP8x23 { mag: 11158159, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_3/output_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_3/output_0.cairo new file mode 100644 index 000000000..125162520 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_3/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 30795284, sign: true }); + data.append(FP8x23 { mag: 24452244, sign: false }); + data.append(FP8x23 { mag: 17639236, sign: true }); + data.append(FP8x23 { mag: 9402301, sign: true }); + data.append(FP8x23 { mag: 20899032, sign: true }); + data.append(FP8x23 { mag: 9595832, sign: true }); + data.append(FP8x23 { mag: 6771320, sign: false }); + data.append(FP8x23 { mag: 8896149, sign: false }); + data.append(FP8x23 { mag: 20488648, sign: true }); + data.append(FP8x23 { mag: 16405, sign: false }); + data.append(FP8x23 { mag: 1124332, sign: true }); + data.append(FP8x23 { mag: 4234479, sign: false }); + data.append(FP8x23 { mag: 45132380, sign: true }); + data.append(FP8x23 { mag: 13698921, sign: false }); + data.append(FP8x23 { mag: 12585182, sign: true }); + data.append(FP8x23 { mag: 11810859, sign: false }); + data.append(FP8x23 { mag: 562048, sign: false }); + data.append(FP8x23 { mag: 646758, sign: false }); + data.append(FP8x23 { mag: 5416873, sign: false }); + data.append(FP8x23 { mag: 4003232, sign: false }); + data.append(FP8x23 { mag: 2271500, sign: true }); + data.append(FP8x23 { mag: 11301107, sign: false }); + data.append(FP8x23 { mag: 11551598, sign: true }); + data.append(FP8x23 { mag: 10750163, sign: true }); + data.append(FP8x23 { mag: 6965746, sign: false }); + data.append(FP8x23 { mag: 20345854, sign: true }); + data.append(FP8x23 { mag: 1702719, sign: true }); + data.append(FP8x23 { mag: 4582737, sign: true }); + data.append(FP8x23 { mag: 82919, sign: false }); + data.append(FP8x23 { mag: 6685446, sign: false }); + data.append(FP8x23 { mag: 36982080, sign: true }); + data.append(FP8x23 { mag: 4222303, sign: true }); + data.append(FP8x23 { mag: 738083, sign: false }); + data.append(FP8x23 { mag: 20419578, sign: true }); + data.append(FP8x23 { mag: 21150390, sign: false }); + data.append(FP8x23 { mag: 6484889, sign: true }); + data.append(FP8x23 { mag: 10663730, sign: true }); + data.append(FP8x23 { mag: 11339833, sign: false }); + data.append(FP8x23 { mag: 4661686, sign: false }); + data.append(FP8x23 { mag: 8139288, sign: true }); + data.append(FP8x23 { mag: 2797626, sign: false }); + data.append(FP8x23 { mag: 3984268, sign: true }); + data.append(FP8x23 { mag: 13723312, sign: true }); + data.append(FP8x23 { mag: 9324158, sign: false }); + data.append(FP8x23 { mag: 3278882, sign: false }); + data.append(FP8x23 { mag: 2948328, sign: true }); + data.append(FP8x23 { mag: 9469186, sign: true }); + data.append(FP8x23 { mag: 5225903, sign: false }); + data.append(FP8x23 { mag: 7613417, sign: true }); + data.append(FP8x23 { mag: 19009284, sign: true }); + data.append(FP8x23 { mag: 1089409, sign: false }); + data.append(FP8x23 { mag: 8903622, sign: false }); + data.append(FP8x23 { mag: 6974317, sign: true }); + data.append(FP8x23 { mag: 9963899, sign: false }); + data.append(FP8x23 { mag: 502413, sign: true }); + data.append(FP8x23 { mag: 12015909, sign: true }); + data.append(FP8x23 { mag: 11418954, sign: true }); + data.append(FP8x23 { mag: 8397312, sign: true }); + data.append(FP8x23 { mag: 3492521, sign: false }); + data.append(FP8x23 { mag: 11426886, sign: true }); + data.append(FP8x23 { mag: 20181712, sign: true }); + data.append(FP8x23 { mag: 16384300, sign: false }); + data.append(FP8x23 { mag: 16232709, sign: true }); + data.append(FP8x23 { mag: 6893410, sign: true }); + data.append(FP8x23 { mag: 7340870, sign: false }); + data.append(FP8x23 { mag: 6947460, sign: true }); + data.append(FP8x23 { mag: 12463704, sign: false }); + data.append(FP8x23 { mag: 13613810, sign: false }); + data.append(FP8x23 { mag: 33562828, sign: false }); + data.append(FP8x23 { mag: 75456, sign: true }); + data.append(FP8x23 { mag: 860922, sign: true }); + data.append(FP8x23 { mag: 9085466, sign: true }); + data.append(FP8x23 { mag: 1045640, sign: false }); + data.append(FP8x23 { mag: 9649592, sign: true }); + data.append(FP8x23 { mag: 23571772, sign: true }); + data.append(FP8x23 { mag: 12698496, sign: true }); + data.append(FP8x23 { mag: 49246, sign: false }); + data.append(FP8x23 { mag: 3492380, sign: false }); + data.append(FP8x23 { mag: 856033, sign: false }); + data.append(FP8x23 { mag: 3361495, sign: false }); + data.append(FP8x23 { mag: 8664075, sign: false }); + data.append(FP8x23 { mag: 6857170, sign: false }); + data.append(FP8x23 { mag: 25788504, sign: true }); + data.append(FP8x23 { mag: 7463156, sign: true }); + data.append(FP8x23 { mag: 3757968, sign: false }); + data.append(FP8x23 { mag: 11816282, sign: false }); + data.append(FP8x23 { mag: 2570839, sign: true }); + data.append(FP8x23 { mag: 4046212, sign: true }); + data.append(FP8x23 { mag: 2796458, sign: true }); + data.append(FP8x23 { mag: 4330926, sign: false }); + data.append(FP8x23 { mag: 17195276, sign: true }); + data.append(FP8x23 { mag: 10822918, sign: true }); + data.append(FP8x23 { mag: 348715, sign: false }); + data.append(FP8x23 { mag: 15089289, sign: true }); + data.append(FP8x23 { mag: 24348896, sign: false }); + data.append(FP8x23 { mag: 5659438, sign: false }); + data.append(FP8x23 { mag: 14449243, sign: true }); + data.append(FP8x23 { mag: 3150978, sign: false }); + data.append(FP8x23 { mag: 1988497, sign: true }); + data.append(FP8x23 { mag: 8541534, sign: true }); + data.append(FP8x23 { mag: 1048302, sign: false }); + data.append(FP8x23 { mag: 5198116, sign: true }); + data.append(FP8x23 { mag: 15995067, sign: true }); + data.append(FP8x23 { mag: 2841452, sign: true }); + data.append(FP8x23 { mag: 3255403, sign: true }); + data.append(FP8x23 { mag: 1682816, sign: true }); + data.append(FP8x23 { mag: 10523466, sign: true }); + data.append(FP8x23 { mag: 4797427, sign: false }); + data.append(FP8x23 { mag: 12720153, sign: false }); + data.append(FP8x23 { mag: 20364616, sign: true }); + data.append(FP8x23 { mag: 5689884, sign: false }); + data.append(FP8x23 { mag: 8559599, sign: false }); + data.append(FP8x23 { mag: 7186609, sign: true }); + data.append(FP8x23 { mag: 9984205, sign: false }); + data.append(FP8x23 { mag: 4139366, sign: true }); + data.append(FP8x23 { mag: 9793996, sign: true }); + data.append(FP8x23 { mag: 44222332, sign: true }); + data.append(FP8x23 { mag: 9786600, sign: true }); + data.append(FP8x23 { mag: 6043721, sign: false }); + data.append(FP8x23 { mag: 10586105, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_4.cairo b/tests/nodes/layer_normalization_4d_axis_negative_4.cairo new file mode 100644 index 000000000..7acaf24b3 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_4.cairo @@ -0,0 +1,32 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP8x23Tensor; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23TensorPartialEq; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_4d_axis_negative_4() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, + Option::Some(@input_2), + Option::Some(IntegerTrait::::new(4, true)), + Option::None, + Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_4/input_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_4/input_0.cairo new file mode 100644 index 000000000..d7913bb11 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_4/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 9853496, sign: true }); + data.append(FP8x23 { mag: 12258403, sign: false }); + data.append(FP8x23 { mag: 872656, sign: false }); + data.append(FP8x23 { mag: 7388622, sign: true }); + data.append(FP8x23 { mag: 15454325, sign: false }); + data.append(FP8x23 { mag: 16251435, sign: false }); + data.append(FP8x23 { mag: 903277, sign: true }); + data.append(FP8x23 { mag: 3289794, sign: true }); + data.append(FP8x23 { mag: 8057933, sign: false }); + data.append(FP8x23 { mag: 6386388, sign: true }); + data.append(FP8x23 { mag: 5708410, sign: true }); + data.append(FP8x23 { mag: 4543373, sign: true }); + data.append(FP8x23 { mag: 23644376, sign: false }); + data.append(FP8x23 { mag: 7843321, sign: true }); + data.append(FP8x23 { mag: 5801261, sign: false }); + data.append(FP8x23 { mag: 5404517, sign: true }); + data.append(FP8x23 { mag: 3421350, sign: false }); + data.append(FP8x23 { mag: 2806284, sign: true }); + data.append(FP8x23 { mag: 5479745, sign: false }); + data.append(FP8x23 { mag: 9044852, sign: true }); + data.append(FP8x23 { mag: 2879371, sign: true }); + data.append(FP8x23 { mag: 7506722, sign: false }); + data.append(FP8x23 { mag: 374708, sign: false }); + data.append(FP8x23 { mag: 16088456, sign: false }); + data.append(FP8x23 { mag: 7446071, sign: false }); + data.append(FP8x23 { mag: 12333473, sign: true }); + data.append(FP8x23 { mag: 2694855, sign: false }); + data.append(FP8x23 { mag: 29333, sign: false }); + data.append(FP8x23 { mag: 3210230, sign: true }); + data.append(FP8x23 { mag: 246487, sign: false }); + data.append(FP8x23 { mag: 11307238, sign: true }); + data.append(FP8x23 { mag: 808074, sign: false }); + data.append(FP8x23 { mag: 2244426, sign: false }); + data.append(FP8x23 { mag: 4494036, sign: false }); + data.append(FP8x23 { mag: 9168918, sign: false }); + data.append(FP8x23 { mag: 11831318, sign: true }); + data.append(FP8x23 { mag: 11514568, sign: true }); + data.append(FP8x23 { mag: 3381120, sign: true }); + data.append(FP8x23 { mag: 6827926, sign: true }); + data.append(FP8x23 { mag: 2579494, sign: false }); + data.append(FP8x23 { mag: 4493030, sign: true }); + data.append(FP8x23 { mag: 4570125, sign: false }); + data.append(FP8x23 { mag: 8010665, sign: false }); + data.append(FP8x23 { mag: 5794037, sign: true }); + data.append(FP8x23 { mag: 9849078, sign: false }); + data.append(FP8x23 { mag: 11691798, sign: true }); + data.append(FP8x23 { mag: 3652747, sign: false }); + data.append(FP8x23 { mag: 1032666, sign: true }); + data.append(FP8x23 { mag: 9259310, sign: true }); + data.append(FP8x23 { mag: 7405492, sign: false }); + data.append(FP8x23 { mag: 4687488, sign: true }); + data.append(FP8x23 { mag: 1097650, sign: true }); + data.append(FP8x23 { mag: 2213858, sign: true }); + data.append(FP8x23 { mag: 1436205, sign: true }); + data.append(FP8x23 { mag: 10226423, sign: false }); + data.append(FP8x23 { mag: 6130226, sign: false }); + data.append(FP8x23 { mag: 1214058, sign: false }); + data.append(FP8x23 { mag: 12299984, sign: true }); + data.append(FP8x23 { mag: 829240, sign: false }); + data.append(FP8x23 { mag: 1612388, sign: false }); + data.append(FP8x23 { mag: 6632529, sign: true }); + data.append(FP8x23 { mag: 4410829, sign: true }); + data.append(FP8x23 { mag: 20654302, sign: false }); + data.append(FP8x23 { mag: 462475, sign: true }); + data.append(FP8x23 { mag: 10502841, sign: true }); + data.append(FP8x23 { mag: 7171902, sign: true }); + data.append(FP8x23 { mag: 4204962, sign: false }); + data.append(FP8x23 { mag: 17427142, sign: true }); + data.append(FP8x23 { mag: 12555224, sign: true }); + data.append(FP8x23 { mag: 8307885, sign: true }); + data.append(FP8x23 { mag: 455536, sign: false }); + data.append(FP8x23 { mag: 577191, sign: false }); + data.append(FP8x23 { mag: 4412268, sign: true }); + data.append(FP8x23 { mag: 15812229, sign: false }); + data.append(FP8x23 { mag: 7098764, sign: true }); + data.append(FP8x23 { mag: 9127468, sign: false }); + data.append(FP8x23 { mag: 4753858, sign: false }); + data.append(FP8x23 { mag: 2074029, sign: true }); + data.append(FP8x23 { mag: 1651256, sign: false }); + data.append(FP8x23 { mag: 9617324, sign: true }); + data.append(FP8x23 { mag: 11400835, sign: false }); + data.append(FP8x23 { mag: 4263073, sign: true }); + data.append(FP8x23 { mag: 22170402, sign: false }); + data.append(FP8x23 { mag: 7715608, sign: false }); + data.append(FP8x23 { mag: 7511781, sign: true }); + data.append(FP8x23 { mag: 8686402, sign: false }); + data.append(FP8x23 { mag: 2710329, sign: true }); + data.append(FP8x23 { mag: 5540998, sign: true }); + data.append(FP8x23 { mag: 11608300, sign: true }); + data.append(FP8x23 { mag: 3020404, sign: true }); + data.append(FP8x23 { mag: 6342478, sign: true }); + data.append(FP8x23 { mag: 9399735, sign: true }); + data.append(FP8x23 { mag: 446463, sign: false }); + data.append(FP8x23 { mag: 13691013, sign: true }); + data.append(FP8x23 { mag: 11552903, sign: true }); + data.append(FP8x23 { mag: 1204731, sign: false }); + data.append(FP8x23 { mag: 1741592, sign: true }); + data.append(FP8x23 { mag: 13103082, sign: false }); + data.append(FP8x23 { mag: 3181444, sign: true }); + data.append(FP8x23 { mag: 256975, sign: true }); + data.append(FP8x23 { mag: 9440785, sign: true }); + data.append(FP8x23 { mag: 2112590, sign: false }); + data.append(FP8x23 { mag: 13404752, sign: false }); + data.append(FP8x23 { mag: 760699, sign: false }); + data.append(FP8x23 { mag: 1588793, sign: true }); + data.append(FP8x23 { mag: 13026604, sign: true }); + data.append(FP8x23 { mag: 452707, sign: false }); + data.append(FP8x23 { mag: 7267348, sign: true }); + data.append(FP8x23 { mag: 14737007, sign: false }); + data.append(FP8x23 { mag: 8457998, sign: false }); + data.append(FP8x23 { mag: 2233703, sign: false }); + data.append(FP8x23 { mag: 3434673, sign: true }); + data.append(FP8x23 { mag: 4280157, sign: true }); + data.append(FP8x23 { mag: 2950181, sign: true }); + data.append(FP8x23 { mag: 1385553, sign: false }); + data.append(FP8x23 { mag: 17250056, sign: false }); + data.append(FP8x23 { mag: 12716927, sign: true }); + data.append(FP8x23 { mag: 2980452, sign: false }); + data.append(FP8x23 { mag: 13031106, sign: true }); + data.append(FP8x23 { mag: 4118717, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_4/input_1.cairo b/tests/nodes/layer_normalization_4d_axis_negative_4/input_1.cairo new file mode 100644 index 000000000..cefc1df2c --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_4/input_1.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 765036, sign: false }); + data.append(FP8x23 { mag: 1487908, sign: true }); + data.append(FP8x23 { mag: 8979452, sign: true }); + data.append(FP8x23 { mag: 2457758, sign: true }); + data.append(FP8x23 { mag: 6713378, sign: true }); + data.append(FP8x23 { mag: 9826923, sign: true }); + data.append(FP8x23 { mag: 2590671, sign: false }); + data.append(FP8x23 { mag: 3000698, sign: true }); + data.append(FP8x23 { mag: 1027684, sign: true }); + data.append(FP8x23 { mag: 6534988, sign: true }); + data.append(FP8x23 { mag: 8946545, sign: false }); + data.append(FP8x23 { mag: 4391800, sign: true }); + data.append(FP8x23 { mag: 7641860, sign: true }); + data.append(FP8x23 { mag: 11675021, sign: false }); + data.append(FP8x23 { mag: 7466702, sign: false }); + data.append(FP8x23 { mag: 395181, sign: true }); + data.append(FP8x23 { mag: 365436, sign: true }); + data.append(FP8x23 { mag: 15999312, sign: true }); + data.append(FP8x23 { mag: 11565880, sign: false }); + data.append(FP8x23 { mag: 9282480, sign: false }); + data.append(FP8x23 { mag: 3785428, sign: true }); + data.append(FP8x23 { mag: 18869726, sign: true }); + data.append(FP8x23 { mag: 6981482, sign: true }); + data.append(FP8x23 { mag: 7335699, sign: false }); + data.append(FP8x23 { mag: 3393176, sign: true }); + data.append(FP8x23 { mag: 17918872, sign: true }); + data.append(FP8x23 { mag: 4778393, sign: true }); + data.append(FP8x23 { mag: 15076621, sign: true }); + data.append(FP8x23 { mag: 2537182, sign: false }); + data.append(FP8x23 { mag: 13904034, sign: false }); + data.append(FP8x23 { mag: 4360076, sign: true }); + data.append(FP8x23 { mag: 15481889, sign: false }); + data.append(FP8x23 { mag: 2741193, sign: false }); + data.append(FP8x23 { mag: 465479, sign: true }); + data.append(FP8x23 { mag: 2151935, sign: true }); + data.append(FP8x23 { mag: 3630430, sign: true }); + data.append(FP8x23 { mag: 8001601, sign: false }); + data.append(FP8x23 { mag: 10861003, sign: true }); + data.append(FP8x23 { mag: 15152404, sign: false }); + data.append(FP8x23 { mag: 12626251, sign: true }); + data.append(FP8x23 { mag: 4662175, sign: true }); + data.append(FP8x23 { mag: 1467433, sign: false }); + data.append(FP8x23 { mag: 18234838, sign: false }); + data.append(FP8x23 { mag: 15142594, sign: false }); + data.append(FP8x23 { mag: 797950, sign: false }); + data.append(FP8x23 { mag: 6158520, sign: true }); + data.append(FP8x23 { mag: 1464793, sign: false }); + data.append(FP8x23 { mag: 9229589, sign: false }); + data.append(FP8x23 { mag: 2331004, sign: false }); + data.append(FP8x23 { mag: 4402549, sign: true }); + data.append(FP8x23 { mag: 5288030, sign: false }); + data.append(FP8x23 { mag: 8947151, sign: true }); + data.append(FP8x23 { mag: 13011209, sign: false }); + data.append(FP8x23 { mag: 14550682, sign: true }); + data.append(FP8x23 { mag: 19891878, sign: false }); + data.append(FP8x23 { mag: 8668230, sign: false }); + data.append(FP8x23 { mag: 6457287, sign: false }); + data.append(FP8x23 { mag: 2427503, sign: false }); + data.append(FP8x23 { mag: 12205684, sign: false }); + data.append(FP8x23 { mag: 2003377, sign: false }); + data.append(FP8x23 { mag: 11948757, sign: false }); + data.append(FP8x23 { mag: 1542336, sign: true }); + data.append(FP8x23 { mag: 762426, sign: false }); + data.append(FP8x23 { mag: 7311585, sign: true }); + data.append(FP8x23 { mag: 4824334, sign: true }); + data.append(FP8x23 { mag: 5502968, sign: true }); + data.append(FP8x23 { mag: 12092839, sign: true }); + data.append(FP8x23 { mag: 3938484, sign: false }); + data.append(FP8x23 { mag: 13236991, sign: false }); + data.append(FP8x23 { mag: 13323203, sign: false }); + data.append(FP8x23 { mag: 3383393, sign: true }); + data.append(FP8x23 { mag: 2710544, sign: true }); + data.append(FP8x23 { mag: 1101084, sign: true }); + data.append(FP8x23 { mag: 8167031, sign: true }); + data.append(FP8x23 { mag: 10908585, sign: false }); + data.append(FP8x23 { mag: 5727771, sign: true }); + data.append(FP8x23 { mag: 7391430, sign: true }); + data.append(FP8x23 { mag: 9532868, sign: true }); + data.append(FP8x23 { mag: 570993, sign: false }); + data.append(FP8x23 { mag: 1661402, sign: true }); + data.append(FP8x23 { mag: 8927581, sign: false }); + data.append(FP8x23 { mag: 69087, sign: true }); + data.append(FP8x23 { mag: 13202806, sign: true }); + data.append(FP8x23 { mag: 4596499, sign: false }); + data.append(FP8x23 { mag: 6458124, sign: false }); + data.append(FP8x23 { mag: 16955478, sign: true }); + data.append(FP8x23 { mag: 12176098, sign: false }); + data.append(FP8x23 { mag: 9405263, sign: false }); + data.append(FP8x23 { mag: 7171010, sign: true }); + data.append(FP8x23 { mag: 2977961, sign: false }); + data.append(FP8x23 { mag: 5183702, sign: false }); + data.append(FP8x23 { mag: 15341216, sign: false }); + data.append(FP8x23 { mag: 18373916, sign: false }); + data.append(FP8x23 { mag: 3529331, sign: false }); + data.append(FP8x23 { mag: 21097650, sign: false }); + data.append(FP8x23 { mag: 3775469, sign: false }); + data.append(FP8x23 { mag: 6304868, sign: false }); + data.append(FP8x23 { mag: 3531356, sign: false }); + data.append(FP8x23 { mag: 2877400, sign: true }); + data.append(FP8x23 { mag: 12725005, sign: false }); + data.append(FP8x23 { mag: 2672641, sign: true }); + data.append(FP8x23 { mag: 2460268, sign: false }); + data.append(FP8x23 { mag: 9700178, sign: false }); + data.append(FP8x23 { mag: 1683799, sign: true }); + data.append(FP8x23 { mag: 2685739, sign: true }); + data.append(FP8x23 { mag: 2567529, sign: false }); + data.append(FP8x23 { mag: 12283587, sign: false }); + data.append(FP8x23 { mag: 903086, sign: true }); + data.append(FP8x23 { mag: 20664720, sign: true }); + data.append(FP8x23 { mag: 4231587, sign: false }); + data.append(FP8x23 { mag: 3279288, sign: false }); + data.append(FP8x23 { mag: 2866349, sign: true }); + data.append(FP8x23 { mag: 11227992, sign: true }); + data.append(FP8x23 { mag: 13129429, sign: true }); + data.append(FP8x23 { mag: 3240746, sign: false }); + data.append(FP8x23 { mag: 9516856, sign: true }); + data.append(FP8x23 { mag: 6088257, sign: true }); + data.append(FP8x23 { mag: 502779, sign: false }); + data.append(FP8x23 { mag: 4585135, sign: true }); + data.append(FP8x23 { mag: 9007322, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_4/input_2.cairo b/tests/nodes/layer_normalization_4d_axis_negative_4/input_2.cairo new file mode 100644 index 000000000..993d1b2b5 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_4/input_2.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 1565206, sign: true }); + data.append(FP8x23 { mag: 8440746, sign: false }); + data.append(FP8x23 { mag: 73220, sign: true }); + data.append(FP8x23 { mag: 9789189, sign: false }); + data.append(FP8x23 { mag: 3941934, sign: false }); + data.append(FP8x23 { mag: 870617, sign: false }); + data.append(FP8x23 { mag: 9469178, sign: true }); + data.append(FP8x23 { mag: 2579358, sign: true }); + data.append(FP8x23 { mag: 11437672, sign: true }); + data.append(FP8x23 { mag: 4638952, sign: false }); + data.append(FP8x23 { mag: 6273082, sign: true }); + data.append(FP8x23 { mag: 6222525, sign: true }); + data.append(FP8x23 { mag: 10101133, sign: true }); + data.append(FP8x23 { mag: 5359626, sign: true }); + data.append(FP8x23 { mag: 3085420, sign: false }); + data.append(FP8x23 { mag: 9194004, sign: false }); + data.append(FP8x23 { mag: 6966236, sign: false }); + data.append(FP8x23 { mag: 6880860, sign: true }); + data.append(FP8x23 { mag: 5546158, sign: false }); + data.append(FP8x23 { mag: 6934687, sign: true }); + data.append(FP8x23 { mag: 987627, sign: true }); + data.append(FP8x23 { mag: 4104909, sign: false }); + data.append(FP8x23 { mag: 10332664, sign: true }); + data.append(FP8x23 { mag: 6589758, sign: true }); + data.append(FP8x23 { mag: 13722424, sign: true }); + data.append(FP8x23 { mag: 9462928, sign: true }); + data.append(FP8x23 { mag: 13947351, sign: true }); + data.append(FP8x23 { mag: 5509455, sign: true }); + data.append(FP8x23 { mag: 644274, sign: true }); + data.append(FP8x23 { mag: 2030384, sign: true }); + data.append(FP8x23 { mag: 1089765, sign: true }); + data.append(FP8x23 { mag: 175051, sign: false }); + data.append(FP8x23 { mag: 10990995, sign: false }); + data.append(FP8x23 { mag: 950685, sign: true }); + data.append(FP8x23 { mag: 6512231, sign: false }); + data.append(FP8x23 { mag: 980559, sign: true }); + data.append(FP8x23 { mag: 2639622, sign: true }); + data.append(FP8x23 { mag: 6566764, sign: true }); + data.append(FP8x23 { mag: 5196593, sign: true }); + data.append(FP8x23 { mag: 12299425, sign: true }); + data.append(FP8x23 { mag: 98068, sign: true }); + data.append(FP8x23 { mag: 5477706, sign: false }); + data.append(FP8x23 { mag: 540028, sign: false }); + data.append(FP8x23 { mag: 6323924, sign: false }); + data.append(FP8x23 { mag: 4402990, sign: false }); + data.append(FP8x23 { mag: 4425289, sign: true }); + data.append(FP8x23 { mag: 1864532, sign: true }); + data.append(FP8x23 { mag: 3821369, sign: true }); + data.append(FP8x23 { mag: 4152119, sign: true }); + data.append(FP8x23 { mag: 13493616, sign: false }); + data.append(FP8x23 { mag: 16485032, sign: false }); + data.append(FP8x23 { mag: 47558, sign: true }); + data.append(FP8x23 { mag: 4175409, sign: false }); + data.append(FP8x23 { mag: 1769218, sign: true }); + data.append(FP8x23 { mag: 1835677, sign: true }); + data.append(FP8x23 { mag: 6190130, sign: false }); + data.append(FP8x23 { mag: 9290798, sign: true }); + data.append(FP8x23 { mag: 14518570, sign: false }); + data.append(FP8x23 { mag: 5531764, sign: false }); + data.append(FP8x23 { mag: 3781891, sign: true }); + data.append(FP8x23 { mag: 8213439, sign: true }); + data.append(FP8x23 { mag: 5323399, sign: false }); + data.append(FP8x23 { mag: 15125746, sign: true }); + data.append(FP8x23 { mag: 3776717, sign: false }); + data.append(FP8x23 { mag: 5903805, sign: false }); + data.append(FP8x23 { mag: 735915, sign: true }); + data.append(FP8x23 { mag: 2326781, sign: false }); + data.append(FP8x23 { mag: 10726848, sign: true }); + data.append(FP8x23 { mag: 8423845, sign: false }); + data.append(FP8x23 { mag: 5989648, sign: true }); + data.append(FP8x23 { mag: 9314061, sign: false }); + data.append(FP8x23 { mag: 11670752, sign: false }); + data.append(FP8x23 { mag: 364021, sign: false }); + data.append(FP8x23 { mag: 1428671, sign: true }); + data.append(FP8x23 { mag: 2859786, sign: true }); + data.append(FP8x23 { mag: 367369, sign: false }); + data.append(FP8x23 { mag: 2107786, sign: true }); + data.append(FP8x23 { mag: 5419940, sign: true }); + data.append(FP8x23 { mag: 8748640, sign: false }); + data.append(FP8x23 { mag: 6783550, sign: true }); + data.append(FP8x23 { mag: 1174084, sign: false }); + data.append(FP8x23 { mag: 1244741, sign: false }); + data.append(FP8x23 { mag: 628290, sign: false }); + data.append(FP8x23 { mag: 7245872, sign: false }); + data.append(FP8x23 { mag: 10358153, sign: true }); + data.append(FP8x23 { mag: 11037845, sign: true }); + data.append(FP8x23 { mag: 7181495, sign: false }); + data.append(FP8x23 { mag: 9622438, sign: true }); + data.append(FP8x23 { mag: 433991, sign: true }); + data.append(FP8x23 { mag: 1950004, sign: true }); + data.append(FP8x23 { mag: 172993, sign: false }); + data.append(FP8x23 { mag: 5089314, sign: true }); + data.append(FP8x23 { mag: 2494218, sign: true }); + data.append(FP8x23 { mag: 6786513, sign: true }); + data.append(FP8x23 { mag: 757032, sign: false }); + data.append(FP8x23 { mag: 265511, sign: false }); + data.append(FP8x23 { mag: 3655001, sign: true }); + data.append(FP8x23 { mag: 1933308, sign: false }); + data.append(FP8x23 { mag: 13213099, sign: true }); + data.append(FP8x23 { mag: 4142122, sign: false }); + data.append(FP8x23 { mag: 1604209, sign: false }); + data.append(FP8x23 { mag: 11013449, sign: false }); + data.append(FP8x23 { mag: 9016796, sign: true }); + data.append(FP8x23 { mag: 12822414, sign: true }); + data.append(FP8x23 { mag: 5674089, sign: true }); + data.append(FP8x23 { mag: 7284811, sign: false }); + data.append(FP8x23 { mag: 5380566, sign: false }); + data.append(FP8x23 { mag: 1200376, sign: true }); + data.append(FP8x23 { mag: 4449674, sign: false }); + data.append(FP8x23 { mag: 2509265, sign: true }); + data.append(FP8x23 { mag: 11423904, sign: false }); + data.append(FP8x23 { mag: 182801, sign: true }); + data.append(FP8x23 { mag: 61327, sign: true }); + data.append(FP8x23 { mag: 20514282, sign: false }); + data.append(FP8x23 { mag: 2488749, sign: true }); + data.append(FP8x23 { mag: 4664432, sign: false }); + data.append(FP8x23 { mag: 12670565, sign: false }); + data.append(FP8x23 { mag: 4800221, sign: true }); + data.append(FP8x23 { mag: 5916248, sign: false }); + data.append(FP8x23 { mag: 2870212, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_4d_axis_negative_4/output_0.cairo b/tests/nodes/layer_normalization_4d_axis_negative_4/output_0.cairo new file mode 100644 index 000000000..915c0e6b9 --- /dev/null +++ b/tests/nodes/layer_normalization_4d_axis_negative_4/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP8x23Tensor; +use orion::numbers::{FixedTrait, FP8x23}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP8x23 { mag: 2424743, sign: true }); + data.append(FP8x23 { mag: 6230880, sign: false }); + data.append(FP8x23 { mag: 1347738, sign: true }); + data.append(FP8x23 { mag: 11835814, sign: false }); + data.append(FP8x23 { mag: 8560172, sign: true }); + data.append(FP8x23 { mag: 18353888, sign: true }); + data.append(FP8x23 { mag: 9644270, sign: true }); + data.append(FP8x23 { mag: 1531681, sign: true }); + data.append(FP8x23 { mag: 12454719, sign: true }); + data.append(FP8x23 { mag: 9308055, sign: false }); + data.append(FP8x23 { mag: 11949580, sign: true }); + data.append(FP8x23 { mag: 4039623, sign: true }); + data.append(FP8x23 { mag: 31716300, sign: true }); + data.append(FP8x23 { mag: 15707956, sign: true }); + data.append(FP8x23 { mag: 8486900, sign: false }); + data.append(FP8x23 { mag: 9430574, sign: false }); + data.append(FP8x23 { mag: 6804483, sign: false }); + data.append(FP8x23 { mag: 2207451, sign: true }); + data.append(FP8x23 { mag: 13474301, sign: false }); + data.append(FP8x23 { mag: 16478194, sign: true }); + data.append(FP8x23 { mag: 150738, sign: false }); + data.append(FP8x23 { mag: 13342368, sign: true }); + data.append(FP8x23 { mag: 10913451, sign: true }); + data.append(FP8x23 { mag: 7620089, sign: false }); + data.append(FP8x23 { mag: 16835534, sign: true }); + data.append(FP8x23 { mag: 15912140, sign: false }); + data.append(FP8x23 { mag: 15652847, sign: true }); + data.append(FP8x23 { mag: 6149347, sign: true }); + data.append(FP8x23 { mag: 1506302, sign: true }); + data.append(FP8x23 { mag: 1084043, sign: true }); + data.append(FP8x23 { mag: 4556684, sign: false }); + data.append(FP8x23 { mag: 2254544, sign: false }); + data.append(FP8x23 { mag: 11823707, sign: false }); + data.append(FP8x23 { mag: 1215629, sign: true }); + data.append(FP8x23 { mag: 4100511, sign: false }); + data.append(FP8x23 { mag: 3945443, sign: false }); + data.append(FP8x23 { mag: 13197695, sign: true }); + data.append(FP8x23 { mag: 2657681, sign: true }); + data.append(FP8x23 { mag: 16811968, sign: true }); + data.append(FP8x23 { mag: 16634120, sign: true }); + data.append(FP8x23 { mag: 2191530, sign: false }); + data.append(FP8x23 { mag: 6326118, sign: false }); + data.append(FP8x23 { mag: 18484424, sign: false }); + data.append(FP8x23 { mag: 3436880, sign: true }); + data.append(FP8x23 { mag: 5361301, sign: false }); + data.append(FP8x23 { mag: 3829618, sign: false }); + data.append(FP8x23 { mag: 1176182, sign: true }); + data.append(FP8x23 { mag: 4586050, sign: true }); + data.append(FP8x23 { mag: 6607650, sign: true }); + data.append(FP8x23 { mag: 9475522, sign: false }); + data.append(FP8x23 { mag: 13766757, sign: false }); + data.append(FP8x23 { mag: 762316, sign: false }); + data.append(FP8x23 { mag: 1284229, sign: false }); + data.append(FP8x23 { mag: 129063, sign: false }); + data.append(FP8x23 { mag: 22939360, sign: false }); + data.append(FP8x23 { mag: 12797228, sign: false }); + data.append(FP8x23 { mag: 8114181, sign: true }); + data.append(FP8x23 { mag: 11090552, sign: false }); + data.append(FP8x23 { mag: 7201685, sign: false }); + data.append(FP8x23 { mag: 3322697, sign: true }); + data.append(FP8x23 { mag: 17097544, sign: true }); + data.append(FP8x23 { mag: 6065884, sign: false }); + data.append(FP8x23 { mag: 13238165, sign: true }); + data.append(FP8x23 { mag: 3890634, sign: false }); + data.append(FP8x23 { mag: 11693645, sign: false }); + data.append(FP8x23 { mag: 3705815, sign: false }); + data.append(FP8x23 { mag: 4143846, sign: true }); + data.append(FP8x23 { mag: 18670984, sign: true }); + data.append(FP8x23 { mag: 10667477, sign: true }); + data.append(FP8x23 { mag: 18529098, sign: true }); + data.append(FP8x23 { mag: 9000333, sign: false }); + data.append(FP8x23 { mag: 11380511, sign: false }); + data.append(FP8x23 { mag: 894273, sign: false }); + data.append(FP8x23 { mag: 16982722, sign: true }); + data.append(FP8x23 { mag: 11570541, sign: true }); + data.append(FP8x23 { mag: 6023858, sign: true }); + data.append(FP8x23 { mag: 6541442, sign: true }); + data.append(FP8x23 { mag: 3458935, sign: true }); + data.append(FP8x23 { mag: 8882136, sign: false }); + data.append(FP8x23 { mag: 4963218, sign: true }); + data.append(FP8x23 { mag: 13530222, sign: false }); + data.append(FP8x23 { mag: 1276796, sign: false }); + data.append(FP8x23 { mag: 34420188, sign: true }); + data.append(FP8x23 { mag: 11609153, sign: false }); + data.append(FP8x23 { mag: 15829801, sign: true }); + data.append(FP8x23 { mag: 29074996, sign: true }); + data.append(FP8x23 { mag: 3762691, sign: false }); + data.append(FP8x23 { mag: 15404224, sign: true }); + data.append(FP8x23 { mag: 9107420, sign: false }); + data.append(FP8x23 { mag: 2895097, sign: true }); + data.append(FP8x23 { mag: 3503790, sign: true }); + data.append(FP8x23 { mag: 21504250, sign: true }); + data.append(FP8x23 { mag: 810151, sign: true }); + data.append(FP8x23 { mag: 12349691, sign: true }); + data.append(FP8x23 { mag: 27176632, sign: true }); + data.append(FP8x23 { mag: 949305, sign: false }); + data.append(FP8x23 { mag: 4704694, sign: true }); + data.append(FP8x23 { mag: 7530050, sign: false }); + data.append(FP8x23 { mag: 12245252, sign: true }); + data.append(FP8x23 { mag: 4252376, sign: false }); + data.append(FP8x23 { mag: 4476849, sign: false }); + data.append(FP8x23 { mag: 11722556, sign: false }); + data.append(FP8x23 { mag: 6701965, sign: false }); + data.append(FP8x23 { mag: 13039167, sign: true }); + data.append(FP8x23 { mag: 5275358, sign: true }); + data.append(FP8x23 { mag: 3438951, sign: false }); + data.append(FP8x23 { mag: 6515470, sign: false }); + data.append(FP8x23 { mag: 461279, sign: true }); + data.append(FP8x23 { mag: 32284742, sign: true }); + data.append(FP8x23 { mag: 1878248, sign: false }); + data.append(FP8x23 { mag: 12415928, sign: false }); + data.append(FP8x23 { mag: 866962, sign: false }); + data.append(FP8x23 { mag: 5170767, sign: false }); + data.append(FP8x23 { mag: 24572292, sign: false }); + data.append(FP8x23 { mag: 1832665, sign: true }); + data.append(FP8x23 { mag: 15074726, sign: true }); + data.append(FP8x23 { mag: 21567628, sign: false }); + data.append(FP8x23 { mag: 4603829, sign: true }); + data.append(FP8x23 { mag: 12786682, sign: false }); + data.append(FP8x23 { mag: 1155521, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_default_axis.cairo b/tests/nodes/layer_normalization_default_axis.cairo new file mode 100644 index 000000000..1d1e1db90 --- /dev/null +++ b/tests/nodes/layer_normalization_default_axis.cairo @@ -0,0 +1,28 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::FP16x16TensorPartialEq; +use orion::utils::{assert_eq, assert_seq_eq}; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{IntegerTrait, i32, FixedTrait}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_default_axis() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, Option::Some(@input_2), Option::None, Option::None, Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_default_axis/input_0.cairo b/tests/nodes/layer_normalization_default_axis/input_0.cairo new file mode 100644 index 000000000..d8da4e64e --- /dev/null +++ b/tests/nodes/layer_normalization_default_axis/input_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 162162, sign: true }); + data.append(FP16x16 { mag: 29273, sign: true }); + data.append(FP16x16 { mag: 155457, sign: true }); + data.append(FP16x16 { mag: 50795, sign: false }); + data.append(FP16x16 { mag: 17629, sign: true }); + data.append(FP16x16 { mag: 26642, sign: false }); + data.append(FP16x16 { mag: 29753, sign: true }); + data.append(FP16x16 { mag: 76017, sign: true }); + data.append(FP16x16 { mag: 58145, sign: true }); + data.append(FP16x16 { mag: 81397, sign: false }); + data.append(FP16x16 { mag: 56926, sign: true }); + data.append(FP16x16 { mag: 92380, sign: true }); + data.append(FP16x16 { mag: 5200, sign: true }); + data.append(FP16x16 { mag: 24905, sign: false }); + data.append(FP16x16 { mag: 14892, sign: true }); + data.append(FP16x16 { mag: 37085, sign: false }); + data.append(FP16x16 { mag: 120688, sign: false }); + data.append(FP16x16 { mag: 24518, sign: true }); + data.append(FP16x16 { mag: 56437, sign: true }); + data.append(FP16x16 { mag: 68199, sign: true }); + data.append(FP16x16 { mag: 49008, sign: true }); + data.append(FP16x16 { mag: 6080, sign: false }); + data.append(FP16x16 { mag: 72783, sign: true }); + data.append(FP16x16 { mag: 100706, sign: true }); + data.append(FP16x16 { mag: 86624, sign: false }); + data.append(FP16x16 { mag: 72942, sign: false }); + data.append(FP16x16 { mag: 416, sign: false }); + data.append(FP16x16 { mag: 60084, sign: false }); + data.append(FP16x16 { mag: 13763, sign: false }); + data.append(FP16x16 { mag: 29895, sign: false }); + data.append(FP16x16 { mag: 70876, sign: false }); + data.append(FP16x16 { mag: 29792, sign: false }); + data.append(FP16x16 { mag: 50085, sign: false }); + data.append(FP16x16 { mag: 14512, sign: false }); + data.append(FP16x16 { mag: 69956, sign: true }); + data.append(FP16x16 { mag: 23179, sign: true }); + data.append(FP16x16 { mag: 34593, sign: false }); + data.append(FP16x16 { mag: 94827, sign: true }); + data.append(FP16x16 { mag: 36976, sign: true }); + data.append(FP16x16 { mag: 66090, sign: true }); + data.append(FP16x16 { mag: 46310, sign: true }); + data.append(FP16x16 { mag: 1270, sign: false }); + data.append(FP16x16 { mag: 10179, sign: false }); + data.append(FP16x16 { mag: 2421, sign: false }); + data.append(FP16x16 { mag: 41063, sign: false }); + data.append(FP16x16 { mag: 10014, sign: true }); + data.append(FP16x16 { mag: 23971, sign: true }); + data.append(FP16x16 { mag: 38734, sign: false }); + data.append(FP16x16 { mag: 252987, sign: false }); + data.append(FP16x16 { mag: 15276, sign: false }); + data.append(FP16x16 { mag: 94190, sign: true }); + data.append(FP16x16 { mag: 152835, sign: true }); + data.append(FP16x16 { mag: 120387, sign: true }); + data.append(FP16x16 { mag: 11624, sign: true }); + data.append(FP16x16 { mag: 83095, sign: false }); + data.append(FP16x16 { mag: 120401, sign: true }); + data.append(FP16x16 { mag: 32613, sign: false }); + data.append(FP16x16 { mag: 125957, sign: true }); + data.append(FP16x16 { mag: 44002, sign: true }); + data.append(FP16x16 { mag: 98578, sign: false }); + data.append(FP16x16 { mag: 40668, sign: true }); + data.append(FP16x16 { mag: 15027, sign: true }); + data.append(FP16x16 { mag: 80889, sign: true }); + data.append(FP16x16 { mag: 46530, sign: true }); + data.append(FP16x16 { mag: 14759, sign: false }); + data.append(FP16x16 { mag: 67967, sign: false }); + data.append(FP16x16 { mag: 83003, sign: false }); + data.append(FP16x16 { mag: 59413, sign: true }); + data.append(FP16x16 { mag: 24514, sign: false }); + data.append(FP16x16 { mag: 33820, sign: true }); + data.append(FP16x16 { mag: 71782, sign: false }); + data.append(FP16x16 { mag: 102396, sign: false }); + data.append(FP16x16 { mag: 118, sign: true }); + data.append(FP16x16 { mag: 1158, sign: true }); + data.append(FP16x16 { mag: 60291, sign: true }); + data.append(FP16x16 { mag: 41730, sign: true }); + data.append(FP16x16 { mag: 17552, sign: true }); + data.append(FP16x16 { mag: 27965, sign: true }); + data.append(FP16x16 { mag: 12189, sign: true }); + data.append(FP16x16 { mag: 1007, sign: true }); + data.append(FP16x16 { mag: 75390, sign: false }); + data.append(FP16x16 { mag: 29536, sign: true }); + data.append(FP16x16 { mag: 39461, sign: true }); + data.append(FP16x16 { mag: 36937, sign: false }); + data.append(FP16x16 { mag: 74171, sign: false }); + data.append(FP16x16 { mag: 18374, sign: false }); + data.append(FP16x16 { mag: 25540, sign: false }); + data.append(FP16x16 { mag: 72170, sign: true }); + data.append(FP16x16 { mag: 13960, sign: true }); + data.append(FP16x16 { mag: 25485, sign: false }); + data.append(FP16x16 { mag: 36529, sign: true }); + data.append(FP16x16 { mag: 157843, sign: false }); + data.append(FP16x16 { mag: 18027, sign: false }); + data.append(FP16x16 { mag: 49602, sign: false }); + data.append(FP16x16 { mag: 78001, sign: false }); + data.append(FP16x16 { mag: 12776, sign: false }); + data.append(FP16x16 { mag: 2507, sign: false }); + data.append(FP16x16 { mag: 85706, sign: false }); + data.append(FP16x16 { mag: 76268, sign: true }); + data.append(FP16x16 { mag: 20694, sign: false }); + data.append(FP16x16 { mag: 111079, sign: true }); + data.append(FP16x16 { mag: 118166, sign: false }); + data.append(FP16x16 { mag: 84064, sign: false }); + data.append(FP16x16 { mag: 146152, sign: false }); + data.append(FP16x16 { mag: 99879, sign: false }); + data.append(FP16x16 { mag: 101373, sign: false }); + data.append(FP16x16 { mag: 138415, sign: false }); + data.append(FP16x16 { mag: 34200, sign: true }); + data.append(FP16x16 { mag: 61986, sign: true }); + data.append(FP16x16 { mag: 99859, sign: false }); + data.append(FP16x16 { mag: 35109, sign: true }); + data.append(FP16x16 { mag: 28029, sign: false }); + data.append(FP16x16 { mag: 56997, sign: false }); + data.append(FP16x16 { mag: 19277, sign: false }); + data.append(FP16x16 { mag: 38195, sign: true }); + data.append(FP16x16 { mag: 56295, sign: true }); + data.append(FP16x16 { mag: 25408, sign: false }); + data.append(FP16x16 { mag: 72993, sign: true }); + data.append(FP16x16 { mag: 28011, sign: false }); + data.append(FP16x16 { mag: 63069, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_default_axis/input_1.cairo b/tests/nodes/layer_normalization_default_axis/input_1.cairo new file mode 100644 index 000000000..f180a8bd1 --- /dev/null +++ b/tests/nodes/layer_normalization_default_axis/input_1.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 8602, sign: true }); + data.append(FP16x16 { mag: 134121, sign: true }); + data.append(FP16x16 { mag: 39230, sign: false }); + data.append(FP16x16 { mag: 17052, sign: true }); + data.append(FP16x16 { mag: 24886, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_default_axis/input_2.cairo b/tests/nodes/layer_normalization_default_axis/input_2.cairo new file mode 100644 index 000000000..540f8c89b --- /dev/null +++ b/tests/nodes/layer_normalization_default_axis/input_2.cairo @@ -0,0 +1,17 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 35619, sign: true }); + data.append(FP16x16 { mag: 31160, sign: false }); + data.append(FP16x16 { mag: 51197, sign: true }); + data.append(FP16x16 { mag: 150071, sign: true }); + data.append(FP16x16 { mag: 57770, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_default_axis/output_0.cairo b/tests/nodes/layer_normalization_default_axis/output_0.cairo new file mode 100644 index 000000000..78cc8f3c2 --- /dev/null +++ b/tests/nodes/layer_normalization_default_axis/output_0.cairo @@ -0,0 +1,135 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(2); + shape.append(3); + shape.append(4); + shape.append(5); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 25327, sign: true }); + data.append(FP16x16 { mag: 22863, sign: true }); + data.append(FP16x16 { mag: 94966, sign: true }); + data.append(FP16x16 { mag: 173370, sign: true }); + data.append(FP16x16 { mag: 71281, sign: false }); + data.append(FP16x16 { mag: 41238, sign: true }); + data.append(FP16x16 { mag: 74203, sign: false }); + data.append(FP16x16 { mag: 95139, sign: true }); + data.append(FP16x16 { mag: 136235, sign: true }); + data.append(FP16x16 { mag: 97567, sign: false }); + data.append(FP16x16 { mag: 29760, sign: true }); + data.append(FP16x16 { mag: 238045, sign: false }); + data.append(FP16x16 { mag: 28606, sign: true }); + data.append(FP16x16 { mag: 172365, sign: true }); + data.append(FP16x16 { mag: 66239, sign: false }); + data.append(FP16x16 { mag: 39976, sign: true }); + data.append(FP16x16 { mag: 197413, sign: true }); + data.append(FP16x16 { mag: 65945, sign: true }); + data.append(FP16x16 { mag: 135863, sign: true }); + data.append(FP16x16 { mag: 32841, sign: false }); + data.append(FP16x16 { mag: 32630, sign: true }); + data.append(FP16x16 { mag: 33602, sign: true }); + data.append(FP16x16 { mag: 78882, sign: true }); + data.append(FP16x16 { mag: 130860, sign: true }); + data.append(FP16x16 { mag: 99996, sign: false }); + data.append(FP16x16 { mag: 47418, sign: true }); + data.append(FP16x16 { mag: 202789, sign: false }); + data.append(FP16x16 { mag: 15824, sign: true }); + data.append(FP16x16 { mag: 136570, sign: true }); + data.append(FP16x16 { mag: 52743, sign: false }); + data.append(FP16x16 { mag: 44830, sign: true }); + data.append(FP16x16 { mag: 1418, sign: false }); + data.append(FP16x16 { mag: 26045, sign: true }); + data.append(FP16x16 { mag: 148467, sign: true }); + data.append(FP16x16 { mag: 11987, sign: false }); + data.append(FP16x16 { mag: 38404, sign: true }); + data.append(FP16x16 { mag: 189961, sign: true }); + data.append(FP16x16 { mag: 102957, sign: true }); + data.append(FP16x16 { mag: 150196, sign: true }); + data.append(FP16x16 { mag: 41335, sign: false }); + data.append(FP16x16 { mag: 20875, sign: true }); + data.append(FP16x16 { mag: 33337, sign: false }); + data.append(FP16x16 { mag: 39363, sign: true }); + data.append(FP16x16 { mag: 150495, sign: true }); + data.append(FP16x16 { mag: 92702, sign: false }); + data.append(FP16x16 { mag: 30142, sign: true }); + data.append(FP16x16 { mag: 134994, sign: false }); + data.append(FP16x16 { mag: 57330, sign: true }); + data.append(FP16x16 { mag: 183402, sign: true }); + data.append(FP16x16 { mag: 48127, sign: false }); + data.append(FP16x16 { mag: 32082, sign: true }); + data.append(FP16x16 { mag: 178676, sign: false }); + data.append(FP16x16 { mag: 79394, sign: true }); + data.append(FP16x16 { mag: 159596, sign: true }); + data.append(FP16x16 { mag: 99358, sign: false }); + data.append(FP16x16 { mag: 26881, sign: true }); + data.append(FP16x16 { mag: 67973, sign: true }); + data.append(FP16x16 { mag: 93546, sign: true }); + data.append(FP16x16 { mag: 147691, sign: true }); + data.append(FP16x16 { mag: 94992, sign: false }); + data.append(FP16x16 { mag: 33741, sign: true }); + data.append(FP16x16 { mag: 46849, sign: true }); + data.append(FP16x16 { mag: 108987, sign: true }); + data.append(FP16x16 { mag: 143230, sign: true }); + data.append(FP16x16 { mag: 95371, sign: false }); + data.append(FP16x16 { mag: 43596, sign: true }); + data.append(FP16x16 { mag: 129515, sign: true }); + data.append(FP16x16 { mag: 104770, sign: true }); + data.append(FP16x16 { mag: 152546, sign: true }); + data.append(FP16x16 { mag: 35250, sign: false }); + data.append(FP16x16 { mag: 42942, sign: true }); + data.append(FP16x16 { mag: 153994, sign: true }); + data.append(FP16x16 { mag: 66548, sign: true }); + data.append(FP16x16 { mag: 143091, sign: true }); + data.append(FP16x16 { mag: 22149, sign: false }); + data.append(FP16x16 { mag: 22209, sign: true }); + data.append(FP16x16 { mag: 6651, sign: false }); + data.append(FP16x16 { mag: 73456, sign: true }); + data.append(FP16x16 { mag: 159774, sign: true }); + data.append(FP16x16 { mag: 91978, sign: false }); + data.append(FP16x16 { mag: 44648, sign: true }); + data.append(FP16x16 { mag: 175059, sign: false }); + data.append(FP16x16 { mag: 101164, sign: true }); + data.append(FP16x16 { mag: 154706, sign: true }); + data.append(FP16x16 { mag: 83280, sign: false }); + data.append(FP16x16 { mag: 40617, sign: true }); + data.append(FP16x16 { mag: 72477, sign: true }); + data.append(FP16x16 { mag: 123424, sign: true }); + data.append(FP16x16 { mag: 145228, sign: true }); + data.append(FP16x16 { mag: 76964, sign: false }); + data.append(FP16x16 { mag: 23646, sign: true }); + data.append(FP16x16 { mag: 185699, sign: true }); + data.append(FP16x16 { mag: 72671, sign: true }); + data.append(FP16x16 { mag: 149071, sign: true }); + data.append(FP16x16 { mag: 67251, sign: false }); + data.append(FP16x16 { mag: 36234, sign: true }); + data.append(FP16x16 { mag: 48229, sign: false }); + data.append(FP16x16 { mag: 6980, sign: false }); + data.append(FP16x16 { mag: 121901, sign: true }); + data.append(FP16x16 { mag: 63362, sign: false }); + data.append(FP16x16 { mag: 18857, sign: true }); + data.append(FP16x16 { mag: 43105, sign: true }); + data.append(FP16x16 { mag: 44077, sign: true }); + data.append(FP16x16 { mag: 164722, sign: true }); + data.append(FP16x16 { mag: 66582, sign: false }); + data.append(FP16x16 { mag: 41234, sign: true }); + data.append(FP16x16 { mag: 117951, sign: true }); + data.append(FP16x16 { mag: 91492, sign: true }); + data.append(FP16x16 { mag: 126684, sign: true }); + data.append(FP16x16 { mag: 73548, sign: false }); + data.append(FP16x16 { mag: 26057, sign: true }); + data.append(FP16x16 { mag: 47620, sign: true }); + data.append(FP16x16 { mag: 2425, sign: false }); + data.append(FP16x16 { mag: 156071, sign: true }); + data.append(FP16x16 { mag: 28040, sign: false }); + data.append(FP16x16 { mag: 30148, sign: true }); + data.append(FP16x16 { mag: 128023, sign: true }); + data.append(FP16x16 { mag: 90764, sign: true }); + data.append(FP16x16 { mag: 171299, sign: true }); + data.append(FP16x16 { mag: 38179, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_test.cairo b/tests/nodes/layer_normalization_test.cairo new file mode 100644 index 000000000..02235bd8b --- /dev/null +++ b/tests/nodes/layer_normalization_test.cairo @@ -0,0 +1,28 @@ +mod input_0; +mod input_1; +mod input_2; +mod output_0; + + +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +#[test] +#[available_gas(2000000000)] +fn test_layer_normalization_test() { + let input_0 = input_0::input_0(); + let input_1 = input_1::input_1(); + let input_2 = input_2::input_2(); + let z_0 = output_0::output_0(); + + let (y_0, _, _) = input_0 + .layer_normalization( + @input_1, Option::Some(@input_2), Option::None, Option::None, Option::None + ); + + assert_eq(y_0, z_0); +} diff --git a/tests/nodes/layer_normalization_test/input_0.cairo b/tests/nodes/layer_normalization_test/input_0.cairo new file mode 100644 index 000000000..6a1e2573e --- /dev/null +++ b/tests/nodes/layer_normalization_test/input_0.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 41143, sign: true }); + data.append(FP16x16 { mag: 51803, sign: false }); + data.append(FP16x16 { mag: 113556, sign: false }); + data.append(FP16x16 { mag: 64774, sign: false }); + data.append(FP16x16 { mag: 866, sign: false }); + data.append(FP16x16 { mag: 698, sign: true }); + data.append(FP16x16 { mag: 106500, sign: false }); + data.append(FP16x16 { mag: 98929, sign: false }); + data.append(FP16x16 { mag: 7551, sign: false }); + data.append(FP16x16 { mag: 30689, sign: true }); + data.append(FP16x16 { mag: 38325, sign: false }); + data.append(FP16x16 { mag: 48164, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_test/input_1.cairo b/tests/nodes/layer_normalization_test/input_1.cairo new file mode 100644 index 000000000..c56fcc649 --- /dev/null +++ b/tests/nodes/layer_normalization_test/input_1.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_1() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 49855, sign: false }); + data.append(FP16x16 { mag: 150787, sign: false }); + data.append(FP16x16 { mag: 83498, sign: true }); + data.append(FP16x16 { mag: 30346, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_test/input_2.cairo b/tests/nodes/layer_normalization_test/input_2.cairo new file mode 100644 index 000000000..9c662c873 --- /dev/null +++ b/tests/nodes/layer_normalization_test/input_2.cairo @@ -0,0 +1,16 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn input_2() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 54864, sign: true }); + data.append(FP16x16 { mag: 50952, sign: false }); + data.append(FP16x16 { mag: 8870, sign: true }); + data.append(FP16x16 { mag: 23216, sign: true }); + TensorTrait::new(shape.span(), data.span()) +} diff --git a/tests/nodes/layer_normalization_test/output_0.cairo b/tests/nodes/layer_normalization_test/output_0.cairo new file mode 100644 index 000000000..c7ddbe92d --- /dev/null +++ b/tests/nodes/layer_normalization_test/output_0.cairo @@ -0,0 +1,25 @@ +use core::array::{ArrayTrait, SpanTrait}; +use orion::operators::tensor::{TensorTrait, Tensor}; +use orion::operators::tensor::FP16x16Tensor; +use orion::numbers::{FixedTrait, FP16x16}; + +fn output_0() -> Tensor { + let mut shape = ArrayTrait::::new(); + shape.append(3); + shape.append(4); + + let mut data = ArrayTrait::new(); + data.append(FP16x16 { mag: 133576, sign: true }); + data.append(FP16x16 { mag: 63223, sign: false }); + data.append(FP16x16 { mag: 107763, sign: true }); + data.append(FP16x16 { mag: 13716, sign: true }); + data.append(FP16x16 { mag: 103890, sign: true }); + data.append(FP16x16 { mag: 101915, sign: true }); + data.append(FP16x16 { mag: 98400, sign: true }); + data.append(FP16x16 { mag: 4851, sign: false }); + data.append(FP16x16 { mag: 68295, sign: true }); + data.append(FP16x16 { mag: 177139, sign: true }); + data.append(FP16x16 { mag: 69916, sign: true }); + data.append(FP16x16 { mag: 8677, sign: false }); + TensorTrait::new(shape.span(), data.span()) +} From 498762261d6b2aeb52334d8219cc859b133f3ea2 Mon Sep 17 00:00:00 2001 From: chachaleo Date: Thu, 11 Jan 2024 16:13:24 +0100 Subject: [PATCH 31/38] feat: SVM Regressor --- docgen/src/main.rs | 10 +- docs/SUMMARY.md | 2 + .../machine-learning/svm-regressor/README.md | 23 ++ .../svm-regressor/svm_regressor.predict.md | 111 ++++++ src/operators/ml.cairo | 1 + src/operators/ml/svm.cairo | 2 + src/operators/ml/svm/core.cairo | 100 +++++ src/operators/ml/svm/svm_regressor.cairo | 252 +++++++++++++ .../tree_ensemble_classifier.cairo | 68 ++-- .../tensor/implementations/tensor_bool.cairo | 4 +- .../implementations/tensor_complex64.cairo | 4 +- .../implementations/tensor_fp16x16.cairo | 8 +- .../implementations/tensor_fp16x16wide.cairo | 8 +- .../implementations/tensor_fp32x32.cairo | 8 +- .../implementations/tensor_fp64x64.cairo | 8 +- .../implementations/tensor_fp8x23.cairo | 8 +- .../implementations/tensor_fp8x23wide.cairo | 8 +- .../tensor/implementations/tensor_i32.cairo | 4 +- .../tensor/implementations/tensor_i8.cairo | 4 +- .../tensor/implementations/tensor_u32.cairo | 4 +- src/operators/tensor/math.cairo | 2 +- src/operators/tensor/math/compress.cairo | 69 ++-- tests/ml.cairo | 2 +- tests/ml/svm_regressor_test.cairo | 284 +++++++++++++++ tests/ml/tree_ensemble_classifier.cairo | 342 +++++++++--------- tests/nodes/compress_fp16x16_3d_axis1.cairo | 2 +- tests/nodes/compress_fp16x16_3d_axis2.cairo | 2 +- tests/nodes/compress_fp16x16_3d_axis3.cairo | 2 +- tests/nodes/compress_fp16x16_3d_default.cairo | 2 +- tests/nodes/compress_fp16x16_3d_noaxis.cairo | 2 +- tests/nodes/compress_fp8x23_3d_axis1.cairo | 2 +- tests/nodes/compress_fp8x23_3d_axis2.cairo | 2 +- tests/nodes/compress_fp8x23_3d_default.cairo | 2 +- tests/nodes/compress_i32_3d_axis1.cairo | 2 +- tests/nodes/compress_i32_3d_axis2.cairo | 2 +- tests/nodes/compress_i32_3d_default.cairo | 2 +- tests/nodes/compress_i8_3d_axis1.cairo | 2 +- tests/nodes/compress_i8_3d_axis2.cairo | 2 +- tests/nodes/compress_i8_3d_default.cairo | 2 +- tests/nodes/compress_u32_3d_axis1.cairo | 2 +- tests/nodes/compress_u32_3d_axis2.cairo | 2 +- tests/nodes/compress_u32_3d_axis2_2.cairo | 2 +- tests/nodes/compress_u32_3d_axis3.cairo | 2 +- tests/nodes/compress_u32_3d_default.cairo | 2 +- 44 files changed, 1080 insertions(+), 294 deletions(-) create mode 100644 docs/framework/operators/machine-learning/svm-regressor/README.md create mode 100644 docs/framework/operators/machine-learning/svm-regressor/svm_regressor.predict.md create mode 100644 src/operators/ml/svm.cairo create mode 100644 src/operators/ml/svm/core.cairo create mode 100644 src/operators/ml/svm/svm_regressor.cairo create mode 100644 tests/ml/svm_regressor_test.cairo diff --git a/docgen/src/main.rs b/docgen/src/main.rs index 274c9d855..ed2a69460 100644 --- a/docgen/src/main.rs +++ b/docgen/src/main.rs @@ -75,13 +75,21 @@ fn main() { doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); - // LINEAR REGRESSOR DOC + // LINEAR CLASSIFIER DOC let trait_path = "src/operators/ml/linear/linear_classifier.cairo"; let doc_path = "docs/framework/operators/machine-learning/linear-classifier"; let label = "linear_classifier"; let trait_name: &str = "LinearClassifierTrait"; doc_trait(trait_path, doc_path, label); doc_functions(trait_path, doc_path, trait_name, label); + + // SVM REGRESSOR DOC + let trait_path = "src/operators/ml/svm/svm_regressor.cairo"; + let doc_path = "docs/framework/operators/machine-learning/svm-regressor"; + let label = "svm_regressor"; + let trait_name: &str = "SVMRegressorTrait"; + doc_trait(trait_path, doc_path, label); + doc_functions(trait_path, doc_path, trait_name, label); } fn doc_trait(trait_path: &str, doc_path: &str, label: &str) { diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md index fa9998f2e..5cd38c5d0 100644 --- a/docs/SUMMARY.md +++ b/docs/SUMMARY.md @@ -143,6 +143,8 @@ * [linear\_classifier.predict](framework/operators/machine-learning/linear-classifier/linear\_classifier.predict.md) * [Linear Regressor](framework/operators/machine-learning/linear-regressor/README.md) * [linear\_regressor.predict](framework/operators/machine-learning/linear-regressor/linear\_regressor.predict.md) + * [SVM Regressor](framework/operators/machine-learning/svm-regressor/README.md) + * [svm\_regressor.predict](framework/operators/machine-learning/svm-regressor/svm\_regressor.predict.md) * [Sequence](framework/operators/sequence/README.md) * [sequence.sequence\_construct](framework/operators/sequence/sequence.sequence\_construct.md) * [sequence.sequence\_empty](framework/operators/sequence/sequence.sequence\_empty.md) diff --git a/docs/framework/operators/machine-learning/svm-regressor/README.md b/docs/framework/operators/machine-learning/svm-regressor/README.md new file mode 100644 index 000000000..f659cbbd1 --- /dev/null +++ b/docs/framework/operators/machine-learning/svm-regressor/README.md @@ -0,0 +1,23 @@ +# SVM Regressor + +`SVMRegressorTrait` provides a trait definition for svm regression problem. + +```rust +use orion::operators::ml::SVMRegressorTrait; +``` + +### Data types + +Orion supports currently only fixed point data types for `SVMRegressorTrait`. + +| Data type | dtype | +| -------------------- | ------------------------------------------------------------- | +| Fixed point (signed) | `SVMRegressorTrait` | + + +*** + +| function | description | +| --- | --- | +| [`svm_regressor.predict`](svm_regressor.predict.md) | Returns the regressed values for each input in N. | + diff --git a/docs/framework/operators/machine-learning/svm-regressor/svm_regressor.predict.md b/docs/framework/operators/machine-learning/svm-regressor/svm_regressor.predict.md new file mode 100644 index 000000000..68be4c922 --- /dev/null +++ b/docs/framework/operators/machine-learning/svm-regressor/svm_regressor.predict.md @@ -0,0 +1,111 @@ +# SVMRegressorTrait::predict + +```rust + fn predict(ref self: SVMRegressor, X: Tensor) -> Tensor; +``` + +Support Vector Machine regression prediction and one-class SVM anomaly detection. + +## Args + +* `self`: SVMRegressor - A SVMRegressor object. +* `X`: Input 2D tensor. + +## Returns + +* Tensor containing the Support Vector Machine regression prediction and one-class SVM anomaly detection of the input X. + +## Type Constraints + +`SVMRegressor` and `X` must be fixed points + +## Examples + +```rust +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; +use orion::operators::tensor::FP16x16TensorPartialEq; + +use orion::operators::ml::svm::svm_regressor::{SVMRegressorTrait, POST_TRANSFORM, SVMRegressor}; +use orion::operators::ml::svm::core::{KERNEL_TYPE}; + +fn example_svm_regressor_linear() -> Tensor { + let coefficients: Span = array![ + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: true }, + FP16x16 { mag: 54959, sign: false }, + FP16x16 { mag: 54959, sign: true }, + FP16x16 { mag: 29299, sign: false }, + FP16x16 { mag: 65536, sign: true }, + FP16x16 { mag: 36236, sign: false } + ] + .span(); + let n_supports: usize = 7; + let one_class: usize = 0; + let rho: Span = array![FP16x16 { mag: 35788, sign: false }].span(); + let support_vectors: Span = array![ + FP16x16 { mag: 8421, sign: true }, + FP16x16 { mag: 5842, sign: false }, + FP16x16 { mag: 4510, sign: false }, + FP16x16 { mag: 5202, sign: true }, + FP16x16 { mag: 14783, sign: true }, + FP16x16 { mag: 17380, sign: true }, + FP16x16 { mag: 60595, sign: false }, + FP16x16 { mag: 1674, sign: true }, + FP16x16 { mag: 38669, sign: true }, + FP16x16 { mag: 63803, sign: false }, + FP16x16 { mag: 87720, sign: true }, + FP16x16 { mag: 22236, sign: false }, + FP16x16 { mag: 61816, sign: false }, + FP16x16 { mag: 34267, sign: true }, + FP16x16 { mag: 36418, sign: false }, + FP16x16 { mag: 27471, sign: false }, + FP16x16 { mag: 28421, sign: false }, + FP16x16 { mag: 69270, sign: true }, + FP16x16 { mag: 152819, sign: false }, + FP16x16 { mag: 4065, sign: false }, + FP16x16 { mag: 62274, sign: true } + ] + .span(); + let post_transform = POST_TRANSFORM::NONE; + let kernel_params: Span = array![ + FP16x16 { mag: 27812, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 196608, sign: false } + ] + .span(); + let kernel_type = KERNEL_TYPE::LINEAR; + + let mut regressor: SVMRegressor = SVMRegressor { + coefficients, + kernel_params, + kernel_type, + n_supports, + one_class, + post_transform, + rho, + support_vectors, + }; + + let mut X: Tensor = TensorTrait::new( + array![3, 3].span(), + array![ + FP16x16 { mag: 32768, sign: true }, + FP16x16 { mag: 26214, sign: true }, + FP16x16 { mag: 19660, sign: true }, + FP16x16 { mag: 13107, sign: true }, + FP16x16 { mag: 6553, sign: true }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 6553, sign: false }, + FP16x16 { mag: 13107, sign: false }, + FP16x16 { mag: 19660, sign: false }, + ] + .span() + ); + + return SVMRegressorTrait::predict(ref regressor, X); +} + +>>> [[-0.468206], [0.227487], [0.92318]] +``` + diff --git a/src/operators/ml.cairo b/src/operators/ml.cairo index 1098e42f3..724664216 100644 --- a/src/operators/ml.cairo +++ b/src/operators/ml.cairo @@ -1,5 +1,6 @@ mod tree_ensemble; mod linear; +mod svm; use orion::operators::ml::tree_ensemble::core::{ TreeEnsemble, TreeEnsembleAttributes, TreeEnsembleImpl, NODE_MODES diff --git a/src/operators/ml/svm.cairo b/src/operators/ml/svm.cairo new file mode 100644 index 000000000..93ab4515e --- /dev/null +++ b/src/operators/ml/svm.cairo @@ -0,0 +1,2 @@ +mod core; +mod svm_regressor; diff --git a/src/operators/ml/svm/core.cairo b/src/operators/ml/svm/core.cairo new file mode 100644 index 000000000..156fea8ee --- /dev/null +++ b/src/operators/ml/svm/core.cairo @@ -0,0 +1,100 @@ +use core::traits::TryInto; +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::traits::Into; +use orion::numbers::NumberTrait; +use orion::operators::tensor::{ + TensorTrait, Tensor, I8Tensor, I32Tensor, U32Tensor, FP16x16Tensor, BoolTensor +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; +use core::debug::PrintTrait; +use orion::utils::get_row; + +#[derive(Copy, Drop)] +enum KERNEL_TYPE { + LINEAR, + POLY, + RBF, + SIGMOID, +} + + +fn kernel_dot< + T, + MAG, + +Drop, + +Copy, + +NumberTrait, + +Add, + +TensorTrait, + +AddEq, + +Mul, + +Neg, + +Sub, +>( + kernel_params: Span, pA: Span, pB: Span, kernel: KERNEL_TYPE +) -> T { + let s = match kernel { + KERNEL_TYPE::LINEAR => sv_dot(pA, pB), + KERNEL_TYPE::POLY => { + let mut s = sv_dot(pA, pB); + s = s * *kernel_params.at(0) + *kernel_params.at(1); + s.pow(*kernel_params.at(2)) + }, + KERNEL_TYPE::RBF => { + let mut s = squared_diff(pA, pB); + NumberTrait::exp(-*kernel_params.at(0) * s) + }, + KERNEL_TYPE::SIGMOID => { + let mut s = sv_dot(pA, pB); + s = s * *kernel_params.at(0) + *kernel_params.at(1); + NumberTrait::tanh(s) + }, + }; + return s; +} + + +fn sv_dot< + T, MAG, +Drop, +Copy, +NumberTrait, +Add, +TensorTrait, +AddEq, +Mul, +>( + pA: Span, pB: Span +) -> T { + let mut i = 0; + let mut sum = NumberTrait::zero(); + loop { + if i == pA.len() { + break; + } + sum = sum + *pA.at(i) * *pB.at(i); + i += 1; + }; + + return sum; +} + +fn squared_diff< + T, + MAG, + +Drop, + +Copy, + +NumberTrait, + +Add, + +TensorTrait, + +AddEq, + +Mul, + +Sub, +>( + pA: Span, pB: Span +) -> T { + let mut i = 0; + let mut sum = NumberTrait::zero(); + loop { + if i == pA.len() { + break; + } + sum = sum + (*pA.at(i) - *pB.at(i)).pow(NumberTrait::one() + NumberTrait::one()); + i += 1; + }; + return sum; +} diff --git a/src/operators/ml/svm/svm_regressor.cairo b/src/operators/ml/svm/svm_regressor.cairo new file mode 100644 index 000000000..be76931e9 --- /dev/null +++ b/src/operators/ml/svm/svm_regressor.cairo @@ -0,0 +1,252 @@ +use core::traits::TryInto; +use core::array::ArrayTrait; +use core::array::SpanTrait; +use core::traits::Into; +use orion::numbers::NumberTrait; +use orion::operators::tensor::{ + TensorTrait, Tensor, I8Tensor, I32Tensor, U32Tensor, FP16x16Tensor, BoolTensor +}; +use orion::numbers::{FP16x16, FP16x16Impl, FP32x32, FP32x32Impl, FixedTrait}; +use core::debug::PrintTrait; +use orion::operators::nn::{NNTrait, FP16x16NN}; +use orion::utils::get_row; + +use orion::operators::ml::svm::core::{kernel_dot, KERNEL_TYPE}; + +#[derive(Copy, Drop, Destruct)] +struct SVMRegressor { + coefficients: Span, + kernel_params: Span, + kernel_type: KERNEL_TYPE, + n_supports: usize, + one_class: usize, + post_transform: POST_TRANSFORM, + rho: Span, + support_vectors: Span, +} + +#[derive(Copy, Drop)] +enum POST_TRANSFORM { + NONE, + SOFTMAX, + LOGISTIC, + SOFTMAXZERO, + PROBIT, +} + +#[derive(Copy, Drop)] +enum MODE { + SVM_LINEAR, + SVM_SVC, +} + +/// Trait +/// +/// predict - Returns the regressed values for each input in N. +trait SVMRegressorTrait { + /// # SVMRegressorTrait::predict + /// + /// ```rust + /// fn predict(ref self: SVMRegressor, X: Tensor) -> Tensor; + /// ``` + /// + /// Support Vector Machine regression prediction and one-class SVM anomaly detection. + /// + /// ## Args + /// + /// * `self`: SVMRegressor - A SVMRegressor object. + /// * `X`: Input 2D tensor. + /// + /// ## Returns + /// + /// * Tensor containing the Support Vector Machine regression prediction and one-class SVM anomaly detection of the input X. + /// + /// ## Type Constraints + /// + /// `SVMRegressor` and `X` must be fixed points + /// + /// ## Examples + /// + /// ```rust + /// use orion::numbers::FP16x16; + /// use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; + /// use orion::operators::tensor::FP16x16TensorPartialEq; + /// + /// use orion::operators::ml::svm::svm_regressor::{SVMRegressorTrait, POST_TRANSFORM, SVMRegressor}; + /// use orion::operators::ml::svm::core::{KERNEL_TYPE}; + /// + /// fn example_svm_regressor_linear() -> Tensor { + /// let coefficients: Span = array![ + /// FP16x16 { mag: 65536, sign: false }, + /// FP16x16 { mag: 65536, sign: true }, + /// FP16x16 { mag: 54959, sign: false }, + /// FP16x16 { mag: 54959, sign: true }, + /// FP16x16 { mag: 29299, sign: false }, + /// FP16x16 { mag: 65536, sign: true }, + /// FP16x16 { mag: 36236, sign: false } + /// ] + /// .span(); + /// let n_supports: usize = 7; + /// let one_class: usize = 0; + /// let rho: Span = array![FP16x16 { mag: 35788, sign: false }].span(); + /// let support_vectors: Span = array![ + /// FP16x16 { mag: 8421, sign: true }, + /// FP16x16 { mag: 5842, sign: false }, + /// FP16x16 { mag: 4510, sign: false }, + /// FP16x16 { mag: 5202, sign: true }, + /// FP16x16 { mag: 14783, sign: true }, + /// FP16x16 { mag: 17380, sign: true }, + /// FP16x16 { mag: 60595, sign: false }, + /// FP16x16 { mag: 1674, sign: true }, + /// FP16x16 { mag: 38669, sign: true }, + /// FP16x16 { mag: 63803, sign: false }, + /// FP16x16 { mag: 87720, sign: true }, + /// FP16x16 { mag: 22236, sign: false }, + /// FP16x16 { mag: 61816, sign: false }, + /// FP16x16 { mag: 34267, sign: true }, + /// FP16x16 { mag: 36418, sign: false }, + /// FP16x16 { mag: 27471, sign: false }, + /// FP16x16 { mag: 28421, sign: false }, + /// FP16x16 { mag: 69270, sign: true }, + /// FP16x16 { mag: 152819, sign: false }, + /// FP16x16 { mag: 4065, sign: false }, + /// FP16x16 { mag: 62274, sign: true } + /// ] + /// .span(); + /// let post_transform = POST_TRANSFORM::NONE; + /// let kernel_params: Span = array![ + /// FP16x16 { mag: 27812, sign: false }, + /// FP16x16 { mag: 0, sign: false }, + /// FP16x16 { mag: 196608, sign: false } + /// ] + /// .span(); + /// let kernel_type = KERNEL_TYPE::LINEAR; + /// + /// let mut regressor: SVMRegressor = SVMRegressor { + /// coefficients, + /// kernel_params, + /// kernel_type, + /// n_supports, + /// one_class, + /// post_transform, + /// rho, + /// support_vectors, + /// }; + /// + /// let mut X: Tensor = TensorTrait::new( + /// array![3, 3].span(), + /// array![ + /// FP16x16 { mag: 32768, sign: true }, + /// FP16x16 { mag: 26214, sign: true }, + /// FP16x16 { mag: 19660, sign: true }, + /// FP16x16 { mag: 13107, sign: true }, + /// FP16x16 { mag: 6553, sign: true }, + /// FP16x16 { mag: 0, sign: false }, + /// FP16x16 { mag: 6553, sign: false }, + /// FP16x16 { mag: 13107, sign: false }, + /// FP16x16 { mag: 19660, sign: false }, + /// ] + /// .span() + /// ); + /// + /// return SVMRegressorTrait::predict(ref regressor, X); + /// } + /// + /// >>> [[-0.468206], [0.227487], [0.92318]] + /// ``` + /// + /// + fn predict(ref self: SVMRegressor, X: Tensor) -> Tensor; +} + +impl SVMRegressorImpl< + T, + MAG, + +Drop, + +Copy, + +NumberTrait, + +PartialOrd, + +PartialEq, + +Add, + +TensorTrait, + +PrintTrait, + +AddEq, + +Div, + +Mul, + +Neg, + +Sub, + +NNTrait, +> of SVMRegressorTrait { + fn predict(ref self: SVMRegressor, X: Tensor) -> Tensor { + let (mode_, kernel_type_, sv) = if self.n_supports > 0 { + let mode_ = MODE::SVM_SVC; + let kernel_type_ = self.kernel_type; + let sv = TensorTrait::new( + array![self.n_supports, self.support_vectors.len() / self.n_supports].span(), + self.support_vectors + ); //self.atts.support_vectors.reshape((self.atts.n_supports, -1)) + (mode_, kernel_type_, sv) + } else { + let mode_ = MODE::SVM_LINEAR; + let kernel_type_ = KERNEL_TYPE::LINEAR; + let sv = TensorTrait::new( + array![self.support_vectors.len()].span(), self.support_vectors + ); + (mode_, kernel_type_, sv) + }; + + let mut z = ArrayTrait::new(); + let mut n = 0; + loop { + if n == *X.shape.at(0) { + break; + } + let mut s = NumberTrait::zero(); + match mode_ { + MODE::SVM_LINEAR => { + let mut x_n = get_row(@X, n); + s = kernel_dot(self.kernel_params, x_n, self.coefficients, kernel_type_); + s += *self.rho.at(0); + }, + MODE::SVM_SVC => { + let mut x_n = get_row(@X, n); + let mut j = 0; + loop { + if j == self.n_supports { + break; + } + let mut sv_j = get_row(@sv, j); + let d = kernel_dot(self.kernel_params, x_n, sv_j, kernel_type_); + s += *self.coefficients.at(j) * d; + j += 1; + }; + s += *self.rho.at(0); + }, + } + if self.one_class == 1 { + let elem = if s > NumberTrait::zero() { + NumberTrait::one() + } else { + -NumberTrait::one() + }; + z.append(elem); + } else { + z.append(s); + }; + n += 1; + }; + + // Post Transform + let mut score = TensorTrait::new(array![*X.shape.at(0)].span(), z.span()); + score = match self.post_transform { + POST_TRANSFORM::NONE => score, + POST_TRANSFORM::SOFTMAX => NNTrait::softmax(@score, 1), + POST_TRANSFORM::LOGISTIC => NNTrait::sigmoid(@score), + POST_TRANSFORM::SOFTMAXZERO => NNTrait::softmax_zero(@score, 1), + POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), + }; + + return score; + } +} + diff --git a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo index eb50a2e14..051965260 100644 --- a/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo +++ b/src/operators/ml/tree_ensemble/tree_ensemble_classifier.cairo @@ -408,12 +408,8 @@ impl TreeEnsembleClassifierImpl< let mut class_id: usize = 0; // Get first class_id in class_ids match class_ids.pop_front() { - Option::Some(c_id) => { - let mut class_id = *c_id; - }, - Option::None(_) => { - let mut class_id: usize = 0; - } + Option::Some(c_id) => { let mut class_id = *c_id; }, + Option::None(_) => { let mut class_id: usize = 0; } }; loop { if i == self.class_ids.len() { @@ -424,19 +420,17 @@ impl TreeEnsembleClassifierImpl< if *c_id == class_id { binary = true; continue; - }else{ + } else { binary = false; break; } - }, Option::None(_) => { break; } }; - }; // Clone res - if binary{ + if binary { let mut new_res: MutMatrix = MutMatrixImpl::new(res.rows, res.cols); let mut i: usize = 0; loop { @@ -445,14 +439,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_1 = match res.get(i, 0) { - Option::Some(res_0) => { - new_res.set(i, 1, res_0); - }, - Option::None(_) => { - new_res.set(i, 1, NumberTrait::zero()); - }, + Option::Some(res_0) => { new_res.set(i, 1, res_0); }, + Option::None(_) => { new_res.set(i, 1, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; match self.post_transform { POST_TRANSFORM::NONE => { @@ -467,11 +457,9 @@ impl TreeEnsembleClassifierImpl< let value = NumberTrait::sub(NumberTrait::one(), res_1); new_res.set(i, 0, value); }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::SOFTMAX => { @@ -482,14 +470,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_0 = match new_res.get(i, 1) { - Option::Some(res_1) => { - new_res.set(i, 0, res_1.neg()); - }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::LOGISTIC => { @@ -500,14 +484,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_0 = match new_res.get(i, 1) { - Option::Some(res_1) => { - new_res.set(i, 0, res_1.neg()); - }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::SOFTMAXZERO => { @@ -518,14 +498,10 @@ impl TreeEnsembleClassifierImpl< } // Exchange let res_ele_0 = match new_res.get(i, 1) { - Option::Some(res_1) => { - new_res.set(i, 0, res_1.neg()); - }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::Some(res_1) => { new_res.set(i, 0, res_1.neg()); }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, POST_TRANSFORM::PROBIT => { @@ -540,17 +516,15 @@ impl TreeEnsembleClassifierImpl< let value = NumberTrait::sub(NumberTrait::one(), res_1); new_res.set(i, 0, value); }, - Option::None(_) => { - new_res.set(i, 0, NumberTrait::zero()); - }, + Option::None(_) => { new_res.set(i, 0, NumberTrait::zero()); }, }; - i+=1; + i += 1; }; }, }; res = new_res; } - + // Post Transform let mut new_scores = match self.post_transform { POST_TRANSFORM::NONE => res, // No action required diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index 8ca90eef6..bdfc5aa41 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -444,7 +444,9 @@ impl BoolTensor of TensorTrait { math::gather_nd::gather_nd(self, indices, batch_dims) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 53feb8980..36b7520ee 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -472,7 +472,9 @@ impl Complex64Tensor of TensorTrait { panic(array!['not supported!']) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 2f13326c1..3d2e59684 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -486,7 +486,9 @@ impl FP16x16Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -504,7 +506,9 @@ impl FP16x16Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 4070d0154..339fc9762 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -452,7 +452,9 @@ impl FP16x16WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -470,7 +472,9 @@ impl FP16x16WTensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index bc77c3e15..8d7522848 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -487,7 +487,9 @@ impl FP32x32Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -505,7 +507,9 @@ impl FP32x32Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 7cac1e80f..ac144a06a 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -487,7 +487,9 @@ impl FP64x64Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -505,7 +507,9 @@ impl FP64x64Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 6b8a471f0..e18d668bf 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -485,7 +485,9 @@ impl FP8x23Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -503,7 +505,9 @@ impl FP8x23Tensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 54118f17b..920e57e05 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -438,7 +438,9 @@ impl FP8x23WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } @@ -456,7 +458,9 @@ impl FP8x23WTensor of TensorTrait { manipulation::unique::unique(self, axis, sorted) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 67401fcb2..0875d7393 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -482,7 +482,9 @@ impl I32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 4077b9bd3..843f236ef 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -480,7 +480,9 @@ impl I8Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index b69c19b04..add0c6d12 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -424,7 +424,9 @@ impl U32Tensor of TensorTrait { panic(array!['not supported!']) } - fn gather_nd(self: @Tensor, indices: Tensor, batch_dims: Option) -> Tensor { + fn gather_nd( + self: @Tensor, indices: Tensor, batch_dims: Option + ) -> Tensor { math::gather_nd::gather_nd(self, indices, batch_dims) } diff --git a/src/operators/tensor/math.cairo b/src/operators/tensor/math.cairo index 4cb97feda..0f675ee4d 100644 --- a/src/operators/tensor/math.cairo +++ b/src/operators/tensor/math.cairo @@ -58,4 +58,4 @@ mod is_inf; mod gather_nd; mod reduce_log_sum; mod erf; -mod compress; \ No newline at end of file +mod compress; diff --git a/src/operators/tensor/math/compress.cairo b/src/operators/tensor/math/compress.cairo index 6380d5d15..d22eb1d82 100644 --- a/src/operators/tensor/math/compress.cairo +++ b/src/operators/tensor/math/compress.cairo @@ -14,12 +14,7 @@ use orion::operators::tensor::U32TensorPartialEq; use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor}; /// Cf: TensorTrait::compare docstring -fn compress< - T, - impl TTensorTrait: TensorTrait, - impl TCopy: Copy, - impl TDrop: Drop, ->( +fn compress, impl TCopy: Copy, impl TDrop: Drop,>( self: @Tensor, condition: Tensor, axis: Option ) -> Tensor { let axis = match axis { @@ -29,7 +24,7 @@ fn compress< let data_rank = (*self.shape).len(); let condition_rank = (condition.shape).len(); - assert((data_rank >= 1 ), 'data rank must > 1'); + assert((data_rank >= 1), 'data rank must > 1'); assert((condition_rank == 1), 'condition rank must be 1'); let mut data_shape = *self.shape; @@ -67,9 +62,7 @@ fn compress< let mut total_shape = 1; loop { match data_shape.pop_front() { - Option::Some(val) => { - total_shape *= *val; - }, + Option::Some(val) => { total_shape *= *val; }, Option::None(_) => { break; } }; }; @@ -78,8 +71,10 @@ fn compress< loop { match condition_data.pop_front() { Option::Some(val) => { - if (ind == total_shape) {break; } - if (*val != 0){ + if (ind == total_shape) { + break; + } + if (*val != 0) { output_data.append(*self.data[ind]); } ind += 1; @@ -99,8 +94,7 @@ fn compress< Option::Some(val) => { if (ind == axis) { output_shape.append(output); - } - else { + } else { output_shape.append(*val); if (ind > axis) { loop_breaker *= *val; @@ -120,31 +114,34 @@ fn compress< let mut ind = 0; let mut ind_loop = 0; - + let mut inner_index: usize = 0; let mut condition_data_clone = condition_data.clone(); loop { - if (ind == other_loop_breaker) {break;} + if (ind == other_loop_breaker) { + break; + } let mut condition_data_clone = condition_data.clone(); - inner_index = *data_shape.at(axis) * ind; + inner_index = *data_shape.at(axis) * ind; loop { - match condition_data_clone.pop_front() { - Option::Some(val) => { - if (*val != 0){ - let result = inner_index * loop_breaker ; - - let mut data_ind:usize = result ; - loop { - if data_ind == result + loop_breaker { break; } - index_data.append(data_ind); - data_ind+=1; - }; - } - inner_index += 1; - }, - Option::None(_) => { break; } + Option::Some(val) => { + if (*val != 0) { + let result = inner_index * loop_breaker; + + let mut data_ind: usize = result; + loop { + if data_ind == result + loop_breaker { + break; + } + index_data.append(data_ind); + data_ind += 1; + }; + } + inner_index += 1; + }, + Option::None(_) => { break; } }; }; @@ -153,14 +150,12 @@ fn compress< loop { match index_data.pop_front() { - Option::Some(val) => { - output_data.append(*self.data[val]); - }, + Option::Some(val) => { output_data.append(*self.data[val]); }, Option::None(_) => { break; } }; - }; + }; } let mut output_tensor = TensorTrait::::new(output_shape.span(), output_data.span()); return output_tensor; -} \ No newline at end of file +} diff --git a/tests/ml.cairo b/tests/ml.cairo index 78f6b370b..4a5abf9a8 100644 --- a/tests/ml.cairo +++ b/tests/ml.cairo @@ -2,4 +2,4 @@ mod tree_ensemble_classifier; mod tree_ensemble_regressor; mod linear_regressor_test; mod linear_classifier_test; - +mod svm_regressor_test; diff --git a/tests/ml/svm_regressor_test.cairo b/tests/ml/svm_regressor_test.cairo new file mode 100644 index 000000000..afa924e56 --- /dev/null +++ b/tests/ml/svm_regressor_test.cairo @@ -0,0 +1,284 @@ +use orion::numbers::FP16x16; +use orion::operators::tensor::{Tensor, TensorTrait, FP16x16Tensor, U32Tensor}; +use orion::utils::{assert_eq, assert_seq_eq}; +use orion::operators::tensor::FP16x16TensorPartialEq; + +use orion::operators::ml::svm::svm_regressor::{SVMRegressorTrait, POST_TRANSFORM, SVMRegressor}; +use orion::operators::ml::svm::core::{KERNEL_TYPE}; + + +#[test] +#[available_gas(200000000000)] +fn test_svm_regressor_linear() { + let kernel_params: Span = array![ + FP16x16 { mag: 27812, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 196608, sign: false } + ] + .span(); + let kernel_type = KERNEL_TYPE::LINEAR; + let (mut regressor, X) = svm_regressor_helper(kernel_type, kernel_params); + + let scores = SVMRegressorTrait::predict(ref regressor, X); + + let mut expected_scores: Tensor = TensorTrait::new( + array![3, 1].span(), + array![ + FP16x16 { mag: 30684, sign: true }, + FP16x16 { mag: 14908, sign: false }, + FP16x16 { mag: 60501, sign: false }, + ] + .span() + ); + assert_eq(scores, expected_scores); +} + +#[test] +#[available_gas(200000000000)] +fn test_svm_regressor_poly() { + let kernel_params: Span = array![ + FP16x16 { mag: 22456, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 196608, sign: false } + ] + .span(); + + let kernel_type = KERNEL_TYPE::POLY; + let (mut regressor, X) = svm_regressor_helper(kernel_type, kernel_params); + + let scores = SVMRegressorTrait::predict(ref regressor, X); + + let mut expected_scores: Tensor = TensorTrait::new( + array![3, 1].span(), + array![ + FP16x16 { mag: 34542, sign: false }, + FP16x16 { mag: 35623, sign: false }, + FP16x16 { mag: 35815, sign: false }, + ] + .span() + ); + assert_eq(scores, expected_scores); +} + + +#[test] +#[available_gas(200000000000)] +fn test_svm_regressor_rbf() { + let kernel_params: Span = array![ + FP16x16 { mag: 19848, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 196608, sign: false } + ] + .span(); + let kernel_type = KERNEL_TYPE::RBF; + let (mut regressor, X) = svm_regressor_helper(kernel_type, kernel_params); + + let scores = SVMRegressorTrait::predict(ref regressor, X); + + let mut expected_scores: Tensor = TensorTrait::new( + array![3, 1].span(), + array![ + FP16x16 { mag: 19376, sign: false }, + FP16x16 { mag: 31318, sign: false }, + FP16x16 { mag: 45566, sign: false }, + ] + .span() + ); + assert_eq(scores, expected_scores); +} + +#[test] +#[available_gas(200000000000)] +fn test_svm_regressor_sigmoid() { + let kernel_params: Span = array![ + FP16x16 { mag: 20108, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 196608, sign: false } + ] + .span(); + let kernel_type = KERNEL_TYPE::SIGMOID; + let (mut regressor, X) = svm_regressor_helper(kernel_type, kernel_params); + + let scores = SVMRegressorTrait::predict(ref regressor, X); + + let mut expected_scores: Tensor = TensorTrait::new( + array![3, 1].span(), + array![ + FP16x16 { mag: 15683, sign: false }, + FP16x16 { mag: 29421, sign: false }, + FP16x16 { mag: 43364, sign: false }, + ] + .span() + ); + assert_eq(scores, expected_scores); +} + +#[test] +#[available_gas(200000000000)] +fn test_svm_regressor_linear_one_class_0() { + let post_transform = POST_TRANSFORM::NONE; + let one_class = 0; + let (mut regressor, X) = svm_regressor_linear_helper(post_transform, one_class); + + let scores = SVMRegressorTrait::predict(ref regressor, X); + + let mut expected_scores: Tensor = TensorTrait::new( + array![3, 1].span(), + array![ + FP16x16 { mag: 63484, sign: false }, + FP16x16 { mag: 74218, sign: false }, + FP16x16 { mag: 84953, sign: false }, + ] + .span() + ); + assert_eq(scores, expected_scores); +} + +#[test] +#[available_gas(200000000000)] +fn test_svm_regressor_linear_one_class_1() { + let post_transform = POST_TRANSFORM::NONE; + let one_class = 1; + let (mut regressor, X) = svm_regressor_linear_helper(post_transform, one_class); + + let scores = SVMRegressorTrait::predict(ref regressor, X); + + let mut expected_scores: Tensor = TensorTrait::new( + array![3, 1].span(), + array![ + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: false }, + ] + .span() + ); + assert_eq(scores, expected_scores); +} + + +// ============ HELPER ============ // + +fn svm_regressor_helper( + kernel_type: KERNEL_TYPE, kernel_params: Span +) -> (SVMRegressor, Tensor) { + let coefficients: Span = array![ + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 65536, sign: true }, + FP16x16 { mag: 54959, sign: false }, + FP16x16 { mag: 54959, sign: true }, + FP16x16 { mag: 29299, sign: false }, + FP16x16 { mag: 65536, sign: true }, + FP16x16 { mag: 36236, sign: false } + ] + .span(); + + let n_supports: usize = 7; + let one_class: usize = 0; + let rho: Span = array![FP16x16 { mag: 35788, sign: false }].span(); + + let support_vectors: Span = array![ + FP16x16 { mag: 8421, sign: true }, + FP16x16 { mag: 5842, sign: false }, + FP16x16 { mag: 4510, sign: false }, + FP16x16 { mag: 5202, sign: true }, + FP16x16 { mag: 14783, sign: true }, + FP16x16 { mag: 17380, sign: true }, + FP16x16 { mag: 60595, sign: false }, + FP16x16 { mag: 1674, sign: true }, + FP16x16 { mag: 38669, sign: true }, + FP16x16 { mag: 63803, sign: false }, + FP16x16 { mag: 87720, sign: true }, + FP16x16 { mag: 22236, sign: false }, + FP16x16 { mag: 61816, sign: false }, + FP16x16 { mag: 34267, sign: true }, + FP16x16 { mag: 36418, sign: false }, + FP16x16 { mag: 27471, sign: false }, + FP16x16 { mag: 28421, sign: false }, + FP16x16 { mag: 69270, sign: true }, + FP16x16 { mag: 152819, sign: false }, + FP16x16 { mag: 4065, sign: false }, + FP16x16 { mag: 62274, sign: true } + ] + .span(); + + let post_transform = POST_TRANSFORM::NONE; + + let mut regressor: SVMRegressor = SVMRegressor { + coefficients, + kernel_params, + kernel_type, + n_supports, + one_class, + post_transform, + rho, + support_vectors, + }; + + let mut X: Tensor = TensorTrait::new( + array![3, 3].span(), + array![ + FP16x16 { mag: 32768, sign: true }, + FP16x16 { mag: 26214, sign: true }, + FP16x16 { mag: 19660, sign: true }, + FP16x16 { mag: 13107, sign: true }, + FP16x16 { mag: 6553, sign: true }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 6553, sign: false }, + FP16x16 { mag: 13107, sign: false }, + FP16x16 { mag: 19660, sign: false }, + ] + .span() + ); + + (regressor, X) +} + +fn svm_regressor_linear_helper( + post_transform: POST_TRANSFORM, one_class: usize +) -> (SVMRegressor, Tensor) { + let coefficients: Span = array![ + FP16x16 { mag: 18540, sign: false }, + FP16x16 { mag: 1746, sign: true }, + FP16x16 { mag: 1097, sign: false } + ] + .span(); + let kernel_params: Span = array![ + FP16x16 { mag: 65, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 196608, sign: false } + ] + .span(); + let kernel_type = KERNEL_TYPE::LINEAR; + let n_supports: usize = 0; + let rho: Span = array![FP16x16 { mag: 81285, sign: false }].span(); + + let support_vectors: Span = array![].span(); + + let mut regressor: SVMRegressor = SVMRegressor { + coefficients, + kernel_params, + kernel_type, + n_supports, + one_class, + post_transform, + rho, + support_vectors, + }; + let mut X: Tensor = TensorTrait::new( + array![3, 3].span(), + array![ + FP16x16 { mag: 65536, sign: true }, + FP16x16 { mag: 52428, sign: true }, + FP16x16 { mag: 39321, sign: true }, + FP16x16 { mag: 26214, sign: true }, + FP16x16 { mag: 13107, sign: true }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 13107, sign: false }, + FP16x16 { mag: 26214, sign: false }, + FP16x16 { mag: 39321, sign: false }, + ] + .span() + ); + + (regressor, X) +} diff --git a/tests/ml/tree_ensemble_classifier.cairo b/tests/ml/tree_ensemble_classifier.cairo index 6ee2afc11..2d325a334 100644 --- a/tests/ml/tree_ensemble_classifier.cairo +++ b/tests/ml/tree_ensemble_classifier.cairo @@ -241,7 +241,7 @@ fn test_tree_ensemble_classifier_binary_none() { #[test] #[available_gas(200000000000)] fn test_tree_ensemble_classifier_binary_logistic() { - + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::LOGISTIC); let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); @@ -282,7 +282,7 @@ fn test_tree_ensemble_classifier_binary_softmax() { 'score[0, 1]' ); } - + #[test] #[available_gas(200000000000)] fn test_tree_ensemble_classifier_binary_softmax_zero() { @@ -502,163 +502,163 @@ fn tree_ensemble_classifier_binary_class_helper( let tree_ids: Span = array![0].span(); let mut root_index: Felt252Dict = Default::default(); - root_index.insert(0, 0); + root_index.insert(0, 0); let mut node_index: Felt252Dict = Default::default(); - node_index.insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); - node_index.insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); - node_index.insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); - node_index.insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); - node_index.insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); - node_index.insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); - node_index.insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); - node_index.insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); - node_index.insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); - node_index.insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); - node_index.insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); - node_index.insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); - node_index.insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); - node_index.insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); - node_index.insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); - node_index.insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); - node_index.insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); - node_index.insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); - node_index.insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); - node_index.insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); - node_index.insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); - node_index.insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); - node_index.insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); - node_index.insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); - node_index.insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); - node_index.insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); - node_index.insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); - node_index.insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); - node_index.insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); - node_index.insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); - node_index.insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); - node_index.insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); - node_index.insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); - node_index.insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); - node_index.insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); - node_index.insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); - node_index.insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); - node_index.insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); - node_index.insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); - node_index.insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); - node_index.insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); - node_index.insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); - node_index.insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); - node_index.insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); - node_index.insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); - node_index.insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); - node_index.insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); - node_index.insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); - node_index.insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); - node_index.insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); - node_index.insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); - node_index.insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); - node_index.insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); - node_index.insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); - node_index.insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); - node_index.insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); - node_index.insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); - node_index.insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); - node_index.insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); - node_index.insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); - node_index.insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); - node_index.insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); - node_index.insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); - node_index.insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); - node_index.insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); - node_index.insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); - node_index.insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); - node_index.insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); - node_index.insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); - node_index.insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); - node_index.insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); - node_index.insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); - node_index.insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); - node_index.insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); - node_index.insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); - node_index.insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); - node_index.insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); - node_index.insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); - node_index.insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); - node_index.insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); - node_index.insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); - node_index.insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); - node_index.insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); - node_index.insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); - node_index.insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); - node_index.insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); - node_index.insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); - node_index.insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); - node_index.insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); - node_index.insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); - node_index.insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); - node_index.insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); - node_index.insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); - node_index.insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); - node_index.insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); - node_index.insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); - node_index.insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); - node_index.insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); - node_index.insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); - node_index.insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); - node_index.insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); - node_index.insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); - node_index.insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); - node_index.insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); - node_index.insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); - node_index.insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); - node_index.insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); - node_index.insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); - node_index.insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); - node_index.insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); - node_index.insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); - node_index.insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); - node_index.insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); - node_index.insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); - node_index.insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); - node_index.insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); - node_index.insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); - node_index.insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); - node_index.insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); - node_index.insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); - node_index.insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); - node_index.insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); - node_index.insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); - node_index.insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); - node_index.insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); - node_index.insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); - node_index.insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); - node_index.insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); - node_index.insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); - node_index.insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); - node_index.insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); - node_index.insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); - node_index.insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); - node_index.insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); - node_index.insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); - node_index.insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); - node_index.insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); - node_index.insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); - node_index.insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); - node_index.insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); - node_index.insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); - node_index.insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); - node_index.insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); - node_index.insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); - node_index.insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); - node_index.insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); - node_index.insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); - node_index.insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); - node_index.insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); - node_index.insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); - node_index.insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); - node_index.insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); - node_index.insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); - node_index.insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); - node_index.insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); + node_index.insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); + node_index.insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); + node_index.insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); + node_index.insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); + node_index.insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); + node_index.insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); + node_index.insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); + node_index.insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); + node_index.insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); + node_index.insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); + node_index.insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); + node_index.insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); + node_index.insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); + node_index.insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); + node_index.insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); + node_index.insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); + node_index.insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); + node_index.insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); + node_index.insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); + node_index.insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); + node_index.insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); + node_index.insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); + node_index.insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); + node_index.insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); + node_index.insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); + node_index.insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); + node_index.insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); + node_index.insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); + node_index.insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); + node_index.insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); + node_index.insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); + node_index.insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); + node_index.insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); + node_index.insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); + node_index.insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); + node_index.insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); + node_index.insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); + node_index.insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); + node_index.insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); + node_index.insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); + node_index.insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); + node_index.insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); + node_index.insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); + node_index.insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); + node_index.insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); + node_index.insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); + node_index.insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); + node_index.insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); + node_index.insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); + node_index.insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); + node_index.insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); + node_index.insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); + node_index.insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); + node_index.insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); + node_index.insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); + node_index.insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); + node_index.insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); + node_index.insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); + node_index.insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); + node_index.insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); + node_index.insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); + node_index.insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); + node_index.insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); + node_index.insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); + node_index.insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); + node_index.insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); + node_index.insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); + node_index.insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); + node_index.insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); + node_index.insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); + node_index.insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); + node_index.insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); + node_index.insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); + node_index.insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); + node_index.insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); + node_index.insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); + node_index.insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); + node_index.insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); + node_index.insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); + node_index.insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); + node_index.insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); + node_index.insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); + node_index.insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); + node_index.insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); + node_index.insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); + node_index.insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); + node_index.insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); + node_index.insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); + node_index.insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); + node_index.insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); + node_index.insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); + node_index.insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); + node_index.insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); + node_index.insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); + node_index.insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); + node_index.insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); + node_index.insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); + node_index.insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); + node_index.insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); + node_index.insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); + node_index.insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); + node_index.insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); + node_index.insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); + node_index.insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); + node_index.insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); + node_index.insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); + node_index.insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); + node_index.insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); + node_index.insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); + node_index.insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); + node_index.insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); + node_index.insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); + node_index.insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); + node_index.insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); + node_index.insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); + node_index.insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); + node_index.insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); + node_index.insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); + node_index.insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); + node_index.insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); + node_index.insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); + node_index.insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); + node_index.insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); + node_index.insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); + node_index.insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); + node_index.insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); + node_index.insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); + node_index.insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); + node_index.insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); + node_index.insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); + node_index.insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); + node_index.insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); + node_index.insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); + node_index.insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); + node_index.insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); + node_index.insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); + node_index.insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); + node_index.insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); + node_index.insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); + node_index.insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); + node_index.insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); + node_index.insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); + node_index.insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); + node_index.insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); + node_index.insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); + node_index.insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); + node_index.insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); + node_index.insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); + node_index.insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); + node_index.insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); + node_index.insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); + node_index.insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); + node_index.insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); + node_index.insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); + node_index.insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); let atts = TreeEnsembleAttributes { nodes_falsenodeids, @@ -687,19 +687,19 @@ fn tree_ensemble_classifier_binary_class_helper( }; let mut X = TensorTrait::new( - array![1,9].span(), - array![ - FP16x16 { mag: 39321, sign: false }, - FP16x16 { mag: 32768, sign: false }, - FP16x16 { mag: 52428, sign: false }, - FP16x16 { mag: 16384, sign: false }, - FP16x16 { mag: 0, sign: false }, - FP16x16 { mag: 65536, sign: false }, - FP16x16 { mag: 0, sign: false }, - FP16x16 { mag: 16384, sign: false }, - FP16x16 { mag: 0, sign: false }, - ].span() - ); + array![1,9].span(), + array![ + FP16x16 { mag: 39321, sign: false }, + FP16x16 { mag: 32768, sign: false }, + FP16x16 { mag: 52428, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 65536, sign: false }, + FP16x16 { mag: 0, sign: false }, + FP16x16 { mag: 16384, sign: false }, + FP16x16 { mag: 0, sign: false }, + ].span() + ); (classifier, X) } \ No newline at end of file diff --git a/tests/nodes/compress_fp16x16_3d_axis1.cairo b/tests/nodes/compress_fp16x16_3d_axis1.cairo index de0c173ed..f110fd66d 100644 --- a/tests/nodes/compress_fp16x16_3d_axis1.cairo +++ b/tests/nodes/compress_fp16x16_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_axis2.cairo b/tests/nodes/compress_fp16x16_3d_axis2.cairo index 765bcb5ea..1115fb557 100644 --- a/tests/nodes/compress_fp16x16_3d_axis2.cairo +++ b/tests/nodes/compress_fp16x16_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_axis3.cairo b/tests/nodes/compress_fp16x16_3d_axis3.cairo index ffa9c8321..76ef5f641 100644 --- a/tests/nodes/compress_fp16x16_3d_axis3.cairo +++ b/tests/nodes/compress_fp16x16_3d_axis3.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_axis3() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(3)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(3)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_default.cairo b/tests/nodes/compress_fp16x16_3d_default.cairo index d9b837a19..aff1849e2 100644 --- a/tests/nodes/compress_fp16x16_3d_default.cairo +++ b/tests/nodes/compress_fp16x16_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp16x16_3d_noaxis.cairo b/tests/nodes/compress_fp16x16_3d_noaxis.cairo index 2bd536e08..3c9645b1d 100644 --- a/tests/nodes/compress_fp16x16_3d_noaxis.cairo +++ b/tests/nodes/compress_fp16x16_3d_noaxis.cairo @@ -18,7 +18,7 @@ fn test_compress_fp16x16_3d_noaxis() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::None(())); + let y_0 = input_0.compress(condition: input_1, axis: Option::None(())); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp8x23_3d_axis1.cairo b/tests/nodes/compress_fp8x23_3d_axis1.cairo index edd013f54..f7edfd13a 100644 --- a/tests/nodes/compress_fp8x23_3d_axis1.cairo +++ b/tests/nodes/compress_fp8x23_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_fp8x23_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp8x23_3d_axis2.cairo b/tests/nodes/compress_fp8x23_3d_axis2.cairo index 580a6272a..369ffb8bf 100644 --- a/tests/nodes/compress_fp8x23_3d_axis2.cairo +++ b/tests/nodes/compress_fp8x23_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_fp8x23_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_fp8x23_3d_default.cairo b/tests/nodes/compress_fp8x23_3d_default.cairo index a927f7fe8..eab9aa1ac 100644 --- a/tests/nodes/compress_fp8x23_3d_default.cairo +++ b/tests/nodes/compress_fp8x23_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_fp8x23_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i32_3d_axis1.cairo b/tests/nodes/compress_i32_3d_axis1.cairo index f69cf2e2a..571e5beb5 100644 --- a/tests/nodes/compress_i32_3d_axis1.cairo +++ b/tests/nodes/compress_i32_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_i32_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i32_3d_axis2.cairo b/tests/nodes/compress_i32_3d_axis2.cairo index bfe01e5a0..be674ffba 100644 --- a/tests/nodes/compress_i32_3d_axis2.cairo +++ b/tests/nodes/compress_i32_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_i32_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i32_3d_default.cairo b/tests/nodes/compress_i32_3d_default.cairo index b07d95010..4bd05fce1 100644 --- a/tests/nodes/compress_i32_3d_default.cairo +++ b/tests/nodes/compress_i32_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_i32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i8_3d_axis1.cairo b/tests/nodes/compress_i8_3d_axis1.cairo index 6a4197ce1..fae6c2356 100644 --- a/tests/nodes/compress_i8_3d_axis1.cairo +++ b/tests/nodes/compress_i8_3d_axis1.cairo @@ -18,7 +18,7 @@ fn test_compress_i8_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i8_3d_axis2.cairo b/tests/nodes/compress_i8_3d_axis2.cairo index 4dd7b5a8f..f8e90c133 100644 --- a/tests/nodes/compress_i8_3d_axis2.cairo +++ b/tests/nodes/compress_i8_3d_axis2.cairo @@ -18,7 +18,7 @@ fn test_compress_i8_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_i8_3d_default.cairo b/tests/nodes/compress_i8_3d_default.cairo index 14b684377..1b4052d0e 100644 --- a/tests/nodes/compress_i8_3d_default.cairo +++ b/tests/nodes/compress_i8_3d_default.cairo @@ -18,7 +18,7 @@ fn test_compress_i8_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis1.cairo b/tests/nodes/compress_u32_3d_axis1.cairo index dda59bead..7cfadc989 100644 --- a/tests/nodes/compress_u32_3d_axis1.cairo +++ b/tests/nodes/compress_u32_3d_axis1.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis1() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(1)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(1)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis2.cairo b/tests/nodes/compress_u32_3d_axis2.cairo index ba8fa77ef..9c70291c5 100644 --- a/tests/nodes/compress_u32_3d_axis2.cairo +++ b/tests/nodes/compress_u32_3d_axis2.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis2_2.cairo b/tests/nodes/compress_u32_3d_axis2_2.cairo index aa283b2cc..850c10296 100644 --- a/tests/nodes/compress_u32_3d_axis2_2.cairo +++ b/tests/nodes/compress_u32_3d_axis2_2.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis2_2() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(2)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(2)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_axis3.cairo b/tests/nodes/compress_u32_3d_axis3.cairo index 62684b39f..c53e3e1b1 100644 --- a/tests/nodes/compress_u32_3d_axis3.cairo +++ b/tests/nodes/compress_u32_3d_axis3.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_axis3() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(3)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(3)); assert_eq(y_0, z_0); } diff --git a/tests/nodes/compress_u32_3d_default.cairo b/tests/nodes/compress_u32_3d_default.cairo index 058750c53..a7d987eb6 100644 --- a/tests/nodes/compress_u32_3d_default.cairo +++ b/tests/nodes/compress_u32_3d_default.cairo @@ -16,7 +16,7 @@ fn test_compress_u32_3d_default() { let input_1 = input_1::input_1(); let z_0 = output_0::output_0(); - let y_0 = input_0.compress(condition:input_1, axis:Option::Some(0)); + let y_0 = input_0.compress(condition: input_1, axis: Option::Some(0)); assert_eq(y_0, z_0); } From 7eaddb76cf0152e095f9e89773a73124d839219a Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 11 Jan 2024 19:41:07 -0500 Subject: [PATCH 32/38] make it compile --- .../implementations/tensor_complex64.cairo | 20 +++++++++++++++++++ .../implementations/tensor_fp16x16.cairo | 6 ------ .../implementations/tensor_fp16x16wide.cairo | 6 ------ .../implementations/tensor_fp32x32.cairo | 6 ------ .../implementations/tensor_fp64x64.cairo | 6 ------ .../implementations/tensor_fp8x23.cairo | 2 ++ .../implementations/tensor_fp8x23wide.cairo | 2 ++ .../tensor/implementations/tensor_i32.cairo | 2 ++ .../tensor/implementations/tensor_i8.cairo | 2 ++ .../tensor/implementations/tensor_u32.cairo | 2 ++ 10 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index 53feb8980..2f2882db0 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -475,6 +475,26 @@ impl Complex64Tensor of TensorTrait { fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } + + + fn resize( + self: @Tensor, + roi: Option>, + scales: Option>, + sizes: Option>, + antialias: Option, + axes: Option>, + coordinate_transformation_mode: Option, + cubic_coeff_a: Option, + exclude_outside: Option, + extrapolation_value: Option, + keep_aspect_ratio_policy: Option, + mode: Option, + nearest_mode: Option, + ) -> Tensor { + panic(array!['not supported!']) + } + } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index f4cbf0957..a8d07aa6c 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -486,12 +486,6 @@ impl FP16x16Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { diff --git a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo index 8e1603795..c371ab08f 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16wide.cairo @@ -452,12 +452,6 @@ impl FP16x16WTensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index cba05da0f..fcf6bbc62 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -487,12 +487,6 @@ impl FP32x32Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 8c912eda3..22cb9885f 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -487,12 +487,6 @@ impl FP64x64Tensor of TensorTrait { math::is_nan::is_nan(self) } - fn concat_from_sequence( - sequence: Array>, axis: i32, new_axis: Option - ) -> Tensor { - math::concat_from_sequence::concat_from_sequence(sequence, axis, new_axis) - } - fn gather_nd( self: @Tensor, indices: Tensor, batch_dims: Option ) -> Tensor { diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index b0c27d6df..3157df272 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -533,6 +533,8 @@ impl FP8x23Tensor of TensorTrait { mode, nearest_mode ) + } + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 744b3f8bf..89b91f695 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -472,6 +472,8 @@ impl FP8x23WTensor of TensorTrait { nearest_mode: Option, ) -> Tensor { panic(array!['not supported!']) + } + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 49c251f8a..b3a560c87 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -516,6 +516,8 @@ impl I32Tensor of TensorTrait { nearest_mode: Option, ) -> Tensor { panic(array!['not supported!']) + } + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index d1024baf7..b67a57957 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -514,6 +514,8 @@ impl I8Tensor of TensorTrait { nearest_mode: Option, ) -> Tensor { panic(array!['not supported!']) + } + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index cd796d554..ed740fe90 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -458,6 +458,8 @@ impl U32Tensor of TensorTrait { nearest_mode: Option, ) -> Tensor { panic(array!['not supported!']) + } + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } From 546651b4b0b7fb2bcfa65157af9937255fd7f126 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 12 Jan 2024 00:45:41 +0000 Subject: [PATCH 33/38] docs: update README.md [skip ci] --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d0b3055e4..1f48201d3 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ # Orion: An Open-source Framework for Validity and ZK ML ✨ -[![All Contributors](https://img.shields.io/badge/all_contributors-28-orange.svg?style=flat-square)](#contributors-) +[![All Contributors](https://img.shields.io/badge/all_contributors-29-orange.svg?style=flat-square)](#contributors-) Orion is an open-source, community-driven framework dedicated to Provable Machine Learning. It provides essential components and a new ONNX runtime for building verifiable Machine Learning models using [STARKs](https://starkware.co/stark/). @@ -104,6 +104,9 @@ Thanks goes to these wonderful people: Bilgin KoΓ§ak
Bilgin Koçak

πŸ’» akhercha
akhercha

πŸ’» + + Vid Kersic
Vid Kersic

πŸ’» + From 4102e1e7db3407812427ada6152c43aeed1e0425 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 12 Jan 2024 00:45:42 +0000 Subject: [PATCH 34/38] docs: update .all-contributorsrc [skip ci] --- .all-contributorsrc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.all-contributorsrc b/.all-contributorsrc index 8a41f9cbe..7f624a7d6 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -260,6 +260,15 @@ "contributions": [ "code" ] + }, + { + "login": "Vid201", + "name": "Vid Kersic", + "avatar_url": "https://avatars.githubusercontent.com/u/38610409?v=4", + "profile": "https://github.com/Vid201", + "contributions": [ + "code" + ] } ], "contributorsPerLine": 7, From aee4d18e0bc35bba0bce3d5013eff076acaf3a16 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 11 Jan 2024 20:18:09 -0500 Subject: [PATCH 35/38] remove duplicate compress --- src/operators/tensor/implementations/tensor_bool.cairo | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index f7f7a82ce..fab0b8f3f 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -479,10 +479,6 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { - math::compress::compress(self, condition, axis) - } - fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option>) -> Array> { panic(array!['not supported!']) } From 6188e1603cf9376abbd701d29028ffc08d49f6e7 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 11 Jan 2024 20:52:12 -0500 Subject: [PATCH 36/38] format code --- docs/framework/operators/tensor/README.md | 1 + .../ml/linear/linear_classifier.cairo | 4 +- src/operators/tensor/core.cairo | 3 +- .../tensor/implementations/tensor_bool.cairo | 6 +- .../implementations/tensor_complex64.cairo | 12 +- .../implementations/tensor_fp16x16.cairo | 8 +- .../implementations/tensor_fp32x32.cairo | 6 +- .../implementations/tensor_fp64x64.cairo | 6 +- .../implementations/tensor_fp8x23.cairo | 6 +- .../implementations/tensor_fp8x23wide.cairo | 6 +- .../tensor/implementations/tensor_i32.cairo | 4 +- .../tensor/implementations/tensor_i8.cairo | 4 +- .../tensor/implementations/tensor_u32.cairo | 4 +- src/operators/tensor/manipulation/split.cairo | 109 ++-- tests/ml/linear_classifier_test.cairo | 1 - tests/ml/tree_ensemble_classifier.cairo | 475 ++++++++++++------ .../split_fp16x16_1d_variable_parts.cairo | 9 +- .../split_fp16x16_2d_variable_parts.cairo | 9 +- tests/nodes/split_fp16x16_zero_size.cairo | 9 +- tests/nodes/split_u32_1d_variable_parts.cairo | 9 +- tests/nodes/split_u32_2d_variable_parts.cairo | 9 +- tests/nodes/split_u32_zero_size.cairo | 9 +- tests/operators/transpose_test.cairo | 4 +- 23 files changed, 457 insertions(+), 256 deletions(-) diff --git a/docs/framework/operators/tensor/README.md b/docs/framework/operators/tensor/README.md index 49241c9e0..281135f63 100644 --- a/docs/framework/operators/tensor/README.md +++ b/docs/framework/operators/tensor/README.md @@ -119,6 +119,7 @@ use orion::operators::tensor::TensorTrait; | [`tensor.reduce_log_sum`](tensor.reduce\_log\_sum.md) | Computes the log sum of the input tensor's elements along the provided axes. | | [`tensor.erf`](tensor.erf.md) | Computes the error function of the given input tensor element-wise. | | [`tensor.layer_normalization`](tensor.layer\_normalization.md) | computes the layer normalization of the input tensor. | +| [`tensor.split`](tensor.split.md) | Split a tensor into a list of tensors, along the specified β€˜axis’. | ## Arithmetic Operations diff --git a/src/operators/ml/linear/linear_classifier.cairo b/src/operators/ml/linear/linear_classifier.cairo index 2230500bc..fad7ea2d4 100644 --- a/src/operators/ml/linear/linear_classifier.cairo +++ b/src/operators/ml/linear/linear_classifier.cairo @@ -196,7 +196,7 @@ impl LinearClassifierImpl< POST_TRANSFORM::NONE => { scores }, POST_TRANSFORM::SOFTMAX => { NNTrait::softmax(@scores, 1) }, POST_TRANSFORM::LOGISTIC => { NNTrait::sigmoid(@scores) }, - POST_TRANSFORM::SOFTMAXZERO => { NNTrait::softmax_zero(@scores, 1)}, + POST_TRANSFORM::SOFTMAXZERO => { NNTrait::softmax_zero(@scores, 1) }, POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), }; @@ -264,7 +264,7 @@ impl LinearClassifierImpl< } i += 1; }; - }, + }, POST_TRANSFORM::PROBIT => core::panic_with_felt252('Probit not supported yet'), }; } diff --git a/src/operators/tensor/core.cairo b/src/operators/tensor/core.cairo index 978ae7c91..796224fb1 100644 --- a/src/operators/tensor/core.cairo +++ b/src/operators/tensor/core.cairo @@ -5175,7 +5175,8 @@ trait TensorTrait { /// [[2,3],[6,7]] /// ``` /// - fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> ) -> Array>; } diff --git a/src/operators/tensor/implementations/tensor_bool.cairo b/src/operators/tensor/implementations/tensor_bool.cairo index fab0b8f3f..eaeda00d5 100644 --- a/src/operators/tensor/implementations/tensor_bool.cairo +++ b/src/operators/tensor/implementations/tensor_bool.cairo @@ -460,7 +460,7 @@ impl BoolTensor of TensorTrait { ) -> (Tensor, Tensor, Tensor) { panic(array!['not supported!']) } - + fn resize( self: @Tensor, roi: Option>, @@ -479,7 +479,9 @@ impl BoolTensor of TensorTrait { panic(array!['not supported!']) } - fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option>) -> Array> { + fn split( + self: @Tensor, axis: usize, num_outputs: Option, spl: Option> + ) -> Array> { panic(array!['not supported!']) } } diff --git a/src/operators/tensor/implementations/tensor_complex64.cairo b/src/operators/tensor/implementations/tensor_complex64.cairo index c68ebe636..52d916b14 100644 --- a/src/operators/tensor/implementations/tensor_complex64.cairo +++ b/src/operators/tensor/implementations/tensor_complex64.cairo @@ -488,8 +488,13 @@ impl Complex64Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor) { panic(array!['not supported!']) } - - fn split(self: @Tensor, axis: usize, num_outputs: Option, spl: Option>) -> Array> { + + fn split( + self: @Tensor, + axis: usize, + num_outputs: Option, + spl: Option> + ) -> Array> { panic(array!['not supported!']) } @@ -508,9 +513,8 @@ impl Complex64Tensor of TensorTrait { mode: Option, nearest_mode: Option, ) -> Tensor { - panic(array!['not supported!']) + panic(array!['not supported!']) } - } /// Implements addition for `Tensor` using the `Add` trait. diff --git a/src/operators/tensor/implementations/tensor_fp16x16.cairo b/src/operators/tensor/implementations/tensor_fp16x16.cairo index 548b1f9b0..cbe7fdfcc 100644 --- a/src/operators/tensor/implementations/tensor_fp16x16.cairo +++ b/src/operators/tensor/implementations/tensor_fp16x16.cairo @@ -516,7 +516,7 @@ impl FP16x16Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor) { math::layer_normalization::layer_normalization(self, scale, B, axis, epsilon, stash_type) } - + fn resize( self: @Tensor, roi: Option>, @@ -548,8 +548,10 @@ impl FP16x16Tensor of TensorTrait { nearest_mode ) } - - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp32x32.cairo b/src/operators/tensor/implementations/tensor_fp32x32.cairo index 644e5ef5e..e7b517eb4 100644 --- a/src/operators/tensor/implementations/tensor_fp32x32.cairo +++ b/src/operators/tensor/implementations/tensor_fp32x32.cairo @@ -549,8 +549,10 @@ impl FP32x32Tensor of TensorTrait { nearest_mode ) } - - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp64x64.cairo b/src/operators/tensor/implementations/tensor_fp64x64.cairo index 1694343e6..a8121fc31 100644 --- a/src/operators/tensor/implementations/tensor_fp64x64.cairo +++ b/src/operators/tensor/implementations/tensor_fp64x64.cairo @@ -549,8 +549,10 @@ impl FP64x64Tensor of TensorTrait { nearest_mode ) } - - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23.cairo b/src/operators/tensor/implementations/tensor_fp8x23.cairo index 5ddcb1333..9f3e78573 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23.cairo @@ -547,8 +547,10 @@ impl FP8x23Tensor of TensorTrait { nearest_mode ) } - - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo index 7546f0072..6a7ca5ba4 100644 --- a/src/operators/tensor/implementations/tensor_fp8x23wide.cairo +++ b/src/operators/tensor/implementations/tensor_fp8x23wide.cairo @@ -486,8 +486,10 @@ impl FP8x23WTensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } - - fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { + + fn compress( + self: @Tensor, condition: Tensor, axis: Option + ) -> Tensor { math::compress::compress(self, condition, axis) } diff --git a/src/operators/tensor/implementations/tensor_i32.cairo b/src/operators/tensor/implementations/tensor_i32.cairo index 0706ad712..2e2ae267d 100644 --- a/src/operators/tensor/implementations/tensor_i32.cairo +++ b/src/operators/tensor/implementations/tensor_i32.cairo @@ -519,7 +519,7 @@ impl I32Tensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } - + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } @@ -534,7 +534,7 @@ impl I32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor) { panic(array!['not supported!']) } - + fn split( self: @Tensor, axis: usize, num_outputs: Option, spl: Option> ) -> Array> { diff --git a/src/operators/tensor/implementations/tensor_i8.cairo b/src/operators/tensor/implementations/tensor_i8.cairo index 997a59f2c..769368166 100644 --- a/src/operators/tensor/implementations/tensor_i8.cairo +++ b/src/operators/tensor/implementations/tensor_i8.cairo @@ -517,7 +517,7 @@ impl I8Tensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } - + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } @@ -532,7 +532,7 @@ impl I8Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor) { panic(array!['not supported!']) } - + fn split( self: @Tensor, axis: usize, num_outputs: Option, spl: Option> ) -> Array> { diff --git a/src/operators/tensor/implementations/tensor_u32.cairo b/src/operators/tensor/implementations/tensor_u32.cairo index 0edb0da92..d8e02d490 100644 --- a/src/operators/tensor/implementations/tensor_u32.cairo +++ b/src/operators/tensor/implementations/tensor_u32.cairo @@ -461,7 +461,7 @@ impl U32Tensor of TensorTrait { ) -> Tensor { panic(array!['not supported!']) } - + fn compress(self: @Tensor, condition: Tensor, axis: Option) -> Tensor { math::compress::compress(self, condition, axis) } @@ -476,7 +476,7 @@ impl U32Tensor of TensorTrait { ) -> (Tensor, Tensor, Tensor) { panic(array!['not supported!']) } - + fn split( self: @Tensor, axis: usize, num_outputs: Option, spl: Option> ) -> Array> { diff --git a/src/operators/tensor/manipulation/split.cairo b/src/operators/tensor/manipulation/split.cairo index 436de61ca..bf0274aec 100644 --- a/src/operators/tensor/manipulation/split.cairo +++ b/src/operators/tensor/manipulation/split.cairo @@ -17,29 +17,25 @@ fn split< self: @Tensor, axis: usize, num_outputs: Option, split: Option> ) -> Array> { let has_num_outputs = match num_outputs { - Option::Some(value) => { - true - }, + Option::Some(value) => { true }, Option::None => false, }; let has_split = match split { - Option::Some(value) => { - true - }, + Option::Some(value) => { true }, Option::None => false, }; assert(!(has_num_outputs && has_split), 'split or num_outputs not both.'); assert(has_num_outputs || has_split, 'split or num_outputs not both.'); - + let mut splited_t: Array> = array![]; let rank = (*self).shape.len(); // assert(axis < rank && axis > -rank, 'axis out of dimensions'); assert(axis < rank, 'axis out of dimensions'); - if (has_num_outputs){ + if (has_num_outputs) { splited_t = split_num_outputs(self, axis, num_outputs.unwrap()); - }else{ + } else { splited_t = split_has_split(self, axis, split.unwrap()); } splited_t @@ -57,30 +53,30 @@ fn split_num_outputs, +Drop, +TensorTrait, +PartialOrd, +Pa // if axis==0 { // axis = 1; // } - if (*(*t).shape.at(axis) % num_outputs == 0){ - div = *(*t).shape.at(axis) / num_outputs; + if (*(*t).shape.at(axis) % num_outputs == 0) { + div = *(*t).shape.at(axis) / num_outputs; let mut i = 0; loop { - if (i>=num_outputs) { + if (i >= num_outputs) { break; } split.append(div); i += 1; }; } else { - div = *(*t).shape.at(axis) / num_outputs+1; + div = *(*t).shape.at(axis) / num_outputs + 1; let mut i = 0; loop { - if (i>=num_outputs) { + if (i >= num_outputs) { break; } split.append(div); i += 1; }; - match split.pop_front(){ + match split.pop_front() { Option::Some(split_last_one) => { - split.append(split_last_one + *(*t).shape.at(axis) - div*(num_outputs-1)); - }, + split.append(split_last_one + *(*t).shape.at(axis) - div * (num_outputs - 1)); + }, Option::None(_) => { assert(false, 'split is none array'); } } } @@ -89,44 +85,40 @@ fn split_num_outputs, +Drop, +TensorTrait, +PartialOrd, +Pa let mut pos: usize = 0; let mut i = 0; loop { - if (i>=(*t).shape.len()) { + if (i >= (*t).shape.len()) { break; } let s: usize = *(*t).shape.at(i); - sli.set(i,0,0); - sli.set(i,1,s); + sli.set(i, 0, 0); + sli.set(i, 1, s); i += 1; }; let mut i: usize = 0; loop { - if (i>=split.len()) { + if (i >= split.len()) { break; } let spl = *split.at(i); sli.set(axis, 0, pos); - pos += spl; + pos += spl; sli.set(axis, 1, pos); let end_ele_0 = match sli.get(axis, 0) { - Option::Some(res) => { - res - }, - Option::None(_) => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + Option::Some(res) => { res }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, }; let end_ele_1 = match sli.get(axis, 1) { - Option::Some(res) => { - res - }, - Option::None(_) => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + Option::Some(res) => { res }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, }; - let starts: Span = array![sli.get(0,0).unwrap(),end_ele_0].span(); - let ends: Span = array![ sli.get(0,1).unwrap(), end_ele_1].span(); + let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); + let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); let axes: Option> = Option::None(()); let steps: Option> = Option::None(()); let sub_t: Tensor = t.slice(starts, ends, axes, steps); @@ -134,7 +126,6 @@ fn split_num_outputs, +Drop, +TensorTrait, +PartialOrd, +Pa i += 1; }; splited_t - } /// Subfunction split for tensors (wth split). @@ -147,44 +138,40 @@ fn split_has_split, +Drop, +TensorTrait, +PartialOrd, +Part let mut pos: usize = 0; let mut i = 0; loop { - if (i>=(*t).shape.len()) { + if (i >= (*t).shape.len()) { break; } let s: usize = *(*t).shape.at(i); - sli.set(i,0,0); - sli.set(i,1,s); + sli.set(i, 0, 0); + sli.set(i, 1, s); i += 1; }; let mut i: usize = 0; loop { - if (i>=split.data.len()) { + if (i >= split.data.len()) { break; } let spl: usize = split.at(indices: array![i].span()); sli.set(axis, 0, pos); - pos += spl; + pos += spl; sli.set(axis, 1, pos); let end_ele_0 = match sli.get(axis, 0) { - Option::Some(res) => { - res - }, - Option::None(_) => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + Option::Some(res) => { res }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, }; let end_ele_1 = match sli.get(axis, 1) { - Option::Some(res) => { - res - }, - Option::None(_) => { - assert(false, 'Get end_ele_0 is failed'); - 0 - }, + Option::Some(res) => { res }, + Option::None(_) => { + assert(false, 'Get end_ele_0 is failed'); + 0 + }, }; - let starts: Span = array![sli.get(0,0).unwrap(),end_ele_0].span(); - let ends: Span = array![ sli.get(0,1).unwrap(), end_ele_1].span(); + let starts: Span = array![sli.get(0, 0).unwrap(), end_ele_0].span(); + let ends: Span = array![sli.get(0, 1).unwrap(), end_ele_1].span(); let axes: Option> = Option::None(()); let steps: Option> = Option::None(()); let sub_t: Tensor = t.slice(starts, ends, axes, steps); @@ -192,4 +179,4 @@ fn split_has_split, +Drop, +TensorTrait, +PartialOrd, +Part i += 1; }; splited_t -} \ No newline at end of file +} diff --git a/tests/ml/linear_classifier_test.cairo b/tests/ml/linear_classifier_test.cairo index 1a1c90e7d..e0c892328 100644 --- a/tests/ml/linear_classifier_test.cairo +++ b/tests/ml/linear_classifier_test.cairo @@ -180,7 +180,6 @@ fn test_linear_classifier_binary_softmax_zero() { assert(*scores.data[1] == FP16x16 { mag: 65535, sign: false }, '*scores[1] == 1.000000'); assert(*scores.data[2] == FP16x16 { mag: 0, sign: false }, '*scores[2] == 1.674492e-06'); assert(*scores.data[3] == FP16x16 { mag: 65535, sign: false }, '*scores[3] == 9.999983e-01'); - } #[test] diff --git a/tests/ml/tree_ensemble_classifier.cairo b/tests/ml/tree_ensemble_classifier.cairo index 2603625b4..441aabb34 100644 --- a/tests/ml/tree_ensemble_classifier.cairo +++ b/tests/ml/tree_ensemble_classifier.cairo @@ -241,8 +241,9 @@ fn test_tree_ensemble_classifier_binary_none() { #[test] #[available_gas(200000000000)] fn test_tree_ensemble_classifier_binary_logistic() { - - let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper(POST_TRANSFORM::LOGISTIC); + let (mut classifier, X) = tree_ensemble_classifier_binary_class_helper( + POST_TRANSFORM::LOGISTIC + ); let (labels, mut scores) = TreeEnsembleClassifierTrait::predict(ref classifier, X); @@ -2082,161 +2083,316 @@ fn tree_ensemble_classifier_binary_class_helper( let mut root_index: Felt252Dict = Default::default(); root_index.insert(0, 0); let mut node_index: Felt252Dict = Default::default(); - node_index.insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); - node_index.insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); - node_index.insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); - node_index.insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); - node_index.insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); - node_index.insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); - node_index.insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); - node_index.insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); - node_index.insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); - node_index.insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); - node_index.insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); - node_index.insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); - node_index.insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); - node_index.insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); - node_index.insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); - node_index.insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); - node_index.insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); - node_index.insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); - node_index.insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); - node_index.insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); - node_index.insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); - node_index.insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); - node_index.insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); - node_index.insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); - node_index.insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); - node_index.insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); - node_index.insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); - node_index.insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); - node_index.insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); - node_index.insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); - node_index.insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); - node_index.insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); - node_index.insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); - node_index.insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); - node_index.insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); - node_index.insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); - node_index.insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); - node_index.insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); - node_index.insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); - node_index.insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); - node_index.insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); - node_index.insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); - node_index.insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); - node_index.insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); - node_index.insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); - node_index.insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); - node_index.insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); - node_index.insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); - node_index.insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); - node_index.insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); - node_index.insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); - node_index.insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); - node_index.insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); - node_index.insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); - node_index.insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); - node_index.insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); - node_index.insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); - node_index.insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); - node_index.insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); - node_index.insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); - node_index.insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); - node_index.insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); - node_index.insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); - node_index.insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); - node_index.insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); - node_index.insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); - node_index.insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); - node_index.insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); - node_index.insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); - node_index.insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); - node_index.insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); - node_index.insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); - node_index.insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); - node_index.insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); - node_index.insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); - node_index.insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); - node_index.insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); - node_index.insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); - node_index.insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); - node_index.insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); - node_index.insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); - node_index.insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); - node_index.insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); - node_index.insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); - node_index.insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); - node_index.insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); - node_index.insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); - node_index.insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); - node_index.insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); - node_index.insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); - node_index.insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); - node_index.insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); - node_index.insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); - node_index.insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); - node_index.insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); - node_index.insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); - node_index.insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); - node_index.insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); - node_index.insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); - node_index.insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); - node_index.insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); - node_index.insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); - node_index.insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); - node_index.insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); - node_index.insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); - node_index.insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); - node_index.insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); - node_index.insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); - node_index.insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); - node_index.insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); - node_index.insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); - node_index.insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); - node_index.insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); - node_index.insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); - node_index.insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); - node_index.insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); - node_index.insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); - node_index.insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); - node_index.insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); - node_index.insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); - node_index.insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); - node_index.insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); - node_index.insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); - node_index.insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); - node_index.insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); - node_index.insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); - node_index.insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); - node_index.insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); - node_index.insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); - node_index.insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); - node_index.insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); - node_index.insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); - node_index.insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); - node_index.insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); - node_index.insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); - node_index.insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); - node_index.insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); - node_index.insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); - node_index.insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); - node_index.insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); - node_index.insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); - node_index.insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); - node_index.insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); - node_index.insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); - node_index.insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); - node_index.insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); - node_index.insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); - node_index.insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); - node_index.insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); - node_index.insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); - node_index.insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); - node_index.insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); - node_index.insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); - node_index.insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); - node_index.insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); + node_index + .insert(2089986280348253421170679821480865132823066470938446095505822317253594081284, 0); + node_index + .insert(2001140082530619239661729809084578298299223810202097622761632384561112390979, 1); + node_index + .insert(2592670241084192212354027440049085852792506518781954896144296316131790403900, 2); + node_index + .insert(2960591271376829378356567803618548672034867345123727178628869426548453833420, 3); + node_index + .insert(458933264452572171106695256465341160654132084710250671055261382009315664425, 4); + node_index + .insert(3344223123784052057366048933846905716067140384361791026153972616805110454637, 5); + node_index + .insert(658476905110174425295568215706634733332002869979287079110965040248935650599, 6); + node_index + .insert(2836212335642438363012490794290757623813171043187182819737087983331902926990, 7); + node_index + .insert(3496601277869056110810900082189273917786762659443522403285387602989271154262, 8); + node_index + .insert(1249294489531540970169611621067106471309281870082955806338234725206665112557, 9); + node_index + .insert(2161697998033672097816961828039488190903838124365465380011173778905747857792, 10); + node_index + .insert(1129815197211541481934112806673325772687763881719835256646064516195041515616, 11); + node_index + .insert(2592593088135949192377729543480191336537305484235681164569491942155715064163, 12); + node_index + .insert(578223957014284909949571568465953382377214912750427143720957054706073492593, 13); + node_index + .insert(1645617302026197421098102802983206579163506957138012501615708926120228167528, 14); + node_index + .insert(2809438816810155970395166036110536928593305127049404137239671320081144123490, 15); + node_index + .insert(2496308528011391755709310159103918074725328650411689040761791240500618770096, 16); + node_index + .insert(2003594778587446957576114348312422277631766150749194167061999666337236425714, 17); + node_index + .insert(2215681478480673835576618830034726157921200517935329010004363713426342305479, 18); + node_index + .insert(3185925835074464079989752015681272863271067691852543168049845807561733691707, 19); + node_index + .insert(1207265836470221457484062512091666004839070622130697586496866096347024057755, 20); + node_index + .insert(1870230949202979679764944800468118671928852128047695497376875566624821494262, 21); + node_index + .insert(618060852536781954395603948693216564334274573299243914053414488061601327758, 22); + node_index + .insert(232760707548494477255512699093366059519467428168757247456690480397246371463, 23); + node_index + .insert(1617386247965480308136742715422077429967341022950306068917456849194882895900, 24); + node_index + .insert(654822874782506608656472905579051041410086644071534146326024101025575400153, 25); + node_index + .insert(525638101901638132526332140778087078272370083489998903571807698910013602668, 26); + node_index + .insert(3091640181556387972179279087539287892670640556085669903494551919685982442095, 27); + node_index + .insert(1425411460578159050163131982087304445715005458700346341117759372943452688022, 28); + node_index + .insert(1722933265299553894839124723076027659619615015638971980461286818493531809034, 29); + node_index + .insert(3325117385742592388671007840076299062858228097051060057749225651290693960897, 30); + node_index + .insert(1869273998012404873272699831805499731567895666937555882116307079956228100456, 31); + node_index + .insert(257262395234910825879033951801423835835630270967846664413154594520703929530, 32); + node_index + .insert(2891500475385583315757684141371327604925143655360011721762142660942782195029, 33); + node_index + .insert(1257459981124043271342269816753070228024611695909553991758648317372015085782, 34); + node_index + .insert(3573101724490615587655146760489247477770015274618159524231872921394794809579, 35); + node_index + .insert(2951401777594449283985541406642940553317465718696638438535370997641527993378, 36); + node_index + .insert(2436860863451320452900512817385686838091627966322316039332239784330434600829, 37); + node_index + .insert(3257977356974702770994741663931928753019715185508521958836925918758890988390, 38); + node_index + .insert(2741853283805093821434776875305720302351684616683152528499335618682018880592, 39); + node_index + .insert(514567459251558911686762246500770717674979116530125263461114578537254680672, 40); + node_index + .insert(2119374930171040799805795099091470687208894498354655018353474015395489390434, 41); + node_index + .insert(3338470191188327918255138125570464269857839379813971679216902484398948556964, 42); + node_index + .insert(2892272281879752543368066497063301979597320550780387266511926397533716561161, 43); + node_index + .insert(2855312300216814846973137837923466865382642814675378398541743368270404441020, 44); + node_index + .insert(3483159989811162048659069774034779954374540681397531094699912464364012442948, 45); + node_index + .insert(2987290998320166766043911843685118029159841654368226419198314196237253901671, 46); + node_index + .insert(2925128850088180758852255336587985612621894021863350117875677692518888637440, 47); + node_index + .insert(2816470536741550741568042622139415760794090671576940833850781679568928363263, 48); + node_index + .insert(117504025904364990582663097556885493352655695615775952177872159762046032741, 49); + node_index + .insert(2143228410294149239354901612797540167003066966910132278060626241695943498248, 50); + node_index + .insert(419311759585766455354017006957403420381614228026953716552023555428752798694, 51); + node_index + .insert(3050064038480880151202753004776919876287903442365303272956696507808448797287, 52); + node_index + .insert(1385347512411195789080079656286641766866442255046855963092069449745407366357, 53); + node_index + .insert(3070310993421490198115289431281422702215620142859327949152517372324361472619, 54); + node_index + .insert(2913742884576958969164113782587195202828846527657900496424141449477472273564, 55); + node_index + .insert(2093568472535973986606438755824580633177115509557931302974988564932601955239, 56); + node_index + .insert(3560543329106347446823281318204312198881533222464682017397248462954529220234, 57); + node_index + .insert(2258329791422139736262782239641765930569031761627249090322755566443202104242, 58); + node_index + .insert(780147230530856456622774510057100334628735431063744145772648079601317149643, 59); + node_index + .insert(2316329094783634722527635915976455864728431870713378530935487247638854220445, 60); + node_index + .insert(595942459003356191117553450912822964169058193996898486073017533717706655996, 61); + node_index + .insert(468061318535033931711585815055033307297228787991312757359512916260570188285, 62); + node_index + .insert(2052204235688624923559873131063770183910134013049526186717275231865702195614, 63); + node_index + .insert(1699955311620840869165542755053722387608345658646185648087789689690825797785, 64); + node_index + .insert(3374282522812564185678772854203408947562394461702303390331208821006329361123, 65); + node_index + .insert(2973169188135795465401576355486514117723575153845438471619715618155257254587, 66); + node_index + .insert(1933845760462748501896196912926633344425020928596291295340561855718789280752, 67); + node_index + .insert(1400206374308839959676708676217334569580738052049798766556848516900888958934, 68); + node_index + .insert(1440488595273849761788031183901254714714513692476890759699232177835922420051, 69); + node_index + .insert(1765607197782429306903827944694032984087223086461400721152786273443512274576, 70); + node_index + .insert(1081728107764482028110815183657783965582618309560569428049406599883158895762, 71); + node_index + .insert(2062101824085365476835789898002802715794623271831111740147610520210138854237, 72); + node_index + .insert(2074740322618091900768870458741540994849904300182495465356314088191301853065, 73); + node_index + .insert(3258451235037745323160669027918885172565773098482160366154412360890640013860, 74); + node_index + .insert(525053653813541387331907730505904505067816165493211829943994988775279102044, 75); + node_index + .insert(1899573658331441767985549642643113663505618738939032010935036740376062596854, 76); + node_index + .insert(350484224543766923071449868701665032398970313961410080649918872017849315812, 77); + node_index + .insert(1950842492180490337143378914485176805944281696420768035114335939818602766139, 78); + node_index + .insert(1404824782481446239312837894341789608778585592445990662138109764117920511709, 79); + node_index + .insert(362836422984951199752185473435750713386745407518736982952373985921347236081, 80); + node_index + .insert(946623025367211063265176586824604502073515634531788667777364911179858705558, 81); + node_index + .insert(2633163324000277496191816132521100721217797223993064604664039067710591734562, 82); + node_index + .insert(1801986104078933931671502775029170829560335045042499367678597186639133610708, 83); + node_index + .insert(1420697278439090953165809531316265389371075037014378922361911811337560296928, 84); + node_index + .insert(2818913779862691152404893285048164649343019708946413114150419613972391643833, 85); + node_index + .insert(2117995436013652728497840885480545729833030913486848118093758726746902541269, 86); + node_index + .insert(127751852951361188238686395231851222850913859197429858579312845246901369178, 87); + node_index + .insert(2698811633001158191033663638617437313508153976714307643233173949778419312517, 88); + node_index + .insert(658388282521842455588914251287531837029259203197178137902217792556456503561, 89); + node_index + .insert(1181527093320872098458354979612125149419384756607076935731557552577945926179, 90); + node_index + .insert(749436134732178646256740138670151907037714564259781780243747781475007506978, 91); + node_index + .insert(139527053159256821789882596124320673637475746672994443968014105962305658551, 92); + node_index + .insert(2256264752321707533173578319742847366660740117899562657584919346001438808295, 93); + node_index + .insert(1471349294215639651865069312281269029496180149092207674923855978537861742949, 94); + node_index + .insert(1599527610774916650758786135513735847459194869088601099692148267264507139422, 95); + node_index + .insert(1348925567371118538973078195838174941892601233016661969987842843098656775084, 96); + node_index + .insert(3255130909854220350850821724488067913492420563978595271106701962634473840914, 97); + node_index + .insert(1098499015810170842401428216621470177488952811780672364884710297364076372943, 98); + node_index + .insert(2666902303639302012507119689908308317608522901613536135678723310999647515155, 99); + node_index + .insert(907997515879651052705985194221621380802961721264372722705825219340461809200, 100); + node_index + .insert(2124360554325144308113106422635485756539471211141315552843423768396084888273, 101); + node_index + .insert(3598736440043009208771817410113758019876931018927260161846683440123219507147, 102); + node_index + .insert(1237113034722832488580561245188430373504295256910735188987019984096012001931, 103); + node_index + .insert(884558344049768836371555446021588200903052780339208951904957349404044037185, 104); + node_index + .insert(784280321344489256066716285882203121428790637989919760379274813665427427262, 105); + node_index + .insert(3472551952588748711709398308465335743810517871695257916614928877311914574241, 106); + node_index + .insert(1579363348100943961344032004617708767155021524242506190674861550786419896732, 107); + node_index + .insert(653576968777651719072715499492112313607520878545254037043893560183879857489, 108); + node_index + .insert(2633327961579170199842757290989312779085828750765842327985383652720803061926, 109); + node_index + .insert(3101204920253220343970782457572784926765600523633379722044614528209389590915, 110); + node_index + .insert(2537565394330405662800880050062241097694806466900452037378113841155978555645, 111); + node_index + .insert(306955559655552244989220345789093187601563118591829582730637833945761653350, 112); + node_index + .insert(1144065212212058748489308207801098564095305699242880891977316839573431241916, 113); + node_index + .insert(3478181491851418723342103101321490659650934149094649769124337426850038155270, 114); + node_index + .insert(3419621624676637660673415219086314486713019053519954317586073983685881930356, 115); + node_index + .insert(2426908011370291613447136873176769136554489197972200481728552402228021778402, 116); + node_index + .insert(1916122042123370178944690083048900704842269230325086549679099089416174875473, 117); + node_index + .insert(2057207652658215393591191155928140567561900227203223756539551876829334137660, 118); + node_index + .insert(2722034389703601317070746005702467061064354401688341549606678773616189196490, 119); + node_index + .insert(1171026027377763359814377926117880688616494219551682642535759838199732407496, 120); + node_index + .insert(3507234282031533800397666430789917374211847440333243952151005899337152633413, 121); + node_index + .insert(591003147462937848375161803108517142253138969543815135207326321181858185919, 122); + node_index + .insert(182069734527202013451813026473135702900640769187641767871411473365447302169, 123); + node_index + .insert(1195243682249232878341146428166676460720423167409013083888435705219134747702, 124); + node_index + .insert(1793425644853312386902998134061844248823841892125424765064687913085130719534, 125); + node_index + .insert(1983622665815164792580256365519803214027269990384198703315493315153573288434, 126); + node_index + .insert(3615973154491344159350153395208055142342062736505558158666764642048838175685, 127); + node_index + .insert(2751715913626909804252433699602081411293721754810298670422380863932998088133, 128); + node_index + .insert(186918881712189523740089713555196200069231794627360499557319265374750577226, 129); + node_index + .insert(696585542544434929491503209053317581175146475161262066468664234437983008675, 130); + node_index + .insert(4359830495913805154545225899592517767672472055784183911796827820518038513, 131); + node_index + .insert(2954335207058000607751727656601539819316106074875304820535376873121805433820, 132); + node_index + .insert(2510390039949230255082316953804013731253145558531652907601250263563528226672, 133); + node_index + .insert(3226995230854300551967642178527450300960499043510855212238369890580256668532, 134); + node_index + .insert(1620924075233065517364532267959798304439946408626316544761884056227131075831, 135); + node_index + .insert(1610900122192929153657761847202689179268074338802437933866337242354758101660, 136); + node_index + .insert(2565949095169598991903537465065584077778440646580025930326495506484329892725, 137); + node_index + .insert(1012362975819634411571869839734809106575285344002573666983595104659295812607, 138); + node_index + .insert(242312010918799555845832460483650516749990744287009628468613253461264531026, 139); + node_index + .insert(1104776796569046483584574115975216172161469015460244982207905888870418040487, 140); + node_index + .insert(3289555912992777681578950209252840071327866822704829766247386311885634446673, 141); + node_index + .insert(3133389957643610781371406448279843175887428913359743769920083259111437722268, 142); + node_index + .insert(1169918710119352022244140656086831769713178729571654411898266328562003734517, 143); + node_index + .insert(3592039235252149652556167686570045881877115549259769455422056097903987237819, 144); + node_index + .insert(2048175709145840597887667330964815895803568760936075562647625937161113445908, 145); + node_index + .insert(602222645962845554276438041138511866776339653340605661136009451417275008940, 146); + node_index + .insert(3318742320906017551291978242369663702298606650330380959683585594592748661010, 147); + node_index + .insert(564160996724923690963741657975239836484028160385417016805513722318839327322, 148); + node_index + .insert(656294390376267384135628810815504467149264887388377312825033341338166573620, 149); + node_index + .insert(1201592236750942207412694706123654466634588634474700675083122904145559965915, 150); + node_index + .insert(2141408926815137181004274624388915700231991905288681935478972043994347966006, 151); + node_index + .insert(1440847977042239464860406726605567303568767649154338464116083965986084755262, 152); + node_index + .insert(950585553138591375958592507876257987416844837045084288783892644487908218679, 153); + node_index + .insert(257643451533833048856069434258149588745628261389615631070776723485957908127, 154); let atts = TreeEnsembleAttributes { nodes_falsenodeids, @@ -2265,7 +2421,7 @@ fn tree_ensemble_classifier_binary_class_helper( }; let mut X = TensorTrait::new( - array![1,9].span(), + array![1, 9].span(), array![ FP16x16 { mag: 39321, sign: false }, FP16x16 { mag: 32768, sign: false }, @@ -2276,7 +2432,8 @@ fn tree_ensemble_classifier_binary_class_helper( FP16x16 { mag: 0, sign: false }, FP16x16 { mag: 16384, sign: false }, FP16x16 { mag: 0, sign: false }, - ].span() + ] + .span() ); (classifier, X) diff --git a/tests/nodes/split_fp16x16_1d_variable_parts.cairo b/tests/nodes/split_fp16x16_1d_variable_parts.cairo index f5f46e75d..888e8ef2e 100644 --- a/tests/nodes/split_fp16x16_1d_variable_parts.cairo +++ b/tests/nodes/split_fp16x16_1d_variable_parts.cairo @@ -15,7 +15,14 @@ fn test_split_fp16x16_1d_variable_parts() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + let y = input_0 + .split( + 0, + Option::None(()), + Option::Some( + TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),) + ) + ); assert_seq_eq(y, z); } diff --git a/tests/nodes/split_fp16x16_2d_variable_parts.cairo b/tests/nodes/split_fp16x16_2d_variable_parts.cairo index d627014e2..ab6090f5c 100644 --- a/tests/nodes/split_fp16x16_2d_variable_parts.cairo +++ b/tests/nodes/split_fp16x16_2d_variable_parts.cairo @@ -15,7 +15,14 @@ fn test_split_fp16x16_2d_variable_parts() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = input_0.split(1, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + let y = input_0 + .split( + 1, + Option::None(()), + Option::Some( + TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),) + ) + ); assert_seq_eq(y, z); } diff --git a/tests/nodes/split_fp16x16_zero_size.cairo b/tests/nodes/split_fp16x16_zero_size.cairo index c9056376b..a0749c925 100644 --- a/tests/nodes/split_fp16x16_zero_size.cairo +++ b/tests/nodes/split_fp16x16_zero_size.cairo @@ -15,7 +15,14 @@ fn test_split_fp16x16_zero_size() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),))); + let y = input_0 + .split( + 0, + Option::None(()), + Option::Some( + TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),) + ) + ); assert_seq_eq(y, z); } diff --git a/tests/nodes/split_u32_1d_variable_parts.cairo b/tests/nodes/split_u32_1d_variable_parts.cairo index 2680a6f77..3e4c7facd 100644 --- a/tests/nodes/split_u32_1d_variable_parts.cairo +++ b/tests/nodes/split_u32_1d_variable_parts.cairo @@ -14,7 +14,14 @@ fn test_split_u32_1d_variable_parts() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + let y = input_0 + .split( + 0, + Option::None(()), + Option::Some( + TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),) + ) + ); assert_seq_eq(y, z); } diff --git a/tests/nodes/split_u32_2d_variable_parts.cairo b/tests/nodes/split_u32_2d_variable_parts.cairo index b38f87122..f591103a2 100644 --- a/tests/nodes/split_u32_2d_variable_parts.cairo +++ b/tests/nodes/split_u32_2d_variable_parts.cairo @@ -14,7 +14,14 @@ fn test_split_u32_2d_variable_parts() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = input_0.split(1, Option::None(()), Option::Some(TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),))); + let y = input_0 + .split( + 1, + Option::None(()), + Option::Some( + TensorTrait::::new(shape: array![2].span(), data: array![2, 4].span(),) + ) + ); assert_seq_eq(y, z); } diff --git a/tests/nodes/split_u32_zero_size.cairo b/tests/nodes/split_u32_zero_size.cairo index 39eeb9d67..53c432c19 100644 --- a/tests/nodes/split_u32_zero_size.cairo +++ b/tests/nodes/split_u32_zero_size.cairo @@ -14,7 +14,14 @@ fn test_split_u32_zero_size() { let input_0 = input_0::input_0(); let z = output_0::output_0(); - let y = input_0.split(0, Option::None(()), Option::Some(TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),))); + let y = input_0 + .split( + 0, + Option::None(()), + Option::Some( + TensorTrait::::new(shape: array![3].span(), data: array![0, 0, 0].span(),) + ) + ); assert_seq_eq(y, z); } diff --git a/tests/operators/transpose_test.cairo b/tests/operators/transpose_test.cairo index ee720be28..b4683162e 100644 --- a/tests/operators/transpose_test.cairo +++ b/tests/operators/transpose_test.cairo @@ -28,9 +28,7 @@ fn transpose_test_values() { #[test] #[available_gas(200000000000)] fn transpose_test_1D() { - let tensor = TensorTrait::< - u32 - >::new(shape: array![4].span(), data: array![0, 1, 2, 3].span(),); + let tensor = TensorTrait::::new(shape: array![4].span(), data: array![0, 1, 2, 3].span(),); let result = tensor.transpose(axes: array![0].span()); From f37b57af338d79e0a2c850a09aef5ae00239758c Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 11 Jan 2024 20:52:57 -0500 Subject: [PATCH 37/38] Update Scarb.toml --- Scarb.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Scarb.toml b/Scarb.toml index 6ebf1fd1f..cf6169b1f 100644 --- a/Scarb.toml +++ b/Scarb.toml @@ -1,6 +1,6 @@ [package] name = "orion" -version = "0.1.9" +version = "0.2.0" cairo-version = "2.4.0" edition = "2023_10" description = "ONNX Runtime in Cairo for verifiable ML inference using STARK" From 4d34849c342ca74d3113527f666df55994809eb3 Mon Sep 17 00:00:00 2001 From: raphaelDkhn Date: Thu, 11 Jan 2024 20:54:51 -0500 Subject: [PATCH 38/38] Create .tool-versions --- .tool-versions | 1 + 1 file changed, 1 insertion(+) create mode 100644 .tool-versions diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 000000000..21cfc8077 --- /dev/null +++ b/.tool-versions @@ -0,0 +1 @@ +scarb 2.4.0