Skip to content

Commit

Permalink
refactoring of annotators
Browse files Browse the repository at this point in the history
  • Loading branch information
karasikov committed Jul 21, 2023
1 parent 7430904 commit 9951913
Show file tree
Hide file tree
Showing 40 changed files with 235 additions and 780 deletions.
158 changes: 77 additions & 81 deletions metagraph/src/annotation/binary_matrix/base/binary_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,55 @@
#include "common/serialization.hpp"
#include "common/utils/template_utils.hpp"
#include "common/vector_map.hpp"
#include "common/vector_set.hpp"
#include "common/hashers/hash.hpp"
#include "annotation/binary_matrix/row_diff/row_diff.hpp"


namespace mtg {
namespace annot {
namespace matrix {

const size_t kRowBatchSize = 100'000;


std::vector<BinaryMatrix::SetBitPositions>
BinaryMatrix::get_rows(const std::vector<Row> &row_ids) const {
std::vector<SetBitPositions> rows(row_ids.size());
BinaryMatrix::get_rows_dict(std::vector<Row> *rows, size_t num_threads) const {
VectorSet<SetBitPositions, utils::VectorHash> unique_rows;

std::vector<std::pair<Row, size_t>> row_to_index(rows->size());
for (size_t i = 0; i < rows->size(); ++i) {
row_to_index[i] = std::make_pair((*rows)[i], i);
}

auto slice = slice_rows(row_ids);
// don't break the topological order for row-diff annotation
if (!dynamic_cast<const IRowDiff *>(this)) {
ips4o::parallel::sort(row_to_index.begin(), row_to_index.end(),
utils::LessFirst(), num_threads);
}

assert(slice.size() >= row_ids.size());
#pragma omp parallel for num_threads(num_threads) schedule(dynamic)
for (uint64_t begin = 0; begin < row_to_index.size(); begin += kRowBatchSize) {
const uint64_t end = std::min(begin + kRowBatchSize,
static_cast<uint64_t>(row_to_index.size()));

std::vector<Row> ids(end - begin);
for (uint64_t i = begin; i < end; ++i) {
ids[i - begin] = row_to_index[i].first;
}

auto row_begin = slice.begin();
auto batch = get_rows(ids);

for (size_t i = 0; i < rows.size(); ++i) {
// every row in `slice` ends with `-1`
auto row_end = std::find(row_begin, slice.end(),
std::numeric_limits<Column>::max());
rows[i].assign(row_begin, row_end);
row_begin = row_end + 1;
#pragma omp critical
{
for (uint64_t i = begin; i < end; ++i) {
auto it = unique_rows.emplace(std::move(batch[i - begin])).first;
(*rows)[row_to_index[i].second] = it - unique_rows.begin();
}
}
}

return rows;
return const_cast<std::vector<SetBitPositions>&&>(unique_rows.values_container());
}

void BinaryMatrix::call_columns(const std::vector<Column> &column_ids,
Expand Down Expand Up @@ -60,14 +84,18 @@ BinaryMatrix::sum_rows(const std::vector<std::pair<Row, size_t>> &index_counts,
if (total_sum_count < min_count)
return {};

// TODO: call slice_rows
auto rows = get_rows(row_ids);
auto distinct_rows = get_rows_dict(&row_ids);

std::vector<size_t> counts(distinct_rows.size(), 0);
for (size_t i = 0; i < index_counts.size(); ++i) {
counts[row_ids[i]] += index_counts[i].second;
}

VectorMap<Column, size_t> col_counts;

for (size_t i = 0; i < index_counts.size(); ++i) {
for (size_t j : rows[i]) {
col_counts[j] += index_counts[i].second;
for (size_t i = 0; i < counts.size(); ++i) {
for (size_t j : distinct_rows[i]) {
col_counts[j] += counts[i];
}
}

Expand All @@ -81,6 +109,19 @@ BinaryMatrix::sum_rows(const std::vector<std::pair<Row, size_t>> &index_counts,
}


std::vector<RainbowMatrix::SetBitPositions>
RainbowMatrix::get_rows(const std::vector<Row> &rows) const {
std::vector<Row> pointers = rows;
auto distinct_rows = get_rows_dict(&pointers);

std::vector<SetBitPositions> result(rows.size());
for (size_t i = 0; i < pointers.size(); ++i) {
result[i] = distinct_rows[pointers[i]];
}

return result;
}

std::vector<RainbowMatrix::SetBitPositions>
RainbowMatrix::get_rows_dict(std::vector<Row> *rows, size_t num_threads) const {
assert(rows);
Expand Down Expand Up @@ -108,80 +149,35 @@ RainbowMatrix::get_rows_dict(std::vector<Row> *rows, size_t num_threads) const {
}
row_codes = {};

std::vector<SetBitPositions> unique_rows(codes.size());

#pragma omp parallel for num_threads(num_threads)
for (size_t i = 0; i < codes.size(); ++i) {
unique_rows[i] = code_to_row(codes[i]);
}
if (num_threads <= 1)
return codes_to_rows(codes);

return unique_rows;
}
std::vector<SetBitPositions> unique_rows(codes.size());

// TODO: improve
RainbowMatrix::SetBitPositions
RainbowMatrix::slice_rows(const std::vector<Row> &rows) const {
SetBitPositions slice;
slice.reserve(rows.size() * 2);
size_t batch_size = std::min(kRowBatchSize,
(codes.size() + num_threads - 1) / num_threads);

for (const auto &row : get_rows(rows)) {
for (uint64_t j : row) {
slice.push_back(j);
#pragma omp parallel for num_threads(num_threads) schedule(dynamic)
for (size_t i = 0; i < codes.size(); i += batch_size) {
std::vector<uint64_t> ids(codes.begin() + i,
codes.begin() + std::min(i + batch_size, codes.size()));
auto rows = codes_to_rows(ids);
for (size_t j = 0; j < rows.size(); ++j) {
unique_rows[i + j] = std::move(rows[j]);
}
slice.push_back(std::numeric_limits<Column>::max());
}

return slice;
}

std::vector<RainbowMatrix::SetBitPositions>
RainbowMatrix::get_rows(const std::vector<Row> &rows) const {
std::vector<Row> pointers = rows;
auto distinct_rows = get_rows_dict(&pointers);

std::vector<SetBitPositions> result(rows.size());
for (size_t i = 0; i < pointers.size(); ++i) {
result[i] = distinct_rows[pointers[i]];
}

return result;
return unique_rows;
}

// TODO: merge with BinaryMatrix::sum_rows
std::vector<std::pair<RainbowMatrix::Column, size_t /* count */>>
RainbowMatrix::sum_rows(const std::vector<std::pair<Row, size_t>> &index_counts,
size_t min_count) const {
min_count = std::max(min_count, size_t(1));

size_t total_sum_count = 0;
for (const auto &pair : index_counts) {
total_sum_count += pair.second;
}

if (total_sum_count < min_count)
return {};

tsl::hopscotch_map<size_t, size_t> code_count;
code_count.reserve(index_counts.size());
for (auto [i, count] : index_counts) {
code_count[get_code(i)] += count;
}

VectorMap<Column, size_t> col_counts;

for (auto [c, count] : code_count) {
for (size_t j : code_to_row(c)) {
col_counts[j] += count;
}
std::vector<BinaryMatrix::SetBitPositions>
RowMajor::get_rows(const std::vector<Row> &row_ids) const {
std::vector<SetBitPositions> rows(row_ids.size());
for (size_t i = 0; i < row_ids.size(); ++i) {
rows[i] = get_row(row_ids[i]);
}

auto &result = const_cast<std::vector<std::pair<Column, size_t>>&>(col_counts.values_container());

result.erase(std::remove_if(result.begin(), result.end(),
[&](const auto &p) { return p.second < min_count; }),
result.end());

return std::move(result);
return rows;
}


Expand Down
35 changes: 18 additions & 17 deletions metagraph/src/annotation/binary_matrix/base/binary_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ class BinaryMatrix {
virtual uint64_t num_rows() const = 0;

// row is in [0, num_rows), column is in [0, num_columns)
virtual std::vector<SetBitPositions> get_rows(const std::vector<Row> &rows) const;
virtual std::vector<SetBitPositions> get_rows(const std::vector<Row> &rows) const = 0;
// Return unique rows (in arbitrary order) and update the row indexes
// in |rows| to point to their respective rows in the vector returned.
virtual std::vector<SetBitPositions> get_rows_dict(std::vector<Row> *rows,
size_t num_threads = 1) const;
virtual std::vector<Row> get_column(Column column) const = 0;

// get all selected rows appended with -1 and concatenated
virtual SetBitPositions slice_rows(const std::vector<Row> &rows) const = 0;

// For each column id in columns, run callback on its respective index in columns
// and a bitmap representing the column
virtual void call_columns(const std::vector<Column> &columns,
Expand All @@ -59,30 +60,30 @@ class RainbowMatrix : public BinaryMatrix {
public:
virtual ~RainbowMatrix() {}

// row is in [0, num_rows), column is in [0, num_columns)
virtual std::vector<SetBitPositions> get_rows(const std::vector<Row> &rows) const final;

// Return unique rows (in arbitrary order) and update the row indexes
// in |rows| to point to their respective rows in the vector returned.
virtual std::vector<SetBitPositions> get_rows_dict(std::vector<Row> *rows, size_t num_threads = 1) const;
virtual std::vector<SetBitPositions> get_rows(const std::vector<Row> &rows) const;
virtual SetBitPositions slice_rows(const std::vector<Row> &rows) const;

// Return all columns for which counts are greater than or equal to |min_count|.
virtual std::vector<std::pair<Column, size_t /* count */>>
sum_rows(const std::vector<std::pair<Row, size_t>> &index_counts,
size_t min_count = 1) const;
virtual std::vector<SetBitPositions> get_rows_dict(std::vector<Row> *rows,
size_t num_threads = 1) const final;

virtual uint64_t num_distinct_rows() const = 0;

private:
virtual uint64_t get_code(Row row) const = 0;
virtual SetBitPositions code_to_row(Row row) const = 0;
virtual std::vector<SetBitPositions> codes_to_rows(const std::vector<Row> &rows) const = 0;
};


class GetRowSupport {
class RowMajor : public BinaryMatrix {
public:
virtual ~GetRowSupport() {}
virtual ~RowMajor() {}

virtual SetBitPositions get_row(Row row) const = 0;

virtual BinaryMatrix::SetBitPositions get_row(BinaryMatrix::Row row) const = 0;
// row is in [0, num_rows), column is in [0, num_columns)
virtual std::vector<SetBitPositions> get_rows(const std::vector<Row> &rows) const final;
};

class GetEntrySupport {
Expand All @@ -92,7 +93,7 @@ class GetEntrySupport {
virtual bool get(BinaryMatrix::Row row, BinaryMatrix::Column column) const = 0;
};

class BinaryMatrixRowDynamic : public BinaryMatrix, public GetRowSupport {
class BinaryMatrixRowDynamic : public RowMajor {
public:
virtual ~BinaryMatrixRowDynamic() {}

Expand Down
37 changes: 0 additions & 37 deletions metagraph/src/annotation/binary_matrix/bin_rel_wt/bin_rel_wt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,43 +145,6 @@ BinRelWT::SetBitPositions BinRelWT::get_row(Row row) const {
return relations_in_row;
}

// Concatenate the set columns of all requested rows into one flat vector.
// The columns of each row are followed by std::numeric_limits<Column>::max(),
// which serves as the end-of-row delimiter; a zero row contributes only
// the delimiter.
BinRelWT::SetBitPositions
BinRelWT::slice_rows(const std::vector<Row> &rows) const {
    const Column row_end = std::numeric_limits<Column>::max();

    SetBitPositions result;
    result.reserve(rows.size() * 2);

    for (Row row : rows) {
        assert(row < num_objects);

        if (is_zero_row(row)) {
            // nothing to query for this row, emit just the delimiter
            result.push_back(row_end);
            continue;
        }

        auto label_lo = to_label_id(0);
        auto label_hi = to_label_id(max_used_label);
        auto object = to_object_id(row);

        auto relation_count = static_cast<uint64_t>(
            binary_relation_.count_distinct_labels(object, object,
                                                   label_lo, label_hi)
        );

        // nth_element is queried with ranks 1..relation_count
        for (uint64_t rank = 1; rank <= relation_count; ++rank) {
            auto element = binary_relation_.nth_element(object, object,
                                                        label_lo, rank,
                                                        brwt::lab_major);
            assert(element);
            result.push_back(to_index(element->label));
        }

        result.push_back(row_end);
    }

    return result;
}

std::vector<BinRelWT::Row> BinRelWT::get_column(Column column) const {
assert(column < num_labels);
if (is_zero_column(column)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace mtg {
namespace annot {
namespace matrix {

class BinRelWT : public BinaryMatrix, public GetRowSupport, public GetEntrySupport {
class BinRelWT : public RowMajor, public GetEntrySupport {
public:
BinRelWT() {}

Expand All @@ -26,7 +26,6 @@ class BinRelWT : public BinaryMatrix, public GetRowSupport, public GetEntrySuppo

bool get(Row row, Column column) const override;
SetBitPositions get_row(Row row) const override;
SetBitPositions slice_rows(const std::vector<Row> &row_ids) const override;
std::vector<Row> get_column(Column column) const override;

bool load(std::istream &in) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,45 +86,6 @@ BinRelWT_sdsl::SetBitPositions BinRelWT_sdsl::get_row(Row row) const {
return SetBitPositions(label_indices.begin(), label_indices.end());
}

// Concatenate the label indices of all requested rows into one flat vector,
// writing std::numeric_limits<Column>::max() after each row as a delimiter.
BinRelWT_sdsl::SetBitPositions BinRelWT_sdsl::slice_rows(const std::vector<Row> &row_ids) const {
    using size_type = sdsl::int_vector<>::size_type;
    using value_type = typename decltype(wt_)::value_type;

    SetBitPositions result;
    result.reserve(row_ids.size() * 2);

    // Scratch buffers handed to wt_.interval_symbols(); kept outside the
    // loop so their storage is reused from row to row.
    std::vector<value_type> symbols;
    std::vector<size_type> ranks_begin;
    std::vector<size_type> ranks_end;

    for (uint64_t row : row_ids) {
        // Get label indices from the base string stored in wt_.
        // The row's range in that string is derived from delimiters_.
        const uint64_t range_begin = delimiters_.select1(row + 1) - row;
        const uint64_t range_end = delimiters_.select1(row + 2) - (row + 1);
        const uint64_t range_size = range_end - range_begin;

        ranks_begin.resize(range_size);
        ranks_end.resize(range_size);
        symbols.resize(range_size);

        size_type num_symbols;
        wt_.interval_symbols(range_begin, range_end, num_symbols,
                             symbols, ranks_begin, ranks_end);

        for (Column column : symbols) {
            result.push_back(column);
        }

        result.push_back(std::numeric_limits<Column>::max());
    }

    return result;
}

std::vector<BinRelWT_sdsl::Row> BinRelWT_sdsl::get_column(Column column) const {
assert(column < num_columns_);

Expand Down
Loading

0 comments on commit 9951913

Please sign in to comment.