Skip to content

Commit

Permalink
constellation-lib: add init test
Browse files Browse the repository at this point in the history
Signed-off-by: Moritz Sanft <[email protected]>
  • Loading branch information
msanft committed Dec 1, 2023
1 parent 7e105a8 commit 844d5a9
Show file tree
Hide file tree
Showing 4 changed files with 333 additions and 2 deletions.
27 changes: 26 additions & 1 deletion internal/constellation/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")

go_library(
name = "constellation",
Expand All @@ -12,7 +13,6 @@ go_library(
deps = [
"//bootstrapper/initproto",
"//internal/cloud/cloudprovider",
"//internal/config",
"//internal/constants",
"//internal/crypto",
"//internal/grpc/grpclog",
Expand All @@ -25,3 +25,28 @@ go_library(
"@org_golang_google_grpc//:go_default_library",
],
)

go_test(
name = "constellation_test",
srcs = [
"apply_test.go",
"applyinit_test.go",
],
embed = [":constellation"],
deps = [
"//bootstrapper/initproto",
"//internal/cloud/cloudprovider",
"//internal/constants",
"//internal/crypto",
"//internal/grpc/atlscredentials",
"//internal/grpc/dialer",
"//internal/grpc/testdialer",
"//internal/kms/uri",
"//internal/license",
"//internal/logger",
"//internal/state",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_golang_google_grpc//:go_default_library",
],
)
6 changes: 5 additions & 1 deletion internal/constellation/apply.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ import (
// In Particular, this involves Initialization and Upgrading of the cluster.
type Applier struct {
log debugLog
licenseChecker *license.Checker
licenseChecker licenseChecker
spinner spinnerInterf
}

type licenseChecker interface {
CheckLicense(context.Context, cloudprovider.Provider, string) (license.QuotaCheckResponse, error)
}

type debugLog interface {
Debugf(format string, args ...any)
}
Expand Down
74 changes: 74 additions & 0 deletions internal/constellation/apply_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

package constellation

import (
"context"
"testing"

"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCheckLicense(t *testing.T) {
testCases := map[string]struct {
licenseChecker *stubLicenseChecker
wantErr bool
}{
"success": {
licenseChecker: &stubLicenseChecker{},
wantErr: false,
},
"check license error": {
licenseChecker: &stubLicenseChecker{checkLicenseErr: assert.AnError},
wantErr: true,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
require := require.New(t)

a := &Applier{licenseChecker: tc.licenseChecker, log: logger.NewTest(t)}
_, err := a.CheckLicense(context.Background(), cloudprovider.Unknown, license.CommunityLicense)
if tc.wantErr {
require.Error(err)
} else {
require.NoError(err)
}
})
}
}

type stubLicenseChecker struct {
checkLicenseErr error
}

func (c *stubLicenseChecker) CheckLicense(context.Context, cloudprovider.Provider, string) (license.QuotaCheckResponse, error) {
return license.QuotaCheckResponse{}, c.checkLicenseErr
}

func TestGenerateMasterSecret(t *testing.T) {
assert := assert.New(t)
a := &Applier{log: logger.NewTest(t)}
sec, err := a.GenerateMasterSecret()
assert.NoError(err)
assert.Len(sec.Key, crypto.MasterSecretLengthDefault)
assert.Len(sec.Key, crypto.RNGLengthDefault)
}

func TestGenerateMeasurementSalt(t *testing.T) {
assert := assert.New(t)
a := &Applier{log: logger.NewTest(t)}
salt, err := a.GenerateMeasurementSalt()
assert.NoError(err)
assert.Len(salt, crypto.RNGLengthDefault)
}
228 changes: 228 additions & 0 deletions internal/constellation/applyinit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/

package constellation

import (
"bytes"
"context"
"io"
"net"
"strconv"
"testing"
"time"

"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/state"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
)

func TestInit(t *testing.T) {
clusterEndpoint := "192.0.2.1"
newState := func(endpoint string) *state.State {
return &state.State{
Infrastructure: state.Infrastructure{
ClusterEndpoint: endpoint,
},
}
}
newInitServer := func(initErr error, responses ...*initproto.InitResponse) *stubInitServer {
return &stubInitServer{
res: responses,
initErr: initErr,
}
}

testCases := map[string]struct {
server initproto.APIServer
state *state.State
initServerEndpoint string
wantClusterLogs []byte
wantErr bool
}{
"success": {
server: newInitServer(nil,
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitSuccess{
InitSuccess: &initproto.InitSuccessResponse{
Kubeconfig: []byte{},
OwnerId: []byte{},
ClusterId: []byte{},
},
},
}),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
},
"no response": {
server: newInitServer(nil),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
"nil response": {
server: newInitServer(nil, &initproto.InitResponse{Kind: nil}),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
"failure response": {
server: newInitServer(nil,
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitFailure{
InitFailure: &initproto.InitFailureResponse{
Error: assert.AnError.Error(),
},
},
}),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
"setup server error": {
server: newInitServer(assert.AnError),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
"expected log response, got failure": {
server: newInitServer(nil,
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitFailure{
InitFailure: &initproto.InitFailureResponse{
Error: assert.AnError.Error(),
},
},
},
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitFailure{
InitFailure: &initproto.InitFailureResponse{
Error: assert.AnError.Error(),
},
},
},
),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
"expected log response, got success": {
server: newInitServer(nil,
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitFailure{
InitFailure: &initproto.InitFailureResponse{
Error: assert.AnError.Error(),
},
},
},
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitSuccess{
InitSuccess: &initproto.InitSuccessResponse{
Kubeconfig: []byte{},
OwnerId: []byte{},
ClusterId: []byte{},
},
},
},
),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
"collect logs": {
server: newInitServer(nil,
&initproto.InitResponse{
Kind: &initproto.InitResponse_InitFailure{
InitFailure: &initproto.InitFailureResponse{
Error: assert.AnError.Error(),
},
},
},
&initproto.InitResponse{
Kind: &initproto.InitResponse_Log{
Log: &initproto.LogResponseType{
Log: []byte("some log"),
},
},
},
),
wantClusterLogs: []byte("some log"),
state: newState(clusterEndpoint),
initServerEndpoint: clusterEndpoint,
wantErr: true,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := require.New(t)

netDialer := testdialer.NewBufconnDialer()
dialer := dialer.New(nil, nil, netDialer)
stop := setupTestInitServer(netDialer, tc.server, tc.initServerEndpoint)
defer stop()

a := &Applier{log: logger.NewTest(t), spinner: &nopSpinner{}}

clusterLogs := &bytes.Buffer{}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*4)
defer cancel()
_, err := a.Init(ctx, dialer, tc.state, clusterLogs, InitPayload{
MasterSecret: uri.MasterSecret{},
MeasurementSalt: []byte{},
K8sVersion: "v1.26.5",
ConformanceMode: false,
})
if tc.wantErr {
assert.Error(err)
assert.Equal(tc.wantClusterLogs, clusterLogs.Bytes())
} else {
assert.NoError(err)
}
})
}
}

type nopSpinner struct {
io.Writer
}

func (s *nopSpinner) Start(string, bool) {}
func (s *nopSpinner) Stop() {}
func (s *nopSpinner) Write(p []byte) (n int, err error) {
return s.Writer.Write(p)
}

func setupTestInitServer(dialer *testdialer.BufconnDialer, server initproto.APIServer, host string) func() {
serverCreds := atlscredentials.New(nil, nil)
initServer := grpc.NewServer(grpc.Creds(serverCreds))
initproto.RegisterAPIServer(initServer, server)
listener := dialer.GetListener(net.JoinHostPort(host, strconv.Itoa(constants.BootstrapperPort)))
go initServer.Serve(listener)
return initServer.GracefulStop
}

type stubInitServer struct {
res []*initproto.InitResponse
initErr error

initproto.UnimplementedAPIServer
}

func (s *stubInitServer) Init(_ *initproto.InitRequest, stream initproto.API_InitServer) error {
for _, r := range s.res {
_ = stream.Send(r)
}
return s.initErr
}

0 comments on commit 844d5a9

Please sign in to comment.