Skip to content

Commit

Permalink
Merge pull request #554 from gizatechxyz/develop
Browse files Browse the repository at this point in the history
Merge Develop into Main
  • Loading branch information
raphaelDkhn authored Jan 27, 2024
2 parents 11e5799 + d4ecfdf commit 8ff7389
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
3 changes: 1 addition & 2 deletions Scarb.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ homepage = "https://github.com/gizatechxyz/orion"
alexandria_merkle_tree = { git = "https://github.com/keep-starknet-strange/alexandria.git", rev = "01a7690" }
alexandria_data_structures = { git = "https://github.com/keep-starknet-strange/alexandria.git", rev = "01a7690" }
alexandria_sorting = { git = "https://github.com/keep-starknet-strange/alexandria.git", rev = "01a7690" }
# TODO: update to https://github.com/influenceth/cubit & change rev
cubit = { git = "https://github.com/akhercha/cubit.git", rev = "d3869a3" }
cubit = { git = "https://github.com/influenceth/cubit.git", rev = "6275608" }

[scripts]
sierra = "cairo-compile . -r"
Expand Down
9 changes: 9 additions & 0 deletions src/operators/nn/functional/gemm.cairo
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use alexandria_data_structures::array_ext::SpanTraitExt;
use core::array::SpanTrait;

use orion::numbers::NumberTrait;
Expand Down Expand Up @@ -48,6 +49,14 @@ fn gemm<

match C {
Option::Some(c) => {
let broadcast_c_shape = if c.shape.len() == 1 {
array![1].span().concat(c.shape)
} else {
c.shape
};

let c = Tensor { shape: broadcast_c_shape, data: c.data };

return mul_by_scalar(@A.matmul(@B), alpha) + mul_by_scalar(@c, beta);
},
Option::None(_) => { return mul_by_scalar(@A.matmul(@B), alpha); }
Expand Down
2 changes: 1 addition & 1 deletion src/operators/tensor/helpers.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -496,4 +496,4 @@ impl SpanPartialOrd<T, +Drop<T>, +Copy<T>, +PartialEq<T>, +PartialOrd<T>> of Par
fn lt(lhs: Span<T>, rhs: Span<T>) -> bool {
span_cmp(lhs, rhs) < 0
}
}
}

0 comments on commit 8ff7389

Please sign in to comment.