Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go/adbc/driver/snowflake): Removing SQL injection to get table name with special character for getObjectsTables #1338

105 changes: 98 additions & 7 deletions csharp/test/Drivers/Interop/Snowflake/DriverTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
Expand All @@ -16,6 +16,7 @@
*/

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Apache.Arrow.Adbc.Tests.Metadata;
Expand Down Expand Up @@ -110,10 +111,7 @@ public void CanExecuteUpdate()
for (int i = 0; i < queries.Length; i++)
{
string query = queries[i];
using AdbcStatement statement = _connection.CreateStatement();
statement.SqlQuery = query;

UpdateResult updateResult = statement.ExecuteUpdate();
UpdateResult updateResult = ExecuteUpdateStatement(query);

Assert.Equal(expectedResults[i], updateResult.AffectedRows);
}
Expand Down Expand Up @@ -279,17 +277,66 @@ public void CanGetObjectsAll()
{
IEnumerable<AdbcColumn> highPrecisionColumns = columns.Where(c => c.XdbcTypeName == "NUMBER");

if(highPrecisionColumns.Count() > 0)
if (highPrecisionColumns.Count() > 0)
{
// ensure they all are coming back as XdbcDataType_XDBC_DECIMAL because they are Decimal128
short XdbcDataType_XDBC_DECIMAL = 3;
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
IEnumerable<AdbcColumn> invalidHighPrecisionColumns = highPrecisionColumns.Where(c => c.XdbcSqlDataType != XdbcDataType_XDBC_DECIMAL);
int count = invalidHighPrecisionColumns.Count();
Assert.True(count == 0, $"There are {count} columns that do not map to the correct XdbcSqlDataType when UseHighPrecision=true");
}
}
}

/// <summary>
/// Validates if the driver can call GetObjects with GetObjectsDepth as Tables with TableName as a Special Character.
/// </summary>
[SkippableTheory, Order(3)]
[InlineData(@"ADBCDEMO_DB",@"PUBLIC","MyIdentifier")]
[InlineData(@"ADBCDEMO'DB", @"PUBLIC'SCHEMA","my.identifier")]
[InlineData(@"ADBCDEM""DB", @"PUBLIC""SCHEMA", "my.identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "my identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My 'Identifier'")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "3rd_identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "$Identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My ^Identifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "My ^Ident~ifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"My\^Ident~ifier")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "идентификатор")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"ADBCTest_""ALL""TYPES")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", @"ADBC\TEST""\TAB_""LE")]
[InlineData(@"ADBCDEMO_DB", @"PUBLIC", "ONE")]
public void CanGetObjectsTablesWithSpecialCharacter(string databaseName, string schemaName, string tableName)
{
CreateDatabaseAndTable(databaseName, schemaName, tableName);

using IArrowArrayStream stream = _connection.GetObjects(
depth: AdbcConnection.GetObjectsDepth.Tables,
catalogPattern: databaseName,
dbSchemaPattern: schemaName,
tableNamePattern: tableName,
tableTypes: new List<string> { "BASE TABLE", "VIEW" },
columnNamePattern: null);

using RecordBatch recordBatch = stream.ReadNextRecordBatchAsync().Result;

List<AdbcCatalog> catalogs = GetObjectsParser.ParseCatalog(recordBatch, databaseName, schemaName);

List<AdbcTable> tables = catalogs
.Where(c => string.Equals(c.Name, databaseName))
.Select(c => c.DbSchemas)
.FirstOrDefault()
.Where(s => string.Equals(s.Name, schemaName))
.Select(s => s.Tables)
.FirstOrDefault();

AdbcTable table = tables.FirstOrDefault();

Assert.True(table != null, "table should not be null");
Assert.Equal(tableName, table.Name, true);
DropDatabaseAndTable(databaseName, schemaName, tableName);
}

/// <summary>
/// Validates if the driver can call GetTableSchema.
/// </summary>
Expand Down Expand Up @@ -354,6 +401,50 @@ public void CanExecuteQuery()
Tests.DriverTests.CanExecuteQuery(queryResult, _testConfiguration.ExpectedResultsCount);
}

private void CreateDatabaseAndTable(string databaseName, string schemaName, string tableName)
{
databaseName = databaseName.Replace("\"", "\"\"");
schemaName = schemaName.Replace("\"", "\"\"");
tableName = tableName.Replace("\"", "\"\"");

string createDatabase = string.Format("CREATE DATABASE IF NOT EXISTS \"{0}\"", databaseName);
ExecuteUpdateStatement(createDatabase);

string createSchema = string.Format("CREATE SCHEMA IF NOT EXISTS \"{0}\".\"{1}\"", databaseName, schemaName);
ExecuteUpdateStatement(createSchema);

string fullyQualifiedTableName = string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
string createTableStatement = string.Format("CREATE OR REPLACE TABLE {0} (INDEX INT)", fullyQualifiedTableName);
ExecuteUpdateStatement(createTableStatement);

}

private void DropDatabaseAndTable(string databaseName, string schemaName, string tableName)
{
tableName = tableName.Replace("\"", "\"\"");
schemaName = schemaName.Replace("\"", "\"\"");
databaseName = databaseName.Replace("\"", "\"\"");

string fullyQualifiedTableName = string.Format("\"{0}\".\"{1}\".\"{2}\"", databaseName, schemaName, tableName);
string createTableStatement = string.Format("DROP TABLE IF EXISTS {0} ", fullyQualifiedTableName);
ExecuteUpdateStatement(createTableStatement);

string createSchema = string.Format("DROP SCHEMA IF EXISTS \"{0}\".\"{1}\"", databaseName, schemaName);
ExecuteUpdateStatement(createSchema);

string createDatabase = string.Format("DROP DATABASE IF EXISTS \"{0}\"", databaseName);
ExecuteUpdateStatement(createDatabase);

}

private UpdateResult ExecuteUpdateStatement(string query)
{
using AdbcStatement statement = _connection.CreateStatement();
statement.SqlQuery = query;
UpdateResult updateResult = statement.ExecuteUpdate();
return updateResult;
}

private static string GetPartialNameForPatternMatch(string name)
{
if (string.IsNullOrEmpty(name) || name.Length == 1) return name;
Expand Down
4 changes: 2 additions & 2 deletions go/adbc/driver/flightsql/flightsql_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ func (c *cnxn) readInfo(ctx context.Context, expectedSchema *arrow.Schema, info
}

// Helper function to build up a map of catalogs to DB schemas
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string) (result map[string][]string, err error) {
func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, metadataRecords []internal.Metadata) (result map[string][]string, err error) {
if depth == adbc.ObjectDepthCatalogs {
return
}
Expand Down Expand Up @@ -588,7 +588,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, depth adbc.ObjectDepth,
return
}

func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string) (result internal.SchemaToTableInfo, err error) {
func (c *cnxn) getObjectsTables(ctx context.Context, depth adbc.ObjectDepth, catalog *string, dbSchema *string, tableName *string, columnName *string, tableType []string, metadataRecords []internal.Metadata) (result internal.SchemaToTableInfo, err error) {
if depth == adbc.ObjectDepthCatalogs || depth == adbc.ObjectDepthDBSchemas {
return
}
Expand Down
21 changes: 17 additions & 4 deletions go/adbc/driver/internal/shared_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package internal

import (
"context"
"database/sql"
"regexp"
"strconv"
"strings"
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow/go/v14/arrow"
Expand All @@ -38,8 +40,18 @@ type TableInfo struct {
Schema *arrow.Schema
}

type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string) (map[string][]string, error)
type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string) (map[CatalogAndSchema][]TableInfo, error)
type Metadata struct {
Created time.Time
ColName, DataType string
Dbname, Kind, Schema, TblName, TblType, IdentGen, IdentIncrement, Comment sql.NullString
OrdinalPos int
NumericPrec, NumericPrecRadix, NumericScale, DatetimePrec sql.NullInt16
IsNullable, IsIdent bool
CharMaxLength, CharOctetLength sql.NullInt32
}

type GetObjDBSchemasFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, metadataRecords []Metadata) (map[string][]string, error)
type GetObjTablesFn func(ctx context.Context, depth adbc.ObjectDepth, catalog *string, schema *string, tableName *string, columnName *string, tableType []string, metadataRecords []Metadata) (map[CatalogAndSchema][]TableInfo, error)
type SchemaToTableInfo = map[CatalogAndSchema][]TableInfo

// Helper function that compiles a SQL-style pattern (%, _) to a regex
Expand Down Expand Up @@ -87,6 +99,7 @@ type GetObjects struct {
builder *array.RecordBuilder
schemaLookup map[string][]string
tableLookup map[CatalogAndSchema][]TableInfo
MetadataRecords []Metadata
catalogPattern *regexp.Regexp
columnNamePattern *regexp.Regexp

Expand Down Expand Up @@ -123,13 +136,13 @@ type GetObjects struct {
}

func (g *GetObjects) Init(mem memory.Allocator, getObj GetObjDBSchemasFn, getTbls GetObjTablesFn) error {
if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog, g.DbSchema); err != nil {
if catalogToDbSchemas, err := getObj(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.MetadataRecords); err != nil {
return err
} else {
g.schemaLookup = catalogToDbSchemas
}

if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.TableName, g.ColumnName, g.TableType); err != nil {
if tableLookup, err := getTbls(g.Ctx, g.Depth, g.Catalog, g.DbSchema, g.TableName, g.ColumnName, g.TableType, g.MetadataRecords); err != nil {
return err
} else {
g.tableLookup = tableLookup
Expand Down
Loading
Loading