Skip to content

Commit

Permalink
Implement remaining functions in 1.0 spec.
Browse files Browse the repository at this point in the history
Closes #1221
  • Loading branch information
CurtHagenlocher committed Apr 25, 2024
1 parent 96e05a0 commit 4d2167b
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 45 deletions.
18 changes: 16 additions & 2 deletions csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,29 @@ public AdbcStatement()
/// </summary>
public virtual byte[] SubstraitPlan
{
get { throw new NotImplementedException(); }
set { throw new NotImplementedException(); }
get { throw AdbcException.NotImplemented("Statement does not support SubstraitPlan"); }
set { throw AdbcException.NotImplemented("Statement does not support SubstraitPlan"); }
}

/// <summary>
/// Binds this statement to a <see cref="RecordBatch"/> to provide parameter values or bulk data ingestion.
/// </summary>
/// <param name="batch">the RecordBatch to bind</param>
/// <param name="schema">the schema of the RecordBatch</param>
public virtual void Bind(RecordBatch batch, Schema schema)
{
throw AdbcException.NotImplemented("Statement does not support Bind");
}

/// <summary>
/// Binds this statement to an <see cref="IArrowArrayStream"/> to provide parameter values or bulk data ingestion.
/// </summary>
/// <param name="stream"></param>
public virtual void BindStream(IArrowArrayStream stream)
{
throw AdbcException.NotImplemented("Statement does not support BindStream");
}

/// <summary>
/// Executes the statement and returns a tuple containing the number
/// of records and the <see cref="IArrowArrayStream"/>.
Expand Down
179 changes: 155 additions & 24 deletions csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,10 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using Apache.Arrow.Adbc.Extensions;
using Apache.Arrow.C;
using Apache.Arrow.Ipc;

#if NETSTANDARD
using Apache.Arrow.Adbc.Extensions;
#endif

namespace Apache.Arrow.Adbc.C
{
public class CAdbcDriverExporter
Expand All @@ -38,6 +35,7 @@ public class CAdbcDriverExporter
#if NET5_0_OR_GREATER
private static unsafe delegate* unmanaged<CAdbcError*, void> ReleaseErrorPtr => (delegate* unmanaged<CAdbcError*, void>)s_releaseError.Pointer;
private static unsafe delegate* unmanaged<CAdbcDriver*, CAdbcError*, AdbcStatusCode> ReleaseDriverPtr => &ReleaseDriver;
private static unsafe delegate* unmanaged<CAdbcPartitions*, void> ReleasePartitionsPtr => &ReleasePartitions;

private static unsafe delegate* unmanaged<CAdbcDatabase*, CAdbcError*, AdbcStatusCode> DatabaseInitPtr => &InitDatabase;
private static unsafe delegate* unmanaged<CAdbcDatabase*, CAdbcError*, AdbcStatusCode> DatabaseReleasePtr => &ReleaseDatabase;
Expand All @@ -55,16 +53,23 @@ public class CAdbcDriverExporter
private static unsafe delegate* unmanaged<CAdbcConnection*, byte*, byte*, CAdbcError*, AdbcStatusCode> ConnectionSetOptionPtr => &SetConnectionOption;

private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowArray*, CArrowSchema*, CAdbcError*, AdbcStatusCode> StatementBindPtr => &BindStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> StatementBindStreamPtr => &BindStreamStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowArrayStream*, long*, CAdbcError*, AdbcStatusCode> StatementExecuteQueryPtr => &ExecuteStatementQuery;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowSchema*, CAdbcPartitions*, long*, CAdbcError*, AdbcStatusCode> StatementExecutePartitionsPtr => &ExecuteStatementPartitions;
private static unsafe delegate* unmanaged<CAdbcConnection*, CAdbcStatement*, CAdbcError*, AdbcStatusCode> StatementNewPtr => &NewStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CAdbcError*, AdbcStatusCode> StatementReleasePtr => &ReleaseStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, CAdbcError*, AdbcStatusCode> StatementPreparePtr => &PrepareStatement;
private static unsafe delegate* unmanaged<CAdbcStatement*, byte*, CAdbcError*, AdbcStatusCode> StatementSetSqlQueryPtr => &SetStatementSqlQuery;
private static unsafe delegate* unmanaged<CAdbcStatement*, byte*, int, CAdbcError*, AdbcStatusCode> StatementSetSubstraitPlanPtr => &SetStatementSubstraitPlan;
private static unsafe delegate* unmanaged<CAdbcStatement*, CArrowSchema*, CAdbcError*, AdbcStatusCode> StatementGetParameterSchemaPtr => &GetStatementParameterSchema;
#else
private static IntPtr ReleaseErrorPtr => s_releaseError.Pointer;
internal unsafe delegate AdbcStatusCode DriverRelease(CAdbcDriver* driver, CAdbcError* error);
private static unsafe readonly NativeDelegate<DriverRelease> s_releaseDriver = new NativeDelegate<DriverRelease>(ReleaseDriver);
private static IntPtr ReleaseDriverPtr => s_releaseDriver.Pointer;
internal unsafe delegate void PartitionsRelease(CAdbcPartitions* partitions);
private static unsafe readonly NativeDelegate<PartitionsRelease> s_releasePartitions = new NativeDelegate<PartitionsRelease>(ReleasePartitions);
private static IntPtr ReleasePartitionsPtr => s_releasePartitions.Pointer;

private static unsafe readonly NativeDelegate<DatabaseFn> s_databaseInit = new NativeDelegate<DatabaseFn>(InitDatabase);
private static IntPtr DatabaseInitPtr => s_databaseInit.Pointer;
Expand Down Expand Up @@ -102,12 +107,18 @@ public class CAdbcDriverExporter
private static unsafe readonly NativeDelegate<ConnectionSetOption> s_connectionSetOption = new NativeDelegate<ConnectionSetOption>(SetConnectionOption);
private static IntPtr ConnectionSetOptionPtr => s_connectionSetOption.Pointer;

private unsafe delegate AdbcStatusCode StatementBind(CAdbcStatement* statement, CArrowArray* array, CArrowSchema* schema, CAdbcError* error);
internal unsafe delegate AdbcStatusCode StatementBind(CAdbcStatement* statement, CArrowArray* array, CArrowSchema* schema, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementBind> s_statementBind = new NativeDelegate<StatementBind>(BindStatement);
private static IntPtr StatementBindPtr => s_statementBind.Pointer;
internal unsafe delegate AdbcStatusCode StatementBindStream(CAdbcStatement* statement, CArrowArrayStream* stream, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementBindStream> s_statementBindStream = new NativeDelegate<StatementBindStream>(BindStreamStatement);
private static IntPtr StatementBindStreamPtr => s_statementBindStream.Pointer;
internal unsafe delegate AdbcStatusCode StatementExecuteQuery(CAdbcStatement* statement, CArrowArrayStream* stream, long* rows, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementExecuteQuery> s_statementExecuteQuery = new NativeDelegate<StatementExecuteQuery>(ExecuteStatementQuery);
private static IntPtr StatementExecuteQueryPtr = s_statementExecuteQuery.Pointer;
internal unsafe delegate AdbcStatusCode StatementExecutePartitions(CAdbcStatement* statement, CArrowSchema* schema, CAdbcPartitions* partitions, long* rows, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementExecutePartitions> s_statementExecutePartitions = new NativeDelegate<StatementExecutePartitions>(ExecuteStatementPartitions);
private static IntPtr StatementExecutePartitionsPtr = s_statementExecutePartitions.Pointer;
internal unsafe delegate AdbcStatusCode StatementNew(CAdbcConnection* connection, CAdbcStatement* statement, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementNew> s_statementNew = new NativeDelegate<StatementNew>(NewStatement);
private static IntPtr StatementNewPtr => s_statementNew.Pointer;
Expand All @@ -119,17 +130,14 @@ public class CAdbcDriverExporter
internal unsafe delegate AdbcStatusCode StatementSetSqlQuery(CAdbcStatement* statement, byte* text, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementSetSqlQuery> s_statementSetSqlQuery = new NativeDelegate<StatementSetSqlQuery>(SetStatementSqlQuery);
private static IntPtr StatementSetSqlQueryPtr = s_statementSetSqlQuery.Pointer;
internal unsafe delegate AdbcStatusCode StatementSetSubstraitPlan(CAdbcStatement* statement, byte* plan, int length, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementSetSubstraitPlan> s_statementSetSubstraitPlan = new NativeDelegate<StatementSetSubstraitPlan>(SetStatementSubstraitPlan);
private static IntPtr StatementSetSubstraitPlanPtr = s_statementSetSubstraitPlan.Pointer;
internal unsafe delegate AdbcStatusCode StatementGetParameterSchema(CAdbcStatement* statement, CArrowSchema* schema, CAdbcError* error);
private static unsafe readonly NativeDelegate<StatementGetParameterSchema> s_statementGetParameterSchema = new NativeDelegate<StatementGetParameterSchema>(GetStatementParameterSchema);
private static IntPtr StatementGetParameterSchemaPtr = s_statementGetParameterSchema.Pointer;
#endif

/*
* Not yet implemented
unsafe delegate AdbcStatusCode StatementBindStream(CAdbcStatement* statement, CArrowArrayStream* stream, CAdbcError* error);
unsafe delegate AdbcStatusCode StatementExecutePartitions(CAdbcStatement* statement, CArrowSchema* schema, CAdbcPartitions* partitions, long* rows_affected, CAdbcError* error);
unsafe delegate AdbcStatusCode StatementGetParameterSchema(CAdbcStatement* statement, CArrowSchema* schema, CAdbcError* error);
unsafe delegate AdbcStatusCode StatementSetSubstraitPlan(CAdbcStatement statement, byte* plan, int length, CAdbcError error);
*/

public unsafe static AdbcStatusCode AdbcDriverInit(int version, CAdbcDriver* nativeDriver, CAdbcError* error, AdbcDriver driver)
{
DriverStub stub = new DriverStub(driver);
Expand All @@ -142,7 +150,6 @@ public unsafe static AdbcStatusCode AdbcDriverInit(int version, CAdbcDriver* nat
nativeDriver->DatabaseSetOption = DatabaseSetOptionPtr;
nativeDriver->DatabaseRelease = DatabaseReleasePtr;

// TODO: This should probably only set the pointers for the functionality actually supported by this particular driver
nativeDriver->ConnectionCommit = ConnectionCommitPtr;
nativeDriver->ConnectionGetInfo = ConnectionGetInfoPtr;
nativeDriver->ConnectionGetObjects = ConnectionGetObjectsPtr;
Expand All @@ -156,15 +163,15 @@ public unsafe static AdbcStatusCode AdbcDriverInit(int version, CAdbcDriver* nat
nativeDriver->ConnectionRollback = ConnectionRollbackPtr;

nativeDriver->StatementBind = StatementBindPtr;
// nativeDriver->StatementBindStream = StatementBindStreamPtr;
nativeDriver->StatementBindStream = StatementBindStreamPtr;
nativeDriver->StatementExecuteQuery = StatementExecuteQueryPtr;
// nativeDriver->StatementExecutePartitions = StatementExecutePartitionsPtr;
// nativeDriver->StatementGetParameterSchema = StatementGetParameterSchemaPtr;
nativeDriver->StatementExecutePartitions = StatementExecutePartitionsPtr;
nativeDriver->StatementGetParameterSchema = StatementGetParameterSchemaPtr;
nativeDriver->StatementNew = StatementNewPtr;
nativeDriver->StatementPrepare = StatementPreparePtr;
nativeDriver->StatementRelease = StatementReleasePtr;
nativeDriver->StatementSetSqlQuery = StatementSetSqlQueryPtr;
// nativeDriver->StatementSetSubstraitPlan = StatementSetSubstraitPlanPtr;
nativeDriver->StatementSetSubstraitPlan = StatementSetSubstraitPlanPtr;

return 0;
}
Expand All @@ -181,12 +188,7 @@ private unsafe static AdbcStatusCode SetError(CAdbcError* error, Exception excep
{
ReleaseError(error);

#if NETSTANDARD
error->message = (byte*)MarshalExtensions.StringToCoTaskMemUTF8(exception.Message);
#else
error->message = (byte*)Marshal.StringToCoTaskMemUTF8(exception.Message);
#endif

error->sqlstate0 = (byte)0;
error->sqlstate1 = (byte)0;
error->sqlstate2 = (byte)0;
Expand Down Expand Up @@ -249,6 +251,37 @@ private unsafe static AdbcStatusCode ReleaseDriver(CAdbcDriver* nativeDriver, CA
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static void ReleasePartitions(CAdbcPartitions* partitions)
{
if (partitions != null)
{
if (partitions->partitions != null)
{
for (int i = 0; i < partitions->num_partitions; i++)
{
byte* partition = partitions->partitions[i];
if (partition != null)
{
Marshal.FreeHGlobal((IntPtr)partition);
partitions->partitions[i] = null;
}
}
Marshal.FreeHGlobal((IntPtr)partitions->partitions);
partitions->partitions = null;
}
if (partitions->partition_lengths != null)
{
Marshal.FreeHGlobal((IntPtr)partitions->partition_lengths);
partitions->partition_lengths = null;
}

partitions->release = default;
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand Down Expand Up @@ -512,6 +545,46 @@ private unsafe static AdbcStatusCode SetStatementSqlQuery(CAdbcStatement* native
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode SetStatementSubstraitPlan(CAdbcStatement* nativeStatement, byte* plan, int length, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;

stub.SubstraitPlan = MarshalExtensions.MarshalBuffer(plan, length);

return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode GetStatementParameterSchema(CAdbcStatement* nativeStatement, CArrowSchema* schema, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;

CArrowSchemaExporter.ExportSchema(stub.GetParameterSchema(), schema);

return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand All @@ -533,6 +606,26 @@ private unsafe static AdbcStatusCode BindStatement(CAdbcStatement* nativeStateme
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode BindStreamStatement(CAdbcStatement* nativeStatement, CArrowArrayStream* stream, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;

IArrowArrayStream arrayStream = CArrowArrayStreamImporter.ImportArrayStream(stream);
stub.BindStream(arrayStream);
return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand All @@ -557,6 +650,44 @@ private unsafe static AdbcStatusCode ExecuteStatementQuery(CAdbcStatement* nativ
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
private unsafe static AdbcStatusCode ExecuteStatementPartitions(CAdbcStatement* nativeStatement, CArrowSchema* schema, CAdbcPartitions* partitions, long* rows, CAdbcError* error)
{
try
{
GCHandle gch = GCHandle.FromIntPtr((IntPtr)nativeStatement->private_data);
AdbcStatement stub = (AdbcStatement)gch.Target;
var result = stub.ExecutePartitioned();
if (rows != null)
{
*rows = result.AffectedRows;
}

partitions->release = ReleasePartitionsPtr;
partitions->num_partitions = result.PartitionDescriptors.Count;
partitions->partitions = (byte**)Marshal.AllocHGlobal(IntPtr.Size * result.PartitionDescriptors.Count);
partitions->partition_lengths = (nuint*)Marshal.AllocHGlobal(IntPtr.Size * result.PartitionDescriptors.Count);
for (int i = 0; i < partitions->num_partitions; i++)
{
ReadOnlySpan<byte> partition = result.PartitionDescriptors[i].Descriptor;
partitions->partition_lengths[i] = (nuint)partition.Length;
partitions->partitions[i] = (byte*)Marshal.AllocHGlobal(partition.Length);
fixed (void* descriptor = partition)
{
Buffer.MemoryCopy(descriptor, partitions->partitions[i], partition.Length, partition.Length);
}
}

return AdbcStatusCode.Success;
}
catch (Exception e)
{
return SetError(error, e);
}
}

#if NET5_0_OR_GREATER
[UnmanagedCallersOnly]
#endif
Expand Down
Loading

0 comments on commit 4d2167b

Please sign in to comment.