Skip to content

Commit

Permalink
Set a timeout for image pulls (#7884)
Browse files Browse the repository at this point in the history
  • Loading branch information
bduffany authored Nov 14, 2024
1 parent de4208a commit c0c93ec
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 0 deletions.
20 changes: 20 additions & 0 deletions enterprise/server/remote_execution/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ var (
ErrRemoved = status.UnavailableError("container has been removed")

recordCPUTimelines = flag.Bool("executor.record_cpu_timelines", false, "Capture CPU timeseries data in UsageStats for each task.")
imagePullTimeout = flag.Duration("executor.image_pull_timeout", 5*time.Minute, "How long to wait for the container image to be pulled before returning an Unavailable (retryable) error for an action execution attempt. Applies to all isolation types (docker, firecracker, etc.)")
debugUseLocalImagesOnly = flag.Bool("debug_use_local_images_only", false, "Do not pull OCI images and only used locally cached images. This can be set to test local image builds during development without needing to push to a container registry. Not intended for production use.")
DebugEnableAnonymousRecycling = flag.Bool("debug_enable_anonymous_runner_recycling", false, "Whether to enable runner recycling for unauthenticated requests. For debugging purposes only - do not use in production.")

Expand Down Expand Up @@ -513,6 +514,25 @@ func PullImageIfNecessary(ctx context.Context, env environment.Env, ctr CommandC

ctx, span := tracing.StartSpan(ctx)
defer span.End()

if *imagePullTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, *imagePullTimeout)
defer cancel()
}

if err := pullImageIfNecessary(ctx, env, ctr, creds, imageRef); err != nil {
// make sure we always return Unavailable if the context deadline
// was exceeded
if err == context.DeadlineExceeded || ctx.Err() != nil {
return status.UnavailableErrorf("%s", status.Message(err))
}
return err
}
return nil
}

func pullImageIfNecessary(ctx context.Context, env environment.Env, ctr CommandContainer, creds oci.Credentials, imageRef string) error {
cacheAuth := env.GetImageCacheAuthenticator()
if cacheAuth == nil || env.GetAuthenticator() == nil {
// If we don't have an authenticator available, fall back to
Expand Down
49 changes: 49 additions & 0 deletions enterprise/server/remote_execution/runner/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ type fakeContainer struct {
CreateError error
Removed chan struct{}
Result *interfaces.CommandResult
Isolation string // Fake isolation type name
ImageCached bool // Return value for IsImageCached
BlockPull bool // PullImage blocks forever if true.
}

func NewFakeContainer() *fakeContainer {
Expand All @@ -81,11 +84,26 @@ func NewFakeContainer() *fakeContainer {
}
}

// IsolationType reports the fake isolation type name configured on this
// container. When no Isolation name has been set, it falls back to "bare".
func (c *fakeContainer) IsolationType() string {
	if name := c.Isolation; name != "" {
		return name
	}
	return "bare"
}

// Run does not execute anything: it disregards the command, working
// directory, and credentials, and hands back the canned Result stored on the
// fake container (which may be nil if none was configured).
func (c *fakeContainer) Run(ctx context.Context, cmd *repb.Command, workdir string, creds oci.Credentials) *interfaces.CommandResult {
	canned := c.Result
	return canned
}

// IsImageCached returns the preconfigured ImageCached value. The context is
// not consulted and the returned error is always nil.
func (c *fakeContainer) IsImageCached(ctx context.Context) (bool, error) {
	cached := c.ImageCached
	return cached, nil
}

// PullImage simulates an image pull. By default it succeeds immediately.
// When BlockPull is set, it instead waits until ctx is done and then returns
// the context's error, which lets tests exercise pull timeouts and
// cancellation. The credentials are ignored.
func (c *fakeContainer) PullImage(ctx context.Context, creds oci.Credentials) error {
	if !c.BlockPull {
		return nil
	}
	// Park here until the caller's context is canceled or its deadline passes.
	<-ctx.Done()
	return ctx.Err()
}

Expand Down Expand Up @@ -867,3 +885,34 @@ func TestDoNotRecycleSpecialFile(t *testing.T) {
})
}
}

// TestImagePullTimeout checks that PrepareForTask fails with an Unavailable
// error whose message mentions "deadline exceeded" when the container's
// PullImage blocks until its context is done and
// executor.image_pull_timeout is set to the smallest possible duration.
func TestImagePullTimeout(t *testing.T) {
	// Enable OCI isolation so we can pull images.
	flags.Set(t, "executor.enable_oci", true)
	// Time out image pulls immediately.
	flags.Set(t, "executor.image_pull_timeout", time.Nanosecond)

	env := newTestEnv(t)
	cfg := noLimitsCfg()
	// Every container handed out by the pool is a fake whose PullImage only
	// returns once its context is done.
	cfg.ContainerProvider = providerFunc(func(ctx context.Context, args *container.Init) (container.CommandContainer, error) {
		ctr := NewFakeContainer()
		ctr.BlockPull = true
		ctr.Isolation = "oci"
		return ctr, nil
	})
	pool := newRunnerPool(t, env, cfg)
	ctx := withAuthenticatedUser(t, context.Background(), env, "US1")

	// Request an OCI-isolated task with a container image so that preparing
	// the runner involves an image pull (presumably via
	// container.PullImageIfNecessary — confirm against PrepareForTask).
	task := newTask()
	plat := task.ExecutionTask.Command.Platform
	plat.Properties = append(plat.Properties, []*repb.Platform_Property{
		{Name: "container-image", Value: "docker://busybox"},
		{Name: "workload-isolation-type", Value: "oci"},
	}...)
	r, err := pool.Get(ctx, task)
	require.NoError(t, err)

	err = r.PrepareForTask(ctx)
	require.Error(t, err)
	// Report both the type and the value on failure: the type alone does not
	// show which status or message was actually returned.
	assert.True(t, status.IsUnavailableError(err), "expected Unavailable, got %T: %v", err, err)
	assert.Contains(t, err.Error(), "deadline exceeded")
}

0 comments on commit c0c93ec

Please sign in to comment.