Skip to content

Commit

Permalink
Yayy! working
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 committed Apr 22, 2022
1 parent f266be5 commit 8f6aa67
Showing 1 changed file with 73 additions and 35 deletions.
108 changes: 73 additions & 35 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1515,15 +1515,19 @@ namespace clad {
ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl));
idx++;
}
Expr* pullback = dfdx();
if (!pullback && FD->getReturnType()->isLValueReferenceType())
pullback = getZeroInit(FD->getReturnType().getNonReferenceType());

// FIXME: Remove this restriction.
if (!FD->getReturnType()->isVoidType()) {
assert((dfdx() && !FD->getReturnType()->isVoidType()) &&
assert((pullback && !FD->getReturnType()->isVoidType()) &&
"Call to function returning non-void type with no dfdx() is not "
"supported!");
}

if (FD->getReturnType()->isVoidType()) {
assert(dfdx() == nullptr && FD->getReturnType()->isVoidType() &&
assert(pullback == nullptr && FD->getReturnType()->isVoidType() &&
"Call to function returning void type should not have any "
"corresponding dfdx().");
}
Expand All @@ -1533,9 +1537,9 @@ namespace clad {
DerivedCallOutputArgs.end());
pullbackCallArgs = DerivedCallArgs;

if (dfdx())
if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(),
dfdx());
pullback);

// Try to find it in builtin derivatives
OverloadedDerivedFn =
Expand Down Expand Up @@ -1684,6 +1688,18 @@ namespace clad {
m_ExternalSource->ActBeforeFinalizingVisitCallExpr(
CE, OverloadedDerivedFn, DerivedCallArgs, ArgResultDecls, asGrad);

// FIXME: Why are we cloning args here? We already created different
// expressions for call to original function and call to gradient.
// Re-clone function arguments again, since they are required at 2 places:
// call to gradient and call to original function. At this point, each arg
// is either a simple expression or a reference to a temporary variable.
// Therefore cloning it has constant complexity.
std::transform(std::begin(CallArgs),
std::end(CallArgs),
std::begin(CallArgs),
[this](Expr* E) { return Clone(E); });

Expr* call = nullptr;

if (FD->getReturnType()->isReferenceType()) {
DiffRequest transformReq;
Expand All @@ -1702,26 +1718,48 @@ namespace clad {
policy.Bool = true;
transformedSourcefn->print(llvm::outs(), policy);
}
// FIXME: Add derivative of `this`.
for (std::size_t i=0, e = CE->getNumArgs(); i != e; ++i) {
const Expr* arg = CE->getArg(i);
StmtDiff argDiff = Visit(arg);
if (argDiff.getExpr_dx()) {
Expr* derivedArg = argDiff.getExpr_dx();
if (isCladArrayType(derivedArg->getType()))
CallArgs.push_back(derivedArg);
else
CallArgs.push_back(
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc));
}
else
CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get());
}
llvm::errs()<<"Dumping fwd call args:\n";
for (auto arg : CallArgs) {
arg->dumpColor();
llvm::errs()<<"\n";
}
call = m_Sema
.ActOnCallExpr(getCurrentScope(),
BuildDeclRef(transformedSourcefn), noLoc,
CallArgs, noLoc)
.get();
auto callRes = StoreAndRef(call);
auto resValue =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto resAdjoint = utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return {resValue, resAdjoint};
} else {
// Recreate the original call expression.
call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Clone(CE->getCallee()),
noLoc,
CallArgs,
noLoc)
.get();
return StmtDiff(call);
}
// FIXME: Why are we cloning args here? We already created different
// expressions for call to original function and call to gradient.
// Re-clone function arguments again, since they are required at 2 places:
// call to gradient and call to original function. At this point, each arg
// is either a simple expression or a reference to a temporary variable.
// Therefore cloning it has constant complexity.
std::transform(std::begin(CallArgs),
std::end(CallArgs),
std::begin(CallArgs),
[this](Expr* E) { return Clone(E); });
// Recreate the original call expression.
Expr* call = m_Sema
.ActOnCallExpr(getCurrentScope(),
Clone(CE->getCallee()),
noLoc,
CallArgs,
noLoc)
.get();
return StmtDiff(call);
return {};
}

StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
Expand Down Expand Up @@ -2183,14 +2221,16 @@ namespace clad {
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
// if (isVDRefType || specialThisDiffCase) {
// VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
// initDiff = Visit(VD->getInit());
// if (initDiff.getExpr_dx())
// VDDerivedInit = initDiff.getExpr_dx();
// else
// VDDerivedType = VDDerivedType.getNonReferenceType();
// }
if (isVDRefType || specialThisDiffCase) {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
initDiff = Visit(VD->getInit());
// if (initDiff.getExpr_dx())
// VDDerivedInit = initDiff.getExpr_dx();
// else
// VDDerivedType = VDDerivedType.getNonReferenceType();
if (!initDiff.getExpr_dx())
VDDerivedType = VDDerivedType.getNonReferenceType();
}
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
if (VDDerivedType->isRecordType())
Expand All @@ -2208,7 +2248,7 @@ namespace clad {
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isVDRefType || isa<CallExpr>(VD->getInit()->IgnoreParenImpCasts())) {
if (!isVDRefType) {
Expr* derivedE = BuildDeclRef(VDDerived);
if (isVDRefType)
derivedE = BuildOp(UnaryOperatorKind::UO_Deref, derivedE);
Expand All @@ -2233,8 +2273,6 @@ namespace clad {
getZeroInit(VDDerivedType));
addToCurrentBlock(assignToZero, direction::reverse);
}
} else {
initDiff = VD->getInit() ? Visit(VD->getInit()) : StmtDiff{};
}
VarDecl* VDClone = nullptr;
// Here separate behaviour for record and non-record types is only
Expand Down Expand Up @@ -2945,7 +2983,7 @@ namespace clad {
if (effectiveReturnType->isVoidType())
effectiveReturnType = m_Context.DoubleTy;
else
paramTypes.push_back(m_Function->getReturnType());
paramTypes.push_back(m_Function->getReturnType().getNonReferenceType());
}

if (auto MD = dyn_cast<CXXMethodDecl>(m_Function)) {
Expand Down

0 comments on commit 8f6aa67

Please sign in to comment.