Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(c/driver/postgresql): Interval support #908

Merged
merged 11 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions c/driver/flightsql/sqlite_flightsql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class SqliteFlightSqlStatementTest : public ::testing::Test,
void TearDown() override { ASSERT_NO_FATAL_FAILURE(TearDownTest()); }

void TestSqlIngestTableEscaping() { GTEST_SKIP() << "Table escaping not implemented"; }
void TestSqlIngestInterval() {
GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
}

protected:
SqliteFlightSqlQuirks quirks_;
Expand Down
49 changes: 48 additions & 1 deletion c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include <algorithm>
#include <cerrno>
#include <cinttypes>
#include <cstdint>
#include <memory>
#include <string>
Expand Down Expand Up @@ -212,7 +213,43 @@ class PostgresCopyNetworkEndianFieldReader : public PostgresCopyFieldReader {
}
};

// Converts COPY resulting from the Postgres NUMERIC type into a string.
// Reader for Intervals
class PostgresCopyIntervalFieldReader : public PostgresCopyFieldReader {
public:
ArrowErrorCode Read(ArrowBufferView* data, int32_t field_size_bytes, ArrowArray* array,
ArrowError* error) override {
if (field_size_bytes <= 0) {
return ArrowArrayAppendNull(array, 1);
}

if (field_size_bytes != 16) {
ArrowErrorSet(error, "Expected field with %d bytes but found field with %d bytes",
16,
static_cast<int>(field_size_bytes)); // NOLINT(runtime/int)
return EINVAL;
}

// postgres stores time as usec, arrow stores as ns
const int64_t time_usec = ReadUnsafe<int64_t>(data);

if ((time_usec > INT64_MAX / 1000) | (time_usec < INT64_MIN / 1000)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if ((time_usec > INT64_MAX / 1000) | (time_usec < INT64_MIN / 1000)) {
if ((time_usec > INT64_MAX / 1000) || (time_usec < INT64_MIN / 1000)) {

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, Arrow vendors https://github.com/nemequ/portable-snippets/tree/master/safe-math for overflow-safe helpers (we can do that as a followup though)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is your idea to also vendor that in this repo? Or something we should do in nanoarrow?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, in this repo.

ArrowErrorSet(error, "[libpq] Interval with time value %" PRId64
" usec would overflow when converting to nanoseconds");
return EINVAL;
}

const int64_t time = time_usec * 1000;
const int32_t days = ReadUnsafe<int32_t>(data);
const int32_t months = ReadUnsafe<int32_t>(data);

NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &months, sizeof(int32_t)));
NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &days, sizeof(int32_t)));
NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(data_, &time, sizeof(int64_t)));
return AppendValid(array);
}
};

// // Converts COPY resulting from the Postgres NUMERIC type into a string.
// Rewritten based on the Postgres implementation of NUMERIC cast to string in
// src/backend/utils/adt/numeric.c : get_str_from_var() (Note that in the initial source,
// DEC_DIGITS is always 4 and DBASE is always 10000).
Expand Down Expand Up @@ -836,6 +873,16 @@ static inline ArrowErrorCode MakeCopyFieldReader(const PostgresType& pg_type,
default:
return ErrorCantConvert(error, pg_type, schema_view);
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
switch (pg_type.type_id()) {
case PostgresTypeId::kInterval: {
*out = new PostgresCopyIntervalFieldReader();
return NANOARROW_OK;
}
default:
return ErrorCantConvert(error, pg_type, schema_view);
}

default:
return ErrorCantConvert(error, pg_type, schema_view);
}
Expand Down
5 changes: 5 additions & 0 deletions c/driver/postgresql/postgres_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,11 @@ class PostgresType {
NANOARROW_TIME_UNIT_MICRO, /*timezone=*/"UTC"));
break;

case PostgresTypeId::kInterval:
NANOARROW_RETURN_NOT_OK(
ArrowSchemaSetType(schema, NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO));
break;

// ---- Nested --------------------
case PostgresTypeId::kRecord:
NANOARROW_RETURN_NOT_OK(ArrowSchemaSetTypeStruct(schema, n_children()));
Expand Down
24 changes: 24 additions & 0 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ struct BindStream {
type_id = PostgresTypeId::kTimestamp;
param_lengths[i] = 8;
break;
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
type_id = PostgresTypeId::kInterval;
param_lengths[i] = 16;
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
Expand Down Expand Up @@ -426,6 +430,23 @@ struct BindStream {
std::memcpy(param_values[col], &value, sizeof(int64_t));
break;
}
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
const auto buf =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that is missing from https://github.com/apache/arrow-nanoarrow/pull/258/files is a ArrowArrayViewGetIntervalUnsafe function, which would help both here and in the test

array_view->children[col]->buffer_views[1].data.as_uint8 + row * 16;
const int32_t raw_months = *(int32_t*)buf;
const int32_t raw_days = *(int32_t*)(buf + 4);
const int64_t raw_ns = *(int64_t*)(buf + 8);

const uint32_t months = ToNetworkInt32(raw_months);
const uint32_t days = ToNetworkInt32(raw_days);
const uint64_t ms = ToNetworkInt64(raw_ns / 1000);

std::memcpy(param_values[col], &ms, sizeof(uint64_t));
std::memcpy(param_values[col] + sizeof(uint64_t), &days, sizeof(uint32_t));
std::memcpy(param_values[col] + sizeof(uint64_t) + sizeof(uint32_t),
&months, sizeof(uint32_t));
break;
}
default:
SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('",
bind_schema->children[col]->name,
Expand Down Expand Up @@ -787,6 +808,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += " TIMESTAMP";
}
break;
case ArrowType::NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO:
create += " INTERVAL";
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
Expand Down
3 changes: 3 additions & 0 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ class SqliteStatementTest : public ::testing::Test,
}

void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
void TestSqlIngestInterval() {
GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
}

protected:
void ValidateIngestedTemporalData(struct ArrowArrayView* values,
Expand Down
3 changes: 3 additions & 0 deletions c/driver_manager/adbc_driver_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ class SqliteStatementTest : public ::testing::Test,
void TestSqlIngestTimestampTz() {
GTEST_SKIP() << "Cannot ingest TIMESTAMP WITH TIMEZONE (not implemented)";
}
void TestSqlIngestInterval() {
GTEST_SKIP() << "Cannot ingest Interval (not implemented)";
}

protected:
SqliteQuirks quirks_;
Expand Down
83 changes: 83 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,89 @@ void StatementTest::TestSqlIngestTimestampTz() {
TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>("America/Los_Angeles"));
}

void StatementTest::TestSqlIngestInterval() {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
}

ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
IsOkStatus(&error));

Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
const enum ArrowType type = NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO;
// values are days, months, ns
struct ArrowInterval neg_interval;
struct ArrowInterval zero_interval;
struct ArrowInterval pos_interval;

ArrowIntervalInit(&neg_interval, type);
ArrowIntervalInit(&zero_interval, type);
ArrowIntervalInit(&pos_interval, type);

neg_interval.months = -5;
neg_interval.days = -5;
neg_interval.ns = -42000;

pos_interval.months = 5;
pos_interval.days = 5;
pos_interval.ns = 42000;

const std::vector<std::optional<ArrowInterval*>> values = {
std::nullopt, &neg_interval, &zero_interval, &pos_interval};

ASSERT_THAT(MakeSchema(&schema.value, {{"col", type}}), IsOkErrno());

ASSERT_THAT(MakeBatch<ArrowInterval*>(&schema.value, &array.value, &na_error, values),
IsOkErrno());

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
"bulk_ingest", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
IsOkStatus(&error));

int64_t rows_affected = 0;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(rows_affected,
::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));

ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement,
"SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error),
IsOkStatus(&error));
{
StreamReader reader;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
IsOkStatus(&error));
ASSERT_THAT(reader.rows_affected,
::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ArrowType round_trip_type = quirks()->IngestSelectRoundTripType(type);
ASSERT_NO_FATAL_FAILURE(
CompareSchema(&reader.schema.value, {{"col", round_trip_type, NULLABLE}}));

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_NE(nullptr, reader.array->release);
ASSERT_EQ(values.size(), reader.array->length);
ASSERT_EQ(1, reader.array->n_children);

if (round_trip_type == type) {
ASSERT_NO_FATAL_FAILURE(
CompareArray<ArrowInterval*>(reader.array_view->children[0], values));
}

ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_EQ(nullptr, reader.array->release);
}
ASSERT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}

void StatementTest::TestSqlIngestTableEscaping() {
std::string name = "create_table_escaping";

Expand Down
2 changes: 2 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ class StatementTest {
// Temporal
void TestSqlIngestTimestamp();
void TestSqlIngestTimestampTz();
void TestSqlIngestInterval();

// ---- End Type-specific tests ----------------

Expand Down Expand Up @@ -302,6 +303,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestTimestampTz) { TestSqlIngestTimestampTz(); } \
TEST_F(FIXTURE, SqlIngestInterval) { TestSqlIngestInterval(); } \
TEST_F(FIXTURE, SqlIngestTableEscaping) { TestSqlIngestTableEscaping(); } \
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \
TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); } \
Expand Down
19 changes: 13 additions & 6 deletions c/validation/adbc_validation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,8 @@ struct GetObjectsReader {
}
~GetObjectsReader() { AdbcGetObjectsDataDelete(get_objects_data_); }

struct AdbcGetObjectsData* operator*() {
return get_objects_data_;
}
struct AdbcGetObjectsData* operator->() {
return get_objects_data_;
}
struct AdbcGetObjectsData* operator*() { return get_objects_data_; }
struct AdbcGetObjectsData* operator->() { return get_objects_data_; }

private:
struct AdbcGetObjectsData* get_objects_data_;
Expand Down Expand Up @@ -264,6 +260,10 @@ int MakeArray(struct ArrowArray* parent, struct ArrowArray* array,
if (int errno_res = ArrowArrayAppendBytes(array, view); errno_res != 0) {
return errno_res;
}
} else if constexpr (std::is_same<T, ArrowInterval*>::value) {
if (int errno_res = ArrowArrayAppendInterval(array, *v); errno_res != 0) {
return errno_res;
}
} else {
static_assert(!sizeof(T), "Not yet implemented");
return ENOTSUP;
Expand Down Expand Up @@ -375,6 +375,13 @@ void CompareArray(struct ArrowArrayView* array,
struct ArrowStringView view = ArrowArrayViewGetStringUnsafe(array, i);
std::string str(view.data, view.size_bytes);
ASSERT_EQ(*v, str);
} else if constexpr (std::is_same<T, ArrowInterval*>::value) {
ASSERT_NE(array->buffer_views[1].data.data, nullptr);
const auto buf = array->buffer_views[1].data.as_uint8;
const auto record = buf + i * 16;
ASSERT_EQ(memcmp(record, &(*v)->months, 4), 0);
ASSERT_EQ(memcmp(record + 4, &(*v)->days, 4), 0);
ASSERT_EQ(memcmp(record + 8, &(*v)->ns, 8), 0);
} else {
static_assert(!sizeof(T), "Not yet implemented");
}
Expand Down
73 changes: 73 additions & 0 deletions c/vendor/nanoarrow/nanoarrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,29 @@ struct ArrowArrayPrivateData {
int8_t union_type_id_is_child_index;
};

/// \brief A representation of an interval.
/// \ingroup nanoarrow-utils
struct ArrowInterval {
/// \brief The type of interval being used
enum ArrowType type;
/// \brief The number of months represented by the interval
int32_t months;
/// \brief The number of days represented by the interval
int32_t days;
/// \brief The number of ms represented by the interval
int32_t ms;
/// \brief The number of ns represented by the interval
int64_t ns;
};

/// \brief Zero initialize an Interval with a given unit
/// \ingroup nanoarrow-utils
static inline void ArrowIntervalInit(struct ArrowInterval* interval,
enum ArrowType type) {
memset(interval, 0, sizeof(struct ArrowInterval));
interval->type = type;
}

/// \brief A representation of a fixed-precision decimal number
/// \ingroup nanoarrow-utils
///
Expand Down Expand Up @@ -1649,6 +1672,13 @@ static inline ArrowErrorCode ArrowArrayAppendBytes(struct ArrowArray* array,
static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
struct ArrowStringView value);

/// \brief Append a Interval to an array
///
/// Returns NANOARROW_OK if value can be exactly represented by
/// the underlying storage type or EINVAL otherwise.
static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array,
struct ArrowInterval* value);

/// \brief Append a decimal value to an array
///
/// Returns NANOARROW_OK if array is a decimal array with the appropriate
Expand Down Expand Up @@ -2891,6 +2921,49 @@ static inline ArrowErrorCode ArrowArrayAppendString(struct ArrowArray* array,
}
}

static inline ArrowErrorCode ArrowArrayAppendInterval(struct ArrowArray* array,
struct ArrowInterval* value) {
struct ArrowArrayPrivateData* private_data =
(struct ArrowArrayPrivateData*)array->private_data;

struct ArrowBuffer* data_buffer = ArrowArrayBuffer(array, 1);

switch (private_data->storage_type) {
case NANOARROW_TYPE_INTERVAL_MONTHS: {
if (value->type != NANOARROW_TYPE_INTERVAL_MONTHS) {
return EINVAL;
}

NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months));
break;
}
case NANOARROW_TYPE_INTERVAL_DAY_TIME: {
if (value->type != NANOARROW_TYPE_INTERVAL_DAY_TIME) {
return EINVAL;
}

NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days));
NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->ms));
break;
}
case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: {
if (value->type != NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO) {
return EINVAL;
}

NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->months));
NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt32(data_buffer, value->days));
NANOARROW_RETURN_NOT_OK(ArrowBufferAppendInt64(data_buffer, value->ns));
break;
}
default:
return EINVAL;
}

array->length++;
return NANOARROW_OK;
}

static inline ArrowErrorCode ArrowArrayAppendDecimal(struct ArrowArray* array,
struct ArrowDecimal* value) {
struct ArrowArrayPrivateData* private_data =
Expand Down
Loading