Skip to content

Commit

Permalink
Skip derivative creation of const vars and params
Browse files Browse the repository at this point in the history
  • Loading branch information
kchristin22 committed Nov 3, 2024
1 parent cddc21d commit 583e80a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
3 changes: 3 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,9 @@ namespace clad {
}

bool DiffRequest::shouldHaveAdjoint(const VarDecl* VD) const {
if (VD->getType().isConstQualified())
return false;

if (!EnableVariedAnalysis)
return true;

Expand Down
2 changes: 2 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
i == m_DiffReq->getNumParams() - 1)
continue;
auto VDDerivedType = param->getType();
if (VDDerivedType.isConstQualified())
continue;
// We cannot initialize derived variable for pointer types because
// we do not know the correct size.
if (utils::isArrayOrPointerType(VDDerivedType))
Expand Down
38 changes: 38 additions & 0 deletions test/Gradient/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,34 @@ double f23(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double constVal(double y, const double x) {
const double z = y;
y *= z;
return y * x;
}

//CHECK: void constVal_grad_0(double y, const double x, double *_d_y) {
//CHECK-NEXT: const double z = y;
//CHECK-NEXT: double _t0 = y;
//CHECK-NEXT: y *= z;
//CHECK-NEXT: *_d_y += 1 * x;
//CHECK-NEXT: {
//CHECK-NEXT: y = _t0;
//CHECK-NEXT: double _r_d0 = *_d_y;
//CHECK-NEXT: *_d_y = 0.;
//CHECK-NEXT: *_d_y += _r_d0 * z;
//CHECK-NEXT: }
//CHECK-NEXT:}

double constValInput(const double x) {
return x;
}

//CHECK: void constValInput_grad(const double x, double *_d_x) {
//CHECK-NEXT: *_d_x += 1;
//CHECK-NEXT:}


#define TEST(F, x, y) \
{ \
result[0] = 0; \
Expand Down Expand Up @@ -884,4 +912,14 @@ int main() {
TEST(f21, 6, 4); // CHECK-EXEC: {1.00, 0.00}
TEST(f22, 6, 4); // CHECK-EXEC: {0.00, 0.00}
TEST(f23, 7, 5); // CHECK-EXEC: {1.00, 1.00}

auto const_test = clad::gradient(constVal, "y");
double const_test_result = 0;
const_test.execute(3, 4, &const_test_result);
printf("%.2f\n", const_test_result); // CHECK-EXEC: 12.00

auto const_test_input = clad::gradient(constValInput);
double const_test_input_result = 0;
const_test_input.execute(3, &const_test_input_result);
printf("%.2f\n", const_test_input_result); // CHECK-EXEC: 1.00
}

0 comments on commit 583e80a

Please sign in to comment.