diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index f0e8219..85fc33f 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -308,45 +308,49 @@ public extension TextDecoding { return kvCache } - static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray, - valueTensor: MLMultiArray, valueSlice: MLMultiArray, - insertAtIndex index: Int) - { - let tensorShape = keyTensor.shape.map { $0.intValue } - let sliceShape = keySlice.shape.map { $0.intValue } - let sliceStrides = keySlice.strides.map { $0.intValue } // same for val - let bytesPerSample = MemoryLayout.size - - keyTensor.withUnsafeMutableBytes { keyTensorPointer, keyTargetStrides in - keySlice.withUnsafeBytes { keySlicePointer in - valueTensor.withUnsafeMutableBytes { valueTensorPointer, valueTargetStrides in - valueSlice.withUnsafeBytes { valueSlicePointer in - // Assuming batch size is always 1 - DispatchQueue.concurrentPerform(iterations: tensorShape[1]) { j in - // Slice size is 3 for prefill and 1 for decode loops - for k in 0...size) + memcpy(&valueData, valueTensor.dataPointer, valueTensor.count * MemoryLayout.size) + + // Calculate dimensions for index mapping + let seqLength = tensorShape[3] + let hiddenDim = tensorShape[1] + + // Concurrent processing across hidden dimension + DispatchQueue.concurrentPerform(iterations: hiddenDim) { j in + for k in 0...size) + memcpy(valueTensor.dataPointer, &valueData, valueTensor.count * MemoryLayout.size) + } static func updateAlignmentWeights( alignmentTensor: MLMultiArray,