Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CodeGen: Rewrite dot product lowering using a dedicated IR instruction #1512

Merged
merged 6 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CodeGen/include/Luau/AssemblyBuilderA64.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ class AssemblyBuilderA64
void fneg(RegisterA64 dst, RegisterA64 src);
void fsqrt(RegisterA64 dst, RegisterA64 src);
void fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2);
void faddp(RegisterA64 dst, RegisterA64 src);

// Vector component manipulation
void ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index);
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/include/Luau/AssemblyBuilderX64.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class AssemblyBuilderX64
void vpshufps(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t shuffle);
void vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 src2, uint8_t offset);

void vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask);

// Run final checks
bool finalize();

Expand Down
4 changes: 4 additions & 0 deletions CodeGen/include/Luau/IrData.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ enum class IrCmd : uint8_t
// A: TValue
UNM_VEC,

// Compute dot product between two vectors
// A, B: TValue
DOT_VEC,

// Compute Luau 'not' operation on destructured TValue
// A: tag
// B: int (value)
Expand Down
1 change: 1 addition & 0 deletions CodeGen/include/Luau/IrUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ inline bool hasResult(IrCmd cmd)
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
case IrCmd::UNM_VEC:
case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY:
Expand Down
8 changes: 8 additions & 0 deletions CodeGen/src/AssemblyBuilderA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,14 @@ void AssemblyBuilderA64::fabs(RegisterA64 dst, RegisterA64 src)
placeR1("fabs", dst, src, 0b000'11110'01'1'0000'01'10000);
}

void AssemblyBuilderA64::faddp(RegisterA64 dst, RegisterA64 src)
{
CODEGEN_ASSERT(dst.kind == KindA64::d || dst.kind == KindA64::s);
CODEGEN_ASSERT(dst.kind == src.kind);

placeR1("faddp", dst, src, 0b011'11110'0'0'11000'01101'10 | ((dst.kind == KindA64::d) << 12));
}

void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
{
if (dst.kind == KindA64::d)
Expand Down
5 changes: 5 additions & 0 deletions CodeGen/src/AssemblyBuilderX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,11 @@ void AssemblyBuilderX64::vpinsrd(RegisterX64 dst, RegisterX64 src1, OperandX64 s
placeAvx("vpinsrd", dst, src1, src2, offset, 0x22, false, AVX_0F3A, AVX_66);
}

void AssemblyBuilderX64::vdpps(OperandX64 dst, OperandX64 src1, OperandX64 src2, uint8_t mask)
{
placeAvx("vdpps", dst, src1, src2, mask, 0x40, false, AVX_0F3A, AVX_66);
}

bool AssemblyBuilderX64::finalize()
{
code.resize(codePos - code.data());
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/src/IrDump.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ const char* getCmdName(IrCmd cmd)
return "DIV_VEC";
case IrCmd::UNM_VEC:
return "UNM_VEC";
case IrCmd::DOT_VEC:
return "DOT_VEC";
case IrCmd::NOT_ANY:
return "NOT_ANY";
case IrCmd::CMP_ANY:
Expand Down
15 changes: 15 additions & 0 deletions CodeGen/src/IrLoweringA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,21 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.fneg(inst.regA64, regOp(inst.a));
break;
}
case IrCmd::DOT_VEC:
{
inst.regA64 = regs.allocReg(KindA64::d, index);

RegisterA64 temp = regs.allocTemp(KindA64::q);
RegisterA64 temps = castReg(KindA64::s, temp);
RegisterA64 regs = castReg(KindA64::s, inst.regA64);

build.fmul(temp, regOp(inst.a), regOp(inst.b));
build.faddp(regs, temps); // x+y
build.dup_4s(temp, temp, 2);
build.fadd(regs, regs, temps); // +z
build.fcvt(inst.regA64, regs);
break;
}
case IrCmd::NOT_ANY:
{
inst.regA64 = regs.allocReuse(KindA64::w, index, {inst.a, inst.b});
Expand Down
14 changes: 14 additions & 0 deletions CodeGen/src/IrLoweringX64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,20 @@ void IrLoweringX64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
build.vxorpd(inst.regX64, regOp(inst.a), build.f32x4(-0.0, -0.0, -0.0, -0.0));
break;
}
case IrCmd::DOT_VEC:
{
inst.regX64 = regs.allocRegOrReuse(SizeX64::xmmword, index, {inst.a, inst.b});

ScopedRegX64 tmp1{regs};
ScopedRegX64 tmp2{regs};

RegisterX64 tmpa = vecOp(inst.a, tmp1);
RegisterX64 tmpb = (inst.a == inst.b) ? tmpa : vecOp(inst.b, tmp2);

build.vdpps(inst.regX64, tmpa, tmpb, 0x71); // 7 = 0b0111, sum first 3 products into first float
build.vcvtss2sd(inst.regX64, inst.regX64, inst.regX64);
break;
}
case IrCmd::NOT_ANY:
{
// TODO: if we have a single user which is a STORE_INT, we are missing the opportunity to write directly to target
Expand Down
104 changes: 73 additions & 31 deletions CodeGen/src/IrTranslateBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ static const int kMinMaxUnrolledParams = 5;
static const int kBit32BinaryOpUnrolledParams = 5;

LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeCodegen);
LUAU_FASTFLAGVARIABLE(LuauVectorLibNativeDot);

namespace Luau
{
Expand Down Expand Up @@ -907,15 +908,26 @@ static BuiltinImplResult translateBuiltinVectorMagnitude(

build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));

IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
IrOp sum;

if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));

IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
sum = build.inst(IrCmd::DOT_VEC, a, a);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);

sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
}

IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);

Expand Down Expand Up @@ -945,25 +957,43 @@ static BuiltinImplResult translateBuiltinVectorNormalize(

build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));

IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp sum = build.inst(IrCmd::DOT_VEC, a, a);

IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);
IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
IrOp invvec = build.inst(IrCmd::NUM_TO_VEC, inv);

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);
IrOp result = build.inst(IrCmd::MUL_VEC, a, invvec);

IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);
result = build.inst(IrCmd::TAG_VECTOR, result);

IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);
build.inst(IrCmd::STORE_TVALUE, build.vmReg(ra), result);
}
else
{
IrOp x = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp y = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp z = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));

build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
IrOp x2 = build.inst(IrCmd::MUL_NUM, x, x);
IrOp y2 = build.inst(IrCmd::MUL_NUM, y, y);
IrOp z2 = build.inst(IrCmd::MUL_NUM, z, z);

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, x2, y2), z2);

IrOp mag = build.inst(IrCmd::SQRT_NUM, sum);
IrOp inv = build.inst(IrCmd::DIV_NUM, build.constDouble(1.0), mag);

IrOp xr = build.inst(IrCmd::MUL_NUM, x, inv);
IrOp yr = build.inst(IrCmd::MUL_NUM, y, inv);
IrOp zr = build.inst(IrCmd::MUL_NUM, z, inv);

build.inst(IrCmd::STORE_VECTOR, build.vmReg(ra), xr, yr, zr);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TVECTOR));
}

return {BuiltinImplType::Full, 1};
}
Expand Down Expand Up @@ -1019,19 +1049,31 @@ static BuiltinImplResult translateBuiltinVectorDot(IrBuilder& build, int nparams
build.loadAndCheckTag(arg1, LUA_TVECTOR, build.vmExit(pcpos));
build.loadAndCheckTag(args, LUA_TVECTOR, build.vmExit(pcpos));

IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);
IrOp sum;

IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);
if (FFlag::LuauVectorLibNativeDot)
{
IrOp a = build.inst(IrCmd::LOAD_TVALUE, arg1, build.constInt(0));
IrOp b = build.inst(IrCmd::LOAD_TVALUE, args, build.constInt(0));

IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);
sum = build.inst(IrCmd::DOT_VEC, a, b);
}
else
{
IrOp x1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(0));
IrOp x2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(0));
IrOp xx = build.inst(IrCmd::MUL_NUM, x1, x2);

IrOp y1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(4));
IrOp y2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(4));
IrOp yy = build.inst(IrCmd::MUL_NUM, y1, y2);

IrOp sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
IrOp z1 = build.inst(IrCmd::LOAD_FLOAT, arg1, build.constInt(8));
IrOp z2 = build.inst(IrCmd::LOAD_FLOAT, args, build.constInt(8));
IrOp zz = build.inst(IrCmd::MUL_NUM, z1, z2);

sum = build.inst(IrCmd::ADD_NUM, build.inst(IrCmd::ADD_NUM, xx, yy), zz);
}

build.inst(IrCmd::STORE_DOUBLE, build.vmReg(ra), sum);
build.inst(IrCmd::STORE_TAG, build.vmReg(ra), build.constTag(LUA_TNUMBER));
Expand Down
2 changes: 2 additions & 0 deletions CodeGen/src/IrUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ IrValueKind getCmdValueKind(IrCmd cmd)
case IrCmd::DIV_VEC:
case IrCmd::UNM_VEC:
return IrValueKind::Tvalue;
case IrCmd::DOT_VEC:
return IrValueKind::Double;
case IrCmd::NOT_ANY:
case IrCmd::CMP_ANY:
return IrValueKind::Int;
Expand Down
4 changes: 3 additions & 1 deletion CodeGen/src/OptimizeConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,8 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
if (tag == LUA_TBOOLEAN &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Int)))
canSplitTvalueStore = true;
else if (tag == LUA_TNUMBER && (value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
else if (tag == LUA_TNUMBER &&
(value.kind == IrOpKind::Inst || (value.kind == IrOpKind::Constant && function.constOp(value).kind == IrConstKind::Double)))
canSplitTvalueStore = true;
else if (tag != 0xff && isGCO(tag) && value.kind == IrOpKind::Inst)
canSplitTvalueStore = true;
Expand Down Expand Up @@ -1342,6 +1343,7 @@ static void constPropInInst(ConstPropState& state, IrBuilder& build, IrFunction&
case IrCmd::SUB_VEC:
case IrCmd::MUL_VEC:
case IrCmd::DIV_VEC:
case IrCmd::DOT_VEC:
if (IrInst* a = function.asInstOp(inst.a); a && a->cmd == IrCmd::TAG_VECTOR)
replace(function, inst.a, a->a);

Expand Down
3 changes: 3 additions & 0 deletions tests/AssemblyBuilderA64.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "FPMath")
SINGLE_COMPARE(fsub(d1, d2, d3), 0x1E633841);
SINGLE_COMPARE(fsub(s29, s29, s28), 0x1E3C3BBD);

SINGLE_COMPARE(faddp(s29, s28), 0x7E30DB9D);
SINGLE_COMPARE(faddp(d29, d28), 0x7E70DB9D);

SINGLE_COMPARE(frinta(d1, d2), 0x1E664041);
SINGLE_COMPARE(frintm(d1, d2), 0x1E654041);
SINGLE_COMPARE(frintp(d1, d2), 0x1E64C041);
Expand Down
2 changes: 2 additions & 0 deletions tests/AssemblyBuilderX64.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,8 @@ TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "AVXTernaryInstructionForms")

SINGLE_COMPARE(vpshufps(xmm7, xmm12, xmmword[rcx + r10], 0b11010100), 0xc4, 0xa1, 0x18, 0xc6, 0x3c, 0x11, 0xd4);
SINGLE_COMPARE(vpinsrd(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x22, 0x3c, 0x11, 0x02);

SINGLE_COMPARE(vdpps(xmm7, xmm12, xmmword[rcx + r10], 2), 0xc4, 0xa3, 0x19, 0x40, 0x3c, 0x11, 0x02);
}

TEST_CASE_FIXTURE(AssemblyBuilderX64Fixture, "MiscInstructions")
Expand Down
Loading