Skip to content

Commit

Permalink
feat(csharp): Add support for SqlDecimal (#1241)
Browse files Browse the repository at this point in the history
This PR:

- Adds support for SqlDecimal for Decimal128 (follow up to
apache/arrow#38481)
- Treats Decimal256 values as string (follow up to
apache/arrow#38508)
- Adds a new DecimalBehavior to the Client to allow the caller to
determine how to treat decimal values
- Standardizes the test frameworks to Xunit

Addresses #1230

---------

Co-authored-by: David Coe <[email protected]>
  • Loading branch information
davidhcoe and David Coe authored Oct 31, 2023
1 parent b4dbd9d commit 2a74e8e
Show file tree
Hide file tree
Showing 29 changed files with 938 additions and 960 deletions.
69 changes: 68 additions & 1 deletion csharp/src/Apache.Arrow.Adbc/AdbcStatement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,73 @@ public virtual void Dispose()
/// <param name="index">
/// The index in the array to get the value from.
/// </param>
public abstract object GetValue(IArrowArray arrowArray, Field field, int index);
public virtual object GetValue(IArrowArray arrowArray, Field field, int index)
{
if (arrowArray == null) throw new ArgumentNullException(nameof(arrowArray));
if (field == null) throw new ArgumentNullException(nameof(field));
if (index < 0) throw new ArgumentOutOfRangeException(nameof(index));

switch (arrowArray)
{
case BooleanArray booleanArray:
return booleanArray.GetValue(index);
case Date32Array date32Array:
return date32Array.GetDateTime(index);
case Date64Array date64Array:
return date64Array.GetDateTime(index);
case Decimal128Array decimal128Array:
return decimal128Array.GetSqlDecimal(index);
case Decimal256Array decimal256Array:
return decimal256Array.GetString(index);
case DoubleArray doubleArray:
return doubleArray.GetValue(index);
case FloatArray floatArray:
return floatArray.GetValue(index);
#if NET5_0_OR_GREATER
case PrimitiveArray<Half> halfFloatArray:
return halfFloatArray.GetValue(index);
#endif
case Int8Array int8Array:
return int8Array.GetValue(index);
case Int16Array int16Array:
return int16Array.GetValue(index);
case Int32Array int32Array:
return int32Array.GetValue(index);
case Int64Array int64Array:
return int64Array.GetValue(index);
case StringArray stringArray:
return stringArray.GetString(index);
case Time32Array time32Array:
return time32Array.GetValue(index);
case Time64Array time64Array:
return time64Array.GetValue(index);
case TimestampArray timestampArray:
DateTimeOffset dateTimeOffset = timestampArray.GetTimestamp(index).Value;
return dateTimeOffset;
case UInt8Array uInt8Array:
return uInt8Array.GetValue(index);
case UInt16Array uInt16Array:
return uInt16Array.GetValue(index);
case UInt32Array uInt32Array:
return uInt32Array.GetValue(index);
case UInt64Array uInt64Array:
return uInt64Array.GetValue(index);

case BinaryArray binaryArray:
if (!binaryArray.IsNull(index))
return binaryArray.GetBytes(index).ToArray();

return null;

// not covered:
// -- struct array
// -- dictionary array
// -- fixed size binary
// -- list array
// -- union array
}

return null;
}
}
}
154 changes: 1 addition & 153 deletions csharp/src/Apache.Arrow.Adbc/C/CAdbcDriverImporter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

using System;
using System.Collections.Generic;
using System.Data;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using Apache.Arrow.C;
using Apache.Arrow.Ipc;

Expand Down Expand Up @@ -328,158 +328,6 @@ public override unsafe UpdateResult ExecuteUpdate()
return new UpdateResult(rows);
}
}

public override object GetValue(IArrowArray arrowArray, Field field, int index)
{
if (arrowArray is BooleanArray)
{
return ((BooleanArray)arrowArray).GetValue(index);
}
else if (arrowArray is Date32Array)
{
Date32Array date32Array = (Date32Array)arrowArray;

return date32Array.GetDateTime(index);
}
else if (arrowArray is Date64Array)
{
Date64Array date64Array = (Date64Array)arrowArray;

return date64Array.GetDateTime(index);
}
else if (arrowArray is Decimal128Array)
{
try
{
// the value may be <decimal.min or >decimal.max
// then Arrow throws an exception
// no good way to check prior to
return ((Decimal128Array)arrowArray).GetValue(index);
}
catch (OverflowException oex)
{
return ParseDecimalValueFromOverflowException(oex);
}
}
else if (arrowArray is Decimal256Array)
{
try
{
return ((Decimal256Array)arrowArray).GetValue(index);
}
catch (OverflowException oex)
{
return ParseDecimalValueFromOverflowException(oex);
}
}
else if (arrowArray is DoubleArray)
{
return ((DoubleArray)arrowArray).Values[index];
}
else if (arrowArray is FloatArray)
{
return ((FloatArray)arrowArray).GetValue(index);
}
#if NET5_0_OR_GREATER
else if (arrowArray is PrimitiveArray<Half>)
{
// TODO: HalfFloatArray not present in current library

return ((PrimitiveArray<Half>)arrowArray).GetValue(index);
}
#endif
else if (arrowArray is Int8Array)
{
Int8Array array = (Int8Array)arrowArray;
return array.GetValue(index);
}
else if (arrowArray is Int16Array)
{
return ((Int16Array)arrowArray).Values[index];
}
else if (arrowArray is Int32Array)
{
return ((Int32Array)arrowArray).Values[index];
}
else if (arrowArray is Int64Array)
{
Int64Array array = (Int64Array)arrowArray;

return array.GetValue(index);
}
else if (arrowArray is StringArray)
{
return ((StringArray)arrowArray).GetString(index);
}
else if (arrowArray is Time32Array)
{
return ((Time32Array)arrowArray).GetValue(index);
}
else if (arrowArray is Time64Array)
{
return ((Time64Array)arrowArray).GetValue(index);
}
else if (arrowArray is TimestampArray)
{
TimestampArray timestampArray = (TimestampArray)arrowArray;
DateTimeOffset dateTimeOffset = timestampArray.GetTimestamp(index).Value;
return dateTimeOffset;
}
else if (arrowArray is UInt8Array)
{
return ((UInt8Array)arrowArray).GetValue(index);
}
else if (arrowArray is UInt16Array)
{
return ((UInt16Array)arrowArray).GetValue(index);
}
else if (arrowArray is UInt32Array)
{
return ((UInt32Array)arrowArray).GetValue(index);
}
else if (arrowArray is UInt64Array)
{
return ((UInt64Array)arrowArray).GetValue(index);
}
else if (arrowArray is BinaryArray)
{
ReadOnlySpan<byte> bytes = ((BinaryArray)arrowArray).GetBytes(index);

if (bytes != null)
return bytes.ToArray();
}

// not covered:
// -- struct array
// -- dictionary array
// -- fixed size binary
// -- list array
// -- union array

return null;
}

private string ParseDecimalValueFromOverflowException(OverflowException oex)
{
if (oex == null)
throw new ArgumentNullException(nameof(oex));

// any decimal value, positive or negative, with or without a decimal in place
Regex regex = new Regex(" -?\\d*\\.?\\d* ");

var matches = regex.Matches(oex.Message);

foreach (Match match in matches)
{
string value = match.Value;

if (!string.IsNullOrEmpty(value))
return value;
}

throw oex;
}

}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion csharp/src/Client/AdbcCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public AdbcCommand(AdbcStatement adbcStatement, AdbcConnection adbcConnection) :

this.adbcStatement = adbcStatement;
this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
}

/// <summary>
Expand All @@ -69,6 +70,7 @@ public AdbcCommand(string query, AdbcConnection adbcConnection) : base()
this.CommandText = query;

this.DbConnection = adbcConnection;
this.DecimalBehavior = adbcConnection.DecimalBehavior;
}

/// <summary>
Expand All @@ -77,6 +79,8 @@ public AdbcCommand(string query, AdbcConnection adbcConnection) : base()
/// </summary>
public AdbcStatement AdbcStatement => this.adbcStatement;

public DecimalBehavior DecimalBehavior { get; set; }

public override string CommandText
{
get => this.adbcStatement.SqlQuery;
Expand Down Expand Up @@ -170,7 +174,7 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
case CommandBehavior.SchemaOnly: // The schema is not known until a read happens
case CommandBehavior.Default:
QueryResult result = this.ExecuteQuery();
return new AdbcDataReader(this, result);
return new AdbcDataReader(this, result, this.DecimalBehavior);

default:
throw new InvalidOperationException($"{behavior} is not supported with this provider");
Expand Down
15 changes: 9 additions & 6 deletions csharp/src/Client/AdbcConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public sealed class AdbcConnection : DbConnection
public AdbcConnection()
{
this.AdbcDriver = null;
this.DecimalBehavior = DecimalBehavior.UseSqlDecimal;
this.adbcConnectionParameters = new Dictionary<string, string>();
this.adbcConnectionOptions = new Dictionary<string, string>();
}
Expand All @@ -55,7 +56,7 @@ public AdbcConnection(string connectionString) : this()
}

/// <summary>
/// Overloaded. Intializes an <see cref="AdbcConnection"/>.
/// Overloaded. Initializes an <see cref="AdbcConnection"/>.
/// </summary>
/// <param name="adbcDriver">
/// The <see cref="AdbcDriver"/> to use for connecting. This value
Expand All @@ -67,7 +68,7 @@ public AdbcConnection(AdbcDriver adbcDriver) : this()
}

/// <summary>
/// Overloaded. Intializes an <see cref="AdbcConnection"/>.
/// Overloaded. Initializes an <see cref="AdbcConnection"/>.
/// </summary>
/// <param name="adbcDriver">
/// The <see cref="AdbcDriver"/> to use for connecting. This value
Expand Down Expand Up @@ -120,6 +121,11 @@ internal AdbcStatement AdbcStatement

public override string ConnectionString { get => GetConnectionString(); set => SetConnectionProperties(value); }

/// <summary>
/// Gets or sets the behavior of decimals.
/// </summary>
public DecimalBehavior DecimalBehavior { get; set; }

protected override DbCommand CreateDbCommand()
{
EnsureConnectionOpen();
Expand Down Expand Up @@ -220,9 +226,6 @@ public override DataTable GetSchema()
return GetSchema(null);
}

//GetSchema("TABLES")
//GetSchema("VIEWS")

public override DataTable GetSchema(string collectionName)
{
return GetSchema(collectionName, null);
Expand All @@ -231,7 +234,7 @@ public override DataTable GetSchema(string collectionName)
public override DataTable GetSchema(string collectionName, string[] restrictionValues)
{
Schema arrowSchema = this.adbcConnectionInternal.GetTableSchema("", "", "");
return SchemaConverter.ConvertArrowSchema(arrowSchema, this.AdbcStatement);
return SchemaConverter.ConvertArrowSchema(arrowSchema, this.AdbcStatement, this.DecimalBehavior);
}

#region NOT_IMPLEMENTED
Expand Down
Loading

0 comments on commit 2a74e8e

Please sign in to comment.