Skip to content

Commit

Permalink
fix(csharp): include GetInfo and GetObjects call for .NET 4.7.2 (#945)
Browse files Browse the repository at this point in the history
Continuation of #930

---------

Co-authored-by: David Coe <[email protected]>
  • Loading branch information
davidhcoe and David Coe authored Jul 28, 2023
1 parent b97e22c commit 7933e16
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 3 deletions.
4 changes: 2 additions & 2 deletions csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public class CAdbcDriverExporter
private static unsafe readonly NativeDelegate<DatabaseSetOption> s_databaseSetOption = new NativeDelegate<DatabaseSetOption>(SetDatabaseOption);
private static IntPtr DatabaseSetOptionPtr => s_databaseSetOption.Pointer;

private unsafe delegate AdbcStatusCode ConnectionGetObjects(CAdbcConnection* connection, int depth, byte* catalog, byte* db_schema, byte* table_name, byte** table_type, byte* column_name, CArrowArrayStream* stream, CAdbcError* error);
internal unsafe delegate AdbcStatusCode ConnectionGetObjects(CAdbcConnection* connection, int depth, byte* catalog, byte* db_schema, byte* table_name, byte** table_type, byte* column_name, CArrowArrayStream* stream, CAdbcError* error);
private static unsafe readonly NativeDelegate<ConnectionGetObjects> s_connectionGetObjects = new NativeDelegate<ConnectionGetObjects>(GetConnectionObjects);
private static IntPtr ConnectionGetObjectsPtr => s_connectionGetObjects.Pointer;
private unsafe delegate AdbcStatusCode ConnectionGetTableSchema(CAdbcConnection* connection, byte* catalog, byte* db_schema, byte* table_name, CArrowSchema* schema, CAdbcError* error);
Expand All @@ -93,7 +93,7 @@ public class CAdbcDriverExporter
private static IntPtr ConnectionCommitPtr => s_connectionCommit.Pointer;
private static unsafe readonly NativeDelegate<ConnectionFn> s_connectionRelease = new NativeDelegate<ConnectionFn>(ReleaseConnection);
private static IntPtr ConnectionReleasePtr => s_connectionRelease.Pointer;
private unsafe delegate AdbcStatusCode ConnectionGetInfo(CAdbcConnection* connection, byte* info_codes, int info_codes_length, CArrowArrayStream* stream, CAdbcError* error);
internal unsafe delegate AdbcStatusCode ConnectionGetInfo(CAdbcConnection* connection, byte* info_codes, int info_codes_length, CArrowArrayStream* stream, CAdbcError* error);
private static unsafe readonly NativeDelegate<ConnectionGetInfo> s_connectionGetInfo = new NativeDelegate<ConnectionGetInfo>(GetConnectionInfo);
private static IntPtr ConnectionGetInfoPtr => s_connectionGetInfo.Pointer;
private unsafe delegate AdbcStatusCode ConnectionReadPartition(CAdbcConnection* connection, byte* serialized_partition, int serialized_length, CArrowArrayStream* stream, CAdbcError* error);
Expand Down
122 changes: 122 additions & 0 deletions csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using Apache.Arrow.C;
using Apache.Arrow.Ipc;

#if NETSTANDARD
using Apache.Arrow.Adbc.Extensions;
Expand Down Expand Up @@ -220,6 +222,38 @@ public unsafe override AdbcStatement CreateStatement()
return new AdbcStatementNative(_nativeDriver, nativeStatement);
}

public override IArrowArrayStream GetInfo(List<AdbcInfoCode> codes)
{
return GetInfo(codes.Select(x => (int)x).ToList<int>());
}

public override unsafe IArrowArrayStream GetInfo(List<int> codes)
{
CArrowArrayStream* nativeArrayStream = CArrowArrayStream.Create();

using (CallHelper caller = new CallHelper())
{
caller.Call(_nativeDriver.ConnectionGetInfo, ref _nativeConnection, codes, nativeArrayStream);
}

IArrowArrayStream arrowArrayStream = CArrowArrayStreamImporter.ImportArrayStream(nativeArrayStream);

return arrowArrayStream;
}

public override unsafe IArrowArrayStream GetObjects(GetObjectsDepth depth, string catalogPattern, string dbSchemaPattern, string tableNamePattern, List<string> tableTypes, string columnNamePattern)
{
CArrowArrayStream* nativeArrayStream = CArrowArrayStream.Create();

using (CallHelper caller = new CallHelper())
{
caller.Call(_nativeDriver.ConnectionGetObjects, ref _nativeConnection, (int)depth, catalogPattern, dbSchemaPattern, tableNamePattern, tableTypes, columnNamePattern, nativeArrayStream);
}

IArrowArrayStream arrowArrayStream = CArrowArrayStreamImporter.ImportArrayStream(nativeArrayStream);

return arrowArrayStream;
}
}

/// <summary>
Expand Down Expand Up @@ -567,6 +601,92 @@ public unsafe void Dispose()
}
}

#if NET5_0_OR_GREATER
public unsafe void Call(delegate* unmanaged<CAdbcConnection*, byte*, int, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> fn, ref CAdbcConnection connection, List<int> infoCodes, CArrowArrayStream* stream)
#else
public unsafe void Call(IntPtr ptr, ref CAdbcConnection connection, List<int> infoCodes, CArrowArrayStream* stream)
#endif
{
int numInts = infoCodes.Count;

// Calculate the total number of bytes needed
int totalBytes = numInts * sizeof(int);

IntPtr bytePtr = Marshal.AllocHGlobal(totalBytes);

int[] intArray = infoCodes.ToArray();
Marshal.Copy(intArray, 0, bytePtr, numInts);

fixed (CAdbcConnection* cn = &connection)
fixed (CAdbcError* e = &_error)
{
#if NET5_0_OR_GREATER
TranslateCode(fn(cn, (byte*)bytePtr, infoCodes.Count, stream, e));
#else
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetInfo>(ptr)(cn, (byte*)bytePtr, infoCodes.Count, stream, e));
#endif
}
}

#if NET5_0_OR_GREATER
public unsafe void Call(delegate* unmanaged<CAdbcConnection*, int, byte*, byte*, byte*, byte**, byte*, CArrowArrayStream*, CAdbcError*, AdbcStatusCode> fn, ref CAdbcConnection connection, int depth, string catalog, string db_schema, string table_name, List<string> table_types, string column_name, CArrowArrayStream* stream)
#else
public unsafe void Call(IntPtr fn, ref CAdbcConnection connection, int depth, string catalog, string db_schema, string table_name, List<string> table_types, string column_name, CArrowArrayStream* stream)
#endif
{
byte* bcatalog, bDb_schema, bTable_name, bColumn_Name;

if(table_types == null)
{
table_types = new List<string>();
}

// need to terminate with a null entry per https://github.com/apache/arrow-adbc/blob/b97e22c4d6524b60bf261e1970155500645be510/adbc.h#L909-L911
table_types.Add(null);

byte** bTable_type = (byte**)Marshal.AllocHGlobal(IntPtr.Size * table_types.Count);

for (int i = 0; i < table_types.Count; i++)
{
string tableType = table_types[i];
#if NETSTANDARD
bTable_type[i] = (byte*)MarshalExtensions.StringToCoTaskMemUTF8(tableType);
#else
bTable_type[i] = (byte*)Marshal.StringToCoTaskMemUTF8(tableType);
#endif
}

using (Utf8Helper helper = new Utf8Helper(catalog))
{
bcatalog = (byte*)(IntPtr)(helper);
}

using (Utf8Helper helper = new Utf8Helper(db_schema))
{
bDb_schema = (byte*)(IntPtr)(helper);
}

using (Utf8Helper helper = new Utf8Helper(table_name))
{
bTable_name = (byte*)(IntPtr)(helper);
}

using (Utf8Helper helper = new Utf8Helper(column_name))
{
bColumn_Name = (byte*)(IntPtr)(helper);
}

fixed (CAdbcConnection* cn = &connection)
fixed (CAdbcError* e = &_error)
{
#if NET5_0_OR_GREATER
TranslateCode(fn(cn, depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream, e));
#else
TranslateCode(Marshal.GetDelegateForFunctionPointer<CAdbcDriverExporter.ConnectionGetObjects>(fn)(cn, depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream, e));
#endif
}
}

private unsafe void TranslateCode(AdbcStatusCode statusCode)
{
if (statusCode != AdbcStatusCode.Success)
Expand All @@ -580,7 +700,9 @@ private unsafe void TranslateCode(AdbcStatusCode statusCode)
message = Marshal.PtrToStringUTF8((IntPtr)_error.message);
#endif
}

Dispose();

throw new AdbcException(message);
}
}
Expand Down
2 changes: 1 addition & 1 deletion csharp/src/Apache.Arrow.Adbc/StandardSchemas.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static class StandardSchemas
)
},
// TBD if this line is the best approach but its a good one-liner
new int[] {0, 1, 2, 3, 4, 5}.SelectMany(BitConverter.GetBytes).ToArray(),
new int[] {0, 1, 2, 3, 4, 5}.ToArray(),
UnionMode.Dense),
true)
},
Expand Down

0 comments on commit 7933e16

Please sign in to comment.