From 844d5a995190f9bbd0cb0c8abbd39a542e16edfe Mon Sep 17 00:00:00 2001 From: Moritz Sanft <58110325+msanft@users.noreply.github.com> Date: Fri, 1 Dec 2023 10:21:20 +0100 Subject: [PATCH] constellation-lib: add init test Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> --- internal/constellation/BUILD.bazel | 27 ++- internal/constellation/apply.go | 6 +- internal/constellation/apply_test.go | 74 ++++++++ internal/constellation/applyinit_test.go | 228 +++++++++++++++++++++++ 4 files changed, 333 insertions(+), 2 deletions(-) create mode 100644 internal/constellation/apply_test.go create mode 100644 internal/constellation/applyinit_test.go diff --git a/internal/constellation/BUILD.bazel b/internal/constellation/BUILD.bazel index 50df0883c62..0a682728052 100644 --- a/internal/constellation/BUILD.bazel +++ b/internal/constellation/BUILD.bazel @@ -1,4 +1,5 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//bazel/go:go_test.bzl", "go_test") go_library( name = "constellation", @@ -12,7 +13,6 @@ go_library( deps = [ "//bootstrapper/initproto", "//internal/cloud/cloudprovider", - "//internal/config", "//internal/constants", "//internal/crypto", "//internal/grpc/grpclog", @@ -25,3 +25,28 @@ go_library( "@org_golang_google_grpc//:go_default_library", ], ) + +go_test( + name = "constellation_test", + srcs = [ + "apply_test.go", + "applyinit_test.go", + ], + embed = [":constellation"], + deps = [ + "//bootstrapper/initproto", + "//internal/cloud/cloudprovider", + "//internal/constants", + "//internal/crypto", + "//internal/grpc/atlscredentials", + "//internal/grpc/dialer", + "//internal/grpc/testdialer", + "//internal/kms/uri", + "//internal/license", + "//internal/logger", + "//internal/state", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + "@org_golang_google_grpc//:go_default_library", + ], +) diff --git a/internal/constellation/apply.go b/internal/constellation/apply.go index 8c4e89fbdb1..c74b6aeacee 100644 --- a/internal/constellation/apply.go +++ b/internal/constellation/apply.go @@ -21,10 +21,14 @@ import ( // In Particular, this involves Initialization and Upgrading of the cluster. type Applier struct { log debugLog - licenseChecker *license.Checker + licenseChecker licenseChecker spinner spinnerInterf } +type licenseChecker interface { + CheckLicense(context.Context, cloudprovider.Provider, string) (license.QuotaCheckResponse, error) +} + type debugLog interface { Debugf(format string, args ...any) } diff --git a/internal/constellation/apply_test.go b/internal/constellation/apply_test.go new file mode 100644 index 00000000000..ce2466fdfd4 --- /dev/null +++ b/internal/constellation/apply_test.go @@ -0,0 +1,74 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package constellation + +import ( + "context" + "testing" + + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" + "github.com/edgelesssys/constellation/v2/internal/crypto" + "github.com/edgelesssys/constellation/v2/internal/license" + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCheckLicense(t *testing.T) { + testCases := map[string]struct { + licenseChecker *stubLicenseChecker + wantErr bool + }{ + "success": { + licenseChecker: &stubLicenseChecker{}, + wantErr: false, + }, + "check license error": { + licenseChecker: &stubLicenseChecker{checkLicenseErr: assert.AnError}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + require := require.New(t) + + a := &Applier{licenseChecker: tc.licenseChecker, log: logger.NewTest(t)} + _, err := a.CheckLicense(context.Background(), cloudprovider.Unknown, license.CommunityLicense) + if tc.wantErr { + require.Error(err) + } else { + require.NoError(err) + } + }) + } +} + +type stubLicenseChecker struct { + checkLicenseErr error +} + +func (c *stubLicenseChecker) CheckLicense(context.Context, cloudprovider.Provider, string) (license.QuotaCheckResponse, error) { + return license.QuotaCheckResponse{}, c.checkLicenseErr +} + +func TestGenerateMasterSecret(t *testing.T) { + assert := assert.New(t) + a := &Applier{log: logger.NewTest(t)} + sec, err := a.GenerateMasterSecret() + assert.NoError(err) + assert.Len(sec.Key, crypto.MasterSecretLengthDefault) + assert.Len(sec.Key, crypto.RNGLengthDefault) +} + +func TestGenerateMeasurementSalt(t *testing.T) { + assert := assert.New(t) + a := &Applier{log: logger.NewTest(t)} + salt, err := a.GenerateMeasurementSalt() + assert.NoError(err) + assert.Len(salt, crypto.RNGLengthDefault) +} diff --git a/internal/constellation/applyinit_test.go b/internal/constellation/applyinit_test.go new file mode 100644 index 00000000000..5a08abac5b2 --- /dev/null +++ b/internal/constellation/applyinit_test.go @@ -0,0 +1,228 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package constellation + +import ( + "bytes" + "context" + "io" + "net" + "strconv" + "testing" + "time" + + "github.com/edgelesssys/constellation/v2/bootstrapper/initproto" + "github.com/edgelesssys/constellation/v2/internal/constants" + "github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials" + "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" + "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" + "github.com/edgelesssys/constellation/v2/internal/kms/uri" + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/state" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" +) + +func TestInit(t *testing.T) { + clusterEndpoint := "192.0.2.1" + newState := func(endpoint string) *state.State { + return &state.State{ + Infrastructure: state.Infrastructure{ + ClusterEndpoint: endpoint, + }, + } + } + newInitServer := func(initErr error, responses ...*initproto.InitResponse) *stubInitServer { + return &stubInitServer{ + res: responses, + initErr: initErr, + } + } + + testCases := map[string]struct { + server initproto.APIServer + state *state.State + initServerEndpoint string + wantClusterLogs []byte + wantErr bool + }{ + "success": { + server: newInitServer(nil, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitSuccess{ + InitSuccess: &initproto.InitSuccessResponse{ + Kubeconfig: []byte{}, + OwnerId: []byte{}, + ClusterId: []byte{}, + }, + }, + }), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + }, + "no response": { + server: newInitServer(nil), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + "nil response": { + server: newInitServer(nil, &initproto.InitResponse{Kind: nil}), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + "failure response": { + server: newInitServer(nil, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitFailure{ + InitFailure: &initproto.InitFailureResponse{ + Error: assert.AnError.Error(), + }, + }, + }), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + "setup server error": { + server: newInitServer(assert.AnError), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + "expected log response, got failure": { + server: newInitServer(nil, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitFailure{ + InitFailure: &initproto.InitFailureResponse{ + Error: assert.AnError.Error(), + }, + }, + }, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitFailure{ + InitFailure: &initproto.InitFailureResponse{ + Error: assert.AnError.Error(), + }, + }, + }, + ), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + "expected log response, got success": { + server: newInitServer(nil, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitFailure{ + InitFailure: &initproto.InitFailureResponse{ + Error: assert.AnError.Error(), + }, + }, + }, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitSuccess{ + InitSuccess: &initproto.InitSuccessResponse{ + Kubeconfig: []byte{}, + OwnerId: []byte{}, + ClusterId: []byte{}, + }, + }, + }, + ), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + "collect logs": { + server: newInitServer(nil, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_InitFailure{ + InitFailure: &initproto.InitFailureResponse{ + Error: assert.AnError.Error(), + }, + }, + }, + &initproto.InitResponse{ + Kind: &initproto.InitResponse_Log{ + Log: &initproto.LogResponseType{ + Log: []byte("some log"), + }, + }, + }, + ), + wantClusterLogs: []byte("some log"), + state: newState(clusterEndpoint), + initServerEndpoint: clusterEndpoint, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := require.New(t) + + netDialer := testdialer.NewBufconnDialer() + dialer := dialer.New(nil, nil, netDialer) + stop := setupTestInitServer(netDialer, tc.server, tc.initServerEndpoint) + defer stop() + + a := &Applier{log: logger.NewTest(t), spinner: &nopSpinner{}} + + clusterLogs := &bytes.Buffer{} + ctx, cancel := context.WithTimeout(context.Background(), time.Second*4) + defer cancel() + _, err := a.Init(ctx, dialer, tc.state, clusterLogs, InitPayload{ + MasterSecret: uri.MasterSecret{}, + MeasurementSalt: []byte{}, + K8sVersion: "v1.26.5", + ConformanceMode: false, + }) + if tc.wantErr { + assert.Error(err) + assert.Equal(tc.wantClusterLogs, clusterLogs.Bytes()) + } else { + assert.NoError(err) + } + }) + } +} + +type nopSpinner struct { + io.Writer +} + +func (s *nopSpinner) Start(string, bool) {} +func (s *nopSpinner) Stop() {} +func (s *nopSpinner) Write(p []byte) (n int, err error) { + return s.Writer.Write(p) +} + +func setupTestInitServer(dialer *testdialer.BufconnDialer, server initproto.APIServer, host string) func() { + serverCreds := atlscredentials.New(nil, nil) + initServer := grpc.NewServer(grpc.Creds(serverCreds)) + initproto.RegisterAPIServer(initServer, server) + listener := dialer.GetListener(net.JoinHostPort(host, strconv.Itoa(constants.BootstrapperPort))) + go initServer.Serve(listener) + return initServer.GracefulStop +} + +type stubInitServer struct { + res []*initproto.InitResponse + initErr error + + initproto.UnimplementedAPIServer +} + +func (s *stubInitServer) Init(_ *initproto.InitRequest, stream initproto.API_InitServer) error { + for _, r := range s.res { + _ = stream.Send(r) + } + return s.initErr +}