Skip to content

Commit

Permalink
Fix synthesizing literals function for enums (vgvassilev#1113)
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 authored Oct 12, 2024
1 parent 04b353b commit f86eede
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 1 deletion.
16 changes: 15 additions & 1 deletion lib/Differentiator/ConstantFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//----------------------------------------------------------------------------//

#include "ConstantFolder.h"
#include "clad/Differentiator/Compatibility.h"

#include "clang/AST/ASTContext.h"

Expand Down Expand Up @@ -141,7 +142,20 @@ namespace clad {
// SourceLocation noLoc;
Expr* Result = 0;
QT = QT.getCanonicalType();
if (QT->isPointerType()) {
if (QT->isEnumeralType()) {
llvm::APInt APVal(C.getIntWidth(QT), val,
QT->isSignedIntegerOrEnumerationType());
Result = clad::synthesizeLiteral(
dyn_cast<EnumType>(QT)->getDecl()->getIntegerType(), C, APVal);
SourceLocation noLoc;
Expr* cast = CXXStaticCastExpr::Create(
C, QT, CLAD_COMPAT_ExprValueKind_R_or_PR_Value,
clang::CastKind::CK_IntegralCast, Result, /*CXXCastPath=*/nullptr,
C.getTrivialTypeSourceInfo(QT, noLoc)
CLAD_COMPAT_CLANG12_CastExpr_DefaultFPO,
noLoc, noLoc, SourceRange());
Result = cast;
} else if (QT->isPointerType()) {
Result = clad::synthesizeLiteral(QT, C);
} else if (QT->isBooleanType()) {
Result = clad::synthesizeLiteral(QT, C, (bool)val);
Expand Down
153 changes: 153 additions & 0 deletions test/Gradient/Switch.C
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,146 @@ double fn7(double u, double v) {
// CHECK-NEXT: }
// CHECK-NEXT: }

enum Op {
Add,
Sub,
Mul,
Div
};

double fn24(double x, double y, Op op) {
double res = 0;
switch (op) {
case Add:
res = x + y;
break;
case Sub:
res = x - y;
break;
case Mul:
res = x * y;
break;
case Div:
res = x / y;
break;
}
return res;
}

// CHECK: void fn24_grad_0_1(double x, double y, Op op, double *_d_x, double *_d_y) {
// CHECK-NEXT: Op _d_op = static_cast<Op>(0U);
// CHECK-NEXT: Op _cond0;
// CHECK-NEXT: double _t0;
// CHECK-NEXT: clad::tape<unsigned {{int|long}}> _t1 = {};
// CHECK-NEXT: double _t2;
// CHECK-NEXT: double _t3;
// CHECK-NEXT: double _t4;
// CHECK-NEXT: double _d_res = 0.;
// CHECK-NEXT: double res = 0;
// CHECK-NEXT: {
// CHECK-NEXT: _cond0 = op;
// CHECK-NEXT: switch (_cond0) {
// CHECK-NEXT: {
// CHECK-NEXT: case Add:
// CHECK-NEXT: res = x + y;
// CHECK-NEXT: _t0 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, {{1U|1UL}});
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case Sub:
// CHECK-NEXT: res = x - y;
// CHECK-NEXT: _t2 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, {{2U|2UL}});
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case Mul:
// CHECK-NEXT: res = x * y;
// CHECK-NEXT: _t3 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, {{3U|3UL}});
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: case Div:
// CHECK-NEXT: res = x / y;
// CHECK-NEXT: _t4 = res;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: clad::push(_t1, {{4U|4UL}});
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: clad::push(_t1, {{5U|5UL}});
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: _d_res += 1;
// CHECK-NEXT: {
// CHECK-NEXT: switch (clad::pop(_t1)) {
// CHECK-NEXT: case {{5U|5UL}}:
// CHECK-NEXT: ;
// CHECK-NEXT: case {{4U|4UL}}:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t4;
// CHECK-NEXT: double _r_d3 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d3 / y;
// CHECK-NEXT: double _r0 = _r_d3 * -(x / (y * y));
// CHECK-NEXT: _d_y += _r0;
// CHECK-NEXT: }
// CHECK-NEXT: if (Div == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: case {{3U|3UL}}:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t3;
// CHECK-NEXT: double _r_d2 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d2 * y;
// CHECK-NEXT: _d_y += x * _r_d2;
// CHECK-NEXT: }
// CHECK-NEXT: if (Mul == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: case {{2U|2UL}}:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t2;
// CHECK-NEXT: double _r_d1 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d1;
// CHECK-NEXT: _d_y += -_r_d1;
// CHECK-NEXT: }
// CHECK-NEXT: if (Sub == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: case {{1U|1UL}}:
// CHECK-NEXT: ;
// CHECK-NEXT: {
// CHECK-NEXT: {
// CHECK-NEXT: res = _t0;
// CHECK-NEXT: double _r_d0 = _d_res;
// CHECK-NEXT: _d_res = 0.;
// CHECK-NEXT: *_d_x += _r_d0;
// CHECK-NEXT: _d_y += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: if (Add == _cond0)
// CHECK-NEXT: break;
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT:}


#define TEST_2(F, x, y) \
{ \
Expand All @@ -691,6 +831,14 @@ double fn7(double u, double v) {
printf("{%.2f, %.2f}\n", result[0], result[1]); \
}

#define TEST_2_Op(F, x, y, op) \
{ \
result[0] = result[1] = 0; \
auto d_##F = clad::gradient(F, "x, y"); \
d_##F.execute(x, y, op, result, result + 1); \
printf("{%.2f, %.2f}\n", result[0], result[1]); \
}

int main() {
double result[2] = {};

Expand All @@ -705,4 +853,9 @@ int main() {

TEST_GRADIENT(fn6, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {5.00, 3.00}
TEST_GRADIENT(fn7, 2, 3, 5, &result[0], &result[1]); // CHECK-EXEC: {3.00, 2.00}

TEST_2_Op(fn24, 3, 5, Add); // CHECK-EXEC: {1.00, 1.00}
TEST_2_Op(fn24, 3, 5, Sub); // CHECK-EXEC: {1.00, -1.00}
TEST_2_Op(fn24, 3, 5, Mul); // CHECK-EXEC: {5.00, 3.00}
TEST_2_Op(fn24, 3, 5, Div); // CHECK-EXEC: {0.20, -0.12}
}

0 comments on commit f86eede

Please sign in to comment.