diff --git a/gatewayd_plugin.yaml b/gatewayd_plugin.yaml index cae7662..59215c3 100644 --- a/gatewayd_plugin.yaml +++ b/gatewayd_plugin.yaml @@ -30,4 +30,5 @@ plugins: - API_ADDRESS=localhost:18080 - EXIT_ON_STARTUP_ERROR=False - SENTRY_DSN=https://70eb1abcd32e41acbdfc17bc3407a543@o4504550475038720.ingest.sentry.io/4505342961123328 + - CACHE_CHANNEL_BUFFER_SIZE=100 checksum: 3988e10aefce2cd9b30888eddd2ec93a431c9018a695aea1cea0dac46ba91cae diff --git a/main.go b/main.go index 1f972b3..9c1c261 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "github.com/gatewayd-io/gatewayd-plugin-sdk/logging" "github.com/gatewayd-io/gatewayd-plugin-sdk/metrics" p "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin" + v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" "github.com/getsentry/sentry-go" "github.com/go-redis/redis/v8" "github.com/hashicorp/go-hclog" @@ -52,6 +53,14 @@ func main() { go metrics.ExposeMetrics(metricsConfig, logger) } + cacheBufferSize := cast.ToUint(cfg["cacheBufferSize"]) + if cacheBufferSize <= 0 { + cacheBufferSize = 100 // default value + } + + pluginInstance.Impl.UpdateCacheChannel = make(chan *v1.Struct, cacheBufferSize) + go pluginInstance.Impl.UpdateCache(context.Background()) + pluginInstance.Impl.RedisURL = cast.ToString(cfg["redisURL"]) pluginInstance.Impl.Expiry = cast.ToDuration(cfg["expiry"]) pluginInstance.Impl.DefaultDBName = cast.ToString(cfg["defaultDBName"]) @@ -93,6 +102,8 @@ func main() { } } + defer close(pluginInstance.Impl.UpdateCacheChannel) + goplugin.Serve(&goplugin.ServeConfig{ HandshakeConfig: goplugin.HandshakeConfig{ ProtocolVersion: 1, diff --git a/plugin/module.go b/plugin/module.go index c51c362..4d727a9 100644 --- a/plugin/module.go +++ b/plugin/module.go @@ -45,6 +45,7 @@ var ( "PERIODIC_INVALIDATOR_INTERVAL", "1m"), "apiAddress": sdkConfig.GetEnv("API_ADDRESS", "localhost:8080"), "exitOnStartupError": sdkConfig.GetEnv("EXIT_ON_STARTUP_ERROR", "false"), + "cacheBufferSize": sdkConfig.GetEnv("CACHE_CHANNEL_BUFFER_SIZE", "100"), }, "hooks": []interface{}{ int32(v1.HookName_HOOK_NAME_ON_CLOSED), diff --git a/plugin/plugin.go b/plugin/plugin.go index 04d7701..3ecab15 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -30,6 +30,8 @@ type Plugin struct { ScanCount int64 ExitOnStartupError bool + UpdateCacheChannel chan *v1.Struct + // Periodic invalidator configuration. PeriodicInvalidatorEnabled bool PeriodicInvalidatorStartDelay time.Duration @@ -144,87 +146,103 @@ func (p *Plugin) OnTrafficFromClient( return req, nil } -// OnTrafficFromServer is called when a response is received by GatewayD from the server. -func (p *Plugin) OnTrafficFromServer( - ctx context.Context, resp *v1.Struct, -) (*v1.Struct, error) { - OnTrafficFromServerCounter.Inc() - resp, err := postgres.HandleServerMessage(resp, p.Logger) - if err != nil { - p.Logger.Info("Failed to handle server message", "error", err) - } - - rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", "")) - dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{})) - errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", "")) - request, ok := sdkPlugin.GetAttr(resp, "request", nil).([]byte) - if !ok { - request = []byte{} - } - response, ok := sdkPlugin.GetAttr(resp, "response", nil).([]byte) - if !ok { - response = []byte{} - } - server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", "")) +func (p *Plugin) UpdateCache(ctx context.Context) { + for { + serverResponse, ok := <-p.UpdateCacheChannel + if !ok { + p.Logger.Info("Channel closed, returning from function") + return + } - // This is used as a fallback if the database is not found in the startup message. - database := p.DefaultDBName - if database == "" { - client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", "")) - if client != nil && client["remote"] != "" { - database, err = p.RedisClient.Get(ctx, client["remote"]).Result() - if err != nil { - CacheMissesCounter.Inc() - p.Logger.Debug("Failed to get cached response", "error", err) - } - CacheGetsCounter.Inc() + OnTrafficFromServerCounter.Inc() + resp, err := postgres.HandleServerMessage(serverResponse, p.Logger) + if err != nil { + p.Logger.Info("Failed to handle server message", "error", err) } - } - // If the database is still not found, return the response as is without caching. - // This might also happen if the cache is cleared while the client is still connected. - // In this case, the client should reconnect and the error will go away. - if database == "" { - p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. Skipping cache") - p.Logger.Debug("Consider setting the database name in the plugin config or disabling the plugin if you don't need it") - return resp, nil - } + rowDescription := cast.ToString(sdkPlugin.GetAttr(resp, "rowDescription", "")) + dataRow := cast.ToStringSlice(sdkPlugin.GetAttr(resp, "dataRow", []interface{}{})) + errorResponse := cast.ToString(sdkPlugin.GetAttr(resp, "errorResponse", "")) + request, isOk := sdkPlugin.GetAttr(resp, "request", nil).([]byte) + if !isOk { + request = []byte{} + } - cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":") - if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 { - // The request was successful and the response contains data. Cache the response. - if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil { - CacheMissesCounter.Inc() - p.Logger.Debug("Failed to set cache", "error", err) + response, isOk := sdkPlugin.GetAttr(resp, "response", nil).([]byte) + if !isOk { + response = []byte{} } - CacheSetsCounter.Inc() + server := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "server", "")) - // Cache the query as well. - query, err := postgres.GetQueryFromRequest(request) - if err != nil { - p.Logger.Debug("Failed to get query from request", "error", err) - return resp, nil + // This is used as a fallback if the database is not found in the startup message. + + database := p.DefaultDBName + if database == "" { + client := cast.ToStringMapString(sdkPlugin.GetAttr(resp, "client", "")) + if client != nil && client["remote"] != "" { + database, err = p.RedisClient.Get(ctx, client["remote"]).Result() + if err != nil { + CacheMissesCounter.Inc() + p.Logger.Debug("Failed to get cached response", "error", err) + } + CacheGetsCounter.Inc() + } } - tables, err := postgres.GetTablesFromQuery(query) - if err != nil { - p.Logger.Debug("Failed to get tables from query", "error", err) - return resp, nil + // If the database is still not found, return the response as is without caching. + // This might also happen if the cache is cleared while the client is still connected. + // In this case, the client should reconnect and the error will go away. + if database == "" { + p.Logger.Debug("Database name not found or set in cache, startup message or plugin config. " + + "Skipping cache") + p.Logger.Debug("Consider setting the database name in the " + + "plugin config or disabling the plugin if you don't need it") + return } - // Cache the table(s) used in each cached request. This is used to invalidate - // the cache when a rows is inserted, updated or deleted into that table. - for _, table := range tables { - requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":") - if err := p.RedisClient.Set( - ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil { + cacheKey := strings.Join([]string{server["remote"], database, string(request)}, ":") + if errorResponse == "" && rowDescription != "" && dataRow != nil && len(dataRow) > 0 { + // The request was successful and the response contains data. Cache the response. + if err := p.RedisClient.Set(ctx, cacheKey, response, p.Expiry).Err(); err != nil { CacheMissesCounter.Inc() p.Logger.Debug("Failed to set cache", "error", err) } CacheSetsCounter.Inc() + + // Cache the query as well. + query, err := postgres.GetQueryFromRequest(request) + if err != nil { + p.Logger.Debug("Failed to get query from request", "error", err) + return + } + + tables, err := postgres.GetTablesFromQuery(query) + if err != nil { + p.Logger.Debug("Failed to get tables from query", "error", err) + return + } + + // Cache the table(s) used in each cached request. This is used to invalidate + // the cache when a rows is inserted, updated or deleted into that table. + for _, table := range tables { + requestQueryCacheKey := strings.Join([]string{table, cacheKey}, ":") + if err := p.RedisClient.Set( + ctx, requestQueryCacheKey, "", p.Expiry).Err(); err != nil { + CacheMissesCounter.Inc() + p.Logger.Debug("Failed to set cache", "error", err) + } + CacheSetsCounter.Inc() + } } } +} +// OnTrafficFromServer is called when a response is received by GatewayD from the server. +func (p *Plugin) OnTrafficFromServer( + _ context.Context, resp *v1.Struct, +) (*v1.Struct, error) { + p.Logger.Debug("Traffic is coming from the server side") + p.UpdateCacheChannel <- resp return resp, nil } diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go index 07b99e8..4cefabf 100644 --- a/plugin/plugin_test.go +++ b/plugin/plugin_test.go @@ -3,9 +3,6 @@ package plugin import ( "context" "encoding/base64" - "os" - "testing" - miniredis "github.com/alicebob/miniredis/v2" "github.com/gatewayd-io/gatewayd-plugin-sdk/logging" v1 "github.com/gatewayd-io/gatewayd-plugin-sdk/plugin/v1" @@ -13,6 +10,9 @@ import ( "github.com/hashicorp/go-hclog" pgproto3 "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" + "os" + "sync" + "testing" ) func testQueryRequest() (string, []byte) { @@ -44,16 +44,28 @@ func Test_Plugin(t *testing.T) { redisClient := redis.NewClient(redisConfig) assert.NotNil(t, redisClient) + updateCacheChannel := make(chan *v1.Struct, 10) + // Create and initialize a new plugin. logger := hclog.New(&hclog.LoggerOptions{ Level: logging.GetLogLevel("error"), Output: os.Stdout, }) p := NewCachePlugin(Plugin{ - Logger: logger, - RedisURL: redisURL, - RedisClient: redisClient, + Logger: logger, + RedisURL: redisURL, + RedisClient: redisClient, + UpdateCacheChannel: updateCacheChannel, }) + + // Use a WaitGroup to wait for the goroutine to finish + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + p.Impl.UpdateCache(context.Background()) + }() + assert.NotNil(t, p) // Test the plugin's GetPluginConfig method. @@ -146,6 +158,10 @@ func Test_Plugin(t *testing.T) { assert.NotNil(t, result) assert.Equal(t, result, resp) + // Close the channel and wait for the cache updater to return gracefully + close(updateCacheChannel) + wg.Wait() + // Check that the query and response was cached. cachedResponse, err := redisClient.Get( context.Background(), "localhost:5432:postgres:"+string(request)).Bytes()