From 2116cffffe264ccdbaecffeee9fd06c88f22bce5 Mon Sep 17 00:00:00 2001 From: William Ayd Date: Wed, 3 Jan 2024 08:28:08 -0500 Subject: [PATCH] feat(c/driver/postgresql): Support for writing DECIMAL types (#1288) --- c/driver/postgresql/postgres_copy_reader.h | 160 ++++++++++++ .../postgresql/postgres_copy_reader_test.cc | 66 +++++ c/driver/postgresql/postgresql_benchmark.cc | 152 +++++++++++ c/driver/postgresql/postgresql_test.cc | 238 ++++++++++++++++++ c/driver/postgresql/statement.cc | 9 + c/validation/adbc_validation_util.h | 4 + 6 files changed, 629 insertions(+) diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h index 686d54b81b..8a9192c329 100644 --- a/c/driver/postgresql/postgres_copy_reader.h +++ b/c/driver/postgresql/postgres_copy_reader.h @@ -1224,6 +1224,152 @@ class PostgresCopyIntervalFieldWriter : public PostgresCopyFieldWriter { } }; +// Inspiration for this taken from get_str_from_var in the pg source +// src/backend/utils/adt/numeric.c +template +class PostgresCopyNumericFieldWriter : public PostgresCopyFieldWriter { +public: + PostgresCopyNumericFieldWriter(int32_t precision, int32_t scale) : + precision_{precision}, scale_{scale} {} + + ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override { + struct ArrowDecimal decimal; + ArrowDecimalInit(&decimal, bitwidth_, precision_, scale_); + ArrowArrayViewGetDecimalUnsafe(array_view_, index, &decimal); + + const int16_t sign = ArrowDecimalSign(&decimal) > 0 ? kNumericPos : kNumericNeg; + + // Number of decimal digits per Postgres digit + constexpr int kDecDigits = 4; + std::vector pg_digits; + int16_t weight = -(scale_ / kDecDigits); + int16_t dscale = scale_; + bool seen_decimal = scale_ == 0; + bool truncating_trailing_zeros = true; + + char decimal_string[max_decimal_digits_ + 1]; + int digits_remaining = DecimalToString(&decimal, decimal_string); + do { + const int start_pos = digits_remaining < kDecDigits ? + 0 : digits_remaining - kDecDigits; + const size_t len = digits_remaining < 4 ? digits_remaining : kDecDigits; + char substr[kDecDigits + 1]; + std::memcpy(substr, decimal_string + start_pos, len); + substr[len] = '\0'; + int16_t val = static_cast(std::atoi(substr)); + + if (val == 0) { + if (!seen_decimal && truncating_trailing_zeros) { + dscale -= kDecDigits; + } + } else { + pg_digits.insert(pg_digits.begin(), val); + if (!seen_decimal && truncating_trailing_zeros) { + if (val % 1000 == 0) { + dscale -= 3; + } else if (val % 100 == 0) { + dscale -= 2; + } else if (val % 10 == 0) { + dscale -= 1; + } + } + truncating_trailing_zeros = false; + } + digits_remaining -= kDecDigits; + if (digits_remaining <= 0) { + break; + } + weight++; + + if (start_pos <= static_cast(std::strlen(decimal_string)) - scale_) { + seen_decimal = true; + } + } while (true); + + int16_t ndigits = pg_digits.size(); + int32_t field_size_bytes = sizeof(ndigits) + + sizeof(weight) + + sizeof(sign) + + sizeof(dscale) + + ndigits * sizeof(int16_t); + + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, field_size_bytes, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, ndigits, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, weight, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, sign, error)); + NANOARROW_RETURN_NOT_OK(WriteChecked(buffer, dscale, error)); + + const size_t pg_digit_bytes = sizeof(int16_t) * pg_digits.size(); + NANOARROW_RETURN_NOT_OK(ArrowBufferReserve(buffer, pg_digit_bytes)); + for (auto pg_digit : pg_digits) { + WriteUnsafe(buffer, pg_digit); + } + + return ADBC_STATUS_OK; + } + +private: + // returns the length of the string + template + int DecimalToString(struct ArrowDecimal* decimal, char* out) { + constexpr size_t nwords = (DEC_WIDTH == 128) ? 2 : 4; + uint8_t tmp[DEC_WIDTH / 8]; + ArrowDecimalGetBytes(decimal, tmp); + uint64_t buf[DEC_WIDTH / 64]; + std::memcpy(buf, tmp, sizeof(buf)); + const int16_t sign = ArrowDecimalSign(decimal) > 0 ? kNumericPos : kNumericNeg; + const bool is_negative = sign == kNumericNeg ? true : false; + if (is_negative) { + buf[0] = ~buf[0] + 1; + for (size_t i = 1; i < nwords; i++) { + buf[i] = ~buf[i]; + } + } + + // Basic approach adopted from https://stackoverflow.com/a/8023862/621736 + char s[max_decimal_digits_ + 1]; + std::memset(s, '0', sizeof(s) - 1); + s[sizeof(s) - 1] = '\0'; + + for (size_t i = 0; i < DEC_WIDTH; i++) { + int carry; + + carry = (buf[nwords - 1] >= 0x7FFFFFFFFFFFFFFF); + for (size_t j = nwords - 1; j > 0; j--) { + buf[j] = ((buf[j] << 1) & 0xFFFFFFFFFFFFFFFF) + (buf[j-1] >= 0x7FFFFFFFFFFFFFFF); + } + buf[0] = ((buf[0] << 1) & 0xFFFFFFFFFFFFFFFF); + + for (int j = sizeof(s) - 2; j>= 0; j--) { + s[j] += s[j] - '0' + carry; + carry = (s[j] > '9'); + if (carry) { + s[j] -= 10; + } + } + } + + char* p = s; + while ((p[0] == '0') && (p < &s[sizeof(s) - 2])) { + p++; + } + + const size_t ndigits = sizeof(s) - 1 - (p - s); + std::memcpy(out, p, ndigits); + out[ndigits] = '\0'; + + return ndigits; + } + + static constexpr uint16_t kNumericPos = 0x0000; + static constexpr uint16_t kNumericNeg = 0x4000; + static constexpr int32_t bitwidth_ = (T == NANOARROW_TYPE_DECIMAL128) ? 128 : 256; + static constexpr size_t max_decimal_digits_ = + (T == NANOARROW_TYPE_DECIMAL128) ? 39 : 78; + const int32_t precision_; + const int32_t scale_; +}; + template class PostgresCopyDurationFieldWriter : public PostgresCopyFieldWriter { public: @@ -1392,6 +1538,20 @@ static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema, case NANOARROW_TYPE_DOUBLE: *out = new PostgresCopyDoubleFieldWriter(); return NANOARROW_OK; + case NANOARROW_TYPE_DECIMAL128: { + const auto precision = schema_view.decimal_precision; + const auto scale = schema_view.decimal_scale; + *out = new PostgresCopyNumericFieldWriter< + NANOARROW_TYPE_DECIMAL128>(precision, scale); + return NANOARROW_OK; + } + case NANOARROW_TYPE_DECIMAL256: { + const auto precision = schema_view.decimal_precision; + const auto scale = schema_view.decimal_scale; + *out = new PostgresCopyNumericFieldWriter< + NANOARROW_TYPE_DECIMAL256>(precision, scale); + return NANOARROW_OK; + } case NANOARROW_TYPE_BINARY: case NANOARROW_TYPE_STRING: case NANOARROW_TYPE_LARGE_STRING: diff --git a/c/driver/postgresql/postgres_copy_reader_test.cc b/c/driver/postgresql/postgres_copy_reader_test.cc index 7882d602af..201aa223a2 100644 --- a/c/driver/postgresql/postgres_copy_reader_test.cc +++ b/c/driver/postgresql/postgres_copy_reader_test.cc @@ -693,6 +693,72 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadNumeric) { EXPECT_EQ(std::string(item.data, item.size_bytes), "inf"); } +// This buffer is similar to the read variant above but removes special values +// nan, ±inf as they are not supported via the Arrow Decimal types +// COPY (SELECT CAST(col AS NUMERIC) AS col FROM ( VALUES (NULL), (-123.456), +// ('0.00001234'), (1.0000), (123.456), (1000000)) AS drvd(col)) +// TO STDOUT WITH (FORMAT binary); +static uint8_t kTestPgCopyNumericWrite[] = { + 0x50, 0x47, 0x43, 0x4f, 0x50, 0x59, 0x0a, 0xff, 0x0d, 0x0a, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0xff, 0xff, 0xff, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, 0x00, 0x40, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, + 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0xff, 0xfe, 0x00, 0x00, 0x00, + 0x08, 0x04, 0xd2, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x02, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x7b, 0x11, 0xd0, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0xff, 0xff}; + +TEST(PostgresCopyUtilsTest, PostgresCopyWriteNumeric) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + constexpr enum ArrowType type = NANOARROW_TYPE_DECIMAL128; + constexpr int32_t size = 128; + constexpr int32_t precision = 38; + constexpr int32_t scale = 8; + + struct ArrowDecimal decimal1; + struct ArrowDecimal decimal2; + struct ArrowDecimal decimal3; + struct ArrowDecimal decimal4; + struct ArrowDecimal decimal5; + + ArrowDecimalInit(&decimal1, size, 19, 8); + ArrowDecimalSetInt(&decimal1, -12345600000); + ArrowDecimalInit(&decimal2, size, 19, 8); + ArrowDecimalSetInt(&decimal2, 1234); + ArrowDecimalInit(&decimal3, size, 19, 8); + ArrowDecimalSetInt(&decimal3, 100000000); + ArrowDecimalInit(&decimal4, size, 19, 8); + ArrowDecimalSetInt(&decimal4, 12345600000); + ArrowDecimalInit(&decimal5, size, 19, 8); + ArrowDecimalSetInt(&decimal5, 100000000000000); + + const std::vector> values = { + std::nullopt, &decimal1, &decimal2, &decimal3, &decimal4, &decimal5}; + + ArrowSchemaInit(&schema.value); + ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), 0); + ASSERT_EQ(AdbcNsArrowSchemaSetTypeDecimal(schema.value.children[0], + type, precision, scale), 0); + ASSERT_EQ(ArrowSchemaSetName(schema.value.children[0], "col"), 0); + ASSERT_EQ(adbc_validation::MakeBatch(&schema.value, &array.value, + &na_error, values), ADBC_STATUS_OK); + + PostgresCopyStreamWriteTester tester; + ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK); + ASSERT_EQ(tester.WriteAll(nullptr), ENODATA); + + const struct ArrowBuffer buf = tester.WriteBuffer(); + // The last 2 bytes of a message can be transmitted via PQputCopyData + // so no need to test those bytes from the Writer + constexpr size_t buf_size = sizeof(kTestPgCopyNumericWrite) - 2; + ASSERT_EQ(buf.size_bytes, buf_size); + for (size_t i = 0; i < buf_size; i++) { + ASSERT_EQ(buf.data[i], kTestPgCopyNumericWrite[i]) << " at position " << i; + } +} + // COPY (SELECT CAST(col AS TIMESTAMP) FROM ( VALUES ('1900-01-01 12:34:56'), // ('2100-01-01 12:34:56'), (NULL)) AS drvd("col")) TO STDOUT WITH (FORMAT BINARY); static uint8_t kTestPgCopyTimestamp[] = { diff --git a/c/driver/postgresql/postgresql_benchmark.cc b/c/driver/postgresql/postgresql_benchmark.cc index 239575a7d7..908269966e 100644 --- a/c/driver/postgresql/postgresql_benchmark.cc +++ b/c/driver/postgresql/postgresql_benchmark.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include #include #include @@ -175,5 +176,156 @@ static void BM_PostgresqlExecute(benchmark::State& state) { &error)); } +static void BM_PostgresqlDecimalWrite(benchmark::State& state) { + const char* uri = std::getenv("ADBC_POSTGRESQL_TEST_URI"); + if (!uri || !strcmp(uri, "")) { + state.SkipWithError("ADBC_POSTGRESQL_TEST_URI not set!"); + return; + } + adbc_validation::Handle database; + struct AdbcError error; + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcDatabaseNew(&database.value, &error)); + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcDatabaseSetOption(&database.value, + "uri", + uri, + &error)); + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcDatabaseInit(&database.value, &error)); + + adbc_validation::Handle connection; + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcConnectionNew(&connection.value, &error)); + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcConnectionInit(&connection.value, + &database.value, + &error)); + + adbc_validation::Handle statement; + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementNew(&connection.value, + &statement.value, + &error)); + + const char* drop_query = "DROP TABLE IF EXISTS adbc_postgresql_ingest_benchmark"; + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementSetSqlQuery(&statement.value, + drop_query, + &error)); + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementExecuteQuery(&statement.value, + nullptr, + nullptr, + &error)); + + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + constexpr enum ArrowType type = NANOARROW_TYPE_DECIMAL128; + constexpr int32_t bitwidth = 128; + constexpr int32_t precision = 38; + constexpr int32_t scale = 8; + constexpr size_t ncols = 5; + ArrowSchemaInit(&schema.value); + if (ArrowSchemaSetTypeStruct(&schema.value, ncols) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowSchemaSetTypeStruct failed!"); + error.release(&error); + return; + } + + for (size_t i = 0; i < ncols; i++) { + if (AdbcNsArrowSchemaSetTypeDecimal(schema.value.children[i], + type, precision, scale) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowSchemaSetTypeDecimal failed!"); + error.release(&error); + return; + } + + std::string colname = "col" + std::to_string(i); + if (ArrowSchemaSetName(schema.value.children[i], colname.c_str()) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowSchemaSetName failed!"); + error.release(&error); + return; + } + } + if (ArrowArrayInitFromSchema(&array.value, &schema.value, &na_error) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowArrayInitFromSchema failed!"); + error.release(&error); + return; + } + + if (ArrowArrayStartAppending(&array.value) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowArrayStartAppending failed!"); + error.release(&error); + return; + } + + constexpr size_t nrows = 1000; + struct ArrowDecimal decimal; + ArrowDecimalInit(&decimal, bitwidth, precision, scale); + for (size_t i = 0; i < nrows; i++) { + for (size_t j = 0; j < ncols; j++) { + ArrowDecimalSetInt(&decimal, i + j); + if (ArrowArrayAppendDecimal(array.value.children[j], &decimal) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowArrayAppendDecimal failed"); + error.release(&error); + return; + } + } + } + + for (int64_t i = 0; i < array.value.n_children; i++) { + array.value.children[i]->length = nrows; + } + array.value.length = nrows; + + if (ArrowArrayFinishBuildingDefault(&array.value, &na_error) != NANOARROW_OK) { + state.SkipWithError("Call to ArrowArrayFinishBuildingDefault failed"); + error.release(&error); + return; + } + + const char* create_query = + "CREATE TABLE adbc_postgresql_ingest_benchmark (col0 DECIMAL(38, 8), " + "col1 DECIMAL(38, 8), col2 DECIMAL(38, 8), col3 DECIMAL(38, 8), col4 DECIMAL(38, 8))"; + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementSetSqlQuery(&statement.value, + create_query, + &error)); + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementExecuteQuery(&statement.value, + nullptr, + nullptr, + &error)); + + adbc_validation::Handle insert_stmt; + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementNew(&connection.value, + &insert_stmt.value, + &error)); + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementSetOption(&insert_stmt.value, + ADBC_INGEST_OPTION_TARGET_TABLE, + "adbc_postgresql_ingest_benchmark", + &error)); + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementSetOption(&insert_stmt.value, + ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, + &error)); + + for (auto _ : state) { + AdbcStatementBind(&insert_stmt.value, &array.value, &schema.value, &error); + AdbcStatementExecuteQuery(&insert_stmt.value, nullptr, nullptr, &error); + } + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementSetSqlQuery(&statement.value, + drop_query, + &error)); + + ADBC_BENCHMARK_RETURN_NOT_OK(AdbcStatementExecuteQuery(&statement.value, + nullptr, + nullptr, + &error)); +} + +// TODO: we are limited to only 1 iteration as AdbcStatementBind is part of +// the benchmark loop, but releases the array when it is done BENCHMARK(BM_PostgresqlExecute)->Iterations(1); +BENCHMARK(BM_PostgresqlDecimalWrite)->Iterations(1); BENCHMARK_MAIN(); diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index 2327767a4e..8ac841d8a0 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include #include #include #include @@ -119,6 +120,9 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { return NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO; case NANOARROW_TYPE_LARGE_STRING: return NANOARROW_TYPE_STRING; + case NANOARROW_TYPE_DECIMAL128: + case NANOARROW_TYPE_DECIMAL256: + return NANOARROW_TYPE_STRING; default: return ingest_type; } @@ -1645,3 +1649,237 @@ INSTANTIATE_TEST_SUITE_P(TimeTypes, PostgresTypeTest, testing::ValuesIn(kTimeTyp INSTANTIATE_TEST_SUITE_P(TimestampTypes, PostgresTypeTest, testing::ValuesIn(kTimestampTypeCases), TypeTestCase::FormatName); + +struct DecimalTestCase { + const enum ArrowType type; + const int32_t precision; + const int32_t scale; + const std::vector> data; + const std::vector> expected; +}; + +class PostgresDecimalTest : public ::testing::TestWithParam { +public: + void SetUp() override { + ASSERT_THAT(AdbcDatabaseNew(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(quirks_.SetupDatabase(&database_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcDatabaseInit(&database_, &error_), IsOkStatus(&error_)); + + ASSERT_THAT(AdbcConnectionNew(&connection_, &error_), IsOkStatus(&error_)); + ASSERT_THAT(AdbcConnectionInit(&connection_, &database_, &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(AdbcStatementNew(&connection_, &statement_, &error_), + IsOkStatus(&error_)); + + ASSERT_THAT(quirks_.DropTable(&connection_, "bulk_ingest", &error_), + IsOkStatus(&error_)); + } + + void TearDown() override { + if (statement_.private_data) { + ASSERT_THAT(AdbcStatementRelease(&statement_, &error_), IsOkStatus(&error_)); + } + if (connection_.private_data) { + ASSERT_THAT(AdbcConnectionRelease(&connection_, &error_), IsOkStatus(&error_)); + } + if (database_.private_data) { + ASSERT_THAT(AdbcDatabaseRelease(&database_, &error_), IsOkStatus(&error_)); + } + + if (error_.release) error_.release(&error_); + } + +protected: + PostgresQuirks quirks_; + struct AdbcError error_ = {}; + struct AdbcDatabase database_ = {}; + struct AdbcConnection connection_ = {}; + struct AdbcStatement statement_ = {}; +}; + +TEST_P(PostgresDecimalTest, SelectValue) { + adbc_validation::Handle schema; + adbc_validation::Handle array; + struct ArrowError na_error; + + const enum ArrowType type = GetParam().type; + const int32_t precision = GetParam().precision; + const int32_t scale = GetParam().scale; + const auto data = GetParam().data; + const auto expected = GetParam().expected; + const size_t nrecords = expected.size(); + + int32_t bitwidth; + switch (type) { + case NANOARROW_TYPE_DECIMAL128: + bitwidth = 128; + break; + case NANOARROW_TYPE_DECIMAL256: + bitwidth = 256; + break; + default: + FAIL(); + } + + // this is a bit of a hack to make std::vector play nicely with + // a dynamic number of stack-allocated ArrowDecimal objects + constexpr size_t max_decimals = 10; + struct ArrowDecimal decimals[max_decimals]; + if (nrecords > max_decimals) { + FAIL() << + " max_decimals exceeded for test case - please change parametrization"; + } + + std::vector> values; + for (size_t i = 0; i < nrecords; i++) { + ArrowDecimalInit(&decimals[i], bitwidth, precision, scale); + uint8_t buf[32]; + const auto record = data[i]; + memcpy(buf, record.data(), sizeof(buf)); + ArrowDecimalSetBytes(&decimals[i], buf); + values.push_back(&decimals[i]); + } + + auto expected_with_null{expected}; + expected_with_null.insert(expected_with_null.begin(), std::nullopt); + values.push_back(std::nullopt); + + ArrowSchemaInit(&schema.value); + ASSERT_EQ(ArrowSchemaSetTypeStruct(&schema.value, 1), 0); + ASSERT_EQ(AdbcNsArrowSchemaSetTypeDecimal(schema.value.children[0], + type, precision, scale), 0); + ASSERT_EQ(ArrowSchemaSetName(schema.value.children[0], "col"), 0); + + ASSERT_THAT(adbc_validation::MakeBatch(&schema.value, &array.value, + &na_error, values), + adbc_validation::IsOkErrno()); + + ASSERT_THAT(AdbcStatementSetOption(&statement_, + ADBC_INGEST_OPTION_TARGET_TABLE, + "bulk_ingest", &error_), + IsOkStatus(&error_)); + ASSERT_THAT(AdbcStatementBind(&statement_, &array.value, &schema.value, &error_), + IsOkStatus(&error_)); + + int64_t rows_affected = 0; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement_, nullptr, &rows_affected, &error_), + IsOkStatus(&error_)); + ASSERT_THAT(rows_affected, + ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); + + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement_, + "SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error_), + IsOkStatus(&error_)); + + { + adbc_validation::StreamReader reader; + ASSERT_THAT(AdbcStatementExecuteQuery(&statement_, &reader.stream.value, + &reader.rows_affected, &error_), + IsOkStatus(&error_)); + ASSERT_THAT(reader.rows_affected, + ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1))); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ArrowType round_trip_type = quirks_.IngestSelectRoundTripType(type); + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareSchema(&reader.schema.value, + {{"col", + round_trip_type, true}})); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(values.size(), reader.array->length); + ASSERT_EQ(1, reader.array->n_children); + + ASSERT_NO_FATAL_FAILURE(adbc_validation::CompareArray< + std::string>(reader.array_view->children[0], + expected_with_null)); + + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_EQ(nullptr, reader.array->release); + } +} + +static std::vector> kDecimalData = { + // -12345600000 + {18446744061363951616ULL, 18446744073709551615ULL, 0, 0}, + // 1234 + {1234ULL, 0, 0, 0}, + // 100000000 + {100000000ULL, 0, 0, 0}, + // 12345600000 + {12345600000ULL, 0, 0, 0}, + // 100000000000000 + {100000000000000ULL, 0, 0, 0}, + // 2342394230592232349023094 + {8221368519775271798ULL, 126981ULL, 0, 0}, +}; + +static std::vector> kDecimal256Data = { + // 1234567890123456789012345678901234567890123456789012345678901234567890123456 + {17877984925544397504ULL, 5352188884907840935ULL, 234631617561833724ULL, + 196678011949953713ULL}, + // -1234567890123456789012345678901234567890123456789012345678901234567890123456 + {568759148165154112ULL, 13094555188801710680ULL, 18212112456147717891ULL, + 18250066061759597902ULL}, +}; + +static std::initializer_list kDecimal128Cases = { + { + NANOARROW_TYPE_DECIMAL128, 38, 8, kDecimalData, + {"-123.456", "0.00001234", "1", "123.456", "1000000", + "23423942305922323.49023094"} + }}; + +static std::initializer_list kDecimal128NoScaleCases = { + { + NANOARROW_TYPE_DECIMAL128, 38, 0, kDecimalData, + {"-12345600000", "1234", "100000000", "12345600000", "100000000000000", + "2342394230592232349023094"} + }}; + +static std::initializer_list kDecimal256Cases = { + { + NANOARROW_TYPE_DECIMAL256, 38, 8, kDecimalData, + {"-123.456", "0.00001234", "1", "123.456", "1000000", + "23423942305922323.49023094"} + }}; + +static std::initializer_list kDecimal256NoScaleCases = { + { + NANOARROW_TYPE_DECIMAL256, 38, 0, kDecimalData, + {"-12345600000", "1234", "100000000", "12345600000", "100000000000000", + "2342394230592232349023094"} + }}; + +static std::initializer_list kDecimal256LargeCases = { + { + NANOARROW_TYPE_DECIMAL256, 76, 8, kDecimal256Data, + { + "-12345678901234567890123456789012345678901234567890123456789012345678.90123456", + "12345678901234567890123456789012345678901234567890123456789012345678.90123456", + } + }}; + +static std::initializer_list kDecimal256LargeNoScaleCases = { + { + NANOARROW_TYPE_DECIMAL256, 76, 0, kDecimal256Data, + { + "-1234567890123456789012345678901234567890123456789012345678901234567890123456", + "1234567890123456789012345678901234567890123456789012345678901234567890123456", + } + }}; + +INSTANTIATE_TEST_SUITE_P(Decimal128Tests, PostgresDecimalTest, + testing::ValuesIn(kDecimal128Cases)); +INSTANTIATE_TEST_SUITE_P(Decimal128NoScale, PostgresDecimalTest, + testing::ValuesIn(kDecimal128NoScaleCases)); +INSTANTIATE_TEST_SUITE_P(Decimal256Tests, PostgresDecimalTest, + testing::ValuesIn(kDecimal128Cases)); +INSTANTIATE_TEST_SUITE_P(Decimal256NoScale, PostgresDecimalTest, + testing::ValuesIn(kDecimal128NoScaleCases)); +INSTANTIATE_TEST_SUITE_P(Decimal256LargeTests, PostgresDecimalTest, + testing::ValuesIn(kDecimal256LargeCases)); +INSTANTIATE_TEST_SUITE_P(Decimal256LargeNoScale, PostgresDecimalTest, + testing::ValuesIn(kDecimal256LargeNoScaleCases)); diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index 811a086a4f..68fd45a944 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -214,6 +214,11 @@ struct BindStream { type_id = PostgresTypeId::kInterval; param_lengths[i] = 16; break; + case ArrowType::NANOARROW_TYPE_DECIMAL128: + case ArrowType::NANOARROW_TYPE_DECIMAL256: + type_id = PostgresTypeId::kNumeric; + param_lengths[i] = 0; + break; case ArrowType::NANOARROW_TYPE_DICTIONARY: { struct ArrowSchemaView value_view; CHECK_NA(INTERNAL, @@ -1062,6 +1067,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: create += " INTERVAL"; break; + case ArrowType::NANOARROW_TYPE_DECIMAL128: + case ArrowType::NANOARROW_TYPE_DECIMAL256: + create += " DECIMAL"; + break; case ArrowType::NANOARROW_TYPE_DICTIONARY: { struct ArrowSchemaView value_view; CHECK_NA(INTERNAL, diff --git a/c/validation/adbc_validation_util.h b/c/validation/adbc_validation_util.h index da71e0d9e8..321b10f9d8 100644 --- a/c/validation/adbc_validation_util.h +++ b/c/validation/adbc_validation_util.h @@ -283,6 +283,10 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array, if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) { return errno_res; } + } else if constexpr (std::is_same::value) { + if (int errno_res = ArrowArrayAppendDecimal(array, *v); errno_res != 0) { + return errno_res; + } } else { static_assert(!sizeof(T), "Not yet implemented"); return ENOTSUP;