Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(c/driver/postgresql): Duration support #907

Merged
merged 16 commits into from
Sep 7, 2023
35 changes: 35 additions & 0 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,32 @@ class PostgresCopyNetworkEndianFieldReader : public PostgresCopyFieldReader {
}
};

// Reader for Durations; Similar to PostgresCopyNetworkEndianFieldReader but
// discards an extra 64 bits per read (representing the postgres day / month)
class PostgresIntervalReader : public PostgresCopyFieldReader {
public:
ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
ArrowError* error) override {
if (field_size_bytes <= 0) {
return ArrowArrayAppendNull(array, 1);
}

if (field_size_bytes != static_cast<int32_t>(sizeof(int64_t) * 2)) {
ArrowErrorSet(error, "Expected field with %d bytes but found field with %d bytes",
static_cast<int>(sizeof(int64_t)),
static_cast<int>(field_size_bytes)); // NOLINT(runtime/int)
return EINVAL;
}

int64_t value = ReadUnsafe<int64_t>(data);
NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &value, sizeof(int64_t)));

// discard unnecessary bits
Copy link
Contributor Author

@WillAyd WillAyd Jul 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought probably want to raise if these are non-zero

ReadUnsafe<int64_t>(data);
return AppendValid(array);
}
};

// Converts COPY resulting from the Postgres NUMERIC type into a string.
// Rewritten based on the Postgres implementation of NUMERIC cast to string in
// src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source,
Expand Down Expand Up @@ -836,6 +862,15 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type,
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
case NANOARROW_TYPE_DURATION:
switch (pg_type.type_id()) {
case PostgresTypeId::kInterval: {
*out = new PostgresIntervalReader();
return NANOARROW_OK;
}
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
Expand Down
6 changes: 6 additions & 0 deletions c/driver/postgresql/postgres_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ class PostgresType {
NANOARROW_TIME_UNIT_MICRO, /*timezone=*/"UTC"));
break;

case PostgresTypeId::kInterval:
NANOARROW_RETURN_NOT_OK(
ArrowSchemaSetTypeDateTime(schema, NANOARROW_TYPE_DURATION,
NANOARROW_TIME_UNIT_MICRO, /*timezone=*/nullptr));
break;

// ---- Nested --------------------
case PostgresTypeId::kRecord:
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children()));
Expand Down
34 changes: 25 additions & 9 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ struct BindStream {
type_id = PostgresTypeId::kTimestamp;
param_lengths[i] = 8;
break;
case ArrowType::NANOARROW_TYPE_DURATION:
type_id = PostgresTypeId::kInterval;
param_lengths[i] = 16;
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
Expand Down Expand Up @@ -385,11 +389,10 @@ struct BindStream {
param_values[col] = const_cast<char*>(view.data.as_char);
break;
}
case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
case ArrowType::NANOARROW_TYPE_DURATION: {
int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row];

// 2000-01-01 00:00:00.000000 in microseconds
constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
constexpr int64_t kSecOverflowLimit = 9223372036854;
constexpr int64_t kmSecOverflowLimit = 9223372036854775;

Expand All @@ -399,8 +402,7 @@ struct BindStream {
if (abs(val) > kSecOverflowLimit) {
SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s",
col + 1, "('", bind_schema->children[col]->name, "') Row #",
row + 1,
"has value which exceeds postgres timestamp limits");
row + 1, "has value which exceeds postgres temporal limits");
return ADBC_STATUS_INVALID_ARGUMENT;
}
val *= 1000000;
Expand All @@ -409,8 +411,7 @@ struct BindStream {
if (abs(val) > kmSecOverflowLimit) {
SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s",
col + 1, "('", bind_schema->children[col]->name, "') Row #",
row + 1,
"has value which exceeds postgres timestamp limits");
row + 1, "has value which exceeds postgres temporal limits");
return ADBC_STATUS_INVALID_ARGUMENT;
}
val *= 1000;
Expand All @@ -422,8 +423,20 @@ struct BindStream {
break;
}

const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch);
std::memcpy(param_values[col], &value, sizeof(int64_t));
if (bind_schema_fields[col].type == ArrowType::NANOARROW_TYPE_TIMESTAMP) {
// 2000-01-01 00:00:00.000000 in microseconds
constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch);
std::memcpy(param_values[col], &value, sizeof(int64_t));
} else if (bind_schema_fields[col].type ==
ArrowType::NANOARROW_TYPE_DURATION) {
// postgres stores an interval as a 64 bit offset in microsecond
// resolution alongside a 32 bit day and 32 bit month
// for now we just send 0 for the day / month values
const uint64_t value = ToNetworkInt64(val);
std::memcpy(param_values[col], &value, sizeof(int64_t));
std::memset(param_values[col] + sizeof(int64_t), 0, sizeof(int64_t));
}
break;
}
default:
Expand Down Expand Up @@ -787,6 +800,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += " TIMESTAMP";
}
break;
case ArrowType::NANOARROW_TYPE_DURATION:
create += " INTERVAL";
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
Expand Down
65 changes: 42 additions & 23 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1095,8 +1095,9 @@ void StatementTest::TestSqlIngestBinary() {
NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "\xFE\xFF"}));
}

template <enum ArrowTimeUnit TU>
void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
template <enum ArrowType T>
void StatementTest::TestSqlIngestTemporalType(enum ArrowTimeUnit unit,
const char* timezone) {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
}
Expand All @@ -1108,13 +1109,12 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
Handle<struct ArrowArray> array;
struct ArrowError na_error;
const std::vector<std::optional<int64_t>> values = {std::nullopt, -42, 0, 42};
const ArrowType type = NANOARROW_TYPE_TIMESTAMP;

// much of this code is shared with TestSqlIngestType with minor
// changes to allow for various time units to be tested
ArrowSchemaInit(&schema.value);
ArrowSchemaSetTypeStruct(&schema.value, 1);
ArrowSchemaSetTypeDateTime(schema->children[0], type, TU, timezone);
ArrowSchemaSetTypeDateTime(schema->children[0], T, unit, timezone);
ArrowSchemaSetName(schema->children[0], "col");
ASSERT_THAT(MakeBatch<int64_t>(&schema.value, &array.value, &na_error, values),
IsOkErrno());
Expand Down Expand Up @@ -1146,7 +1146,7 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());

ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type);
ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(T);
ASSERT_NO_FATAL_FAILURE(
CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));

Expand All @@ -1155,7 +1155,7 @@ void StatementTest::TestSqlIngestTemporalType(const char* timezone) {
ASSERT_EQ(values.size(), reader.array->length);
ASSERT_EQ(1, reader.array->n_children);

ValidateIngestedTemporalData(reader.array_view->children[0], TU, timezone);
ValidateIngestedTemporalData(reader.array_view->children[0], unit, timezone);

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(nullptr, reader.array->release);
Expand All @@ -1170,27 +1170,46 @@ void StatementTest::ValidateIngestedTemporalData(struct ArrowArrayView* values,
FAIL() << "ValidateIngestedTemporalData is not implemented in the base class";
}

void StatementTest::TestSqlIngestDuration() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_DURATION>(
NANOARROW_TIME_UNIT_SECOND, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_DURATION>(
NANOARROW_TIME_UNIT_MICRO, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_DURATION>(
NANOARROW_TIME_UNIT_MILLI, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_DURATION>(
NANOARROW_TIME_UNIT_NANO, nullptr));
}

void StatementTest::TestSqlIngestTimestamp() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>(nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_SECOND, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_MICRO, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_MILLI, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_NANO, nullptr));
}

void StatementTest::TestSqlIngestTimestampTz() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("UTC"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("UTC"));

ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_SECOND, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_MICRO, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_MILLI, nullptr));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_NANO, nullptr));

ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_SECOND, "America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_MICRO, "America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_MILLI, "America/Los_Angeles"));
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_TIMESTAMP>(
NANOARROW_TIME_UNIT_NANO, "America/Los_Angeles"));
}

void StatementTest::TestSqlIngestAppend() {
Expand Down
6 changes: 4 additions & 2 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class StatementTest {
void TestSqlIngestBinary();

// Temporal
void TestSqlIngestDuration();
void TestSqlIngestTimestamp();
void TestSqlIngestTimestampTz();

Expand Down Expand Up @@ -274,8 +275,8 @@ class StatementTest {
template <typename CType>
void TestSqlIngestNumericType(ArrowType type);

template <enum ArrowTimeUnit TU>
void TestSqlIngestTemporalType(const char* timezone);
template <enum ArrowType T>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I tried to provide ArrowType and ArrowTimeUnit as template parameters, but that kept yielding errors like

/home/willayd/clones/arrow-adbc/c/validation/adbc_validation.cc:1173:114: error: macro "ASSERT_NO_FATAL_FAILURE" passed 2 arguments, but takes just 1
 1173 |   ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TYPE_DURATION, NANOARROW_TIME_UNIT_SECOND>(nullptr));
      |                                                                                                                  ^
In file included from /home/willayd/clones/arrow-adbc/c/validation/adbc_validation.h:26,
                 from /home/willayd/clones/arrow-adbc/c/validation/adbc_validation.cc:18:
/usr/local/include/gtest/gtest.h:2216: note: macro "ASSERT_NO_FATAL_FAILURE" defined here
 2216 | #define ASSERT_NO_FATAL_FAILURE(statement) \

I'm not sure if that is a bug with gtest or my lack of C++ template knowledge, but figured changing the template around like this wasn't a huge deal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh for future reference, I think you just need MACRO((Template<A, B>), ...)

void TestSqlIngestTemporalType(enum ArrowTimeUnit, const char* timezone);

virtual void ValidateIngestedTemporalData(struct ArrowArrayView* values,
enum ArrowTimeUnit unit,
Expand All @@ -299,6 +300,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \
TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \
TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \
TEST_F(FIXTURE, SqlIngestDuration) { TestSqlIngestDuration(); } \
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \
Expand Down