Skip to content

Commit

Permalink
feat(go/adbc/driver/snowflake): support PEM decoding JWT private keys (
Browse files Browse the repository at this point in the history
…#1199)

Resolves #1198.
  • Loading branch information
superhawk610 authored Oct 20, 2023
1 parent abf67e9 commit 126235b
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
91 changes: 91 additions & 0 deletions go/adbc/driver/snowflake/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@ package snowflake_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/pem"
"fmt"
"os"
"strconv"
Expand Down Expand Up @@ -672,3 +677,89 @@ func (suite *SnowflakeTests) TestUseHighPrecision() {
suite.Equal(1234567.89, rec.Column(1).(*array.Float64).Value(0))
suite.Equal(9876543210.99, rec.Column(1).(*array.Float64).Value(1))
}

func (suite *SnowflakeTests) TestJwtPrivateKey() {
// grab the username from the DSN
cfg, err := gosnowflake.ParseDSN(suite.Quirks.dsn)
suite.NoError(err)
username := cfg.User

// write the generated RSA key out to a file
writeKey := func(filename string, key []byte) string {
f, err := os.CreateTemp("", filename)
suite.NoError(err)
_, err = f.Write(key)
suite.NoError(err)
return f.Name()
}

// set the Snowflake user's RSA public key
setKey := func(privKey *rsa.PrivateKey) {
suite.NoError(suite.stmt.SetSqlQuery("USE ROLE ACCOUNTADMIN"))
_, err := suite.stmt.ExecuteUpdate(suite.ctx)
suite.NoError(err)

if privKey != nil {
pubKeyBytes, err := x509.MarshalPKIXPublicKey(privKey.Public())
suite.NoError(err)
encodedKey := base64.StdEncoding.EncodeToString(pubKeyBytes)
suite.NoError(suite.stmt.SetSqlQuery(fmt.Sprintf("ALTER USER %s SET RSA_PUBLIC_KEY='%s'", username, encodedKey)))
} else {
suite.NoError(suite.stmt.SetSqlQuery(fmt.Sprintf("ALTER USER %s SET RSA_PUBLIC_KEY=''", username)))
}
_, err = suite.stmt.ExecuteUpdate(suite.ctx)
suite.NoError(err)
}

// open a new connection using JWT authentication and verify that a simple query runs
verifyKey := func(keyFile string) {
opts := suite.Quirks.DatabaseOptions()
opts[driver.OptionAuthType] = driver.OptionValueAuthJwt
opts[driver.OptionJwtPrivateKey] = keyFile
db, err := suite.driver.NewDatabase(opts)
suite.NoError(err)
cnxn, err := db.Open(suite.ctx)
suite.NoError(err)
defer cnxn.Close()
stmt, err := cnxn.NewStatement()
suite.NoError(err)
defer stmt.Close()

suite.NoError(stmt.SetSqlQuery("SELECT 1"))
rdr, _, err := stmt.ExecuteQuery(suite.ctx)
defer rdr.Release()
suite.NoError(err)
}

// generate a key and set it the Snowflake user
rsaKey, _ := rsa.GenerateKey(rand.Reader, 2048)
setKey(rsaKey)

// when the test concludes, reset the user's key
defer setKey(nil)

// PKCS1 key
rsaKeyPem := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
})
pkcs1Key := writeKey("key.pem", rsaKeyPem)
defer os.Remove(pkcs1Key)
verifyKey(pkcs1Key)

// PKCS8 key
rsaKeyP8Bytes, _ := x509.MarshalPKCS8PrivateKey(rsaKey)
rsaKeyP8 := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: rsaKeyP8Bytes,
})
pkcs8Key := writeKey("key.p8", rsaKeyP8)
defer os.Remove(pkcs8Key)
verifyKey(pkcs8Key)

// binary key
block, _ := pem.Decode([]byte(rsaKeyPem))
binKey := writeKey("key.bin", block.Bytes)
defer os.Remove(binKey)
verifyKey(binKey)
}
26 changes: 25 additions & 1 deletion go/adbc/driver/snowflake/snowflake_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@ package snowflake

import (
"context"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"errors"
"fmt"
"net/url"
"os"
"strconv"
"strings"
"time"

"github.com/apache/arrow-adbc/go/adbc"
Expand Down Expand Up @@ -328,13 +332,33 @@ func (d *databaseImpl) SetOptions(cnOptions map[string]string) error {
}
}

d.cfg.PrivateKey, err = x509.ParsePKCS1PrivateKey(data)
var block []byte
if strings.Contains(string(data), "PRIVATE KEY") {
b, _ := pem.Decode(data)
block = b.Bytes
} else {
block = data
}

var key *rsa.PrivateKey
key, err = x509.ParsePKCS1PrivateKey(block)
if err != nil && strings.Contains(err.Error(), "use ParsePKCS8PrivateKey instead") {
var pkcs8Key any
pkcs8Key, err = x509.ParsePKCS8PrivateKey(block)
key, ok = pkcs8Key.(*rsa.PrivateKey)
if !ok {
err = errors.New("file does not contain an RSA private key")
}
}

if err != nil {
return adbc.Error{
Msg: "failed parsing private key file '" + v + "': " + err.Error(),
Code: adbc.StatusInvalidArgument,
}
}

d.cfg.PrivateKey = key
case OptionClientRequestMFAToken:
switch v {
case adbc.OptionValueEnabled:
Expand Down

0 comments on commit 126235b

Please sign in to comment.