diff --git a/testing/mocks.go b/testing/mocks.go index 427fbcd..289543a 100644 --- a/testing/mocks.go +++ b/testing/mocks.go @@ -151,20 +151,6 @@ func (d *Device) Product() *spb.SevProduct { return d.SevProduct } -// Getter represents a static server for request/respond url -> body contents. -type Getter struct { - Responses map[string][]byte -} - -// Get returns a registered response for a given URL. -func (g *Getter) Get(url string) ([]byte, error) { - v, ok := g.Responses[url] - if !ok { - return nil, fmt.Errorf("404: %s", url) - } - return v, nil -} - // GetResponse controls how often (Occurances) a certain response should be // provided. type GetResponse struct { @@ -173,16 +159,34 @@ type GetResponse struct { Error error } -// VariableResponseGetter is a mock for HTTPSGetter interface that sequentially +// Getter is a mock for HTTPSGetter interface that sequentially // returns the configured responses for the provided URL. Responses are returned // as a queue, i.e., always serving from index 0. -type VariableResponseGetter struct { +type Getter struct { Responses map[string][]GetResponse } -// Get the next configured response body and error. The configured response -// is also removed if it has been requested the configured number of times. -func (g *VariableResponseGetter) Get(url string) ([]byte, error) { +// SimpleGetter constructs a static server from url -> body responses. +// For more elaborate tests, construct a custom Getter. +func SimpleGetter(responses map[string][]byte) *Getter { + getter := &Getter{ + Responses: make(map[string][]GetResponse), + } + for key, value := range responses { + getter.Responses[key] = []GetResponse{ + { + Occurances: ^uint(0), + Body: value, + Error: nil, + }, + } + } + return getter +} + +// Get the next response body and error. The response is also removed, +// if it has been requested the configured number of times. +func (g *Getter) Get(url string) ([]byte, error) { resp, ok := g.Responses[url] if !ok || len(resp) == 0 { return nil, fmt.Errorf("404: %s", url) @@ -198,7 +202,7 @@ func (g *VariableResponseGetter) Get(url string) ([]byte, error) { // Done checks that all configured responses have been consumed, and errors // otherwise. -func (g *VariableResponseGetter) Done(t testing.TB) { +func (g *Getter) Done(t testing.TB) { for key := range g.Responses { if len(g.Responses[key]) != 0 { t.Errorf("Prepared response for '%s' not retrieved.", key) diff --git a/validate/validate_test.go b/validate/validate_test.go index 41ec2be..30cdc59 100644 --- a/validate/validate_test.go +++ b/validate/validate_test.go @@ -235,12 +235,12 @@ func TestValidateSnpAttestation(t *testing.T) { if err != nil { t.Fatal(err) } - getter := &test.Getter{ - Responses: map[string][]byte{ + getter := test.SimpleGetter( + map[string][]byte{ "https://kdsintf.amd.com/vcek/v1/Milan/cert_chain": rootBytes, "https://kdsintf.amd.com/vcek/v1/Milan/0a0b0c0d0e0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000010203040506?blSPL=31&teeSPL=127&snpSPL=112&ucodeSPL=146": sign.Vcek.Raw, }, - } + ) attestationFn := func(nonce [64]byte) *spb.Attestation { report, err := sg.GetReport(device, nonce) if err != nil { @@ -271,12 +271,14 @@ func TestValidateSnpAttestation(t *testing.T) { if err != nil { t.Fatal(err) } - return &spb.Attestation{Report: report, + return &spb.Attestation{ + Report: report, CertificateChain: &spb.CertificateChain{ AskCert: sign0.Ask.Raw, ArkCert: sign0.Ark.Raw, VcekCert: sign0.Vcek.Raw, - }} + }, + } }(), opts: &Options{ReportData: nonce0s1[:], GuestPolicy: abi.SnpPolicy{Debug: true}}, }, diff --git a/verify/trust/trust_test.go b/verify/trust/trust_test.go index 289a06a..417cf29 100644 --- a/verify/trust/trust_test.go +++ b/verify/trust/trust_test.go @@ -26,12 +26,12 @@ import ( func TestRetryHTTPSGetter(t *testing.T) { testCases := map[string]struct { - getter *test.VariableResponseGetter + getter *test.Getter timeout time.Duration maxRetryDelay time.Duration }{ "immediate success": { - getter: &test.VariableResponseGetter{ + getter: &test.Getter{ Responses: map[string][]test.GetResponse{ "https://fetch.me": { { @@ -46,7 +46,7 @@ func TestRetryHTTPSGetter(t *testing.T) { maxRetryDelay: time.Millisecond, }, "second success": { - getter: &test.VariableResponseGetter{ + getter: &test.Getter{ Responses: map[string][]test.GetResponse{ "https://fetch.me": { { @@ -66,7 +66,7 @@ func TestRetryHTTPSGetter(t *testing.T) { maxRetryDelay: time.Millisecond, }, "third success": { - getter: &test.VariableResponseGetter{ + getter: &test.Getter{ Responses: map[string][]test.GetResponse{ "https://fetch.me": { { @@ -108,7 +108,7 @@ func TestRetryHTTPSGetter(t *testing.T) { } func TestRetryHTTPSGetterAllFail(t *testing.T) { - testGetter := &test.VariableResponseGetter{ + testGetter := &test.Getter{ Responses: map[string][]test.GetResponse{ "https://fetch.me": { { diff --git a/verify/verify_test.go b/verify/verify_test.go index 7008bc3..1821f47 100644 --- a/verify/verify_test.go +++ b/verify/verify_test.go @@ -41,8 +41,10 @@ import ( const product = "Milan" -var signMu sync.Once -var signer *test.AmdSigner +var ( + signMu sync.Once + signer *test.AmdSigner +) func initSigner() { newSigner, err := test.DefaultCertChain(product, time.Now()) @@ -296,15 +298,16 @@ func TestKdsMetadataLogic(t *testing.T) { } // Trust the test-generated root if the test should pass. Otherwise, other root logic // won't get tested. - options := &Options{TrustedRoots: map[string][]*trust.AMDRootCerts{ - "Milan": {&trust.AMDRootCerts{ - Product: "Milan", - ProductCerts: &trust.ProductCerts{ - Ark: newSigner.Ark, - Ask: newSigner.Ask, - }, - }}, - }, + options := &Options{ + TrustedRoots: map[string][]*trust.AMDRootCerts{ + "Milan": {&trust.AMDRootCerts{ + Product: "Milan", + ProductCerts: &trust.ProductCerts{ + Ark: newSigner.Ark, + Ask: newSigner.Ask, + }, + }}, + }, Now: time.Date(1, time.January, 5, 0, 0, 0, 0, time.UTC), } if tc.wantErr != "" { @@ -374,11 +377,11 @@ func TestCRLRootValidity(t *testing.T) { if err != nil { t.Fatal(err) } - g2 := &test.Getter{ - Responses: map[string][]byte{ + g2 := test.SimpleGetter( + map[string][]byte{ "https://kdsintf.amd.com/vcek/v1/Milan/crl": crl, }, - } + ) wantErr := "CRL is not signed by ARK" if err := VcekNotRevoked(root, signer2.Vcek, &Options{Getter: g2}); !test.Match(err, wantErr) { t.Errorf("Bad Root: VcekNotRevoked(%v) did not error as expected. Got %v, want %v", signer.Vcek, err, wantErr) @@ -452,13 +455,13 @@ func TestRealAttestationVerification(t *testing.T) { trust.ClearProductCertCache() var nonce [64]byte copy(nonce[:], []byte{1, 2, 3, 4, 5}) - getter := &test.Getter{ - Responses: map[string][]byte{ + getter := test.SimpleGetter( + map[string][]byte{ "https://kdsintf.amd.com/vcek/v1/Milan/cert_chain": testdata.MilanBytes, // Use the VCEK's hwID and known TCB values to specify the URL its VCEK cert would be fetched from. "https://kdsintf.amd.com/vcek/v1/Milan/3ac3fe21e13fb0990eb28a802e3fb6a29483a6b0753590c951bdd3b8e53786184ca39e359669a2b76a1936776b564ea464cdce40c05f63c9b610c5068b006b5d?blSPL=2&teeSPL=0&snpSPL=5&ucodeSPL=68": testdata.VcekBytes, }, - } + ) if err := RawSnpReport(testdata.AttestationBytes, &Options{Getter: getter}); err != nil { t.Error(err) }