diff --git a/Tensile/Components/MFMA.py b/Tensile/Components/MFMA.py index ec052fa74f..60d6cb4f8b 100644 --- a/Tensile/Components/MFMA.py +++ b/Tensile/Components/MFMA.py @@ -28,6 +28,7 @@ class WMMASelection(MFMA): def __call__(self, writer, accOutStart, accOutEnd, in0, in1, accInStart, accInEnd, accStoreCIdx, firstIter): kernel = writer.kernel inType = kernel["ProblemType"]["DataType"].toNameAbbrev() + neg = " neg_lo:[1,1,1]" if (inType == "i8") else "" inType = "iu8" if inType == "i8" else inType outType = kernel["ProblemType"]["ComputeDataType"].toNameAbbrev() if kernel["ProblemType"]["DataType"].isComplex(): @@ -38,12 +39,11 @@ def __call__(self, writer, accOutStart, accOutEnd, in0, in1, accInStart, accInEn # miB = kernel["MatrixInstB"] str0 = in1 if kernel["SourceSwap"] else in0 str1 = in0 if kernel["SourceSwap"] else in1 - # use const 0 for src2 in firstIter case src2 = "0" if firstIter else "v[%u:%u]"%(accOutStart, accOutEnd) - kStr = "v_wmma_%s_%ux%ux%u_%s v[%u+%u:%u+%u], %s, %s, %s%s" \ - % (outType, miM, miN, miK, inType, accInStart, accStoreCIdx, accInEnd, accStoreCIdx, str0, str1, src2, writer.endLine) + kStr = "v_wmma_%s_%ux%ux%u_%s v[%u+%u:%u+%u], %s, %s, %s%s%s" \ + % (outType, miM, miN, miK, inType, accInStart, accStoreCIdx, accInEnd, accStoreCIdx, str0, str1, src2, neg, writer.endLine) return kStr diff --git a/Tensile/KernelWriterAssembly.py b/Tensile/KernelWriterAssembly.py index df395b548a..5f9d34b0c5 100644 --- a/Tensile/KernelWriterAssembly.py +++ b/Tensile/KernelWriterAssembly.py @@ -6468,8 +6468,6 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F # calculate constant is_mfma = globalParameters["AsmCaps"][self.version]["HasMFMA"] - # mfma_1k = "_1k" if (kernel["MFMA_BF16_1K"] or kernel["ProblemType"]["Fp16AltImpl"]) else "" - # accumRegType = "a" if not kernel["MIArchVgpr"] else "v" numRegistersIn = miInputType.numRegisters() numRegistersOut = kernel["MIRegPerOut"] @@ -6479,7 +6477,6 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F dividerFortidInK = kernel["MatrixInstN"] * kernel["MatrixInstB"] numMIInput = kernel["MIInputPerThread"] miInTypeName = "bf16" if kernel["ProblemType"]["Fp16AltImpl"] else miInputType.toNameAbbrev() # v_mfma_[...xK] - # neg = " neg_lo:[1,1,1]" if ((not is_mfma) and (miInTypeName == "i8")) else "" miInTypeName = "iu8" if ((not is_mfma) and miInTypeName == "i8") else miInTypeName miOutTypeName = miInputType.MIOutputTypeNameAbbrev() # v_mfma_.. miOutTypeName = miOutTypeName if is_mfma else kernel["ProblemType"]["ComputeDataType"].toNameAbbrev()