Skip to content

Commit

Permalink
feat(c/driver/postgresql): Support for writing DECIMAL types (#1288)
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd authored Jan 3, 2024
1 parent 650994d commit 2116cff
Show file tree
Hide file tree
Showing 6 changed files with 629 additions and 0 deletions.
160 changes: 160 additions & 0 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,152 @@ class PostgresCopyIntervalFieldWriter : public PostgresCopyFieldWriter {
}
};

// Inspiration for this taken from get_str_from_var in the pg source
// src/backend/utils/adt/numeric.c
/// \brief Writes Arrow decimal128/decimal256 values as PostgreSQL NUMERIC
/// fields in COPY binary format.
///
/// The NUMERIC payload written for each value is four int16 header fields
/// (ndigits, weight, sign, dscale) followed by ndigits int16 "digits" in base
/// 10000, most significant first. `weight` is the power of 10000 that applies
/// to the first digit, so every base-10000 digit must be aligned to the
/// decimal point (it covers decimal exponents [4 * w, 4 * w + 3]).
///
/// \tparam T NANOARROW_TYPE_DECIMAL128 or NANOARROW_TYPE_DECIMAL256.
template <enum ArrowType T>
class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
 public:
  /// \param precision Decimal precision declared by the Arrow schema.
  /// \param scale Decimal scale declared by the Arrow schema, i.e. the value
  ///   represented is unscaled_integer * 10^(-scale).
  PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale)
      : precision_{precision}, scale_{scale} {}

  /// \brief Append the NUMERIC encoding of element `index` to `buffer`.
  ///
  /// Trailing fractional zeros are not counted in the display scale (dscale),
  /// so e.g. 123.45600000 at scale 8 is written with dscale 3. Zero is written
  /// canonically: no digits, weight 0, positive sign, dscale 0.
  ///
  /// \return NANOARROW_OK on success, or the error reported by the underlying
  ///   buffer writes.
  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
    struct ArrowDecimal decimal;
    ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_);
    ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);

    // Decimal digits of |unscaled integer|, most significant first, no sign.
    char decimal_string[max_decimal_digits_ + 1];
    const int total_digits = DecimalToString<bitwidth_>(&decimal, decimal_string);
    const bool is_zero = (total_digits == 1) && (decimal_string[0] == '0');

    // Never emit a negative zero, whatever ArrowDecimalSign reports for 0.
    const int16_t sign =
        (is_zero || ArrowDecimalSign(&decimal) > 0) ? kNumericPos : kNumericNeg;

    // Number of decimal digits per Postgres digit
    constexpr int kDecDigits = 4;
    std::vector<int16_t> pg_digits;
    int16_t weight = 0;
    int16_t dscale = 0;

    if (!is_zero) {
      // The value is non-zero, so decimal_string[0] != '0' and this scan stops
      // before running off the front of the string.
      int trailing_zeros = 0;
      while (decimal_string[total_digits - 1 - trailing_zeros] == '0') {
        trailing_zeros++;
      }

      const int64_t scale = scale_;
      if (scale > trailing_zeros) {
        // NOTE: narrowed to int16_t as required by the wire format; scales far
        // beyond the decimal's precision are not representable here.
        dscale = static_cast<int16_t>(scale - trailing_zeros);
      }

      // decimal_string[i] carries decimal exponent (highest_exp - i).
      const int64_t highest_exp = static_cast<int64_t>(total_digits) - 1 - scale;
      // Exponent of the least significant non-zero decimal digit.
      const int64_t lowest_exp = static_cast<int64_t>(trailing_zeros) - scale;

      // First/last base-10000 digits that contain a non-zero decimal digit.
      // Because both ends are non-zero there is nothing to trim afterwards,
      // while all-zero groups in between are (and must be) kept.
      const int64_t first_weight = FloorDiv(highest_exp, kDecDigits);
      const int64_t last_weight = FloorDiv(lowest_exp, kDecDigits);
      weight = static_cast<int16_t>(first_weight);

      pg_digits.reserve(static_cast<size_t>(first_weight - last_weight + 1));
      for (int64_t w = first_weight; w >= last_weight; w--) {
        int16_t group = 0;
        for (int k = kDecDigits - 1; k >= 0; k--) {
          // Positions outside the string are implicit zeros (left padding of
          // the integer part / right padding of the fractional part).
          const int64_t pos = highest_exp - (w * kDecDigits + k);
          const int digit =
              (pos >= 0 && pos < total_digits) ? (decimal_string[pos] - '0') : 0;
          group = static_cast<int16_t>(group * 10 + digit);
        }
        pg_digits.push_back(group);
      }
    }

    const int16_t ndigits = static_cast<int16_t>(pg_digits.size());
    const size_t pg_digit_bytes = sizeof(int16_t) * pg_digits.size();
    const int32_t field_size_bytes =
        static_cast<int32_t>(sizeof(ndigits) + sizeof(weight) + sizeof(sign) +
                             sizeof(dscale) + pg_digit_bytes);

    NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
    NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));

    // Reserve once so the per-digit writes below cannot fail.
    NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes));
    for (const int16_t pg_digit : pg_digits) {
      WriteUnsafe<int16_t>(buffer, pg_digit);
    }

    return NANOARROW_OK;
  }

 private:
  /// \brief Integer division rounding toward negative infinity.
  static constexpr int64_t FloorDiv(int64_t num, int64_t den) {
    return num / den - (((num % den != 0) && ((num < 0) != (den < 0))) ? 1 : 0);
  }

  /// \brief Render the absolute value of `decimal`'s unscaled integer as a
  /// NUL-terminated base-10 string without leading zeros ("0" for zero).
  ///
  /// \param decimal Value to convert.
  /// \param out Destination; must have room for max_decimal_digits_ + 1 chars.
  /// \return The length of the string (number of digits written).
  template <int32_t DEC_WIDTH>
  int DecimalToString(struct ArrowDecimal* decimal, char* out) {
    constexpr size_t nwords = DEC_WIDTH / 64;
    uint8_t tmp[DEC_WIDTH / 8];
    ArrowDecimalGetBytes(decimal, tmp);
    // NOTE(review): buf[0] is treated as the least significant 64-bit word,
    // which assumes a little-endian word layout from ArrowDecimalGetBytes --
    // confirm before relying on this on big-endian targets.
    uint64_t buf[nwords];
    std::memcpy(buf, tmp, sizeof(buf));

    if (!(ArrowDecimalSign(decimal) > 0)) {
      // Two's-complement negate across all words: invert every word and add
      // one, carrying into the next word whenever the addition wraps to zero.
      uint64_t carry = 1;
      for (size_t i = 0; i < nwords; i++) {
        buf[i] = ~buf[i] + carry;
        carry = (carry != 0 && buf[i] == 0) ? 1 : 0;
      }
    }

    // Basic approach adopted from https://stackoverflow.com/a/8023862/621736
    // Shift the binary value out one bit at a time (MSB first) and fold each
    // bit into a decimal accumulator: s = s * 2 + bit.
    char s[max_decimal_digits_ + 1];
    std::memset(s, '0', sizeof(s) - 1);
    s[sizeof(s) - 1] = '\0';

    for (int32_t bit = 0; bit < DEC_WIDTH; bit++) {
      // The bit shifted out of a word is its top bit.
      int carry = static_cast<int>(buf[nwords - 1] >> 63);
      for (size_t j = nwords - 1; j > 0; j--) {
        buf[j] = (buf[j] << 1) | (buf[j - 1] >> 63);
      }
      buf[0] <<= 1;

      for (int j = static_cast<int>(sizeof(s)) - 2; j >= 0; j--) {
        s[j] = static_cast<char>(s[j] + (s[j] - '0') + carry);
        carry = (s[j] > '9');
        if (carry) {
          s[j] -= 10;
        }
      }
    }

    // Skip leading zeros but always keep at least one digit.
    const char* p = s;
    while ((p[0] == '0') && (p < &s[sizeof(s) - 2])) {
      p++;
    }

    const size_t ndigits = sizeof(s) - 1 - static_cast<size_t>(p - s);
    std::memcpy(out, p, ndigits);
    out[ndigits] = '\0';

    return static_cast<int>(ndigits);
  }

  // Values of the NUMERIC sign field.
  static constexpr uint16_t kNumericPos = 0x0000;
  static constexpr uint16_t kNumericNeg = 0x4000;
  static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256;
  // Enough digits for 2^127 (39) and 2^255 (77, rounded up to 78).
  static constexpr size_t max_decimal_digits_ =
      (T == NANOARROW_TYPE_DECIMAL128) ? 39 : 78;
  const int32_t precision_;
  const int32_t scale_;
};

template <enum ArrowTimeUnit TU>
class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
public:
Expand Down Expand Up @@ -1392,6 +1538,20 @@ static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
case NANOARROW_TYPE_DOUBLE:
*out = new PostgresCopyDoubleFieldWriter();
return NANOARROW_OK;
case NANOARROW_TYPE_DECIMAL128: {
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = new PostgresCopyNumericFieldWriter<
NANOARROW_TYPE_DECIMAL128>(precision, scale);
return NANOARROW_OK;
}
case NANOARROW_TYPE_DECIMAL256: {
const auto precision = schema_view.decimal_precision;
const auto scale = schema_view.decimal_scale;
*out = new PostgresCopyNumericFieldWriter<
NANOARROW_TYPE_DECIMAL256>(precision, scale);
return NANOARROW_OK;
}
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
Expand Down
66 changes: 66 additions & 0 deletions c/driver/postgresql/postgres_copy_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,72 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) {
EXPECT_EQ(std::string(item.data, item.size_bytes), "inf");
}

// Expected COPY binary output for the decimal writer test below.
//
// This buffer is similar to the read variant above but removes special values
// nan, ±inf as they are not supported via the Arrow Decimal types
// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (NULL), (-123.456),
// ('0.00001234'), (1.0000), (123.456), (1000000)) AS drvd(col))
// TO STDOUT WITH (FORMAT binary);
//
// Layout: 19-byte header ("PGCOPY\n\xff\r\n\0", int32 flags, int32 extension
// length), then one tuple per row (int16 field count, int32 field length or
// -1 for NULL, then the NUMERIC payload: int16 ndigits, weight, sign, dscale
// followed by ndigits base-10000 digits), and finally the 0xffff trailer.
//
// Decoded rows (ndigits / weight / sign / dscale / digits):
//   NULL
//   2 /  0 / 0x4000 / 3 / 123, 4560    (-123.456)
//   1 / -2 / 0x0000 / 8 / 1234         (0.00001234)
//   1 /  0 / 0x0000 / 0 / 1            (1.0000)
//   2 /  0 / 0x0000 / 3 / 123, 4560    (123.456)
//   1 /  1 / 0x0000 / 0 / 100          (1000000)
//
// NOTE(review): the 1.0000 row carries dscale 0 rather than 4, so it was
// presumably adjusted to match the writer, which does not count trailing
// fractional zeros in dscale -- confirm against real server output.
static uint8_t kTestPgCopyNumericWrite[] = {
0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0x00,
0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, 0x40, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11,
0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0xff, 0xfe, 0x00, 0x00, 0x00,
0x08, 0x04, 0xd2, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00,
0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0a, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff};

// Round-trips a decimal128 column (one NULL followed by five values) through
// the COPY writer and compares the produced bytes with kTestPgCopyNumericWrite.
TEST(PostgresCopyUtilsTest, PostgresCopyWriteNumeric) {
  adbc_validation::Handle<struct ArrowSchema> schema;
  adbc_validation::Handle<struct ArrowArray> array;
  struct ArrowError na_error;
  constexpr enum ArrowType type = NANOARROW_TYPE_DECIMAL128;
  constexpr int32_t size = 128;
  constexpr int32_t precision = 38;
  constexpr int32_t scale = 8;

  // Unscaled integers for -123.456, 0.00001234, 1.0000, 123.456 and 1000000
  // at a scale of 8.
  const int64_t unscaled_values[] = {-12345600000, 1234, 100000000, 12345600000,
                                     100000000000000};
  constexpr size_t kNumDecimals = sizeof(unscaled_values) / sizeof(unscaled_values[0]);

  struct ArrowDecimal decimals[kNumDecimals];
  for (size_t i = 0; i < kNumDecimals; i++) {
    ArrowDecimalInit(&decimals[i], size, 19, 8);
    ArrowDecimalSetInt(&decimals[i], unscaled_values[i]);
  }

  const std::vector<std::optional<ArrowDecimal*>> values = {
      std::nullopt,  &decimals[0], &decimals[1],
      &decimals[2],  &decimals[3], &decimals[4]};

  ArrowSchemaInit(&schema.value);
  ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), 0);
  ASSERT_EQ(
      AdbcNsArrowSchemaSetTypeDecimal(schema.value.children[0], type, precision, scale),
      0);
  ASSERT_EQ(ArrowSchemaSetName(schema.value.children[0], "col"), 0);
  ASSERT_EQ(adbc_validation::MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
                                                      &na_error, values),
            ADBC_STATUS_OK);

  PostgresCopyStreamWriteTester tester;
  ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
  ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

  const struct ArrowBuffer buf = tester.WriteBuffer();
  // The writer does not emit the 2-byte end-of-data trailer (it can be sent
  // separately via PQputCopyData), so leave it out of the comparison.
  constexpr size_t buf_size = sizeof(kTestPgCopyNumericWrite) - 2;
  ASSERT_EQ(buf.size_bytes, buf_size);
  for (size_t pos = 0; pos < buf_size; ++pos) {
    ASSERT_EQ(buf.data[pos], kTestPgCopyNumericWrite[pos]) << " at position " << pos;
  }
}

// COPY (SELECT CAST(col AS TIMESTAMP) FROM ( VALUES ('1900-01-01 12:34:56'),
// ('2100-01-01 12:34:56'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT BINARY);
static uint8_t kTestPgCopyTimestamp[] = {
Expand Down
Loading

0 comments on commit 2116cff

Please sign in to comment.