Skip to content

Commit

Permalink
Add test for functions in inline namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak authored and vgvassilev committed Oct 29, 2023
1 parent 270777e commit 6a35905
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
4 changes: 4 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ namespace clad {
continue;
}
if (DC2->isInlineNamespace()) {
// Inline namespace can be skipped from context, because its members
// are automatically searched from the parent namespace.
// This will also help us to deal with intermediate inline namespaces
// like std::__1::, as present in std functions for libc++.
DC2 = DC2->getParent();
continue;
}
Expand Down
41 changes: 41 additions & 0 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,45 @@ double fn10(double x, double y) {
// CHECK-NEXT: * _d_x += _d_out;
// CHECK-NEXT: }

namespace n1{
inline namespace n2{
double sum(const double& x, const double& y) {
return x + y;
}
}
}

namespace clad{
namespace custom_derivatives{
namespace n1{
inline namespace n2{
void sum_pullback(const double& x, const double& y, double _d_y0, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
* _d_x += _d_y0;
* _d_y += _d_y0;
}
}
}
}
}

double fn11(double x, double y) {
return n1::n2::sum(x, y);
}

// CHECK: void fn11_grad(double x, double y, clad::array_ref<double> _d_x, clad::array_ref<double> _d_y) {
// CHECK-NEXT: double _t0;
// CHECK-NEXT: double _t1;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: _t1 = y;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: clad::custom_derivatives::n1::sum_pullback(_t0, _t1, 1, &* _d_x, &* _d_y);
// CHECK-NEXT: double _r0 = * _d_x;
// CHECK-NEXT: double _r1 = * _d_y;
// CHECK-NEXT: }
// CHECK-NEXT: }

template<typename T>
void reset(T* arr, int n) {
for (int i=0; i<n; ++i)
Expand Down Expand Up @@ -649,6 +688,7 @@ int main() {
INIT(fn8);
INIT(fn9);
INIT(fn10);
INIT(fn11);

TEST1_float(fn1, 11); // CHECK-EXEC: {3.00}
TEST2(fn2, 3, 5); // CHECK-EXEC: {1.00, 3.00}
Expand All @@ -661,4 +701,5 @@ int main() {
TEST2(fn8, 3, 5); // CHECK-EXEC: {7.62, 4.57}
TEST2(fn9, 3, 5); // CHECK-EXEC: {5.00, 3.00}
TEST2(fn10, 8, 5); // CHECK-EXEC: {0.00, 7.00}
TEST2(fn11, 3, 5); // CHECK-EXEC: {1.00, 1.00}
}

0 comments on commit 6a35905

Please sign in to comment.