Skip to content

Commit

Permalink
unary apply methods for TableFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
varunagrawal committed Jul 17, 2023
1 parent cf12927 commit aafc33d
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 5 deletions.
41 changes: 40 additions & 1 deletion gtsam/discrete/TableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ TableFactor::TableFactor(const DiscreteConditional& c)
Eigen::SparseVector<double> TableFactor::Convert(
const std::vector<double>& table) {
Eigen::SparseVector<double> sparse_table(table.size());
// Count number of nonzero elements in table and reserving the space.
// Count number of nonzero elements in table and reserve the space.
const uint64_t nnz = std::count_if(table.begin(), table.end(),
[](uint64_t i) { return i != 0; });
sparse_table.reserve(nnz);
Expand Down Expand Up @@ -218,6 +218,45 @@ void TableFactor::print(const string& s, const KeyFormatter& formatter) const {
cout << "number of nnzs: " << sparse_table_.nonZeros() << endl;
}

/* ************************************************************************ */
TableFactor TableFactor::apply(Unary op) const {
// Initialize new factor.
uint64_t cardi = 1;
for (auto [key, c] : cardinalities_) cardi *= c;
Eigen::SparseVector<double> sparse_table(cardi);
sparse_table.reserve(sparse_table_.nonZeros());

// Populate
for (SparseIt it(sparse_table_); it; ++it) {
sparse_table.coeffRef(it.index()) = op(it.value());
}

// Free unused memory and return.
sparse_table.pruned();
sparse_table.data().squeeze();
return TableFactor(discreteKeys(), sparse_table);
}

/* ************************************************************************ */
TableFactor TableFactor::apply(UnaryAssignment op) const {
// Initialize new factor.
uint64_t cardi = 1;
for (auto [key, c] : cardinalities_) cardi *= c;
Eigen::SparseVector<double> sparse_table(cardi);
sparse_table.reserve(sparse_table_.nonZeros());

// Populate
for (SparseIt it(sparse_table_); it; ++it) {
DiscreteValues assignment = findAssignments(it.index());
sparse_table.coeffRef(it.index()) = op(assignment, it.value());
}

// Free unused memory and return.
sparse_table.pruned();
sparse_table.data().squeeze();
return TableFactor(discreteKeys(), sparse_table);
}

/* ************************************************************************ */
TableFactor TableFactor::apply(const TableFactor& f, Binary op) const {
if (keys_.empty() && sparse_table_.nonZeros() == 0)
Expand Down
28 changes: 26 additions & 2 deletions gtsam/discrete/TableFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
typedef std::shared_ptr<TableFactor> shared_ptr;
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;
using Unary = std::function<double(const double&)>;
using UnaryAssignment =
std::function<double(const Assignment<Key>&, const double&)>;
using Binary = std::function<double(const double, const double)>;

public:
Expand Down Expand Up @@ -218,17 +221,38 @@ class GTSAM_EXPORT TableFactor : public DiscreteFactor {
/// @name Advanced Interface
/// @{

/**
* Apply unary operator `op(*this)` where `op` accepts the discrete value.
* @param op a unary operator that operates on TableFactor
*/
TableFactor apply(Unary op) const;
/**
* Apply unary operator `op(*this)` where `op` accepts the discrete assignment
* and the value at that assignment.
* @param op a unary operator that operates on TableFactor
*/
TableFactor apply(UnaryAssignment op) const;

/**
* Apply binary operator (*this) "op" f
* @param f the second argument for op
* @param op a binary operator that operates on TableFactor
*/
TableFactor apply(const TableFactor& f, Binary op) const;

/// Return keys in contract mode.
/**
* Return keys in contract mode.
*
* Modes are each of the dimensions of a sparse tensor,
* and the contract modes represent which dimensions will
* be involved in contraction (aka tensor multiplication).
*/
DiscreteKeys contractDkeys(const TableFactor& f) const;

/// Return keys in free mode.
/**
* @brief Return keys in free mode which are the dimensions
* not involved in the contraction operation.
*/
DiscreteKeys freeDkeys(const TableFactor& f) const;

/// Return union of DiscreteKeys in two factors.
Expand Down
36 changes: 34 additions & 2 deletions gtsam/discrete/tests/testTableFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ void printTime(map<double, pair<chrono::microseconds, chrono::microseconds>>
for (auto&& kv : measured_time) {
cout << "dropout: " << kv.first
<< " | TableFactor time: " << kv.second.first.count()
<< " | DecisionTreeFactor time: " << kv.second.second.count() <<
endl;
<< " | DecisionTreeFactor time: " << kv.second.second.count() << endl;
}
}

Expand Down Expand Up @@ -361,6 +360,39 @@ TEST(TableFactor, htmlWithValueFormatter) {
EXPECT(actual == expected);
}

/* ************************************************************************* */
TEST(TableFactor, Unary) {
// Declare a bunch of keys
DiscreteKey X(0, 2), Y(1, 3);

// Create factors
TableFactor f(X & Y, "2 5 3 6 2 7");
auto op = [](const double x) { return 2 * x; };
auto g = f.apply(op);

TableFactor expected(X & Y, "4 10 6 12 4 14");
EXPECT(assert_equal(g, expected));

auto sq_op = [](const double x) { return x * x; };
auto g_sq = f.apply(sq_op);
TableFactor expected_sq(X & Y, "4 25 9 36 4 49");
EXPECT(assert_equal(g_sq, expected_sq));
}

/* ************************************************************************* */
TEST(TableFactor, UnaryAssignment) {
// Declare a bunch of keys
DiscreteKey X(0, 2), Y(1, 3);

// Create factors
TableFactor f(X & Y, "2 5 3 6 2 7");
auto op = [](const Assignment<Key>& key, const double x) { return 2 * x; };
auto g = f.apply(op);

TableFactor expected(X & Y, "4 10 6 12 4 14");
EXPECT(assert_equal(g, expected));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down

0 comments on commit aafc33d

Please sign in to comment.