From e47531e79431a4e3e38e699f7fc0f7d96b1cb5ce Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 14 Jul 2023 17:52:22 -0400 Subject: [PATCH] prune joint discrete probability which is faster --- gtsam/hybrid/HybridBayesNet.cpp | 85 ++++++++++------------- gtsam/hybrid/HybridBayesNet.h | 13 +--- gtsam/hybrid/tests/testHybridBayesNet.cpp | 26 +++++-- 3 files changed, 58 insertions(+), 66 deletions(-) diff --git a/gtsam/hybrid/HybridBayesNet.cpp b/gtsam/hybrid/HybridBayesNet.cpp index ff2752bcbe..b4bf612208 100644 --- a/gtsam/hybrid/HybridBayesNet.cpp +++ b/gtsam/hybrid/HybridBayesNet.cpp @@ -37,19 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const { return Base::equals(bn, tol); } -/* ************************************************************************* */ -DiscreteConditional::shared_ptr HybridBayesNet::discreteConditionals() const { - // The joint discrete probability. - DiscreteConditional discreteProbs; - - for (auto &&conditional : *this) { - if (conditional->isDiscrete()) { - discreteProbs = discreteProbs * (*conditional->asDiscrete()); - } - } - return std::make_shared(discreteProbs); -} - /* ************************************************************************* */ /** * @brief Helper function to get the pruner functional. @@ -139,52 +126,52 @@ std::function &, double)> prunerFunc( } /* ************************************************************************* */ -void HybridBayesNet::updateDiscreteConditionals( - const DecisionTreeFactor &prunedDiscreteProbs) { - // TODO(Varun) Should prune the joint conditional, maybe during elimination? - // Loop with index since we need it later. +DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals( + size_t maxNrLeaves) { + // Get the joint distribution of only the discrete keys + gttic_(HybridBayesNet_PruneDiscreteConditionals); + // The joint discrete probability. 
+  DiscreteConditional discreteProbs;
+
+  std::vector<size_t> discrete_factor_idxs;
+  // Record frontal keys so we can maintain ordering
+  Ordering discrete_frontals;
+
   for (size_t i = 0; i < this->size(); i++) {
-    HybridConditional::shared_ptr conditional = this->at(i);
+    auto conditional = this->at(i);
     if (conditional->isDiscrete()) {
-      auto discrete = conditional->asDiscrete();
-
-      // Convert pointer from conditional to factor
-      auto discreteFactor =
-          std::dynamic_pointer_cast<DecisionTreeFactor>(discrete);
-      // Apply prunerFunc to the underlying conditional
-      DecisionTreeFactor::ADT prunedDiscreteFactor =
-          discreteFactor->apply(prunerFunc(prunedDiscreteProbs, *conditional));
-
-      gttic_(HybridBayesNet_MakeConditional);
-      // Create the new (hybrid) conditional
-      KeyVector frontals(discrete->frontals().begin(),
-                         discrete->frontals().end());
-      auto prunedDiscrete = std::make_shared<DiscreteConditional>(
-          frontals.size(), conditional->discreteKeys(), prunedDiscreteFactor);
-      conditional = std::make_shared<HybridConditional>(prunedDiscrete);
-      gttoc_(HybridBayesNet_MakeConditional);
-
-      // Add it back to the BayesNet
-      this->at(i) = conditional;
+      discreteProbs = discreteProbs * (*conditional->asDiscrete());
+
+      Ordering conditional_keys(conditional->frontals());
+      discrete_frontals += conditional_keys;
+      discrete_factor_idxs.push_back(i);
     }
   }
-}
-
-/* ************************************************************************* */
-HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
-  // Get the joint distribution of only the discrete keys
-  gttic_(HybridBayesNet_PruneDiscreteConditionals);
-  DiscreteConditional::shared_ptr discreteConditionals =
-      this->discreteConditionals();
   const DecisionTreeFactor prunedDiscreteProbs =
-      discreteConditionals->prune(maxNrLeaves);
+      discreteProbs.prune(maxNrLeaves);
   gttoc_(HybridBayesNet_PruneDiscreteConditionals);
 
+  // Eliminate joint probability back into conditionals
   gttic_(HybridBayesNet_UpdateDiscreteConditionals);
-  this->updateDiscreteConditionals(prunedDiscreteProbs);
+  DiscreteFactorGraph dfg{prunedDiscreteProbs};
+  DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);
+
+  // Assign pruned discrete conditionals back at the correct indices.
+  for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
+    size_t idx = discrete_factor_idxs.at(i);
+    this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
+  }
   gttoc_(HybridBayesNet_UpdateDiscreteConditionals);
 
-  /* To Prune, we visitWith every leaf in the GaussianMixture.
+  return prunedDiscreteProbs;
+}
+
+/* ************************************************************************* */
+HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
+  DecisionTreeFactor prunedDiscreteProbs =
+      this->pruneDiscreteConditionals(maxNrLeaves);
+
+  /* To prune, we visitWith every leaf in the GaussianMixture.
    * For each leaf, using the assignment we can check the discrete decision tree
    * for 0.0 probability, then just set the leaf to a nullptr.
    *
diff --git a/gtsam/hybrid/HybridBayesNet.h b/gtsam/hybrid/HybridBayesNet.h
index 19e88d754d..e71cfe9b43 100644
--- a/gtsam/hybrid/HybridBayesNet.h
+++ b/gtsam/hybrid/HybridBayesNet.h
@@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
    */
   VectorValues optimize(const DiscreteValues &assignment) const;
 
-  /**
-   * @brief Get all the discrete conditionals as a decision tree factor.
-   *
-   * @return DiscreteConditional::shared_ptr
-   */
-  DiscreteConditional::shared_ptr discreteConditionals() const;
-
   /**
    * @brief Sample from an incomplete BayesNet, given missing variables.
    *
@@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 
  private:
   /**
-   * @brief Update the discrete conditionals with the pruned versions.
+   * @brief Prune all the discrete conditionals.
    *
-   * @param prunedDiscreteProbs
+   * @param maxNrLeaves
    */
-  void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
+  DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);
 
 #ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
   /** Serialization function */
diff --git a/gtsam/hybrid/tests/testHybridBayesNet.cpp b/gtsam/hybrid/tests/testHybridBayesNet.cpp
index d2f39c6edd..5248fce015 100644
--- a/gtsam/hybrid/tests/testHybridBayesNet.cpp
+++ b/gtsam/hybrid/tests/testHybridBayesNet.cpp
@@ -231,7 +231,7 @@ TEST(HybridBayesNet, Pruning) {
   auto prunedTree = prunedBayesNet.evaluate(delta.continuous());
 
   // Regression test on pruned logProbability tree
-  std::vector<double> pruned_leaves = {0.0, 20.346113, 0.0, 19.738098};
+  std::vector<double> pruned_leaves = {0.0, 32.713418, 0.0, 31.735823};
   AlgebraicDecisionTree<Key> expected_pruned(discrete_keys, pruned_leaves);
   EXPECT(assert_equal(expected_pruned, prunedTree, 1e-6));
 
@@ -248,8 +248,10 @@ TEST(HybridBayesNet, Pruning) {
   logProbability +=
       posterior->at(4)->asDiscrete()->logProbability(hybridValues);
 
+  // Regression
   double density = exp(logProbability);
-  EXPECT_DOUBLES_EQUAL(density, actualTree(discrete_values), 1e-9);
+  EXPECT_DOUBLES_EQUAL(density,
+                       1.6078460548731697 * actualTree(discrete_values), 1e-6);
   EXPECT_DOUBLES_EQUAL(density, prunedTree(discrete_values), 1e-9);
   EXPECT_DOUBLES_EQUAL(logProbability, posterior->logProbability(hybridValues),
                        1e-9);
@@ -283,10 +285,16 @@ TEST(HybridBayesNet, UpdateDiscreteConditionals) {
   EXPECT_LONGS_EQUAL(7, posterior->size());
 
   size_t maxNrLeaves = 3;
-  auto discreteConditionals = posterior->discreteConditionals();
+  DiscreteConditional discreteConditionals;
+  for (auto&& conditional : *posterior) {
+    if (conditional->isDiscrete()) {
+      discreteConditionals =
+          discreteConditionals * (*conditional->asDiscrete());
+    }
+  }
   const DecisionTreeFactor::shared_ptr prunedDecisionTree =
       std::make_shared<DecisionTreeFactor>(
-          discreteConditionals->prune(maxNrLeaves));
+          discreteConditionals.prune(maxNrLeaves));
 
 #ifdef GTSAM_DT_MERGING
   EXPECT_LONGS_EQUAL(maxNrLeaves + 2 /*2 zero leaves*/,
                      prunedDecisionTree->nrLeaves());
 #else
   EXPECT_LONGS_EQUAL(8 /*full tree*/, prunedDecisionTree->nrLeaves());
 #endif
 
-  auto original_discrete_conditionals = *(posterior->at(4)->asDiscrete());
+  // regression
+  DiscreteKeys dkeys{{M(0), 2}, {M(1), 2}, {M(2), 2}};
+  DecisionTreeFactor::ADT potentials(
+      dkeys, std::vector<double>{0, 0, 0, 0.505145423, 0, 1, 0, 0.494854577});
+  DiscreteConditional expected_discrete_conditionals(1, dkeys, potentials);
 
   // Prune!
   posterior->prune(maxNrLeaves);
 
-  // Functor to verify values against the original_discrete_conditionals
+  // Functor to verify values against the expected_discrete_conditionals
   auto checker = [&](const Assignment<Key>& assignment,
                      double probability) -> double {
     // typecast so we can use this to get probability value
     DiscreteValues choices(assignment);
     if (prunedDecisionTree->operator()(choices) == 0) {
       EXPECT_DOUBLES_EQUAL(0.0, probability, 1e-9);
     } else {
-      EXPECT_DOUBLES_EQUAL(original_discrete_conditionals(choices), probability,
+      EXPECT_DOUBLES_EQUAL(expected_discrete_conditionals(choices), probability,
                            1e-9);
     }
     return 0.0;
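
Below is a minimal, standalone sketch of the flow this patch introduces: multiply the discrete conditionals into one joint DiscreteConditional, prune that joint once, then eliminate the pruned joint back into a DiscreteBayesNet. It is not part of the commit; the keys m0 and m1, their probability tables, and the choice of maxNrLeaves = 2 are made up for illustration, and only GTSAM calls the patch itself relies on (DiscreteConditional multiplication, prune, DiscreteFactorGraph::eliminateSequential) are used.

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/inference/Ordering.h>

using namespace gtsam;

int main() {
  // Two hypothetical binary mode variables (illustrative only).
  const DiscreteKey m0(0, 2), m1(1, 2);

  // Joint discrete probability P(m0, m1): the same kind of product the patch
  // accumulates over all discrete conditionals of the HybridBayesNet.
  DiscreteConditional joint = DiscreteConditional(m0 | m1 = "1/3 3/2") *
                              DiscreteConditional(m1 % "4/1");

  // Prune the joint once: keep the 2 largest leaves, zero out the rest.
  const DecisionTreeFactor pruned = joint.prune(2);

  // Eliminate the pruned joint back into conditionals, preserving the
  // frontal ordering, as pruneDiscreteConditionals() does above.
  DiscreteFactorGraph dfg{pruned};
  Ordering frontals;
  frontals.push_back(m0.first);
  frontals.push_back(m1.first);
  const DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(frontals);

  // These pruned conditionals are what would be written back into the
  // hybrid Bayes net at the recorded indices.
  dbn->print("pruned discrete Bayes net");
  return 0;
}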