Skip to content

Commit

Permalink
Merge pull request #21 from vejed/feature-transit-sign-verify
Browse files Browse the repository at this point in the history
Feature transit sign verify
  • Loading branch information
Lucaber authored Oct 19, 2022
2 parents 9813995 + 6bfbd29 commit ddeb4b7
Show file tree
Hide file tree
Showing 2 changed files with 231 additions and 0 deletions.
167 changes: 167 additions & 0 deletions transit.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ package vault
import (
"encoding/base64"
"errors"
"fmt"
"net/http"
"net/url"
"regexp"
"strconv"

"github.com/hashicorp/vault/api"
)
Expand Down Expand Up @@ -308,6 +311,170 @@ func (t *Transit) DecryptBatch(key string, opts TransitDecryptOptionsBatch) (*Tr
return res, nil
}

type TransitSignOptions struct {
Input string `json:"input"`
KeyVersion *int `json:"key_version,omitempty"`
HashAlgorithm string `json:"hash_algorithm,omitempty"`
Context string `json:"context,omitempty"`
Prehashed bool `json:"prehashed,omitempty"`
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
SaltLength string `json:"salt_length,omitempty"`
}

type TransitSignResponse struct {
Data struct {
Signature string `json:"signature"`
KeyVersion int `json:"key_version,omitempty"`
} `json:"data"`
}

func (t *Transit) Sign(key string, opts *TransitSignOptions) (*TransitSignResponse, error) {
res := &TransitSignResponse{}

opts.Input = base64.StdEncoding.EncodeToString([]byte(opts.Input))

err := t.client.Write([]string{"v1", t.MountPoint, "sign", url.PathEscape(key)}, opts, res, nil)
if err != nil {
return nil, err
}

return res, nil
}

type TransitBatchSignInput struct {
Input string `json:"input"`
Context string `json:"context,omitempty"`
}

type TransitBatchSignature struct {
Signature string `json:"signature"`
KeyVersion int `json:"key_version,omitempty"`
}

type TransitSignOptionsBatch struct {
BatchInput []TransitBatchSignInput `json:"batch_input"`
KeyVersion *int `json:"key_version,omitempty"`
HashAlgorithm string `json:"hash_algorithm,omitempty"`
Prehashed bool `json:"prehashed,omitempty"`
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
SaltLength string `json:"salt_length,omitempty"`
}

type TransitSignResponseBatch struct {
Data struct {
BatchResults []TransitBatchSignature `json:"batch_results"`
} `json:"data"`
}

func (t *Transit) SignBatch(key string, opts *TransitSignOptionsBatch) (*TransitSignResponseBatch, error) {
res := &TransitSignResponseBatch{}

for i := range opts.BatchInput {
opts.BatchInput[i].Input = base64.StdEncoding.EncodeToString([]byte(opts.BatchInput[i].Input))
}

err := t.client.Write([]string{"v1", t.MountPoint, "sign", url.PathEscape(key)}, opts, res, nil)
if err != nil {
return nil, err
}

return res, nil
}

type TransitVerifyOptions struct {
Input string `json:"input"`
Signature string `json:"signature"`
HashAlgorithm string `json:"hash_algorithm,omitempty"`
Context string `json:"context,omitempty"`
Prehashed bool `json:"prehashed,omitempty"`
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
SaltLength string `json:"salt_length,omitempty"`
}

type TransitVerifyResponse struct {
Data struct {
Valid bool `json:"valid"`
} `json:"data"`
}

func (t *Transit) Verify(key string, opts *TransitVerifyOptions) (*TransitVerifyResponse, error) {
res := &TransitVerifyResponse{}

opts.Input = base64.StdEncoding.EncodeToString([]byte(opts.Input))

err := t.client.Write([]string{"v1", t.MountPoint, "verify", url.PathEscape(key)}, opts, res, nil)
if err != nil {
return nil, err
}

return res, nil
}

type TransitBatchVerifyInput struct {
Input string `json:"input"`
Signature string `json:"signature"`
Context string `json:"context,omitempty"`
}

type TransitBatchVerifyData struct {
Valid bool `json:"valid"`
}

type TransitVerifyOptionsBatch struct {
BatchInput []TransitBatchVerifyInput `json:"batch_input"`
HashAlgorithm string `json:"hash_algorithm,omitempty"`
Context string `json:"context,omitempty"`
Prehashed bool `json:"prehashed,omitempty"`
SignatureAlgorithm string `json:"signature_algorithm,omitempty"`
MarshalingAlgorithm string `json:"marshaling_algorithm,omitempty"`
SaltLength string `json:"salt_length,omitempty"`
}

type TransitVerifyResponseBatch struct {
Data struct {
BatchResults []TransitBatchVerifyData `json:"batch_results"`
} `json:"data"`
}

func (t *Transit) VerifyBatch(key string, opts *TransitVerifyOptionsBatch) (*TransitVerifyResponseBatch, error) {
res := &TransitVerifyResponseBatch{}

for i := range opts.BatchInput {
opts.BatchInput[i].Input = base64.StdEncoding.EncodeToString([]byte(opts.BatchInput[i].Input))
}

err := t.client.Write([]string{"v1", t.MountPoint, "verify", url.PathEscape(key)}, opts, res, nil)
if err != nil {
return nil, err
}

return res, nil
}

// DecodeCipherText gets payload from vault ciphertext format (removes "vault:v<ver>:" prefix)
func DecodeCipherText(vaultCipherText string) (string, int, error) {
regex := regexp.MustCompile(`^vault:v(\d+):(.+)$`)
matches := regex.FindStringSubmatch(vaultCipherText)
if len(matches) != 3 {
return "", 0, errors.New("invalid vault ciphertext format")
}

keyVersion, err := strconv.Atoi(matches[1])
if err != nil {
return "", 0, errors.New("can't parse key version")
}

return matches[2], keyVersion, nil
}

// EncodeCipherText encodes payload to vault ciphertext format (adda "vault:v<ver>:" prefix)
func EncodeCipherText(cipherText string, keyVersion int) string {
return fmt.Sprintf("vault:v%d:%s", keyVersion, cipherText)
}

func (t *Transit) mapError(err error) error {
resErr := &api.ResponseError{}
if errors.As(err, &resErr) {
Expand Down
64 changes: 64 additions & 0 deletions transit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,67 @@ func (s *TransitTestSuite) TestCreateKeyThatDoesAlreadyExist() {
err = s.client.Create("testCeateKeyThatDoesAlreadyExist", &TransitCreateOptions{})
require.NoError(s.T(), err)
}

func (s *TransitTestSuite) TestSignVerify() {
err := s.client.Create("testSignVerify", &TransitCreateOptions{Type: "rsa-2048"})
require.NoError(s.T(), err)

text := "test"

signRes, err := s.client.Sign("testSignVerify", &TransitSignOptions{
Input: text,
})
require.NoError(s.T(), err)

verifyRes, err := s.client.Verify("testSignVerify", &TransitVerifyOptions{
Input: text,
Signature: signRes.Data.Signature,
})
require.NoError(s.T(), err)

s.True(verifyRes.Data.Valid)
}

func (s *TransitTestSuite) TestSignVerifyBatch() {
err := s.client.Create("testSignVerify", &TransitCreateOptions{Type: "rsa-2048"})
require.NoError(s.T(), err)

text1 := "test1"
text2 := "test2"

signRes, err := s.client.SignBatch("testSignVerify", &TransitSignOptionsBatch{
BatchInput: []TransitBatchSignInput{
{Input: text1},
{Input: text2},
},
})
require.NoError(s.T(), err)

verifyRes, err := s.client.VerifyBatch("testSignVerify", &TransitVerifyOptionsBatch{
BatchInput: []TransitBatchVerifyInput{
{Input: text1, Signature: signRes.Data.BatchResults[0].Signature},
{Input: text2, Signature: signRes.Data.BatchResults[1].Signature},
},
})
require.NoError(s.T(), err)

s.True(verifyRes.Data.BatchResults[0].Valid)
s.True(verifyRes.Data.BatchResults[1].Valid)
}

func (s *TransitTestSuite) TestDecodeCipherText() {
dec, ver, err := DecodeCipherText("vault:v123:SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
require.NoError(s.T(), err)
s.Equal("SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", dec)
s.Equal(123, ver)
}

func (s *TransitTestSuite) TestDecodeCipherTextError() {
_, _, err := DecodeCipherText("vault:SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c")
s.NotNil(err)
}

func (s *TransitTestSuite) TestEncodeCipherText() {
enc := EncodeCipherText("SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", 123)
s.Equal("vault:v123:SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c", enc)
}

0 comments on commit ddeb4b7

Please sign in to comment.