Skip to content

Commit

Permalink
Add cudaMemset call after cudaMalloc for derivative pointers (#1129)
Browse files: browse the repository at this point in the history
  • Loading branch information
kchristin22 authored Nov 3, 2024
1 parent cddc21d commit effbb7b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
12 changes: 12 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
.ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), Loc,
llvm::MutableArrayRef<Expr*>(DerivedCallArgs), Loc)
.get();
if (FD->getNameAsString() == "cudaMalloc") {
if (auto* addrOp = dyn_cast<UnaryOperator>(DerivedCallArgs[0]))
if (addrOp->getOpcode() == UO_AddrOf)
DerivedCallArgs[0] = addrOp->getSubExpr(); // get the pointer

llvm::SmallVector<Expr*, 3> args = {DerivedCallArgs[0],
getZeroInit(m_Context.IntTy),
DerivedCallArgs[1]};
addToCurrentBlock(call_dx, direction::forward);
addToCurrentBlock(GetFunctionCall("cudaMemset", "", args));
call_dx = nullptr;
}
return StmtDiff(call, call_dx);
}
// For calls to C-style memory deallocation functions, we do not need to
Expand Down
1 change: 1 addition & 0 deletions test/CUDA/GradientKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ double fn_memory(double *out, double *in) {
//CHECK-NEXT: double *_d_in_dev = nullptr;
//CHECK-NEXT: double *in_dev = nullptr;
//CHECK-NEXT: cudaMalloc(&_d_in_dev, 10 * sizeof(double));
//CHECK-NEXT: cudaMemset(_d_in_dev, 0, 10 * sizeof(double));
//CHECK-NEXT: cudaMalloc(&in_dev, 10 * sizeof(double));
//CHECK-NEXT: cudaMemcpy(in_dev, in, 10 * sizeof(double), cudaMemcpyHostToDevice);
//CHECK-NEXT: kernel_call<<<1, 10>>>(out, in_dev);
Expand Down

0 comments on commit effbb7b

Please sign in to comment.