From 0f06843e498d5f7afed45c5d50ab523b05933f67 Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 22 Dec 2023 08:53:27 -0500
Subject: [PATCH] fix(c/driver/postgresql): fix ingest with multiple batches
(#1393)
The COPY writer was ending the COPY command after each batch, so any
dataset with more than one batch would fail. Instead, write the header
once and don't end the command until we've written all batches.
Fixes #1310.
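In outline, the corrected sequence looks like the following sketch. It is a minimal, hedged outline based on the writer API added in this patch (Init, InitFieldWriters, WriteHeader, SetArray, WriteRecord, Rewind); the conn handle and the NextBatch() helper that pulls the next Arrow batch are placeholders, and the flushing and error handling in BindStream::ExecuteCopy are elided:

    PostgresCopyStreamWriter writer;
    writer.Init(&bind_schema.value);    // bind the schema once, up front
    writer.InitFieldWriters(nullptr);
    writer.WriteHeader(nullptr);        // the COPY binary header is written exactly once
    while (NextBatch(&array)) {         // placeholder: pull the next batch from the bind stream
      writer.SetArray(&array.value);    // retarget the writer at this batch
      while (writer.WriteRecord(nullptr) == NANOARROW_OK) {
        // encode rows for this batch into the writer's buffer
      }
      const struct ArrowBuffer& buf = writer.WriteBuffer();
      PQputCopyData(conn, reinterpret_cast<char*>(buf.data),
                    static_cast<int>(buf.size_bytes));
      writer.Rewind();                  // reset the buffer, but keep the COPY open
    }
    PQputCopyEnd(conn, NULL);           // end the command only after the last batch
    PQclear(PQgetResult(conn));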
---
c/driver/postgresql/postgres_copy_reader.h | 13 +-
.../postgresql/postgres_copy_reader_test.cc | 55 +++++-
c/driver/postgresql/postgres_util.h | 4 +-
c/driver/postgresql/statement.cc | 51 +++--
docs/source/python/recipe/postgresql.rst | 5 +
.../recipe/postgresql_create_dataset_table.py | 184 ++++++++++++++++++
.../adbc_driver_manager/dbapi.py | 16 ++
.../tests/test_dbapi.py | 56 ++++++
8 files changed, 350 insertions(+), 34 deletions(-)
create mode 100644 docs/source/python/recipe/postgresql_create_dataset_table.py
diff --git a/c/driver/postgresql/postgres_copy_reader.h b/c/driver/postgresql/postgres_copy_reader.h
index 7ba29ea3a3..686d54b81b 100644
--- a/c/driver/postgresql/postgres_copy_reader.h
+++ b/c/driver/postgresql/postgres_copy_reader.h
@@ -1460,16 +1460,20 @@ static inline ArrowErrorCode MakeCopyFieldWriter(struct ArrowSchema* schema,
class PostgresCopyStreamWriter {
public:
- ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array) {
+ ArrowErrorCode Init(struct ArrowSchema* schema) {
schema_ = schema;
NANOARROW_RETURN_NOT_OK(
ArrowArrayViewInitFromSchema(&array_view_.value, schema, nullptr));
- NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(&array_view_.value, array, nullptr));
root_writer_.Init(&array_view_.value);
ArrowBufferInit(&buffer_.value);
return NANOARROW_OK;
}
+ ArrowErrorCode SetArray(struct ArrowArray* array) {
+ NANOARROW_RETURN_NOT_OK(ArrowArrayViewSetArray(&array_view_.value, array, nullptr));
+ return NANOARROW_OK;
+ }
+
ArrowErrorCode WriteHeader(ArrowError* error) {
NANOARROW_RETURN_NOT_OK(ArrowBufferAppend(&buffer_.value, kPgCopyBinarySignature,
sizeof(kPgCopyBinarySignature)));
@@ -1508,6 +1512,11 @@ class PostgresCopyStreamWriter {
const struct ArrowBuffer& WriteBuffer() const { return buffer_.value; }
+ void Rewind() {
+ records_written_ = 0;
+ buffer_->size_bytes = 0;
+ }
+
private:
PostgresCopyFieldTupleWriter root_writer_;
struct ArrowSchema* schema_;
diff --git a/c/driver/postgresql/postgres_copy_reader_test.cc b/c/driver/postgresql/postgres_copy_reader_test.cc
index 00ba120480..7882d602af 100644
--- a/c/driver/postgresql/postgres_copy_reader_test.cc
+++ b/c/driver/postgresql/postgres_copy_reader_test.cc
@@ -60,13 +60,14 @@ class PostgresCopyStreamTester {
class PostgresCopyStreamWriteTester {
public:
ArrowErrorCode Init(struct ArrowSchema* schema, struct ArrowArray* array,
- ArrowError* error = nullptr) {
- NANOARROW_RETURN_NOT_OK(writer_.Init(schema, array));
+ struct ArrowError* error = nullptr) {
+ NANOARROW_RETURN_NOT_OK(writer_.Init(schema));
NANOARROW_RETURN_NOT_OK(writer_.InitFieldWriters(error));
+ NANOARROW_RETURN_NOT_OK(writer_.SetArray(array));
return NANOARROW_OK;
}
- ArrowErrorCode WriteAll(ArrowError* error = nullptr) {
+ ArrowErrorCode WriteAll(struct ArrowError* error) {
NANOARROW_RETURN_NOT_OK(writer_.WriteHeader(error));
int result;
@@ -77,8 +78,20 @@ class PostgresCopyStreamWriteTester {
return result;
}
+ ArrowErrorCode WriteArray(struct ArrowArray* array, struct ArrowError* error) {
+ NANOARROW_RETURN_NOT_OK(writer_.SetArray(array));
+ int result;
+ do {
+ result = writer_.WriteRecord(error);
+ } while (result == NANOARROW_OK);
+
+ return result;
+ }
+
const struct ArrowBuffer& WriteBuffer() const { return writer_.WriteBuffer(); }
+ void Rewind() { writer_.Rewind(); }
+
private:
PostgresCopyStreamWriter writer_;
};
@@ -1261,4 +1274,40 @@ TEST(PostgresCopyUtilsTest, PostgresCopyReadCustomRecord) {
ASSERT_DOUBLE_EQ(data_buffer2[2], 0);
}
+TEST(PostgresCopyUtilsTest, PostgresCopyWriteMultiBatch) {
+ // Regression test for https://github.com/apache/arrow-adbc/issues/1310
+ adbc_validation::Handle schema;
+ adbc_validation::Handle array;
+ struct ArrowError na_error;
+ ASSERT_EQ(adbc_validation::MakeSchema(&schema.value, {{"col", NANOARROW_TYPE_INT32}}),
+ NANOARROW_OK);
+ ASSERT_EQ(adbc_validation::MakeBatch(&schema.value, &array.value, &na_error,
+ {-123, -1, 1, 123, std::nullopt}),
+ NANOARROW_OK);
+
+ PostgresCopyStreamWriteTester tester;
+ ASSERT_EQ(tester.Init(&schema.value, &array.value), NANOARROW_OK);
+ ASSERT_EQ(tester.WriteAll(nullptr), ENODATA);
+
+ struct ArrowBuffer buf = tester.WriteBuffer();
+ // The last 2 bytes of a message can be transmitted via PQputCopyData
+ // so no need to test those bytes from the Writer
+ size_t buf_size = sizeof(kTestPgCopyInteger) - 2;
+ ASSERT_EQ(buf.size_bytes, buf_size);
+ for (size_t i = 0; i < buf_size; i++) {
+ ASSERT_EQ(buf.data[i], kTestPgCopyInteger[i]);
+ }
+
+ tester.Rewind();
+ ASSERT_EQ(tester.WriteArray(&array.value, nullptr), ENODATA);
+
+ buf = tester.WriteBuffer();
+ // Ignore the header and footer
+ buf_size = sizeof(kTestPgCopyInteger) - 21;
+ ASSERT_EQ(buf.size_bytes, buf_size);
+ for (size_t i = 0; i < buf_size; i++) {
+ ASSERT_EQ(buf.data[i], kTestPgCopyInteger[i + 19]);
+ }
+}
+
} // namespace adbcpq
diff --git a/c/driver/postgresql/postgres_util.h b/c/driver/postgresql/postgres_util.h
index 1009d70b55..95e2619f10 100644
--- a/c/driver/postgresql/postgres_util.h
+++ b/c/driver/postgresql/postgres_util.h
@@ -166,9 +166,11 @@ struct Handle {
Handle() { std::memset(&value, 0, sizeof(value)); }
- ~Handle() { Releaser::Release(&value); }
+ ~Handle() { reset(); }
Resource* operator->() { return &value; }
+
+ void reset() { Releaser::Release(&value); }
};
} // namespace adbcpq
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 6c0541a6c5..811a086a4f 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -565,7 +565,12 @@ struct BindStream {
AdbcStatusCode ExecuteCopy(PGconn* conn, int64_t* rows_affected,
struct AdbcError* error) {
if (rows_affected) *rows_affected = 0;
- PGresult* result = nullptr;
+
+ PostgresCopyStreamWriter writer;
+ CHECK_NA(INTERNAL, writer.Init(&bind_schema.value), error);
+ CHECK_NA(INTERNAL, writer.InitFieldWriters(nullptr), error);
+
+ CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error);
while (true) {
Handle array;
@@ -579,20 +584,9 @@ struct BindStream {
}
if (!array->release) break;
- Handle array_view;
- CHECK_NA(
- INTERNAL,
- ArrowArrayViewInitFromSchema(&array_view.value, &bind_schema.value, nullptr),
- error);
- CHECK_NA(INTERNAL, ArrowArrayViewSetArray(&array_view.value, &array.value, nullptr),
- error);
-
- PostgresCopyStreamWriter writer;
- CHECK_NA(INTERNAL, writer.Init(&bind_schema.value, &array.value), error);
- CHECK_NA(INTERNAL, writer.InitFieldWriters(nullptr), error);
+ CHECK_NA(INTERNAL, writer.SetArray(&array.value), error);
// build writer buffer
- CHECK_NA(INTERNAL, writer.WriteHeader(nullptr), error);
int write_result;
do {
write_result = writer.WriteRecord(nullptr);
@@ -611,25 +605,26 @@ struct BindStream {
return ADBC_STATUS_IO;
}
- if (PQputCopyEnd(conn, NULL) <= 0) {
- SetError(error, "Error message returned by PQputCopyEnd: %s",
- PQerrorMessage(conn));
- return ADBC_STATUS_IO;
- }
+ if (rows_affected) *rows_affected += array->length;
+ writer.Rewind();
+ }
- result = PQgetResult(conn);
- ExecStatusType pg_status = PQresultStatus(result);
- if (pg_status != PGRES_COMMAND_OK) {
- AdbcStatusCode code =
- SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s",
- PQresStatus(pg_status), PQerrorMessage(conn));
- PQclear(result);
- return code;
- }
+ if (PQputCopyEnd(conn, NULL) <= 0) {
+ SetError(error, "Error message returned by PQputCopyEnd: %s", PQerrorMessage(conn));
+ return ADBC_STATUS_IO;
+ }
+ PGresult* result = PQgetResult(conn);
+ ExecStatusType pg_status = PQresultStatus(result);
+ if (pg_status != PGRES_COMMAND_OK) {
+ AdbcStatusCode code =
+ SetError(error, result, "[libpq] Failed to execute COPY statement: %s %s",
+ PQresStatus(pg_status), PQerrorMessage(conn));
PQclear(result);
- if (rows_affected) *rows_affected += array->length;
+ return code;
}
+
+ PQclear(result);
return ADBC_STATUS_OK;
}
};
diff --git a/docs/source/python/recipe/postgresql.rst b/docs/source/python/recipe/postgresql.rst
index 7e93a47912..7d578b3633 100644
--- a/docs/source/python/recipe/postgresql.rst
+++ b/docs/source/python/recipe/postgresql.rst
@@ -26,6 +26,11 @@ Authenticate with a username and password
.. _recipe-postgresql-create-append:
+Create/append to a table from an Arrow dataset
+==============================================
+
+.. recipe:: postgresql_create_dataset_table.py
+
Create/append to a table from an Arrow table
============================================
diff --git a/docs/source/python/recipe/postgresql_create_dataset_table.py b/docs/source/python/recipe/postgresql_create_dataset_table.py
new file mode 100644
index 0000000000..e26093a308
--- /dev/null
+++ b/docs/source/python/recipe/postgresql_create_dataset_table.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# RECIPE STARTS HERE
+
+#: ADBC makes it easy to load PyArrow datasets into your datastore.
+
+import os
+import tempfile
+from pathlib import Path
+
+import pyarrow
+import pyarrow.csv
+import pyarrow.dataset
+import pyarrow.feather
+import pyarrow.parquet
+
+import adbc_driver_postgresql.dbapi
+
+uri = os.environ["ADBC_POSTGRESQL_TEST_URI"]
+conn = adbc_driver_postgresql.dbapi.connect(uri)
+
+#: For the purposes of testing, we'll first make sure the tables we're about
+#: to use don't exist.
+with conn.cursor() as cur:
+ cur.execute("DROP TABLE IF EXISTS csvtable")
+ cur.execute("DROP TABLE IF EXISTS ipctable")
+ cur.execute("DROP TABLE IF EXISTS pqtable")
+ cur.execute("DROP TABLE IF EXISTS csvdataset")
+ cur.execute("DROP TABLE IF EXISTS ipcdataset")
+ cur.execute("DROP TABLE IF EXISTS pqdataset")
+
+conn.commit()
+
+#: Generating sample data
+#: ~~~~~~~~~~~~~~~~~~~~~~
+
+tempdir = tempfile.TemporaryDirectory(
+ prefix="adbc-docs-",
+ ignore_cleanup_errors=True,
+)
+root = Path(tempdir.name)
+table = pyarrow.table(
+ [
+ [1, 1, 2],
+ ["foo", "bar", "baz"],
+ ],
+ names=["ints", "strs"],
+)
+
+#: First we'll write single files.
+
+csv_file = root / "example.csv"
+pyarrow.csv.write_csv(table, csv_file)
+
+ipc_file = root / "example.arrow"
+pyarrow.feather.write_feather(table, ipc_file)
+
+parquet_file = root / "example.parquet"
+pyarrow.parquet.write_table(table, parquet_file)
+
+#: We'll also generate some partitioned datasets.
+
+csv_dataset = root / "csv_dataset"
+pyarrow.dataset.write_dataset(
+ table,
+ csv_dataset,
+ format="csv",
+ partitioning=["ints"],
+)
+
+ipc_dataset = root / "ipc_dataset"
+pyarrow.dataset.write_dataset(
+ table,
+ ipc_dataset,
+ format="feather",
+ partitioning=["ints"],
+)
+
+parquet_dataset = root / "parquet_dataset"
+pyarrow.dataset.write_dataset(
+ table,
+ parquet_dataset,
+ format="parquet",
+ partitioning=["ints"],
+)
+
+#: Loading CSV Files into PostgreSQL
+#: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+#: We can directly pass a :py:class:`pyarrow.RecordBatchReader` (from
+#: ``open_csv``) to ``adbc_ingest``. We can also pass a
+#: :py:class:`pyarrow.dataset.Dataset`, or a
+#: :py:class:`pyarrow.dataset.Scanner`.
+
+with conn.cursor() as cur:
+ reader = pyarrow.csv.open_csv(csv_file)
+ cur.adbc_ingest("csvtable", reader, mode="create")
+
+ reader = pyarrow.dataset.dataset(
+ csv_dataset,
+ format="csv",
+ partitioning=["ints"],
+ )
+ cur.adbc_ingest("csvdataset", reader, mode="create")
+
+conn.commit()
+
+with conn.cursor() as cur:
+ cur.execute("SELECT ints, strs FROM csvtable ORDER BY ints, strs ASC")
+ assert cur.fetchall() == [(1, "bar"), (1, "foo"), (2, "baz")]
+
+ cur.execute("SELECT ints, strs FROM csvdataset ORDER BY ints, strs ASC")
+ assert cur.fetchall() == [(1, "bar"), (1, "foo"), (2, "baz")]
+
+#: Loading Arrow IPC (Feather) Files into PostgreSQL
+#: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+with conn.cursor() as cur:
+ reader = pyarrow.ipc.RecordBatchFileReader(ipc_file)
+ #: Because of quirks in the PyArrow API, we have to read the file into
+ #: memory.
+ cur.adbc_ingest("ipctable", reader.read_all(), mode="create")
+
+ #: The Dataset API will stream the data into memory and then into
+ #: PostgreSQL, though.
+ reader = pyarrow.dataset.dataset(
+ ipc_dataset,
+ format="feather",
+ partitioning=["ints"],
+ )
+ cur.adbc_ingest("ipcdataset", reader, mode="create")
+
+conn.commit()
+
+with conn.cursor() as cur:
+ cur.execute("SELECT ints, strs FROM ipctable ORDER BY ints, strs ASC")
+ assert cur.fetchall() == [(1, "bar"), (1, "foo"), (2, "baz")]
+
+ cur.execute("SELECT ints, strs FROM ipcdataset ORDER BY ints, strs ASC")
+ assert cur.fetchall() == [(1, "bar"), (1, "foo"), (2, "baz")]
+
+#: Loading Parquet Files into PostgreSQL
+#: ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+with conn.cursor() as cur:
+ reader = pyarrow.parquet.ParquetFile(parquet_file)
+ cur.adbc_ingest("pqtable", reader.iter_batches(), mode="create")
+
+ reader = pyarrow.dataset.dataset(
+ parquet_dataset,
+ format="parquet",
+ partitioning=["ints"],
+ )
+ cur.adbc_ingest("pqdataset", reader, mode="create")
+
+conn.commit()
+
+with conn.cursor() as cur:
+ cur.execute("SELECT ints, strs FROM pqtable ORDER BY ints, strs ASC")
+ assert cur.fetchall() == [(1, "bar"), (1, "foo"), (2, "baz")]
+
+ cur.execute("SELECT ints, strs FROM pqdataset ORDER BY ints, strs ASC")
+ assert cur.fetchall() == [(1, "bar"), (1, "foo"), (2, "baz")]
+
+#: Cleanup
+#: ~~~~~~~
+
+conn.close()
+tempdir.cleanup()
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 8f06ed0396..1e86144c12 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -43,6 +43,15 @@
except ImportError as e:
raise ImportError("PyArrow is required for the DBAPI-compatible interface") from e
+try:
+ import pyarrow.dataset
+except ImportError:
+ _pya_dataset = ()
+ _pya_scanner = ()
+else:
+ _pya_dataset = (pyarrow.dataset.Dataset,)
+ _pya_scanner = (pyarrow.dataset.Scanner,)
+
import adbc_driver_manager
from . import _lib, _reader
@@ -891,6 +900,13 @@ def adbc_ingest(
else:
if isinstance(data, pyarrow.Table):
data = data.to_reader()
+ elif isinstance(data, _pya_dataset):
+ data = data.scanner().to_reader()
+ elif isinstance(data, _pya_scanner):
+ data = data.to_reader()
+ elif not hasattr(data, "_export_to_c"):
+ data = pyarrow.Table.from_batches(data)
+ data = data.to_reader()
handle = _lib.ArrowArrayStreamHandle()
data._export_to_c(handle.address)
self._stmt.bind_stream(handle)
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index 2a132bd4a7..283e3fe687 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
+from pathlib import Path
from typing import Generator
import pyarrow
+import pyarrow.dataset
import pytest
from adbc_driver_postgresql import StatementOptions, dbapi
@@ -213,6 +215,60 @@ def test_stmt_ingest(postgres: dbapi.Connection) -> None:
assert cur.fetch_arrow_table() == table
+def test_stmt_ingest_dataset(postgres: dbapi.Connection, tmp_path: Path) -> None:
+ # Regression test for https://github.com/apache/arrow-adbc/issues/1310
+ table = pyarrow.table(
+ [
+ [1, 1, 2, 2, 3, 3],
+ ["a", "a", None, None, "b", "b"],
+ ],
+ schema=pyarrow.schema([("ints", "int32"), ("strs", "string")]),
+ )
+ pyarrow.dataset.write_dataset(
+ table, tmp_path, format="parquet", partitioning=["ints"]
+ )
+ ds = pyarrow.dataset.dataset(tmp_path, format="parquet", partitioning=["ints"])
+
+ with postgres.cursor() as cur:
+ for item in (
+ lambda: ds,
+ lambda: ds.scanner(),
+ lambda: ds.scanner().to_reader(),
+ lambda: ds.scanner().to_table(),
+ ):
+ cur.execute("DROP TABLE IF EXISTS test_ingest")
+
+ cur.adbc_ingest(
+ "test_ingest",
+ item(),
+ mode="create_append",
+ )
+ cur.execute("SELECT ints, strs FROM test_ingest ORDER BY ints")
+ assert cur.fetch_arrow_table() == table
+
+
+def test_stmt_ingest_multi(postgres: dbapi.Connection) -> None:
+ # Regression test for https://github.com/apache/arrow-adbc/issues/1310
+ table = pyarrow.table(
+ [
+ [1, 1, 2, 2, 3, 3],
+ ["a", "a", None, None, "b", "b"],
+ ],
+ names=["ints", "strs"],
+ )
+
+ with postgres.cursor() as cur:
+ cur.execute("DROP TABLE IF EXISTS test_ingest")
+
+ cur.adbc_ingest(
+ "test_ingest",
+ table.to_batches(max_chunksize=2),
+ mode="create_append",
+ )
+ cur.execute("SELECT * FROM test_ingest ORDER BY ints")
+ assert cur.fetch_arrow_table() == table
+
+
def test_ddl(postgres: dbapi.Connection):
with postgres.cursor() as cur:
cur.execute("DROP TABLE IF EXISTS test_ddl")