Skip to content

Commit

Permalink
Fix merge error affecting i8 with wmma (#1743)
Browse files Browse the repository at this point in the history
Merging in i8 wmma fix into rocm-5.7 staging branch

Co-authored-by: Alex Brown <[email protected]>
  • Loading branch information
yoichiyoshida and AlexBrownAMD authored Jul 14, 2023
1 parent 5b36bb8 commit 97e0cfc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
6 changes: 3 additions & 3 deletions Tensile/Components/MFMA.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class WMMASelection(MFMA):
def __call__(self, writer, accOutStart, accOutEnd, in0, in1, accInStart, accInEnd, accStoreCIdx, firstIter):
kernel = writer.kernel
inType = kernel["ProblemType"]["DataType"].toNameAbbrev()
neg = " neg_lo:[1,1,1]" if (inType == "i8") else ""
inType = "iu8" if inType == "i8" else inType
outType = kernel["ProblemType"]["ComputeDataType"].toNameAbbrev()
if kernel["ProblemType"]["DataType"].isComplex():
Expand All @@ -38,12 +39,11 @@ def __call__(self, writer, accOutStart, accOutEnd, in0, in1, accInStart, accInEn
# miB = kernel["MatrixInstB"]
str0 = in1 if kernel["SourceSwap"] else in0
str1 = in0 if kernel["SourceSwap"] else in1

# use const 0 for src2 in firstIter case
src2 = "0" if firstIter else "v[%u:%u]"%(accOutStart, accOutEnd)

kStr = "v_wmma_%s_%ux%ux%u_%s v[%u+%u:%u+%u], %s, %s, %s%s" \
% (outType, miM, miN, miK, inType, accInStart, accStoreCIdx, accInEnd, accStoreCIdx, str0, str1, src2, writer.endLine)
kStr = "v_wmma_%s_%ux%ux%u_%s v[%u+%u:%u+%u], %s, %s, %s%s%s" \
% (outType, miM, miN, miK, inType, accInStart, accStoreCIdx, accInEnd, accStoreCIdx, str0, str1, src2, neg, writer.endLine)

return kStr

Expand Down
3 changes: 0 additions & 3 deletions Tensile/KernelWriterAssembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -6468,8 +6468,6 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F

# calculate constant
is_mfma = globalParameters["AsmCaps"][self.version]["HasMFMA"]
# mfma_1k = "_1k" if (kernel["MFMA_BF16_1K"] or kernel["ProblemType"]["Fp16AltImpl"]) else ""
# accumRegType = "a" if not kernel["MIArchVgpr"] else "v"

numRegistersIn = miInputType.numRegisters()
numRegistersOut = kernel["MIRegPerOut"]
Expand All @@ -6479,7 +6477,6 @@ def mfmaIter(self, kernel, u, innerUnroll, vregSetIdx, lastKinloop=False, tail=F
dividerFortidInK = kernel["MatrixInstN"] * kernel["MatrixInstB"]
numMIInput = kernel["MIInputPerThread"]
miInTypeName = "bf16" if kernel["ProblemType"]["Fp16AltImpl"] else miInputType.toNameAbbrev() # v_mfma_[...xK]<InType>
# neg = " neg_lo:[1,1,1]" if ((not is_mfma) and (miInTypeName == "i8")) else ""
miInTypeName = "iu8" if ((not is_mfma) and miInTypeName == "i8") else miInTypeName
miOutTypeName = miInputType.MIOutputTypeNameAbbrev() # v_mfma_<OutType>..
miOutTypeName = miOutTypeName if is_mfma else kernel["ProblemType"]["ComputeDataType"].toNameAbbrev()
Expand Down

0 comments on commit 97e0cfc

Please sign in to comment.