Skip to content

Commit

Permalink
Export functions to shut down the process ungracefully
Browse files Browse the repository at this point in the history
Those are useful in more places than just the platform specific entry points.

PiperOrigin-RevId: 686485935
  • Loading branch information
torsm authored and copybara-github committed Oct 16, 2024
1 parent 19279ad commit 41cf2b9
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 23 deletions.
24 changes: 15 additions & 9 deletions fleetspeak/src/client/entry/entry_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"syscall"
"time"

"github.com/google/fleetspeak/fleetspeak/src/common/fscontext"

log "github.com/golang/glog"
)

Expand All @@ -28,15 +30,10 @@ func RunMain(innerMain InnerMain, _ /* windowsServiceName */ string) {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

context.AfterFunc(ctx, func() {
log.Info("Main context stopped, shutting down...")
time.AfterFunc(shutdownTimeout, func() {
if err := dumpProfile(*profileDir, "goroutine", 2); err != nil {
log.Errorf("Failed to dump goroutine profile: %v", err)
}
log.Exitf("Fleetspeak failed to shut down in %v. Exiting ungracefully.", shutdownTimeout)
})
})
stop := fscontext.AfterFuncDelayed(ctx, func() {
ExitUngracefully(fmt.Errorf("process did not exit within %d", shutdownTimeout))
}, shutdownTimeout)
defer stop()

cancelSignal := notifyFunc(func(si os.Signal) {
runtime.GC()
Expand Down Expand Up @@ -83,6 +80,15 @@ func notifyFunc(callback func(os.Signal), signals ...os.Signal) func() {
}
}

// ExitUngracefully can be called to exit the process after a failed attempt to
// properly free all resources.
func ExitUngracefully(cause error) {
if err := dumpProfile(*profileDir, "goroutine", 2); err != nil {
log.Errorf("Failed to dump goroutine profile: %v", err)
}
log.Exitf("Exiting ungracefully due to %v", cause)
}

// dumpProfile writes the given pprof profile to disk with the given debug flag.
func dumpProfile(profileDir, profileName string, pprofDebugFlag int) error {
profileDumpPath := filepath.Join(profileDir, fmt.Sprintf("fleetspeakd-%s-pprof-%d-%v", profileName, os.Getpid(), time.Now().Format("2006-01-02-15-04-05.000")))
Expand Down
24 changes: 14 additions & 10 deletions fleetspeak/src/client/entry/entry_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ package entry
import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"sync"
"syscall"
"time"

log "github.com/golang/glog"
"github.com/google/fleetspeak/fleetspeak/src/common/fscontext"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/svc"
)
Expand All @@ -30,7 +31,10 @@ func (m *fleetspeakService) Execute(args []string, r <-chan svc.ChangeRequest, c
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

enforceShutdownTimeout(ctx)
stop := fscontext.AfterFuncDelayed(ctx, func() {
ExitUngracefully(fmt.Errorf("process did not exit within %d", shutdownTimeout))
}, shutdownTimeout)
defer stop()

sighupCh := make(chan os.Signal, 1)
defer close(sighupCh)
Expand Down Expand Up @@ -84,7 +88,10 @@ func (m *fleetspeakService) ExecuteAsRegularProcess() {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()

enforceShutdownTimeout(ctx)
stop := fscontext.AfterFuncDelayed(ctx, func() {
ExitUngracefully(fmt.Errorf("process did not exit within %d", shutdownTimeout))
}, shutdownTimeout)
defer stop()

err := m.innerMain(ctx, nil)
if err != nil {
Expand All @@ -108,13 +115,10 @@ func RunMain(innerMain InnerMain, windowsServiceName string) {
}
}

func enforceShutdownTimeout(ctx context.Context) {
context.AfterFunc(ctx, func() {
log.Info("Main context stopped, shutting down...")
time.AfterFunc(shutdownTimeout, func() {
log.Exitf("Fleetspeak failed to shut down in %v. Exiting ungracefully.", shutdownTimeout)
})
})
// ExitUngracefully can be called to exit the process after a failed attempt to
// properly free all resources.
func ExitUngracefully(cause error) {
log.Exitf("Exiting ungracefully due to %v", cause)
}

// tryDisableStderr redirects [os.Stderr] to [os.DevNull]. When running as a
Expand Down
28 changes: 28 additions & 0 deletions fleetspeak/src/common/fscontext/fscontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"fmt"
"sync"
"time"
)

// ErrStopRequested is the cancelation cause to be used when callers
Expand Down Expand Up @@ -60,3 +61,30 @@ func WithDoneChan(ctx context.Context, cause error, done <-chan struct{}) (newCt
wg.Wait()
}
}

// AfterFuncDelayed behaves like context.AfterFunc, but with a delay.
// The returned stop function must be called eventually to free resources.
func AfterFuncDelayed(ctx context.Context, f func(), delay time.Duration) func() bool {
stopCh := make(chan func() bool)

cancelInitiation := context.AfterFunc(ctx, func() {
timer := time.AfterFunc(delay, f)
stopCh <- timer.Stop
close(stopCh)
})

return func() bool {
if cancelInitiation() {
// context.AfterFunc will not fire
close(stopCh)
return true
}

stop, ok := <-stopCh
if ok {
return stop()
}

return false
}
}
129 changes: 125 additions & 4 deletions fleetspeak/src/common/fscontext/fscontext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package fscontext_test
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"

Expand All @@ -30,7 +31,7 @@ var (

var shortDuration = 100 * time.Millisecond

func TestWithDoneChanNotCanceled(t *testing.T) {
func TestWithDoneChan_NotCanceled(t *testing.T) {
// Given
doneCh := make(chan struct{})
defer close(doneCh)
Expand All @@ -49,7 +50,7 @@ func TestWithDoneChanNotCanceled(t *testing.T) {
}
}

func TestWithDoneChanCanceledThroughChannel(t *testing.T) {
func TestWithDoneChan_CanceledThroughChannel(t *testing.T) {
// Given
doneCh := make(chan struct{})

Expand All @@ -73,7 +74,7 @@ func TestWithDoneChanCanceledThroughChannel(t *testing.T) {
}
}

func TestWithDoneChanCanceledThroughOuterContext(t *testing.T) {
func TestWithDoneChan_CanceledThroughOuterContext(t *testing.T) {
// Given
doneCh := make(chan struct{})
defer close(doneCh)
Expand All @@ -100,7 +101,7 @@ func TestWithDoneChanCanceledThroughOuterContext(t *testing.T) {
}
}

func TestWithDoneChanCanceledThroughOwnContext(t *testing.T) {
func TestWithDoneChan_CanceledThroughOwnContext(t *testing.T) {
// Given
doneCh := make(chan struct{})
defer close(doneCh)
Expand All @@ -124,3 +125,123 @@ func TestWithDoneChanCanceledThroughOwnContext(t *testing.T) {
}
}
}

func TestAfterFuncDelayed_DelayReached(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
cancel()
time.Sleep(2 * delay)

// Then
if !flag.Load() {
t.Errorf("flag not set after %v", delay)
}
if stop() {
t.Errorf("stop() returned true, but flag was set")
}
}

func TestAfterFuncDelayed_StopBeforeCancel(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
stopped := stop()
cancel()
time.Sleep(2 * delay)

// Then
if !stopped {
t.Errorf("stop() returned false, but context was not canceled")
}
if flag.Load() {
t.Errorf("flag was set, but stop() was called before")
}
if stop() {
t.Errorf("stop() returned true, but stop() was called before")
}
}

func TestAfterFuncDelayed_StopAfterCancel(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
cancel()
stopped := stop()
time.Sleep(2 * delay)

// Then
if !stopped {
t.Errorf("stop() returned false, but delay was not reached")
}
if flag.Load() {
t.Errorf("flag was set, but stop() was called before")
}
if stop() {
t.Errorf("stop() returned true, but stop() was called before")
}
}

func TestAfterFuncDelayed_StopAfterDelay(t *testing.T) {
// Given
const delay = 100 * time.Millisecond
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

flag := atomic.Bool{}
setFlag := func() {
flag.Store(true)
}

stop := fscontext.AfterFuncDelayed(ctx, setFlag, delay)
defer stop()

// When
cancel()
time.Sleep(2 * delay)
stopped := stop()

// Then
if stopped {
t.Errorf("stop() returned true, but delay was reached")
}
if !flag.Load() {
t.Errorf("flag was not set, but delay was reached")
}
if stop() {
t.Errorf("stop() returned true, but stop() was called before")
}
}

0 comments on commit 41cf2b9

Please sign in to comment.