From 6a17176b40a36d5be56e6bd5009e2d0aaac50fa0 Mon Sep 17 00:00:00 2001 From: Dionna Glaze Date: Thu, 26 Jan 2023 19:14:32 +0000 Subject: [PATCH] Fix SevFirmwareErr, SevEsErr implementations and clean up error conditionals An error type should be a pointer to an object that implements Error. The indirection makes direct error comparison fail, so we switch to the existing pattern of wantErr as a string with a Contains test. The testing library implementation can't be used in kds_test.go due to a cyclic dependency that would induce, but it's a mild inconvenince. Signed-off-by: Dionna Glaze --- abi/amdsp.go | 3 +-- client/client.go | 6 +++--- client/client_test.go | 14 +++++++------- client/linuxabi/linux_abi.go | 3 +-- kds/kds_test.go | 5 +++-- testing/match.go | 25 +++++++++++++++++++++++++ testing/test_cases.go | 4 ++-- verify/verify_test.go | 23 ++++++++--------------- 8 files changed, 50 insertions(+), 33 deletions(-) create mode 100644 testing/match.go diff --git a/abi/amdsp.go b/abi/amdsp.go index f9b0f29..4572e85 100644 --- a/abi/amdsp.go +++ b/abi/amdsp.go @@ -106,11 +106,10 @@ const GuestRequestInvalidLength SevFirmwareStatus = 0x100000000 // SevFirmwareErr is an error that interprets firmware status codes from the AMD secure processor. type SevFirmwareErr struct { - error Status SevFirmwareStatus } -func (e SevFirmwareErr) Error() string { +func (e *SevFirmwareErr) Error() string { if e.Status == Success { return "success" } diff --git a/client/client.go b/client/client.go index 8027f0a..469e325 100644 --- a/client/client.go +++ b/client/client.go @@ -46,12 +46,12 @@ func message(d Device, command uintptr, req *labi.SnpUserGuestRequest) error { // indicates a problem certificate length. We need to // communicate that specifically. if req.FwErr != 0 { - return abi.SevFirmwareErr{Status: abi.SevFirmwareStatus(req.FwErr)} + return &abi.SevFirmwareErr{Status: abi.SevFirmwareStatus(req.FwErr)} } return err } if result != uintptr(labi.EsOk) { - return labi.SevEsErr{Result: labi.EsResult(result)} + return &labi.SevEsErr{Result: labi.EsResult(result)} } return nil } @@ -113,7 +113,7 @@ func getExtendedReportIn(d Device, reportData [64]byte, vmpl int, certs []byte) } // Query the length required for certs. if err := message(d, labi.IocSnpGetExtendedReport, &userGuestReq); err != nil { - var fwErr abi.SevFirmwareErr + var fwErr *abi.SevFirmwareErr if errors.As(err, &fwErr) && fwErr.Status == abi.GuestRequestInvalidLength { return nil, snpExtReportReq.CertsLength, nil } diff --git a/client/client_test.go b/client/client_test.go index b14f3ae..4aec4e7 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -44,7 +44,7 @@ func initDevice() { now := time.Date(2022, time.May, 3, 9, 0, 0, 0, time.UTC) for _, tc := range test.TestCases() { // Don't test faked errors when running real hardware tests. - if !UseDefaultSevGuest() && tc.WantErr != nil { + if !UseDefaultSevGuest() && tc.WantErr != "" { continue } tests = append(tests, tc) @@ -137,11 +137,11 @@ func TestOpenGetReportClose(t *testing.T) { // Does the proto report match expectations? got, err := GetReport(device, tc.Input) - if err != tc.WantErr { + if !test.Match(err, tc.WantErr) { t.Fatalf("GetReport(device, %v) = %v, %v. Want err: %v", tc.Input, got, err, tc.WantErr) } - if tc.WantErr == nil { + if tc.WantErr == "" { cleanReport(got) want := reportProto want.Signature = got.Signature // Zeros were placeholders. @@ -156,10 +156,10 @@ func TestOpenGetRawExtendedReportClose(t *testing.T) { devMu.Do(initDevice) for _, tc := range tests { raw, certs, err := GetRawExtendedReport(device, tc.Input) - if err != tc.WantErr { + if !test.Match(err, tc.WantErr) { t.Fatalf("%s: GetRawExtendedReport(device, %v) = %v, %v, %v. Want err: %v", tc.Name, tc.Input, raw, certs, err, tc.WantErr) } - if tc.WantErr == nil { + if tc.WantErr == "" { if err := cleanRawReport(raw); err != nil { t.Fatal(err) } @@ -189,10 +189,10 @@ func TestOpenGetExtendedReportClose(t *testing.T) { devMu.Do(initDevice) for _, tc := range tests { ereport, err := GetExtendedReport(device, tc.Input) - if err != tc.WantErr { + if !test.Match(err, tc.WantErr) { t.Fatalf("%s: GetExtendedReport(device, %v) = %v, %v. Want err: %v", tc.Name, tc.Input, ereport, err, tc.WantErr) } - if tc.WantErr == nil { + if tc.WantErr == "" { reportProto := &spb.Report{} if err := prototext.Unmarshal([]byte(tc.OutputProto), reportProto); err != nil { t.Fatalf("test failure: %v", err) diff --git a/client/linuxabi/linux_abi.go b/client/linuxabi/linux_abi.go index d81f380..86d9ce8 100644 --- a/client/linuxabi/linux_abi.go +++ b/client/linuxabi/linux_abi.go @@ -82,11 +82,10 @@ const ( // SevEsErr is an error that interprets SEV-ES guest-host communication results. type SevEsErr struct { - error Result EsResult } -func (err SevEsErr) Error() string { +func (err *SevEsErr) Error() string { if err.Result == EsUnsupported { return "requested operation not supported" } diff --git a/kds/kds_test.go b/kds/kds_test.go index 110bf2a..30ef83d 100644 --- a/kds/kds_test.go +++ b/kds/kds_test.go @@ -18,6 +18,7 @@ import ( "encoding/hex" "fmt" "net/url" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -80,7 +81,7 @@ func TestParseProductBaseURL(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { gotProduct, gotURL, err := parseBaseProductURL(tc.url) - if (err != nil && err.Error() != tc.wantErr) || (err == nil && tc.wantErr != "") { + if (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) { t.Fatalf("parseBaseProductURL(%q) = _, _, %v, want %q", tc.url, err, tc.wantErr) } if err == nil { @@ -144,7 +145,7 @@ func TestParseVCEKCertURL(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { got, err := ParseVCEKCertURL(tc.url) - if (err != nil && err.Error() != tc.wantErr) || (err == nil && tc.wantErr != "") { + if (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) { t.Fatalf("ParseVCEKCertURL(%q) = _, %v, want %q", tc.url, err, tc.wantErr) } if err == nil { diff --git a/testing/match.go b/testing/match.go new file mode 100644 index 0000000..d3b4cb1 --- /dev/null +++ b/testing/match.go @@ -0,0 +1,25 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testing + +import "strings" + +// Match returns true iff both errors match expectations closely enough +func Match(got error, want string) bool { + if got == nil { + return want == "" + } + return strings.Contains(got.Error(), want) +} diff --git a/testing/test_cases.go b/testing/test_cases.go index 5025fa8..81c40fe 100644 --- a/testing/test_cases.go +++ b/testing/test_cases.go @@ -134,7 +134,7 @@ type TestCase struct { OutputProto string FwErr abi.SevFirmwareStatus EsResult labi.EsResult - WantErr error + WantErr string } // TestCases returns common test cases for get_report. @@ -158,7 +158,7 @@ func TestCases() []TestCase { Name: "fw oom", Input: userZeros11, FwErr: abi.ResourceLimit, - WantErr: abi.SevFirmwareErr{Status: abi.ResourceLimit}, + WantErr: (&abi.SevFirmwareErr{Status: abi.ResourceLimit}).Error(), }, } } diff --git a/verify/verify_test.go b/verify/verify_test.go index 0ac5637..33dd6fc 100644 --- a/verify/verify_test.go +++ b/verify/verify_test.go @@ -23,7 +23,6 @@ import ( "math/big" "math/rand" "os" - "strings" "sync" "testing" "time" @@ -151,10 +150,10 @@ func TestSnpReportSignature(t *testing.T) { } // Does the Raw report match expectations? raw, err := sg.GetRawReport(d, tc.Input) - if err != tc.WantErr { + if !test.Match(err, tc.WantErr) { t.Fatalf("GetRawReport(d, %v) = %v, %v. Want err: %v", tc.Input, raw, err, tc.WantErr) } - if tc.WantErr == nil { + if tc.WantErr == "" { got := abi.SignedComponent(raw) want := abi.SignedComponent(tc.Output[:]) if !bytes.Equal(got, want) { @@ -303,14 +302,8 @@ func TestKdsMetadataLogic(t *testing.T) { options = &Options{} } vcek, _, err := VcekDER(newSigner.Vcek.Raw, newSigner.Ask.Raw, newSigner.Ark.Raw, options) - if err == nil && tc.wantErr != "" { - t.Errorf("%s: VcekDER(...) = %+v did not error as expected.", tc.name, vcek) - } - if err != nil && tc.wantErr == "" { - t.Errorf("%s: VcekDER(...) errored unexpectedly: %v", tc.name, err) - } - if err != nil && tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { - t.Errorf("%s: VcekDER(...) did not error as expected. Got %v, want %s", tc.name, err, tc.wantErr) + if !test.Match(err, tc.wantErr) { + t.Errorf("%s: VcekDER(...) = %+v, %v did not error as expected. Want %q", tc.name, vcek, err, tc.wantErr) } } } @@ -374,7 +367,7 @@ func TestCRLRootValidity(t *testing.T) { }, } wantErr := "CRL is not signed by ARK" - if err := VcekNotRevoked(root, g2, signer2.Vcek); err == nil || !strings.Contains(err.Error(), wantErr) { + if err := VcekNotRevoked(root, g2, signer2.Vcek); !test.Match(err, wantErr) { t.Errorf("Bad Root: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr) } @@ -385,7 +378,7 @@ func TestCRLRootValidity(t *testing.T) { AskX509: signer2.Ask, } wantErr2 := "ASK was revoked at 2022-06-14 12:01:00 +0000 UTC" - if err := VcekNotRevoked(root2, g2, signer2.Vcek); err == nil || !strings.Contains(err.Error(), wantErr2) { + if err := VcekNotRevoked(root2, g2, signer2.Vcek); !test.Match(err, wantErr2) { t.Errorf("Bad ASK: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr2) } } @@ -422,10 +415,10 @@ func TestOpenGetExtendedReportVerifyClose(t *testing.T) { } for _, getReport := range reportGetters { ereport, err := getReport.getter(d, tc.Input) - if err != tc.WantErr { + if !test.Match(err, tc.WantErr) { t.Fatalf("%s: %s(d, %v) = %v, %v. Want err: %v", tc.Name, getReport.name, tc.Input, ereport, err, tc.WantErr) } - if tc.WantErr == nil { + if tc.WantErr == "" { if err := SnpAttestation(ereport, options); err != nil { t.Errorf("SnpAttestation(%v) errored unexpectedly: %v", ereport, err) }