diff --git a/c/driver/postgresql/postgresql_test.cc b/c/driver/postgresql/postgresql_test.cc index f6df18093f..343c729ddd 100644 --- a/c/driver/postgresql/postgresql_test.cc +++ b/c/driver/postgresql/postgresql_test.cc @@ -107,6 +107,14 @@ class PostgresQuirks : public adbc_validation::DriverQuirks { return ddl; } + std::optional PrimaryKeyIngestTableDdl( + std::string_view name) const override { + std::string ddl = "CREATE TABLE "; + ddl += name; + ddl += " (id BIGSERIAL PRIMARY KEY, value BIGINT)"; + return ddl; + } + std::optional CompositePrimaryKeyTableDdl( std::string_view name) const override { std::string ddl = "CREATE TABLE "; diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc index e5691b529c..eac7ededfe 100644 --- a/c/driver/postgresql/statement.cc +++ b/c/driver/postgresql/statement.cc @@ -887,7 +887,8 @@ AdbcStatusCode PostgresStatement::Cancel(struct AdbcError* error) { AdbcStatusCode PostgresStatement::CreateBulkTable( const std::string& current_schema, const struct ArrowSchema& source_schema, const std::vector& source_schema_fields, - std::string* escaped_table, struct AdbcError* error) { + std::string* escaped_table, std::string* escaped_field_list, + struct AdbcError* error) { PGconn* conn = connection_->conn(); if (!ingest_.db_schema.empty() && ingest_.temporary) { @@ -944,10 +945,9 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( switch (ingest_.mode) { case IngestMode::kCreate: + case IngestMode::kAppend: // Nothing to do break; - case IngestMode::kAppend: - return ADBC_STATUS_OK; case IngestMode::kReplace: { std::string drop = "DROP TABLE IF EXISTS " + *escaped_table; PGresult* result = PQexecParams(conn, drop.c_str(), /*nParams=*/0, @@ -972,7 +972,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( create += " ("; for (size_t i = 0; i < source_schema_fields.size(); i++) { - if (i > 0) create += ", "; + if (i > 0) { + create += ", "; + *escaped_field_list += ", "; + } const char* unescaped = source_schema.children[i]->name; char* escaped = PQescapeIdentifier(conn, unescaped, std::strlen(unescaped)); @@ -982,6 +985,7 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( return ADBC_STATUS_INTERNAL; } create += escaped; + *escaped_field_list += escaped; PQfreemem(escaped); switch (source_schema_fields[i].type) { @@ -1034,6 +1038,10 @@ AdbcStatusCode PostgresStatement::CreateBulkTable( } } + if (ingest_.mode == IngestMode::kAppend) { + return ADBC_STATUS_OK; + } + create += ")"; SetError(error, "%s%s", "[libpq] ", create.c_str()); PGresult* result = PQexecParams(conn, create.c_str(), /*nParams=*/0, @@ -1203,15 +1211,21 @@ AdbcStatusCode PostgresStatement::ExecuteUpdateBulk(int64_t* rows_affected, BindStream bind_stream(std::move(bind_)); std::memset(&bind_, 0, sizeof(bind_)); std::string escaped_table; + std::string escaped_field_list; RAISE_ADBC(bind_stream.Begin( [&]() -> AdbcStatusCode { return CreateBulkTable(current_schema, bind_stream.bind_schema.value, - bind_stream.bind_schema_fields, &escaped_table, error); + bind_stream.bind_schema_fields, &escaped_table, + &escaped_field_list, error); }, error)); RAISE_ADBC(bind_stream.SetParamTypes(*type_resolver_, error)); - std::string query = "COPY " + escaped_table + " FROM STDIN WITH (FORMAT binary)"; + std::string query = "COPY "; + query += escaped_table; + query += " ("; + query += escaped_field_list; + query += ") FROM STDIN WITH (FORMAT binary)"; PGresult* result = PQexec(connection_->conn(), query.c_str()); if (PQresultStatus(result) != PGRES_COPY_IN) { AdbcStatusCode code = diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h index 20bb3b7ace..c822390d8c 100644 --- a/c/driver/postgresql/statement.h +++ b/c/driver/postgresql/statement.h @@ -128,7 +128,8 @@ class PostgresStatement { AdbcStatusCode CreateBulkTable( const std::string& current_schema, const struct ArrowSchema& source_schema, const std::vector& source_schema_fields, - std::string* escaped_table, struct AdbcError* error); + std::string* escaped_table, std::string* escaped_field_list, + struct AdbcError* error); AdbcStatusCode ExecuteUpdateBulk(int64_t* rows_affected, struct AdbcError* error); AdbcStatusCode ExecuteUpdateQuery(int64_t* rows_affected, struct AdbcError* error); AdbcStatusCode ExecutePreparedStatement(struct ArrowArrayStream* stream, diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c index e492801832..a94b83f750 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -1136,7 +1136,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, goto cleanup; } - sqlite3_str_appendf(insert_query, "INSERT INTO %s VALUES (", table); + sqlite3_str_appendf(insert_query, "INSERT INTO %s (", table); if (sqlite3_str_errcode(insert_query)) { SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); code = ADBC_STATUS_INTERNAL; @@ -1154,6 +1154,14 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, code = ADBC_STATUS_INTERNAL; goto cleanup; } + + sqlite3_str_appendf(insert_query, "%s", ", "); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", + sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } } sqlite3_str_appendf(create_query, "\"%w\"", stmt->binder.schema.children[i]->name); @@ -1163,6 +1171,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, goto cleanup; } + sqlite3_str_appendf(insert_query, "\"%w\"", stmt->binder.schema.children[i]->name); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + int status = ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error); if (status != 0) { @@ -1199,13 +1214,6 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, default: break; } - - sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : "")); - if (sqlite3_str_errcode(insert_query)) { - SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); - code = ADBC_STATUS_INTERNAL; - goto cleanup; - } } sqlite3_str_appendchar(create_query, 1, ')'); @@ -1215,6 +1223,22 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt, goto cleanup; } + sqlite3_str_appendall(insert_query, ") VALUES ("); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + + for (int i = 0; i < stmt->binder.schema.n_children; i++) { + sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : "")); + if (sqlite3_str_errcode(insert_query)) { + SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); + code = ADBC_STATUS_INTERNAL; + goto cleanup; + } + } + sqlite3_str_appendchar(insert_query, 1, ')'); if (sqlite3_str_errcode(insert_query)) { SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn)); diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index 13da21c10c..db31891774 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -98,6 +98,14 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { return ddl; } + std::optional PrimaryKeyIngestTableDdl( + std::string_view name) const override { + std::string ddl = "CREATE TABLE "; + ddl += name; + ddl += " (id INTEGER PRIMARY KEY, value BIGINT)"; + return ddl; + } + std::optional CompositePrimaryKeyTableDdl( std::string_view name) const override { std::string ddl = "CREATE TABLE "; diff --git a/c/validation/adbc_validation.cc b/c/validation/adbc_validation.cc index f0f42937f8..d30aa0a979 100644 --- a/c/validation/adbc_validation.cc +++ b/c/validation/adbc_validation.cc @@ -2803,6 +2803,115 @@ void StatementTest::TestSqlIngestTemporaryExclusive() { } } +void StatementTest::TestSqlIngestPrimaryKey() { + std::string name = "pkeytest"; + auto ddl = quirks()->PrimaryKeyIngestTableDdl(name); + if (!ddl) { + GTEST_SKIP(); + } + ASSERT_THAT(quirks()->DropTable(&connection, name, &error), IsOkStatus(&error)); + + // Create table + { + Handle statement; + StreamReader reader; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement.value, ddl->c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + // Ingest without the primary key + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, {{"value", NANOARROW_TYPE_INT64}}), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {42, -42, std::nullopt})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + // Ingest with the primary key + { + Handle schema; + Handle array; + struct ArrowError na_error; + ASSERT_THAT(MakeSchema(&schema.value, + { + {"id", NANOARROW_TYPE_INT64}, + {"value", NANOARROW_TYPE_INT64}, + }), + IsOkErrno()); + ASSERT_THAT((MakeBatch(&schema.value, &array.value, &na_error, + {4, 5, 6}, {1, 0, -1})), + IsOkErrno()); + + Handle statement; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_TARGET_TABLE, + name.c_str(), &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetOption(&statement.value, ADBC_INGEST_OPTION_MODE, + ADBC_INGEST_OPTION_MODE_APPEND, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementBind(&statement.value, &array.value, &schema.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, nullptr, nullptr, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementRelease(&statement.value, &error), IsOkStatus(&error)); + } + + // Get the data + { + Handle statement; + StreamReader reader; + ASSERT_THAT(AdbcStatementNew(&connection, &statement.value, &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementSetSqlQuery( + &statement.value, "SELECT * FROM pkeytest ORDER BY id ASC", &error), + IsOkStatus(&error)); + ASSERT_THAT(AdbcStatementExecuteQuery(&statement.value, &reader.stream.value, nullptr, + &error), + IsOkStatus(&error)); + + ASSERT_NO_FATAL_FAILURE(reader.GetSchema()); + ASSERT_EQ(2, reader.schema->n_children); + ASSERT_NO_FATAL_FAILURE(reader.Next()); + ASSERT_NE(nullptr, reader.array->release); + ASSERT_EQ(6, reader.array->length); + ASSERT_EQ(2, reader.array->n_children); + + // Different databases start numbering at 0 or 1 for the primary key + // column, so can't compare it + // TODO(https://github.com/apache/arrow-adbc/issues/938): if the test + // helpers converted data to plain C++ values we could do a more + // sophisticated assertion + ASSERT_NO_FATAL_FAILURE(CompareArray(reader.array_view->children[1], + {42, -42, std::nullopt, 1, 0, -1})); + } +} + void StatementTest::TestSqlPartitionedInts() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error), diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index e2b5d434d2..874d9a0584 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -86,6 +86,19 @@ class DriverQuirks { return std::nullopt; } + /// \brief Get the statement to create a table with a primary key, or + /// nullopt if not supported. This is used to test ingestion into a table + /// with an auto-incrementing primary key (which should not require the + /// data to contain the primary key). + /// + /// The table should have two columns: + /// - "id" which should be an auto-incrementing primary key compatible with int64 + /// - "value" with Arrow type int64 + virtual std::optional PrimaryKeyIngestTableDdl( + std::string_view name) const { + return std::nullopt; + } + /// \brief Get the statement to create a table with a composite primary key, /// or nullopt if not supported. /// @@ -347,6 +360,7 @@ class StatementTest { void TestSqlIngestTemporaryAppend(); void TestSqlIngestTemporaryReplace(); void TestSqlIngestTemporaryExclusive(); + void TestSqlIngestPrimaryKey(); void TestSqlPartitionedInts(); @@ -444,6 +458,7 @@ class StatementTest { TEST_F(FIXTURE, SqlIngestTemporaryAppend) { TestSqlIngestTemporaryAppend(); } \ TEST_F(FIXTURE, SqlIngestTemporaryReplace) { TestSqlIngestTemporaryReplace(); } \ TEST_F(FIXTURE, SqlIngestTemporaryExclusive) { TestSqlIngestTemporaryExclusive(); } \ + TEST_F(FIXTURE, SqlIngestPrimaryKey) { TestSqlIngestPrimaryKey(); } \ TEST_F(FIXTURE, SqlPartitionedInts) { TestSqlPartitionedInts(); } \ TEST_F(FIXTURE, SqlPrepareGetParameterSchema) { TestSqlPrepareGetParameterSchema(); } \ TEST_F(FIXTURE, SqlPrepareSelectNoParams) { TestSqlPrepareSelectNoParams(); } \