Hybrid Elimination Improvements #1575

Merged · 17 commits · Jul 19, 2023
13 changes: 12 additions & 1 deletion gtsam/discrete/DecisionTree-inl.h
@@ -93,7 +93,8 @@ namespace gtsam {
/// print
void print(const std::string& s, const LabelFormatter& labelFormatter,
const ValueFormatter& valueFormatter) const override {
std::cout << s << " Leaf " << valueFormatter(constant_) << std::endl;
std::cout << s << " Leaf [" << nrAssignments() << "] "
<< valueFormatter(constant_) << std::endl;
}

/** Write graphviz format to stream `os`. */
@@ -827,6 +828,16 @@ namespace gtsam {
return total;
}

/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}

/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
36 changes: 36 additions & 0 deletions gtsam/discrete/DecisionTree.h
@@ -299,6 +299,42 @@ namespace gtsam {
/// Return the number of leaves in the tree.
size_t nrLeaves() const;

/**
* @brief This is a convenience function which returns the total number of
Member:
Spelling. And why are we adding it? And why is the implementation recursive?
I would just as well delete it unless it has a purpose.

Collaborator (Author):
I'm a bit lost on which word is misspelled. The purpose is to help with testing and ensure correctness as a convenience method.

* leaf assignments in the decision tree.
* This function is not used for any major operations within the discrete
* factor graph framework.
*
* Leaf assignments represent the cardinality of each leaf node, e.g. in a
* binary tree each leaf has 2 assignments. This includes counts removed
* via implicit pruning; hence, it will always be >= nrLeaves().
*
* E.g. we have a decision tree as below, where each node has 2 branches:
*
* Choice(m1)
* 0 Choice(m0)
* 0 0 Leaf 0.0
* 0 1 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* In the unpruned form, the tree has 4 assignments (2 values for each of
* the 2 keys) and 4 leaves.
*
* In the pruned form, the number of assignments is still 4 but the number
* of leaves is now 3, as below:
*
* Choice(m1)
* 0 Leaf 0.0
* 1 Choice(m0)
* 1 0 Leaf 1.0
* 1 1 Leaf 2.0
*
* @return size_t
*/
size_t nrAssignments() const;
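
A minimal usage sketch of the new method (illustrative, not part of this diff): it builds the two-key tree from the doc comment above. The std::string label type is an arbitrary choice here, and whether the duplicate 0.0 leaves get merged at construction is an assumption.

#include <gtsam/discrete/DecisionTree-inl.h>
#include <iostream>
#include <string>
#include <vector>

int main() {
  using Tree = gtsam::DecisionTree<std::string, double>;
  // Two binary labels m0 and m1; leaf values are ordered over their
  // assignments, matching the example in the doc comment.
  const std::vector<Tree::LabelC> labels{{"m0", 2}, {"m1", 2}};
  const Tree tree(labels, std::vector<double>{0.0, 0.0, 1.0, 2.0});
  // If the two identical 0.0 leaves are merged (implicit pruning),
  // nrLeaves() can report 3 while nrAssignments() still reports 4.
  std::cout << "leaves: " << tree.nrLeaves()
            << ", assignments: " << tree.nrAssignments() << std::endl;
  return 0;
}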

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
8 changes: 8 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.cpp
@@ -101,6 +101,14 @@ namespace gtsam {
return DecisionTreeFactor(keys, result);
}

/* ************************************************************************ */
DecisionTreeFactor DecisionTreeFactor::apply(ADT::UnaryAssignment op) const {
// Apply the unary operator to this factor's underlying ADT
ADT result = ADT::apply(op);
// Make a new factor
return DecisionTreeFactor(discreteKeys(), result);
}

/* ************************************************************************ */
DecisionTreeFactor::shared_ptr DecisionTreeFactor::combine(
size_t nrFrontals, ADT::Binary op) const {
6 changes: 6 additions & 0 deletions gtsam/discrete/DecisionTreeFactor.h
@@ -147,6 +147,12 @@ namespace gtsam {
/// @name Advanced Interface
/// @{

/**
* Apply unary operator op to *this
* @param op a unary operator that operates on AlgebraicDecisionTree
*/
DecisionTreeFactor apply(ADT::UnaryAssignment op) const;

/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
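
A hedged usage sketch of the new assignment-aware apply overload (the keys and table values below are invented for illustration):

#include <gtsam/discrete/DecisionTreeFactor.h>

using namespace gtsam;

int main() {
  const DiscreteKey m0(0, 2), m1(1, 2);
  const DecisionTreeFactor f({m0, m1}, "0.1 0.2 0.3 0.4");
  // The UnaryAssignment op sees each leaf's assignment along with its value,
  // so we can mask out every entry where m1 == 0.
  auto op = [&m1](const Assignment<Key>& choices, const double& p) {
    return choices.at(m1.first) == 0 ? 0.0 : p;
  };
  const DecisionTreeFactor masked = f.apply(op);
  masked.print("masked");
  return 0;
}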
1 change: 1 addition & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
@@ -25,6 +25,7 @@
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Symbol.h>

#include <iomanip>

89 changes: 35 additions & 54 deletions gtsam/hybrid/HybridBayesNet.cpp
@@ -37,24 +37,6 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
return Base::equals(bn, tol);
}

/* ************************************************************************* */
DecisionTreeFactor::shared_ptr HybridBayesNet::discreteConditionals() const {
AlgebraicDecisionTree<Key> discreteProbs;

// The canonical decision tree factor which will get
// the discrete conditionals added to it.
DecisionTreeFactor discreteProbsFactor;

for (auto &&conditional : *this) {
if (conditional->isDiscrete()) {
// Convert to a DecisionTreeFactor and add it to the main factor.
DecisionTreeFactor f(*conditional->asDiscrete());
discreteProbsFactor = discreteProbsFactor * f;
}
}
return std::make_shared<DecisionTreeFactor>(discreteProbsFactor);
}

/* ************************************************************************* */
/**
* @brief Helper function to get the pruner functional.
@@ -144,53 +126,52 @@ std::function<double(const Assignment<Key> &, double)> prunerFunc(
}

/* ************************************************************************* */
void HybridBayesNet::updateDiscreteConditionals(
const DecisionTreeFactor &prunedDiscreteProbs) {
KeyVector prunedTreeKeys = prunedDiscreteProbs.keys();
DecisionTreeFactor HybridBayesNet::pruneDiscreteConditionals(
size_t maxNrLeaves) {
// Get the joint distribution of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
// The joint discrete probability.
DiscreteConditional discreteProbs;

std::vector<size_t> discrete_factor_idxs;
// Record frontal keys so we can maintain ordering
Ordering discrete_frontals;

// Loop with index since we need it later.
for (size_t i = 0; i < this->size(); i++) {
HybridConditional::shared_ptr conditional = this->at(i);
auto conditional = this->at(i);
if (conditional->isDiscrete()) {
auto discrete = conditional->asDiscrete();

// Convert pointer from conditional to factor
auto discreteTree =
std::dynamic_pointer_cast<DecisionTreeFactor::ADT>(discrete);
// Apply prunerFunc to the underlying AlgebraicDecisionTree
DecisionTreeFactor::ADT prunedDiscreteTree =
discreteTree->apply(prunerFunc(prunedDiscreteProbs, *conditional));

gttic_(HybridBayesNet_MakeConditional);
// Create the new (hybrid) conditional
KeyVector frontals(discrete->frontals().begin(),
discrete->frontals().end());
auto prunedDiscrete = std::make_shared<DiscreteLookupTable>(
frontals.size(), conditional->discreteKeys(), prunedDiscreteTree);
conditional = std::make_shared<HybridConditional>(prunedDiscrete);
gttoc_(HybridBayesNet_MakeConditional);

// Add it back to the BayesNet
this->at(i) = conditional;
discreteProbs = discreteProbs * (*conditional->asDiscrete());

Ordering conditional_keys(conditional->frontals());
discrete_frontals += conditional_keys;
discrete_factor_idxs.push_back(i);
}
}
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
// Get the decision tree of only the discrete keys
gttic_(HybridBayesNet_PruneDiscreteConditionals);
DecisionTreeFactor::shared_ptr discreteConditionals =
this->discreteConditionals();
const DecisionTreeFactor prunedDiscreteProbs =
discreteConditionals->prune(maxNrLeaves);
discreteProbs.prune(maxNrLeaves);
gttoc_(HybridBayesNet_PruneDiscreteConditionals);

// Eliminate joint probability back into conditionals
gttic_(HybridBayesNet_UpdateDiscreteConditionals);
this->updateDiscreteConditionals(prunedDiscreteProbs);
DiscreteFactorGraph dfg{prunedDiscreteProbs};
DiscreteBayesNet::shared_ptr dbn = dfg.eliminateSequential(discrete_frontals);

// Assign pruned discrete conditionals back at the correct indices.
for (size_t i = 0; i < discrete_factor_idxs.size(); i++) {
size_t idx = discrete_factor_idxs.at(i);
this->at(idx) = std::make_shared<HybridConditional>(dbn->at(i));
}
gttoc_(HybridBayesNet_UpdateDiscreteConditionals);

/* To Prune, we visitWith every leaf in the GaussianMixture.
return prunedDiscreteProbs;
}

/* ************************************************************************* */
HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
DecisionTreeFactor prunedDiscreteProbs =
this->pruneDiscreteConditionals(maxNrLeaves);

/* To prune, we visitWith every leaf in the GaussianMixture.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
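
In short, this refactor folds the old discreteConditionals() + updateDiscreteConditionals() pair into one pruneDiscreteConditionals() that multiplies the discrete conditionals into a joint, prunes the joint, and re-eliminates it back into conditionals at their original indices. A caller-side sketch (the wrapper function and the leaf budget are invented for illustration):

#include <gtsam/hybrid/HybridBayesNet.h>

using namespace gtsam;

// Note: prune() rewrites the discrete conditionals in place, so this sketch
// takes a copy of the posterior rather than a const reference.
HybridBayesNet prunedCopy(HybridBayesNet posterior) {
  // Keep at most 8 leaves in the joint discrete distribution; GaussianMixture
  // components whose discrete assignments were pruned away are dropped too.
  const size_t maxNrLeaves = 8;
  return posterior.prune(maxNrLeaves);
}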
13 changes: 3 additions & 10 deletions gtsam/hybrid/HybridBayesNet.h
@@ -136,13 +136,6 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
*/
VectorValues optimize(const DiscreteValues &assignment) const;

/**
* @brief Get all the discrete conditionals as a decision tree factor.
*
* @return DecisionTreeFactor::shared_ptr
*/
DecisionTreeFactor::shared_ptr discreteConditionals() const;

/**
* @brief Sample from an incomplete BayesNet, given missing variables.
*
@@ -222,11 +215,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {

private:
/**
* @brief Update the discrete conditionals with the pruned versions.
* @brief Prune all the discrete conditionals.
*
* @param prunedDiscreteProbs
* @param maxNrLeaves
*/
void updateDiscreteConditionals(const DecisionTreeFactor &prunedDiscreteProbs);
DecisionTreeFactor pruneDiscreteConditionals(size_t maxNrLeaves);

#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */
1 change: 1 addition & 0 deletions gtsam/hybrid/HybridFactor.h
@@ -20,6 +20,7 @@
#include <gtsam/base/Testable.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/nonlinear/Values.h>
5 changes: 3 additions & 2 deletions gtsam/hybrid/HybridFactorGraph.cpp
@@ -17,7 +17,6 @@
* @date January, 2023
*/

#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>

namespace gtsam {
@@ -26,7 +25,7 @@ namespace gtsam {
std::set<DiscreteKey> HybridFactorGraph::discreteKeys() const {
std::set<DiscreteKey> keys;
for (auto& factor : factors_) {
if (auto p = std::dynamic_pointer_cast<DecisionTreeFactor>(factor)) {
if (auto p = std::dynamic_pointer_cast<DiscreteFactor>(factor)) {
for (const DiscreteKey& key : p->discreteKeys()) {
keys.insert(key);
}
@@ -67,6 +66,8 @@ const KeySet HybridFactorGraph::continuousKeySet() const {
for (const Key& key : p->continuousKeys()) {
keys.insert(key);
}
} else if (auto p = std::dynamic_pointer_cast<GaussianFactor>(factor)) {
keys.insert(p->keys().begin(), p->keys().end());
}
}
return keys;
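
Two behavioral changes in this file: discreteKeys() now matches any DiscreteFactor (for example the newly included TableFactor), not just DecisionTreeFactor, and continuousKeySet() also collects keys from plain GaussianFactors. A hedged sketch of both, assuming TableFactor's string constructor and with invented example keys:

#include <gtsam/discrete/TableFactor.h>
#include <gtsam/hybrid/HybridFactorGraph.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/JacobianFactor.h>

using namespace gtsam;

int main() {
  HybridFactorGraph graph;
  const DiscreteKey m(Symbol('m', 0), 2);
  // A TableFactor is a DiscreteFactor but not a DecisionTreeFactor, so the
  // old dynamic_pointer_cast<DecisionTreeFactor> would have skipped it.
  graph.push_back(std::make_shared<TableFactor>(DiscreteKeys{m}, "0.3 0.7"));
  // A plain GaussianFactor's keys now show up in continuousKeySet().
  graph.push_back(
      std::make_shared<JacobianFactor>(Symbol('x', 0), I_1x1, Vector1(0.0)));
  graph.discreteKeys();      // contains m
  graph.continuousKeySet();  // contains x0
  return 0;
}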