Skip to content

Commit

Permalink
Improve A64 lowering for vector operations by using vector instructions (#1164)
Browse files Browse the repository at this point in the history

This change replaces scalar versions of vector opcodes for A64 with
actual vector instructions.

We take an approach similar to X64: patch the last component with zero,
perform the math, then patch the last component with the type tag. I'm hoping that in
the future the type tag will be placed separately (separate IR opcode?)
because right now chains of math operations result in excessive type tag
operations.

To patch the type tag without always keeping a mask in a register,
ins.4s instructions can be used; unfortunately it's only capable of
patching a register in-place, so we need an extra register copy in case
it's not last-use. Usually it's last-use so the patch is free; probably
with IR rework mentioned above all of this can be improved (e.g.
load-with-patch will never need to copy).

~It's not 100% clear if we *have* to patch type tag: Apple does preserve
denormals but we'd need to benchmark this to see if there's an actual
performance impact. But for now we're playing it safe.~

This was tested by running the conformance tests, and new opcode
implementations were checked by comparing the result with
https://armconverter.com/.

Performance testing is complicated by the fact that OSS Luau doesn't
support vector constructor out of the box, and other limitations of
codegen. I've hacked vector constructor/type into REPL and confirmed
that on a test that calls this function in a loop (not inlined):

```
function fma(a: vector, b: vector, c: vector)
        return a * b + c
end
```

... this PR improves performance by ~6% (note that probably most of the
overhead here is the call dispatch; I didn't want to brave testing a
more complex expression). The assembly for an individual operation
changes as follows:

Before:

```
#   %14 = MUL_VEC %12, %13                                    ; useCount: 2, lastUse: %22
 dup         s29,v31.s[0]
 dup         s28,v30.s[0]
 fmul        s29,s29,s28
 ins         v31.s[0],v29.s[0]
 dup         s29,v31.s[1]
 dup         s28,v30.s[1]
 fmul        s29,s29,s28
 ins         v31.s[1],v29.s[0]
 dup         s29,v31.s[2]
 dup         s28,v30.s[2]
 fmul        s29,s29,s28
 ins         v31.s[2],v29.s[0]
```

After:

```
#   %14 = MUL_VEC %12, %13                                    ; useCount: 2, lastUse: %22
 ins         v31.s[3],w31
 ins         v30.s[3],w31
 fmul        v31.4s,v31.4s,v30.4s
 movz        w17,#4
 ins         v31.s[3],w17
```

**edit** final form (see comments):

```
#   %14 = MUL_VEC %12, %13                                    ; useCount: 2, lastUse: %22
 fmul        v31.4s,v31.4s,v30.4s
 movz        w17,#4
 ins         v31.s[3],w17
```
  • Loading branch information
zeux authored Feb 16, 2024
1 parent ea14e65 commit c5f4d97
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 61 deletions.
2 changes: 1 addition & 1 deletion CodeGen/include/Luau/AssemblyBuilderA64.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ class AssemblyBuilderA64
void placeSR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift = 0, int N = 0);
void placeSR2(const char* name, RegisterA64 dst, RegisterA64 src, uint8_t op, uint8_t op2 = 0);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, uint8_t op2);
void placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t sizes, uint8_t op, uint8_t op2);
void placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op);
void placeI12(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op);
void placeI16(const char* name, RegisterA64 dst, int src, uint8_t op, int shift = 0);
Expand All @@ -230,6 +229,7 @@ class AssemblyBuilderA64
void placeBM(const char* name, RegisterA64 dst, RegisterA64 src1, uint32_t src2, uint8_t op);
void placeBFM(const char* name, RegisterA64 dst, RegisterA64 src1, int src2, uint8_t op, int immr, int imms);
void placeER(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t op, int shift);
void placeVR(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint16_t op, uint8_t op2);

void place(uint32_t word);

Expand Down
92 changes: 65 additions & 27 deletions CodeGen/src/AssemblyBuilderA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,22 @@ AssemblyBuilderA64::~AssemblyBuilderA64()

void AssemblyBuilderA64::mov(RegisterA64 dst, RegisterA64 src)
{
CODEGEN_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp);
CODEGEN_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x));
if (dst.kind != KindA64::q)
{
CODEGEN_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst == sp);
CODEGEN_ASSERT(dst.kind == src.kind || (dst.kind == KindA64::x && src == sp) || (dst == sp && src.kind == KindA64::x));

if (dst == sp || src == sp)
placeR1("mov", dst, src, 0b00'100010'0'000000000000);
if (dst == sp || src == sp)
placeR1("mov", dst, src, 0b00'100010'0'000000000000);
else
placeSR2("mov", dst, src, 0b01'01010);
}
else
placeSR2("mov", dst, src, 0b01'01010);
{
CODEGEN_ASSERT(dst.kind == src.kind);

placeR1("mov", dst, src, 0b10'01110'10'1'00000'00011'1 | (src.index << 6));
}
}

void AssemblyBuilderA64::mov(RegisterA64 dst, int src)
Expand Down Expand Up @@ -575,12 +584,18 @@ void AssemblyBuilderA64::fadd(RegisterA64 dst, RegisterA64 src1, RegisterA64 src

placeR3("fadd", dst, src1, src2, 0b11110'01'1, 0b0010'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);

placeR3("fadd", dst, src1, src2, 0b11110'00'1, 0b0010'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);

placeVR("fadd", dst, src1, src2, 0b0'01110'0'0'1, 0b11010'1);
}
}

void AssemblyBuilderA64::fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
Expand All @@ -591,12 +606,18 @@ void AssemblyBuilderA64::fdiv(RegisterA64 dst, RegisterA64 src1, RegisterA64 src

placeR3("fdiv", dst, src1, src2, 0b11110'01'1, 0b0001'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);

placeR3("fdiv", dst, src1, src2, 0b11110'00'1, 0b0001'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);

placeVR("fdiv", dst, src1, src2, 0b1'01110'00'1, 0b11111'1);
}
}

void AssemblyBuilderA64::fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src2)
Expand All @@ -607,12 +628,18 @@ void AssemblyBuilderA64::fmul(RegisterA64 dst, RegisterA64 src1, RegisterA64 src

placeR3("fmul", dst, src1, src2, 0b11110'01'1, 0b0000'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);

placeR3("fmul", dst, src1, src2, 0b11110'00'1, 0b0000'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);

placeVR("fmul", dst, src1, src2, 0b1'01110'00'1, 0b11011'1);
}
}

void AssemblyBuilderA64::fneg(RegisterA64 dst, RegisterA64 src)
Expand All @@ -623,12 +650,18 @@ void AssemblyBuilderA64::fneg(RegisterA64 dst, RegisterA64 src)

placeR1("fneg", dst, src, 0b000'11110'01'1'0000'10'10000);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src.kind == KindA64::s);
CODEGEN_ASSERT(src.kind == KindA64::s);

placeR1("fneg", dst, src, 0b000'11110'00'1'0000'10'10000);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src.kind == KindA64::q);

placeR1("fneg", dst, src, 0b011'01110'1'0'10000'01111'10);
}
}

void AssemblyBuilderA64::fsqrt(RegisterA64 dst, RegisterA64 src)
Expand All @@ -646,12 +679,18 @@ void AssemblyBuilderA64::fsub(RegisterA64 dst, RegisterA64 src1, RegisterA64 src

placeR3("fsub", dst, src1, src2, 0b11110'01'1, 0b0011'10);
}
else
else if (dst.kind == KindA64::s)
{
CODEGEN_ASSERT(dst.kind == KindA64::s && src1.kind == KindA64::s && src2.kind == KindA64::s);
CODEGEN_ASSERT(src1.kind == KindA64::s && src2.kind == KindA64::s);

placeR3("fsub", dst, src1, src2, 0b11110'00'1, 0b0011'10);
}
else
{
CODEGEN_ASSERT(dst.kind == KindA64::q && src1.kind == KindA64::q && src2.kind == KindA64::q);

placeVR("fsub", dst, src1, src2, 0b0'01110'10'1, 0b11010'1);
}
}

void AssemblyBuilderA64::ins_4s(RegisterA64 dst, RegisterA64 src, uint8_t index)
Expand Down Expand Up @@ -952,18 +991,6 @@ void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64
commit();
}

// Emits a generic three-register instruction with an explicit size field.
// Encoding layout: Rd in bits 0-4, Rn in bits 5-9, op2 at bit 10, Rm in
// bits 16-20, op at bit 21, and the caller-supplied size bits at bit 29.
// NOTE(review): removed in this commit in favor of placeVR, which fixes the
// vector arrangement to .4s instead of taking a free-form `sizes` argument.
void AssemblyBuilderA64::placeR3(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint8_t sizes, uint8_t op, uint8_t op2)
{
if (logText)
log(name, dst, src1, src2);

// General-purpose (w/x) and SIMD scalar/vector (d/q) registers are accepted;
// all three operands must share the same register kind.
CODEGEN_ASSERT(dst.kind == KindA64::w || dst.kind == KindA64::x || dst.kind == KindA64::d || dst.kind == KindA64::q);
CODEGEN_ASSERT(dst.kind == src1.kind && dst.kind == src2.kind);

place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | (sizes << 29));
commit();
}

void AssemblyBuilderA64::placeR1(const char* name, RegisterA64 dst, RegisterA64 src, uint32_t op)
{
if (logText)
Expand Down Expand Up @@ -1226,6 +1253,17 @@ void AssemblyBuilderA64::placeER(const char* name, RegisterA64 dst, RegisterA64
commit();
}

// Emits a three-operand SIMD vector instruction operating on a .4s
// (four single-precision lanes) arrangement of q registers.
// Encoding layout: Rd in bits 0-4, Rn in bits 5-9, op2 at bit 10, Rm in
// bits 16-20, op at bit 21; bit 30 (the Q bit) is forced to 1, selecting
// the full 128-bit vector form.
void AssemblyBuilderA64::placeVR(const char* name, RegisterA64 dst, RegisterA64 src1, RegisterA64 src2, uint16_t op, uint8_t op2)
{
if (logText)
// Log with explicit .4s arrangement, e.g. "fmul v0.4s,v1.4s,v2.4s".
logAppend(" %-12sv%d.4s,v%d.4s,v%d.4s\n", name, dst.index, src1.index, src2.index);

// Only full 128-bit vector registers are valid, and all three must match.
CODEGEN_ASSERT(dst.kind == KindA64::q && dst.kind == src1.kind && dst.kind == src2.kind);

place(dst.index | (src1.index << 5) | (op2 << 10) | (src2.index << 16) | (op << 21) | (1 << 30));
commit();
}

void AssemblyBuilderA64::place(uint32_t word)
{
CODEGEN_ASSERT(codePos < codeEnd);
Expand Down
122 changes: 89 additions & 33 deletions CodeGen/src/IrLoweringA64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "lgc.h"

LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauCodeGenFixBufferLenCheckA64, false)
LUAU_FASTFLAGVARIABLE(LuauCodeGenVectorA64, false)

namespace Luau
{
Expand Down Expand Up @@ -673,77 +674,132 @@ void IrLoweringA64::lowerInst(IrInst& inst, uint32_t index, const IrBlock& next)
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});

RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
if (FFlag::LuauCodeGenVectorA64)
{
build.fadd(inst.regA64, regOp(inst.a), regOp(inst.b));

for (uint8_t i = 0; i < 3; i++)
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fadd(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);

for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fadd(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
case IrCmd::SUB_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});

RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
if (FFlag::LuauCodeGenVectorA64)
{
build.fsub(inst.regA64, regOp(inst.a), regOp(inst.b));

for (uint8_t i = 0; i < 3; i++)
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fsub(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);

for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fsub(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
case IrCmd::MUL_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});

RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
if (FFlag::LuauCodeGenVectorA64)
{
build.fmul(inst.regA64, regOp(inst.a), regOp(inst.b));

for (uint8_t i = 0; i < 3; i++)
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fmul(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);

for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fmul(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
case IrCmd::DIV_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a, inst.b});

RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);
if (FFlag::LuauCodeGenVectorA64)
{
build.fdiv(inst.regA64, regOp(inst.a), regOp(inst.b));

for (uint8_t i = 0; i < 3; i++)
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fdiv(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
RegisterA64 tempa = regs.allocTemp(KindA64::s);
RegisterA64 tempb = regs.allocTemp(KindA64::s);

for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.dup_4s(tempb, regOp(inst.b), i);
build.fdiv(tempa, tempa, tempb);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
case IrCmd::UNM_VEC:
{
inst.regA64 = regs.allocReuse(KindA64::q, index, {inst.a});

RegisterA64 tempa = regs.allocTemp(KindA64::s);
if (FFlag::LuauCodeGenVectorA64)
{
build.fneg(inst.regA64, regOp(inst.a));

for (uint8_t i = 0; i < 3; i++)
RegisterA64 tempw = regs.allocTemp(KindA64::w);
build.mov(tempw, LUA_TVECTOR);
build.ins_4s(inst.regA64, tempw, 3);
}
else
{
build.dup_4s(tempa, regOp(inst.a), i);
build.fneg(tempa, tempa);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
RegisterA64 tempa = regs.allocTemp(KindA64::s);

for (uint8_t i = 0; i < 3; i++)
{
build.dup_4s(tempa, regOp(inst.a), i);
build.fneg(tempa, tempa);
build.ins_4s(inst.regA64, i, castReg(KindA64::q, tempa), 0);
}
}
break;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/AssemblyBuilderA64.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "Moves")
{
SINGLE_COMPARE(mov(x0, x1), 0xAA0103E0);
SINGLE_COMPARE(mov(w0, w1), 0x2A0103E0);
SINGLE_COMPARE(mov(q0, q1), 0x4EA11C20);

SINGLE_COMPARE(movz(x0, 42), 0xD2800540);
SINGLE_COMPARE(movz(w0, 42), 0x52800540);
Expand Down Expand Up @@ -501,6 +502,15 @@ TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "PrePostIndexing")
SINGLE_COMPARE(str(q0, mem(x1, 1, AddressKindA64::post)), 0x3C801420);
}

// Verifies the machine-code encodings of the new .4s vector float ops
// against known-good instruction words (checked against armconverter.com
// per the commit description).
TEST_CASE_FIXTURE(AssemblyBuilderA64Fixture, "SIMDMath")
{
SINGLE_COMPARE(fadd(q0, q1, q2), 0x4E22D420);
SINGLE_COMPARE(fsub(q0, q1, q2), 0x4EA2D420);
SINGLE_COMPARE(fmul(q0, q1, q2), 0x6E22DC20);
SINGLE_COMPARE(fdiv(q0, q1, q2), 0x6E22FC20);
SINGLE_COMPARE(fneg(q0, q1), 0x6EA0F820);
}

TEST_CASE("LogTest")
{
AssemblyBuilderA64 build(/* logText= */ true);
Expand Down Expand Up @@ -552,6 +562,7 @@ TEST_CASE("LogTest")
build.ins_4s(q31, 1, q29, 2);
build.dup_4s(s29, q31, 2);
build.dup_4s(q29, q30, 0);
build.fmul(q0, q1, q2);

build.setLabel(l);
build.ret();
Expand Down Expand Up @@ -594,6 +605,7 @@ TEST_CASE("LogTest")
ins v31.s[1],v29.s[2]
dup s29,v31.s[2]
dup v29.4s,v30.s[0]
fmul v0.4s,v1.4s,v2.4s
.L1:
ret
)";
Expand Down

0 comments on commit c5f4d97

Please sign in to comment.