Skip to content

Commit

Permalink
Fix issue #604: Using the .item function on a Complex Tensor crashes (#…
Browse files Browse the repository at this point in the history
…605)

Also adds a test case for the issue.

Co-authored-by: Angel Ezquerra <[email protected]>
  • Loading branch information
AngelEzquerra and AngelEzquerraAtKeysight authored Oct 20, 2023
1 parent caa5223 commit 6fd4df3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/arraymancer/laser/tensor/initialization.nim
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,19 @@ func item*[T_IN, T_OUT](t: Tensor[T_IN], _: typedesc[T_OUT]): T_OUT =
## This only works for Tensors (of any rank) that contain one single element.
## If the tensor has more than one element IndexDefect is raised.
if likely(t.size == 1):
when T_OUT is Complex64:
when T_IN is Complex and T_OUT is Complex:
# When the input and the output types are Complex, we need to find
# the "base" type of the output type (e.g. float32 or float64),
# and then convert the real and imaginary parts of the input value
# into the output base type before creating the output complex type
type TT = typeof(
block:
var tmp: T_OUT
tmp.re
)
let val = t.squeeze[0]
result = complex(TT(val.re), TT(val.im))
elif T_OUT is Complex64:
result = complex(float64(t.squeeze[0]))
elif T_OUT is Complex32:
result = complex(float32(t.squeeze[0]))
Expand Down
4 changes: 4 additions & 0 deletions tests/tensor/test_shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ proc main() =
let a = [[[[1]]]].toTensor
let value = a.item(Complex64)
check value == complex(1.0, 0)
block:
let a = [[[[complex[float64](1.0, 1.1)]]]].toTensor
let value = a.item(Complex32)
check value == complex[float32](1.0, 1.1)

test "Unsqueeze":
block:
Expand Down

0 comments on commit 6fd4df3

Please sign in to comment.