diff --git a/pkg/function-proxy/proxy.go b/pkg/function-proxy/proxy.go index 7210b5ab1..d884eb6bc 100644 --- a/pkg/function-proxy/proxy.go +++ b/pkg/function-proxy/proxy.go @@ -82,8 +82,19 @@ func startNativeDaemon() { func main() { go startNativeDaemon() - http.HandleFunc("/", handler) - http.HandleFunc("/healthz", health) - http.Handle("/metrics", promhttp.Handler()) - utils.ListenAndServe() + + mux := http.NewServeMux() + mux.HandleFunc("/", handler) + mux.HandleFunc("/healthz", health) + mux.Handle("/metrics", promhttp.Handler()) + + server := utils.NewServer(mux) + + go func() { + if err := server.ListenAndServe(); err != http.ErrServerClosed { + panic(err) + } + }() + + utils.GracefulShutdown(server) } diff --git a/pkg/function-proxy/utils/proxy-utils.go b/pkg/function-proxy/utils/proxy-utils.go index bc53050d1..77bce1c36 100644 --- a/pkg/function-proxy/utils/proxy-utils.go +++ b/pkg/function-proxy/utils/proxy-utils.go @@ -22,17 +22,21 @@ import ( "log" "net/http" "os" + "os/signal" "strconv" + "syscall" "time" "github.com/prometheus/client_golang/prometheus" ) var ( - timeout = os.Getenv("FUNC_TIMEOUT") - funcPort = os.Getenv("FUNC_PORT") - intTimeout int - funcHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + timeout = os.Getenv("FUNC_TIMEOUT") + funcPort = os.Getenv("FUNC_PORT") + shutdownTimeout = os.Getenv("SHUTDOWN_TIMEOUT") + intTimeout int + intShutdownTimeout int + funcHistogram = prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: "function_duration_seconds", Help: "Duration of user function in seconds", }, []string{"method"}) @@ -53,11 +57,18 @@ func init() { if funcPort == "" { funcPort = "8080" } + if shutdownTimeout == "" { + shutdownTimeout = "10" + } var err error intTimeout, err = strconv.Atoi(timeout) if err != nil { panic(err) } + intShutdownTimeout, err = strconv.Atoi(shutdownTimeout) + if err != nil { + panic(err) + } prometheus.MustRegister(funcHistogram, funcCalls, funcErrors) } @@ -139,9 +150,26 @@ func Handler(w http.ResponseWriter, r *http.Request, h Handle) { } } -// ListenAndServe starts an HTTP server in FUNC_PORT using custom logging -func ListenAndServe() { - if err := http.ListenAndServe(fmt.Sprintf(":%s", funcPort), logReq(http.DefaultServeMux)); err != nil { - panic(err) +// NewServer returns an HTTP server ready to listen on the configured port +// and with logReq mixed in for logging. +func NewServer(mux *http.ServeMux) *http.Server { + return &http.Server{Addr: fmt.Sprintf(":%s", funcPort), Handler: logReq(mux)} +} + +// GracefulShutdown accepts a server reference and triggers a graceful shutdown +// for it when either SIGINT or SIGTERM is received. +func GracefulShutdown(server *http.Server) { + stop := make(chan os.Signal, 1) + signal.Notify(stop, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + <-stop + timeoutDuration := time.Duration(intShutdownTimeout) * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration) + defer cancel() + + log.Printf("Shuting down with timeout: %s\n", timeoutDuration) + if err := server.Shutdown(ctx); err != nil { + log.Printf("Error: %v\n", err) + } else { + log.Println("Server stopped") } }