diff --git a/src/BoundedTerminationPass.cpp b/src/BoundedTerminationPass.cpp index b55095c..9fc3475 100644 --- a/src/BoundedTerminationPass.cpp +++ b/src/BoundedTerminationPass.cpp @@ -1,7 +1,7 @@ +#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/CGSCCPassManager.h" -#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -373,15 +373,6 @@ FunctionTerminationPass::run(llvm::Function &F, llvm::ScalarEvolution &SE = FAM.getResult(F); llvm::LoopInfo &loop_info = FAM.getResult(F); - // const auto &outer_result = detect_cgscc_recursion(F, FAM); - // - // // If this function is part of a recursive call-graph group - // // (a non-trivial CGSCC), then we don't do any more analysis. - // // This way, we avoid recursing in getResult - // // for the callees of this function. - // if (outer_result.elt >= DoesThisTerminate::Unbounded) { - // return outer_result; - // } std::map blocks_to_results; // SetVector preserves insertion order - which is nice because it makes this @@ -450,9 +441,38 @@ ModuleTerminationPass::run(llvm::Module &IR, llvm::ModuleAnalysisManager &AM) { {&function, FAM.getResult(function)}); } - // Step 2 : CGSCC analysis; note recursion up-front. - // TODO: This is taking literally forever. No, figuratively forever - // auto &call_graph = AM.getResult(IR); + // Step 2 : CGSCC analysis. + // Take anything in a recursive group and force it Unknown. + // See also NoRecursionCheck in clang-tidy + llvm::CallGraph &CG = AM.getResult(IR); + for(llvm::scc_iterator SCCI = llvm::scc_begin(&CG); !SCCI.isAtEnd(); ++SCCI) { + if(!SCCI.hasCycle()) { + // SCC doesn't have a loop. We don't need to update anything. + continue; + } + // SCC has a loop. Update all functions to note they're mutually recursive. + const std::vector &nextSCC = *SCCI; + TerminationPassResult shared_result = { + .elt = DoesThisTerminate::Unknown, + .explanation = "part of a call graph that contains a loop: ", + }; + int count = 0; + for(llvm::CallGraphNode *node : nextSCC) { + const llvm::Function *f = node->getFunction(); + shared_result.explanation = (shared_result.explanation + llvm::demangle(f->getName())); + if(count < nextSCC.size() - 1) { + shared_result.explanation += ", "; + } else { + ++count; + } + } + for(llvm::CallGraphNode *node : nextSCC) { + llvm::Function *f = node->getFunction(); + const auto new_result = update(per_function_results[f], {shared_result}); + per_function_results[f] = new_result; + } + } + // TODO // Step 3 : worklist algorithm on the call graph. // TODO: