Additional Nonlinear CG Direction Methods #1878

Merged: 7 commits, Oct 20, 2024
9 changes: 5 additions & 4 deletions gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp
@@ -43,7 +43,7 @@ static VectorValues gradientInPlace(const NonlinearFactorGraph& nfg,

NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer(
const NonlinearFactorGraph& graph, const Values& initialValues,
- const Parameters& params)
+ const Parameters& params, const DirectionMethod& directionMethod)
: Base(graph, std::unique_ptr<State>(
new State(initialValues, graph.error(initialValues)))),
params_(params), directionMethod_(directionMethod) {}
@@ -70,7 +70,8 @@ NonlinearConjugateGradientOptimizer::System::advance(const State& current,

GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {
const auto [newValues, dummy] = nonlinearConjugateGradient<System, Values>(
- System(graph_), state_->values, params_, true /* single iteration */);
+ System(graph_), state_->values, params_, true /* single iteration */,
+ directionMethod_);
state_.reset(
new State(newValues, graph_.error(newValues), state_->iterations + 1));

@@ -81,8 +82,8 @@ GaussianFactorGraph::shared_ptr NonlinearConjugateGradientOptimizer::iterate() {
const Values& NonlinearConjugateGradientOptimizer::optimize() {
// Optimize until convergence
System system(graph_);
- const auto [newValues, iterations] =
-     nonlinearConjugateGradient(system, state_->values, params_, false);
+ const auto [newValues, iterations] = nonlinearConjugateGradient(
+     system, state_->values, params_, false, directionMethod_);
state_.reset(
new State(std::move(newValues), graph_.error(newValues), iterations));
return state_->values;
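For orientation, a minimal usage sketch of the extended API (it mirrors the new tests below; `graph` and `initialEstimate` are stand-ins for any NonlinearFactorGraph and Values in scope):

// Usage sketch (not part of the diff): select a non-default direction method.
// `graph` and `initialEstimate` are assumed to be defined elsewhere.
NonlinearOptimizerParams params;
NonlinearConjugateGradientOptimizer optimizer(
    graph, initialEstimate, params, DirectionMethod::DaiYuan);
Values result = optimizer.optimize();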
84 changes: 76 additions & 8 deletions gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h
@@ -23,6 +23,57 @@

namespace gtsam {

/// Fletcher-Reeves formula for computing β, the weight applied to the
/// previous search direction in the conjugate-gradient update.
template <typename Gradient>
double FletcherReeves(const Gradient &currentGradient,
const Gradient &prevGradient) {
// Fletcher-Reeves: beta = (g_n' * g_n) / (g_{n-1}' * g_{n-1})
const double beta =
currentGradient.dot(currentGradient) / prevGradient.dot(prevGradient);
return beta;
}

/// Polak-Ribiere formula for computing β, the weight applied to the
/// previous search direction in the conjugate-gradient update.
template <typename Gradient>
double PolakRibiere(const Gradient &currentGradient,
const Gradient &prevGradient) {
// Polak-Ribiere (clamped): beta = max(0, g_n'*(g_n - g_{n-1}) / (g_{n-1}'*g_{n-1}))
const double beta =
std::max(0.0, currentGradient.dot(currentGradient - prevGradient) /
prevGradient.dot(prevGradient));
return beta;
}

/// The Hestenes-Stiefel formula for computing β, the weight applied to the
/// previous search direction in the conjugate-gradient update.
template <typename Gradient>
double HestenesStiefel(const Gradient &currentGradient,
const Gradient &prevGradient,
const Gradient &direction) {
// Hestenes-Stiefel (clamped): beta = max(0, g_n'*(g_n - g_{n-1}) / (-s_{n-1}'*(g_n - g_{n-1})))
Gradient d = currentGradient - prevGradient;
const double beta = std::max(0.0, currentGradient.dot(d) / -direction.dot(d));
return beta;
}

/// The Dai-Yuan formula for computing β, the weight applied to the
/// previous search direction in the conjugate-gradient update.
template <typename Gradient>
double DaiYuan(const Gradient &currentGradient, const Gradient &prevGradient,
const Gradient &direction) {
// Dai-Yuan (clamped): beta = max(0, g_n'*g_n / (-s_{n-1}'*(g_n - g_{n-1})))
const double beta =
std::max(0.0, currentGradient.dot(currentGradient) /
-direction.dot(currentGradient - prevGradient));
return beta;
}

enum class DirectionMethod {
FletcherReeves,
PolakRibiere,
HestenesStiefel,
DaiYuan
};
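In standard notation, with $g_n$ the current gradient, $g_{n-1}$ the previous gradient, and $s_{n-1}$ the previous search direction, the four rules above compute:

$$\beta_{FR} = \frac{g_n^\top g_n}{g_{n-1}^\top g_{n-1}}, \qquad \beta_{PR} = \max\left(0,\; \frac{g_n^\top (g_n - g_{n-1})}{g_{n-1}^\top g_{n-1}}\right),$$

$$\beta_{HS} = \max\left(0,\; \frac{g_n^\top (g_n - g_{n-1})}{-s_{n-1}^\top (g_n - g_{n-1})}\right), \qquad \beta_{DY} = \max\left(0,\; \frac{g_n^\top g_n}{-s_{n-1}^\top (g_n - g_{n-1})}\right).$$

Clamping β at zero (all rules except Fletcher-Reeves here) restarts the iteration along the steepest-descent direction whenever the formula would otherwise go negative.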

/** An implementation of the nonlinear CG method using the template below */
class GTSAM_EXPORT NonlinearConjugateGradientOptimizer
: public NonlinearOptimizer {
@@ -51,14 +102,16 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer

protected:
Parameters params_;
+ DirectionMethod directionMethod_ = DirectionMethod::PolakRibiere;

const NonlinearOptimizerParams &_params() const override { return params_; }

public:
/// Constructor
- NonlinearConjugateGradientOptimizer(const NonlinearFactorGraph &graph,
-     const Values &initialValues,
-     const Parameters &params = Parameters());
+ NonlinearConjugateGradientOptimizer(
+     const NonlinearFactorGraph &graph, const Values &initialValues,
+     const Parameters &params = Parameters(),
+     const DirectionMethod &directionMethod = DirectionMethod::PolakRibiere);

/// Destructor
~NonlinearConjugateGradientOptimizer() override {}
@@ -140,7 +193,9 @@ double lineSearch(const S &system, const V currentValues, const W &gradient) {
template <class S, class V>
std::tuple<V, int> nonlinearConjugateGradient(
const S &system, const V &initial, const NonlinearOptimizerParams &params,
- const bool singleIteration, const bool gradientDescent = false) {
+ const bool singleIteration,
+ const DirectionMethod &directionMethod = DirectionMethod::PolakRibiere,
+ const bool gradientDescent = false) {
// GTSAM_CONCEPT_MANIFOLD_TYPE(V)

size_t iteration = 0;
@@ -177,10 +232,23 @@ std::tuple<V, int> nonlinearConjugateGradient(
} else {
prevGradient = currentGradient;
currentGradient = system.gradient(currentValues);
- // Polak-Ribiere: beta = g'*(g_n-g_n-1)/g_n-1'*g_n-1
- const double beta =
-     std::max(0.0, currentGradient.dot(currentGradient - prevGradient) /
-                       prevGradient.dot(prevGradient));
+ double beta;
+ switch (directionMethod) {
+   case DirectionMethod::FletcherReeves:
+     beta = FletcherReeves(currentGradient, prevGradient);
+     break;
+   case DirectionMethod::PolakRibiere:
+     beta = PolakRibiere(currentGradient, prevGradient);
+     break;
+   case DirectionMethod::HestenesStiefel:
+     beta = HestenesStiefel(currentGradient, prevGradient, direction);
+     break;
+   case DirectionMethod::DaiYuan:
+     beta = DaiYuan(currentGradient, prevGradient, direction);
+     break;
+ }

direction = currentGradient + (beta * direction);
}

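As a sanity check of the update rules outside GTSAM, here is a small self-contained sketch (not part of this PR) that runs conjugate gradient on a 2-D quadratic using Eigen. The matrix A, the vector b, and the textbook sign convention d = -g + beta*d are illustrative assumptions; the beta line is the clamped Polak-Ribiere rule from the header above, and any of the other three rules can be swapped in:

#include <Eigen/Dense>
#include <algorithm>
#include <iostream>

int main() {
  // Minimize f(x) = 0.5 x'Ax - b'x; the gradient is g(x) = Ax - b.
  Eigen::Matrix2d A;
  A << 4, 1, 1, 3;  // symmetric positive definite (illustrative)
  const Eigen::Vector2d b(1, 2);

  Eigen::Vector2d x = Eigen::Vector2d::Zero();
  Eigen::Vector2d g = A * x - b;
  Eigen::Vector2d d = -g;  // start along steepest descent
  for (int i = 0; i < 50 && g.norm() > 1e-10; ++i) {
    // Exact line search on a quadratic: alpha = -g'd / d'Ad.
    const double alpha = -g.dot(d) / d.dot(A * d);
    x += alpha * d;
    const Eigen::Vector2d prevG = g;
    g = A * x - b;
    // Polak-Ribiere with the same max(0, .) clamp as the PR.
    const double beta = std::max(0.0, g.dot(g - prevG) / prevG.dot(prevG));
    d = -g + beta * d;
  }
  std::cout << "solution: " << x.transpose() << std::endl;  // ~ A^{-1} b
}

With exact line search on a quadratic, consecutive gradients are orthogonal and the four formulas coincide with linear CG, so the loop converges in at most two iterations here; that equivalence makes a quadratic a convenient correctness check.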
43 changes: 43 additions & 0 deletions gtsam/nonlinear/tests/testNonlinearConjugateGradientOptimizer.cpp
@@ -79,6 +79,49 @@ TEST(NonlinearConjugateGradientOptimizer, Optimize) {
EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
}

/* ************************************************************************* */
/// Test different direction methods
TEST(NonlinearConjugateGradientOptimizer, DirectionMethods) {
const auto [graph, initialEstimate] = generateProblem();

NonlinearOptimizerParams param;
param.maxIterations =
500; /* requires a larger number of iterations to converge */
param.verbosity = NonlinearOptimizerParams::SILENT;

// Fletcher-Reeves
{
NonlinearConjugateGradientOptimizer optimizer(
graph, initialEstimate, param, DirectionMethod::FletcherReeves);
Values result = optimizer.optimize();

EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
}
// Polak-Ribiere
{
NonlinearConjugateGradientOptimizer optimizer(
graph, initialEstimate, param, DirectionMethod::PolakRibiere);
Values result = optimizer.optimize();

EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
}
// Hestenes-Stiefel
{
NonlinearConjugateGradientOptimizer optimizer(
graph, initialEstimate, param, DirectionMethod::HestenesStiefel);
Values result = optimizer.optimize();

EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
}
// Dai-Yuan
{
NonlinearConjugateGradientOptimizer optimizer(graph, initialEstimate, param,
DirectionMethod::DaiYuan);
Values result = optimizer.optimize();

EXPECT_DOUBLES_EQUAL(0.0, graph.error(result), 1e-4);
}
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}