Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GTSAM_DT_MERGING Flag #1501

Merged
merged 53 commits into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
e114e9f
add nrAssignments method for DecisionTree
varunagrawal Mar 26, 2023
6aa7d66
add unit test showing issue with nrAssignments
varunagrawal Mar 26, 2023
1818695
updated docs to better describe nrAssignments
varunagrawal Mar 29, 2023
0cd36db
Merge branch 'develop' into fix-1496
varunagrawal Jun 7, 2023
73b563a
WIP for debugging nrAssignments issue
varunagrawal Jun 8, 2023
8a8f146
update Unique to be recursive
varunagrawal Jun 8, 2023
ff1ea32
remove unnecessary code
varunagrawal Jun 8, 2023
dbd0a7d
re-enable DecisionTree tests
varunagrawal Jun 8, 2023
68cb724
add new build method to replace create, and let create call Unique
varunagrawal Jun 8, 2023
be70ffc
remove excessive Unique call to improve efficiency
varunagrawal Jun 8, 2023
c3090f0
cleanup
varunagrawal Jun 8, 2023
70ffbf3
mark nrAssignments as const
varunagrawal Jun 8, 2023
2352043
rename GTSAM_DT_NO_PRUNING to GTSAM_DT_NO_MERGING to help with disamb…
varunagrawal Jun 8, 2023
2998820
bottom-up Unique method that works much, much better
varunagrawal Jun 8, 2023
a66e270
print nrAssignments when printing decision trees
varunagrawal Jun 8, 2023
d74e41a
Merge branch 'develop' into decisiontree-improvements
varunagrawal Jun 9, 2023
39cf348
Merge branch 'develop' into fix-1496
varunagrawal Jun 9, 2023
0cb1316
Merge branch 'fix-1496' into decisiontree-improvements
varunagrawal Jun 9, 2023
76568f2
formatting
varunagrawal Jun 9, 2023
29c1816
change to GTSAM_DT_MERGING and expose via CMake
varunagrawal Jun 10, 2023
8959982
remove extra calls to Unique
varunagrawal Jun 14, 2023
88ab371
Merge branch 'develop' into decisiontree-improvements
varunagrawal Jun 22, 2023
7af8e66
Merge branch 'develop' into decisiontree-improvements
varunagrawal Jun 22, 2023
c605a5b
Merge branch 'develop' into fix-1496
varunagrawal Jun 26, 2023
b37fc3f
update DecisionTree::nrAssignments docstring
varunagrawal Jun 26, 2023
3d7163a
Merge branch 'fix-1496' into decisiontree-improvements
varunagrawal Jun 26, 2023
b24f20a
fix tests to work when GTSAM_DT_MERGING=OFF
varunagrawal Jun 26, 2023
8ffddc4
print GTSAM_DT_MERGING cmake config
varunagrawal Jun 26, 2023
e5fea0d
update docstring
varunagrawal Jun 26, 2023
9b7f4b3
fix test case
varunagrawal Jun 28, 2023
8c38e45
enumerate all assignments for computing probabilities to prune
varunagrawal Jun 28, 2023
b86696a
Merge pull request #1542 from borglab/decisiontree-improvements
varunagrawal Jun 28, 2023
647d3c0
remove nrAssignments from the DecisionTree
varunagrawal Jun 28, 2023
2db0828
Revert "remove nrAssignments from the DecisionTree"
varunagrawal Jul 10, 2023
b7deefd
Revert "enumerate all assignments for computing probabilities to prune"
varunagrawal Jul 10, 2023
e5a7bac
Merge pull request #1555 from borglab/remove-nrAssignments
varunagrawal Jul 10, 2023
3fe9f1a
Merge branch 'develop' into fix-1496
varunagrawal Jul 18, 2023
ff7c368
Merge branch 'hybrid-tablefactor-2' into fix-1496
varunagrawal Jul 19, 2023
cf6c1ca
fix tests
varunagrawal Jul 19, 2023
372e703
Merge branch 'develop' into fix-1496
varunagrawal Jul 19, 2023
1dfb388
fix odd behavior in nrAssignments
varunagrawal Jul 20, 2023
ea24a2c
park changes so I can come back to them later
varunagrawal Jul 20, 2023
369d08b
Merge branch 'develop' into fix-1496
varunagrawal Jul 28, 2023
b35fb0f
update tests
varunagrawal Jul 28, 2023
4e9d849
remove prints
varunagrawal Jul 28, 2023
4580c51
undo change
varunagrawal Jul 28, 2023
8cb33dd
remove make_unique flag
varunagrawal Jul 29, 2023
94d737e
remove printing
varunagrawal Jul 29, 2023
4386c51
remove nrAssignments from DecisionTree
varunagrawal Nov 6, 2023
ecd6450
Merge branch 'develop' into fix-1496
varunagrawal Nov 6, 2023
c4d11c4
fix unittest assertion deprecation
varunagrawal Nov 6, 2023
9b67c3a
Merge branch 'develop' into remove-nrAssignments
varunagrawal Nov 6, 2023
fe81362
Merge branch 'fix-1496' into remove-nrAssignments
varunagrawal Nov 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,9 @@ namespace gtsam {
throw std::invalid_argument("DecisionTree::create invalid argument");
}
auto choice = std::make_shared<Choice>(begin->first, endY - beginY);
for (ValueIt y = beginY; y != endY; y++)
for (ValueIt y = beginY; y != endY; y++) {
choice->push_back(NodePtr(new Leaf(*y)));
}
return Choice::Unique(choice);
}

Expand Down Expand Up @@ -827,6 +828,16 @@ namespace gtsam {
return total;
}

/****************************************************************************/
template <typename L, typename Y>
size_t DecisionTree<L, Y>::nrAssignments() const {
size_t n = 0;
this->visitLeaf([&n](const DecisionTree<L, Y>::Leaf& leaf) {
n += leaf.nrAssignments();
});
return n;
}

/****************************************************************************/
// fold is just done with a visit
template <typename L, typename Y>
Expand Down
9 changes: 9 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,15 @@ namespace gtsam {
/// Return the number of leaves in the tree.
size_t nrLeaves() const;

/**
* @brief Return the number of total leaf assignments.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please give examples as to what this is, I think "number of assignments" needs to be explained.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO(Frank): Add comments to mention that this is a convenience function and not used for any major operation.

* This includes counts removed from implicit pruning hence,
* it will always be >= nrLeaves().
*
* @return size_t
*/
size_t nrAssignments() const;

/**
* @brief Fold a binary function over the tree, returning accumulator.
*
Expand Down
23 changes: 23 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Symbol.h>

#include <iomanip>

Expand Down Expand Up @@ -329,6 +330,9 @@ TEST(DecisionTree, Containers) {
TEST(DecisionTree, NrAssignments) {
const std::pair<string, size_t> A("A", 2), B("B", 2), C("C", 2);
DT tree({A, B, C}, "1 1 1 1 1 1 1 1");

EXPECT_LONGS_EQUAL(8, tree.nrAssignments());

EXPECT(tree.root_->isLeaf());
auto leaf = std::dynamic_pointer_cast<const DT::Leaf>(tree.root_);
EXPECT_LONGS_EQUAL(8, leaf->nrAssignments());
Expand All @@ -348,6 +352,8 @@ TEST(DecisionTree, NrAssignments) {
1 1 Leaf 5
*/

EXPECT_LONGS_EQUAL(8, tree2.nrAssignments());

auto root = std::dynamic_pointer_cast<const DT::Choice>(tree2.root_);
CHECK(root);
auto choice0 = std::dynamic_pointer_cast<const DT::Choice>(root->branches()[0]);
Expand Down Expand Up @@ -531,6 +537,23 @@ TEST(DecisionTree, ApplyWithAssignment) {
EXPECT_LONGS_EQUAL(5, count);
}

/* ************************************************************************** */
// Test number of assignments.
TEST(DecisionTree, NrAssignments2) {
using gtsam::symbol_shorthand::M;

DiscreteKeys keys{{M(1), 2}, {M(0), 2}};
std::vector<double> probs = {0, 0, 1, 2};
DecisionTree<Key, double> dt1(keys, probs);

EXPECT_LONGS_EQUAL(4, dt1.nrAssignments());

DiscreteKeys keys2{{M(0), 2}, {M(1), 2}};
DecisionTree<Key, double> dt2(keys2, probs);
//TODO(Varun) The below is failing, because the number of assignments aren't being set correctly.
EXPECT_LONGS_EQUAL(4, dt2.nrAssignments());
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down