diff --git a/src/arraymancer/laser/tensor/initialization.nim b/src/arraymancer/laser/tensor/initialization.nim index 30477ed5..7b1b869d 100644 --- a/src/arraymancer/laser/tensor/initialization.nim +++ b/src/arraymancer/laser/tensor/initialization.nim @@ -285,7 +285,19 @@ func item*[T_IN, T_OUT](t: Tensor[T_IN], _: typedesc[T_OUT]): T_OUT = ## This only works for Tensors (of any rank) that contain one single element. ## If the tensor has more than one element IndexDefect is raised. if likely(t.size == 1): - when T_OUT is Complex64: + when T_IN is Complex and T_OUT is Complex: + # When the input and the output types are Complex, we need to find + # the "base" type of the output type (e.g. float32 or float64), + # and then convert the real and imaginary parts of the input value + # into the output base type before creating the output complex type + type TT = typeof( + block: + var tmp: T_OUT + tmp.re + ) + let val = t.squeeze[0] + result = complex(TT(val.re), TT(val.im)) + elif T_OUT is Complex64: result = complex(float64(t.squeeze[0])) elif T_OUT is Complex32: result = complex(float32(t.squeeze[0])) diff --git a/tests/tensor/test_shapeshifting.nim b/tests/tensor/test_shapeshifting.nim index f4f247bd..2525ab79 100644 --- a/tests/tensor/test_shapeshifting.nim +++ b/tests/tensor/test_shapeshifting.nim @@ -157,6 +157,10 @@ proc main() = let a = [[[[1]]]].toTensor let value = a.item(Complex64) check value == complex(1.0, 0) + block: + let a = [[[[complex[float64](1.0, 1.1)]]]].toTensor + let value = a.item(Complex32) + check value == complex[float32](1.0, 1.1) test "Unsqueeze": block: