Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fetch_server: support digest function and blake3 #6382

Merged
merged 3 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion server/remote_asset/fetch_server/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "fetch_server",
Expand All @@ -23,3 +23,23 @@ go_library(
"@org_golang_google_protobuf//types/known/durationpb",
],
)

go_test(
name = "fetch_server_test",
srcs = ["fetch_server_test.go"],
deps = [
":fetch_server",
"//proto:remote_asset_go_proto",
"//proto:remote_execution_go_proto",
"//proto:resource_go_proto",
"//server/remote_cache/byte_stream_server",
"//server/remote_cache/digest",
"//server/testutil/testenv",
"//server/util/prefix",
"//server/util/scratchspace",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
"@org_golang_google_genproto_googleapis_bytestream//:bytestream",
"@org_golang_google_grpc//:go_default_library",
],
)
187 changes: 130 additions & 57 deletions server/remote_asset/fetch_server/fetch_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ import (
)

const (
checksumQualifier = "checksum.sri"
ChecksumQualifier = "checksum.sri"
sha256Prefix = "sha256-"
blake3Prefix = "blake3-"
maxHTTPTimeout = 60 * time.Minute
)

Expand Down Expand Up @@ -64,9 +65,6 @@ func checkPreconditions(env environment.Env) error {
if env.GetCache() == nil {
return status.FailedPreconditionError("missing Cache")
}
if env.GetByteStreamClient() == nil {
return status.FailedPreconditionError("missing ByteStreamClient")
}
return nil
}

Expand Down Expand Up @@ -109,52 +107,43 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
return nil, err
}

var expectedSHA256 string

storageFunc := req.GetDigestFunction()
if storageFunc == repb.DigestFunction_UNKNOWN {
storageFunc = repb.DigestFunction_SHA256
}
var checksumFunc repb.DigestFunction_Value
sluongng marked this conversation as resolved.
Show resolved Hide resolved
var expectedChecksum string
for _, qualifier := range req.GetQualifiers() {
if qualifier.GetName() == checksumQualifier && strings.HasPrefix(qualifier.GetValue(), sha256Prefix) {
b64sha256 := strings.TrimPrefix(qualifier.GetValue(), sha256Prefix)
sha256, err := base64.StdEncoding.DecodeString(b64sha256)
if err != nil {
return nil, status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
}
blobDigest := &repb.Digest{
Hash: fmt.Sprintf("%x", sha256),
// The digest size is unknown since the client only sends up
// the hash. We can look up the size using the Metadata API,
// which looks up only using the hash, so the size we pass here
// doesn't matter.
SizeBytes: 1,
var prefix string
if qualifier.GetName() == ChecksumQualifier {
if strings.HasPrefix(qualifier.GetValue(), sha256Prefix) {
checksumFunc = repb.DigestFunction_SHA256
prefix = sha256Prefix
} else if strings.HasPrefix(qualifier.GetValue(), blake3Prefix) {
checksumFunc = repb.DigestFunction_BLAKE3
prefix = blake3Prefix
}
expectedSHA256 = blobDigest.Hash
cacheRN := digest.NewResourceName(blobDigest, req.GetInstanceName(), rspb.CacheType_CAS, repb.DigestFunction_SHA256)

log.CtxInfof(ctx, "Looking up %s in cache", blobDigest.Hash)

// Lookup metadata to get the correct digest size to be returned to
// the client.
cache := p.env.GetCache()
md, err := cache.Metadata(ctx, cacheRN.ToProto())
}
if prefix != "" {
b64hash := strings.TrimPrefix(qualifier.GetValue(), prefix)
decodedHash, err := base64.StdEncoding.DecodeString(b64hash)
if err != nil {
log.CtxInfof(ctx, "FetchServer failed to get metadata for %s: %s", expectedSHA256, err)
continue
return nil, status.FailedPreconditionErrorf("Error decoding qualifier %q: %s", qualifier.GetName(), err.Error())
}
blobDigest.SizeBytes = md.DigestSizeBytes
expectedChecksum = fmt.Sprintf("%x", decodedHash)
break
}
}
if len(expectedChecksum) != 0 {
blobDigest := p.findBlobInCache(ctx, req.GetInstanceName(), checksumFunc, expectedChecksum)
// If the digestFunc is supplied and differ from the checksum sri,
// after looking up the cached blob using checksum sri, re-upload
// that blob using the requested digestFunc.
if blobDigest != nil && checksumFunc != storageFunc {
blobDigest = p.rewriteToCache(ctx, blobDigest, req.GetInstanceName(), checksumFunc, storageFunc)
}

// Even though we successfully fetched metadata, we need to renew
// the cache entry (using Contains()) to ensure that it doesn't
// expire by the time the client requests it from cache.
cacheRN = digest.NewResourceName(blobDigest, req.GetInstanceName(), rspb.CacheType_CAS, repb.DigestFunction_SHA256)
exists, err := cache.Contains(ctx, cacheRN.ToProto())
if err != nil {
log.CtxErrorf(ctx, "Failed to renew %s: %s", digest.String(blobDigest), err)
continue
}
if !exists {
log.CtxInfof(ctx, "Blob %s expired before we could renew it", digest.String(blobDigest))
continue
}
log.CtxInfof(ctx, "FetchServer found %s in cache", digest.String(blobDigest))
if blobDigest != nil {
return &rapb.FetchBlobResponse{
Status: &statuspb.Status{Code: int32(gcodes.OK)},
BlobDigest: blobDigest,
Expand All @@ -172,7 +161,7 @@ func (p *FetchServer) FetchBlob(ctx context.Context, req *rapb.FetchBlobRequest)
if err != nil {
return nil, status.InvalidArgumentErrorf("unparsable URI: %q", uri)
}
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, expectedSHA256)
blobDigest, err := mirrorToCache(ctx, p.env.GetByteStreamClient(), req.GetInstanceName(), httpClient, uri, storageFunc, checksumFunc, expectedChecksum)
if err != nil {
lastFetchErr = err
log.CtxWarningf(ctx, "Failed to mirror %q to cache: %s", uri, err)
Expand Down Expand Up @@ -203,12 +192,81 @@ func (p *FetchServer) FetchDirectory(ctx context.Context, req *rapb.FetchDirecto
return nil, status.UnimplementedError("FetchDirectory is not yet implemented")
}

func (p *FetchServer) rewriteToCache(ctx context.Context, blobDigest *repb.Digest, instanceName string, fromFunc, toFunc repb.DigestFunction_Value) *repb.Digest {
cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, fromFunc)
cache := p.env.GetCache()
reader, err := cache.Reader(ctx, cacheRN.ToProto(), 0, 0)
if err != nil {
log.CtxErrorf(ctx, "Failed to get cache reader for %s: %s", digest.String(blobDigest), err)
return nil
}

tmpFilePath, err := tempCopy(reader)
if err != nil {
log.CtxErrorf(ctx, "Failed to copy from reader to temp for %s: %s", digest.String(blobDigest), err)
return nil
}
defer func() {
if err := os.Remove(tmpFilePath); err != nil {
log.Errorf("Failed to remove temp file: %s", err)
}
}()

bsClient := p.env.GetByteStreamClient()
storageDigest, err := cachetools.UploadFile(ctx, bsClient, instanceName, toFunc, tmpFilePath)
if err != nil {
log.CtxErrorf(ctx, "Failed to re-upload blob with new digestFunc %s for %s: %s", toFunc, digest.String(blobDigest), err)
return nil
}
return storageDigest
}

func (p *FetchServer) findBlobInCache(ctx context.Context, instanceName string, checksumFunc repb.DigestFunction_Value, expectedChecksum string) *repb.Digest {
blobDigest := &repb.Digest{
Hash: expectedChecksum,
// The digest size is unknown since the client only sends up
// the hash. We can look up the size using the Metadata API,
// which looks up only using the hash, so the size we pass here
// doesn't matter.
SizeBytes: 1,
}
cacheRN := digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, checksumFunc)
log.CtxDebugf(ctx, "Looking up %s in cache", blobDigest.Hash)

// Lookup metadata to get the correct digest size to be returned to
// the client.
cache := p.env.GetCache()
md, err := cache.Metadata(ctx, cacheRN.ToProto())
if err != nil {
log.CtxInfof(ctx, "FetchServer failed to get metadata for %s: %s", expectedChecksum, err)
return nil
}
blobDigest.SizeBytes = md.DigestSizeBytes

// Even though we successfully fetched metadata, we need to renew
// the cache entry (using Contains()) to ensure that it doesn't
// expire by the time the client requests it from cache.
cacheRN = digest.NewResourceName(blobDigest, instanceName, rspb.CacheType_CAS, checksumFunc)
exists, err := cache.Contains(ctx, cacheRN.ToProto())
if err != nil {
log.CtxErrorf(ctx, "Failed to renew %s: %s", digest.String(blobDigest), err)
return nil
}
if !exists {
log.CtxInfof(ctx, "Blob %s expired before we could renew it", digest.String(blobDigest))
return nil
}

log.CtxDebugf(ctx, "FetchServer found %s in cache", digest.String(blobDigest))
return blobDigest
}

// mirrorToCache uploads the contents at the given URI to the given cache,
// returning the digest. The fetched contents are checked against the given
// expectedSHA256 (if non-empty), and if there is a mismatch then an error is
// expectedChecksum (if non-empty), and if there is a mismatch then an error is
// returned.
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri, expectedSHA256 string) (*repb.Digest, error) {
log.CtxInfof(ctx, "Fetching %s", uri)
func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteInstanceName string, httpClient *http.Client, uri string, storageFunc repb.DigestFunction_Value, checksumFunc repb.DigestFunction_Value, expectedChecksum string) (*repb.Digest, error) {
log.CtxDebugf(ctx, "Fetching %s", uri)
rsp, err := httpClient.Get(uri)
if err != nil {
return nil, status.UnavailableErrorf("failed to fetch %q: HTTP GET failed: %s", uri, err)
Expand All @@ -218,12 +276,12 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
return nil, status.UnavailableErrorf("failed to fetch %q: HTTP %s", uri, err)
}

// If we know what the SHA256 should be and the content length is known,
// If we know what the hash should be and the content length is known,
// then we know the full digest, and can pipe directly from the HTTP
// response to cache.
if expectedSHA256 != "" && rsp.ContentLength >= 0 {
d := &repb.Digest{Hash: expectedSHA256, SizeBytes: rsp.ContentLength}
rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, repb.DigestFunction_SHA256)
if checksumFunc == storageFunc && expectedChecksum != "" && rsp.ContentLength >= 0 {
d := &repb.Digest{Hash: expectedChecksum, SizeBytes: rsp.ContentLength}
rn := digest.NewResourceName(d, remoteInstanceName, rspb.CacheType_CAS, storageFunc)
if _, err := cachetools.UploadFromReader(ctx, bsClient, rn, rsp.Body); err != nil {
return nil, status.UnavailableErrorf("failed to upload %s to cache: %s", digest.String(d), err)
}
Expand All @@ -246,14 +304,29 @@ func mirrorToCache(ctx context.Context, bsClient bspb.ByteStreamClient, remoteIn
log.Errorf("Failed to remove temp file: %s", err)
}
}()
blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, repb.DigestFunction_SHA256, tmpFilePath)

// If the requested digestFunc is supplied and differ from the checksum sri,
// verify the downloaded file with the checksum sri before storing it to
// our cache.
if checksumFunc != storageFunc {
checksumDigestRN, err := cachetools.ComputeFileDigest(tmpFilePath, remoteInstanceName, checksumFunc)
if err != nil {
return nil, status.UnavailableErrorf("failed to compute checksum digest: %s", err)
}
if expectedChecksum != "" && checksumDigestRN.GetDigest().GetHash() != expectedChecksum {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, checksumDigestRN.GetDigest().Hash, expectedChecksum)
}
}
blobDigest, err := cachetools.UploadFile(ctx, bsClient, remoteInstanceName, storageFunc, tmpFilePath)
if err != nil {
return nil, status.UnavailableErrorf("failed to add object to cache: %s", err)
}
if expectedSHA256 != "" && blobDigest.Hash != expectedSHA256 {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, expectedSHA256)
// If the requested digestFunc is supplied is the same with the checksum sri,
// verify the expected checksum of the downloaded file after storing it in our cache.
if checksumFunc == storageFunc && expectedChecksum != "" && blobDigest.Hash != expectedChecksum {
return nil, status.InvalidArgumentErrorf("response body checksum for %q was %q but wanted %q", uri, blobDigest.Hash, expectedChecksum)
}
log.CtxInfof(ctx, "Mirrored %s to cache (digest: %s)", uri, digest.String(blobDigest))
log.CtxDebugf(ctx, "Mirrored %s to cache (digest: %s)", uri, digest.String(blobDigest))
return blobDigest, nil
}

Expand Down
Loading
Loading