More tests, some small bugfixes #1854

Merged: 11 commits, Sep 30, 2024
7 changes: 6 additions & 1 deletion gtsam/discrete/DiscreteBayesNet.cpp
@@ -58,7 +58,12 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
// sample each node in turn in topological sort order (parents first)
for (auto it = std::make_reverse_iterator(end());
it != std::make_reverse_iterator(begin()); ++it) {
(*it)->sampleInPlace(&result);
const DiscreteConditional::shared_ptr& conditional = *it;
// Sample the conditional only if value for j not already in result
const Key j = conditional->firstFrontalKey();
if (result.count(j) == 0) {
conditional->sampleInPlace(&result);
}
}
return result;
}
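
Note: a minimal sketch (not from this PR) of what the change enables, namely sampling with some variables pre-assigned. The keys and CPT strings are illustrative and assume the standard discrete Signature DSL.

```cpp
// Sample a DiscreteBayesNet while keeping a pre-assigned value fixed.
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/Signature.h>

using namespace gtsam;

int main() {
  // Two binary variables C -> E (keys and tables are made up).
  DiscreteKey C(0, 2), E(1, 2);

  DiscreteBayesNet bn;
  bn.add(E | C = "90/10 10/90");  // P(E | C)
  bn.add(C % "50/50");            // P(C), added last so reverse iteration samples it first

  // Pre-assign C = 1; with this patch sample() keeps it and only draws E.
  DiscreteValues given;
  given[C.first] = 1;
  const DiscreteValues full = bn.sample(given);

  return full.at(C.first) == 1 ? 0 : 1;
}
```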
18 changes: 13 additions & 5 deletions gtsam/discrete/DiscreteConditional.cpp
@@ -259,8 +259,18 @@ size_t DiscreteConditional::argmax(const DiscreteValues& parentsValues) const {

/* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
// throw if more than one frontal:
if (nrFrontals() != 1) {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace can only be called on single "
"variable conditionals");
}
Key j = firstFrontalKey();
// throw if values already contains j:
if (values->count(j) > 0) {
throw std::invalid_argument(
"DiscreteConditional::sampleInPlace: values already contains j");
}
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}
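
Note: a small sketch (not from this PR) of the new failure mode, assuming a single-frontal conditional built from the Signature DSL; the key and table are illustrative.

```cpp
// sampleInPlace() now refuses to overwrite an already-assigned variable.
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>

#include <cassert>
#include <stdexcept>

using namespace gtsam;

int main() {
  DiscreteKey X(0, 2);
  DiscreteConditional prior(X % "30/70");  // P(X), single frontal variable

  DiscreteValues values;
  values[X.first] = 1;  // X already has a value

  bool threw = false;
  try {
    prior.sampleInPlace(&values);  // throws std::invalid_argument after this patch
  } catch (const std::invalid_argument&) {
    threw = true;
  }
  assert(threw);
  return threw ? 0 : 1;
}
```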
@@ -467,9 +477,7 @@ double DiscreteConditional::evaluate(const HybridValues& x) const {
}

/* ************************************************************************* */
double DiscreteConditional::negLogConstant() const {
return 0.0;
}
double DiscreteConditional::negLogConstant() const { return 0.0; }

/* ************************************************************************* */

2 changes: 1 addition & 1 deletion gtsam/discrete/DiscreteConditional.h
@@ -168,7 +168,7 @@ class GTSAM_EXPORT DiscreteConditional
static_cast<const BaseConditional*>(this)->print(s, formatter);
}

/// Evaluate, just look up in AlgebraicDecisonTree
/// Evaluate, just look up in AlgebraicDecisionTree
double evaluate(const DiscreteValues& values) const {
return ADT::operator()(values);
}
2 changes: 1 addition & 1 deletion gtsam/hybrid/HybridBayesNet.cpp
@@ -206,7 +206,7 @@ GaussianBayesNet HybridBayesNet::choose(
for (auto &&conditional : *this) {
if (auto gm = conditional->asHybrid()) {
// If conditional is hybrid, select based on assignment.
gbn.push_back((*gm)(assignment));
gbn.push_back(gm->choose(assignment));
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, add Gaussian conditional.
gbn.push_back(gc);
2 changes: 2 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
@@ -127,6 +127,8 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @brief Get the Gaussian Bayes Net which corresponds to a specific discrete
* value assignment.
*
* @note Any pure discrete factors are ignored.
*
* @param assignment The discrete value assignment for the discrete keys.
* @return GaussianBayesNet
*/
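
Note: a sketch (not from this PR) of choose() and the note above. It assumes a HybridGaussianConditional constructor taking a discrete parent plus a list of Gaussian components, and an emplace_shared that wraps conditionals in HybridConditional; check the headers for the exact signatures. Keys, means, and the discrete table are illustrative.

```cpp
// Select the GaussianBayesNet for one discrete assignment; the pure discrete
// prior P(m0) is dropped from the result, per the note above.
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

#include <memory>
#include <vector>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  const DiscreteKey mode(M(0), 2);

  // Two Gaussian components for P(x0 | m0), with illustrative means 0 and 5.
  auto c0 = std::make_shared<GaussianConditional>(X(0), Vector1(0.0), I_1x1);
  auto c1 = std::make_shared<GaussianConditional>(X(0), Vector1(5.0), I_1x1);

  HybridBayesNet hbn;
  hbn.emplace_shared<HybridGaussianConditional>(
      mode, std::vector<GaussianConditional::shared_ptr>{c0, c1});
  hbn.emplace_shared<DiscreteConditional>(mode, "60/40");  // P(m0)

  DiscreteValues assignment;
  assignment[M(0)] = 1;
  const GaussianBayesNet gbn = hbn.choose(assignment);
  gbn.print("Gaussian Bayes net for m0 = 1:\n");
  return 0;
}
```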
11 changes: 5 additions & 6 deletions gtsam/hybrid/HybridGaussianConditional.cpp
@@ -168,7 +168,7 @@ size_t HybridGaussianConditional::nrComponents() const {
}

/* *******************************************************************************/
GaussianConditional::shared_ptr HybridGaussianConditional::operator()(
GaussianConditional::shared_ptr HybridGaussianConditional::choose(
const DiscreteValues &discreteValues) const {
auto &ptr = conditionals_(discreteValues);
if (!ptr) return nullptr;
@@ -192,11 +192,10 @@ bool HybridGaussianConditional::equals(const HybridFactor &lf,

// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,
[tol](const GaussianConditional::shared_ptr &f1,
const GaussianConditional::shared_ptr &f2) {
return f1->equals(*(f2), tol);
});
conditionals_.equals(
e->conditionals_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
}

/* *******************************************************************************/
8 changes: 7 additions & 1 deletion gtsam/hybrid/HybridGaussianConditional.h
@@ -159,9 +159,15 @@ class GTSAM_EXPORT HybridGaussianConditional
/// @{

/// @brief Return the conditional Gaussian for the given discrete assignment.
GaussianConditional::shared_ptr operator()(
GaussianConditional::shared_ptr choose(
const DiscreteValues &discreteValues) const;

/// @brief Syntactic sugar for choose.
GaussianConditional::shared_ptr operator()(
const DiscreteValues &discreteValues) const {
return choose(discreteValues);
}

/// Returns the total number of continuous components
size_t nrComponents() const;
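
Note: a sketch (not from this PR) showing that operator() and choose() select the same component; it assumes the same component-list constructor as in the sketch above, and the keys and means are illustrative.

```cpp
// operator() is syntactic sugar for choose().
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>

#include <cassert>
#include <memory>
#include <vector>

using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;

int main() {
  const DiscreteKey mode(M(0), 2);
  auto c0 = std::make_shared<GaussianConditional>(X(0), Vector1(0.0), I_1x1);
  auto c1 = std::make_shared<GaussianConditional>(X(0), Vector1(5.0), I_1x1);
  const HybridGaussianConditional hgc(
      mode, std::vector<GaussianConditional::shared_ptr>{c0, c1});

  DiscreteValues assignment;
  assignment[M(0)] = 1;

  // Both calls return the Gaussian component selected by m0 = 1.
  const auto byChoose = hgc.choose(assignment);
  const auto byCall = hgc(assignment);
  assert(byChoose && byCall && byChoose->equals(*byCall));
  return 0;
}
```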

24 changes: 11 additions & 13 deletions gtsam/hybrid/HybridGaussianFactor.cpp
@@ -154,10 +154,9 @@ bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {

// Check the base and the factors:
return Base::equals(*e, tol) &&
factors_.equals(e->factors_,
[tol](const sharedFactor &f1, const sharedFactor &f2) {
return f1->equals(*f2, tol);
});
factors_.equals(e->factors_, [tol](const auto &f1, const auto &f2) {
return (!f1 && !f2) || (f1 && f2 && f1->equals(*f2, tol));
});
}

/* *******************************************************************************/
@@ -213,16 +212,15 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
}

/* *******************************************************************************/
double HybridGaussianFactor::potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &values) const {
/// Helper method to compute the error of a component.
static double PotentiallyPrunedComponentError(
const GaussianFactor::shared_ptr &gf, const VectorValues &values) {
// Check if valid pointer
if (gf) {
return gf->error(values);
} else {
// If not valid, pointer, it means this component was pruned,
// so we return maximum error.
// This way the negative exponential will give
// a probability value close to 0.0.
// If nullptr this component was pruned, so we return maximum error. This
// way the negative exponential will give a probability value close to 0.0.
return std::numeric_limits<double>::max();
}
}
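
Note: the "maximum error" convention can be checked with plain C++ (this snippet is illustrative, not from the PR): exp(-max) underflows to zero, which is the "probability value close to 0.0" mentioned in the comment.

```cpp
// A pruned (nullptr) component reports the largest representable error, so its
// unnormalized probability exp(-error) underflows to exactly 0.0.
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const double prunedError = std::numeric_limits<double>::max();
  const double probability = std::exp(-prunedError);
  std::printf("exp(-max) = %g\n", probability);  // prints 0
  return probability == 0.0 ? 0 : 1;
}
```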
@@ -231,8 +229,8 @@ double HybridGaussianFactor::potentiallyPrunedComponentError(
AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
const VectorValues &continuousValues) const {
// functor to convert from sharedFactor to double error value.
auto errorFunc = [this, &continuousValues](const sharedFactor &gf) {
return this->potentiallyPrunedComponentError(gf, continuousValues);
auto errorFunc = [&continuousValues](const sharedFactor &gf) {
return PotentiallyPrunedComponentError(gf, continuousValues);
};
DecisionTree<Key, double> error_tree(factors_, errorFunc);
return error_tree;
@@ -242,7 +240,7 @@ AlgebraicDecisionTree<Key> HybridGaussianFactor::errorTree(
double HybridGaussianFactor::error(const HybridValues &values) const {
// Directly index to get the component, no need to build the whole tree.
const sharedFactor gf = factors_(values.discrete());
return potentiallyPrunedComponentError(gf, values.continuous());
return PotentiallyPrunedComponentError(gf, values.continuous());
}

} // namespace gtsam
4 changes: 0 additions & 4 deletions gtsam/hybrid/HybridGaussianFactor.h
@@ -189,10 +189,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
static Factors augment(const FactorValuePairs &factors);

/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;

/// Helper struct to assist private constructor below.
struct ConstructorHelper;

92 changes: 34 additions & 58 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
@@ -74,6 +74,32 @@ const Ordering HybridOrdering(const HybridGaussianFactorGraph &graph) {
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
}

/* ************************************************************************ */
static void printFactor(const std::shared_ptr<Factor> &factor,
const DiscreteValues &assignment,
const KeyFormatter &keyFormatter) {
if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
hgf->operator()(assignment)
->print("HybridGaussianFactor, component:", keyFormatter);
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
factor->print("GaussianFactor:\n", keyFormatter);
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
factor->print("DiscreteFactor:\n", keyFormatter);
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (hc->isContinuous()) {
factor->print("GaussianConditional:\n", keyFormatter);
} else if (hc->isDiscrete()) {
factor->print("DiscreteConditional:\n", keyFormatter);
} else {
hc->asHybrid()
->choose(assignment)
->print("HybridConditional, component:\n", keyFormatter);
}
} else {
factor->print("Unknown factor type\n", keyFormatter);
}
}

/* ************************************************************************ */
void HybridGaussianFactorGraph::printErrors(
const HybridValues &values, const std::string &str,
@@ -83,69 +109,19 @@ void HybridGaussianFactorGraph::printErrors(
&printCondition) const {
std::cout << str << "size: " << size() << std::endl << std::endl;

std::stringstream ss;

for (size_t i = 0; i < factors_.size(); i++) {
auto &&factor = factors_[i];
std::cout << "Factor " << i << ": ";

// Clear the stringstream
ss.str(std::string());

if (auto hgf = std::dynamic_pointer_cast<HybridGaussianFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
hgf->operator()(values.discrete())->print(ss.str(), keyFormatter);
std::cout << "error = " << factor->error(values) << std::endl;
}
} else if (auto hc = std::dynamic_pointer_cast<HybridConditional>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
if (hc->isContinuous()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asGaussian()->error(values) << "\n";
} else if (hc->isDiscrete()) {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << hc->asDiscrete()->error(values.discrete())
<< "\n";
} else {
// Is hybrid
auto conditionalComponent =
hc->asHybrid()->operator()(values.discrete());
conditionalComponent->print(ss.str(), keyFormatter);
std::cout << "error = " << conditionalComponent->error(values)
<< "\n";
}
}
} else if (auto gf = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
const double errorValue = (factor != nullptr ? gf->error(values) : .0);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << errorValue << "\n";
}
} else if (auto df = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
if (factor == nullptr) {
std::cout << "nullptr"
<< "\n";
} else {
factor->print(ss.str(), keyFormatter);
std::cout << "error = " << df->error(values.discrete()) << std::endl;
}

} else {
if (factor == nullptr) {
std::cout << "Factor " << i << ": nullptr\n";
continue;
}
const double errorValue = factor->error(values);
if (!printCondition(factor.get(), errorValue, i))
continue; // User-provided filter did not pass

// Print the factor
std::cout << "Factor " << i << ", error = " << errorValue << "\n";
printFactor(factor, values.discrete(), keyFormatter);
std::cout << "\n";
}
std::cout.flush();
5 changes: 5 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
@@ -231,4 +231,9 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
GaussianFactorGraph operator()(const DiscreteValues& assignment) const;
};

// traits
template <>
struct traits<HybridGaussianFactorGraph>
: public Testable<HybridGaussianFactorGraph> {};

} // namespace gtsam
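
Note: this traits specialization presumably makes HybridGaussianFactorGraph satisfy the Testable concept so generic test helpers compile; a minimal sketch (not from this PR) under that assumption.

```cpp
// With traits<HybridGaussianFactorGraph> defined, assert_equal (and the
// CppUnitLite EXPECT macros built on it) can be used on whole graphs.
#include <gtsam/base/Testable.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>

using namespace gtsam;

int main() {
  HybridGaussianFactorGraph expected, actual;         // both empty
  const bool same = assert_equal(expected, actual);   // uses equals()/print()
  return same ? 0 : 1;
}
```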
21 changes: 10 additions & 11 deletions gtsam/hybrid/tests/Switching.h
@@ -114,11 +114,11 @@ inline std::pair<KeyVector, std::vector<int>> makeBinaryOrdering(
return {new_order, levels};
}

/* ***************************************************************************
*/
/* ****************************************************************************/
using MotionModel = BetweenFactor<double>;

// Test fixture with switching network.
/// ϕ(X(0)) .. ϕ(X(k),X(k+1)) .. ϕ(X(k);z_k) .. ϕ(M(0)) .. ϕ(M(k),M(k+1))
struct Switching {
size_t K;
DiscreteKeys modes;
@@ -140,8 +140,8 @@ struct Switching {
: K(K) {
using noiseModel::Isotropic;

// Create DiscreteKeys for binary K modes.
for (size_t k = 0; k < K; k++) {
// Create DiscreteKeys for K-1 binary modes.
for (size_t k = 0; k < K - 1; k++) {
modes.emplace_back(M(k), 2);
}

@@ -153,34 +153,33 @@ struct Switching {
}

// Create hybrid factor graph.
// Add a prior on X(0).

// Add a prior ϕ(X(0)) on X(0).
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(0), measurements.at(0), Isotropic::Sigma(1, prior_sigma));

// Add "motion models".
// Add "motion models" ϕ(X(k),X(k+1)).
for (size_t k = 0; k < K - 1; k++) {
auto motion_models = motionModels(k, between_sigma);
nonlinearFactorGraph.emplace_shared<HybridNonlinearFactor>(modes[k],
motion_models);
}

// Add measurement factors
// Add measurement factors ϕ(X(k);z_k).
auto measurement_noise = Isotropic::Sigma(1, prior_sigma);
for (size_t k = 1; k < K; k++) {
nonlinearFactorGraph.emplace_shared<PriorFactor<double>>(
X(k), measurements.at(k), measurement_noise);
}

// Add "mode chain"
// Add "mode chain" ϕ(M(0)) ϕ(M(0),M(1)) ... ϕ(M(K-3),M(K-2))
addModeChain(&nonlinearFactorGraph, discrete_transition_prob);

// Create the linearization point.
for (size_t k = 0; k < K; k++) {
linearizationPoint.insert<double>(X(k), static_cast<double>(k + 1));
}

// The ground truth is robot moving forward
// and one less than the linearization point
linearizedFactorGraph = *nonlinearFactorGraph.linearize(linearizationPoint);
}

@@ -196,7 +195,7 @@ struct Switching {
}

/**
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2).
* @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-1).
* E.g. if K=4, we want M0, M1 and M2.
*
* @param fg The factor graph to which the mode chain is added.
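
Note: a sketch (not from this PR) of how the Switching fixture is typically used in the hybrid tests, assuming its constructor arguments other than K are defaulted and that this test-only header is on the include path.

```cpp
// Build the K = 3 switching example: a prior and measurements on X(0..2), two
// hybrid motion models gated by the binary modes M(0), M(1), and the mode
// chain phi(M(0)) phi(M(0),M(1)).
#include <gtsam/hybrid/tests/Switching.h>

using namespace gtsam;

int main() {
  Switching self(3);

  // The fixture exposes both the nonlinear graph and its linearization.
  self.nonlinearFactorGraph.print("Nonlinear factor graph:\n");
  self.linearizedFactorGraph.print("Linearized factor graph:\n");
  return 0;
}
```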