diff --git a/gtsam/discrete/DecisionTree-inl.h b/gtsam/discrete/DecisionTree-inl.h index 9d618dea02..f998a60658 100644 --- a/gtsam/discrete/DecisionTree-inl.h +++ b/gtsam/discrete/DecisionTree-inl.h @@ -93,7 +93,8 @@ namespace gtsam { /// print void print(const std::string& s, const LabelFormatter& labelFormatter, const ValueFormatter& valueFormatter) const override { - std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl; + std::cout << s << " Leaf [" << nrAssignments() << "] " + << valueFormatter(constant_) << std::endl; } /** Write graphviz format to stream `os`. */ @@ -827,6 +828,16 @@ namespace gtsam { return total; } + /****************************************************************************/ + template + size_t DecisionTree::nrAssignments() const { + size_t n = 0; + this->visitLeaf([&n](const DecisionTree::Leaf& leaf) { + n += leaf.nrAssignments(); + }); + return n; + } + /****************************************************************************/ // fold is just done with a visit template diff --git a/gtsam/discrete/DecisionTree.h b/gtsam/discrete/DecisionTree.h index ed19084859..bee0ce5c70 100644 --- a/gtsam/discrete/DecisionTree.h +++ b/gtsam/discrete/DecisionTree.h @@ -299,6 +299,42 @@ namespace gtsam { /// Return the number of leaves in the tree. size_t nrLeaves() const; + /** + * @brief This is a convenience function which returns the total number of + * leaf assignments in the decision tree. + * This function is not used for anymajor operations within the discrete + * factor graph framework. + * + * Leaf assignments represent the cardinality of each leaf node, e.g. in a + * binary tree each leaf has 2 assignments. This includes counts removed + * from implicit pruning hence, it will always be >= nrLeaves(). + * + * E.g. we have a decision tree as below, where each node has 2 branches: + * + * Choice(m1) + * 0 Choice(m0) + * 0 0 Leaf 0.0 + * 0 1 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * In the unpruned form, the tree will have 4 assignments, 2 for each key, + * and 4 leaves. + * + * In the pruned form, the number of assignments is still 4 but the number + * of leaves is now 3, as below: + * + * Choice(m1) + * 0 Leaf 0.0 + * 1 Choice(m0) + * 1 0 Leaf 1.0 + * 1 1 Leaf 2.0 + * + * @return size_t + */ + size_t nrAssignments() const; + /** * @brief Fold a binary function over the tree, returning accumulator. * diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index ff18268b14..56f1659dc4 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -101,6 +101,14 @@ namespace gtsam { return DecisionTreeFactor(keys, result); } + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const { + // apply operand + ADT result = ADT::apply(op); + // Make a new factor + return DecisionTreeFactor(discreteKeys(), result); + } + /* ************************************************************************ */ DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine( size_t nrFrontals, ADT::Binary op) const { diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 6cce6e5d4d..e92c82b77b 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -147,6 +147,12 @@ namespace gtsam { /// @name Advanced Interface /// @{ + /** + * Apply unary operator (*this) "op" f + * @param op a unary operator that operates on AlgebraicDecisionTree + */ + DecisionTreeFactor apply(ADT::UnaryAssignment op) const; + /** * Apply binary operator (*this) "op" f * @param f the second argument for op diff --git a/gtsam/discrete/tests/testDecisionTree.cpp b/gtsam/discrete/tests/testDecisionTree.cpp index d2a94ddc33..efa7a1c445 100644 --- a/gtsam/discrete/tests/testDecisionTree.cpp +++ b/gtsam/discrete/tests/testDecisionTree.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index 266e02b0dd..b4bf612208 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { return Base::equals(bn, tol); } -/* ************************************************************************* */ -DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const { - AlgebraicDecisionTree discreteProbs; - - // The canonical decision tree factor which will get - // the discrete conditionals added to it. - DecisionTreeFactor discreteProbsFactor; - - for (auto &&conditional : *this) { - if (conditional->isDiscrete()) { - // Convert to a DecisionTreeFactor and add it to the main factor. - DecisionTreeFactor f(*conditional->asDiscrete()); - discreteProbsFactor = discreteProbsFactor * f; - } - } - return std::make_shared(discreteProbsFactor); -} - /* ************************************************************************* */ /** * @brief Helper function to get the pruner functional. @@ -144,53 +126,52 @@ std::function &, double)> prunerFunc( } /* ************************************************************************* */ -void HybridBayesNet::updateDiscreteConditionals( - const DecisionTreeFactor &prunedDiscreteProbs) { - KeyVector prunedTreeKeys = prunedDiscreteProbs.keys(); +DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( + size_t maxNrLeaves) { + // Get the joint distribution of only the discrete keys + gttic_(HybridBayesNet_PruneDiscreteConditionals); + // The joint discrete probability. + DiscreteConditional discreteProbs; + + std::vector discrete_factor_idxs; + // Record frontal keys so we can maintain ordering + Ordering discrete_frontals; - // Loop with index since we need it later. for (size_t i = 0; i < this->size(); i++) { - HybridConditional::shared_ptr conditional = this->at(i); + auto conditional = this->at(i); if (conditional->isDiscrete()) { - auto discrete = conditional->asDiscrete(); - - // Convert pointer from conditional to factor - auto discreteTree = - std::dynamic_pointer_cast(discrete); - // Apply prunerFunc to the underlying AlgebraicDecisionTree - DecisionTreeFactor::ADT prunedDiscreteTree = - discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional)); - - gttic_(HybridBayesNet_MakeConditional); - // Create the new (hybrid) conditional - KeyVector frontals(discrete->frontals().begin(), - discrete->frontals().end()); - auto prunedDiscrete = std::make_shared( - frontals.size(), conditional->discreteKeys(), prunedDiscreteTree); - conditional = std::make_shared(prunedDiscrete); - gttoc_(HybridBayesNet_MakeConditional); - - // Add it back to the BayesNet - this->at(i) = conditional; + discreteProbs = discreteProbs * (*conditional->asDiscrete()); + + Ordering conditional_keys(conditional->frontals()); + discrete_frontals += conditional_keys; + discrete_factor_idxs.push_back(i); } } -} - -/* ************************************************************************* */ -HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { - // Get the decision tree of only the discrete keys - gttic_(HybridBayesNet_PruneDiscreteConditionals); - DecisionTreeFactor::shared_ptr discreteConditionals = - this->discreteConditionals(); const DecisionTreeFactor prunedDiscreteProbs = - discreteConditionals->prune(maxNrLeaves); + discreteProbs.prune(maxNrLeaves); gttoc_(HybridBayesNet_PruneDiscreteConditionals); + // Eliminate joint probability back into conditionals gttic_(HybridBayesNet_UpdateDiscreteConditionals); - this->updateDiscreteConditionals(prunedDiscreteProbs); + DiscreteFactorGraph dfg{prunedDiscreteProbs}; + DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals); + + // Assign pruned discrete conditionals back at the correct indices. + for (size_t i = 0; i < discrete_factor_idxs.size(); i++) { + size_t idx = discrete_factor_idxs.at(i); + this->at(idx) = std::make_shared(dbn->at(i)); + } gttoc_(HybridBayesNet_UpdateDiscreteConditionals); - /* To Prune, we visitWith every leaf in the GaussianMixture. + return prunedDiscreteProbs; +} + +/* ************************************************************************* */ +HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) { + DecisionTreeFactor prunedDiscreteProbs = + this->pruneDiscreteConditionals(maxNrLeaves); + + /* To prune, we visitWith every leaf in the GaussianMixture. * For each leaf, using the assignment we can check the discrete decision tree * for 0.0 probability, then just set the leaf to a nullptr. * diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h index 23fc4d5d30..e71cfe9b43 100644 --- a/gtsam/hybrid/HybridBayesNet.h +++ b/gtsam/hybrid/HybridBayesNet.h @@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { */ VectorValues optimize(const DiscreteValues &assignment) const; - /** - * @brief Get all the discrete conditionals as a decision tree factor. - * - * @return DecisionTreeFactor::shared_ptr - */ - DecisionTreeFactor::shared_ptr discreteConditionals() const; - /** * @brief Sample from an incomplete BayesNet, given missing variables. * @@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet { private: /** - * @brief Update the discrete conditionals with the pruned versions. + * @brief Prune all the discrete conditionals. * - * @param prunedDiscreteProbs + * @param maxNrLeaves */ - void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs); + DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves); #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION /** Serialization function */ diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index 13d5c2cba3..afd1c80328 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include diff --git a/gtsam/hybrid/HybridFactorGraph.cpp b/gtsam/hybrid/HybridFactorGraph.cpp index d96a890f49..f7b96f6944 100644 --- a/gtsam/hybrid/HybridFactorGraph.cpp +++ b/gtsam/hybrid/HybridFactorGraph.cpp @@ -17,7 +17,6 @@ * @date January, 2023 */ -#include #include namespace gtsam { @@ -26,7 +25,7 @@ namespace gtsam { std::set HybridFactorGraph::discreteKeys() const { std::set keys; for (auto& factor : factors_) { - if (auto p = std::dynamic_pointer_cast(factor)) { + if (auto p = std::dynamic_pointer_cast(factor)) { for (const DiscreteKey& key : p->discreteKeys()) { keys.insert(key); } @@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const { for (const Key& key : p->continuousKeys()) { keys.insert(key); } + } else if (auto p = std::dynamic_pointer_cast(factor)) { + keys.insert(p->keys().begin(), p->keys().end()); } } return keys; diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.cpp b/gtsam/hybrid/HybridGaussianFactorGraph.cpp index 2b23ed4dbf..2d4ac83f61 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/HybridGaussianFactorGraph.cpp @@ -48,8 +48,6 @@ #include #include -// #define HYBRID_TIMING - namespace gtsam { /// Specialize EliminateableFactorGraph for HybridGaussianFactorGraph: @@ -120,7 +118,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { // TODO(dellaert): in C++20, we can use std::visit. continue; } - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // Don't do anything for discrete-only factors // since we want to eliminate continuous values only. continue; @@ -167,8 +165,8 @@ discreteElimination(const HybridGaussianFactorGraph &factors, DiscreteFactorGraph dfg; for (auto &f : factors) { - if (auto dtf = dynamic_pointer_cast(f)) { - dfg.push_back(dtf); + if (auto df = dynamic_pointer_cast(f)) { + dfg.push_back(df); } else if (auto orphan = dynamic_pointer_cast(f)) { // Ignore orphaned clique. // TODO(dellaert): is this correct? If so explain here. @@ -262,6 +260,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors, }; DecisionTree probabilities(eliminationResults, probability); + return { std::make_shared(gaussianMixture), std::make_shared(discreteSeparator, probabilities)}; @@ -348,64 +347,68 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors, // When the number of assignments is large we may encounter stack overflows. // However this is also the case with iSAM2, so no pressure :) - // PREPROCESS: Identify the nature of the current elimination - - // TODO(dellaert): just check the factors: + // Check the factors: // 1. if all factors are discrete, then we can do discrete elimination: // 2. if all factors are continuous, then we can do continuous elimination: // 3. if not, we do hybrid elimination: - // First, identify the separator keys, i.e. all keys that are not frontal. - KeySet separatorKeys; + bool only_discrete = true, only_continuous = true; for (auto &&factor : factors) { - separatorKeys.insert(factor->begin(), factor->end()); - } - // remove frontals from separator - for (auto &k : frontalKeys) { - separatorKeys.erase(k); - } - - // Build a map from keys to DiscreteKeys - auto mapFromKeyToDiscreteKey = factors.discreteKeyMap(); - - // Fill in discrete frontals and continuous frontals. - std::set discreteFrontals; - KeySet continuousFrontals; - for (auto &k : frontalKeys) { - if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { - discreteFrontals.insert(mapFromKeyToDiscreteKey.at(k)); - } else { - continuousFrontals.insert(k); - } - } - - // Fill in discrete discrete separator keys and continuous separator keys. - std::set discreteSeparatorSet; - KeyVector continuousSeparator; - for (auto &k : separatorKeys) { - if (mapFromKeyToDiscreteKey.find(k) != mapFromKeyToDiscreteKey.end()) { - discreteSeparatorSet.insert(mapFromKeyToDiscreteKey.at(k)); - } else { - continuousSeparator.push_back(k); + if (auto hybrid_factor = std::dynamic_pointer_cast(factor)) { + if (hybrid_factor->isDiscrete()) { + only_continuous = false; + } else if (hybrid_factor->isContinuous()) { + only_discrete = false; + } else if (hybrid_factor->isHybrid()) { + only_continuous = false; + only_discrete = false; + } + } else if (auto cont_factor = + std::dynamic_pointer_cast(factor)) { + only_discrete = false; + } else if (auto discrete_factor = + std::dynamic_pointer_cast(factor)) { + only_continuous = false; } } - // Check if we have any continuous keys: - const bool discrete_only = - continuousFrontals.empty() && continuousSeparator.empty(); - // NOTE: We should really defer the product here because of pruning - if (discrete_only) { + if (only_discrete) { // Case 1: we are only dealing with discrete return discreteElimination(factors, frontalKeys); - } else if (mapFromKeyToDiscreteKey.empty()) { + } else if (only_continuous) { // Case 2: we are only dealing with continuous return continuousElimination(factors, frontalKeys); } else { // Case 3: We are now in the hybrid land! + KeySet frontalKeysSet(frontalKeys.begin(), frontalKeys.end()); + + // Find all the keys in the set of continuous keys + // which are not in the frontal keys. This is our continuous separator. + KeyVector continuousSeparator; + auto continuousKeySet = factors.continuousKeySet(); + std::set_difference( + continuousKeySet.begin(), continuousKeySet.end(), + frontalKeysSet.begin(), frontalKeysSet.end(), + std::inserter(continuousSeparator, continuousSeparator.begin())); + + // Similarly for the discrete separator. + KeySet discreteSeparatorSet; + std::set discreteSeparator; + auto discreteKeySet = factors.discreteKeySet(); + std::set_difference( + discreteKeySet.begin(), discreteKeySet.end(), frontalKeysSet.begin(), + frontalKeysSet.end(), + std::inserter(discreteSeparatorSet, discreteSeparatorSet.begin())); + // Convert from set of keys to set of DiscreteKeys + auto discreteKeyMap = factors.discreteKeyMap(); + for (auto key : discreteSeparatorSet) { + discreteSeparator.insert(discreteKeyMap.at(key)); + } + return hybridElimination(factors, frontalKeys, continuousSeparator, - discreteSeparatorSet); + discreteSeparator); } } @@ -429,7 +432,7 @@ AlgebraicDecisionTree HybridGaussianFactorGraph::error( // Add the gaussian factor error to every leaf of the error tree. error_tree = error_tree.apply( [error](double leaf_value) { return leaf_value + error; }); - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If factor at `idx` is discrete-only, we skip. continue; } else { diff --git a/gtsam/hybrid/HybridGaussianFactorGraph.h b/gtsam/hybrid/HybridGaussianFactorGraph.h index 421e69aa05..b3f1591507 100644 --- a/gtsam/hybrid/HybridGaussianFactorGraph.h +++ b/gtsam/hybrid/HybridGaussianFactorGraph.h @@ -40,6 +40,7 @@ class HybridEliminationTree; class HybridBayesTree; class HybridJunctionTree; class DecisionTreeFactor; +class TableFactor; class JacobianFactor; class HybridValues; diff --git a/gtsam/hybrid/HybridJunctionTree.cpp b/gtsam/hybrid/HybridJunctionTree.cpp index 6f2898bf19..22d3c7dd25 100644 --- a/gtsam/hybrid/HybridJunctionTree.cpp +++ b/gtsam/hybrid/HybridJunctionTree.cpp @@ -66,7 +66,7 @@ struct HybridConstructorTraversalData { for (auto& k : hf->discreteKeys()) { data.discreteKeys.insert(k.first); } - } else if (auto hf = std::dynamic_pointer_cast(f)) { + } else if (auto hf = std::dynamic_pointer_cast(f)) { for (auto& k : hf->discreteKeys()) { data.discreteKeys.insert(k.first); } @@ -161,7 +161,7 @@ HybridJunctionTree::HybridJunctionTree( Data rootData(0); rootData.junctionTreeNode = std::make_shared(); // Make a dummy node to gather - // the junction tree roots + // the junction tree roots treeTraversal::DepthFirstForest(eliminationTree, rootData, Data::ConstructorTraversalVisitorPre, Data::ConstructorTraversalVisitorPost); diff --git a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp index 260f534e3c..2459e4ec9e 100644 --- a/gtsam/hybrid/HybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/HybridNonlinearFactorGraph.cpp @@ -17,6 +17,7 @@ */ #include +#include #include #include #include @@ -67,7 +68,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize( } else if (auto nlf = dynamic_pointer_cast(f)) { const GaussianFactor::shared_ptr& gf = nlf->linearize(continuousValues); linearFG->push_back(gf); - } else if (dynamic_pointer_cast(f)) { + } else if (dynamic_pointer_cast(f)) { // If discrete-only: doesn't need linearization. linearFG->push_back(f); } else if (auto gmf = dynamic_pointer_cast(f)) { diff --git a/gtsam/hybrid/HybridSmoother.cpp b/gtsam/hybrid/HybridSmoother.cpp index 56c62cf191..afa8340d2b 100644 --- a/gtsam/hybrid/HybridSmoother.cpp +++ b/gtsam/hybrid/HybridSmoother.cpp @@ -72,7 +72,8 @@ void HybridSmoother::update(HybridGaussianFactorGraph graph, addConditionals(graph, hybridBayesNet_, ordering); // Eliminate. - auto bayesNetFragment = graph.eliminateSequential(ordering); + HybridBayesNet::shared_ptr bayesNetFragment = + graph.eliminateSequential(ordering); /// Prune if (maxNrLeaves) { @@ -96,7 +97,8 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph, HybridGaussianFactorGraph graph(originalGraph); HybridBayesNet hybridBayesNet(originalHybridBayesNet); - // If we are not at the first iteration, means we have conditionals to add. + // If hybridBayesNet is not empty, + // it means we have conditionals to add to the factor graph. if (!hybridBayesNet.empty()) { // We add all relevant conditional mixtures on the last continuous variable // in the previous `hybridBayesNet` to the graph diff --git a/gtsam/hybrid/tests/Switching.h b/gtsam/hybrid/tests/Switching.h index 5842e1f1ac..4b2d3f11b6 100644 --- a/gtsam/hybrid/tests/Switching.h +++ b/gtsam/hybrid/tests/Switching.h @@ -202,31 +202,16 @@ struct Switching { * @brief Add "mode chain" to HybridNonlinearFactorGraph from M(0) to M(K-2). * E.g. if K=4, we want M0, M1 and M2. * - * @param fg The nonlinear factor graph to which the mode chain is added. + * @param fg The factor graph to which the mode chain is added. */ - void addModeChain(HybridNonlinearFactorGraph *fg, + template + void addModeChain(FACTORGRAPH *fg, std::string discrete_transition_prob = "1/2 3/2") { - fg->emplace_shared(modes[0], "1/1"); + fg->template emplace_shared(modes[0], "1/1"); for (size_t k = 0; k < K - 2; k++) { auto parents = {modes[k]}; - fg->emplace_shared(modes[k + 1], parents, - discrete_transition_prob); - } - } - - /** - * @brief Add "mode chain" to HybridGaussianFactorGraph from M(0) to M(K-2). - * E.g. if K=4, we want M0, M1 and M2. - * - * @param fg The gaussian factor graph to which the mode chain is added. - */ - void addModeChain(HybridGaussianFactorGraph *fg, - std::string discrete_transition_prob = "1/2 3/2") { - fg->emplace_shared(modes[0], "1/1"); - for (size_t k = 0; k < K - 2; k++) { - auto parents = {modes[k]}; - fg->emplace_shared(modes[k + 1], parents, - discrete_transition_prob); + fg->template emplace_shared( + modes[k + 1], parents, discrete_transition_prob); } } }; diff --git a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp index 75ba5a0594..5492234979 100644 --- a/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testGaussianMixtureFactor.cpp @@ -108,7 +108,7 @@ TEST(GaussianMixtureFactor, Printing) { std::string expected = R"(Hybrid [x1 x2; 1]{ Choice(1) - 0 Leaf : + 0 Leaf [1] : A[x1] = [ 0; 0 @@ -120,7 +120,7 @@ TEST(GaussianMixtureFactor, Printing) { b = [ 0 0 ] No noise model - 1 Leaf : + 1 Leaf [1] : A[x1] = [ 0; 0 diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp index f25675a552..1dfcbd6b72 100644 --- a/gtsam/hybrid/tests/testHybridBayesNet.cpp +++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp @@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) { auto prunedTree = prunedBayesNet.evaluate(delta.continuous()); // Regression test on pruned logProbability tree - std::vector pruned_leaves = {0.0, 20.346113, 0.0, 19.738098}; + std::vector pruned_leaves = {0.0, 32.713418, 0.0, 31.735823}; AlgebraicDecisionTree expected_pruned(discrete_keys, pruned_leaves); EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6)); @@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) { logProbability += posterior->at(4)->asDiscrete()->logProbability(hybridValues); + // Regression double density = exp(logProbability); - EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9); + EXPECT_DOUBLES_EQUAL(density, + 1.6078460548731697 * actualTree(discrete_values), 1e-6); EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9); EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues), 1e-9); @@ -283,20 +285,30 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { EXPECT_LONGS_EQUAL(7, posterior->size()); size_t maxNrLeaves = 3; - auto discreteConditionals = posterior->discreteConditionals(); + DiscreteConditional discreteConditionals; + for (auto&& conditional : *posterior) { + if (conditional->isDiscrete()) { + discreteConditionals = + discreteConditionals * (*conditional->asDiscrete()); + } + } const DecisionTreeFactor::shared_ptr prunedDecisionTree = std::make_shared( - discreteConditionals->prune(maxNrLeaves)); + discreteConditionals.prune(maxNrLeaves)); EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/, prunedDecisionTree->nrLeaves()); - auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete()); + // regression + DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}}; + DecisionTreeFactor::ADT potentials( + dkeys, std::vector{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577}); + DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials); // Prune! posterior->prune(maxNrLeaves); - // Functor to verify values against the original_discrete_conditionals + // Functor to verify values against the expected_discrete_conditionals auto checker = [&](const Assignment& assignment, double probability) -> double { // typecast so we can use this to get probability value @@ -304,7 +316,7 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) { if (prunedDecisionTree->operator()(choices) == 0) { EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9); } else { - EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability, + EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability, 1e-9); } return 0.0; diff --git a/gtsam/hybrid/tests/testHybridBayesTree.cpp b/gtsam/hybrid/tests/testHybridBayesTree.cpp index 578f5d605c..81b257c32e 100644 --- a/gtsam/hybrid/tests/testHybridBayesTree.cpp +++ b/gtsam/hybrid/tests/testHybridBayesTree.cpp @@ -146,7 +146,7 @@ TEST(HybridBayesTree, Optimize) { DiscreteFactorGraph dfg; for (auto&& f : *remainingFactorGraph) { - auto discreteFactor = dynamic_pointer_cast(f); + auto discreteFactor = dynamic_pointer_cast(f); assert(discreteFactor); dfg.push_back(discreteFactor); } diff --git a/gtsam/hybrid/tests/testHybridEstimation.cpp b/gtsam/hybrid/tests/testHybridEstimation.cpp index b5f5244fa4..b8edc39d88 100644 --- a/gtsam/hybrid/tests/testHybridEstimation.cpp +++ b/gtsam/hybrid/tests/testHybridEstimation.cpp @@ -140,6 +140,61 @@ TEST(HybridEstimation, IncrementalSmoother) { EXPECT(assert_equal(expected_continuous, result)); } +/****************************************************************************/ +// Test approximate inference with an additional pruning step. +TEST(HybridEstimation, ISAM) { + size_t K = 15; + std::vector measurements = {0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 6, + 7, 8, 9, 9, 9, 10, 11, 11, 11, 11}; + // Ground truth discrete seq + std::vector discrete_seq = {1, 1, 0, 0, 0, 1, 1, 1, 1, 0, + 1, 1, 1, 0, 0, 1, 1, 0, 0, 0}; + // Switching example of robot moving in 1D + // with given measurements and equal mode priors. + Switching switching(K, 1.0, 0.1, measurements, "1/1 1/1"); + HybridNonlinearISAM isam; + HybridNonlinearFactorGraph graph; + Values initial; + + // gttic_(Estimation); + + // Add the X(0) prior + graph.push_back(switching.nonlinearFactorGraph.at(0)); + initial.insert(X(0), switching.linearizationPoint.at(X(0))); + + HybridGaussianFactorGraph linearized; + + for (size_t k = 1; k < K; k++) { + // Motion Model + graph.push_back(switching.nonlinearFactorGraph.at(k)); + // Measurement + graph.push_back(switching.nonlinearFactorGraph.at(k + K - 1)); + + initial.insert(X(k), switching.linearizationPoint.at(X(k))); + + isam.update(graph, initial, 3); + // isam.bayesTree().print("\n\n"); + + graph.resize(0); + initial.clear(); + } + + Values result = isam.estimate(); + DiscreteValues assignment = isam.assignment(); + + DiscreteValues expected_discrete; + for (size_t k = 0; k < K - 1; k++) { + expected_discrete[M(k)] = discrete_seq[k]; + } + EXPECT(assert_equal(expected_discrete, assignment)); + + Values expected_continuous; + for (size_t k = 0; k < K; k++) { + expected_continuous.insert(X(k), measurements[k]); + } + EXPECT(assert_equal(expected_continuous, result)); +} + /** * @brief A function to get a specific 1D robot motion problem as a linearized * factor graph. This is the problem P(X|Z, M), i.e. estimating the continuous diff --git a/gtsam/hybrid/tests/testHybridFactorGraph.cpp b/gtsam/hybrid/tests/testHybridFactorGraph.cpp index f5b4ec0b19..33c0761eb4 100644 --- a/gtsam/hybrid/tests/testHybridFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridFactorGraph.cpp @@ -18,7 +18,9 @@ #include #include #include +#include #include +#include #include using namespace std; @@ -37,6 +39,32 @@ TEST(HybridFactorGraph, Constructor) { HybridFactorGraph fg; } +/* ************************************************************************* */ +// Test if methods to get keys work as expected. +TEST(HybridFactorGraph, Keys) { + HybridGaussianFactorGraph hfg; + + // Add prior on x0 + hfg.add(JacobianFactor(X(0), I_3x3, Z_3x1)); + + // Add factor between x0 and x1 + hfg.add(JacobianFactor(X(0), I_3x3, X(1), -I_3x3, Z_3x1)); + + // Add a gaussian mixture factor ϕ(x1, c1) + DiscreteKey m1(M(1), 2); + DecisionTree dt( + M(1), std::make_shared(X(1), I_3x3, Z_3x1), + std::make_shared(X(1), I_3x3, Vector3::Ones())); + hfg.add(GaussianMixtureFactor({X(1)}, {m1}, dt)); + + KeySet expected_continuous{X(0), X(1)}; + EXPECT( + assert_container_equality(expected_continuous, hfg.continuousKeySet())); + + KeySet expected_discrete{M(1)}; + EXPECT(assert_container_equality(expected_discrete, hfg.discreteKeySet())); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp index 8276264ae0..1da897103e 100644 --- a/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridGaussianFactorGraph.cpp @@ -902,7 +902,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) { // Test resulting posterior Bayes net has correct size: EXPECT_LONGS_EQUAL(8, posterior->size()); - // TODO(dellaert): this test fails - no idea why. + // Ratio test EXPECT(ratioTest(bn, measurements, *posterior)); } diff --git a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp index af3a23b947..12506b8af1 100644 --- a/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp +++ b/gtsam/hybrid/tests/testHybridNonlinearFactorGraph.cpp @@ -492,7 +492,7 @@ factor 0: factor 1: Hybrid [x0 x1; m0]{ Choice(m0) - 0 Leaf : + 0 Leaf [1] : A[x0] = [ -1 ] @@ -502,7 +502,7 @@ Hybrid [x0 x1; m0]{ b = [ -1 ] No noise model - 1 Leaf : + 1 Leaf [1] : A[x0] = [ -1 ] @@ -516,7 +516,7 @@ Hybrid [x0 x1; m0]{ factor 2: Hybrid [x1 x2; m1]{ Choice(m1) - 0 Leaf : + 0 Leaf [1] : A[x1] = [ -1 ] @@ -526,7 +526,7 @@ Hybrid [x1 x2; m1]{ b = [ -1 ] No noise model - 1 Leaf : + 1 Leaf [1] : A[x1] = [ -1 ] @@ -550,16 +550,16 @@ factor 4: b = [ -10 ] No noise model factor 5: P( m0 ): - Leaf 0.5 + Leaf [2] 0.5 factor 6: P( m1 | m0 ): Choice(m1) 0 Choice(m0) - 0 0 Leaf 0.33333333 - 0 1 Leaf 0.6 + 0 0 Leaf [1] 0.33333333 + 0 1 Leaf [1] 0.6 1 Choice(m0) - 1 0 Leaf 0.66666667 - 1 1 Leaf 0.4 + 1 0 Leaf [1] 0.66666667 + 1 1 Leaf [1] 0.4 )"; EXPECT(assert_print_equal(expected_hybridFactorGraph, linearizedFactorGraph)); @@ -570,13 +570,13 @@ size: 3 conditional 0: Hybrid P( x0 | x1 m0) Discrete Keys = (m0, 2), Choice(m0) - 0 Leaf p(x0 | x1) + 0 Leaf [1] p(x0 | x1) R = [ 10.0499 ] S[x1] = [ -0.0995037 ] d = [ -9.85087 ] No noise model - 1 Leaf p(x0 | x1) + 1 Leaf [1] p(x0 | x1) R = [ 10.0499 ] S[x1] = [ -0.0995037 ] d = [ -9.95037 ] @@ -586,26 +586,26 @@ conditional 1: Hybrid P( x1 | x2 m0 m1) Discrete Keys = (m0, 2), (m1, 2), Choice(m1) 0 Choice(m0) - 0 0 Leaf p(x1 | x2) + 0 0 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -9.99901 ] No noise model - 0 1 Leaf p(x1 | x2) + 0 1 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -9.90098 ] No noise model 1 Choice(m0) - 1 0 Leaf p(x1 | x2) + 1 0 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -10.098 ] No noise model - 1 1 Leaf p(x1 | x2) + 1 1 Leaf [1] p(x1 | x2) R = [ 10.099 ] S[x2] = [ -0.0990196 ] d = [ -10 ] @@ -615,14 +615,14 @@ conditional 2: Hybrid P( x2 | m0 m1) Discrete Keys = (m0, 2), (m1, 2), Choice(m1) 0 Choice(m0) - 0 0 Leaf p(x2) + 0 0 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.1489 ] mean: 1 elements x2: -1.0099 No noise model - 0 1 Leaf p(x2) + 0 1 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.1479 ] mean: 1 elements @@ -630,14 +630,14 @@ conditional 2: Hybrid P( x2 | m0 m1) No noise model 1 Choice(m0) - 1 0 Leaf p(x2) + 1 0 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.0504 ] mean: 1 elements x2: -1.0001 No noise model - 1 1 Leaf p(x2) + 1 1 Leaf [1] p(x2) R = [ 10.0494 ] d = [ -10.0494 ] mean: 1 elements diff --git a/gtsam/hybrid/tests/testMixtureFactor.cpp b/gtsam/hybrid/tests/testMixtureFactor.cpp index 67a7fd8ae1..03fdccff26 100644 --- a/gtsam/hybrid/tests/testMixtureFactor.cpp +++ b/gtsam/hybrid/tests/testMixtureFactor.cpp @@ -63,8 +63,8 @@ TEST(MixtureFactor, Printing) { R"(Hybrid [x1 x2; 1] MixtureFactor Choice(1) - 0 Leaf Nonlinear factor on 2 keys - 1 Leaf Nonlinear factor on 2 keys + 0 Leaf [1] Nonlinear factor on 2 keys + 1 Leaf [1] Nonlinear factor on 2 keys )"; EXPECT(assert_print_equal(expected, mixtureFactor)); } diff --git a/gtsam/linear/GaussianConditional.cpp b/gtsam/linear/GaussianConditional.cpp index 188c31abec..0112835aa4 100644 --- a/gtsam/linear/GaussianConditional.cpp +++ b/gtsam/linear/GaussianConditional.cpp @@ -99,7 +99,7 @@ namespace gtsam { /* ************************************************************************ */ void GaussianConditional::print(const string &s, const KeyFormatter& formatter) const { - cout << s << " p("; + cout << (s.empty() ? "" : s + " ") << "p("; for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { cout << formatter(*it) << (nrFrontals() > 1 ? " " : ""); }