Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(c/driver/postgresql): Support for writing DECIMAL types #1288

Merged
merged 35 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a24b046
Initial hacks
WillAyd Nov 3, 2023
dbddd8b
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Nov 13, 2023
c5dfd05
feat(c/driver/postgresql): Support for writing DECIMAL128
WillAyd Nov 13, 2023
a091bcd
removed TODO
WillAyd Nov 13, 2023
61eb3cc
trailing decimals
WillAyd Nov 14, 2023
a76ab65
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Nov 26, 2023
078de30
more decimal hacks
WillAyd Nov 28, 2023
bf8ed7b
working for positive decimal values
WillAyd Nov 29, 2023
75cbd58
Merge branch 'main' into copy-decimal
WillAyd Nov 29, 2023
94bf657
negative value support
WillAyd Nov 29, 2023
4b49999
skip other drivers
WillAyd Nov 29, 2023
c046632
No std::string_view
WillAyd Nov 29, 2023
3957b6d
cleanups
WillAyd Nov 29, 2023
c5d19bb
more generic ToString
WillAyd Nov 29, 2023
ba44774
don't hardcode precision and scale
WillAyd Nov 29, 2023
06b6349
Decimal256 Support
WillAyd Nov 30, 2023
bc19709
remove dead code
WillAyd Nov 30, 2023
10e6e09
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Dec 14, 2023
e9967a7
less string
WillAyd Dec 14, 2023
6a0d3c9
Allocate up front
WillAyd Dec 16, 2023
df7ba3e
compiling with lifecycle issues
WillAyd Dec 18, 2023
ac733bf
lifecycle workarounds
WillAyd Dec 18, 2023
0cf303f
Try parametrized postgres-test suite
WillAyd Dec 18, 2023
59cdb22
fix test precision / scale arguments
WillAyd Dec 18, 2023
759b0f1
add nullability testing
WillAyd Dec 18, 2023
9472ff5
decimal256 test cases (but failing)
WillAyd Dec 18, 2023
0eba157
passing DECIMAL256 tests
WillAyd Dec 18, 2023
97c2d5c
lint
WillAyd Dec 18, 2023
443efed
endian agnosticism
WillAyd Dec 18, 2023
5a93f9e
fixups
WillAyd Dec 18, 2023
b629aca
msvc compat?
WillAyd Dec 22, 2023
f5100d0
fix COPY test
WillAyd Dec 22, 2023
7e7351d
Simple benchmark
WillAyd Dec 22, 2023
dc1b735
return int instead of void
WillAyd Dec 22, 2023
cc252cd
Merge remote-tracking branch 'upstream/main' into copy-decimal
WillAyd Jan 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,77 @@ class PostgresCopyIntervalFieldWriter : public PostgresCopyFieldWriter {
}
};


// Inspiration for this taken from get_str_from_var in the pg source
// src/backend/utils/adt/numeric.c
class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter {
public:
ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
struct ArrowDecimal decimal;
// TODO: these need to be inferred from the schema not hard coded
constexpr int16_t precision = 19;
constexpr int16_t scale = 8;
ArrowDecimalInit(&decimal, 128, precision, scale);
ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal);
constexpr uint16_t kNumericPos = 0x0000;
constexpr uint16_t kNumericNeg = 0x4000;
constexpr int64_t kNBase = 10000;
// Number of decimal digits per Postgres digit
constexpr int kDecDigits = 4;

// TODO: need some kind of bounds check on this
int64_t decimal_int = ArrowDecimalGetIntUnsafe(&decimal);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is ultimately correct. I think ideally we would just use the bytes backing the Decimal object, but I haven't yet figured out how that all gets managed when multiple words are required

// TODO: is -INT64_MIN possible? If so how do we handle?
if (decimal_int < 0) {
decimal_int = -decimal_int;
}
std::vector<int16_t> pg_digits;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably too small of an optimization to matter, but in principle you should be able to put an upper bound on the number of digits needed to represent an Arrow decimal, and then just stack-allocate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea that's true and actually how postgres does it internally.

https://github.com/postgres/postgres/blob/8680bae8463a0b213893ca6a1c5bb2c2530e823c/src/backend/utils/adt/numeric.c#L8026

If we wanted to stack allocate I guess would just expand that out to whatever is required to store up to 4 decimal words?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, decimal128 would be 38 digits and decimal256 would be 76

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK cool. Shouldn't be too hard to switch to that - just need to figure out how to handle once I get multi-word decimals supported


int16_t weight = -1;
constexpr size_t nloops = (precision + kDecDigits - 1) / kDecDigits;
for (size_t i = 0; i < nloops; i++) {
const int64_t rem = decimal_int % kNBase;
// TODO: postgres seems to pack records to the left of a decimal place
// internally, so 1000000.0 would be sent as one digit of 100 with
// a weight of 1 (there are weight + 1 pg digits to the left of a decimal place)
// Here we still send two digits of 100 and 0000
pg_digits.insert(pg_digits.begin(), rem);

// TODO: how does pg deal with words when integer and decimal part are sent
// in same word?
decimal_int /= kNBase;
if (i >= scale / kDecDigits) {
weight++;
if (decimal_int == 0) {
break;
}
}
}

int16_t ndigits = pg_digits.size();
const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg;
const int16_t dscale = scale;

int32_t field_size_bytes = sizeof(ndigits)
+ sizeof(weight)
+ sizeof(sign)
+ sizeof(dscale)
+ ndigits * sizeof(int16_t);

NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, ndigits, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, weight, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, sign, error));
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, dscale, error));

for (auto pg_digit : pg_digits) {
NANOARROW_RETURN_NOT_OK(WriteChecked<int16_t>(buffer, pg_digit, error));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

presumably you could check once then memcpy the digits over

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(again, possibly the compiler already does this)

}

return ADBC_STATUS_OK;
}
};

template <enum ArrowTimeUnit TU>
class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter {
public:
Expand Down Expand Up @@ -1379,6 +1450,9 @@ static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
case NANOARROW_TYPE_DOUBLE:
*out = new PostgresCopyDoubleFieldWriter();
return NANOARROW_OK;
case NANOARROW_TYPE_DECIMAL128:
*out = new PostgresCopyNumericFieldWriter();
return NANOARROW_OK;
case NANOARROW_TYPE_BINARY:
case NANOARROW_TYPE_STRING:
case NANOARROW_TYPE_LARGE_STRING:
Expand Down
61 changes: 61 additions & 0 deletions c/driver/postgresql/postgres_copy_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,67 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) {
EXPECT_EQ(std::string(item.data, item.size_bytes), "inf");
}

// This buffer is similar to the read variant above but removes special values
// nan, ±inf as they are not supported via the Arrow Decimal types
// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (NULL), (-123.456),
// ('0.00001234'), ('1.0000'), (123.456), (1000000)) AS drvd(col))
// TO STDOUT WITH (FORMAT binary);
static uint8_t kTestPgCopyNumericWrite[] = {
0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0x00,
0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, 0x40, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11,
0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0, 0x00, 0x01, 0xff, 0xfe, 0x00, 0x00, 0x00,
0x08, 0x04, 0xd2, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00,
0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00,
0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00,
0x0a, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff};

TEST(PostgresCopyUtilsTest, PostgresCopyWriteNumeric) {
adbc_validation::Handle<struct ArrowSchema> schema;
adbc_validation::Handle<struct ArrowArray> array;
struct ArrowError na_error;
constexpr int32_t size = 128;
constexpr enum ArrowType type = NANOARROW_TYPE_DECIMAL128;

struct ArrowDecimal decimal1;
struct ArrowDecimal decimal2;
struct ArrowDecimal decimal3;
struct ArrowDecimal decimal4;
struct ArrowDecimal decimal5;

ArrowDecimalInit(&decimal1, size, 19, 8);
ArrowDecimalSetInt(&decimal1, -12345600000);
ArrowDecimalInit(&decimal2, size, 19, 8);
ArrowDecimalSetInt(&decimal2, 1234);
ArrowDecimalInit(&decimal3, size, 19, 8);
ArrowDecimalSetInt(&decimal3, 100000000);
ArrowDecimalInit(&decimal4, size, 19, 8);
ArrowDecimalSetInt(&decimal4, 12345600000);
ArrowDecimalInit(&decimal5, size, 19, 8);
ArrowDecimalSetInt(&decimal5, 100000000000000);

const std::vector<std::optional<ArrowDecimal*>> values = {
std::nullopt, &decimal1, &decimal2, &decimal3, &decimal4, &decimal5};

ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", type}}),
ADBC_STATUS_OK);
ASSERT_EQ(adbc_validation::MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
&na_error, values), ADBC_STATUS_OK);

PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

const struct ArrowBuffer buf = tester.WriteBuffer();
// The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
constexpr size_t buf_size = sizeof(kTestPgCopyNumericWrite) - 2;
ASSERT_EQ(buf.size_bytes, buf_size);
for (size_t i = 0; i < buf_size; i++) {
ASSERT_EQ(buf.data[i], kTestPgCopyNumericWrite[i]);
}
}

// COPY (SELECT CAST(col AS TIMESTAMP) FROM ( VALUES ('1900-01-01 12:34:56'),
// ('2100-01-01 12:34:56'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT BINARY);
static uint8_t kTestPgCopyTimestamp[] = {
Expand Down
2 changes: 2 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO;
case NANOARROW_TYPE_LARGE_STRING:
return NANOARROW_TYPE_STRING;
case NANOARROW_TYPE_DECIMAL128:
return NANOARROW_TYPE_STRING;
default:
return ingest_type;
}
Expand Down
7 changes: 7 additions & 0 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ struct BindStream {
type_id = PostgresTypeId::kInterval;
param_lengths[i] = 16;
break;
case ArrowType::NANOARROW_TYPE_DECIMAL128:
type_id = PostgresTypeId::kNumeric;
param_lengths[i] = 0;
break;
case ArrowType::NANOARROW_TYPE_DICTIONARY: {
struct ArrowSchemaView value_view;
CHECK_NA(INTERNAL,
Expand Down Expand Up @@ -1056,6 +1060,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
create += " INTERVAL";
break;
case ArrowType::NANOARROW_TYPE_DECIMAL128:
create += " DECIMAL";
break;
case ArrowType::NANOARROW_TYPE_DICTIONARY: {
struct ArrowSchemaView value_view;
CHECK_NA(INTERNAL,
Expand Down
100 changes: 100 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,106 @@ void StatementTest::TestSqlIngestFloat64() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestNumericType<double>(NANOARROW_TYPE_DOUBLE));
}

// For full coverage, ensure that this contains Decimal examples that:
// - Have >= four zeroes to the left of the decimal point
// - Have >= four zeroes to the right of the decimal point
// - Have >= four trailing zeroes to the right of the decimal point
// - Have >= four leading zeroes before the first digit to the right of the decimal point
// - Is < 0 (negative)
// - The arrow Decimal implementations do not support special values nan, ±inf
void StatementTest::TestSqlIngestDecimal128() {
if (!quirks()->supports_bulk_ingest(ADBC_INGEST_OPTION_MODE_CREATE)) {
GTEST_SKIP();
}

ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
IsOkStatus(&error));

Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
constexpr int32_t size = 128;
constexpr enum ArrowType type = NANOARROW_TYPE_DECIMAL128;

struct ArrowDecimal decimal1;
struct ArrowDecimal decimal2;
struct ArrowDecimal decimal3;
struct ArrowDecimal decimal4;
struct ArrowDecimal decimal5;

ArrowDecimalInit(&decimal1, size, 19, 8);
ArrowDecimalSetInt(&decimal1, -12345600000);
ArrowDecimalInit(&decimal2, size, 19, 8);
ArrowDecimalSetInt(&decimal2, 1234);
ArrowDecimalInit(&decimal3, size, 19, 8);
ArrowDecimalSetInt(&decimal3, 100000000);
ArrowDecimalInit(&decimal4, size, 19, 8);
ArrowDecimalSetInt(&decimal4, 12345600000);
ArrowDecimalInit(&decimal5, size, 19, 8);
ArrowDecimalSetInt(&decimal5, 100000000000000);

const std::vector<std::optional<ArrowDecimal*>> values = {
std::nullopt, &decimal1, &decimal2, &decimal3, &decimal4, &decimal5};

ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno());
ASSERT_THAT(MakeBatch<ArrowDecimal*>(&schema.value, &array.value,
&na_error, values), IsOkErrno());

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
"bulk_ingest", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
IsOkStatus(&error));

int64_t rows_affected = 0;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(rows_affected,
::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));

ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement,
"SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error),
IsOkStatus(&error));
{
StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type);
ASSERT_NO_FATAL_FAILURE(
CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_NE(nullptr, reader.array->release);
ASSERT_EQ(values.size(), reader.array->length);
ASSERT_EQ(1, reader.array->n_children);

// Currently postgres roundtrips to string, but in the future we should
// roundtrip to a decimal
//
//if (round_trip_type == type) {
// ASSERT_NO_FATAL_FAILURE(
// CompareArray<ArrowDecimal*>(reader.array_view->children[0], values));
//}

const std::vector<std::optional<std::string>> str_values = {
std::nullopt, "-123.45600000", "0.00001234", "1.00000000", "123.45600000",
"1000000.00000000"};
ASSERT_NO_FATAL_FAILURE(
CompareArray<std::string>(reader.array_view->children[0], str_values));

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(nullptr, reader.array->release);
}
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestSqlIngestString() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, false));
Expand Down
4 changes: 4 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ class StatementTest {
void TestSqlIngestFloat32();
void TestSqlIngestFloat64();

// Decmial
void TestSqlIngestDecimal128();

// Strings
void TestSqlIngestString();
void TestSqlIngestLargeString();
Expand Down Expand Up @@ -434,6 +437,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestUInt64) { TestSqlIngestUInt64(); } \
TEST_F(FIXTURE, SqlIngestFloat32) { TestSqlIngestFloat32(); } \
TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \
TEST_F(FIXTURE, SqlIngestDecimal128) { TestSqlIngestDecimal128(); } \
TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \
TEST_F(FIXTURE, SqlIngestLargeString) { TestSqlIngestLargeString(); } \
TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \
Expand Down
13 changes: 12 additions & 1 deletion c/validation/adbc_validation_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,18 @@ int MakeSchema(struct ArrowSchema* schema, const std::vector<SchemaField>& field
CHECK_ERRNO(ArrowSchemaSetTypeStruct(schema, fields.size()));
size_t i = 0;
for (const SchemaField& field : fields) {
CHECK_ERRNO(ArrowSchemaSetType(schema->children[i], field.type));
switch (field.type) {
case NANOARROW_TYPE_DECIMAL128:
// TODO: don't hardcore 19, 8
CHECK_ERRNO(AdbcNsArrowSchemaSetTypeDecimal(schema->children[i],
field.type,
19,
8));
break;
default:
CHECK_ERRNO(ArrowSchemaSetType(schema->children[i], field.type));
}

CHECK_ERRNO(ArrowSchemaSetName(schema->children[i], field.name.c_str()));
if (!field.nullable) {
schema->children[i]->flags &= ~ARROW_FLAG_NULLABLE;
Expand Down
15 changes: 15 additions & 0 deletions c/validation/adbc_validation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,10 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array,
if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) {
return errno_res;
}
} else if constexpr (std::is_same<T, ArrowDecimal*>::value) {
if (int errno_res = ArrowArrayAppendDecimal(array, *v); errno_res != 0) {
return errno_res;
}
} else {
static_assert(!sizeof(T), "Not yet implemented");
return ENOTSUP;
Expand Down Expand Up @@ -412,6 +416,17 @@ void CompareArray(struct ArrowArrayView* array,
ASSERT_EQ(interval.months, (*v)->months);
ASSERT_EQ(interval.days, (*v)->days);
ASSERT_EQ(interval.ns, (*v)->ns);
} else if constexpr (std::is_same<T, ArrowDecimal*>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
struct ArrowDecimal decimal;
// For now assuming Decimal128 so set as bitwidth
ArrowDecimalInit(&decimal, 128, (*v)->precision, (*v)->scale);
ArrowArrayViewGetDecimalUnsafe(array, i, &decimal);

ASSERT_EQ(decimal.n_words, (*v)->n_words);
// For now assuming Decimal128 so only need to check first two words
ASSERT_EQ(decimal.words[0], (*v)->words[0]);
ASSERT_EQ(decimal.words[1], (*v)->words[1]);
} else {
static_assert(!sizeof(T), "Not yet implemented");
}
Expand Down
Loading