Skip to content

Commit

Permalink
fix(c/driver/postgresql): fix ingest with multiple batches (#1393)
Browse files Browse the repository at this point in the history
The COPY writer was ending the COPY command after each batch, so any
dataset with more than one batch would fail. Instead, write the header
once and don't end the command until we've written all batches.

Fixes #1310.
  • Loading branch information
lidavidm authored Dec 22, 2023
1 parent 7531ac1 commit 0f06843
Show file tree
Hide file tree
Showing 8 changed files with 350 additions and 34 deletions.
13 changes: 11 additions & 2 deletions c/driver/postgresql/postgres_copy_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -1460,16 +1460,20 @@ static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,

class PostgresCopyStreamWriter {
public:
ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array) {
ArrowErrorCode Init(struct ArrowSchema* schema) {
schema_ = schema;
NANOARROW_RETURN_NOT_OK(
ArrowArrayViewInitFromSchema(&array_view_.value, schema, nullptr));
NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(&array_view_.value, array, nullptr));
root_writer_.Init(&array_view_.value);
ArrowBufferInit(&buffer_.value);
return NANOARROW_OK;
}

ArrowErrorCode SetArray(struct ArrowArray* array) {
NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(&array_view_.value, array, nullptr));
return NANOARROW_OK;
}

ArrowErrorCode WriteHeader(ArrowError* error) {
NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(&buffer_.value, kPgCopyBinarySignature,
sizeof(kPgCopyBinarySignature)));
Expand Down Expand Up @@ -1508,6 +1512,11 @@ class PostgresCopyStreamWriter {

const struct ArrowBuffer& WriteBuffer() const { return buffer_.value; }

void Rewind() {
records_written_ = 0;
buffer_->size_bytes = 0;
}

private:
PostgresCopyFieldTupleWriter root_writer_;
struct ArrowSchema* schema_;
Expand Down
55 changes: 52 additions & 3 deletions c/driver/postgresql/postgres_copy_reader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ class PostgresCopyStreamTester {
class PostgresCopyStreamWriteTester {
public:
ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array,
ArrowError* error = nullptr) {
NANOARROW_RETURN_NOT_OK(writer_.Init(schema, array));
struct ArrowError* error = nullptr) {
NANOARROW_RETURN_NOT_OK(writer_.Init(schema));
NANOARROW_RETURN_NOT_OK(writer_.InitFieldWriters(error));
NANOARROW_RETURN_NOT_OK(writer_.SetArray(array));
return NANOARROW_OK;
}

ArrowErrorCode WriteAll(ArrowError* error = nullptr) {
ArrowErrorCode WriteAll(struct ArrowError* error) {
NANOARROW_RETURN_NOT_OK(writer_.WriteHeader(error));

int result;
Expand All @@ -77,8 +78,20 @@ class PostgresCopyStreamWriteTester {
return result;
}

ArrowErrorCode WriteArray(struct ArrowArray* array, struct ArrowError* error) {
writer_.SetArray(array);
int result;
do {
result = writer_.WriteRecord(error);
} while (result == NANOARROW_OK);

return result;
}

const struct ArrowBuffer& WriteBuffer() const { return writer_.WriteBuffer(); }

void Rewind() { writer_.Rewind(); }

private:
PostgresCopyStreamWriter writer_;
};
Expand Down Expand Up @@ -1261,4 +1274,40 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadCustomRecord) {
ASSERT_DOUBLE_EQ(data_buffer2[2], 0);
}

TEST(PostgresCopyUtilsTest, PostgresCopyWriteMultiBatch) {
// Regression test for https://github.com/apache/arrow-adbc/issues/1310
adbc_validation::Handle<struct ArrowSchema> schema;
adbc_validation::Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_INT32}}),
NANOARROW_OK);
ASSERT_EQ(adbc_validation::MakeBatch<int32_t>(&schema.value, &array.value, &na_error,
{-123, -1, 1, 123, std::nullopt}),
NANOARROW_OK);

PostgresCopyStreamWriteTester tester;
ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);

struct ArrowBuffer buf = tester.WriteBuffer();
// The last 2 bytes of a message can be transmitted via PQputCopyData
// so no need to test those bytes from the Writer
size_t buf_size = sizeof(kTestPgCopyInteger) - 2;
ASSERT_EQ(buf.size_bytes, buf_size);
for (size_t i = 0; i < buf_size; i++) {
ASSERT_EQ(buf.data[i], kTestPgCopyInteger[i]);
}

tester.Rewind();
ASSERT_EQ(tester.WriteArray(&array.value, nullptr), ENODATA);

buf = tester.WriteBuffer();
// Ignore the header and footer
buf_size = sizeof(kTestPgCopyInteger) - 21;
ASSERT_EQ(buf.size_bytes, buf_size);
for (size_t i = 0; i < buf_size; i++) {
ASSERT_EQ(buf.data[i], kTestPgCopyInteger[i + 19]);
}
}

} // namespace adbcpq
4 changes: 3 additions & 1 deletion c/driver/postgresql/postgres_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,11 @@ struct Handle {

Handle() { std::memset(&value, 0, sizeof(value)); }

~Handle() { Releaser<Resource>::Release(&value); }
~Handle() { reset(); }

Resource* operator->() { return &value; }

void reset() { Releaser<Resource>::Release(&value); }
};

} // namespace adbcpq
51 changes: 23 additions & 28 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,12 @@ struct BindStream {
AdbcStatusCode ExecuteCopy(PGconn* conn, int64_t* rows_affected,
struct AdbcError* error) {
if (rows_affected) *rows_affected = 0;
PGresult* result = nullptr;

PostgresCopyStreamWriter writer;
CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error);
CHECK_NA(INTERNAL, writer.InitFieldWriters(nullptr), error);

CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error);

while (true) {
Handle<struct ArrowArray> array;
Expand All @@ -579,20 +584,9 @@ struct BindStream {
}
if (!array->release) break;

Handle<struct ArrowArrayView> array_view;
CHECK_NA(
INTERNAL,
ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr),
error);
CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr),
error);

PostgresCopyStreamWriter writer;
CHECK_NA(INTERNAL, writer.Init(&bind_schema.value, &array.value), error);
CHECK_NA(INTERNAL, writer.InitFieldWriters(nullptr), error);
CHECK_NA(INTERNAL, writer.SetArray(&array.value), error);

// build writer buffer
CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error);
int write_result;
do {
write_result = writer.WriteRecord(nullptr);
Expand All @@ -611,25 +605,26 @@ struct BindStream {
return ADBC_STATUS_IO;
}

if (PQputCopyEnd(conn, NULL) <= 0) {
SetError(error, "Error message returned by PQputCopyEnd: %s",
PQerrorMessage(conn));
return ADBC_STATUS_IO;
}
if (rows_affected) *rows_affected += array->length;
writer.Rewind();
}

result = PQgetResult(conn);
ExecStatusType pg_status = PQresultStatus(result);
if (pg_status != PGRES_COMMAND_OK) {
AdbcStatusCode code =
SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s",
PQresStatus(pg_status), PQerrorMessage(conn));
PQclear(result);
return code;
}
if (PQputCopyEnd(conn, NULL) <= 0) {
SetError(error, "Error message returned by PQputCopyEnd: %s", PQerrorMessage(conn));
return ADBC_STATUS_IO;
}

PGresult* result = PQgetResult(conn);
ExecStatusType pg_status = PQresultStatus(result);
if (pg_status != PGRES_COMMAND_OK) {
AdbcStatusCode code =
SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s",
PQresStatus(pg_status), PQerrorMessage(conn));
PQclear(result);
if (rows_affected) *rows_affected += array->length;
return code;
}

PQclear(result);
return ADBC_STATUS_OK;
}
};
Expand Down
5 changes: 5 additions & 0 deletions docs/source/python/recipe/postgresql.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ Authenticate with a username and password

.. _recipe-postgresql-create-append:

Create/append to a table from an Arrow dataset
==============================================

.. recipe:: postgresql_create_dataset_table.py

Create/append to a table from an Arrow table
============================================

Expand Down
Loading

0 comments on commit 0f06843

Please sign in to comment.