Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(c/driver): be explicit about columns in ingestion #1238

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions c/driver/postgresql/postgresql_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@ class PostgresQuirks : public adbc_validation::DriverQuirks {
return ddl;
}

// Builds the PostgreSQL DDL for the primary-key ingestion test table:
// "id" is a BIGSERIAL primary key and "value" is a BIGINT column.
// `name` is spliced into the statement verbatim (no quoting is applied here).
std::optional<std::string> PrimaryKeyIngestTableDdl(
    std::string_view name) const override {
  std::string statement("CREATE TABLE ");
  statement.append(name);
  statement.append(" (id BIGSERIAL PRIMARY KEY, value BIGINT)");
  return statement;
}

std::optional<std::string> CompositePrimaryKeyTableDdl(
std::string_view name) const override {
std::string ddl = "CREATE TABLE ";
Expand Down
26 changes: 20 additions & 6 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,8 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) {
AdbcStatusCode PostgresStatement::CreateBulkTable(
const std::string& current_schema, const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
std::string* escaped_table, struct AdbcError* error) {
std::string* escaped_table, std::string* escaped_field_list,
struct AdbcError* error) {
PGconn* conn = connection_->conn();

if (!ingest_.db_schema.empty() && ingest_.temporary) {
Expand Down Expand Up @@ -944,10 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(

switch (ingest_.mode) {
case IngestMode::kCreate:
case IngestMode::kAppend:
// Nothing to do
break;
case IngestMode::kAppend:
return ADBC_STATUS_OK;
case IngestMode::kReplace: {
std::string drop = "DROP TABLE IF EXISTS " + *escaped_table;
PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0,
Expand All @@ -972,7 +972,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
create += " (";

for (size_t i = 0; i < source_schema_fields.size(); i++) {
if (i > 0) create += ", ";
if (i > 0) {
create += ", ";
*escaped_field_list += ", ";
}

const char* unescaped = source_schema.children[i]->name;
char* escaped = PQescapeIdentifier(conn, unescaped, std::strlen(unescaped));
Expand All @@ -982,6 +985,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
return ADBC_STATUS_INTERNAL;
}
create += escaped;
*escaped_field_list += escaped;
PQfreemem(escaped);

switch (source_schema_fields[i].type) {
Expand Down Expand Up @@ -1034,6 +1038,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
}
}

if (ingest_.mode == IngestMode::kAppend) {
return ADBC_STATUS_OK;
}

create += ")";
SetError(error, "%s%s", "[libpq] ", create.c_str());
PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0,
Expand Down Expand Up @@ -1203,15 +1211,21 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected,
BindStream bind_stream(std::move(bind_));
std::memset(&bind_, 0, sizeof(bind_));
std::string escaped_table;
std::string escaped_field_list;
RAISE_ADBC(bind_stream.Begin(
[&]() -> AdbcStatusCode {
return CreateBulkTable(current_schema, bind_stream.bind_schema.value,
bind_stream.bind_schema_fields, &escaped_table, error);
bind_stream.bind_schema_fields, &escaped_table,
&escaped_field_list, error);
},
error));
RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error));

std::string query = "COPY " + escaped_table + " FROM STDIN WITH (FORMAT binary)";
std::string query = "COPY ";
query += escaped_table;
query += " (";
query += escaped_field_list;
query += ") FROM STDIN WITH (FORMAT binary)";
PGresult* result = PQexec(connection_->conn(), query.c_str());
if (PQresultStatus(result) != PGRES_COPY_IN) {
AdbcStatusCode code =
Expand Down
3 changes: 2 additions & 1 deletion c/driver/postgresql/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ class PostgresStatement {
AdbcStatusCode CreateBulkTable(
const std::string& current_schema, const struct ArrowSchema& source_schema,
const std::vector<struct ArrowSchemaView>& source_schema_fields,
std::string* escaped_table, struct AdbcError* error);
std::string* escaped_table, std::string* escaped_field_list,
struct AdbcError* error);
AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error);
AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error);
AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream,
Expand Down
40 changes: 32 additions & 8 deletions c/driver/sqlite/sqlite.c
Original file line number Diff line number Diff line change
Expand Up @@ -1136,7 +1136,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
goto cleanup;
}

sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table);
sqlite3_str_appendf(insert_query, "INSERT INTO %s (", table);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
Expand All @@ -1154,6 +1154,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

sqlite3_str_appendf(insert_query, "%s", ", ");
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s",
sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendf(create_query, "\"%w\"", stmt->binder.schema.children[i]->name);
Expand All @@ -1163,6 +1171,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
goto cleanup;
}

sqlite3_str_appendf(insert_query, "\"%w\"", stmt->binder.schema.children[i]->name);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

int status =
ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error);
if (status != 0) {
Expand Down Expand Up @@ -1199,13 +1214,6 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
default:
break;
}

sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendchar(create_query, 1, ')');
Expand All @@ -1215,6 +1223,22 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
goto cleanup;
}

sqlite3_str_appendall(insert_query, ") VALUES (");
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

for (int i = 0; i < stmt->binder.schema.n_children; i++) {
sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendchar(insert_query, 1, ')');
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
Expand Down
8 changes: 8 additions & 0 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ class SqliteQuirks : public adbc_validation::DriverQuirks {
return ddl;
}

// Builds the SQLite DDL for the primary-key ingestion test table:
// "id" is declared INTEGER PRIMARY KEY and "value" is a BIGINT column.
// `name` is spliced into the statement verbatim (no quoting is applied here).
std::optional<std::string> PrimaryKeyIngestTableDdl(
    std::string_view name) const override {
  std::string statement("CREATE TABLE ");
  statement.append(name);
  statement.append(" (id INTEGER PRIMARY KEY, value BIGINT)");
  return statement;
}

std::optional<std::string> CompositePrimaryKeyTableDdl(
std::string_view name) const override {
std::string ddl = "CREATE TABLE ";
Expand Down
109 changes: 109 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2803,6 +2803,115 @@ void StatementTest::TestSqlIngestTemporaryExclusive() {
}
}

void StatementTest::TestSqlIngestPrimaryKey() {
std::string name = "pkeytest";
auto ddl = quirks()->PrimaryKeyIngestTableDdl(name);
if (!ddl) {
GTEST_SKIP();
}
ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error));

// Create table
{
Handle<struct AdbcStatement> statement;
StreamReader reader;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, ddl->c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
}

// Ingest without the primary key
{
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_THAT(MakeSchema(&schema.value, {{"value", NANOARROW_TYPE_INT64}}),
IsOkErrno());
ASSERT_THAT((MakeBatch<int64_t>(&schema.value, &array.value, &na_error,
{42, -42, std::nullopt})),
IsOkErrno());

Handle<struct AdbcStatement> statement;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
name.c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE,
ADBC_INGEST_OPTION_MODE_APPEND, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
}

// Ingest with the primary key
{
Handle<struct ArrowSchema> schema;
Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_THAT(MakeSchema(&schema.value,
{
{"id", NANOARROW_TYPE_INT64},
{"value", NANOARROW_TYPE_INT64},
}),
IsOkErrno());
ASSERT_THAT((MakeBatch<int64_t, int64_t>(&schema.value, &array.value, &na_error,
{4, 5, 6}, {1, 0, -1})),
IsOkErrno());

Handle<struct AdbcStatement> statement;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE,
name.c_str(), &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE,
ADBC_INGEST_OPTION_MODE_APPEND, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error));
}

// Get the data
{
Handle<struct AdbcStatement> statement;
StreamReader reader;
ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement.value, "SELECT * FROM pkeytest ORDER BY id ASC", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, nullptr,
&error),
IsOkStatus(&error));

ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
ASSERT_EQ(2, reader.schema->n_children);
ASSERT_NO_FATAL_FAILURE(reader.Next());
ASSERT_NE(nullptr, reader.array->release);
ASSERT_EQ(6, reader.array->length);
ASSERT_EQ(2, reader.array->n_children);

// Different databases start numbering at 0 or 1 for the primary key
// column, so can't compare it
// TODO(https://github.com/apache/arrow-adbc/issues/938): if the test
// helpers converted data to plain C++ values we could do a more
// sophisticated assertion
ASSERT_NO_FATAL_FAILURE(CompareArray<int64_t>(reader.array_view->children[1],
{42, -42, std::nullopt, 1, 0, -1}));
}
}

void StatementTest::TestSqlPartitionedInts() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
Expand Down
15 changes: 15 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ class DriverQuirks {
return std::nullopt;
}

/// \brief DDL for the table used by the primary-key ingestion test.
///
/// Drivers override this to return a CREATE TABLE statement for a table
/// called \p name; the default returns nullopt, meaning "not supported".
/// The table exists to exercise ingestion into a table with an
/// auto-incrementing primary key, where the ingested data should not need
/// to carry the key column itself.
///
/// Required columns:
/// - "id": auto-incrementing primary key, compatible with int64
/// - "value": Arrow type int64
virtual std::optional<std::string> PrimaryKeyIngestTableDdl(
    std::string_view name) const {
  return std::optional<std::string>{};
}

/// \brief Get the statement to create a table with a composite primary key,
/// or nullopt if not supported.
///
Expand Down Expand Up @@ -347,6 +360,7 @@ class StatementTest {
void TestSqlIngestTemporaryAppend();
void TestSqlIngestTemporaryReplace();
void TestSqlIngestTemporaryExclusive();
void TestSqlIngestPrimaryKey();

void TestSqlPartitionedInts();

Expand Down Expand Up @@ -444,6 +458,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestTemporaryAppend) { TestSqlIngestTemporaryAppend(); } \
TEST_F(FIXTURE, SqlIngestTemporaryReplace) { TestSqlIngestTemporaryReplace(); } \
TEST_F(FIXTURE, SqlIngestTemporaryExclusive) { TestSqlIngestTemporaryExclusive(); } \
TEST_F(FIXTURE, SqlIngestPrimaryKey) { TestSqlIngestPrimaryKey(); } \
TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); } \
TEST_F(FIXTURE, SqlPrepareGetParameterSchema) { TestSqlPrepareGetParameterSchema(); } \
TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams(); } \
Expand Down
Loading