Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decimal Type determined at Compile Time (#26) #126

Merged
merged 30 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e170ee7
feat: add compile time decimal type
mookums Nov 24, 2023
3cdcd31
feat: add -D when passed LOST_DATABASE_DOUBLE
mookums Nov 24, 2023
7fadedf
feat: switch from float to decimal
mookums Nov 24, 2023
c1c0fd4
feat: add db flagging & checking
mookums Nov 24, 2023
4a70b3a
fix: cpplint
mookums Nov 24, 2023
9ae2b85
refactor: change all float endianness to decimal endianness
mookums Nov 24, 2023
dff0155
feat: make double mode default
mookums Nov 28, 2023
83e66a6
feat: eigen3 with float/double accordingly
mookums Nov 28, 2023
92b143e
feat: optimize dbFlags
mookums Nov 28, 2023
09ff6d9
fix: prevent double promotion
mookums Nov 28, 2023
6195ddd
fix: cpplint again
mookums Nov 28, 2023
3378aa4
fix: tests with decimal type
mookums Nov 28, 2023
3c4f908
test: add test with float mode
mookums Nov 28, 2023
2128c9e
fix: remove debug logging
mookums Nov 28, 2023
c42076e
feat: add decimal support to kvector test
mookums Nov 28, 2023
46392c8
fix: rebuild with LOST_FLOAT_MODE on test
mookums Nov 28, 2023
08a5bd5
fix: gitlab actions syntax
mookums Nov 28, 2023
8044f6e
fix: split into independent double and float builds
mookums Nov 28, 2023
1328d9f
fix: switch all math functions to decimal wrapped ones
mookums Nov 28, 2023
4511cc3
fix: use math macros instead of redefining them
mookums Nov 28, 2023
bf183b1
fix: cpplint again again
mookums Nov 28, 2023
380cdd6
fix: findpairsexact instead of findpairsliberal
mookums Nov 29, 2023
53259c9
refactor: apply conditional only to format string
mookums Nov 29, 2023
a55c216
feat: error on double promotion when in float mode
mookums Nov 29, 2023
f9c1d80
fix: double promotion in test
mookums Nov 29, 2023
321a4ab
fix: use DECIMAL macro instead of typecast
mookums Nov 29, 2023
1e49367
refactor: return to uint32_t for dbFlags
mookums Nov 29, 2023
e36e1ec
fix: double promotion in kvector test
mookums Nov 29, 2023
debaa0f
fix: cpplint wins again
mookums Nov 29, 2023
aace162
fix: wrap all decimal macros in parenthesis
mookums Dec 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions src/databases.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ namespace lost {

const int32_t PairDistanceKVectorDatabase::kMagicValue = 0x2536f009;

inline bool isFlagSet(uint8_t dbFlags, uint8_t flag) {
return (dbFlags & flag) != 0;
}

struct KVectorPair {
int16_t index1;
int16_t index2;
Expand Down Expand Up @@ -109,7 +113,7 @@ KVectorIndex::KVectorIndex(DeserializeContext *des) {
max = DeserializePrimitive<decimal>(des);
numBins = DeserializePrimitive<int32_t>(des);

assert(min >= 0.0f);
assert(min >= DECIMAL(0.0));
assert(max > min);
binWidth = (max - min) / numBins;

Expand All @@ -124,10 +128,10 @@ KVectorIndex::KVectorIndex(DeserializeContext *des) {
long KVectorIndex::QueryLiberal(decimal minQueryDistance, decimal maxQueryDistance, long *upperIndex) const {
assert(maxQueryDistance > minQueryDistance);
if (maxQueryDistance >= max) {
maxQueryDistance = max - 0.00001; // TODO: better way to avoid hitting the bottom bin
maxQueryDistance = max - DECIMAL(0.00001); // TODO: better way to avoid hitting the bottom bin
}
if (minQueryDistance <= min) {
minQueryDistance = min + 0.00001;
minQueryDistance = min + DECIMAL(0.00001);
}
if (minQueryDistance > max || maxQueryDistance < min) {
*upperIndex = 0;
Expand Down Expand Up @@ -175,7 +179,7 @@ std::vector<KVectorPair> CatalogToPairDistances(const Catalog &catalog, decimal
KVectorPair pair = { i, k, AngleUnit(catalog[i].spatial, catalog[k].spatial) };
assert(isfinite(pair.distance));
assert(pair.distance >= 0);
assert(pair.distance <= M_PI);
assert(pair.distance <= DECIMAL_M_PI);

if (pair.distance >= minDistance && pair.distance <= maxDistance) {
// we'll sort later
Expand Down Expand Up @@ -232,7 +236,7 @@ decimal Clamp(decimal num, decimal low, decimal high) {
const int16_t *PairDistanceKVectorDatabase::FindPairsLiberal(
decimal minQueryDistance, decimal maxQueryDistance, const int16_t **end) const {

assert(maxQueryDistance <= M_PI);
assert(maxQueryDistance <= DECIMAL_M_PI);

long upperIndex = -1;
long lowerIndex = index.QueryLiberal(minQueryDistance, maxQueryDistance, &upperIndex);
Expand All @@ -245,9 +249,9 @@ const int16_t *PairDistanceKVectorDatabase::FindPairsExact(const Catalog &catalo

// Instead of computing the angle for every pair in the database, we pre-compute the /cosines/
// of the min and max query distances so that we can compare against dot products directly! As
// angle increases, cosine decreases, up to M_PI (and queries larger than that don't really make
// angle increases, cosine decreases, up to DECIMAL_M_PI (and queries larger than that don't really make
// sense anyway)
assert(maxQueryDistance <= M_PI);
assert(maxQueryDistance <= DECIMAL_M_PI);

decimal maxQueryCos = cos(minQueryDistance);
decimal minQueryCos = cos(maxQueryDistance);
Expand Down Expand Up @@ -318,13 +322,20 @@ const unsigned char *MultiDatabase::SubDatabasePointer(int32_t magicValue) const
if (curMagicValue == 0) {
return nullptr;
}
uint32_t dbFlags = DeserializePrimitive<uint32_t>(des);
uint8_t dbFlags = DeserializePrimitive<uint8_t>(des);

// Ensure that our database is using the same type as the runtime.
if (dbFlags & MULTI_DB_IS_DOUBLE) {
assert(typeid(decimal) == typeid(double));
} else {
assert(typeid(decimal) == typeid(float));
}
#ifdef LOST_FLOAT_MODE
if(!isFlagSet(dbFlags, MULTI_DB_FLOAT_FLAG)) {
std::cerr << "LOST was compiled in float mode. This database was serialized in double mode and is incompatible." << std::endl;
exit(1);
}
#else
if(isFlagSet(dbFlags, MULTI_DB_FLOAT_FLAG)) {
std::cerr << "LOST was compiled in double mode. This database was serialized in float mode and is incompatible." << std::endl;
exit(1);
}
#endif

uint32_t dbLength = DeserializePrimitive<uint32_t>(des);
assert(dbLength > 0);
Expand All @@ -340,10 +351,10 @@ const unsigned char *MultiDatabase::SubDatabasePointer(int32_t magicValue) const

void SerializeMultiDatabase(SerializeContext *ser,
const MultiDatabaseDescriptor &dbs,
uint32_t flags) {
uint8_t flags) {
mookums marked this conversation as resolved.
Show resolved Hide resolved
for (const MultiDatabaseEntry &multiDbEntry : dbs) {
SerializePrimitive<int32_t>(ser, multiDbEntry.magicValue);
SerializePrimitive<uint32_t>(ser, flags);
SerializePrimitive<uint8_t>(ser, flags);
SerializePrimitive<uint32_t>(ser, multiDbEntry.bytes.size());
SerializePadding<uint64_t>(ser);
std::copy(multiDbEntry.bytes.cbegin(), multiDbEntry.bytes.cend(), std::back_inserter(ser->buffer));
Expand Down
9 changes: 5 additions & 4 deletions src/databases.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace lost {

const int32_t kCatalogMagicValue = 0xF9A283BC;

inline bool isFlagSet(uint8_t dbFlags, uint8_t flag);

/**
* A data structure enabling constant-time range queries into fixed numerical data.
*
Expand Down Expand Up @@ -97,8 +99,7 @@ class PairDistanceKVectorDatabase {
* Multi-databases are essentially a map from "magic values" to database buffers.
*/

#define MULTI_DB_IS_DOUBLE 0x0001
#define MULTI_DB_IS_FLOAT 0x0000
#define MULTI_DB_FLOAT_FLAG 0x0001 // By default, our DB is in double mode.

class MultiDatabase {
public:
Expand All @@ -115,13 +116,13 @@ class MultiDatabaseEntry {
: magicValue(magicValue), bytes(bytes) { }

int32_t magicValue;
uint32_t flags;
uint8_t flags;
std::vector<unsigned char> bytes;
};

typedef std::vector<MultiDatabaseEntry> MultiDatabaseDescriptor;

void SerializeMultiDatabase(SerializeContext *, const MultiDatabaseDescriptor &dbs, uint32_t flags);
void SerializeMultiDatabase(SerializeContext *, const MultiDatabaseDescriptor &dbs, uint8_t flags);

}

Expand Down