From b78b5412d6ec3c6dd3c393e3712f863e9271f5a5 Mon Sep 17 00:00:00 2001 From: Mostafa Moradian Date: Thu, 26 Dec 2024 01:40:27 +0100 Subject: [PATCH] Refactor run command into a separate file --- cmd/gatewayd_app.go | 1071 +++++++++++++++++++++++++++++++++++++++++++ cmd/run.go | 1037 +++-------------------------------------- errors/errors.go | 14 +- 3 files changed, 1148 insertions(+), 974 deletions(-) create mode 100644 cmd/gatewayd_app.go diff --git a/cmd/gatewayd_app.go b/cmd/gatewayd_app.go new file mode 100644 index 00000000..933bfc09 --- /dev/null +++ b/cmd/gatewayd_app.go @@ -0,0 +1,1071 @@ +package cmd + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/http" + "net/url" + "os" + "os/signal" + "runtime" + "strconv" + "time" + + "github.com/NYTimes/gziphandler" + sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" + sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" + v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" + "github.com/gatewayd-io/gatewayd/act" + "github.com/gatewayd-io/gatewayd/api" + "github.com/gatewayd-io/gatewayd/config" + gdErr "github.com/gatewayd-io/gatewayd/errors" + "github.com/gatewayd-io/gatewayd/logging" + "github.com/gatewayd-io/gatewayd/metrics" + "github.com/gatewayd-io/gatewayd/network" + "github.com/gatewayd-io/gatewayd/plugin" + "github.com/gatewayd-io/gatewayd/pool" + "github.com/gatewayd-io/gatewayd/raft" + usage "github.com/gatewayd-io/gatewayd/usagereport/v1" + "github.com/getsentry/sentry-go" + "github.com/go-co-op/gocron" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/redis/go-redis/v9" + "github.com/rs/zerolog" + "github.com/spf13/cobra" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +type GatewayDApp struct { + EnableTracing bool + EnableSentry bool + EnableLinting bool + EnableUsageReport bool + DevMode bool + CollectorURL string + PluginConfigFile string + GlobalConfigFile string + + conf *config.Config + pluginRegistry *plugin.Registry + actRegistry *act.Registry + metricsServer *http.Server + metricsMerger *metrics.Merger + httpServer *api.HTTPServer + grpcServer *api.GRPCServer + + loggers map[string]zerolog.Logger + pools map[string]map[string]*pool.Pool + clients map[string]map[string]*config.Client + proxies map[string]map[string]*network.Proxy + servers map[string]*network.Server + healthCheckScheduler *gocron.Scheduler + stopChan chan struct{} +} + +// loadConfig loads global and plugin configuration. +func (app *GatewayDApp) loadConfig(runCtx context.Context) error { + app.conf = config.NewConfig(runCtx, + config.Config{ + GlobalConfigFile: app.GlobalConfigFile, + PluginConfigFile: app.PluginConfigFile, + }, + ) + if err := app.conf.InitConfig(runCtx); err != nil { + return err + } + return nil +} + +// createLoggers creates loggers from the config. +func (app *GatewayDApp) createLoggers( + runCtx context.Context, cmd *cobra.Command, +) zerolog.Logger { + // Use cobra command cmd instead of os.Stdout for the console output. + cmdLogger := &cobraCmdWriter{cmd} + + // Create a logger for each tenant. 
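+ // Each key under the global "loggers" section yields one named logger;
+ // missing or invalid values fall back to defaults via config.If and
+ // config.Exists below. An illustrative (not authoritative) gatewayd.yaml
+ // entry might look like:
+ //
+ //	loggers:
+ //	  default:
+ //	    output: ["console"]
+ //	    level: "info"
+ //	    noColor: false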
+ for name, cfg := range app.conf.Global.Loggers { + app.loggers[name] = logging.NewLogger(runCtx, logging.LoggerConfig{ + Output: cfg.GetOutput(), + ConsoleOut: cmdLogger, + Level: config.If( + config.Exists(config.LogLevels, cfg.Level), + config.LogLevels[cfg.Level], + config.LogLevels[config.DefaultLogLevel], + ), + TimeFormat: config.If( + config.Exists(config.TimeFormats, cfg.TimeFormat), + config.TimeFormats[cfg.TimeFormat], + config.TimeFormats[config.DefaultTimeFormat], + ), + ConsoleTimeFormat: config.If( + config.Exists( + config.ConsoleTimeFormats, cfg.ConsoleTimeFormat), + config.ConsoleTimeFormats[cfg.ConsoleTimeFormat], + config.ConsoleTimeFormats[config.DefaultConsoleTimeFormat], + ), + NoColor: cfg.NoColor, + FileName: cfg.FileName, + MaxSize: cfg.MaxSize, + MaxBackups: cfg.MaxBackups, + MaxAge: cfg.MaxAge, + Compress: cfg.Compress, + LocalTime: cfg.LocalTime, + SyslogPriority: cfg.GetSyslogPriority(), + RSyslogNetwork: cfg.RSyslogNetwork, + RSyslogAddress: cfg.RSyslogAddress, + Name: name, + }) + } + return app.loggers[config.Default] +} + +// createActRegistry creates a new act registry given +// the built-in signals, policies, and actions. +func (app *GatewayDApp) createActRegistry(logger zerolog.Logger) error { + // Create a new act registry given the built-in signals, policies, and actions. + var publisher *act.Publisher + if app.conf.Plugin.ActionRedis.Enabled { + rdb := redis.NewClient(&redis.Options{ + Addr: app.conf.Plugin.ActionRedis.Address, + }) + var err error + publisher, err = act.NewPublisher(act.Publisher{ + Logger: logger, + RedisDB: rdb, + ChannelName: app.conf.Plugin.ActionRedis.Channel, + }) + if err != nil { + logger.Error().Err(err).Msg("Failed to create publisher for act registry") + return err + } + logger.Info().Msg("Created Redis publisher for Act registry") + } + + app.actRegistry = act.NewActRegistry( + act.Registry{ + Signals: act.BuiltinSignals(), + Policies: act.BuiltinPolicies(), + Actions: act.BuiltinActions(), + DefaultPolicyName: app.conf.Plugin.DefaultPolicy, + PolicyTimeout: app.conf.Plugin.PolicyTimeout, + DefaultActionTimeout: app.conf.Plugin.ActionTimeout, + TaskPublisher: publisher, + Logger: logger, + }, + ) + + return nil +} + +// loadPolicies loads policies from the configuration file and +// adds them to the registry. +func (app *GatewayDApp) loadPolicies(logger zerolog.Logger) error { + // Load policies from the configuration file and add them to the registry. + for _, plc := range app.conf.Plugin.Policies { + if policy, err := sdkAct.NewPolicy( + plc.Name, plc.Policy, plc.Metadata, + ); err != nil || policy == nil { + logger.Error().Err(err).Str("name", plc.Name).Msg("Failed to create policy") + return err + } else { + app.actRegistry.Add(policy) + } + } + + return nil +} + +// createPluginRegistry creates a new plugin registry. +func (app *GatewayDApp) createPluginRegistry(runCtx context.Context, logger zerolog.Logger) { + // Create a new plugin registry. + // The plugins are loaded and hooks registered before the configuration is loaded. + app.pluginRegistry = plugin.NewRegistry( + runCtx, + plugin.Registry{ + ActRegistry: app.actRegistry, + Compatibility: config.If( + config.Exists( + config.CompatibilityPolicies, app.conf.Plugin.CompatibilityPolicy, + ), + config.CompatibilityPolicies[app.conf.Plugin.CompatibilityPolicy], + config.DefaultCompatibilityPolicy), + Logger: logger, + DevMode: app.DevMode, + }, + ) +} + +// startMetricsMerger starts the metrics merger if enabled. 
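+// Plugins opt in by setting "metricsEnabled" to "true" in their config and
+// exposing their metrics over the Unix domain socket named by
+// "metricsUnixDomainSocket"; the merger then periodically scrapes those
+// sockets and merges the results with GatewayD's own Prometheus metrics.
+// For example (the socket path is hypothetical):
+//
+//	plugin.Config["metricsEnabled"]          // "true"
+//	plugin.Config["metricsUnixDomainSocket"] // "/tmp/my-plugin-metrics.sock"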
+func (app *GatewayDApp) startMetricsMerger(runCtx context.Context, logger zerolog.Logger) { + // Start the metrics merger if enabled. + if app.conf.Plugin.EnableMetricsMerger { + app.metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{ + MetricsMergerPeriod: app.conf.Plugin.MetricsMergerPeriod, + Logger: logger, + }) + app.pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) { + if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled { + app.metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"]) + logger.Debug().Str("plugin", plugin.ID.Name).Msg( + "Added plugin to metrics merger") + } + }) + app.metricsMerger.Start() + } +} + +// startHealthCheckScheduler starts the health check scheduler if enabled. +func (app *GatewayDApp) startHealthCheckScheduler( + runCtx, ctx context.Context, span trace.Span, logger zerolog.Logger, +) { + // Ping the plugins to check if they are alive, and remove them if they are not. + startDelay := time.Now().Add(app.conf.Plugin.HealthCheckPeriod) + if _, err := app.healthCheckScheduler.Every( + app.conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do( + func() { + _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check") + defer span.End() + + var plugins []string + app.pluginRegistry.ForEach( + func(pluginId sdkPlugin.Identifier, plugin *plugin.Plugin) { + if err := plugin.Ping(); err != nil { + span.RecordError(err) + logger.Error().Err(err).Msg("Failed to ping plugin") + if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil { + app.metricsMerger.Remove(pluginId.Name) + } + app.pluginRegistry.Remove(pluginId) + + if !app.conf.Plugin.ReloadOnCrash { + return // Do not reload the plugins. + } + + // Reload the plugins and register their hooks upon crash. + logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin") + pluginConfig := app.conf.Plugin.GetPlugins(pluginId.Name) + if pluginConfig != nil { + app.pluginRegistry.LoadPlugins(runCtx, pluginConfig, app.conf.Plugin.StartTimeout) + } + } else { + logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin") + plugins = append(plugins, pluginId.Name) + } + }) + span.SetAttributes(attribute.StringSlice("plugins", plugins)) + }); err != nil { + logger.Error().Err(err).Msg("Failed to start plugin health check scheduler") + span.RecordError(err) + } + + // Start the health check scheduler only if there are plugins. + if app.pluginRegistry.Size() > 0 { + logger.Info().Str( + "healthCheckPeriod", app.conf.Plugin.HealthCheckPeriod.String(), + ).Msg("Starting plugin health check scheduler") + app.healthCheckScheduler.StartAsync() + } +} + +// onConfigLoaded runs the OnConfigLoaded hook and +// merges the global config with the one from the plugins. +func (app *GatewayDApp) onConfigLoaded( + runCtx context.Context, span trace.Span, logger zerolog.Logger, +) error { + // Set the plugin timeout context. + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin. + // The plugins can modify the config and return it. 
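+ // The hook receives the global config flattened by koanf into dotted
+ // keys, e.g. {"loggers.default.level": "info", ...}; a plugin may return
+ // a modified map, while a nil result leaves the config untouched.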
+ updatedGlobalConfig, err := app.pluginRegistry.Run( + pluginTimeoutCtx, app.conf.GlobalKoanf.All(), v1.HookName_HOOK_NAME_ON_CONFIG_LOADED) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") + span.RecordError(err) + } + if updatedGlobalConfig != nil { + updatedGlobalConfig = app.pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig) + } + + // If the config was modified by the plugins, merge it with the one loaded from the file. + // Only global configuration is merged, which means that plugins cannot modify the plugin + // configurations. + if updatedGlobalConfig != nil { + // Merge the config with the one loaded from the file (in memory). + // The changes won't be persisted to disk. + if err := app.conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil { + logger.Error().Err(err).Msg("Failed to merge global config") + span.RecordError(err) + return err + } + } + + return nil +} + +// startMetricsServer starts the metrics server if enabled. +func (app *GatewayDApp) startMetricsServer( + runCtx context.Context, logger zerolog.Logger, +) error { + // Start the metrics server if enabled. + // TODO: Start multiple metrics servers. For now, only one default is supported. + // I should first find a use case for those multiple metrics servers. + _, span := otel.Tracer(config.TracerName).Start(runCtx, "Start metrics server") + defer span.End() + + metricsConfig := app.conf.Global.Metrics[config.Default] + + // TODO: refactor this to a separate function. + if !metricsConfig.Enabled { + logger.Info().Msg("Metrics server is disabled") + return nil + } + + scheme := "http://" + if metricsConfig.KeyFile != "" && metricsConfig.CertFile != "" { + scheme = "https://" + } + + fqdn, err := url.Parse(scheme + metricsConfig.Address) + if err != nil { + logger.Error().Err(err).Msg("Failed to parse metrics address") + span.RecordError(err) + return err + } + + address, err := url.JoinPath(fqdn.String(), metricsConfig.Path) + if err != nil { + logger.Error().Err(err).Msg("Failed to parse metrics path") + span.RecordError(err) + return err + } + + // Merge the metrics from the plugins with the ones from GatewayD. + mergedMetricsHandler := func(next http.Handler) http.Handler { + handler := func(responseWriter http.ResponseWriter, request *http.Request) { + if _, err := responseWriter.Write(app.metricsMerger.OutputMetrics); err != nil { + logger.Error().Err(err).Msg("Failed to write metrics") + span.RecordError(err) + sentry.CaptureException(err) + } + // The WriteHeader method intentionally does nothing, to prevent a bug + // in the merging metrics that causes the headers to be written twice, + // which results in an error: "http: superfluous response.WriteHeader call". + next.ServeHTTP( + &metrics.HeaderBypassResponseWriter{ + ResponseWriter: responseWriter, + }, + request) + } + return http.HandlerFunc(handler) + } + + handler := func() http.Handler { + return promhttp.InstrumentMetricHandler( + prometheus.DefaultRegisterer, + promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ + DisableCompression: true, + }), + ) + }() + + mux := http.NewServeMux() + mux.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) { + // Serve a static page with a link to the metrics endpoint. 
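+ // The page is a minimal HTML document: a title plus a single anchor tag
+ // whose href is the fully-qualified metrics address computed above, so
+ // that visiting the server root leads straight to the metrics endpoint.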
+ if _, err := responseWriter.Write([]byte(fmt.Sprintf( + `GatewayD Prometheus Metrics ServerMetrics`, + address, + ))); err != nil { + logger.Error().Err(err).Msg("Failed to write metrics") + span.RecordError(err) + sentry.CaptureException(err) + } + }) + + if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil { + handler = mergedMetricsHandler(handler) + } + + readHeaderTimeout := config.If( + metricsConfig.ReadHeaderTimeout > 0, + metricsConfig.ReadHeaderTimeout, + config.DefaultReadHeaderTimeout, + ) + + // Check if the metrics server is already running before registering the handler. + if _, err = http.Get(address); err != nil { //nolint:gosec + // The timeout handler limits the nested handlers from running for too long. + mux.Handle( + metricsConfig.Path, + http.TimeoutHandler( + gziphandler.GzipHandler(handler), + readHeaderTimeout, + "The request timed out while fetching the metrics", + ), + ) + } else { + logger.Warn().Msg("Metrics server is already running, consider changing the port") + span.RecordError(err) + } + + // Create a new metrics server. + timeout := config.If( + metricsConfig.Timeout > 0, + metricsConfig.Timeout, + config.DefaultMetricsServerTimeout, + ) + app.metricsServer = &http.Server{ + Addr: metricsConfig.Address, + Handler: mux, + ReadHeaderTimeout: readHeaderTimeout, + ReadTimeout: timeout, + WriteTimeout: timeout, + IdleTimeout: timeout, + } + + logger.Info().Fields(map[string]any{ + "address": address, + "timeout": timeout.String(), + "readHeaderTimeout": readHeaderTimeout.String(), + }).Msg("Metrics are exposed") + + if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" { + // Set up TLS. + app.metricsServer.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS13, + CurvePreferences: []tls.CurveID{ + tls.CurveP521, + tls.CurveP384, + tls.CurveP256, + }, + CipherSuites: []uint16{ + tls.TLS_AES_128_GCM_SHA256, + tls.TLS_AES_256_GCM_SHA384, + tls.TLS_CHACHA20_POLY1305_SHA256, + }, + } + app.metricsServer.TLSNextProto = make( + map[string]func(*http.Server, *tls.Conn, http.Handler)) + logger.Debug().Msg("Metrics server is running with TLS") + + // Start the metrics server with TLS. + if err = app.metricsServer.ListenAndServeTLS( + metricsConfig.CertFile, metricsConfig.KeyFile); !errors.Is(err, http.ErrServerClosed) { + logger.Error().Err(err).Msg("Failed to start metrics server") + span.RecordError(err) + } + } else { + // Start the metrics server without TLS. + if err = app.metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + logger.Error().Err(err).Msg("Failed to start metrics server") + span.RecordError(err) + } + } + + return nil +} + +// onNewLogger runs the OnNewLogger hook. +func (app *GatewayDApp) onNewLogger( + span trace.Span, logger zerolog.Logger, +) { + // This is a notification hook, so we don't care about the result. + pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + if data, ok := app.conf.GlobalKoanf.Get("loggers").(map[string]any); ok { + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) + } + } else { + logger.Error().Msg("Failed to get loggers from config") + } +} + +// createPoolAndClients creates pools of connections and clients. 
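+// Pools and clients are keyed first by configuration group and then by
+// configuration block, mirroring the nesting of the global "pools" and
+// "clients" sections. An illustrative (hypothetical) layout:
+//
+//	pools:
+//	  default:   # configuration group
+//	    writer:  # configuration block
+//	      size: 10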
+func (app *GatewayDApp) createPoolAndClients( + runCtx context.Context, span trace.Span, +) error { + // Create and initialize pools of connections. + for configGroupName, configGroup := range app.conf.Global.Pools { + for configBlockName, cfg := range configGroup { + logger := app.loggers[configGroupName] + // Check if the pool size is greater than zero. + currentPoolSize := config.If( + cfg.Size > 0, + // Check if the pool size is greater than the minimum pool size. + config.If( + cfg.Size > config.MinimumPoolSize, + cfg.Size, + config.MinimumPoolSize, + ), + config.DefaultPoolSize, + ) + + if _, ok := app.pools[configGroupName]; !ok { + app.pools[configGroupName] = make(map[string]*pool.Pool) + } + app.pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize) + + span.AddEvent("Create pool", trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.Int("size", currentPoolSize), + )) + + if _, ok := app.clients[configGroupName]; !ok { + app.clients[configGroupName] = make(map[string]*config.Client) + } + + // Get client config from the config file. + if clientConfig, ok := app.conf.Global.Clients[configGroupName][configBlockName]; !ok { + // This ensures that the default client config is used if the pool name is not + // found in the clients section. + app.clients[configGroupName][configBlockName] = app.conf.Global.Clients[config.Default][config.DefaultConfigurationBlock] //nolint:lll + } else { + // Merge the default client config with the one from the pool. + app.clients[configGroupName][configBlockName] = clientConfig + } + + // Fill the missing and zero values with the default ones. + app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If( + app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0, + app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod, + config.DefaultTCPKeepAlivePeriod, + ) + app.clients[configGroupName][configBlockName].ReceiveDeadline = config.If( + app.clients[configGroupName][configBlockName].ReceiveDeadline > 0, + app.clients[configGroupName][configBlockName].ReceiveDeadline, + config.DefaultReceiveDeadline, + ) + app.clients[configGroupName][configBlockName].ReceiveTimeout = config.If( + app.clients[configGroupName][configBlockName].ReceiveTimeout > 0, + app.clients[configGroupName][configBlockName].ReceiveTimeout, + config.DefaultReceiveTimeout, + ) + app.clients[configGroupName][configBlockName].SendDeadline = config.If( + app.clients[configGroupName][configBlockName].SendDeadline > 0, + app.clients[configGroupName][configBlockName].SendDeadline, + config.DefaultSendDeadline, + ) + app.clients[configGroupName][configBlockName].ReceiveChunkSize = config.If( + app.clients[configGroupName][configBlockName].ReceiveChunkSize > 0, + app.clients[configGroupName][configBlockName].ReceiveChunkSize, + config.DefaultChunkSize, + ) + app.clients[configGroupName][configBlockName].DialTimeout = config.If( + app.clients[configGroupName][configBlockName].DialTimeout > 0, + app.clients[configGroupName][configBlockName].DialTimeout, + config.DefaultDialTimeout, + ) + + // Add clients to the pool. 
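+ // Each iteration eagerly dials one upstream connection so the pool is
+ // fully warmed before any proxy starts serving traffic. Note that
+ // ranging over an integer ("for range currentPoolSize") requires
+ // Go 1.22 or newer.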
+ for range currentPoolSize { + clientConfig := app.clients[configGroupName][configBlockName] + clientConfig.GroupName = configGroupName + clientConfig.BlockName = configBlockName + client := network.NewClient( + runCtx, clientConfig, logger, + network.NewRetry( + network.Retry{ + Retries: clientConfig.Retries, + Backoff: config.If( + clientConfig.Backoff > 0, + clientConfig.Backoff, + config.DefaultBackoff, + ), + BackoffMultiplier: clientConfig.BackoffMultiplier, + DisableBackoffCaps: clientConfig.DisableBackoffCaps, + Logger: app.loggers[configBlockName], + }, + ), + ) + + if client == nil { + return errors.New("failed to create client, please check the configuration") + } + + eventOptions := trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.String("group", configGroupName), + attribute.String("network", client.Network), + attribute.String("address", client.Address), + attribute.Int("receiveChunkSize", client.ReceiveChunkSize), + attribute.String("receiveDeadline", client.ReceiveDeadline.String()), + attribute.String("receiveTimeout", client.ReceiveTimeout.String()), + attribute.String("sendDeadline", client.SendDeadline.String()), + attribute.String("dialTimeout", client.DialTimeout.String()), + attribute.Bool("tcpKeepAlive", client.TCPKeepAlive), + attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()), + attribute.String("localAddress", client.LocalAddr()), + attribute.String("remoteAddress", client.RemoteAddr()), + attribute.Int("retries", clientConfig.Retries), + attribute.String("backoff", client.Retry().Backoff.String()), + attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier), + attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps), + ) + if client.ID != "" { + eventOptions = trace.WithAttributes( + attribute.String("id", client.ID), + ) + } + + span.AddEvent("Create client", eventOptions) + + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + clientCfg := map[string]any{ + "id": client.ID, + "name": configBlockName, + "group": configGroupName, + "network": client.Network, + "address": client.Address, + "receiveChunkSize": client.ReceiveChunkSize, + "receiveDeadline": client.ReceiveDeadline.String(), + "receiveTimeout": client.ReceiveTimeout.String(), + "sendDeadline": client.SendDeadline.String(), + "dialTimeout": client.DialTimeout.String(), + "tcpKeepAlive": client.TCPKeepAlive, + "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(), + "localAddress": client.LocalAddr(), + "remoteAddress": client.RemoteAddr(), + "retries": clientConfig.Retries, + "backoff": client.Retry().Backoff.String(), + "backoffMultiplier": clientConfig.BackoffMultiplier, + "disableBackoffCaps": clientConfig.DisableBackoffCaps, + } + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewClient hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) + } + + err = app.pools[configGroupName][configBlockName].Put(client.ID, client) + if err != nil { + logger.Error().Err(err).Msg("Failed to add client to the pool") + span.RecordError(err) + } + } + + // Verify that the pool is properly populated. 
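+ // A pool smaller than requested means some clients failed to connect,
+ // either due to missing network connectivity or a stopped upstream
+ // server, so startup is aborted below instead of serving traffic with a
+ // partially-filled pool.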
+ logger.Info().Fields(map[string]any{ + "name": configBlockName, + "count": strconv.Itoa(app.pools[configGroupName][configBlockName].Size()), + }).Msg("There are clients available in the pool") + + if app.pools[configGroupName][configBlockName].Size() != currentPoolSize { + logger.Error().Msg( + "The pool size is incorrect, either because " + + "the clients cannot connect due to no network connectivity " + + "or the server is not running. exiting...") + app.pluginRegistry.Shutdown() + return errors.New("failed to initialize pool, please check the configuration") + } + + // Run the OnNewPool hook. + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, + map[string]any{"name": configBlockName, "size": currentPoolSize}, + v1.HookName_HOOK_NAME_ON_NEW_POOL) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewPool hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) + } + } + } + + return nil +} + +// createProxies creates proxies. +func (app *GatewayDApp) createProxies(runCtx context.Context, span trace.Span) { + // Create and initialize prefork proxies with each pool of clients. + for configGroupName, configGroup := range app.conf.Global.Proxies { + for configBlockName, cfg := range configGroup { + logger := app.loggers[configGroupName] + clientConfig := app.clients[configGroupName][configBlockName] + + // Fill the missing and zero value with the default one. + cfg.HealthCheckPeriod = config.If( + cfg.HealthCheckPeriod > 0, + cfg.HealthCheckPeriod, + config.DefaultHealthCheckPeriod, + ) + + if _, ok := app.proxies[configGroupName]; !ok { + app.proxies[configGroupName] = make(map[string]*network.Proxy) + } + + app.proxies[configGroupName][configBlockName] = network.NewProxy( + runCtx, + network.Proxy{ + GroupName: configGroupName, + BlockName: configBlockName, + AvailableConnections: app.pools[configGroupName][configBlockName], + PluginRegistry: app.pluginRegistry, + HealthCheckPeriod: cfg.HealthCheckPeriod, + ClientConfig: clientConfig, + Logger: logger, + PluginTimeout: app.conf.Plugin.Timeout, + }, + ) + + span.AddEvent("Create proxy", trace.WithAttributes( + attribute.String("name", configBlockName), + attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()), + )) + + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + if data, ok := app.conf.GlobalKoanf.Get("proxies").(map[string]any); ok { + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) + } + } else { + logger.Error().Msg("Failed to get proxy from config") + } + } + } +} + +// createServers creates servers. +func (app *GatewayDApp) createServers( + runCtx context.Context, span trace.Span, raftNode *raft.Node, +) { + // Create and initialize servers. 
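+ // One network.Server is created per entry in the global "servers"
+ // section, with every proxy of the matching group attached to it. An
+ // illustrative (hypothetical) entry:
+ //
+ //	servers:
+ //	  default:
+ //	    network: tcp
+ //	    address: 0.0.0.0:15432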
+ for name, cfg := range app.conf.Global.Servers { + logger := app.loggers[name] + + var serverProxies []network.IProxy + for _, proxy := range app.proxies[name] { + serverProxies = append(serverProxies, proxy) + } + + app.servers[name] = network.NewServer( + runCtx, + network.Server{ + GroupName: name, + Network: cfg.Network, + Address: cfg.Address, + TickInterval: config.If( + cfg.TickInterval > 0, + cfg.TickInterval, + config.DefaultTickInterval, + ), + Options: network.Option{ + // Can be used to send keepalive messages to the client. + EnableTicker: cfg.EnableTicker, + }, + Proxies: serverProxies, + Logger: logger, + PluginRegistry: app.pluginRegistry, + PluginTimeout: app.conf.Plugin.Timeout, + EnableTLS: cfg.EnableTLS, + CertFile: cfg.CertFile, + KeyFile: cfg.KeyFile, + HandshakeTimeout: cfg.HandshakeTimeout, + LoadbalancerStrategyName: cfg.LoadBalancer.Strategy, + LoadbalancerRules: cfg.LoadBalancer.LoadBalancingRules, + LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash, + RaftNode: raftNode, + }, + ) + + span.AddEvent("Create server", trace.WithAttributes( + attribute.String("name", name), + attribute.String("network", cfg.Network), + attribute.String("address", cfg.Address), + attribute.String("tickInterval", cfg.TickInterval.String()), + attribute.String("pluginTimeout", app.conf.Plugin.Timeout.String()), + attribute.Bool("enableTLS", cfg.EnableTLS), + attribute.String("certFile", cfg.CertFile), + attribute.String("keyFile", cfg.KeyFile), + attribute.String("handshakeTimeout", cfg.HandshakeTimeout.String()), + )) + + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + if data, ok := app.conf.GlobalKoanf.Get("servers").(map[string]any); ok { + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnNewServer hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) + } + } else { + logger.Error().Msg("Failed to get the servers configuration") + } + } +} + +// startAPIServers starts the API servers. +func (app *GatewayDApp) startAPIServers( + runCtx context.Context, logger zerolog.Logger, raftNode *raft.Node, +) { + // Start the HTTP and gRPC APIs. + if !app.conf.Global.API.Enabled { + logger.Info().Msg("API is not enabled, skipping") + return + } + + apiOptions := api.Options{ + Logger: logger, + GRPCNetwork: app.conf.Global.API.GRPCNetwork, + GRPCAddress: app.conf.Global.API.GRPCAddress, + HTTPAddress: app.conf.Global.API.HTTPAddress, + Servers: app.servers, + RaftNode: raftNode, + } + + apiObj := &api.API{ + Options: &apiOptions, + Config: app.conf, + PluginRegistry: app.pluginRegistry, + Pools: app.pools, + Proxies: app.proxies, + Servers: app.servers, + } + app.grpcServer = api.NewGRPCServer( + runCtx, + api.GRPCServer{ + API: apiObj, + HealthChecker: &api.HealthChecker{Servers: app.servers}, + }, + ) + if app.grpcServer != nil { + go app.grpcServer.Start() + logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API") + + app.httpServer = api.NewHTTPServer(&apiOptions) + go app.httpServer.Start() + + logger.Info().Fields( + map[string]any{ + "network": apiOptions.GRPCNetwork, + "address": apiOptions.GRPCAddress, + }, + ).Msg("Started the gRPC Server") + } +} + +// reportUsage reports usage statistics. 
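+// Reporting runs in a background goroutine: it dials UsageReportURL over
+// TLS, sends the GatewayD version, the Go runtime version, OS and
+// architecture, and the identifiers of the loaded plugins, and it logs
+// failures at trace level only, so reporting can never block or break
+// startup.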
+func (app *GatewayDApp) reportUsage(logger zerolog.Logger) { + if !app.EnableUsageReport { + logger.Info().Msg("Usage reporting is not enabled, skipping") + return + } + + // Report usage statistics. + go func() { + conn, err := grpc.NewClient( + UsageReportURL, + grpc.WithTransportCredentials( + credentials.NewTLS( + &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + ), + ), + ) + if err != nil { + logger.Trace().Err(err).Msg( + "Failed to dial to the gRPC server for usage reporting") + } + defer func(conn *grpc.ClientConn) { + err := conn.Close() + if err != nil { + logger.Trace().Err(err).Msg("Failed to close the connection to the usage report service") + } + }(conn) + + client := usage.NewUsageReportServiceClient(conn) + report := usage.UsageReportRequest{ + Version: config.Version, + RuntimeVersion: runtime.Version(), + Goos: runtime.GOOS, + Goarch: runtime.GOARCH, + Service: "gatewayd", + DevMode: app.DevMode, + Plugins: []*usage.Plugin{}, + } + app.pluginRegistry.ForEach( + func(identifier sdkPlugin.Identifier, _ *plugin.Plugin) { + report.Plugins = append(report.GetPlugins(), &usage.Plugin{ + Name: identifier.Name, + Version: identifier.Version, + Checksum: identifier.Checksum, + }) + }, + ) + _, err = client.Report(context.Background(), &report) + if err != nil { + logger.Trace().Err(err).Msg("Failed to report usage statistics") + } + }() +} + +// startServers starts the servers. +func (app *GatewayDApp) startServers( + runCtx context.Context, span trace.Span, +) { + // Start the server. + for name, server := range app.servers { + logger := app.loggers[name] + go func( + span trace.Span, + server *network.Server, + logger zerolog.Logger, + healthCheckScheduler *gocron.Scheduler, + metricsMerger *metrics.Merger, + pluginRegistry *plugin.Registry, + ) { + span.AddEvent("Start server") + if err := server.Run(); err != nil { + logger.Error().Err(err).Msg("Failed to start server") + span.RecordError(err) + app.stopGracefully(runCtx, nil) + os.Exit(gdErr.FailedToStartServer) + } + }(span, server, logger, app.healthCheckScheduler, app.metricsMerger, app.pluginRegistry) + } +} + +// stopGracefully stops the server gracefully. 
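+// Shutdown proceeds in dependency order: plugins are notified through the
+// OnSignal hook, the health check scheduler and metrics merger are stopped,
+// the metrics server is shut down, every network server is stopped, the
+// plugin registry is shut down, then the HTTP and gRPC API servers, and
+// finally stopChan is signaled and closed so waiting goroutines can exit.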
+func (app *GatewayDApp) stopGracefully(runCtx context.Context, sig os.Signal) { + _, span := otel.Tracer(config.TracerName).Start(runCtx, "Shutdown server") + currentSignal := "unknown" + if sig != nil { + currentSignal = sig.String() + } + + logger := app.loggers[config.Default] + + logger.Info().Msg("Notifying the plugins that the server is shutting down") + if app.pluginRegistry != nil { + pluginTimeoutCtx, cancel := context.WithTimeout( + context.Background(), app.conf.Plugin.Timeout) + defer cancel() + + //nolint:contextcheck + result, err := app.pluginRegistry.Run( + pluginTimeoutCtx, + map[string]any{"signal": currentSignal}, + v1.HookName_HOOK_NAME_ON_SIGNAL, + ) + if err != nil { + logger.Error().Err(err).Msg("Failed to run OnSignal hooks") + span.RecordError(err) + } + if result != nil { + _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck + } + } + + logger.Info().Msg("GatewayD is shutting down") + span.AddEvent("GatewayD is shutting down", trace.WithAttributes( + attribute.String("signal", currentSignal), + )) + if app.healthCheckScheduler != nil { + app.healthCheckScheduler.Stop() + app.healthCheckScheduler.Clear() + logger.Info().Msg("Stopped health check scheduler") + span.AddEvent("Stopped health check scheduler") + } + if app.metricsMerger != nil { + app.metricsMerger.Stop() + logger.Info().Msg("Stopped metrics merger") + span.AddEvent("Stopped metrics merger") + } + if app.metricsServer != nil { + //nolint:contextcheck + if err := app.metricsServer.Shutdown(context.Background()); err != nil { + logger.Error().Err(err).Msg("Failed to stop metrics server") + span.RecordError(err) + } else { + logger.Info().Msg("Stopped metrics server") + span.AddEvent("Stopped metrics server") + } + } + for name, server := range app.servers { + logger.Info().Str("name", name).Msg("Stopping server") + server.Shutdown() + span.AddEvent("Stopped server") + } + logger.Info().Msg("Stopped all servers") + if app.pluginRegistry != nil { + app.pluginRegistry.Shutdown() + logger.Info().Msg("Stopped plugin registry") + span.AddEvent("Stopped plugin registry") + } + span.End() + + if app.httpServer != nil { + app.httpServer.Shutdown(runCtx) + logger.Info().Msg("Stopped HTTP Server") + span.AddEvent("Stopped HTTP Server") + } + + if app.grpcServer != nil { + app.grpcServer.Shutdown(runCtx) + logger.Info().Msg("Stopped gRPC Server") + span.AddEvent("Stopped gRPC Server") + } + + // Close the stop channel to notify the other goroutines to stop. + app.stopChan <- struct{}{} + close(app.stopChan) +} + +// handleSignals handles the signals and stops the server gracefully. +func (app *GatewayDApp) handleSignals(runCtx context.Context, signals []os.Signal) { + signalsCh := make(chan os.Signal, 1) + signal.Notify(signalsCh, signals...) 
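+ // A buffer of one is sufficient here: the first signal received triggers
+ // a full graceful shutdown followed by os.Exit(0), so no later signal is
+ // ever consumed.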
+ + go func() { + for sig := range signalsCh { + app.stopGracefully(runCtx, sig) + os.Exit(0) + } + }() +} diff --git a/cmd/run.go b/cmd/run.go index ff2ace2e..4022d67a 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -2,49 +2,24 @@ package cmd import ( "context" - "crypto/tls" - "errors" - "fmt" "io" "log" - "net/http" - "net/url" "os" - "os/signal" - "runtime" - "strconv" "syscall" "time" - "github.com/NYTimes/gziphandler" - sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act" - sdkPlugin "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" - v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" - "github.com/gatewayd-io/gatewayd/act" - "github.com/gatewayd-io/gatewayd/api" "github.com/gatewayd-io/gatewayd/config" gerr "github.com/gatewayd-io/gatewayd/errors" - "github.com/gatewayd-io/gatewayd/logging" - "github.com/gatewayd-io/gatewayd/metrics" "github.com/gatewayd-io/gatewayd/network" - "github.com/gatewayd-io/gatewayd/plugin" "github.com/gatewayd-io/gatewayd/pool" "github.com/gatewayd-io/gatewayd/raft" "github.com/gatewayd-io/gatewayd/tracing" - usage "github.com/gatewayd-io/gatewayd/usagereport/v1" "github.com/getsentry/sentry-go" "github.com/go-co-op/gocron" - "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/redis/go-redis/v9" "github.com/rs/zerolog" "github.com/spf13/cobra" "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "golang.org/x/exp/maps" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" ) var _ io.Writer = &cobraCmdWriter{} @@ -58,37 +33,10 @@ func (c *cobraCmdWriter) Write(p []byte) (int, error) { return len(p), nil } -type GatewayDInstance struct { - EnableTracing bool - EnableSentry bool - EnableLinting bool - EnableUsageReport bool - DevMode bool - CollectorURL string - PluginConfigFile string - GlobalConfigFile string - - conf *config.Config - pluginRegistry *plugin.Registry - actRegistry *act.Registry - metricsServer *http.Server - metricsMerger *metrics.Merger - httpServer *api.HTTPServer - grpcServer *api.GRPCServer - - loggers map[string]zerolog.Logger - pools map[string]map[string]*pool.Pool - clients map[string]map[string]*config.Client - proxies map[string]map[string]*network.Proxy - servers map[string]*network.Server - healthCheckScheduler *gocron.Scheduler - stopChan chan struct{} -} - var ( UsageReportURL = "localhost:59091" testMode bool - testApp *GatewayDInstance + testApp *GatewayDApp ) // EnableTestMode enables test mode and returns the previous value. @@ -99,105 +47,6 @@ func EnableTestMode() bool { return previous } -// stopGracefully stops the server gracefully. 
-func (app *GatewayDInstance) stopGracefully(runCtx context.Context, sig os.Signal) { - _, span := otel.Tracer(config.TracerName).Start(runCtx, "Shutdown server") - currentSignal := "unknown" - if sig != nil { - currentSignal = sig.String() - } - - logger := app.loggers[config.Default] - - logger.Info().Msg("Notifying the plugins that the server is shutting down") - if app.pluginRegistry != nil { - pluginTimeoutCtx, cancel := context.WithTimeout( - context.Background(), app.conf.Plugin.Timeout) - defer cancel() - - //nolint:contextcheck - result, err := app.pluginRegistry.Run( - pluginTimeoutCtx, - map[string]any{"signal": currentSignal}, - v1.HookName_HOOK_NAME_ON_SIGNAL, - ) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnSignal hooks") - span.RecordError(err) - } - if result != nil { - _ = app.pluginRegistry.ActRegistry.RunAll(result) //nolint:contextcheck - } - } - - logger.Info().Msg("GatewayD is shutting down") - span.AddEvent("GatewayD is shutting down", trace.WithAttributes( - attribute.String("signal", currentSignal), - )) - if app.healthCheckScheduler != nil { - app.healthCheckScheduler.Stop() - app.healthCheckScheduler.Clear() - logger.Info().Msg("Stopped health check scheduler") - span.AddEvent("Stopped health check scheduler") - } - if app.metricsMerger != nil { - app.metricsMerger.Stop() - logger.Info().Msg("Stopped metrics merger") - span.AddEvent("Stopped metrics merger") - } - if app.metricsServer != nil { - //nolint:contextcheck - if err := app.metricsServer.Shutdown(context.Background()); err != nil { - logger.Error().Err(err).Msg("Failed to stop metrics server") - span.RecordError(err) - } else { - logger.Info().Msg("Stopped metrics server") - span.AddEvent("Stopped metrics server") - } - } - for name, server := range app.servers { - logger.Info().Str("name", name).Msg("Stopping server") - server.Shutdown() - span.AddEvent("Stopped server") - } - logger.Info().Msg("Stopped all servers") - if app.pluginRegistry != nil { - app.pluginRegistry.Shutdown() - logger.Info().Msg("Stopped plugin registry") - span.AddEvent("Stopped plugin registry") - } - span.End() - - if app.httpServer != nil { - app.httpServer.Shutdown(runCtx) - logger.Info().Msg("Stopped HTTP Server") - span.AddEvent("Stopped HTTP Server") - } - - if app.grpcServer != nil { - app.grpcServer.Shutdown(runCtx) - logger.Info().Msg("Stopped gRPC Server") - span.AddEvent("Stopped gRPC Server") - } - - // Close the stop channel to notify the other goroutines to stop. - app.stopChan <- struct{}{} - close(app.stopChan) -} - -// handleSignals handles the signals and stops the server gracefully. -func (app *GatewayDInstance) handleSignals(runCtx context.Context, signals []os.Signal) { - signalsCh := make(chan os.Signal, 1) - signal.Notify(signalsCh, signals...) - - go func() { - for sig := range signalsCh { - app.stopGracefully(runCtx, sig) - os.Exit(0) - } - }() -} - // runCmd represents the run command. var runCmd = &cobra.Command{ Use: "run", @@ -214,18 +63,22 @@ var runCmd = &cobra.Command{ runCtx, span := otel.Tracer(config.TracerName).Start(context.Background(), "GatewayD") span.End() - // Setup signal handling after context is created - signals := []os.Signal{ - os.Interrupt, - os.Kill, - syscall.SIGTERM, - syscall.SIGABRT, - syscall.SIGQUIT, - syscall.SIGHUP, - syscall.SIGINT, - } + // Handle signals from the user. 
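+ // Note: os.Interrupt is the portable alias for syscall.SIGINT, so SIGINT
+ // is effectively listed twice, and os.Kill (SIGKILL) cannot actually be
+ // trapped on Unix; both are presumably kept for cross-platform coverage.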
+ app.handleSignals( + runCtx, + []os.Signal{ + os.Interrupt, + os.Kill, + syscall.SIGTERM, + syscall.SIGABRT, + syscall.SIGQUIT, + syscall.SIGHUP, + syscall.SIGINT, + }, + ) - app.handleSignals(runCtx, signals) + // Stop the server gracefully when the program terminates cleanly. + defer app.stopGracefully(runCtx, nil) // Enable tracing with OpenTelemetry. if app.EnableTracing { @@ -234,6 +87,7 @@ var runCmd = &cobra.Command{ defer func() { if err := shutdown(context.Background()); err != nil { cmd.Println(err) + app.stopGracefully(runCtx, nil) os.Exit(gerr.FailedToStartTracer) } }() @@ -280,670 +134,91 @@ var runCmd = &cobra.Command{ } } - // Load global and plugin configuration. - app.conf = config.NewConfig(runCtx, - config.Config{ - GlobalConfigFile: app.GlobalConfigFile, - PluginConfigFile: app.PluginConfigFile, - }, - ) - if err := app.conf.InitConfig(runCtx); err != nil { + // Load the configuration files. + if err := app.loadConfig(runCtx); err != nil { log.Fatal(err) } - // Create and initialize App.loggers from the config. - // Use cobra command cmd instead of os.Stdout for the console output. - cmdLogger := &cobraCmdWriter{cmd} - for name, cfg := range app.conf.Global.Loggers { - app.loggers[name] = logging.NewLogger(runCtx, logging.LoggerConfig{ - Output: cfg.GetOutput(), - ConsoleOut: cmdLogger, - Level: config.If( - config.Exists(config.LogLevels, cfg.Level), - config.LogLevels[cfg.Level], - config.LogLevels[config.DefaultLogLevel], - ), - TimeFormat: config.If( - config.Exists(config.TimeFormats, cfg.TimeFormat), - config.TimeFormats[cfg.TimeFormat], - config.TimeFormats[config.DefaultTimeFormat], - ), - ConsoleTimeFormat: config.If( - config.Exists( - config.ConsoleTimeFormats, cfg.ConsoleTimeFormat), - config.ConsoleTimeFormats[cfg.ConsoleTimeFormat], - config.ConsoleTimeFormats[config.DefaultConsoleTimeFormat], - ), - NoColor: cfg.NoColor, - FileName: cfg.FileName, - MaxSize: cfg.MaxSize, - MaxBackups: cfg.MaxBackups, - MaxAge: cfg.MaxAge, - Compress: cfg.Compress, - LocalTime: cfg.LocalTime, - SyslogPriority: cfg.GetSyslogPriority(), - RSyslogNetwork: cfg.RSyslogNetwork, - RSyslogAddress: cfg.RSyslogAddress, - Name: name, - }) - } - - // Set the default logger. - logger := app.loggers[config.Default] + // Create and initialize loggers from the config. + // And then set the default logger. + logger := app.createLoggers(runCtx, cmd) if app.DevMode { logger.Warn().Msg( "Running GatewayD in development mode (not recommended for production)") } - // Create a new act registry given the built-in signals, policies, and actions. 
- var publisher *act.Publisher - if app.conf.Plugin.ActionRedis.Enabled { - rdb := redis.NewClient(&redis.Options{ - Addr: app.conf.Plugin.ActionRedis.Address, - }) - var err error - publisher, err = act.NewPublisher(act.Publisher{ - Logger: logger, - RedisDB: rdb, - ChannelName: app.conf.Plugin.ActionRedis.Channel, - }) - if err != nil { - logger.Error().Err(err).Msg("Failed to create publisher for act registry") - os.Exit(gerr.FailedToCreateActRegistry) - } - logger.Info().Msg("Created Redis publisher for Act registry") - } - - app.actRegistry = act.NewActRegistry( - act.Registry{ - Signals: act.BuiltinSignals(), - Policies: act.BuiltinPolicies(), - Actions: act.BuiltinActions(), - DefaultPolicyName: app.conf.Plugin.DefaultPolicy, - PolicyTimeout: app.conf.Plugin.PolicyTimeout, - DefaultActionTimeout: app.conf.Plugin.ActionTimeout, - TaskPublisher: publisher, - Logger: logger, - }) - - if app.actRegistry == nil { - logger.Error().Msg("Failed to create act registry") + // Create the Act registry. + if err := app.createActRegistry(logger); err != nil { + logger.Error().Err(err).Msg("Failed to create act registry") + app.stopGracefully(runCtx, nil) os.Exit(gerr.FailedToCreateActRegistry) } // Load policies from the configuration file and add them to the registry. - for _, plc := range app.conf.Plugin.Policies { - if policy, err := sdkAct.NewPolicy( - plc.Name, plc.Policy, plc.Metadata, - ); err != nil || policy == nil { - logger.Error().Err(err).Str("name", plc.Name).Msg("Failed to create policy") - } else { - app.actRegistry.Add(policy) - } + if err := app.loadPolicies(logger); err != nil { + logger.Error().Err(err).Msg("Failed to load policies") + app.stopGracefully(runCtx, nil) + os.Exit(gerr.FailedToLoadPolicies) } logger.Info().Fields(map[string]any{ "policies": maps.Keys(app.actRegistry.Policies), }).Msg("Policies are loaded") - // Create a new plugin registry. - // The plugins are loaded and hooks registered before the configuration is loaded. - app.pluginRegistry = plugin.NewRegistry( - runCtx, - plugin.Registry{ - ActRegistry: app.actRegistry, - Compatibility: config.If( - config.Exists( - config.CompatibilityPolicies, app.conf.Plugin.CompatibilityPolicy, - ), - config.CompatibilityPolicies[app.conf.Plugin.CompatibilityPolicy], - config.DefaultCompatibilityPolicy), - Logger: logger, - DevMode: app.DevMode, - }, - ) + // Create the plugin registry. + app.createPluginRegistry(runCtx, logger) // Load plugins and register their hooks. - app.pluginRegistry.LoadPlugins(runCtx, app.conf.Plugin.Plugins, app.conf.Plugin.StartTimeout) + app.pluginRegistry.LoadPlugins( + runCtx, + app.conf.Plugin.Plugins, + app.conf.Plugin.StartTimeout, + ) // Start the metrics merger if enabled. - if app.conf.Plugin.EnableMetricsMerger { - app.metricsMerger = metrics.NewMerger(runCtx, metrics.Merger{ - MetricsMergerPeriod: app.conf.Plugin.MetricsMergerPeriod, - Logger: logger, - }) - app.pluginRegistry.ForEach(func(_ sdkPlugin.Identifier, plugin *plugin.Plugin) { - if metricsEnabled, err := strconv.ParseBool(plugin.Config["metricsEnabled"]); err == nil && metricsEnabled { - app.metricsMerger.Add(plugin.ID.Name, plugin.Config["metricsUnixDomainSocket"]) - logger.Debug().Str("plugin", plugin.ID.Name).Msg( - "Added plugin to metrics merger") - } - }) - app.metricsMerger.Start() - } + app.startMetricsMerger(runCtx, logger) // TODO: Move this to the plugin registry. 
ctx, span := otel.Tracer(config.TracerName).Start(runCtx, "Plugin health check") - // Ping the plugins to check if they are alive, and remove them if they are not. - startDelay := time.Now().Add(app.conf.Plugin.HealthCheckPeriod) - if _, err := app.healthCheckScheduler.Every( - app.conf.Plugin.HealthCheckPeriod).SingletonMode().StartAt(startDelay).Do(func() { - _, span := otel.Tracer(config.TracerName).Start(ctx, "Run plugin health check") - defer span.End() - - var plugins []string - app.pluginRegistry.ForEach(func(pluginId sdkPlugin.Identifier, plugin *plugin.Plugin) { - if err := plugin.Ping(); err != nil { - span.RecordError(err) - logger.Error().Err(err).Msg("Failed to ping plugin") - if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil { - app.metricsMerger.Remove(pluginId.Name) - } - app.pluginRegistry.Remove(pluginId) - - if !app.conf.Plugin.ReloadOnCrash { - return // Do not reload the plugins. - } - - // Reload the plugins and register their hooks upon crash. - logger.Info().Str("name", pluginId.Name).Msg("Reloading crashed plugin") - pluginConfig := app.conf.Plugin.GetPlugins(pluginId.Name) - if pluginConfig != nil { - app.pluginRegistry.LoadPlugins(runCtx, pluginConfig, app.conf.Plugin.StartTimeout) - } - } else { - logger.Trace().Str("name", pluginId.Name).Msg("Successfully pinged plugin") - plugins = append(plugins, pluginId.Name) - } - }) - span.SetAttributes(attribute.StringSlice("plugins", plugins)) - }); err != nil { - logger.Error().Err(err).Msg("Failed to start plugin health check scheduler") - span.RecordError(err) - } - if app.pluginRegistry.Size() > 0 { - logger.Info().Str( - "healthCheckPeriod", app.conf.Plugin.HealthCheckPeriod.String(), - ).Msg("Starting plugin health check scheduler") - app.healthCheckScheduler.StartAsync() - } + // Start the health check scheduler only if there are plugins. + app.startHealthCheckScheduler(runCtx, ctx, span, logger) span.End() - // Set the plugin timeout context. - pluginTimeoutCtx, cancel := context.WithTimeout(context.Background(), app.conf.Plugin.Timeout) - defer cancel() - - // The config will be passed to the plugins that register to the "OnConfigLoaded" plugin. - // The plugins can modify the config and return it. - updatedGlobalConfig, err := app.pluginRegistry.Run( - pluginTimeoutCtx, app.conf.GlobalKoanf.All(), v1.HookName_HOOK_NAME_ON_CONFIG_LOADED) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnConfigLoaded hooks") - span.RecordError(err) - } - if updatedGlobalConfig != nil { - updatedGlobalConfig = app.pluginRegistry.ActRegistry.RunAll(updatedGlobalConfig) - } - - // If the config was modified by the plugins, merge it with the one loaded from the file. - // Only global configuration is merged, which means that plugins cannot modify the plugin - // configurations. - if updatedGlobalConfig != nil { - // Merge the config with the one loaded from the file (in memory). - // The changes won't be persisted to disk. - if err := app.conf.MergeGlobalConfig(runCtx, updatedGlobalConfig); err != nil { - log.Fatal(err) - } + // Merge the global config with the one from the plugins. + if err := app.onConfigLoaded(runCtx, span, logger); err != nil { + app.stopGracefully(runCtx, nil) + os.Exit(gerr.FailedToMergeGlobalConfig) } // Start the metrics server if enabled. - // TODO: Start multiple metrics servers. For now, only one default is supported. - // I should first find a use case for those multiple metrics servers. 
- go func(metricsConfig *config.Metrics, logger zerolog.Logger) { - _, span := otel.Tracer(config.TracerName).Start(runCtx, "Start metrics server") - defer span.End() - - // TODO: refactor this to a separate function. - if !metricsConfig.Enabled { - logger.Info().Msg("Metrics server is disabled") - return - } - - scheme := "http://" - if metricsConfig.KeyFile != "" && metricsConfig.CertFile != "" { - scheme = "https://" - } - - fqdn, err := url.Parse(scheme + metricsConfig.Address) - if err != nil { - logger.Error().Err(err).Msg("Failed to parse metrics address") + go func(app *GatewayDApp) { + if err := app.startMetricsServer(runCtx, logger); err != nil { + logger.Error().Err(err).Msg("Failed to start metrics server") span.RecordError(err) - return } + }(app) - address, err := url.JoinPath(fqdn.String(), metricsConfig.Path) - if err != nil { - logger.Error().Err(err).Msg("Failed to parse metrics path") - span.RecordError(err) - return - } - - // Merge the metrics from the plugins with the ones from GatewayD. - mergedMetricsHandler := func(next http.Handler) http.Handler { - handler := func(responseWriter http.ResponseWriter, request *http.Request) { - if _, err := responseWriter.Write(app.metricsMerger.OutputMetrics); err != nil { - logger.Error().Err(err).Msg("Failed to write metrics") - span.RecordError(err) - sentry.CaptureException(err) - } - // The WriteHeader method intentionally does nothing, to prevent a bug - // in the merging metrics that causes the headers to be written twice, - // which results in an error: "http: superfluous response.WriteHeader call". - next.ServeHTTP( - &metrics.HeaderBypassResponseWriter{ - ResponseWriter: responseWriter, - }, - request) - } - return http.HandlerFunc(handler) - } - - handler := func() http.Handler { - return promhttp.InstrumentMetricHandler( - prometheus.DefaultRegisterer, - promhttp.HandlerFor(prometheus.DefaultGatherer, promhttp.HandlerOpts{ - DisableCompression: true, - }), - ) - }() - - mux := http.NewServeMux() - mux.HandleFunc("/", func(responseWriter http.ResponseWriter, _ *http.Request) { - // Serve a static page with a link to the metrics endpoint. - if _, err := responseWriter.Write([]byte(fmt.Sprintf( - `GatewayD Prometheus Metrics ServerMetrics`, - address, - ))); err != nil { - logger.Error().Err(err).Msg("Failed to write metrics") - span.RecordError(err) - sentry.CaptureException(err) - } - }) - - if app.conf.Plugin.EnableMetricsMerger && app.metricsMerger != nil { - handler = mergedMetricsHandler(handler) - } - - readHeaderTimeout := config.If( - metricsConfig.ReadHeaderTimeout > 0, - metricsConfig.ReadHeaderTimeout, - config.DefaultReadHeaderTimeout, - ) - - // Check if the metrics server is already running before registering the handler. - if _, err = http.Get(address); err != nil { //nolint:gosec - // The timeout handler limits the nested handlers from running for too long. - mux.Handle( - metricsConfig.Path, - http.TimeoutHandler( - gziphandler.GzipHandler(handler), - readHeaderTimeout, - "The request timed out while fetching the metrics", - ), - ) - } else { - logger.Warn().Msg("Metrics server is already running, consider changing the port") - span.RecordError(err) - } - - // Create a new metrics server. 
- timeout := config.If( - metricsConfig.Timeout > 0, - metricsConfig.Timeout, - config.DefaultMetricsServerTimeout, - ) - app.metricsServer = &http.Server{ - Addr: metricsConfig.Address, - Handler: mux, - ReadHeaderTimeout: readHeaderTimeout, - ReadTimeout: timeout, - WriteTimeout: timeout, - IdleTimeout: timeout, - } - - logger.Info().Fields(map[string]any{ - "address": address, - "timeout": timeout.String(), - "readHeaderTimeout": readHeaderTimeout.String(), - }).Msg("Metrics are exposed") - - if metricsConfig.CertFile != "" && metricsConfig.KeyFile != "" { - // Set up TLS. - app.metricsServer.TLSConfig = &tls.Config{ - MinVersion: tls.VersionTLS13, - CurvePreferences: []tls.CurveID{ - tls.CurveP521, - tls.CurveP384, - tls.CurveP256, - }, - CipherSuites: []uint16{ - tls.TLS_AES_128_GCM_SHA256, - tls.TLS_AES_256_GCM_SHA384, - tls.TLS_CHACHA20_POLY1305_SHA256, - }, - } - app.metricsServer.TLSNextProto = make( - map[string]func(*http.Server, *tls.Conn, http.Handler)) - logger.Debug().Msg("Metrics server is running with TLS") - - // Start the metrics server with TLS. - if err = app.metricsServer.ListenAndServeTLS( - metricsConfig.CertFile, metricsConfig.KeyFile); !errors.Is(err, http.ErrServerClosed) { - logger.Error().Err(err).Msg("Failed to start metrics server") - span.RecordError(err) - } - } else { - // Start the metrics server without TLS. - if err = app.metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { - logger.Error().Err(err).Msg("Failed to start metrics server") - span.RecordError(err) - } - } - }(app.conf.Global.Metrics[config.Default], logger) - - // This is a notification hook, so we don't care about the result. - pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), app.conf.Plugin.Timeout) - defer cancel() - - if data, ok := app.conf.GlobalKoanf.Get("loggers").(map[string]any); ok { - result, err := app.pluginRegistry.Run( - pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_LOGGER) - if err != nil { - logger.Error().Err(err).Msg("Failed to run OnNewLogger hooks") - span.RecordError(err) - } - if result != nil { - _ = app.pluginRegistry.ActRegistry.RunAll(result) - } - } else { - logger.Error().Msg("Failed to get loggers from config") - } - - // Declare httpServer and grpcServer here as it is used in the StopGracefully function ahead of their definition. - var httpServer *api.HTTPServer - var grpcServer *api.GRPCServer + // Run the OnNewLogger hook. + app.onNewLogger(span, logger) _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create pools and clients") - // Create and initialize pools of connections. - for configGroupName, configGroup := range app.conf.Global.Pools { - for configBlockName, cfg := range configGroup { - logger := app.loggers[configGroupName] - // Check if the pool size is greater than zero. - currentPoolSize := config.If( - cfg.Size > 0, - // Check if the pool size is greater than the minimum pool size. 
- config.If( - cfg.Size > config.MinimumPoolSize, - cfg.Size, - config.MinimumPoolSize, - ), - config.DefaultPoolSize, - ) - - if _, ok := app.pools[configGroupName]; !ok { - app.pools[configGroupName] = make(map[string]*pool.Pool) - } - app.pools[configGroupName][configBlockName] = pool.NewPool(runCtx, currentPoolSize) - - span.AddEvent("Create pool", trace.WithAttributes( - attribute.String("name", configBlockName), - attribute.Int("size", currentPoolSize), - )) - - if _, ok := app.clients[configGroupName]; !ok { - app.clients[configGroupName] = make(map[string]*config.Client) - } - - // Get client config from the config file. - if clientConfig, ok := app.conf.Global.Clients[configGroupName][configBlockName]; !ok { - // This ensures that the default client config is used if the pool name is not - // found in the clients section. - app.clients[configGroupName][configBlockName] = app.conf.Global.Clients[config.Default][config.DefaultConfigurationBlock] //nolint:lll - } else { - // Merge the default client config with the one from the pool. - app.clients[configGroupName][configBlockName] = clientConfig - } - - // Fill the missing and zero values with the default ones. - app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod = config.If( - app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod > 0, - app.clients[configGroupName][configBlockName].TCPKeepAlivePeriod, - config.DefaultTCPKeepAlivePeriod, - ) - app.clients[configGroupName][configBlockName].ReceiveDeadline = config.If( - app.clients[configGroupName][configBlockName].ReceiveDeadline > 0, - app.clients[configGroupName][configBlockName].ReceiveDeadline, - config.DefaultReceiveDeadline, - ) - app.clients[configGroupName][configBlockName].ReceiveTimeout = config.If( - app.clients[configGroupName][configBlockName].ReceiveTimeout > 0, - app.clients[configGroupName][configBlockName].ReceiveTimeout, - config.DefaultReceiveTimeout, - ) - app.clients[configGroupName][configBlockName].SendDeadline = config.If( - app.clients[configGroupName][configBlockName].SendDeadline > 0, - app.clients[configGroupName][configBlockName].SendDeadline, - config.DefaultSendDeadline, - ) - app.clients[configGroupName][configBlockName].ReceiveChunkSize = config.If( - app.clients[configGroupName][configBlockName].ReceiveChunkSize > 0, - app.clients[configGroupName][configBlockName].ReceiveChunkSize, - config.DefaultChunkSize, - ) - app.clients[configGroupName][configBlockName].DialTimeout = config.If( - app.clients[configGroupName][configBlockName].DialTimeout > 0, - app.clients[configGroupName][configBlockName].DialTimeout, - config.DefaultDialTimeout, - ) - - // Add clients to the pool. 
-                for range currentPoolSize {
-                    clientConfig := app.clients[configGroupName][configBlockName]
-                    clientConfig.GroupName = configGroupName
-                    clientConfig.BlockName = configBlockName
-                    client := network.NewClient(
-                        runCtx, clientConfig, logger,
-                        network.NewRetry(
-                            network.Retry{
-                                Retries: clientConfig.Retries,
-                                Backoff: config.If(
-                                    clientConfig.Backoff > 0,
-                                    clientConfig.Backoff,
-                                    config.DefaultBackoff,
-                                ),
-                                BackoffMultiplier:  clientConfig.BackoffMultiplier,
-                                DisableBackoffCaps: clientConfig.DisableBackoffCaps,
-                                Logger:             app.loggers[configBlockName],
-                            },
-                        ),
-                    )
-
-                    if client != nil {
-                        eventOptions := trace.WithAttributes(
-                            attribute.String("name", configBlockName),
-                            attribute.String("group", configGroupName),
-                            attribute.String("network", client.Network),
-                            attribute.String("address", client.Address),
-                            attribute.Int("receiveChunkSize", client.ReceiveChunkSize),
-                            attribute.String("receiveDeadline", client.ReceiveDeadline.String()),
-                            attribute.String("receiveTimeout", client.ReceiveTimeout.String()),
-                            attribute.String("sendDeadline", client.SendDeadline.String()),
-                            attribute.String("dialTimeout", client.DialTimeout.String()),
-                            attribute.Bool("tcpKeepAlive", client.TCPKeepAlive),
-                            attribute.String("tcpKeepAlivePeriod", client.TCPKeepAlivePeriod.String()),
-                            attribute.String("localAddress", client.LocalAddr()),
-                            attribute.String("remoteAddress", client.RemoteAddr()),
-                            attribute.Int("retries", clientConfig.Retries),
-                            attribute.String("backoff", client.Retry().Backoff.String()),
-                            attribute.Float64("backoffMultiplier", clientConfig.BackoffMultiplier),
-                            attribute.Bool("disableBackoffCaps", clientConfig.DisableBackoffCaps),
-                        )
-                        if client.ID != "" {
-                            eventOptions = trace.WithAttributes(
-                                attribute.String("id", client.ID),
-                            )
-                        }
-
-                        span.AddEvent("Create client", eventOptions)
-
-                        pluginTimeoutCtx, cancel = context.WithTimeout(
-                            context.Background(), app.conf.Plugin.Timeout)
-                        defer cancel()
-
-                        clientCfg := map[string]any{
-                            "id":                 client.ID,
-                            "name":               configBlockName,
-                            "group":              configGroupName,
-                            "network":            client.Network,
-                            "address":            client.Address,
-                            "receiveChunkSize":   client.ReceiveChunkSize,
-                            "receiveDeadline":    client.ReceiveDeadline.String(),
-                            "receiveTimeout":     client.ReceiveTimeout.String(),
-                            "sendDeadline":       client.SendDeadline.String(),
-                            "dialTimeout":        client.DialTimeout.String(),
-                            "tcpKeepAlive":       client.TCPKeepAlive,
-                            "tcpKeepAlivePeriod": client.TCPKeepAlivePeriod.String(),
-                            "localAddress":       client.LocalAddr(),
-                            "remoteAddress":      client.RemoteAddr(),
-                            "retries":            clientConfig.Retries,
-                            "backoff":            client.Retry().Backoff.String(),
-                            "backoffMultiplier":  clientConfig.BackoffMultiplier,
-                            "disableBackoffCaps": clientConfig.DisableBackoffCaps,
-                        }
-                        result, err := app.pluginRegistry.Run(
-                            pluginTimeoutCtx, clientCfg, v1.HookName_HOOK_NAME_ON_NEW_CLIENT)
-                        if err != nil {
-                            logger.Error().Err(err).Msg("Failed to run OnNewClient hooks")
-                            span.RecordError(err)
-                        }
-                        if result != nil {
-                            _ = app.pluginRegistry.ActRegistry.RunAll(result)
-                        }
-
-                        err = app.pools[configGroupName][configBlockName].Put(client.ID, client)
-                        if err != nil {
-                            logger.Error().Err(err).Msg("Failed to add client to the pool")
-                            span.RecordError(err)
-                        }
-                    } else {
-                        logger.Error().Msg("Failed to create client, please check the configuration")
-                        go func() {
-                            // Wait for the stop signal to exit gracefully.
-                            // This prevents the program from waiting indefinitely
-                            // after the StopGracefully function is called.
-                            <-app.stopChan
-                            os.Exit(gerr.FailedToCreateClient)
-                        }()
-                        app.stopGracefully(runCtx, nil)
-                        os.Exit(gerr.FailedToCreateClient)
-                    }
-                }
-
-                // Verify that the pool is properly populated.
-                logger.Info().Fields(map[string]any{
-                    "name":  configBlockName,
-                    "count": strconv.Itoa(app.pools[configGroupName][configBlockName].Size()),
-                }).Msg("There are clients available in the pool")
-
-                if app.pools[configGroupName][configBlockName].Size() != currentPoolSize {
-                    logger.Error().Msg(
-                        "The pool size is incorrect, either because " +
-                            "the clients cannot connect due to no network connectivity " +
-                            "or the server is not running. exiting...")
-                    app.pluginRegistry.Shutdown()
-                    os.Exit(gerr.FailedToInitializePool)
-                }
-
-                pluginTimeoutCtx, cancel = context.WithTimeout(
-                    context.Background(), app.conf.Plugin.Timeout)
-                defer cancel()
-                result, err := app.pluginRegistry.Run(
-                    pluginTimeoutCtx,
-                    map[string]any{"name": configBlockName, "size": currentPoolSize},
-                    v1.HookName_HOOK_NAME_ON_NEW_POOL)
-                if err != nil {
-                    logger.Error().Err(err).Msg("Failed to run OnNewPool hooks")
-                    span.RecordError(err)
-                }
-                if result != nil {
-                    _ = app.pluginRegistry.ActRegistry.RunAll(result)
-                }
-            }
+        if err := app.createPoolAndClients(runCtx, span); err != nil {
+            logger.Error().Err(err).Msg("Failed to create pools and clients")
+            span.RecordError(err)
+            app.stopGracefully(runCtx, nil)
+            os.Exit(gerr.FailedToCreatePoolAndClients)
         }

         span.End()

         _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create proxies")

-        // Create and initialize prefork proxies with each pool of clients.
-        for configGroupName, configGroup := range app.conf.Global.Proxies {
-            for configBlockName, cfg := range configGroup {
-                logger := app.loggers[configGroupName]
-                clientConfig := app.clients[configGroupName][configBlockName]
-
-                // Fill the missing and zero value with the default one.
-                cfg.HealthCheckPeriod = config.If(
-                    cfg.HealthCheckPeriod > 0,
-                    cfg.HealthCheckPeriod,
-                    config.DefaultHealthCheckPeriod,
-                )
-
-                if _, ok := app.proxies[configGroupName]; !ok {
-                    app.proxies[configGroupName] = make(map[string]*network.Proxy)
-                }
-
-                app.proxies[configGroupName][configBlockName] = network.NewProxy(
-                    runCtx,
-                    network.Proxy{
-                        GroupName:            configGroupName,
-                        BlockName:            configBlockName,
-                        AvailableConnections: app.pools[configGroupName][configBlockName],
-                        PluginRegistry:       app.pluginRegistry,
-                        HealthCheckPeriod:    cfg.HealthCheckPeriod,
-                        ClientConfig:         clientConfig,
-                        Logger:               logger,
-                        PluginTimeout:        app.conf.Plugin.Timeout,
-                    },
-                )
-
-                span.AddEvent("Create proxy", trace.WithAttributes(
-                    attribute.String("name", configBlockName),
-                    attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()),
-                ))
-                pluginTimeoutCtx, cancel = context.WithTimeout(
-                    context.Background(), app.conf.Plugin.Timeout)
-                defer cancel()
-
-                if data, ok := app.conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
-                    result, err := app.pluginRegistry.Run(
-                        pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
-                    if err != nil {
-                        logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
-                        span.RecordError(err)
-                    }
-                    if result != nil {
-                        _ = app.pluginRegistry.ActRegistry.RunAll(result)
-                    }
-                } else {
-                    logger.Error().Msg("Failed to get proxy from config")
-                }
-
-            }
-        }
+        // Create proxies.
+        app.createProxies(runCtx, span)

         span.End()
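The call above replaces the deleted proxy-creation loop with a method on GatewayDApp. For reference, a minimal sketch of how createProxies presumably looks in the new gatewayd_app.go file, reconstructed from the deleted lines; the receiver name and signature are taken from the call site, and everything else is the moved code rather than a verified copy of the upstream body:

// createProxies builds one proxy per configured block, wires it to the
// matching connection pool, and fires the OnNewProxy notification hook.
// A sketch reconstructed from the deleted loop above.
func (app *GatewayDApp) createProxies(runCtx context.Context, span trace.Span) {
    for configGroupName, configGroup := range app.conf.Global.Proxies {
        for configBlockName, cfg := range configGroup {
            logger := app.loggers[configGroupName]
            clientConfig := app.clients[configGroupName][configBlockName]

            // Fill the missing and zero value with the default one.
            cfg.HealthCheckPeriod = config.If(
                cfg.HealthCheckPeriod > 0,
                cfg.HealthCheckPeriod,
                config.DefaultHealthCheckPeriod,
            )

            if _, ok := app.proxies[configGroupName]; !ok {
                app.proxies[configGroupName] = make(map[string]*network.Proxy)
            }

            app.proxies[configGroupName][configBlockName] = network.NewProxy(
                runCtx,
                network.Proxy{
                    GroupName:            configGroupName,
                    BlockName:            configBlockName,
                    AvailableConnections: app.pools[configGroupName][configBlockName],
                    PluginRegistry:       app.pluginRegistry,
                    HealthCheckPeriod:    cfg.HealthCheckPeriod,
                    ClientConfig:         clientConfig,
                    Logger:               logger,
                    PluginTimeout:        app.conf.Plugin.Timeout,
                },
            )

            span.AddEvent("Create proxy", trace.WithAttributes(
                attribute.String("name", configBlockName),
                attribute.String("healthCheckPeriod", cfg.HealthCheckPeriod.String()),
            ))

            // Run the OnNewProxy hook (notification only); the per-iteration
            // defer cancel() is kept as in the original loop.
            pluginTimeoutCtx, cancel := context.WithTimeout(
                context.Background(), app.conf.Plugin.Timeout)
            defer cancel()

            if data, ok := app.conf.GlobalKoanf.Get("proxies").(map[string]any); ok {
                result, err := app.pluginRegistry.Run(
                    pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_PROXY)
                if err != nil {
                    logger.Error().Err(err).Msg("Failed to run OnNewProxy hooks")
                    span.RecordError(err)
                }
                if result != nil {
                    _ = app.pluginRegistry.ActRegistry.RunAll(result)
                }
            } else {
                logger.Error().Msg("Failed to get proxy from config")
            }
        }
    }
}

One side effect of the extraction: the deferred cancels now fire when createProxies returns instead of lingering until the cobra Run function exits.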
@@ -954,202 +229,28 @@ var runCmd = &cobra.Command{
         if originalErr != nil {
             logger.Error().Err(originalErr).Msg("Failed to start raft node")
             span.RecordError(originalErr)
-            app.pluginRegistry.Shutdown()
+            app.stopGracefully(runCtx, nil)
             os.Exit(gerr.FailedToStartRaftNode)
         }

         _, span = otel.Tracer(config.TracerName).Start(runCtx, "Create servers")

-        // Create and initialize servers.
-        for name, cfg := range app.conf.Global.Servers {
-            logger := app.loggers[name]
-
-            var serverProxies []network.IProxy
-            for _, proxy := range app.proxies[name] {
-                serverProxies = append(serverProxies, proxy)
-            }
-
-            app.servers[name] = network.NewServer(
-                runCtx,
-                network.Server{
-                    GroupName: name,
-                    Network:   cfg.Network,
-                    Address:   cfg.Address,
-                    TickInterval: config.If(
-                        cfg.TickInterval > 0,
-                        cfg.TickInterval,
-                        config.DefaultTickInterval,
-                    ),
-                    Options: network.Option{
-                        // Can be used to send keepalive messages to the client.
-                        EnableTicker: cfg.EnableTicker,
-                    },
-                    Proxies:                    serverProxies,
-                    Logger:                     logger,
-                    PluginRegistry:             app.pluginRegistry,
-                    PluginTimeout:              app.conf.Plugin.Timeout,
-                    EnableTLS:                  cfg.EnableTLS,
-                    CertFile:                   cfg.CertFile,
-                    KeyFile:                    cfg.KeyFile,
-                    HandshakeTimeout:           cfg.HandshakeTimeout,
-                    LoadbalancerStrategyName:   cfg.LoadBalancer.Strategy,
-                    LoadbalancerRules:          cfg.LoadBalancer.LoadBalancingRules,
-                    LoadbalancerConsistentHash: cfg.LoadBalancer.ConsistentHash,
-                    RaftNode:                   raftNode,
-                },
-            )
-            span.AddEvent("Create server", trace.WithAttributes(
-                attribute.String("name", name),
-                attribute.String("network", cfg.Network),
-                attribute.String("address", cfg.Address),
-                attribute.String("tickInterval", cfg.TickInterval.String()),
-                attribute.String("pluginTimeout", app.conf.Plugin.Timeout.String()),
-                attribute.Bool("enableTLS", cfg.EnableTLS),
-                attribute.String("certFile", cfg.CertFile),
-                attribute.String("keyFile", cfg.KeyFile),
-                attribute.String("handshakeTimeout", cfg.HandshakeTimeout.String()),
-            ))
-
-            pluginTimeoutCtx, cancel = context.WithTimeout(
-                context.Background(), app.conf.Plugin.Timeout)
-            defer cancel()
-
-            if data, ok := app.conf.GlobalKoanf.Get("servers").(map[string]any); ok {
-                result, err := app.pluginRegistry.Run(
-                    pluginTimeoutCtx, data, v1.HookName_HOOK_NAME_ON_NEW_SERVER)
-                if err != nil {
-                    logger.Error().Err(err).Msg("Failed to run OnNewServer hooks")
-                    span.RecordError(err)
-                }
-                if result != nil {
-                    _ = app.pluginRegistry.ActRegistry.RunAll(result)
-                }
-            } else {
-                logger.Error().Msg("Failed to get the servers configuration")
-            }
-        }
+        // Create servers.
+        app.createServers(runCtx, span, raftNode)

         span.End()

-        // Start the HTTP and gRPC APIs.
-        if app.conf.Global.API.Enabled {
-            apiOptions := api.Options{
-                Logger:      logger,
-                GRPCNetwork: app.conf.Global.API.GRPCNetwork,
-                GRPCAddress: app.conf.Global.API.GRPCAddress,
-                HTTPAddress: app.conf.Global.API.HTTPAddress,
-                Servers:     app.servers,
-                RaftNode:    raftNode,
-            }
-
-            apiObj := &api.API{
-                Options:        &apiOptions,
-                Config:         app.conf,
-                PluginRegistry: app.pluginRegistry,
-                Pools:          app.pools,
-                Proxies:        app.proxies,
-                Servers:        app.servers,
-            }
-            grpcServer = api.NewGRPCServer(
-                runCtx,
-                api.GRPCServer{
-                    API:           apiObj,
-                    HealthChecker: &api.HealthChecker{Servers: app.servers},
-                },
-            )
-            if grpcServer != nil {
-                go grpcServer.Start()
-                logger.Info().Str("address", apiOptions.HTTPAddress).Msg("Started the HTTP API")
-
-                httpServer = api.NewHTTPServer(&apiOptions)
-                go httpServer.Start()
-
-                logger.Info().Fields(
-                    map[string]any{
-                        "network": apiOptions.GRPCNetwork,
-                        "address": apiOptions.GRPCAddress,
-                    },
-                ).Msg("Started the gRPC Server")
-            }
-        }
+        // Start the API servers.
+        app.startAPIServers(runCtx, logger, raftNode)

-        // Report usage statistics.
-        if app.EnableUsageReport {
-            go func() {
-                conn, err := grpc.NewClient(
-                    UsageReportURL,
-                    grpc.WithTransportCredentials(
-                        credentials.NewTLS(
-                            &tls.Config{
-                                MinVersion: tls.VersionTLS12,
-                            },
-                        ),
-                    ),
-                )
-                if err != nil {
-                    logger.Trace().Err(err).Msg(
-                        "Failed to dial to the gRPC server for usage reporting")
-                }
-                defer func(conn *grpc.ClientConn) {
-                    err := conn.Close()
-                    if err != nil {
-                        logger.Trace().Err(err).Msg("Failed to close the connection to the usage report service")
-                    }
-                }(conn)
-
-                client := usage.NewUsageReportServiceClient(conn)
-                report := usage.UsageReportRequest{
-                    Version:        config.Version,
-                    RuntimeVersion: runtime.Version(),
-                    Goos:           runtime.GOOS,
-                    Goarch:         runtime.GOARCH,
-                    Service:        "gatewayd",
-                    DevMode:        app.DevMode,
-                    Plugins:        []*usage.Plugin{},
-                }
-                app.pluginRegistry.ForEach(
-                    func(identifier sdkPlugin.Identifier, _ *plugin.Plugin) {
-                        report.Plugins = append(report.GetPlugins(), &usage.Plugin{
-                            Name:     identifier.Name,
-                            Version:  identifier.Version,
-                            Checksum: identifier.Checksum,
-                        })
-                    },
-                )
-                _, err = client.Report(context.Background(), &report)
-                if err != nil {
-                    logger.Trace().Err(err).Msg("Failed to report usage statistics")
-                }
-            }()
-        }
+        // Report usage.
+        app.reportUsage(logger)

         _, span = otel.Tracer(config.TracerName).Start(runCtx, "Start servers")

-        // Start the server.
-        for name, server := range app.servers {
-            logger := app.loggers[name]
-            go func(
-                span trace.Span,
-                server *network.Server,
-                logger zerolog.Logger,
-                healthCheckScheduler *gocron.Scheduler,
-                metricsMerger *metrics.Merger,
-                pluginRegistry *plugin.Registry,
-            ) {
-                span.AddEvent("Start server")
-                if err := server.Run(); err != nil {
-                    logger.Error().Err(err).Msg("Failed to start server")
-                    span.RecordError(err)
-                    healthCheckScheduler.Clear()
-                    if metricsMerger != nil {
-                        metricsMerger.Stop()
-                    }
-                    server.Shutdown()
-                    pluginRegistry.Shutdown()
-                    os.Exit(gerr.FailedToStartServer)
-                }
-            }(span, server, logger, app.healthCheckScheduler, app.metricsMerger, app.pluginRegistry)
-        }
+        // Start the servers.
+        app.startServers(runCtx, span)
+        span.End()

         // Wait for the server to shut down.
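Likewise, startServers presumably wraps the deleted server-startup loop. A minimal sketch under those assumptions: the method lives next to the other extracted helpers, closes over the app fields instead of threading them through goroutine parameters, and keeps the gerr alias that run.go uses for the errors package; none of this is verified against the new file's body:

// startServers runs every configured server in its own goroutine and
// tears the process down with a dedicated exit code if any of them
// fails to come up. A sketch reconstructed from the deleted loop above.
func (app *GatewayDApp) startServers(runCtx context.Context, span trace.Span) {
    for name, server := range app.servers {
        logger := app.loggers[name]
        go func(server *network.Server, logger zerolog.Logger) {
            span.AddEvent("Start server")
            if err := server.Run(); err != nil {
                logger.Error().Err(err).Msg("Failed to start server")
                span.RecordError(err)
                // Stop background work before exiting, as the old loop did.
                app.healthCheckScheduler.Clear()
                if app.metricsMerger != nil {
                    app.metricsMerger.Stop()
                }
                server.Shutdown()
                app.pluginRegistry.Shutdown()
                os.Exit(gerr.FailedToStartServer)
            }
        }(server, logger)
    }
}

Capturing app in the closure drops four of the six goroutine parameters the old inline version had to pass explicitly.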
@@ -1176,8 +277,8 @@ func init() {
     runCmd.Flags().Bool("metrics-merger", true, "Enable metrics merger")
 }

-func NewGatewayDInstance(cmd *cobra.Command) *GatewayDInstance {
-    app := GatewayDInstance{
+func NewGatewayDInstance(cmd *cobra.Command) *GatewayDApp {
+    app := GatewayDApp{
         loggers: make(map[string]zerolog.Logger),
         pools:   make(map[string]map[string]*pool.Pool),
         clients: make(map[string]map[string]*config.Client),
diff --git a/errors/errors.go b/errors/errors.go
index 61bb0e32..6c0d791b 100644
--- a/errors/errors.go
+++ b/errors/errors.go
@@ -214,10 +214,12 @@ var (
 )

 const (
-    FailedToCreateClient      = 1
-    FailedToInitializePool    = 2
-    FailedToStartServer       = 3
-    FailedToStartTracer       = 4
-    FailedToCreateActRegistry = 5
-    FailedToStartRaftNode     = 6
+    FailedToCreateClient         = 1
+    FailedToStartServer          = 2
+    FailedToStartTracer          = 3
+    FailedToCreateActRegistry    = 4
+    FailedToLoadPolicies         = 5
+    FailedToStartRaftNode        = 6
+    FailedToMergeGlobalConfig    = 7
+    FailedToCreatePoolAndClients = 8
 )
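Note that dropping FailedToInitializePool renumbers exit statuses 2 through 6, so anything that branches on GatewayD's exit code needs updating alongside this patch. A hypothetical supervisor sketch; the wrapper, its names, and the invocation are illustrative, and only the constants come from this diff:

package main

import (
    "errors"
    "fmt"
    "os/exec"

    gerr "github.com/gatewayd-io/gatewayd/errors"
)

// describeExit maps GatewayD's process exit codes, after this renumbering,
// to human-readable causes; the switch arms mirror the new const block.
func describeExit(code int) string {
    switch code {
    case gerr.FailedToCreateClient:
        return "failed to create client"
    case gerr.FailedToStartServer:
        return "failed to start server"
    case gerr.FailedToStartTracer:
        return "failed to start tracer"
    case gerr.FailedToCreateActRegistry:
        return "failed to create act registry"
    case gerr.FailedToLoadPolicies:
        return "failed to load policies"
    case gerr.FailedToStartRaftNode:
        return "failed to start raft node"
    case gerr.FailedToMergeGlobalConfig:
        return "failed to merge global config"
    case gerr.FailedToCreatePoolAndClients:
        return "failed to create pools and clients"
    default:
        return "unknown failure"
    }
}

func main() {
    // Hypothetical invocation; the binary name and subcommand are assumptions.
    cmd := exec.Command("gatewayd", "run")
    if err := cmd.Run(); err != nil {
        var exitErr *exec.ExitError
        if errors.As(err, &exitErr) {
            fmt.Println(describeExit(exitErr.ExitCode()))
        }
    }
}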