Skip to content

Commit

Permalink
Merge pull request #35 from deeglaze/fwerr
Browse files Browse the repository at this point in the history
Fix SevFirmwareErr implementation and clean up error conditionals
  • Loading branch information
deeglaze authored Jan 26, 2023
2 parents 0d57edf + 6a17176 commit 0f7e438
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 33 deletions.
3 changes: 1 addition & 2 deletions abi/amdsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,10 @@ const GuestRequestInvalidLength SevFirmwareStatus = 0x100000000

// SevFirmwareErr is an error that interprets firmware status codes from the AMD secure processor.
type SevFirmwareErr struct {
error
Status SevFirmwareStatus
}

func (e SevFirmwareErr) Error() string {
func (e *SevFirmwareErr) Error() string {
if e.Status == Success {
return "success"
}
Expand Down
6 changes: 3 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ func message(d Device, command uintptr, req *labi.SnpUserGuestRequest) error {
// indicates a problem certificate length. We need to
// communicate that specifically.
if req.FwErr != 0 {
return abi.SevFirmwareErr{Status: abi.SevFirmwareStatus(req.FwErr)}
return &abi.SevFirmwareErr{Status: abi.SevFirmwareStatus(req.FwErr)}
}
return err
}
if result != uintptr(labi.EsOk) {
return labi.SevEsErr{Result: labi.EsResult(result)}
return &labi.SevEsErr{Result: labi.EsResult(result)}
}
return nil
}
Expand Down Expand Up @@ -113,7 +113,7 @@ func getExtendedReportIn(d Device, reportData [64]byte, vmpl int, certs []byte)
}
// Query the length required for certs.
if err := message(d, labi.IocSnpGetExtendedReport, &userGuestReq); err != nil {
var fwErr abi.SevFirmwareErr
var fwErr *abi.SevFirmwareErr
if errors.As(err, &fwErr) && fwErr.Status == abi.GuestRequestInvalidLength {
return nil, snpExtReportReq.CertsLength, nil
}
Expand Down
14 changes: 7 additions & 7 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func initDevice() {
now := time.Date(2022, time.May, 3, 9, 0, 0, 0, time.UTC)
for _, tc := range test.TestCases() {
// Don't test faked errors when running real hardware tests.
if !UseDefaultSevGuest() && tc.WantErr != nil {
if !UseDefaultSevGuest() && tc.WantErr != "" {
continue
}
tests = append(tests, tc)
Expand Down Expand Up @@ -137,11 +137,11 @@ func TestOpenGetReportClose(t *testing.T) {

// Does the proto report match expectations?
got, err := GetReport(device, tc.Input)
if err != tc.WantErr {
if !test.Match(err, tc.WantErr) {
t.Fatalf("GetReport(device, %v) = %v, %v. Want err: %v", tc.Input, got, err, tc.WantErr)
}

if tc.WantErr == nil {
if tc.WantErr == "" {
cleanReport(got)
want := reportProto
want.Signature = got.Signature // Zeros were placeholders.
Expand All @@ -156,10 +156,10 @@ func TestOpenGetRawExtendedReportClose(t *testing.T) {
devMu.Do(initDevice)
for _, tc := range tests {
raw, certs, err := GetRawExtendedReport(device, tc.Input)
if err != tc.WantErr {
if !test.Match(err, tc.WantErr) {
t.Fatalf("%s: GetRawExtendedReport(device, %v) = %v, %v, %v. Want err: %v", tc.Name, tc.Input, raw, certs, err, tc.WantErr)
}
if tc.WantErr == nil {
if tc.WantErr == "" {
if err := cleanRawReport(raw); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -189,10 +189,10 @@ func TestOpenGetExtendedReportClose(t *testing.T) {
devMu.Do(initDevice)
for _, tc := range tests {
ereport, err := GetExtendedReport(device, tc.Input)
if err != tc.WantErr {
if !test.Match(err, tc.WantErr) {
t.Fatalf("%s: GetExtendedReport(device, %v) = %v, %v. Want err: %v", tc.Name, tc.Input, ereport, err, tc.WantErr)
}
if tc.WantErr == nil {
if tc.WantErr == "" {
reportProto := &spb.Report{}
if err := prototext.Unmarshal([]byte(tc.OutputProto), reportProto); err != nil {
t.Fatalf("test failure: %v", err)
Expand Down
3 changes: 1 addition & 2 deletions client/linuxabi/linux_abi.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,10 @@ const (

// SevEsErr is an error that interprets SEV-ES guest-host communication results.
type SevEsErr struct {
error
Result EsResult
}

func (err SevEsErr) Error() string {
func (err *SevEsErr) Error() string {
if err.Result == EsUnsupported {
return "requested operation not supported"
}
Expand Down
5 changes: 3 additions & 2 deletions kds/kds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"encoding/hex"
"fmt"
"net/url"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -80,7 +81,7 @@ func TestParseProductBaseURL(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
gotProduct, gotURL, err := parseBaseProductURL(tc.url)
if (err != nil && err.Error() != tc.wantErr) || (err == nil && tc.wantErr != "") {
if (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) {
t.Fatalf("parseBaseProductURL(%q) = _, _, %v, want %q", tc.url, err, tc.wantErr)
}
if err == nil {
Expand Down Expand Up @@ -144,7 +145,7 @@ func TestParseVCEKCertURL(t *testing.T) {
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
got, err := ParseVCEKCertURL(tc.url)
if (err != nil && err.Error() != tc.wantErr) || (err == nil && tc.wantErr != "") {
if (err == nil && tc.wantErr != "") || (err != nil && !strings.Contains(err.Error(), tc.wantErr)) {
t.Fatalf("ParseVCEKCertURL(%q) = _, %v, want %q", tc.url, err, tc.wantErr)
}
if err == nil {
Expand Down
25 changes: 25 additions & 0 deletions testing/match.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package testing

import "strings"

// Match returns true iff both errors match expectations closely enough
func Match(got error, want string) bool {
if got == nil {
return want == ""
}
return strings.Contains(got.Error(), want)
}
4 changes: 2 additions & 2 deletions testing/test_cases.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ type TestCase struct {
OutputProto string
FwErr abi.SevFirmwareStatus
EsResult labi.EsResult
WantErr error
WantErr string
}

// TestCases returns common test cases for get_report.
Expand All @@ -158,7 +158,7 @@ func TestCases() []TestCase {
Name: "fw oom",
Input: userZeros11,
FwErr: abi.ResourceLimit,
WantErr: abi.SevFirmwareErr{Status: abi.ResourceLimit},
WantErr: (&abi.SevFirmwareErr{Status: abi.ResourceLimit}).Error(),
},
}
}
Expand Down
23 changes: 8 additions & 15 deletions verify/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"math/big"
"math/rand"
"os"
"strings"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -151,10 +150,10 @@ func TestSnpReportSignature(t *testing.T) {
}
// Does the Raw report match expectations?
raw, err := sg.GetRawReport(d, tc.Input)
if err != tc.WantErr {
if !test.Match(err, tc.WantErr) {
t.Fatalf("GetRawReport(d, %v) = %v, %v. Want err: %v", tc.Input, raw, err, tc.WantErr)
}
if tc.WantErr == nil {
if tc.WantErr == "" {
got := abi.SignedComponent(raw)
want := abi.SignedComponent(tc.Output[:])
if !bytes.Equal(got, want) {
Expand Down Expand Up @@ -303,14 +302,8 @@ func TestKdsMetadataLogic(t *testing.T) {
options = &Options{}
}
vcek, _, err := VcekDER(newSigner.Vcek.Raw, newSigner.Ask.Raw, newSigner.Ark.Raw, options)
if err == nil && tc.wantErr != "" {
t.Errorf("%s: VcekDER(...) = %+v did not error as expected.", tc.name, vcek)
}
if err != nil && tc.wantErr == "" {
t.Errorf("%s: VcekDER(...) errored unexpectedly: %v", tc.name, err)
}
if err != nil && tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) {
t.Errorf("%s: VcekDER(...) did not error as expected. Got %v, want %s", tc.name, err, tc.wantErr)
if !test.Match(err, tc.wantErr) {
t.Errorf("%s: VcekDER(...) = %+v, %v did not error as expected. Want %q", tc.name, vcek, err, tc.wantErr)
}
}
}
Expand Down Expand Up @@ -374,7 +367,7 @@ func TestCRLRootValidity(t *testing.T) {
},
}
wantErr := "CRL is not signed by ARK"
if err := VcekNotRevoked(root, g2, signer2.Vcek); err == nil || !strings.Contains(err.Error(), wantErr) {
if err := VcekNotRevoked(root, g2, signer2.Vcek); !test.Match(err, wantErr) {
t.Errorf("Bad Root: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr)
}

Expand All @@ -385,7 +378,7 @@ func TestCRLRootValidity(t *testing.T) {
AskX509: signer2.Ask,
}
wantErr2 := "ASK was revoked at 2022-06-14 12:01:00 +0000 UTC"
if err := VcekNotRevoked(root2, g2, signer2.Vcek); err == nil || !strings.Contains(err.Error(), wantErr2) {
if err := VcekNotRevoked(root2, g2, signer2.Vcek); !test.Match(err, wantErr2) {
t.Errorf("Bad ASK: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr2)
}
}
Expand Down Expand Up @@ -422,10 +415,10 @@ func TestOpenGetExtendedReportVerifyClose(t *testing.T) {
}
for _, getReport := range reportGetters {
ereport, err := getReport.getter(d, tc.Input)
if err != tc.WantErr {
if !test.Match(err, tc.WantErr) {
t.Fatalf("%s: %s(d, %v) = %v, %v. Want err: %v", tc.Name, getReport.name, tc.Input, ereport, err, tc.WantErr)
}
if tc.WantErr == nil {
if tc.WantErr == "" {
if err := SnpAttestation(ereport, options); err != nil {
t.Errorf("SnpAttestation(%v) errored unexpectedly: %v", ereport, err)
}
Expand Down

0 comments on commit 0f7e438

Please sign in to comment.