diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs index ed6057131d..eb0c5e2520 100644 --- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs +++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverExporter.cs @@ -75,7 +75,7 @@ public class CAdbcDriverExporter private static unsafe readonly NativeDelegate s_databaseSetOption = new NativeDelegate(SetDatabaseOption); private static IntPtr DatabaseSetOptionPtr => s_databaseSetOption.Pointer; - private unsafe delegate AdbcStatusCode ConnectionGetObjects(CAdbcConnection* connection, int depth, byte* catalog, byte* db_schema, byte* table_name, byte** table_type, byte* column_name, CArrowArrayStream* stream, CAdbcError* error); + internal unsafe delegate AdbcStatusCode ConnectionGetObjects(CAdbcConnection* connection, int depth, byte* catalog, byte* db_schema, byte* table_name, byte** table_type, byte* column_name, CArrowArrayStream* stream, CAdbcError* error); private static unsafe readonly NativeDelegate s_connectionGetObjects = new NativeDelegate(GetConnectionObjects); private static IntPtr ConnectionGetObjectsPtr => s_connectionGetObjects.Pointer; private unsafe delegate AdbcStatusCode ConnectionGetTableSchema(CAdbcConnection* connection, byte* catalog, byte* db_schema, byte* table_name, CArrowSchema* schema, CAdbcError* error); @@ -93,7 +93,7 @@ public class CAdbcDriverExporter private static IntPtr ConnectionCommitPtr => s_connectionCommit.Pointer; private static unsafe readonly NativeDelegate s_connectionRelease = new NativeDelegate(ReleaseConnection); private static IntPtr ConnectionReleasePtr => s_connectionRelease.Pointer; - private unsafe delegate AdbcStatusCode ConnectionGetInfo(CAdbcConnection* connection, byte* info_codes, int info_codes_length, CArrowArrayStream* stream, CAdbcError* error); + internal unsafe delegate AdbcStatusCode ConnectionGetInfo(CAdbcConnection* connection, byte* info_codes, int info_codes_length, CArrowArrayStream* stream, CAdbcError* error); private static unsafe readonly NativeDelegate s_connectionGetInfo = new NativeDelegate(GetConnectionInfo); private static IntPtr ConnectionGetInfoPtr => s_connectionGetInfo.Pointer; private unsafe delegate AdbcStatusCode ConnectionReadPartition(CAdbcConnection* connection, byte* serialized_partition, int serialized_length, CArrowArrayStream* stream, CAdbcError* error); diff --git a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs index fabd1700da..108087645b 100644 --- a/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs +++ b/csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs @@ -18,8 +18,10 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Runtime.InteropServices; using Apache.Arrow.C; +using Apache.Arrow.Ipc; #if NETSTANDARD using Apache.Arrow.Adbc.Extensions; @@ -220,6 +222,38 @@ public unsafe override AdbcStatement CreateStatement() return new AdbcStatementNative(_nativeDriver, nativeStatement); } + public override IArrowArrayStream GetInfo(List codes) + { + return GetInfo(codes.Select(x => (int)x).ToList()); + } + + public override unsafe IArrowArrayStream GetInfo(List codes) + { + CArrowArrayStream* nativeArrayStream = CArrowArrayStream.Create(); + + using (CallHelper caller = new CallHelper()) + { + caller.Call(_nativeDriver.ConnectionGetInfo, ref _nativeConnection, codes, nativeArrayStream); + } + + IArrowArrayStream arrowArrayStream = CArrowArrayStreamImporter.ImportArrayStream(nativeArrayStream); + + return arrowArrayStream; + } + + public override unsafe IArrowArrayStream GetObjects(GetObjectsDepth depth, string catalogPattern, string dbSchemaPattern, string tableNamePattern, List tableTypes, string columnNamePattern) + { + CArrowArrayStream* nativeArrayStream = CArrowArrayStream.Create(); + + using (CallHelper caller = new CallHelper()) + { + caller.Call(_nativeDriver.ConnectionGetObjects, ref _nativeConnection, (int)depth, catalogPattern, dbSchemaPattern, tableNamePattern, tableTypes, columnNamePattern, nativeArrayStream); + } + + IArrowArrayStream arrowArrayStream = CArrowArrayStreamImporter.ImportArrayStream(nativeArrayStream); + + return arrowArrayStream; + } } /// @@ -567,6 +601,92 @@ public unsafe void Dispose() } } +#if NET5_0_OR_GREATER + public unsafe void Call(delegate* unmanaged fn, ref CAdbcConnection connection, List infoCodes, CArrowArrayStream* stream) +#else + public unsafe void Call(IntPtr ptr, ref CAdbcConnection connection, List infoCodes, CArrowArrayStream* stream) +#endif + { + int numInts = infoCodes.Count; + + // Calculate the total number of bytes needed + int totalBytes = numInts * sizeof(int); + + IntPtr bytePtr = Marshal.AllocHGlobal(totalBytes); + + int[] intArray = infoCodes.ToArray(); + Marshal.Copy(intArray, 0, bytePtr, numInts); + + fixed (CAdbcConnection* cn = &connection) + fixed (CAdbcError* e = &_error) + { +#if NET5_0_OR_GREATER + TranslateCode(fn(cn, (byte*)bytePtr, infoCodes.Count, stream, e)); +#else + TranslateCode(Marshal.GetDelegateForFunctionPointer(ptr)(cn, (byte*)bytePtr, infoCodes.Count, stream, e)); +#endif + } + } + +#if NET5_0_OR_GREATER + public unsafe void Call(delegate* unmanaged fn, ref CAdbcConnection connection, int depth, string catalog, string db_schema, string table_name, List table_types, string column_name, CArrowArrayStream* stream) +#else + public unsafe void Call(IntPtr fn, ref CAdbcConnection connection, int depth, string catalog, string db_schema, string table_name, List table_types, string column_name, CArrowArrayStream* stream) +#endif + { + byte* bcatalog, bDb_schema, bTable_name, bColumn_Name; + + if(table_types == null) + { + table_types = new List(); + } + + // need to terminate with a null entry per https://github.com/apache/arrow-adbc/blob/b97e22c4d6524b60bf261e1970155500645be510/adbc.h#L909-L911 + table_types.Add(null); + + byte** bTable_type = (byte**)Marshal.AllocHGlobal(IntPtr.Size * table_types.Count); + + for (int i = 0; i < table_types.Count; i++) + { + string tableType = table_types[i]; +#if NETSTANDARD + bTable_type[i] = (byte*)MarshalExtensions.StringToCoTaskMemUTF8(tableType); +#else + bTable_type[i] = (byte*)Marshal.StringToCoTaskMemUTF8(tableType); +#endif + } + + using (Utf8Helper helper = new Utf8Helper(catalog)) + { + bcatalog = (byte*)(IntPtr)(helper); + } + + using (Utf8Helper helper = new Utf8Helper(db_schema)) + { + bDb_schema = (byte*)(IntPtr)(helper); + } + + using (Utf8Helper helper = new Utf8Helper(table_name)) + { + bTable_name = (byte*)(IntPtr)(helper); + } + + using (Utf8Helper helper = new Utf8Helper(column_name)) + { + bColumn_Name = (byte*)(IntPtr)(helper); + } + + fixed (CAdbcConnection* cn = &connection) + fixed (CAdbcError* e = &_error) + { +#if NET5_0_OR_GREATER + TranslateCode(fn(cn, depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream, e)); +#else + TranslateCode(Marshal.GetDelegateForFunctionPointer(fn)(cn, depth, bcatalog, bDb_schema, bTable_name, bTable_type, bColumn_Name, stream, e)); +#endif + } + } + private unsafe void TranslateCode(AdbcStatusCode statusCode) { if (statusCode != AdbcStatusCode.Success) @@ -580,7 +700,9 @@ private unsafe void TranslateCode(AdbcStatusCode statusCode) message = Marshal.PtrToStringUTF8((IntPtr)_error.message); #endif } + Dispose(); + throw new AdbcException(message); } } diff --git a/csharp/src/Apache.Arrow.Adbc/StandardSchemas.cs b/csharp/src/Apache.Arrow.Adbc/StandardSchemas.cs index 10fdd04b83..3eab6ed270 100644 --- a/csharp/src/Apache.Arrow.Adbc/StandardSchemas.cs +++ b/csharp/src/Apache.Arrow.Adbc/StandardSchemas.cs @@ -67,7 +67,7 @@ public static class StandardSchemas ) }, // TBD if this line is the best approach but its a good one-liner - new int[] {0, 1, 2, 3, 4, 5}.SelectMany(BitConverter.GetBytes).ToArray(), + new int[] {0, 1, 2, 3, 4, 5}.ToArray(), UnionMode.Dense), true) },