From 091c9c95575a88dfede3b9d963b60832aebc9966 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 8 Jan 2022 18:29:11 -0500 Subject: [PATCH 01/14] fail fast in linux matrix --- .github/workflows/build-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml index f52e5eec36..7b13b66460 100644 --- a/.github/workflows/build-linux.yml +++ b/.github/workflows/build-linux.yml @@ -15,7 +15,7 @@ jobs: BOOST_VERSION: 1.67.0 strategy: - fail-fast: false + fail-fast: true matrix: # Github Actions requires a single row to be added to the build matrix. # See https://help.github.com/en/articles/workflow-syntax-for-github-actions. From fa5ead62465cb84709aea44b36a39da9b39c7dd4 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 9 Jan 2022 15:59:40 -0500 Subject: [PATCH 02/14] Fix failing test --- gtsam_unstable/discrete/tests/testLoopyBelief.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp index 6561949b14..eac0d834e6 100644 --- a/gtsam_unstable/discrete/tests/testLoopyBelief.cpp +++ b/gtsam_unstable/discrete/tests/testLoopyBelief.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include From 3851a985179b41000714b78523f45232018b2156 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 9 Jan 2022 17:00:41 -0500 Subject: [PATCH 03/14] Fix single quotes --- gtsam/discrete/DecisionTreeFactor.cpp | 2 +- gtsam/discrete/DiscreteConditional.cpp | 2 +- gtsam/discrete/DiscreteValues.cpp | 4 ++-- gtsam/discrete/tests/testDiscreteValues.cpp | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index c50811a506..ad4cbad434 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -226,7 +226,7 @@ namespace gtsam { stringstream ss; // Print out preamble. - ss << "
\n\n \n"; + ss << "
\n
\n \n"; // Print out header row. ss << " "; diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 82018abea6..0bdc7d7b5a 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -396,7 +396,7 @@ string DiscreteConditional::html(const KeyFormatter& keyFormatter, } // Print out preamble. - ss << "
\n \n"; + ss << "
\n \n"; // Print out header row. ss << " "; diff --git a/gtsam/discrete/DiscreteValues.cpp b/gtsam/discrete/DiscreteValues.cpp index 8a09de7425..5d0c8dd3d5 100644 --- a/gtsam/discrete/DiscreteValues.cpp +++ b/gtsam/discrete/DiscreteValues.cpp @@ -65,7 +65,7 @@ string DiscreteValues::html(const KeyFormatter& keyFormatter, stringstream ss; // Print out preamble. - ss << "
\n
\n \n"; + ss << "
\n
\n \n"; // Print out header row. ss << " \n"; @@ -76,7 +76,7 @@ string DiscreteValues::html(const KeyFormatter& keyFormatter, // Print out all rows. for (const auto& kv : *this) { ss << " "; - ss << ""; ss << "\n"; } diff --git a/gtsam/discrete/tests/testDiscreteValues.cpp b/gtsam/discrete/tests/testDiscreteValues.cpp index 5e7c0ac6f1..c8a1fa1680 100644 --- a/gtsam/discrete/tests/testDiscreteValues.cpp +++ b/gtsam/discrete/tests/testDiscreteValues.cpp @@ -57,8 +57,8 @@ TEST(DiscreteValues, htmlWithValueFormatter) { " \n" " \n" " \n" - " \n" - " \n" + " \n" + " \n" " \n" "
Variablevalue
" << keyFormatter(kv.first) << "\'" + ss << "" << keyFormatter(kv.first) << "" << Translate(names, kv.first, kv.second) << "
Variablevalue
B'-
A'One
B-
AOne
\n" "
"; From b79c59acd5c630d258dab2b75dd46209aa9088c8 Mon Sep 17 00:00:00 2001 From: Jose Luis Blanco-Claraco Date: Wed, 12 Jan 2022 11:00:24 +0100 Subject: [PATCH 04/14] FG print(): fix empty lines on nullptr; avoid endl --- gtsam/nonlinear/NonlinearFactorGraph.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 0d1ed31487..0e0d70268c 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -54,9 +54,14 @@ void NonlinearFactorGraph::print(const std::string& str, const KeyFormatter& key for (size_t i = 0; i < factors_.size(); i++) { stringstream ss; ss << "Factor " << i << ": "; - if (factors_[i] != nullptr) factors_[i]->print(ss.str(), keyFormatter); - cout << endl; + if (factors_[i] != nullptr) { + factors_[i]->print(ss.str(), keyFormatter); + cout << "\n"; + } else { + cout << ss.str() << "nullptr\n"; + } } + std::cout.flush(); } /* ************************************************************************* */ @@ -80,8 +85,9 @@ void NonlinearFactorGraph::printErrors(const Values& values, const std::string& factor->print(ss.str(), keyFormatter); cout << "error = " << errorValue << "\n"; } - cout << endl; // only one "endl" at end might be faster, \n for each factor + cout << "\n"; } + std::cout.flush(); } /* ************************************************************************* */ From 3c804d89b5a5da2113d0241a3f25728bbef4c34c Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Wed, 12 Jan 2022 16:50:10 -0500 Subject: [PATCH 05/14] add better tests for probPrime and add a fix --- gtsam/linear/GaussianFactorGraph.h | 3 ++- .../linear/tests/testGaussianFactorGraph.cpp | 23 +++++++++++++++++++ gtsam/nonlinear/NonlinearFactorGraph.cpp | 3 ++- gtsam/nonlinear/NonlinearFactorGraph.h | 2 +- tests/testNonlinearFactorGraph.cpp | 18 +++++++++++++++ 5 files changed, 46 insertions(+), 3 deletions(-) diff --git a/gtsam/linear/GaussianFactorGraph.h b/gtsam/linear/GaussianFactorGraph.h index 7bee4c9fb5..f392221222 100644 --- a/gtsam/linear/GaussianFactorGraph.h +++ b/gtsam/linear/GaussianFactorGraph.h @@ -154,7 +154,8 @@ namespace gtsam { /** Unnormalized probability. O(n) */ double probPrime(const VectorValues& c) const { - return exp(-0.5 * error(c)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(c)); } /** diff --git a/gtsam/linear/tests/testGaussianFactorGraph.cpp b/gtsam/linear/tests/testGaussianFactorGraph.cpp index bb07a36aae..41464a1109 100644 --- a/gtsam/linear/tests/testGaussianFactorGraph.cpp +++ b/gtsam/linear/tests/testGaussianFactorGraph.cpp @@ -426,6 +426,7 @@ TEST(GaussianFactorGraph, hessianDiagonal) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ TEST(GaussianFactorGraph, DenseSolve) { GaussianFactorGraph fg = createSimpleGaussianFactorGraph(); VectorValues expected = fg.optimize(); @@ -433,6 +434,28 @@ TEST(GaussianFactorGraph, DenseSolve) { EXPECT(assert_equal(expected, actual)); } +/* ************************************************************************* */ +TEST(GaussianFactorGraph, ProbPrime) { + GaussianFactorGraph gfg; + gfg.emplace_shared(1, I_1x1, Z_1x1, + noiseModel::Isotropic::Sigma(1, 1.0)); + + VectorValues values; + values.insert(1, I_1x1); + + // We are testing the normal distribution PDF where info matrix Σ = 1, + // mean mu = 0 and x = 1. + // Therefore factor squared error: y = 0.5 * (Σ*x - mu)^2 = + // 0.5 * (1.0 - 0)^2 = 0.5 + // NOTE the 0.5 constant is a part of the factor error. + EXPECT_DOUBLES_EQUAL(0.5, gfg.error(values), 1e-12); + + // The gaussian PDF value is: exp^(-0.5 * (Σ*x - mu)^2) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-0.5 * (1.0)^2) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, gfg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ int main() { TestResult tr; diff --git a/gtsam/nonlinear/NonlinearFactorGraph.cpp b/gtsam/nonlinear/NonlinearFactorGraph.cpp index 0e0d70268c..89236ea878 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.cpp +++ b/gtsam/nonlinear/NonlinearFactorGraph.cpp @@ -45,7 +45,8 @@ template class FactorGraph; /* ************************************************************************* */ double NonlinearFactorGraph::probPrime(const Values& values) const { - return exp(-0.5 * error(values)); + // NOTE the 0.5 constant is handled by the factor error. + return exp(-error(values)); } /* ************************************************************************* */ diff --git a/gtsam/nonlinear/NonlinearFactorGraph.h b/gtsam/nonlinear/NonlinearFactorGraph.h index 160e469241..ea8748f63b 100644 --- a/gtsam/nonlinear/NonlinearFactorGraph.h +++ b/gtsam/nonlinear/NonlinearFactorGraph.h @@ -90,7 +90,7 @@ namespace gtsam { /** Test equality */ bool equals(const NonlinearFactorGraph& other, double tol = 1e-9) const; - /** unnormalized error, \f$ 0.5 \sum_i (h_i(X_i)-z)^2/\sigma^2 \f$ in the most common case */ + /** unnormalized error, \f$ \sum_i 0.5 (h_i(X_i)-z)^2 / \sigma^2 \f$ in the most common case */ double error(const Values& values) const; /** Unnormalized probability. O(n) */ diff --git a/tests/testNonlinearFactorGraph.cpp b/tests/testNonlinearFactorGraph.cpp index 4dec08f45c..8a360e4542 100644 --- a/tests/testNonlinearFactorGraph.cpp +++ b/tests/testNonlinearFactorGraph.cpp @@ -107,6 +107,24 @@ TEST( NonlinearFactorGraph, probPrime ) DOUBLES_EQUAL(expected,actual,0); } +/* ************************************************************************* */ +TEST(NonlinearFactorGraph, ProbPrime2) { + NonlinearFactorGraph fg; + fg.emplace_shared>(1, 0.0, + noiseModel::Isotropic::Sigma(1, 1.0)); + + Values values; + values.insert(1, 1.0); + + // The prior factor squared error is: 0.5. + EXPECT_DOUBLES_EQUAL(0.5, fg.error(values), 1e-12); + + // The probability value is: exp^(-factor_error) / sqrt(2 * PI) + // Ignore the denominator and we get: exp^(-factor_error) = exp^(-0.5) + double expected = exp(-0.5); + EXPECT_DOUBLES_EQUAL(expected, fg.probPrime(values), 1e-12); +} + /* ************************************************************************* */ TEST( NonlinearFactorGraph, linearize ) { From be5aa56df72f654f338168d6e79c69e915186ebc Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 08:15:46 -0500 Subject: [PATCH 06/14] Constructor from PMF --- gtsam/discrete/DiscretePrior.h | 14 +++++++------- gtsam/discrete/discrete.i | 1 + gtsam/discrete/tests/testDiscretePrior.cpp | 11 +++++++++-- python/gtsam/tests/test_DiscretePrior.py | 6 +++++- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscretePrior.h index 9ac8acb17a..1da1882155 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscretePrior.h @@ -48,17 +48,17 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { DiscretePrior(const Signature& s) : Base(s) {} /** - * Construct from key and a Signature::Table specifying the - * conditional probability table (CPT). + * Construct from key and a vector of floats specifying the probability mass + * function (PMF). * - * Example: DiscretePrior P(D, table); + * Example: DiscretePrior P(D, {0.4, 0.6}); */ - DiscretePrior(const DiscreteKey& key, const Signature::Table& table) - : Base(Signature(key, {}, table)) {} + DiscretePrior(const DiscreteKey& key, const std::vector& spec) + : DiscretePrior(Signature(key, {}, Signature::Table{spec})) {} /** - * Construct from key and a string specifying the conditional - * probability table (CPT). + * Construct from key and a string specifying the probability mass function + * (PMF). * * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); */ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 218b790e88..12bd5be549 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -120,6 +120,7 @@ virtual class DiscretePrior : gtsam::DiscreteConditional { DiscretePrior(); DiscretePrior(const gtsam::DecisionTreeFactor& f); DiscretePrior(const gtsam::DiscreteKey& key, string spec); + DiscretePrior(const gtsam::DiscreteKey& key, std::vector spec); void print(string s = "Discrete Prior\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index 23f093b229..6225d227e0 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -27,12 +27,19 @@ static const DiscreteKey X(0, 2); /* ************************************************************************* */ TEST(DiscretePrior, constructors) { + DecisionTreeFactor f(X, "0.4 0.6"); + DiscretePrior expected(f); + DiscretePrior actual(X % "2/3"); EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(0, actual.nrParents()); - DecisionTreeFactor f(X, "0.4 0.6"); - DiscretePrior expected(f); EXPECT(assert_equal(expected, actual, 1e-9)); + + const vector pmf{0.4, 0.6}; + DiscretePrior actual2(X, pmf); + EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actual2.nrParents()); + EXPECT(assert_equal(expected, actual2, 1e-9)); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscretePrior.py index 2c923589ce..06bdc81ca7 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscretePrior.py @@ -25,12 +25,16 @@ class TestDiscretePrior(GtsamTestCase): def test_constructor(self): """Test various constructors.""" - actual = DiscretePrior(X, "2/3") keys = DiscreteKeys() keys.push_back(X) f = DecisionTreeFactor(keys, "0.4 0.6") expected = DiscretePrior(f) + + actual = DiscretePrior(X, "2/3") self.gtsamAssertEquals(actual, expected) + + actual2 = DiscretePrior(X, [0.4, 0.6]) + self.gtsamAssertEquals(actual2, expected) def test_operator(self): prior = DiscretePrior(X, "2/3") From c15bbed9dc044ffa159ec5a243dce6985e5203cd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 08:44:10 -0500 Subject: [PATCH 07/14] exposing more factor methods --- gtsam/discrete/discrete.i | 9 ++++ .../discrete/tests/testDecisionTreeFactor.cpp | 26 ++++++---- python/gtsam/tests/test_DecisionTreeFactor.py | 52 +++++++++++++++++-- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 12bd5be549..24a9410561 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -58,6 +58,15 @@ virtual class DecisionTreeFactor : gtsam::DiscreteFactor { const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DecisionTreeFactor& other, double tol = 1e-9) const; + + double operator()(const gtsam::DiscreteValues& values) const; + gtsam::DecisionTreeFactor operator*(const gtsam::DecisionTreeFactor& f) const; + size_t cardinality(gtsam::Key j) const; + gtsam::DecisionTreeFactor operator/(const gtsam::DecisionTreeFactor& f) const; + gtsam::DecisionTreeFactor* sum(size_t nrFrontals) const; + gtsam::DecisionTreeFactor* sum(const gtsam::Ordering& keys) const; + gtsam::DecisionTreeFactor* max(size_t nrFrontals) const; + string dot( const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter, bool showZero = true) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 594134edf7..f2ab5f6bc2 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -17,10 +17,12 @@ * @author Duy-Nguyen Ta */ -#include -#include -#include #include +#include +#include +#include +#include + #include using namespace boost::assign; @@ -51,17 +53,21 @@ TEST( DecisionTreeFactor, constructors) } /* ************************************************************************* */ -TEST_UNSAFE( DecisionTreeFactor, multiplication) -{ - DiscreteKey v0(0,2), v1(1,2), v2(2,2); +TEST(DecisionTreeFactor, multiplication) { + DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); + // Multiply with a DiscretePrior, i.e., Bayes Law! + DiscretePrior prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); - DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); - - DecisionTreeFactor expected(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); + CHECK(assert_equal(expected, prior * f1)); + CHECK(assert_equal(expected, f1 * prior)); + // Multiply two factors + DecisionTreeFactor f2(v1 & v2, "5 6 7 8"); DecisionTreeFactor actual = f1 * f2; - CHECK(assert_equal(expected, actual)); + DecisionTreeFactor expected2(v0 & v1 & v2, "5 6 14 16 15 18 28 32"); + CHECK(assert_equal(expected2, actual)); } /* ************************************************************************* */ diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 12a60d5cb1..03d9f82d7e 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,7 +13,7 @@ import unittest -from gtsam import DecisionTreeFactor, DecisionTreeFactor, DiscreteKeys +from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering from gtsam.utils.test_case import GtsamTestCase @@ -21,15 +21,59 @@ class TestDecisionTreeFactor(GtsamTestCase): """Tests for DecisionTreeFactors.""" def setUp(self): - A = (12, 3) - B = (5, 2) - self.factor = DecisionTreeFactor([A, B], "1 2 3 4 5 6") + self.A = (12, 3) + self.B = (5, 2) + self.factor = DecisionTreeFactor([self.A, self.B], "1 2 3 4 5 6") def test_enumerate(self): actual = self.factor.enumerate() _, values = zip(*actual) self.assertEqual(list(values), [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + def test_multiplication(self): + """Test whether multiplication works with overloading.""" + v0 = (0, 2) + v1 = (1, 2) + v2 = (2, 2) + + # Multiply with a DiscretePrior, i.e., Bayes Law! + prior = DiscretePrior(v1, [1, 3]) + f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") + expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") + self.gtsamAssertEquals(prior * f1, expected) + self.gtsamAssertEquals(f1 * prior, expected) + + # Multiply two factors + f2 = DecisionTreeFactor([v1, v2], "5 6 7 8") + actual = f1 * f2 + expected2 = DecisionTreeFactor([v0, v1, v2], "5 6 14 16 15 18 28 32") + self.gtsamAssertEquals(actual, expected2) + + def test_methods(self): + """Test whether we can call methods in python.""" + # double operator()(const DiscreteValues& values) const; + values = DiscreteValues() + values[self.A[0]] = 0 + values[self.B[0]] = 0 + self.assertIsInstance(self.factor(values), float) + + # size_t cardinality(Key j) const; + self.assertIsInstance(self.factor.cardinality(self.A[0]), int) + + # DecisionTreeFactor operator/(const DecisionTreeFactor& f) const; + self.assertIsInstance(self.factor / self.factor, DecisionTreeFactor) + + # DecisionTreeFactor* sum(size_t nrFrontals) const; + self.assertIsInstance(self.factor.sum(1), DecisionTreeFactor) + + # DecisionTreeFactor* sum(const Ordering& keys) const; + ordering = Ordering() + ordering.push_back(self.A[0]) + self.assertIsInstance(self.factor.sum(ordering), DecisionTreeFactor) + + # DecisionTreeFactor* max(size_t nrFrontals) const; + self.assertIsInstance(self.factor.max(1), DecisionTreeFactor) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" From 0909e9838915ab6b6332d27462d9dd58309b438a Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 15:11:25 -0500 Subject: [PATCH 08/14] Comments only --- gtsam/discrete/DecisionTreeFactor.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index b5f6c0c4af..8beeb4c4a0 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -57,7 +57,7 @@ namespace gtsam { /** Default constructor for I/O */ DecisionTreeFactor(); - /** Constructor from Indices, Ordering, and AlgebraicDecisionDiagram */ + /** Constructor from DiscreteKeys and AlgebraicDecisionTree */ DecisionTreeFactor(const DiscreteKeys& keys, const ADT& potentials); /** Constructor from doubles */ @@ -139,14 +139,14 @@ namespace gtsam { /** * Apply binary operator (*this) "op" f * @param f the second argument for op - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree */ DecisionTreeFactor apply(const DecisionTreeFactor& f, ADT::Binary op) const; /** * Combine frontal variables using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(size_t nrFrontals, ADT::Binary op) const; @@ -154,7 +154,7 @@ namespace gtsam { /** * Combine frontal variables in an Ordering using binary operator "op" * @param nrFrontals nr. of frontal to combine variables in this factor - * @param op a binary operator that operates on AlgebraicDecisionDiagram potentials + * @param op a binary operator that operates on AlgebraicDecisionTree * @return shared pointer to newly created DecisionTreeFactor */ shared_ptr combine(const Ordering& keys, ADT::Binary op) const; From f9dd225ca5d4498bcd9b3f1aa75441c0a351e3f1 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 15:12:55 -0500 Subject: [PATCH 09/14] Replace buggy/awkward Combine with principled operator*, remove toFactor --- gtsam/discrete/DiscreteConditional.cpp | 77 ++++++++--- gtsam/discrete/DiscreteConditional.h | 74 +++++------ gtsam/discrete/discrete.i | 1 - .../discrete/tests/testDecisionTreeFactor.cpp | 2 +- .../tests/testDiscreteConditional.cpp | 122 ++++++++++++++---- gtsam/discrete/tests/testDiscretePrior.cpp | 13 ++ 6 files changed, 202 insertions(+), 87 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 0bdc7d7b5a..5acd7c0f65 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -30,6 +30,7 @@ #include #include #include +#include using namespace std; using std::stringstream; @@ -38,37 +39,77 @@ using std::pair; namespace gtsam { // Instantiate base class -template class GTSAM_EXPORT Conditional ; +template class GTSAM_EXPORT + Conditional; -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const size_t nrFrontals, - const DecisionTreeFactor& f) : - BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) { -} + const DecisionTreeFactor& f) + : BaseFactor(f / (*f.sum(nrFrontals))), BaseConditional(nrFrontals) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional::DiscreteConditional(size_t nrFrontals, + const DiscreteKeys& keys, + const ADT& potentials) + : BaseFactor(keys, potentials), BaseConditional(nrFrontals) {} + +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal) : - BaseFactor( - ISDEBUG("DiscreteConditional::COUNT") ? joint : joint / marginal), BaseConditional( - joint.size()-marginal.size()) { - if (ISDEBUG("DiscreteConditional::DiscreteConditional")) - cout << (firstFrontalKey()) << endl; //TODO Print all keys -} + const DecisionTreeFactor& marginal) + : BaseFactor(joint / marginal), + BaseConditional(joint.size() - marginal.size()) {} -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const DecisionTreeFactor& joint, - const DecisionTreeFactor& marginal, const Ordering& orderedKeys) : - DiscreteConditional(joint, marginal) { + const DecisionTreeFactor& marginal, + const Ordering& orderedKeys) + : DiscreteConditional(joint, marginal) { keys_.clear(); keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); } -/* ******************************************************************************** */ +/* ************************************************************************** */ DiscreteConditional::DiscreteConditional(const Signature& signature) : BaseFactor(signature.discreteKeys(), signature.cpt()), BaseConditional(1) {} +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::operator*( + const DiscreteConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteConditional::operator* called with overlapping frontal keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + ADT product = ADT::apply(other, ADT::Ring::mul); + return DiscreteConditional(newFrontals.size(), discreteKeys, product); +} + /* ******************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { @@ -82,7 +123,7 @@ void DiscreteConditional::print(const string& s, cout << formatter(*it) << " "; } } - cout << ")"; + cout << "):\n"; ADT::print(""); cout << endl; } diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 4a83ff83a0..450af57ab5 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -49,14 +49,21 @@ class GTSAM_EXPORT DiscreteConditional /// @name Standard Constructors /// @{ - /** default constructor needed for serialization */ + /// Default constructor needed for serialization. DiscreteConditional() {} - /** constructor from factor */ + /// Construct from factor, taking the first `nFrontals` keys as frontals. DiscreteConditional(size_t nFrontals, const DecisionTreeFactor& f); + /** + * Construct from DiscreteKeys and AlgebraicDecisionTree, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteConditional(size_t nFrontals, const DiscreteKeys& keys, + const ADT& potentials); + /** Construct from signature */ - DiscreteConditional(const Signature& signature); + explicit DiscreteConditional(const Signature& signature); /** * Construct from key, parents, and a Signature::Table specifying the @@ -86,27 +93,38 @@ class GTSAM_EXPORT DiscreteConditional DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal); - /** construct P(X|Y)=P(X,Y)/P(Y) from P(X,Y) and P(Y) */ + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ DiscreteConditional(const DecisionTreeFactor& joint, const DecisionTreeFactor& marginal, const Ordering& orderedKeys); /** - * Combine several conditional into a single one. - * The conditionals must be given in increasing order, meaning that the - * parents of any conditional may not include a conditional coming before it. - * @param firstConditional Iterator to the first conditional to combine, must - * dereference to a shared_ptr. - * @param lastConditional Iterator to after the last conditional to combine, - * must dereference to a shared_ptr. - * */ - template - static shared_ptr Combine(ITERATOR firstConditional, - ITERATOR lastConditional); + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteConditional operator*(const DiscreteConditional& other) const; /// @} /// @name Testable @@ -136,11 +154,6 @@ class GTSAM_EXPORT DiscreteConditional return ADT::operator()(values); } - /** Convert to a factor */ - DecisionTreeFactor::shared_ptr toFactor() const { - return DecisionTreeFactor::shared_ptr(new DecisionTreeFactor(*this)); - } - /** Restrict to given parent values, returns DecisionTreeFactor */ DecisionTreeFactor::shared_ptr choose( const DiscreteValues& parentsValues) const; @@ -208,23 +221,4 @@ class GTSAM_EXPORT DiscreteConditional template <> struct traits : public Testable {}; -/* ************************************************************************* */ -template -DiscreteConditional::shared_ptr DiscreteConditional::Combine( - ITERATOR firstConditional, ITERATOR lastConditional) { - // TODO: check for being a clique - - // multiply all the potentials of the given conditionals - size_t nrFrontals = 0; - DecisionTreeFactor product; - for (ITERATOR it = firstConditional; it != lastConditional; - ++it, ++nrFrontals) { - DiscreteConditional::shared_ptr c = *it; - DecisionTreeFactor::shared_ptr factor = c->toFactor(); - product = (*factor) * product; - } - // and then create a new multi-frontal conditional - return boost::make_shared(nrFrontals, product); -} - } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 24a9410561..5fce25cf5e 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -102,7 +102,6 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; - gtsam::DecisionTreeFactor* toFactor() const; gtsam::DecisionTreeFactor* choose( const gtsam::DiscreteValues& parentsValues) const; gtsam::DecisionTreeFactor* likelihood( diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index f2ab5f6bc2..7e89874a59 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -60,7 +60,7 @@ TEST(DecisionTreeFactor, multiplication) { DiscretePrior prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); - CHECK(assert_equal(expected, prior * f1)); + CHECK(assert_equal(expected, static_cast(prior) * f1)); CHECK(assert_equal(expected, f1 * prior)); // Multiply two factors diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 3fb67a615c..03766136c1 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -34,20 +34,21 @@ using namespace gtsam; TEST(DiscreteConditional, constructors) { DiscreteKey X(0, 2), Y(2, 3), Z(1, 2); // watch ordering ! - DiscreteConditional expected(X | Y = "1/1 2/3 1/4"); - EXPECT_LONGS_EQUAL(0, *(expected.beginFrontals())); - EXPECT_LONGS_EQUAL(2, *(expected.beginParents())); - EXPECT(expected.endParents() == expected.end()); - EXPECT(expected.endFrontals() == expected.beginParents()); + DiscreteConditional actual(X | Y = "1/1 2/3 1/4"); + EXPECT_LONGS_EQUAL(0, *(actual.beginFrontals())); + EXPECT_LONGS_EQUAL(2, *(actual.beginParents())); + EXPECT(actual.endParents() == actual.end()); + EXPECT(actual.endFrontals() == actual.beginParents()); DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); - DiscreteConditional actual1(1, f1); - EXPECT(assert_equal(expected, actual1, 1e-9)); + DiscreteConditional expected1(1, f1); + EXPECT(assert_equal(expected1, actual, 1e-9)); DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ @@ -61,6 +62,7 @@ TEST(DiscreteConditional, constructors_alt_interface) { r3 += 1.0, 4.0; table += r1, r2, r3; DiscreteConditional actual1(X, {Y}, table); + DecisionTreeFactor f1(X & Y, "0.5 0.4 0.2 0.5 0.6 0.8"); DiscreteConditional expected1(1, f1); EXPECT(assert_equal(expected1, actual1, 1e-9)); @@ -68,43 +70,109 @@ TEST(DiscreteConditional, constructors_alt_interface) { DecisionTreeFactor f2( X & Y & Z, "0.2 0.5 0.3 0.6 0.4 0.7 0.25 0.55 0.35 0.65 0.45 0.75"); DiscreteConditional actual2(1, f2); - EXPECT(assert_equal(f2 / *f2.sum(1), *actual2.toFactor(), 1e-9)); + DecisionTreeFactor expected2 = f2 / *f2.sum(1); + EXPECT(assert_equal(expected2, static_cast(actual2))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors2) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2); - DecisionTreeFactor actual(C & B, "0.8 0.75 0.2 0.25"); Signature signature((C | B) = "4/1 3/1"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B, "0.8 0.75 0.2 0.25"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ TEST(DiscreteConditional, constructors3) { - // Declare keys and ordering DiscreteKey C(0, 2), B(1, 2), A(2, 2); - DecisionTreeFactor actual(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); Signature signature((C | B, A) = "4/1 1/1 1/1 1/4"); - DiscreteConditional expected(signature); - DecisionTreeFactor::shared_ptr expectedFactor = expected.toFactor(); - EXPECT(assert_equal(*expectedFactor, actual)); + DiscreteConditional actual(signature); + + DecisionTreeFactor expected(C & B & A, "0.8 0.5 0.5 0.2 0.2 0.5 0.5 0.8"); + EXPECT(assert_equal(expected, static_cast(actual))); } /* ************************************************************************* */ -TEST(DiscreteConditional, Combine) { +// Check calculation of joint P(A,B) +TEST(DiscreteConditional, Multiply) { DiscreteKey A(0, 2), B(1, 2); - vector c; - c.push_back(boost::make_shared(A | B = "1/2 2/1")); - c.push_back(boost::make_shared(B % "1/2")); - DecisionTreeFactor factor(A & B, "0.111111 0.444444 0.222222 0.222222"); - DiscreteConditional expected(2, factor); - auto actual = DiscreteConditional::Combine(c.begin(), c.end()); - EXPECT(assert_equal(expected, *actual, 1e-5)); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for (auto&& actual : {prior * conditional, conditional * prior}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); + } + } } +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C) +TEST(DiscreteConditional, Multiply2) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B|C), double check keys +TEST(DiscreteConditional, Multiply3) { + DiscreteKey A(1, 2), B(2, 2), C(0, 2); // different keys!!! + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_C(B | C = "1/3 3/1"); + + // P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for (auto&& actual : {A_given_B * B_given_C, B_given_C * A_given_B}) { + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(1, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{1, 2})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), A_given_B(v) * B_given_C(v), 1e-9); + } + } +} +/* ************************************************************************* */ +// Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) +TEST(DiscreteConditional, Multiply4) { + DiscreteKey A(0, 2), B(1, 2), C(2, 2), D(4, 2), E(3, 2); + DiscreteConditional A_given_B(A | B = "1/3 3/1"); + DiscreteConditional B_given_D(B | D = "1/3 3/1"); + DiscreteConditional AB_given_D = A_given_B * B_given_D; + DiscreteConditional C_given_DE((C | D, E) = "4/1 1/1 1/1 1/4"); + + // P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for (auto&& actual : {AB_given_D * C_given_DE, C_given_DE * AB_given_D}) { + EXPECT_LONGS_EQUAL(3, actual.nrFrontals()); + EXPECT_LONGS_EQUAL(2, actual.nrParents()); + KeyVector frontals(actual.beginFrontals(), actual.endFrontals()); + EXPECT((frontals == KeyVector{0, 1, 2})); + KeyVector parents(actual.beginParents(), actual.endParents()); + EXPECT((parents == KeyVector{3, 4})); + for (auto&& it : actual.enumerate()) { + const DiscreteValues& v = it.first; + EXPECT_DOUBLES_EQUAL(actual(v), AB_given_D(v) * C_given_DE(v), 1e-9); + } + } +} /* ************************************************************************* */ TEST(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscretePrior.cpp index 6225d227e0..6ef57c7ff4 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscretePrior.cpp @@ -42,6 +42,19 @@ TEST(DiscretePrior, constructors) { EXPECT(assert_equal(expected, actual2, 1e-9)); } +/* ************************************************************************* */ +TEST(DiscretePrior, Multiply) { + DiscreteKey A(0, 2), B(1, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscretePrior prior(B, "1/2"); + DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) + + EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) + DecisionTreeFactor factor(A & B, "1 4 2 2"); + DiscreteConditional expected(2, factor); + EXPECT(assert_equal(expected, actual, 1e-5)); +} + /* ************************************************************************* */ TEST(DiscretePrior, operator) { DiscretePrior prior(X % "2/3"); From 23a8dba7163f57988a495c898828d938b1a678dd Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 15:33:01 -0500 Subject: [PATCH 10/14] Wrapped multiplication --- gtsam/discrete/discrete.i | 4 ++ python/gtsam/tests/test_DecisionTreeFactor.py | 2 +- .../gtsam/tests/test_DiscreteConditional.py | 48 ++++++++++++++++++- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 5fce25cf5e..8bcb8b4aa9 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -95,10 +95,14 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { DiscreteConditional(const gtsam::DecisionTreeFactor& joint, const gtsam::DecisionTreeFactor& marginal, const gtsam::Ordering& orderedKeys); + gtsam::DiscreteConditional operator*( + const gtsam::DiscreteConditional& other) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; bool equals(const gtsam::DiscreteConditional& other, double tol = 1e-9) const; + size_t nrFrontals() const; + size_t nrParents() const; void printSignature( string s = "Discrete Conditional: ", const gtsam::KeyFormatter& formatter = gtsam::DefaultKeyFormatter) const; diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index 03d9f82d7e..a13a43e263 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -40,7 +40,7 @@ def test_multiplication(self): prior = DiscretePrior(v1, [1, 3]) f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") - self.gtsamAssertEquals(prior * f1, expected) + self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) self.gtsamAssertEquals(f1 * prior, expected) # Multiply two factors diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 0ae66c2d40..190c221819 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -16,6 +16,13 @@ from gtsam import DecisionTreeFactor, DiscreteConditional, DiscreteKeys from gtsam.utils.test_case import GtsamTestCase +# Some DiscreteKeys for binary variables: +A = 0, 2 +B = 1, 2 +C = 2, 2 +D = 4, 2 +E = 3, 2 + class TestDiscreteConditional(GtsamTestCase): """Tests for Discrete Conditionals.""" @@ -36,6 +43,44 @@ def test_single_value_versions(self): actual = conditional.sample(2) self.assertIsInstance(actual, int) + def test_multiply(self): + """Check calculation of joint P(A,B)""" + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + + # P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) + for actual in [prior * conditional, conditional * prior]: + self.assertEqual(2, actual.nrFrontals()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), conditional(v) * prior(v)) + + def test_multiply2(self): + """Check calculation of conditional joint P(A,B|C)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_C = DiscreteConditional(B, [C], "1/3 3/1") + + # P(A,B|C) = P(A|B)P(B|C) = P(B|C)P(A|B) + for actual in [A_given_B * B_given_C, B_given_C * A_given_B]: + self.assertEqual(2, actual.nrFrontals()) + self.assertEqual(1, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual(actual(v), A_given_B(v) * B_given_C(v)) + + def test_multiply4(self): + """Check calculation of joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E)""" + A_given_B = DiscreteConditional(A, [B], "1/3 3/1") + B_given_D = DiscreteConditional(B, [D], "1/3 3/1") + AB_given_D = A_given_B * B_given_D + C_given_DE = DiscreteConditional(C, [D, E], "4/1 1/1 1/1 1/4") + + # P(A,B,C|D,E) = P(A,B|D) P(C|D,E) = P(C|D,E) P(A,B|D) + for actual in [AB_given_D * C_given_DE, C_given_DE * AB_given_D]: + self.assertEqual(3, actual.nrFrontals()) + self.assertEqual(2, actual.nrParents()) + for v, value in actual.enumerate(): + self.assertAlmostEqual( + actual(v), AB_given_D(v) * C_given_DE(v)) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" @@ -48,8 +93,7 @@ def test_markdown(self): conditional = DiscreteConditional(A, parents, "0/1 1/3 1/1 3/1 0/1 1/0") - expected = \ - " *P(A|B,C):*\n\n" \ + expected = " *P(A|B,C):*\n\n" \ "|*B*|*C*|0|1|\n" \ "|:-:|:-:|:-:|:-:|\n" \ "|0|0|0|1|\n" \ From 64cd58843acf7664cf84169cd829507fff3050fa Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sat, 15 Jan 2022 16:28:34 -0500 Subject: [PATCH 11/14] marginals without parents --- gtsam/discrete/DiscreteConditional.cpp | 21 ++++++++++- gtsam/discrete/DiscreteConditional.h | 3 ++ gtsam/discrete/discrete.i | 1 + .../tests/testDiscreteConditional.cpp | 36 ++++++++++++++++++- .../gtsam/tests/test_DiscreteConditional.py | 9 +++++ 5 files changed, 68 insertions(+), 2 deletions(-) diff --git a/gtsam/discrete/DiscreteConditional.cpp b/gtsam/discrete/DiscreteConditional.cpp index 5acd7c0f65..e8aa4511d8 100644 --- a/gtsam/discrete/DiscreteConditional.cpp +++ b/gtsam/discrete/DiscreteConditional.cpp @@ -110,7 +110,26 @@ DiscreteConditional DiscreteConditional::operator*( return DiscreteConditional(newFrontals.size(), discreteKeys, product); } -/* ******************************************************************************** */ +/* ************************************************************************** */ +DiscreteConditional DiscreteConditional::marginal(Key key) const { + if (nrParents() > 0) + throw std::invalid_argument( + "DiscreteConditional::marginal: single argument version only valid for " + "fully specified joint distributions (i.e., no parents)."); + + // Calculate the keys as the frontal keys without the given key. + DiscreteKeys discreteKeys{{key, cardinality(key)}}; + + // Calculate sum + ADT adt(*this); + for (auto&& k : frontals()) + if (k != key) adt = adt.sum(k, cardinality(k)); + + // Return new factor + return DiscreteConditional(1, discreteKeys, adt); +} + +/* ************************************************************************** */ void DiscreteConditional::print(const string& s, const KeyFormatter& formatter) const { cout << s << " P( "; diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 450af57ab5..836aa39200 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -126,6 +126,9 @@ class GTSAM_EXPORT DiscreteConditional */ DiscreteConditional operator*(const DiscreteConditional& other) const; + /** Calculate marginal on given key, no parent case. */ + DiscreteConditional marginal(Key key) const; + /// @} /// @name Testable /// @{ diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index 8bcb8b4aa9..cd3e85598d 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -97,6 +97,7 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { const gtsam::Ordering& orderedKeys); gtsam::DiscreteConditional operator*( const gtsam::DiscreteConditional& other) const; + DiscreteConditional marginal(gtsam::Key key) const; void print(string s = "Discrete Conditional\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDiscreteConditional.cpp b/gtsam/discrete/tests/testDiscreteConditional.cpp index 03766136c1..1256595170 100644 --- a/gtsam/discrete/tests/testDiscreteConditional.cpp +++ b/gtsam/discrete/tests/testDiscreteConditional.cpp @@ -97,10 +97,14 @@ TEST(DiscreteConditional, constructors3) { /* ************************************************************************* */ // Check calculation of joint P(A,B) TEST(DiscreteConditional, Multiply) { - DiscreteKey A(0, 2), B(1, 2); + DiscreteKey A(1, 2), B(0, 2); DiscreteConditional conditional(A | B = "1/2 2/1"); DiscreteConditional prior(B % "1/2"); + // The expected factor + DecisionTreeFactor f(A & B, "1 4 2 2"); + DiscreteConditional expected(2, f); + // P(A,B) = P(A|B) * P(B) = P(B) * P(A|B) for (auto&& actual : {prior * conditional, conditional * prior}) { EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); @@ -110,8 +114,11 @@ TEST(DiscreteConditional, Multiply) { const DiscreteValues& v = it.first; EXPECT_DOUBLES_EQUAL(actual(v), conditional(v) * prior(v), 1e-9); } + // And for good measure: + EXPECT(assert_equal(expected, actual)); } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C) TEST(DiscreteConditional, Multiply2) { @@ -131,6 +138,7 @@ TEST(DiscreteConditional, Multiply2) { } } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B|C), double check keys TEST(DiscreteConditional, Multiply3) { @@ -150,6 +158,7 @@ TEST(DiscreteConditional, Multiply3) { } } } + /* ************************************************************************* */ // Check calculation of conditional joint P(A,B,C|D,E) = P(A,B|D) P(C|D,E) TEST(DiscreteConditional, Multiply4) { @@ -173,6 +182,31 @@ TEST(DiscreteConditional, Multiply4) { } } } + +/* ************************************************************************* */ +// Check calculation of marginals for joint P(A,B) +TEST(DiscreteConditional, marginals) { + DiscreteKey A(1, 2), B(0, 2); + DiscreteConditional conditional(A | B = "1/2 2/1"); + DiscreteConditional prior(B % "1/2"); + DiscreteConditional pAB = prior * conditional; + + DiscreteConditional actualA = pAB.marginal(A.first); + DiscreteConditional pA(A % "5/4"); + EXPECT(assert_equal(pA, actualA)); + EXPECT_LONGS_EQUAL(1, actualA.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualA.nrParents()); + KeyVector frontalsA(actualA.beginFrontals(), actualA.endFrontals()); + EXPECT((frontalsA == KeyVector{1})); + + DiscreteConditional actualB = pAB.marginal(B.first); + EXPECT(assert_equal(prior, actualB)); + EXPECT_LONGS_EQUAL(1, actualB.nrFrontals()); + EXPECT_LONGS_EQUAL(0, actualB.nrParents()); + KeyVector frontalsB(actualB.beginFrontals(), actualB.endFrontals()); + EXPECT((frontalsB == KeyVector{0})); +} + /* ************************************************************************* */ TEST(DiscreteConditional, likelihood) { DiscreteKey X(0, 2), Y(1, 3); diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index 190c221819..f46a0e8773 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -81,6 +81,15 @@ def test_multiply4(self): self.assertAlmostEqual( actual(v), AB_given_D(v) * C_given_DE(v)) + def test_marginals(self): + conditional = DiscreteConditional(A, [B], "1/2 2/1") + prior = DiscreteConditional(B, "1/2") + pAB = prior * conditional + self.gtsamAssertEquals(prior, pAB.marginal(B[0])) + + pA = DiscreteConditional(A % "5/4") + self.gtsamAssertEquals(pA, pAB.marginal(A[0])) + def test_markdown(self): """Test whether the _repr_markdown_ method.""" From 0b11b127609c3a0f7492050bf0613457f02fba22 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Sun, 16 Jan 2022 12:02:22 -0500 Subject: [PATCH 12/14] fix tests --- gtsam/slam/slam.i | 2 +- python/gtsam/tests/test_DiscreteConditional.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gtsam/slam/slam.i b/gtsam/slam/slam.i index a0a7329dd8..602b2afe3c 100644 --- a/gtsam/slam/slam.i +++ b/gtsam/slam/slam.i @@ -11,7 +11,7 @@ namespace gtsam { // ###### #include -template virtual class BetweenFactor : gtsam::NoiseModelFactor { diff --git a/python/gtsam/tests/test_DiscreteConditional.py b/python/gtsam/tests/test_DiscreteConditional.py index f46a0e8773..241a5f0be9 100644 --- a/python/gtsam/tests/test_DiscreteConditional.py +++ b/python/gtsam/tests/test_DiscreteConditional.py @@ -87,7 +87,7 @@ def test_marginals(self): pAB = prior * conditional self.gtsamAssertEquals(prior, pAB.marginal(B[0])) - pA = DiscreteConditional(A % "5/4") + pA = DiscreteConditional(A, "5/4") self.gtsamAssertEquals(pA, pAB.marginal(A[0])) def test_markdown(self): From 4235334c83f1878a2446c621bf3b12588576028f Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 16 Jan 2022 09:47:36 -0500 Subject: [PATCH 13/14] Rename DiscretePrior -> DiscreteDistribution --- gtsam/discrete/DiscreteBayesNet.h | 6 ++-- gtsam/discrete/DiscreteConditional.h | 2 +- ...retePrior.cpp => DiscreteDistribution.cpp} | 16 +++++---- ...DiscretePrior.h => DiscreteDistribution.h} | 30 ++++++++-------- gtsam/discrete/discrete.i | 12 +++---- .../discrete/tests/testDecisionTreeFactor.cpp | 6 ++-- ...Prior.cpp => testDiscreteDistribution.cpp} | 35 +++++++++---------- python/gtsam/tests/test_DecisionTreeFactor.py | 6 ++-- python/gtsam/tests/test_DiscreteBayesNet.py | 4 +-- ...ePrior.py => test_DiscreteDistribution.py} | 20 +++++------ 10 files changed, 70 insertions(+), 67 deletions(-) rename gtsam/discrete/{DiscretePrior.cpp => DiscreteDistribution.cpp} (71%) rename gtsam/discrete/{DiscretePrior.h => DiscreteDistribution.h} (69%) rename gtsam/discrete/tests/{testDiscretePrior.cpp => testDiscreteDistribution.cpp} (74%) rename python/gtsam/tests/{test_DiscretePrior.py => test_DiscreteDistribution.py} (77%) diff --git a/gtsam/discrete/DiscreteBayesNet.h b/gtsam/discrete/DiscreteBayesNet.h index 17dfe2c5ff..db20e7223a 100644 --- a/gtsam/discrete/DiscreteBayesNet.h +++ b/gtsam/discrete/DiscreteBayesNet.h @@ -19,7 +19,7 @@ #pragma once #include -#include +#include #include #include @@ -79,9 +79,9 @@ namespace gtsam { // Add inherited versions of add. using Base::add; - /** Add a DiscretePrior using a table or a string */ + /** Add a DiscreteDistribution using a table or a string */ void add(const DiscreteKey& key, const std::string& spec) { - emplace_shared(key, spec); + emplace_shared(key, spec); } /** Add a DiscreteCondtional */ diff --git a/gtsam/discrete/DiscreteConditional.h b/gtsam/discrete/DiscreteConditional.h index 836aa39200..c3c8a66def 100644 --- a/gtsam/discrete/DiscreteConditional.h +++ b/gtsam/discrete/DiscreteConditional.h @@ -89,7 +89,7 @@ class GTSAM_EXPORT DiscreteConditional const std::string& spec) : DiscreteConditional(Signature(key, parents, spec)) {} - /// No-parent specialization; can also use DiscretePrior. + /// No-parent specialization; can also use DiscreteDistribution. DiscreteConditional(const DiscreteKey& key, const std::string& spec) : DiscreteConditional(Signature(key, {}, spec)) {} diff --git a/gtsam/discrete/DiscretePrior.cpp b/gtsam/discrete/DiscreteDistribution.cpp similarity index 71% rename from gtsam/discrete/DiscretePrior.cpp rename to gtsam/discrete/DiscreteDistribution.cpp index 3941e0199e..7397714709 100644 --- a/gtsam/discrete/DiscretePrior.cpp +++ b/gtsam/discrete/DiscreteDistribution.cpp @@ -10,21 +10,23 @@ * -------------------------------------------------------------------------- */ /** - * @file DiscretePrior.cpp + * @file DiscreteDistribution.cpp * @date December 2021 * @author Frank Dellaert */ -#include +#include + +#include namespace gtsam { -void DiscretePrior::print(const std::string& s, - const KeyFormatter& formatter) const { +void DiscreteDistribution::print(const std::string& s, + const KeyFormatter& formatter) const { Base::print(s, formatter); } -double DiscretePrior::operator()(size_t value) const { +double DiscreteDistribution::operator()(size_t value) const { if (nrFrontals() != 1) throw std::invalid_argument( "Single value operator can only be invoked on single-variable " @@ -34,10 +36,10 @@ double DiscretePrior::operator()(size_t value) const { return Base::operator()(values); } -std::vector DiscretePrior::pmf() const { +std::vector DiscreteDistribution::pmf() const { if (nrFrontals() != 1) throw std::invalid_argument( - "DiscretePrior::pmf only defined for single-variable priors"); + "DiscreteDistribution::pmf only defined for single-variable priors"); const size_t nrValues = cardinalities_.at(keys_[0]); std::vector array; array.reserve(nrValues); diff --git a/gtsam/discrete/DiscretePrior.h b/gtsam/discrete/DiscreteDistribution.h similarity index 69% rename from gtsam/discrete/DiscretePrior.h rename to gtsam/discrete/DiscreteDistribution.h index 1da1882155..fae6e355bd 100644 --- a/gtsam/discrete/DiscretePrior.h +++ b/gtsam/discrete/DiscreteDistribution.h @@ -10,7 +10,7 @@ * -------------------------------------------------------------------------- */ /** - * @file DiscretePrior.h + * @file DiscreteDistribution.h * @date December 2021 * @author Frank Dellaert */ @@ -20,6 +20,7 @@ #include #include +#include namespace gtsam { @@ -27,7 +28,7 @@ namespace gtsam { * A prior probability on a set of discrete variables. * Derives from DiscreteConditional */ -class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { +class GTSAM_EXPORT DiscreteDistribution : public DiscreteConditional { public: using Base = DiscreteConditional; @@ -35,35 +36,36 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { /// @{ /// Default constructor needed for serialization. - DiscretePrior() {} + DiscreteDistribution() {} /// Constructor from factor. - DiscretePrior(const DecisionTreeFactor& f) : Base(f.size(), f) {} + explicit DiscreteDistribution(const DecisionTreeFactor& f) + : Base(f.size(), f) {} /** * Construct from a Signature. * - * Example: DiscretePrior P(D % "3/2"); + * Example: DiscreteDistribution P(D % "3/2"); */ - DiscretePrior(const Signature& s) : Base(s) {} + explicit DiscreteDistribution(const Signature& s) : Base(s) {} /** * Construct from key and a vector of floats specifying the probability mass * function (PMF). * - * Example: DiscretePrior P(D, {0.4, 0.6}); + * Example: DiscreteDistribution P(D, {0.4, 0.6}); */ - DiscretePrior(const DiscreteKey& key, const std::vector& spec) - : DiscretePrior(Signature(key, {}, Signature::Table{spec})) {} + DiscreteDistribution(const DiscreteKey& key, const std::vector& spec) + : DiscreteDistribution(Signature(key, {}, Signature::Table{spec})) {} /** * Construct from key and a string specifying the probability mass function * (PMF). * - * Example: DiscretePrior P(D, "9/1 2/8 3/7 1/9"); + * Example: DiscreteDistribution P(D, "9/1 2/8 3/7 1/9"); */ - DiscretePrior(const DiscreteKey& key, const std::string& spec) - : DiscretePrior(Signature(key, {}, spec)) {} + DiscreteDistribution(const DiscreteKey& key, const std::string& spec) + : DiscreteDistribution(Signature(key, {}, spec)) {} /// @} /// @name Testable @@ -102,10 +104,10 @@ class GTSAM_EXPORT DiscretePrior : public DiscreteConditional { /// @} }; -// DiscretePrior +// DiscreteDistribution // traits template <> -struct traits : public Testable {}; +struct traits : public Testable {}; } // namespace gtsam diff --git a/gtsam/discrete/discrete.i b/gtsam/discrete/discrete.i index cd3e85598d..7ce4bd9021 100644 --- a/gtsam/discrete/discrete.i +++ b/gtsam/discrete/discrete.i @@ -128,12 +128,12 @@ virtual class DiscreteConditional : gtsam::DecisionTreeFactor { std::map> names) const; }; -#include -virtual class DiscretePrior : gtsam::DiscreteConditional { - DiscretePrior(); - DiscretePrior(const gtsam::DecisionTreeFactor& f); - DiscretePrior(const gtsam::DiscreteKey& key, string spec); - DiscretePrior(const gtsam::DiscreteKey& key, std::vector spec); +#include +virtual class DiscreteDistribution : gtsam::DiscreteConditional { + DiscreteDistribution(); + DiscreteDistribution(const gtsam::DecisionTreeFactor& f); + DiscreteDistribution(const gtsam::DiscreteKey& key, string spec); + DiscreteDistribution(const gtsam::DiscreteKey& key, std::vector spec); void print(string s = "Discrete Prior\n", const gtsam::KeyFormatter& keyFormatter = gtsam::DefaultKeyFormatter) const; diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 7e89874a59..92145b8b76 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include @@ -56,8 +56,8 @@ TEST( DecisionTreeFactor, constructors) TEST(DecisionTreeFactor, multiplication) { DiscreteKey v0(0, 2), v1(1, 2), v2(2, 2); - // Multiply with a DiscretePrior, i.e., Bayes Law! - DiscretePrior prior(v1 % "1/3"); + // Multiply with a DiscreteDistribution, i.e., Bayes Law! + DiscreteDistribution prior(v1 % "1/3"); DecisionTreeFactor f1(v0 & v1, "1 2 3 4"); DecisionTreeFactor expected(v0 & v1, "0.25 1.5 0.75 3"); CHECK(assert_equal(expected, static_cast(prior) * f1)); diff --git a/gtsam/discrete/tests/testDiscretePrior.cpp b/gtsam/discrete/tests/testDiscreteDistribution.cpp similarity index 74% rename from gtsam/discrete/tests/testDiscretePrior.cpp rename to gtsam/discrete/tests/testDiscreteDistribution.cpp index 6ef57c7ff4..5c0c42e737 100644 --- a/gtsam/discrete/tests/testDiscretePrior.cpp +++ b/gtsam/discrete/tests/testDiscreteDistribution.cpp @@ -11,42 +11,41 @@ /* * @file testDiscretePrior.cpp - * @brief unit tests for DiscretePrior + * @brief unit tests for DiscreteDistribution * @author Frank dellaert * @date December 2021 */ #include -#include +#include #include -using namespace std; using namespace gtsam; static const DiscreteKey X(0, 2); /* ************************************************************************* */ -TEST(DiscretePrior, constructors) { +TEST(DiscreteDistribution, constructors) { DecisionTreeFactor f(X, "0.4 0.6"); - DiscretePrior expected(f); + DiscreteDistribution expected(f); - DiscretePrior actual(X % "2/3"); + DiscreteDistribution actual(X % "2/3"); EXPECT_LONGS_EQUAL(1, actual.nrFrontals()); EXPECT_LONGS_EQUAL(0, actual.nrParents()); EXPECT(assert_equal(expected, actual, 1e-9)); - const vector pmf{0.4, 0.6}; - DiscretePrior actual2(X, pmf); + const std::vector pmf{0.4, 0.6}; + DiscreteDistribution actual2(X, pmf); EXPECT_LONGS_EQUAL(1, actual2.nrFrontals()); EXPECT_LONGS_EQUAL(0, actual2.nrParents()); EXPECT(assert_equal(expected, actual2, 1e-9)); } /* ************************************************************************* */ -TEST(DiscretePrior, Multiply) { +TEST(DiscreteDistribution, Multiply) { DiscreteKey A(0, 2), B(1, 2); DiscreteConditional conditional(A | B = "1/2 2/1"); - DiscretePrior prior(B, "1/2"); + DiscreteDistribution prior(B, "1/2"); DiscreteConditional actual = prior * conditional; // P(A|B) * P(B) EXPECT_LONGS_EQUAL(2, actual.nrFrontals()); // = P(A,B) @@ -56,22 +55,22 @@ TEST(DiscretePrior, Multiply) { } /* ************************************************************************* */ -TEST(DiscretePrior, operator) { - DiscretePrior prior(X % "2/3"); +TEST(DiscreteDistribution, operator) { + DiscreteDistribution prior(X % "2/3"); EXPECT_DOUBLES_EQUAL(prior(0), 0.4, 1e-9); EXPECT_DOUBLES_EQUAL(prior(1), 0.6, 1e-9); } /* ************************************************************************* */ -TEST(DiscretePrior, pmf) { - DiscretePrior prior(X % "2/3"); - vector expected {0.4, 0.6}; - EXPECT(prior.pmf() == expected); +TEST(DiscreteDistribution, pmf) { + DiscreteDistribution prior(X % "2/3"); + std::vector expected{0.4, 0.6}; + EXPECT(prior.pmf() == expected); } /* ************************************************************************* */ -TEST(DiscretePrior, sample) { - DiscretePrior prior(X % "2/3"); +TEST(DiscreteDistribution, sample) { + DiscreteDistribution prior(X % "2/3"); prior.sample(); } diff --git a/python/gtsam/tests/test_DecisionTreeFactor.py b/python/gtsam/tests/test_DecisionTreeFactor.py index a13a43e263..0499e72154 100644 --- a/python/gtsam/tests/test_DecisionTreeFactor.py +++ b/python/gtsam/tests/test_DecisionTreeFactor.py @@ -13,7 +13,7 @@ import unittest -from gtsam import DecisionTreeFactor, DiscreteValues, DiscretePrior, Ordering +from gtsam import DecisionTreeFactor, DiscreteValues, DiscreteDistribution, Ordering from gtsam.utils.test_case import GtsamTestCase @@ -36,8 +36,8 @@ def test_multiplication(self): v1 = (1, 2) v2 = (2, 2) - # Multiply with a DiscretePrior, i.e., Bayes Law! - prior = DiscretePrior(v1, [1, 3]) + # Multiply with a DiscreteDistribution, i.e., Bayes Law! + prior = DiscreteDistribution(v1, [1, 3]) f1 = DecisionTreeFactor([v0, v1], "1 2 3 4") expected = DecisionTreeFactor([v0, v1], "0.25 1.5 0.75 3") self.gtsamAssertEquals(DecisionTreeFactor(prior) * f1, expected) diff --git a/python/gtsam/tests/test_DiscreteBayesNet.py b/python/gtsam/tests/test_DiscreteBayesNet.py index bdd5a05464..36f0d153d9 100644 --- a/python/gtsam/tests/test_DiscreteBayesNet.py +++ b/python/gtsam/tests/test_DiscreteBayesNet.py @@ -14,7 +14,7 @@ import unittest from gtsam import (DiscreteBayesNet, DiscreteConditional, DiscreteFactorGraph, - DiscreteKeys, DiscretePrior, DiscreteValues, Ordering) + DiscreteKeys, DiscreteDistribution, DiscreteValues, Ordering) from gtsam.utils.test_case import GtsamTestCase @@ -74,7 +74,7 @@ def test_Asia(self): for j in range(8): ordering.push_back(j) chordal = fg.eliminateSequential(ordering) - expected2 = DiscretePrior(Bronchitis, "11/9") + expected2 = DiscreteDistribution(Bronchitis, "11/9") self.gtsamAssertEquals(chordal.at(7), expected2) # solve diff --git a/python/gtsam/tests/test_DiscretePrior.py b/python/gtsam/tests/test_DiscreteDistribution.py similarity index 77% rename from python/gtsam/tests/test_DiscretePrior.py rename to python/gtsam/tests/test_DiscreteDistribution.py index 06bdc81ca7..fa999fd6b5 100644 --- a/python/gtsam/tests/test_DiscretePrior.py +++ b/python/gtsam/tests/test_DiscreteDistribution.py @@ -14,7 +14,7 @@ import unittest import numpy as np -from gtsam import DecisionTreeFactor, DiscreteKeys, DiscretePrior +from gtsam import DecisionTreeFactor, DiscreteKeys, DiscreteDistribution from gtsam.utils.test_case import GtsamTestCase X = 0, 2 @@ -28,33 +28,33 @@ def test_constructor(self): keys = DiscreteKeys() keys.push_back(X) f = DecisionTreeFactor(keys, "0.4 0.6") - expected = DiscretePrior(f) - - actual = DiscretePrior(X, "2/3") + expected = DiscreteDistribution(f) + + actual = DiscreteDistribution(X, "2/3") self.gtsamAssertEquals(actual, expected) - - actual2 = DiscretePrior(X, [0.4, 0.6]) + + actual2 = DiscreteDistribution(X, [0.4, 0.6]) self.gtsamAssertEquals(actual2, expected) def test_operator(self): - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") self.assertAlmostEqual(prior(0), 0.4) self.assertAlmostEqual(prior(1), 0.6) def test_pmf(self): - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") expected = np.array([0.4, 0.6]) np.testing.assert_allclose(expected, prior.pmf()) def test_sample(self): - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") actual = prior.sample() self.assertIsInstance(actual, int) def test_markdown(self): """Test the _repr_markdown_ method.""" - prior = DiscretePrior(X, "2/3") + prior = DiscreteDistribution(X, "2/3") expected = " *P(0):*\n\n" \ "|0|value|\n" \ "|:-:|:-:|\n" \ From 91de3cb6ba333e9e874e3e3fcd92673e018ae0c3 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Sun, 16 Jan 2022 15:17:26 -0500 Subject: [PATCH 14/14] Bump version to 4.2a3 --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d040f9e82a..7c37099a45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,7 +11,7 @@ endif() set (GTSAM_VERSION_MAJOR 4) set (GTSAM_VERSION_MINOR 2) set (GTSAM_VERSION_PATCH 0) -set (GTSAM_PRERELEASE_VERSION "a2") +set (GTSAM_PRERELEASE_VERSION "a3") math (EXPR GTSAM_VERSION_NUMERIC "10000 * ${GTSAM_VERSION_MAJOR} + 100 * ${GTSAM_VERSION_MINOR} + ${GTSAM_VERSION_PATCH}") if (${GTSAM_VERSION_PATCH} EQUAL 0)