From 7561a0d3b8a846ebd0d3be8550a8e9886438820b Mon Sep 17 00:00:00 2001 From: Ashley Jeffs Date: Mon, 30 Sep 2024 10:05:18 +0100 Subject: [PATCH 001/118] Move repos into connect --- go.mod | 11 +- go.sum | 8 + internal/impl/postgresql/pg_stream/README.md | 48 + .../pg_stream/pg_stream/integration_test.go | 215 +++++ .../pg_stream/pg_stream/pg_stream.go | 258 ++++++ .../pg_stream_schemaless.go | 74 ++ .../pg_stream_schemaless/wal_message.go | 25 + .../impl/postgresql/pglogicalstream/README.MD | 91 ++ .../impl/postgresql/pglogicalstream/config.go | 30 + .../pglogicalstream/docker-compose.yaml | 11 + .../pglogicalstream/example/simple/main.go | 44 + .../pglogicalstream/example/ws/main.go | 54 ++ .../impl/postgresql/pglogicalstream/filter.go | 73 ++ .../internal/helpers/arrow_schema_builder.go | 44 + .../internal/helpers/availablememory.go | 20 + .../internal/schemas/schemas.go | 16 + .../pglogicalstream/logical_stream.go | 511 +++++++++++ .../postgresql/pglogicalstream/message.go | 728 +++++++++++++++ .../pglogicalstream/message_test.go | 856 ++++++++++++++++++ .../postgresql/pglogicalstream/pglogrepl.go | 773 ++++++++++++++++ .../pglogicalstream/pglogrepl_test.go | 414 +++++++++ .../postgresql/pglogicalstream/snapshotter.go | 101 +++ .../impl/postgresql/pglogicalstream/types.go | 24 + .../pglogicalstream/wal_changes_message.go | 25 + 24 files changed, 4451 insertions(+), 3 deletions(-) create mode 100644 internal/impl/postgresql/pg_stream/README.md create mode 100644 internal/impl/postgresql/pg_stream/pg_stream/integration_test.go create mode 100644 internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go create mode 100644 internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go create mode 100644 internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go create mode 100644 internal/impl/postgresql/pglogicalstream/README.MD create mode 100644 internal/impl/postgresql/pglogicalstream/config.go create mode 100644 
internal/impl/postgresql/pglogicalstream/docker-compose.yaml create mode 100644 internal/impl/postgresql/pglogicalstream/example/simple/main.go create mode 100644 internal/impl/postgresql/pglogicalstream/example/ws/main.go create mode 100644 internal/impl/postgresql/pglogicalstream/filter.go create mode 100644 internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go create mode 100644 internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go create mode 100644 internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go create mode 100644 internal/impl/postgresql/pglogicalstream/logical_stream.go create mode 100644 internal/impl/postgresql/pglogicalstream/message.go create mode 100644 internal/impl/postgresql/pglogicalstream/message_test.go create mode 100644 internal/impl/postgresql/pglogicalstream/pglogrepl.go create mode 100644 internal/impl/postgresql/pglogicalstream/pglogrepl_test.go create mode 100644 internal/impl/postgresql/pglogicalstream/snapshotter.go create mode 100644 internal/impl/postgresql/pglogicalstream/types.go create mode 100644 internal/impl/postgresql/pglogicalstream/wal_changes_message.go diff --git a/go.mod b/go.mod index 7eab79eb4b..370692cafb 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/Masterminds/squirrel v1.5.4 github.com/PaesslerAG/gval v1.2.2 github.com/PaesslerAG/jsonpath v0.1.1 + github.com/apache/arrow/go/v14 v14.0.2 github.com/apache/pulsar-client-go v0.13.1 github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go-v2 v1.30.4 @@ -65,10 +66,14 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/gosimple/slug v1.14.0 github.com/influxdata/influxdb1-client v0.0.0-20220302092344-a9ab5670611c + github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 github.com/jackc/pgx/v4 v4.18.3 + github.com/jackc/pgx/v5 v5.6.0 + github.com/jaswdr/faker v1.19.1 github.com/jhump/protoreflect v1.16.0 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 
v2.13.0 + github.com/lucasepe/codename v0.2.0 github.com/matoous/go-nanoid/v2 v2.1.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/microsoft/gocosmos v1.1.1 @@ -262,7 +267,7 @@ require ( github.com/gorilla/css v1.0.1 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect - github.com/gorilla/websocket v1.5.3 // indirect + github.com/gorilla/websocket v1.5.3 github.com/gosimple/unidecode v1.0.1 // indirect github.com/govalues/decimal v0.1.29 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect @@ -281,7 +286,7 @@ require ( github.com/itchyny/timefmt-go v0.1.6 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 - github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgio v1.0.0 github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.3 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -389,7 +394,7 @@ require ( gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/go.sum b/go.sum index 8108de7522..1335d065bb 100644 --- a/go.sum +++ b/go.sum @@ -165,6 +165,8 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/apache/arrow/go/arrow v0.0.0-20200730104253-651201b0f516/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= +github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO5IONw= 
+github.com/apache/arrow/go/v14 v14.0.2/go.mod h1:u3fgh3EdgN/YQ8cVQRguVW3R+seMybFg8QBQ5LU+eBY= github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= github.com/apache/arrow/go/v15 v15.0.2/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA= github.com/apache/pulsar-client-go v0.13.1 h1:XAAKXjF99du7LP6qu/nBII1HC2nS483/vQoQIWmm5Yg= @@ -690,6 +692,8 @@ github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 h1:86CQbMauoZdLS0HDLcEHYo6rErjiCBjVvcxGsioIn7s= +github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9/go.mod h1:SO15KF4QqfUM5UhsG9roXre5qeAQLC1rm8a8Gjpgg5k= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= @@ -732,6 +736,8 @@ github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jaswdr/faker v1.19.1 h1:xBoz8/O6r0QAR8eEvKJZMdofxiRH+F0M/7MU9eNKhsM= +github.com/jaswdr/faker v1.19.1/go.mod h1:x7ZlyB1AZqwqKZgyQlnqEG8FDptmHlncA5u2zY/yi6w= github.com/jawher/mow.cli v1.0.4/go.mod h1:5hQj2V8g+qYmLUVWqu4Wuja1pI57M83EChYLVZ0sMKk= github.com/jawher/mow.cli v1.2.0/go.mod h1:y+pcA3jBAdo/GIZx/0rFjw/K2bVEODP9rfZOfaiq8Ko= 
github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= @@ -816,6 +822,8 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/linkedin/goavro/v2 v2.13.0 h1:L8eI8GcuciwUkt41Ej62joSZS4kKaYIUdze+6for9NU= github.com/linkedin/goavro/v2 v2.13.0/go.mod h1:KXx+erlq+RPlGSPmLF7xGo6SAbh8sCQ53x064+ioxhk= +github.com/lucasepe/codename v0.2.0 h1:zkW9mKWSO8jjVIYFyZWE9FPvBtFVJxgMpQcMkf4Vv20= +github.com/lucasepe/codename v0.2.0/go.mod h1:RDcExRuZPWp5Uz+BosvpROFTrxpt5r1vSzBObHdBdDM= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a h1:3Bm7EwfUQUvhNeKIkUct/gl9eod1TcXuj8stxvi/GoI= github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= diff --git a/internal/impl/postgresql/pg_stream/README.md b/internal/impl/postgresql/pg_stream/README.md new file mode 100644 index 0000000000..758da9d8e4 --- /dev/null +++ b/internal/impl/postgresql/pg_stream/README.md @@ -0,0 +1,48 @@ +# PostgreSQL Logical Replication Streaming Plugin for Benthos + +Welcome to the PostgreSQL Logical Replication Streaming Plugin for Benthos! This plugin allows you to seamlessly stream data changes from your PostgreSQL database using Benthos, a versatile stream processor. + +## Features + +- **Real-time Data Streaming:** Capture data changes in real-time as they happen in your PostgreSQL database. + +- **Flexible Configuration:** Easily configure the plugin to specify the database connection details, replication slot, and table filtering rules. 
+ +- **Checkpoints:** Store your replication consuming progress in Redis + +## Prerequisites + +Before you begin, make sure you have the following prerequisites: + +- [PostgreSQL](https://www.postgresql.org/): Ensure you have a PostgreSQL database instance that supports logical replication. + +### Create benthos configuration with plugin + +```yaml +input: + label: postgres_cdc_input + # register new plugin + pg_stream: + host: database host + slot_name: replication slot name + user: postgres username with replication permissions + password: password + port: 5432 + schema: schema you want to replicate tables from + stream_snapshot: set true if you want to stream existing data. If set to false only new data will be streamed + database: name of the database + checkpoint_storage: redis uri if you want to store checkpoints + tables: ## list of tables you want to replicate + - table_name +``` + +### Register processor to pretty format your data +By default, the plugin exports raw `wal2json` messages. If you want to receive your data as a json structure +without metadata to transform it with benthos - you can register the `pg_stream_schemaless` plugin to transform it + +```yaml +pipeline: + processors: + - label: pretty_changes_processor + pg_stream_schemaless: { } +``` diff --git a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go new file mode 100644 index 0000000000..02f53a7bbb --- /dev/null +++ b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go @@ -0,0 +1,215 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License.
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pg_stream + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + "sync" + "testing" + "time" + + "github.com/jaswdr/faker" + _ "github.com/lib/pq" + _ "github.com/redpanda-data/benthos/v4/public/components/io" + _ "github.com/redpanda-data/benthos/v4/public/components/pure" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +func TestIntegrationPgCDC(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + // Use custom PostgreSQL image with wal2json plugin compiled in + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "usedatabrew/pgwal2json", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + ExposedPorts: []string{"5432"}, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, pool.Purge(resource)) + }) + + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseUrl := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseUrl); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + var walLevel string + if err = db.QueryRow("SHOW 
wal_level").Scan(&walLevel); err != nil { + return err + } + + var pgConfig string + if err = db.QueryRow("SHOW config_file").Scan(&pgConfig); err != nil { + return err + } + + if walLevel != "logical" { + return fmt.Errorf("wal_level is not logical") + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + + return err + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + fake := faker.New() + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot + user: user_name + password: secret + port: %s + schema: public + tls: none + stream_snapshot: true + database: dbname + tables: + - flights +`, hostAndPortSplited[0], hostAndPortSplited[1]) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 1000; i++ { + _, err = 
db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 2000 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 50; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 50 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) +} diff --git a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go new file mode 100644 index 0000000000..fac92c400c --- /dev/null +++ 
b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go @@ -0,0 +1,258 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pg_stream + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/lucasepe/codename" + "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" +) + +var randomSlotName string + +var pgStreamConfigSpec = service.NewConfigSpec(). + Summary("Creates Postgres replication slot for CDC"). + Field(service.NewStringField("host"). + Description("PostgreSQL instance host"). + Example("123.0.0.1")). + Field(service.NewIntField("port"). + Description("PostgreSQL instance port"). + Example(5432). + Default(5432)). + Field(service.NewStringField("user"). + Description("Username with permissions to start replication (RDS superuser)"). + Example("postgres"), + ). + Field(service.NewStringField("password"). + Description("PostgreSQL database password")). + Field(service.NewStringField("schema"). + Description("Schema that will be used to create replication")). + Field(service.NewStringField("database"). + Description("PostgreSQL database name")). + Field(service.NewStringEnumField("tls", "require", "none"). + Description("Defines whether benthos need to verify (skipinsecure) TLS configuration"). + Example("none"). + Default("none")). + Field(service.NewBoolField("stream_snapshot"). + Description("Set `true` if you want to receive all the data that currently exist in database"). + Example(true). + Default(false)). + Field(service.NewFloatField("snapshot_memory_safety_factor"). 
+ Description("Sets amout of memory that can be used to stream snapshot. If affects batch sizes. If we want to use only 25% of the memory available - put 0.25 factor. It will make initial streaming slower, but it will prevent your worker from OOM Kill"). + Example(0.2). + Default(0.5)). + Field(service.NewStringListField("tables"). + Example(` + - my_table + - my_table_2 + `). + Description("List of tables we have to create logical replication for")). + Field(service.NewStringField("slot_name"). + Description("PostgeSQL logical replication slot name. You can create it manually before starting the sync. If not provided will be replaced with a random one"). + Example("my_test_slot"). + Default(randomSlotName)) + +func newPgStreamInput(conf *service.ParsedConfig) (s service.Input, err error) { + var ( + dbName string + dbPort int + dbHost string + dbSchema string + dbUser string + dbPassword string + dbSlotName string + tlsSetting string + tables []string + streamSnapshot bool + snapshotMemSafetyFactor float64 + ) + + dbSchema, err = conf.FieldString("schema") + if err != nil { + return nil, err + } + + dbSlotName, err = conf.FieldString("slot_name") + if err != nil { + return nil, err + } + + if dbSlotName == "" { + dbSlotName = randomSlotName + } + + dbPassword, err = conf.FieldString("password") + if err != nil { + return nil, err + } + + dbUser, err = conf.FieldString("user") + if err != nil { + return nil, err + } + + tlsSetting, err = conf.FieldString("tls") + if err != nil { + return nil, err + } + + dbName, err = conf.FieldString("database") + if err != nil { + return nil, err + } + + dbHost, err = conf.FieldString("host") + if err != nil { + return nil, err + } + + dbPort, err = conf.FieldInt("port") + if err != nil { + return nil, err + } + + tables, err = conf.FieldStringList("tables") + if err != nil { + return nil, err + } + + streamSnapshot, err = conf.FieldBool("stream_snapshot") + if err != nil { + return nil, err + } + + snapshotMemSafetyFactor, err = 
conf.FieldFloat("snapshot_memory_safety_factor") + if err != nil { + return nil, err + } + + pgconnConfig := pgconn.Config{ + Host: dbHost, + Port: uint16(dbPort), + Database: dbName, + User: dbUser, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Password: dbPassword, + } + + if tlsSetting == "none" { + pgconnConfig.TLSConfig = nil + } + + return service.AutoRetryNacks(&pgStreamInput{ + dbConfig: pgconnConfig, + streamSnapshot: streamSnapshot, + snapshotMemSafetyFactor: snapshotMemSafetyFactor, + slotName: dbSlotName, + schema: dbSchema, + tls: pglogicalstream.TlsVerify(tlsSetting), + tables: tables, + }), err +} + +func init() { + rng, _ := codename.DefaultRNG() + randomSlotName = fmt.Sprintf("%s", strings.ReplaceAll(codename.Generate(rng, 5), "-", "_")) + + err := service.RegisterInput( + "pg_stream", pgStreamConfigSpec, + func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) { + return newPgStreamInput(conf) + }) + if err != nil { + panic(err) + } +} + +type pgStreamInput struct { + dbConfig pgconn.Config + pglogicalStream *pglogicalstream.Stream + redisUri string + slotName string + schema string + tables []string + streamSnapshot bool + tls pglogicalstream.TlsVerify // none, require + snapshotMemSafetyFactor float64 + logger *service.Logger +} + +func (p *pgStreamInput) Connect(ctx context.Context) error { + pgStream, err := pglogicalstream.NewPgStream(pglogicalstream.Config{ + DbHost: p.dbConfig.Host, + DbPassword: p.dbConfig.Password, + DbUser: p.dbConfig.User, + DbPort: int(p.dbConfig.Port), + DbTables: p.tables, + DbName: p.dbConfig.Database, + DbSchema: p.schema, + ReplicationSlotName: fmt.Sprintf("rs_%s", p.slotName), + TlsVerify: p.tls, + StreamOldData: p.streamSnapshot, + SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, + SeparateChanges: true, + }) + if err != nil { + panic(err) + } + p.pglogicalStream = pgStream + return err +} + +func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, 
service.AckFunc, error) { + select { + case snapshotMessage := <-p.pglogicalStream.SnapshotMessageC(): + var ( + mb []byte + err error + ) + if mb, err = json.Marshal(snapshotMessage); err != nil { + return nil, nil, err + } + return service.NewMessage(mb), func(ctx context.Context, err error) error { + // Nacks are retried automatically when we use service.AutoRetryNacks + return nil + }, nil + case message := <-p.pglogicalStream.LrMessageC(): + var ( + mb []byte + err error + ) + if mb, err = json.Marshal(message); err != nil { + return nil, nil, err + } + return service.NewMessage(mb), func(ctx context.Context, err error) error { + // Nacks are retried automatically when we use service.AutoRetryNacks + //message.ServerHeartbeat. + + if message.Lsn != nil { + p.pglogicalStream.AckLSN(*message.Lsn) + } + return nil + }, nil + case <-ctx.Done(): + return nil, nil, p.pglogicalStream.Stop() + } +} + +func (p *pgStreamInput) Close(ctx context.Context) error { + if p.pglogicalStream != nil { + return p.pglogicalStream.Stop() + } + return nil +} diff --git a/internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go b/internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go new file mode 100644 index 0000000000..c03a94fc74 --- /dev/null +++ b/internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go @@ -0,0 +1,74 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pg_stream_schemaless + +import ( + "bytes" + "context" + + "encoding/json" + + "github.com/redpanda-data/benthos/v4/public/service" +) + +func init() { + // Config spec is empty for now as we don't have any dynamic fields. 
+ configSpec := service.NewConfigSpec() + + constructor := func(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { + return newPgSchematicProcessor(mgr.Logger(), mgr.Metrics()), nil + } + + err := service.RegisterProcessor("pg_stream_schemaless", configSpec, constructor) + if err != nil { + panic(err) + } +} + +type pgSchematicProcessor struct { +} + +func newPgSchematicProcessor(logger *service.Logger, metrics *service.Metrics) *pgSchematicProcessor { + // The logger and metrics components will already be labelled with the + // identifier of this component within a config. + return &pgSchematicProcessor{} +} + +func (r *pgSchematicProcessor) Process(ctx context.Context, m *service.Message) (service.MessageBatch, error) { + bytesContent, err := m.AsBytes() + if err != nil { + return nil, err + } + var message WalMessage + if err = json.NewDecoder(bytes.NewReader(bytesContent)).Decode(&message); err != nil { + return nil, err + } + + var messageAsSchema = map[string]interface{}{} + if len(message.Change) == 0 { + return nil, nil + } + + for _, change := range message.Change { + for i, k := range change.Columnnames { + messageAsSchema[k] = change.Columnvalues[i] + } + } + var newBytes []byte + if newBytes, err = json.Marshal(&messageAsSchema); err != nil { + return nil, err + } + + m.SetBytes(newBytes) + return []*service.Message{m}, nil +} + +func (r *pgSchematicProcessor) Close(ctx context.Context) error { + return nil +} diff --git a/internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go b/internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go new file mode 100644 index 0000000000..2477ce4538 --- /dev/null +++ b/internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go @@ -0,0 +1,25 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pg_stream_schemaless + +type WalMessage struct { + Change []struct { + Kind string `json:"kind"` + Schema string `json:"schema"` + Table string `json:"table"` + Columnnames []string `json:"columnnames"` + Columntypes []string `json:"columntypes"` + Columnvalues []interface{} `json:"columnvalues"` + Oldkeys struct { + Keynames []string `json:"keynames"` + Keytypes []string `json:"keytypes"` + Keyvalues []interface{} `json:"keyvalues"` + } `json:"oldkeys"` + } `json:"change"` +} diff --git a/internal/impl/postgresql/pglogicalstream/README.MD b/internal/impl/postgresql/pglogicalstream/README.MD new file mode 100644 index 0000000000..89fa6fc314 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/README.MD @@ -0,0 +1,91 @@ +This Go module builds upon [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) to provide an advanced +logical replication solution for PostgreSQL. It extends the capabilities of jackc/pglogrep for logical replication by +introducing several key features, making it easier to implement Change Data Capture (CDC) in your Go-based applications. + +## Features + +- **Checkpoints Storing:** Efficiently manage and store replication checkpoints, facilitating better tracking and + management of data changes. + +- **Snapshot Streaming:** Seamlessly capture and replicate snapshots of your PostgreSQL database, ensuring all data is + streamed through the pipeline. + +- **Table Filtering:** Tailor your CDC process by selectively filtering and replicating specific tables, optimizing + resource usage. 
+ +## Getting Started + +Follow these steps to get started with our PostgreSQL Logical Replication CDC Module for Go: + +### Configure your replication stream + +Create `config.yaml` file + +```yaml +db_host: database host +db_password: password12345 +db_user: postgres +db_port: 5432 +db_name: mocks +db_schema: public +db_tables: + - rides +replication_slot_name: morning_elephant +tls_verify: require +stream_old_data: true +``` + +### Basic usage example + +By default `pglogicalstream` will create replication slot and publication for the tables you provide in Yaml config +It immediately starts streaming updates and you can receive them in the `OnMessage` function + +```go +package main + +import ( + "fmt" + "github.com/usedatabrew/pglogicalstream" + "gopkg.in/yaml.v3" + "io/ioutil" + "log" +) + +func main() { + var config pglogicalstream.Config + yamlFile, err := ioutil.ReadFile("./example/simple/config.yaml") + if err != nil { + log.Printf("yamlFile.Get err #%v ", err) + } + + err = yaml.Unmarshal(yamlFile, &config) + if err != nil { + log.Fatalf("Unmarshal: %v", err) + } + + pgStream, err := pglogicalstream.NewPgStream(config, log.WithPrefix("pg-cdc")) + if err != nil { + panic(err) + } + + pgStream.OnMessage(func(message messages.Wal2JsonChanges) { + fmt.Println(message.Changes) + }) +} + +``` + +### Example with checkpointer + +In order to recover after the failure, etc you have to store LSN somewhere to continue streaming the data +You can implement `CheckPointer` interface and pass it's instance to `NewPgStreamCheckPointer` and your LSN +will be stored automatically + +```go +checkPointer, err := NewPgStreamCheckPointer("redis.com:port", "user", "password") +if err != nil { + log.Fatalf("Checkpointer error") +} +pgStream, err := pglogicalstream.NewPgStream(config, checkPointer) +``` + diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go new file mode 100644 index 0000000000..d9569427e0 --- 
/dev/null +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -0,0 +1,30 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +type TlsVerify string + +const TlsNoVerify TlsVerify = "none" +const TlsRequireVerify TlsVerify = "require" + +type Config struct { + DbHost string `yaml:"db_host"` + DbPassword string `yaml:"db_password"` + DbUser string `yaml:"db_user"` + DbPort int `yaml:"db_port"` + DbName string `yaml:"db_name"` + DbSchema string `yaml:"db_schema"` + DbTables []string `yaml:"db_tables"` + ReplicationSlotName string `yaml:"replication_slot_name"` + TlsVerify TlsVerify `yaml:"tls_verify"` + StreamOldData bool `yaml:"stream_old_data"` + SeparateChanges bool `yaml:"separate_changes"` + SnapshotMemorySafetyFactor float64 `yaml:"snapshot_memory_safety_factor"` + BatchSize int `yaml:"batch_size"` +} diff --git a/internal/impl/postgresql/pglogicalstream/docker-compose.yaml b/internal/impl/postgresql/pglogicalstream/docker-compose.yaml new file mode 100644 index 0000000000..3ff4981ee8 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/docker-compose.yaml @@ -0,0 +1,11 @@ +services: + postgres: + image: postgres:${POSTGRES_VERSION:-15} + restart: always + command: ["-c", "wal_level=logical", "-c", "max_wal_senders=10", "-c", "max_replication_slots=10"] + environment: + POSTGRES_USER: ${POSTGRES_USER:-pglogrepl} + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-secret} + POSTGRES_DB: ${POSTGRES_DB:-pglogrepl} + POSTGRES_HOST_AUTH_METHOD: trust + network_mode: "host" \ No newline at end of file diff --git a/internal/impl/postgresql/pglogicalstream/example/simple/main.go b/internal/impl/postgresql/pglogicalstream/example/simple/main.go new file mode 
100644 index 0000000000..b0c4ad01a9 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/example/simple/main.go @@ -0,0 +1,44 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package main + +import ( + "fmt" + "io/ioutil" + "log" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" + "gopkg.in/yaml.v3" +) + +func main() { + var config pglogicalstream.Config + yamlFile, err := ioutil.ReadFile("./config.yaml") + if err != nil { + log.Printf("yamlFile.Get err #%v ", err) + } + + err = yaml.Unmarshal(yamlFile, &config) + if err != nil { + log.Fatalf("Unmarshal: %v", err) + } + + pgStream, err := pglogicalstream.NewPgStream(config) + if err != nil { + panic(err) + } + + pgStream.OnMessage(func(message pglogicalstream.Wal2JsonChanges) { + fmt.Println(message.Changes) + if message.Lsn != nil { + // Snapshots dont have LSN + pgStream.AckLSN(*message.Lsn) + } + }) +} diff --git a/internal/impl/postgresql/pglogicalstream/example/ws/main.go b/internal/impl/postgresql/pglogicalstream/example/ws/main.go new file mode 100644 index 0000000000..4ce2d9d067 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/example/ws/main.go @@ -0,0 +1,54 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package main + +import ( + "io/ioutil" + "log" + + "github.com/gorilla/websocket" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" + "gopkg.in/yaml.v3" +) + +func main() { + var config pglogicalstream.Config + yamlFile, err := ioutil.ReadFile("./example/simple/config.yaml") + if err != nil { + log.Printf("yamlFile.Get err #%v ", err) + } + + err = yaml.Unmarshal(yamlFile, &config) + if err != nil { + log.Fatalf("Unmarshal: %v", err) + } + + pgStream, err := pglogicalstream.NewPgStream(config) + if err != nil { + panic(err) + } + + wsClient, _, err := websocket.DefaultDialer.Dial("ws://localhost:10000/ws", nil) + if err != nil { + panic(err) + } + defer wsClient.Close() + + pgStream.OnMessage(func(message pglogicalstream.Wal2JsonChanges) { + marshaledChanges, err := message.Changes[0].Row.MarshalJSON() + if err != nil { + panic(err) + } + + err = wsClient.WriteMessage(websocket.TextMessage, marshaledChanges) + if err != nil { + log.Fatalf("write: %v", err) + } + }) +} diff --git a/internal/impl/postgresql/pglogicalstream/filter.go b/internal/impl/postgresql/pglogicalstream/filter.go new file mode 100644 index 0000000000..6841b71554 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/filter.go @@ -0,0 +1,73 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +type ChangeFilter struct { + tablesWhiteList map[string]bool + schemaWhiteList string +} + +type Filtered func(change Wal2JsonChanges) + +func NewChangeFilter(tableSchemas []string, schema string) ChangeFilter { + tablesMap := map[string]bool{} + for _, table := range tableSchemas { + tablesMap[table] = true + } + + return ChangeFilter{ + tablesWhiteList: tablesMap, + schemaWhiteList: schema, + } +} + +func (c ChangeFilter) FilterChange(lsn string, changes WallMessage, OnFiltered Filtered) { + if len(changes.Change) == 0 { + return + } + + for _, ch := range changes.Change { + var filteredChanges = Wal2JsonChanges{ + Lsn: &lsn, + Changes: []Wal2JsonChange{}, + } + if ch.Schema != c.schemaWhiteList { + continue + } + + var ( + tableExist bool + ) + + if _, tableExist = c.tablesWhiteList[ch.Table]; !tableExist { + continue + } + + if ch.Kind == "delete" { + ch.Columnvalues = make([]interface{}, len(ch.Oldkeys.Keyvalues)) + for i, changedValue := range ch.Oldkeys.Keyvalues { + if len(ch.Columnvalues) == 0 { + break + } + ch.Columnvalues[i] = changedValue + } + } + + filteredChanges.Changes = append(filteredChanges.Changes, Wal2JsonChange{ + Kind: ch.Kind, + Schema: ch.Schema, + Table: ch.Table, + ColumnNames: ch.Columnnames, + ColumnTypes: ch.Columntypes, + ColumnValues: ch.Columnvalues, + }) + + OnFiltered(filteredChanges) + } +} diff --git a/internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go b/internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go new file mode 100644 index 0000000000..c8b5912311 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go @@ -0,0 +1,44 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package helpers + +import "github.com/apache/arrow/go/v14/arrow" + +func MapPlainTypeToArrow(fieldType string) arrow.DataType { + switch fieldType { + case "Boolean": + return arrow.FixedWidthTypes.Boolean + case "Int16": + return arrow.PrimitiveTypes.Int16 + case "Int32": + return arrow.PrimitiveTypes.Int32 + case "Int64": + return arrow.PrimitiveTypes.Int64 + case "Uint64": + return arrow.PrimitiveTypes.Uint64 + case "Float64": + return arrow.PrimitiveTypes.Float64 + case "Float32": + return arrow.PrimitiveTypes.Float32 + case "UUID": + return arrow.BinaryTypes.String + case "bytea": + return arrow.BinaryTypes.Binary + case "JSON": + return arrow.BinaryTypes.String + case "Inet": + return arrow.BinaryTypes.String + case "MAC": + return arrow.BinaryTypes.String + case "Date32": + return arrow.FixedWidthTypes.Date32 + default: + return arrow.BinaryTypes.String + } +} diff --git a/internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go b/internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go new file mode 100644 index 0000000000..c586ffecc6 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go @@ -0,0 +1,20 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package helpers + +import "runtime" + +func GetAvailableMemory() uint64 { + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + // You can use memStats.Sys or another appropriate memory metric. + // Consider leaving some memory unused for other processes. + availableMemory := memStats.Sys - memStats.HeapInuse + return availableMemory +} diff --git a/internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go b/internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go new file mode 100644 index 0000000000..d1611c3fb2 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go @@ -0,0 +1,16 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package schemas + +import "github.com/apache/arrow/go/v14/arrow" + +type DataTableSchema struct { + TableName string + Schema *arrow.Schema +} diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go new file mode 100644 index 0000000000..b15e9be282 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -0,0 +1,511 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "bytes" + "context" + "crypto/tls" + "database/sql" + "encoding/json" + "errors" + "fmt" + "log" + "os" + "strings" + "sync" + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/internal/helpers" +) + +var pluginArguments = []string{"\"pretty-print\" 'true'"} + +type Stream struct { + pgConn *pgconn.PgConn + // extra copy of db config is required to establish a new db connection + // which is required to take snapshot data + dbConfig pgconn.Config + streamCtx context.Context + streamCancel context.CancelFunc + + standbyCtxCancel context.CancelFunc + clientXLogPos pglogrepl.LSN + standbyMessageTimeout time.Duration + nextStandbyMessageDeadline time.Time + messages chan Wal2JsonChanges + snapshotMessages chan Wal2JsonChanges + snapshotName string + changeFilter ChangeFilter + lsnrestart pglogrepl.LSN + slotName string + schema string + tableNames []string + separateChanges bool + snapshotBatchSize int + snapshotMemorySafetyFactor float64 + logger *log.Logger + + m sync.Mutex + stopped bool +} + +func NewPgStream(config Config) (*Stream, error) { + var ( + cfg *pgconn.Config + err error + ) + + sslVerifyFull := "" + if config.TlsVerify == TlsRequireVerify { + sslVerifyFull = "&sslmode=verify-full" + } + + if cfg, err = pgconn.ParseConfig(fmt.Sprintf("postgres://%s:%s@%s:%d/%s?replication=database%s", + config.DbUser, + config.DbPassword, + config.DbHost, + config.DbPort, + config.DbName, + sslVerifyFull, + )); err != nil { + return nil, err + } + + if config.TlsVerify == TlsRequireVerify { + cfg.TLSConfig = &tls.Config{ + InsecureSkipVerify: true, + ServerName: config.DbHost, + } + } else { + cfg.TLSConfig = nil + } + + dbConn, err := 
pgconn.ConnectConfig(context.Background(), cfg) + if err != nil { + return nil, err + } + + var tableNames []string + for _, table := range config.DbTables { + tableNames = append(tableNames, table) + } + + stream := &Stream{ + pgConn: dbConn, + dbConfig: *cfg, + messages: make(chan Wal2JsonChanges), + snapshotMessages: make(chan Wal2JsonChanges, 100), + slotName: config.ReplicationSlotName, + schema: config.DbSchema, + snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, + separateChanges: config.SeparateChanges, + snapshotBatchSize: config.BatchSize, + tableNames: tableNames, + changeFilter: NewChangeFilter(tableNames, config.DbSchema), + logger: log.WithPrefix("[pg-stream]"), + m: sync.Mutex{}, + stopped: false, + } + + result := stream.pgConn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS pglog_stream_%s;", config.ReplicationSlotName)) + _, err = result.ReadAll() + if err != nil { + stream.logger.Errorf("drop publication if exists error %s", err.Error()) + } + + for i, table := range tableNames { + tableNames[i] = fmt.Sprintf("%s.%s", config.DbSchema, table) + } + + tablesSchemaFilter := fmt.Sprintf("FOR TABLE %s", strings.Join(tableNames, ",")) + stream.logger.Infof("Create publication for table schemas with query %s", fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) + result = stream.pgConn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) + _, err = result.ReadAll() + if err != nil { + stream.logger.Fatalf("create publication error %s", err.Error()) + } + stream.logger.Info("Created Postgresql publication", "publication_name", config.ReplicationSlotName) + + sysident, err := pglogrepl.IdentifySystem(context.Background(), stream.pgConn) + if err != nil { + stream.logger.Fatalf("Failed to identify the system %s", err.Error()) + } + + stream.logger.Info("System identification result", "SystemID:", 
sysident.SystemID, "Timeline:", sysident.Timeline, "XLogPos:", sysident.XLogPos, "DBName:", sysident.DBName) + + var freshlyCreatedSlot = false + var confirmedLSNFromDB string + // check is replication slot exist to get last restart SLN + connExecResult := stream.pgConn.Exec(context.TODO(), fmt.Sprintf("SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = '%s'", config.ReplicationSlotName)) + if slotCheckResults, err := connExecResult.ReadAll(); err != nil { + stream.logger.Fatal(err) + } else { + if len(slotCheckResults) == 0 || len(slotCheckResults[0].Rows) == 0 { + // here we create a new replication slot because there is no slot found + var createSlotResult CreateReplicationSlotResult + createSlotResult, err = CreateReplicationSlot(context.Background(), stream.pgConn, stream.slotName, "wal2json", + CreateReplicationSlotOptions{Temporary: false, + SnapshotAction: "export", + }) + if err != nil { + stream.logger.Fatalf("Failed to create replication slot for the database: %s", err.Error()) + } + stream.snapshotName = createSlotResult.SnapshotName + freshlyCreatedSlot = true + } else { + slotCheckRow := slotCheckResults[0].Rows[0] + confirmedLSNFromDB = string(slotCheckRow[0]) + stream.logger.Info("Replication slot restart LSN extracted from DB", "LSN", confirmedLSNFromDB) + } + } + + var lsnrestart pglogrepl.LSN + if freshlyCreatedSlot { + lsnrestart = sysident.XLogPos + } else { + lsnrestart, _ = pglogrepl.ParseLSN(confirmedLSNFromDB) + } + + stream.lsnrestart = lsnrestart + + if freshlyCreatedSlot { + stream.clientXLogPos = sysident.XLogPos + } else { + stream.clientXLogPos = lsnrestart + } + + stream.standbyMessageTimeout = time.Second * 10 + stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) + stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) + + if !freshlyCreatedSlot || config.StreamOldData == false { + stream.startLr() + go stream.streamMessagesAsync() + } else { + // New 
messages will be streamed after the snapshot has been processed. + go stream.processSnapshot() + } + + return stream, err +} + +func (s *Stream) startLr() { + var err error + err = pglogrepl.StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}) + if err != nil { + s.logger.Fatalf("Starting replication slot failed: %s", err.Error()) + } + s.logger.Info("Started logical replication on slot", "slot-name", s.slotName) +} + +func (s *Stream) AckLSN(lsn string) { + var err error + s.clientXLogPos, err = pglogrepl.ParseLSN(lsn) + if err != nil { + s.logger.Fatalf("Failed to parse LSN for Acknowledge %s", err.Error()) + } + + err = pglogrepl.SendStandbyStatusUpdate(context.Background(), s.pgConn, pglogrepl.StandbyStatusUpdate{ + WALApplyPosition: s.clientXLogPos, + WALWritePosition: s.clientXLogPos, + ReplyRequested: true, + }) + + if err != nil { + s.logger.Fatalf("SendStandbyStatusUpdate failed: %s", err.Error()) + } + s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) + s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) +} + +func (s *Stream) streamMessagesAsync() { + for { + select { + case <-s.streamCtx.Done(): + s.logger.Warn("Stream was cancelled...exiting...") + return + default: + if time.Now().After(s.nextStandbyMessageDeadline) { + var err error + err = pglogrepl.SendStandbyStatusUpdate(context.Background(), s.pgConn, pglogrepl.StandbyStatusUpdate{ + WALWritePosition: s.clientXLogPos, + }) + + if err != nil { + s.logger.Fatalf("SendStandbyStatusUpdate failed: %s", err.Error()) + } + s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) + s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + } + + ctx, cancel := context.WithDeadline(context.Background(), s.nextStandbyMessageDeadline) + rawMsg, err := s.pgConn.ReceiveMessage(ctx) + s.standbyCtxCancel = cancel + + if err != nil && 
(errors.Is(err, context.Canceled) || s.stopped) { + s.logger.Warn("Service was interrpupted....stop reading from replication slot") + return + } + + if err != nil { + if pgconn.Timeout(err) { + continue + } + + s.logger.Fatalf("Failed to receive messages from PostgreSQL %s", err.Error()) + } + + if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { + s.logger.Fatalf("Received broken Postgres WAL. Error: %+v", errMsg) + } + + msg, ok := rawMsg.(*pgproto3.CopyData) + if !ok { + s.logger.Warnf("Received unexpected message: %T\n", rawMsg) + continue + } + + switch msg.Data[0] { + case pglogrepl.PrimaryKeepaliveMessageByteID: + pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) + if err != nil { + s.logger.Fatalf("ParsePrimaryKeepaliveMessage failed: %s", err.Error()) + } + + if pkm.ReplyRequested { + s.nextStandbyMessageDeadline = time.Time{} + } + + case pglogrepl.XLogDataByteID: + xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) + if err != nil { + s.logger.Fatalf("ParseXLogData failed: %s", err.Error()) + } + clientXLogPos := xld.WALStart + pglogrepl.LSN(len(xld.WALData)) + var changes WallMessage + if err := json.NewDecoder(bytes.NewReader(xld.WALData)).Decode(&changes); err != nil { + panic(fmt.Errorf("cant parse change from database to filter it %v", err)) + } + + if len(changes.Change) == 0 { + s.AckLSN(clientXLogPos.String()) + } else { + s.changeFilter.FilterChange(clientXLogPos.String(), changes, func(change Wal2JsonChanges) { + s.messages <- change + }) + } + } + } + } +} +func (s *Stream) processSnapshot() { + snapshotter, err := NewSnapshotter(s.dbConfig, s.snapshotName) + if err != nil { + s.logger.Errorf("Failed to create database snapshot: %v", err.Error()) + s.cleanUpOnFailure() + os.Exit(1) + } + if err = snapshotter.Prepare(); err != nil { + s.logger.Errorf("Failed to prepare database snapshot: %v", err.Error()) + s.cleanUpOnFailure() + os.Exit(1) + } + defer func() { + snapshotter.ReleaseSnapshot() + snapshotter.CloseConn() + }() + 
+ for _, table := range s.tableNames { + s.logger.Info("Processing snapshot for table", "table", table) + + var ( + avgRowSizeBytes sql.NullInt64 + offset = int(0) + ) + avgRowSizeBytes = snapshotter.FindAvgRowSize(table) + + batchSize := snapshotter.CalculateBatchSize(helpers.GetAvailableMemory(), uint64(avgRowSizeBytes.Int64)) + s.logger.Info("Querying snapshot", "batch_side", batchSize, "available_memory", helpers.GetAvailableMemory(), "avg_row_size", avgRowSizeBytes.Int64) + + tablePk, err := s.getPrimaryKeyColumn(table) + if err != nil { + panic(err) + } + + for { + var snapshotRows *sql.Rows + if snapshotRows, err = snapshotter.QuerySnapshotData(table, tablePk, batchSize, offset); err != nil { + log.Fatalf("Can't query snapshot data %v", err) + } + + columnTypes, err := snapshotRows.ColumnTypes() + var columnTypesString = make([]string, len(columnTypes)) + columnNames, err := snapshotRows.Columns() + for i, _ := range columnNames { + columnTypesString[i] = columnTypes[i].DatabaseTypeName() + } + + if err != nil { + panic(err) + } + + count := len(columnTypes) + var rowsCount = 0 + for snapshotRows.Next() { + rowsCount += 1 + scanArgs := make([]interface{}, count) + for i, v := range columnTypes { + switch v.DatabaseTypeName() { + case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": + scanArgs[i] = new(sql.NullString) + break + case "BOOL": + scanArgs[i] = new(sql.NullBool) + break + case "INT4": + scanArgs[i] = new(sql.NullInt64) + break + default: + scanArgs[i] = new(sql.NullString) + } + } + + err := snapshotRows.Scan(scanArgs...) 
+ + if err != nil { + panic(err) + } + + var columnValues = make([]interface{}, len(columnTypes)) + for i, _ := range columnTypes { + if z, ok := (scanArgs[i]).(*sql.NullBool); ok { + columnValues[i] = z.Bool + continue + } + if z, ok := (scanArgs[i]).(*sql.NullString); ok { + columnValues[i] = z.String + continue + } + if z, ok := (scanArgs[i]).(*sql.NullInt64); ok { + columnValues[i] = z.Int64 + continue + } + if z, ok := (scanArgs[i]).(*sql.NullFloat64); ok { + columnValues[i] = z.Float64 + continue + } + if z, ok := (scanArgs[i]).(*sql.NullInt32); ok { + columnValues[i] = z.Int32 + continue + } + + columnValues[i] = scanArgs[i] + } + + var snapshotChanges []Wal2JsonChange + snapshotChanges = append(snapshotChanges, Wal2JsonChange{ + Kind: "insert", + Schema: s.schema, + Table: table, + ColumnNames: columnNames, + ColumnValues: columnValues, + }) + var lsn *string + snapshotChangePacket := Wal2JsonChanges{ + Lsn: lsn, + Changes: snapshotChanges, + } + + s.snapshotMessages <- snapshotChangePacket + } + + offset += batchSize + + if batchSize != rowsCount { + break + } + } + + } + + s.startLr() + go s.streamMessagesAsync() +} + +func (s *Stream) OnMessage(callback OnMessage) { + for { + select { + case snapshotMessage := <-s.snapshotMessages: + callback(snapshotMessage) + case message := <-s.messages: + callback(message) + case <-s.streamCtx.Done(): + return + } + } +} + +func (s *Stream) SnapshotMessageC() chan Wal2JsonChanges { + return s.snapshotMessages +} + +func (s *Stream) LrMessageC() chan Wal2JsonChanges { + return s.messages +} + +// cleanUpOnFailure drops replication slot and publication if database snapshotting was failed for any reason +func (s *Stream) cleanUpOnFailure() { + s.logger.Warn("Cleaning up resources on accident.", "replication-slot", s.slotName) + err := DropReplicationSlot(context.Background(), s.pgConn, s.slotName, DropReplicationSlotOptions{Wait: true}) + if err != nil { + s.logger.Errorf("Failed to drop replication slot: %s", 
err.Error()) + } + s.pgConn.Close(context.TODO()) +} + +func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { + q := fmt.Sprintf(` + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '%s'::regclass + AND i.indisprimary; + `, tableName) + + reader := s.pgConn.Exec(context.Background(), q) + data, err := reader.ReadAll() + if err != nil { + return "", err + } + + pkResultRow := data[0].Rows[0] + pkColName := string(pkResultRow[0]) + return pkColName, nil +} + +func (s *Stream) Stop() error { + s.m.Lock() + s.stopped = true + s.m.Unlock() + + if s.pgConn != nil { + if s.streamCtx != nil { + s.streamCancel() + s.standbyCtxCancel() + } + return s.pgConn.Close(context.Background()) + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/message.go b/internal/impl/postgresql/pglogicalstream/message.go new file mode 100644 index 0000000000..914d0e9c96 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/message.go @@ -0,0 +1,728 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "strconv" + "time" +) + +var ( + errMsgNotSupported = errors.New("replication message not supported") +) + +// MessageType indicates the type of logical replication message. 
+type MessageType uint8 + +func (t MessageType) String() string { + switch t { + case MessageTypeBegin: + return "Begin" + case MessageTypeCommit: + return "Commit" + case MessageTypeOrigin: + return "Origin" + case MessageTypeRelation: + return "Relation" + case MessageTypeType: + return "Type" + case MessageTypeInsert: + return "Insert" + case MessageTypeUpdate: + return "Update" + case MessageTypeDelete: + return "Delete" + case MessageTypeTruncate: + return "Truncate" + case MessageTypeMessage: + return "Message" + case MessageTypeStreamStart: + return "StreamStart" + case MessageTypeStreamStop: + return "StreamStop" + case MessageTypeStreamCommit: + return "StreamCommit" + case MessageTypeStreamAbort: + return "StreamAbort" + default: + return "Unknown" + } +} + +// List of types of logical replication messages. +const ( + MessageTypeBegin MessageType = 'B' + MessageTypeMessage MessageType = 'M' + MessageTypeCommit MessageType = 'C' + MessageTypeOrigin MessageType = 'O' + MessageTypeRelation MessageType = 'R' + MessageTypeType MessageType = 'Y' + MessageTypeInsert MessageType = 'I' + MessageTypeUpdate MessageType = 'U' + MessageTypeDelete MessageType = 'D' + MessageTypeTruncate MessageType = 'T' + MessageTypeStreamStart MessageType = 'S' + MessageTypeStreamStop MessageType = 'E' + MessageTypeStreamCommit MessageType = 'c' + MessageTypeStreamAbort MessageType = 'A' +) + +// Message is a message received from server. +type Message interface { + Type() MessageType +} + +// MessageDecoder decodes message into struct. +type MessageDecoder interface { + Decode([]byte) error +} + +type baseMessage struct { + msgType MessageType +} + +// Type returns message type. +func (m *baseMessage) Type() MessageType { + return m.msgType +} + +// SetType sets message type. +// This method is added to help writing test code in application. +// The message type is still defined by message data. 
+func (m *baseMessage) SetType(t MessageType) { + m.msgType = t +} + +// Decode parse src into message struct. The src must contain the complete message starts after +// the first message type byte. +func (m *baseMessage) Decode(_ []byte) error { + return fmt.Errorf("message decode not implemented") +} + +func (m *baseMessage) lengthError(name string, expectedLen, actualLen int) error { + return fmt.Errorf("%s must have %d bytes, got %d bytes", name, expectedLen, actualLen) +} + +func (m *baseMessage) decodeStringError(name, field string) error { + return fmt.Errorf("%s.%s decode string error", name, field) +} + +func (m *baseMessage) decodeTupleDataError(name, field string, e error) error { + return fmt.Errorf("%s.%s decode tuple error: %s", name, field, e.Error()) +} + +func (m *baseMessage) invalidTupleTypeError(name, field string, e string, a byte) error { + return fmt.Errorf("%s.%s invalid tuple type value, expect %s, actual %c", name, field, e, a) +} + +// decodeString decode a string from src and returns the length of bytes being parsed. +// +// String type definition: https://www.postgresql.org/docs/current/protocol-message-types.html +// String(s) +// +// A null-terminated string (C-style string). There is no specific length limitation on strings. +// If s is specified it is the exact value that will appear, otherwise the value is variable. +// Eg. String, String("user"). +// +// If there is no null byte in src, return -1. +func (m *baseMessage) decodeString(src []byte) (string, int) { + end := bytes.IndexByte(src, byte(0)) + if end == -1 { + return "", -1 + } + // Trim the last null byte before converting it to a Golang string, then we can + // compare the result string with a Golang string literal. 
+ return string(src[:end]), end + 1 +} + +func (m *baseMessage) decodeLSN(src []byte) (LSN, int) { + return LSN(binary.BigEndian.Uint64(src)), 8 +} + +func (m *baseMessage) decodeTime(src []byte) (time.Time, int) { + return pgTimeToTime(int64(binary.BigEndian.Uint64(src))), 8 +} + +func (m *baseMessage) decodeUint16(src []byte) (uint16, int) { + return binary.BigEndian.Uint16(src), 2 +} + +func (m *baseMessage) decodeUint32(src []byte) (uint32, int) { + return binary.BigEndian.Uint32(src), 4 +} + +func (m *baseMessage) decodeInt32(src []byte) (int32, int) { + asUint32, size := m.decodeUint32(src) + return int32(asUint32), size +} + +// BeginMessage is a begin message. +type BeginMessage struct { + baseMessage + //FinalLSN is the final LSN of the transaction. + FinalLSN LSN + // CommitTime is the commit timestamp of the transaction. + CommitTime time.Time + // Xid of the transaction. + Xid uint32 +} + +// Decode decodes the message from src. +func (m *BeginMessage) Decode(src []byte) error { + if len(src) < 20 { + return m.lengthError("BeginMessage", 20, len(src)) + } + var low, used int + m.FinalLSN, used = m.decodeLSN(src) + low += used + m.CommitTime, used = m.decodeTime(src[low:]) + low += used + m.Xid = binary.BigEndian.Uint32(src[low:]) + + m.SetType(MessageTypeBegin) + + return nil +} + +// CommitMessage is a commit message. +type CommitMessage struct { + baseMessage + // Flags currently unused (must be 0). + Flags uint8 + // CommitLSN is the LSN of the commit. + CommitLSN LSN + // TransactionEndLSN is the end LSN of the transaction. + TransactionEndLSN LSN + // CommitTime is the commit timestamp of the transaction + CommitTime time.Time +} + +// Decode decodes the message from src. 
+func (m *CommitMessage) Decode(src []byte) error { + if len(src) < 25 { + return m.lengthError("CommitMessage", 25, len(src)) + } + var low, used int + m.Flags = src[0] + low += 1 + m.CommitLSN, used = m.decodeLSN(src[low:]) + low += used + m.TransactionEndLSN, used = m.decodeLSN(src[low:]) + low += used + m.CommitTime, _ = m.decodeTime(src[low:]) + + m.SetType(MessageTypeCommit) + + return nil +} + +// OriginMessage is an origin message. +type OriginMessage struct { + baseMessage + // CommitLSN is the LSN of the commit on the origin server. + CommitLSN LSN + Name string +} + +// Decode decodes to message from src. +func (m *OriginMessage) Decode(src []byte) error { + if len(src) < 8 { + return m.lengthError("OriginMessage", 9, len(src)) + } + + var low, used int + m.CommitLSN, used = m.decodeLSN(src) + low += used + m.Name, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("OriginMessage", "Name") + } + + m.SetType(MessageTypeOrigin) + + return nil +} + +// RelationMessageColumn is one column in a RelationMessage. +type RelationMessageColumn struct { + // Flags for the column. Currently, it can be either 0 for no flags or 1 which marks the column as part of the key. + Flags uint8 + + Name string + + // DataType is the ID of the column's data type. + DataType uint32 + + // TypeModifier is type modifier of the column (atttypmod). + TypeModifier int32 +} + +// RelationMessage is a relation message. +type RelationMessage struct { + baseMessage + RelationID uint32 + Namespace string + RelationName string + ReplicaIdentity uint8 + ColumnNum uint16 + Columns []*RelationMessageColumn +} + +// Decode decodes to message from src. 
+func (m *RelationMessage) Decode(src []byte) error { + if len(src) < 7 { + return m.lengthError("RelationMessage", 7, len(src)) + } + + var low, used int + m.RelationID, used = m.decodeUint32(src) + low += used + + m.Namespace, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("RelationMessage", "Namespace") + } + low += used + + m.RelationName, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("RelationMessage", "RelationName") + } + low += used + + m.ReplicaIdentity = src[low] + low++ + + m.ColumnNum, used = m.decodeUint16(src[low:]) + low += used + + for i := 0; i < int(m.ColumnNum); i++ { + column := new(RelationMessageColumn) + column.Flags = src[low] + low++ + column.Name, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("RelationMessage", fmt.Sprintf("Column[%d].Name", i)) + } + low += used + + column.DataType, used = m.decodeUint32(src[low:]) + low += used + + column.TypeModifier, used = m.decodeInt32(src[low:]) + low += used + + m.Columns = append(m.Columns, column) + } + + m.SetType(MessageTypeRelation) + + return nil +} + +// TypeMessage is a type message. +type TypeMessage struct { + baseMessage + DataType uint32 + Namespace string + Name string +} + +// Decode decodes to message from src. +func (m *TypeMessage) Decode(src []byte) error { + if len(src) < 6 { + return m.lengthError("TypeMessage", 6, len(src)) + } + + var low, used int + m.DataType, used = m.decodeUint32(src) + low += used + + m.Namespace, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("TypeMessage", "Namespace") + } + low += used + + m.Name, used = m.decodeString(src[low:]) + if used < 0 { + return m.decodeStringError("TypeMessage", "Name") + } + + m.SetType(MessageTypeType) + + return nil +} + +// List of types of data in a tuple. 
const (
	// TupleDataTypeNull marks the column value as NULL.
	TupleDataTypeNull = uint8('n')
	// TupleDataTypeToast marks an unchanged TOASTed value (the actual value is not sent).
	TupleDataTypeToast = uint8('u')
	// TupleDataTypeText marks a text-formatted value.
	TupleDataTypeText = uint8('t')
	// TupleDataTypeBinary marks a binary value.
	TupleDataTypeBinary = uint8('b')
)

// TupleDataColumn is a column in a TupleData.
type TupleDataColumn struct {
	// DataType indicates how the data is stored.
	// Byte1('n') Identifies the data as NULL value.
	// Or
	// Byte1('u') Identifies unchanged TOASTed value (the actual value is not sent).
	// Or
	// Byte1('t') Identifies the data as text formatted value.
	// Or
	// Byte1('b') Identifies the data as binary value.
	DataType uint8
	// Length is the byte length of Data; only meaningful for text/binary columns.
	Length uint32
	// Data is the value of the column, in text format. (A future release might support additional formats.) n is the above length.
	Data []byte
}

// Int64 parse column data as an int64 integer.
// It only accepts text-formatted columns; any other DataType is an error.
func (c *TupleDataColumn) Int64() (int64, error) {
	if c.DataType != TupleDataTypeText {
		return 0, fmt.Errorf("invalid column's data type, expect %c, actual %c",
			TupleDataTypeText, c.DataType)
	}

	return strconv.ParseInt(string(c.Data), 10, 64)
}

// TupleData contains row change information.
type TupleData struct {
	baseMessage
	// ColumnNum is the number of columns that follow.
	ColumnNum uint16
	Columns   []*TupleDataColumn
}

// Decode decodes to message from src.
// Decode decodes to message from src.
// It returns the number of bytes consumed so that callers embedding a
// TupleData inside a larger message can continue decoding after it.
func (m *TupleData) Decode(src []byte) (int, error) {
	var low, used int

	m.ColumnNum, used = m.decodeUint16(src)
	low += used

	// NOTE(review): the loop indexes src without bounds checks; a truncated
	// payload would panic rather than return an error — presumably input is
	// server-produced and well-formed. Confirm before hardening.
	for i := 0; i < int(m.ColumnNum); i++ {
		column := new(TupleDataColumn)
		column.DataType = src[low]
		low += 1

		switch column.DataType {
		case TupleDataTypeText, TupleDataTypeBinary:
			// Text/binary columns carry a 4-byte length followed by the raw bytes.
			column.Length, used = m.decodeUint32(src[low:])
			low += used

			column.Data = make([]byte, int(column.Length))
			for j := 0; j < int(column.Length); j++ {
				column.Data[j] = src[low+j]
			}
			low += int(column.Length)
		case TupleDataTypeNull, TupleDataTypeToast:
			// NULL / unchanged-TOAST columns carry no payload.
		}

		m.Columns = append(m.Columns, column)
	}

	return low, nil
}

// InsertMessage is a insert message
type InsertMessage struct {
	baseMessage
	// RelationID is the ID of the relation corresponding to the ID in the relation message.
	RelationID uint32
	// Tuple is the newly inserted row.
	Tuple *TupleData
}

// Decode decodes to message from src.
// src is the payload following the 'I' type byte: relation OID, the literal
// byte 'N', then the new tuple.
func (m *InsertMessage) Decode(src []byte) error {
	if len(src) < 8 {
		return m.lengthError("InsertMessage", 8, len(src))
	}

	var low, used int

	m.RelationID, used = m.decodeUint32(src)
	low += used

	tupleType := src[low]
	low += 1
	if tupleType != 'N' {
		return m.invalidTupleTypeError("InsertMessage", "TupleType", "N", tupleType)
	}

	m.Tuple = new(TupleData)
	_, err := m.Tuple.Decode(src[low:])
	if err != nil {
		return m.decodeTupleDataError("InsertMessage", "TupleData", err)
	}

	m.SetType(MessageTypeInsert)

	return nil
}

// List of types of UpdateMessage tuples.
const (
	UpdateMessageTupleTypeNone = uint8(0)
	UpdateMessageTupleTypeKey  = uint8('K')
	UpdateMessageTupleTypeOld  = uint8('O')
	UpdateMessageTupleTypeNew  = uint8('N')
)

// UpdateMessage is a update message.
type UpdateMessage struct {
	baseMessage
	RelationID uint32

	// OldTupleType
	// Byte1('K'):
	// Identifies the following TupleData submessage as a key.
	// This field is optional and is only present if the update changed data
	// in any of the column(s) that are part of the REPLICA IDENTITY index.
	//
	// Byte1('O'):
	// Identifies the following TupleData submessage as an old tuple.
	// This field is optional and is only present if table in which the update happened
	// has REPLICA IDENTITY set to FULL.
	//
	// The Update message may contain either a 'K' message part or an 'O' message part
	// or neither of them, but never both of them.
	OldTupleType uint8
	OldTuple     *TupleData

	// NewTuple is the contents of a new tuple.
	// Byte1('N'): Identifies the following TupleData message as a new tuple.
	NewTuple *TupleData
}

// Decode decodes to message from src.
// src is the payload following the 'U' type byte: relation OID, an optional
// 'K'/'O' old tuple, then the mandatory 'N' new tuple.
func (m *UpdateMessage) Decode(src []byte) (err error) {
	if len(src) < 6 {
		return m.lengthError("UpdateMessage", 6, len(src))
	}

	var low, used int

	m.RelationID, used = m.decodeUint32(src)
	low += used

	tupleType := src[low]
	low++

	switch tupleType {
	case UpdateMessageTupleTypeKey, UpdateMessageTupleTypeOld:
		m.OldTupleType = tupleType
		m.OldTuple = new(TupleData)
		used, err = m.OldTuple.Decode(src[low:])
		if err != nil {
			return m.decodeTupleDataError("UpdateMessage", "OldTuple", err)
		}
		low += used
		// Skip the tuple-type byte introducing the new tuple so the
		// fallthrough case decodes the TupleData that follows it.
		low++
		fallthrough
	case UpdateMessageTupleTypeNew:
		m.NewTuple = new(TupleData)
		_, err = m.NewTuple.Decode(src[low:])
		if err != nil {
			return m.decodeTupleDataError("UpdateMessage", "NewTuple", err)
		}
	default:
		return m.invalidTupleTypeError("UpdateMessage", "Tuple", "K/O/N", tupleType)
	}

	m.SetType(MessageTypeUpdate)

	return nil
}

// List of types of DeleteMessage tuples.
const (
	DeleteMessageTupleTypeKey = uint8('K')
	DeleteMessageTupleTypeOld = uint8('O')
)

// DeleteMessage is a delete message.
type DeleteMessage struct {
	baseMessage
	RelationID uint32
	// OldTupleType
	// Byte1('K'):
	// Identifies the following TupleData submessage as a key.
	// This field is present if the table in which the delete has happened uses an index
	// as REPLICA IDENTITY.
	//
	// Byte1('O')
	// Identifies the following TupleData message as an old tuple.
	// This field is present if the table in which the delete has happened has
	// REPLICA IDENTITY set to FULL.
	//
	// The Delete message may contain either a 'K' message part or an 'O' message part,
	// but never both of them.
	OldTupleType uint8
	OldTuple     *TupleData
}

// Decode decodes a message from src.
// src is the payload following the 'D' type byte: relation OID, a mandatory
// 'K' or 'O' marker, then the old tuple.
func (m *DeleteMessage) Decode(src []byte) (err error) {
	if len(src) < 4 {
		return m.lengthError("DeleteMessage", 4, len(src))
	}

	var low, used int

	m.RelationID, used = m.decodeUint32(src)
	low += used

	m.OldTupleType = src[low]
	low++

	switch m.OldTupleType {
	case DeleteMessageTupleTypeKey, DeleteMessageTupleTypeOld:
		m.OldTuple = new(TupleData)
		_, err = m.OldTuple.Decode(src[low:])
		if err != nil {
			return m.decodeTupleDataError("DeleteMessage", "OldTuple", err)
		}
	default:
		return m.invalidTupleTypeError("DeleteMessage", "OldTupleType", "K/O", m.OldTupleType)
	}

	m.SetType(MessageTypeDelete)

	return nil
}

// List of truncate options.
const (
	TruncateOptionCascade = uint8(1) << iota
	TruncateOptionRestartIdentity
)

// TruncateMessage is a truncate message.
type TruncateMessage struct {
	baseMessage
	// RelationNum is the number of relations being truncated.
	RelationNum uint32
	// Option is a bitmask of TruncateOptionCascade / TruncateOptionRestartIdentity.
	Option uint8
	// RelationIDs are the OIDs of the truncated relations.
	RelationIDs []uint32
}

// Decode decodes to message from src.
+func (m *TruncateMessage) Decode(src []byte) (err error) { + if len(src) < 9 { + return m.lengthError("TruncateMessage", 9, len(src)) + } + + var low, used int + m.RelationNum, used = m.decodeUint32(src) + low += used + + m.Option = src[low] + low++ + + m.RelationIDs = make([]uint32, m.RelationNum) + for i := 0; i < int(m.RelationNum); i++ { + m.RelationIDs[i], used = m.decodeUint32(src[low:]) + low += used + } + + m.SetType(MessageTypeTruncate) + + return nil +} + +// LogicalDecodingMessage is a logical decoding message. +type LogicalDecodingMessage struct { + baseMessage + + LSN LSN + Transactional bool + Prefix string + Content []byte +} + +// Decode decodes a message from src. +func (m *LogicalDecodingMessage) Decode(src []byte) (err error) { + if len(src) < 14 { + return m.lengthError("LogicalDecodingMessage", 14, len(src)) + } + + var low, used int + + flags := src[low] + m.Transactional = flags == 1 + low++ + + m.LSN, used = m.decodeLSN(src[low:]) + low += used + + m.Prefix, used = m.decodeString(src[low:]) + low += used + + contentLength, used := m.decodeUint32(src[low:]) + low += used + + m.Content = src[low : low+int(contentLength)] + + m.SetType(MessageTypeMessage) + + return nil +} + +// Parse parse a logical replication message. 
+func Parse(data []byte) (m Message, err error) { + var decoder MessageDecoder + msgType := MessageType(data[0]) + switch msgType { + case MessageTypeRelation: + decoder = new(RelationMessage) + case MessageTypeType: + decoder = new(TypeMessage) + case MessageTypeInsert: + decoder = new(InsertMessage) + case MessageTypeUpdate: + decoder = new(UpdateMessage) + case MessageTypeDelete: + decoder = new(DeleteMessage) + case MessageTypeTruncate: + decoder = new(TruncateMessage) + case MessageTypeMessage: + decoder = new(LogicalDecodingMessage) + default: + decoder = getCommonDecoder(msgType) + } + + if decoder == nil { + return nil, errMsgNotSupported + } + + if err = decoder.Decode(data[1:]); err != nil { + return nil, err + } + + return decoder.(Message), nil +} + +func getCommonDecoder(msgType MessageType) MessageDecoder { + var decoder MessageDecoder + switch msgType { + case MessageTypeBegin: + decoder = new(BeginMessage) + case MessageTypeCommit: + decoder = new(CommitMessage) + case MessageTypeOrigin: + decoder = new(OriginMessage) + } + + return decoder +} diff --git a/internal/impl/postgresql/pglogicalstream/message_test.go b/internal/impl/postgresql/pglogicalstream/message_test.go new file mode 100644 index 0000000000..761875ad1c --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/message_test.go @@ -0,0 +1,856 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "encoding/binary" + "errors" + "math/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +var bigEndian = binary.BigEndian + +type messageSuite struct { + suite.Suite +} + +func (s *messageSuite) R() *require.Assertions { + return s.Require() +} + +func (s *messageSuite) Equal(e, a interface{}, args ...interface{}) { + s.R().Equal(e, a, args...) +} + +func (s *messageSuite) NoError(err error) { + s.R().NoError(err) +} + +func (s *messageSuite) True(value bool) { + s.R().True(value) +} + +func (s *messageSuite) newLSN() LSN { + return LSN(rand.Int63()) +} + +func (s *messageSuite) newXid() uint32 { + return uint32(rand.Int31()) +} + +func (s *messageSuite) newTime() (time.Time, uint64) { + // Postgres time format only support millisecond accuracy. + now := time.Now().Truncate(time.Millisecond) + return now, uint64(timeToPgTime(now)) +} + +func (s *messageSuite) newRelationID() uint32 { + return uint32(rand.Int31()) +} + +func (s *messageSuite) putString(dst []byte, value string) int { + copy(dst, value) + dst[len(value)] = byte(0) + return len(value) + 1 +} + +func (s *messageSuite) tupleColumnLength(dataType uint8, data []byte) int { + switch dataType { + case uint8('n'), uint8('u'): + return 1 + case uint8('t'): + return 1 + 4 + len(data) + default: + s.FailNow("invalid data type of a tuple: %c", dataType) + return 0 + } +} + +func (s *messageSuite) putTupleColumn(dst []byte, dataType uint8, data []byte) int { + dst[0] = dataType + + switch dataType { + case uint8('n'), uint8('u'): + return 1 + case uint8('t'): + bigEndian.PutUint32(dst[1:], uint32(len(data))) + copy(dst[5:], data) + return 5 + len(data) + default: + s.FailNow("invalid data type of a tuple: %c", dataType) + return 0 + } +} + +func (s *messageSuite) putMessageTestData(msg []byte) 
*LogicalDecodingMessage { + // transaction flag + msg[0] = 1 + off := 1 + + lsn := s.newLSN() + bigEndian.PutUint64(msg[off:], uint64(lsn)) + off += 8 + + off += s.putString(msg[off:], "test") + + content := "hello" + + bigEndian.PutUint32(msg[off:], uint32(len(content))) + off += 4 + + for i := 0; i < len(content); i++ { + msg[off] = content[i] + off++ + } + return &LogicalDecodingMessage{ + Transactional: true, + LSN: lsn, + Prefix: "test", + Content: []byte("hello"), + } +} + +func (s *messageSuite) assertV1NotSupported(msg []byte) { + _, err := Parse(msg) + s.Error(err) + s.True(errors.Is(err, errMsgNotSupported)) +} + +func (s *messageSuite) createRelationTestData() ([]byte, *RelationMessage) { + relationID := uint32(rand.Int31()) + namespace := "public" + relationName := "table1" + noAtttypmod := int32(-1) + col1 := "id" // int8 + col2 := "name" // text + col3 := "created_at" // timestamptz + + col1Length := 1 + len(col1) + 1 + 4 + 4 + col2Length := 1 + len(col2) + 1 + 4 + 4 + col3Length := 1 + len(col3) + 1 + 4 + 4 + + msg := make([]byte, 1+4+len(namespace)+1+len(relationName)+1+1+ + 2+col1Length+col2Length+col3Length) + msg[0] = 'R' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + off += s.putString(msg[off:], namespace) + off += s.putString(msg[off:], relationName) + msg[off] = 1 + off++ + bigEndian.PutUint16(msg[off:], 3) + off += 2 + + msg[off] = 1 // column id is key + off++ + off += s.putString(msg[off:], col1) + bigEndian.PutUint32(msg[off:], 20) // int8 + off += 4 + bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) + off += 4 + + msg[off] = 0 + off++ + off += s.putString(msg[off:], col2) + bigEndian.PutUint32(msg[off:], 25) // text + off += 4 + bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) + off += 4 + + msg[off] = 0 + off++ + off += s.putString(msg[off:], col3) + bigEndian.PutUint32(msg[off:], 1184) // timestamptz + off += 4 + bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) + off += 4 + + expected := &RelationMessage{ 
+ RelationID: relationID, + Namespace: namespace, + RelationName: relationName, + ReplicaIdentity: 1, + ColumnNum: 3, + Columns: []*RelationMessageColumn{ + { + Flags: 1, + Name: col1, + DataType: 20, + TypeModifier: -1, + }, + { + Flags: 0, + Name: col2, + DataType: 25, + TypeModifier: -1, + }, + { + Flags: 0, + Name: col3, + DataType: 1184, + TypeModifier: -1, + }, + }, + } + expected.msgType = 'R' + + return msg, expected +} + +func (s *messageSuite) createTypeTestData() ([]byte, *TypeMessage) { + dataType := uint32(1184) // timestamptz + namespace := "public" + name := "created_at" + + msg := make([]byte, 1+4+len(namespace)+1+len(name)+1) + msg[0] = 'Y' + off := 1 + bigEndian.PutUint32(msg[off:], dataType) + off += 4 + off += s.putString(msg[off:], namespace) + s.putString(msg[off:], name) + + expected := &TypeMessage{ + DataType: dataType, + Namespace: namespace, + Name: name, + } + expected.msgType = 'Y' + + return msg, expected +} + +func (s *messageSuite) createInsertTestData() ([]byte, *InsertMessage) { + relationID := s.newRelationID() + + col1Data := []byte("1") + col2Data := []byte("myname") + col3Data := []byte("123456789") + col1Length := s.tupleColumnLength('t', col1Data) + col2Length := s.tupleColumnLength('t', col2Data) + col3Length := s.tupleColumnLength('t', col3Data) + col4Length := s.tupleColumnLength('n', nil) + col5Length := s.tupleColumnLength('u', nil) + + msg := make([]byte, 1+4+1+2+col1Length+col2Length+col3Length+col4Length+col5Length) + msg[0] = 'I' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 5) + off += 2 + off += s.putTupleColumn(msg[off:], 't', col1Data) + off += s.putTupleColumn(msg[off:], 't', col2Data) + off += s.putTupleColumn(msg[off:], 't', col3Data) + off += s.putTupleColumn(msg[off:], 'n', nil) + s.putTupleColumn(msg[off:], 'u', nil) + + expected := &InsertMessage{ + RelationID: relationID, + Tuple: &TupleData{ + ColumnNum: 5, + Columns: 
[]*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(col1Data)), + Data: col1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(col2Data)), + Data: col2Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(col3Data)), + Data: col3Data, + }, + { + DataType: TupleDataTypeNull, + }, + { + DataType: TupleDataTypeToast, + }, + }, + }, + } + expected.msgType = 'I' + + return msg, expected +} + +func (s *messageSuite) createUpdateTestDataTypeK() ([]byte, *UpdateMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + + newCol1Data := []byte("1124") + newCol2Data := []byte("myname") + newCol1Length := s.tupleColumnLength('t', newCol1Data) + newCol2Length := s.tupleColumnLength('t', newCol2Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length+ + 1+2+newCol1Length+newCol2Length) + msg[0] = 'U' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'K' + off += 1 + bigEndian.PutUint16(msg[off:], 1) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', newCol1Data) + s.putTupleColumn(msg[off:], 't', newCol2Data) + expected := &UpdateMessage{ + RelationID: relationID, + OldTupleType: UpdateMessageTupleTypeKey, + OldTuple: &TupleData{ + ColumnNum: 1, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + }, + }, + NewTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol1Data)), + Data: newCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol2Data)), + Data: newCol2Data, + }, + }, + }, + } + expected.msgType = 'U' + + return msg, expected +} + +func (s *messageSuite) createUpdateTestDataTypeO() ([]byte, 
*UpdateMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + oldCol2Data := []byte("myoldname") + oldCol2Length := s.tupleColumnLength('t', oldCol2Data) + + newCol1Data := []byte("1124") + newCol2Data := []byte("myname") + newCol1Length := s.tupleColumnLength('t', newCol1Data) + newCol2Length := s.tupleColumnLength('t', newCol2Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length+oldCol2Length+ + 1+2+newCol1Length+newCol2Length) + msg[0] = 'U' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'O' + off += 1 + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + off += s.putTupleColumn(msg[off:], 't', oldCol2Data) + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', newCol1Data) + s.putTupleColumn(msg[off:], 't', newCol2Data) + expected := &UpdateMessage{ + RelationID: relationID, + OldTupleType: UpdateMessageTupleTypeOld, + OldTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol2Data)), + Data: oldCol2Data, + }, + }, + }, + NewTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol1Data)), + Data: newCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol2Data)), + Data: newCol2Data, + }, + }, + }, + } + expected.msgType = 'U' + + return msg, expected +} + +func (s *messageSuite) createUpdateTestDataWithoutOldTuple() ([]byte, *UpdateMessage) { + relationID := s.newRelationID() + + newCol1Data := []byte("1124") + newCol2Data := []byte("myname") + newCol1Length := s.tupleColumnLength('t', newCol1Data) + newCol2Length := s.tupleColumnLength('t', newCol2Data) + + msg := 
make([]byte, 1+4+ + 1+2+newCol1Length+newCol2Length) + msg[0] = 'U' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'N' + off++ + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', newCol1Data) + s.putTupleColumn(msg[off:], 't', newCol2Data) + expected := &UpdateMessage{ + RelationID: relationID, + OldTupleType: UpdateMessageTupleTypeNone, + NewTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol1Data)), + Data: newCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(newCol2Data)), + Data: newCol2Data, + }, + }, + }, + } + expected.msgType = 'U' + + return msg, expected +} + +func (s *messageSuite) createDeleteTestDataTypeK() ([]byte, *DeleteMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length) + msg[0] = 'D' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) + off += 4 + msg[off] = 'K' + off++ + bigEndian.PutUint16(msg[off:], 1) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + expected := &DeleteMessage{ + RelationID: relationID, + OldTupleType: DeleteMessageTupleTypeKey, + OldTuple: &TupleData{ + ColumnNum: 1, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + }, + }, + } + expected.msgType = 'D' + return msg, expected +} + +func (s *messageSuite) createDeleteTestDataTypeO() ([]byte, *DeleteMessage) { + relationID := s.newRelationID() + + oldCol1Data := []byte("123") // like an id + oldCol1Length := s.tupleColumnLength('t', oldCol1Data) + oldCol2Data := []byte("myoldname") + oldCol2Length := s.tupleColumnLength('t', oldCol2Data) + + msg := make([]byte, 1+4+ + 1+2+oldCol1Length+oldCol2Length) + msg[0] = 'D' + off := 1 + bigEndian.PutUint32(msg[off:], relationID) 
+ off += 4 + msg[off] = 'O' + off += 1 + bigEndian.PutUint16(msg[off:], 2) + off += 2 + off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + off += s.putTupleColumn(msg[off:], 't', oldCol2Data) + expected := &DeleteMessage{ + RelationID: relationID, + OldTupleType: DeleteMessageTupleTypeOld, + OldTuple: &TupleData{ + ColumnNum: 2, + Columns: []*TupleDataColumn{ + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol1Data)), + Data: oldCol1Data, + }, + { + DataType: TupleDataTypeText, + Length: uint32(len(oldCol2Data)), + Data: oldCol2Data, + }, + }, + }, + } + expected.msgType = 'D' + return msg, expected +} + +func (s *messageSuite) createTruncateTestData() ([]byte, *TruncateMessage) { + relationID1 := s.newRelationID() + relationID2 := s.newRelationID() + option := uint8(0x01 | 0x02) + + msg := make([]byte, 1+4+1+4*2) + msg[0] = 'T' + off := 1 + bigEndian.PutUint32(msg[off:], 2) + off += 4 + msg[off] = option + off++ + bigEndian.PutUint32(msg[off:], relationID1) + off += 4 + bigEndian.PutUint32(msg[off:], relationID2) + expected := &TruncateMessage{ + RelationNum: 2, + Option: TruncateOptionCascade | TruncateOptionRestartIdentity, + RelationIDs: []uint32{ + relationID1, + relationID2, + }, + } + expected.msgType = 'T' + return msg, expected +} + +func (s *messageSuite) insertXid(msg []byte) ([]byte, uint32) { + msgV2 := make([]byte, 4+len(msg)) + msgV2[0] = msg[0] + xid := s.newXid() + bigEndian.PutUint32(msgV2[1:], xid) + copy(msgV2[5:], msg[1:]) + + return msgV2, xid +} + +func TestBeginMessageSuite(t *testing.T) { + suite.Run(t, new(beginMessageSuite)) +} + +type beginMessageSuite struct { + messageSuite +} + +func (s *beginMessageSuite) Test() { + finalLSN := s.newLSN() + commitTime, pgCommitTime := s.newTime() + xid := s.newXid() + + msg := make([]byte, 1+8+8+4) + msg[0] = 'B' + bigEndian.PutUint64(msg[1:], uint64(finalLSN)) + bigEndian.PutUint64(msg[9:], pgCommitTime) + bigEndian.PutUint32(msg[17:], xid) + + m, err := Parse(msg) + s.NoError(err) + 
beginMsg, ok := m.(*BeginMessage) + s.True(ok) + + expected := &BeginMessage{ + FinalLSN: finalLSN, + CommitTime: commitTime, + Xid: xid, + } + expected.msgType = 'B' + s.Equal(expected, beginMsg) +} + +func TestCommitMessage(t *testing.T) { + suite.Run(t, new(commitMessageSuite)) +} + +type commitMessageSuite struct { + messageSuite +} + +func (s *commitMessageSuite) Test() { + flags := uint8(0) + commitLSN := s.newLSN() + transactionEndLSN := s.newLSN() + commitTime, pgCommitTime := s.newTime() + + msg := make([]byte, 1+1+8+8+8) + msg[0] = 'C' + msg[1] = flags + bigEndian.PutUint64(msg[2:], uint64(commitLSN)) + bigEndian.PutUint64(msg[10:], uint64(transactionEndLSN)) + bigEndian.PutUint64(msg[18:], pgCommitTime) + + m, err := Parse(msg) + s.NoError(err) + commitMsg, ok := m.(*CommitMessage) + s.True(ok) + + expected := &CommitMessage{ + Flags: 0, + CommitLSN: commitLSN, + TransactionEndLSN: transactionEndLSN, + CommitTime: commitTime, + } + expected.msgType = 'C' + s.Equal(expected, commitMsg) +} + +func TestOriginMessage(t *testing.T) { + suite.Run(t, new(originMessageSuite)) +} + +type originMessageSuite struct { + messageSuite +} + +func (s *originMessageSuite) Test() { + commitLSN := s.newLSN() + name := "someorigin" + + msg := make([]byte, 1+8+len(name)+1) // 1 byte for \0 + msg[0] = 'O' + bigEndian.PutUint64(msg[1:], uint64(commitLSN)) + s.putString(msg[9:], name) + + m, err := Parse(msg) + s.NoError(err) + originMsg, ok := m.(*OriginMessage) + s.True(ok) + + expected := &OriginMessage{ + CommitLSN: commitLSN, + Name: name, + } + expected.msgType = 'O' + s.Equal(expected, originMsg) +} + +func TestRelationMessageSuite(t *testing.T) { + suite.Run(t, new(relationMessageSuite)) +} + +type relationMessageSuite struct { + messageSuite +} + +func (s *relationMessageSuite) Test() { + + msg, expected := s.createRelationTestData() + + m, err := Parse(msg) + s.NoError(err) + relationMsg, ok := m.(*RelationMessage) + s.True(ok) + + s.Equal(expected, relationMsg) +} + 
+func TestTypeMessageSuite(t *testing.T) { + suite.Run(t, new(typeMessageSuite)) +} + +type typeMessageSuite struct { + messageSuite +} + +func (s *typeMessageSuite) Test() { + msg, expected := s.createTypeTestData() + + m, err := Parse(msg) + s.NoError(err) + typeMsg, ok := m.(*TypeMessage) + s.True(ok) + + s.Equal(expected, typeMsg) +} + +func TestInsertMessageSuite(t *testing.T) { + suite.Run(t, new(insertMessageSuite)) +} + +type insertMessageSuite struct { + messageSuite +} + +func (s *insertMessageSuite) Test() { + + msg, expected := s.createInsertTestData() + + m, err := Parse(msg) + s.NoError(err) + insertMsg, ok := m.(*InsertMessage) + s.True(ok) + + s.Equal(expected, insertMsg) +} + +func TestUpdateMessageSuite(t *testing.T) { + suite.Run(t, new(updateMessageSuite)) +} + +type updateMessageSuite struct { + messageSuite +} + +func (s *updateMessageSuite) TestWithOldTupleTypeK() { + msg, expected := s.createUpdateTestDataTypeK() + m, err := Parse(msg) + s.NoError(err) + updateMsg, ok := m.(*UpdateMessage) + s.True(ok) + + s.Equal(expected, updateMsg) +} + +func (s *updateMessageSuite) TestWithOldTupleTypeO() { + msg, expected := s.createUpdateTestDataTypeO() + m, err := Parse(msg) + s.NoError(err) + updateMsg, ok := m.(*UpdateMessage) + s.True(ok) + + s.Equal(expected, updateMsg) +} + +func (s *updateMessageSuite) TestWithoutOldTuple() { + msg, expected := s.createUpdateTestDataWithoutOldTuple() + m, err := Parse(msg) + s.NoError(err) + updateMsg, ok := m.(*UpdateMessage) + s.True(ok) + + s.Equal(expected, updateMsg) +} + +func TestDeleteMessageSuite(t *testing.T) { + suite.Run(t, new(deleteMessageSuite)) +} + +type deleteMessageSuite struct { + messageSuite +} + +func (s *deleteMessageSuite) TestWithOldTupleTypeK() { + msg, expected := s.createDeleteTestDataTypeK() + + m, err := Parse(msg) + s.NoError(err) + deleteMsg, ok := m.(*DeleteMessage) + s.True(ok) + + s.Equal(expected, deleteMsg) +} + +func (s *deleteMessageSuite) TestWithOldTupleTypeO() { + msg, 
expected := s.createDeleteTestDataTypeO() + + m, err := Parse(msg) + s.NoError(err) + deleteMsg, ok := m.(*DeleteMessage) + s.True(ok) + + s.Equal(expected, deleteMsg) +} + +func TestTruncateMessageSuite(t *testing.T) { + suite.Run(t, new(truncateMessageSuite)) +} + +type truncateMessageSuite struct { + messageSuite +} + +func (s *truncateMessageSuite) Test() { + msg, expected := s.createTruncateTestData() + + m, err := Parse(msg) + s.NoError(err) + truncateMsg, ok := m.(*TruncateMessage) + s.True(ok) + + s.Equal(expected, truncateMsg) +} + +func TestLogicalDecodingMessageSuite(t *testing.T) { + suite.Run(t, new(logicalDecodingMessageSuite)) +} + +type logicalDecodingMessageSuite struct { + messageSuite +} + +func (s *logicalDecodingMessageSuite) Test() { + msg := make([]byte, 1+1+8+5+4+5) + msg[0] = 'M' + + expected := s.putMessageTestData(msg[1:]) + + expected.msgType = MessageTypeMessage + + m, err := Parse(msg) + s.NoError(err) + logicalDecodingMsg, ok := m.(*LogicalDecodingMessage) + s.True(ok) + + s.Equal(expected, logicalDecodingMsg) +} diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go new file mode 100644 index 0000000000..9124fc0de7 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -0,0 +1,773 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +// Package pglogrepl implements PostgreSQL logical replication client functionality. +// +// pglogrepl uses package github.com/jackc/pgconn as its underlying PostgreSQL connection. +// Use pgconn to establish a connection to PostgreSQL and then use the pglogrepl functions +// on that connection. 
+// +// Proper use of this package requires understanding the underlying PostgreSQL concepts. +// See https://www.postgresql.org/docs/current/protocol-replication.html. + +import ( + "context" + "database/sql/driver" + "encoding/binary" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgio" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" +) + +const ( + XLogDataByteID = 'w' + PrimaryKeepaliveMessageByteID = 'k' + StandbyStatusUpdateByteID = 'r' +) + +type ReplicationMode int + +const ( + LogicalReplication ReplicationMode = iota + PhysicalReplication +) + +// String formats the mode into a postgres valid string +func (mode ReplicationMode) String() string { + if mode == LogicalReplication { + return "LOGICAL" + } else { + return "PHYSICAL" + } +} + +// LSN is a PostgreSQL Log Sequence Number. See https://www.postgresql.org/docs/current/datatype-pg-lsn.html. +type LSN uint64 + +// String formats the LSN value into the XXX/XXX format which is the text format used by PostgreSQL. +func (lsn LSN) String() string { + return fmt.Sprintf("%X/%X", uint32(lsn>>32), uint32(lsn)) +} + +func (lsn *LSN) decodeText(src string) error { + lsnValue, err := ParseLSN(src) + if err != nil { + return err + } + *lsn = lsnValue + + return nil +} + +// Scan implements the Scanner interface. +func (lsn *LSN) Scan(src interface{}) error { + if lsn == nil { + return nil + } + + switch v := src.(type) { + case uint64: + *lsn = LSN(v) + case string: + if err := lsn.decodeText(v); err != nil { + return err + } + case []byte: + if err := lsn.decodeText(string(v)); err != nil { + return err + } + default: + return fmt.Errorf("can not scan %T to LSN", src) + } + + return nil +} + +// Value implements the Valuer interface. +func (lsn LSN) Value() (driver.Value, error) { + return driver.Value(lsn.String()), nil +} + +// ParseLSN parses the given XXX/XXX text format LSN used by PostgreSQL. 
+func ParseLSN(s string) (LSN, error) { + var upperHalf uint64 + var lowerHalf uint64 + var nparsed int + nparsed, err := fmt.Sscanf(s, "%X/%X", &upperHalf, &lowerHalf) + if err != nil { + return 0, fmt.Errorf("failed to parse LSN: %w", err) + } + + if nparsed != 2 { + return 0, fmt.Errorf("failed to parsed LSN: %s", s) + } + + return LSN((upperHalf << 32) + lowerHalf), nil +} + +// IdentifySystemResult is the parsed result of the IDENTIFY_SYSTEM command. +type IdentifySystemResult struct { + SystemID string + Timeline int32 + XLogPos LSN + DBName string +} + +// IdentifySystem executes the IDENTIFY_SYSTEM command. +func IdentifySystem(ctx context.Context, conn *pgconn.PgConn) (IdentifySystemResult, error) { + return ParseIdentifySystem(conn.Exec(ctx, "IDENTIFY_SYSTEM")) +} + +// ParseIdentifySystem parses the result of the IDENTIFY_SYSTEM command. +func ParseIdentifySystem(mrr *pgconn.MultiResultReader) (IdentifySystemResult, error) { + var isr IdentifySystemResult + results, err := mrr.ReadAll() + if err != nil { + return isr, err + } + + if len(results) != 1 { + return isr, fmt.Errorf("expected 1 result set, got %d", len(results)) + } + + result := results[0] + if len(result.Rows) != 1 { + return isr, fmt.Errorf("expected 1 result row, got %d", len(result.Rows)) + } + + row := result.Rows[0] + if len(row) != 4 { + return isr, fmt.Errorf("expected 4 result columns, got %d", len(row)) + } + + isr.SystemID = string(row[0]) + timeline, err := strconv.ParseInt(string(row[1]), 10, 32) + if err != nil { + return isr, fmt.Errorf("failed to parse timeline: %w", err) + } + isr.Timeline = int32(timeline) + + isr.XLogPos, err = ParseLSN(string(row[2])) + if err != nil { + return isr, fmt.Errorf("failed to parse xlogpos as LSN: %w", err) + } + + isr.DBName = string(row[3]) + + return isr, nil +} + +// TimelineHistoryResult is the parsed result of the TIMELINE_HISTORY command. 
+type TimelineHistoryResult struct { + FileName string + Content []byte +} + +// TimelineHistory executes the TIMELINE_HISTORY command. +func TimelineHistory(ctx context.Context, conn *pgconn.PgConn, timeline int32) (TimelineHistoryResult, error) { + sql := fmt.Sprintf("TIMELINE_HISTORY %d", timeline) + return ParseTimelineHistory(conn.Exec(ctx, sql)) +} + +// ParseTimelineHistory parses the result of the TIMELINE_HISTORY command. +func ParseTimelineHistory(mrr *pgconn.MultiResultReader) (TimelineHistoryResult, error) { + var thr TimelineHistoryResult + results, err := mrr.ReadAll() + if err != nil { + return thr, err + } + + if len(results) != 1 { + return thr, fmt.Errorf("expected 1 result set, got %d", len(results)) + } + + result := results[0] + if len(result.Rows) != 1 { + return thr, fmt.Errorf("expected 1 result row, got %d", len(result.Rows)) + } + + row := result.Rows[0] + if len(row) != 2 { + return thr, fmt.Errorf("expected 2 result columns, got %d", len(row)) + } + + thr.FileName = string(row[0]) + thr.Content = row[1] + return thr, nil +} + +type CreateReplicationSlotOptions struct { + Temporary bool + SnapshotAction string + Mode ReplicationMode +} + +// CreateReplicationSlotResult is the parsed results the CREATE_REPLICATION_SLOT command. +type CreateReplicationSlotResult struct { + SlotName string + ConsistentPoint string + SnapshotName string + OutputPlugin string +} + +// CreateReplicationSlot creates a logical replication slot. 
+func CreateReplicationSlot( + ctx context.Context, + conn *pgconn.PgConn, + slotName string, + outputPlugin string, + options CreateReplicationSlotOptions, +) (CreateReplicationSlotResult, error) { + var temporaryString string + if options.Temporary { + temporaryString = "TEMPORARY" + } + var snapshotString string + if options.SnapshotAction == "export" { + snapshotString = "(SNAPSHOT export)" + } else { + snapshotString = options.SnapshotAction + } + sql := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) + return ParseCreateReplicationSlot(conn.Exec(ctx, sql)) +} + +// ParseCreateReplicationSlot parses the result of the CREATE_REPLICATION_SLOT command. +func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader) (CreateReplicationSlotResult, error) { + var crsr CreateReplicationSlotResult + results, err := mrr.ReadAll() + if err != nil { + return crsr, err + } + + if len(results) != 1 { + return crsr, fmt.Errorf("expected 1 result set, got %d", len(results)) + } + + result := results[0] + if len(result.Rows) != 1 { + return crsr, fmt.Errorf("expected 1 result row, got %d", len(result.Rows)) + } + + row := result.Rows[0] + if len(row) != 4 { + return crsr, fmt.Errorf("expected 4 result columns, got %d", len(row)) + } + + crsr.SlotName = string(row[0]) + crsr.ConsistentPoint = string(row[1]) + crsr.SnapshotName = string(row[2]) + crsr.OutputPlugin = string(row[3]) + + return crsr, nil +} + +type DropReplicationSlotOptions struct { + Wait bool +} + +// DropReplicationSlot drops a logical replication slot. 
+func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName string, options DropReplicationSlotOptions) error { + var waitString string + if options.Wait { + waitString = "WAIT" + } + sql := fmt.Sprintf("DROP_REPLICATION_SLOT %s %s", slotName, waitString) + _, err := conn.Exec(ctx, sql).ReadAll() + return err +} + +type StartReplicationOptions struct { + Timeline int32 // 0 means current server timeline + Mode ReplicationMode + PluginArgs []string +} + +// StartReplication begins the replication process by executing the START_REPLICATION command. +func StartReplication(ctx context.Context, conn *pgconn.PgConn, slotName string, startLSN LSN, options StartReplicationOptions) error { + var timelineString string + if options.Timeline > 0 { + timelineString = fmt.Sprintf("TIMELINE %d", options.Timeline) + options.PluginArgs = append(options.PluginArgs, timelineString) + } + + sql := fmt.Sprintf("START_REPLICATION SLOT %s %s %s ", slotName, options.Mode, startLSN) + if options.Mode == LogicalReplication { + if len(options.PluginArgs) > 0 { + sql += fmt.Sprintf("(%s)", strings.Join(options.PluginArgs, ", ")) + } + } else { + sql += timelineString + } + + conn.Frontend().SendQuery(&pgproto3.Query{String: sql}) + err := conn.Frontend().Flush() + if err != nil { + return fmt.Errorf("failed to send START_REPLICATION: %w", err) + } + + for { + msg, err := conn.ReceiveMessage(ctx) + if err != nil { + return fmt.Errorf("failed to receive message: %w", err) + } + + switch msg := msg.(type) { + case *pgproto3.NoticeResponse: + case *pgproto3.ErrorResponse: + return pgconn.ErrorResponseToPgError(msg) + case *pgproto3.CopyBothResponse: + // This signals the start of the replication stream. + return nil + default: + return fmt.Errorf("unexpected response type: %T", msg) + } + } +} + +type BaseBackupOptions struct { + // Request information required to generate a progress report, but might as such have a negative impact on the performance. 
+ Progress bool + // Sets the label of the backup. If none is specified, a backup label of 'wal-g' will be used. + Label string + // Request a fast checkpoint. + Fast bool + // Include the necessary WAL segments in the backup. This will include all the files between start and stop backup in the pg_wal directory of the base directory tar file. + WAL bool + // By default, the backup will wait until the last required WAL segment has been archived, or emit a warning if log archiving is not enabled. + // Specifying NOWAIT disables both the waiting and the warning, leaving the client responsible for ensuring the required log is available. + NoWait bool + // Limit (throttle) the maximum amount of data transferred from server to client per unit of time (kb/s). + MaxRate int32 + // Include information about symbolic links present in the directory pg_tblspc in a file named tablespace_map. + TablespaceMap bool + // Disable checksums being verified during a base backup. + // Note that NoVerifyChecksums=true is only supported since PG11 + NoVerifyChecksums bool +} + +func (bbo BaseBackupOptions) sql(serverVersion int) string { + var parts []string + if bbo.Label != "" { + parts = append(parts, "LABEL '"+strings.ReplaceAll(bbo.Label, "'", "''")+"'") + } + if bbo.Progress { + parts = append(parts, "PROGRESS") + } + if bbo.Fast { + if serverVersion >= 15 { + parts = append(parts, "CHECKPOINT 'fast'") + } else { + parts = append(parts, "FAST") + } + } + if bbo.WAL { + parts = append(parts, "WAL") + } + if bbo.NoWait { + if serverVersion >= 15 { + parts = append(parts, "WAIT false") + } else { + parts = append(parts, "NOWAIT") + } + } + if bbo.MaxRate >= 32 { + parts = append(parts, fmt.Sprintf("MAX_RATE %d", bbo.MaxRate)) + } + if bbo.TablespaceMap { + parts = append(parts, "TABLESPACE_MAP") + } + if bbo.NoVerifyChecksums { + if serverVersion >= 15 { + parts = append(parts, "VERIFY_CHECKSUMS false") + } else if serverVersion >= 11 { + parts = append(parts, "NOVERIFY_CHECKSUMS") + } 
+ } + if serverVersion >= 15 { + return "BASE_BACKUP(" + strings.Join(parts, ", ") + ")" + } + return "BASE_BACKUP " + strings.Join(parts, " ") +} + +// BaseBackupTablespace represents a tablespace in the backup +type BaseBackupTablespace struct { + OID int32 + Location string + Size int8 +} + +// BaseBackupResult will hold the return values of the BaseBackup command +type BaseBackupResult struct { + LSN LSN + TimelineID int32 + Tablespaces []BaseBackupTablespace +} + +func serverMajorVersion(conn *pgconn.PgConn) (int, error) { + verString := conn.ParameterStatus("server_version") + dot := strings.IndexByte(verString, '.') + if dot == -1 { + return 0, fmt.Errorf("bad server version string: '%s'", verString) + } + return strconv.Atoi(verString[:dot]) +} + +// StartBaseBackup begins the process for copying a basebackup by executing the BASE_BACKUP command. +func StartBaseBackup(ctx context.Context, conn *pgconn.PgConn, options BaseBackupOptions) (result BaseBackupResult, err error) { + serverVersion, err := serverMajorVersion(conn) + if err != nil { + return result, err + } + sql := options.sql(serverVersion) + + conn.Frontend().SendQuery(&pgproto3.Query{String: sql}) + err = conn.Frontend().Flush() + if err != nil { + return result, fmt.Errorf("failed to send BASE_BACKUP: %w", err) + } + // From here Postgres returns result sets, but pgconn has no infrastructure to properly capture them. + // So we capture data low level with sub functions, before we return from this function when we get to the CopyData part. 
+ result.LSN, result.TimelineID, err = getBaseBackupInfo(ctx, conn) + if err != nil { + return result, err + } + result.Tablespaces, err = getTableSpaceInfo(ctx, conn) + return result, err +} + +// getBaseBackupInfo returns the start or end position of the backup as returned by Postgres +func getBaseBackupInfo(ctx context.Context, conn *pgconn.PgConn) (start LSN, timelineID int32, err error) { + for { + msg, err := conn.ReceiveMessage(ctx) + if err != nil { + return start, timelineID, fmt.Errorf("failed to receive message: %w", err) + } + switch msg := msg.(type) { + case *pgproto3.RowDescription: + if len(msg.Fields) != 2 { + return start, timelineID, fmt.Errorf("expected 2 column headers, received: %d", len(msg.Fields)) + } + colName := string(msg.Fields[0].Name) + if colName != "recptr" { + return start, timelineID, fmt.Errorf("unexpected col name for recptr col: %s", colName) + } + colName = string(msg.Fields[1].Name) + if colName != "tli" { + return start, timelineID, fmt.Errorf("unexpected col name for tli col: %s", colName) + } + case *pgproto3.DataRow: + if len(msg.Values) != 2 { + return start, timelineID, fmt.Errorf("expected 2 columns, received: %d", len(msg.Values)) + } + colData := string(msg.Values[0]) + start, err = ParseLSN(colData) + if err != nil { + return start, timelineID, fmt.Errorf("cannot convert result to LSN: %s", colData) + } + colData = string(msg.Values[1]) + tli, err := strconv.Atoi(colData) + if err != nil { + return start, timelineID, fmt.Errorf("cannot convert timelineID to int: %s", colData) + } + timelineID = int32(tli) + case *pgproto3.NoticeResponse: + case *pgproto3.CommandComplete: + return start, timelineID, nil + case *pgproto3.ErrorResponse: + return start, timelineID, fmt.Errorf("error response sev=%q code=%q message=%q detail=%q position=%d", msg.Severity, msg.Code, msg.Message, msg.Detail, msg.Position) + default: + return start, timelineID, fmt.Errorf("unexpected response type: %T", msg) + } + } +} + +// 
getTableSpaceInfo parses the tablespace result set returned by the BASE_BACKUP command
+func getTableSpaceInfo(ctx context.Context, conn *pgconn.PgConn) (tbss []BaseBackupTablespace, err error) {
+	for {
+		msg, err := conn.ReceiveMessage(ctx)
+		if err != nil {
+			return tbss, fmt.Errorf("failed to receive message: %w", err)
+		}
+		switch msg := msg.(type) {
+		case *pgproto3.RowDescription:
+			if len(msg.Fields) != 3 {
+				return tbss, fmt.Errorf("expected 3 column headers, received: %d", len(msg.Fields))
+			}
+			colName := string(msg.Fields[0].Name)
+			if colName != "spcoid" {
+				return tbss, fmt.Errorf("unexpected col name for spcoid col: %s", colName)
+			}
+			colName = string(msg.Fields[1].Name)
+			if colName != "spclocation" {
+				return tbss, fmt.Errorf("unexpected col name for spclocation col: %s", colName)
+			}
+			colName = string(msg.Fields[2].Name)
+			if colName != "size" {
+				return tbss, fmt.Errorf("unexpected col name for size col: %s", colName)
+			}
+		case *pgproto3.DataRow:
+			if len(msg.Values) != 3 {
+				return tbss, fmt.Errorf("expected 3 columns, received: %d", len(msg.Values))
+			}
+			// A NULL spcoid denotes the base data directory row; skip it.
+			if msg.Values[0] == nil {
+				continue
+			}
+			tbs := BaseBackupTablespace{}
+			colData := string(msg.Values[0])
+			OID, err := strconv.Atoi(colData)
+			if err != nil {
+				return tbss, fmt.Errorf("cannot convert spcoid to int: %s", colData)
+			}
+			tbs.OID = int32(OID)
+			tbs.Location = string(msg.Values[1])
+			// size may be NULL when PROGRESS was not requested.
+			if msg.Values[2] != nil {
+				colData := string(msg.Values[2])
+				size, err := strconv.Atoi(colData)
+				if err != nil {
+					return tbss, fmt.Errorf("cannot convert size to int: %s", colData)
+				}
+				tbs.Size = int8(size)
+			}
+			tbss = append(tbss, tbs)
+		case *pgproto3.CommandComplete:
+			return tbss, nil
+		default:
+			return tbss, fmt.Errorf("unexpected response type: %T", msg)
+		}
+	}
+}
+
+// NextTableSpace consumes some msgs so we are at start of CopyData
+func NextTableSpace(ctx context.Context, conn *pgconn.PgConn) (err error) {
+
+	for {
+		msg, err := conn.ReceiveMessage(ctx)
+		if err != nil {
+			return fmt.Errorf("failed to receive message: %w", err)
+		}
+
+		switch msg := msg.(type) {
+		case *pgproto3.CopyOutResponse:
+			return nil
+		case *pgproto3.CopyData:
+			return nil
+		case *pgproto3.ErrorResponse:
+			return pgconn.ErrorResponseToPgError(msg)
+		// Notices and row descriptions are deliberately skipped over.
+		case *pgproto3.NoticeResponse:
+		case *pgproto3.RowDescription:
+
+		default:
+			return fmt.Errorf("unexpected response type: %T", msg)
+		}
+	}
+}
+
+// FinishBaseBackup wraps up a backup after copying all results from the BASE_BACKUP command.
+func FinishBaseBackup(ctx context.Context, conn *pgconn.PgConn) (result BaseBackupResult, err error) {
+
+	// From here Postgres returns result sets, but pgconn has no infrastructure to properly capture them.
+	// So we capture data low level with sub functions, before we return from this function when we get to the CopyData part.
+	result.LSN, result.TimelineID, err = getBaseBackupInfo(ctx, conn)
+	if err != nil {
+		return result, err
+	}
+
+	// Base_Backup done, server send a command complete response from pg13
+	vmaj, err := serverMajorVersion(conn)
+	if err != nil {
+		return
+	}
+	var (
+		pack pgproto3.BackendMessage
+		ok   bool
+	)
+	if vmaj > 12 {
+		pack, err = conn.ReceiveMessage(ctx)
+		if err != nil {
+			return
+		}
+		_, ok = pack.(*pgproto3.CommandComplete)
+		if !ok {
+			err = fmt.Errorf("expect command_complete, got %T", pack)
+			return
+		}
+	}
+
+	// simple query done, server send a ready for query response
+	pack, err = conn.ReceiveMessage(ctx)
+	if err != nil {
+		return
+	}
+	_, ok = pack.(*pgproto3.ReadyForQuery)
+	if !ok {
+		err = fmt.Errorf("expect ready_for_query, got %T", pack)
+		return
+	}
+	return
+}
+
+// PrimaryKeepaliveMessage is a keepalive sent by the primary on the replication stream.
+type PrimaryKeepaliveMessage struct {
+	ServerWALEnd   LSN
+	ServerTime     time.Time
+	ReplyRequested bool
+}
+
+// ParsePrimaryKeepaliveMessage parses a Primary keepalive message from the server.
+func ParsePrimaryKeepaliveMessage(buf []byte) (PrimaryKeepaliveMessage, error) { + var pkm PrimaryKeepaliveMessage + if len(buf) != 17 { + return pkm, fmt.Errorf("PrimaryKeepaliveMessage must be 17 bytes, got %d", len(buf)) + } + + pkm.ServerWALEnd = LSN(binary.BigEndian.Uint64(buf)) + pkm.ServerTime = pgTimeToTime(int64(binary.BigEndian.Uint64(buf[8:]))) + pkm.ReplyRequested = buf[16] != 0 + + return pkm, nil +} + +type XLogData struct { + WALStart LSN + ServerWALEnd LSN + ServerTime time.Time + WALData []byte +} + +// ParseXLogData parses a XLogData message from the server. +func ParseXLogData(buf []byte) (XLogData, error) { + var xld XLogData + if len(buf) < 24 { + return xld, fmt.Errorf("XLogData must be at least 24 bytes, got %d", len(buf)) + } + + xld.WALStart = LSN(binary.BigEndian.Uint64(buf)) + xld.ServerWALEnd = LSN(binary.BigEndian.Uint64(buf[8:])) + xld.ServerTime = pgTimeToTime(int64(binary.BigEndian.Uint64(buf[16:]))) + xld.WALData = buf[24:] + + return xld, nil +} + +// StandbyStatusUpdate is a message sent from the client that acknowledges receipt of WAL records. +type StandbyStatusUpdate struct { + WALWritePosition LSN // The WAL position that's been locally written + WALFlushPosition LSN // The WAL position that's been locally flushed + WALApplyPosition LSN // The WAL position that's been locally applied + ClientTime time.Time // Client system clock time + ReplyRequested bool // Request server to reply immediately. +} + +// SendStandbyStatusUpdate sends a StandbyStatusUpdate to the PostgreSQL server. +// +// The only required field in ssu is WALWritePosition. If WALFlushPosition is 0 then WALWritePosition will be assigned +// to it. If WALApplyPosition is 0 then WALWritePosition will be assigned to it. If ClientTime is the zero value then +// the current time will be assigned to it. 
+func SendStandbyStatusUpdate(_ context.Context, conn *pgconn.PgConn, ssu StandbyStatusUpdate) error { + if ssu.WALFlushPosition == 0 { + ssu.WALFlushPosition = ssu.WALWritePosition + } + if ssu.WALApplyPosition == 0 { + ssu.WALApplyPosition = ssu.WALWritePosition + } + if ssu.ClientTime == (time.Time{}) { + ssu.ClientTime = time.Now() + } + + data := make([]byte, 0, 34) + data = append(data, StandbyStatusUpdateByteID) + data = pgio.AppendUint64(data, uint64(ssu.WALWritePosition)) + data = pgio.AppendUint64(data, uint64(ssu.WALFlushPosition)) + data = pgio.AppendUint64(data, uint64(ssu.WALApplyPosition)) + data = pgio.AppendInt64(data, timeToPgTime(ssu.ClientTime)) + if ssu.ReplyRequested { + data = append(data, 1) + } else { + data = append(data, 0) + } + + cd := &pgproto3.CopyData{Data: data} + buf, err := cd.Encode(nil) + if err != nil { + return err + } + + return conn.Frontend().SendUnbufferedEncodedCopyData(buf) +} + +// CopyDoneResult is the parsed result as returned by the server after the client +// sends a CopyDone to the server to confirm ending the copy-both mode. +type CopyDoneResult struct { + Timeline int32 + LSN LSN +} + +// SendStandbyCopyDone sends a StandbyCopyDone to the PostgreSQL server +// to confirm ending the copy-both mode. 
+func SendStandbyCopyDone(_ context.Context, conn *pgconn.PgConn) (cdr *CopyDoneResult, err error) { + // I am suspicious that this is wildly wrong, but I'm pretty sure the previous + // code was wildly wrong too -- wttw + conn.Frontend().Send(&pgproto3.CopyDone{}) + err = conn.Frontend().Flush() + if err != nil { + return + } + + for { + var msg pgproto3.BackendMessage + msg, err = conn.Frontend().Receive() + if err != nil { + return cdr, err + } + + switch m := msg.(type) { + case *pgproto3.CopyDone: + case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse: + case *pgproto3.CommandComplete: + case *pgproto3.RowDescription: + case *pgproto3.DataRow: + // We are expecting just one row returned, with two columns timeline and LSN + // We should pay attention to RowDescription, but we'll take it on trust. + if len(m.Values) == 2 { + timeline, lerr := strconv.Atoi(string(m.Values[0])) + if lerr == nil { + lsn, lerr := ParseLSN(string(m.Values[1])) + if lerr == nil { + cdr.Timeline = int32(timeline) + cdr.LSN = lsn + } + } + } + case *pgproto3.EmptyQueryResponse: + case *pgproto3.ErrorResponse: + return cdr, pgconn.ErrorResponseToPgError(m) + case *pgproto3.ReadyForQuery: + // Should we eat the ReadyForQuery here, or not? 
+ return cdr, err + } + } +} + +const microsecFromUnixEpochToY2K = 946684800 * 1000000 + +func pgTimeToTime(microsecSinceY2K int64) time.Time { + microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + microsecSinceY2K + return time.Unix(0, microsecSinceUnixEpoch*1000) +} + +func timeToPgTime(t time.Time) int64 { + microsecSinceUnixEpoch := t.Unix()*1000000 + int64(t.Nanosecond())/1000 + return microsecSinceUnixEpoch - microsecFromUnixEpochToY2K +} diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go new file mode 100644 index 0000000000..0e81f4c00e --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -0,0 +1,414 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import ( + "context" + "fmt" + "os" + "strconv" + "testing" + "time" + + "github.com/jackc/pglogrepl" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestLSNSuite(t *testing.T) { + suite.Run(t, new(lsnSuite)) +} + +type lsnSuite struct { + suite.Suite +} + +func (s *lsnSuite) R() *require.Assertions { + return s.Require() +} + +func (s *lsnSuite) Equal(e, a interface{}, args ...interface{}) { + s.R().Equal(e, a, args...) 
+} + +func (s *lsnSuite) NoError(err error) { + s.R().NoError(err) +} + +func (s *lsnSuite) TestScannerInterface() { + var lsn pglogrepl.LSN + lsnText := "16/B374D848" + lsnUint64 := uint64(97500059720) + var err error + + err = lsn.Scan(lsnText) + s.NoError(err) + s.Equal(lsnText, lsn.String()) + + err = lsn.Scan([]byte(lsnText)) + s.NoError(err) + s.Equal(lsnText, lsn.String()) + + lsn = 0 + err = lsn.Scan(lsnUint64) + s.NoError(err) + s.Equal(lsnText, lsn.String()) + + err = lsn.Scan(int64(lsnUint64)) + s.Error(err) + s.T().Log(err) +} + +func (s *lsnSuite) TestScanToNil() { + var lsnPtr *pglogrepl.LSN + err := lsnPtr.Scan("16/B374D848") + s.NoError(err) +} + +func (s *lsnSuite) TestValueInterface() { + lsn := pglogrepl.LSN(97500059720) + driverValue, err := lsn.Value() + s.NoError(err) + lsnStr, ok := driverValue.(string) + s.R().True(ok) + s.Equal("16/B374D848", lsnStr) +} + +const slotName = "pglogrepl_test" +const outputPlugin = "test_decoding" + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + require.NoError(t, conn.Close(ctx)) +} + +func TestIdentifySystem(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err := pglogrepl.IdentifySystem(ctx, conn) + require.NoError(t, err) + + assert.Greater(t, len(sysident.SystemID), 0) + assert.True(t, sysident.Timeline > 0) + assert.True(t, sysident.XLogPos > 0) + assert.Greater(t, len(sysident.DBName), 0) +} + +func TestGetHistoryFile(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + config.RuntimeParams["replication"] = "on" + + conn, err := 
pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err := pglogrepl.IdentifySystem(ctx, conn) + require.NoError(t, err) + + _, err = pglogrepl.TimelineHistory(ctx, conn, 0) + require.Error(t, err) + + _, err = pglogrepl.TimelineHistory(ctx, conn, 1) + require.Error(t, err) + + if sysident.Timeline > 1 { + // This test requires a Postgres with at least 1 timeline increase (promote, or recover)... + tlh, err := pglogrepl.TimelineHistory(ctx, conn, sysident.Timeline) + require.NoError(t, err) + + expectedFileName := fmt.Sprintf("%08X.history", sysident.Timeline) + assert.Equal(t, expectedFileName, tlh.FileName) + assert.Greater(t, len(tlh.Content), 0) + } +} + +func TestCreateReplicationSlot(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + result, err := pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + require.NoError(t, err) + + assert.Equal(t, slotName, result.SlotName) + assert.Equal(t, outputPlugin, result.OutputPlugin) +} + +func TestDropReplicationSlot(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + require.NoError(t, err) + + err = pglogrepl.DropReplicationSlot(ctx, conn, slotName, pglogrepl.DropReplicationSlotOptions{}) + require.NoError(t, err) + + _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + require.NoError(t, err) +} + +func 
TestStartReplication(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err := pglogrepl.IdentifySystem(ctx, conn) + require.NoError(t, err) + + _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + require.NoError(t, err) + + err = pglogrepl.StartReplication(ctx, conn, slotName, sysident.XLogPos, pglogrepl.StartReplicationOptions{}) + require.NoError(t, err) + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + delete(config.RuntimeParams, "replication") + + conn, err := pgconn.ConnectConfig(ctx, config) + require.NoError(t, err) + defer closeConn(t, conn) + + _, err = conn.Exec(ctx, ` +create table t(id int primary key, name text); + +insert into t values (1, 'foo'); +insert into t values (2, 'bar'); +insert into t values (3, 'baz'); + +update t set name='quz' where id=3; + +delete from t where id=2; + +drop table t; +`).ReadAll() + require.NoError(t, err) + }() + + rxKeepAlive := func() pglogrepl.PrimaryKeepaliveMessage { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + cdMsg, ok := msg.(*pgproto3.CopyData) + require.True(t, ok) + + require.Equal(t, byte(pglogrepl.PrimaryKeepaliveMessageByteID), cdMsg.Data[0]) + pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(cdMsg.Data[1:]) + require.NoError(t, err) + return pkm + } + + rxXLogData := func() pglogrepl.XLogData { + var cdMsg *pgproto3.CopyData + // Discard keepalive messages + for { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + var ok bool + cdMsg, ok = msg.(*pgproto3.CopyData) + require.True(t, ok) + if cdMsg.Data[0] != 
pglogrepl.PrimaryKeepaliveMessageByteID { + break + } + } + require.Equal(t, byte(pglogrepl.XLogDataByteID), cdMsg.Data[0]) + xld, err := pglogrepl.ParseXLogData(cdMsg.Data[1:]) + require.NoError(t, err) + return xld + } + + rxKeepAlive() + xld := rxXLogData() + assert.Equal(t, "BEGIN", string(xld.WALData[:5])) + xld = rxXLogData() + assert.Equal(t, "table public.t: INSERT: id[integer]:1 name[text]:'foo'", string(xld.WALData)) + xld = rxXLogData() + assert.Equal(t, "table public.t: INSERT: id[integer]:2 name[text]:'bar'", string(xld.WALData)) + xld = rxXLogData() + assert.Equal(t, "table public.t: INSERT: id[integer]:3 name[text]:'baz'", string(xld.WALData)) + xld = rxXLogData() + assert.Equal(t, "table public.t: UPDATE: id[integer]:3 name[text]:'quz'", string(xld.WALData)) + xld = rxXLogData() + assert.Equal(t, "table public.t: DELETE: id[integer]:2", string(xld.WALData)) + xld = rxXLogData() + assert.Equal(t, "COMMIT", string(xld.WALData[:6])) +} + +func TestStartReplicationPhysical(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*50) + defer cancel() + + conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + defer closeConn(t, conn) + + sysident, err := pglogrepl.IdentifySystem(ctx, conn) + require.NoError(t, err) + + _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, "", pglogrepl.CreateReplicationSlotOptions{Temporary: true, Mode: pglogrepl.PhysicalReplication}) + require.NoError(t, err) + + err = pglogrepl.StartReplication(ctx, conn, slotName, sysident.XLogPos, pglogrepl.StartReplicationOptions{Mode: pglogrepl.PhysicalReplication}) + require.NoError(t, err) + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + config, err := pgconn.ParseConfig(os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + require.NoError(t, err) + delete(config.RuntimeParams, "replication") + + conn, err := pgconn.ConnectConfig(ctx, config) + 
require.NoError(t, err) + defer closeConn(t, conn) + + _, err = conn.Exec(ctx, ` +create table mytable(id int primary key, name text); +drop table mytable; +`).ReadAll() + require.NoError(t, err) + }() + + _ = func() pglogrepl.PrimaryKeepaliveMessage { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + cdMsg, ok := msg.(*pgproto3.CopyData) + require.True(t, ok) + + require.Equal(t, byte(pglogrepl.PrimaryKeepaliveMessageByteID), cdMsg.Data[0]) + pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(cdMsg.Data[1:]) + require.NoError(t, err) + return pkm + } + + rxXLogData := func() pglogrepl.XLogData { + msg, err := conn.ReceiveMessage(ctx) + require.NoError(t, err) + cdMsg, ok := msg.(*pgproto3.CopyData) + require.True(t, ok) + + require.Equal(t, byte(pglogrepl.XLogDataByteID), cdMsg.Data[0]) + xld, err := pglogrepl.ParseXLogData(cdMsg.Data[1:]) + require.NoError(t, err) + return xld + } + + xld := rxXLogData() + assert.Contains(t, string(xld.WALData), "mytable") + + copyDoneResult, err := pglogrepl.SendStandbyCopyDone(ctx, conn) + require.NoError(t, err) + assert.Nil(t, copyDoneResult) +} + +func TestBaseBackup(t *testing.T) { + // base backup test could take a long time. Therefore it can be disabled. 
+	envSkipTest := os.Getenv("PGLOGREPL_SKIP_BASE_BACKUP")
+	if envSkipTest != "" {
+		skipTest, err := strconv.ParseBool(envSkipTest)
+		if err != nil {
+			t.Error(err)
+		} else if skipTest {
+			return
+		}
+	}
+
+	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGLOGREPL_TEST_CONN_STRING"))
+	require.NoError(t, err)
+	defer closeConn(t, conn)
+
+	options := pglogrepl.BaseBackupOptions{
+		NoVerifyChecksums: true,
+		Progress:          true,
+		Label:             "pglogrepltest",
+		Fast:              true,
+		WAL:               true,
+		NoWait:            true,
+		MaxRate:           1024,
+		TablespaceMap:     true,
+	}
+	startRes, err := pglogrepl.StartBaseBackup(context.Background(), conn, options)
+	require.NoError(t, err)
+	require.GreaterOrEqual(t, startRes.TimelineID, int32(1))
+
+	// Write the tablespaces
+	for i := 0; i < len(startRes.Tablespaces)+1; i++ {
+		f, err := os.CreateTemp("", fmt.Sprintf("pglogrepl_test_tbs_%d.tar", i))
+		require.NoError(t, err)
+		// Clean the temp tar file up once the test finishes.
+		defer os.Remove(f.Name())
+		err = pglogrepl.NextTableSpace(context.Background(), conn)
+		// Previously this error was silently overwritten by the ReceiveMessage
+		// loop below; fail fast instead.
+		require.NoError(t, err)
+		var message pgproto3.BackendMessage
+	L:
+		for {
+			message, err = conn.ReceiveMessage(context.Background())
+			require.NoError(t, err)
+			switch msg := message.(type) {
+			case *pgproto3.CopyData:
+				_, err := f.Write(msg.Data)
+				require.NoError(t, err)
+			case *pgproto3.CopyDone:
+				break L
+			default:
+				t.Errorf("Received unexpected message: %#v\n", msg)
+			}
+		}
+		err = f.Close()
+		require.NoError(t, err)
+	}
+
+	stopRes, err := pglogrepl.FinishBaseBackup(context.Background(), conn)
+	require.NoError(t, err)
+	require.Equal(t, startRes.TimelineID, stopRes.TimelineID)
+	require.Equal(t, len(stopRes.Tablespaces), 0)
+	require.Less(t, uint64(startRes.LSN), uint64(stopRes.LSN))
+	_, err = pglogrepl.StartBaseBackup(context.Background(), conn, options)
+	require.NoError(t, err)
+}
+
+func TestSendStandbyStatusUpdate(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	defer cancel()
+
+	conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING"))
+	require.NoError(t, err)
+	defer closeConn(t, conn)
+
+	sysident, err := pglogrepl.IdentifySystem(ctx, conn)
+	require.NoError(t, err)
+
+	err = pglogrepl.SendStandbyStatusUpdate(ctx, conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: sysident.XLogPos})
+	require.NoError(t, err)
+}
diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go
new file mode 100644
index 0000000000..d98c9f305b
--- /dev/null
+++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go
@@ -0,0 +1,101 @@
+// Copyright 2024 Redpanda Data, Inc.
+//
+// Licensed as a Redpanda Enterprise file under the Redpanda Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md
+
+package pglogicalstream
+
+import (
+	"database/sql"
+	"fmt"
+	"log"
+
+	"github.com/jackc/pgx/v5/pgconn"
+	_ "github.com/lib/pq"
+)
+
+// Snapshotter reads an exported replication-slot snapshot over a separate
+// database/sql connection so tables can be bulk-read consistently.
+type Snapshotter struct {
+	pgConnection *sql.DB
+	snapshotName string
+}
+
+// NewSnapshotter opens a lib/pq connection built from dbConf. sslmode is
+// "require" when a TLS config is present, "disable" otherwise.
+func NewSnapshotter(dbConf pgconn.Config, snapshotName string) (*Snapshotter, error) {
+	var sslMode string
+	if dbConf.TLSConfig != nil {
+		sslMode = "require"
+	} else {
+		sslMode = "disable"
+	}
+	connStr := fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s sslmode=%s", dbConf.User,
+		dbConf.Password, dbConf.Host, dbConf.Port, dbConf.Database, sslMode,
+	)
+
+	pgConn, err := sql.Open("postgres", connStr)
+
+	return &Snapshotter{
+		pgConnection: pgConn,
+		snapshotName: snapshotName,
+	}, err
+}
+
+// Prepare opens a REPEATABLE READ transaction pinned to the exported snapshot.
+func (s *Snapshotter) Prepare() error {
+	if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil {
+		return err
+	}
+	if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// FindAvgRowSize estimates the average row size in bytes for table.
+func (s *Snapshotter) FindAvgRowSize(table string) sql.NullInt64 {
+	var avgRowSize
sql.NullInt64 + + if rows, err := s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { + log.Fatal("Can get avg row size", err) + } else { + if rows.Next() { + if err = rows.Scan(&avgRowSize); err != nil { + log.Fatal("Can get avg row size", err) + } + } else { + log.Fatal("Can get avg row size; 0 rows returned") + } + } + + return avgRowSize +} + +func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSize uint64) int { + // Adjust this factor based on your system's memory constraints. + // This example uses a safety factor of 0.8 to leave some memory headroom. + safetyFactor := 0.6 + batchSize := int(float64(availableMemory) * safetyFactor / float64(estimatedRowSize)) + if batchSize < 1 { + batchSize = 1 + } + return batchSize +} + +func (s *Snapshotter) QuerySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) { + // fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset) + log.WithPrefix("[pg-stream/snapshotter]").Info("Query snapshot", "table", table, "limit", limit, "offset", offset, "pk", pk) + return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset)) +} + +func (s *Snapshotter) ReleaseSnapshot() error { + _, err := s.pgConnection.Exec("COMMIT;") + return err +} + +func (s *Snapshotter) CloseConn() error { + if s.pgConnection != nil { + return s.pgConnection.Close() + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/types.go b/internal/impl/postgresql/pglogicalstream/types.go new file mode 100644 index 0000000000..7dd4670cbc --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/types.go @@ -0,0 +1,24 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +type Wal2JsonChanges struct { + Lsn *string `json:"lsn"` + Changes []Wal2JsonChange `json:"change"` +} + +type Wal2JsonChange struct { + Kind string `json:"kind"` + Schema string `json:"schema"` + Table string `json:"table"` + ColumnNames []string `json:"columnnames"` + ColumnTypes []string `json:"columntypes"` + ColumnValues []interface{} `json:"columnvalues"` +} +type OnMessage = func(message Wal2JsonChanges) diff --git a/internal/impl/postgresql/pglogicalstream/wal_changes_message.go b/internal/impl/postgresql/pglogicalstream/wal_changes_message.go new file mode 100644 index 0000000000..de5c1255e2 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/wal_changes_message.go @@ -0,0 +1,25 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +type WallMessage struct { + Change []struct { + Kind string `json:"kind"` + Schema string `json:"schema"` + Table string `json:"table"` + Columnnames []string `json:"columnnames"` + Columntypes []string `json:"columntypes"` + Columnvalues []interface{} `json:"columnvalues"` + Oldkeys struct { + Keynames []string `json:"keynames"` + Keytypes []string `json:"keytypes"` + Keyvalues []interface{} `json:"keyvalues"` + } `json:"oldkeys"` + } `json:"change"` +} From 99164a2e55e8376f709f0d3aa50807143c8e1882 Mon Sep 17 00:00:00 2001 From: Ashley Jeffs Date: Mon, 30 Sep 2024 10:12:24 +0100 Subject: [PATCH 002/118] Add placeholders for logging and TODOs on panics --- .../pg_stream/pg_stream/integration_test.go | 3 +- .../pg_stream/pg_stream/pg_stream.go | 1 + .../pglogicalstream/example/simple/main.go | 3 +- .../pglogicalstream/example/ws/main.go | 3 +- .../pglogicalstream/logical_stream.go | 53 ++++++++++--------- .../postgresql/pglogicalstream/snapshotter.go | 12 +++-- 6 files changed, 40 insertions(+), 35 deletions(-) diff --git a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go index 02f53a7bbb..4678af70fe 100644 --- a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go @@ -12,7 +12,6 @@ import ( "context" "database/sql" "fmt" - "log" "strings" "sync" "testing" @@ -96,7 +95,7 @@ func TestIntegrationPgCDC(t *testing.T) { return err }); err != nil { - log.Fatalf("Could not connect to docker: %s", err) + panic(fmt.Errorf("could not connect to docker: %w", err)) } fake := faker.New() diff --git a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go index fac92c400c..c321e0aecb 100644 --- 
a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go @@ -18,6 +18,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/lucasepe/codename" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" ) diff --git a/internal/impl/postgresql/pglogicalstream/example/simple/main.go b/internal/impl/postgresql/pglogicalstream/example/simple/main.go index b0c4ad01a9..1008f4a4a3 100644 --- a/internal/impl/postgresql/pglogicalstream/example/simple/main.go +++ b/internal/impl/postgresql/pglogicalstream/example/simple/main.go @@ -13,8 +13,9 @@ import ( "io/ioutil" "log" - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" "gopkg.in/yaml.v3" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" ) func main() { diff --git a/internal/impl/postgresql/pglogicalstream/example/ws/main.go b/internal/impl/postgresql/pglogicalstream/example/ws/main.go index 4ce2d9d067..53e0821162 100644 --- a/internal/impl/postgresql/pglogicalstream/example/ws/main.go +++ b/internal/impl/postgresql/pglogicalstream/example/ws/main.go @@ -13,8 +13,9 @@ import ( "log" "github.com/gorilla/websocket" - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" "gopkg.in/yaml.v3" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" ) func main() { diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index b15e9be282..1a35a521d1 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -16,7 +16,6 @@ import ( "encoding/json" "errors" "fmt" - "log" "os" "strings" "sync" @@ -25,6 +24,8 @@ import ( "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + 
"github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/internal/helpers" ) @@ -53,7 +54,7 @@ type Stream struct { separateChanges bool snapshotBatchSize int snapshotMemorySafetyFactor float64 - logger *log.Logger + logger *service.Logger m sync.Mutex stopped bool @@ -112,7 +113,7 @@ func NewPgStream(config Config) (*Stream, error) { snapshotBatchSize: config.BatchSize, tableNames: tableNames, changeFilter: NewChangeFilter(tableNames, config.DbSchema), - logger: log.WithPrefix("[pg-stream]"), + logger: nil, // TODO m: sync.Mutex{}, stopped: false, } @@ -132,23 +133,23 @@ func NewPgStream(config Config) (*Stream, error) { result = stream.pgConn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) _, err = result.ReadAll() if err != nil { - stream.logger.Fatalf("create publication error %s", err.Error()) + panic(fmt.Errorf("create publication error %w", err)) // TODO } - stream.logger.Info("Created Postgresql publication", "publication_name", config.ReplicationSlotName) + stream.logger.Infof("Created Postgresql publication %v %v", "publication_name", config.ReplicationSlotName) sysident, err := pglogrepl.IdentifySystem(context.Background(), stream.pgConn) if err != nil { - stream.logger.Fatalf("Failed to identify the system %s", err.Error()) + panic(fmt.Errorf("failed to identify the system %w", err)) // TODO } - stream.logger.Info("System identification result", "SystemID:", sysident.SystemID, "Timeline:", sysident.Timeline, "XLogPos:", sysident.XLogPos, "DBName:", sysident.DBName) + stream.logger.Infof("System identification result SystemID: %v Timeline: %v XLogPos: %v DBName: %v", sysident.SystemID, sysident.Timeline, sysident.XLogPos, sysident.DBName) var freshlyCreatedSlot = false var confirmedLSNFromDB string // check is replication slot exist to get last restart SLN connExecResult := 
stream.pgConn.Exec(context.TODO(), fmt.Sprintf("SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = '%s'", config.ReplicationSlotName)) if slotCheckResults, err := connExecResult.ReadAll(); err != nil { - stream.logger.Fatal(err) + panic(err) // TODO } else { if len(slotCheckResults) == 0 || len(slotCheckResults[0].Rows) == 0 { // here we create a new replication slot because there is no slot found @@ -158,14 +159,14 @@ func NewPgStream(config Config) (*Stream, error) { SnapshotAction: "export", }) if err != nil { - stream.logger.Fatalf("Failed to create replication slot for the database: %s", err.Error()) + panic(fmt.Errorf("failed to create replication slot for the database: %w", err)) // TODO } stream.snapshotName = createSlotResult.SnapshotName freshlyCreatedSlot = true } else { slotCheckRow := slotCheckResults[0].Rows[0] confirmedLSNFromDB = string(slotCheckRow[0]) - stream.logger.Info("Replication slot restart LSN extracted from DB", "LSN", confirmedLSNFromDB) + stream.logger.Infof("Replication slot restart LSN extracted from DB: LSN %v", confirmedLSNFromDB) } } @@ -203,16 +204,16 @@ func (s *Stream) startLr() { var err error err = pglogrepl.StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}) if err != nil { - s.logger.Fatalf("Starting replication slot failed: %s", err.Error()) + panic(fmt.Errorf("starting replication slot failed: %w", err)) // TODO } - s.logger.Info("Started logical replication on slot", "slot-name", s.slotName) + s.logger.Infof("Started logical replication on slot slot-name: %v", s.slotName) } func (s *Stream) AckLSN(lsn string) { var err error s.clientXLogPos, err = pglogrepl.ParseLSN(lsn) if err != nil { - s.logger.Fatalf("Failed to parse LSN for Acknowledge %s", err.Error()) + panic(fmt.Errorf("failed to parse LSN for Acknowledge %w", err)) // TODO } err = pglogrepl.SendStandbyStatusUpdate(context.Background(), s.pgConn, 
pglogrepl.StandbyStatusUpdate{ @@ -222,7 +223,7 @@ func (s *Stream) AckLSN(lsn string) { }) if err != nil { - s.logger.Fatalf("SendStandbyStatusUpdate failed: %s", err.Error()) + panic(fmt.Errorf("sendStandbyStatusUpdate failed: %w", err)) // TODO } s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) @@ -242,7 +243,7 @@ func (s *Stream) streamMessagesAsync() { }) if err != nil { - s.logger.Fatalf("SendStandbyStatusUpdate failed: %s", err.Error()) + panic(fmt.Errorf("sendStandbyStatusUpdate failed: %w", err)) // TODO } s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) @@ -262,11 +263,11 @@ func (s *Stream) streamMessagesAsync() { continue } - s.logger.Fatalf("Failed to receive messages from PostgreSQL %s", err.Error()) + panic(fmt.Errorf("failed to receive messages from PostgreSQL %w", err)) // TODO } if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { - s.logger.Fatalf("Received broken Postgres WAL. Error: %+v", errMsg) + panic(fmt.Errorf("received broken Postgres WAL. 
Error: %+v", errMsg)) // TODO } msg, ok := rawMsg.(*pgproto3.CopyData) @@ -279,7 +280,7 @@ func (s *Stream) streamMessagesAsync() { case pglogrepl.PrimaryKeepaliveMessageByteID: pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { - s.logger.Fatalf("ParsePrimaryKeepaliveMessage failed: %s", err.Error()) + panic(fmt.Errorf("parsePrimaryKeepaliveMessage failed: %w", err)) // TODO } if pkm.ReplyRequested { @@ -289,7 +290,7 @@ func (s *Stream) streamMessagesAsync() { case pglogrepl.XLogDataByteID: xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) if err != nil { - s.logger.Fatalf("ParseXLogData failed: %s", err.Error()) + panic(fmt.Errorf("parseXLogData failed: %w", err)) // TODO } clientXLogPos := xld.WALStart + pglogrepl.LSN(len(xld.WALData)) var changes WallMessage @@ -326,7 +327,7 @@ func (s *Stream) processSnapshot() { }() for _, table := range s.tableNames { - s.logger.Info("Processing snapshot for table", "table", table) + s.logger.Infof("Processing snapshot for table: %v", table) var ( avgRowSizeBytes sql.NullInt64 @@ -335,7 +336,7 @@ func (s *Stream) processSnapshot() { avgRowSizeBytes = snapshotter.FindAvgRowSize(table) batchSize := snapshotter.CalculateBatchSize(helpers.GetAvailableMemory(), uint64(avgRowSizeBytes.Int64)) - s.logger.Info("Querying snapshot", "batch_side", batchSize, "available_memory", helpers.GetAvailableMemory(), "avg_row_size", avgRowSizeBytes.Int64) + s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, helpers.GetAvailableMemory(), avgRowSizeBytes.Int64) tablePk, err := s.getPrimaryKeyColumn(table) if err != nil { @@ -345,13 +346,13 @@ func (s *Stream) processSnapshot() { for { var snapshotRows *sql.Rows if snapshotRows, err = snapshotter.QuerySnapshotData(table, tablePk, batchSize, offset); err != nil { - log.Fatalf("Can't query snapshot data %v", err) + panic(fmt.Errorf("can't query snapshot data: %w", err)) // TODO } columnTypes, err := 
snapshotRows.ColumnTypes() var columnTypesString = make([]string, len(columnTypes)) columnNames, err := snapshotRows.Columns() - for i, _ := range columnNames { + for i := range columnNames { columnTypesString[i] = columnTypes[i].DatabaseTypeName() } @@ -387,7 +388,7 @@ func (s *Stream) processSnapshot() { } var columnValues = make([]interface{}, len(columnTypes)) - for i, _ := range columnTypes { + for i := range columnTypes { if z, ok := (scanArgs[i]).(*sql.NullBool); ok { columnValues[i] = z.Bool continue @@ -465,7 +466,7 @@ func (s *Stream) LrMessageC() chan Wal2JsonChanges { // cleanUpOnFailure drops replication slot and publication if database snapshotting was failed for any reason func (s *Stream) cleanUpOnFailure() { - s.logger.Warn("Cleaning up resources on accident.", "replication-slot", s.slotName) + s.logger.Warnf("Cleaning up resources on accident: %v", s.slotName) err := DropReplicationSlot(context.Background(), s.pgConn, s.slotName, DropReplicationSlotOptions{Wait: true}) if err != nil { s.logger.Errorf("Failed to drop replication slot: %s", err.Error()) @@ -478,7 +479,7 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid - AND a.attnum = ANY(i.indkey) + AND a.attnum = ANY(i.indkey) WHERE i.indrelid = '%s'::regclass AND i.indisprimary; `, tableName) diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index d98c9f305b..a4c73479f9 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -11,15 +11,16 @@ package pglogicalstream import ( "database/sql" "fmt" - "log" "github.com/jackc/pgx/v5/pgconn" _ "github.com/lib/pq" + "github.com/redpanda-data/benthos/v4/public/service" ) type Snapshotter struct { pgConnection *sql.DB snapshotName string + logger *service.Logger } func NewSnapshotter(dbConf 
pgconn.Config, snapshotName string) (*Snapshotter, error) { @@ -38,6 +39,7 @@ func NewSnapshotter(dbConf pgconn.Config, snapshotName string) (*Snapshotter, er return &Snapshotter{ pgConnection: pgConn, snapshotName: snapshotName, + logger: nil, // TODO }, err } @@ -56,14 +58,14 @@ func (s *Snapshotter) FindAvgRowSize(table string) sql.NullInt64 { var avgRowSize sql.NullInt64 if rows, err := s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { - log.Fatal("Can get avg row size", err) + panic(fmt.Errorf("can get avg row size: %w", err)) // TODO } else { if rows.Next() { if err = rows.Scan(&avgRowSize); err != nil { - log.Fatal("Can get avg row size", err) + panic(fmt.Errorf("can get avg row size: %w", err)) // TODO } } else { - log.Fatal("Can get avg row size; 0 rows returned") + panic("can get avg row size; 0 rows returned") // TODO } } @@ -83,7 +85,7 @@ func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSiz func (s *Snapshotter) QuerySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) { // fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset) - log.WithPrefix("[pg-stream/snapshotter]").Info("Query snapshot", "table", table, "limit", limit, "offset", offset, "pk", pk) + s.logger.Infof("Query snapshot table: %v, limit: %v, offset: %v, pk: %v", table, limit, offset, pk) return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset)) } From ffd356bcb789c726001bb6b12920ad47aa3dc5c2 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 1 Oct 2024 21:17:32 +0200 Subject: [PATCH 003/118] feat(pgstream): added support for pgoutput native plugin --- .../pg_stream/pg_stream/integration_test.go | 194 +++++++++++++++++- .../pg_stream/pg_stream/pg_stream.go | 31 ++- .../pg_stream_schemaless.go | 74 ------- .../pg_stream_schemaless/wal_message.go | 25 --- 
.../pglogicalstream/docker-compose.yaml | 11 - .../pglogicalstream/example/simple/main.go | 45 ---- .../pglogicalstream/example/ws/main.go | 55 ----- .../impl/postgresql/pglogicalstream/filter.go | 73 ------- .../internal/helpers/arrow_schema_builder.go | 44 ---- .../internal/schemas/schemas.go | 16 -- .../{message.go => replication_message.go} | 0 ...ge_test.go => replication_message_test.go} | 0 .../pglogicalstream/wal_changes_message.go | 25 --- 13 files changed, 213 insertions(+), 380 deletions(-) delete mode 100644 internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go delete mode 100644 internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go delete mode 100644 internal/impl/postgresql/pglogicalstream/docker-compose.yaml delete mode 100644 internal/impl/postgresql/pglogicalstream/example/simple/main.go delete mode 100644 internal/impl/postgresql/pglogicalstream/example/ws/main.go delete mode 100644 internal/impl/postgresql/pglogicalstream/filter.go delete mode 100644 internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go delete mode 100644 internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go rename internal/impl/postgresql/pglogicalstream/{message.go => replication_message.go} (100%) rename internal/impl/postgresql/pglogicalstream/{message_test.go => replication_message_test.go} (100%) delete mode 100644 internal/impl/postgresql/pglogicalstream/wal_changes_message.go diff --git a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go index 4678af70fe..bec6d7e0f9 100644 --- a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go @@ -6,7 +6,7 @@ // // https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md -package pg_stream +package pgstream import ( "context" @@ -30,7 +30,6 @@ import ( ) func 
TestIntegrationPgCDC(t *testing.T) { - t.Parallel() tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -63,13 +62,13 @@ func TestIntegrationPgCDC(t *testing.T) { hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseUrl := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) var db *sql.DB pool.MaxWait = 120 * time.Second if err = pool.Retry(func() error { - if db, err = sql.Open("postgres", databaseUrl); err != nil { + if db, err = sql.Open("postgres", databaseURL); err != nil { return err } @@ -92,6 +91,12 @@ func TestIntegrationPgCDC(t *testing.T) { } _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + if err != nil { + return err + } + + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") return err }); err != nil { @@ -101,6 +106,8 @@ func TestIntegrationPgCDC(t *testing.T) { fake := faker.New() for i := 0; i < 1000; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -157,6 +164,7 @@ file: for i := 0; i < 1000; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, 
created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -205,6 +213,184 @@ file: return len(outMessages) == 50 }, time.Second*20, time.Millisecond*100) + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉", len(outMessages)) + + t.Cleanup(func() { + db.Close() + }) +} + +func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + ExposedPorts: []string{"5432"}, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, pool.Purge(resource)) + }) + + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + var walLevel string + if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { + return err + } + + if walLevel != "logical" { + return fmt.Errorf("wal_level is not logical") + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + + return err + }); err != nil { + panic(fmt.Errorf("could not connect to docker: %w", 
err)) + } + + fake := faker.New() + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot_native_decoder + user: user_name + password: secret + port: %s + schema: public + tls: none + stream_snapshot: true + decoding_plugin: pgoutput + database: dbname + tables: + - flights +`, hostAndPortSplited[0], hostAndPortSplited[1]) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 10 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 20 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication 
slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 10 + }, time.Second*20, time.Millisecond*100) + require.NoError(t, streamOut.StopWithin(time.Second*10)) t.Log("All the conditions are met 🎉") diff --git a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go index c321e0aecb..8a3dd20408 100644 --- a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go @@ -6,7 +6,7 @@ // // https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md -package pg_stream +package pgstream import ( "context" @@ -55,6 +55,9 @@ var pgStreamConfigSpec = service.NewConfigSpec(). Description("Sets amout of memory that can be used to stream snapshot. If affects batch sizes. If we want to use only 25% of the memory available - put 0.25 factor. 
It will make initial streaming slower, but it will prevent your worker from OOM Kill"). Example(0.2). Default(0.5)). + Field(service.NewStringEnumField("decoding_plugin", "pgoutput", "wal2json").Description("Specifies which decoding plugin to use when streaming data from PostgreSQL"). + Example("pgoutput"). + Default("pgoutput")). Field(service.NewStringListField("tables"). Example(` - my_table @@ -66,7 +69,7 @@ var pgStreamConfigSpec = service.NewConfigSpec(). Example("my_test_slot"). Default(randomSlotName)) -func newPgStreamInput(conf *service.ParsedConfig) (s service.Input, err error) { +func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s service.Input, err error) { var ( dbName string dbPort int @@ -79,6 +82,7 @@ func newPgStreamInput(conf *service.ParsedConfig) (s service.Input, err error) { tables []string streamSnapshot bool snapshotMemSafetyFactor float64 + decodingPlugin string ) dbSchema, err = conf.FieldString("schema") @@ -135,6 +139,11 @@ func newPgStreamInput(conf *service.ParsedConfig) (s service.Input, err error) { return nil, err } + decodingPlugin, err = conf.FieldString("decoding_plugin") + if err != nil { + return nil, err + } + snapshotMemSafetyFactor, err = conf.FieldFloat("snapshot_memory_safety_factor") if err != nil { return nil, err @@ -163,17 +172,19 @@ func newPgStreamInput(conf *service.ParsedConfig) (s service.Input, err error) { schema: dbSchema, tls: pglogicalstream.TlsVerify(tlsSetting), tables: tables, + decodingPlugin: decodingPlugin, + logger: logger, }), err } func init() { rng, _ := codename.DefaultRNG() - randomSlotName = fmt.Sprintf("%s", strings.ReplaceAll(codename.Generate(rng, 5), "-", "_")) + randomSlotName = strings.ReplaceAll(codename.Generate(rng, 5), "-", "_") err := service.RegisterInput( "pg_stream", pgStreamConfigSpec, func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) { - return newPgStreamInput(conf) + return newPgStreamInput(conf, mgr.Logger()) }) if err != 
nil { panic(err) @@ -183,10 +194,10 @@ func init() { type pgStreamInput struct { dbConfig pgconn.Config pglogicalStream *pglogicalstream.Stream - redisUri string slotName string schema string tables []string + decodingPlugin string streamSnapshot bool tls pglogicalstream.TlsVerify // none, require snapshotMemSafetyFactor float64 @@ -202,14 +213,15 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { DbTables: p.tables, DbName: p.dbConfig.Database, DbSchema: p.schema, - ReplicationSlotName: fmt.Sprintf("rs_%s", p.slotName), + ReplicationSlotName: "rs_" + p.slotName, TlsVerify: p.tls, StreamOldData: p.streamSnapshot, + DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, SeparateChanges: true, }) if err != nil { - panic(err) + return err } p.pglogicalStream = pgStream return err @@ -242,7 +254,10 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack //message.ServerHeartbeat. if message.Lsn != nil { - p.pglogicalStream.AckLSN(*message.Lsn) + if err := p.pglogicalStream.AckLSN(*message.Lsn); err != nil { + fmt.Println("Error while acking LSN", err) + return err + } } return nil }, nil diff --git a/internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go b/internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go deleted file mode 100644 index c03a94fc74..0000000000 --- a/internal/impl/postgresql/pg_stream/pg_stream_schemaless/pg_stream_schemaless.go +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pg_stream_schemaless - -import ( - "bytes" - "context" - - "encoding/json" - - "github.com/redpanda-data/benthos/v4/public/service" -) - -func init() { - // Config spec is empty for now as we don't have any dynamic fields. - configSpec := service.NewConfigSpec() - - constructor := func(conf *service.ParsedConfig, mgr *service.Resources) (service.Processor, error) { - return newPgSchematicProcessor(mgr.Logger(), mgr.Metrics()), nil - } - - err := service.RegisterProcessor("pg_stream_schemaless", configSpec, constructor) - if err != nil { - panic(err) - } -} - -type pgSchematicProcessor struct { -} - -func newPgSchematicProcessor(logger *service.Logger, metrics *service.Metrics) *pgSchematicProcessor { - // The logger and metrics components will already be labelled with the - // identifier of this component within a config. - return &pgSchematicProcessor{} -} - -func (r *pgSchematicProcessor) Process(ctx context.Context, m *service.Message) (service.MessageBatch, error) { - bytesContent, err := m.AsBytes() - if err != nil { - return nil, err - } - var message WalMessage - if err = json.NewDecoder(bytes.NewReader(bytesContent)).Decode(&message); err != nil { - return nil, err - } - - var messageAsSchema = map[string]interface{}{} - if len(message.Change) == 0 { - return nil, nil - } - - for _, change := range message.Change { - for i, k := range change.Columnnames { - messageAsSchema[k] = change.Columnvalues[i] - } - } - var newBytes []byte - if newBytes, err = json.Marshal(&messageAsSchema); err != nil { - return nil, err - } - - m.SetBytes(newBytes) - return []*service.Message{m}, nil -} - -func (r *pgSchematicProcessor) Close(ctx context.Context) error { - return nil -} diff --git a/internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go b/internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go deleted file 
mode 100644 index 2477ce4538..0000000000 --- a/internal/impl/postgresql/pg_stream/pg_stream_schemaless/wal_message.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pg_stream_schemaless - -type WalMessage struct { - Change []struct { - Kind string `json:"kind"` - Schema string `json:"schema"` - Table string `json:"table"` - Columnnames []string `json:"columnnames"` - Columntypes []string `json:"columntypes"` - Columnvalues []interface{} `json:"columnvalues"` - Oldkeys struct { - Keynames []string `json:"keynames"` - Keytypes []string `json:"keytypes"` - Keyvalues []interface{} `json:"keyvalues"` - } `json:"oldkeys"` - } `json:"change"` -} diff --git a/internal/impl/postgresql/pglogicalstream/docker-compose.yaml b/internal/impl/postgresql/pglogicalstream/docker-compose.yaml deleted file mode 100644 index 3ff4981ee8..0000000000 --- a/internal/impl/postgresql/pglogicalstream/docker-compose.yaml +++ /dev/null @@ -1,11 +0,0 @@ -services: - postgres: - image: postgres:${POSTGRES_VERSION:-15} - restart: always - command: ["-c", "wal_level=logical", "-c", "max_wal_senders=10", "-c", "max_replication_slots=10"] - environment: - POSTGRES_USER: ${POSTGRES_USER:-pglogrepl} - POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-secret} - POSTGRES_DB: ${POSTGRES_DB:-pglogrepl} - POSTGRES_HOST_AUTH_METHOD: trust - network_mode: "host" \ No newline at end of file diff --git a/internal/impl/postgresql/pglogicalstream/example/simple/main.go b/internal/impl/postgresql/pglogicalstream/example/simple/main.go deleted file mode 100644 index 1008f4a4a3..0000000000 --- a/internal/impl/postgresql/pglogicalstream/example/simple/main.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2024 
Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package main - -import ( - "fmt" - "io/ioutil" - "log" - - "gopkg.in/yaml.v3" - - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" -) - -func main() { - var config pglogicalstream.Config - yamlFile, err := ioutil.ReadFile("./config.yaml") - if err != nil { - log.Printf("yamlFile.Get err #%v ", err) - } - - err = yaml.Unmarshal(yamlFile, &config) - if err != nil { - log.Fatalf("Unmarshal: %v", err) - } - - pgStream, err := pglogicalstream.NewPgStream(config) - if err != nil { - panic(err) - } - - pgStream.OnMessage(func(message pglogicalstream.Wal2JsonChanges) { - fmt.Println(message.Changes) - if message.Lsn != nil { - // Snapshots dont have LSN - pgStream.AckLSN(*message.Lsn) - } - }) -} diff --git a/internal/impl/postgresql/pglogicalstream/example/ws/main.go b/internal/impl/postgresql/pglogicalstream/example/ws/main.go deleted file mode 100644 index 53e0821162..0000000000 --- a/internal/impl/postgresql/pglogicalstream/example/ws/main.go +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package main - -import ( - "io/ioutil" - "log" - - "github.com/gorilla/websocket" - "gopkg.in/yaml.v3" - - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" -) - -func main() { - var config pglogicalstream.Config - yamlFile, err := ioutil.ReadFile("./example/simple/config.yaml") - if err != nil { - log.Printf("yamlFile.Get err #%v ", err) - } - - err = yaml.Unmarshal(yamlFile, &config) - if err != nil { - log.Fatalf("Unmarshal: %v", err) - } - - pgStream, err := pglogicalstream.NewPgStream(config) - if err != nil { - panic(err) - } - - wsClient, _, err := websocket.DefaultDialer.Dial("ws://localhost:10000/ws", nil) - if err != nil { - panic(err) - } - defer wsClient.Close() - - pgStream.OnMessage(func(message pglogicalstream.Wal2JsonChanges) { - marshaledChanges, err := message.Changes[0].Row.MarshalJSON() - if err != nil { - panic(err) - } - - err = wsClient.WriteMessage(websocket.TextMessage, marshaledChanges) - if err != nil { - log.Fatalf("write: %v", err) - } - }) -} diff --git a/internal/impl/postgresql/pglogicalstream/filter.go b/internal/impl/postgresql/pglogicalstream/filter.go deleted file mode 100644 index 6841b71554..0000000000 --- a/internal/impl/postgresql/pglogicalstream/filter.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pglogicalstream - -type ChangeFilter struct { - tablesWhiteList map[string]bool - schemaWhiteList string -} - -type Filtered func(change Wal2JsonChanges) - -func NewChangeFilter(tableSchemas []string, schema string) ChangeFilter { - tablesMap := map[string]bool{} - for _, table := range tableSchemas { - tablesMap[table] = true - } - - return ChangeFilter{ - tablesWhiteList: tablesMap, - schemaWhiteList: schema, - } -} - -func (c ChangeFilter) FilterChange(lsn string, changes WallMessage, OnFiltered Filtered) { - if len(changes.Change) == 0 { - return - } - - for _, ch := range changes.Change { - var filteredChanges = Wal2JsonChanges{ - Lsn: &lsn, - Changes: []Wal2JsonChange{}, - } - if ch.Schema != c.schemaWhiteList { - continue - } - - var ( - tableExist bool - ) - - if _, tableExist = c.tablesWhiteList[ch.Table]; !tableExist { - continue - } - - if ch.Kind == "delete" { - ch.Columnvalues = make([]interface{}, len(ch.Oldkeys.Keyvalues)) - for i, changedValue := range ch.Oldkeys.Keyvalues { - if len(ch.Columnvalues) == 0 { - break - } - ch.Columnvalues[i] = changedValue - } - } - - filteredChanges.Changes = append(filteredChanges.Changes, Wal2JsonChange{ - Kind: ch.Kind, - Schema: ch.Schema, - Table: ch.Table, - ColumnNames: ch.Columnnames, - ColumnTypes: ch.Columntypes, - ColumnValues: ch.Columnvalues, - }) - - OnFiltered(filteredChanges) - } -} diff --git a/internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go b/internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go deleted file mode 100644 index c8b5912311..0000000000 --- a/internal/impl/postgresql/pglogicalstream/internal/helpers/arrow_schema_builder.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. 
-// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package helpers - -import "github.com/apache/arrow/go/v14/arrow" - -func MapPlainTypeToArrow(fieldType string) arrow.DataType { - switch fieldType { - case "Boolean": - return arrow.FixedWidthTypes.Boolean - case "Int16": - return arrow.PrimitiveTypes.Int16 - case "Int32": - return arrow.PrimitiveTypes.Int32 - case "Int64": - return arrow.PrimitiveTypes.Int64 - case "Uint64": - return arrow.PrimitiveTypes.Uint64 - case "Float64": - return arrow.PrimitiveTypes.Float64 - case "Float32": - return arrow.PrimitiveTypes.Float32 - case "UUID": - return arrow.BinaryTypes.String - case "bytea": - return arrow.BinaryTypes.Binary - case "JSON": - return arrow.BinaryTypes.String - case "Inet": - return arrow.BinaryTypes.String - case "MAC": - return arrow.BinaryTypes.String - case "Date32": - return arrow.FixedWidthTypes.Date32 - default: - return arrow.BinaryTypes.String - } -} diff --git a/internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go b/internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go deleted file mode 100644 index d1611c3fb2..0000000000 --- a/internal/impl/postgresql/pglogicalstream/internal/schemas/schemas.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package schemas - -import "github.com/apache/arrow/go/v14/arrow" - -type DataTableSchema struct { - TableName string - Schema *arrow.Schema -} diff --git a/internal/impl/postgresql/pglogicalstream/message.go b/internal/impl/postgresql/pglogicalstream/replication_message.go similarity index 100% rename from internal/impl/postgresql/pglogicalstream/message.go rename to internal/impl/postgresql/pglogicalstream/replication_message.go diff --git a/internal/impl/postgresql/pglogicalstream/message_test.go b/internal/impl/postgresql/pglogicalstream/replication_message_test.go similarity index 100% rename from internal/impl/postgresql/pglogicalstream/message_test.go rename to internal/impl/postgresql/pglogicalstream/replication_message_test.go diff --git a/internal/impl/postgresql/pglogicalstream/wal_changes_message.go b/internal/impl/postgresql/pglogicalstream/wal_changes_message.go deleted file mode 100644 index de5c1255e2..0000000000 --- a/internal/impl/postgresql/pglogicalstream/wal_changes_message.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pglogicalstream - -type WallMessage struct { - Change []struct { - Kind string `json:"kind"` - Schema string `json:"schema"` - Table string `json:"table"` - Columnnames []string `json:"columnnames"` - Columntypes []string `json:"columntypes"` - Columnvalues []interface{} `json:"columnvalues"` - Oldkeys struct { - Keynames []string `json:"keynames"` - Keytypes []string `json:"keytypes"` - Keyvalues []interface{} `json:"keyvalues"` - } `json:"oldkeys"` - } `json:"change"` -} From 2a1b515e031f081ba12b0943ce1084bede0ed48e Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 1 Oct 2024 21:19:13 +0200 Subject: [PATCH 004/118] feat(pgstream): added support for pgoutput native plugin --- go.mod | 1 - go.sum | 2 - .../impl/postgresql/pglogicalstream/config.go | 35 +- .../impl/postgresql/pglogicalstream/consts.go | 28 ++ .../pglogicalstream/logical_stream.go | 327 ++++++++++++------ .../postgresql/pglogicalstream/pglogrepl.go | 283 --------------- .../pglogicalstream/pglogrepl_test.go | 181 ++-------- .../replication_message_decoders.go | 208 +++++++++++ .../postgresql/pglogicalstream/snapshotter.go | 24 +- .../pglogicalstream/stream_message.go | 14 + 10 files changed, 537 insertions(+), 566 deletions(-) create mode 100644 internal/impl/postgresql/pglogicalstream/consts.go create mode 100644 internal/impl/postgresql/pglogicalstream/replication_message_decoders.go create mode 100644 internal/impl/postgresql/pglogicalstream/stream_message.go diff --git a/go.mod b/go.mod index 370692cafb..6dab1d5b5a 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,6 @@ require ( github.com/Masterminds/squirrel v1.5.4 github.com/PaesslerAG/gval v1.2.2 github.com/PaesslerAG/jsonpath v0.1.1 - github.com/apache/arrow/go/v14 v14.0.2 github.com/apache/pulsar-client-go v0.13.1 github.com/aws/aws-lambda-go v1.47.0 github.com/aws/aws-sdk-go-v2 v1.30.4 diff --git a/go.sum 
b/go.sum index 1335d065bb..f193a3aa89 100644 --- a/go.sum +++ b/go.sum @@ -165,8 +165,6 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/apache/arrow/go/arrow v0.0.0-20200730104253-651201b0f516/go.mod h1:QNYViu/X0HXDHw7m3KXzWSVXIbfUvJqBFe6Gj8/pYA0= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= -github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO5IONw= -github.com/apache/arrow/go/v14 v14.0.2/go.mod h1:u3fgh3EdgN/YQ8cVQRguVW3R+seMybFg8QBQ5LU+eBY= github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= github.com/apache/arrow/go/v15 v15.0.2/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA= github.com/apache/pulsar-client-go v0.13.1 h1:XAAKXjF99du7LP6qu/nBII1HC2nS483/vQoQIWmm5Yg= diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index d9569427e0..446c276dc6 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -8,23 +8,24 @@ package pglogicalstream -type TlsVerify string - -const TlsNoVerify TlsVerify = "none" -const TlsRequireVerify TlsVerify = "require" +import "github.com/redpanda-data/benthos/v4/public/service" type Config struct { - DbHost string `yaml:"db_host"` - DbPassword string `yaml:"db_password"` - DbUser string `yaml:"db_user"` - DbPort int `yaml:"db_port"` - DbName string `yaml:"db_name"` - DbSchema string `yaml:"db_schema"` - DbTables []string `yaml:"db_tables"` - ReplicationSlotName string `yaml:"replication_slot_name"` - TlsVerify TlsVerify `yaml:"tls_verify"` - StreamOldData bool `yaml:"stream_old_data"` - SeparateChanges bool `yaml:"separate_changes"` - SnapshotMemorySafetyFactor float64 
`yaml:"snapshot_memory_safety_factor"` - BatchSize int `yaml:"batch_size"` + DbHost string `yaml:"db_host"` + DbPassword string `yaml:"db_password"` + DbUser string `yaml:"db_user"` + DbPort int `yaml:"db_port"` + DbName string `yaml:"db_name"` + DbSchema string `yaml:"db_schema"` + DbTables []string `yaml:"db_tables"` + TlsVerify TlsVerify `yaml:"tls_verify"` + + ReplicationSlotName string `yaml:"replication_slot_name"` + StreamOldData bool `yaml:"stream_old_data"` + SeparateChanges bool `yaml:"separate_changes"` + SnapshotMemorySafetyFactor float64 `yaml:"snapshot_memory_safety_factor"` + DecodingPlugin string `yaml:"decoding_plugin"` + BatchSize int `yaml:"batch_size"` + + logger *service.Logger } diff --git a/internal/impl/postgresql/pglogicalstream/consts.go b/internal/impl/postgresql/pglogicalstream/consts.go new file mode 100644 index 0000000000..544c152d51 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/consts.go @@ -0,0 +1,28 @@ +package pglogicalstream + +type DecodingPlugin string + +const ( + Wal2JSON DecodingPlugin = "wal2json" + PgOutput DecodingPlugin = "pgoutput" +) + +func DecodingPluginFromString(plugin string) DecodingPlugin { + switch plugin { + case "wal2json": + return Wal2JSON + case "pgoutput": + return PgOutput + default: + return PgOutput + } +} + +func (d DecodingPlugin) String() string { + return string(d) +} + +type TlsVerify string + +const TlsNoVerify TlsVerify = "none" +const TlsRequireVerify TlsVerify = "require" diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 1a35a521d1..8636e209ae 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -9,11 +9,9 @@ package pglogicalstream import ( - "bytes" "context" "crypto/tls" "database/sql" - "encoding/json" "errors" "fmt" "os" @@ -21,9 +19,9 @@ import ( "sync" "time" - "github.com/jackc/pglogrepl" 
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/internal/helpers" @@ -39,20 +37,22 @@ type Stream struct { streamCtx context.Context streamCancel context.CancelFunc - standbyCtxCancel context.CancelFunc - clientXLogPos pglogrepl.LSN + standbyCtxCancel context.CancelFunc + + clientXLogPos LSN + lsnrestart LSN + standbyMessageTimeout time.Duration nextStandbyMessageDeadline time.Time - messages chan Wal2JsonChanges - snapshotMessages chan Wal2JsonChanges + messages chan StreamMessage + snapshotMessages chan StreamMessage snapshotName string - changeFilter ChangeFilter - lsnrestart pglogrepl.LSN slotName string schema string tableNames []string separateChanges bool snapshotBatchSize int + decodingPlugin DecodingPlugin snapshotMemorySafetyFactor float64 logger *service.Logger @@ -104,24 +104,31 @@ func NewPgStream(config Config) (*Stream, error) { stream := &Stream{ pgConn: dbConn, dbConfig: *cfg, - messages: make(chan Wal2JsonChanges), - snapshotMessages: make(chan Wal2JsonChanges, 100), + messages: make(chan StreamMessage), + snapshotMessages: make(chan StreamMessage, 100), slotName: config.ReplicationSlotName, schema: config.DbSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, separateChanges: config.SeparateChanges, snapshotBatchSize: config.BatchSize, tableNames: tableNames, - changeFilter: NewChangeFilter(tableNames, config.DbSchema), - logger: nil, // TODO + logger: config.logger, m: sync.Mutex{}, - stopped: false, + decodingPlugin: DecodingPluginFromString(config.DecodingPlugin), + } + + if stream.decodingPlugin == "pgoutput" { + pluginArguments = []string{ + "proto_version '1'", + fmt.Sprintf("publication_names 'pglog_stream_%s'", config.ReplicationSlotName), + "messages 'true'", + } } result := stream.pgConn.Exec(context.Background(), fmt.Sprintf("DROP 
PUBLICATION IF EXISTS pglog_stream_%s;", config.ReplicationSlotName)) _, err = result.ReadAll() if err != nil { - stream.logger.Errorf("drop publication if exists error %s", err.Error()) + return nil, err } for i, table := range tableNames { @@ -133,13 +140,14 @@ func NewPgStream(config Config) (*Stream, error) { result = stream.pgConn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) _, err = result.ReadAll() if err != nil { - panic(fmt.Errorf("create publication error %w", err)) // TODO + return nil, err } + stream.logger.Infof("Created Postgresql publication %v %v", "publication_name", config.ReplicationSlotName) - sysident, err := pglogrepl.IdentifySystem(context.Background(), stream.pgConn) + sysident, err := IdentifySystem(context.Background(), stream.pgConn) if err != nil { - panic(fmt.Errorf("failed to identify the system %w", err)) // TODO + return nil, err } stream.logger.Infof("System identification result SystemID: %v Timeline: %v XLogPos: %v DBName: %v", sysident.SystemID, sysident.Timeline, sysident.XLogPos, sysident.DBName) @@ -149,17 +157,17 @@ func NewPgStream(config Config) (*Stream, error) { // check is replication slot exist to get last restart SLN connExecResult := stream.pgConn.Exec(context.TODO(), fmt.Sprintf("SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = '%s'", config.ReplicationSlotName)) if slotCheckResults, err := connExecResult.ReadAll(); err != nil { - panic(err) // TODO + return nil, err } else { if len(slotCheckResults) == 0 || len(slotCheckResults[0].Rows) == 0 { // here we create a new replication slot because there is no slot found var createSlotResult CreateReplicationSlotResult - createSlotResult, err = CreateReplicationSlot(context.Background(), stream.pgConn, stream.slotName, "wal2json", + createSlotResult, err = CreateReplicationSlot(context.Background(), stream.pgConn, stream.slotName, stream.decodingPlugin.String(), 
CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export", }) if err != nil { - panic(fmt.Errorf("failed to create replication slot for the database: %w", err)) // TODO + return nil, err } stream.snapshotName = createSlotResult.SnapshotName freshlyCreatedSlot = true @@ -170,11 +178,13 @@ func NewPgStream(config Config) (*Stream, error) { } } - var lsnrestart pglogrepl.LSN + // TODO:: check decoding plugin and replication slot plugin should match + + var lsnrestart LSN if freshlyCreatedSlot { lsnrestart = sysident.XLogPos } else { - lsnrestart, _ = pglogrepl.ParseLSN(confirmedLSNFromDB) + lsnrestart, _ = ParseLSN(confirmedLSNFromDB) } stream.lsnrestart = lsnrestart @@ -190,46 +200,60 @@ func NewPgStream(config Config) (*Stream, error) { stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) if !freshlyCreatedSlot || config.StreamOldData == false { - stream.startLr() + if err = stream.startLr(); err != nil { + return nil, err + } + go stream.streamMessagesAsync() } else { // New messages will be streamed after the snapshot has been processed. 
+ // stream.startLr() and stream.streamMessagesAsync() will be called inside stream.processSnapshot() go stream.processSnapshot() } return stream, err } -func (s *Stream) startLr() { +func (s *Stream) startLr() error { var err error - err = pglogrepl.StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, pglogrepl.StartReplicationOptions{PluginArgs: pluginArguments}) + err = StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: pluginArguments}) if err != nil { - panic(fmt.Errorf("starting replication slot failed: %w", err)) // TODO + return err } s.logger.Infof("Started logical replication on slot slot-name: %v", s.slotName) + + return nil } -func (s *Stream) AckLSN(lsn string) { +func (s *Stream) AckLSN(lsn string) error { var err error - s.clientXLogPos, err = pglogrepl.ParseLSN(lsn) + s.clientXLogPos, err = ParseLSN(lsn) if err != nil { panic(fmt.Errorf("failed to parse LSN for Acknowledge %w", err)) // TODO } - err = pglogrepl.SendStandbyStatusUpdate(context.Background(), s.pgConn, pglogrepl.StandbyStatusUpdate{ + err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ WALApplyPosition: s.clientXLogPos, WALWritePosition: s.clientXLogPos, ReplyRequested: true, }) if err != nil { - panic(fmt.Errorf("sendStandbyStatusUpdate failed: %w", err)) // TODO + s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", s.clientXLogPos.String(), err) + return err } + s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + + return nil } func (s *Stream) streamMessagesAsync() { + relations := map[uint32]*RelationMessage{} + typeMap := pgtype.NewMap() + pgoutputChanges := []StreamMessageChanges{} + for { select { case <-s.streamCtx.Done(): @@ -237,13 +261,19 @@ func (s *Stream) streamMessagesAsync() { return default: if 
time.Now().After(s.nextStandbyMessageDeadline) { - var err error - err = pglogrepl.SendStandbyStatusUpdate(context.Background(), s.pgConn, pglogrepl.StandbyStatusUpdate{ + if s.pgConn.IsClosed() { + s.logger.Warn("Postgres connection is closed...stop reading from replication slot") + return + } + + err := SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ WALWritePosition: s.clientXLogPos, }) if err != nil { - panic(fmt.Errorf("sendStandbyStatusUpdate failed: %w", err)) // TODO + s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", s.clientXLogPos.String(), err) + s.Stop() + return } s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) @@ -254,7 +284,7 @@ func (s *Stream) streamMessagesAsync() { s.standbyCtxCancel = cancel if err != nil && (errors.Is(err, context.Canceled) || s.stopped) { - s.logger.Warn("Service was interrpupted....stop reading from replication slot") + s.logger.Warn("Service was interrupted....stop reading from replication slot") return } @@ -263,11 +293,15 @@ func (s *Stream) streamMessagesAsync() { continue } - panic(fmt.Errorf("failed to receive messages from PostgreSQL %w", err)) // TODO + s.logger.Errorf("Failed to receive messages from PostgreSQL: %v", err) + s.Stop() + return } if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { - panic(fmt.Errorf("received broken Postgres WAL. 
Error: %+v", errMsg)) // TODO + s.logger.Errorf("Received error message from Postgres: %v", errMsg) + s.Stop() + return } msg, ok := rawMsg.(*pgproto3.CopyData) @@ -277,8 +311,8 @@ func (s *Stream) streamMessagesAsync() { } switch msg.Data[0] { - case pglogrepl.PrimaryKeepaliveMessageByteID: - pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(msg.Data[1:]) + case PrimaryKeepaliveMessageByteID: + pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { panic(fmt.Errorf("parsePrimaryKeepaliveMessage failed: %w", err)) // TODO } @@ -287,43 +321,117 @@ func (s *Stream) streamMessagesAsync() { s.nextStandbyMessageDeadline = time.Time{} } - case pglogrepl.XLogDataByteID: - xld, err := pglogrepl.ParseXLogData(msg.Data[1:]) + case XLogDataByteID: + xld, err := ParseXLogData(msg.Data[1:]) if err != nil { - panic(fmt.Errorf("parseXLogData failed: %w", err)) // TODO + panic(fmt.Errorf("parseXLogData failed: %w", err)) } - clientXLogPos := xld.WALStart + pglogrepl.LSN(len(xld.WALData)) - var changes WallMessage - if err := json.NewDecoder(bytes.NewReader(xld.WALData)).Decode(&changes); err != nil { - panic(fmt.Errorf("cant parse change from database to filter it %v", err)) + clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) + + if s.decodingPlugin == "wal2json" { + message, err := DecodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) + if err != nil { + s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) + s.Stop() + return + } + + if message == nil || len(message.Changes) == 0 { + // automatic ack for empty changes + // basically mean that the client is up-to-date, + // but we still need to acknowledge the LSN for standby + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + s.Stop() + return + } + } else { + s.messages <- *message + } } - if len(changes.Change) == 0 { - s.AckLSN(clientXLogPos.String()) - } else { - 
s.changeFilter.FilterChange(clientXLogPos.String(), changes, func(change Wal2JsonChanges) { - s.messages <- change - }) + if s.decodingPlugin == "pgoutput" { + // message changes must be collected in the buffer in the context of the same transaction + // as single transaction can contain multiple changes + // and LSN ack will cause potential loss of changes + isBegin, err := IsBeginMessage(xld.WALData) + if err != nil { + s.logger.Errorf("Failed to parse WAL data: %w", err) + s.Stop() + return + } + + if isBegin { + pgoutputChanges = []StreamMessageChanges{} + } + + // parse changes inside the transaction + message, err := DecodePgOutput(xld.WALData, relations, typeMap) + if err != nil { + s.logger.Errorf("decodePgOutput failed: %w", err) + s.Stop() + return + } + + if message != nil { + pgoutputChanges = append(pgoutputChanges, *message) + } + + isCommit, err := IsCommitMessage(xld.WALData) + if err != nil { + s.logger.Errorf("Failed to parse WAL data: %w", err) + s.Stop() + return + } + + if isCommit { + if len(pgoutputChanges) == 0 { + // 0 changes happened in the transaction + // or we received a change that are not supported/needed by the replication stream + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + s.Stop() + return + } + } else { + // send all collected changes + lsn := clientXLogPos.String() + s.messages <- StreamMessage{Lsn: &lsn, Changes: pgoutputChanges} + } + } } } } } } + func (s *Stream) processSnapshot() { - snapshotter, err := NewSnapshotter(s.dbConfig, s.snapshotName) + snapshotter, err := NewSnapshotter(s.dbConfig, s.snapshotName, s.logger) if err != nil { - s.logger.Errorf("Failed to create database snapshot: %v", err.Error()) - s.cleanUpOnFailure() + s.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", 
err.Error()) + } + os.Exit(1) } if err = snapshotter.Prepare(); err != nil { - s.logger.Errorf("Failed to prepare database snapshot: %v", err.Error()) - s.cleanUpOnFailure() + s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + os.Exit(1) } defer func() { - snapshotter.ReleaseSnapshot() - snapshotter.CloseConn() + if err = snapshotter.ReleaseSnapshot(); err != nil { + s.logger.Errorf("Failed to release database snapshot: %v", err.Error()) + } + if err = snapshotter.CloseConn(); err != nil { + s.logger.Errorf("Failed to close database connection: %v", err.Error()) + } }() for _, table := range s.tableNames { @@ -331,16 +439,30 @@ func (s *Stream) processSnapshot() { var ( avgRowSizeBytes sql.NullInt64 - offset = int(0) + offset = 0 ) - avgRowSizeBytes = snapshotter.FindAvgRowSize(table) + + avgRowSizeBytes, err = snapshotter.FindAvgRowSize(table) + if err != nil { + s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) + } batchSize := snapshotter.CalculateBatchSize(helpers.GetAvailableMemory(), uint64(avgRowSizeBytes.Int64)) s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, helpers.GetAvailableMemory(), avgRowSizeBytes.Int64) tablePk, err := s.getPrimaryKeyColumn(table) if err != nil { - panic(err) + s.logger.Errorf("Failed to get primary key column for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) } for { @@ -350,14 +472,26 @@ func (s *Stream) processSnapshot() { } columnTypes, err := snapshotRows.ColumnTypes() + if err != 
nil { + s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + os.Exit(1) + } + var columnTypesString = make([]string, len(columnTypes)) columnNames, err := snapshotRows.Columns() - for i := range columnNames { - columnTypesString[i] = columnTypes[i].DatabaseTypeName() + if err != nil { + s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + os.Exit(1) } - if err != nil { - panic(err) + for i := range columnNames { + columnTypesString[i] = columnTypes[i].DatabaseTypeName() } count := len(columnTypes) @@ -384,7 +518,11 @@ func (s *Stream) processSnapshot() { err := snapshotRows.Scan(scanArgs...) if err != nil { - panic(err) + s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + os.Exit(1) } var columnValues = make([]interface{}, len(columnTypes)) @@ -413,18 +551,23 @@ func (s *Stream) processSnapshot() { columnValues[i] = scanArgs[i] } - var snapshotChanges []Wal2JsonChange - snapshotChanges = append(snapshotChanges, Wal2JsonChange{ - Kind: "insert", - Schema: s.schema, - Table: table, - ColumnNames: columnNames, - ColumnValues: columnValues, - }) - var lsn *string - snapshotChangePacket := Wal2JsonChanges{ - Lsn: lsn, - Changes: snapshotChanges, + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: table, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range columnNames { + data[cn] = columnValues[i] + } + + return data + }(), + }, + }, } s.snapshotMessages <- snapshotChangePacket 
@@ -439,39 +582,29 @@ func (s *Stream) processSnapshot() { } - s.startLr() - go s.streamMessagesAsync() -} - -func (s *Stream) OnMessage(callback OnMessage) { - for { - select { - case snapshotMessage := <-s.snapshotMessages: - callback(snapshotMessage) - case message := <-s.messages: - callback(message) - case <-s.streamCtx.Done(): - return - } + if err = s.startLr(); err != nil { + s.logger.Errorf("Failed to start logical replication after snapshot: %v", err.Error()) + os.Exit(1) } + go s.streamMessagesAsync() } -func (s *Stream) SnapshotMessageC() chan Wal2JsonChanges { +func (s *Stream) SnapshotMessageC() chan StreamMessage { return s.snapshotMessages } -func (s *Stream) LrMessageC() chan Wal2JsonChanges { +func (s *Stream) LrMessageC() chan StreamMessage { return s.messages } // cleanUpOnFailure drops replication slot and publication if database snapshotting was failed for any reason -func (s *Stream) cleanUpOnFailure() { +func (s *Stream) cleanUpOnFailure() error { s.logger.Warnf("Cleaning up resources on accident: %v", s.slotName) err := DropReplicationSlot(context.Background(), s.pgConn, s.slotName, DropReplicationSlotOptions{Wait: true}) if err != nil { s.logger.Errorf("Failed to drop replication slot: %s", err.Error()) } - s.pgConn.Close(context.TODO()) + return s.pgConn.Close(context.TODO()) } func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 9124fc0de7..1b625d6f76 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -41,7 +41,6 @@ type ReplicationMode int const ( LogicalReplication ReplicationMode = iota - PhysicalReplication ) // String formats the mode into a postgres valid string @@ -337,288 +336,6 @@ func StartReplication(ctx context.Context, conn *pgconn.PgConn, slotName string, } } -type BaseBackupOptions struct { - // 
Request information required to generate a progress report, but might as such have a negative impact on the performance. - Progress bool - // Sets the label of the backup. If none is specified, a backup label of 'wal-g' will be used. - Label string - // Request a fast checkpoint. - Fast bool - // Include the necessary WAL segments in the backup. This will include all the files between start and stop backup in the pg_wal directory of the base directory tar file. - WAL bool - // By default, the backup will wait until the last required WAL segment has been archived, or emit a warning if log archiving is not enabled. - // Specifying NOWAIT disables both the waiting and the warning, leaving the client responsible for ensuring the required log is available. - NoWait bool - // Limit (throttle) the maximum amount of data transferred from server to client per unit of time (kb/s). - MaxRate int32 - // Include information about symbolic links present in the directory pg_tblspc in a file named tablespace_map. - TablespaceMap bool - // Disable checksums being verified during a base backup. 
- // Note that NoVerifyChecksums=true is only supported since PG11 - NoVerifyChecksums bool -} - -func (bbo BaseBackupOptions) sql(serverVersion int) string { - var parts []string - if bbo.Label != "" { - parts = append(parts, "LABEL '"+strings.ReplaceAll(bbo.Label, "'", "''")+"'") - } - if bbo.Progress { - parts = append(parts, "PROGRESS") - } - if bbo.Fast { - if serverVersion >= 15 { - parts = append(parts, "CHECKPOINT 'fast'") - } else { - parts = append(parts, "FAST") - } - } - if bbo.WAL { - parts = append(parts, "WAL") - } - if bbo.NoWait { - if serverVersion >= 15 { - parts = append(parts, "WAIT false") - } else { - parts = append(parts, "NOWAIT") - } - } - if bbo.MaxRate >= 32 { - parts = append(parts, fmt.Sprintf("MAX_RATE %d", bbo.MaxRate)) - } - if bbo.TablespaceMap { - parts = append(parts, "TABLESPACE_MAP") - } - if bbo.NoVerifyChecksums { - if serverVersion >= 15 { - parts = append(parts, "VERIFY_CHECKSUMS false") - } else if serverVersion >= 11 { - parts = append(parts, "NOVERIFY_CHECKSUMS") - } - } - if serverVersion >= 15 { - return "BASE_BACKUP(" + strings.Join(parts, ", ") + ")" - } - return "BASE_BACKUP " + strings.Join(parts, " ") -} - -// BaseBackupTablespace represents a tablespace in the backup -type BaseBackupTablespace struct { - OID int32 - Location string - Size int8 -} - -// BaseBackupResult will hold the return values of the BaseBackup command -type BaseBackupResult struct { - LSN LSN - TimelineID int32 - Tablespaces []BaseBackupTablespace -} - -func serverMajorVersion(conn *pgconn.PgConn) (int, error) { - verString := conn.ParameterStatus("server_version") - dot := strings.IndexByte(verString, '.') - if dot == -1 { - return 0, fmt.Errorf("bad server version string: '%s'", verString) - } - return strconv.Atoi(verString[:dot]) -} - -// StartBaseBackup begins the process for copying a basebackup by executing the BASE_BACKUP command. 
-func StartBaseBackup(ctx context.Context, conn *pgconn.PgConn, options BaseBackupOptions) (result BaseBackupResult, err error) { - serverVersion, err := serverMajorVersion(conn) - if err != nil { - return result, err - } - sql := options.sql(serverVersion) - - conn.Frontend().SendQuery(&pgproto3.Query{String: sql}) - err = conn.Frontend().Flush() - if err != nil { - return result, fmt.Errorf("failed to send BASE_BACKUP: %w", err) - } - // From here Postgres returns result sets, but pgconn has no infrastructure to properly capture them. - // So we capture data low level with sub functions, before we return from this function when we get to the CopyData part. - result.LSN, result.TimelineID, err = getBaseBackupInfo(ctx, conn) - if err != nil { - return result, err - } - result.Tablespaces, err = getTableSpaceInfo(ctx, conn) - return result, err -} - -// getBaseBackupInfo returns the start or end position of the backup as returned by Postgres -func getBaseBackupInfo(ctx context.Context, conn *pgconn.PgConn) (start LSN, timelineID int32, err error) { - for { - msg, err := conn.ReceiveMessage(ctx) - if err != nil { - return start, timelineID, fmt.Errorf("failed to receive message: %w", err) - } - switch msg := msg.(type) { - case *pgproto3.RowDescription: - if len(msg.Fields) != 2 { - return start, timelineID, fmt.Errorf("expected 2 column headers, received: %d", len(msg.Fields)) - } - colName := string(msg.Fields[0].Name) - if colName != "recptr" { - return start, timelineID, fmt.Errorf("unexpected col name for recptr col: %s", colName) - } - colName = string(msg.Fields[1].Name) - if colName != "tli" { - return start, timelineID, fmt.Errorf("unexpected col name for tli col: %s", colName) - } - case *pgproto3.DataRow: - if len(msg.Values) != 2 { - return start, timelineID, fmt.Errorf("expected 2 columns, received: %d", len(msg.Values)) - } - colData := string(msg.Values[0]) - start, err = ParseLSN(colData) - if err != nil { - return start, timelineID, 
fmt.Errorf("cannot convert result to LSN: %s", colData) - } - colData = string(msg.Values[1]) - tli, err := strconv.Atoi(colData) - if err != nil { - return start, timelineID, fmt.Errorf("cannot convert timelineID to int: %s", colData) - } - timelineID = int32(tli) - case *pgproto3.NoticeResponse: - case *pgproto3.CommandComplete: - return start, timelineID, nil - case *pgproto3.ErrorResponse: - return start, timelineID, fmt.Errorf("error response sev=%q code=%q message=%q detail=%q position=%d", msg.Severity, msg.Code, msg.Message, msg.Detail, msg.Position) - default: - return start, timelineID, fmt.Errorf("unexpected response type: %T", msg) - } - } -} - -// getBaseBackupInfo returns the start or end position of the backup as returned by Postgres -func getTableSpaceInfo(ctx context.Context, conn *pgconn.PgConn) (tbss []BaseBackupTablespace, err error) { - for { - msg, err := conn.ReceiveMessage(ctx) - if err != nil { - return tbss, fmt.Errorf("failed to receive message: %w", err) - } - switch msg := msg.(type) { - case *pgproto3.RowDescription: - if len(msg.Fields) != 3 { - return tbss, fmt.Errorf("expected 3 column headers, received: %d", len(msg.Fields)) - } - colName := string(msg.Fields[0].Name) - if colName != "spcoid" { - return tbss, fmt.Errorf("unexpected col name for spcoid col: %s", colName) - } - colName = string(msg.Fields[1].Name) - if colName != "spclocation" { - return tbss, fmt.Errorf("unexpected col name for spclocation col: %s", colName) - } - colName = string(msg.Fields[2].Name) - if colName != "size" { - return tbss, fmt.Errorf("unexpected col name for size col: %s", colName) - } - case *pgproto3.DataRow: - if len(msg.Values) != 3 { - return tbss, fmt.Errorf("expected 3 columns, received: %d", len(msg.Values)) - } - if msg.Values[0] == nil { - continue - } - tbs := BaseBackupTablespace{} - colData := string(msg.Values[0]) - OID, err := strconv.Atoi(colData) - if err != nil { - return tbss, fmt.Errorf("cannot convert spcoid to int: %s", 
colData) - } - tbs.OID = int32(OID) - tbs.Location = string(msg.Values[1]) - if msg.Values[2] != nil { - colData := string(msg.Values[2]) - size, err := strconv.Atoi(colData) - if err != nil { - return tbss, fmt.Errorf("cannot convert size to int: %s", colData) - } - tbs.Size = int8(size) - } - tbss = append(tbss, tbs) - case *pgproto3.CommandComplete: - return tbss, nil - default: - return tbss, fmt.Errorf("unexpected response type: %T", msg) - } - } -} - -// NextTableSpace consumes some msgs so we are at start of CopyData -func NextTableSpace(ctx context.Context, conn *pgconn.PgConn) (err error) { - - for { - msg, err := conn.ReceiveMessage(ctx) - if err != nil { - return fmt.Errorf("failed to receive message: %w", err) - } - - switch msg := msg.(type) { - case *pgproto3.CopyOutResponse: - return nil - case *pgproto3.CopyData: - return nil - case *pgproto3.ErrorResponse: - return pgconn.ErrorResponseToPgError(msg) - case *pgproto3.NoticeResponse: - case *pgproto3.RowDescription: - - default: - return fmt.Errorf("unexpected response type: %T", msg) - } - } -} - -// FinishBaseBackup wraps up a backup after copying all results from the BASE_BACKUP command. -func FinishBaseBackup(ctx context.Context, conn *pgconn.PgConn) (result BaseBackupResult, err error) { - - // From here Postgres returns result sets, but pgconn has no infrastructure to properly capture them. - // So we capture data low level with sub functions, before we return from this function when we get to the CopyData part. 
- result.LSN, result.TimelineID, err = getBaseBackupInfo(ctx, conn) - if err != nil { - return result, err - } - - // Base_Backup done, server send a command complete response from pg13 - vmaj, err := serverMajorVersion(conn) - if err != nil { - return - } - var ( - pack pgproto3.BackendMessage - ok bool - ) - if vmaj > 12 { - pack, err = conn.ReceiveMessage(ctx) - if err != nil { - return - } - _, ok = pack.(*pgproto3.CommandComplete) - if !ok { - err = fmt.Errorf("expect command_complete, got %T", pack) - return - } - } - - // simple query done, server send a ready for query response - pack, err = conn.ReceiveMessage(ctx) - if err != nil { - return - } - _, ok = pack.(*pgproto3.ReadyForQuery) - if !ok { - err = fmt.Errorf("expect ready_for_query, got %T", pack) - return - } - return -} - type PrimaryKeepaliveMessage struct { ServerWALEnd LSN ServerTime time.Time diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index 0e81f4c00e..f9b501e8bc 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -12,11 +12,9 @@ import ( "context" "fmt" "os" - "strconv" "testing" "time" - "github.com/jackc/pglogrepl" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/stretchr/testify/assert" @@ -45,7 +43,7 @@ func (s *lsnSuite) NoError(err error) { } func (s *lsnSuite) TestScannerInterface() { - var lsn pglogrepl.LSN + var lsn LSN lsnText := "16/B374D848" lsnUint64 := uint64(97500059720) var err error @@ -69,13 +67,13 @@ func (s *lsnSuite) TestScannerInterface() { } func (s *lsnSuite) TestScanToNil() { - var lsnPtr *pglogrepl.LSN + var lsnPtr *LSN err := lsnPtr.Scan("16/B374D848") s.NoError(err) } func (s *lsnSuite) TestValueInterface() { - lsn := pglogrepl.LSN(97500059720) + lsn := LSN(97500059720) driverValue, err := lsn.Value() s.NoError(err) lsnStr, ok := driverValue.(string) @@ -100,7 
+98,7 @@ func TestIdentifySystem(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - sysident, err := pglogrepl.IdentifySystem(ctx, conn) + sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) assert.Greater(t, len(sysident.SystemID), 0) @@ -121,18 +119,18 @@ func TestGetHistoryFile(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - sysident, err := pglogrepl.IdentifySystem(ctx, conn) + sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) - _, err = pglogrepl.TimelineHistory(ctx, conn, 0) + _, err = TimelineHistory(ctx, conn, 0) require.Error(t, err) - _, err = pglogrepl.TimelineHistory(ctx, conn, 1) + _, err = TimelineHistory(ctx, conn, 1) require.Error(t, err) if sysident.Timeline > 1 { // This test requires a Postgres with at least 1 timeline increase (promote, or recover)... - tlh, err := pglogrepl.TimelineHistory(ctx, conn, sysident.Timeline) + tlh, err := TimelineHistory(ctx, conn, sysident.Timeline) require.NoError(t, err) expectedFileName := fmt.Sprintf("%08X.history", sysident.Timeline) @@ -149,7 +147,7 @@ func TestCreateReplicationSlot(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - result, err := pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + result, err := CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) require.NoError(t, err) assert.Equal(t, slotName, result.SlotName) @@ -164,13 +162,13 @@ func TestDropReplicationSlot(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) require.NoError(t, err) - err = pglogrepl.DropReplicationSlot(ctx, conn, slotName, 
pglogrepl.DropReplicationSlotOptions{}) + err = DropReplicationSlot(ctx, conn, slotName, DropReplicationSlotOptions{}) require.NoError(t, err) - _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) require.NoError(t, err) } @@ -182,13 +180,13 @@ func TestStartReplication(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - sysident, err := pglogrepl.IdentifySystem(ctx, conn) + sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) - _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, outputPlugin, pglogrepl.CreateReplicationSlotOptions{Temporary: true}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) require.NoError(t, err) - err = pglogrepl.StartReplication(ctx, conn, slotName, sysident.XLogPos, pglogrepl.StartReplicationOptions{}) + err = StartReplication(ctx, conn, slotName, sysident.XLogPos, StartReplicationOptions{}) require.NoError(t, err) go func() { @@ -219,19 +217,19 @@ drop table t; require.NoError(t, err) }() - rxKeepAlive := func() pglogrepl.PrimaryKeepaliveMessage { + rxKeepAlive := func() PrimaryKeepaliveMessage { msg, err := conn.ReceiveMessage(ctx) require.NoError(t, err) cdMsg, ok := msg.(*pgproto3.CopyData) require.True(t, ok) - require.Equal(t, byte(pglogrepl.PrimaryKeepaliveMessageByteID), cdMsg.Data[0]) - pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(cdMsg.Data[1:]) + require.Equal(t, byte(PrimaryKeepaliveMessageByteID), cdMsg.Data[0]) + pkm, err := ParsePrimaryKeepaliveMessage(cdMsg.Data[1:]) require.NoError(t, err) return pkm } - rxXLogData := func() pglogrepl.XLogData { + rxXLogData := func() XLogData { var cdMsg *pgproto3.CopyData // Discard keepalive messages for { @@ -240,12 +238,12 @@ drop table t; var ok bool cdMsg, ok = 
msg.(*pgproto3.CopyData) require.True(t, ok) - if cdMsg.Data[0] != pglogrepl.PrimaryKeepaliveMessageByteID { + if cdMsg.Data[0] != PrimaryKeepaliveMessageByteID { break } } - require.Equal(t, byte(pglogrepl.XLogDataByteID), cdMsg.Data[0]) - xld, err := pglogrepl.ParseXLogData(cdMsg.Data[1:]) + require.Equal(t, byte(XLogDataByteID), cdMsg.Data[0]) + xld, err := ParseXLogData(cdMsg.Data[1:]) require.NoError(t, err) return xld } @@ -267,137 +265,6 @@ drop table t; assert.Equal(t, "COMMIT", string(xld.WALData[:6])) } -func TestStartReplicationPhysical(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*50) - defer cancel() - - conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, conn) - - sysident, err := pglogrepl.IdentifySystem(ctx, conn) - require.NoError(t, err) - - _, err = pglogrepl.CreateReplicationSlot(ctx, conn, slotName, "", pglogrepl.CreateReplicationSlotOptions{Temporary: true, Mode: pglogrepl.PhysicalReplication}) - require.NoError(t, err) - - err = pglogrepl.StartReplication(ctx, conn, slotName, sysident.XLogPos, pglogrepl.StartReplicationOptions{Mode: pglogrepl.PhysicalReplication}) - require.NoError(t, err) - - go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - config, err := pgconn.ParseConfig(os.Getenv("PGLOGREPL_TEST_CONN_STRING")) - require.NoError(t, err) - delete(config.RuntimeParams, "replication") - - conn, err := pgconn.ConnectConfig(ctx, config) - require.NoError(t, err) - defer closeConn(t, conn) - - _, err = conn.Exec(ctx, ` -create table mytable(id int primary key, name text); -drop table mytable; -`).ReadAll() - require.NoError(t, err) - }() - - _ = func() pglogrepl.PrimaryKeepaliveMessage { - msg, err := conn.ReceiveMessage(ctx) - require.NoError(t, err) - cdMsg, ok := msg.(*pgproto3.CopyData) - require.True(t, ok) - - require.Equal(t, byte(pglogrepl.PrimaryKeepaliveMessageByteID), 
cdMsg.Data[0]) - pkm, err := pglogrepl.ParsePrimaryKeepaliveMessage(cdMsg.Data[1:]) - require.NoError(t, err) - return pkm - } - - rxXLogData := func() pglogrepl.XLogData { - msg, err := conn.ReceiveMessage(ctx) - require.NoError(t, err) - cdMsg, ok := msg.(*pgproto3.CopyData) - require.True(t, ok) - - require.Equal(t, byte(pglogrepl.XLogDataByteID), cdMsg.Data[0]) - xld, err := pglogrepl.ParseXLogData(cdMsg.Data[1:]) - require.NoError(t, err) - return xld - } - - xld := rxXLogData() - assert.Contains(t, string(xld.WALData), "mytable") - - copyDoneResult, err := pglogrepl.SendStandbyCopyDone(ctx, conn) - require.NoError(t, err) - assert.Nil(t, copyDoneResult) -} - -func TestBaseBackup(t *testing.T) { - // base backup test could take a long time. Therefore it can be disabled. - envSkipTest := os.Getenv("PGLOGREPL_SKIP_BASE_BACKUP") - if envSkipTest != "" { - skipTest, err := strconv.ParseBool(envSkipTest) - if err != nil { - t.Error(err) - } else if skipTest { - return - } - } - - conn, err := pgconn.Connect(context.Background(), os.Getenv("PGLOGREPL_TEST_CONN_STRING")) - require.NoError(t, err) - defer closeConn(t, conn) - - options := pglogrepl.BaseBackupOptions{ - NoVerifyChecksums: true, - Progress: true, - Label: "pglogrepltest", - Fast: true, - WAL: true, - NoWait: true, - MaxRate: 1024, - TablespaceMap: true, - } - startRes, err := pglogrepl.StartBaseBackup(context.Background(), conn, options) - require.NoError(t, err) - require.GreaterOrEqual(t, startRes.TimelineID, int32(1)) - - //Write the tablespaces - for i := 0; i < len(startRes.Tablespaces)+1; i++ { - f, err := os.CreateTemp("", fmt.Sprintf("pglogrepl_test_tbs_%d.tar", i)) - require.NoError(t, err) - err = pglogrepl.NextTableSpace(context.Background(), conn) - var message pgproto3.BackendMessage - L: - for { - message, err = conn.ReceiveMessage(context.Background()) - require.NoError(t, err) - switch msg := message.(type) { - case *pgproto3.CopyData: - _, err := f.Write(msg.Data) - require.NoError(t, 
err) - case *pgproto3.CopyDone: - break L - default: - t.Errorf("Received unexpected message: %#v\n", msg) - } - } - err = f.Close() - require.NoError(t, err) - } - - stopRes, err := pglogrepl.FinishBaseBackup(context.Background(), conn) - require.NoError(t, err) - require.Equal(t, startRes.TimelineID, stopRes.TimelineID) - require.Equal(t, len(stopRes.Tablespaces), 0) - require.Less(t, uint64(startRes.LSN), uint64(stopRes.LSN)) - _, err = pglogrepl.StartBaseBackup(context.Background(), conn, options) - require.NoError(t, err) -} - func TestSendStandbyStatusUpdate(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() @@ -406,9 +273,9 @@ func TestSendStandbyStatusUpdate(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - sysident, err := pglogrepl.IdentifySystem(ctx, conn) + sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) - err = pglogrepl.SendStandbyStatusUpdate(ctx, conn, pglogrepl.StandbyStatusUpdate{WALWritePosition: sysident.XLogPos}) + err = SendStandbyStatusUpdate(ctx, conn, StandbyStatusUpdate{WALWritePosition: sysident.XLogPos}) require.NoError(t, err) } diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go new file mode 100644 index 0000000000..2070f88e86 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -0,0 +1,208 @@ +package pglogicalstream + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/jackc/pgx/v5/pgtype" + "log" +) + +// ---------------------------------------------------------------------------- +// PgOutput section + +func IsBeginMessage(WALData []byte) (bool, error) { + logicalMsg, err := Parse(WALData) + if err != nil { + return false, err + } + + _, ok := logicalMsg.(*BeginMessage) + return ok, nil +} + +func IsCommitMessage(WALData []byte) (bool, error) { + logicalMsg, err := Parse(WALData) + if 
err != nil { + return false, err + } + + _, ok := logicalMsg.(*CommitMessage) + return ok, nil +} + +// DecodePgOutput decodes a logical replication message in pgoutput format. +// It uses the provided relations map to look up the relation metadata for the +// as a side effect it updates the relations map with any new relation metadata +// When the relation is changes in the database, the relation message is sent +// before the change message. +func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (message *StreamMessageChanges, err error) { + logicalMsg, err := Parse(WALData) + message = &StreamMessageChanges{} + + if err != nil { + return nil, err + } + switch logicalMsg := logicalMsg.(type) { + case *RelationMessage: + relations[logicalMsg.RelationID] = logicalMsg + return nil, nil + case *BeginMessage: + return nil, nil + case *CommitMessage: + return nil, nil + case *InsertMessage: + rel, ok := relations[logicalMsg.RelationID] + if !ok { + return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID) + } + message.Operation = "insert" + message.Schema = rel.Namespace + message.Table = rel.RelationName + values := map[string]interface{}{} + for idx, col := range logicalMsg.Tuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': // null + values[colName] = nil + case 'u': // unchanged toast + // This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you. 
+ case 't': //text + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + return nil, err + } + values[colName] = val + } + } + + message.Data = values + case *UpdateMessage: + rel, ok := relations[logicalMsg.RelationID] + if !ok { + return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID) + } + values := map[string]interface{}{} + for idx, col := range logicalMsg.NewTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': // null + values[colName] = nil + case 'u': // unchanged toast + // This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you. + case 't': //text + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + log.Fatalln("error decoding column data:", err) + } + values[colName] = val + } + } + message.Data = values + //log.Printf("UPDATE %s.%s: SET %v", rel.Namespace, rel.RelationName, values) + case *DeleteMessage: + rel, ok := relations[logicalMsg.RelationID] + if !ok { + return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID) + } + + values := map[string]interface{}{} + for idx, col := range logicalMsg.OldTuple.Columns { + colName := rel.Columns[idx].Name + switch col.DataType { + case 'n': // null + values[colName] = nil + case 'u': // unchanged toast + // This TOAST value was not changed. TOAST values are not stored in the tuple, and logical replication doesn't want to spend a disk read to fetch its value for you. 
+ case 't': //text + val, err := decodeTextColumnData(typeMap, col.Data, rel.Columns[idx].DataType) + if err != nil { + log.Fatalln("error decoding column data:", err) + } + values[colName] = val + } + } + message.Data = values + case *TruncateMessage: + + case *TypeMessage: + case *OriginMessage: + + case *LogicalDecodingMessage: + log.Printf("Logical decoding message: %q, %q", logicalMsg.Prefix, logicalMsg.Content) + return nil, nil + default: + log.Printf("Unknown message type in pgoutput stream: %T", logicalMsg) + return nil, nil + } + + return message, nil +} + +func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interface{}, error) { + if dt, ok := mi.TypeForOID(dataType); ok { + return dt.Codec.DecodeValue(mi, dataType, pgtype.TextFormatCode, data) + } + return string(data), nil +} + +// ---------------------------------------------------------------------------- +// Wal2Json section + +type WallMessageWal2JSON struct { + Change []struct { + Kind string `json:"kind"` + Schema string `json:"schema"` + Table string `json:"table"` + Columnnames []string `json:"columnnames"` + Columntypes []string `json:"columntypes"` + Columnvalues []interface{} `json:"columnvalues"` + Oldkeys struct { + Keynames []string `json:"keynames"` + Keytypes []string `json:"keytypes"` + Keyvalues []interface{} `json:"keyvalues"` + } `json:"oldkeys"` + } `json:"change"` +} + +func DecodeWal2JsonChanges(clientXLogPosition string, WALData []byte) (*StreamMessage, error) { + var changes WallMessageWal2JSON + if err := json.NewDecoder(bytes.NewReader(WALData)).Decode(&changes); err != nil { + return nil, err + } + + if len(changes.Change) == 0 { + return nil, nil + } + message := &StreamMessage{ + Lsn: &clientXLogPosition, + Changes: make([]StreamMessageChanges, len(changes.Change)), + } + + for _, change := range changes.Change { + messageChange := StreamMessageChanges{ + Operation: change.Kind, + Schema: change.Schema, + Table: change.Table, + Data: 
make(map[string]any), + } + + if change.Kind == "delete" { + for i, keyName := range change.Oldkeys.Keynames { + if len(change.Columnvalues) == 0 { + break + } + + messageChange.Data[keyName] = change.Oldkeys.Keyvalues[i] + } + } else { + for i, columnName := range change.Columnnames { + messageChange.Data[columnName] = change.Columnvalues[i] + } + } + + message.Changes = append(message.Changes, messageChange) + } + + return message, nil +} diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index a4c73479f9..7bbe171661 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -17,13 +17,19 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" ) +// Snapshotter is a structure that allows the creation of a snapshot of a database at a given point in time +// At the time we initialize logical replication - we specify what we want to export the snapshot. +// This snapshot exists until the connection that created the replication slot remains open. +// Therefore Snapshotter opens another connection to the database and sets the transaction to the snapshot. +// This allows you to read the data that was in the database at the time of the snapshot creation. 
type Snapshotter struct { pgConnection *sql.DB - snapshotName string logger *service.Logger + + snapshotName string } -func NewSnapshotter(dbConf pgconn.Config, snapshotName string) (*Snapshotter, error) { +func NewSnapshotter(dbConf pgconn.Config, snapshotName string, logger *service.Logger) (*Snapshotter, error) { var sslMode = "none" if dbConf.TLSConfig != nil { sslMode = "require" @@ -39,7 +45,7 @@ func NewSnapshotter(dbConf pgconn.Config, snapshotName string) (*Snapshotter, er return &Snapshotter{ pgConnection: pgConn, snapshotName: snapshotName, - logger: nil, // TODO + logger: logger, }, err } @@ -54,22 +60,22 @@ func (s *Snapshotter) Prepare() error { return nil } -func (s *Snapshotter) FindAvgRowSize(table string) sql.NullInt64 { +func (s *Snapshotter) FindAvgRowSize(table string) (sql.NullInt64, error) { var avgRowSize sql.NullInt64 if rows, err := s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { - panic(fmt.Errorf("can get avg row size: %w", err)) // TODO + return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) } else { if rows.Next() { if err = rows.Scan(&avgRowSize); err != nil { - panic(fmt.Errorf("can get avg row size: %w", err)) // TODO + return avgRowSize, fmt.Errorf("can get avg row size: %w", err) } } else { - panic("can get avg row size; 0 rows returned") // TODO + return avgRowSize, fmt.Errorf("can get avg row size; 0 rows returned") } } - return avgRowSize + return avgRowSize, nil } func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSize uint64) int { @@ -80,11 +86,11 @@ func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSiz if batchSize < 1 { batchSize = 1 } + return batchSize } func (s *Snapshotter) QuerySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) { - // fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset) 
s.logger.Infof("Query snapshot table: %v, limit: %v, offset: %v, pk: %v", table, limit, offset, pk) return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset)) } diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go new file mode 100644 index 0000000000..520539500a --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -0,0 +1,14 @@ +package pglogicalstream + +type StreamMessageChanges struct { + Operation string `json:"operation"` + Schema string `json:"schema"` + Table string `json:"table"` + // For deleted messages - there will be old changes if replica identity set to full or empty changes + Data map[string]any `json:"data"` +} + +type StreamMessage struct { + Lsn *string `json:"lsn"` + Changes []StreamMessageChanges `json:"changes"` +} From 2c66b77f5686134b002be6948f358a1ca1ff21e2 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 2 Oct 2024 18:17:15 +0200 Subject: [PATCH 005/118] chore(pg_stream): updated table filtering --- go.mod | 5 ++-- go.sum | 2 -- .../pg_stream/pg_stream/integration_test.go | 12 ++++++---- .../pglogicalstream/logical_stream.go | 24 +++++++++++++------ 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/go.mod b/go.mod index 6dab1d5b5a..e5bbf4939f 100644 --- a/go.mod +++ b/go.mod @@ -65,7 +65,6 @@ require ( github.com/golang-jwt/jwt/v5 v5.2.1 github.com/gosimple/slug v1.14.0 github.com/influxdata/influxdb1-client v0.0.0-20220302092344-a9ab5670611c - github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 github.com/jackc/pgx/v4 v4.18.3 github.com/jackc/pgx/v5 v5.6.0 github.com/jaswdr/faker v1.19.1 @@ -266,7 +265,7 @@ require ( github.com/gorilla/css v1.0.1 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/gorilla/mux v1.8.1 // indirect - github.com/gorilla/websocket v1.5.3 + github.com/gorilla/websocket v1.5.3 // indirect 
github.com/gosimple/unidecode v1.0.1 // indirect github.com/govalues/decimal v0.1.29 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect @@ -393,7 +392,7 @@ require ( gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 + gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/gc/v3 v3.0.0-20240107210532-573471604cb6 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect diff --git a/go.sum b/go.sum index f193a3aa89..66681fbcfb 100644 --- a/go.sum +++ b/go.sum @@ -690,8 +690,6 @@ github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= -github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9 h1:86CQbMauoZdLS0HDLcEHYo6rErjiCBjVvcxGsioIn7s= -github.com/jackc/pglogrepl v0.0.0-20240307033717-828fbfe908e9/go.mod h1:SO15KF4QqfUM5UhsG9roXre5qeAQLC1rm8a8Gjpgg5k= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= diff --git a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go index bec6d7e0f9..696df06f5f 100644 --- a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go @@ -43,7 +43,6 @@ func TestIntegrationPgCDC(t *testing.T) { "POSTGRES_USER=user_name", "POSTGRES_DB=dbname", }, - ExposedPorts: 
[]string{"5432"}, Cmd: []string{ "postgres", "-c", "wal_level=logical", @@ -106,7 +105,6 @@ func TestIntegrationPgCDC(t *testing.T) { fake := faker.New() for i := 0; i < 1000; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -119,6 +117,7 @@ pg_stream: password: secret port: %s schema: public + decoding_plugin: wal2json tls: none stream_snapshot: true database: dbname @@ -214,7 +213,7 @@ file: }, time.Second*20, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) - t.Log("All the conditions are met 🎉", len(outMessages)) + t.Log("All the conditions are met 🎉") t.Cleanup(func() { db.Close() @@ -234,7 +233,6 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { "POSTGRES_USER=user_name", "POSTGRES_DB=dbname", }, - ExposedPorts: []string{"5432"}, Cmd: []string{ "postgres", "-c", "wal_level=logical", @@ -277,7 +275,12 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { } _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + if err != nil { + return err + } + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") return err }); err != nil { panic(fmt.Errorf("could not connect to docker: %w", err)) @@ -343,6 +346,7 @@ file: for i := 0; i < 10; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), 
fake.Time().RFC1123(time.Now())) require.NoError(t, err) } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 8636e209ae..c1cc322966 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -27,8 +27,6 @@ import ( "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/internal/helpers" ) -var pluginArguments = []string{"\"pretty-print\" 'true'"} - type Stream struct { pgConn *pgconn.PgConn // extra copy of db config is required to establish a new db connection @@ -53,6 +51,7 @@ type Stream struct { separateChanges bool snapshotBatchSize int decodingPlugin DecodingPlugin + decodingPluginArguments []string snapshotMemorySafetyFactor float64 logger *service.Logger @@ -117,6 +116,11 @@ func NewPgStream(config Config) (*Stream, error) { decodingPlugin: DecodingPluginFromString(config.DecodingPlugin), } + for i, table := range tableNames { + tableNames[i] = fmt.Sprintf("%s.%s", config.DbSchema, table) + } + + var pluginArguments = []string{} if stream.decodingPlugin == "pgoutput" { pluginArguments = []string{ "proto_version '1'", @@ -125,16 +129,22 @@ func NewPgStream(config Config) (*Stream, error) { } } + if stream.decodingPlugin == "wal2json" { + tablesFilterRule := strings.Join(tableNames, ", ") + pluginArguments = []string{ + "\"pretty-print\" 'true'", + "\"add-tables\"" + " " + fmt.Sprintf("'%s'", tablesFilterRule), + } + } + + stream.decodingPluginArguments = pluginArguments + result := stream.pgConn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS pglog_stream_%s;", config.ReplicationSlotName)) _, err = result.ReadAll() if err != nil { return nil, err } - for i, table := range tableNames { - tableNames[i] = fmt.Sprintf("%s.%s", config.DbSchema, table) - } - tablesSchemaFilter := fmt.Sprintf("FOR TABLE %s", strings.Join(tableNames, ",")) 
stream.logger.Infof("Create publication for table schemas with query %s", fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) result = stream.pgConn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) @@ -216,7 +226,7 @@ func NewPgStream(config Config) (*Stream, error) { func (s *Stream) startLr() error { var err error - err = StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: pluginArguments}) + err = StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}) if err != nil { return err } From 73fb9f79787002c738ca2e08406533c30009051d Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 4 Oct 2024 11:36:51 +0200 Subject: [PATCH 006/118] chore(): updated tests for pglogical stream --- .../pg_stream/pg_stream/integration_test.go | 163 +++++++++++++++- .../pg_stream/pg_stream/pg_stream.go | 15 +- .../impl/postgresql/pglogicalstream/config.go | 17 +- .../pglogicalstream/logical_stream.go | 30 +-- .../postgresql/pglogicalstream/pglogrepl.go | 17 ++ .../pglogicalstream/pglogrepl_test.go | 178 +++++++++++++----- .../replication_message_decoders.go | 20 +- 7 files changed, 352 insertions(+), 88 deletions(-) diff --git a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go index 696df06f5f..af46dca815 100644 --- a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go @@ -256,7 +256,7 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { var db *sql.DB pool.MaxWait = 120 * time.Second - if err = pool.Retry(func() error { + err = pool.Retry(func() error { if db, err = sql.Open("postgres", databaseURL); err != nil { return err } @@ -282,9 +282,8 
@@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") return err - }); err != nil { - panic(fmt.Errorf("could not connect to docker: %w", err)) - } + }) + require.NoError(t, err) fake := faker.New() for i := 0; i < 10; i++ { @@ -402,3 +401,159 @@ file: db.Close() }) } + +func TestNeonPostgresCDCReplication(t *testing.T) { + t.Skip() + tmpDir := t.TempDir() + + dbhost := "" // neon db host + dbport := "" // neondb port + dbPassword := "" // neondb password + dbUser := "" // neondb user + dbName := "" // neondb name + + databaseURL := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=require host=%s port=%s", dbUser, dbPassword, dbName, dbhost, dbport) + + db, err := sql.Open("postgres", databaseURL) + require.NoError(t, err) + + err = db.Ping() + require.NoError(t, err) + + var walLevel string + err = db.QueryRow("SHOW wal_level").Scan(&walLevel) + require.NoError(t, err) + + assert.Equal(t, "logical", walLevel) + + _, err = db.Exec("DROP TABLE IF EXISTS flights;") + require.NoError(t, err) + _, err = db.Exec("DROP TABLE IF EXISTS flights_non_streamed;") + require.NoError(t, err) + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + require.NoError(t, err) + + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + require.NoError(t, err) + + fake := faker.New() + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + 
+ dbConnOptions := "" // neondb connection options for the endpoint + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot_native_decoder + user: %s + password: %s + port: %s + schema: public + tls: require + stream_snapshot: true + pg_conn_options: endpoint=%s + decoding_plugin: wal2json + database: %s + tables: + - flights +`, dbhost, dbUser, dbPassword, dbport, dbConnOptions, dbName) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 10 + }, time.Minute, time.Millisecond*100) + + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 20 + }, time.Minute, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the 
same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 10 + }, time.Minute, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) +} diff --git a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go index 8a3dd20408..a7393bf880 100644 --- a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go +++ b/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go @@ -47,6 +47,7 @@ var pgStreamConfigSpec = service.NewConfigSpec(). Description("Defines whether benthos need to verify (skipinsecure) TLS configuration"). Example("none"). Default("none")). + Field(service.NewStringField("pg_conn_options").Default("")). Field(service.NewBoolField("stream_snapshot"). 
Description("Set `true` if you want to receive all the data that currently exist in database"). Example(true). @@ -83,6 +84,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser streamSnapshot bool snapshotMemSafetyFactor float64 decodingPlugin string + pgConnOptions string ) dbSchema, err = conf.FieldString("schema") @@ -149,6 +151,14 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser return nil, err } + if pgConnOptions, err = conf.FieldString("pg_conn_options"); err != nil { + return nil, err + } + + if pgConnOptions != "" { + pgConnOptions = fmt.Sprintf("options=%s", pgConnOptions) + } + pgconnConfig := pgconn.Config{ Host: dbHost, Port: uint16(dbPort), @@ -170,6 +180,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser snapshotMemSafetyFactor: snapshotMemSafetyFactor, slotName: dbSlotName, schema: dbSchema, + pgConnRuntimeParam: pgConnOptions, tls: pglogicalstream.TlsVerify(tlsSetting), tables: tables, decodingPlugin: decodingPlugin, @@ -193,19 +204,21 @@ func init() { type pgStreamInput struct { dbConfig pgconn.Config + tls pglogicalstream.TlsVerify pglogicalStream *pglogicalstream.Stream + pgConnRuntimeParam string slotName string schema string tables []string decodingPlugin string streamSnapshot bool - tls pglogicalstream.TlsVerify // none, require snapshotMemSafetyFactor float64 logger *service.Logger } func (p *pgStreamInput) Connect(ctx context.Context) error { pgStream, err := pglogicalstream.NewPgStream(pglogicalstream.Config{ + PgConnRuntimeParam: p.pgConnRuntimeParam, DbHost: p.dbConfig.Host, DbPassword: p.dbConfig.Password, DbUser: p.dbConfig.User, diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 446c276dc6..f31cbb6f0c 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -11,14 +11,15 @@ package pglogicalstream 
import "github.com/redpanda-data/benthos/v4/public/service" type Config struct { - DbHost string `yaml:"db_host"` - DbPassword string `yaml:"db_password"` - DbUser string `yaml:"db_user"` - DbPort int `yaml:"db_port"` - DbName string `yaml:"db_name"` - DbSchema string `yaml:"db_schema"` - DbTables []string `yaml:"db_tables"` - TlsVerify TlsVerify `yaml:"tls_verify"` + DbHost string `yaml:"db_host"` + DbPassword string `yaml:"db_password"` + DbUser string `yaml:"db_user"` + DbPort int `yaml:"db_port"` + DbName string `yaml:"db_name"` + DbSchema string `yaml:"db_schema"` + DbTables []string `yaml:"db_tables"` + TlsVerify TlsVerify `yaml:"tls_verify"` + PgConnRuntimeParam string `yaml:"pg_conn_options"` ReplicationSlotName string `yaml:"replication_slot_name"` StreamOldData bool `yaml:"stream_old_data"` diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index c1cc322966..c17edee921 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -70,21 +70,28 @@ func NewPgStream(config Config) (*Stream, error) { sslVerifyFull = "&sslmode=verify-full" } - if cfg, err = pgconn.ParseConfig(fmt.Sprintf("postgres://%s:%s@%s:%d/%s?replication=database%s", + connectionParams := "" + if config.PgConnRuntimeParam != "" { + connectionParams = fmt.Sprintf("&%s", config.PgConnRuntimeParam) + } + + q := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?replication=database%s%s", config.DbUser, config.DbPassword, config.DbHost, config.DbPort, config.DbName, sslVerifyFull, - )); err != nil { + connectionParams, + ) + + if cfg, err = pgconn.ParseConfig(q); err != nil { return nil, err } if config.TlsVerify == TlsRequireVerify { cfg.TLSConfig = &tls.Config{ InsecureSkipVerify: true, - ServerName: config.DbHost, } } else { cfg.TLSConfig = nil @@ -95,6 +102,10 @@ func NewPgStream(config Config) (*Stream, error) { return nil, err } + if err = 
dbConn.Ping(context.Background()); err != nil { + return nil, err + } + var tableNames []string for _, table := range config.DbTables { tableNames = append(tableNames, table) @@ -139,17 +150,8 @@ func NewPgStream(config Config) (*Stream, error) { stream.decodingPluginArguments = pluginArguments - result := stream.pgConn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS pglog_stream_%s;", config.ReplicationSlotName)) - _, err = result.ReadAll() - if err != nil { - return nil, err - } - - tablesSchemaFilter := fmt.Sprintf("FOR TABLE %s", strings.Join(tableNames, ",")) - stream.logger.Infof("Create publication for table schemas with query %s", fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) - result = stream.pgConn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION pglog_stream_%s %s;", config.ReplicationSlotName, tablesSchemaFilter)) - _, err = result.ReadAll() - if err != nil { + pubName := fmt.Sprintf("pglog_stream_%s", config.ReplicationSlotName) + if err = CreatePublication(context.Background(), stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 1b625d6f76..0b04b0ea58 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -288,6 +288,23 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri return err } +func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string, dropIfExist bool) error { + result := conn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", publicationName)) + if _, err := result.ReadAll(); err != nil { + return nil + } + + tablesSchemaFilter := fmt.Sprintf("FOR TABLE %s", strings.Join(tables, ",")) + if len(tables) == 0 { + tablesSchemaFilter = "FOR 
ALL TABLES" + } + result = conn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesSchemaFilter)) + if _, err := result.ReadAll(); err != nil { + return err + } + return nil +} + type StartReplicationOptions struct { Timeline int32 // 0 means current server timeline Mode ReplicationMode diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index f9b501e8bc..e726b2f08a 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -10,13 +10,18 @@ package pglogicalstream import ( "context" + "database/sql" + "encoding/json" "fmt" - "os" + "strings" "testing" "time" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/jackc/pgx/v5/pgtype" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -82,7 +87,7 @@ func (s *lsnSuite) TestValueInterface() { } const slotName = "pglogrepl_test" -const outputPlugin = "test_decoding" +const outputPlugin = "pgoutput" func closeConn(t testing.TB, conn *pgconn.PgConn) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -90,64 +95,82 @@ func closeConn(t testing.TB, conn *pgconn.PgConn) { require.NoError(t, conn.Close(ctx)) } -func TestIdentifySystem(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() +func createDockerInstance(t *testing.T) (*dockertest.Pool, *dockertest.Resource, string) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + Cmd: []string{ + "postgres", + "-c", 
"wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) - conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) require.NoError(t, err) - defer closeConn(t, conn) + require.NoError(t, resource.Expire(120)) - sysident, err := IdentifySystem(ctx, conn) + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s replication=database", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + pool.MaxWait = 120 * time.Second + err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + return err + }) require.NoError(t, err) - assert.Greater(t, len(sysident.SystemID), 0) - assert.True(t, sysident.Timeline > 0) - assert.True(t, sysident.XLogPos > 0) - assert.Greater(t, len(sysident.DBName), 0) + return pool, resource, databaseURL } -func TestGetHistoryFile(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) +func TestIdentifySystem(t *testing.T) { + pool, resource, dbUrl := createDockerInstance(t) + defer pool.Purge(resource) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*100) defer cancel() - config, err := pgconn.ParseConfig(os.Getenv("PGLOGREPL_TEST_CONN_STRING")) - require.NoError(t, err) - config.RuntimeParams["replication"] = "on" - - conn, err := pgconn.ConnectConfig(ctx, config) + conn, err := pgconn.Connect(ctx, dbUrl) require.NoError(t, err) defer closeConn(t, conn) sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) - _, err = TimelineHistory(ctx, conn, 0) - require.Error(t, err) - - _, err = TimelineHistory(ctx, conn, 1) - require.Error(t, err) - - if sysident.Timeline > 1 { - // This test requires 
a Postgres with at least 1 timeline increase (promote, or recover)... - tlh, err := TimelineHistory(ctx, conn, sysident.Timeline) - require.NoError(t, err) - - expectedFileName := fmt.Sprintf("%08X.history", sysident.Timeline) - assert.Equal(t, expectedFileName, tlh.FileName) - assert.Greater(t, len(tlh.Content), 0) - } + assert.Greater(t, len(sysident.SystemID), 0) + assert.True(t, sysident.Timeline > 0) + assert.True(t, sysident.XLogPos > 0) + assert.Greater(t, len(sysident.DBName), 0) } func TestCreateReplicationSlot(t *testing.T) { + pool, resource, dbUrl := createDockerInstance(t) + defer pool.Purge(resource) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + conn, err := pgconn.Connect(ctx, dbUrl) require.NoError(t, err) defer closeConn(t, conn) - result, err := CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) + result, err := CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}) require.NoError(t, err) assert.Equal(t, slotName, result.SlotName) @@ -155,45 +178,62 @@ func TestCreateReplicationSlot(t *testing.T) { } func TestDropReplicationSlot(t *testing.T) { + pool, resource, dbUrl := createDockerInstance(t) + defer pool.Purge(resource) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + conn, err := pgconn.Connect(ctx, dbUrl) require.NoError(t, err) defer closeConn(t, conn) - _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}) require.NoError(t, err) err = DropReplicationSlot(ctx, conn, slotName, DropReplicationSlotOptions{}) 
require.NoError(t, err) - _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}) require.NoError(t, err) } func TestStartReplication(t *testing.T) { + pool, resource, dbUrl := createDockerInstance(t) + defer pool.Purge(resource) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + conn, err := pgconn.Connect(ctx, dbUrl) require.NoError(t, err) defer closeConn(t, conn) sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) - _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: true}) + // create publication + publicationName := "test_publication" + err = CreatePublication(context.Background(), conn, publicationName, []string{}, true) + require.NoError(t, err) + + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}) require.NoError(t, err) - err = StartReplication(ctx, conn, slotName, sysident.XLogPos, StartReplicationOptions{}) + err = StartReplication(ctx, conn, slotName, sysident.XLogPos, StartReplicationOptions{ + PluginArgs: []string{ + "proto_version '1'", + "publication_names 'test_publication'", + "messages 'true'", + }, + Mode: LogicalReplication, + }) require.NoError(t, err) go func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - config, err := pgconn.ParseConfig(os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + config, err := pgconn.ParseConfig(dbUrl) require.NoError(t, err) delete(config.RuntimeParams, "replication") @@ -229,6 +269,9 @@ drop table t; return pkm } + relations := map[uint32]*RelationMessage{} + typeMap := pgtype.NewMap() + rxXLogData := func() XLogData { var cdMsg 
*pgproto3.CopyData // Discard keepalive messages @@ -250,26 +293,59 @@ drop table t; rxKeepAlive() xld := rxXLogData() - assert.Equal(t, "BEGIN", string(xld.WALData[:5])) + begin, err := IsBeginMessage(xld.WALData) + require.NoError(t, err) + assert.Equal(t, true, begin) + xld = rxXLogData() - assert.Equal(t, "table public.t: INSERT: id[integer]:1 name[text]:'foo'", string(xld.WALData)) + relationStreamMessage, err := DecodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) + assert.Nil(t, relationStreamMessage) + xld = rxXLogData() - assert.Equal(t, "table public.t: INSERT: id[integer]:2 name[text]:'bar'", string(xld.WALData)) + streamMessage, err := DecodePgOutput(xld.WALData, relations, typeMap) + jsonData, err := json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":1,\"name\":\"foo\"}}", string(jsonData)) + + xld = rxXLogData() + streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":2,\"name\":\"bar\"}}", string(jsonData)) + xld = rxXLogData() - assert.Equal(t, "table public.t: INSERT: id[integer]:3 name[text]:'baz'", string(xld.WALData)) + streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":3,\"name\":\"baz\"}}", string(jsonData)) + xld = rxXLogData() - assert.Equal(t, "table public.t: UPDATE: id[integer]:3 name[text]:'quz'", string(xld.WALData)) + streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, 
"{\"operation\":\"update\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":3,\"name\":\"quz\"}}", string(jsonData)) + xld = rxXLogData() - assert.Equal(t, "table public.t: DELETE: id[integer]:2", string(xld.WALData)) + streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + jsonData, err = json.Marshal(&streamMessage) + require.NoError(t, err) + assert.Equal(t, "{\"operation\":\"delete\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":2,\"name\":null}}", string(jsonData)) xld = rxXLogData() - assert.Equal(t, "COMMIT", string(xld.WALData[:6])) + + commit, err := IsCommitMessage(xld.WALData) + require.NoError(t, err) + assert.Equal(t, true, commit) } func TestSendStandbyStatusUpdate(t *testing.T) { + pool, resource, dbUrl := createDockerInstance(t) + defer pool.Purge(resource) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, os.Getenv("PGLOGREPL_TEST_CONN_STRING")) + conn, err := pgconn.Connect(ctx, dbUrl) require.NoError(t, err) defer closeConn(t, conn) diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 2070f88e86..551deec921 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -4,8 +4,9 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/jackc/pgx/v5/pgtype" "log" + + "github.com/jackc/pgx/v5/pgtype" ) // ---------------------------------------------------------------------------- @@ -36,9 +37,9 @@ func IsCommitMessage(WALData []byte) (bool, error) { // as a side effect it updates the relations map with any new relation metadata // When the relation is changes in the database, the relation message is sent // before the change message. 
-func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (message *StreamMessageChanges, err error) { +func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (*StreamMessageChanges, error) { logicalMsg, err := Parse(WALData) - message = &StreamMessageChanges{} + message := &StreamMessageChanges{} if err != nil { return nil, err @@ -75,13 +76,15 @@ func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM values[colName] = val } } - message.Data = values case *UpdateMessage: rel, ok := relations[logicalMsg.RelationID] if !ok { return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID) } + message.Operation = "update" + message.Schema = rel.Namespace + message.Table = rel.RelationName values := map[string]interface{}{} for idx, col := range logicalMsg.NewTuple.Columns { colName := rel.Columns[idx].Name @@ -99,13 +102,14 @@ func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM } } message.Data = values - //log.Printf("UPDATE %s.%s: SET %v", rel.Namespace, rel.RelationName, values) case *DeleteMessage: rel, ok := relations[logicalMsg.RelationID] if !ok { return nil, fmt.Errorf("unknown relation ID %d", logicalMsg.RelationID) } - + message.Operation = "delete" + message.Schema = rel.Namespace + message.Table = rel.RelationName values := map[string]interface{}{} for idx, col := range logicalMsg.OldTuple.Columns { colName := rel.Columns[idx].Name @@ -124,15 +128,11 @@ func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeM } message.Data = values case *TruncateMessage: - case *TypeMessage: case *OriginMessage: - case *LogicalDecodingMessage: - log.Printf("Logical decoding message: %q, %q", logicalMsg.Prefix, logicalMsg.Content) return nil, nil default: - log.Printf("Unknown message type in pgoutput stream: %T", logicalMsg) return nil, nil } From 5306dcc3a451eeda21ce9611903f1bdf8918c11d Mon Sep 17 00:00:00 
2001 From: Vladyslav Len Date: Fri, 4 Oct 2024 11:38:58 +0200 Subject: [PATCH 007/118] chore(): fmt applied --- internal/impl/kafka/enterprise/schema_registry_input.go | 1 + internal/impl/kafka/enterprise/schema_registry_output.go | 1 + 2 files changed, 2 insertions(+) diff --git a/internal/impl/kafka/enterprise/schema_registry_input.go b/internal/impl/kafka/enterprise/schema_registry_input.go index 5b4af24507..772b802052 100644 --- a/internal/impl/kafka/enterprise/schema_registry_input.go +++ b/internal/impl/kafka/enterprise/schema_registry_input.go @@ -20,6 +20,7 @@ import ( "sync" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" ) diff --git a/internal/impl/kafka/enterprise/schema_registry_output.go b/internal/impl/kafka/enterprise/schema_registry_output.go index be3cb1c95a..1f9a1f6c41 100644 --- a/internal/impl/kafka/enterprise/schema_registry_output.go +++ b/internal/impl/kafka/enterprise/schema_registry_output.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/confluent/sr" ) From 7031109f0abc1339cf2716a891edda9d9605487d Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 4 Oct 2024 11:44:35 +0200 Subject: [PATCH 008/118] chore(): code re-org --- .../pg_stream.go => input_postgrecdc.go} | 0 .../pg_stream => }/integration_test.go | 0 internal/impl/postgresql/pg_stream/README.md | 48 ---------- .../impl/postgresql/pglogicalstream/README.MD | 91 ------------------- .../{internal/helpers => }/availablememory.go | 2 +- .../pglogicalstream/logical_stream.go | 7 +- 6 files changed, 4 insertions(+), 144 deletions(-) rename internal/impl/postgresql/{pg_stream/pg_stream/pg_stream.go => input_postgrecdc.go} (100%) rename internal/impl/postgresql/{pg_stream/pg_stream => }/integration_test.go (100%) delete mode 100644 internal/impl/postgresql/pg_stream/README.md delete mode 100644 
internal/impl/postgresql/pglogicalstream/README.MD rename internal/impl/postgresql/pglogicalstream/{internal/helpers => }/availablememory.go (96%) diff --git a/internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go b/internal/impl/postgresql/input_postgrecdc.go similarity index 100% rename from internal/impl/postgresql/pg_stream/pg_stream/pg_stream.go rename to internal/impl/postgresql/input_postgrecdc.go diff --git a/internal/impl/postgresql/pg_stream/pg_stream/integration_test.go b/internal/impl/postgresql/integration_test.go similarity index 100% rename from internal/impl/postgresql/pg_stream/pg_stream/integration_test.go rename to internal/impl/postgresql/integration_test.go diff --git a/internal/impl/postgresql/pg_stream/README.md b/internal/impl/postgresql/pg_stream/README.md deleted file mode 100644 index 758da9d8e4..0000000000 --- a/internal/impl/postgresql/pg_stream/README.md +++ /dev/null @@ -1,48 +0,0 @@ -# PostgreSQL Logical Replication Streaming Plugin for Benthos - -Welcome to the PostgreSQL Logical Replication Streaming Plugin for Benthos! This plugin allows you to seamlessly stream data changes from your PostgreSQL database using Benthos, a versatile stream processor. - -## Features - -- **Real-time Data Streaming:** Capture data changes in real-time as they happen in your PostgreSQL database. - -- **Flexible Configuration:** Easily configure the plugin to specify the database connection details, replication slot, and table filtering rules. - -- **Checkpoints:** Store your replication consuming progress in Redis - -## Prerequisites - -Before you begin, make sure you have the following prerequisites: - -- [PostgreSQL](https://www.postgresql.org/): Ensure you have a PostgreSQL database instance that supports logical replication. 
- -### Create benthos configuration with plugin - -```yaml -input: - label: postgres_cdc_input - # register new plugin - pg_stream: - host: datbase hoat - slot_name: reqplication slot name - user: postgres username with replication permissions - password: password - port: 5432 - schema: schema you want to replicate tables from - stream_snapshot: set true if you want to stream existing data. If set to false only a new data will be streamed - database: name of the database - checkpoint_storage: redis uri if you want to store checkpoints - tables: ## list of tables you want to replicate - - table_name -``` - -### Register processor to pretty format your data -By default, plugins exports raw `wal2json` message. If you want to receive your data as json structure -without metadata to transform it with benthos - you can register `pg_stream_schemaless` plugin to transform it - -```yaml -pipeline: - processors: - - label: pretty_changes_processor - pg_stream_schemaless: { } -``` diff --git a/internal/impl/postgresql/pglogicalstream/README.MD b/internal/impl/postgresql/pglogicalstream/README.MD deleted file mode 100644 index 89fa6fc314..0000000000 --- a/internal/impl/postgresql/pglogicalstream/README.MD +++ /dev/null @@ -1,91 +0,0 @@ -This Go module builds upon [github.com/jackc/pglogrepl](https://github.com/jackc/pglogrepl) to provide an advanced -logical replication solution for PostgreSQL. It extends the capabilities of jackc/pglogrep for logical replication by -introducing several key features, making it easier to implement Change Data Capture (CDC) in your Go-based applications. - -## Features - -- **Checkpoints Storing:** Efficiently manage and store replication checkpoints, facilitating better tracking and - management of data changes. - -- **Snapshot Streaming:** Seamlessly capture and replicate snapshots of your PostgreSQL database, ensuring all data is - streamed through the pipeline. 
- -- **Table Filtering:** Tailor your CDC process by selectively filtering and replicating specific tables, optimizing - resource usage. - -## Getting Started - -Follow these steps to get started with our PostgreSQL Logical Replication CDC Module for Go: - -### Configure your replication stream - -Create `config.yaml` file - -```yaml -db_host: database host -db_password: password12345 -db_user: postgres -db_port: 5432 -db_name: mocks -db_schema: public -db_tables: - - rides -replication_slot_name: morning_elephant -tls_verify: require -stream_old_data: true -``` - -### Basic usage example - -By default `pglogicalstream` will create replication slot and publication for the tables you provide in Yaml config -It immediately starts streaming updates and you can receive them in the `OnMessage` function - -```go -package main - -import ( - "fmt" - "github.com/usedatabrew/pglogicalstream" - "gopkg.in/yaml.v3" - "io/ioutil" - "log" -) - -func main() { - var config pglogicalstream.Config - yamlFile, err := ioutil.ReadFile("./example/simple/config.yaml") - if err != nil { - log.Printf("yamlFile.Get err #%v ", err) - } - - err = yaml.Unmarshal(yamlFile, &config) - if err != nil { - log.Fatalf("Unmarshal: %v", err) - } - - pgStream, err := pglogicalstream.NewPgStream(config, log.WithPrefix("pg-cdc")) - if err != nil { - panic(err) - } - - pgStream.OnMessage(func(message messages.Wal2JsonChanges) { - fmt.Println(message.Changes) - }) -} - -``` - -### Example with checkpointer - -In order to recover after the failure, etc you have to store LSN somewhere to continue streaming the data -You can implement `CheckPointer` interface and pass it's instance to `NewPgStreamCheckPointer` and your LSN -will be stored automatically - -```go -checkPointer, err := NewPgStreamCheckPointer("redis.com:port", "user", "password") -if err != nil { - log.Fatalf("Checkpointer error") -} -pgStream, err := pglogicalstream.NewPgStream(config, checkPointer) -``` - diff --git 
a/internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go b/internal/impl/postgresql/pglogicalstream/availablememory.go similarity index 96% rename from internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go rename to internal/impl/postgresql/pglogicalstream/availablememory.go index c586ffecc6..55df055faf 100644 --- a/internal/impl/postgresql/pglogicalstream/internal/helpers/availablememory.go +++ b/internal/impl/postgresql/pglogicalstream/availablememory.go @@ -6,7 +6,7 @@ // // https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md -package helpers +package pglogicalstream import "runtime" diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index c17edee921..d5358ecbba 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -23,8 +23,6 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/benthos/v4/public/service" - - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/internal/helpers" ) type Stream struct { @@ -464,8 +462,9 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - batchSize := snapshotter.CalculateBatchSize(helpers.GetAvailableMemory(), uint64(avgRowSizeBytes.Int64)) - s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, helpers.GetAvailableMemory(), avgRowSizeBytes.Int64) + availableMemory := GetAvailableMemory() + batchSize := snapshotter.CalculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) + s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) tablePk, err := s.getPrimaryKeyColumn(table) if err != nil { From 184cf6492ea0e730738c28ada2f4c72f729a6d60 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: 
Fri, 4 Oct 2024 15:59:25 +0200 Subject: [PATCH 009/118] chore(): added temp. replication slot and removed outdated code --- internal/impl/postgresql/input_postgrecdc.go | 11 +- internal/impl/postgresql/integration_test.go | 156 ------------------ .../impl/postgresql/pglogicalstream/config.go | 1 + .../pglogicalstream/logical_stream.go | 4 +- .../replication_message_decoders.go | 6 +- 5 files changed, 17 insertions(+), 161 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index a7393bf880..4663b3286d 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -65,6 +65,7 @@ var pgStreamConfigSpec = service.NewConfigSpec(). - my_table_2 `). Description("List of tables we have to create logical replication for")). + Field(service.NewBoolField("temporary_slot").Default(false)). Field(service.NewStringField("slot_name"). Description("PostgeSQL logical replication slot name. You can create it manually before starting the sync. If not provided will be replaced with a random one"). Example("my_test_slot"). 
@@ -79,6 +80,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser dbUser string dbPassword string dbSlotName string + temporarySlot bool tlsSetting string tables []string streamSnapshot bool @@ -97,6 +99,11 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser return nil, err } + temporarySlot, err = conf.FieldBool("temporary_slot") + if err != nil { + return nil, err + } + if dbSlotName == "" { dbSlotName = randomSlotName } @@ -185,6 +192,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser tables: tables, decodingPlugin: decodingPlugin, logger: logger, + temporarySlot: temporarySlot, }), err } @@ -208,6 +216,7 @@ type pgStreamInput struct { pglogicalStream *pglogicalstream.Stream pgConnRuntimeParam string slotName string + temporarySlot bool schema string tables []string decodingPlugin string @@ -229,9 +238,9 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { ReplicationSlotName: "rs_" + p.slotName, TlsVerify: p.tls, StreamOldData: p.streamSnapshot, + TemporaryReplicationSlot: p.temporarySlot, DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, - SeparateChanges: true, }) if err != nil { return err diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index af46dca815..eac1b77b8d 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -401,159 +401,3 @@ file: db.Close() }) } - -func TestNeonPostgresCDCReplication(t *testing.T) { - t.Skip() - tmpDir := t.TempDir() - - dbhost := "" // neon db host - dbport := "" // neondb port - dbPassword := "" // neondb password - dbUser := "" // neondb user - dbName := "" // neondb name - - databaseURL := fmt.Sprintf("user=%s password=%s dbname=%s sslmode=require host=%s port=%s", dbUser, dbPassword, dbName, dbhost, dbport) - - db, err := sql.Open("postgres", databaseURL) - 
require.NoError(t, err) - - err = db.Ping() - require.NoError(t, err) - - var walLevel string - err = db.QueryRow("SHOW wal_level").Scan(&walLevel) - require.NoError(t, err) - - assert.Equal(t, "logical", walLevel) - - _, err = db.Exec("DROP TABLE IF EXISTS flights;") - require.NoError(t, err) - _, err = db.Exec("DROP TABLE IF EXISTS flights_non_streamed;") - require.NoError(t, err) - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - require.NoError(t, err) - - // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - require.NoError(t, err) - - fake := faker.New() - for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - require.NoError(t, err) - } - - dbConnOptions := "" // neondb connection options for the endpoint - - template := fmt.Sprintf(` -pg_stream: - host: %s - slot_name: test_slot_native_decoder - user: %s - password: %s - port: %s - schema: public - tls: require - stream_snapshot: true - pg_conn_options: endpoint=%s - decoding_plugin: wal2json - database: %s - tables: - - flights -`, dbhost, dbUser, dbPassword, dbport, dbConnOptions, dbName) - - cacheConf := fmt.Sprintf(` -label: pg_stream_cache -file: - directory: %v -`, tmpDir) - - streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) - require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) - require.NoError(t, streamOutBuilder.AddInputYAML(template)) - - var outMessages []string - var outMessagesMut sync.Mutex - - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() - require.NoError(t, 
err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() - return nil - })) - - streamOut, err := streamOutBuilder.Build() - require.NoError(t, err) - - go func() { - _ = streamOut.Run(context.Background()) - }() - - assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 10 - }, time.Minute, time.Millisecond*100) - - for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - require.NoError(t, err) - } - - assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 20 - }, time.Minute, time.Millisecond*100) - - require.NoError(t, streamOut.StopWithin(time.Second*10)) - - // Starting stream for the same replication slot should continue from the last LSN - // Meaning we must not receive any old messages again - - streamOutBuilder = service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) - require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) - require.NoError(t, streamOutBuilder.AddInputYAML(template)) - - outMessages = []string{} - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() - require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() - return nil - })) - - streamOut, err = streamOutBuilder.Build() - require.NoError(t, err) - - go func() { - assert.NoError(t, streamOut.Run(context.Background())) - }() - - time.Sleep(time.Second * 5) - for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", 
fake.Address().City(), fake.Time().RFC1123(time.Now())) - require.NoError(t, err) - } - - assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 10 - }, time.Minute, time.Millisecond*100) - - require.NoError(t, streamOut.StopWithin(time.Second*10)) - t.Log("All the conditions are met 🎉") - - t.Cleanup(func() { - db.Close() - }) -} diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index f31cbb6f0c..c49f4ebdb0 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -22,6 +22,7 @@ type Config struct { PgConnRuntimeParam string `yaml:"pg_conn_options"` ReplicationSlotName string `yaml:"replication_slot_name"` + TemporaryReplicationSlot bool `yaml:"temporary_replication_slot"` StreamOldData bool `yaml:"stream_old_data"` SeparateChanges bool `yaml:"separate_changes"` SnapshotMemorySafetyFactor float64 `yaml:"snapshot_memory_safety_factor"` diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index d5358ecbba..a873e5ee5b 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -46,7 +46,6 @@ type Stream struct { slotName string schema string tableNames []string - separateChanges bool snapshotBatchSize int decodingPlugin DecodingPlugin decodingPluginArguments []string @@ -117,7 +116,6 @@ func NewPgStream(config Config) (*Stream, error) { slotName: config.ReplicationSlotName, schema: config.DbSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, - separateChanges: config.SeparateChanges, snapshotBatchSize: config.BatchSize, tableNames: tableNames, logger: config.logger, @@ -173,7 +171,7 @@ func NewPgStream(config Config) (*Stream, error) { // here we create a new replication slot because there is no 
slot found var createSlotResult CreateReplicationSlotResult createSlotResult, err = CreateReplicationSlot(context.Background(), stream.pgConn, stream.slotName, stream.decodingPlugin.String(), - CreateReplicationSlotOptions{Temporary: false, + CreateReplicationSlotOptions{Temporary: config.TemporaryReplicationSlot, SnapshotAction: "export", }) if err != nil { diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 551deec921..259c7ca13b 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -176,10 +176,14 @@ func DecodeWal2JsonChanges(clientXLogPosition string, WALData []byte) (*StreamMe } message := &StreamMessage{ Lsn: &clientXLogPosition, - Changes: make([]StreamMessageChanges, len(changes.Change)), + Changes: []StreamMessageChanges{}, } for _, change := range changes.Change { + if change.Kind == "" { + continue + } + messageChange := StreamMessageChanges{ Operation: change.Kind, Schema: change.Schema, From c388aab5c8bf0f41c1d39298b8b55cbe53fa27a1 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 7 Oct 2024 14:53:37 +0200 Subject: [PATCH 010/118] chore(): fixed eslint errors && tests --- internal/impl/postgresql/input_postgrecdc.go | 22 ++-- .../pglogicalstream/availablememory.go | 2 +- .../impl/postgresql/pglogicalstream/config.go | 45 ++++--- .../impl/postgresql/pglogicalstream/consts.go | 15 ++- .../pglogicalstream/logical_stream.go | 123 +++++++++++------- .../postgresql/pglogicalstream/pglogrepl.go | 21 ++- .../pglogicalstream/pglogrepl_test.go | 86 +++++++----- .../pglogicalstream/replication_message.go | 2 +- .../replication_message_decoders.go | 14 +- .../replication_message_test.go | 22 +--- .../postgresql/pglogicalstream/snapshotter.go | 44 ++++--- .../pglogicalstream/stream_message.go | 3 + 
.../impl/postgresql/pglogicalstream/types.go | 15 --- 13 files changed, 245 insertions(+), 169 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 4663b3286d..8aad13df2a 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -163,7 +163,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser } if pgConnOptions != "" { - pgConnOptions = fmt.Sprintf("options=%s", pgConnOptions) + pgConnOptions = "options=" + pgConnOptions } pgconnConfig := pgconn.Config{ @@ -188,7 +188,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser slotName: dbSlotName, schema: dbSchema, pgConnRuntimeParam: pgConnOptions, - tls: pglogicalstream.TlsVerify(tlsSetting), + tls: pglogicalstream.TLSVerify(tlsSetting), tables: tables, decodingPlugin: decodingPlugin, logger: logger, @@ -212,7 +212,7 @@ func init() { type pgStreamInput struct { dbConfig pgconn.Config - tls pglogicalstream.TlsVerify + tls pglogicalstream.TLSVerify pglogicalStream *pglogicalstream.Stream pgConnRuntimeParam string slotName string @@ -228,15 +228,15 @@ type pgStreamInput struct { func (p *pgStreamInput) Connect(ctx context.Context) error { pgStream, err := pglogicalstream.NewPgStream(pglogicalstream.Config{ PgConnRuntimeParam: p.pgConnRuntimeParam, - DbHost: p.dbConfig.Host, - DbPassword: p.dbConfig.Password, - DbUser: p.dbConfig.User, - DbPort: int(p.dbConfig.Port), - DbTables: p.tables, - DbName: p.dbConfig.Database, - DbSchema: p.schema, + DBHost: p.dbConfig.Host, + DBPassword: p.dbConfig.Password, + DBUser: p.dbConfig.User, + DBPort: int(p.dbConfig.Port), + DBTables: p.tables, + DBName: p.dbConfig.Database, + DBSchema: p.schema, ReplicationSlotName: "rs_" + p.slotName, - TlsVerify: p.tls, + TLSVerify: p.tls, StreamOldData: p.streamSnapshot, TemporaryReplicationSlot: p.temporarySlot, DecodingPlugin: p.decodingPlugin, diff --git 
a/internal/impl/postgresql/pglogicalstream/availablememory.go b/internal/impl/postgresql/pglogicalstream/availablememory.go index 55df055faf..ae0ae7e42b 100644 --- a/internal/impl/postgresql/pglogicalstream/availablememory.go +++ b/internal/impl/postgresql/pglogicalstream/availablememory.go @@ -10,7 +10,7 @@ package pglogicalstream import "runtime" -func GetAvailableMemory() uint64 { +func getAvailableMemory() uint64 { var memStats runtime.MemStats runtime.ReadMemStats(&memStats) // You can use memStats.Sys or another appropriate memory metric. diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index c49f4ebdb0..b9ee6845d5 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -10,24 +10,39 @@ package pglogicalstream import "github.com/redpanda-data/benthos/v4/public/service" +// Config is the configuration for the pglogicalstream plugin type Config struct { - DbHost string `yaml:"db_host"` - DbPassword string `yaml:"db_password"` - DbUser string `yaml:"db_user"` - DbPort int `yaml:"db_port"` - DbName string `yaml:"db_name"` - DbSchema string `yaml:"db_schema"` - DbTables []string `yaml:"db_tables"` - TlsVerify TlsVerify `yaml:"tls_verify"` - PgConnRuntimeParam string `yaml:"pg_conn_options"` + // DbHost is the host of the PostgreSQL instance + DBHost string `yaml:"db_host"` + // DbPassword is the password for the PostgreSQL instance + DBPassword string `yaml:"db_password"` + // DbUser is the user for the PostgreSQL instance + DBUser string `yaml:"db_user"` + // DbPort is the port of the PostgreSQL instance + DBPort int `yaml:"db_port"` + // DbName is the name of the database to connect to + DBName string `yaml:"db_name"` + // DbSchema is the schema to stream changes from + DBSchema string `yaml:"db_schema"` + // DbTables is the tables to stream changes from + DBTables []string `yaml:"db_tables"` + // TlsVerify is the TLS 
verification configuration + TLSVerify TLSVerify `yaml:"tls_verify"` + // PgConnRuntimeParam is the runtime parameter for the PostgreSQL connection + PgConnRuntimeParam string `yaml:"pg_conn_options"` - ReplicationSlotName string `yaml:"replication_slot_name"` - TemporaryReplicationSlot bool `yaml:"temporary_replication_slot"` - StreamOldData bool `yaml:"stream_old_data"` - SeparateChanges bool `yaml:"separate_changes"` + // ReplicationSlotName is the name of the replication slot to use + ReplicationSlotName string `yaml:"replication_slot_name"` + // TemporaryReplicationSlot is whether to use a temporary replication slot + TemporaryReplicationSlot bool `yaml:"temporary_replication_slot"` + // StreamOldData is whether to stream all existing data + StreamOldData bool `yaml:"stream_old_data"` + // SnapshotMemorySafetyFactor is the memory safety factor for streaming snapshot SnapshotMemorySafetyFactor float64 `yaml:"snapshot_memory_safety_factor"` - DecodingPlugin string `yaml:"decoding_plugin"` - BatchSize int `yaml:"batch_size"` + // DecodingPlugin is the decoding plugin to use + DecodingPlugin string `yaml:"decoding_plugin"` + // BatchSize is the batch size for streaming + BatchSize int `yaml:"batch_size"` logger *service.Logger } diff --git a/internal/impl/postgresql/pglogicalstream/consts.go b/internal/impl/postgresql/pglogicalstream/consts.go index 544c152d51..0728c3a538 100644 --- a/internal/impl/postgresql/pglogicalstream/consts.go +++ b/internal/impl/postgresql/pglogicalstream/consts.go @@ -1,13 +1,16 @@ package pglogicalstream +// DecodingPlugin is a type for the decoding plugin type DecodingPlugin string const ( + // Wal2JSON is the value for the wal2json decoding plugin. It requires wal2json extension to be installed on the PostgreSQL instance Wal2JSON DecodingPlugin = "wal2json" + // PgOutput is the value for the pgoutput decoding plugin. 
It requires pgoutput extension to be installed on the PostgreSQL instance PgOutput DecodingPlugin = "pgoutput" ) -func DecodingPluginFromString(plugin string) DecodingPlugin { +func decodingPluginFromString(plugin string) DecodingPlugin { switch plugin { case "wal2json": return Wal2JSON @@ -22,7 +25,11 @@ func (d DecodingPlugin) String() string { return string(d) } -type TlsVerify string +// TLSVerify is a type for the TLS verification mode +type TLSVerify string -const TlsNoVerify TlsVerify = "none" -const TlsRequireVerify TlsVerify = "require" +// TLSNoVerify is the value for no TLS verification +const TLSNoVerify TLSVerify = "none" + +// TLSRequireVerify is the value for TLS verification with a CA +const TLSRequireVerify TLSVerify = "require" diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index a873e5ee5b..e0aae04293 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -25,6 +25,8 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" ) +// Stream is a structure that represents a logical replication stream +// It includes the connection to the database, the context for the stream, and snapshotting functionality type Stream struct { pgConn *pgconn.PgConn // extra copy of db config is required to establish a new db connection @@ -56,6 +58,7 @@ type Stream struct { stopped bool } +// NewPgStream creates a new instance of the Stream struct func NewPgStream(config Config) (*Stream, error) { var ( cfg *pgconn.Config @@ -63,21 +66,21 @@ func NewPgStream(config Config) (*Stream, error) { ) sslVerifyFull := "" - if config.TlsVerify == TlsRequireVerify { + if config.TLSVerify == TLSRequireVerify { sslVerifyFull = "&sslmode=verify-full" } connectionParams := "" if config.PgConnRuntimeParam != "" { - connectionParams = fmt.Sprintf("&%s", config.PgConnRuntimeParam) + connectionParams = "&" + 
config.PgConnRuntimeParam } q := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?replication=database%s%s", - config.DbUser, - config.DbPassword, - config.DbHost, - config.DbPort, - config.DbName, + config.DBUser, + config.DBPassword, + config.DBHost, + config.DBPort, + config.DBName, sslVerifyFull, connectionParams, ) @@ -86,7 +89,7 @@ func NewPgStream(config Config) (*Stream, error) { return nil, err } - if config.TlsVerify == TlsRequireVerify { + if config.TLSVerify == TLSRequireVerify { cfg.TLSConfig = &tls.Config{ InsecureSkipVerify: true, } @@ -104,9 +107,7 @@ func NewPgStream(config Config) (*Stream, error) { } var tableNames []string - for _, table := range config.DbTables { - tableNames = append(tableNames, table) - } + tableNames = append(tableNames, config.DBTables...) stream := &Stream{ pgConn: dbConn, @@ -114,17 +115,17 @@ func NewPgStream(config Config) (*Stream, error) { messages: make(chan StreamMessage), snapshotMessages: make(chan StreamMessage, 100), slotName: config.ReplicationSlotName, - schema: config.DbSchema, + schema: config.DBSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, snapshotBatchSize: config.BatchSize, tableNames: tableNames, logger: config.logger, m: sync.Mutex{}, - decodingPlugin: DecodingPluginFromString(config.DecodingPlugin), + decodingPlugin: decodingPluginFromString(config.DecodingPlugin), } for i, table := range tableNames { - tableNames[i] = fmt.Sprintf("%s.%s", config.DbSchema, table) + tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) } var pluginArguments = []string{} @@ -146,7 +147,7 @@ func NewPgStream(config Config) (*Stream, error) { stream.decodingPluginArguments = pluginArguments - pubName := fmt.Sprintf("pglog_stream_%s", config.ReplicationSlotName) + pubName := "pglog_stream_" + config.ReplicationSlotName if err = CreatePublication(context.Background(), stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } @@ -207,7 +208,7 @@ func NewPgStream(config Config) (*Stream, 
error) { stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) - if !freshlyCreatedSlot || config.StreamOldData == false { + if !freshlyCreatedSlot || !config.StreamOldData { if err = stream.startLr(); err != nil { return nil, err } @@ -223,16 +224,16 @@ func NewPgStream(config Config) (*Stream, error) { } func (s *Stream) startLr() error { - var err error - err = StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}) - if err != nil { + if err := StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}); err != nil { return err } - s.logger.Infof("Started logical replication on slot slot-name: %v", s.slotName) + s.logger.Infof("Started logical replication on slot slot-name: %v", s.slotName) return nil } +// AckLSN acknowledges the LSN up to which the stream has processed the messages. +// This makes Postgres to remove the WAL files that are no longer needed. 
func (s *Stream) AckLSN(lsn string) error { var err error s.clientXLogPos, err = ParseLSN(lsn) @@ -280,7 +281,9 @@ func (s *Stream) streamMessagesAsync() { if err != nil { s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", s.clientXLogPos.String(), err) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) @@ -302,13 +305,17 @@ func (s *Stream) streamMessagesAsync() { } s.logger.Errorf("Failed to receive messages from PostgreSQL: %v", err) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { s.logger.Errorf("Received error message from Postgres: %v", errMsg) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } @@ -337,10 +344,12 @@ func (s *Stream) streamMessagesAsync() { clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) if s.decodingPlugin == "wal2json" { - message, err := DecodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) + message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) if err != nil { s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } @@ -351,7 +360,9 @@ func (s *Stream) streamMessagesAsync() { if err = s.AckLSN(clientXLogPos.String()); err != nil { // stop reading from replication slot // if we can't acknowledge the LSN - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } } else { @@ -363,10 +374,12 @@ func (s *Stream) streamMessagesAsync() { // message changes must be collected in the buffer in the context of the same transaction // as single transaction can contain multiple changes // and LSN ack will cause 
potential loss of changes - isBegin, err := IsBeginMessage(xld.WALData) + isBegin, err := isBeginMessage(xld.WALData) if err != nil { s.logger.Errorf("Failed to parse WAL data: %w", err) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } @@ -375,10 +388,12 @@ func (s *Stream) streamMessagesAsync() { } // parse changes inside the transaction - message, err := DecodePgOutput(xld.WALData, relations, typeMap) + message, err := decodePgOutput(xld.WALData, relations, typeMap) if err != nil { s.logger.Errorf("decodePgOutput failed: %w", err) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } @@ -386,10 +401,12 @@ func (s *Stream) streamMessagesAsync() { pgoutputChanges = append(pgoutputChanges, *message) } - isCommit, err := IsCommitMessage(xld.WALData) + isCommit, err := isCommitMessage(xld.WALData) if err != nil { s.logger.Errorf("Failed to parse WAL data: %w", err) - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } @@ -400,7 +417,9 @@ func (s *Stream) streamMessagesAsync() { if err = s.AckLSN(clientXLogPos.String()); err != nil { // stop reading from replication slot // if we can't acknowledge the LSN - s.Stop() + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } return } } else { @@ -425,7 +444,7 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - if err = snapshotter.Prepare(); err != nil { + if err = snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. 
Probably snapshot is expired...: %v", err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -434,10 +453,10 @@ func (s *Stream) processSnapshot() { os.Exit(1) } defer func() { - if err = snapshotter.ReleaseSnapshot(); err != nil { + if err = snapshotter.releaseSnapshot(); err != nil { s.logger.Errorf("Failed to release database snapshot: %v", err.Error()) } - if err = snapshotter.CloseConn(); err != nil { + if err = snapshotter.closeConn(); err != nil { s.logger.Errorf("Failed to close database connection: %v", err.Error()) } }() @@ -450,7 +469,7 @@ func (s *Stream) processSnapshot() { offset = 0 ) - avgRowSizeBytes, err = snapshotter.FindAvgRowSize(table) + avgRowSizeBytes, err = snapshotter.findAvgRowSize(table) if err != nil { s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { @@ -460,8 +479,8 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - availableMemory := GetAvailableMemory() - batchSize := snapshotter.CalculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) + availableMemory := getAvailableMemory() + batchSize := snapshotter.calculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) tablePk, err := s.getPrimaryKeyColumn(table) @@ -476,8 +495,22 @@ func (s *Stream) processSnapshot() { for { var snapshotRows *sql.Rows - if snapshotRows, err = snapshotter.QuerySnapshotData(table, tablePk, batchSize, offset); err != nil { - panic(fmt.Errorf("can't query snapshot data: %w", err)) // TODO + if snapshotRows, err = snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { + s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + 
s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) + } + + if snapshotRows.Err() != nil { + s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) } columnTypes, err := snapshotRows.ColumnTypes() @@ -512,13 +545,10 @@ func (s *Stream) processSnapshot() { switch v.DatabaseTypeName() { case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": scanArgs[i] = new(sql.NullString) - break case "BOOL": scanArgs[i] = new(sql.NullBool) - break case "INT4": scanArgs[i] = new(sql.NullInt64) - break default: scanArgs[i] = new(sql.NullString) } @@ -598,10 +628,14 @@ func (s *Stream) processSnapshot() { go s.streamMessagesAsync() } +// SnapshotMessageC represents a message from the stream that are sent to the consumer on the snapshot processing stage +// meaning these messages will have nil LSN field func (s *Stream) SnapshotMessageC() chan StreamMessage { return s.snapshotMessages } +// LrMessageC represents a message from the stream that are sent to the consumer on the logical replication stage +// meaning these messages will have non-nil LSN field func (s *Stream) LrMessageC() chan StreamMessage { return s.messages } @@ -637,6 +671,7 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { return pkColName, nil } +// Stop closes the stream conect and prevents from replication slot read func (s *Stream) Stop() error { s.m.Lock() s.stopped = true diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 0b04b0ea58..0213b008ba 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -32,14 +32,19 @@ import ( ) const ( - XLogDataByteID = 'w' + // XLogDataByteID is the byte ID for XLogData messages. 
+ XLogDataByteID = 'w' + // PrimaryKeepaliveMessageByteID is the byte ID for PrimaryKeepaliveMessage messages. PrimaryKeepaliveMessageByteID = 'k' - StandbyStatusUpdateByteID = 'r' + // StandbyStatusUpdateByteID is the byte ID for StandbyStatusUpdate messages. + StandbyStatusUpdateByteID = 'r' ) +// ReplicationMode is the mode of replication to use. type ReplicationMode int const ( + // LogicalReplication is the only replication mode supported by this plugin LogicalReplication ReplicationMode = iota ) @@ -207,6 +212,7 @@ func ParseTimelineHistory(mrr *pgconn.MultiResultReader) (TimelineHistoryResult, return thr, nil } +// CreateReplicationSlotOptions are the options for the CREATE_REPLICATION_SLOT command. Including Mode, Temporary, and SnapshotAction. type CreateReplicationSlotOptions struct { Temporary bool SnapshotAction string @@ -273,6 +279,7 @@ func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader) (CreateReplicatio return crsr, nil } +// DropReplicationSlotOptions are options for the DROP_REPLICATION_SLOT command. 
type DropReplicationSlotOptions struct { Wait bool } @@ -288,13 +295,14 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri return err } +// CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string, dropIfExist bool) error { result := conn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", publicationName)) if _, err := result.ReadAll(); err != nil { return nil } - tablesSchemaFilter := fmt.Sprintf("FOR TABLE %s", strings.Join(tables, ",")) + tablesSchemaFilter := "FOR TABLE " + strings.Join(tables, ",") if len(tables) == 0 { tablesSchemaFilter = "FOR ALL TABLES" } @@ -305,6 +313,10 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName return nil } +// StartReplicationOptions are the options for the START_REPLICATION command. +// The Timeline field is optional and defaults to 0, which means the current server timeline. +// The Mode field is required and must be either PhysicalReplication or LogicalReplication. ## PhysicalReplication is not supporter by this plugin, but still can be implemented +// The PluginArgs field is optional and only used for LogicalReplication. type StartReplicationOptions struct { Timeline int32 // 0 means current server timeline Mode ReplicationMode @@ -353,6 +365,7 @@ func StartReplication(ctx context.Context, conn *pgconn.PgConn, slotName string, } } +// PrimaryKeepaliveMessage is a message sent by the primary server to the replica server to keep the connection alive. type PrimaryKeepaliveMessage struct { ServerWALEnd LSN ServerTime time.Time @@ -373,6 +386,7 @@ func ParsePrimaryKeepaliveMessage(buf []byte) (PrimaryKeepaliveMessage, error) { return pkm, nil } +// XLogData is a message sent by the primary server to the replica server containing WAL data. 
type XLogData struct { WALStart LSN ServerWALEnd LSN @@ -479,6 +493,7 @@ func SendStandbyCopyDone(_ context.Context, conn *pgconn.PgConn) (cdr *CopyDoneR if lerr == nil { lsn, lerr := ParseLSN(string(m.Values[1])) if lerr == nil { + cdr = new(CopyDoneResult) cdr.Timeline = int32(timeline) cdr.LSN = lsn } diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index e726b2f08a..fbcec320a3 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -142,31 +142,39 @@ func createDockerInstance(t *testing.T) (*dockertest.Pool, *dockertest.Resource, } func TestIdentifySystem(t *testing.T) { - pool, resource, dbUrl := createDockerInstance(t) - defer pool.Purge(resource) + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() ctx, cancel := context.WithTimeout(context.Background(), time.Second*100) defer cancel() - conn, err := pgconn.Connect(ctx, dbUrl) + conn, err := pgconn.Connect(ctx, dbURL) require.NoError(t, err) defer closeConn(t, conn) sysident, err := IdentifySystem(ctx, conn) require.NoError(t, err) - assert.Greater(t, len(sysident.SystemID), 0) - assert.True(t, sysident.Timeline > 0) - assert.True(t, sysident.XLogPos > 0) - assert.Greater(t, len(sysident.DBName), 0) + assert.NotEmpty(t, sysident.SystemID, 0) + assert.Greater(t, sysident.Timeline, int32(0)) + + xlogPositionIsPositive := sysident.XLogPos > 0 + assert.True(t, xlogPositionIsPositive) + assert.NotEmpty(t, sysident.DBName, 0) } func TestCreateReplicationSlot(t *testing.T) { - pool, resource, dbUrl := createDockerInstance(t) - defer pool.Purge(resource) + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, 
err := pgconn.Connect(ctx, dbUrl) + conn, err := pgconn.Connect(ctx, dbURL) require.NoError(t, err) defer closeConn(t, conn) @@ -178,12 +186,15 @@ func TestCreateReplicationSlot(t *testing.T) { } func TestDropReplicationSlot(t *testing.T) { - pool, resource, dbUrl := createDockerInstance(t) - defer pool.Purge(resource) + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, dbUrl) + conn, err := pgconn.Connect(ctx, dbURL) require.NoError(t, err) defer closeConn(t, conn) @@ -198,13 +209,16 @@ func TestDropReplicationSlot(t *testing.T) { } func TestStartReplication(t *testing.T) { - pool, resource, dbUrl := createDockerInstance(t) - defer pool.Purge(resource) + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, dbUrl) + conn, err := pgconn.Connect(ctx, dbURL) require.NoError(t, err) defer closeConn(t, conn) @@ -233,7 +247,7 @@ func TestStartReplication(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - config, err := pgconn.ParseConfig(dbUrl) + config, err := pgconn.ParseConfig(dbURL) require.NoError(t, err) delete(config.RuntimeParams, "replication") @@ -293,59 +307,69 @@ drop table t; rxKeepAlive() xld := rxXLogData() - begin, err := IsBeginMessage(xld.WALData) + begin, err := isBeginMessage(xld.WALData) require.NoError(t, err) - assert.Equal(t, true, begin) + assert.True(t, begin) xld = rxXLogData() - relationStreamMessage, err := DecodePgOutput(xld.WALData, relations, typeMap) + var streamMessage *StreamMessageChanges + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) require.NoError(t, err) - 
assert.Nil(t, relationStreamMessage) + assert.Nil(t, streamMessage) xld = rxXLogData() - streamMessage, err := DecodePgOutput(xld.WALData, relations, typeMap) + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) jsonData, err := json.Marshal(&streamMessage) require.NoError(t, err) assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":1,\"name\":\"foo\"}}", string(jsonData)) xld = rxXLogData() - streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) jsonData, err = json.Marshal(&streamMessage) require.NoError(t, err) assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":2,\"name\":\"bar\"}}", string(jsonData)) xld = rxXLogData() - streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) jsonData, err = json.Marshal(&streamMessage) require.NoError(t, err) assert.Equal(t, "{\"operation\":\"insert\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":3,\"name\":\"baz\"}}", string(jsonData)) xld = rxXLogData() - streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) jsonData, err = json.Marshal(&streamMessage) require.NoError(t, err) assert.Equal(t, "{\"operation\":\"update\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":3,\"name\":\"quz\"}}", string(jsonData)) xld = rxXLogData() - streamMessage, err = DecodePgOutput(xld.WALData, relations, typeMap) + streamMessage, err = decodePgOutput(xld.WALData, relations, typeMap) + require.NoError(t, err) jsonData, err = json.Marshal(&streamMessage) require.NoError(t, err) assert.Equal(t, 
"{\"operation\":\"delete\",\"schema\":\"public\",\"table\":\"t\",\"data\":{\"id\":2,\"name\":null}}", string(jsonData)) xld = rxXLogData() - commit, err := IsCommitMessage(xld.WALData) + var commit bool + commit, err = isCommitMessage(xld.WALData) require.NoError(t, err) - assert.Equal(t, true, commit) + assert.True(t, commit) } func TestSendStandbyStatusUpdate(t *testing.T) { - pool, resource, dbUrl := createDockerInstance(t) - defer pool.Purge(resource) + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - conn, err := pgconn.Connect(ctx, dbUrl) + conn, err := pgconn.Connect(ctx, dbURL) require.NoError(t, err) defer closeConn(t, conn) diff --git a/internal/impl/postgresql/pglogicalstream/replication_message.go b/internal/impl/postgresql/pglogicalstream/replication_message.go index 914d0e9c96..f68abd7d3b 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message.go @@ -106,7 +106,7 @@ func (m *baseMessage) SetType(t MessageType) { // Decode parse src into message struct. The src must contain the complete message starts after // the first message type byte. 
func (m *baseMessage) Decode(_ []byte) error { - return fmt.Errorf("message decode not implemented") + return errors.New("message decode not implemented") } func (m *baseMessage) lengthError(name string, expectedLen, actualLen int) error { diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 259c7ca13b..461cb90378 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -12,7 +12,7 @@ import ( // ---------------------------------------------------------------------------- // PgOutput section -func IsBeginMessage(WALData []byte) (bool, error) { +func isBeginMessage(WALData []byte) (bool, error) { logicalMsg, err := Parse(WALData) if err != nil { return false, err @@ -22,7 +22,7 @@ func IsBeginMessage(WALData []byte) (bool, error) { return ok, nil } -func IsCommitMessage(WALData []byte) (bool, error) { +func isCommitMessage(WALData []byte) (bool, error) { logicalMsg, err := Parse(WALData) if err != nil { return false, err @@ -32,12 +32,12 @@ func IsCommitMessage(WALData []byte) (bool, error) { return ok, nil } -// DecodePgOutput decodes a logical replication message in pgoutput format. +// decodePgOutput decodes a logical replication message in pgoutput format. // It uses the provided relations map to look up the relation metadata for the // as a side effect it updates the relations map with any new relation metadata // When the relation is changes in the database, the relation message is sent // before the change message. 
-func DecodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (*StreamMessageChanges, error) { +func decodePgOutput(WALData []byte, relations map[uint32]*RelationMessage, typeMap *pgtype.Map) (*StreamMessageChanges, error) { logicalMsg, err := Parse(WALData) message := &StreamMessageChanges{} @@ -149,7 +149,7 @@ func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interfa // ---------------------------------------------------------------------------- // Wal2Json section -type WallMessageWal2JSON struct { +type walMessageWal2JSON struct { Change []struct { Kind string `json:"kind"` Schema string `json:"schema"` @@ -165,8 +165,8 @@ type WallMessageWal2JSON struct { } `json:"change"` } -func DecodeWal2JsonChanges(clientXLogPosition string, WALData []byte) (*StreamMessage, error) { - var changes WallMessageWal2JSON +func decodeWal2JsonChanges(clientXLogPosition string, WALData []byte) (*StreamMessage, error) { + var changes walMessageWal2JSON if err := json.NewDecoder(bytes.NewReader(WALData)).Decode(&changes); err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_test.go b/internal/impl/postgresql/pglogicalstream/replication_message_test.go index 761875ad1c..d0c438900e 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_test.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_test.go @@ -10,7 +10,6 @@ package pglogicalstream import ( "encoding/binary" - "errors" "math/rand" "testing" "time" @@ -121,12 +120,6 @@ func (s *messageSuite) putMessageTestData(msg []byte) *LogicalDecodingMessage { } } -func (s *messageSuite) assertV1NotSupported(msg []byte) { - _, err := Parse(msg) - s.Error(err) - s.True(errors.Is(err, errMsgNotSupported)) -} - func (s *messageSuite) createRelationTestData() ([]byte, *RelationMessage) { relationID := uint32(rand.Int31()) namespace := "public" @@ -175,7 +168,6 @@ func (s *messageSuite) 
createRelationTestData() ([]byte, *RelationMessage) { bigEndian.PutUint32(msg[off:], 1184) // timestamptz off += 4 bigEndian.PutUint32(msg[off:], uint32(noAtttypmod)) - off += 4 expected := &RelationMessage{ RelationID: relationID, @@ -487,7 +479,7 @@ func (s *messageSuite) createDeleteTestDataTypeK() ([]byte, *DeleteMessage) { off++ bigEndian.PutUint16(msg[off:], 1) off += 2 - off += s.putTupleColumn(msg[off:], 't', oldCol1Data) + s.putTupleColumn(msg[off:], 't', oldCol1Data) expected := &DeleteMessage{ RelationID: relationID, OldTupleType: DeleteMessageTupleTypeKey, @@ -525,7 +517,7 @@ func (s *messageSuite) createDeleteTestDataTypeO() ([]byte, *DeleteMessage) { bigEndian.PutUint16(msg[off:], 2) off += 2 off += s.putTupleColumn(msg[off:], 't', oldCol1Data) - off += s.putTupleColumn(msg[off:], 't', oldCol2Data) + s.putTupleColumn(msg[off:], 't', oldCol2Data) expected := &DeleteMessage{ RelationID: relationID, OldTupleType: DeleteMessageTupleTypeOld, @@ -576,16 +568,6 @@ func (s *messageSuite) createTruncateTestData() ([]byte, *TruncateMessage) { return msg, expected } -func (s *messageSuite) insertXid(msg []byte) ([]byte, uint32) { - msgV2 := make([]byte, 4+len(msg)) - msgV2[0] = msg[0] - xid := s.newXid() - bigEndian.PutUint32(msgV2[1:], xid) - copy(msgV2[5:], msg[1:]) - - return msgV2, xid -} - func TestBeginMessageSuite(t *testing.T) { suite.Run(t, new(beginMessageSuite)) } diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 7bbe171661..1ecbc50ac1 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -12,6 +12,8 @@ import ( "database/sql" "fmt" + "errors" + "github.com/jackc/pgx/v5/pgconn" _ "github.com/lib/pq" "github.com/redpanda-data/benthos/v4/public/service" @@ -29,8 +31,9 @@ type Snapshotter struct { snapshotName string } +// NewSnapshotter creates a new Snapshotter instance func 
NewSnapshotter(dbConf pgconn.Config, snapshotName string, logger *service.Logger) (*Snapshotter, error) { - var sslMode = "none" + var sslMode string if dbConf.TLSConfig != nil { sslMode = "require" } else { @@ -49,7 +52,7 @@ func NewSnapshotter(dbConf pgconn.Config, snapshotName string, logger *service.L }, err } -func (s *Snapshotter) Prepare() error { +func (s *Snapshotter) prepare() error { if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { return err } @@ -60,25 +63,32 @@ func (s *Snapshotter) Prepare() error { return nil } -func (s *Snapshotter) FindAvgRowSize(table string) (sql.NullInt64, error) { - var avgRowSize sql.NullInt64 - - if rows, err := s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { +func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { + var ( + avgRowSize sql.NullInt64 + rows *sql.Rows + err error + ) + if rows, err = s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) - } else { - if rows.Next() { - if err = rows.Scan(&avgRowSize); err != nil { - return avgRowSize, fmt.Errorf("can get avg row size: %w", err) - } - } else { - return avgRowSize, fmt.Errorf("can get avg row size; 0 rows returned") + } + + if rows.Err() != nil { + return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", rows.Err()) + } + + if rows.Next() { + if err = rows.Scan(&avgRowSize); err != nil { + return avgRowSize, fmt.Errorf("can get avg row size: %w", err) } + } else { + return avgRowSize, errors.New("can get avg row size; 0 rows returned") } return avgRowSize, nil } -func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSize uint64) int { +func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSize uint64) int { 
// Adjust this factor based on your system's memory constraints. // This example uses a safety factor of 0.8 to leave some memory headroom. safetyFactor := 0.6 @@ -90,17 +100,17 @@ func (s *Snapshotter) CalculateBatchSize(availableMemory uint64, estimatedRowSiz return batchSize } -func (s *Snapshotter) QuerySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) { +func (s *Snapshotter) querySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) { s.logger.Infof("Query snapshot table: %v, limit: %v, offset: %v, pk: %v", table, limit, offset, pk) return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset)) } -func (s *Snapshotter) ReleaseSnapshot() error { +func (s *Snapshotter) releaseSnapshot() error { _, err := s.pgConnection.Exec("COMMIT;") return err } -func (s *Snapshotter) CloseConn() error { +func (s *Snapshotter) closeConn() error { if s.pgConnection != nil { return s.pgConnection.Close() } diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index 520539500a..446422c16c 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -1,5 +1,7 @@ package pglogicalstream +// StreamMessageChanges represents the changes in a single message +// Single message can have multiple changes type StreamMessageChanges struct { Operation string `json:"operation"` Schema string `json:"schema"` @@ -8,6 +10,7 @@ type StreamMessageChanges struct { Data map[string]any `json:"data"` } +// StreamMessage represents a single message after it has been decoded by the plugin type StreamMessage struct { Lsn *string `json:"lsn"` Changes []StreamMessageChanges `json:"changes"` diff --git a/internal/impl/postgresql/pglogicalstream/types.go b/internal/impl/postgresql/pglogicalstream/types.go index 
7dd4670cbc..2d1d0ff3ad 100644 --- a/internal/impl/postgresql/pglogicalstream/types.go +++ b/internal/impl/postgresql/pglogicalstream/types.go @@ -7,18 +7,3 @@ // https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md package pglogicalstream - -type Wal2JsonChanges struct { - Lsn *string `json:"lsn"` - Changes []Wal2JsonChange `json:"change"` -} - -type Wal2JsonChange struct { - Kind string `json:"kind"` - Schema string `json:"schema"` - Table string `json:"table"` - ColumnNames []string `json:"columnnames"` - ColumnTypes []string `json:"columntypes"` - ColumnValues []interface{} `json:"columnvalues"` -} -type OnMessage = func(message Wal2JsonChanges) From dd82b0f82595548f969ed21bb87eae99749e5e19 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 7 Oct 2024 14:55:26 +0200 Subject: [PATCH 011/118] chore(): removed panics --- .../postgresql/pglogicalstream/logical_stream.go | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index e0aae04293..aaad6f65c7 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -238,7 +238,10 @@ func (s *Stream) AckLSN(lsn string) error { var err error s.clientXLogPos, err = ParseLSN(lsn) if err != nil { - panic(fmt.Errorf("failed to parse LSN for Acknowledge %w", err)) // TODO + s.logger.Errorf("Failed to parse LSN for Acknowledge: %v", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } } err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ @@ -329,7 +332,10 @@ func (s *Stream) streamMessagesAsync() { case PrimaryKeepaliveMessageByteID: pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { - panic(fmt.Errorf("parsePrimaryKeepaliveMessage failed: %w", err)) // TODO + s.logger.Errorf("Failed to 
parse PrimaryKeepaliveMessage: %v", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } } if pkm.ReplyRequested { @@ -339,7 +345,10 @@ func (s *Stream) streamMessagesAsync() { case XLogDataByteID: xld, err := ParseXLogData(msg.Data[1:]) if err != nil { - panic(fmt.Errorf("parseXLogData failed: %w", err)) + s.logger.Errorf("Failed to parse XLogData: %v", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) From aa230d51e6e124450b50c1dcab255c2cf5fe72f1 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 8 Oct 2024 14:07:53 +0200 Subject: [PATCH 012/118] fix(): table name in snapshotter --- internal/impl/postgresql/pglogicalstream/logical_stream.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index aaad6f65c7..e4c0c658c2 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -599,11 +599,12 @@ func (s *Stream) processSnapshot() { columnValues[i] = scanArgs[i] } + tableWithoutSchema := strings.Split(table, ".")[1] snapshotChangePacket := StreamMessage{ Lsn: nil, Changes: []StreamMessageChanges{ { - Table: table, + Table: tableWithoutSchema, Operation: "insert", Schema: s.schema, Data: func() map[string]any { From 3fb399686d1cee8f8cd9a6eca81b8df24ff9b331 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 9 Oct 2024 09:52:51 +0200 Subject: [PATCH 013/118] chore(): working on stream uncomited changes --- internal/impl/postgresql/input_postgrecdc.go | 11 ++ internal/impl/postgresql/integration_test.go | 185 ++++++++++++++++++ .../impl/postgresql/pglogicalstream/config.go | 3 + .../pglogicalstream/logical_stream.go | 137 +++++++++---- .../pglogicalstream/pglogrepl_test.go | 2 +- 
.../replication_message_decoders.go | 8 +- 6 files changed, 298 insertions(+), 48 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 8aad13df2a..fa5357dfae 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -47,6 +47,7 @@ var pgStreamConfigSpec = service.NewConfigSpec(). Description("Defines whether benthos need to verify (skipinsecure) TLS configuration"). Example("none"). Default("none")). + Field(service.NewBoolField("stream_uncomited").Default(false).Description("Defines whether you want to stream uncomitted messages before receiving commit message from postgres. This may lead to duplicated records after the the connector has been restarted")). Field(service.NewStringField("pg_conn_options").Default("")). Field(service.NewBoolField("stream_snapshot"). Description("Set `true` if you want to receive all the data that currently exist in database"). @@ -87,6 +88,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser snapshotMemSafetyFactor float64 decodingPlugin string pgConnOptions string + streamUncomited bool ) dbSchema, err = conf.FieldString("schema") @@ -148,6 +150,11 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser return nil, err } + streamUncomited, err = conf.FieldBool("stream_uncomited") + if err != nil { + return nil, err + } + decodingPlugin, err = conf.FieldString("decoding_plugin") if err != nil { return nil, err @@ -192,6 +199,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser tables: tables, decodingPlugin: decodingPlugin, logger: logger, + streamUncomited: streamUncomited, temporarySlot: temporarySlot, }), err } @@ -222,6 +230,7 @@ type pgStreamInput struct { decodingPlugin string streamSnapshot bool snapshotMemSafetyFactor float64 + streamUncomited bool logger *service.Logger } @@ -239,6 +248,7 @@ func (p 
*pgStreamInput) Connect(ctx context.Context) error { TLSVerify: p.tls, StreamOldData: p.streamSnapshot, TemporaryReplicationSlot: p.temporarySlot, + StreamUncomited: p.streamUncomited, DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, }) @@ -280,6 +290,7 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack fmt.Println("Error while acking LSN", err) return err } + fmt.Println("Ack LSN", *message.Lsn) } return nil }, nil diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index eac1b77b8d..ac019dd10c 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -401,3 +401,188 @@ file: db.Close() }) } + +func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, pool.Purge(resource)) + }) + + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + + pool.MaxWait = 120 * time.Second + err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + var walLevel string 
+ if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { + return err + } + + if walLevel != "logical" { + return fmt.Errorf("wal_level is not logical") + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + if err != nil { + return err + } + + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + return err + }) + require.NoError(t, err) + + fake := faker.New() + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot_native_decoder + user: user_name + password: secret + port: %s + schema: public + tls: none + stream_snapshot: true + decoding_plugin: pgoutput + stream_uncomited: true + database: dbname + tables: + - flights +`, hostAndPortSplited[0], hostAndPortSplited[1]) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + fmt.Println("Msg received", string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + 
_ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 10 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 20 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + fmt.Println("Msg received", string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 10; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer 
outMessagesMut.Unlock() + return len(outMessages) == 10 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) +} diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index b9ee6845d5..1c506d743b 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -44,5 +44,8 @@ type Config struct { // BatchSize is the batch size for streaming BatchSize int `yaml:"batch_size"` + // StreamUncommitted is whether to stream uncommitted messages before receiving commit message + StreamUncomited bool `yaml:"stream_uncommitted"` + logger *service.Logger } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index e4c0c658c2..669063cb7d 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -53,6 +53,9 @@ type Stream struct { decodingPluginArguments []string snapshotMemorySafetyFactor float64 logger *service.Logger + streamUncomited bool + + lsnAckBuffer []string m sync.Mutex stopped bool @@ -117,8 +120,10 @@ func NewPgStream(config Config) (*Stream, error) { slotName: config.ReplicationSlotName, schema: config.DBSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, + streamUncomited: config.StreamUncomited, snapshotBatchSize: config.BatchSize, tableNames: tableNames, + lsnAckBuffer: []string{}, logger: config.logger, m: sync.Mutex{}, decodingPlugin: decodingPluginFromString(config.DecodingPlugin), @@ -235,18 +240,20 @@ func (s *Stream) startLr() error { // AckLSN acknowledges the LSN up to which the stream has processed the messages. // This makes Postgres to remove the WAL files that are no longer needed. 
func (s *Stream) AckLSN(lsn string) error { - var err error - s.clientXLogPos, err = ParseLSN(lsn) + clientXLogPos, err := ParseLSN(lsn) if err != nil { s.logger.Errorf("Failed to parse LSN for Acknowledge: %v", err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } + + return err } err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ - WALApplyPosition: s.clientXLogPos, - WALWritePosition: s.clientXLogPos, + WALApplyPosition: clientXLogPos, + WALWritePosition: clientXLogPos, + WALFlushPosition: clientXLogPos, ReplyRequested: true, }) @@ -255,6 +262,8 @@ func (s *Stream) AckLSN(lsn string) error { return err } + // Update client XLogPos after we ack the message + s.clientXLogPos = clientXLogPos s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) @@ -380,47 +389,29 @@ func (s *Stream) streamMessagesAsync() { } if s.decodingPlugin == "pgoutput" { - // message changes must be collected in the buffer in the context of the same transaction - // as single transaction can contain multiple changes - // and LSN ack will cause potential loss of changes - isBegin, err := isBeginMessage(xld.WALData) - if err != nil { - s.logger.Errorf("Failed to parse WAL data: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - if isBegin { - pgoutputChanges = []StreamMessageChanges{} - } - - // parse changes inside the transaction - message, err := decodePgOutput(xld.WALData, relations, typeMap) - if err != nil { - s.logger.Errorf("decodePgOutput failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) + fmt.Println(string(xld.WALData)) + if s.streamUncomited { + // parse changes inside the transaction + message, err := decodePgOutput(xld.WALData, relations, typeMap) + if err != nil { + 
s.logger.Errorf("decodePgOutput failed: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return } - return - } - if message != nil { - pgoutputChanges = append(pgoutputChanges, *message) - } + if message == nil { + if ok, _ := isBeginMessage(xld.WALData); ok { + fmt.Println("Begin message on ", clientXLogPos.String()) + } - isCommit, err := isCommitMessage(xld.WALData) - if err != nil { - s.logger.Errorf("Failed to parse WAL data: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } + if ok, commit, _ := isCommitMessage(xld.WALData); ok { + fmt.Println("Commit message on ", clientXLogPos.String()) + fmt.Println("Commit transaction end", commit.TransactionEndLSN.String()) + clientXLogPos = commit.TransactionEndLSN + } - if isCommit { - if len(pgoutputChanges) == 0 { // 0 changes happened in the transaction // or we received a change that are not supported/needed by the replication stream if err = s.AckLSN(clientXLogPos.String()); err != nil { @@ -432,9 +423,69 @@ func (s *Stream) streamMessagesAsync() { return } } else { - // send all collected changes lsn := clientXLogPos.String() - s.messages <- StreamMessage{Lsn: &lsn, Changes: pgoutputChanges} + fmt.Println("Message on LSN", lsn) + s.messages <- StreamMessage{Lsn: &lsn, Changes: []StreamMessageChanges{ + *message, + }} + } + } else { + // message changes must be collected in the buffer in the context of the same transaction + // as single transaction can contain multiple changes + // and LSN ack will cause potential loss of changes + isBegin, err := isBeginMessage(xld.WALData) + if err != nil { + s.logger.Errorf("Failed to parse WAL data: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } + + if isBegin { + pgoutputChanges = []StreamMessageChanges{} + } + + // parse changes inside the transaction + message, err := 
decodePgOutput(xld.WALData, relations, typeMap) + if err != nil { + s.logger.Errorf("decodePgOutput failed: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } + + if message != nil { + pgoutputChanges = append(pgoutputChanges, *message) + } + + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + s.logger.Errorf("Failed to parse WAL data: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } + + if isCommit { + if len(pgoutputChanges) == 0 { + // 0 changes happened in the transaction + // or we received a change that are not supported/needed by the replication stream + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } + } else { + // send all collected changes + lsn := clientXLogPos.String() + s.messages <- StreamMessage{Lsn: &lsn, Changes: pgoutputChanges} + } } } } diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index fbcec320a3..d38aa11a7e 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -354,7 +354,7 @@ drop table t; xld = rxXLogData() var commit bool - commit, err = isCommitMessage(xld.WALData) + commit, _, err = isCommitMessage(xld.WALData) require.NoError(t, err) assert.True(t, commit) } diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 461cb90378..829ba906fe 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -22,14 +22,14 @@ 
func isBeginMessage(WALData []byte) (bool, error) { return ok, nil } -func isCommitMessage(WALData []byte) (bool, error) { +func isCommitMessage(WALData []byte) (bool, *CommitMessage, error) { logicalMsg, err := Parse(WALData) if err != nil { - return false, err + return false, nil, err } - _, ok := logicalMsg.(*CommitMessage) - return ok, nil + m, ok := logicalMsg.(*CommitMessage) + return ok, m, nil } // decodePgOutput decodes a logical replication message in pgoutput format. From 878fdc54528f37d12ba957c3d887d8f5397883ce Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 9 Oct 2024 10:41:20 +0200 Subject: [PATCH 014/118] fix(postgres): correct order for message LSN ack --- internal/impl/postgresql/input_postgrecdc.go | 15 ++++++++------ internal/impl/postgresql/integration_test.go | 1 - .../pglogicalstream/logical_stream.go | 20 ++++++++----------- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index fa5357dfae..274529bf6a 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -12,7 +12,6 @@ import ( "context" "crypto/tls" "encoding/json" - "fmt" "strings" "github.com/jackc/pgx/v5/pgconn" @@ -72,7 +71,7 @@ var pgStreamConfigSpec = service.NewConfigSpec(). Example("my_test_slot"). 
Default(randomSlotName)) -func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s service.Input, err error) { +func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metrics *service.Metrics) (s service.Input, err error) { var ( dbName string dbPort int @@ -198,9 +197,11 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger) (s ser tls: pglogicalstream.TLSVerify(tlsSetting), tables: tables, decodingPlugin: decodingPlugin, - logger: logger, streamUncomited: streamUncomited, temporarySlot: temporarySlot, + + logger: logger, + metrics: metrics, }), err } @@ -211,7 +212,7 @@ func init() { err := service.RegisterInput( "pg_stream", pgStreamConfigSpec, func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) { - return newPgStreamInput(conf, mgr.Logger()) + return newPgStreamInput(conf, mgr.Logger(), mgr.Metrics()) }) if err != nil { panic(err) @@ -232,6 +233,7 @@ type pgStreamInput struct { snapshotMemSafetyFactor float64 streamUncomited bool logger *service.Logger + metrics *service.Metrics } func (p *pgStreamInput) Connect(ctx context.Context) error { @@ -287,10 +289,11 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack if message.Lsn != nil { if err := p.pglogicalStream.AckLSN(*message.Lsn); err != nil { - fmt.Println("Error while acking LSN", err) return err } - fmt.Println("Ack LSN", *message.Lsn) + if p.streamUncomited { + p.pglogicalStream.ConsumedCallback() <- true + } } return nil }, nil diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index ac019dd10c..0363faec00 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -555,7 +555,6 @@ file: require.NoError(t, err) outMessagesMut.Lock() outMessages = append(outMessages, string(msgBytes)) - fmt.Println("Msg received", string(msgBytes)) outMessagesMut.Unlock() return nil })) diff --git 
a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 669063cb7d..87fff10450 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -59,6 +59,8 @@ type Stream struct { m sync.Mutex stopped bool + + consumedCallback chan bool } // NewPgStream creates a new instance of the Stream struct @@ -123,6 +125,7 @@ func NewPgStream(config Config) (*Stream, error) { streamUncomited: config.StreamUncomited, snapshotBatchSize: config.BatchSize, tableNames: tableNames, + consumedCallback: make(chan bool), lsnAckBuffer: []string{}, logger: config.logger, m: sync.Mutex{}, @@ -228,6 +231,10 @@ func NewPgStream(config Config) (*Stream, error) { return stream, err } +func (s *Stream) ConsumedCallback() chan bool { + return s.consumedCallback +} + func (s *Stream) startLr() error { if err := StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}); err != nil { return err @@ -389,7 +396,6 @@ func (s *Stream) streamMessagesAsync() { } if s.decodingPlugin == "pgoutput" { - fmt.Println(string(xld.WALData)) if s.streamUncomited { // parse changes inside the transaction message, err := decodePgOutput(xld.WALData, relations, typeMap) @@ -402,16 +408,6 @@ func (s *Stream) streamMessagesAsync() { } if message == nil { - if ok, _ := isBeginMessage(xld.WALData); ok { - fmt.Println("Begin message on ", clientXLogPos.String()) - } - - if ok, commit, _ := isCommitMessage(xld.WALData); ok { - fmt.Println("Commit message on ", clientXLogPos.String()) - fmt.Println("Commit transaction end", commit.TransactionEndLSN.String()) - clientXLogPos = commit.TransactionEndLSN - } - // 0 changes happened in the transaction // or we received a change that are not supported/needed by the replication stream if err = s.AckLSN(clientXLogPos.String()); err != nil { @@ -424,10 
+420,10 @@ func (s *Stream) streamMessagesAsync() { } } else { lsn := clientXLogPos.String() - fmt.Println("Message on LSN", lsn) s.messages <- StreamMessage{Lsn: &lsn, Changes: []StreamMessageChanges{ *message, }} + <-s.consumedCallback } } else { // message changes must be collected in the buffer in the context of the same transaction From 5b4a8352ff105a12d9e02ec427f9316f434077b8 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 10 Oct 2024 17:31:46 +0200 Subject: [PATCH 015/118] chore(): removed log line --- internal/impl/postgresql/integration_test.go | 177 ++++++++++++++++++- 1 file changed, 176 insertions(+), 1 deletion(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 0363faec00..677b038e2b 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -509,7 +509,6 @@ file: require.NoError(t, err) outMessagesMut.Lock() outMessages = append(outMessages, string(msgBytes)) - fmt.Println("Msg received", string(msgBytes)) outMessagesMut.Unlock() return nil })) @@ -585,3 +584,179 @@ file: db.Close() }) } + +func bulkInsert(db *sql.DB, generateData func() (string, time.Time), totalInserts int) error { + const batchSize = 10000 + + for i := 0; i < totalInserts; i += batchSize { + end := i + batchSize + if end > totalInserts { + end = totalInserts + } + + valueStrings := make([]string, 0, batchSize) + valueArgs := make([]interface{}, 0, batchSize*2) + + for j := 0; j < end-i; j++ { + valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d)", j*2+1, j*2+2)) + name, createdAt := generateData() + valueArgs = append(valueArgs, name, createdAt) + } + + stmt := fmt.Sprintf("INSERT INTO flights (name, created_at) VALUES %s", + strings.Join(valueStrings, ",")) + + _, err := db.Exec(stmt, valueArgs...) 
+ if err != nil { + return fmt.Errorf("bulk insert failed: %w", err) + } + } + + return nil +} + +func TestIntegrationPgCDCForPgOutputStreamUncomitedPluginForNeonTech(t *testing.T) { + tmpDir := t.TempDir() + + fake := faker.New() + generateData := func() (string, time.Time) { + return fake.Address().City(), fake.Time().Time(time.Now()) + } + + databaseURL := "user=redpanda_owner password=MwDzur6AWUZ4 dbname=redpanda sslmode=require host=ep-holy-hill-a5zyhish.us-east-2.aws.neon.tech port=5432" + + var ( + db *sql.DB + err error + ) + + db, err = sql.Open("postgres", databaseURL) + require.NoError(t, err) + + err = db.Ping() + require.NoError(t, err) + + var walLevel string + err = db.QueryRow("SHOW wal_level").Scan(&walLevel) + require.NoError(t, err) + require.Equal(t, "logical", walLevel) + + _, err = db.Exec("DROP TABLE IF EXISTS flights") + require.NoError(t, err) + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + require.NoError(t, err) + + err = bulkInsert(db, generateData, 100000) + require.NoError(t, err) + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: my_pg_slot_to_check_wal + user: redpanda_owner + password: MwDzur6AWUZ4 + port: %s + schema: public + tls: require + stream_snapshot: true + stream_uncomited: true + database: redpanda + temporary_slot: false + pg_conn_options: "endpoint=ep-holy-hill-a5zyhish" + tables: + - flights +`, "ep-holy-hill-a5zyhish.us-east-2.aws.neon.tech", "5432") + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m 
*service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + fmt.Println("Messages count", len(outMessages)) + return len(outMessages) == 100000 + }, time.Minute, time.Second) + + err = bulkInsert(db, generateData, 100000) + require.NoError(t, err) + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + fmt.Println("Messages count", len(outMessages)) + return len(outMessages) == 200000 + }, time.Minute, time.Second) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + err = bulkInsert(db, generateData, 10000) + require.NoError(t, err) + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + fmt.Println("Messages", len(outMessages)) + 
return len(outMessages) == 10000 + }, time.Second*20, time.Second) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) +} From 8cf7eddabe5c74fee58f88a62a7a36cb5b0f9801 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 11 Oct 2024 13:02:00 +0200 Subject: [PATCH 016/118] chore(): working on metrics --- internal/impl/postgresql/input_postgrecdc.go | 24 ++- .../postgresql/pglogicalstream/debouncer.go | 35 ++++ .../pglogicalstream/logical_stream.go | 24 +++ .../postgresql/pglogicalstream/monitor.go | 164 ++++++++++++++++++ .../postgresql/pglogicalstream/snapshotter.go | 31 ++-- .../impl/postgresql/pglogicalstream/util.go | 22 +++ 6 files changed, 287 insertions(+), 13 deletions(-) create mode 100644 internal/impl/postgresql/pglogicalstream/debouncer.go create mode 100644 internal/impl/postgresql/pglogicalstream/monitor.go create mode 100644 internal/impl/postgresql/pglogicalstream/util.go diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 274529bf6a..fe09e7ef7f 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "encoding/json" "strings" + "time" "github.com/jackc/pgx/v5/pgconn" "github.com/lucasepe/codename" @@ -187,6 +188,9 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric pgconnConfig.TLSConfig = nil } + snapsotMetrics := metrics.NewGauge("snapshot_progress") + replicationLag := metrics.NewGauge("replication_lag") + return service.AutoRetryNacks(&pgStreamInput{ dbConfig: pgconnConfig, streamSnapshot: streamSnapshot, @@ -200,8 +204,10 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric streamUncomited: streamUncomited, temporarySlot: temporarySlot, - logger: logger, - metrics: metrics, + logger: logger, + metrics: metrics, + snapshotMetrics: 
snapsotMetrics, + replicationLag: replicationLag, }), err } @@ -234,6 +240,10 @@ type pgStreamInput struct { streamUncomited bool logger *service.Logger metrics *service.Metrics + metricsTicker *time.Ticker + + snapshotMetrics *service.MetricGauge + replicationLag *service.MetricGauge } func (p *pgStreamInput) Connect(ctx context.Context) error { @@ -257,12 +267,22 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { if err != nil { return err } + + p.metricsTicker = time.NewTicker(5 * time.Second) p.pglogicalStream = pgStream + return err } func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.AckFunc, error) { + select { + case <-p.metricsTicker.C: + progress := p.pglogicalStream.GetProgress() + for table, progress := range progress.TableProgress { + p.snapshotMetrics.Set(int64(progress), table) + } + p.replicationLag.Set(progress.WalLagInBytes) case snapshotMessage := <-p.pglogicalStream.SnapshotMessageC(): var ( mb []byte diff --git a/internal/impl/postgresql/pglogicalstream/debouncer.go b/internal/impl/postgresql/pglogicalstream/debouncer.go new file mode 100644 index 0000000000..1ddc8b3eea --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/debouncer.go @@ -0,0 +1,35 @@ +package pglogicalstream + +import ( + "sync" + "time" +) + +// New returns a debounced function that takes another functions as its argument. +// This function will be called when the debounced function stops being called +// for the given duration. +// The debounced function can be invoked with different functions, if needed, +// the last one will win. 
+func NewDebouncer(after time.Duration) func(f func()) { + d := &debouncer{after: after} + + return func(f func()) { + d.add(f) + } +} + +type debouncer struct { + mu sync.Mutex + after time.Duration + timer *time.Timer +} + +func (d *debouncer) add(f func()) { + d.mu.Lock() + defer d.mu.Unlock() + + if d.timer != nil { + d.timer.Stop() + } + d.timer = time.AfterFunc(d.after, f) +} diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 87fff10450..2cacc537f9 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -53,6 +53,7 @@ type Stream struct { decodingPluginArguments []string snapshotMemorySafetyFactor float64 logger *service.Logger + monitor *Monitor streamUncomited bool lsnAckBuffer []string @@ -216,6 +217,12 @@ func NewPgStream(config Config) (*Stream, error) { stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) + monitor, err := NewMonitor(cfg, stream.logger, tableNames, stream.slotName) + if err != nil { + return nil, err + } + stream.monitor = monitor + if !freshlyCreatedSlot || !config.StreamOldData { if err = stream.startLr(); err != nil { return nil, err @@ -231,6 +238,12 @@ func NewPgStream(config Config) (*Stream, error) { return stream, err } +// GetProgress returns the progress of the stream. +// including the % of snapsho messages processed and the WAL lag in bytes. 
+func (s *Stream) GetProgress() *Report { + return s.monitor.Report() +} + func (s *Stream) ConsumedCallback() chan bool { return s.consumedCallback } @@ -517,6 +530,16 @@ func (s *Stream) processSnapshot() { } }() + tableStats, err := snapshotter.GetRowsCountPerTable(s.tableNames) + if err != nil { + s.logger.Errorf("Failed to get table stats: %v", err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) + } + for _, table := range s.tableNames { s.logger.Infof("Processing snapshot for table: %v", table) @@ -733,6 +756,7 @@ func (s *Stream) Stop() error { s.m.Lock() s.stopped = true s.m.Unlock() + s.monitor.Stop() if s.pgConn != nil { if s.streamCtx != nil { diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go new file mode 100644 index 0000000000..b5e4abdc91 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -0,0 +1,164 @@ +package pglogicalstream + +import ( + "context" + "database/sql" + "fmt" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/redpanda-data/benthos/v4/public/service" +) + +type Report struct { + WalLagInBytes int64 + TableProgress map[string]float64 +} + +type Monitor struct { + // tableStat contains numbers of rows for each table determined at the moment of the snapshot creation + // this is used to calculate snapshot ingestion progress + tableStat map[string]int64 + lock sync.Mutex + // snapshotProgress is a map of table names to the percentage of rows ingested from the snapshot + snapshotProgress map[string]float64 + // replicationLagInBytes is the replication lag in bytes measured by + // finding the difference between the latest LSN and the last confirmed LSN for the replication slot + replicationLagInBytes int64 + + debounced func(f func()) + + dbConn *sql.DB + slotName string + logger *service.Logger + ticker *time.Ticker 
+ cancelTicker context.CancelFunc + ctx context.Context +} + +func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, slotName string) (*Monitor, error) { + // debounces is user to throttle locks on the monitor to prevent unnecessary updates that would affect the performance + debounced := NewDebouncer(100 * time.Millisecond) + dbConn, err := openPgConnectionFromConfig(*conf) + if err != nil { + return nil, err + } + + m := &Monitor{ + snapshotProgress: map[string]float64{}, + replicationLagInBytes: 0, + debounced: debounced, + dbConn: dbConn, + slotName: slotName, + logger: logger, + } + + if err = m.readTablesStat(tables); err != nil { + return nil, err + } + + ctx, cancel := context.WithCancel(context.Background()) + m.ctx = ctx + m.cancelTicker = cancel + m.ticker = time.NewTicker(5 * time.Second) + + return m, nil +} + +// UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table +func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) { + storeSnapshotProgress := func() { + m.lock.Lock() + defer m.lock.Unlock() + m.snapshotProgress[table] = float64(position) / float64(m.tableStat[table]) * 100 + } + + m.debounced(storeSnapshotProgress) +} + +// we need to read the tables stat to calculate the snapshot ingestion progress +func (m *Monitor) readTablesStat(tables []string) error { + results := make(map[string]int64) + + // Construct the query + queryParts := make([]string, len(tables)) + for i, table := range tables { + queryParts[i] = fmt.Sprintf("SELECT '%s' AS table_name, COUNT(*) FROM %s", table, table) + } + query := strings.Join(queryParts, " UNION ALL ") + + // Execute the query + rows, err := m.dbConn.Query(query) + if err != nil { + return err + } + defer rows.Close() + + // Process the results + for rows.Next() { + var tableName string + var count int64 + if err := rows.Scan(&tableName, &count); err != nil { + return err + } + results[tableName] = count + } + + if err := 
rows.Err(); err != nil { + return err + } + + m.tableStat = results + return nil +} + +func (m *Monitor) readReplicationLag() { + result, err := m.dbConn.Query(`SELECT slot_name, + pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) AS lag_bytes + FROM pg_replication_slots WHERE slot_name = ?;`, m.slotName) + // calculate the replication lag in bytes + // replicationLagInBytes = latestLsn - confirmedLsn + if result.Err() != nil || err != nil { + m.logger.Errorf("Error reading replication lag: %v", err) + return + } + + var slotName string + var lagbytes int64 + if err = result.Scan(&slotName, &lagbytes); err != nil { + m.logger.Errorf("Error reading replication lag: %v", err) + return + } + + m.replicationLagInBytes = lagbytes +} + +func (m *Monitor) Report() *Report { + m.lock.Lock() + defer m.lock.Unlock() + // report the snapshot ingestion progress + // report the replication lag + return &Report{ + WalLagInBytes: m.replicationLagInBytes, + TableProgress: m.snapshotProgress, + } +} + +func (m *Monitor) Stop() { + m.cancelTicker() + m.ticker.Stop() + m.dbConn.Close() +} + +func (m *Monitor) startSync() { + for { + select { + case <-m.ctx.Done(): + return + case <-m.ticker.C: + m.readReplicationLag() + } + } +} diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 1ecbc50ac1..f20580a7c3 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -33,17 +33,7 @@ type Snapshotter struct { // NewSnapshotter creates a new Snapshotter instance func NewSnapshotter(dbConf pgconn.Config, snapshotName string, logger *service.Logger) (*Snapshotter, error) { - var sslMode string - if dbConf.TLSConfig != nil { - sslMode = "require" - } else { - sslMode = "disable" - } - connStr := fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s sslmode=%s", dbConf.User, - dbConf.Password, dbConf.Host, dbConf.Port, dbConf.Database, 
sslMode, - ) - - pgConn, err := sql.Open("postgres", connStr) + pgConn, err := openPgConnectionFromConfig(dbConf) return &Snapshotter{ pgConnection: pgConn, @@ -63,6 +53,25 @@ func (s *Snapshotter) prepare() error { return nil } +func (s *Snapshotter) GetRowsCountPerTable(tableNames []string) (map[string]int, error) { + tables := make(map[string]int) + rows, err := s.pgConnection.Query("SELECT table_name, count(*) FROM information_schema.tables WHERE table_name in (?) GROUP BY table_name;", tableNames) + if err != nil { + return tables, err + } + + for rows.Next() { + var tableName string + var count int + if err := rows.Scan(&tableName, &count); err != nil { + return tables, err + } + tables[tableName] = count + } + + return tables, nil +} + func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { var ( avgRowSize sql.NullInt64 diff --git a/internal/impl/postgresql/pglogicalstream/util.go b/internal/impl/postgresql/pglogicalstream/util.go new file mode 100644 index 0000000000..041427c2cd --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/util.go @@ -0,0 +1,22 @@ +package pglogicalstream + +import ( + "database/sql" + "fmt" + + "github.com/jackc/pgx/v5/pgconn" +) + +func openPgConnectionFromConfig(dbConf pgconn.Config) (*sql.DB, error) { + var sslMode string + if dbConf.TLSConfig != nil { + sslMode = "require" + } else { + sslMode = "disable" + } + connStr := fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s sslmode=%s", dbConf.User, + dbConf.Password, dbConf.Host, dbConf.Port, dbConf.Database, sslMode, + ) + + return sql.Open("postgres", connStr) +} From 2149189a99cac7fbe88db3d92826d704ac06fa6b Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Sat, 12 Oct 2024 11:51:01 +0200 Subject: [PATCH 017/118] chore(): removed test case && working on monitor testing --- internal/impl/postgresql/integration_test.go | 146 ------------------ .../pglogicalstream/monitor_test.go | 71 +++++++++ 2 files changed, 71 insertions(+), 146 
deletions(-) create mode 100644 internal/impl/postgresql/pglogicalstream/monitor_test.go diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 677b038e2b..be8afb5fab 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -614,149 +614,3 @@ func bulkInsert(db *sql.DB, generateData func() (string, time.Time), totalInsert return nil } - -func TestIntegrationPgCDCForPgOutputStreamUncomitedPluginForNeonTech(t *testing.T) { - tmpDir := t.TempDir() - - fake := faker.New() - generateData := func() (string, time.Time) { - return fake.Address().City(), fake.Time().Time(time.Now()) - } - - databaseURL := "user=redpanda_owner password=MwDzur6AWUZ4 dbname=redpanda sslmode=require host=ep-holy-hill-a5zyhish.us-east-2.aws.neon.tech port=5432" - - var ( - db *sql.DB - err error - ) - - db, err = sql.Open("postgres", databaseURL) - require.NoError(t, err) - - err = db.Ping() - require.NoError(t, err) - - var walLevel string - err = db.QueryRow("SHOW wal_level").Scan(&walLevel) - require.NoError(t, err) - require.Equal(t, "logical", walLevel) - - _, err = db.Exec("DROP TABLE IF EXISTS flights") - require.NoError(t, err) - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - require.NoError(t, err) - - err = bulkInsert(db, generateData, 100000) - require.NoError(t, err) - - template := fmt.Sprintf(` -pg_stream: - host: %s - slot_name: my_pg_slot_to_check_wal - user: redpanda_owner - password: MwDzur6AWUZ4 - port: %s - schema: public - tls: require - stream_snapshot: true - stream_uncomited: true - database: redpanda - temporary_slot: false - pg_conn_options: "endpoint=ep-holy-hill-a5zyhish" - tables: - - flights -`, "ep-holy-hill-a5zyhish.us-east-2.aws.neon.tech", "5432") - - cacheConf := fmt.Sprintf(` -label: pg_stream_cache -file: - directory: %v -`, tmpDir) - - streamOutBuilder := 
service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) - require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) - require.NoError(t, streamOutBuilder.AddInputYAML(template)) - - var outMessages []string - var outMessagesMut sync.Mutex - - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() - require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() - return nil - })) - - streamOut, err := streamOutBuilder.Build() - require.NoError(t, err) - - go func() { - _ = streamOut.Run(context.Background()) - }() - - assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - fmt.Println("Messages count", len(outMessages)) - return len(outMessages) == 100000 - }, time.Minute, time.Second) - - err = bulkInsert(db, generateData, 100000) - require.NoError(t, err) - - assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - fmt.Println("Messages count", len(outMessages)) - return len(outMessages) == 200000 - }, time.Minute, time.Second) - - require.NoError(t, streamOut.StopWithin(time.Second*10)) - - // Starting stream for the same replication slot should continue from the last LSN - // Meaning we must not receive any old messages again - - streamOutBuilder = service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) - require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) - require.NoError(t, streamOutBuilder.AddInputYAML(template)) - - outMessages = []string{} - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() - require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() - return nil - })) - - streamOut, err = 
streamOutBuilder.Build() - require.NoError(t, err) - - go func() { - assert.NoError(t, streamOut.Run(context.Background())) - }() - - time.Sleep(time.Second * 5) - err = bulkInsert(db, generateData, 10000) - require.NoError(t, err) - - assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - fmt.Println("Messages", len(outMessages)) - return len(outMessages) == 10000 - }, time.Second*20, time.Second) - - require.NoError(t, streamOut.StopWithin(time.Second*10)) - t.Log("All the conditions are met 🎉") - - t.Cleanup(func() { - db.Close() - }) -} diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go new file mode 100644 index 0000000000..211f89e212 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/monitor_test.go @@ -0,0 +1,71 @@ +package pglogicalstream + +import ( + "database/sql" + "fmt" + "strings" + "testing" + "time" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/require" +) + +func Test_MonitorReplorting(t *testing.T) { + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: "16", + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s replication=database", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + pool.MaxWait = 120 * time.Second + err 
= pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + return err + }) + require.NoError(t, err) + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + if err != nil { + return err + } + + fake := faker.New() + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + mon := NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, slotName string) +} From 7bf08989c38e7205bf566c5e770b249450bfaa86 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Sun, 13 Oct 2024 21:34:45 +0200 Subject: [PATCH 018/118] chore(): monitor testing --- .../impl/postgresql/pglogicalstream/monitor_test.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go index 211f89e212..ca1dca7c40 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor_test.go +++ b/internal/impl/postgresql/pglogicalstream/monitor_test.go @@ -7,8 +7,10 @@ import ( "testing" "time" + "github.com/jaswdr/faker" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "github.com/redpanda-data/benthos/v4/public/service" "github.com/stretchr/testify/require" ) @@ -56,16 +58,14 @@ func Test_MonitorReplorting(t *testing.T) { require.NoError(t, err) _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - if err != nil { - return err - } + require.NoError(t, err) fake := faker.New() for i := 0; i < 1000; i++ { _, err = 
db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } - mon := NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, slotName string) + slotName := "test_slot" + mon := NewMonitor(db, logger*service.Logger, []string{"flights"}, slotName) } From f18c8e44829f9ce9b6557b21ec1cb544c2b18a81 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 15 Oct 2024 13:25:23 +0200 Subject: [PATCH 019/118] chore(): added backward compatibility for postgresql --- internal/impl/postgresql/input_postgrecdc.go | 20 +- internal/impl/postgresql/integration_test.go | 331 +++++++++++------- .../pglogicalstream/logical_stream.go | 70 ++-- .../pglogicalstream/monitor_test.go | 17 +- .../postgresql/pglogicalstream/pglogrepl.go | 54 ++- .../pglogicalstream/pglogrepl_test.go | 8 +- .../postgresql/pglogicalstream/snapshotter.go | 95 +++-- .../impl/postgresql/pglogicalstream/util.go | 29 ++ 8 files changed, 418 insertions(+), 206 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index fe09e7ef7f..4836d2cdb1 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -277,12 +277,6 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.AckFunc, error) { select { - case <-p.metricsTicker.C: - progress := p.pglogicalStream.GetProgress() - for table, progress := range progress.TableProgress { - p.snapshotMetrics.Set(int64(progress), table) - } - p.replicationLag.Set(progress.WalLagInBytes) case snapshotMessage := <-p.pglogicalStream.SnapshotMessageC(): var ( mb []byte @@ -291,7 +285,11 @@ func (p *pgStreamInput) Read(ctx context.Context) 
(*service.Message, service.Ack if mb, err = json.Marshal(snapshotMessage); err != nil { return nil, nil, err } - return service.NewMessage(mb), func(ctx context.Context, err error) error { + + connectMessage := service.NewMessage(mb) + connectMessage.MetaSet("table", snapshotMessage.Changes[0].Table) + connectMessage.MetaSet("operation", snapshotMessage.Changes[0].Operation) + return connectMessage, func(ctx context.Context, err error) error { // Nacks are retried automatically when we use service.AutoRetryNacks return nil }, nil @@ -303,10 +301,10 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack if mb, err = json.Marshal(message); err != nil { return nil, nil, err } - return service.NewMessage(mb), func(ctx context.Context, err error) error { - // Nacks are retried automatically when we use service.AutoRetryNacks - //message.ServerHeartbeat. - + connectMessage := service.NewMessage(mb) + connectMessage.MetaSet("table", message.Changes[0].Table) + connectMessage.MetaSet("operation", message.Changes[0].Operation) + return connectMessage, func(ctx context.Context, err error) error { if message.Lsn != nil { if err := p.pglogicalStream.AckLSN(*message.Lsn); err != nil { return err diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index be8afb5fab..b97357c4f5 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -29,6 +29,76 @@ import ( "github.com/ory/dockertest/v3/docker" ) +func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version string) (*dockertest.Resource, *sql.DB, error) { + resource, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "postgres", + Tag: version, + Env: []string{ + "POSTGRES_PASSWORD=secret", + "POSTGRES_USER=user_name", + "POSTGRES_DB=dbname", + }, + Cmd: []string{ + "postgres", + "-c", "wal_level=logical", + }, + }, func(config *docker.HostConfig) { + 
config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, pool.Purge(resource)) + }) + + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + + var db *sql.DB + pool.MaxWait = 120 * time.Second + if err = pool.Retry(func() error { + if db, err = sql.Open("postgres", databaseURL); err != nil { + return err + } + + if err = db.Ping(); err != nil { + return err + } + + var walLevel string + if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { + return err + } + + var pgConfig string + if err = db.QueryRow("SHOW config_file").Scan(&pgConfig); err != nil { + return err + } + + if walLevel != "logical" { + return fmt.Errorf("wal_level is not logical") + } + + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + if err != nil { + return err + } + + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming + _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") + + return err + }); err != nil { + panic(fmt.Errorf("could not connect to docker: %w", err)) + } + + return resource, db, nil +} + func TestIntegrationPgCDC(t *testing.T) { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") @@ -225,64 +295,18 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { pool, err := dockertest.NewPool("") require.NoError(t, err) - resource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "postgres", - Tag: "16", - Env: []string{ - "POSTGRES_PASSWORD=secret", - "POSTGRES_USER=user_name", - 
"POSTGRES_DB=dbname", - }, - Cmd: []string{ - "postgres", - "-c", "wal_level=logical", - }, - }, func(config *docker.HostConfig) { - config.AutoRemove = true - config.RestartPolicy = docker.RestartPolicy{Name: "no"} - }) + var ( + resource *dockertest.Resource + db *sql.DB + ) + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16") require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, pool.Purge(resource)) - }) - require.NoError(t, resource.Expire(120)) hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) - - var db *sql.DB - - pool.MaxWait = 120 * time.Second - err = pool.Retry(func() error { - if db, err = sql.Open("postgres", databaseURL); err != nil { - return err - } - - if err = db.Ping(); err != nil { - return err - } - - var walLevel string - if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { - return err - } - if walLevel != "logical" { - return fmt.Errorf("wal_level is not logical") - } - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - if err != nil { - return err - } - - // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - return err - }) require.NoError(t, err) fake := faker.New() @@ -407,65 +431,17 @@ func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { pool, err := dockertest.NewPool("") require.NoError(t, err) - resource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "postgres", - Tag: "16", - Env: []string{ - "POSTGRES_PASSWORD=secret", - "POSTGRES_USER=user_name", - "POSTGRES_DB=dbname", - }, - Cmd: 
[]string{ - "postgres", - "-c", "wal_level=logical", - }, - }, func(config *docker.HostConfig) { - config.AutoRemove = true - config.RestartPolicy = docker.RestartPolicy{Name: "no"} - }) + var ( + resource *dockertest.Resource + db *sql.DB + ) + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, "16") require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, pool.Purge(resource)) - }) - require.NoError(t, resource.Expire(120)) hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) - - var db *sql.DB - - pool.MaxWait = 120 * time.Second - err = pool.Retry(func() error { - if db, err = sql.Open("postgres", databaseURL); err != nil { - return err - } - - if err = db.Ping(); err != nil { - return err - } - - var walLevel string - if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { - return err - } - - if walLevel != "logical" { - return fmt.Errorf("wal_level is not logical") - } - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - if err != nil { - return err - } - - // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - return err - }) - require.NoError(t, err) fake := faker.New() for i := 0; i < 10; i++ { @@ -585,32 +561,141 @@ file: }) } -func bulkInsert(db *sql.DB, generateData func() (string, time.Time), totalInserts int) error { - const batchSize = 10000 +func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { + // running tests in the look to test different PostgreSQL versions + t.Parallel() + for _, v := range []string{"13", "12", "11", 
"10", "9.6", "9.4"} { + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, v) + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) - for i := 0; i < totalInserts; i += batchSize { - end := i + batchSize - if end > totalInserts { - end = totalInserts + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + + fake := faker.New() + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) } - valueStrings := make([]string, 0, batchSize) - valueArgs := make([]interface{}, 0, batchSize*2) + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot_native_decoder + user: user_name + password: secret + port: %s + schema: public + tls: none + stream_snapshot: true + decoding_plugin: pgoutput + stream_uncomited: true + database: dbname + tables: + - flights +`, hostAndPortSplited[0], hostAndPortSplited[1]) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) - for j := 0; j < end-i; j++ { - valueStrings = append(valueStrings, fmt.Sprintf("($%d, $%d)", j*2+1, j*2+2)) - name, createdAt := generateData() - valueArgs = append(valueArgs, name, createdAt) + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, 
string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) } - stmt := fmt.Sprintf("INSERT INTO flights (name, created_at) VALUES %s", - strings.Join(valueStrings, ",")) + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 2000 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) - _, err := db.Exec(stmt, valueArgs...) 
- if err != nil { - return fmt.Errorf("bulk insert failed: %w", err) + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) } - } - return nil + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) + } } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 2cacc537f9..7455685b06 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -55,6 +55,7 @@ type Stream struct { logger *service.Logger monitor *Monitor streamUncomited bool + snapshotter *Snapshotter lsnAckBuffer []string @@ -137,12 +138,34 @@ func NewPgStream(config Config) (*Stream, error) { tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) } + var version int + version, err = getPostgresVersion(*cfg) + if err != nil { + return nil, err + } + + snapshotter, err := NewSnapshotter(stream.dbConfig, stream.logger, version) + if err != nil { + if err != nil { + stream.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) + if err = stream.cleanUpOnFailure(); err != nil { + stream.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) + } + } + stream.snapshotter = snapshotter + var pluginArguments = []string{} if stream.decodingPlugin == "pgoutput" { pluginArguments = []string{ "proto_version '1'", fmt.Sprintf("publication_names 'pglog_stream_%s'", config.ReplicationSlotName), 
- "messages 'true'", + } + + if version > 14 { + pluginArguments = append(pluginArguments, "messages 'true'") } } @@ -156,20 +179,18 @@ func NewPgStream(config Config) (*Stream, error) { stream.decodingPluginArguments = pluginArguments + // create snapshot transaction before creating a slot for older PostgreSQL versions to ensure consistency + pubName := "pglog_stream_" + config.ReplicationSlotName if err = CreatePublication(context.Background(), stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } - stream.logger.Infof("Created Postgresql publication %v %v", "publication_name", config.ReplicationSlotName) - sysident, err := IdentifySystem(context.Background(), stream.pgConn) if err != nil { return nil, err } - stream.logger.Infof("System identification result SystemID: %v Timeline: %v XLogPos: %v DBName: %v", sysident.SystemID, sysident.Timeline, sysident.XLogPos, sysident.DBName) - var freshlyCreatedSlot = false var confirmedLSNFromDB string // check is replication slot exist to get last restart SLN @@ -183,8 +204,10 @@ func NewPgStream(config Config) (*Stream, error) { createSlotResult, err = CreateReplicationSlot(context.Background(), stream.pgConn, stream.slotName, stream.decodingPlugin.String(), CreateReplicationSlotOptions{Temporary: config.TemporaryReplicationSlot, SnapshotAction: "export", - }) + }, version, stream.snapshotter) if err != nil { + fmt.Println(err) + fmt.Println("Failed to create replication slot", err.Error()) return nil, err } stream.snapshotName = createSlotResult.SnapshotName @@ -192,7 +215,6 @@ func NewPgStream(config Config) (*Stream, error) { } else { slotCheckRow := slotCheckResults[0].Rows[0] confirmedLSNFromDB = string(slotCheckRow[0]) - stream.logger.Infof("Replication slot restart LSN extracted from DB: LSN %v", confirmedLSNFromDB) } } @@ -504,16 +526,7 @@ func (s *Stream) streamMessagesAsync() { } func (s *Stream) processSnapshot() { - snapshotter, err := NewSnapshotter(s.dbConfig, s.snapshotName, s.logger) - 
if err != nil { - s.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) - } - if err = snapshotter.prepare(); err != nil { + if err := s.snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -522,33 +535,24 @@ func (s *Stream) processSnapshot() { os.Exit(1) } defer func() { - if err = snapshotter.releaseSnapshot(); err != nil { + if err := s.snapshotter.releaseSnapshot(); err != nil { s.logger.Errorf("Failed to release database snapshot: %v", err.Error()) } - if err = snapshotter.closeConn(); err != nil { + if err := s.snapshotter.closeConn(); err != nil { s.logger.Errorf("Failed to close database connection: %v", err.Error()) } }() - tableStats, err := snapshotter.GetRowsCountPerTable(s.tableNames) - if err != nil { - s.logger.Errorf("Failed to get table stats: %v", err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) - } - for _, table := range s.tableNames { s.logger.Infof("Processing snapshot for table: %v", table) var ( avgRowSizeBytes sql.NullInt64 offset = 0 + err error ) - avgRowSizeBytes, err = snapshotter.findAvgRowSize(table) + avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) if err != nil { s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { @@ -559,7 +563,7 @@ func (s *Stream) processSnapshot() { } availableMemory := getAvailableMemory() - batchSize := snapshotter.calculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) + batchSize := 
s.snapshotter.calculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) tablePk, err := s.getPrimaryKeyColumn(table) @@ -574,7 +578,7 @@ func (s *Stream) processSnapshot() { for { var snapshotRows *sql.Rows - if snapshotRows, err = snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { + if snapshotRows, err = s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -701,7 +705,7 @@ func (s *Stream) processSnapshot() { } - if err = s.startLr(); err != nil { + if err := s.startLr(); err != nil { s.logger.Errorf("Failed to start logical replication after snapshot: %v", err.Error()) os.Exit(1) } diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go index ca1dca7c40..c78e4a0f8d 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor_test.go +++ b/internal/impl/postgresql/pglogicalstream/monitor_test.go @@ -3,10 +3,12 @@ package pglogicalstream import ( "database/sql" "fmt" + "strconv" "strings" "testing" "time" + "github.com/jackc/pgx/v5/pgconn" "github.com/jaswdr/faker" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" @@ -15,6 +17,7 @@ import ( ) func Test_MonitorReplorting(t *testing.T) { + t.Skip("Skipping for now") pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -40,7 +43,7 @@ func Test_MonitorReplorting(t *testing.T) { hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s replication=database", 
hostAndPortSplited[0], hostAndPortSplited[1]) + databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) var db *sql.DB pool.MaxWait = 120 * time.Second @@ -66,6 +69,16 @@ func Test_MonitorReplorting(t *testing.T) { require.NoError(t, err) } + portUint64, err := strconv.ParseUint(hostAndPortSplited[1], 10, 10) + require.NoError(t, err) slotName := "test_slot" - mon := NewMonitor(db, logger*service.Logger, []string{"flights"}, slotName) + mon, err := NewMonitor(&pgconn.Config{ + Host: hostAndPortSplited[0], + Port: uint16(portUint64), + User: "user_name", + Password: "secret", + Database: "dbname", + }, &service.Logger{}, []string{"flights"}, slotName) + require.NoError(t, err) + require.NotNil(t, mon) } diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 0213b008ba..89aba6479b 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -234,6 +234,8 @@ func CreateReplicationSlot( slotName string, outputPlugin string, options CreateReplicationSlotOptions, + version int, + snapshotter *Snapshotter, ) (CreateReplicationSlotResult, error) { var temporaryString string if options.Temporary { @@ -245,12 +247,42 @@ func CreateReplicationSlot( } else { snapshotString = options.SnapshotAction } - sql := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) - return ParseCreateReplicationSlot(conn.Exec(ctx, sql)) + + newPgCreateSlotCommand := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) + oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") + + var snapshotName string + if version 
> 14 { + result, err := ParseCreateReplicationSlot(conn.Exec(ctx, newPgCreateSlotCommand), version, snapshotName) + if err != nil { + return CreateReplicationSlotResult{}, err + } + snapshotter.setTransactionSnapshotName(result.SnapshotName) + } + + var snapshotResponse SnapshotCreationResponse + if options.SnapshotAction == "export" { + var err error + snapshotResponse, err = snapshotter.initSnapshotTransaction() + if err != nil { + return CreateReplicationSlotResult{}, err + } + snapshotter.setTransactionSnapshotName(snapshotResponse.ExportedSnapshotName) + } + + replicationSlotCreationResponse := conn.Exec(ctx, oldPgCreateSlotCommand) + _, err := replicationSlotCreationResponse.ReadAll() + if err != nil { + return CreateReplicationSlotResult{}, err + } + + return CreateReplicationSlotResult{ + SnapshotName: snapshotResponse.ExportedSnapshotName, + }, nil } // ParseCreateReplicationSlot parses the result of the CREATE_REPLICATION_SLOT command. -func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader) (CreateReplicationSlotResult, error) { +func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader, version int, snapshotName string) (CreateReplicationSlotResult, error) { var crsr CreateReplicationSlotResult results, err := mrr.ReadAll() if err != nil { @@ -267,14 +299,22 @@ func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader) (CreateReplicatio } row := result.Rows[0] - if len(row) != 4 { - return crsr, fmt.Errorf("expected 4 result columns, got %d", len(row)) + if version > 14 { + if len(row) != 4 { + return crsr, fmt.Errorf("expected 4 result columns, got %d", len(row)) + } } crsr.SlotName = string(row[0]) crsr.ConsistentPoint = string(row[1]) - crsr.SnapshotName = string(row[2]) - crsr.OutputPlugin = string(row[3]) + + if version > 14 { + crsr.SnapshotName = string(row[2]) + } else { + crsr.SnapshotName = snapshotName + } + + fmt.Println("Snapshot name", crsr.SnapshotName) return crsr, nil } diff --git 
a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index d38aa11a7e..7859652afd 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -178,7 +178,7 @@ func TestCreateReplicationSlot(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - result, err := CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}) + result, err := CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil) require.NoError(t, err) assert.Equal(t, slotName, result.SlotName) @@ -198,13 +198,13 @@ func TestDropReplicationSlot(t *testing.T) { require.NoError(t, err) defer closeConn(t, conn) - _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}, 16, nil) require.NoError(t, err) err = DropReplicationSlot(ctx, conn, slotName, DropReplicationSlotOptions{}) require.NoError(t, err) - _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false}, 16, nil) require.NoError(t, err) } @@ -230,7 +230,7 @@ func TestStartReplication(t *testing.T) { err = CreatePublication(context.Background(), conn, publicationName, []string{}, true) require.NoError(t, err) - _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}) + _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil) 
require.NoError(t, err) err = StartReplication(ctx, conn, slotName, sysident.XLogPos, StartReplicationOptions{ diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index f20580a7c3..00e3fd460e 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -19,57 +19,90 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" ) +type SnapshotCreationResponse struct { + ExportedSnapshotName string +} + // Snapshotter is a structure that allows the creation of a snapshot of a database at a given point in time // At the time we initialize logical replication - we specify what we want to export the snapshot. // This snapshot exists until the connection that created the replication slot remains open. // Therefore Snapshotter opens another connection to the database and sets the transaction to the snapshot. // This allows you to read the data that was in the database at the time of the snapshot creation. 
type Snapshotter struct { - pgConnection *sql.DB - logger *service.Logger + pgConnection *sql.DB + snapshotCreateConnection *sql.DB + logger *service.Logger snapshotName string + + version int } // NewSnapshotter creates a new Snapshotter instance -func NewSnapshotter(dbConf pgconn.Config, snapshotName string, logger *service.Logger) (*Snapshotter, error) { +func NewSnapshotter(dbConf pgconn.Config, logger *service.Logger, version int) (*Snapshotter, error) { pgConn, err := openPgConnectionFromConfig(dbConf) + if err != nil { + return nil, err + } + + snapshotCreateConnection, err := openPgConnectionFromConfig(dbConf) + if err != nil { + return nil, err + } return &Snapshotter{ - pgConnection: pgConn, - snapshotName: snapshotName, - logger: logger, - }, err + pgConnection: pgConn, + snapshotCreateConnection: snapshotCreateConnection, + logger: logger, + version: version, + }, nil } -func (s *Snapshotter) prepare() error { - if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { - return err - } - if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil { - return err +func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error) { + if s.version >= 14 { + return SnapshotCreationResponse{}, errors.New("Snapshot is exported by default for versions above PG14") } - return nil -} + var snapshotName sql.NullString -func (s *Snapshotter) GetRowsCountPerTable(tableNames []string) (map[string]int, error) { - tables := make(map[string]int) - rows, err := s.pgConnection.Query("SELECT table_name, count(*) FROM information_schema.tables WHERE table_name in (?) 
GROUP BY table_name;", tableNames) + snapshotRow, err := s.pgConnection.Query(`BEGIN; SELECT pg_export_snapshot();`) if err != nil { - return tables, err + return SnapshotCreationResponse{}, fmt.Errorf("Cant get exported snapshot for initial streaming", err) + } + + if snapshotRow.Err() != nil { + return SnapshotCreationResponse{}, fmt.Errorf("can get avg row size due to query failure: %w", snapshotRow.Err()) } - for rows.Next() { - var tableName string - var count int - if err := rows.Scan(&tableName, &count); err != nil { - return tables, err + if snapshotRow.Next() { + if err = snapshotRow.Scan(&snapshotName); err != nil { + return SnapshotCreationResponse{}, fmt.Errorf("Cant scan snapshot name into string: %w", err) } - tables[tableName] = count + } else { + return SnapshotCreationResponse{}, errors.New("can get avg row size; 0 rows returned") } - return tables, nil + return SnapshotCreationResponse{ExportedSnapshotName: snapshotName.String}, nil +} + +func (s *Snapshotter) setTransactionSnapshotName(snapshotName string) { + s.snapshotName = snapshotName +} + +func (s *Snapshotter) prepare() error { + if s.snapshotName == "" { + return errors.New("Snapshot name is not set") + } + + if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { + return err + } + if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil { + fmt.Println("Failed to prepare snapshot", err) + return err + } + + return nil } func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { @@ -115,6 +148,12 @@ func (s *Snapshotter) querySnapshotData(table string, pk string, limit, offset i } func (s *Snapshotter) releaseSnapshot() error { + if s.version < 14 && s.snapshotCreateConnection != nil { + if _, err := s.snapshotCreateConnection.Exec("COMMIT;"); err != nil { + return err + } + } + _, err := s.pgConnection.Exec("COMMIT;") return err } @@ -124,5 +163,9 @@ func (s *Snapshotter) 
closeConn() error { return s.pgConnection.Close() } + if s.snapshotCreateConnection != nil { + return s.snapshotCreateConnection.Close() + } + return nil } diff --git a/internal/impl/postgresql/pglogicalstream/util.go b/internal/impl/postgresql/pglogicalstream/util.go index 041427c2cd..65e44fe610 100644 --- a/internal/impl/postgresql/pglogicalstream/util.go +++ b/internal/impl/postgresql/pglogicalstream/util.go @@ -3,6 +3,8 @@ package pglogicalstream import ( "database/sql" "fmt" + "regexp" + "strconv" "github.com/jackc/pgx/v5/pgconn" ) @@ -20,3 +22,30 @@ func openPgConnectionFromConfig(dbConf pgconn.Config) (*sql.DB, error) { return sql.Open("postgres", connStr) } + +func getPostgresVersion(connConfig pgconn.Config) (int, error) { + conn, err := openPgConnectionFromConfig(connConfig) + if err != nil { + return 0, fmt.Errorf("failed to connect to the database: %w", err) + } + + var versionString string + err = conn.QueryRow("SHOW server_version").Scan(&versionString) + if err != nil { + return 0, fmt.Errorf("failed to execute query: %w", err) + } + + // Extract the major version number + re := regexp.MustCompile(`^(\d+)`) + match := re.FindStringSubmatch(versionString) + if len(match) < 2 { + return 0, fmt.Errorf("failed to parse version string: %s", versionString) + } + + majorVersion, err := strconv.Atoi(match[1]) + if err != nil { + return 0, fmt.Errorf("failed to convert version to integer: %w", err) + } + + return majorVersion, nil +} From 39fdace9086d245ae155a28f1a14984c1e7e5b48 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 16 Oct 2024 13:19:52 +0200 Subject: [PATCH 020/118] chore(): updated tests for different pg versions && working on metrics --- internal/impl/postgresql/input_postgrecdc.go | 11 ++ internal/impl/postgresql/integration_test.go | 141 +++++++++++++++++- .../impl/postgresql/pglogicalstream/consts.go | 8 + .../postgresql/pglogicalstream/debouncer.go | 8 + .../pglogicalstream/logical_stream.go | 17 ++- 
.../postgresql/pglogicalstream/monitor.go | 62 ++++---- .../pglogicalstream/monitor_test.go | 8 + .../postgresql/pglogicalstream/pglogrepl.go | 8 +- .../pglogicalstream/pglogrepl_test.go | 1 - .../replication_message_decoders.go | 8 + .../postgresql/pglogicalstream/snapshotter.go | 4 +- .../pglogicalstream/stream_message.go | 26 +++- .../impl/postgresql/pglogicalstream/util.go | 8 + 13 files changed, 261 insertions(+), 49 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 4836d2cdb1..a338687122 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -12,6 +12,7 @@ import ( "context" "crypto/tls" "encoding/json" + "fmt" "strings" "time" @@ -289,6 +290,11 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack connectMessage := service.NewMessage(mb) connectMessage.MetaSet("table", snapshotMessage.Changes[0].Table) connectMessage.MetaSet("operation", snapshotMessage.Changes[0].Operation) + if snapshotMessage.Changes[0].TableSnapshotProgress != nil { + fmt.Println("Table snapshot progress", *snapshotMessage.Changes[0].TableSnapshotProgress, snapshotMessage.Changes[0].Table) + p.snapshotMetrics.SetFloat64(*snapshotMessage.Changes[0].TableSnapshotProgress, snapshotMessage.Changes[0].Table) + } + return connectMessage, func(ctx context.Context, err error) error { // Nacks are retried automatically when we use service.AutoRetryNacks return nil @@ -304,6 +310,11 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack connectMessage := service.NewMessage(mb) connectMessage.MetaSet("table", message.Changes[0].Table) connectMessage.MetaSet("operation", message.Changes[0].Operation) + if message.WALLagBytes != nil { + fmt.Println("Wal lag", *message.WALLagBytes) + p.replicationLag.Set(*message.WALLagBytes) + } + return connectMessage, func(ctx context.Context, err error) error { if message.Lsn 
!= nil { if err := p.pglogicalStream.AckLSN(*message.Lsn); err != nil { diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index b97357c4f5..d8147a0feb 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -564,7 +564,7 @@ file: func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { // running tests in the look to test different PostgreSQL versions t.Parallel() - for _, v := range []string{"13", "12", "11", "10", "9.6", "9.4"} { + for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -699,3 +699,142 @@ file: }) } } + +func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing.T) { + // running tests in the look to test different PostgreSQL versions + t.Parallel() + for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { + tmpDir := t.TempDir() + pool, err := dockertest.NewPool("") + require.NoError(t, err) + + var ( + resource *dockertest.Resource + db *sql.DB + ) + + resource, db, err = ResourceWithPostgreSQLVersion(t, pool, v) + require.NoError(t, err) + require.NoError(t, resource.Expire(120)) + + hostAndPort := resource.GetHostPort("5432/tcp") + hostAndPortSplited := strings.Split(hostAndPort, ":") + + fake := faker.New() + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot_native_decoder + user: user_name + password: secret + port: %s + schema: public + tls: none + stream_snapshot: true + decoding_plugin: pgoutput + stream_uncomited: false + database: dbname + tables: + - flights +`, hostAndPortSplited[0], hostAndPortSplited[1]) + + cacheConf := 
fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages []string + var outMessagesMut sync.Mutex + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*25, time.Millisecond*100) + + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 2000 + }, time.Second*25, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + // Starting stream for the same replication slot should continue from the last LSN + // Meaning we must not receive any old messages again + + streamOutBuilder = service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + outMessages = []string{} + require.NoError(t, 
streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + msgBytes, err := m.AsBytes() + require.NoError(t, err) + outMessagesMut.Lock() + outMessages = append(outMessages, string(msgBytes)) + outMessagesMut.Unlock() + return nil + })) + + streamOut, err = streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + assert.NoError(t, streamOut.Run(context.Background())) + }() + + time.Sleep(time.Second * 5) + for i := 0; i < 1000; i++ { + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) + } + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return len(outMessages) == 1000 + }, time.Second*20, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + t.Log("All the conditions are met 🎉") + + t.Cleanup(func() { + db.Close() + }) + } +} diff --git a/internal/impl/postgresql/pglogicalstream/consts.go b/internal/impl/postgresql/pglogicalstream/consts.go index 0728c3a538..968d0e60ed 100644 --- a/internal/impl/postgresql/pglogicalstream/consts.go +++ b/internal/impl/postgresql/pglogicalstream/consts.go @@ -1,3 +1,11 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream // DecodingPlugin is a type for the decoding plugin diff --git a/internal/impl/postgresql/pglogicalstream/debouncer.go b/internal/impl/postgresql/pglogicalstream/debouncer.go index 1ddc8b3eea..e33837279e 100644 --- a/internal/impl/postgresql/pglogicalstream/debouncer.go +++ b/internal/impl/postgresql/pglogicalstream/debouncer.go @@ -1,3 +1,11 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream import ( diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 7455685b06..a1639c5ae1 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -402,6 +402,9 @@ func (s *Stream) streamMessagesAsync() { } } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) + fmt.Println("Accessing meterics") + metrics := s.monitor.Report() + fmt.Println("Accessing meterics", metrics) if s.decodingPlugin == "wal2json" { message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) @@ -426,6 +429,7 @@ func (s *Stream) streamMessagesAsync() { return } } else { + s.messages <- *message } } @@ -455,9 +459,12 @@ func (s *Stream) streamMessagesAsync() { } } else { lsn := clientXLogPos.String() - s.messages <- StreamMessage{Lsn: &lsn, Changes: []StreamMessageChanges{ - *message, - }} + s.messages <- StreamMessage{ + Lsn: &lsn, + Changes: []StreamMessageChanges{ + *message, + }, + } <-s.consumedCallback } } else { @@ -526,7 +533,9 @@ func (s *Stream) streamMessagesAsync() { } func (s *Stream) processSnapshot() { + // metricsCtx, cancel := context.WithCancel(context.Background()) if err := s.snapshotter.prepare(); err != nil { + // cancel() s.logger.Errorf("Failed to prepare database snapshot. 
Probably snapshot is expired...: %v", err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -535,6 +544,7 @@ func (s *Stream) processSnapshot() { os.Exit(1) } defer func() { + // cancel() if err := s.snapshotter.releaseSnapshot(); err != nil { s.logger.Errorf("Failed to release database snapshot: %v", err.Error()) } @@ -620,6 +630,7 @@ func (s *Stream) processSnapshot() { } count := len(columnTypes) + var rowsCount = 0 for snapshotRows.Next() { rowsCount += 1 diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index b5e4abdc91..b4c2c9a1a8 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -1,9 +1,18 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream import ( "context" "database/sql" "fmt" + "math" "strings" "sync" "time" @@ -28,8 +37,6 @@ type Monitor struct { // finding the difference between the latest LSN and the last confirmed LSN for the replication slot replicationLagInBytes int64 - debounced func(f func()) - dbConn *sql.DB slotName string logger *service.Logger @@ -39,8 +46,6 @@ type Monitor struct { } func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, slotName string) (*Monitor, error) { - // debounces is user to throttle locks on the monitor to prevent unnecessary updates that would affect the performance - debounced := NewDebouncer(100 * time.Millisecond) dbConn, err := openPgConnectionFromConfig(*conf) if err != nil { return nil, err @@ -49,7 +54,6 @@ func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, sl m := &Monitor{ snapshotProgress: map[string]float64{}, replicationLagInBytes: 0, - debounced: debounced, dbConn: dbConn, slotName: slotName, logger: logger, @@ -69,47 +73,37 @@ func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, sl // UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) { - storeSnapshotProgress := func() { - m.lock.Lock() - defer m.lock.Unlock() - m.snapshotProgress[table] = float64(position) / float64(m.tableStat[table]) * 100 - } - - m.debounced(storeSnapshotProgress) + m.lock.Lock() + defer m.lock.Unlock() + m.snapshotProgress[table] = math.Round(float64(position) / float64(m.tableStat[table]) * 100) } // we need to read the tables stat to calculate the snapshot ingestion progress func (m *Monitor) readTablesStat(tables []string) error { results := make(map[string]int64) - // Construct the query - queryParts := make([]string, len(tables)) - for i, 
table := range tables { - queryParts[i] = fmt.Sprintf("SELECT '%s' AS table_name, COUNT(*) FROM %s", table, table) - } - query := strings.Join(queryParts, " UNION ALL ") + for _, table := range tables { + tableWithoutSchema := strings.Split(table, ".")[1] + query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableWithoutSchema) - // Execute the query - rows, err := m.dbConn.Query(query) - if err != nil { - return err - } - defer rows.Close() - - // Process the results - for rows.Next() { - var tableName string var count int64 - if err := rows.Scan(&tableName, &count); err != nil { - return err + err := m.dbConn.QueryRow(query).Scan(&count) + + if err != nil { + // If the error is because the table doesn't exist, we'll set the count to 0 + // and continue. You might want to log this situation. + if strings.Contains(err.Error(), "does not exist") { + results[tableWithoutSchema] = 0 + continue + } + // For any other error, we'll return it + return fmt.Errorf("error counting rows in table %s: %w", tableWithoutSchema, err) } - results[tableName] = count - } - if err := rows.Err(); err != nil { - return err + results[tableWithoutSchema] = count } + fmt.Println("Table stat", results) m.tableStat = results return nil } diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go index c78e4a0f8d..3a2c800555 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor_test.go +++ b/internal/impl/postgresql/pglogicalstream/monitor_test.go @@ -1,3 +1,11 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream import ( diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 89aba6479b..5975c6d97f 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -257,7 +257,11 @@ func CreateReplicationSlot( if err != nil { return CreateReplicationSlotResult{}, err } - snapshotter.setTransactionSnapshotName(result.SnapshotName) + if snapshotter != nil { + snapshotter.setTransactionSnapshotName(result.SnapshotName) + } + + return result, nil } var snapshotResponse SnapshotCreationResponse @@ -314,8 +318,6 @@ func ParseCreateReplicationSlot(mrr *pgconn.MultiResultReader, version int, snap crsr.SnapshotName = snapshotName } - fmt.Println("Snapshot name", crsr.SnapshotName) - return crsr, nil } diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index 7859652afd..26689674cb 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -182,7 +182,6 @@ func TestCreateReplicationSlot(t *testing.T) { require.NoError(t, err) assert.Equal(t, slotName, result.SlotName) - assert.Equal(t, outputPlugin, result.OutputPlugin) } func TestDropReplicationSlot(t *testing.T) { diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 829ba906fe..5ad6117593 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -1,3 +1,11 @@ +// Copyright 2024 Redpanda Data, Inc. 
+// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream import ( diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 00e3fd460e..a1e44e9dd0 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -59,7 +59,7 @@ func NewSnapshotter(dbConf pgconn.Config, logger *service.Logger, version int) ( } func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error) { - if s.version >= 14 { + if s.version > 14 { return SnapshotCreationResponse{}, errors.New("Snapshot is exported by default for versions above PG14") } @@ -67,7 +67,7 @@ func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error snapshotRow, err := s.pgConnection.Query(`BEGIN; SELECT pg_export_snapshot();`) if err != nil { - return SnapshotCreationResponse{}, fmt.Errorf("Cant get exported snapshot for initial streaming", err) + return SnapshotCreationResponse{}, fmt.Errorf("Cant get exported snapshot for initial streaming %w", err) } if snapshotRow.Err() != nil { diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index 446422c16c..d2f6e72c24 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -1,17 +1,33 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream // StreamMessageChanges represents the changes in a single message // Single message can have multiple changes type StreamMessageChanges struct { - Operation string `json:"operation"` - Schema string `json:"schema"` - Table string `json:"table"` + Operation string `json:"operation"` + Schema string `json:"schema"` + Table string `json:"table"` + TableSnapshotProgress *float64 `json:"table_snapshot_progress"` // For deleted messages - there will be old changes if replica identity set to full or empty changes Data map[string]any `json:"data"` } +type StreamMessageMetrics struct { + WALLagBytes *int64 `json:"wal_lag_bytes"` + IsStreaming bool `json:"is_streaming"` +} + // StreamMessage represents a single message after it has been decoded by the plugin type StreamMessage struct { - Lsn *string `json:"lsn"` - Changes []StreamMessageChanges `json:"changes"` + Lsn *string `json:"lsn"` + Changes []StreamMessageChanges `json:"changes"` + IsStreaming bool `json:"is_streaming"` + WALLagBytes *int64 `json:"wal_lag_bytes"` } diff --git a/internal/impl/postgresql/pglogicalstream/util.go b/internal/impl/postgresql/pglogicalstream/util.go index 65e44fe610..6765f868de 100644 --- a/internal/impl/postgresql/pglogicalstream/util.go +++ b/internal/impl/postgresql/pglogicalstream/util.go @@ -1,3 +1,11 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pglogicalstream import ( From 4f360c9b0fa5da547ab3c2c4a5df7bc6067756e8 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 16 Oct 2024 13:29:53 +0200 Subject: [PATCH 021/118] chore(): added WAL lag streaming --- internal/impl/postgresql/input_postgrecdc.go | 1 - .../pglogicalstream/logical_stream.go | 14 ++++++---- .../postgresql/pglogicalstream/monitor.go | 28 +++++++++++++++---- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index a338687122..c4ed24db0f 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -311,7 +311,6 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack connectMessage.MetaSet("table", message.Changes[0].Table) connectMessage.MetaSet("operation", message.Changes[0].Operation) if message.WALLagBytes != nil { - fmt.Println("Wal lag", *message.WALLagBytes) p.replicationLag.Set(*message.WALLagBytes) } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index a1639c5ae1..b844a99195 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -402,10 +402,7 @@ func (s *Stream) streamMessagesAsync() { } } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) - fmt.Println("Accessing meterics") metrics := s.monitor.Report() - fmt.Println("Accessing meterics", metrics) - if s.decodingPlugin == "wal2json" { message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) if err != nil { @@ -429,7 +426,7 @@ func (s *Stream) streamMessagesAsync() { return } } else { - + message.WALLagBytes = &metrics.WalLagInBytes s.messages <- *message } } @@ -464,6 +461,8 @@ func 
(s *Stream) streamMessagesAsync() { Changes: []StreamMessageChanges{ *message, }, + IsStreaming: true, + WALLagBytes: &metrics.WalLagInBytes, } <-s.consumedCallback } @@ -522,7 +521,12 @@ func (s *Stream) streamMessagesAsync() { } else { // send all collected changes lsn := clientXLogPos.String() - s.messages <- StreamMessage{Lsn: &lsn, Changes: pgoutputChanges} + s.messages <- StreamMessage{ + Lsn: &lsn, + Changes: pgoutputChanges, + IsStreaming: true, + WALLagBytes: &metrics.WalLagInBytes, + } } } } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index b4c2c9a1a8..b6b403f517 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -66,7 +66,21 @@ func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, sl ctx, cancel := context.WithCancel(context.Background()) m.ctx = ctx m.cancelTicker = cancel - m.ticker = time.NewTicker(5 * time.Second) + // hardocded duration to monitor slot lag + m.ticker = time.NewTicker(time.Second * 3) + + go func() { + for { + select { + case <-m.ticker.C: + m.readReplicationLag() + break + case <-m.ctx.Done(): + m.ticker.Stop() + return + } + } + }() return m, nil } @@ -111,19 +125,21 @@ func (m *Monitor) readTablesStat(tables []string) error { func (m *Monitor) readReplicationLag() { result, err := m.dbConn.Query(`SELECT slot_name, pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) AS lag_bytes - FROM pg_replication_slots WHERE slot_name = ?;`, m.slotName) + FROM pg_replication_slots WHERE slot_name = $1;`, m.slotName) // calculate the replication lag in bytes // replicationLagInBytes = latestLsn - confirmedLsn - if result.Err() != nil || err != nil { + if err != nil || result.Err() != nil { m.logger.Errorf("Error reading replication lag: %v", err) return } var slotName string var lagbytes int64 - if err = result.Scan(&slotName, &lagbytes); err != nil { - m.logger.Errorf("Error 
reading replication lag: %v", err) - return + for result.Next() { + if err = result.Scan(&slotName, &lagbytes); err != nil { + m.logger.Errorf("Error reading replication lag: %v", err) + return + } } m.replicationLagInBytes = lagbytes From fac3216419bd97fc08bd9dcf36cf7be32224914b Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 16 Oct 2024 13:36:25 +0200 Subject: [PATCH 022/118] chore(): added snapshot metrics --- internal/impl/postgresql/input_postgrecdc.go | 2 -- .../impl/postgresql/pglogicalstream/logical_stream.go | 9 +++++---- internal/impl/postgresql/pglogicalstream/monitor.go | 6 ++++++ 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index c4ed24db0f..36aa64d2f8 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -12,7 +12,6 @@ import ( "context" "crypto/tls" "encoding/json" - "fmt" "strings" "time" @@ -291,7 +290,6 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack connectMessage.MetaSet("table", snapshotMessage.Changes[0].Table) connectMessage.MetaSet("operation", snapshotMessage.Changes[0].Operation) if snapshotMessage.Changes[0].TableSnapshotProgress != nil { - fmt.Println("Table snapshot progress", *snapshotMessage.Changes[0].TableSnapshotProgress, snapshotMessage.Changes[0].Table) p.snapshotMetrics.SetFloat64(*snapshotMessage.Changes[0].TableSnapshotProgress, snapshotMessage.Changes[0].Table) } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index b844a99195..5300e35946 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -537,9 +537,7 @@ func (s *Stream) streamMessagesAsync() { } func (s *Stream) processSnapshot() { - // metricsCtx, cancel := 
context.WithCancel(context.Background()) if err := s.snapshotter.prepare(); err != nil { - // cancel() s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -548,7 +546,6 @@ func (s *Stream) processSnapshot() { os.Exit(1) } defer func() { - // cancel() if err := s.snapshotter.releaseSnapshot(); err != nil { s.logger.Errorf("Failed to release database snapshot: %v", err.Error()) } @@ -557,6 +554,8 @@ func (s *Stream) processSnapshot() { } }() + s.logger.Infof("Starting snapshot processing") + for _, table := range s.tableNames { s.logger.Infof("Processing snapshot for table: %v", table) @@ -707,7 +706,9 @@ func (s *Stream) processSnapshot() { }, }, } - + s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) + tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress s.snapshotMessages <- snapshotChangePacket } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index b6b403f517..f9f6852d83 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -85,6 +85,12 @@ func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, sl return m, nil } +func (m *Monitor) GetSnapshotProgressForTable(table string) float64 { + m.lock.Lock() + defer m.lock.Unlock() + return m.snapshotProgress[table] +} + // UpdateSnapshotProgressForTable updates the snapshot ingestion progress for a given table func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) { m.lock.Lock() From 0418bcf9fe02fedd8c4cd71ff04fcca38c70b91b Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 16 Oct 2024 13:50:01 +0200 Subject: [PATCH 023/118] chore(): 
added snapshot metrics streaming --- internal/impl/postgresql/integration_test.go | 2 -- internal/impl/postgresql/pglogicalstream/logical_stream.go | 2 -- internal/impl/postgresql/pglogicalstream/monitor.go | 1 - internal/impl/postgresql/pglogicalstream/snapshotter.go | 1 - internal/impl/postgresql/pglogicalstream/stream_message.go | 2 +- 5 files changed, 1 insertion(+), 7 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index d8147a0feb..efb8c6aafd 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -701,8 +701,6 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing.T) { - // running tests in the look to test different PostgreSQL versions - t.Parallel() for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 5300e35946..9234e40e89 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -206,8 +206,6 @@ func NewPgStream(config Config) (*Stream, error) { SnapshotAction: "export", }, version, stream.snapshotter) if err != nil { - fmt.Println(err) - fmt.Println("Failed to create replication slot", err.Error()) return nil, err } stream.snapshotName = createSlotResult.SnapshotName diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index f9f6852d83..662fc93b2c 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -123,7 +123,6 @@ func (m *Monitor) readTablesStat(tables []string) error { results[tableWithoutSchema] = count } - fmt.Println("Table stat", results) m.tableStat = 
results return nil } diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index a1e44e9dd0..54d9c800c6 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -98,7 +98,6 @@ func (s *Snapshotter) prepare() error { return err } if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil { - fmt.Println("Failed to prepare snapshot", err) return err } diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index d2f6e72c24..e4abd06e9e 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -14,7 +14,7 @@ type StreamMessageChanges struct { Operation string `json:"operation"` Schema string `json:"schema"` Table string `json:"table"` - TableSnapshotProgress *float64 `json:"table_snapshot_progress"` + TableSnapshotProgress *float64 `json:"table_snapshot_progress,omitempty"` // For deleted messages - there will be old changes if replica identity set to full or empty changes Data map[string]any `json:"data"` } From 5579b16f0e3c8d36e343f9695ae67e17cd2b24c6 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 17 Oct 2024 13:34:17 +0200 Subject: [PATCH 024/118] chore(): added explicit value for snapshot batch size --- internal/impl/postgresql/input_postgrecdc.go | 15 +- internal/impl/postgresql/integration_test.go | 101 +++++++ .../pglogicalstream/logical_stream.go | 267 ++++++++++-------- internal/impl/postgresql/test_utils.go | 39 +++ 4 files changed, 299 insertions(+), 123 deletions(-) create mode 100644 internal/impl/postgresql/test_utils.go diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 36aa64d2f8..c5ef62d5cb 100644 --- 
a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -56,7 +56,11 @@ var pgStreamConfigSpec = service.NewConfigSpec(). Field(service.NewFloatField("snapshot_memory_safety_factor"). Description("Sets amout of memory that can be used to stream snapshot. If affects batch sizes. If we want to use only 25% of the memory available - put 0.25 factor. It will make initial streaming slower, but it will prevent your worker from OOM Kill"). Example(0.2). - Default(0.5)). + Default(1)). + Field(service.NewIntField("snapshot_batch_size"). + Description("Batch side for querying the snapshot"). + Example(10000). + Default(0)). Field(service.NewStringEnumField("decoding_plugin", "pgoutput", "wal2json").Description("Specifies which decoding plugin to use when streaming data from PostgreSQL"). Example("pgoutput"). Default("pgoutput")). @@ -89,6 +93,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric decodingPlugin string pgConnOptions string streamUncomited bool + snapshotBatchSize int ) dbSchema, err = conf.FieldString("schema") @@ -165,6 +170,11 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric return nil, err } + snapshotBatchSize, err = conf.FieldInt("snapshot_batch_size") + if err != nil { + return nil, err + } + if pgConnOptions, err = conf.FieldString("pg_conn_options"); err != nil { return nil, err } @@ -203,6 +213,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric decodingPlugin: decodingPlugin, streamUncomited: streamUncomited, temporarySlot: temporarySlot, + snapshotBatchSize: snapshotBatchSize, logger: logger, metrics: metrics, @@ -237,6 +248,7 @@ type pgStreamInput struct { decodingPlugin string streamSnapshot bool snapshotMemSafetyFactor float64 + snapshotBatchSize int streamUncomited bool logger *service.Logger metrics *service.Metrics @@ -258,6 +270,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { 
DBSchema: p.schema, ReplicationSlotName: "rs_" + p.slotName, TLSVerify: p.tls, + BatchSize: p.snapshotBatchSize, StreamOldData: p.streamSnapshot, TemporaryReplicationSlot: p.temporarySlot, StreamUncomited: p.streamUncomited, diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index efb8c6aafd..6da75dd4ee 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -426,6 +426,107 @@ file: }) } +func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { + t.Skip("This test requires a remote database to run. Aimed to test AWS") + tmpDir := t.TempDir() + + // tables: users, products, orders, order_items + + host := "" + user := "" + password := "" + dbname := "" + port := "" + sslmode := "" + + template := fmt.Sprintf(` +pg_stream: + host: %s + slot_name: test_slot_native_decoder + user: %s + password: %s + port: %s + schema: public + tls: %s + snapshot_batch_size: 100000 + stream_snapshot: true + decoding_plugin: pgoutput + stream_uncomited: false + database: %s + tables: + - users + - products + - orders + - order_items +`, host, user, password, port, sslmode, dbname) + + cacheConf := fmt.Sprintf(` +label: pg_stream_cache +file: + directory: %v +`, tmpDir) + + streamOutBuilder := service.NewStreamBuilder() + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) + require.NoError(t, streamOutBuilder.AddInputYAML(template)) + + var outMessages int64 + var outMessagesMut sync.Mutex + + rc := NewRateCounter() + + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for range ticker.C { + fmt.Printf("Current rate: %.2f messages per second\n", rc.Rate()) + fmt.Printf("Total messages: %d\n", outMessages) + } + }() + + require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { + mb, err := m.AsBytes() + fmt.Println(string(mb)) + 
require.NoError(t, err) + outMessagesMut.Lock() + outMessages += 1 + outMessagesMut.Unlock() + rc.Increment() + return nil + })) + + streamOut, err := streamOutBuilder.Build() + require.NoError(t, err) + + go func() { + fmt.Println("Starting stream") + _ = streamOut.Run(context.Background()) + }() + + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return outMessages == 28528761 + }, time.Minute*15, time.Millisecond*100) + + t.Log("Backfill conditioins are met 🎉") + + // you need to start inserting the data somewhere in another place + time.Sleep(time.Second * 30) + outMessages = 0 + assert.Eventually(t, func() bool { + outMessagesMut.Lock() + defer outMessagesMut.Unlock() + return outMessages == 1000000 + }, time.Minute*15, time.Millisecond*100) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) + + require.NoError(t, streamOut.StopWithin(time.Second*10)) +} + func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 9234e40e89..371abde455 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -104,10 +104,12 @@ func NewPgStream(config Config) (*Stream, error) { cfg.TLSConfig = nil } + fmt.Println("Connecting to database") dbConn, err := pgconn.ConnectConfig(context.Background(), cfg) if err != nil { return nil, err } + fmt.Println("Connected to database") if err = dbConn.Ping(context.Background()); err != nil { return nil, err @@ -182,6 +184,7 @@ func NewPgStream(config Config) (*Stream, error) { // create snapshot transaction before creating a slot for older PostgreSQL versions to ensure consistency pubName := "pglog_stream_" + config.ReplicationSlotName + fmt.Println("Creating publication", pubName, "for tables", 
tableNames) if err = CreatePublication(context.Background(), stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } @@ -243,6 +246,7 @@ func NewPgStream(config Config) (*Stream, error) { } stream.monitor = monitor + fmt.Println("Starting stream from LSN", stream.lsnrestart, "with clientXLogPos", stream.clientXLogPos, "and snapshot name", stream.snapshotName) if !freshlyCreatedSlot || !config.StreamOldData { if err = stream.startLr(); err != nil { return nil, err @@ -554,43 +558,22 @@ func (s *Stream) processSnapshot() { s.logger.Infof("Starting snapshot processing") - for _, table := range s.tableNames { - s.logger.Infof("Processing snapshot for table: %v", table) - - var ( - avgRowSizeBytes sql.NullInt64 - offset = 0 - err error - ) - - avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) - if err != nil { - s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) - } - - availableMemory := getAvailableMemory() - batchSize := s.snapshotter.calculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) - s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) + var wg sync.WaitGroup - tablePk, err := s.getPrimaryKeyColumn(table) - if err != nil { - s.logger.Errorf("Failed to get primary key column for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } + for _, table := range s.tableNames { + wg.Add(1) + go func(tableName string) { + s.logger.Infof("Processing snapshot for table: %v", table) - os.Exit(1) - } + var ( + avgRowSizeBytes sql.NullInt64 + offset = 0 + err error + ) - for { - var snapshotRows *sql.Rows - if snapshotRows, err = 
s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { - s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) + avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) + if err != nil { + s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } @@ -598,127 +581,159 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - if snapshotRows.Err() != nil { - s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) + availableMemory := getAvailableMemory() + batchSize := s.snapshotter.calculateBatchSize(availableMemory, uint64(avgRowSizeBytes.Int64)) + if s.snapshotBatchSize > 0 { + batchSize = s.snapshotBatchSize } - columnTypes, err := snapshotRows.ColumnTypes() - if err != nil { - s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - os.Exit(1) - } + s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) - var columnTypesString = make([]string, len(columnTypes)) - columnNames, err := snapshotRows.Columns() + tablePk, err := s.getPrimaryKeyColumn(table) if err != nil { - s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) + s.logger.Errorf("Failed to get primary key column for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } - os.Exit(1) - } - for i := range columnNames { - columnTypesString[i] = 
columnTypes[i].DatabaseTypeName() + os.Exit(1) } - count := len(columnTypes) - - var rowsCount = 0 - for snapshotRows.Next() { - rowsCount += 1 - scanArgs := make([]interface{}, count) - for i, v := range columnTypes { - switch v.DatabaseTypeName() { - case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": - scanArgs[i] = new(sql.NullString) - case "BOOL": - scanArgs[i] = new(sql.NullBool) - case "INT4": - scanArgs[i] = new(sql.NullInt64) - default: - scanArgs[i] = new(sql.NullString) + for { + var snapshotRows *sql.Rows + if snapshotRows, err = s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { + s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } + + os.Exit(1) } - err := snapshotRows.Scan(scanArgs...) + if snapshotRows.Err() != nil { + s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + + os.Exit(1) + } + columnTypes, err := snapshotRows.ColumnTypes() if err != nil { - s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) + s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } os.Exit(1) } - var columnValues = make([]interface{}, len(columnTypes)) - for i := range columnTypes { - if z, ok := (scanArgs[i]).(*sql.NullBool); ok { - columnValues[i] = z.Bool - continue - } - if z, ok := (scanArgs[i]).(*sql.NullString); ok { - columnValues[i] = z.String - continue - } - if z, ok := (scanArgs[i]).(*sql.NullInt64); ok { - columnValues[i] = z.Int64 - continue + var columnTypesString = make([]string, len(columnTypes)) + columnNames, err := snapshotRows.Columns() + if err 
!= nil { + s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } - if z, ok := (scanArgs[i]).(*sql.NullFloat64); ok { - columnValues[i] = z.Float64 - continue + os.Exit(1) + } + + for i := range columnNames { + columnTypesString[i] = columnTypes[i].DatabaseTypeName() + } + + count := len(columnTypes) + + var rowsCount = 0 + for snapshotRows.Next() { + rowsCount += 1 + scanArgs := make([]interface{}, count) + for i, v := range columnTypes { + switch v.DatabaseTypeName() { + case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": + scanArgs[i] = new(sql.NullString) + case "BOOL": + scanArgs[i] = new(sql.NullBool) + case "INT4": + scanArgs[i] = new(sql.NullInt64) + default: + scanArgs[i] = new(sql.NullString) + } } - if z, ok := (scanArgs[i]).(*sql.NullInt32); ok { - columnValues[i] = z.Int32 - continue + + err := snapshotRows.Scan(scanArgs...) + + if err != nil { + s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) + if err = s.cleanUpOnFailure(); err != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) + } + os.Exit(1) } - columnValues[i] = scanArgs[i] - } + var columnValues = make([]interface{}, len(columnTypes)) + for i := range columnTypes { + if z, ok := (scanArgs[i]).(*sql.NullBool); ok { + columnValues[i] = z.Bool + continue + } + if z, ok := (scanArgs[i]).(*sql.NullString); ok { + columnValues[i] = z.String + continue + } + if z, ok := (scanArgs[i]).(*sql.NullInt64); ok { + columnValues[i] = z.Int64 + continue + } + if z, ok := (scanArgs[i]).(*sql.NullFloat64); ok { + columnValues[i] = z.Float64 + continue + } + if z, ok := (scanArgs[i]).(*sql.NullInt32); ok { + columnValues[i] = z.Int32 + continue + } - tableWithoutSchema := strings.Split(table, ".")[1] - snapshotChangePacket := StreamMessage{ - Lsn: nil, - Changes: []StreamMessageChanges{ - { - Table: 
tableWithoutSchema, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - for i, cn := range columnNames { - data[cn] = columnValues[i] - } + columnValues[i] = scanArgs[i] + } + + tableWithoutSchema := strings.Split(table, ".")[1] + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: tableWithoutSchema, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range columnNames { + data[cn] = columnValues[i] + } - return data - }(), + return data + }(), + }, }, - }, + } + s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) + tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress + s.snapshotMessages <- snapshotChangePacket } - s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) - tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) - snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - s.snapshotMessages <- snapshotChangePacket - } - offset += batchSize + offset += batchSize - if batchSize != rowsCount { - break + if batchSize != rowsCount { + break + } } - } - + wg.Done() + }(table) } + wg.Wait() + if err := s.startLr(); err != nil { s.logger.Errorf("Failed to start logical replication after snapshot: %v", err.Error()) os.Exit(1) @@ -771,6 +786,9 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { // Stop closes the stream conect and prevents from replication slot read func (s *Stream) Stop() error { + if s == nil { + return nil + } s.m.Lock() s.stopped = true s.m.Unlock() @@ -779,7 +797,12 @@ func (s *Stream) Stop() error { if s.pgConn != nil { if s.streamCtx != nil { s.streamCancel() - s.standbyCtxCancel() + // s.standbyCtxCancel is initialized later when starting reading from the 
replication slot. + // In case we failed to start replication of the process was shut down before starting the replication slot + // we need to check if the context is not nil before calling cancel + if s.standbyCtxCancel != nil { + s.standbyCtxCancel() + } } return s.pgConn.Close(context.Background()) } diff --git a/internal/impl/postgresql/test_utils.go b/internal/impl/postgresql/test_utils.go new file mode 100644 index 0000000000..700bcd7ed2 --- /dev/null +++ b/internal/impl/postgresql/test_utils.go @@ -0,0 +1,39 @@ +package pgstream + +import ( + "sync" + "sync/atomic" + "time" +) + +// RateCounter is used to measure the rate of invocations +type RateCounter struct { + count int64 + lastChecked time.Time + mutex sync.Mutex +} + +// NewRateCounter creates a new RateCounter +func NewRateCounter() *RateCounter { + return &RateCounter{ + lastChecked: time.Now(), + } +} + +// Increment increases the counter by 1 +func (rc *RateCounter) Increment() { + atomic.AddInt64(&rc.count, 1) +} + +// Rate calculates the current rate of invocations per second +func (rc *RateCounter) Rate() float64 { + rc.mutex.Lock() + defer rc.mutex.Unlock() + + now := time.Now() + duration := now.Sub(rc.lastChecked).Seconds() + count := atomic.SwapInt64(&rc.count, 0) + rc.lastChecked = now + + return float64(count) / duration +} From 64770ad9cca15c286bb5e13eda8cc07984ea491f Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 17 Oct 2024 13:58:54 +0200 Subject: [PATCH 025/118] chore(): updated docs --- internal/impl/postgresql/input_postgrecdc.go | 44 ++++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index c5ef62d5cb..c4ac4229d8 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -25,54 +25,62 @@ import ( var randomSlotName string var pgStreamConfigSpec = service.NewConfigSpec(). 
- Summary("Creates Postgres replication slot for CDC"). + Summary("Creates a PostgreSQL replication slot for Change Data Capture (CDC)"). Field(service.NewStringField("host"). - Description("PostgreSQL instance host"). + Description("The hostname or IP address of the PostgreSQL instance."). Example("123.0.0.1")). Field(service.NewIntField("port"). - Description("PostgreSQL instance port"). + Description("The port number on which the PostgreSQL instance is listening."). Example(5432). Default(5432)). Field(service.NewStringField("user"). - Description("Username with permissions to start replication (RDS superuser)"). + Description("Username of a user with replication permissions. For AWS RDS, this typically requires superuser privileges."). Example("postgres"), ). Field(service.NewStringField("password"). - Description("PostgreSQL database password")). + Description("Password for the specified PostgreSQL user.")). Field(service.NewStringField("schema"). - Description("Schema that will be used to create replication")). + Description("The PostgreSQL schema from which to replicate data.")). Field(service.NewStringField("database"). - Description("PostgreSQL database name")). + Description("The name of the PostgreSQL database to connect to.")). Field(service.NewStringEnumField("tls", "require", "none"). - Description("Defines whether benthos need to verify (skipinsecure) TLS configuration"). + Description("Specifies whether to use TLS for the database connection. Set to 'require' to enforce TLS, or 'none' to disable it."). Example("none"). Default("none")). - Field(service.NewBoolField("stream_uncomited").Default(false).Description("Defines whether you want to stream uncomitted messages before receiving commit message from postgres. This may lead to duplicated records after the the connector has been restarted")). - Field(service.NewStringField("pg_conn_options").Default("")). + Field(service.NewBoolField("stream_uncomited"). 
+ Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). + Default(false)). + Field(service.NewStringField("pg_conn_options"). + Description("Additional PostgreSQL connection options as a string. Refer to PostgreSQL documentation for available options."). + Default(""), + ). Field(service.NewBoolField("stream_snapshot"). - Description("Set `true` if you want to receive all the data that currently exist in database"). + Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes."). Example(true). Default(false)). Field(service.NewFloatField("snapshot_memory_safety_factor"). - Description("Sets amout of memory that can be used to stream snapshot. If affects batch sizes. If we want to use only 25% of the memory available - put 0.25 factor. It will make initial streaming slower, but it will prevent your worker from OOM Kill"). + Description("Determines the fraction of available memory that can be used for streaming the snapshot. Values between 0 and 1 represent the percentage of memory to use. Lower values make initial streaming slower but help prevent out-of-memory errors."). Example(0.2). Default(1)). Field(service.NewIntField("snapshot_batch_size"). - Description("Batch side for querying the snapshot"). + Description("The number of rows to fetch in each batch when querying the snapshot. A value of 0 lets the plugin determine the batch size based on `snapshot_memory_safety_factor` property."). Example(10000). Default(0)). - Field(service.NewStringEnumField("decoding_plugin", "pgoutput", "wal2json").Description("Specifies which decoding plugin to use when streaming data from PostgreSQL"). + Field(service.NewStringEnumField("decoding_plugin", "pgoutput", "wal2json"). 
+ Description("Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON."). Example("pgoutput"). Default("pgoutput")). Field(service.NewStringListField("tables"). + Description("A list of table names to include in the logical replication. Each table should be specified as a separate item."). Example(` - my_table - my_table_2 - `). - Description("List of tables we have to create logical replication for")). - Field(service.NewBoolField("temporary_slot").Default(false)). + `)). + Field(service.NewBoolField("temporary_slot"). + Description("If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed."). + Default(false)). Field(service.NewStringField("slot_name"). - Description("PostgeSQL logical replication slot name. You can create it manually before starting the sync. If not provided will be replaced with a random one"). + Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). Default(randomSlotName)) From 803b6822e2b0af3d56308b7d82bd33162d5310c6 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 17 Oct 2024 14:05:10 +0200 Subject: [PATCH 026/118] chore(): updated docs --- internal/impl/postgresql/input_postgrecdc.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index c4ac4229d8..eb8aef5de3 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -25,7 +25,17 @@ import ( var randomSlotName string var pgStreamConfigSpec = service.NewConfigSpec(). - Summary("Creates a PostgreSQL replication slot for Change Data Capture (CDC)"). + Beta(). 
+ Categories("Services"). + Version("0.0.1"). + Summary(`Creates a PostgreSQL replication slot for Change Data Capture (CDC) + == Metadata + +This input adds the following metadata fields to each message: +- streaming (Indicates whether the message is part of a streaming operation or snapshot processing) +- table (Name of the table that the message originated from) +- operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) + `). Field(service.NewStringField("host"). Description("The hostname or IP address of the PostgreSQL instance."). Example("123.0.0.1")). @@ -308,6 +318,7 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack } connectMessage := service.NewMessage(mb) + connectMessage.MetaSet("streaming", "false") connectMessage.MetaSet("table", snapshotMessage.Changes[0].Table) connectMessage.MetaSet("operation", snapshotMessage.Changes[0].Operation) if snapshotMessage.Changes[0].TableSnapshotProgress != nil { @@ -327,6 +338,7 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack return nil, nil, err } connectMessage := service.NewMessage(mb) + connectMessage.MetaSet("streaming", "true") connectMessage.MetaSet("table", message.Changes[0].Table) connectMessage.MetaSet("operation", message.Changes[0].Operation) if message.WALLagBytes != nil { From 34f5995f810f0197336c96d90290b91376b05267 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 17 Oct 2024 14:09:36 +0200 Subject: [PATCH 027/118] chore(): applieds golangci-lint notes --- internal/impl/postgresql/integration_test.go | 4 ++++ .../postgresql/pglogicalstream/debouncer.go | 2 +- .../pglogicalstream/logical_stream.go | 13 ++++++------- .../postgresql/pglogicalstream/monitor.go | 19 +++++++------------ .../postgresql/pglogicalstream/snapshotter.go | 11 ++++++----- .../pglogicalstream/stream_message.go | 1 + 6 files changed, 25 insertions(+), 25 deletions(-) diff --git 
a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 6da75dd4ee..4a1d35e2fd 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -369,6 +369,7 @@ file: for i := 0; i < 10; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -605,6 +606,7 @@ file: for i := 0; i < 10; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -743,6 +745,7 @@ file: for i := 0; i < 1000; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -880,6 +883,7 @@ file: for i := 0; i < 1000; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + require.NoError(t, err) _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } diff --git a/internal/impl/postgresql/pglogicalstream/debouncer.go b/internal/impl/postgresql/pglogicalstream/debouncer.go index e33837279e..9fbd9ae4f0 100644 --- a/internal/impl/postgresql/pglogicalstream/debouncer.go +++ 
b/internal/impl/postgresql/pglogicalstream/debouncer.go @@ -13,7 +13,7 @@ import ( "time" ) -// New returns a debounced function that takes another functions as its argument. +// NewDebouncer New returns a debounced function that takes another functions as its argument. // This function will be called when the debounced function stops being called // for the given duration. // The debounced function can be invoked with different functions, if needed, diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 371abde455..38813d1756 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -148,14 +148,12 @@ func NewPgStream(config Config) (*Stream, error) { snapshotter, err := NewSnapshotter(stream.dbConfig, stream.logger, version) if err != nil { - if err != nil { - stream.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) - if err = stream.cleanUpOnFailure(); err != nil { - stream.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) + stream.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) + if err = stream.cleanUpOnFailure(); err != nil { + stream.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } + + os.Exit(1) } stream.snapshotter = snapshotter @@ -268,6 +266,7 @@ func (s *Stream) GetProgress() *Report { return s.monitor.Report() } +// ConsumedCallback returns a channel that is used to tell the plugin to commit consumed offset func (s *Stream) ConsumedCallback() chan bool { return s.consumedCallback } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index 662fc93b2c..cf4add3aa1 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ 
-21,11 +21,13 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" ) +// Report is a structure that contains the current state of the Monitor type Report struct { WalLagInBytes int64 TableProgress map[string]float64 } +// Monitor is a structure that allows monitoring the progress of snapshot ingestion and replication lag type Monitor struct { // tableStat contains numbers of rows for each table determined at the moment of the snapshot creation // this is used to calculate snapshot ingestion progress @@ -45,6 +47,7 @@ type Monitor struct { ctx context.Context } +// NewMonitor creates a new Monitor instance func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, slotName string) (*Monitor, error) { dbConn, err := openPgConnectionFromConfig(*conf) if err != nil { @@ -85,6 +88,7 @@ func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, sl return m, nil } +// GetSnapshotProgressForTable returns the snapshot ingestion progress for a given table func (m *Monitor) GetSnapshotProgressForTable(table string) float64 { m.lock.Lock() defer m.lock.Unlock() @@ -104,7 +108,7 @@ func (m *Monitor) readTablesStat(tables []string) error { for _, table := range tables { tableWithoutSchema := strings.Split(table, ".")[1] - query := fmt.Sprintf("SELECT COUNT(*) FROM %s", tableWithoutSchema) + query := "SELECT COUNT(*) FROM %s" + tableWithoutSchema var count int64 err := m.dbConn.QueryRow(query).Scan(&count) @@ -150,6 +154,7 @@ func (m *Monitor) readReplicationLag() { m.replicationLagInBytes = lagbytes } +// Report returns a snapshot of the monitor's state func (m *Monitor) Report() *Report { m.lock.Lock() defer m.lock.Unlock() @@ -161,19 +166,9 @@ func (m *Monitor) Report() *Report { } } +// Stop stops the monitor func (m *Monitor) Stop() { m.cancelTicker() m.ticker.Stop() m.dbConn.Close() } - -func (m *Monitor) startSync() { - for { - select { - case <-m.ctx.Done(): - return - case <-m.ticker.C: - m.readReplicationLag() - } - } -} 
diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 54d9c800c6..4015c21e1e 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -19,6 +19,7 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" ) +// SnapshotCreationResponse is a structure that contains the name of the snapshot that was created type SnapshotCreationResponse struct { ExportedSnapshotName string } @@ -60,14 +61,14 @@ func NewSnapshotter(dbConf pgconn.Config, logger *service.Logger, version int) ( func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error) { if s.version > 14 { - return SnapshotCreationResponse{}, errors.New("Snapshot is exported by default for versions above PG14") + return SnapshotCreationResponse{}, errors.New("snapshot is exported by default for versions above PG14") } var snapshotName sql.NullString snapshotRow, err := s.pgConnection.Query(`BEGIN; SELECT pg_export_snapshot();`) if err != nil { - return SnapshotCreationResponse{}, fmt.Errorf("Cant get exported snapshot for initial streaming %w", err) + return SnapshotCreationResponse{}, fmt.Errorf("cant get exported snapshot for initial streaming %w", err) } if snapshotRow.Err() != nil { @@ -76,10 +77,10 @@ func (s *Snapshotter) initSnapshotTransaction() (SnapshotCreationResponse, error if snapshotRow.Next() { if err = snapshotRow.Scan(&snapshotName); err != nil { - return SnapshotCreationResponse{}, fmt.Errorf("Cant scan snapshot name into string: %w", err) + return SnapshotCreationResponse{}, fmt.Errorf("cant scan snapshot name into string: %w", err) } } else { - return SnapshotCreationResponse{}, errors.New("can get avg row size; 0 rows returned") + return SnapshotCreationResponse{}, errors.New("cant get avg row size; 0 rows returned") } return SnapshotCreationResponse{ExportedSnapshotName: snapshotName.String}, nil @@ -91,7 +92,7 @@ 
func (s *Snapshotter) setTransactionSnapshotName(snapshotName string) { func (s *Snapshotter) prepare() error { if s.snapshotName == "" { - return errors.New("Snapshot name is not set") + return errors.New("snapshot name is not set") } if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index e4abd06e9e..3a139fcc29 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -19,6 +19,7 @@ type StreamMessageChanges struct { Data map[string]any `json:"data"` } +// StreamMessageMetrics represents the metrics of a stream. Passed to each message type StreamMessageMetrics struct { WALLagBytes *int64 `json:"wal_lag_bytes"` IsStreaming bool `json:"is_streaming"` From 622303e618c852365c72ffdcc9961b6ff387f1fb Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 23 Oct 2024 11:32:29 +0200 Subject: [PATCH 028/118] chore(): working on faster snapshot processing --- go.mod | 1 + go.sum | 3 + internal/impl/postgresql/input_postgrecdc.go | 3 - internal/impl/postgresql/integration_test.go | 24 +-- .../pglogicalstream/logical_stream.go | 139 +++++++++--------- .../postgresql/pglogicalstream/monitor.go | 2 +- .../postgresql/pglogicalstream/snapshotter.go | 24 +++ public/components/community/package.go | 1 + public/components/postgresql/package.go | 20 +++ 9 files changed, 133 insertions(+), 84 deletions(-) create mode 100644 public/components/postgresql/package.go diff --git a/go.mod b/go.mod index e5bbf4939f..a6dabfc2e7 100644 --- a/go.mod +++ b/go.mod @@ -88,6 +88,7 @@ require ( github.com/opensearch-project/opensearch-go/v3 v3.1.0 github.com/ory/dockertest/v3 v3.11.0 github.com/oschwald/geoip2-golang v1.11.0 + github.com/panjf2000/ants/v2 v2.10.0 github.com/parquet-go/parquet-go v0.23.0 github.com/pebbe/zmq4 v1.2.11 
github.com/pinecone-io/go-pinecone v1.0.0 diff --git a/go.sum b/go.sum index 66681fbcfb..a732a1e76a 100644 --- a/go.sum +++ b/go.sum @@ -948,6 +948,8 @@ github.com/oschwald/geoip2-golang v1.11.0 h1:hNENhCn1Uyzhf9PTmquXENiWS6AlxAEnBII github.com/oschwald/geoip2-golang v1.11.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo= github.com/oschwald/maxminddb-golang v1.13.0 h1:R8xBorY71s84yO06NgTmQvqvTvlS/bnYZrrWX1MElnU= github.com/oschwald/maxminddb-golang v1.13.0/go.mod h1:BU0z8BfFVhi1LQaonTwwGQlsHUEu9pWNdMfmq4ztm0o= +github.com/panjf2000/ants/v2 v2.10.0 h1:zhRg1pQUtkyRiOFo2Sbqwjp0GfBNo9cUY2/Grpx1p+8= +github.com/panjf2000/ants/v2 v2.10.0/go.mod h1:7ZxyxsqE4vvW0M7LSD8aI3cKwgFhBHbxnlN8mDqHa1I= github.com/parquet-go/parquet-go v0.23.0 h1:dyEU5oiHCtbASyItMCD2tXtT2nPmoPbKpqf0+nnGrmk= github.com/parquet-go/parquet-go v0.23.0/go.mod h1:MnwbUcFHU6uBYMymKAlPPAw9yh3kE1wWl6Gl1uLdkNk= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -1404,6 +1406,7 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index eb8aef5de3..d886990b15 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -13,7 +13,6 @@ import ( 
"crypto/tls" "encoding/json" "strings" - "time" "github.com/jackc/pgx/v5/pgconn" "github.com/lucasepe/codename" @@ -270,7 +269,6 @@ type pgStreamInput struct { streamUncomited bool logger *service.Logger metrics *service.Metrics - metricsTicker *time.Ticker snapshotMetrics *service.MetricGauge replicationLag *service.MetricGauge @@ -299,7 +297,6 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { return err } - p.metricsTicker = time.NewTicker(5 * time.Second) p.pglogicalStream = pgStream return err diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 4a1d35e2fd..3afadffb48 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -428,17 +428,17 @@ file: } func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { - t.Skip("This test requires a remote database to run. Aimed to test AWS") + // t.Skip("This test requires a remote database to run. Aimed to test remote databases") tmpDir := t.TempDir() // tables: users, products, orders, order_items - host := "" - user := "" - password := "" - dbname := "" - port := "" - sslmode := "" + host := "localhost" + user := "postgres" + password := "postgres" + dbname := "postgres" + port := "5432" + sslmode := "none" template := fmt.Sprintf(` pg_stream: @@ -453,6 +453,7 @@ pg_stream: stream_snapshot: true decoding_plugin: pgoutput stream_uncomited: false + temporary_slot: true database: %s tables: - users @@ -488,8 +489,7 @@ file: }() require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - mb, err := m.AsBytes() - fmt.Println(string(mb)) + _, err := m.AsBytes() require.NoError(t, err) outMessagesMut.Lock() outMessages += 1 @@ -509,13 +509,13 @@ file: assert.Eventually(t, func() bool { outMessagesMut.Lock() defer outMessagesMut.Unlock() - return outMessages == 28528761 + return outMessages == 200000 }, time.Minute*15, time.Millisecond*100) 
t.Log("Backfill conditioins are met 🎉") // you need to start inserting the data somewhere in another place - time.Sleep(time.Second * 30) + time.Sleep(time.Minute * 30) outMessages = 0 assert.Eventually(t, func() bool { outMessagesMut.Lock() @@ -575,7 +575,7 @@ file: `, tmpDir) streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 38813d1756..20bb1f40e0 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -19,6 +19,8 @@ import ( "sync" "time" + "github.com/panjf2000/ants/v2" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -95,6 +97,7 @@ func NewPgStream(config Config) (*Stream, error) { if cfg, err = pgconn.ParseConfig(q); err != nil { return nil, err } + cfg.Password = config.DBPassword if config.TLSVerify == TLSRequireVerify { cfg.TLSConfig = &tls.Config{ @@ -104,12 +107,10 @@ func NewPgStream(config Config) (*Stream, error) { cfg.TLSConfig = nil } - fmt.Println("Connecting to database") dbConn, err := pgconn.ConnectConfig(context.Background(), cfg) if err != nil { return nil, err } - fmt.Println("Connected to database") if err = dbConn.Ping(context.Background()); err != nil { return nil, err @@ -598,9 +599,48 @@ func (s *Stream) processSnapshot() { os.Exit(1) } + type RawMessage struct { + ColumnNames []string + ColumnValues []interface{} + TableName string + } + + var pwg sync.WaitGroup + p, _ := ants.NewPoolWithFunc(batchSize/4, func(i interface{}) { + m := i.(RawMessage) + + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: 
[]StreamMessageChanges{ + { + Table: m.TableName, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range m.ColumnNames { + data[cn] = m.ColumnValues[i] + } + return data + }(), + }, + }, + } + + tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress + + s.snapshotMessages <- snapshotChangePacket + + pwg.Done() + }, ants.WithPreAlloc(true)) + defer p.Release() + for { var snapshotRows *sql.Rows + queryStart := time.Now() if snapshotRows, err = s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { + s.logger.Errorf("Failed to query snapshot data for table %v: %v", table, err.Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -609,7 +649,11 @@ func (s *Stream) processSnapshot() { os.Exit(1) } + queryDuration := time.Since(queryStart) + fmt.Printf("Query duration: %v %s \n", queryDuration, tableName) + if snapshotRows.Err() != nil { + s.logger.Errorf("Failed to get snapshot data for table %v: %v", table, snapshotRows.Err().Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -620,6 +664,7 @@ func (s *Stream) processSnapshot() { columnTypes, err := snapshotRows.ColumnTypes() if err != nil { + fmt.Println("Failed to get column types") s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -627,9 +672,9 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - var columnTypesString = make([]string, len(columnTypes)) 
columnNames, err := snapshotRows.Columns() if err != nil { + fmt.Println("Failed to get column names") s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -637,32 +682,18 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - for i := range columnNames { - columnTypesString[i] = columnTypes[i].DatabaseTypeName() - } - - count := len(columnTypes) - var rowsCount = 0 + rowsStart := time.Now() + + tableWithoutSchema := strings.Split(table, ".")[1] for snapshotRows.Next() { rowsCount += 1 - scanArgs := make([]interface{}, count) - for i, v := range columnTypes { - switch v.DatabaseTypeName() { - case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": - scanArgs[i] = new(sql.NullString) - case "BOOL": - scanArgs[i] = new(sql.NullBool) - case "INT4": - scanArgs[i] = new(sql.NullInt64) - default: - scanArgs[i] = new(sql.NullString) - } - } + scanArgs, valueGetters := s.snapshotter.prepareScannersAndGetters(columnTypes) err := snapshotRows.Scan(scanArgs...) 
if err != nil { + fmt.Println("Failed to scan row") s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -670,57 +701,29 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - var columnValues = make([]interface{}, len(columnTypes)) - for i := range columnTypes { - if z, ok := (scanArgs[i]).(*sql.NullBool); ok { - columnValues[i] = z.Bool - continue - } - if z, ok := (scanArgs[i]).(*sql.NullString); ok { - columnValues[i] = z.String - continue - } - if z, ok := (scanArgs[i]).(*sql.NullInt64); ok { - columnValues[i] = z.Int64 - continue - } - if z, ok := (scanArgs[i]).(*sql.NullFloat64); ok { - columnValues[i] = z.Float64 - continue - } - if z, ok := (scanArgs[i]).(*sql.NullInt32); ok { - columnValues[i] = z.Int32 - continue - } - - columnValues[i] = scanArgs[i] + columnValues := make([]interface{}, len(columnTypes)) + for i, getter := range valueGetters { + columnValues[i] = getter(scanArgs[i]) } - tableWithoutSchema := strings.Split(table, ".")[1] - snapshotChangePacket := StreamMessage{ - Lsn: nil, - Changes: []StreamMessageChanges{ - { - Table: tableWithoutSchema, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - for i, cn := range columnNames { - data[cn] = columnValues[i] - } - - return data - }(), - }, - }, + if rowsCount%100 == 0 { + s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) } - s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) - tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) - snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - s.snapshotMessages <- snapshotChangePacket + + pwg.Add(1) + _ = p.Invoke(RawMessage{ + TableName: tableWithoutSchema, + ColumnNames: columnNames, + ColumnValues: columnValues, + }) } + // waiting for batch to be 
processed + pwg.Wait() + + batchEnd := time.Since(rowsStart) + fmt.Printf("Batch duration: %v %s \n", batchEnd, tableName) + offset += batchSize if batchSize != rowsCount { diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index cf4add3aa1..3ae98834e8 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -108,7 +108,7 @@ func (m *Monitor) readTablesStat(tables []string) error { for _, table := range tables { tableWithoutSchema := strings.Split(table, ".")[1] - query := "SELECT COUNT(*) FROM %s" + tableWithoutSchema + query := "SELECT COUNT(*) FROM " + tableWithoutSchema var count int64 err := m.dbConn.QueryRow(query).Scan(&count) diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 4015c21e1e..525e6d8d4c 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -130,6 +130,30 @@ func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { return avgRowSize, nil } +func (s *Snapshotter) prepareScannersAndGetters(columnTypes []*sql.ColumnType) ([]interface{}, []func(interface{}) interface{}) { + scanArgs := make([]interface{}, len(columnTypes)) + valueGetters := make([]func(interface{}) interface{}, len(columnTypes)) + + for i, v := range columnTypes { + switch v.DatabaseTypeName() { + case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": + scanArgs[i] = new(sql.NullString) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullString).String } + case "BOOL": + scanArgs[i] = new(sql.NullBool) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullBool).Bool } + case "INT4": + scanArgs[i] = new(sql.NullInt64) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullInt64).Int64 } + default: + scanArgs[i] = 
new(sql.NullString) + valueGetters[i] = func(v interface{}) interface{} { return v.(*sql.NullString).String } + } + } + + return scanArgs, valueGetters +} + func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSize uint64) int { // Adjust this factor based on your system's memory constraints. // This example uses a safety factor of 0.8 to leave some memory headroom. diff --git a/public/components/community/package.go b/public/components/community/package.go index 630be9754f..902eb34967 100644 --- a/public/components/community/package.go +++ b/public/components/community/package.go @@ -70,4 +70,5 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/twitter" _ "github.com/redpanda-data/connect/v4/public/components/wasm" _ "github.com/redpanda-data/connect/v4/public/components/zeromq" + _ "github.com/redpanda-data/connect/v4/public/components/postgresql" ) diff --git a/public/components/postgresql/package.go b/public/components/postgresql/package.go new file mode 100644 index 0000000000..fa5d81b263 --- /dev/null +++ b/public/components/postgresql/package.go @@ -0,0 +1,20 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgresql + +import ( + // Bring in the internal plugin definitions. 
+ _ "github.com/redpanda-data/connect/v4/internal/impl/postgresql" +) From 041a55ca2a3cffd802c8cef2b91202623a0d9efe Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 23 Oct 2024 12:32:05 +0200 Subject: [PATCH 029/118] chore(): experimenting with object pool --- internal/impl/postgresql/integration_test.go | 1 + .../pglogicalstream/logical_stream.go | 44 +++++++++++-------- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 3afadffb48..66cb74b34d 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -470,6 +470,7 @@ file: streamOutBuilder := service.NewStreamBuilder() require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) + streamOutBuilder.SetThreads(4) require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 20bb1f40e0..54927d1f29 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -560,6 +560,17 @@ func (s *Stream) processSnapshot() { var wg sync.WaitGroup + _ = make([]byte, 100<<20) + + var objPool = sync.Pool{ + New: func() interface{} { + return &StreamMessage{ + Lsn: nil, + Changes: make([]StreamMessageChanges, 1), + } + }, + } + for _, table := range s.tableNames { wg.Add(1) go func(tableName string) { @@ -608,29 +619,26 @@ func (s *Stream) processSnapshot() { var pwg sync.WaitGroup p, _ := ants.NewPoolWithFunc(batchSize/4, func(i interface{}) { m := i.(RawMessage) - - snapshotChangePacket := StreamMessage{ - Lsn: nil, - Changes: []StreamMessageChanges{ - { - Table: m.TableName, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - 
for i, cn := range m.ColumnNames { - data[cn] = m.ColumnValues[i] - } - return data - }(), - }, - }, + snapshotChangePacket := objPool.Get().(*StreamMessage) + defer objPool.Put(snapshotChangePacket) + + snapshotChangePacket.Changes[0] = StreamMessageChanges{ + Table: m.TableName, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range m.ColumnNames { + data[cn] = m.ColumnValues[i] + } + return data + }(), } tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - s.snapshotMessages <- snapshotChangePacket + s.snapshotMessages <- *snapshotChangePacket pwg.Done() }, ants.WithPreAlloc(true)) From c7c198ca9074698779660d4694258d9ae4337e1b Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 23 Oct 2024 12:36:35 +0200 Subject: [PATCH 030/118] Revert "chore(): experimenting with object pool" This reverts commit 041a55ca2a3cffd802c8cef2b91202623a0d9efe. 
--- internal/impl/postgresql/integration_test.go | 1 - .../pglogicalstream/logical_stream.go | 44 ++++++++----------- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 66cb74b34d..3afadffb48 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -470,7 +470,6 @@ file: streamOutBuilder := service.NewStreamBuilder() require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) - streamOutBuilder.SetThreads(4) require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 54927d1f29..20bb1f40e0 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -560,17 +560,6 @@ func (s *Stream) processSnapshot() { var wg sync.WaitGroup - _ = make([]byte, 100<<20) - - var objPool = sync.Pool{ - New: func() interface{} { - return &StreamMessage{ - Lsn: nil, - Changes: make([]StreamMessageChanges, 1), - } - }, - } - for _, table := range s.tableNames { wg.Add(1) go func(tableName string) { @@ -619,26 +608,29 @@ func (s *Stream) processSnapshot() { var pwg sync.WaitGroup p, _ := ants.NewPoolWithFunc(batchSize/4, func(i interface{}) { m := i.(RawMessage) - snapshotChangePacket := objPool.Get().(*StreamMessage) - defer objPool.Put(snapshotChangePacket) - - snapshotChangePacket.Changes[0] = StreamMessageChanges{ - Table: m.TableName, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - for i, cn := range m.ColumnNames { - data[cn] = m.ColumnValues[i] - } - return data - }(), + + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: 
m.TableName, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range m.ColumnNames { + data[cn] = m.ColumnValues[i] + } + return data + }(), + }, + }, } tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - s.snapshotMessages <- *snapshotChangePacket + s.snapshotMessages <- snapshotChangePacket pwg.Done() }, ants.WithPreAlloc(true)) From b16738d54d3f472b52be82aae8908e8c486556a8 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 23 Oct 2024 12:54:01 +0200 Subject: [PATCH 031/118] chore(): use common pool to process snapshot --- .../pglogicalstream/logical_stream.go | 113 ++++++++++-------- 1 file changed, 63 insertions(+), 50 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 20bb1f40e0..354d4e9417 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -558,6 +558,56 @@ func (s *Stream) processSnapshot() { s.logger.Infof("Starting snapshot processing") + type RawMessage struct { + RowsCount int + Offset int + ColumnTypes []*sql.ColumnType + ColumnNames []string + ScanArgs []interface{} + ValueGetters []func(interface{}) interface{} + TableName string + } + + var pwg sync.WaitGroup + p, _ := ants.NewPoolWithFunc(2000, func(i interface{}) { + m := i.(RawMessage) + + columnValues := make([]interface{}, len(m.ColumnTypes)) + for i, getter := range m.ValueGetters { + columnValues[i] = getter(m.ScanArgs[i]) + } + + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: m.TableName, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range m.ColumnNames { + data[cn] = columnValues[i] + } + return data + }(), + }, 
+ }, + } + + if m.RowsCount%100 == 0 { + s.monitor.UpdateSnapshotProgressForTable(m.TableName, m.RowsCount+m.Offset) + } + + tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress + + s.snapshotMessages <- snapshotChangePacket + + pwg.Done() + }, ants.WithPreAlloc(true)) + defer p.Release() + var wg sync.WaitGroup for _, table := range s.tableNames { @@ -599,43 +649,6 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - type RawMessage struct { - ColumnNames []string - ColumnValues []interface{} - TableName string - } - - var pwg sync.WaitGroup - p, _ := ants.NewPoolWithFunc(batchSize/4, func(i interface{}) { - m := i.(RawMessage) - - snapshotChangePacket := StreamMessage{ - Lsn: nil, - Changes: []StreamMessageChanges{ - { - Table: m.TableName, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - for i, cn := range m.ColumnNames { - data[cn] = m.ColumnValues[i] - } - return data - }(), - }, - }, - } - - tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) - snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - - s.snapshotMessages <- snapshotChangePacket - - pwg.Done() - }, ants.WithPreAlloc(true)) - defer p.Release() - for { var snapshotRows *sql.Rows queryStart := time.Now() @@ -684,13 +697,17 @@ func (s *Stream) processSnapshot() { var rowsCount = 0 rowsStart := time.Now() + totalScanDuration := time.Duration(0) tableWithoutSchema := strings.Split(table, ".")[1] for snapshotRows.Next() { rowsCount += 1 + scanStart := time.Now() scanArgs, valueGetters := s.snapshotter.prepareScannersAndGetters(columnTypes) err := snapshotRows.Scan(scanArgs...) 
+ scanEnd := time.Since(scanStart) + totalScanDuration += scanEnd if err != nil { fmt.Println("Failed to scan row") @@ -701,28 +718,21 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - columnValues := make([]interface{}, len(columnTypes)) - for i, getter := range valueGetters { - columnValues[i] = getter(scanArgs[i]) - } - - if rowsCount%100 == 0 { - s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) - } - pwg.Add(1) _ = p.Invoke(RawMessage{ + ColumnTypes: columnTypes, + RowsCount: rowsCount, TableName: tableWithoutSchema, ColumnNames: columnNames, - ColumnValues: columnValues, + ScanArgs: scanArgs, + ValueGetters: valueGetters, + Offset: offset, }) } - // waiting for batch to be processed - pwg.Wait() - batchEnd := time.Since(rowsStart) fmt.Printf("Batch duration: %v %s \n", batchEnd, tableName) + fmt.Println("Scan duration", totalScanDuration, tableName) offset += batchSize @@ -734,6 +744,9 @@ func (s *Stream) processSnapshot() { }(table) } + // waiting for batch to be processed + pwg.Wait() + wg.Wait() if err := s.startLr(); err != nil { From 8800444be9e0c7bc900df461e89ca628b1d3bfd2 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 23 Oct 2024 13:06:56 +0200 Subject: [PATCH 032/118] chore(): added snapshot message rate counter --- internal/impl/postgresql/input_postgrecdc.go | 13 ++++++++++--- .../impl/postgresql/{test_utils.go => utils.go} | 0 2 files changed, 10 insertions(+), 3 deletions(-) rename internal/impl/postgresql/{test_utils.go => utils.go} (100%) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index d886990b15..a458de8f92 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -217,6 +217,8 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric snapsotMetrics := metrics.NewGauge("snapshot_progress") replicationLag := metrics.NewGauge("replication_lag") + 
snapshotMessageRate := metrics.NewGauge("snapshot_message_rate") + snapshotRateCounter := NewRateCounter() return service.AutoRetryNacks(&pgStreamInput{ dbConfig: pgconnConfig, @@ -231,6 +233,8 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric streamUncomited: streamUncomited, temporarySlot: temporarySlot, snapshotBatchSize: snapshotBatchSize, + snapshotMessageRate: snapshotMessageRate, + snapshotRateCounter: snapshotRateCounter, logger: logger, metrics: metrics, @@ -270,8 +274,10 @@ type pgStreamInput struct { logger *service.Logger metrics *service.Metrics - snapshotMetrics *service.MetricGauge - replicationLag *service.MetricGauge + snapshotRateCounter *RateCounter + snapshotMessageRate *service.MetricGauge + snapshotMetrics *service.MetricGauge + replicationLag *service.MetricGauge } func (p *pgStreamInput) Connect(ctx context.Context) error { @@ -303,7 +309,6 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.AckFunc, error) { - select { case snapshotMessage := <-p.pglogicalStream.SnapshotMessageC(): var ( @@ -322,6 +327,8 @@ func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.Ack p.snapshotMetrics.SetFloat64(*snapshotMessage.Changes[0].TableSnapshotProgress, snapshotMessage.Changes[0].Table) } + p.snapshotMessageRate.SetFloat64(p.snapshotRateCounter.Rate()) + return connectMessage, func(ctx context.Context, err error) error { // Nacks are retried automatically when we use service.AutoRetryNacks return nil diff --git a/internal/impl/postgresql/test_utils.go b/internal/impl/postgresql/utils.go similarity index 100% rename from internal/impl/postgresql/test_utils.go rename to internal/impl/postgresql/utils.go From 79428da57fc2c4b63a8617dbf9827f88d3b905b9 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 25 Oct 2024 12:04:42 +0200 Subject: [PATCH 033/118] chore(): working on batches --- 
internal/impl/postgresql/input_postgrecdc.go | 323 ++++++++++++++---- internal/impl/postgresql/integration_test.go | 4 +- .../pglogicalstream/logical_stream.go | 166 +++++---- internal/impl/postgresql/utils.go | 42 +++ 4 files changed, 390 insertions(+), 145 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index a458de8f92..b8328d8f39 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -12,8 +12,13 @@ import ( "context" "crypto/tls" "encoding/json" + "fmt" + "strconv" "strings" + "sync" + "time" + "github.com/Jeffail/checkpoint" "github.com/jackc/pgx/v5/pgconn" "github.com/lucasepe/codename" "github.com/redpanda-data/benthos/v4/public/service" @@ -23,6 +28,11 @@ import ( var randomSlotName string +type asyncMessage struct { + msg service.MessageBatch + ackFn service.AckFunc +} + var pgStreamConfigSpec = service.NewConfigSpec(). Beta(). Categories("Services"). @@ -85,15 +95,19 @@ This input adds the following metadata fields to each message: - my_table - my_table_2 `)). + Field(service.NewIntField("checkpoint_limit"). + Description("The maximum number of messages of the same topic and partition that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level to work on individual partitions. Any given offset will not be committed unless all messages under that offset are delivered in order to preserve at least once delivery guarantees."). + Version("3.33.0").Default(1024)). Field(service.NewBoolField("temporary_slot"). Description("If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed."). Default(false)). Field(service.NewStringField("slot_name"). Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. 
You can create this slot manually before starting replication if desired."). Example("my_test_slot"). - Default(randomSlotName)) + Default(randomSlotName)). + Field(service.NewBatchPolicyField("batching").Advanced()) -func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metrics *service.Metrics) (s service.Input, err error) { +func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s service.BatchInput, err error) { var ( dbName string dbPort int @@ -111,6 +125,8 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric pgConnOptions string streamUncomited bool snapshotBatchSize int + checkpointLimit int + batching service.BatchPolicy ) dbSchema, err = conf.FieldString("schema") @@ -167,6 +183,10 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric return nil, err } + if checkpointLimit, err = conf.FieldInt("checkpoint_limit"); err != nil { + return nil, err + } + streamSnapshot, err = conf.FieldBool("stream_snapshot") if err != nil { return nil, err @@ -200,6 +220,12 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric pgConnOptions = "options=" + pgConnOptions } + if batching, err = conf.FieldBatchPolicy("batching"); err != nil { + return nil, err + } else if batching.IsNoop() { + batching.Count = 1 + } + pgconnConfig := pgconn.Config{ Host: dbHost, Port: uint16(dbPort), @@ -215,12 +241,12 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric pgconnConfig.TLSConfig = nil } - snapsotMetrics := metrics.NewGauge("snapshot_progress") - replicationLag := metrics.NewGauge("replication_lag") - snapshotMessageRate := metrics.NewGauge("snapshot_message_rate") + snapsotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") + replicationLag := mgr.Metrics().NewGauge("replication_lag_bytes") + snapshotMessageRate := mgr.Metrics().NewGauge("snapshot_message_rate") snapshotRateCounter := NewRateCounter() - return 
service.AutoRetryNacks(&pgStreamInput{ + i := &pgStreamInput{ dbConfig: pgconnConfig, streamSnapshot: streamSnapshot, snapshotMemSafetyFactor: snapshotMemSafetyFactor, @@ -235,22 +261,34 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric snapshotBatchSize: snapshotBatchSize, snapshotMessageRate: snapshotMessageRate, snapshotRateCounter: snapshotRateCounter, - - logger: logger, - metrics: metrics, + batching: batching, + checkpointLimit: checkpointLimit, + cMut: sync.Mutex{}, + msgChan: make(chan asyncMessage), + + mgr: mgr, + logger: mgr.Logger(), + metrics: mgr.Metrics(), snapshotMetrics: snapsotMetrics, replicationLag: replicationLag, - }), err + } + + r, err := service.AutoRetryNacksBatchedToggled(conf, i) + if err != nil { + return nil, err + } + + return conf.WrapBatchInputExtractTracingSpanMapping("pg_stream", r) } func init() { rng, _ := codename.DefaultRNG() randomSlotName = strings.ReplaceAll(codename.Generate(rng, 5), "-", "_") - err := service.RegisterInput( + err := service.RegisterBatchInput( "pg_stream", pgStreamConfigSpec, - func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) { - return newPgStreamInput(conf, mgr.Logger(), mgr.Metrics()) + func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) { + return newPgStreamInput(conf, mgr) }) if err != nil { panic(err) @@ -272,7 +310,12 @@ type pgStreamInput struct { snapshotBatchSize int streamUncomited bool logger *service.Logger + mgr *service.Resources metrics *service.Metrics + cMut sync.Mutex + msgChan chan asyncMessage + batching service.BatchPolicy + checkpointLimit int snapshotRateCounter *RateCounter snapshotMessageRate *service.MetricGauge @@ -305,64 +348,228 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { p.pglogicalStream = pgStream - return err -} + go func() { + batchPolicy, err := p.batching.NewBatcher(p.mgr) + if err != nil { + p.logger.Errorf("Failed to initialise batch policy: %v, 
falling back to no policy.\n", err) + conf := service.BatchPolicy{Count: 1} + if batchPolicy, err = conf.NewBatcher(p.mgr); err != nil { + panic(err) + } + } -func (p *pgStreamInput) Read(ctx context.Context) (*service.Message, service.AckFunc, error) { - select { - case snapshotMessage := <-p.pglogicalStream.SnapshotMessageC(): - var ( - mb []byte - err error - ) - if mb, err = json.Marshal(snapshotMessage); err != nil { - return nil, nil, err + defer func() { + batchPolicy.Close(context.Background()) + }() + + var nextTimedBatchChan <-chan time.Time + var flushBatch func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64) bool + if p.checkpointLimit > 1 { + flushBatch = p.asyncCheckpointer() + } else { + flushBatch = p.syncCheckpointer() } - connectMessage := service.NewMessage(mb) - connectMessage.MetaSet("streaming", "false") - connectMessage.MetaSet("table", snapshotMessage.Changes[0].Table) - connectMessage.MetaSet("operation", snapshotMessage.Changes[0].Operation) - if snapshotMessage.Changes[0].TableSnapshotProgress != nil { - p.snapshotMetrics.SetFloat64(*snapshotMessage.Changes[0].TableSnapshotProgress, snapshotMessage.Changes[0].Table) + // offsets are nilable since we don't provide offset tracking during the snapshot phase + var latestOffset *int64 + + for { + select { + case <-nextTimedBatchChan: + nextTimedBatchChan = nil + flushedBatch, err := batchPolicy.Flush(ctx) + if err != nil { + p.mgr.Logger().Debugf("Timed flush batch error: %w", err) + break + } + + if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset) { + break + } + // TrxCommit LSN must be acked when all the bessages in the batch are processed + case trxCommitLsn, open := <-p.pglogicalStream.AckTxChan(): + if !open { + break + } + + fmt.Println("trxCommitLsn", trxCommitLsn) + fmt.Println("Force flushing the batch") + flushedBatch, err := batchPolicy.Flush(ctx) + if err != nil { + p.mgr.Logger().Debugf("Flush batch error: %w", err) + break + } + + if !flushBatch(ctx, 
p.msgChan, flushedBatch, latestOffset) { + break + } + + time.Sleep(time.Second) + + if err = p.pglogicalStream.AckLSN(trxCommitLsn); err != nil { + p.mgr.Logger().Errorf("Failed to ack LSN: %v", err) + break + } + + p.pglogicalStream.ConsumedCallback() <- true + case message, open := <-p.pglogicalStream.Messages(): + if !open { + break + } + var ( + mb []byte + err error + ) + if message.Lsn != nil { + parsedLSN, err := LSNToInt64(*message.Lsn) + if err != nil { + p.logger.Errorf("Failed to parse LSN: %v", err) + break + } + latestOffset = &parsedLSN + } + + if mb, err = json.Marshal(message); err != nil { + break + } + + if message.IsStreaming { + fmt.Println("Message received", string(mb)) + } + + batchMsg := service.NewMessage(mb) + + streaming := strconv.FormatBool(message.IsStreaming) + batchMsg.MetaSet("streaming", streaming) + batchMsg.MetaSet("table", message.Changes[0].Table) + batchMsg.MetaSet("operation", message.Changes[0].Operation) + if message.Changes[0].TableSnapshotProgress != nil { + p.snapshotMetrics.SetFloat64(*message.Changes[0].TableSnapshotProgress, message.Changes[0].Table) + } + if message.WALLagBytes != nil { + p.replicationLag.Set(*message.WALLagBytes) + } + + p.snapshotRateCounter.Increment() + p.snapshotMessageRate.SetFloat64(p.snapshotRateCounter.Rate()) + + if batchPolicy.Add(batchMsg) { + nextTimedBatchChan = nil + flushedBatch, err := batchPolicy.Flush(ctx) + if err != nil { + p.mgr.Logger().Debugf("Flush batch error: %w", err) + break + } + if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset) { + break + } + } + case <-ctx.Done(): + p.pglogicalStream.Stop() + } } + }() + + return err +} - p.snapshotMessageRate.SetFloat64(p.snapshotRateCounter.Rate()) - - return connectMessage, func(ctx context.Context, err error) error { - // Nacks are retried automatically when we use service.AutoRetryNacks - return nil - }, nil - case message := <-p.pglogicalStream.LrMessageC(): - var ( - mb []byte - err error - ) - if mb, err = 
json.Marshal(message); err != nil { - return nil, nil, err +func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64) bool { + cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) + return func(ctx context.Context, c chan<- asyncMessage, msg service.MessageBatch, lsn *int64) bool { + if msg == nil { + return true } - connectMessage := service.NewMessage(mb) - connectMessage.MetaSet("streaming", "true") - connectMessage.MetaSet("table", message.Changes[0].Table) - connectMessage.MetaSet("operation", message.Changes[0].Operation) - if message.WALLagBytes != nil { - p.replicationLag.Set(*message.WALLagBytes) + resolveFn, err := cp.Track(ctx, lsn, int64(len(msg))) + if err != nil { + if ctx.Err() == nil { + p.mgr.Logger().Errorf("Failed to checkpoint offset: %v\n", err) + } + return false } + select { + case c <- asyncMessage{ + msg: msg, + ackFn: func(ctx context.Context, res error) error { + maxOffset := resolveFn() + if maxOffset == nil { + return nil + } + p.cMut.Lock() + if lsn != nil { + fmt.Println("Acking LSN from chackpointer", *lsn) + p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)) + } + p.cMut.Unlock() + return nil + }, + }: + case <-ctx.Done(): + return false + } + return true + } +} - return connectMessage, func(ctx context.Context, err error) error { - if message.Lsn != nil { - if err := p.pglogicalStream.AckLSN(*message.Lsn); err != nil { - return err +func (p *pgStreamInput) syncCheckpointer() func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64) bool { + ackedChan := make(chan error) + return func(ctx context.Context, c chan<- asyncMessage, msg service.MessageBatch, lsn *int64) bool { + if msg == nil { + return true + } + select { + case c <- asyncMessage{ + msg: msg, + ackFn: func(ctx context.Context, res error) error { + resErr := res + if resErr == nil { + p.cMut.Lock() + if lsn != nil { + p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)) + } + p.cMut.Unlock() + } + select { + 
case ackedChan <- resErr: + case <-ctx.Done(): } - if p.streamUncomited { - p.pglogicalStream.ConsumedCallback() <- true + return nil + }, + }: + select { + case resErr := <-ackedChan: + if resErr != nil { + p.mgr.Logger().Errorf("Received error from message batch: %v, shutting down consumer.\n", resErr) + return false } + case <-ctx.Done(): + return false } - return nil - }, nil + case <-ctx.Done(): + return false + } + return true + } +} + +func (p *pgStreamInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { + p.cMut.Lock() + msgChan := p.msgChan + p.cMut.Unlock() + if msgChan == nil { + return nil, nil, service.ErrNotConnected + } + + select { + case m, open := <-msgChan: + if !open { + return nil, nil, service.ErrNotConnected + } + return m.msg, m.ackFn, nil case <-ctx.Done(): - return nil, nil, p.pglogicalStream.Stop() + } + + return nil, nil, ctx.Err() } func (p *pgStreamInput) Close(ctx context.Context) error { diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 3afadffb48..e99643937e 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -488,8 +488,8 @@ file: } }() - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - _, err := m.AsBytes() + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + _, err := mb[0].AsBytes() require.NoError(t, err) outMessagesMut.Lock() outMessages += 1 diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 354d4e9417..643ae1d583 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -19,8 +19,6 @@ import ( "sync" "time" - "github.com/panjf2000/ants/v2" - "github.com/jackc/pgx/v5/pgconn" 
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -45,7 +43,6 @@ type Stream struct { standbyMessageTimeout time.Duration nextStandbyMessageDeadline time.Time messages chan StreamMessage - snapshotMessages chan StreamMessage snapshotName string slotName string schema string @@ -58,6 +55,7 @@ type Stream struct { monitor *Monitor streamUncomited bool snapshotter *Snapshotter + transactionAckChan chan string lsnAckBuffer []string @@ -123,7 +121,6 @@ func NewPgStream(config Config) (*Stream, error) { pgConn: dbConn, dbConfig: *cfg, messages: make(chan StreamMessage), - snapshotMessages: make(chan StreamMessage, 100), slotName: config.ReplicationSlotName, schema: config.DBSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, @@ -131,6 +128,7 @@ func NewPgStream(config Config) (*Stream, error) { snapshotBatchSize: config.BatchSize, tableNames: tableNames, consumedCallback: make(chan bool), + transactionAckChan: make(chan string), lsnAckBuffer: []string{}, logger: config.logger, m: sync.Mutex{}, @@ -294,6 +292,8 @@ func (s *Stream) AckLSN(lsn string) error { return err } + fmt.Println("Ack LSN", lsn, "clientXLogPos") + err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ WALApplyPosition: clientXLogPos, WALWritePosition: clientXLogPos, @@ -445,28 +445,46 @@ func (s *Stream) streamMessagesAsync() { return } - if message == nil { - // 0 changes happened in the transaction - // or we received a change that are not supported/needed by the replication stream - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + fmt.Println("Receive pg message", message) + + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + s.logger.Errorf("Failed to parse WAL data: %w", err) + if err = s.Stop(); err != nil { + 
s.logger.Errorf("Failed to stop the stream: %v", err) } + return + } + + // when receiving a commit message, we need to acknowledge the LSN + // but we must wait for benthos to flush the messages before we can do that + if isCommit { + s.transactionAckChan <- clientXLogPos.String() + <-s.consumedCallback } else { - lsn := clientXLogPos.String() - s.messages <- StreamMessage{ - Lsn: &lsn, - Changes: []StreamMessageChanges{ - *message, - }, - IsStreaming: true, - WALLagBytes: &metrics.WalLagInBytes, + if message == nil { + // 0 changes happened in the transaction + // or we received a change that are not supported/needed by the replication stream + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } + } else { + lsn := clientXLogPos.String() + fmt.Println("Pushed uncomited message to stream", lsn, message) + s.messages <- StreamMessage{ + Lsn: &lsn, + Changes: []StreamMessageChanges{ + *message, + }, + IsStreaming: true, + WALLagBytes: &metrics.WalLagInBytes, + } } - <-s.consumedCallback } } else { // message changes must be collected in the buffer in the context of the same transaction @@ -538,6 +556,14 @@ func (s *Stream) streamMessagesAsync() { } } +func (s *Stream) AckTxChan() chan string { + return s.transactionAckChan +} + +func (s *Stream) ConfigrmAckTxChan() chan bool { + return s.consumedCallback +} + func (s *Stream) processSnapshot() { if err := s.snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. 
Probably snapshot is expired...: %v", err.Error()) @@ -568,46 +594,6 @@ func (s *Stream) processSnapshot() { TableName string } - var pwg sync.WaitGroup - p, _ := ants.NewPoolWithFunc(2000, func(i interface{}) { - m := i.(RawMessage) - - columnValues := make([]interface{}, len(m.ColumnTypes)) - for i, getter := range m.ValueGetters { - columnValues[i] = getter(m.ScanArgs[i]) - } - - snapshotChangePacket := StreamMessage{ - Lsn: nil, - Changes: []StreamMessageChanges{ - { - Table: m.TableName, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - for i, cn := range m.ColumnNames { - data[cn] = columnValues[i] - } - return data - }(), - }, - }, - } - - if m.RowsCount%100 == 0 { - s.monitor.UpdateSnapshotProgressForTable(m.TableName, m.RowsCount+m.Offset) - } - - tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) - snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - - s.snapshotMessages <- snapshotChangePacket - - pwg.Done() - }, ants.WithPreAlloc(true)) - defer p.Release() - var wg sync.WaitGroup for _, table := range s.tableNames { @@ -677,7 +663,6 @@ func (s *Stream) processSnapshot() { columnTypes, err := snapshotRows.ColumnTypes() if err != nil { - fmt.Println("Failed to get column types") s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -687,7 +672,6 @@ func (s *Stream) processSnapshot() { columnNames, err := snapshotRows.Columns() if err != nil { - fmt.Println("Failed to get column names") s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -698,6 +682,7 @@ func (s *Stream) processSnapshot() { var rowsCount = 0 rowsStart := time.Now() totalScanDuration := 
time.Duration(0) + totalWaitingFromBenthos := time.Duration(0) tableWithoutSchema := strings.Split(table, ".")[1] for snapshotRows.Next() { @@ -718,21 +703,41 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - pwg.Add(1) - _ = p.Invoke(RawMessage{ - ColumnTypes: columnTypes, - RowsCount: rowsCount, - TableName: tableWithoutSchema, - ColumnNames: columnNames, - ScanArgs: scanArgs, - ValueGetters: valueGetters, - Offset: offset, - }) + var data = make(map[string]any) + for i, getter := range valueGetters { + data[columnNames[i]] = getter(scanArgs[i]) + } + + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: tableWithoutSchema, + Operation: "insert", + Schema: s.schema, + Data: data, + }, + }, + } + + if rowsCount%100 == 0 { + s.monitor.UpdateSnapshotProgressForTable(tableWithoutSchema, rowsCount+offset) + } + + tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress + snapshotChangePacket.IsStreaming = false + + waitingFromBenthos := time.Now() + s.messages <- snapshotChangePacket + totalWaitingFromBenthos += time.Since(waitingFromBenthos) + } batchEnd := time.Since(rowsStart) fmt.Printf("Batch duration: %v %s \n", batchEnd, tableName) fmt.Println("Scan duration", totalScanDuration, tableName) + fmt.Println("Waiting from benthos duration", totalWaitingFromBenthos, tableName) offset += batchSize @@ -744,9 +749,6 @@ func (s *Stream) processSnapshot() { }(table) } - // waiting for batch to be processed - pwg.Wait() - wg.Wait() if err := s.startLr(); err != nil { @@ -756,15 +758,9 @@ func (s *Stream) processSnapshot() { go s.streamMessagesAsync() } -// SnapshotMessageC represents a message from the stream that are sent to the consumer on the snapshot processing stage -// meaning these messages will have nil LSN field -func (s *Stream) SnapshotMessageC() chan StreamMessage { - return s.snapshotMessages -} - // LrMessageC 
represents a message from the stream that are sent to the consumer on the logical replication stage // meaning these messages will have non-nil LSN field -func (s *Stream) LrMessageC() chan StreamMessage { +func (s *Stream) Messages() chan StreamMessage { return s.messages } diff --git a/internal/impl/postgresql/utils.go b/internal/impl/postgresql/utils.go index 700bcd7ed2..3c3e3bd642 100644 --- a/internal/impl/postgresql/utils.go +++ b/internal/impl/postgresql/utils.go @@ -1,6 +1,9 @@ package pgstream import ( + "fmt" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -37,3 +40,42 @@ func (rc *RateCounter) Rate() float64 { return float64(count) / duration } + +// LSNToInt64 converts a PostgreSQL LSN string to int64 +func LSNToInt64(lsn string) (int64, error) { + // Split the LSN into segments + parts := strings.Split(lsn, "/") + if len(parts) != 2 { + return 0, fmt.Errorf("invalid LSN format: %s", lsn) + } + + // Parse both segments as hex with uint64 first + upper, err := strconv.ParseUint(parts[0], 16, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse upper part: %w", err) + } + + lower, err := strconv.ParseUint(parts[1], 16, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse lower part: %w", err) + } + + // Combine the segments into a single int64 + // Upper part is shifted left by 32 bits + result := int64((upper << 32) | lower) + + return result, nil +} + +// Int64ToLSN converts an int64 to a PostgreSQL LSN string +func Int64ToLSN(value int64) string { + // Convert to uint64 to handle the bitwise operations properly + uvalue := uint64(value) + + // Extract upper and lower parts + upper := uvalue >> 32 + lower := uvalue & 0xFFFFFFFF + + // Format as hexadecimal with proper padding + return fmt.Sprintf("%X/%X", upper, lower) +} From 85abfc83811f7ec2848973bc30a52135ef0d63ff Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 25 Oct 2024 13:01:31 +0200 Subject: [PATCH 034/118] fixed(): test --- 
internal/impl/postgresql/integration_test.go | 2 +- .../pglogicalstream/logical_stream.go | 101 ++++++------------ 2 files changed, 32 insertions(+), 71 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 3afadffb48..6413e83ff3 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -428,7 +428,7 @@ file: } func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { - // t.Skip("This test requires a remote database to run. Aimed to test remote databases") + t.Skip("This test requires a remote database to run. Aimed to test remote databases") tmpDir := t.TempDir() // tables: users, products, orders, order_items diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 354d4e9417..7d02d9b95b 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -19,8 +19,6 @@ import ( "sync" "time" - "github.com/panjf2000/ants/v2" - "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -568,48 +566,7 @@ func (s *Stream) processSnapshot() { TableName string } - var pwg sync.WaitGroup - p, _ := ants.NewPoolWithFunc(2000, func(i interface{}) { - m := i.(RawMessage) - - columnValues := make([]interface{}, len(m.ColumnTypes)) - for i, getter := range m.ValueGetters { - columnValues[i] = getter(m.ScanArgs[i]) - } - - snapshotChangePacket := StreamMessage{ - Lsn: nil, - Changes: []StreamMessageChanges{ - { - Table: m.TableName, - Operation: "insert", - Schema: s.schema, - Data: func() map[string]any { - var data = make(map[string]any) - for i, cn := range m.ColumnNames { - data[cn] = columnValues[i] - } - return data - }(), - }, - }, - } - - if m.RowsCount%100 == 0 { - s.monitor.UpdateSnapshotProgressForTable(m.TableName, m.RowsCount+m.Offset) - } - - 
tableProgress := s.monitor.GetSnapshotProgressForTable(m.TableName) - snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - - s.snapshotMessages <- snapshotChangePacket - - pwg.Done() - }, ants.WithPreAlloc(true)) - defer p.Release() - var wg sync.WaitGroup - for _, table := range s.tableNames { wg.Add(1) go func(tableName string) { @@ -651,7 +608,6 @@ func (s *Stream) processSnapshot() { for { var snapshotRows *sql.Rows - queryStart := time.Now() if snapshotRows, err = s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { s.logger.Errorf("Failed to query snapshot data for table %v: %v", table, err.Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) @@ -662,9 +618,6 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - queryDuration := time.Since(queryStart) - fmt.Printf("Query duration: %v %s \n", queryDuration, tableName) - if snapshotRows.Err() != nil { s.logger.Errorf("Failed to get snapshot data for table %v: %v", table, snapshotRows.Err().Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) @@ -696,21 +649,15 @@ func (s *Stream) processSnapshot() { } var rowsCount = 0 - rowsStart := time.Now() - totalScanDuration := time.Duration(0) tableWithoutSchema := strings.Split(table, ".")[1] for snapshotRows.Next() { rowsCount += 1 - scanStart := time.Now() scanArgs, valueGetters := s.snapshotter.prepareScannersAndGetters(columnTypes) err := snapshotRows.Scan(scanArgs...) 
- scanEnd := time.Since(scanStart) - totalScanDuration += scanEnd if err != nil { - fmt.Println("Failed to scan row") s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -718,21 +665,38 @@ func (s *Stream) processSnapshot() { os.Exit(1) } - pwg.Add(1) - _ = p.Invoke(RawMessage{ - ColumnTypes: columnTypes, - RowsCount: rowsCount, - TableName: tableWithoutSchema, - ColumnNames: columnNames, - ScanArgs: scanArgs, - ValueGetters: valueGetters, - Offset: offset, - }) - } + columnValues := make([]interface{}, len(columnTypes)) + for i, getter := range valueGetters { + columnValues[i] = getter(scanArgs[i]) + } + + snapshotChangePacket := StreamMessage{ + Lsn: nil, + Changes: []StreamMessageChanges{ + { + Table: tableWithoutSchema, + Operation: "insert", + Schema: s.schema, + Data: func() map[string]any { + var data = make(map[string]any) + for i, cn := range columnNames { + data[cn] = columnValues[i] + } + return data + }(), + }, + }, + } + + if rowsCount%100 == 0 { + s.monitor.UpdateSnapshotProgressForTable(tableName, rowsCount+offset) + } - batchEnd := time.Since(rowsStart) - fmt.Printf("Batch duration: %v %s \n", batchEnd, tableName) - fmt.Println("Scan duration", totalScanDuration, tableName) + tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) + snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress + + s.snapshotMessages <- snapshotChangePacket + } offset += batchSize @@ -744,9 +708,6 @@ func (s *Stream) processSnapshot() { }(table) } - // waiting for batch to be processed - pwg.Wait() - wg.Wait() if err := s.startLr(); err != nil { From d4a4960fb8b2a6e9f741e1ca5e86cf098ec491e8 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 25 Oct 2024 13:02:20 +0200 Subject: [PATCH 035/118] fix(): metrics --- internal/impl/postgresql/input_postgrecdc.go | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index a458de8f92..049b35a84e 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -215,7 +215,7 @@ func newPgStreamInput(conf *service.ParsedConfig, logger *service.Logger, metric pgconnConfig.TLSConfig = nil } - snapsotMetrics := metrics.NewGauge("snapshot_progress") + snapsotMetrics := metrics.NewGauge("snapshot_progress", "table") replicationLag := metrics.NewGauge("replication_lag") snapshotMessageRate := metrics.NewGauge("snapshot_message_rate") snapshotRateCounter := NewRateCounter() From 0187c7a5b24cd29b0d0db281a601b94290df883e Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 25 Oct 2024 13:03:13 +0200 Subject: [PATCH 036/118] chore(): removed unused struct --- .../impl/postgresql/pglogicalstream/logical_stream.go | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 7d02d9b95b..318840ef11 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -556,16 +556,6 @@ func (s *Stream) processSnapshot() { s.logger.Infof("Starting snapshot processing") - type RawMessage struct { - RowsCount int - Offset int - ColumnTypes []*sql.ColumnType - ColumnNames []string - ScanArgs []interface{} - ValueGetters []func(interface{}) interface{} - TableName string - } - var wg sync.WaitGroup for _, table := range s.tableNames { wg.Add(1) From fd87cbe067a2b26d472cfd2eb3b980d43d2f2e65 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 28 Oct 2024 20:19:41 +0100 Subject: [PATCH 037/118] chore(): stabilised batches --- internal/impl/postgresql/input_postgrecdc.go | 150 +++++++-------- internal/impl/postgresql/integration_test.go | 178 +++++++++--------- 
.../pglogicalstream/logical_stream.go | 30 ++- 3 files changed, 176 insertions(+), 182 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index b8328d8f39..cbad96c0b3 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -12,10 +12,10 @@ import ( "context" "crypto/tls" "encoding/json" - "fmt" "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/Jeffail/checkpoint" @@ -105,6 +105,7 @@ This input adds the following metadata fields to each message: Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). Default(randomSlotName)). + Field(service.NewAutoRetryNacksToggleField()). Field(service.NewBatchPolicyField("batching").Advanced()) func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s service.BatchInput, err error) { @@ -271,6 +272,8 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser metrics: mgr.Metrics(), snapshotMetrics: snapsotMetrics, replicationLag: replicationLag, + inTxState: atomic.Bool{}, + releaseTrxChan: make(chan bool), } r, err := service.AutoRetryNacksBatchedToggled(conf, i) @@ -321,25 +324,30 @@ type pgStreamInput struct { snapshotMessageRate *service.MetricGauge snapshotMetrics *service.MetricGauge replicationLag *service.MetricGauge + + pendingTrx *string + releaseTrxChan chan bool + inTxState atomic.Bool } func (p *pgStreamInput) Connect(ctx context.Context) error { pgStream, err := pglogicalstream.NewPgStream(pglogicalstream.Config{ - PgConnRuntimeParam: p.pgConnRuntimeParam, - DBHost: p.dbConfig.Host, - DBPassword: p.dbConfig.Password, - DBUser: p.dbConfig.User, - DBPort: int(p.dbConfig.Port), - DBTables: p.tables, - DBName: p.dbConfig.Database, - DBSchema: p.schema, - ReplicationSlotName: 
"rs_" + p.slotName, - TLSVerify: p.tls, - BatchSize: p.snapshotBatchSize, - StreamOldData: p.streamSnapshot, - TemporaryReplicationSlot: p.temporarySlot, - StreamUncomited: p.streamUncomited, - DecodingPlugin: p.decodingPlugin, + PgConnRuntimeParam: p.pgConnRuntimeParam, + DBHost: p.dbConfig.Host, + DBPassword: p.dbConfig.Password, + DBUser: p.dbConfig.User, + DBPort: int(p.dbConfig.Port), + DBTables: p.tables, + DBName: p.dbConfig.Database, + DBSchema: p.schema, + ReplicationSlotName: "rs_" + p.slotName, + TLSVerify: p.tls, + BatchSize: p.snapshotBatchSize, + StreamOldData: p.streamSnapshot, + TemporaryReplicationSlot: p.temporarySlot, + StreamUncomited: p.streamUncomited, + DecodingPlugin: p.decodingPlugin, + SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, }) if err != nil { @@ -363,12 +371,8 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { }() var nextTimedBatchChan <-chan time.Time - var flushBatch func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64) bool - if p.checkpointLimit > 1 { - flushBatch = p.asyncCheckpointer() - } else { - flushBatch = p.syncCheckpointer() - } + var flushBatch func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64, *chan bool) bool + flushBatch = p.asyncCheckpointer() // offsets are nilable since we don't provide offset tracking during the snapshot phase var latestOffset *int64 @@ -383,35 +387,42 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset) { + if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, nil) { + break + } + case _, open := <-p.pglogicalStream.TxBeginChan(): + if !open { + p.logger.Debugf("TxBeginChan closed, exiting...") break } + + p.logger.Debugf("Entering transaction state. 
Stop messages from ack until we receive commit message...") + p.inTxState.Store(true) + // TrxCommit LSN must be acked when all the bessages in the batch are processed case trxCommitLsn, open := <-p.pglogicalStream.AckTxChan(): if !open { break } - fmt.Println("trxCommitLsn", trxCommitLsn) - fmt.Println("Force flushing the batch") + p.cMut.Lock() + p.cMut.Unlock() + flushedBatch, err := batchPolicy.Flush(ctx) if err != nil { p.mgr.Logger().Debugf("Flush batch error: %w", err) break } - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset) { - break - } - - time.Sleep(time.Second) - - if err = p.pglogicalStream.AckLSN(trxCommitLsn); err != nil { - p.mgr.Logger().Errorf("Failed to ack LSN: %v", err) + callbackChan := make(chan bool) + if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, &callbackChan) { break } + <-callbackChan + p.pglogicalStream.AckLSN(trxCommitLsn) p.pglogicalStream.ConsumedCallback() <- true + case message, open := <-p.pglogicalStream.Messages(): if !open { break @@ -428,15 +439,10 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } latestOffset = &parsedLSN } - if mb, err = json.Marshal(message); err != nil { break } - if message.IsStreaming { - fmt.Println("Message received", string(mb)) - } - batchMsg := service.NewMessage(mb) streaming := strconv.FormatBool(message.IsStreaming) @@ -460,9 +466,18 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { p.mgr.Logger().Debugf("Flush batch error: %w", err) break } - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset) { - break + if message.IsStreaming { + callbackChan := make(chan bool) + if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, &callbackChan) { + break + } + <-callbackChan + } else { + if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, nil) { + break + } } + } case <-ctx.Done(): p.pglogicalStream.Stop() @@ -473,12 +488,18 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { return err } -func (p *pgStreamInput) 
asyncCheckpointer() func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64) bool { +func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64, *chan bool) bool { cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) - return func(ctx context.Context, c chan<- asyncMessage, msg service.MessageBatch, lsn *int64) bool { + return func(ctx context.Context, c chan<- asyncMessage, msg service.MessageBatch, lsn *int64, txCommitConfirmChan *chan bool) bool { if msg == nil { + if txCommitConfirmChan != nil { + go func() { + *txCommitConfirmChan <- true + }() + } return true } + resolveFn, err := cp.Track(ctx, lsn, int64(len(msg))) if err != nil { if ctx.Err() == nil { @@ -486,6 +507,7 @@ func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMe } return false } + select { case c <- asyncMessage{ msg: msg, @@ -496,8 +518,10 @@ func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMe } p.cMut.Lock() if lsn != nil { - fmt.Println("Acking LSN from chackpointer", *lsn) p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)) + if txCommitConfirmChan != nil { + *txCommitConfirmChan <- true + } } p.cMut.Unlock() return nil @@ -506,47 +530,7 @@ func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMe case <-ctx.Done(): return false } - return true - } -} -func (p *pgStreamInput) syncCheckpointer() func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64) bool { - ackedChan := make(chan error) - return func(ctx context.Context, c chan<- asyncMessage, msg service.MessageBatch, lsn *int64) bool { - if msg == nil { - return true - } - select { - case c <- asyncMessage{ - msg: msg, - ackFn: func(ctx context.Context, res error) error { - resErr := res - if resErr == nil { - p.cMut.Lock() - if lsn != nil { - p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)) - } - p.cMut.Unlock() - } - select { - case ackedChan <- resErr: - case <-ctx.Done(): - } 
- return nil - }, - }: - select { - case resErr := <-ackedChan: - if resErr != nil { - p.mgr.Logger().Errorf("Received error from message batch: %v, shutting down consumer.\n", resErr) - return false - } - case <-ctx.Done(): - return false - } - case <-ctx.Done(): - return false - } return true } } diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index e99643937e..5b9bc8d3bb 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -202,19 +202,18 @@ file: `, tmpDir) streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - var outMessages []string - var outMessagesMut sync.Mutex - - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -226,9 +225,9 @@ file: }() assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 1000 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 1000 }, time.Second*25, time.Millisecond*100) for i := 0; i < 1000; i++ { @@ -238,10 +237,10 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 2000 - }, 
time.Second*25, time.Millisecond*100) + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 2000 + }, time.Second, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -253,13 +252,13 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - outMessages = []string{} - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() + outBatches = []string{} + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -277,9 +276,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 50 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 50 }, time.Second*20, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -342,15 +341,14 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - var outMessages []string - var outMessagesMut sync.Mutex - - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = 
append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -362,9 +360,9 @@ file: }() assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 10 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10 }, time.Second*25, time.Millisecond*100) for i := 0; i < 10; i++ { @@ -375,9 +373,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 20 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 20 }, time.Second*25, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -390,13 +388,13 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - outMessages = []string{} + outBatches = []string{} require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { msgBytes, err := m.AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -414,9 +412,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 10 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10 }, time.Second*20, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -428,7 +426,7 @@ file: } func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { - // t.Skip("This test requires a remote database to run. Aimed to test remote databases") + t.Skip("This test requires a remote database to run. 
Aimed to test remote databases") tmpDir := t.TempDir() // tables: users, products, orders, order_items @@ -502,7 +500,6 @@ file: require.NoError(t, err) go func() { - fmt.Println("Starting stream") _ = streamOut.Run(context.Background()) }() @@ -546,7 +543,7 @@ func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { hostAndPortSplited := strings.Split(hostAndPort, ":") fake := faker.New() - for i := 0; i < 10; i++ { + for i := 0; i < 10000; i++ { _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) require.NoError(t, err) } @@ -560,6 +557,7 @@ pg_stream: port: %s schema: public tls: none + snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput stream_uncomited: true @@ -579,15 +577,14 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - var outMessages []string - var outMessagesMut sync.Mutex - - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -599,9 +596,9 @@ file: }() assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 10 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10000 }, time.Second*25, time.Millisecond*100) for i := 0; i < 10; i++ { @@ -612,9 +609,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer 
outMessagesMut.Unlock() - return len(outMessages) == 20 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10010 }, time.Second*25, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -627,13 +624,13 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - outMessages = []string{} + outBatches = []string{} require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { msgBytes, err := m.AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -651,9 +648,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 10 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 10 }, time.Second*20, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -718,15 +715,14 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - var outMessages []string - var outMessagesMut sync.Mutex - - require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { - msgBytes, err := m.AsBytes() + var outBatches []string + var outBatchMut sync.Mutex + require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { + msgBytes, err := mb[0].AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -738,9 +734,9 @@ file: }() assert.Eventually(t, func() bool { 
- outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 1000 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 1000 }, time.Second*25, time.Millisecond*100) for i := 0; i < 1000; i++ { @@ -751,9 +747,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 2000 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 2000 }, time.Second*25, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) @@ -766,13 +762,13 @@ file: require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) - outMessages = []string{} + outBatches = []string{} require.NoError(t, streamOutBuilder.AddConsumerFunc(func(c context.Context, m *service.Message) error { msgBytes, err := m.AsBytes() require.NoError(t, err) - outMessagesMut.Lock() - outMessages = append(outMessages, string(msgBytes)) - outMessagesMut.Unlock() + outBatchMut.Lock() + outBatches = append(outBatches, string(msgBytes)) + outBatchMut.Unlock() return nil })) @@ -790,9 +786,9 @@ file: } assert.Eventually(t, func() bool { - outMessagesMut.Lock() - defer outMessagesMut.Unlock() - return len(outMessages) == 1000 + outBatchMut.Lock() + defer outBatchMut.Unlock() + return len(outBatches) == 1000 }, time.Second*20, time.Millisecond*100) require.NoError(t, streamOut.StopWithin(time.Second*10)) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 643ae1d583..04a7fd9ee7 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -56,6 +56,7 @@ type Stream struct { streamUncomited bool snapshotter *Snapshotter transactionAckChan chan string + transactionBeginChan chan bool lsnAckBuffer []string @@ -129,6 +130,7 @@ func NewPgStream(config 
Config) (*Stream, error) { tableNames: tableNames, consumedCallback: make(chan bool), transactionAckChan: make(chan string), + transactionBeginChan: make(chan bool), lsnAckBuffer: []string{}, logger: config.logger, m: sync.Mutex{}, @@ -292,8 +294,6 @@ func (s *Stream) AckLSN(lsn string) error { return err } - fmt.Println("Ack LSN", lsn, "clientXLogPos") - err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ WALApplyPosition: clientXLogPos, WALWritePosition: clientXLogPos, @@ -445,8 +445,6 @@ func (s *Stream) streamMessagesAsync() { return } - fmt.Println("Receive pg message", message) - isCommit, _, err := isCommitMessage(xld.WALData) if err != nil { s.logger.Errorf("Failed to parse WAL data: %w", err) @@ -456,13 +454,26 @@ func (s *Stream) streamMessagesAsync() { return } + isBegin, err := isBeginMessage(xld.WALData) + if err != nil { + s.logger.Errorf("Failed to parse WAL data: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } + + if isBegin { + s.transactionBeginChan <- true + } + // when receiving a commit message, we need to acknowledge the LSN // but we must wait for benthos to flush the messages before we can do that if isCommit { s.transactionAckChan <- clientXLogPos.String() <-s.consumedCallback } else { - if message == nil { + if message == nil && (!isBegin && !isCommit) { // 0 changes happened in the transaction // or we received a change that are not supported/needed by the replication stream if err = s.AckLSN(clientXLogPos.String()); err != nil { @@ -473,9 +484,8 @@ func (s *Stream) streamMessagesAsync() { } return } - } else { + } else if message != nil { lsn := clientXLogPos.String() - fmt.Println("Pushed uncomited message to stream", lsn, message) s.messages <- StreamMessage{ Lsn: &lsn, Changes: []StreamMessageChanges{ @@ -556,6 +566,10 @@ func (s *Stream) streamMessagesAsync() { } } +func (s *Stream) TxBeginChan() chan bool { + return 
s.transactionBeginChan +} + func (s *Stream) AckTxChan() chan string { return s.transactionAckChan } @@ -741,7 +755,7 @@ func (s *Stream) processSnapshot() { offset += batchSize - if batchSize != rowsCount { + if rowsCount < batchSize { break } } From e328d5f3c8515a52b46691b05b2b292639a3e2d2 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 28 Oct 2024 21:28:50 +0100 Subject: [PATCH 038/118] chore(): removed debug lines; fixed linter --- internal/impl/postgresql/input_postgrecdc.go | 28 +++++++------- .../pglogicalstream/logical_stream.go | 37 ++----------------- 2 files changed, 16 insertions(+), 49 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index cbad96c0b3..d514974a3d 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -325,7 +325,6 @@ type pgStreamInput struct { snapshotMetrics *service.MetricGauge replicationLag *service.MetricGauge - pendingTrx *string releaseTrxChan chan bool inTxState atomic.Bool } @@ -371,8 +370,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { }() var nextTimedBatchChan <-chan time.Time - var flushBatch func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64, *chan bool) bool - flushBatch = p.asyncCheckpointer() + flushBatch := p.asyncCheckpointer() // offsets are nilable since we don't provide offset tracking during the snapshot phase var latestOffset *int64 @@ -390,14 +388,6 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, nil) { break } - case _, open := <-p.pglogicalStream.TxBeginChan(): - if !open { - p.logger.Debugf("TxBeginChan closed, exiting...") - break - } - - p.logger.Debugf("Entering transaction state. 
Stop messages from ack until we receive commit message...") - p.inTxState.Store(true) // TrxCommit LSN must be acked when all the bessages in the batch are processed case trxCommitLsn, open := <-p.pglogicalStream.AckTxChan(): @@ -420,7 +410,11 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } <-callbackChan - p.pglogicalStream.AckLSN(trxCommitLsn) + if err = p.pglogicalStream.AckLSN(trxCommitLsn); err != nil { + p.mgr.Logger().Errorf("Failed to ack LSN: %v", err) + break + } + p.pglogicalStream.ConsumedCallback() <- true case message, open := <-p.pglogicalStream.Messages(): @@ -480,7 +474,9 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } case <-ctx.Done(): - p.pglogicalStream.Stop() + if err = p.pglogicalStream.Stop(); err != nil { + p.logger.Errorf("Failed to stop pglogical stream: %v", err) + } } } }() @@ -517,13 +513,15 @@ func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMe return nil } p.cMut.Lock() + defer p.cMut.Unlock() if lsn != nil { - p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)) + if err = p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { + return err + } if txCommitConfirmChan != nil { *txCommitConfirmChan <- true } } - p.cMut.Unlock() return nil }, }: diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 04a7fd9ee7..10e43984ba 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -454,26 +454,13 @@ func (s *Stream) streamMessagesAsync() { return } - isBegin, err := isBeginMessage(xld.WALData) - if err != nil { - s.logger.Errorf("Failed to parse WAL data: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - if isBegin { - s.transactionBeginChan <- true - } - // when receiving a commit message, we need to acknowledge the LSN // but we must wait for 
benthos to flush the messages before we can do that if isCommit { s.transactionAckChan <- clientXLogPos.String() <-s.consumedCallback } else { - if message == nil && (!isBegin && !isCommit) { + if message == nil && !isCommit { // 0 changes happened in the transaction // or we received a change that are not supported/needed by the replication stream if err = s.AckLSN(clientXLogPos.String()); err != nil { @@ -566,18 +553,11 @@ func (s *Stream) streamMessagesAsync() { } } -func (s *Stream) TxBeginChan() chan bool { - return s.transactionBeginChan -} - +// AckTxChan returns the transaction ack channel func (s *Stream) AckTxChan() chan string { return s.transactionAckChan } -func (s *Stream) ConfigrmAckTxChan() chan bool { - return s.consumedCallback -} - func (s *Stream) processSnapshot() { if err := s.snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) @@ -598,16 +578,6 @@ func (s *Stream) processSnapshot() { s.logger.Infof("Starting snapshot processing") - type RawMessage struct { - RowsCount int - Offset int - ColumnTypes []*sql.ColumnType - ColumnNames []string - ScanArgs []interface{} - ValueGetters []func(interface{}) interface{} - TableName string - } - var wg sync.WaitGroup for _, table := range s.tableNames { @@ -772,8 +742,7 @@ func (s *Stream) processSnapshot() { go s.streamMessagesAsync() } -// LrMessageC represents a message from the stream that are sent to the consumer on the logical replication stage -// meaning these messages will have non-nil LSN field +// Messages is a channel that can be used to consume messages from the plugin. 
It will contain LSN nil for snapshot messages func (s *Stream) Messages() chan StreamMessage { return s.messages } From 6fd8e4f3269cad2a0f6381d9b82c292ee2a250f6 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 1 Nov 2024 16:17:28 +0100 Subject: [PATCH 039/118] chore(): updated tls config field && small refactoring --- go.mod | 2 +- go.sum | 3 - internal/impl/postgresql/input_postgrecdc.go | 146 +++++++++--------- internal/impl/postgresql/integration_test.go | 49 +++--- .../impl/postgresql/pglogicalstream/config.go | 10 +- .../pglogicalstream/logical_stream.go | 67 ++++---- .../postgresql/pglogicalstream/monitor.go | 2 +- .../postgresql/pglogicalstream/snapshotter.go | 2 +- .../impl/postgresql/pglogicalstream/util.go | 21 ++- 9 files changed, 152 insertions(+), 150 deletions(-) diff --git a/go.mod b/go.mod index 8dbcb2807e..eaa968b99b 100644 --- a/go.mod +++ b/go.mod @@ -92,7 +92,6 @@ require ( github.com/opensearch-project/opensearch-go/v3 v3.1.0 github.com/ory/dockertest/v3 v3.11.0 github.com/oschwald/geoip2-golang v1.11.0 - github.com/panjf2000/ants/v2 v2.10.0 github.com/parquet-go/parquet-go v0.23.0 github.com/pebbe/zmq4 v1.2.11 github.com/pinecone-io/go-pinecone v1.0.0 @@ -153,6 +152,7 @@ require ( cloud.google.com/go/longrunning v0.5.9 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/hamba/avro/v2 v2.22.2-0.20240625062549-66aad10411d9 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/go.sum b/go.sum index 8259b9d54f..ac8c4e9bb5 100644 --- a/go.sum +++ b/go.sum @@ -945,8 +945,6 @@ github.com/oschwald/geoip2-golang v1.11.0 h1:hNENhCn1Uyzhf9PTmquXENiWS6AlxAEnBII github.com/oschwald/geoip2-golang v1.11.0/go.mod h1:P9zG+54KPEFOliZ29i7SeYZ/GM6tfEL+rgSn03hYuUo= github.com/oschwald/maxminddb-golang v1.13.0 
h1:R8xBorY71s84yO06NgTmQvqvTvlS/bnYZrrWX1MElnU= github.com/oschwald/maxminddb-golang v1.13.0/go.mod h1:BU0z8BfFVhi1LQaonTwwGQlsHUEu9pWNdMfmq4ztm0o= -github.com/panjf2000/ants/v2 v2.10.0 h1:zhRg1pQUtkyRiOFo2Sbqwjp0GfBNo9cUY2/Grpx1p+8= -github.com/panjf2000/ants/v2 v2.10.0/go.mod h1:7ZxyxsqE4vvW0M7LSD8aI3cKwgFhBHbxnlN8mDqHa1I= github.com/parquet-go/parquet-go v0.23.0 h1:dyEU5oiHCtbASyItMCD2tXtT2nPmoPbKpqf0+nnGrmk= github.com/parquet-go/parquet-go v0.23.0/go.mod h1:MnwbUcFHU6uBYMymKAlPPAw9yh3kE1wWl6Gl1uLdkNk= github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -1406,7 +1404,6 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index d514974a3d..8bd559e47b 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -13,19 +13,38 @@ import ( "crypto/tls" "encoding/json" "strconv" - "strings" "sync" "sync/atomic" "time" "github.com/Jeffail/checkpoint" "github.com/jackc/pgx/v5/pgconn" - "github.com/lucasepe/codename" "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" ) +const ( + fieldHost = "host" + fieldPort = "port" + fieldUser = "user" + 
fieldPass = "password" + fieldSchema = "schema" + fieldDatabase = "database" + fieldTls = "tls" + fieldStreamUncomitted = "stream_uncomitted" + fieldPgConnOptions = "pg_conn_options" + fieldStreamSnapshot = "stream_snapshot" + fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" + fieldSnapshotBatchSize = "snapshot_batch_size" + fieldDecodingPlugin = "decoding_plugin" + fieldTables = "tables" + fieldCheckpointLimit = "checkpoint_limit" + fieldTemporarySlot = "temporary_slot" + fieldSlotName = "slot_name" + fieldBatching = "batching" +) + var randomSlotName string type asyncMessage struct { @@ -45,68 +64,69 @@ This input adds the following metadata fields to each message: - table (Name of the table that the message originated from) - operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) `). - Field(service.NewStringField("host"). + Field(service.NewStringField(fieldHost). Description("The hostname or IP address of the PostgreSQL instance."). Example("123.0.0.1")). - Field(service.NewIntField("port"). + Field(service.NewIntField(fieldPort). Description("The port number on which the PostgreSQL instance is listening."). Example(5432). Default(5432)). - Field(service.NewStringField("user"). + Field(service.NewStringField(fieldUser). Description("Username of a user with replication permissions. For AWS RDS, this typically requires superuser privileges."). Example("postgres"), ). - Field(service.NewStringField("password"). + Field(service.NewStringField(fieldPass). Description("Password for the specified PostgreSQL user.")). - Field(service.NewStringField("schema"). + Field(service.NewStringField(fieldSchema). Description("The PostgreSQL schema from which to replicate data.")). - Field(service.NewStringField("database"). + Field(service.NewStringField(fieldDatabase). Description("The name of the PostgreSQL database to connect to.")). - Field(service.NewStringEnumField("tls", "require", "none"). 
+ Field(service.NewTLSToggledField(fieldTls). Description("Specifies whether to use TLS for the database connection. Set to 'require' to enforce TLS, or 'none' to disable it."). - Example("none"). - Default("none")). - Field(service.NewBoolField("stream_uncomited"). + Default(nil)). + Field(service.NewBoolField(fieldStreamUncomitted). Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). Default(false)). - Field(service.NewStringField("pg_conn_options"). + Field(service.NewStringField(fieldPgConnOptions). Description("Additional PostgreSQL connection options as a string. Refer to PostgreSQL documentation for available options."). Default(""), ). - Field(service.NewBoolField("stream_snapshot"). + Field(service.NewBoolField(fieldStreamSnapshot). Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes."). Example(true). Default(false)). - Field(service.NewFloatField("snapshot_memory_safety_factor"). + Field(service.NewFloatField(fieldSnapshotMemSafetyFactor). Description("Determines the fraction of available memory that can be used for streaming the snapshot. Values between 0 and 1 represent the percentage of memory to use. Lower values make initial streaming slower but help prevent out-of-memory errors."). Example(0.2). Default(1)). - Field(service.NewIntField("snapshot_batch_size"). + Field(service.NewIntField(fieldSnapshotBatchSize). Description("The number of rows to fetch in each batch when querying the snapshot. A value of 0 lets the plugin determine the batch size based on `snapshot_memory_safety_factor` property."). Example(10000). Default(0)). - Field(service.NewStringEnumField("decoding_plugin", "pgoutput", "wal2json"). - Description("Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 
'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON."). + Field(service.NewStringEnumField(fieldDecodingPlugin, "pgoutput", "wal2json"). + Description(`Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. + Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Benthos. + `). Example("pgoutput"). Default("pgoutput")). - Field(service.NewStringListField("tables"). + Field(service.NewStringListField(fieldTables). Description("A list of table names to include in the logical replication. Each table should be specified as a separate item."). Example(` - my_table - my_table_2 `)). - Field(service.NewIntField("checkpoint_limit"). - Description("The maximum number of messages of the same topic and partition that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level to work on individual partitions. Any given offset will not be committed unless all messages under that offset are delivered in order to preserve at least once delivery guarantees."). + Field(service.NewIntField(fieldCheckpointLimit). + Description("The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees."). Version("3.33.0").Default(1024)). - Field(service.NewBoolField("temporary_slot"). + Field(service.NewBoolField(fieldTemporarySlot). Description("If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed."). Default(false)). - Field(service.NewStringField("slot_name"). + Field(service.NewStringField(fieldSlotName). 
Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). Default(randomSlotName)). Field(service.NewAutoRetryNacksToggleField()). - Field(service.NewBatchPolicyField("batching").Advanced()) + Field(service.NewBatchPolicyField(fieldBatching)) func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s service.BatchInput, err error) { var ( @@ -118,7 +138,6 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser dbPassword string dbSlotName string temporarySlot bool - tlsSetting string tables []string streamSnapshot bool snapshotMemSafetyFactor float64 @@ -130,90 +149,86 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser batching service.BatchPolicy ) - dbSchema, err = conf.FieldString("schema") + dbSchema, err = conf.FieldString(fieldSchema) if err != nil { return nil, err } - dbSlotName, err = conf.FieldString("slot_name") + dbSlotName, err = conf.FieldString(fieldSlotName) if err != nil { return nil, err } - temporarySlot, err = conf.FieldBool("temporary_slot") + temporarySlot, err = conf.FieldBool(fieldTemporarySlot) if err != nil { return nil, err } - if dbSlotName == "" { - dbSlotName = randomSlotName - } - - dbPassword, err = conf.FieldString("password") + dbPassword, err = conf.FieldString(fieldPass) if err != nil { return nil, err } - dbUser, err = conf.FieldString("user") + dbUser, err = conf.FieldString(fieldUser) if err != nil { return nil, err } - tlsSetting, err = conf.FieldString("tls") + tlsConf, tlsEnabled, err := conf.FieldTLSToggled(fieldTls) if err != nil { return nil, err } - dbName, err = conf.FieldString("database") + dbName, err = conf.FieldString(fieldDatabase) if err != nil { return nil, err } - dbHost, err = conf.FieldString("host") + dbHost, err = conf.FieldString(fieldHost) if err != nil { return 
nil, err } - dbPort, err = conf.FieldInt("port") + dbPort, err = conf.FieldInt(fieldPort) if err != nil { return nil, err } - tables, err = conf.FieldStringList("tables") + tables, err = conf.FieldStringList(fieldTables) if err != nil { return nil, err } - if checkpointLimit, err = conf.FieldInt("checkpoint_limit"); err != nil { + if checkpointLimit, err = conf.FieldInt(fieldCheckpointLimit); err != nil { return nil, err } - streamSnapshot, err = conf.FieldBool("stream_snapshot") + streamSnapshot, err = conf.FieldBool(fieldStreamSnapshot) if err != nil { return nil, err } - streamUncomited, err = conf.FieldBool("stream_uncomited") + streamUncomited, err = conf.FieldBool(fieldStreamUncomitted) if err != nil { return nil, err } - decodingPlugin, err = conf.FieldString("decoding_plugin") + decodingPlugin, err = conf.FieldString(fieldDecodingPlugin) if err != nil { return nil, err } - snapshotMemSafetyFactor, err = conf.FieldFloat("snapshot_memory_safety_factor") + snapshotMemSafetyFactor, err = conf.FieldFloat(fieldSnapshotMemSafetyFactor) if err != nil { return nil, err } - snapshotBatchSize, err = conf.FieldInt("snapshot_batch_size") + snapshotBatchSize, err = conf.FieldInt(fieldSnapshotBatchSize) if err != nil { return nil, err } - if pgConnOptions, err = conf.FieldString("pg_conn_options"); err != nil { + if pgConnOptions, err = conf.FieldString(fieldPgConnOptions); err != nil { return nil, err } @@ -221,31 +236,28 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser pgConnOptions = "options=" + pgConnOptions } - if batching, err = conf.FieldBatchPolicy("batching"); err != nil { + if batching, err = conf.FieldBatchPolicy(fieldBatching); err != nil { return nil, err } else if batching.IsNoop() { batching.Count = 1 } pgconnConfig := pgconn.Config{ - Host: dbHost, - Port: uint16(dbPort), - Database: dbName, - User: dbUser, - TLSConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Password: dbPassword, + Host: dbHost, + Port: 
uint16(dbPort), + Database: dbName, + User: dbUser, + TLSConfig: tlsConf, + Password: dbPassword, } - if tlsSetting == "none" { + if !tlsEnabled { pgconnConfig.TLSConfig = nil + tlsConf = nil } snapsotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") replicationLag := mgr.Metrics().NewGauge("replication_lag_bytes") - snapshotMessageRate := mgr.Metrics().NewGauge("snapshot_message_rate") - snapshotRateCounter := NewRateCounter() i := &pgStreamInput{ dbConfig: pgconnConfig, @@ -254,14 +266,12 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser slotName: dbSlotName, schema: dbSchema, pgConnRuntimeParam: pgConnOptions, - tls: pglogicalstream.TLSVerify(tlsSetting), + tls: tlsConf, tables: tables, decodingPlugin: decodingPlugin, streamUncomited: streamUncomited, temporarySlot: temporarySlot, snapshotBatchSize: snapshotBatchSize, - snapshotMessageRate: snapshotMessageRate, - snapshotRateCounter: snapshotRateCounter, batching: batching, checkpointLimit: checkpointLimit, cMut: sync.Mutex{}, @@ -285,9 +295,6 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser } func init() { - rng, _ := codename.DefaultRNG() - randomSlotName = strings.ReplaceAll(codename.Generate(rng, 5), "-", "_") - err := service.RegisterBatchInput( "pg_stream", pgStreamConfigSpec, func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) { @@ -300,7 +307,7 @@ func init() { type pgStreamInput struct { dbConfig pgconn.Config - tls pglogicalstream.TLSVerify + tls *tls.Config pglogicalStream *pglogicalstream.Stream pgConnRuntimeParam string slotName string @@ -330,7 +337,7 @@ type pgStreamInput struct { } func (p *pgStreamInput) Connect(ctx context.Context) error { - pgStream, err := pglogicalstream.NewPgStream(pglogicalstream.Config{ + pgStream, err := pglogicalstream.NewPgStream(ctx, &pglogicalstream.Config{ PgConnRuntimeParam: p.pgConnRuntimeParam, DBHost: p.dbConfig.Host, DBPassword: 
p.dbConfig.Password, @@ -340,7 +347,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { DBName: p.dbConfig.Database, DBSchema: p.schema, ReplicationSlotName: "rs_" + p.slotName, - TLSVerify: p.tls, + TLSConfig: p.tls, BatchSize: p.snapshotBatchSize, StreamOldData: p.streamSnapshot, TemporaryReplicationSlot: p.temporarySlot, @@ -389,7 +396,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - // TrxCommit LSN must be acked when all the bessages in the batch are processed + // TrxCommit LSN must be acked when all the messages in the batch are processed case trxCommitLsn, open := <-p.pglogicalStream.AckTxChan(): if !open { break @@ -450,9 +457,6 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { p.replicationLag.Set(*message.WALLagBytes) } - p.snapshotRateCounter.Increment() - p.snapshotMessageRate.SetFloat64(p.snapshotRateCounter.Rate()) - if batchPolicy.Add(batchMsg) { nextTimedBatchChan = nil flushedBatch, err := batchPolicy.Flush(ctx) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 5b9bc8d3bb..c7bfd2c9fe 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -34,7 +34,7 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version Repository: "postgres", Tag: version, Env: []string{ - "POSTGRES_PASSWORD=secret", + "POSTGRES_PASSWORD=l]YLSc|4[i56%{gY", "POSTGRES_USER=user_name", "POSTGRES_DB=dbname", }, @@ -56,7 +56,8 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + password := "l]YLSc|4[i56%{gY" + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s 
port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) var db *sql.DB pool.MaxWait = 120 * time.Second @@ -109,7 +110,7 @@ func TestIntegrationPgCDC(t *testing.T) { Repository: "usedatabrew/pgwal2json", Tag: "16", Env: []string{ - "POSTGRES_PASSWORD=secret", + "POSTGRES_PASSWORD=l]YLSc|4[i56%{gY", "POSTGRES_USER=user_name", "POSTGRES_DB=dbname", }, @@ -131,7 +132,8 @@ func TestIntegrationPgCDC(t *testing.T) { hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) + password := "l]YLSc|4[i56%{gY" + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) var db *sql.DB @@ -184,16 +186,15 @@ pg_stream: host: %s slot_name: test_slot user: user_name - password: secret + password: %s port: %s schema: public decoding_plugin: wal2json - tls: none stream_snapshot: true database: dbname tables: - flights -`, hostAndPortSplited[0], hostAndPortSplited[1]) +`, hostAndPortSplited[0], password, hostAndPortSplited[1]) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -305,6 +306,7 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" require.NoError(t, err) @@ -319,7 +321,7 @@ pg_stream: host: %s slot_name: test_slot_native_decoder user: user_name - password: secret + password: %s port: %s schema: public tls: none @@ -328,7 +330,7 @@ pg_stream: database: dbname tables: - flights -`, hostAndPortSplited[0], hostAndPortSplited[1]) +`, hostAndPortSplited[0], password, hostAndPortSplited[1]) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -450,7 +452,7 @@ pg_stream: snapshot_batch_size: 100000 stream_snapshot: true 
decoding_plugin: pgoutput - stream_uncomited: false + stream_uncomitted: false temporary_slot: true database: %s tables: @@ -541,6 +543,7 @@ func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" fake := faker.New() for i := 0; i < 10000; i++ { @@ -553,18 +556,17 @@ pg_stream: host: %s slot_name: test_slot_native_decoder user: user_name - password: secret + password: %s port: %s schema: public - tls: none snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput - stream_uncomited: true + stream_uncomitted: true database: dbname tables: - flights -`, hostAndPortSplited[0], hostAndPortSplited[1]) +`, hostAndPortSplited[0], password, hostAndPortSplited[1]) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -592,7 +594,8 @@ file: require.NoError(t, err) go func() { - _ = streamOut.Run(context.Background()) + err = streamOut.Run(context.Background()) + require.NoError(t, err) }() assert.Eventually(t, func() bool { @@ -680,6 +683,7 @@ func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testin hostAndPort := resource.GetHostPort("5432/tcp") hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" fake := faker.New() for i := 0; i < 1000; i++ { @@ -692,17 +696,16 @@ pg_stream: host: %s slot_name: test_slot_native_decoder user: user_name - password: secret + password: %s port: %s schema: public - tls: none stream_snapshot: true decoding_plugin: pgoutput - stream_uncomited: true + stream_uncomitted: true database: dbname tables: - flights -`, hostAndPortSplited[0], hostAndPortSplited[1]) +`, hostAndPortSplited[0], password, hostAndPortSplited[1]) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -817,6 +820,7 @@ func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing hostAndPort := resource.GetHostPort("5432/tcp") 
hostAndPortSplited := strings.Split(hostAndPort, ":") + password := "l]YLSc|4[i56%{gY" fake := faker.New() for i := 0; i < 1000; i++ { @@ -829,17 +833,16 @@ pg_stream: host: %s slot_name: test_slot_native_decoder user: user_name - password: secret + password: %s port: %s schema: public - tls: none stream_snapshot: true decoding_plugin: pgoutput - stream_uncomited: false + stream_uncomitted: false database: dbname tables: - flights -`, hostAndPortSplited[0], hostAndPortSplited[1]) +`, hostAndPortSplited[0], password, hostAndPortSplited[1]) cacheConf := fmt.Sprintf(` label: pg_stream_cache diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 1c506d743b..3590a6ddf0 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -8,7 +8,11 @@ package pglogicalstream -import "github.com/redpanda-data/benthos/v4/public/service" +import ( + "crypto/tls" + + "github.com/redpanda-data/benthos/v4/public/service" +) // Config is the configuration for the pglogicalstream plugin type Config struct { @@ -26,8 +30,8 @@ type Config struct { DBSchema string `yaml:"db_schema"` // DbTables is the tables to stream changes from DBTables []string `yaml:"db_tables"` - // TlsVerify is the TLS verification configuration - TLSVerify TLSVerify `yaml:"tls_verify"` + // TLSConfig is the TLS verification configuration + TLSConfig *tls.Config `yaml:"tls"` // PgConnRuntimeParam is the runtime parameter for the PostgreSQL connection PgConnRuntimeParam string `yaml:"pg_conn_options"` diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 10e43984ba..de9b9e4d85 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -10,7 +10,6 @@ package pglogicalstream import ( "context" - "crypto/tls" "database/sql" 
"errors" "fmt" @@ -22,6 +21,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" + "github.com/lucasepe/codename" "github.com/redpanda-data/benthos/v4/public/service" ) @@ -31,7 +31,7 @@ type Stream struct { pgConn *pgconn.PgConn // extra copy of db config is required to establish a new db connection // which is required to take snapshot data - dbConfig pgconn.Config + dbConfig *pgconn.Config streamCtx context.Context streamCancel context.CancelFunc @@ -67,44 +67,35 @@ type Stream struct { } // NewPgStream creates a new instance of the Stream struct -func NewPgStream(config Config) (*Stream, error) { +func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { var ( cfg *pgconn.Config err error ) - sslVerifyFull := "" - if config.TLSVerify == TLSRequireVerify { - sslVerifyFull = "&sslmode=verify-full" - } - connectionParams := "" if config.PgConnRuntimeParam != "" { connectionParams = "&" + config.PgConnRuntimeParam } - q := fmt.Sprintf("postgres://%s:%s@%s:%d/%s?replication=database%s%s", + // intentiolly omit password to support password with special characters + // a new lines below cfg.Password = config.DBPassword + // we also need to use pgconn.ParseConfig since we are going to use this pased config to connect to the database for snapshot + // pgx panics when connection is not created via PaseConfig method + q := fmt.Sprintf("postgres://%s@%s:%d/%s?replication=database%s", config.DBUser, - config.DBPassword, config.DBHost, config.DBPort, config.DBName, - sslVerifyFull, connectionParams, ) if cfg, err = pgconn.ParseConfig(q); err != nil { return nil, err } - cfg.Password = config.DBPassword - if config.TLSVerify == TLSRequireVerify { - cfg.TLSConfig = &tls.Config{ - InsecureSkipVerify: true, - } - } else { - cfg.TLSConfig = nil - } + cfg.Password = config.DBPassword + cfg.TLSConfig = config.TLSConfig dbConn, err := pgconn.ConnectConfig(context.Background(), cfg) if err != nil { @@ -118,9 
+109,14 @@ func NewPgStream(config Config) (*Stream, error) { var tableNames []string tableNames = append(tableNames, config.DBTables...) + if config.ReplicationSlotName == "" { + rng, _ := codename.DefaultRNG() + config.ReplicationSlotName = strings.ReplaceAll(codename.Generate(rng, 5), "-", "_") + } + stream := &Stream{ pgConn: dbConn, - dbConfig: *cfg, + dbConfig: cfg, messages: make(chan StreamMessage), slotName: config.ReplicationSlotName, schema: config.DBSchema, @@ -142,7 +138,7 @@ func NewPgStream(config Config) (*Stream, error) { } var version int - version, err = getPostgresVersion(*cfg) + version, err = getPostgresVersion(cfg) if err != nil { return nil, err } @@ -180,30 +176,29 @@ func NewPgStream(config Config) (*Stream, error) { stream.decodingPluginArguments = pluginArguments - // create snapshot transaction before creating a slot for older PostgreSQL versions to ensure consistency - pubName := "pglog_stream_" + config.ReplicationSlotName - fmt.Println("Creating publication", pubName, "for tables", tableNames) - if err = CreatePublication(context.Background(), stream.pgConn, pubName, tableNames, true); err != nil { + stream.logger.Debugf("Creating publication %s for tables: %s", pubName, tableNames) + if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } - sysident, err := IdentifySystem(context.Background(), stream.pgConn) + sysident, err := IdentifySystem(ctx, stream.pgConn) if err != nil { return nil, err } var freshlyCreatedSlot = false var confirmedLSNFromDB string + var outputPlugin string // check is replication slot exist to get last restart SLN - connExecResult := stream.pgConn.Exec(context.TODO(), fmt.Sprintf("SELECT confirmed_flush_lsn FROM pg_replication_slots WHERE slot_name = '%s'", config.ReplicationSlotName)) + connExecResult := stream.pgConn.Exec(ctx, fmt.Sprintf("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = '%s'", config.ReplicationSlotName)) if 
slotCheckResults, err := connExecResult.ReadAll(); err != nil { return nil, err } else { if len(slotCheckResults) == 0 || len(slotCheckResults[0].Rows) == 0 { // here we create a new replication slot because there is no slot found var createSlotResult CreateReplicationSlotResult - createSlotResult, err = CreateReplicationSlot(context.Background(), stream.pgConn, stream.slotName, stream.decodingPlugin.String(), + createSlotResult, err = CreateReplicationSlot(ctx, stream.pgConn, stream.slotName, stream.decodingPlugin.String(), CreateReplicationSlotOptions{Temporary: config.TemporaryReplicationSlot, SnapshotAction: "export", }, version, stream.snapshotter) @@ -215,10 +210,13 @@ func NewPgStream(config Config) (*Stream, error) { } else { slotCheckRow := slotCheckResults[0].Rows[0] confirmedLSNFromDB = string(slotCheckRow[0]) + outputPlugin = string(slotCheckRow[1]) } } - // TODO:: check decoding plugin and replication slot plugin should match + if !freshlyCreatedSlot && outputPlugin != stream.decodingPlugin.String() { + return nil, fmt.Errorf("Replication slot %s already exists with different output plugin: %s", config.ReplicationSlotName, outputPlugin) + } var lsnrestart LSN if freshlyCreatedSlot { @@ -245,7 +243,7 @@ func NewPgStream(config Config) (*Stream, error) { } stream.monitor = monitor - fmt.Println("Starting stream from LSN", stream.lsnrestart, "with clientXLogPos", stream.clientXLogPos, "and snapshot name", stream.snapshotName) + stream.logger.Debugf("Starting stream from LSN %s with clientXLogPos %s and snapshot name %s", stream.lsnrestart.String(), stream.clientXLogPos.String(), stream.snapshotName) if !freshlyCreatedSlot || !config.StreamOldData { if err = stream.startLr(); err != nil { return nil, err @@ -633,7 +631,7 @@ func (s *Stream) processSnapshot() { } queryDuration := time.Since(queryStart) - fmt.Printf("Query duration: %v %s \n", queryDuration, tableName) + s.logger.Debugf("Query duration: %v %s \n", queryDuration, tableName) if 
snapshotRows.Err() != nil { s.logger.Errorf("Failed to get snapshot data for table %v: %v", table, snapshotRows.Err().Error()) @@ -679,7 +677,6 @@ func (s *Stream) processSnapshot() { totalScanDuration += scanEnd if err != nil { - fmt.Println("Failed to scan row") s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) if err = s.cleanUpOnFailure(); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) @@ -719,9 +716,9 @@ func (s *Stream) processSnapshot() { } batchEnd := time.Since(rowsStart) - fmt.Printf("Batch duration: %v %s \n", batchEnd, tableName) - fmt.Println("Scan duration", totalScanDuration, tableName) - fmt.Println("Waiting from benthos duration", totalWaitingFromBenthos, tableName) + s.logger.Debugf("Batch duration: %v %s \n", batchEnd, tableName) + s.logger.Debugf("Scan duration %v %s\n", totalScanDuration, tableName) + s.logger.Debugf("Waiting from benthos duration %v %s\n", totalWaitingFromBenthos, tableName) offset += batchSize diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index 3ae98834e8..2285a09629 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -49,7 +49,7 @@ type Monitor struct { // NewMonitor creates a new Monitor instance func NewMonitor(conf *pgconn.Config, logger *service.Logger, tables []string, slotName string) (*Monitor, error) { - dbConn, err := openPgConnectionFromConfig(*conf) + dbConn, err := openPgConnectionFromConfig(conf) if err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 525e6d8d4c..0511e9e4f8 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -40,7 +40,7 @@ type Snapshotter struct { } // NewSnapshotter creates a new Snapshotter 
instance -func NewSnapshotter(dbConf pgconn.Config, logger *service.Logger, version int) (*Snapshotter, error) { +func NewSnapshotter(dbConf *pgconn.Config, logger *service.Logger, version int) (*Snapshotter, error) { pgConn, err := openPgConnectionFromConfig(dbConf) if err != nil { return nil, err diff --git a/internal/impl/postgresql/pglogicalstream/util.go b/internal/impl/postgresql/pglogicalstream/util.go index 6765f868de..d24995fc6b 100644 --- a/internal/impl/postgresql/pglogicalstream/util.go +++ b/internal/impl/postgresql/pglogicalstream/util.go @@ -14,24 +14,21 @@ import ( "regexp" "strconv" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/stdlib" ) -func openPgConnectionFromConfig(dbConf pgconn.Config) (*sql.DB, error) { - var sslMode string - if dbConf.TLSConfig != nil { - sslMode = "require" - } else { - sslMode = "disable" - } - connStr := fmt.Sprintf("user=%s password=%s host=%s port=%d dbname=%s sslmode=%s", dbConf.User, - dbConf.Password, dbConf.Host, dbConf.Port, dbConf.Database, sslMode, - ) +func openPgConnectionFromConfig(dbConf *pgconn.Config) (*sql.DB, error) { + conf, _ := pgx.ParseConfig("") + delete(dbConf.RuntimeParams, "replication") + conf.Config = *dbConf + connStr := stdlib.RegisterConnConfig(conf) - return sql.Open("postgres", connStr) + return sql.Open("pgx", connStr) } -func getPostgresVersion(connConfig pgconn.Config) (int, error) { +func getPostgresVersion(connConfig *pgconn.Config) (int, error) { conn, err := openPgConnectionFromConfig(connConfig) if err != nil { return 0, fmt.Errorf("failed to connect to the database: %w", err) From f4e14bd2922d856eed97645f27ce45fe95df0be6 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 1 Nov 2024 16:21:03 +0100 Subject: [PATCH 040/118] ref(): use context when create publication --- internal/impl/postgresql/pglogicalstream/pglogrepl.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 5975c6d97f..f7a23c1c83 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -339,7 +339,7 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri // CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string, dropIfExist bool) error { - result := conn.Exec(context.Background(), fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", publicationName)) + result := conn.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", publicationName)) if _, err := result.ReadAll(); err != nil { return nil } @@ -348,7 +348,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName if len(tables) == 0 { tablesSchemaFilter = "FOR ALL TABLES" } - result = conn.Exec(context.Background(), fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesSchemaFilter)) + result = conn.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesSchemaFilter)) if _, err := result.ReadAll(); err != nil { return err } From 9ff7079b29089b216868e2b536c38c8464c60c63 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Mon, 4 Nov 2024 03:22:30 +0000 Subject: [PATCH 041/118] pgcdc: cleanup configuration * By default we were just using a replication slot name of `rs_`. 
* Cleanup description * Fix typos --- internal/impl/postgresql/input_postgrecdc.go | 32 +++++++++++-------- .../pglogicalstream/logical_stream.go | 3 +- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 8bd559e47b..0d30f46422 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -18,6 +18,7 @@ import ( "time" "github.com/Jeffail/checkpoint" + "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" "github.com/redpanda-data/benthos/v4/public/service" @@ -31,8 +32,8 @@ const ( fieldPass = "password" fieldSchema = "schema" fieldDatabase = "database" - fieldTls = "tls" - fieldStreamUncomitted = "stream_uncomitted" + fieldTLS = "tls" + fieldStreamUncommitted = "stream_uncomitted" fieldPgConnOptions = "pg_conn_options" fieldStreamSnapshot = "stream_snapshot" fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" @@ -45,8 +46,6 @@ const ( fieldBatching = "batching" ) -var randomSlotName string - type asyncMessage struct { msg service.MessageBatch ackFn service.AckFunc @@ -55,12 +54,15 @@ type asyncMessage struct { var pgStreamConfigSpec = service.NewConfigSpec(). Beta(). Categories("Services"). - Version("0.0.1"). - Summary(`Creates a PostgreSQL replication slot for Change Data Capture (CDC) - == Metadata + Version("4.39.0"). + Summary(`Streams changes from a PostgreSQL database using logical replication.`). + Description(`Streams changes from a PostgreSQL database for Change Data Capture (CDC). +Additionally, if ` + "`" + fieldStreamSnapshot + "`" + ` is set to true, then the existing data in the database is also streamed too. 
+ +== Metadata This input adds the following metadata fields to each message: -- streaming (Indicates whether the message is part of a streaming operation or snapshot processing) +- streaming (Boolean indicating whether the message is part of a streaming operation or snapshot processing) - table (Name of the table that the message originated from) - operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) `). @@ -81,10 +83,10 @@ This input adds the following metadata fields to each message: Description("The PostgreSQL schema from which to replicate data.")). Field(service.NewStringField(fieldDatabase). Description("The name of the PostgreSQL database to connect to.")). - Field(service.NewTLSToggledField(fieldTls). + Field(service.NewTLSToggledField(fieldTLS). Description("Specifies whether to use TLS for the database connection. Set to 'require' to enforce TLS, or 'none' to disable it."). Default(nil)). - Field(service.NewBoolField(fieldStreamUncomitted). + Field(service.NewBoolField(fieldStreamUncommitted). Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). Default(false)). Field(service.NewStringField(fieldPgConnOptions). @@ -124,7 +126,7 @@ This input adds the following metadata fields to each message: Field(service.NewStringField(fieldSlotName). Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). - Default(randomSlotName)). + Default("")). Field(service.NewAutoRetryNacksToggleField()). 
Field(service.NewBatchPolicyField(fieldBatching)) @@ -158,6 +160,10 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser if err != nil { return nil, err } + // Set the default to be a random string + if dbSlotName == "" { + dbSlotName = uuid.NewString() + } temporarySlot, err = conf.FieldBool(fieldTemporarySlot) if err != nil { @@ -174,7 +180,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - tlsConf, tlsEnabled, err := conf.FieldTLSToggled(fieldTls) + tlsConf, tlsEnabled, err := conf.FieldTLSToggled(fieldTLS) if err != nil { return nil, err } @@ -208,7 +214,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - streamUncomited, err = conf.FieldBool(fieldStreamUncomitted) + streamUncomited, err = conf.FieldBool(fieldStreamUncommitted) if err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index de9b9e4d85..d73cf7520a 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -110,8 +110,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tableNames = append(tableNames, config.DBTables...) if config.ReplicationSlotName == "" { - rng, _ := codename.DefaultRNG() - config.ReplicationSlotName = strings.ReplaceAll(codename.Generate(rng, 5), "-", "_") + return nil, errors.New("missing replication slot name") } stream := &Stream{ From 2517627c355d716564488139bb0e4744fa9915fb Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Mon, 4 Nov 2024 03:59:35 +0000 Subject: [PATCH 042/118] pgcdc: simplify stream setup Just have the user give us a DSN that is standard and our SQL* plugins already expect this format. That fixes bugs we have with special characters that need escaping, and generally simplfies setup. 
Also fixes: - Don't os.Exit, but bubble an error up - Use provided context instead of context.Background - Prevent SQL injection attacks in slot names --- go.mod | 1 - go.sum | 2 - internal/impl/postgresql/input_postgrecdc.go | 132 ++++-------------- .../impl/postgresql/pglogicalstream/config.go | 41 ++---- .../pglogicalstream/logical_stream.go | 110 ++++++--------- 5 files changed, 81 insertions(+), 205 deletions(-) diff --git a/go.mod b/go.mod index eaa968b99b..a1e6a1e1ce 100644 --- a/go.mod +++ b/go.mod @@ -75,7 +75,6 @@ require ( github.com/jhump/protoreflect v1.16.0 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 - github.com/lucasepe/codename v0.2.0 github.com/matoous/go-nanoid/v2 v2.1.0 github.com/microcosm-cc/bluemonday v1.0.27 github.com/microsoft/gocosmos v1.1.1 diff --git a/go.sum b/go.sum index ac8c4e9bb5..a2c41fe1b0 100644 --- a/go.sum +++ b/go.sum @@ -815,8 +815,6 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/linkedin/goavro/v2 v2.13.0 h1:L8eI8GcuciwUkt41Ej62joSZS4kKaYIUdze+6for9NU= github.com/linkedin/goavro/v2 v2.13.0/go.mod h1:KXx+erlq+RPlGSPmLF7xGo6SAbh8sCQ53x064+ioxhk= -github.com/lucasepe/codename v0.2.0 h1:zkW9mKWSO8jjVIYFyZWE9FPvBtFVJxgMpQcMkf4Vv20= -github.com/lucasepe/codename v0.2.0/go.mod h1:RDcExRuZPWp5Uz+BosvpROFTrxpt5r1vSzBObHdBdDM= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a h1:3Bm7EwfUQUvhNeKIkUct/gl9eod1TcXuj8stxvi/GoI= github.com/lufia/plan9stats v0.0.0-20240226150601-1dcf7310316a/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 0d30f46422..26db1c381a 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ 
b/internal/impl/postgresql/input_postgrecdc.go @@ -10,7 +10,6 @@ package pgstream import ( "context" - "crypto/tls" "encoding/json" "strconv" "sync" @@ -26,19 +25,13 @@ import ( ) const ( - fieldHost = "host" - fieldPort = "port" - fieldUser = "user" - fieldPass = "password" - fieldSchema = "schema" - fieldDatabase = "database" - fieldTLS = "tls" + fieldDSN = "dsn" fieldStreamUncommitted = "stream_uncomitted" - fieldPgConnOptions = "pg_conn_options" fieldStreamSnapshot = "stream_snapshot" fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" fieldSnapshotBatchSize = "snapshot_batch_size" fieldDecodingPlugin = "decoding_plugin" + fieldSchema = "schema" fieldTables = "tables" fieldCheckpointLimit = "checkpoint_limit" fieldTemporarySlot = "temporary_slot" @@ -66,33 +59,12 @@ This input adds the following metadata fields to each message: - table (Name of the table that the message originated from) - operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) `). - Field(service.NewStringField(fieldHost). - Description("The hostname or IP address of the PostgreSQL instance."). - Example("123.0.0.1")). - Field(service.NewIntField(fieldPort). - Description("The port number on which the PostgreSQL instance is listening."). - Example(5432). - Default(5432)). - Field(service.NewStringField(fieldUser). - Description("Username of a user with replication permissions. For AWS RDS, this typically requires superuser privileges."). - Example("postgres"), - ). - Field(service.NewStringField(fieldPass). - Description("Password for the specified PostgreSQL user.")). - Field(service.NewStringField(fieldSchema). - Description("The PostgreSQL schema from which to replicate data.")). - Field(service.NewStringField(fieldDatabase). - Description("The name of the PostgreSQL database to connect to.")). - Field(service.NewTLSToggledField(fieldTLS). - Description("Specifies whether to use TLS for the database connection. 
Set to 'require' to enforce TLS, or 'none' to disable it."). - Default(nil)). + Field(service.NewStringField(fieldDSN). + Description("The Data Source Name for the PostgreSQL database in the form of `postgres://[user[:password]@][netloc][:port][/dbname][?param1=value1&...]`. Please note that Postgres enforces SSL by default, you can override this with the parameter `sslmode=disable` if required."). + Example("postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable")). Field(service.NewBoolField(fieldStreamUncommitted). Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). Default(false)). - Field(service.NewStringField(fieldPgConnOptions). - Description("Additional PostgreSQL connection options as a string. Refer to PostgreSQL documentation for available options."). - Default(""), - ). Field(service.NewBoolField(fieldStreamSnapshot). Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes."). Example(true). @@ -111,6 +83,9 @@ This input adds the following metadata fields to each message: `). Example("pgoutput"). Default("pgoutput")). + Field(service.NewStringField(fieldSchema). + Description("The PostgreSQL schema from which to replicate data."). + Example("public")). Field(service.NewStringListField(fieldTables). Description("A list of table names to include in the logical replication. Each table should be specified as a separate item."). 
Example(` @@ -132,26 +107,21 @@ This input adds the following metadata fields to each message: func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s service.BatchInput, err error) { var ( - dbName string - dbPort int - dbHost string - dbSchema string - dbUser string - dbPassword string + dsn string dbSlotName string temporarySlot bool + schema string tables []string streamSnapshot bool snapshotMemSafetyFactor float64 decodingPlugin string - pgConnOptions string - streamUncomited bool + streamUncommitted bool snapshotBatchSize int checkpointLimit int batching service.BatchPolicy ) - dbSchema, err = conf.FieldString(fieldSchema) + dsn, err = conf.FieldString(fieldDSN) if err != nil { return nil, err } @@ -170,32 +140,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - dbPassword, err = conf.FieldString(fieldPass) - if err != nil { - return nil, err - } - - dbUser, err = conf.FieldString(fieldUser) - if err != nil { - return nil, err - } - - tlsConf, tlsEnabled, err := conf.FieldTLSToggled(fieldTLS) - if err != nil { - return nil, err - } - - dbName, err = conf.FieldString(fieldDatabase) - if err != nil { - return nil, err - } - - dbHost, err = conf.FieldString(fieldHost) - if err != nil { - return nil, err - } - - dbPort, err = conf.FieldInt(fieldPort) + schema, err = conf.FieldString(fieldSchema) if err != nil { return nil, err } @@ -214,7 +159,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - streamUncomited, err = conf.FieldBool(fieldStreamUncommitted) + streamUncommitted, err = conf.FieldBool(fieldStreamUncommitted) if err != nil { return nil, err } @@ -234,33 +179,16 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - if pgConnOptions, err = conf.FieldString(fieldPgConnOptions); err != nil { - return nil, err - } - - if pgConnOptions != "" { - pgConnOptions = "options=" + 
pgConnOptions - } - if batching, err = conf.FieldBatchPolicy(fieldBatching); err != nil { return nil, err } else if batching.IsNoop() { batching.Count = 1 } - pgconnConfig := pgconn.Config{ - Host: dbHost, - Port: uint16(dbPort), - Database: dbName, - User: dbUser, - TLSConfig: tlsConf, - Password: dbPassword, - } - - if !tlsEnabled { - pgconnConfig.TLSConfig = nil - tlsConf = nil - } + pgconnConfig, err := pgconn.ParseConfigWithOptions(dsn, pgconn.ParseConfigOptions{ + // Don't support dynamic reading of password + GetSSLPassword: func(context.Context) string { return "" }, + }) snapsotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") replicationLag := mgr.Metrics().NewGauge("replication_lag_bytes") @@ -270,12 +198,10 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser streamSnapshot: streamSnapshot, snapshotMemSafetyFactor: snapshotMemSafetyFactor, slotName: dbSlotName, - schema: dbSchema, - pgConnRuntimeParam: pgConnOptions, - tls: tlsConf, + schema: schema, tables: tables, decodingPlugin: decodingPlugin, - streamUncomited: streamUncomited, + streamUncommitted: streamUncommitted, temporarySlot: temporarySlot, snapshotBatchSize: snapshotBatchSize, batching: batching, @@ -312,10 +238,8 @@ func init() { } type pgStreamInput struct { - dbConfig pgconn.Config - tls *tls.Config + dbConfig *pgconn.Config pglogicalStream *pglogicalstream.Stream - pgConnRuntimeParam string slotName string temporarySlot bool schema string @@ -324,7 +248,7 @@ type pgStreamInput struct { streamSnapshot bool snapshotMemSafetyFactor float64 snapshotBatchSize int - streamUncomited bool + streamUncommitted bool logger *service.Logger mgr *service.Resources metrics *service.Metrics @@ -344,20 +268,14 @@ type pgStreamInput struct { func (p *pgStreamInput) Connect(ctx context.Context) error { pgStream, err := pglogicalstream.NewPgStream(ctx, &pglogicalstream.Config{ - PgConnRuntimeParam: p.pgConnRuntimeParam, - DBHost: p.dbConfig.Host, - DBPassword: 
p.dbConfig.Password, - DBUser: p.dbConfig.User, - DBPort: int(p.dbConfig.Port), - DBTables: p.tables, - DBName: p.dbConfig.Database, + DBConfig: p.dbConfig, DBSchema: p.schema, + DBTables: p.tables, ReplicationSlotName: "rs_" + p.slotName, - TLSConfig: p.tls, BatchSize: p.snapshotBatchSize, StreamOldData: p.streamSnapshot, TemporaryReplicationSlot: p.temporarySlot, - StreamUncomited: p.streamUncomited, + StreamUncommitted: p.streamUncommitted, DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 3590a6ddf0..f17322db3f 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -9,47 +9,32 @@ package pglogicalstream import ( - "crypto/tls" - + "github.com/jackc/pgx/v5/pgconn" "github.com/redpanda-data/benthos/v4/public/service" ) // Config is the configuration for the pglogicalstream plugin type Config struct { - // DbHost is the host of the PostgreSQL instance - DBHost string `yaml:"db_host"` - // DbPassword is the password for the PostgreSQL instance - DBPassword string `yaml:"db_password"` - // DbUser is the user for the PostgreSQL instance - DBUser string `yaml:"db_user"` - // DbPort is the port of the PostgreSQL instance - DBPort int `yaml:"db_port"` - // DbName is the name of the database to connect to - DBName string `yaml:"db_name"` - // DbSchema is the schema to stream changes from - DBSchema string `yaml:"db_schema"` + // DBConfig is the configuration to connect to the database with + DBConfig *pgconn.Config + // The DB schema to lookup tables in + DBSchema string // DbTables is the tables to stream changes from - DBTables []string `yaml:"db_tables"` - // TLSConfig is the TLS verification configuration - TLSConfig *tls.Config `yaml:"tls"` - // PgConnRuntimeParam is the runtime parameter for the PostgreSQL connection - PgConnRuntimeParam 
string `yaml:"pg_conn_options"` - + DBTables []string // ReplicationSlotName is the name of the replication slot to use - ReplicationSlotName string `yaml:"replication_slot_name"` + ReplicationSlotName string // TemporaryReplicationSlot is whether to use a temporary replication slot - TemporaryReplicationSlot bool `yaml:"temporary_replication_slot"` + TemporaryReplicationSlot bool // StreamOldData is whether to stream all existing data - StreamOldData bool `yaml:"stream_old_data"` + StreamOldData bool // SnapshotMemorySafetyFactor is the memory safety factor for streaming snapshot - SnapshotMemorySafetyFactor float64 `yaml:"snapshot_memory_safety_factor"` + SnapshotMemorySafetyFactor float64 // DecodingPlugin is the decoding plugin to use - DecodingPlugin string `yaml:"decoding_plugin"` + DecodingPlugin string // BatchSize is the batch size for streaming - BatchSize int `yaml:"batch_size"` - + BatchSize int // StreamUncommitted is whether to stream uncommitted messages before receiving commit message - StreamUncomited bool `yaml:"stream_uncommitted"` + StreamUncommitted bool logger *service.Logger } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index d73cf7520a..1506bea785 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -14,6 +14,7 @@ import ( "errors" "fmt" "os" + "slices" "strings" "sync" "time" @@ -21,7 +22,6 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" - "github.com/lucasepe/codename" "github.com/redpanda-data/benthos/v4/public/service" ) @@ -68,60 +68,29 @@ type Stream struct { // NewPgStream creates a new instance of the Stream struct func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { - var ( - cfg *pgconn.Config - err error - ) - - connectionParams := "" - if config.PgConnRuntimeParam != "" { - 
connectionParams = "&" + config.PgConnRuntimeParam - } - - // intentiolly omit password to support password with special characters - // a new lines below cfg.Password = config.DBPassword - // we also need to use pgconn.ParseConfig since we are going to use this pased config to connect to the database for snapshot - // pgx panics when connection is not created via PaseConfig method - q := fmt.Sprintf("postgres://%s@%s:%d/%s?replication=database%s", - config.DBUser, - config.DBHost, - config.DBPort, - config.DBName, - connectionParams, - ) - - if cfg, err = pgconn.ParseConfig(q); err != nil { - return nil, err + if config.ReplicationSlotName == "" { + return nil, errors.New("missing replication slot name") } - cfg.Password = config.DBPassword - cfg.TLSConfig = config.TLSConfig - - dbConn, err := pgconn.ConnectConfig(context.Background(), cfg) + dbConn, err := pgconn.ConnectConfig(ctx, config.DBConfig) if err != nil { return nil, err } - if err = dbConn.Ping(context.Background()); err != nil { + if err = dbConn.Ping(ctx); err != nil { return nil, err } - var tableNames []string - tableNames = append(tableNames, config.DBTables...) 
- - if config.ReplicationSlotName == "" { - return nil, errors.New("missing replication slot name") - } - + tableNames := slices.Clone(config.DBTables) stream := &Stream{ pgConn: dbConn, - dbConfig: cfg, + dbConfig: config.DBConfig, messages: make(chan StreamMessage), slotName: config.ReplicationSlotName, - schema: config.DBSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, - streamUncomited: config.StreamUncomited, + streamUncomited: config.StreamUncommitted, snapshotBatchSize: config.BatchSize, + schema: config.DBSchema, tableNames: tableNames, consumedCallback: make(chan bool), transactionAckChan: make(chan string), @@ -137,7 +106,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } var version int - version, err = getPostgresVersion(cfg) + version, err = getPostgresVersion(config.DBConfig) if err != nil { return nil, err } @@ -148,8 +117,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { if err = stream.cleanUpOnFailure(); err != nil { stream.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } - - os.Exit(1) + return nil, err } stream.snapshotter = snapshotter @@ -163,20 +131,22 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { if version > 14 { pluginArguments = append(pluginArguments, "messages 'true'") } - } - - if stream.decodingPlugin == "wal2json" { + } else if stream.decodingPlugin == "wal2json" { tablesFilterRule := strings.Join(tableNames, ", ") pluginArguments = []string{ "\"pretty-print\" 'true'", "\"add-tables\"" + " " + fmt.Sprintf("'%s'", tablesFilterRule), } + } else { + return nil, fmt.Errorf("unknown decoding plugin: %q", stream.decodingPlugin) } stream.decodingPluginArguments = pluginArguments pubName := "pglog_stream_" + config.ReplicationSlotName stream.logger.Debugf("Creating publication %s for tables: %s", pubName, tableNames) + // QUESTION: Do we always want to drop existing publications? 
Does that stop old connect streams that + // are using the same replication slot name? if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } @@ -190,27 +160,33 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { var confirmedLSNFromDB string var outputPlugin string // check is replication slot exist to get last restart SLN - connExecResult := stream.pgConn.Exec(ctx, fmt.Sprintf("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = '%s'", config.ReplicationSlotName)) - if slotCheckResults, err := connExecResult.ReadAll(); err != nil { - return nil, err - } else { - if len(slotCheckResults) == 0 || len(slotCheckResults[0].Rows) == 0 { - // here we create a new replication slot because there is no slot found - var createSlotResult CreateReplicationSlotResult - createSlotResult, err = CreateReplicationSlot(ctx, stream.pgConn, stream.slotName, stream.decodingPlugin.String(), - CreateReplicationSlotOptions{Temporary: config.TemporaryReplicationSlot, - SnapshotAction: "export", - }, version, stream.snapshotter) - if err != nil { - return nil, err - } - stream.snapshotName = createSlotResult.SnapshotName - freshlyCreatedSlot = true - } else { - slotCheckRow := slotCheckResults[0].Rows[0] - confirmedLSNFromDB = string(slotCheckRow[0]) - outputPlugin = string(slotCheckRow[1]) + connExecResult := stream.pgConn.ExecParams( + ctx, + "SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = $1", + [][]byte{[]byte(config.ReplicationSlotName)}, + nil, + nil, + nil, + ).Read() + if connExecResult.Err != nil { + return nil, connExecResult.Err + } + if len(connExecResult.Rows) == 0 { + // here we create a new replication slot because there is no slot found + var createSlotResult CreateReplicationSlotResult + createSlotResult, err = CreateReplicationSlot(ctx, stream.pgConn, stream.slotName, stream.decodingPlugin.String(), + CreateReplicationSlotOptions{Temporary: 
config.TemporaryReplicationSlot, + SnapshotAction: "export", + }, version, stream.snapshotter) + if err != nil { + return nil, err } + stream.snapshotName = createSlotResult.SnapshotName + freshlyCreatedSlot = true + } else { + slotCheckRow := connExecResult.Rows[0] + confirmedLSNFromDB = string(slotCheckRow[0]) + outputPlugin = string(slotCheckRow[1]) } if !freshlyCreatedSlot && outputPlugin != stream.decodingPlugin.String() { @@ -236,7 +212,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) - monitor, err := NewMonitor(cfg, stream.logger, tableNames, stream.slotName) + monitor, err := NewMonitor(config.DBConfig, stream.logger, tableNames, stream.slotName) if err != nil { return nil, err } From 79d9f1faaac6c0c733b35206b159bd7cb757fb11 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Mon, 4 Nov 2024 04:48:43 +0000 Subject: [PATCH 043/118] more review feedback. This got to be a lot so just checkpointing so Vlad can see where I am going. 
--- internal/impl/postgresql/input_postgrecdc.go | 29 ++-- internal/impl/postgresql/integration_test.go | 75 ++++------- .../impl/postgresql/pglogicalstream/config.go | 2 +- .../pglogicalstream/logical_stream.go | 126 ++++++++---------- 4 files changed, 94 insertions(+), 138 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 26db1c381a..c40bd42321 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -11,6 +11,7 @@ package pgstream import ( "context" "encoding/json" + "fmt" "strconv" "sync" "sync/atomic" @@ -189,6 +190,12 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser // Don't support dynamic reading of password GetSSLPassword: func(context.Context) string { return "" }, }) + if err != nil { + return nil, err + } + // This is required for postgres to understand we're interested in replication. + // https://github.com/jackc/pglogrepl/issues/6 + pgconnConfig.RuntimeParams["replication"] = "database" snapsotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") replicationLag := mgr.Metrics().NewGauge("replication_lag_bytes") @@ -268,20 +275,20 @@ type pgStreamInput struct { func (p *pgStreamInput) Connect(ctx context.Context) error { pgStream, err := pglogicalstream.NewPgStream(ctx, &pglogicalstream.Config{ - DBConfig: p.dbConfig, - DBSchema: p.schema, - DBTables: p.tables, - ReplicationSlotName: "rs_" + p.slotName, - BatchSize: p.snapshotBatchSize, - StreamOldData: p.streamSnapshot, - TemporaryReplicationSlot: p.temporarySlot, - StreamUncommitted: p.streamUncommitted, - DecodingPlugin: p.decodingPlugin, - + DBConfig: p.dbConfig, + DBSchema: p.schema, + DBTables: p.tables, + ReplicationSlotName: "rs_" + p.slotName, + BatchSize: p.snapshotBatchSize, + StreamOldData: p.streamSnapshot, + TemporaryReplicationSlot: p.temporarySlot, + StreamUncommitted: p.streamUncommitted, + DecodingPlugin: 
p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, + Logger: p.logger, }) if err != nil { - return err + return fmt.Errorf("unable to create replication stream: %w", err) } p.pglogicalStream = pgStream diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index c7bfd2c9fe..5eda5825c8 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -183,18 +183,14 @@ func TestIntegrationPgCDC(t *testing.T) { template := fmt.Sprintf(` pg_stream: - host: %s + dsn: %s slot_name: test_slot - user: user_name - password: %s - port: %s - schema: public decoding_plugin: wal2json stream_snapshot: true - database: dbname + schema: public tables: - flights -`, hostAndPortSplited[0], password, hostAndPortSplited[1]) +`, databaseURL) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -316,21 +312,17 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { require.NoError(t, err) } + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) template := fmt.Sprintf(` pg_stream: - host: %s + dsn: %s slot_name: test_slot_native_decoder - user: user_name - password: %s - port: %s - schema: public - tls: none stream_snapshot: true decoding_plugin: pgoutput - database: dbname + schema: public tables: - flights -`, hostAndPortSplited[0], password, hostAndPortSplited[1]) +`, databaseURL) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -433,34 +425,22 @@ func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { // tables: users, products, orders, order_items - host := "localhost" - user := "postgres" - password := "postgres" - dbname := "postgres" - port := "5432" - sslmode := "none" - template := fmt.Sprintf(` pg_stream: - host: %s + dsn: postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable slot_name: test_slot_native_decoder - user: %s - 
password: %s - port: %s - schema: public - tls: %s snapshot_batch_size: 100000 stream_snapshot: true decoding_plugin: pgoutput stream_uncomitted: false temporary_slot: true - database: %s + schema: public tables: - users - products - orders - order_items -`, host, user, password, port, sslmode, dbname) +`) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -527,7 +507,7 @@ file: require.NoError(t, streamOut.StopWithin(time.Second*10)) } -func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { +func TestIntegrationPgCDCForPgOutputStreamUncommittedPlugin(t *testing.T) { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -551,22 +531,19 @@ func TestIntegrationPgCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { require.NoError(t, err) } + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) template := fmt.Sprintf(` pg_stream: - host: %s + dsn: %s slot_name: test_slot_native_decoder - user: user_name - password: %s - port: %s - schema: public snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput stream_uncomitted: true - database: dbname + schema: public tables: - flights -`, hostAndPortSplited[0], password, hostAndPortSplited[1]) +`, databaseURL) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -691,21 +668,18 @@ func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testin require.NoError(t, err) } + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) template := fmt.Sprintf(` pg_stream: - host: %s + dsn: %s slot_name: test_slot_native_decoder - user: user_name - password: %s - port: %s - schema: public stream_snapshot: true decoding_plugin: pgoutput stream_uncomitted: true - database: dbname + schema: public tables: - flights -`, hostAndPortSplited[0], 
password, hostAndPortSplited[1]) +`, databaseURL) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -828,21 +802,18 @@ func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing require.NoError(t, err) } + databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) template := fmt.Sprintf(` pg_stream: - host: %s + dsn: %s slot_name: test_slot_native_decoder - user: user_name - password: %s - port: %s - schema: public stream_snapshot: true decoding_plugin: pgoutput stream_uncomitted: false - database: dbname + schema: public tables: - flights -`, hostAndPortSplited[0], password, hostAndPortSplited[1]) +`, databaseURL) cacheConf := fmt.Sprintf(` label: pg_stream_cache diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index f17322db3f..63e026ab74 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -36,5 +36,5 @@ type Config struct { // StreamUncommitted is whether to stream uncommitted messages before receiving commit message StreamUncommitted bool - logger *service.Logger + Logger *service.Logger } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 1506bea785..df4d07fbac 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -13,7 +13,6 @@ import ( "database/sql" "errors" "fmt" - "os" "slices" "strings" "sync" @@ -23,15 +22,13 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/benthos/v4/public/service" + "golang.org/x/sync/errgroup" ) // Stream is a structure that represents a logical replication stream // It includes the connection to the database, the context for the stream, and 
snapshotting functionality type Stream struct { - pgConn *pgconn.PgConn - // extra copy of db config is required to establish a new db connection - // which is required to take snapshot data - dbConfig *pgconn.Config + pgConn *pgconn.PgConn streamCtx context.Context streamCancel context.CancelFunc @@ -72,7 +69,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { return nil, errors.New("missing replication slot name") } - dbConn, err := pgconn.ConnectConfig(ctx, config.DBConfig) + dbConn, err := pgconn.ConnectConfig(ctx, config.DBConfig.Copy()) if err != nil { return nil, err } @@ -84,7 +81,6 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tableNames := slices.Clone(config.DBTables) stream := &Stream{ pgConn: dbConn, - dbConfig: config.DBConfig, messages: make(chan StreamMessage), slotName: config.ReplicationSlotName, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, @@ -96,7 +92,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { transactionAckChan: make(chan string), transactionBeginChan: make(chan bool), lsnAckBuffer: []string{}, - logger: config.logger, + logger: config.Logger, m: sync.Mutex{}, decodingPlugin: decodingPluginFromString(config.DecodingPlugin), } @@ -106,15 +102,15 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } var version int - version, err = getPostgresVersion(config.DBConfig) + version, err = getPostgresVersion(config.DBConfig.Copy()) if err != nil { return nil, err } - snapshotter, err := NewSnapshotter(stream.dbConfig, stream.logger, version) + snapshotter, err := NewSnapshotter(config.DBConfig.Copy(), stream.logger, version) if err != nil { stream.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) - if err = stream.cleanUpOnFailure(); err != nil { + if err = stream.cleanUpOnFailure(ctx); err != nil { stream.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } return 
nil, err @@ -144,13 +140,10 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.decodingPluginArguments = pluginArguments pubName := "pglog_stream_" + config.ReplicationSlotName - stream.logger.Debugf("Creating publication %s for tables: %s", pubName, tableNames) - // QUESTION: Do we always want to drop existing publications? Does that stop old connect streams that - // are using the same replication slot name? + stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames) if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames, true); err != nil { return nil, err } - sysident, err := IdentifySystem(ctx, stream.pgConn) if err != nil { return nil, err @@ -160,18 +153,18 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { var confirmedLSNFromDB string var outputPlugin string // check is replication slot exist to get last restart SLN - connExecResult := stream.pgConn.ExecParams( + + // TODO: There should be a helper method for this that also validates the parameters to fmt here are not possible to cause SQL injection. 
+ // this means we either escape or we validate it's only alphanumeric and `-_` + connExecResult, err := stream.pgConn.Exec( ctx, - "SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = $1", - [][]byte{[]byte(config.ReplicationSlotName)}, - nil, - nil, - nil, - ).Read() - if connExecResult.Err != nil { - return nil, connExecResult.Err + fmt.Sprintf("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = '%s'", + config.ReplicationSlotName), + ).ReadAll() + if err != nil { + return nil, err } - if len(connExecResult.Rows) == 0 { + if len(connExecResult) == 0 || len(connExecResult[0].Rows) == 0 { // here we create a new replication slot because there is no slot found var createSlotResult CreateReplicationSlotResult createSlotResult, err = CreateReplicationSlot(ctx, stream.pgConn, stream.slotName, stream.decodingPlugin.String(), @@ -184,7 +177,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.snapshotName = createSlotResult.SnapshotName freshlyCreatedSlot = true } else { - slotCheckRow := connExecResult.Rows[0] + slotCheckRow := connExecResult[0].Rows[0] confirmedLSNFromDB = string(slotCheckRow[0]) outputPlugin = string(slotCheckRow[1]) } @@ -212,7 +205,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) - monitor, err := NewMonitor(config.DBConfig, stream.logger, tableNames, stream.slotName) + monitor, err := NewMonitor(config.DBConfig.Copy(), stream.logger, tableNames, stream.slotName) if err != nil { return nil, err } @@ -228,7 +221,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } else { // New messages will be streamed after the snapshot has been processed. 
// stream.startLr() and stream.streamMessagesAsync() will be called inside stream.processSnapshot() - go stream.processSnapshot() + go stream.processSnapshot(context.Background()) } return stream, err @@ -531,14 +524,13 @@ func (s *Stream) AckTxChan() chan string { return s.transactionAckChan } -func (s *Stream) processSnapshot() { +func (s *Stream) processSnapshot(ctx context.Context) error { if err := s.snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) - if err = s.cleanUpOnFailure(); err != nil { + if err = s.cleanUpOnFailure(ctx); err != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) } - - os.Exit(1) + return err } defer func() { if err := s.snapshotter.releaseSnapshot(); err != nil { @@ -551,27 +543,31 @@ func (s *Stream) processSnapshot() { s.logger.Infof("Starting snapshot processing") - var wg sync.WaitGroup + var wg errgroup.Group for _, table := range s.tableNames { - wg.Add(1) - go func(tableName string) { + tableName := table + wg.Go(func() (err error) { s.logger.Infof("Processing snapshot for table: %v", table) + defer func() { + if err != nil { + if cleanupErr := s.cleanUpOnFailure(ctx); cleanupErr != nil { + s.logger.Errorf("Failed to clean up resources on accident: %v", cleanupErr.Error()) + } + } + }() + var ( avgRowSizeBytes sql.NullInt64 offset = 0 - err error ) avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) if err != nil { s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - os.Exit(1) + return err } availableMemory := getAvailableMemory() @@ -585,11 +581,7 @@ func (s *Stream) processSnapshot() { tablePk, err := s.getPrimaryKeyColumn(table) if err != nil { s.logger.Errorf("Failed to get primary key column for table %v: %v", table, 
err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) + return err } for { @@ -598,11 +590,7 @@ func (s *Stream) processSnapshot() { if snapshotRows, err = s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { s.logger.Errorf("Failed to query snapshot data for table %v: %v", table, err.Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) + return err } queryDuration := time.Since(queryStart) @@ -611,29 +599,19 @@ func (s *Stream) processSnapshot() { if snapshotRows.Err() != nil { s.logger.Errorf("Failed to get snapshot data for table %v: %v", table, snapshotRows.Err().Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - - os.Exit(1) + return err } columnTypes, err := snapshotRows.ColumnTypes() if err != nil { s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - os.Exit(1) + return err } columnNames, err := snapshotRows.Columns() if err != nil { s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - os.Exit(1) + return err } var rowsCount = 0 @@ -653,10 +631,7 @@ func (s *Stream) processSnapshot() { if err != nil { s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) - if err = s.cleanUpOnFailure(); err != nil { - s.logger.Errorf("Failed to clean up resources 
on accident: %v", err.Error()) - } - os.Exit(1) + return err } var data = make(map[string]any) @@ -701,17 +676,20 @@ func (s *Stream) processSnapshot() { break } } - wg.Done() - }(table) + return nil + }) } - wg.Wait() + if err := wg.Wait(); err != nil { + return err + } if err := s.startLr(); err != nil { s.logger.Errorf("Failed to start logical replication after snapshot: %v", err.Error()) - os.Exit(1) + return err } go s.streamMessagesAsync() + return nil } // Messages is a channel that can be used to consume messages from the plugin. It will contain LSN nil for snapshot messages @@ -720,13 +698,13 @@ func (s *Stream) Messages() chan StreamMessage { } // cleanUpOnFailure drops replication slot and publication if database snapshotting was failed for any reason -func (s *Stream) cleanUpOnFailure() error { +func (s *Stream) cleanUpOnFailure(ctx context.Context) error { s.logger.Warnf("Cleaning up resources on accident: %v", s.slotName) - err := DropReplicationSlot(context.Background(), s.pgConn, s.slotName, DropReplicationSlotOptions{Wait: true}) + err := DropReplicationSlot(ctx, s.pgConn, s.slotName, DropReplicationSlotOptions{Wait: true}) if err != nil { s.logger.Errorf("Failed to drop replication slot: %s", err.Error()) } - return s.pgConn.Close(context.TODO()) + return s.pgConn.Close(ctx) } func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { From 4e05248557f30c5df6b2b0ffd7cfc37a91612d41 Mon Sep 17 00:00:00 2001 From: Mihai Todor Date: Tue, 5 Nov 2024 02:20:02 +0000 Subject: [PATCH 044/118] Chan cleanup WIP Signed-off-by: Mihai Todor --- internal/impl/postgresql/input_postgrecdc.go | 95 +++++++++----------- 1 file changed, 44 insertions(+), 51 deletions(-) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_postgrecdc.go index 8bd559e47b..6ca60282e7 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_postgrecdc.go @@ -377,11 +377,10 @@ func (p *pgStreamInput) 
Connect(ctx context.Context) error { }() var nextTimedBatchChan <-chan time.Time - flushBatch := p.asyncCheckpointer() // offsets are nilable since we don't provide offset tracking during the snapshot phase var latestOffset *int64 - + cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) for { select { case <-nextTimedBatchChan: @@ -392,7 +391,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, nil) { + if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, false) { break } @@ -411,12 +410,10 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - callbackChan := make(chan bool) - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, &callbackChan) { + if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, true) { break } - <-callbackChan if err = p.pglogicalStream.AckLSN(trxCommitLsn); err != nil { p.mgr.Logger().Errorf("Failed to ack LSN: %v", err) break @@ -465,13 +462,11 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } if message.IsStreaming { - callbackChan := make(chan bool) - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, &callbackChan) { + if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, true) { break } - <-callbackChan } else { - if !flushBatch(ctx, p.msgChan, flushedBatch, latestOffset, nil) { + if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, false) { break } } @@ -488,53 +483,51 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { return err } -func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMessage, service.MessageBatch, *int64, *chan bool) bool { - cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) - return func(ctx context.Context, c chan<- asyncMessage, msg service.MessageBatch, lsn *int64, txCommitConfirmChan *chan bool) bool { - if msg == nil { - if txCommitConfirmChan != nil { - go func() { - *txCommitConfirmChan <- true - }() - } - return 
true - } +func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint.Capped[*int64], msg service.MessageBatch, lsn *int64, waitForCommit bool) bool { + if msg == nil { + return true + } - resolveFn, err := cp.Track(ctx, lsn, int64(len(msg))) - if err != nil { - if ctx.Err() == nil { - p.mgr.Logger().Errorf("Failed to checkpoint offset: %v\n", err) - } - return false + resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(msg))) + if err != nil { + if ctx.Err() == nil { + p.mgr.Logger().Errorf("Failed to checkpoint offset: %v\n", err) } + return false + } - select { - case c <- asyncMessage{ - msg: msg, - ackFn: func(ctx context.Context, res error) error { - maxOffset := resolveFn() - if maxOffset == nil { - return nil + commitChan := make(chan bool) + if !waitForCommit { + close(commitChan) + } + select { + case p.msgChan <- asyncMessage{ + msg: msg, + ackFn: func(ctx context.Context, res error) error { + maxOffset := resolveFn() + if maxOffset == nil { + return nil + } + p.cMut.Lock() + defer p.cMut.Unlock() + if lsn != nil { + if err = p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { + return err } - p.cMut.Lock() - defer p.cMut.Unlock() - if lsn != nil { - if err = p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { - return err - } - if txCommitConfirmChan != nil { - *txCommitConfirmChan <- true - } + if waitForCommit { + close(commitChan) } - return nil - }, - }: - case <-ctx.Done(): - return false - } - - return true + } + return nil + }, + }: + case <-ctx.Done(): + return false } + + <-commitChan + + return true } func (p *pgStreamInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { From 4a4b7ba79f72ceee47f8e03d6c8332b91862cfda Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 11:06:12 +0100 Subject: [PATCH 045/118] chore(): addressed pull requests changes --- ...input_postgrecdc.go => input_pg_stream.go} | 126 ++++++++---------- 
internal/impl/postgresql/integration_test.go | 18 +-- .../impl/postgresql/pglogicalstream/config.go | 2 +- .../{util.go => connection.go} | 0 .../pglogicalstream/logical_stream.go | 6 +- internal/impl/postgresql/utils.go | 43 ++---- 6 files changed, 71 insertions(+), 124 deletions(-) rename internal/impl/postgresql/{input_postgrecdc.go => input_pg_stream.go} (84%) rename internal/impl/postgresql/pglogicalstream/{util.go => connection.go} (100%) diff --git a/internal/impl/postgresql/input_postgrecdc.go b/internal/impl/postgresql/input_pg_stream.go similarity index 84% rename from internal/impl/postgresql/input_postgrecdc.go rename to internal/impl/postgresql/input_pg_stream.go index 8bd559e47b..d028b7fa9a 100644 --- a/internal/impl/postgresql/input_postgrecdc.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -10,7 +10,6 @@ package pgstream import ( "context" - "crypto/tls" "encoding/json" "strconv" "sync" @@ -32,7 +31,7 @@ const ( fieldSchema = "schema" fieldDatabase = "database" fieldTls = "tls" - fieldStreamUncomitted = "stream_uncomitted" + fieldStreamUncommitted = "stream_uncommitted" fieldPgConnOptions = "pg_conn_options" fieldStreamSnapshot = "stream_snapshot" fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" @@ -45,8 +44,6 @@ const ( fieldBatching = "batching" ) -var randomSlotName string - type asyncMessage struct { msg service.MessageBatch ackFn service.AckFunc @@ -55,18 +52,18 @@ type asyncMessage struct { var pgStreamConfigSpec = service.NewConfigSpec(). Beta(). Categories("Services"). - Version("0.0.1"). + Version("v4.39.0"). 
Summary(`Creates a PostgreSQL replication slot for Change Data Capture (CDC) - == Metadata +== Metadata This input adds the following metadata fields to each message: -- streaming (Indicates whether the message is part of a streaming operation or snapshot processing) +- is_streaming (Indicates whether the message is part of a streaming operation or snapshot processing) - table (Name of the table that the message originated from) - operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) `). Field(service.NewStringField(fieldHost). Description("The hostname or IP address of the PostgreSQL instance."). - Example("123.0.0.1")). + Example("127.0.0.1")). Field(service.NewIntField(fieldPort). Description("The port number on which the PostgreSQL instance is listening."). Example(5432). @@ -84,7 +81,7 @@ This input adds the following metadata fields to each message: Field(service.NewTLSToggledField(fieldTls). Description("Specifies whether to use TLS for the database connection. Set to 'require' to enforce TLS, or 'none' to disable it."). Default(nil)). - Field(service.NewBoolField(fieldStreamUncomitted). + Field(service.NewBoolField(fieldStreamUncommitted). Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). Default(false)). Field(service.NewStringField(fieldPgConnOptions). @@ -105,7 +102,7 @@ This input adds the following metadata fields to each message: Default(0)). Field(service.NewStringEnumField(fieldDecodingPlugin, "pgoutput", "wal2json"). Description(`Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. - Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Benthos. 
+ Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. `). Example("pgoutput"). Default("pgoutput")). @@ -117,14 +114,14 @@ This input adds the following metadata fields to each message: `)). Field(service.NewIntField(fieldCheckpointLimit). Description("The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees."). - Version("3.33.0").Default(1024)). + Default(1024)). Field(service.NewBoolField(fieldTemporarySlot). Description("If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed."). Default(false)). Field(service.NewStringField(fieldSlotName). Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). - Default(randomSlotName)). + Default(nil)). Field(service.NewAutoRetryNacksToggleField()). 
Field(service.NewBatchPolicyField(fieldBatching)) @@ -143,7 +140,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser snapshotMemSafetyFactor float64 decodingPlugin string pgConnOptions string - streamUncomited bool + streamUncommitted bool snapshotBatchSize int checkpointLimit int batching service.BatchPolicy @@ -174,7 +171,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - tlsConf, tlsEnabled, err := conf.FieldTLSToggled(fieldTls) + tlsConf, _, err := conf.FieldTLSToggled(fieldTls) if err != nil { return nil, err } @@ -208,7 +205,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - streamUncomited, err = conf.FieldBool(fieldStreamUncomitted) + streamUncommitted, err = conf.FieldBool(fieldStreamUncommitted) if err != nil { return nil, err } @@ -242,7 +239,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser batching.Count = 1 } - pgconnConfig := pgconn.Config{ + pgConnConfig := pgconn.Config{ Host: dbHost, Port: uint16(dbPort), Database: dbName, @@ -251,25 +248,19 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser Password: dbPassword, } - if !tlsEnabled { - pgconnConfig.TLSConfig = nil - tlsConf = nil - } - - snapsotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") + snapshotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") replicationLag := mgr.Metrics().NewGauge("replication_lag_bytes") i := &pgStreamInput{ - dbConfig: pgconnConfig, + dbConfig: pgConnConfig, streamSnapshot: streamSnapshot, snapshotMemSafetyFactor: snapshotMemSafetyFactor, slotName: dbSlotName, schema: dbSchema, pgConnRuntimeParam: pgConnOptions, - tls: tlsConf, tables: tables, decodingPlugin: decodingPlugin, - streamUncomited: streamUncomited, + streamUncommitted: streamUncommitted, temporarySlot: temporarySlot, snapshotBatchSize: snapshotBatchSize, batching: 
batching, @@ -279,8 +270,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser mgr: mgr, logger: mgr.Logger(), - metrics: mgr.Metrics(), - snapshotMetrics: snapsotMetrics, + snapshotMetrics: snapshotMetrics, replicationLag: replicationLag, inTxState: atomic.Bool{}, releaseTrxChan: make(chan bool), @@ -307,8 +297,7 @@ func init() { type pgStreamInput struct { dbConfig pgconn.Config - tls *tls.Config - pglogicalStream *pglogicalstream.Stream + pgLogicalStream *pglogicalstream.Stream pgConnRuntimeParam string slotName string temporarySlot bool @@ -318,7 +307,7 @@ type pgStreamInput struct { streamSnapshot bool snapshotMemSafetyFactor float64 snapshotBatchSize int - streamUncomited bool + streamUncommitted bool logger *service.Logger mgr *service.Resources metrics *service.Metrics @@ -327,10 +316,8 @@ type pgStreamInput struct { batching service.BatchPolicy checkpointLimit int - snapshotRateCounter *RateCounter - snapshotMessageRate *service.MetricGauge - snapshotMetrics *service.MetricGauge - replicationLag *service.MetricGauge + snapshotMetrics *service.MetricGauge + replicationLag *service.MetricGauge releaseTrxChan chan bool inTxState atomic.Bool @@ -338,34 +325,32 @@ type pgStreamInput struct { func (p *pgStreamInput) Connect(ctx context.Context) error { pgStream, err := pglogicalstream.NewPgStream(ctx, &pglogicalstream.Config{ - PgConnRuntimeParam: p.pgConnRuntimeParam, - DBHost: p.dbConfig.Host, - DBPassword: p.dbConfig.Password, - DBUser: p.dbConfig.User, - DBPort: int(p.dbConfig.Port), - DBTables: p.tables, - DBName: p.dbConfig.Database, - DBSchema: p.schema, - ReplicationSlotName: "rs_" + p.slotName, - TLSConfig: p.tls, - BatchSize: p.snapshotBatchSize, - StreamOldData: p.streamSnapshot, - TemporaryReplicationSlot: p.temporarySlot, - StreamUncomited: p.streamUncomited, - DecodingPlugin: p.decodingPlugin, - + PgConnRuntimeParam: p.pgConnRuntimeParam, + DBHost: p.dbConfig.Host, + DBPassword: p.dbConfig.Password, + DBUser: 
p.dbConfig.User, + DBPort: int(p.dbConfig.Port), + DBTables: p.tables, + DBName: p.dbConfig.Database, + DBSchema: p.schema, + ReplicationSlotName: "rs_" + p.slotName, + BatchSize: p.snapshotBatchSize, + StreamOldData: p.streamSnapshot, + TemporaryReplicationSlot: p.temporarySlot, + StreamUncommitted: p.streamUncommitted, + DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, }) if err != nil { return err } - p.pglogicalStream = pgStream + p.pgLogicalStream = pgStream go func() { batchPolicy, err := p.batching.NewBatcher(p.mgr) if err != nil { - p.logger.Errorf("Failed to initialise batch policy: %v, falling back to no policy.\n", err) + p.logger.Errorf("Failed to initialise batch policy: %v, falling back to no policy.", err) conf := service.BatchPolicy{Count: 1} if batchPolicy, err = conf.NewBatcher(p.mgr); err != nil { panic(err) @@ -384,11 +369,13 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { for { select { + case <-ctx.Done(): + return case <-nextTimedBatchChan: nextTimedBatchChan = nil flushedBatch, err := batchPolicy.Flush(ctx) if err != nil { - p.mgr.Logger().Debugf("Timed flush batch error: %w", err) + p.logger.Debugf("Timed flush batch error: %w", err) break } @@ -397,17 +384,14 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } // TrxCommit LSN must be acked when all the messages in the batch are processed - case trxCommitLsn, open := <-p.pglogicalStream.AckTxChan(): + case trxCommitLsn, open := <-p.pgLogicalStream.AckTxChan(): if !open { break } - p.cMut.Lock() - p.cMut.Unlock() - flushedBatch, err := batchPolicy.Flush(ctx) if err != nil { - p.mgr.Logger().Debugf("Flush batch error: %w", err) + p.logger.Debugf("Flush batch error: %w", err) break } @@ -417,14 +401,14 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } <-callbackChan - if err = p.pglogicalStream.AckLSN(trxCommitLsn); err != nil { - p.mgr.Logger().Errorf("Failed to ack LSN: %v", err) + if err = 
p.pgLogicalStream.AckLSN(trxCommitLsn); err != nil { + p.logger.Errorf("Failed to ack LSN: %v", err) break } - p.pglogicalStream.ConsumedCallback() <- true + p.pgLogicalStream.ConsumedCallback() <- true - case message, open := <-p.pglogicalStream.Messages(): + case message, open := <-p.pgLogicalStream.Messages(): if !open { break } @@ -440,6 +424,12 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } latestOffset = &parsedLSN } + + if len(message.Changes) == 0 { + p.logger.Debugf("Received empty message on LSN: %v", message.Lsn) + continue + } + if mb, err = json.Marshal(message); err != nil { break } @@ -461,7 +451,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { nextTimedBatchChan = nil flushedBatch, err := batchPolicy.Flush(ctx) if err != nil { - p.mgr.Logger().Debugf("Flush batch error: %w", err) + p.logger.Debugf("Flush batch error: %w", err) break } if message.IsStreaming { @@ -478,7 +468,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } case <-ctx.Done(): - if err = p.pglogicalStream.Stop(); err != nil { + if err = p.pgLogicalStream.Stop(); err != nil { p.logger.Errorf("Failed to stop pglogical stream: %v", err) } } @@ -503,7 +493,7 @@ func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMe resolveFn, err := cp.Track(ctx, lsn, int64(len(msg))) if err != nil { if ctx.Err() == nil { - p.mgr.Logger().Errorf("Failed to checkpoint offset: %v\n", err) + p.logger.Errorf("Failed to checkpoint offset: %v", err) } return false } @@ -519,7 +509,7 @@ func (p *pgStreamInput) asyncCheckpointer() func(context.Context, chan<- asyncMe p.cMut.Lock() defer p.cMut.Unlock() if lsn != nil { - if err = p.pglogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { + if err = p.pgLogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { return err } if txCommitConfirmChan != nil { @@ -559,8 +549,8 @@ func (p *pgStreamInput) ReadBatch(ctx context.Context) (service.MessageBatch, se } func (p *pgStreamInput) Close(ctx 
context.Context) error { - if p.pglogicalStream != nil { - return p.pglogicalStream.Stop() + if p.pgLogicalStream != nil { + return p.pgLogicalStream.Stop() } return nil } diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index c7bfd2c9fe..f65b9633b2 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -324,7 +324,6 @@ pg_stream: password: %s port: %s schema: public - tls: none stream_snapshot: true decoding_plugin: pgoutput database: dbname @@ -438,7 +437,6 @@ func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { password := "postgres" dbname := "postgres" port := "5432" - sslmode := "none" template := fmt.Sprintf(` pg_stream: @@ -448,7 +446,6 @@ pg_stream: password: %s port: %s schema: public - tls: %s snapshot_batch_size: 100000 stream_snapshot: true decoding_plugin: pgoutput @@ -460,7 +457,7 @@ pg_stream: - products - orders - order_items -`, host, user, password, port, sslmode, dbname) +`, host, user, password, port, dbname) cacheConf := fmt.Sprintf(` label: pg_stream_cache @@ -476,25 +473,12 @@ file: var outMessages int64 var outMessagesMut sync.Mutex - rc := NewRateCounter() - - go func() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - - for range ticker.C { - fmt.Printf("Current rate: %.2f messages per second\n", rc.Rate()) - fmt.Printf("Total messages: %d\n", outMessages) - } - }() - require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { _, err := mb[0].AsBytes() require.NoError(t, err) outMessagesMut.Lock() outMessages += 1 outMessagesMut.Unlock() - rc.Increment() return nil })) diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 3590a6ddf0..74aaddef4f 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -49,7 +49,7 @@ 
type Config struct { BatchSize int `yaml:"batch_size"` // StreamUncommitted is whether to stream uncommitted messages before receiving commit message - StreamUncomited bool `yaml:"stream_uncommitted"` + StreamUncommitted bool `yaml:"stream_uncommitted"` logger *service.Logger } diff --git a/internal/impl/postgresql/pglogicalstream/util.go b/internal/impl/postgresql/pglogicalstream/connection.go similarity index 100% rename from internal/impl/postgresql/pglogicalstream/util.go rename to internal/impl/postgresql/pglogicalstream/connection.go diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index de9b9e4d85..6b5e58291c 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -53,7 +53,7 @@ type Stream struct { snapshotMemorySafetyFactor float64 logger *service.Logger monitor *Monitor - streamUncomited bool + streamUncommitted bool snapshotter *Snapshotter transactionAckChan chan string transactionBeginChan chan bool @@ -121,7 +121,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { slotName: config.ReplicationSlotName, schema: config.DBSchema, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, - streamUncomited: config.StreamUncomited, + streamUncommitted: config.StreamUncommitted, snapshotBatchSize: config.BatchSize, tableNames: tableNames, consumedCallback: make(chan bool), @@ -432,7 +432,7 @@ func (s *Stream) streamMessagesAsync() { } if s.decodingPlugin == "pgoutput" { - if s.streamUncomited { + if s.streamUncommitted { // parse changes inside the transaction message, err := decodePgOutput(xld.WALData, relations, typeMap) if err != nil { diff --git a/internal/impl/postgresql/utils.go b/internal/impl/postgresql/utils.go index 3c3e3bd642..e787849a0d 100644 --- a/internal/impl/postgresql/utils.go +++ b/internal/impl/postgresql/utils.go @@ -1,46 +1,19 @@ +// 
Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + package pgstream import ( "fmt" "strconv" "strings" - "sync" - "sync/atomic" - "time" ) -// RateCounter is used to measure the rate of invocations -type RateCounter struct { - count int64 - lastChecked time.Time - mutex sync.Mutex -} - -// NewRateCounter creates a new RateCounter -func NewRateCounter() *RateCounter { - return &RateCounter{ - lastChecked: time.Now(), - } -} - -// Increment increases the counter by 1 -func (rc *RateCounter) Increment() { - atomic.AddInt64(&rc.count, 1) -} - -// Rate calculates the current rate of invocations per second -func (rc *RateCounter) Rate() float64 { - rc.mutex.Lock() - defer rc.mutex.Unlock() - - now := time.Now() - duration := now.Sub(rc.lastChecked).Seconds() - count := atomic.SwapInt64(&rc.count, 0) - rc.lastChecked = now - - return float64(count) / duration -} - // LSNToInt64 converts a PostgreSQL LSN string to int64 func LSNToInt64(lsn string) (int64, error) { // Split the LSN into segments From ea64c1cde3e76ab49448240d7ff8edb800b5b639 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 11:17:47 +0100 Subject: [PATCH 046/118] chore(): updated tests --- internal/impl/postgresql/integration_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index f65b9633b2..0c70d44b70 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -449,7 +449,7 @@ pg_stream: snapshot_batch_size: 100000 stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: false + stream_uncommitted: false temporary_slot: true 
database: %s tables: @@ -546,7 +546,7 @@ pg_stream: snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: true + stream_uncommitted: true database: dbname tables: - flights @@ -685,7 +685,7 @@ pg_stream: schema: public stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: true + stream_uncommitted: true database: dbname tables: - flights @@ -822,7 +822,7 @@ pg_stream: schema: public stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: false + stream_uncommitted: false database: dbname tables: - flights From 22ff49a14706ac807c53a719271cc81e02679df3 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 11:52:05 +0100 Subject: [PATCH 047/118] chore(): removed unused vars --- internal/impl/postgresql/pglogicalstream/consts.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/consts.go b/internal/impl/postgresql/pglogicalstream/consts.go index 968d0e60ed..944fd82b39 100644 --- a/internal/impl/postgresql/pglogicalstream/consts.go +++ b/internal/impl/postgresql/pglogicalstream/consts.go @@ -32,12 +32,3 @@ func decodingPluginFromString(plugin string) DecodingPlugin { func (d DecodingPlugin) String() string { return string(d) } - -// TLSVerify is a type for the TLS verification mode -type TLSVerify string - -// TLSNoVerify is the value for no TLS verification -const TLSNoVerify TLSVerify = "none" - -// TLSRequireVerify is the value for TLS verification with a CA -const TLSRequireVerify TLSVerify = "require" From 2f159777c463296975f5c68e1dae4d4534523f4e Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 17:00:45 +0100 Subject: [PATCH 048/118] chore(): run make deps to fix ci pipeline --- go.mod | 1 - 1 file changed, 1 deletion(-) diff --git a/go.mod b/go.mod index a1e6a1e1ce..e41a494963 100644 --- a/go.mod +++ b/go.mod @@ -151,7 +151,6 @@ require ( cloud.google.com/go/longrunning v0.5.9 // indirect 
github.com/containerd/platforms v0.2.1 // indirect github.com/hamba/avro/v2 v2.22.2-0.20240625062549-66aad10411d9 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect From a492556c1fa4c1e7d30a5649bb1b85717ded9abf Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 17:03:43 +0100 Subject: [PATCH 049/118] fix(postgres_cdc): monitor tests --- .../impl/postgresql/pglogicalstream/monitor_test.go | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go index 3a2c800555..b9a626b714 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor_test.go +++ b/internal/impl/postgresql/pglogicalstream/monitor_test.go @@ -11,12 +11,10 @@ package pglogicalstream import ( "database/sql" "fmt" - "strconv" "strings" "testing" "time" - "github.com/jackc/pgx/v5/pgconn" "github.com/jaswdr/faker" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" @@ -77,16 +75,9 @@ func Test_MonitorReplorting(t *testing.T) { require.NoError(t, err) } - portUint64, err := strconv.ParseUint(hostAndPortSplited[1], 10, 10) require.NoError(t, err) slotName := "test_slot" - mon, err := NewMonitor(&pgconn.Config{ - Host: hostAndPortSplited[0], - Port: uint16(portUint64), - User: "user_name", - Password: "secret", - Database: "dbname", - }, &service.Logger{}, []string{"flights"}, slotName) + mon, err := NewMonitor(databaseURL, &service.Logger{}, []string{"flights"}, slotName) require.NoError(t, err) require.NotNil(t, mon) } From 0f71f7c96c153a5ff6d1b2ea6f4540a0da86c2c9 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 17:08:23 +0100 Subject: [PATCH 050/118] chore(postgres_cdc): added integration test skip check --- 
internal/impl/postgresql/integration_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 596cb1b657..30f04bf7ae 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "fmt" + "github.com/redpanda-data/benthos/v4/public/service/integration" "strings" "sync" "testing" @@ -101,6 +102,8 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version } func TestIntegrationPgCDC(t *testing.T) { + integration.CheckSkip(t) + tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -287,6 +290,7 @@ file: } func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { + integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -495,6 +499,7 @@ file: } func TestIntegrationPgCDCForPgOutputStreamUncommittedPlugin(t *testing.T) { + integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -629,6 +634,7 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { + integration.CheckSkip(t) // running tests in the look to test different PostgreSQL versions t.Parallel() for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { @@ -765,6 +771,7 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing.T) { + integration.CheckSkip(t) for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") From 078ffd9eec2df92b00302897e0935bf7accae4f7 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 5 Nov 2024 17:21:12 +0100 Subject: [PATCH 051/118] fix(postgres_cdc): lint warnings --- internal/impl/postgresql/input_pg_stream.go | 1 - 
internal/impl/postgresql/integration_test.go | 6 +++--- .../postgresql/pglogicalstream/logical_stream.go | 13 +++++++------ public/components/community/package.go | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index e19b10ebce..7a6f7ba426 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -261,7 +261,6 @@ type pgStreamInput struct { streamUncommitted bool logger *service.Logger mgr *service.Resources - metrics *service.Metrics cMut sync.Mutex msgChan chan asyncMessage batching service.BatchPolicy diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 30f04bf7ae..a5a203c146 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -12,7 +12,6 @@ import ( "context" "database/sql" "fmt" - "github.com/redpanda-data/benthos/v4/public/service/integration" "strings" "sync" "testing" @@ -23,6 +22,7 @@ import ( _ "github.com/redpanda-data/benthos/v4/public/components/io" _ "github.com/redpanda-data/benthos/v4/public/components/pure" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -429,7 +429,7 @@ func TestIntegrationPgStreamingFromRemoteDB(t *testing.T) { // tables: users, products, orders, order_items - template := fmt.Sprintf(` + template := ` pg_stream: dsn: postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable slot_name: test_slot_native_decoder @@ -444,7 +444,7 @@ pg_stream: - products - orders - order_items -`) +` cacheConf := fmt.Sprintf(` label: pg_stream_cache diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 4c0ee527d4..7201a52f52 100644 
--- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -28,10 +28,7 @@ import ( // Stream is a structure that represents a logical replication stream // It includes the connection to the database, the context for the stream, and snapshotting functionality type Stream struct { - pgConn *pgconn.PgConn - // pgDsn is used for creating golang PG Connection - // as using pgconn.Config for golang doesn't support multiple queries in the prepared statement for Postgres Version <= 14 - pgDsn string + pgConn *pgconn.PgConn streamCtx context.Context streamCancel context.CancelFunc @@ -120,7 +117,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } stream.snapshotter = snapshotter - var pluginArguments = []string{} + var pluginArguments []string if stream.decodingPlugin == "pgoutput" { pluginArguments = []string{ "proto_version '1'", @@ -224,7 +221,11 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } else { // New messages will be streamed after the snapshot has been processed. 
// stream.startLr() and stream.streamMessagesAsync() will be called inside stream.processSnapshot() - go stream.processSnapshot(context.Background()) + go func() { + if err := stream.processSnapshot(ctx); err != nil { + stream.logger.Errorf("Failed to process snapshot: %v", err.Error()) + } + }() } return stream, err diff --git a/public/components/community/package.go b/public/components/community/package.go index 55c2ba3c58..55ce2c7e55 100644 --- a/public/components/community/package.go +++ b/public/components/community/package.go @@ -54,6 +54,7 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/opensearch" _ "github.com/redpanda-data/connect/v4/public/components/otlp" _ "github.com/redpanda-data/connect/v4/public/components/pinecone" + _ "github.com/redpanda-data/connect/v4/public/components/postgresql" _ "github.com/redpanda-data/connect/v4/public/components/prometheus" _ "github.com/redpanda-data/connect/v4/public/components/pulsar" _ "github.com/redpanda-data/connect/v4/public/components/pure" @@ -71,5 +72,4 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/twitter" _ "github.com/redpanda-data/connect/v4/public/components/wasm" _ "github.com/redpanda-data/connect/v4/public/components/zeromq" - _ "github.com/redpanda-data/connect/v4/public/components/postgresql" ) From cf65fdcac4d81fe744812e921b50ab512e09fc64 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 6 Nov 2024 10:07:45 +0100 Subject: [PATCH 052/118] chore(): specify monitoring && standby intervals via config --- internal/impl/postgresql/input_pg_stream.go | 30 ++++++++++++++++++- internal/impl/postgresql/integration_test.go | 6 ++-- .../impl/postgresql/pglogicalstream/config.go | 3 ++ .../postgresql/pglogicalstream/connection.go | 4 +-- .../pglogicalstream/logical_stream.go | 4 +-- .../postgresql/pglogicalstream/monitor.go | 4 +-- .../pglogicalstream/monitor_test.go | 2 +- 7 files changed, 42 insertions(+), 11 deletions(-) diff --git 
a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 7a6f7ba426..d31050d270 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -27,7 +27,7 @@ import ( const ( fieldDSN = "dsn" - fieldStreamUncommitted = "stream_uncomitted" + fieldStreamUncommitted = "stream_uncommitted" fieldStreamSnapshot = "stream_snapshot" fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" fieldSnapshotBatchSize = "snapshot_batch_size" @@ -36,6 +36,8 @@ const ( fieldTables = "tables" fieldCheckpointLimit = "checkpoint_limit" fieldTemporarySlot = "temporary_slot" + fieldPgStandbyTimeout = "pg_standby_timeout_sec" + fieldWalMonitorIntervalSec = "pg_wal_monitor_interval_sec" fieldSlotName = "slot_name" fieldBatching = "batching" ) @@ -103,6 +105,14 @@ This input adds the following metadata fields to each message: Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). Default("")). + Field(service.NewIntField(fieldPgStandbyTimeout). + Description("Int field that specifies default standby timeout for PostgreSQL replication connection"). + Example(10). + Default(10)). + Field(service.NewIntField(fieldWalMonitorIntervalSec). + Description("Int field stat specifies ticker interval for WAL monitoring. Used to fetch replication slot lag"). + Example(3). + Default(3)). Field(service.NewAutoRetryNacksToggleField()). 
Field(service.NewBatchPolicyField(fieldBatching)) @@ -119,6 +129,8 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser streamUncommitted bool snapshotBatchSize int checkpointLimit int + walMonitorIntervalSec int + pgStandbyTimeoutSec int batching service.BatchPolicy ) @@ -186,6 +198,16 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser batching.Count = 1 } + pgStandbyTimeoutSec, err = conf.FieldInt(fieldPgStandbyTimeout) + if err != nil { + return nil, err + } + + walMonitorIntervalSec, err = conf.FieldInt(fieldWalMonitorIntervalSec) + if err != nil { + return nil, err + } + pgConnConfig, err := pgconn.ParseConfigWithOptions(dsn, pgconn.ParseConfigOptions{ // Don't support dynamic reading of password GetSSLPassword: func(context.Context) string { return "" }, @@ -216,6 +238,8 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser snapshotBatchSize: snapshotBatchSize, batching: batching, checkpointLimit: checkpointLimit, + pgStandbyTimeoutSec: pgStandbyTimeoutSec, + walMonitorIntervalSec: walMonitorIntervalSec, cMut: sync.Mutex{}, msgChan: make(chan asyncMessage), @@ -251,6 +275,8 @@ type pgStreamInput struct { dbRawDSN string pgLogicalStream *pglogicalstream.Stream slotName string + pgStandbyTimeoutSec int + walMonitorIntervalSec int temporarySlot bool schema string tables []string @@ -287,6 +313,8 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { StreamUncommitted: p.streamUncommitted, DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, + PgStandbyTimeoutSec: p.pgStandbyTimeoutSec, + WalMonitorIntervalSec: p.walMonitorIntervalSec, Logger: p.logger, }) if err != nil { diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index a5a203c146..64915e7d20 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -531,7 +531,7 
@@ pg_stream: snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: true + stream_uncommitted: true schema: public tables: - flights @@ -668,7 +668,7 @@ pg_stream: slot_name: test_slot_native_decoder stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: true + stream_uncommitted: true schema: public tables: - flights @@ -803,7 +803,7 @@ pg_stream: slot_name: test_slot_native_decoder stream_snapshot: true decoding_plugin: pgoutput - stream_uncomitted: false + stream_uncommitted: false schema: public tables: - flights diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 92c01b6353..8303694cc2 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -38,4 +38,7 @@ type Config struct { StreamUncommitted bool Logger *service.Logger + + PgStandbyTimeoutSec int + WalMonitorIntervalSec int } diff --git a/internal/impl/postgresql/pglogicalstream/connection.go b/internal/impl/postgresql/pglogicalstream/connection.go index e36625604a..7b801236ef 100644 --- a/internal/impl/postgresql/pglogicalstream/connection.go +++ b/internal/impl/postgresql/pglogicalstream/connection.go @@ -15,6 +15,8 @@ import ( "strconv" ) +var re = regexp.MustCompile(`^(\d+)`) + func openPgConnectionFromConfig(dbDSN string) (*sql.DB, error) { return sql.Open("postgres", dbDSN) } @@ -31,8 +33,6 @@ func getPostgresVersion(dbDSN string) (int, error) { return 0, fmt.Errorf("failed to execute query: %w", err) } - // Extract the major version number - re := regexp.MustCompile(`^(\d+)`) match := re.FindStringSubmatch(versionString) if len(match) < 2 { return 0, fmt.Errorf("failed to parse version string: %s", versionString) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 7201a52f52..f3f0bd40bb 100644 --- 
a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -201,11 +201,11 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.clientXLogPos = lsnrestart } - stream.standbyMessageTimeout = time.Second * 10 + stream.standbyMessageTimeout = time.Duration(config.PgStandbyTimeoutSec) * time.Second stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) - monitor, err := NewMonitor(config.DBRawDSN, stream.logger, tableNames, stream.slotName) + monitor, err := NewMonitor(config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorIntervalSec) if err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index 8c808bcda5..df14b9f4a4 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -47,7 +47,7 @@ type Monitor struct { } // NewMonitor creates a new Monitor instance -func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName string) (*Monitor, error) { +func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName string, intervalSec int) (*Monitor, error) { dbConn, err := openPgConnectionFromConfig(dbDSN) if err != nil { return nil, err @@ -69,7 +69,7 @@ func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName m.ctx = ctx m.cancelTicker = cancel // hardocded duration to monitor slot lag - m.ticker = time.NewTicker(time.Second * 3) + m.ticker = time.NewTicker(time.Second * time.Duration(intervalSec)) go func() { for { diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go index b9a626b714..fd5c304f81 100644 --- 
a/internal/impl/postgresql/pglogicalstream/monitor_test.go +++ b/internal/impl/postgresql/pglogicalstream/monitor_test.go @@ -77,7 +77,7 @@ func Test_MonitorReplorting(t *testing.T) { require.NoError(t, err) slotName := "test_slot" - mon, err := NewMonitor(databaseURL, &service.Logger{}, []string{"flights"}, slotName) + mon, err := NewMonitor(databaseURL, &service.Logger{}, []string{"flights"}, slotName, 1) require.NoError(t, err) require.NotNil(t, mon) } From cdd1a012cd946174d59fa49e9a89b37664e9dbe9 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 6 Nov 2024 10:30:57 +0100 Subject: [PATCH 053/118] chore(): removed redundant tests + deps --- go.mod | 1 - go.sum | 2 - internal/impl/postgresql/integration_test.go | 79 ++++++++++++------ .../pglogicalstream/monitor_test.go | 83 ------------------- 4 files changed, 52 insertions(+), 113 deletions(-) delete mode 100644 internal/impl/postgresql/pglogicalstream/monitor_test.go diff --git a/go.mod b/go.mod index e41a494963..b4f7007d6a 100644 --- a/go.mod +++ b/go.mod @@ -71,7 +71,6 @@ require ( github.com/influxdata/influxdb1-client v0.0.0-20220302092344-a9ab5670611c github.com/jackc/pgx/v4 v4.18.3 github.com/jackc/pgx/v5 v5.6.0 - github.com/jaswdr/faker v1.19.1 github.com/jhump/protoreflect v1.16.0 github.com/lib/pq v1.10.9 github.com/linkedin/goavro/v2 v2.13.0 diff --git a/go.sum b/go.sum index a2c41fe1b0..30321a032b 100644 --- a/go.sum +++ b/go.sum @@ -730,8 +730,6 @@ github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= -github.com/jaswdr/faker v1.19.1 h1:xBoz8/O6r0QAR8eEvKJZMdofxiRH+F0M/7MU9eNKhsM= -github.com/jaswdr/faker v1.19.1/go.mod h1:x7ZlyB1AZqwqKZgyQlnqEG8FDptmHlncA5u2zY/yi6w= github.com/jawher/mow.cli v1.0.4/go.mod 
h1:5hQj2V8g+qYmLUVWqu4Wuja1pI57M83EChYLVZ0sMKk= github.com/jawher/mow.cli v1.2.0/go.mod h1:y+pcA3jBAdo/GIZx/0rFjw/K2bVEODP9rfZOfaiq8Ko= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 64915e7d20..283a281c26 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -17,7 +17,7 @@ import ( "testing" "time" - "github.com/jaswdr/faker" + "github.com/go-faker/faker/v4" _ "github.com/lib/pq" _ "github.com/redpanda-data/benthos/v4/public/components/io" _ "github.com/redpanda-data/benthos/v4/public/components/pure" @@ -30,6 +30,21 @@ import ( "github.com/ory/dockertest/v3/docker" ) +type FakeFlightRecord struct { + RealAddress faker.RealAddress `faker:"real_address"` + CreatedAt int64 `fake:"unix_time"` +} + +func GetFakeFlightRecord() FakeFlightRecord { + flightRecord := FakeFlightRecord{} + err := faker.FakeData(&flightRecord) + if err != nil { + panic(err) + } + + return flightRecord +} + func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version string) (*dockertest.Resource, *sql.DB, error) { resource, err := pool.RunWithOptions(&dockertest.RunOptions{ Repository: "postgres", @@ -177,10 +192,10 @@ func TestIntegrationPgCDC(t *testing.T) { panic(fmt.Errorf("could not connect to docker: %w", err)) } - fake := faker.New() for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec("INSERT INTO flights_non_streamed 
(name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -231,8 +246,9 @@ file: }, time.Second*25, time.Millisecond*100) for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -271,7 +287,8 @@ file: time.Sleep(time.Second * 5) for i := 0; i < 50; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -310,9 +327,9 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { require.NoError(t, err) - fake := faker.New() for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -364,9 +381,10 @@ file: }, time.Second*25, time.Millisecond*100) for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + 
f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -405,7 +423,8 @@ file: time.Sleep(time.Second * 5) for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -517,9 +536,9 @@ func TestIntegrationPgCDCForPgOutputStreamUncommittedPlugin(t *testing.T) { hostAndPortSplited := strings.Split(hostAndPort, ":") password := "l]YLSc|4[i56%{gY" - fake := faker.New() for i := 0; i < 10000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -574,9 +593,10 @@ file: }, time.Second*25, time.Millisecond*100) for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, 
created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -615,7 +635,8 @@ file: time.Sleep(time.Second * 5) for i := 0; i < 10; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -655,9 +676,9 @@ func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testin hostAndPortSplited := strings.Split(hostAndPort, ":") password := "l]YLSc|4[i56%{gY" - fake := faker.New() for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -710,9 +731,10 @@ file: }, time.Second*25, time.Millisecond*100) for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) 
require.NoError(t, err) } @@ -751,7 +773,8 @@ file: time.Sleep(time.Second * 5) for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -790,9 +813,9 @@ func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing hostAndPortSplited := strings.Split(hostAndPort, ":") password := "l]YLSc|4[i56%{gY" - fake := faker.New() for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -846,9 +869,10 @@ file: }, time.Second*25, time.Millisecond*100) for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -887,7 +911,8 @@ file: time.Sleep(time.Second * 5) for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) + f := 
GetFakeFlightRecord() + _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } diff --git a/internal/impl/postgresql/pglogicalstream/monitor_test.go b/internal/impl/postgresql/pglogicalstream/monitor_test.go deleted file mode 100644 index fd5c304f81..0000000000 --- a/internal/impl/postgresql/pglogicalstream/monitor_test.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pglogicalstream - -import ( - "database/sql" - "fmt" - "strings" - "testing" - "time" - - "github.com/jaswdr/faker" - "github.com/ory/dockertest/v3" - "github.com/ory/dockertest/v3/docker" - "github.com/redpanda-data/benthos/v4/public/service" - "github.com/stretchr/testify/require" -) - -func Test_MonitorReplorting(t *testing.T) { - t.Skip("Skipping for now") - pool, err := dockertest.NewPool("") - require.NoError(t, err) - - resource, err := pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "postgres", - Tag: "16", - Env: []string{ - "POSTGRES_PASSWORD=secret", - "POSTGRES_USER=user_name", - "POSTGRES_DB=dbname", - }, - Cmd: []string{ - "postgres", - "-c", "wal_level=logical", - }, - }, func(config *docker.HostConfig) { - config.AutoRemove = true - config.RestartPolicy = docker.RestartPolicy{Name: "no"} - }) - - require.NoError(t, err) - require.NoError(t, resource.Expire(120)) - - hostAndPort := resource.GetHostPort("5432/tcp") - hostAndPortSplited := strings.Split(hostAndPort, ":") - databaseURL := fmt.Sprintf("user=user_name password=secret dbname=dbname sslmode=disable host=%s port=%s", hostAndPortSplited[0], hostAndPortSplited[1]) - - var db *sql.DB - 
pool.MaxWait = 120 * time.Second - err = pool.Retry(func() error { - if db, err = sql.Open("postgres", databaseURL); err != nil { - return err - } - - if err = db.Ping(); err != nil { - return err - } - - return err - }) - require.NoError(t, err) - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - require.NoError(t, err) - - fake := faker.New() - for i := 0; i < 1000; i++ { - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", fake.Address().City(), fake.Time().RFC1123(time.Now())) - require.NoError(t, err) - } - - require.NoError(t, err) - slotName := "test_slot" - mon, err := NewMonitor(databaseURL, &service.Logger{}, []string{"flights"}, slotName, 1) - require.NoError(t, err) - require.NotNil(t, mon) -} From 0f0eb6fb05bb639f28a6b0ce418b490afc6fbe18 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 6 Nov 2024 11:14:27 +0100 Subject: [PATCH 054/118] chore(): updated docs --- .../components/pages/inputs/pg_stream.adoc | 405 ++++++++++++++++++ internal/plugins/info.csv | 1 + 2 files changed, 406 insertions(+) create mode 100644 docs/modules/components/pages/inputs/pg_stream.adoc diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc new file mode 100644 index 0000000000..fb67d91e3a --- /dev/null +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -0,0 +1,405 @@ += pg_stream +:type: input +:status: beta +:categories: ["Services"] + + + +//// + THIS FILE IS AUTOGENERATED! + + To make changes, edit the corresponding source file under: + + https://github.com/redpanda-data/connect/tree/main/internal/impl/. + + And: + + https://github.com/redpanda-data/connect/tree/main/cmd/tools/docs_gen/templates/plugin.adoc.tmpl +//// + +// © 2024 Redpanda Data Inc. + + +component_type_dropdown::[] + + +Streams changes from a PostgreSQL database using logical replication. + +Introduced in version 4.39.0. 
+ + +[tabs] +====== +Common:: ++ +-- + +```yml +# Common config fields, showing default values +input: + label: "" + pg_stream: + dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable # No default (required) + stream_uncommitted: false + stream_snapshot: false + snapshot_memory_safety_factor: 1 + snapshot_batch_size: 0 + decoding_plugin: pgoutput + schema: public # No default (required) + tables: [] # No default (required) + checkpoint_limit: 1024 + temporary_slot: false + slot_name: "" + pg_standby_timeout_sec: 10 + pg_wal_monitor_interval_sec: 3 + auto_replay_nacks: true + batching: + count: 0 + byte_size: 0 + period: "" + check: "" +``` + +-- +Advanced:: ++ +-- + +```yml +# All config fields, showing default values +input: + label: "" + pg_stream: + dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable # No default (required) + stream_uncommitted: false + stream_snapshot: false + snapshot_memory_safety_factor: 1 + snapshot_batch_size: 0 + decoding_plugin: pgoutput + schema: public # No default (required) + tables: [] # No default (required) + checkpoint_limit: 1024 + temporary_slot: false + slot_name: "" + pg_standby_timeout_sec: 10 + pg_wal_monitor_interval_sec: 3 + auto_replay_nacks: true + batching: + count: 0 + byte_size: 0 + period: "" + check: "" + processors: [] # No default (optional) +``` + +-- +====== + +Streams changes from a PostgreSQL database for Change Data Capture (CDC). +Additionally, if `stream_snapshot` is set to true, then the existing data in the database is also streamed too. 
+ +== Metadata + +This input adds the following metadata fields to each message: +- is_streaming (Boolean indicating whether the message is part of a streaming operation or snapshot processing) +- table (Name of the table that the message originated from) +- operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) + + +== Fields + +=== `dsn` + +The Data Source Name for the PostgreSQL database in the form of `postgres://[user[:password]@][netloc][:port][/dbname][?param1=value1&...]`. Please note that Postgres enforces SSL by default, you can override this with the parameter `sslmode=disable` if required. + + +*Type*: `string` + + +```yml +# Examples + +dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable +``` + +=== `stream_uncommitted` + +If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted. + + +*Type*: `bool` + +*Default*: `false` + +=== `stream_snapshot` + +When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. + + +*Type*: `bool` + +*Default*: `false` + +```yml +# Examples + +stream_snapshot: true +``` + +=== `snapshot_memory_safety_factor` + +Determines the fraction of available memory that can be used for streaming the snapshot. Values between 0 and 1 represent the percentage of memory to use. Lower values make initial streaming slower but help prevent out-of-memory errors. + + +*Type*: `float` + +*Default*: `1` + +```yml +# Examples + +snapshot_memory_safety_factor: 0.2 +``` + +=== `snapshot_batch_size` + +The number of rows to fetch in each batch when querying the snapshot. A value of 0 lets the plugin determine the batch size based on `snapshot_memory_safety_factor` property. 
+ + +*Type*: `int` + +*Default*: `0` + +```yml +# Examples + +snapshot_batch_size: 10000 +``` + +=== `decoding_plugin` + +Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. + Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. + + +*Type*: `string` + +*Default*: `"pgoutput"` + +Options: +`pgoutput` +, `wal2json` +. + +```yml +# Examples + +decoding_plugin: pgoutput +``` + +=== `schema` + +The PostgreSQL schema from which to replicate data. + + +*Type*: `string` + + +```yml +# Examples + +schema: public +``` + +=== `tables` + +A list of table names to include in the logical replication. Each table should be specified as a separate item. + + +*Type*: `array` + + +```yml +# Examples + +tables: |2- + - my_table + - my_table_2 + +``` + +=== `checkpoint_limit` + +The maximum number of messages that can be processed at a given time. Increasing this limit enables parallel processing and batching at the output level. Any given LSN will not be acknowledged unless all messages under that offset are delivered in order to preserve at least once delivery guarantees. + + +*Type*: `int` + +*Default*: `1024` + +=== `temporary_slot` + +If set to true, creates a temporary replication slot that is automatically dropped when the connection is closed. + + +*Type*: `bool` + +*Default*: `false` + +=== `slot_name` + +The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired. 
+ + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +slot_name: my_test_slot +``` + +=== `pg_standby_timeout_sec` + +Int field that specifies default standby timeout for PostgreSQL replication connection + + +*Type*: `int` + +*Default*: `10` + +```yml +# Examples + +pg_standby_timeout_sec: 10 +``` + +=== `pg_wal_monitor_interval_sec` + +Int field that specifies ticker interval for WAL monitoring. Used to fetch replication slot lag + + +*Type*: `int` + +*Default*: `3` + +```yml +# Examples + +pg_wal_monitor_interval_sec: 3 +``` + +=== `auto_replay_nacks` + +Whether messages that are rejected (nacked) at the output level should be automatically replayed indefinitely, eventually resulting in back pressure if the cause of the rejections is persistent. If set to `false` these messages will instead be deleted. Disabling auto replays can greatly improve memory efficiency of high throughput streams as the original shape of the data can be discarded immediately upon consumption and mutation. + + +*Type*: `bool` + +*Default*: `true` + +=== `batching` + +Allows you to configure a xref:configuration:batching.adoc[batching policy]. + + +*Type*: `object` + + +```yml +# Examples + +batching: + byte_size: 5000 + count: 0 + period: 1s + +batching: + count: 10 + period: 1s + +batching: + check: this.contains("END BATCH") + count: 0 + period: 1m +``` + +=== `batching.count` + +A number of messages at which the batch should be flushed. If `0` disables count based batching. + + +*Type*: `int` + +*Default*: `0` + +=== `batching.byte_size` + +An amount of bytes at which the batch should be flushed. If `0` disables size based batching. + + +*Type*: `int` + +*Default*: `0` + +=== `batching.period` + +A period in which an incomplete batch should be flushed regardless of its size.
+ + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +period: 1s + +period: 1m + +period: 500ms +``` + +=== `batching.check` + +A xref:guides:bloblang/about.adoc[Bloblang query] that should return a boolean value indicating whether a message should end a batch. + + +*Type*: `string` + +*Default*: `""` + +```yml +# Examples + +check: this.type == "end_of_transaction" +``` + +=== `batching.processors` + +A list of xref:components:processors/about.adoc[processors] to apply to a batch as it is flushed. This allows you to aggregate and archive the batch however you see fit. Please note that all resulting messages are flushed as a single batch, therefore splitting the batch into smaller batches using these processors is a no-op. + + +*Type*: `array` + + +```yml +# Examples + +processors: + - archive: + format: concatenate + +processors: + - archive: + format: lines + +processors: + - archive: + format: json_array +``` + + diff --git a/internal/plugins/info.csv b/internal/plugins/info.csv index 8b57251cc0..2c122588b5 100644 --- a/internal/plugins/info.csv +++ b/internal/plugins/info.csv @@ -168,6 +168,7 @@ parquet ,processor ,parquet ,3.62.0 ,commun parquet_decode ,processor ,parquet_decode ,4.4.0 ,certified ,n ,y ,y parquet_encode ,processor ,parquet_encode ,4.4.0 ,certified ,n ,y ,y parse_log ,processor ,parse_log ,0.0.0 ,community ,n ,y ,y +pg_stream ,input ,pg_stream ,0.0.0 ,community ,n ,n ,n pinecone ,output ,pinecone ,4.31.0 ,certified ,n ,y ,y processors ,processor ,processors ,0.0.0 ,certified ,n ,y ,y prometheus ,metric ,prometheus ,0.0.0 ,certified ,n ,y ,y From 131c0a04ef699c788af124ac2932183e81d6849c Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Thu, 7 Nov 2024 15:13:51 +0000 Subject: [PATCH 055/118] pgstream: create batcher in foreground --- internal/impl/postgresql/input_pg_stream.go | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go 
b/internal/impl/postgresql/input_pg_stream.go index c82ce73c7a..41c858cbdb 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -322,17 +322,11 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } p.pgLogicalStream = pgStream - + batchPolicy, err := p.batching.NewBatcher(p.mgr) + if err != nil { + return err + } go func() { - batchPolicy, err := p.batching.NewBatcher(p.mgr) - if err != nil { - p.logger.Errorf("Failed to initialise batch policy: %v, falling back to no policy.", err) - conf := service.BatchPolicy{Count: 1} - if batchPolicy, err = conf.NewBatcher(p.mgr); err != nil { - panic(err) - } - } - defer func() { batchPolicy.Close(context.Background()) }() From e47ca50a763dd906aaa2e123d986b5cb57c0ce04 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Thu, 7 Nov 2024 15:14:59 +0000 Subject: [PATCH 056/118] pgstream: only check for done once --- internal/impl/postgresql/input_pg_stream.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 41c858cbdb..780253dfce 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -338,8 +338,6 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) for { select { - case <-ctx.Done(): - return case <-nextTimedBatchChan: nextTimedBatchChan = nil flushedBatch, err := batchPolicy.Flush(ctx) @@ -436,6 +434,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { if err = p.pgLogicalStream.Stop(); err != nil { p.logger.Errorf("Failed to stop pglogical stream: %v", err) } + return } } }() From 6eae2329cd00dfaf3de1508e4c734407a4b54f9d Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 19:50:18 +0000 Subject: [PATCH 057/118] pgcdc: remove bool for operation --- internal/impl/postgresql/input_pg_stream.go | 
6 ++---- .../impl/postgresql/pglogicalstream/logical_stream.go | 6 +++--- .../impl/postgresql/pglogicalstream/stream_message.go | 9 ++++++++- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 780253dfce..466ad18612 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -12,7 +12,6 @@ import ( "context" "encoding/json" "fmt" - "strconv" "sync" "sync/atomic" "time" @@ -401,8 +400,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { batchMsg := service.NewMessage(mb) - streaming := strconv.FormatBool(message.IsStreaming) - batchMsg.MetaSet("streaming", streaming) + batchMsg.MetaSet("mode", string(message.Mode)) batchMsg.MetaSet("table", message.Changes[0].Table) batchMsg.MetaSet("operation", message.Changes[0].Operation) if message.Changes[0].TableSnapshotProgress != nil { @@ -419,7 +417,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { p.logger.Debugf("Flush batch error: %w", err) break } - if message.IsStreaming { + if message.Mode == pglogicalstream.StreamModeStreaming { if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, true) { break } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index f3f0bd40bb..d745f6f3a2 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -448,7 +448,7 @@ func (s *Stream) streamMessagesAsync() { Changes: []StreamMessageChanges{ *message, }, - IsStreaming: true, + Mode: StreamModeStreaming, WALLagBytes: &metrics.WalLagInBytes, } } @@ -511,7 +511,7 @@ func (s *Stream) streamMessagesAsync() { s.messages <- StreamMessage{ Lsn: &lsn, Changes: pgoutputChanges, - IsStreaming: true, + Mode: StreamModeStreaming, WALLagBytes: &metrics.WalLagInBytes, } } @@ -661,7 +661,7 @@ func (s 
*Stream) processSnapshot(ctx context.Context) error { tableProgress := s.monitor.GetSnapshotProgressForTable(tableWithoutSchema) snapshotChangePacket.Changes[0].TableSnapshotProgress = &tableProgress - snapshotChangePacket.IsStreaming = false + snapshotChangePacket.Mode = StreamModeSnapshot waitingFromBenthos := time.Now() s.messages <- snapshotChangePacket diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index 3a139fcc29..7c6b8531c0 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -25,10 +25,17 @@ type StreamMessageMetrics struct { IsStreaming bool `json:"is_streaming"` } +type StreamMode string + +const ( + StreamModeStreaming StreamMode = "streaming" + StreamModeSnapshot StreamMode = "snapshot" +) + // StreamMessage represents a single message after it has been decoded by the plugin type StreamMessage struct { Lsn *string `json:"lsn"` Changes []StreamMessageChanges `json:"changes"` - IsStreaming bool `json:"is_streaming"` + Mode StreamMode `json:"mode"` WALLagBytes *int64 `json:"wal_lag_bytes"` } From 35b7d98852757c271acb1df455da44f29f2273fd Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 19:53:59 +0000 Subject: [PATCH 058/118] pgcdc: update docs for mode --- docs/modules/components/pages/inputs/pg_stream.adoc | 2 +- internal/impl/postgresql/input_pg_stream.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc index fb67d91e3a..9b8ca14930 100644 --- a/docs/modules/components/pages/inputs/pg_stream.adoc +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -101,7 +101,7 @@ Additionally, if `stream_snapshot` is set to true, then the existing data in the == Metadata This input adds the following metadata fields to each message: -- is_streaming 
(Boolean indicating whether the message is part of a streaming operation or snapshot processing) +- mode (Either "streaming" or "snapshot" indicating whether the message is part of a streaming operation or snapshot processing) - table (Name of the table that the message originated from) - operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 466ad18612..d1c83c653e 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -57,7 +57,7 @@ Additionally, if ` + "`" + fieldStreamSnapshot + "`" + ` is set to true, then th == Metadata This input adds the following metadata fields to each message: -- is_streaming (Boolean indicating whether the message is part of a streaming operation or snapshot processing) +- mode (Either "streaming" or "snapshot" indicating whether the message is part of a streaming operation or snapshot processing) - table (Name of the table that the message originated from) - operation (Type of operation that generated the message, such as INSERT, UPDATE, or DELETE) `). 
From 4f13dcc1d10840db42829cac28a9e4d444d4013f Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 20:19:31 +0000 Subject: [PATCH 059/118] pgcdc: validate slot names can't cause SQL injection --- internal/impl/postgresql/input_pg_stream.go | 23 +++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index d1c83c653e..e7411425f7 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -12,6 +12,7 @@ import ( "context" "encoding/json" "fmt" + "strings" "sync" "sync/atomic" "time" @@ -147,6 +148,10 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser dbSlotName = uuid.NewString() } + if err := validateSimpleString(dbSlotName); err != nil { + return nil, fmt.Errorf("invalid slot_name: %w", err) + } + temporarySlot, err = conf.FieldBool(fieldTemporarySlot) if err != nil { return nil, err @@ -258,6 +263,24 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return conf.WrapBatchInputExtractTracingSpanMapping("pg_stream", r) } +// validateSimpleString ensures we aren't vuln to SQL injection +func validateSimpleString(s string) error { + for _, b := range []byte(s) { + isDigit := b >= '0' && b <= '9' + isLower := b >= 'a' && b <= 'z' + isUpper := b >= 'A' && b <= 'Z' + isDelimiter := b == '_' || b == '-' + if !isDigit && !isLower && !isUpper && !isDelimiter { + return fmt.Errorf("invalid postgres identifier %q", s) + } + } + // See: https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + if strings.Contains(s, "--") { + return fmt.Errorf("invalid postgres identifier %q", s) + } + return nil +} + func init() { err := service.RegisterBatchInput( "pg_stream", pgStreamConfigSpec, From 83d1db51347f759a545b63c6f7d782eea8eb41d0 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 20:20:00 +0000 Subject: [PATCH 060/118] 
pgcdc: use error type for error handling, not bool --- internal/impl/postgresql/input_pg_stream.go | 75 ++++++++----------- .../pglogicalstream/logical_stream.go | 2 - 2 files changed, 33 insertions(+), 44 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index e7411425f7..b08ebd27c0 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -368,7 +368,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, false) { + if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset, false); err != nil { break } @@ -384,7 +384,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, true) { + if err = p.flushBatch(ctx, cp, flushedBatch, latestOffset, true); err != nil { break } @@ -440,16 +440,10 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { p.logger.Debugf("Flush batch error: %w", err) break } - if message.Mode == pglogicalstream.StreamModeStreaming { - if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, true) { - break - } - } else { - if !p.flushBatch(ctx, cp, flushedBatch, latestOffset, false) { - break - } + waitForCommit := message.Mode == pglogicalstream.StreamModeStreaming + if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset, waitForCommit); err != nil { + break } - } case <-ctx.Done(): if err = p.pgLogicalStream.Stop(); err != nil { @@ -463,9 +457,9 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { return err } -func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint.Capped[*int64], msg service.MessageBatch, lsn *int64, waitForCommit bool) bool { +func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint.Capped[*int64], msg service.MessageBatch, lsn *int64, waitForCommit bool) error { if msg == nil { - return true 
+ return nil } resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(msg))) @@ -473,41 +467,38 @@ func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint if ctx.Err() == nil { p.mgr.Logger().Errorf("Failed to checkpoint offset: %v\n", err) } - return false + return err } - commitChan := make(chan bool) - if !waitForCommit { - close(commitChan) + var wg sync.WaitGroup + if waitForCommit { + wg.Add(1) } - select { - case p.msgChan <- asyncMessage{ - msg: msg, - ackFn: func(ctx context.Context, res error) error { - maxOffset := resolveFn() - if maxOffset == nil { - return nil - } - p.cMut.Lock() - defer p.cMut.Unlock() - if lsn != nil { - if err = p.pgLogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { - return err - } - if waitForCommit { - close(commitChan) - } - } + ackFn := func(ctx context.Context, res error) error { + if waitForCommit { + defer wg.Done() + } + maxOffset := resolveFn() + if maxOffset == nil { return nil - }, - }: + } + p.cMut.Lock() + defer p.cMut.Unlock() + if lsn == nil { + return nil + } + if err = p.pgLogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { + return err + } + return nil + } + select { + case p.msgChan <- asyncMessage{msg: msg, ackFn: ackFn}: case <-ctx.Done(): - return false + return ctx.Err() } - - <-commitChan - - return true + wg.Wait() // Noop if !waitForCommit + return nil } func (p *pgStreamInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index d745f6f3a2..1498ccbfbf 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -154,8 +154,6 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { var outputPlugin string // check is replication slot exist to get last restart SLN - // TODO: There should be a helper method for this 
that also validates the parameters to fmt here are not possible to cause SQL injection. - // this means we either escape or we validate it's only alphanumeric and `-_` connExecResult, err := stream.pgConn.Exec( ctx, fmt.Sprintf("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = '%s'", From 3de85d02737f8fd6c95b4ae777aab91fdce89164 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 20:25:15 +0000 Subject: [PATCH 061/118] pgcdc: import sanitization code from pgx We are forced to use the simple query protocol for pg in replication mode, which means we need to sanitize stuff. Import some internal code from pgx for that. --- .../pglogicalstream/sanitize/sanitize.go | 356 ++++++++++++++++++ .../pglogicalstream/sanitize/sanitize_test.go | 252 +++++++++++++ 2 files changed, 608 insertions(+) create mode 100644 internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go create mode 100644 internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go new file mode 100644 index 0000000000..02098e4def --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go @@ -0,0 +1,356 @@ +// Copyright (c) 2013-2021 Jack Christensen +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. 
+// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// An import of sanitization code from pgx/internal/sanitize so that we +// can sanitize +package sanitize + +import ( + "bytes" + "encoding/hex" + "fmt" + "strconv" + "strings" + "time" + "unicode/utf8" +) + +// Part is either a string or an int. A string is raw SQL. An int is a +// argument placeholder. +type Part any + +type Query struct { + Parts []Part +} + +// utf.DecodeRune returns the utf8.RuneError for errors. But that is actually rune U+FFFD -- the unicode replacement +// character. utf8.RuneError is not an error if it is also width 3. 
+// +// https://github.com/jackc/pgx/issues/1380 +const replacementcharacterwidth = 3 + +func (q *Query) Sanitize(args ...any) (string, error) { + argUse := make([]bool, len(args)) + buf := &bytes.Buffer{} + + for _, part := range q.Parts { + var str string + switch part := part.(type) { + case string: + str = part + case int: + argIdx := part - 1 + + if argIdx < 0 { + return "", fmt.Errorf("first sql argument must be > 0") + } + + if argIdx >= len(args) { + return "", fmt.Errorf("insufficient arguments") + } + arg := args[argIdx] + switch arg := arg.(type) { + case nil: + str = "null" + case int64: + str = strconv.FormatInt(arg, 10) + case float64: + str = strconv.FormatFloat(arg, 'f', -1, 64) + case bool: + str = strconv.FormatBool(arg) + case []byte: + str = QuoteBytes(arg) + case string: + str = QuoteString(arg) + case time.Time: + str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") + default: + return "", fmt.Errorf("invalid arg type: %T", arg) + } + argUse[argIdx] = true + + // Prevent SQL injection via Line Comment Creation + // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p + str = " " + str + " " + default: + return "", fmt.Errorf("invalid Part type: %T", part) + } + buf.WriteString(str) + } + + for i, used := range argUse { + if !used { + return "", fmt.Errorf("unused argument: %d", i) + } + } + return buf.String(), nil +} + +func NewQuery(sql string) (*Query, error) { + l := &sqlLexer{ + src: sql, + stateFn: rawState, + } + + for l.stateFn != nil { + l.stateFn = l.stateFn(l) + } + + query := &Query{Parts: l.parts} + + return query, nil +} + +func QuoteString(str string) string { + return "'" + strings.ReplaceAll(str, "'", "''") + "'" +} + +func QuoteBytes(buf []byte) string { + return `'\x` + hex.EncodeToString(buf) + "'" +} + +type sqlLexer struct { + src string + start int + pos int + nested int // multiline comment nesting level. 
+ stateFn stateFn + parts []Part +} + +type stateFn func(*sqlLexer) stateFn + +func rawState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case 'e', 'E': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '\'' { + l.pos += width + return escapeStringState + } + case '\'': + return singleQuoteState + case '"': + return doubleQuoteState + case '$': + nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) + if '0' <= nextRune && nextRune <= '9' { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos-width]) + } + l.start = l.pos + return placeholderState + } + case '-': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '-' { + l.pos += width + return oneLineCommentState + } + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + return multilineCommentState + } + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func singleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func doubleQuoteState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '"': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '"' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + 
l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +// placeholderState consumes a placeholder value. The $ must have already has +// already been consumed. The first rune must be a digit. +func placeholderState(l *sqlLexer) stateFn { + num := 0 + + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + if '0' <= r && r <= '9' { + num *= 10 + num += int(r - '0') + } else { + l.parts = append(l.parts, num) + l.pos -= width + l.start = l.pos + return rawState + } + } +} + +func escapeStringState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\'': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '\'' { + return rawState + } + l.pos += width + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func oneLineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '\\': + _, width = utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + case '\n', '\r': + return rawState + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +func multilineCommentState(l *sqlLexer) stateFn { + for { + r, width := utf8.DecodeRuneInString(l.src[l.pos:]) + l.pos += width + + switch r { + case '/': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune == '*' { + l.pos += width + l.nested++ + } + case '*': + nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) + if nextRune != '/' { + continue + } + + l.pos += width + if l.nested == 0 { + return rawState + } + 
l.nested-- + + case utf8.RuneError: + if width != replacementcharacterwidth { + if l.pos-l.start > 0 { + l.parts = append(l.parts, l.src[l.start:l.pos]) + l.start = l.pos + } + return nil + } + } + } +} + +// SanitizeSQL replaces placeholder values with args. It quotes and escapes args +// as necessary. This function is only safe when standard_conforming_strings is +// on. +func SanitizeSQL(sql string, args ...any) (string, error) { + query, err := NewQuery(sql) + if err != nil { + return "", err + } + return query.Sanitize(args...) +} diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go new file mode 100644 index 0000000000..ba87ba5eaa --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize_test.go @@ -0,0 +1,252 @@ +// Copyright (c) 2013-2021 Jack Christensen +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +package sanitize_test + +import ( + "testing" + "time" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" +) + +func TestNewQuery(t *testing.T) { + successTests := []struct { + sql string + expected sanitize.Query + }{ + { + sql: "select 42", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + }, + { + sql: "select $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + }, + { + sql: "select 'quoted $42', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, + }, + { + sql: `select "doubled quoted $42", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, + }, + { + sql: "select 'foo''bar', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, + }, + { + sql: `select "foo""bar", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, + }, + { + sql: "select '''', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, + }, + { + sql: `select """", $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, + }, + { + sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", + expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, + }, + { + sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, + expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, + }, + { + sql: `select E'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, + 
}, + { + sql: `select e'escape string\' $42', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, + }, + { + sql: `select /* a baby's toy */ 'barbie', $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select /* a baby's toy */ 'barbie', `, 1}}, + }, + { + sql: `select /* *_* */ $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select /* *_* */ `, 1}}, + }, + { + sql: `select 42 /* /* /* 42 */ */ */, $1`, + expected: sanitize.Query{Parts: []sanitize.Part{`select 42 /* /* /* 42 */ */ */, `, 1}}, + }, + { + sql: "select -- a baby's toy\n'barbie', $1", + expected: sanitize.Query{Parts: []sanitize.Part{"select -- a baby's toy\n'barbie', ", 1}}, + }, + { + sql: "select 42 -- is a Deep Thought's favorite number", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42 -- is a Deep Thought's favorite number"}}, + }, + { + sql: "select 42, -- \\nis a Deep Thought's favorite number\n$1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\n", 1}}, + }, + { + sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", + expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, + }, + { + // https://github.com/jackc/pgx/issues/1380 + sql: "select 'hello w�rld'", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w�rld'"}}, + }, + { + // Unterminated quoted string + sql: "select 'hello world", + expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}}, + }, + } + + for i, tt := range successTests { + query, err := sanitize.NewQuery(tt.sql) + if err != nil { + t.Errorf("%d. %v", i, err) + } + + if len(query.Parts) == len(tt.expected.Parts) { + for j := range query.Parts { + if query.Parts[j] != tt.expected.Parts[j] { + t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) + } + } + } else { + t.Errorf("%d. 
expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) + } + } +} + +func TestQuerySanitize(t *testing.T) { + successfulTests := []struct { + query sanitize.Query + args []any + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, + args: []any{}, + expected: `select 42`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{int64(42)}, + expected: `select 42 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{float64(1.23)}, + expected: `select 1.23 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{true}, + expected: `select true `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{[]byte{0, 1, 2, 3, 255}}, + expected: `select '\x00010203ff' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{nil}, + expected: `select null `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{"foobar"}, + expected: `select 'foobar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{"foo'bar"}, + expected: `select 'foo''bar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{`foo\'bar`}, + expected: `select 'foo\''bar' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"insert ", 1}}, + args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, + expected: `insert '2020-03-01 23:59:59.999999Z' `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{int64(-1)}, + expected: `select 1- -1 `, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}}, + args: []any{float64(-1)}, + expected: `select 1- -1 `, + }, + } + + for i, tt := range successfulTests { + actual, err := tt.query.Sanitize(tt.args...) + if err != nil { + t.Errorf("%d. 
%v", i, err) + continue + } + + if tt.expected != actual { + t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual) + } + } + + errorTests := []struct { + query sanitize.Query + args []any + expected string + }{ + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, + args: []any{int64(42)}, + expected: `insufficient arguments`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, + args: []any{int64(42)}, + expected: `unused argument: 0`, + }, + { + query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, + args: []any{42}, + expected: `invalid arg type: int`, + }, + } + + for i, tt := range errorTests { + _, err := tt.query.Sanitize(tt.args...) + if err == nil || err.Error() != tt.expected { + t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) + } + } +} From a67ca7c5705ea850539f8967cde90fccb3e7dac2 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 20:28:34 +0000 Subject: [PATCH 062/118] pgcdc: add note about pk in snapshot reading --- internal/impl/postgresql/input_pg_stream.go | 2 +- internal/impl/postgresql/pglogicalstream/config.go | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index b08ebd27c0..fb0f15b3cb 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -69,7 +69,7 @@ This input adds the following metadata fields to each message: Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). Default(false)). Field(service.NewBoolField(fieldStreamSnapshot). - Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes."). 
+ Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. In order to use this the tables that are being snapshot MUST have a primary key set so that reading from the table can be parallelized."). Example(true). Default(false)). Field(service.NewFloatField(fieldSnapshotMemSafetyFactor). diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 8303694cc2..cdcc3b0d1d 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -23,6 +23,8 @@ type Config struct { // DbTables is the tables to stream changes from DBTables []string // ReplicationSlotName is the name of the replication slot to use + // + // MUST BE SQL INJECTION FREE ReplicationSlotName string // TemporaryReplicationSlot is whether to use a temporary replication slot TemporaryReplicationSlot bool From ca4cfdbaf2707e45c43df7494e705eddd9205fbd Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 21:07:20 +0000 Subject: [PATCH 063/118] pgcdc: properly sanitize query --- .../pglogicalstream/logical_stream.go | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 1498ccbfbf..ef4704de1e 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -22,6 +22,7 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" "golang.org/x/sync/errgroup" ) @@ -43,7 +44,8 @@ type Stream struct { snapshotName string slotName string schema string - tableNames []string + // includes schema + tableQualifiedName 
[]string snapshotBatchSize int decodingPlugin DecodingPlugin decodingPluginArguments []string @@ -87,7 +89,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { streamUncommitted: config.StreamUncommitted, snapshotBatchSize: config.BatchSize, schema: config.DBSchema, - tableNames: tableNames, + tableQualifiedName: tableNames, consumedCallback: make(chan bool), transactionAckChan: make(chan string), transactionBeginChan: make(chan bool), @@ -121,6 +123,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { if stream.decodingPlugin == "pgoutput" { pluginArguments = []string{ "proto_version '1'", + // Sprintf is safe because we validate ReplicationSlotName is alphanumeric in the config fmt.Sprintf("publication_names 'pglog_stream_%s'", config.ReplicationSlotName), } @@ -131,7 +134,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tablesFilterRule := strings.Join(tableNames, ", ") pluginArguments = []string{ "\"pretty-print\" 'true'", - "\"add-tables\"" + " " + fmt.Sprintf("'%s'", tablesFilterRule), + // TODO: Validate this is escaped properly + fmt.Sprintf(`"add-tables" '%s'`, tablesFilterRule), } } else { return nil, fmt.Errorf("unknown decoding plugin: %q", stream.decodingPlugin) @@ -154,11 +158,11 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { var outputPlugin string // check is replication slot exist to get last restart SLN - connExecResult, err := stream.pgConn.Exec( - ctx, - fmt.Sprintf("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = '%s'", - config.ReplicationSlotName), - ).ReadAll() + s, err := sanitize.SanitizeSQL("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = $1", config.ReplicationSlotName) + if err != nil { + return nil, err + } + connExecResult, err := stream.pgConn.Exec(ctx, s).ReadAll() if err != nil { return nil, err } @@ -547,7 +551,7 @@ func (s *Stream) processSnapshot(ctx 
context.Context) error { var wg errgroup.Group - for _, table := range s.tableNames { + for _, table := range s.tableQualifiedName { tableName := table wg.Go(func() (err error) { s.logger.Infof("Processing snapshot for table: %v", table) @@ -710,14 +714,17 @@ func (s *Stream) cleanUpOnFailure(ctx context.Context) error { } func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { - q := fmt.Sprintf(` + q, err := sanitize.SanitizeSQL(` SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = '%s'::regclass + WHERE i.indrelid = $1::regclass AND i.indisprimary; `, tableName) + if err != nil { + return "", err + } reader := s.pgConn.Exec(context.Background(), q) data, err := reader.ReadAll() From a6854300bfef0f18f2c681e3bb8a027521c4e2b6 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 21:14:45 +0000 Subject: [PATCH 064/118] pgcdc add note about how waiting for commit is buggy --- internal/impl/postgresql/input_pg_stream.go | 7 +++++++ internal/impl/postgresql/pglogicalstream/logical_stream.go | 7 +++---- internal/impl/postgresql/pglogicalstream/pglogrepl.go | 1 + 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index fb0f15b3cb..c515175098 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -475,6 +475,13 @@ func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint wg.Add(1) } ackFn := func(ctx context.Context, res error) error { + // This waits for *THIS MESSAGE* to get acked, which is + // not when we actually ack this LSN because of out of order + // processing might cause another message to actually resolve + // the proper checkpointer to commit. + // + // This waitForCommit business probably needs to happen inside + // the ack stream not here. 
if waitForCommit { defer wg.Done() } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index ef4704de1e..25a879d813 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -81,6 +81,9 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } tableNames := slices.Clone(config.DBTables) + for i, table := range tableNames { + tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) + } stream := &Stream{ pgConn: dbConn, messages: make(chan StreamMessage), @@ -99,10 +102,6 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { decodingPlugin: decodingPluginFromString(config.DecodingPlugin), } - for i, table := range tableNames { - tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) - } - var version int version, err = getPostgresVersion(config.DBRawDSN) if err != nil { diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index f7a23c1c83..c879247521 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -344,6 +344,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName return nil } + // TODO(rockwood): We need to validate the tables don't contain a SQL injection attack tablesSchemaFilter := "FOR TABLE " + strings.Join(tables, ",") if len(tables) == 0 { tablesSchemaFilter = "FOR ALL TABLES" From 2d3b322565b5e443444d72b24237525c019ec29f Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 21:32:41 +0000 Subject: [PATCH 065/118] pgcdc: drop unused param --- internal/impl/postgresql/pglogicalstream/logical_stream.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go 
b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 25a879d813..4bf74cf314 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -144,7 +144,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { pubName := "pglog_stream_" + config.ReplicationSlotName stream.logger.Infof("Creating publication %s for tables: %s", pubName, tableNames) - if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames, true); err != nil { + if err = CreatePublication(ctx, stream.pgConn, pubName, tableNames); err != nil { return nil, err } sysident, err := IdentifySystem(ctx, stream.pgConn) From 743ee339c114990611e589f046946c87abf65b46 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 21:34:52 +0000 Subject: [PATCH 066/118] pgcdc: actually remove unused param --- internal/impl/postgresql/pglogicalstream/pglogrepl.go | 2 +- internal/impl/postgresql/pglogicalstream/pglogrepl_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index c879247521..9d4f1837d3 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -338,7 +338,7 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri } // CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag -func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string, dropIfExist bool) error { +func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string) error { result := conn.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", publicationName)) if _, err := result.ReadAll(); err != nil { return nil diff --git 
a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index 26689674cb..1e54b045d8 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -226,7 +226,7 @@ func TestStartReplication(t *testing.T) { // create publication publicationName := "test_publication" - err = CreatePublication(context.Background(), conn, publicationName, []string{}, true) + err = CreatePublication(context.Background(), conn, publicationName, []string{}) require.NoError(t, err) _, err = CreateReplicationSlot(ctx, conn, slotName, outputPlugin, CreateReplicationSlotOptions{Temporary: false, SnapshotAction: "export"}, 16, nil) From b6789ff86d5ddbf33244561d1e11d003791e7257 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 8 Nov 2024 21:35:55 +0000 Subject: [PATCH 067/118] pgcdc: update docs --- docs/modules/components/pages/inputs/pg_stream.adoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc index 9b8ca14930..22dee435ef 100644 --- a/docs/modules/components/pages/inputs/pg_stream.adoc +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -133,7 +133,7 @@ If set to true, the plugin will stream uncommitted transactions before receiving === `stream_snapshot` -When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. +When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. In order to use this the tables that are being snapshot MUST have a primary key set so that reading from the table can be parallelized. 
*Type*: `bool` From fe18543338d19eaf5db1e3d3451893ecc4973ab5 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 11 Nov 2024 12:48:49 +0100 Subject: [PATCH 068/118] ref(): small code refactoring --- internal/impl/amqp09/integration_test.go | 2 +- .../amqp1/integration_service_bus_test.go | 2 +- internal/impl/amqp1/integration_test.go | 2 +- .../aws/cache_dynamodb_integration_test.go | 2 +- internal/impl/aws/integration_test.go | 2 +- .../aws/output_kinesis_integration_test.go | 3 +- internal/impl/azure/integration_test.go | 4 +- internal/impl/beanstalkd/integration_test.go | 8 +- internal/impl/cassandra/integration_test.go | 2 +- internal/impl/cockroachdb/exploration_test.go | 2 +- internal/impl/cockroachdb/integration_test.go | 2 +- internal/impl/couchbase/cache_test.go | 2 +- internal/impl/couchbase/output_test.go | 3 +- internal/impl/couchbase/processor_test.go | 3 +- internal/impl/cypher/output_test.go | 2 +- .../elasticsearch/aws/integration_test.go | 2 +- .../impl/elasticsearch/integration_test.go | 4 +- .../elasticsearch/writer_integration_test.go | 2 +- internal/impl/gcp/integration_pubsub_test.go | 2 +- internal/impl/gcp/integration_test.go | 2 +- internal/impl/hdfs/integration_test.go | 2 +- .../metrics_influxdb_integration_test.go | 2 +- .../impl/kafka/enterprise/integration_test.go | 4 +- .../impl/kafka/integration_sarama_test.go | 6 +- internal/impl/kafka/integration_test.go | 6 +- .../impl/memcached/cache_integration_test.go | 2 +- internal/impl/mongodb/input_test.go | 2 +- internal/impl/mongodb/integration_test.go | 2 +- internal/impl/mongodb/processor_test.go | 3 +- internal/impl/mqtt/integration_test.go | 2 +- internal/impl/nanomsg/integration_test.go | 2 +- .../impl/nats/integration_jetstream_test.go | 4 +- internal/impl/nats/integration_kv_test.go | 2 +- internal/impl/nats/integration_nats_test.go | 2 +- internal/impl/nats/integration_req_test.go | 3 +- internal/impl/nats/integration_stream_test.go | 2 +- 
internal/impl/ollama/chat_processor_test.go | 3 +- .../impl/ollama/embeddings_processor_test.go | 2 +- internal/impl/opensearch/integration_test.go | 2 +- internal/impl/postgresql/integration_test.go | 11 +- .../pglogicalstream/logical_stream.go | 174 +++++------------- .../pglogicalstream/pluginhandlers.go | 162 ++++++++++++++++ .../pglogicalstream/sanitize/sanitize.go | 20 +- .../pglogicalstream/stream_message.go | 5 +- internal/impl/pulsar/integration_test.go | 2 +- internal/impl/qdrant/integration_test.go | 2 +- internal/impl/questdb/integration_test.go | 2 +- internal/impl/redis/cache_integration_test.go | 6 +- internal/impl/redis/integration_test.go | 2 +- .../impl/redis/processor_integration_test.go | 3 +- .../impl/redis/rate_limit_integration_test.go | 4 +- internal/impl/sftp/integration_test.go | 2 +- internal/impl/splunk/integration_test.go | 2 +- internal/impl/sql/cache_integration_test.go | 2 +- internal/impl/sql/integration_test.go | 21 +-- internal/impl/zeromq/integration_test.go | 2 +- internal/secrets/redis_test.go | 2 +- 57 files changed, 299 insertions(+), 231 deletions(-) create mode 100644 internal/impl/postgresql/pglogicalstream/pluginhandlers.go diff --git a/internal/impl/amqp09/integration_test.go b/internal/impl/amqp09/integration_test.go index 93ef12ef13..5ceddf3a11 100644 --- a/internal/impl/amqp09/integration_test.go +++ b/internal/impl/amqp09/integration_test.go @@ -107,7 +107,7 @@ input: ) } - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/amqp1/integration_service_bus_test.go b/internal/impl/amqp1/integration_service_bus_test.go index 1259d4300a..b1b545f1ca 100644 --- a/internal/impl/amqp1/integration_service_bus_test.go +++ b/internal/impl/amqp1/integration_service_bus_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationAzureServiceBus(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) if testing.Short() { t.Skip("Skipping 
integration test in short mode") diff --git a/internal/impl/amqp1/integration_test.go b/internal/impl/amqp1/integration_test.go index 6cd5729a54..a5ae3420f1 100644 --- a/internal/impl/amqp1/integration_test.go +++ b/internal/impl/amqp1/integration_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationAMQP1(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/aws/cache_dynamodb_integration_test.go b/internal/impl/aws/cache_dynamodb_integration_test.go index 62dd533227..7ef496f7cf 100644 --- a/internal/impl/aws/cache_dynamodb_integration_test.go +++ b/internal/impl/aws/cache_dynamodb_integration_test.go @@ -93,7 +93,7 @@ func createTable(ctx context.Context, t testing.TB, dynamoPort, id string) error } func TestIntegrationDynamoDBCache(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/aws/integration_test.go b/internal/impl/aws/integration_test.go index 14dd2ce57d..e78b2eba68 100644 --- a/internal/impl/aws/integration_test.go +++ b/internal/impl/aws/integration_test.go @@ -73,7 +73,7 @@ func getLocalStack(t testing.TB) (port string) { } func TestIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() servicePort := getLocalStack(t) diff --git a/internal/impl/aws/output_kinesis_integration_test.go b/internal/impl/aws/output_kinesis_integration_test.go index e584deee60..f27f85e8ac 100644 --- a/internal/impl/aws/output_kinesis_integration_test.go +++ b/internal/impl/aws/output_kinesis_integration_test.go @@ -32,12 +32,11 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestKinesisIntegration(t *testing.T) { t.Skip("The docker image we're using here is old and deprecated") - integration.CheckSkip(t) + // 
integration.CheckSkip(t) if testing.Short() { t.Skip("Skipping integration test in short mode") diff --git a/internal/impl/azure/integration_test.go b/internal/impl/azure/integration_test.go index 4448fcb1b6..20501bd7c5 100644 --- a/internal/impl/azure/integration_test.go +++ b/internal/impl/azure/integration_test.go @@ -44,7 +44,7 @@ import ( ) func TestIntegrationAzure(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -277,7 +277,7 @@ input: } func TestIntegrationCosmosDB(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/beanstalkd/integration_test.go b/internal/impl/beanstalkd/integration_test.go index 2964edcd2f..e97238fb8c 100644 --- a/internal/impl/beanstalkd/integration_test.go +++ b/internal/impl/beanstalkd/integration_test.go @@ -37,7 +37,7 @@ input: ` func TestIntegrationBeanstalkdOpenClose(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -65,7 +65,7 @@ func TestIntegrationBeanstalkdOpenClose(t *testing.T) { } func TestIntegrationBeanstalkdSendBatch(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -93,7 +93,7 @@ func TestIntegrationBeanstalkdSendBatch(t *testing.T) { } func TestIntegrationBeanstalkdStreamSequential(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -121,7 +121,7 @@ func TestIntegrationBeanstalkdStreamSequential(t *testing.T) { } func TestIntegrationBeanstalkdStreamParallel(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/cassandra/integration_test.go b/internal/impl/cassandra/integration_test.go index 725f498e03..f284289093 100644 --- 
a/internal/impl/cassandra/integration_test.go +++ b/internal/impl/cassandra/integration_test.go @@ -30,7 +30,7 @@ import ( ) func TestIntegrationCassandra(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() diff --git a/internal/impl/cockroachdb/exploration_test.go b/internal/impl/cockroachdb/exploration_test.go index 437b620b4f..1989d23333 100644 --- a/internal/impl/cockroachdb/exploration_test.go +++ b/internal/impl/cockroachdb/exploration_test.go @@ -35,7 +35,7 @@ import ( ) func TestIntegrationExploration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/cockroachdb/integration_test.go b/internal/impl/cockroachdb/integration_test.go index ad6d1a7f86..043aae159e 100644 --- a/internal/impl/cockroachdb/integration_test.go +++ b/internal/impl/cockroachdb/integration_test.go @@ -33,7 +33,7 @@ import ( ) func TestIntegrationCRDB(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() tmpDir := t.TempDir() diff --git a/internal/impl/couchbase/cache_test.go b/internal/impl/couchbase/cache_test.go index 8a8d6cbeff..710ddbc998 100644 --- a/internal/impl/couchbase/cache_test.go +++ b/internal/impl/couchbase/cache_test.go @@ -27,7 +27,7 @@ import ( ) func TestIntegrationCouchbaseCache(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) servicePort := requireCouchbase(t) diff --git a/internal/impl/couchbase/output_test.go b/internal/impl/couchbase/output_test.go index f63a502afd..bb70a0b337 100644 --- a/internal/impl/couchbase/output_test.go +++ b/internal/impl/couchbase/output_test.go @@ -25,7 +25,6 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/redpanda-data/connect/v4/internal/impl/couchbase" ) @@ -108,7 +107,7 @@ couchbase: } func 
TestIntegrationCouchbaseOutput(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) servicePort := requireCouchbase(t) diff --git a/internal/impl/couchbase/processor_test.go b/internal/impl/couchbase/processor_test.go index 988bd14890..27320f7026 100644 --- a/internal/impl/couchbase/processor_test.go +++ b/internal/impl/couchbase/processor_test.go @@ -25,7 +25,6 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/redpanda-data/connect/v4/internal/impl/couchbase" ) @@ -118,7 +117,7 @@ couchbase: } func TestIntegrationCouchbaseProcessor(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) servicePort := requireCouchbase(t) diff --git a/internal/impl/cypher/output_test.go b/internal/impl/cypher/output_test.go index fd9780ed58..1b9407ae34 100644 --- a/internal/impl/cypher/output_test.go +++ b/internal/impl/cypher/output_test.go @@ -49,7 +49,7 @@ func makeBatch(args ...string) service.MessageBatch { } func TestIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/elasticsearch/aws/integration_test.go b/internal/impl/elasticsearch/aws/integration_test.go index 640fb696e7..c0f0f1a996 100644 --- a/internal/impl/elasticsearch/aws/integration_test.go +++ b/internal/impl/elasticsearch/aws/integration_test.go @@ -55,7 +55,7 @@ var elasticIndex = `{ func TestIntegrationElasticsearchAWS(t *testing.T) { t.Skip("Struggling to get localstack es to work, maybe one day") - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/elasticsearch/integration_test.go b/internal/impl/elasticsearch/integration_test.go index 25a3200a42..9b2aabf541 100644 --- a/internal/impl/elasticsearch/integration_test.go +++ 
b/internal/impl/elasticsearch/integration_test.go @@ -50,7 +50,7 @@ var elasticIndex = `{ }` func TestIntegrationElasticsearchV8(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -133,7 +133,7 @@ output: } func TestIntegrationElasticsearchV7(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/elasticsearch/writer_integration_test.go b/internal/impl/elasticsearch/writer_integration_test.go index 21a0e3ab01..37bbb2306a 100644 --- a/internal/impl/elasticsearch/writer_integration_test.go +++ b/internal/impl/elasticsearch/writer_integration_test.go @@ -47,7 +47,7 @@ func outputFromConf(t testing.TB, confStr string, args ...any) *elasticsearch.Ou } func TestIntegrationWriter(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/gcp/integration_pubsub_test.go b/internal/impl/gcp/integration_pubsub_test.go index 2404d3a792..87fc9777e9 100644 --- a/internal/impl/gcp/integration_pubsub_test.go +++ b/internal/impl/gcp/integration_pubsub_test.go @@ -32,7 +32,7 @@ import ( ) func TestIntegrationGCPPubSub(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) pool, err := dockertest.NewPool("") require.NoError(t, err) diff --git a/internal/impl/gcp/integration_test.go b/internal/impl/gcp/integration_test.go index f5c73a3f3c..e496ef2930 100644 --- a/internal/impl/gcp/integration_test.go +++ b/internal/impl/gcp/integration_test.go @@ -45,7 +45,7 @@ func createGCPCloudStorageBucket(var1, id string) error { } func TestIntegrationGCP(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/hdfs/integration_test.go b/internal/impl/hdfs/integration_test.go index 0516bcb1bf..f8177bc905 100644 --- 
a/internal/impl/hdfs/integration_test.go +++ b/internal/impl/hdfs/integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationHDFS(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) // t.Skip() // Skip until we fix the static port bindings t.Parallel() diff --git a/internal/impl/influxdb/metrics_influxdb_integration_test.go b/internal/impl/influxdb/metrics_influxdb_integration_test.go index 8b4046489f..7fb650aa03 100644 --- a/internal/impl/influxdb/metrics_influxdb_integration_test.go +++ b/internal/impl/influxdb/metrics_influxdb_integration_test.go @@ -34,7 +34,7 @@ func TestInfluxIntegration(t *testing.T) { t.Skip("skipping test on macos") } - integration.CheckSkip(t) + // integration.CheckSkip(t) if testing.Short() { t.Skip("Skipping integration test in short mode") diff --git a/internal/impl/kafka/enterprise/integration_test.go b/internal/impl/kafka/enterprise/integration_test.go index 48549e8d76..659725b2b2 100644 --- a/internal/impl/kafka/enterprise/integration_test.go +++ b/internal/impl/kafka/enterprise/integration_test.go @@ -92,7 +92,7 @@ func readNKafkaMessages(ctx context.Context, t testing.TB, address, topic string } func TestIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -288,7 +288,7 @@ max_message_bytes: 1MB } func TestSchemaRegistryIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/kafka/integration_sarama_test.go b/internal/impl/kafka/integration_sarama_test.go index c58d8a4342..a4952df2ca 100644 --- a/internal/impl/kafka/integration_sarama_test.go +++ b/internal/impl/kafka/integration_sarama_test.go @@ -200,7 +200,7 @@ kafka: } func TestIntegrationSaramaRedpanda(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -452,7 +452,7 @@ input: } func 
TestIntegrationSaramaOld(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) if runtime.GOOS == "darwin" { t.Skip("skipping test on macos") } @@ -668,7 +668,7 @@ input: } func TestIntegrationSaramaOutputFixedTimestamp(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/kafka/integration_test.go b/internal/impl/kafka/integration_test.go index 6ee08fd1e2..a360a3cca5 100644 --- a/internal/impl/kafka/integration_test.go +++ b/internal/impl/kafka/integration_test.go @@ -62,7 +62,7 @@ func createKafkaTopic(ctx context.Context, address, id string, partitions int32) } func TestIntegrationKafka(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -269,7 +269,7 @@ func createKafkaTopicSasl(address, id string, partitions int32) error { } func TestIntegrationKafkaSasl(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -380,7 +380,7 @@ input: } func TestIntegrationKafkaOutputFixedTimestamp(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/memcached/cache_integration_test.go b/internal/impl/memcached/cache_integration_test.go index 53fa137a72..6dccb58198 100644 --- a/internal/impl/memcached/cache_integration_test.go +++ b/internal/impl/memcached/cache_integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationMemcachedCache(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/mongodb/input_test.go b/internal/impl/mongodb/input_test.go index 2612ee431e..97e627352a 100644 --- a/internal/impl/mongodb/input_test.go +++ b/internal/impl/mongodb/input_test.go @@ -57,7 +57,7 @@ query: | } func TestInputIntegration(t 
*testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/mongodb/integration_test.go b/internal/impl/mongodb/integration_test.go index d501386b40..9987f83016 100644 --- a/internal/impl/mongodb/integration_test.go +++ b/internal/impl/mongodb/integration_test.go @@ -37,7 +37,7 @@ func generateCollectionName(testID string) string { } func TestIntegrationMongoDB(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/mongodb/processor_test.go b/internal/impl/mongodb/processor_test.go index 00f41f4c07..1331f7e5dc 100644 --- a/internal/impl/mongodb/processor_test.go +++ b/internal/impl/mongodb/processor_test.go @@ -29,13 +29,12 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/redpanda-data/connect/v4/internal/impl/mongodb" ) func TestProcessorIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/mqtt/integration_test.go b/internal/impl/mqtt/integration_test.go index f5ddcb7945..7a661780b4 100644 --- a/internal/impl/mqtt/integration_test.go +++ b/internal/impl/mqtt/integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationMQTT(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nanomsg/integration_test.go b/internal/impl/nanomsg/integration_test.go index e1d1fb49bc..23f06ec3f8 100644 --- a/internal/impl/nanomsg/integration_test.go +++ b/internal/impl/nanomsg/integration_test.go @@ -22,7 +22,7 @@ import ( ) func TestIntegrationNanomsg(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() template := ` diff 
--git a/internal/impl/nats/integration_jetstream_test.go b/internal/impl/nats/integration_jetstream_test.go index ec31a2590c..9b7348a542 100644 --- a/internal/impl/nats/integration_jetstream_test.go +++ b/internal/impl/nats/integration_jetstream_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationNatsJetstream(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -99,7 +99,7 @@ input: } func TestIntegrationNatsPullConsumer(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_kv_test.go b/internal/impl/nats/integration_kv_test.go index 5732be3457..72384c9c8e 100644 --- a/internal/impl/nats/integration_kv_test.go +++ b/internal/impl/nats/integration_kv_test.go @@ -33,7 +33,7 @@ import ( ) func TestIntegrationNatsKV(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_nats_test.go b/internal/impl/nats/integration_nats_test.go index 1023637208..032aa2ee69 100644 --- a/internal/impl/nats/integration_nats_test.go +++ b/internal/impl/nats/integration_nats_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationNats(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_req_test.go b/internal/impl/nats/integration_req_test.go index ea1664321a..7dcf4fb1fe 100644 --- a/internal/impl/nats/integration_req_test.go +++ b/internal/impl/nats/integration_req_test.go @@ -26,11 +26,10 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestIntegrationNatsReq(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := 
dockertest.NewPool("") diff --git a/internal/impl/nats/integration_stream_test.go b/internal/impl/nats/integration_stream_test.go index 9b66ff3fa3..92ed441101 100644 --- a/internal/impl/nats/integration_stream_test.go +++ b/internal/impl/nats/integration_stream_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationNatsStream(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/ollama/chat_processor_test.go b/internal/impl/ollama/chat_processor_test.go index b564f8e209..eb1b7d8f3d 100644 --- a/internal/impl/ollama/chat_processor_test.go +++ b/internal/impl/ollama/chat_processor_test.go @@ -17,7 +17,6 @@ import ( "github.com/ollama/ollama/api" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go/modules/ollama" @@ -39,7 +38,7 @@ func createCompletionProcessorForTest(t *testing.T, addr string) *ollamaCompleti } func TestOllamaCompletionIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) ctx := context.Background() ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.2.5") diff --git a/internal/impl/ollama/embeddings_processor_test.go b/internal/impl/ollama/embeddings_processor_test.go index c0bbb14753..44eb6c7f78 100644 --- a/internal/impl/ollama/embeddings_processor_test.go +++ b/internal/impl/ollama/embeddings_processor_test.go @@ -37,7 +37,7 @@ func createEmbeddingsProcessorForTest(t *testing.T, addr string) *ollamaEmbeddin } func TestOllamaEmbeddingsIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) ctx := context.Background() ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.2.5") diff --git a/internal/impl/opensearch/integration_test.go b/internal/impl/opensearch/integration_test.go index 
c09fc91312..348c0b7cef 100644 --- a/internal/impl/opensearch/integration_test.go +++ b/internal/impl/opensearch/integration_test.go @@ -51,7 +51,7 @@ func outputFromConf(t testing.TB, confStr string, args ...any) *opensearch.Outpu } func TestIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 283a281c26..21267cdc3c 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -22,7 +22,6 @@ import ( _ "github.com/redpanda-data/benthos/v4/public/components/io" _ "github.com/redpanda-data/benthos/v4/public/components/pure" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -117,7 +116,7 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version } func TestIntegrationPgCDC(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") @@ -307,7 +306,7 @@ file: } func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -518,7 +517,7 @@ file: } func TestIntegrationPgCDCForPgOutputStreamUncommittedPlugin(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -655,7 +654,7 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) // running tests in the look to test different PostgreSQL versions t.Parallel() for _, v := range []string{"17", "16", "15", "14", 
"13", "12", "11", "10"} { @@ -794,7 +793,7 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 4bf74cf314..34ecf57608 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -20,7 +20,6 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" - "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" "golang.org/x/sync/errgroup" @@ -157,7 +156,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { var outputPlugin string // check is replication slot exist to get last restart SLN - s, err := sanitize.SanitizeSQL("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = $1", config.ReplicationSlotName) + s, err := sanitize.SQLQuery("SELECT confirmed_flush_lsn, plugin FROM pg_replication_slots WHERE slot_name = $1", config.ReplicationSlotName) if err != nil { return nil, err } @@ -286,9 +285,20 @@ func (s *Stream) AckLSN(lsn string) error { } func (s *Stream) streamMessagesAsync() { - relations := map[uint32]*RelationMessage{} - typeMap := pgtype.NewMap() - pgoutputChanges := []StreamMessageChanges{} + var handler PluginHandler + switch s.decodingPlugin { + case "wal2json": + handler = NewWal2JsonPluginHandler(s.messages, s.monitor) + case "pgoutput": + handler = NewPgOutputPluginHandler(s.messages, s.streamUncommitted, s.monitor, s.consumedCallback, s.transactionAckChan) + default: + s.logger.Error("Invalid decoding plugin. 
Cant find needed handler implementation") + if err := s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + + return + } for { select { @@ -366,6 +376,8 @@ func (s *Stream) streamMessagesAsync() { s.nextStandbyMessageDeadline = time.Time{} } + // XLogDataByteID is the message type for the actual WAL data + // It will cause the stream to process WAL changes and create the corresponding messages case XLogDataByteID: xld, err := ParseXLogData(msg.Data[1:]) if err != nil { @@ -375,10 +387,8 @@ func (s *Stream) streamMessagesAsync() { } } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) - metrics := s.monitor.Report() if s.decodingPlugin == "wal2json" { - message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) - if err != nil { + if err = handler.Handle(clientXLogPos, xld); err != nil { s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) @@ -386,137 +396,37 @@ func (s *Stream) streamMessagesAsync() { return } - if message == nil || len(message.Changes) == 0 { - // automatic ack for empty changes - // basically mean that the client is up-to-date, - // but we still need to acknowledge the LSN for standby - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + // automatic ack for empty changes + // basically mean that the client is up-to-date, + // but we still need to acknowledge the LSN for standby + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } - } else { - message.WALLagBytes = &metrics.WalLagInBytes - s.messages <- *message + return } } if s.decodingPlugin == 
"pgoutput" { - if s.streamUncommitted { - // parse changes inside the transaction - message, err := decodePgOutput(xld.WALData, relations, typeMap) - if err != nil { - s.logger.Errorf("decodePgOutput failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - isCommit, _, err := isCommitMessage(xld.WALData) - if err != nil { - s.logger.Errorf("Failed to parse WAL data: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - // when receiving a commit message, we need to acknowledge the LSN - // but we must wait for benthos to flush the messages before we can do that - if isCommit { - s.transactionAckChan <- clientXLogPos.String() - <-s.consumedCallback - } else { - if message == nil && !isCommit { - // 0 changes happened in the transaction - // or we received a change that are not supported/needed by the replication stream - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - } else if message != nil { - lsn := clientXLogPos.String() - s.messages <- StreamMessage{ - Lsn: &lsn, - Changes: []StreamMessageChanges{ - *message, - }, - Mode: StreamModeStreaming, - WALLagBytes: &metrics.WalLagInBytes, - } - } - } - } else { - // message changes must be collected in the buffer in the context of the same transaction - // as single transaction can contain multiple changes - // and LSN ack will cause potential loss of changes - isBegin, err := isBeginMessage(xld.WALData) - if err != nil { - s.logger.Errorf("Failed to parse WAL data: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - if isBegin { - pgoutputChanges = []StreamMessageChanges{} - } - - // parse changes inside the transaction - 
message, err := decodePgOutput(xld.WALData, relations, typeMap) - if err != nil { - s.logger.Errorf("decodePgOutput failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - if message != nil { - pgoutputChanges = append(pgoutputChanges, *message) - } - - isCommit, _, err := isCommitMessage(xld.WALData) - if err != nil { - s.logger.Errorf("Failed to parse WAL data: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + if err = handler.Handle(clientXLogPos, xld); err != nil { + s.logger.Errorf("decodePgOutputChanges failed: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } + } - if isCommit { - if len(pgoutputChanges) == 0 { - // 0 changes happened in the transaction - // or we received a change that are not supported/needed by the replication stream - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - } else { - // send all collected changes - lsn := clientXLogPos.String() - s.messages <- StreamMessage{ - Lsn: &lsn, - Changes: pgoutputChanges, - Mode: StreamModeStreaming, - WALLagBytes: &metrics.WalLagInBytes, - } - } + // automatic ack for empty changes + // basically mean that the client is up-to-date, + // but we still need to acknowledge the LSN for standby + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } + return } } } @@ -713,7 +623,7 @@ func (s *Stream) cleanUpOnFailure(ctx context.Context) error { } func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { - q, err := sanitize.SanitizeSQL(` + q, err 
:= sanitize.SQLQuery(` SELECT a.attname FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go new file mode 100644 index 0000000000..b0902bc818 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -0,0 +1,162 @@ +// Copyright 2024 Redpanda Data, Inc. +// +// Licensed as a Redpanda Enterprise file under the Redpanda Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md + +package pglogicalstream + +import "github.com/jackc/pgx/v5/pgtype" + +// PluginHandler is an interface that must be implemented by all plugin handlers +type PluginHandler interface { + Handle(clientXLogPos LSN, xld XLogData) error +} + +// Wal2JsonPluginHandler is a handler for wal2json output plugin +type Wal2JsonPluginHandler struct { + messages chan StreamMessage + monitor *Monitor +} + +// NewWal2JsonPluginHandler creates a new Wal2JsonPluginHandler +func NewWal2JsonPluginHandler(messages chan StreamMessage, monitor *Monitor) *Wal2JsonPluginHandler { + return &Wal2JsonPluginHandler{ + messages: messages, + monitor: monitor, + } +} + +// Handle handles the wal2json output +func (w *Wal2JsonPluginHandler) Handle(clientXLogPos LSN, xld XLogData) error { + // get current stream metrics + metrics := w.monitor.Report() + message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) + if err != nil { + return err + } + + if message != nil && len(message.Changes) > 0 { + message.WALLagBytes = &metrics.WalLagInBytes + w.messages <- *message + } + + return nil +} + +// PgOutputPluginHandler is a handler for pgoutput output plugin +type PgOutputPluginHandler struct { + messages chan StreamMessage + monitor *Monitor + + streamUncommitted bool + relations 
map[uint32]*RelationMessage + typeMap *pgtype.Map + pgoutputChanges []StreamMessageChanges + + consumedCallback chan bool + transactionAckChan chan string +} + +// NewPgOutputPluginHandler creates a new PgOutputPluginHandler +func NewPgOutputPluginHandler( + messages chan StreamMessage, + streamUncommitted bool, + monitor *Monitor, + consumedCallback chan bool, + transactionAckChan chan string, +) *PgOutputPluginHandler { + return &PgOutputPluginHandler{ + messages: messages, + monitor: monitor, + streamUncommitted: streamUncommitted, + relations: map[uint32]*RelationMessage{}, + typeMap: pgtype.NewMap(), + pgoutputChanges: []StreamMessageChanges{}, + consumedCallback: consumedCallback, + transactionAckChan: transactionAckChan, + } +} + +// Handle handles the pgoutput output +func (p *PgOutputPluginHandler) Handle(clientXLogPos LSN, xld XLogData) error { + if p.streamUncommitted { + // parse changes inside the transaction + message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) + if err != nil { + return err + } + + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + return err + } + + // when receiving a commit message, we need to acknowledge the LSN + // but we must wait for benthos to flush the messages before we can do that + if isCommit { + p.transactionAckChan <- clientXLogPos.String() + <-p.consumedCallback + } else { + if message == nil && !isCommit { + return nil + } else if message != nil { + lsn := clientXLogPos.String() + p.messages <- StreamMessage{ + Lsn: &lsn, + Changes: []StreamMessageChanges{ + *message, + }, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, + } + } + } + } else { + // message changes must be collected in the buffer in the context of the same transaction + // as single transaction can contain multiple changes + // and LSN ack will cause potential loss of changes + isBegin, err := isBeginMessage(xld.WALData) + if err != nil { + return err + } + + if isBegin { + 
p.pgoutputChanges = []StreamMessageChanges{} + } + + // parse changes inside the transaction + message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) + if err != nil { + return err + } + + if message != nil { + p.pgoutputChanges = append(p.pgoutputChanges, *message) + } + + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + return err + } + + if isCommit { + if len(p.pgoutputChanges) == 0 { + return nil + } else { + // send all collected changes + lsn := clientXLogPos.String() + p.messages <- StreamMessage{ + Lsn: &lsn, + Changes: p.pgoutputChanges, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, + } + } + } + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go index 02098e4def..95f35bd8c0 100644 --- a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go @@ -28,6 +28,7 @@ package sanitize import ( "bytes" "encoding/hex" + "errors" "fmt" "strconv" "strings" @@ -39,6 +40,7 @@ import ( // argument placeholder. 
type Part any +// Query represents a SQL query that consists of []Part type Query struct { Parts []Part } @@ -49,6 +51,7 @@ type Query struct { // https://github.com/jackc/pgx/issues/1380 const replacementcharacterwidth = 3 +// Sanitize sanitizes a SQL query func (q *Query) Sanitize(args ...any) (string, error) { argUse := make([]bool, len(args)) buf := &bytes.Buffer{} @@ -62,11 +65,11 @@ func (q *Query) Sanitize(args ...any) (string, error) { argIdx := part - 1 if argIdx < 0 { - return "", fmt.Errorf("first sql argument must be > 0") + return "", errors.New("first sql argument must be > 0") } if argIdx >= len(args) { - return "", fmt.Errorf("insufficient arguments") + return "", errors.New("insufficient arguments") } arg := args[argIdx] switch arg := arg.(type) { @@ -79,9 +82,9 @@ func (q *Query) Sanitize(args ...any) (string, error) { case bool: str = strconv.FormatBool(arg) case []byte: - str = QuoteBytes(arg) + str = quoteBytes(arg) case string: - str = QuoteString(arg) + str = quoteString(arg) case time.Time: str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") default: @@ -106,6 +109,7 @@ func (q *Query) Sanitize(args ...any) (string, error) { return buf.String(), nil } +// NewQuery parses a SQL query string and returns a Query object. func NewQuery(sql string) (*Query, error) { l := &sqlLexer{ src: sql, @@ -121,11 +125,11 @@ func NewQuery(sql string) (*Query, error) { return query, nil } -func QuoteString(str string) string { +func quoteString(str string) string { return "'" + strings.ReplaceAll(str, "'", "''") + "'" } -func QuoteBytes(buf []byte) string { +func quoteBytes(buf []byte) string { return `'\x` + hex.EncodeToString(buf) + "'" } @@ -344,10 +348,10 @@ func multilineCommentState(l *sqlLexer) stateFn { } } -// SanitizeSQL replaces placeholder values with args. It quotes and escapes args +// SQLQuery replaces placeholder values with args. It quotes and escapes args // as necessary. 
This function is only safe when standard_conforming_strings is // on. -func SanitizeSQL(sql string, args ...any) (string, error) { +func SQLQuery(sql string, args ...any) (string, error) { query, err := NewQuery(sql) if err != nil { return "", err diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index 7c6b8531c0..6d0dbdf087 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -25,11 +25,14 @@ type StreamMessageMetrics struct { IsStreaming bool `json:"is_streaming"` } +// StreamMode represents the mode of the stream at the time of the message type StreamMode string const ( + // StreamModeStreaming indicates that the stream is in streaming mode StreamModeStreaming StreamMode = "streaming" - StreamModeSnapshot StreamMode = "snapshot" + // StreamModeSnapshot indicates that the stream is in snapshot mode + StreamModeSnapshot StreamMode = "snapshot" ) // StreamMessage represents a single message after it has been decoded by the plugin diff --git a/internal/impl/pulsar/integration_test.go b/internal/impl/pulsar/integration_test.go index 221c26cefe..b835d9a48c 100644 --- a/internal/impl/pulsar/integration_test.go +++ b/internal/impl/pulsar/integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationPulsar(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/qdrant/integration_test.go b/internal/impl/qdrant/integration_test.go index 7f4e396c02..c5700dab2d 100644 --- a/internal/impl/qdrant/integration_test.go +++ b/internal/impl/qdrant/integration_test.go @@ -45,7 +45,7 @@ output: ) func TestIntegrationQdrant(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() diff --git a/internal/impl/questdb/integration_test.go b/internal/impl/questdb/integration_test.go index 
e204cf9863..06e8329c71 100644 --- a/internal/impl/questdb/integration_test.go +++ b/internal/impl/questdb/integration_test.go @@ -35,7 +35,7 @@ import ( func TestIntegrationQuestDB(t *testing.T) { ctx := context.Background() - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/redis/cache_integration_test.go b/internal/impl/redis/cache_integration_test.go index 1fbd4c3b36..ae1f4e7442 100644 --- a/internal/impl/redis/cache_integration_test.go +++ b/internal/impl/redis/cache_integration_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationRedisCache(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -85,7 +85,7 @@ cache_resources: func TestIntegrationRedisClusterCache(t *testing.T) { t.Skip("Skipping as networking often fails for this test") - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -178,7 +178,7 @@ cache_resources: func TestIntegrationRedisFailoverCache(t *testing.T) { t.Skip("Skipping as networking often fails for this test") - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/redis/integration_test.go b/internal/impl/redis/integration_test.go index 5ab6991d14..044215af54 100644 --- a/internal/impl/redis/integration_test.go +++ b/internal/impl/redis/integration_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationRedis(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/redis/processor_integration_test.go b/internal/impl/redis/processor_integration_test.go index 25cc1a3594..16566e3daa 100644 --- a/internal/impl/redis/processor_integration_test.go +++ b/internal/impl/redis/processor_integration_test.go @@ -28,11 +28,10 @@ import ( 
"github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestIntegrationRedisProcessor(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/redis/rate_limit_integration_test.go b/internal/impl/redis/rate_limit_integration_test.go index 19dee16685..5174019443 100644 --- a/internal/impl/redis/rate_limit_integration_test.go +++ b/internal/impl/redis/rate_limit_integration_test.go @@ -26,12 +26,10 @@ import ( "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestIntegrationRedisRateLimit(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/sftp/integration_test.go b/internal/impl/sftp/integration_test.go index 0e236f3c38..c5166eb36e 100644 --- a/internal/impl/sftp/integration_test.go +++ b/internal/impl/sftp/integration_test.go @@ -36,7 +36,7 @@ var ( ) func TestIntegration(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/splunk/integration_test.go b/internal/impl/splunk/integration_test.go index 85ec377a2f..415a8ef6aa 100644 --- a/internal/impl/splunk/integration_test.go +++ b/internal/impl/splunk/integration_test.go @@ -24,7 +24,7 @@ import ( ) func TestIntegrationSplunk(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/sql/cache_integration_test.go b/internal/impl/sql/cache_integration_test.go index ac76c2cc1a..8d62ad806a 100644 --- a/internal/impl/sql/cache_integration_test.go +++ b/internal/impl/sql/cache_integration_test.go @@ -29,7 +29,7 
@@ import ( ) func TestIntegrationCache(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/sql/integration_test.go b/internal/impl/sql/integration_test.go index 1a4b9103fb..8b7ea78539 100644 --- a/internal/impl/sql/integration_test.go +++ b/internal/impl/sql/integration_test.go @@ -29,7 +29,6 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" - "github.com/redpanda-data/benthos/v4/public/service/integration" isql "github.com/redpanda-data/connect/v4/internal/impl/sql" @@ -564,7 +563,7 @@ func testSuite(t *testing.T, driver, dsn string, createTableFn func(string) (str } func TestIntegrationClickhouse(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -619,7 +618,7 @@ func TestIntegrationClickhouse(t *testing.T) { } func TestIntegrationOldClickhouse(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -674,7 +673,7 @@ func TestIntegrationOldClickhouse(t *testing.T) { } func TestIntegrationPostgres(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -735,7 +734,7 @@ func TestIntegrationPostgres(t *testing.T) { } func TestIntegrationPostgresVector(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -846,7 +845,7 @@ suffix: ORDER BY embedding <-> '[3,1,2]' LIMIT 1 } func TestIntegrationMySQL(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -910,7 +909,7 @@ func TestIntegrationMySQL(t *testing.T) { } func TestIntegrationMSSQL(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -971,7 +970,7 
@@ func TestIntegrationMSSQL(t *testing.T) { } func TestIntegrationSQLite(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() var db *sql.DB @@ -1014,7 +1013,7 @@ func TestIntegrationSQLite(t *testing.T) { } func TestIntegrationOracle(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -1076,7 +1075,7 @@ func TestIntegrationOracle(t *testing.T) { } func TestIntegrationTrino(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -1137,7 +1136,7 @@ create table %s ( } func TestIntegrationCosmosDB(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/zeromq/integration_test.go b/internal/impl/zeromq/integration_test.go index 54802da11b..add75809b2 100644 --- a/internal/impl/zeromq/integration_test.go +++ b/internal/impl/zeromq/integration_test.go @@ -25,7 +25,7 @@ import ( ) func TestIntegrationZMQ(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() template := ` diff --git a/internal/secrets/redis_test.go b/internal/secrets/redis_test.go index 0d2e28b839..8495fbafe5 100644 --- a/internal/secrets/redis_test.go +++ b/internal/secrets/redis_test.go @@ -26,7 +26,7 @@ import ( ) func TestIntegrationRedis(t *testing.T) { - integration.CheckSkip(t) + // integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") From 3622b1705070198364dea86e678b091fa834140d Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 11 Nov 2024 12:58:30 +0100 Subject: [PATCH 069/118] feat(): added max_parallel_snapshot_tables config field --- internal/impl/postgresql/input_pg_stream.go | 176 +++++++++--------- .../impl/postgresql/pglogicalstream/config.go | 5 +- .../pglogicalstream/logical_stream.go | 6 +- 3 files changed, 96 insertions(+), 91 deletions(-) diff --git 
a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index c515175098..80744cf6c3 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -26,20 +26,21 @@ import ( ) const ( - fieldDSN = "dsn" - fieldStreamUncommitted = "stream_uncommitted" - fieldStreamSnapshot = "stream_snapshot" - fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" - fieldSnapshotBatchSize = "snapshot_batch_size" - fieldDecodingPlugin = "decoding_plugin" - fieldSchema = "schema" - fieldTables = "tables" - fieldCheckpointLimit = "checkpoint_limit" - fieldTemporarySlot = "temporary_slot" - fieldPgStandbyTimeout = "pg_standby_timeout_sec" - fieldWalMonitorIntervalSec = "pg_wal_monitor_interval_sec" - fieldSlotName = "slot_name" - fieldBatching = "batching" + fieldDSN = "dsn" + fieldStreamUncommitted = "stream_uncommitted" + fieldStreamSnapshot = "stream_snapshot" + fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" + fieldSnapshotBatchSize = "snapshot_batch_size" + fieldDecodingPlugin = "decoding_plugin" + fieldSchema = "schema" + fieldTables = "tables" + fieldCheckpointLimit = "checkpoint_limit" + fieldTemporarySlot = "temporary_slot" + fieldPgStandbyTimeout = "pg_standby_timeout_sec" + fieldWalMonitorIntervalSec = "pg_wal_monitor_interval_sec" + fieldSlotName = "slot_name" + fieldBatching = "batching" + fieldMaxParallelSnapshotTables = "max_parallel_snapshot_tables" ) type asyncMessage struct { @@ -113,34 +114,36 @@ This input adds the following metadata fields to each message: Description("Int field stat specifies ticker interval for WAL monitoring. Used to fetch replication slot lag"). Example(3). Default(3)). + Field(service.NewIntField(fieldMaxParallelSnapshotTables). + Description("Int specifies a number of tables that will be processed in parallel during the snapshot processing stage"). + Default(1)). Field(service.NewAutoRetryNacksToggleField()). 
Field(service.NewBatchPolicyField(fieldBatching)) func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s service.BatchInput, err error) { var ( - dsn string - dbSlotName string - temporarySlot bool - schema string - tables []string - streamSnapshot bool - snapshotMemSafetyFactor float64 - decodingPlugin string - streamUncommitted bool - snapshotBatchSize int - checkpointLimit int - walMonitorIntervalSec int - pgStandbyTimeoutSec int - batching service.BatchPolicy + dsn string + dbSlotName string + temporarySlot bool + schema string + tables []string + streamSnapshot bool + snapshotMemSafetyFactor float64 + decodingPlugin string + streamUncommitted bool + snapshotBatchSize int + checkpointLimit int + walMonitorIntervalSec int + maxParallelSnapshotTables int + pgStandbyTimeoutSec int + batching service.BatchPolicy ) - dsn, err = conf.FieldString(fieldDSN) - if err != nil { + if dsn, err = conf.FieldString(fieldDSN); err != nil { return nil, err } - dbSlotName, err = conf.FieldString(fieldSlotName) - if err != nil { + if dbSlotName, err = conf.FieldString(fieldSlotName); err != nil { return nil, err } // Set the default to be a random string @@ -152,18 +155,15 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, fmt.Errorf("invalid slot_name: %w", err) } - temporarySlot, err = conf.FieldBool(fieldTemporarySlot) - if err != nil { + if temporarySlot, err = conf.FieldBool(fieldTemporarySlot); err != nil { return nil, err } - schema, err = conf.FieldString(fieldSchema) - if err != nil { + if schema, err = conf.FieldString(fieldSchema); err != nil { return nil, err } - tables, err = conf.FieldStringList(fieldTables) - if err != nil { + if tables, err = conf.FieldStringList(fieldTables); err != nil { return nil, err } @@ -171,28 +171,23 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - streamSnapshot, err = conf.FieldBool(fieldStreamSnapshot) - if err != 
nil { + if streamSnapshot, err = conf.FieldBool(fieldStreamSnapshot); err != nil { return nil, err } - streamUncommitted, err = conf.FieldBool(fieldStreamUncommitted) - if err != nil { + if streamUncommitted, err = conf.FieldBool(fieldStreamUncommitted); err != nil { return nil, err } - decodingPlugin, err = conf.FieldString(fieldDecodingPlugin) - if err != nil { + if decodingPlugin, err = conf.FieldString(fieldDecodingPlugin); err != nil { return nil, err } - snapshotMemSafetyFactor, err = conf.FieldFloat(fieldSnapshotMemSafetyFactor) - if err != nil { + if snapshotMemSafetyFactor, err = conf.FieldFloat(fieldSnapshotMemSafetyFactor); err != nil { return nil, err } - snapshotBatchSize, err = conf.FieldInt(fieldSnapshotBatchSize) - if err != nil { + if snapshotBatchSize, err = conf.FieldInt(fieldSnapshotBatchSize); err != nil { return nil, err } @@ -202,13 +197,15 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser batching.Count = 1 } - pgStandbyTimeoutSec, err = conf.FieldInt(fieldPgStandbyTimeout) - if err != nil { + if pgStandbyTimeoutSec, err = conf.FieldInt(fieldPgStandbyTimeout); err != nil { return nil, err } - walMonitorIntervalSec, err = conf.FieldInt(fieldWalMonitorIntervalSec) - if err != nil { + if walMonitorIntervalSec, err = conf.FieldInt(fieldWalMonitorIntervalSec); err != nil { + return nil, err + } + + if maxParallelSnapshotTables, err = conf.FieldInt(fieldMaxParallelSnapshotTables); err != nil { return nil, err } @@ -230,22 +227,23 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser dbConfig: pgConnConfig, // dbRawDSN is used for creating golang PG Connection // as using pgconn.Config for golang doesn't support multiple queries in the prepared statement for Postgres Version <= 14 - dbRawDSN: dsn, - streamSnapshot: streamSnapshot, - snapshotMemSafetyFactor: snapshotMemSafetyFactor, - slotName: dbSlotName, - schema: schema, - tables: tables, - decodingPlugin: decodingPlugin, - 
streamUncommitted: streamUncommitted, - temporarySlot: temporarySlot, - snapshotBatchSize: snapshotBatchSize, - batching: batching, - checkpointLimit: checkpointLimit, - pgStandbyTimeoutSec: pgStandbyTimeoutSec, - walMonitorIntervalSec: walMonitorIntervalSec, - cMut: sync.Mutex{}, - msgChan: make(chan asyncMessage), + dbRawDSN: dsn, + streamSnapshot: streamSnapshot, + snapshotMemSafetyFactor: snapshotMemSafetyFactor, + slotName: dbSlotName, + schema: schema, + tables: tables, + decodingPlugin: decodingPlugin, + streamUncommitted: streamUncommitted, + temporarySlot: temporarySlot, + snapshotBatchSize: snapshotBatchSize, + batching: batching, + checkpointLimit: checkpointLimit, + pgStandbyTimeoutSec: pgStandbyTimeoutSec, + walMonitorIntervalSec: walMonitorIntervalSec, + maxParallelSnapshotTables: maxParallelSnapshotTables, + cMut: sync.Mutex{}, + msgChan: make(chan asyncMessage), mgr: mgr, logger: mgr.Logger(), @@ -293,26 +291,27 @@ func init() { } type pgStreamInput struct { - dbConfig *pgconn.Config - dbRawDSN string - pgLogicalStream *pglogicalstream.Stream - slotName string - pgStandbyTimeoutSec int - walMonitorIntervalSec int - temporarySlot bool - schema string - tables []string - decodingPlugin string - streamSnapshot bool - snapshotMemSafetyFactor float64 - snapshotBatchSize int - streamUncommitted bool - logger *service.Logger - mgr *service.Resources - cMut sync.Mutex - msgChan chan asyncMessage - batching service.BatchPolicy - checkpointLimit int + dbConfig *pgconn.Config + dbRawDSN string + pgLogicalStream *pglogicalstream.Stream + slotName string + pgStandbyTimeoutSec int + walMonitorIntervalSec int + temporarySlot bool + schema string + tables []string + decodingPlugin string + streamSnapshot bool + snapshotMemSafetyFactor float64 + snapshotBatchSize int + streamUncommitted bool + maxParallelSnapshotTables int + logger *service.Logger + mgr *service.Resources + cMut sync.Mutex + msgChan chan asyncMessage + batching service.BatchPolicy + checkpointLimit 
int snapshotMetrics *service.MetricGauge replicationLag *service.MetricGauge @@ -337,6 +336,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, PgStandbyTimeoutSec: p.pgStandbyTimeoutSec, WalMonitorIntervalSec: p.walMonitorIntervalSec, + MaxParallelSnapshotTables: p.maxParallelSnapshotTables, Logger: p.logger, }) if err != nil { diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index cdcc3b0d1d..7af9b3f0fa 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -41,6 +41,7 @@ type Config struct { Logger *service.Logger - PgStandbyTimeoutSec int - WalMonitorIntervalSec int + PgStandbyTimeoutSec int + WalMonitorIntervalSec int + MaxParallelSnapshotTables int } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 34ecf57608..c4c675ede7 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -55,6 +55,7 @@ type Stream struct { snapshotter *Snapshotter transactionAckChan chan string transactionBeginChan chan bool + maxParallelSnapshotTables int lsnAckBuffer []string @@ -96,6 +97,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { transactionAckChan: make(chan string), transactionBeginChan: make(chan bool), lsnAckBuffer: []string{}, + maxParallelSnapshotTables: config.MaxParallelSnapshotTables, logger: config.Logger, m: sync.Mutex{}, decodingPlugin: decodingPluginFromString(config.DecodingPlugin), @@ -457,15 +459,17 @@ func (s *Stream) processSnapshot(ctx context.Context) error { }() s.logger.Infof("Starting snapshot processing") - + sem := make(chan struct{}, s.maxParallelSnapshotTables) var wg errgroup.Group for _, table := range s.tableQualifiedName { tableName := table + sem 
<- struct{}{} wg.Go(func() (err error) { s.logger.Infof("Processing snapshot for table: %v", table) defer func() { + defer func() { <-sem }() if err != nil { if cleanupErr := s.cleanUpOnFailure(ctx); cleanupErr != nil { s.logger.Errorf("Failed to clean up resources on accident: %v", cleanupErr.Error()) From 9f521ba05ee3adcd08926f2b015edfdba98baa40 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 11 Nov 2024 13:27:59 +0100 Subject: [PATCH 070/118] chore(): added pk ordering to consume snapshot --- .../impl/postgresql/pglogicalstream/logical_stream.go | 7 ++++++- internal/impl/postgresql/pglogicalstream/snapshotter.go | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index c4c675ede7..4e2f19dc74 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -503,10 +503,12 @@ func (s *Stream) processSnapshot(ctx context.Context) error { return err } + var lastPkVal interface{} + for { var snapshotRows *sql.Rows queryStart := time.Now() - if snapshotRows, err = s.snapshotter.querySnapshotData(table, tablePk, batchSize, offset); err != nil { + if snapshotRows, err = s.snapshotter.querySnapshotData(table, lastPkVal, tablePk, batchSize); err != nil { s.logger.Errorf("Failed to query snapshot data for table %v: %v", table, err.Error()) s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) return err @@ -556,6 +558,9 @@ func (s *Stream) processSnapshot(ctx context.Context) error { var data = make(map[string]any) for i, getter := range valueGetters { data[columnNames[i]] = getter(scanArgs[i]) + if columnNames[i] == tablePk { + lastPkVal = getter(scanArgs[i]) + } } snapshotChangePacket := StreamMessage{ diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go 
b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 2df221fcc5..875dead7af 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -165,9 +165,12 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz return batchSize } -func (s *Snapshotter) querySnapshotData(table string, pk string, limit, offset int) (rows *sql.Rows, err error) { - s.logger.Infof("Query snapshot table: %v, limit: %v, offset: %v, pk: %v", table, limit, offset, pk) - return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d OFFSET %d;", table, pk, limit, offset)) +func (s *Snapshotter) querySnapshotData(table string, lastSeenPk interface{}, pk string, limit int) (rows *sql.Rows, err error) { + s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk) + if lastSeenPk == nil { + return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT $1;", table, pk), limit) + } + return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s WHERE %s > $1 ORDER BY %s LIMIT $2;", table, pk, pk), lastSeenPk, limit) } func (s *Snapshotter) releaseSnapshot() error { From 7235fd7c9c44e8e9f9b34720e831a8251dba13d1 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 11 Nov 2024 20:55:12 +0100 Subject: [PATCH 071/118] fix(): enabled integration tests --- internal/impl/amqp09/integration_test.go | 2 +- .../amqp1/integration_service_bus_test.go | 2 +- internal/impl/amqp1/integration_test.go | 2 +- .../aws/cache_dynamodb_integration_test.go | 2 +- .../aws/output_kinesis_integration_test.go | 2 +- internal/impl/azure/integration_test.go | 4 ++-- internal/impl/beanstalkd/integration_test.go | 8 ++++---- internal/impl/cassandra/integration_test.go | 2 +- internal/impl/cockroachdb/exploration_test.go | 2 +- internal/impl/cockroachdb/integration_test.go | 2 +- internal/impl/couchbase/cache_test.go | 2 +- 
internal/impl/couchbase/output_test.go | 2 +- internal/impl/couchbase/processor_test.go | 2 +- .../elasticsearch/aws/integration_test.go | 2 +- .../impl/elasticsearch/integration_test.go | 4 ++-- .../elasticsearch/writer_integration_test.go | 2 +- internal/impl/gcp/integration_pubsub_test.go | 2 +- internal/impl/gcp/integration_test.go | 2 +- .../metrics_influxdb_integration_test.go | 2 +- .../impl/memcached/cache_integration_test.go | 2 +- internal/impl/mongodb/input_test.go | 2 +- internal/impl/mongodb/integration_test.go | 2 +- internal/impl/mongodb/processor_test.go | 2 +- internal/impl/mqtt/integration_test.go | 2 +- internal/impl/nanomsg/integration_test.go | 2 +- .../impl/nats/integration_jetstream_test.go | 4 ++-- internal/impl/nats/integration_kv_test.go | 2 +- internal/impl/nats/integration_nats_test.go | 2 +- internal/impl/nats/integration_req_test.go | 2 +- internal/impl/nats/integration_stream_test.go | 2 +- internal/impl/ollama/chat_processor_test.go | 2 +- .../impl/ollama/embeddings_processor_test.go | 2 +- internal/impl/postgresql/integration_test.go | 10 +++++----- internal/impl/pulsar/integration_test.go | 2 +- internal/impl/qdrant/integration_test.go | 2 +- internal/impl/questdb/integration_test.go | 2 +- internal/impl/redis/cache_integration_test.go | 6 +++--- internal/impl/redis/integration_test.go | 2 +- .../impl/redis/processor_integration_test.go | 2 +- .../impl/redis/rate_limit_integration_test.go | 2 +- internal/impl/splunk/integration_test.go | 2 +- internal/impl/sql/cache_integration_test.go | 2 +- internal/impl/sql/integration_test.go | 20 +++++++++---------- internal/impl/zeromq/integration_test.go | 2 +- internal/secrets/redis_test.go | 2 +- 45 files changed, 66 insertions(+), 66 deletions(-) diff --git a/internal/impl/amqp09/integration_test.go b/internal/impl/amqp09/integration_test.go index 5ceddf3a11..93ef12ef13 100644 --- a/internal/impl/amqp09/integration_test.go +++ b/internal/impl/amqp09/integration_test.go @@ -107,7 +107,7 @@ 
input: ) } - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/amqp1/integration_service_bus_test.go b/internal/impl/amqp1/integration_service_bus_test.go index b1b545f1ca..1259d4300a 100644 --- a/internal/impl/amqp1/integration_service_bus_test.go +++ b/internal/impl/amqp1/integration_service_bus_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationAzureServiceBus(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) if testing.Short() { t.Skip("Skipping integration test in short mode") diff --git a/internal/impl/amqp1/integration_test.go b/internal/impl/amqp1/integration_test.go index a5ae3420f1..6cd5729a54 100644 --- a/internal/impl/amqp1/integration_test.go +++ b/internal/impl/amqp1/integration_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationAMQP1(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/aws/cache_dynamodb_integration_test.go b/internal/impl/aws/cache_dynamodb_integration_test.go index 7ef496f7cf..62dd533227 100644 --- a/internal/impl/aws/cache_dynamodb_integration_test.go +++ b/internal/impl/aws/cache_dynamodb_integration_test.go @@ -93,7 +93,7 @@ func createTable(ctx context.Context, t testing.TB, dynamoPort, id string) error } func TestIntegrationDynamoDBCache(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/aws/output_kinesis_integration_test.go b/internal/impl/aws/output_kinesis_integration_test.go index f27f85e8ac..5e7830f33d 100644 --- a/internal/impl/aws/output_kinesis_integration_test.go +++ b/internal/impl/aws/output_kinesis_integration_test.go @@ -36,7 +36,7 @@ import ( func TestKinesisIntegration(t *testing.T) { t.Skip("The docker image we're using here is old and deprecated") - // integration.CheckSkip(t) + integration.CheckSkip(t) 
if testing.Short() { t.Skip("Skipping integration test in short mode") diff --git a/internal/impl/azure/integration_test.go b/internal/impl/azure/integration_test.go index 20501bd7c5..4448fcb1b6 100644 --- a/internal/impl/azure/integration_test.go +++ b/internal/impl/azure/integration_test.go @@ -44,7 +44,7 @@ import ( ) func TestIntegrationAzure(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -277,7 +277,7 @@ input: } func TestIntegrationCosmosDB(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/beanstalkd/integration_test.go b/internal/impl/beanstalkd/integration_test.go index e97238fb8c..2964edcd2f 100644 --- a/internal/impl/beanstalkd/integration_test.go +++ b/internal/impl/beanstalkd/integration_test.go @@ -37,7 +37,7 @@ input: ` func TestIntegrationBeanstalkdOpenClose(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -65,7 +65,7 @@ func TestIntegrationBeanstalkdOpenClose(t *testing.T) { } func TestIntegrationBeanstalkdSendBatch(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -93,7 +93,7 @@ func TestIntegrationBeanstalkdSendBatch(t *testing.T) { } func TestIntegrationBeanstalkdStreamSequential(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -121,7 +121,7 @@ func TestIntegrationBeanstalkdStreamSequential(t *testing.T) { } func TestIntegrationBeanstalkdStreamParallel(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/cassandra/integration_test.go b/internal/impl/cassandra/integration_test.go index f284289093..725f498e03 100644 --- 
a/internal/impl/cassandra/integration_test.go +++ b/internal/impl/cassandra/integration_test.go @@ -30,7 +30,7 @@ import ( ) func TestIntegrationCassandra(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() diff --git a/internal/impl/cockroachdb/exploration_test.go b/internal/impl/cockroachdb/exploration_test.go index 2e59650938..f7dfb63767 100644 --- a/internal/impl/cockroachdb/exploration_test.go +++ b/internal/impl/cockroachdb/exploration_test.go @@ -35,7 +35,7 @@ import ( ) func TestIntegrationExploration(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/cockroachdb/integration_test.go b/internal/impl/cockroachdb/integration_test.go index b6196c204e..c79cb73b6b 100644 --- a/internal/impl/cockroachdb/integration_test.go +++ b/internal/impl/cockroachdb/integration_test.go @@ -33,7 +33,7 @@ import ( ) func TestIntegrationCRDB(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() tmpDir := t.TempDir() diff --git a/internal/impl/couchbase/cache_test.go b/internal/impl/couchbase/cache_test.go index 710ddbc998..8a8d6cbeff 100644 --- a/internal/impl/couchbase/cache_test.go +++ b/internal/impl/couchbase/cache_test.go @@ -27,7 +27,7 @@ import ( ) func TestIntegrationCouchbaseCache(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) servicePort := requireCouchbase(t) diff --git a/internal/impl/couchbase/output_test.go b/internal/impl/couchbase/output_test.go index bb70a0b337..5912d36171 100644 --- a/internal/impl/couchbase/output_test.go +++ b/internal/impl/couchbase/output_test.go @@ -107,7 +107,7 @@ couchbase: } func TestIntegrationCouchbaseOutput(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) servicePort := requireCouchbase(t) diff --git a/internal/impl/couchbase/processor_test.go b/internal/impl/couchbase/processor_test.go index 27320f7026..30d5701bde 100644 
--- a/internal/impl/couchbase/processor_test.go +++ b/internal/impl/couchbase/processor_test.go @@ -117,7 +117,7 @@ couchbase: } func TestIntegrationCouchbaseProcessor(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) servicePort := requireCouchbase(t) diff --git a/internal/impl/elasticsearch/aws/integration_test.go b/internal/impl/elasticsearch/aws/integration_test.go index a72981628d..211addea83 100644 --- a/internal/impl/elasticsearch/aws/integration_test.go +++ b/internal/impl/elasticsearch/aws/integration_test.go @@ -56,7 +56,7 @@ func TestIntegrationElasticsearchAWS(t *testing.T) { // TODO: Fix this test after migrating to the new Elasticsearch client libs. t.Skip("Struggling to get localstack es to work, maybe one day") - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/elasticsearch/integration_test.go b/internal/impl/elasticsearch/integration_test.go index 9b2aabf541..25a3200a42 100644 --- a/internal/impl/elasticsearch/integration_test.go +++ b/internal/impl/elasticsearch/integration_test.go @@ -50,7 +50,7 @@ var elasticIndex = `{ }` func TestIntegrationElasticsearchV8(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -133,7 +133,7 @@ output: } func TestIntegrationElasticsearchV7(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/elasticsearch/writer_integration_test.go b/internal/impl/elasticsearch/writer_integration_test.go index 37bbb2306a..21a0e3ab01 100644 --- a/internal/impl/elasticsearch/writer_integration_test.go +++ b/internal/impl/elasticsearch/writer_integration_test.go @@ -47,7 +47,7 @@ func outputFromConf(t testing.TB, confStr string, args ...any) *elasticsearch.Ou } func TestIntegrationWriter(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) 
t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/gcp/integration_pubsub_test.go b/internal/impl/gcp/integration_pubsub_test.go index 87fc9777e9..2404d3a792 100644 --- a/internal/impl/gcp/integration_pubsub_test.go +++ b/internal/impl/gcp/integration_pubsub_test.go @@ -32,7 +32,7 @@ import ( ) func TestIntegrationGCPPubSub(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) pool, err := dockertest.NewPool("") require.NoError(t, err) diff --git a/internal/impl/gcp/integration_test.go b/internal/impl/gcp/integration_test.go index e496ef2930..f5c73a3f3c 100644 --- a/internal/impl/gcp/integration_test.go +++ b/internal/impl/gcp/integration_test.go @@ -45,7 +45,7 @@ func createGCPCloudStorageBucket(var1, id string) error { } func TestIntegrationGCP(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/influxdb/metrics_influxdb_integration_test.go b/internal/impl/influxdb/metrics_influxdb_integration_test.go index 7fb650aa03..8b4046489f 100644 --- a/internal/impl/influxdb/metrics_influxdb_integration_test.go +++ b/internal/impl/influxdb/metrics_influxdb_integration_test.go @@ -34,7 +34,7 @@ func TestInfluxIntegration(t *testing.T) { t.Skip("skipping test on macos") } - // integration.CheckSkip(t) + integration.CheckSkip(t) if testing.Short() { t.Skip("Skipping integration test in short mode") diff --git a/internal/impl/memcached/cache_integration_test.go b/internal/impl/memcached/cache_integration_test.go index 6dccb58198..53fa137a72 100644 --- a/internal/impl/memcached/cache_integration_test.go +++ b/internal/impl/memcached/cache_integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationMemcachedCache(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/mongodb/input_test.go b/internal/impl/mongodb/input_test.go index 
b7f2735699..37c9e7095f 100644 --- a/internal/impl/mongodb/input_test.go +++ b/internal/impl/mongodb/input_test.go @@ -57,7 +57,7 @@ query: | } func TestInputIntegration(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/mongodb/integration_test.go b/internal/impl/mongodb/integration_test.go index a03b65e9cc..361150ffbc 100644 --- a/internal/impl/mongodb/integration_test.go +++ b/internal/impl/mongodb/integration_test.go @@ -37,7 +37,7 @@ func generateCollectionName(testID string) string { } func TestIntegrationMongoDB(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/mongodb/processor_test.go b/internal/impl/mongodb/processor_test.go index fed8292cdb..7d28eef7c1 100644 --- a/internal/impl/mongodb/processor_test.go +++ b/internal/impl/mongodb/processor_test.go @@ -34,7 +34,7 @@ import ( ) func TestProcessorIntegration(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/mqtt/integration_test.go b/internal/impl/mqtt/integration_test.go index 7a661780b4..f5ddcb7945 100644 --- a/internal/impl/mqtt/integration_test.go +++ b/internal/impl/mqtt/integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationMQTT(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nanomsg/integration_test.go b/internal/impl/nanomsg/integration_test.go index 23f06ec3f8..e1d1fb49bc 100644 --- a/internal/impl/nanomsg/integration_test.go +++ b/internal/impl/nanomsg/integration_test.go @@ -22,7 +22,7 @@ import ( ) func TestIntegrationNanomsg(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() template := ` diff --git a/internal/impl/nats/integration_jetstream_test.go 
b/internal/impl/nats/integration_jetstream_test.go index 9b7348a542..ec31a2590c 100644 --- a/internal/impl/nats/integration_jetstream_test.go +++ b/internal/impl/nats/integration_jetstream_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationNatsJetstream(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -99,7 +99,7 @@ input: } func TestIntegrationNatsPullConsumer(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_kv_test.go b/internal/impl/nats/integration_kv_test.go index 72384c9c8e..5732be3457 100644 --- a/internal/impl/nats/integration_kv_test.go +++ b/internal/impl/nats/integration_kv_test.go @@ -33,7 +33,7 @@ import ( ) func TestIntegrationNatsKV(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_nats_test.go b/internal/impl/nats/integration_nats_test.go index 032aa2ee69..1023637208 100644 --- a/internal/impl/nats/integration_nats_test.go +++ b/internal/impl/nats/integration_nats_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationNats(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_req_test.go b/internal/impl/nats/integration_req_test.go index 7dcf4fb1fe..bdad897f70 100644 --- a/internal/impl/nats/integration_req_test.go +++ b/internal/impl/nats/integration_req_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationNatsReq(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/nats/integration_stream_test.go b/internal/impl/nats/integration_stream_test.go index 92ed441101..9b66ff3fa3 100644 --- 
a/internal/impl/nats/integration_stream_test.go +++ b/internal/impl/nats/integration_stream_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationNatsStream(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/ollama/chat_processor_test.go b/internal/impl/ollama/chat_processor_test.go index eb1b7d8f3d..fbb38066bc 100644 --- a/internal/impl/ollama/chat_processor_test.go +++ b/internal/impl/ollama/chat_processor_test.go @@ -38,7 +38,7 @@ func createCompletionProcessorForTest(t *testing.T, addr string) *ollamaCompleti } func TestOllamaCompletionIntegration(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) ctx := context.Background() ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.2.5") diff --git a/internal/impl/ollama/embeddings_processor_test.go b/internal/impl/ollama/embeddings_processor_test.go index 44eb6c7f78..c0bbb14753 100644 --- a/internal/impl/ollama/embeddings_processor_test.go +++ b/internal/impl/ollama/embeddings_processor_test.go @@ -37,7 +37,7 @@ func createEmbeddingsProcessorForTest(t *testing.T, addr string) *ollamaEmbeddin } func TestOllamaEmbeddingsIntegration(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) ctx := context.Background() ollamaContainer, err := ollama.Run(ctx, "ollama/ollama:0.2.5") diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 21267cdc3c..08c933e610 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -116,7 +116,7 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version } func TestIntegrationPgCDC(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") @@ -306,7 +306,7 @@ file: } func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { - // 
integration.CheckSkip(t) + integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -517,7 +517,7 @@ file: } func TestIntegrationPgCDCForPgOutputStreamUncommittedPlugin(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) tmpDir := t.TempDir() pool, err := dockertest.NewPool("") require.NoError(t, err) @@ -654,7 +654,7 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamUncomitedPlugin(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) // running tests in the look to test different PostgreSQL versions t.Parallel() for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { @@ -793,7 +793,7 @@ file: } func TestIntegrationPgMultiVersionsCDCForPgOutputStreamComittedPlugin(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) for _, v := range []string{"17", "16", "15", "14", "13", "12", "11", "10"} { tmpDir := t.TempDir() pool, err := dockertest.NewPool("") diff --git a/internal/impl/pulsar/integration_test.go b/internal/impl/pulsar/integration_test.go index b835d9a48c..221c26cefe 100644 --- a/internal/impl/pulsar/integration_test.go +++ b/internal/impl/pulsar/integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationPulsar(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/qdrant/integration_test.go b/internal/impl/qdrant/integration_test.go index c5700dab2d..7f4e396c02 100644 --- a/internal/impl/qdrant/integration_test.go +++ b/internal/impl/qdrant/integration_test.go @@ -45,7 +45,7 @@ output: ) func TestIntegrationQdrant(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() diff --git a/internal/impl/questdb/integration_test.go b/internal/impl/questdb/integration_test.go index 06e8329c71..e204cf9863 100644 --- a/internal/impl/questdb/integration_test.go +++ 
b/internal/impl/questdb/integration_test.go @@ -35,7 +35,7 @@ import ( func TestIntegrationQuestDB(t *testing.T) { ctx := context.Background() - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/redis/cache_integration_test.go b/internal/impl/redis/cache_integration_test.go index ae1f4e7442..1fbd4c3b36 100644 --- a/internal/impl/redis/cache_integration_test.go +++ b/internal/impl/redis/cache_integration_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationRedisCache(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -85,7 +85,7 @@ cache_resources: func TestIntegrationRedisClusterCache(t *testing.T) { t.Skip("Skipping as networking often fails for this test") - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -178,7 +178,7 @@ cache_resources: func TestIntegrationRedisFailoverCache(t *testing.T) { t.Skip("Skipping as networking often fails for this test") - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/redis/integration_test.go b/internal/impl/redis/integration_test.go index 044215af54..5ab6991d14 100644 --- a/internal/impl/redis/integration_test.go +++ b/internal/impl/redis/integration_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationRedis(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/redis/processor_integration_test.go b/internal/impl/redis/processor_integration_test.go index 16566e3daa..c4ffde7c48 100644 --- a/internal/impl/redis/processor_integration_test.go +++ b/internal/impl/redis/processor_integration_test.go @@ -31,7 +31,7 @@ import ( ) func TestIntegrationRedisProcessor(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) pool, err 
:= dockertest.NewPool("") if err != nil { diff --git a/internal/impl/redis/rate_limit_integration_test.go b/internal/impl/redis/rate_limit_integration_test.go index 5174019443..729b8ed338 100644 --- a/internal/impl/redis/rate_limit_integration_test.go +++ b/internal/impl/redis/rate_limit_integration_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationRedisRateLimit(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) pool, err := dockertest.NewPool("") if err != nil { diff --git a/internal/impl/splunk/integration_test.go b/internal/impl/splunk/integration_test.go index 415a8ef6aa..85ec377a2f 100644 --- a/internal/impl/splunk/integration_test.go +++ b/internal/impl/splunk/integration_test.go @@ -24,7 +24,7 @@ import ( ) func TestIntegrationSplunk(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/sql/cache_integration_test.go b/internal/impl/sql/cache_integration_test.go index 8d62ad806a..ac76c2cc1a 100644 --- a/internal/impl/sql/cache_integration_test.go +++ b/internal/impl/sql/cache_integration_test.go @@ -29,7 +29,7 @@ import ( ) func TestIntegrationCache(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/sql/integration_test.go b/internal/impl/sql/integration_test.go index 8b7ea78539..6da722fcf6 100644 --- a/internal/impl/sql/integration_test.go +++ b/internal/impl/sql/integration_test.go @@ -563,7 +563,7 @@ func testSuite(t *testing.T, driver, dsn string, createTableFn func(string) (str } func TestIntegrationClickhouse(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -618,7 +618,7 @@ func TestIntegrationClickhouse(t *testing.T) { } func TestIntegrationOldClickhouse(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := 
dockertest.NewPool("") @@ -673,7 +673,7 @@ func TestIntegrationOldClickhouse(t *testing.T) { } func TestIntegrationPostgres(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -734,7 +734,7 @@ func TestIntegrationPostgres(t *testing.T) { } func TestIntegrationPostgresVector(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -845,7 +845,7 @@ suffix: ORDER BY embedding <-> '[3,1,2]' LIMIT 1 } func TestIntegrationMySQL(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -909,7 +909,7 @@ func TestIntegrationMySQL(t *testing.T) { } func TestIntegrationMSSQL(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -970,7 +970,7 @@ func TestIntegrationMSSQL(t *testing.T) { } func TestIntegrationSQLite(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() var db *sql.DB @@ -1013,7 +1013,7 @@ func TestIntegrationSQLite(t *testing.T) { } func TestIntegrationOracle(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -1075,7 +1075,7 @@ func TestIntegrationOracle(t *testing.T) { } func TestIntegrationTrino(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") @@ -1136,7 +1136,7 @@ create table %s ( } func TestIntegrationCosmosDB(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") diff --git a/internal/impl/zeromq/integration_test.go b/internal/impl/zeromq/integration_test.go index add75809b2..54802da11b 100644 --- a/internal/impl/zeromq/integration_test.go +++ b/internal/impl/zeromq/integration_test.go @@ -25,7 +25,7 @@ import ( ) func TestIntegrationZMQ(t 
*testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() template := ` diff --git a/internal/secrets/redis_test.go b/internal/secrets/redis_test.go index 8495fbafe5..0d2e28b839 100644 --- a/internal/secrets/redis_test.go +++ b/internal/secrets/redis_test.go @@ -26,7 +26,7 @@ import ( ) func TestIntegrationRedis(t *testing.T) { - // integration.CheckSkip(t) + integration.CheckSkip(t) t.Parallel() pool, err := dockertest.NewPool("") From 99dbe63d545dc2b6c46466104c9e7f2a64ff86b9 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 11 Nov 2024 20:58:08 +0100 Subject: [PATCH 072/118] chore(): small fixes && pr notes --- internal/impl/postgresql/integration_test.go | 1 + internal/impl/postgresql/pglogicalstream/logical_stream.go | 2 +- internal/impl/postgresql/pglogicalstream/snapshotter.go | 2 +- internal/impl/postgresql/utils.go | 4 ++-- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 08c933e610..dc750fac87 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -12,6 +12,7 @@ import ( "context" "database/sql" "fmt" + "github.com/redpanda-data/benthos/v4/public/service/integration" "strings" "sync" "testing" diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 4e2f19dc74..1d91cf1326 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -503,7 +503,7 @@ func (s *Stream) processSnapshot(ctx context.Context) error { return err } - var lastPkVal interface{} + var lastPkVal any for { var snapshotRows *sql.Rows diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 875dead7af..a5f8931c3d 100644 --- 
a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -165,7 +165,7 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz return batchSize } -func (s *Snapshotter) querySnapshotData(table string, lastSeenPk interface{}, pk string, limit int) (rows *sql.Rows, err error) { +func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, limit int) (rows *sql.Rows, err error) { s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk) if lastSeenPk == nil { return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT $1;", table, pk), limit) diff --git a/internal/impl/postgresql/utils.go b/internal/impl/postgresql/utils.go index e787849a0d..d01bf441cd 100644 --- a/internal/impl/postgresql/utils.go +++ b/internal/impl/postgresql/utils.go @@ -23,12 +23,12 @@ func LSNToInt64(lsn string) (int64, error) { } // Parse both segments as hex with uint64 first - upper, err := strconv.ParseUint(parts[0], 16, 64) + upper, err := strconv.ParseUint(parts[0], 16, 32) if err != nil { return 0, fmt.Errorf("failed to parse upper part: %w", err) } - lower, err := strconv.ParseUint(parts[1], 16, 64) + lower, err := strconv.ParseUint(parts[1], 16, 32) if err != nil { return 0, fmt.Errorf("failed to parse lower part: %w", err) } From 61ea84b3fd5c74bce7e2d5c9dbd19a2e50f3f74d Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Mon, 11 Nov 2024 22:07:37 +0100 Subject: [PATCH 073/118] chore(): updated docs && fixed lint --- docs/modules/components/pages/inputs/pg_stream.adoc | 11 +++++++++++ internal/impl/postgresql/integration_test.go | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc index 22dee435ef..d11a472808 100644 --- a/docs/modules/components/pages/inputs/pg_stream.adoc +++ 
b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -52,6 +52,7 @@ input: slot_name: "" pg_standby_timeout_sec: 10 pg_wal_monitor_interval_sec: 3 + max_parallel_snapshot_tables: 1 auto_replay_nacks: true batching: count: 0 @@ -83,6 +84,7 @@ input: slot_name: "" pg_standby_timeout_sec: 10 pg_wal_monitor_interval_sec: 3 + max_parallel_snapshot_tables: 1 auto_replay_nacks: true batching: count: 0 @@ -291,6 +293,15 @@ Int field stat specifies ticker interval for WAL monitoring. Used to fetch repli pg_wal_monitor_interval_sec: 3 ``` +=== `max_parallel_snapshot_tables` + +Int specifies a number of tables that will be processed in parallel during the snapshot processing stage + + +*Type*: `int` + +*Default*: `1` + === `auto_replay_nacks` Whether messages that are rejected (nacked) at the output level should be automatically replayed indefinitely, eventually resulting in back pressure if the cause of the rejections is persistent. If set to `false` these messages will instead be deleted. Disabling auto replays can greatly improve memory efficiency of high throughput streams as the original shape of the data can be discarded immediately upon consumption and mutation. 
diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index dc750fac87..283a281c26 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -12,7 +12,6 @@ import ( "context" "database/sql" "fmt" - "github.com/redpanda-data/benthos/v4/public/service/integration" "strings" "sync" "testing" @@ -23,6 +22,7 @@ import ( _ "github.com/redpanda-data/benthos/v4/public/components/io" _ "github.com/redpanda-data/benthos/v4/public/components/pure" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" From 141832fb8ce423f7366d3364c72d7f676d8f35d4 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 12 Nov 2024 09:52:11 +0100 Subject: [PATCH 074/118] chore(): revert integration tests --- internal/impl/aws/output_kinesis_integration_test.go | 1 + internal/impl/couchbase/output_test.go | 1 + internal/impl/couchbase/processor_test.go | 1 + internal/impl/mongodb/processor_test.go | 1 + internal/impl/nats/integration_req_test.go | 1 + internal/impl/ollama/chat_processor_test.go | 1 + internal/impl/redis/processor_integration_test.go | 1 + internal/impl/redis/rate_limit_integration_test.go | 1 + internal/impl/sql/integration_test.go | 1 + 9 files changed, 9 insertions(+) diff --git a/internal/impl/aws/output_kinesis_integration_test.go b/internal/impl/aws/output_kinesis_integration_test.go index 5e7830f33d..e584deee60 100644 --- a/internal/impl/aws/output_kinesis_integration_test.go +++ b/internal/impl/aws/output_kinesis_integration_test.go @@ -32,6 +32,7 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestKinesisIntegration(t *testing.T) { diff --git a/internal/impl/couchbase/output_test.go 
b/internal/impl/couchbase/output_test.go index 5912d36171..f63a502afd 100644 --- a/internal/impl/couchbase/output_test.go +++ b/internal/impl/couchbase/output_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/redpanda-data/connect/v4/internal/impl/couchbase" ) diff --git a/internal/impl/couchbase/processor_test.go b/internal/impl/couchbase/processor_test.go index 30d5701bde..988bd14890 100644 --- a/internal/impl/couchbase/processor_test.go +++ b/internal/impl/couchbase/processor_test.go @@ -25,6 +25,7 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/redpanda-data/connect/v4/internal/impl/couchbase" ) diff --git a/internal/impl/mongodb/processor_test.go b/internal/impl/mongodb/processor_test.go index 7d28eef7c1..f586db8761 100644 --- a/internal/impl/mongodb/processor_test.go +++ b/internal/impl/mongodb/processor_test.go @@ -29,6 +29,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/redpanda-data/connect/v4/internal/impl/mongodb" ) diff --git a/internal/impl/nats/integration_req_test.go b/internal/impl/nats/integration_req_test.go index bdad897f70..ea1664321a 100644 --- a/internal/impl/nats/integration_req_test.go +++ b/internal/impl/nats/integration_req_test.go @@ -26,6 +26,7 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestIntegrationNatsReq(t *testing.T) { diff --git a/internal/impl/ollama/chat_processor_test.go b/internal/impl/ollama/chat_processor_test.go index fbb38066bc..b564f8e209 100644 --- 
a/internal/impl/ollama/chat_processor_test.go +++ b/internal/impl/ollama/chat_processor_test.go @@ -17,6 +17,7 @@ import ( "github.com/ollama/ollama/api" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/testcontainers/testcontainers-go/modules/ollama" diff --git a/internal/impl/redis/processor_integration_test.go b/internal/impl/redis/processor_integration_test.go index c4ffde7c48..25cc1a3594 100644 --- a/internal/impl/redis/processor_integration_test.go +++ b/internal/impl/redis/processor_integration_test.go @@ -28,6 +28,7 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" ) func TestIntegrationRedisProcessor(t *testing.T) { diff --git a/internal/impl/redis/rate_limit_integration_test.go b/internal/impl/redis/rate_limit_integration_test.go index 729b8ed338..92098b85ff 100644 --- a/internal/impl/redis/rate_limit_integration_test.go +++ b/internal/impl/redis/rate_limit_integration_test.go @@ -24,6 +24,7 @@ import ( "github.com/ory/dockertest/v3" "github.com/redis/go-redis/v9" + "github.com/redpanda-data/benthos/v4/public/service/integration" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/impl/sql/integration_test.go b/internal/impl/sql/integration_test.go index 6da722fcf6..1a4b9103fb 100644 --- a/internal/impl/sql/integration_test.go +++ b/internal/impl/sql/integration_test.go @@ -29,6 +29,7 @@ import ( "github.com/stretchr/testify/require" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/benthos/v4/public/service/integration" isql "github.com/redpanda-data/connect/v4/internal/impl/sql" From 51a940eeb19c84ab6dadff5cd8b4475f429cafdd Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Tue, 12 Nov 2024 11:21:48 
+0100 Subject: [PATCH 075/118] chore(): added publication updates instead of re-creation --- .../postgresql/pglogicalstream/pglogrepl.go | 117 ++++++++++++++++-- .../pglogicalstream/pglogrepl_test.go | 82 ++++++++++++ 2 files changed, 190 insertions(+), 9 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 9d4f1837d3..0324fe81b3 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -22,6 +22,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "slices" "strconv" "strings" "time" @@ -29,6 +30,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) const ( @@ -339,23 +341,120 @@ func DropReplicationSlot(ctx context.Context, conn *pgconn.PgConn, slotName stri // CreatePublication creates a new PostgreSQL publication with the given name for a list of tables and drop if exists flag func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName string, tables []string) error { - result := conn.Exec(ctx, fmt.Sprintf("DROP PUBLICATION IF EXISTS %s;", publicationName)) - if _, err := result.ReadAll(); err != nil { + // Check if publication exists + pubQuery, err := sanitize.SQLQuery(` + SELECT pubname, puballtables + FROM pg_publication + WHERE pubname = $1; + `, publicationName) + if err != nil { + return fmt.Errorf("failed to sanitize publication query: %w", err) + } + + result := conn.Exec(ctx, pubQuery) + + rows, err := result.ReadAll() + if err != nil { + return fmt.Errorf("failed to check publication existence: %w", err) + } + + tablesClause := "FOR ALL TABLES" + if len(tables) > 0 { + // TODO: Implement proper SQL injection protection, potentially using parameterized queries + // or a SQL query builder that handles proper escaping + 
tablesClause = "FOR TABLE " + strings.Join(tables, ",") + } + + if len(rows) == 0 || len(rows[0].Rows) == 0 { + // Publication doesn't exist, create new one + result = conn.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) + if _, err := result.ReadAll(); err != nil { + return fmt.Errorf("failed to create publication: %w", err) + } + return nil } - // TODO(rockwood): We need to validate the tables don't contain a SQL injection attack - tablesSchemaFilter := "FOR TABLE " + strings.Join(tables, ",") - if len(tables) == 0 { - tablesSchemaFilter = "FOR ALL TABLES" + // assuming publication already exists + // get a list of tables in the publication + pubTables, forAllTables, err := GetPublicationTables(ctx, conn, publicationName) + if err != nil { + return fmt.Errorf("failed to get publication tables: %w", err) } - result = conn.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesSchemaFilter)) - if _, err := result.ReadAll(); err != nil { - return err + + // list of tables to publish is empty and publication is for all tables + // no update is needed + if forAllTables && len(pubTables) == 0 { + return nil + } + + var tablesToRemoveFromPublication = []string{} + var tablesToAddToPublication = []string{} + for _, table := range tables { + if !slices.Contains[[]string, string](pubTables, table) { + tablesToAddToPublication = append(tablesToAddToPublication, table) + } + } + + for _, table := range pubTables { + if !slices.Contains[[]string, string](tables, table) { + tablesToRemoveFromPublication = append(tablesToRemoveFromPublication, table) + } } + + // remove tables from publication + for _, dropTable := range tablesToRemoveFromPublication { + result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) + if _, err := result.ReadAll(); err != nil { + return fmt.Errorf("failed to remove table from publication: %w", err) + } + } + + // add tables to publication + for _, 
addTable := range tablesToAddToPublication { + result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) + if _, err := result.ReadAll(); err != nil { + return fmt.Errorf("failed to add table to publication: %w", err) + } + } + return nil } +// GetPublicationTables returns a list of tables currently in the publication +// Arguments, in order: list of the tables, exist for all tables, errror +func GetPublicationTables(ctx context.Context, conn *pgconn.PgConn, publicationName string) ([]string, bool, error) { + query, err := sanitize.SQLQuery(` + SELECT DISTINCT + tablename as table_name + FROM pg_publication_tables + WHERE pubname = $1 + ORDER BY table_name; + `, publicationName) + if err != nil { + return nil, false, fmt.Errorf("failed to get publication tables: %w", err) + } + + // Get specific tables in the publication + result := conn.Exec(ctx, query) + + rows, err := result.ReadAll() + if err != nil { + return nil, false, fmt.Errorf("failed to get publication tables: %w", err) + } + + if len(rows) == 0 || len(rows[0].Rows) == 0 { + return nil, true, nil // Publication exists and is for all tables + } + + tables := make([]string, 0, len(rows)) + for _, row := range rows[0].Rows { + tables = append(tables, string(row[0])) + } + + return tables, false, nil +} + // StartReplicationOptions are the options for the START_REPLICATION command. // The Timeline field is optional and defaults to 0, which means the current server timeline. // The Mode field is required and must be either PhysicalReplication or LogicalReplication. 
## PhysicalReplication is not supporter by this plugin, but still can be implemented diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go index 1e54b045d8..8a50b34bc1 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl_test.go @@ -207,6 +207,88 @@ func TestDropReplicationSlot(t *testing.T) { require.NoError(t, err) } +func TestCreatePublication(t *testing.T) { + pool, resource, dbURL := createDockerInstance(t) + defer func() { + err := pool.Purge(resource) + require.NoError(t, err) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + conn, err := pgconn.Connect(ctx, dbURL) + require.NoError(t, err) + defer closeConn(t, conn) + + publicationName := "test_publication" + err = CreatePublication(context.Background(), conn, publicationName, []string{}) + require.NoError(t, err) + + tables, forAllTables, err := GetPublicationTables(context.Background(), conn, publicationName) + require.NoError(t, err) + assert.Empty(t, tables) + assert.True(t, forAllTables) + + multiReader := conn.Exec(context.Background(), "CREATE TABLE test_table (id serial PRIMARY KEY, name text);") + _, err = multiReader.ReadAll() + require.NoError(t, err) + + publicationWithTables := "test_pub_with_tables" + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{"test_table"}) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationName) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table") + assert.False(t, forAllTables) + + // add more tables to publication + multiReader = conn.Exec(context.Background(), "CREATE TABLE test_table2 (id serial PRIMARY KEY, name text);") + _, err = multiReader.ReadAll() + require.NoError(t, err) + + // Pass more tables to the 
publication + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{ + "test_table2", + "test_table", + }) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table") + assert.Contains(t, tables, "test_table2") + assert.False(t, forAllTables) + + // Removing one table from the publication + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{ + "test_table", + }) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table") + assert.False(t, forAllTables) + + // Add one table and remove one at the same time + err = CreatePublication(context.Background(), conn, publicationWithTables, []string{ + "test_table2", + }) + require.NoError(t, err) + + tables, forAllTables, err = GetPublicationTables(context.Background(), conn, publicationWithTables) + require.NoError(t, err) + assert.NotEmpty(t, tables) + assert.Contains(t, tables, "test_table2") + assert.False(t, forAllTables) + +} + func TestStartReplication(t *testing.T) { pool, resource, dbURL := createDockerInstance(t) defer func() { From ba87494f6d8dfdf389092c00a147c823fe529cff Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 20:11:16 +0000 Subject: [PATCH 076/118] pgcdc: prefix stat names --- internal/impl/postgresql/input_pg_stream.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 80744cf6c3..b64947954a 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -83,7 +83,7 @@ This input adds the following metadata fields to each message: 
Default(0)). Field(service.NewStringEnumField(fieldDecodingPlugin, "pgoutput", "wal2json"). Description(`Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. - Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. +Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. `). Example("pgoutput"). Default("pgoutput")). @@ -220,8 +220,8 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser // https://github.com/jackc/pglogrepl/issues/6 pgConnConfig.RuntimeParams["replication"] = "database" - snapshotMetrics := mgr.Metrics().NewGauge("snapshot_progress", "table") - replicationLag := mgr.Metrics().NewGauge("replication_lag_bytes") + snapshotMetrics := mgr.Metrics().NewGauge("postgres_snapshot_progress", "table") + replicationLag := mgr.Metrics().NewGauge("postgres_replication_lag_bytes") i := &pgStreamInput{ dbConfig: pgConnConfig, From 19eda8aa95d2089dc4c967d6f45b29d3231d6463 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 20:17:58 +0000 Subject: [PATCH 077/118] pgcdc: remove lsnrestart field --- .../pglogicalstream/logical_stream.go | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 1d91cf1326..148b6d0040 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -35,7 +35,6 @@ type Stream struct { standbyCtxCancel context.CancelFunc clientXLogPos LSN - lsnrestart LSN standbyMessageTimeout time.Duration nextStandbyMessageDeadline time.Time @@ -195,8 +194,6 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { lsnrestart, 
_ = ParseLSN(confirmedLSNFromDB) } - stream.lsnrestart = lsnrestart - if freshlyCreatedSlot { stream.clientXLogPos = sysident.XLogPos } else { @@ -213,9 +210,9 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } stream.monitor = monitor - stream.logger.Debugf("Starting stream from LSN %s with clientXLogPos %s and snapshot name %s", stream.lsnrestart.String(), stream.clientXLogPos.String(), stream.snapshotName) + stream.logger.Debugf("Starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.String(), stream.snapshotName) if !freshlyCreatedSlot || !config.StreamOldData { - if err = stream.startLr(); err != nil { + if err = stream.startLr(lsnrestart); err != nil { return nil, err } @@ -224,7 +221,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { // New messages will be streamed after the snapshot has been processed. // stream.startLr() and stream.streamMessagesAsync() will be called inside stream.processSnapshot() go func() { - if err := stream.processSnapshot(ctx); err != nil { + if err := stream.processSnapshot(ctx, lsnrestart); err != nil { stream.logger.Errorf("Failed to process snapshot: %v", err.Error()) } }() @@ -244,8 +241,8 @@ func (s *Stream) ConsumedCallback() chan bool { return s.consumedCallback } -func (s *Stream) startLr() error { - if err := StartReplication(context.Background(), s.pgConn, s.slotName, s.lsnrestart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}); err != nil { +func (s *Stream) startLr(lsnStart LSN) error { + if err := StartReplication(context.Background(), s.pgConn, s.slotName, lsnStart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}); err != nil { return err } @@ -441,7 +438,7 @@ func (s *Stream) AckTxChan() chan string { return s.transactionAckChan } -func (s *Stream) processSnapshot(ctx context.Context) error { +func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { if 
err := s.snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) if err = s.cleanUpOnFailure(ctx); err != nil { @@ -608,7 +605,7 @@ func (s *Stream) processSnapshot(ctx context.Context) error { return err } - if err := s.startLr(); err != nil { + if err := s.startLr(lsnStart); err != nil { s.logger.Errorf("Failed to start logical replication after snapshot: %v", err.Error()) return err } From af512ce3e468334cbd45371746e2df24833d777b Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 20:31:41 +0000 Subject: [PATCH 078/118] pgcdc: add a high watermark utility --- .../postgresql/pglogicalstream/watermark.go | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 internal/impl/postgresql/pglogicalstream/watermark.go diff --git a/internal/impl/postgresql/pglogicalstream/watermark.go b/internal/impl/postgresql/pglogicalstream/watermark.go new file mode 100644 index 0000000000..bd24967994 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/watermark.go @@ -0,0 +1,59 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package pglogicalstream + +import ( + "cmp" + "sync" +) + +// watermark is a utility that allows you to store the highest value and subscribe to when +// a specific offset is reached +type watermark[T cmp.Ordered] struct { + val T + mu sync.Mutex + cond sync.Cond +} + +// create a new watermark at the initial value +func newWatermark[T cmp.Ordered](initial T) *watermark[T] { + w := &watermark[T]{val: initial} + w.cond = *sync.NewCond(&w.mu) + return w +} + +// Set the watermark value if it's newer +func (w *watermark[T]) Set(v T) { + w.mu.Lock() + defer w.mu.Unlock() + if v <= w.val { + return + } + w.val = v + w.cond.Broadcast() +} + +// Get the current watermark value +func (w *watermark[T]) Get() T { + w.mu.Lock() + cpy := w.val + w.mu.Unlock() + return cpy +} + +// WaitFor waits until the watermark satifies some predicate. +func (w *watermark[T]) WaitFor(pred func(T) bool) { + w.mu.Lock() + defer w.mu.Unlock() + for !pred(w.val) { + w.cond.Wait() + } +} From aa4bc89da7028216bfd5670e3c8fb10a923768f9 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 20:47:53 +0000 Subject: [PATCH 079/118] pgcdc: use watermark for log position --- .../pglogicalstream/logical_stream.go | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 148b6d0040..e2323175ac 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -34,7 +34,7 @@ type Stream struct { standbyCtxCancel context.CancelFunc - clientXLogPos LSN + clientXLogPos *watermark[LSN] standbyMessageTimeout time.Duration nextStandbyMessageDeadline time.Time @@ -195,9 +195,9 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { 
} if freshlyCreatedSlot { - stream.clientXLogPos = sysident.XLogPos + stream.clientXLogPos = newWatermark(sysident.XLogPos) } else { - stream.clientXLogPos = lsnrestart + stream.clientXLogPos = newWatermark(lsnrestart) } stream.standbyMessageTimeout = time.Duration(config.PgStandbyTimeoutSec) * time.Second @@ -210,7 +210,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } stream.monitor = monitor - stream.logger.Debugf("Starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.String(), stream.snapshotName) + stream.logger.Debugf("Starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.Get().String(), stream.snapshotName) if !freshlyCreatedSlot || !config.StreamOldData { if err = stream.startLr(lsnrestart); err != nil { return nil, err @@ -271,13 +271,13 @@ func (s *Stream) AckLSN(lsn string) error { }) if err != nil { - s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", s.clientXLogPos.String(), err) + s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", clientXLogPos.String(), err) return err } // Update client XLogPos after we ack the message - s.clientXLogPos = clientXLogPos - s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) + s.clientXLogPos.Set(clientXLogPos) + s.logger.Debugf("Sent Standby status message at LSN#%s", clientXLogPos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) return nil @@ -311,18 +311,19 @@ func (s *Stream) streamMessagesAsync() { return } + pos := s.clientXLogPos.Get() err := SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ - WALWritePosition: s.clientXLogPos, + WALWritePosition: pos, }) if err != nil { - s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", s.clientXLogPos.String(), err) + s.logger.Errorf("Failed to send Standby status message at 
LSN#%s: %v", pos.String(), err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } return } - s.logger.Debugf("Sent Standby status message at LSN#%s", s.clientXLogPos.String()) + s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) } From 2524a3fc17834842e8e681b40a6e2208ba903556 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 20:51:22 +0000 Subject: [PATCH 080/118] pgcdc: remove layer of nesting from switch --- .../pglogicalstream/logical_stream.go | 188 +++++++++--------- 1 file changed, 93 insertions(+), 95 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index e2323175ac..12bd1cbd1a 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -300,134 +300,132 @@ func (s *Stream) streamMessagesAsync() { } for { - select { - case <-s.streamCtx.Done(): - s.logger.Warn("Stream was cancelled...exiting...") + if s.streamCtx.Err() != nil { + s.logger.Debug("Stream was cancelled... 
exiting...") return - default: - if time.Now().After(s.nextStandbyMessageDeadline) { - if s.pgConn.IsClosed() { - s.logger.Warn("Postgres connection is closed...stop reading from replication slot") - return - } + } + if time.Now().After(s.nextStandbyMessageDeadline) { + if s.pgConn.IsClosed() { + s.logger.Warn("Postgres connection is closed...stop reading from replication slot") + return + } - pos := s.clientXLogPos.Get() - err := SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ - WALWritePosition: pos, - }) + pos := s.clientXLogPos.Get() + err := SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ + WALWritePosition: pos, + }) - if err != nil { - s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", pos.String(), err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + if err != nil { + s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", pos.String(), err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } - s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String()) - s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + return } + s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String()) + s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) + } - ctx, cancel := context.WithDeadline(context.Background(), s.nextStandbyMessageDeadline) - rawMsg, err := s.pgConn.ReceiveMessage(ctx) - s.standbyCtxCancel = cancel + ctx, cancel := context.WithDeadline(context.Background(), s.nextStandbyMessageDeadline) + rawMsg, err := s.pgConn.ReceiveMessage(ctx) + s.standbyCtxCancel = cancel - if err != nil && (errors.Is(err, context.Canceled) || s.stopped) { - s.logger.Warn("Service was interrupted....stop reading from replication slot") - return + if err != nil && (errors.Is(err, context.Canceled) || s.stopped) { + s.logger.Warn("Service was 
interrupted....stop reading from replication slot") + return + } + + if err != nil { + if pgconn.Timeout(err) { + continue } - if err != nil { - if pgconn.Timeout(err) { - continue - } + s.logger.Errorf("Failed to receive messages from PostgreSQL: %v", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + return + } - s.logger.Errorf("Failed to receive messages from PostgreSQL: %v", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { + s.logger.Errorf("Received error message from Postgres: %v", errMsg) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } + return + } + + msg, ok := rawMsg.(*pgproto3.CopyData) + if !ok { + s.logger.Warnf("Received unexpected message: %T\n", rawMsg) + continue + } - if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { - s.logger.Errorf("Received error message from Postgres: %v", errMsg) + switch msg.Data[0] { + case PrimaryKeepaliveMessageByteID: + pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) + if err != nil { + s.logger.Errorf("Failed to parse PrimaryKeepaliveMessage: %v", err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } - return } - msg, ok := rawMsg.(*pgproto3.CopyData) - if !ok { - s.logger.Warnf("Received unexpected message: %T\n", rawMsg) - continue + if pkm.ReplyRequested { + s.nextStandbyMessageDeadline = time.Time{} } - switch msg.Data[0] { - case PrimaryKeepaliveMessageByteID: - pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) - if err != nil { - s.logger.Errorf("Failed to parse PrimaryKeepaliveMessage: %v", err) + // XLogDataByteID is the message type for the actual WAL data + // It will cause the stream to process WAL changes and create the corresponding messages + case XLogDataByteID: + xld, err := ParseXLogData(msg.Data[1:]) + if err != nil { + 
s.logger.Errorf("Failed to parse XLogData: %v", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) + } + } + clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) + if s.decodingPlugin == "wal2json" { + if err = handler.Handle(clientXLogPos, xld); err != nil { + s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } + return } - if pkm.ReplyRequested { - s.nextStandbyMessageDeadline = time.Time{} - } - - // XLogDataByteID is the message type for the actual WAL data - // It will cause the stream to process WAL changes and create the corresponding messages - case XLogDataByteID: - xld, err := ParseXLogData(msg.Data[1:]) - if err != nil { - s.logger.Errorf("Failed to parse XLogData: %v", err) + // automatic ack for empty changes + // basically mean that the client is up-to-date, + // but we still need to acknowledge the LSN for standby + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } + return } - clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) - if s.decodingPlugin == "wal2json" { - if err = handler.Handle(clientXLogPos, xld); err != nil { - s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } + } - // automatic ack for empty changes - // basically mean that the client is up-to-date, - // but we still need to acknowledge the LSN for standby - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + if s.decodingPlugin == "pgoutput" { + if err = handler.Handle(clientXLogPos, 
xld); err != nil { + s.logger.Errorf("decodePgOutputChanges failed: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } } - if s.decodingPlugin == "pgoutput" { - if err = handler.Handle(clientXLogPos, xld); err != nil { - s.logger.Errorf("decodePgOutputChanges failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - } - - // automatic ack for empty changes - // basically mean that the client is up-to-date, - // but we still need to acknowledge the LSN for standby - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + // automatic ack for empty changes + // basically mean that the client is up-to-date, + // but we still need to acknowledge the LSN for standby + if err = s.AckLSN(clientXLogPos.String()); err != nil { + // stop reading from replication slot + // if we can't acknowledge the LSN + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } + return } } } From f31c71bb2d40d8e76071b7cc094110423ad80157 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 20:58:46 +0000 Subject: [PATCH 081/118] pgcdc: use typed duration fields --- internal/impl/postgresql/input_pg_stream.go | 40 +++++++++---------- .../impl/postgresql/pglogicalstream/config.go | 6 ++- .../pglogicalstream/logical_stream.go | 4 +- .../postgresql/pglogicalstream/monitor.go | 5 +-- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index b64947954a..a770552390 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -36,8 +36,8 @@ const ( fieldTables = "tables" fieldCheckpointLimit = "checkpoint_limit" 
fieldTemporarySlot = "temporary_slot" - fieldPgStandbyTimeout = "pg_standby_timeout_sec" - fieldWalMonitorIntervalSec = "pg_wal_monitor_interval_sec" + fieldPgStandbyTimeout = "pg_standby_timeout" + fieldWalMonitorInterval = "pg_wal_monitor_interval" fieldSlotName = "slot_name" fieldBatching = "batching" fieldMaxParallelSnapshotTables = "max_parallel_snapshot_tables" @@ -106,14 +106,14 @@ Important: No matter which plugin you choose, the data will be converted to JSON Description("The name of the PostgreSQL logical replication slot to use. If not provided, a random name will be generated. You can create this slot manually before starting replication if desired."). Example("my_test_slot"). Default("")). - Field(service.NewIntField(fieldPgStandbyTimeout). - Description("Int field that specifies default standby timeout for PostgreSQL replication connection"). - Example(10). - Default(10)). - Field(service.NewIntField(fieldWalMonitorIntervalSec). - Description("Int field stat specifies ticker interval for WAL monitoring. Used to fetch replication slot lag"). - Example(3). - Default(3)). + Field(service.NewDurationField(fieldPgStandbyTimeout). + Description("Specify the standby timeout before refreshing an idle connection."). + Example(30 * time.Second). + Default(10 * time.Second)). + Field(service.NewDurationField(fieldWalMonitorInterval). + Description("How often to report changes to the replication lag."). + Example(6 * time.Second). + Default(3 * time.Second)). Field(service.NewIntField(fieldMaxParallelSnapshotTables). Description("Int specifies a number of tables that will be processed in parallel during the snapshot processing stage"). Default(1)). 
@@ -133,9 +133,9 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser streamUncommitted bool snapshotBatchSize int checkpointLimit int - walMonitorIntervalSec int + walMonitorInterval time.Duration maxParallelSnapshotTables int - pgStandbyTimeoutSec int + pgStandbyTimeout time.Duration batching service.BatchPolicy ) @@ -197,11 +197,11 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser batching.Count = 1 } - if pgStandbyTimeoutSec, err = conf.FieldInt(fieldPgStandbyTimeout); err != nil { + if pgStandbyTimeout, err = conf.FieldDuration(fieldPgStandbyTimeout); err != nil { return nil, err } - if walMonitorIntervalSec, err = conf.FieldInt(fieldWalMonitorIntervalSec); err != nil { + if walMonitorInterval, err = conf.FieldDuration(fieldWalMonitorInterval); err != nil { return nil, err } @@ -239,8 +239,8 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser snapshotBatchSize: snapshotBatchSize, batching: batching, checkpointLimit: checkpointLimit, - pgStandbyTimeoutSec: pgStandbyTimeoutSec, - walMonitorIntervalSec: walMonitorIntervalSec, + pgStandbyTimeout: pgStandbyTimeout, + walMonitorInterval: walMonitorInterval, maxParallelSnapshotTables: maxParallelSnapshotTables, cMut: sync.Mutex{}, msgChan: make(chan asyncMessage), @@ -295,8 +295,8 @@ type pgStreamInput struct { dbRawDSN string pgLogicalStream *pglogicalstream.Stream slotName string - pgStandbyTimeoutSec int - walMonitorIntervalSec int + pgStandbyTimeout time.Duration + walMonitorInterval time.Duration temporarySlot bool schema string tables []string @@ -334,8 +334,8 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { StreamUncommitted: p.streamUncommitted, DecodingPlugin: p.decodingPlugin, SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, - PgStandbyTimeoutSec: p.pgStandbyTimeoutSec, - WalMonitorIntervalSec: p.walMonitorIntervalSec, + PgStandbyTimeout: p.pgStandbyTimeout, + WalMonitorInterval: 
p.walMonitorInterval, MaxParallelSnapshotTables: p.maxParallelSnapshotTables, Logger: p.logger, }) diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 7af9b3f0fa..99415d65f8 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -9,6 +9,8 @@ package pglogicalstream import ( + "time" + "github.com/jackc/pgx/v5/pgconn" "github.com/redpanda-data/benthos/v4/public/service" ) @@ -41,7 +43,7 @@ type Config struct { Logger *service.Logger - PgStandbyTimeoutSec int - WalMonitorIntervalSec int + PgStandbyTimeout time.Duration + WalMonitorInterval time.Duration MaxParallelSnapshotTables int } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 12bd1cbd1a..b5f33b66cd 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -200,11 +200,11 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.clientXLogPos = newWatermark(lsnrestart) } - stream.standbyMessageTimeout = time.Duration(config.PgStandbyTimeoutSec) * time.Second + stream.standbyMessageTimeout = config.PgStandbyTimeout stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) - monitor, err := NewMonitor(config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorIntervalSec) + monitor, err := NewMonitor(config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval) if err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index df14b9f4a4..0408c5a1c6 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ 
b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -47,7 +47,7 @@ type Monitor struct { } // NewMonitor creates a new Monitor instance -func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName string, intervalSec int) (*Monitor, error) { +func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName string, interval time.Duration) (*Monitor, error) { dbConn, err := openPgConnectionFromConfig(dbDSN) if err != nil { return nil, err @@ -68,8 +68,7 @@ func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName ctx, cancel := context.WithCancel(context.Background()) m.ctx = ctx m.cancelTicker = cancel - // hardocded duration to monitor slot lag - m.ticker = time.NewTicker(time.Second * time.Duration(intervalSec)) + m.ticker = time.NewTicker(interval) go func() { for { From 5f2bce055302251eca125dbf7b3cb46ce3491a7d Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 21:03:51 +0000 Subject: [PATCH 082/118] pgcdc: fix waiting for txn ack --- internal/impl/postgresql/input_pg_stream.go | 45 ++----------------- .../pglogicalstream/logical_stream.go | 22 +-------- .../pglogicalstream/pluginhandlers.go | 28 ++++++------ 3 files changed, 17 insertions(+), 78 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index a770552390..f913355f43 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -368,33 +368,10 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { break } - if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset, false); err != nil { + if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset); err != nil { break } - // TrxCommit LSN must be acked when all the messages in the batch are processed - case trxCommitLsn, open := <-p.pgLogicalStream.AckTxChan(): - if !open { - break - } - - flushedBatch, err := batchPolicy.Flush(ctx) - if err != nil { - 
p.logger.Debugf("Flush batch error: %w", err) - break - } - - if err = p.flushBatch(ctx, cp, flushedBatch, latestOffset, true); err != nil { - break - } - - if err = p.pgLogicalStream.AckLSN(trxCommitLsn); err != nil { - p.logger.Errorf("Failed to ack LSN: %v", err) - break - } - - p.pgLogicalStream.ConsumedCallback() <- true - case message, open := <-p.pgLogicalStream.Messages(): if !open { break @@ -440,8 +417,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { p.logger.Debugf("Flush batch error: %w", err) break } - waitForCommit := message.Mode == pglogicalstream.StreamModeStreaming - if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset, waitForCommit); err != nil { + if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset); err != nil { break } } @@ -457,7 +433,7 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { return err } -func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint.Capped[*int64], msg service.MessageBatch, lsn *int64, waitForCommit bool) error { +func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint.Capped[*int64], msg service.MessageBatch, lsn *int64) error { if msg == nil { return nil } @@ -470,21 +446,7 @@ func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint return err } - var wg sync.WaitGroup - if waitForCommit { - wg.Add(1) - } ackFn := func(ctx context.Context, res error) error { - // This waits for *THIS MESSAGE* to get acked, which is - // not when we actually ack this LSN because of out of order - // processing might cause another message to actually resolve - // the proper checkpointer to commit. - // - // This waitForCommit business probably needs to happen inside - // the ack stream not here. 
- if waitForCommit { - defer wg.Done() - } maxOffset := resolveFn() if maxOffset == nil { return nil @@ -504,7 +466,6 @@ func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint case <-ctx.Done(): return ctx.Err() } - wg.Wait() // Noop if !waitForCommit return nil } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index b5f33b66cd..79f27887ea 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -52,16 +52,10 @@ type Stream struct { monitor *Monitor streamUncommitted bool snapshotter *Snapshotter - transactionAckChan chan string - transactionBeginChan chan bool maxParallelSnapshotTables int - lsnAckBuffer []string - m sync.Mutex stopped bool - - consumedCallback chan bool } // NewPgStream creates a new instance of the Stream struct @@ -92,10 +86,6 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { snapshotBatchSize: config.BatchSize, schema: config.DBSchema, tableQualifiedName: tableNames, - consumedCallback: make(chan bool), - transactionAckChan: make(chan string), - transactionBeginChan: make(chan bool), - lsnAckBuffer: []string{}, maxParallelSnapshotTables: config.MaxParallelSnapshotTables, logger: config.Logger, m: sync.Mutex{}, @@ -236,11 +226,6 @@ func (s *Stream) GetProgress() *Report { return s.monitor.Report() } -// ConsumedCallback returns a channel that is used to tell the plugin to commit consumed offset -func (s *Stream) ConsumedCallback() chan bool { - return s.consumedCallback -} - func (s *Stream) startLr(lsnStart LSN) error { if err := StartReplication(context.Background(), s.pgConn, s.slotName, lsnStart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}); err != nil { return err @@ -289,7 +274,7 @@ func (s *Stream) streamMessagesAsync() { case "wal2json": handler = NewWal2JsonPluginHandler(s.messages, s.monitor) case 
"pgoutput": - handler = NewPgOutputPluginHandler(s.messages, s.streamUncommitted, s.monitor, s.consumedCallback, s.transactionAckChan) + handler = NewPgOutputPluginHandler(s.messages, s.streamUncommitted, s.monitor, s.clientXLogPos) default: s.logger.Error("Invalid decoding plugin. Cant find needed handler implementation") if err := s.Stop(); err != nil { @@ -432,11 +417,6 @@ func (s *Stream) streamMessagesAsync() { } } -// AckTxChan returns the transaction ack channel -func (s *Stream) AckTxChan() chan string { - return s.transactionAckChan -} - func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { if err := s.snapshotter.prepare(); err != nil { s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index b0902bc818..45b304ea64 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -56,8 +56,7 @@ type PgOutputPluginHandler struct { typeMap *pgtype.Map pgoutputChanges []StreamMessageChanges - consumedCallback chan bool - transactionAckChan chan string + lsnWatermark *watermark[LSN] } // NewPgOutputPluginHandler creates a new PgOutputPluginHandler @@ -65,18 +64,16 @@ func NewPgOutputPluginHandler( messages chan StreamMessage, streamUncommitted bool, monitor *Monitor, - consumedCallback chan bool, - transactionAckChan chan string, + lsnWatermark *watermark[LSN], ) *PgOutputPluginHandler { return &PgOutputPluginHandler{ - messages: messages, - monitor: monitor, - streamUncommitted: streamUncommitted, - relations: map[uint32]*RelationMessage{}, - typeMap: pgtype.NewMap(), - pgoutputChanges: []StreamMessageChanges{}, - consumedCallback: consumedCallback, - transactionAckChan: transactionAckChan, + messages: messages, + monitor: monitor, + streamUncommitted: streamUncommitted, + 
relations: map[uint32]*RelationMessage{}, + typeMap: pgtype.NewMap(), + pgoutputChanges: []StreamMessageChanges{}, + lsnWatermark: lsnWatermark, } } @@ -95,10 +92,11 @@ func (p *PgOutputPluginHandler) Handle(clientXLogPos LSN, xld XLogData) error { } // when receiving a commit message, we need to acknowledge the LSN - // but we must wait for benthos to flush the messages before we can do that + // but we must wait for connect to flush the messages before we can do that if isCommit { - p.transactionAckChan <- clientXLogPos.String() - <-p.consumedCallback + p.lsnWatermark.WaitFor(func(lsn LSN) bool { + return lsn >= clientXLogPos + }) } else { if message == nil && !isCommit { return nil From ec5783fbf751b26548eb392f8f87cb5919069d05 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 21:11:03 +0000 Subject: [PATCH 083/118] pgcdc: dedup config fields --- internal/impl/postgresql/input_pg_stream.go | 96 +++++++-------------- 1 file changed, 31 insertions(+), 65 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index f913355f43..a4be41caca 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -14,7 +14,6 @@ import ( "fmt" "strings" "sync" - "sync/atomic" "time" "github.com/Jeffail/checkpoint" @@ -224,33 +223,33 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser replicationLag := mgr.Metrics().NewGauge("postgres_replication_lag_bytes") i := &pgStreamInput{ - dbConfig: pgConnConfig, - // dbRawDSN is used for creating golang PG Connection - // as using pgconn.Config for golang doesn't support multiple queries in the prepared statement for Postgres Version <= 14 - dbRawDSN: dsn, - streamSnapshot: streamSnapshot, - snapshotMemSafetyFactor: snapshotMemSafetyFactor, - slotName: dbSlotName, - schema: schema, - tables: tables, - decodingPlugin: decodingPlugin, - streamUncommitted: streamUncommitted, - 
temporarySlot: temporarySlot, - snapshotBatchSize: snapshotBatchSize, - batching: batching, - checkpointLimit: checkpointLimit, - pgStandbyTimeout: pgStandbyTimeout, - walMonitorInterval: walMonitorInterval, - maxParallelSnapshotTables: maxParallelSnapshotTables, - cMut: sync.Mutex{}, - msgChan: make(chan asyncMessage), + streamConfig: &pglogicalstream.Config{ + DBConfig: pgConnConfig, + DBRawDSN: dsn, + DBSchema: schema, + DBTables: tables, + + ReplicationSlotName: "rs_" + dbSlotName, + BatchSize: snapshotBatchSize, + StreamOldData: streamSnapshot, + TemporaryReplicationSlot: temporarySlot, + StreamUncommitted: streamUncommitted, + DecodingPlugin: decodingPlugin, + SnapshotMemorySafetyFactor: snapshotMemSafetyFactor, + PgStandbyTimeout: pgStandbyTimeout, + WalMonitorInterval: walMonitorInterval, + MaxParallelSnapshotTables: maxParallelSnapshotTables, + Logger: mgr.Logger(), + }, + batching: batching, + checkpointLimit: checkpointLimit, + cMut: sync.Mutex{}, + msgChan: make(chan asyncMessage), mgr: mgr, logger: mgr.Logger(), snapshotMetrics: snapshotMetrics, replicationLag: replicationLag, - inTxState: atomic.Bool{}, - releaseTrxChan: make(chan bool), } r, err := service.AutoRetryNacksBatchedToggled(conf, i) @@ -291,54 +290,21 @@ func init() { } type pgStreamInput struct { - dbConfig *pgconn.Config - dbRawDSN string - pgLogicalStream *pglogicalstream.Stream - slotName string - pgStandbyTimeout time.Duration - walMonitorInterval time.Duration - temporarySlot bool - schema string - tables []string - decodingPlugin string - streamSnapshot bool - snapshotMemSafetyFactor float64 - snapshotBatchSize int - streamUncommitted bool - maxParallelSnapshotTables int - logger *service.Logger - mgr *service.Resources - cMut sync.Mutex - msgChan chan asyncMessage - batching service.BatchPolicy - checkpointLimit int + streamConfig *pglogicalstream.Config + pgLogicalStream *pglogicalstream.Stream + logger *service.Logger + mgr *service.Resources + cMut sync.Mutex + msgChan chan 
asyncMessage + batching service.BatchPolicy + checkpointLimit int snapshotMetrics *service.MetricGauge replicationLag *service.MetricGauge - - releaseTrxChan chan bool - inTxState atomic.Bool } func (p *pgStreamInput) Connect(ctx context.Context) error { - pgStream, err := pglogicalstream.NewPgStream(ctx, &pglogicalstream.Config{ - DBConfig: p.dbConfig, - DBRawDSN: p.dbRawDSN, - DBSchema: p.schema, - DBTables: p.tables, - - ReplicationSlotName: "rs_" + p.slotName, - BatchSize: p.snapshotBatchSize, - StreamOldData: p.streamSnapshot, - TemporaryReplicationSlot: p.temporarySlot, - StreamUncommitted: p.streamUncommitted, - DecodingPlugin: p.decodingPlugin, - SnapshotMemorySafetyFactor: p.snapshotMemSafetyFactor, - PgStandbyTimeout: p.pgStandbyTimeout, - WalMonitorInterval: p.walMonitorInterval, - MaxParallelSnapshotTables: p.maxParallelSnapshotTables, - Logger: p.logger, - }) + pgStream, err := pglogicalstream.NewPgStream(ctx, p.streamConfig) if err != nil { return fmt.Errorf("unable to create replication stream: %w", err) } From 089ed64185be1f1260b679f482a5cc97f9b7c370 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Tue, 12 Nov 2024 21:22:48 +0000 Subject: [PATCH 084/118] pgcdc: fix config field defaults --- internal/impl/postgresql/input_pg_stream.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index a4be41caca..a698037181 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -107,12 +107,12 @@ Important: No matter which plugin you choose, the data will be converted to JSON Default("")). Field(service.NewDurationField(fieldPgStandbyTimeout). Description("Specify the standby timeout before refreshing an idle connection."). - Example(30 * time.Second). - Default(10 * time.Second)). + Example("30s"). + Default("10s")). Field(service.NewDurationField(fieldWalMonitorInterval). 
Description("How often to report changes to the replication lag."). - Example(6 * time.Second). - Default(3 * time.Second)). + Example("6s"). + Default("3s")). Field(service.NewIntField(fieldMaxParallelSnapshotTables). Description("Int specifies a number of tables that will be processed in parallel during the snapshot processing stage"). Default(1)). From f8cbc95879500b1b256ca75ebf3ce0f2eaace961 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 02:12:27 +0000 Subject: [PATCH 085/118] pgcdc: properly implement watermark We need to be able to be cancelled if we never reach the watermark --- .../pglogicalstream/logical_stream.go | 11 +-- .../pglogicalstream/pluginhandlers.go | 25 ++++--- .../postgresql/pglogicalstream/watermark.go | 59 ---------------- .../pglogicalstream/watermark/watermark.go | 70 +++++++++++++++++++ .../watermark/watermark_test.go | 51 ++++++++++++++ 5 files changed, 143 insertions(+), 73 deletions(-) delete mode 100644 internal/impl/postgresql/pglogicalstream/watermark.go create mode 100644 internal/impl/postgresql/pglogicalstream/watermark/watermark.go create mode 100644 internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 79f27887ea..d097ae9bfb 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -22,6 +22,7 @@ import ( "github.com/jackc/pgx/v5/pgproto3" "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" "golang.org/x/sync/errgroup" ) @@ -34,7 +35,7 @@ type Stream struct { standbyCtxCancel context.CancelFunc - clientXLogPos *watermark[LSN] + clientXLogPos *watermark.Value[LSN] standbyMessageTimeout time.Duration 
nextStandbyMessageDeadline time.Time @@ -185,9 +186,9 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } if freshlyCreatedSlot { - stream.clientXLogPos = newWatermark(sysident.XLogPos) + stream.clientXLogPos = watermark.New(sysident.XLogPos) } else { - stream.clientXLogPos = newWatermark(lsnrestart) + stream.clientXLogPos = watermark.New(lsnrestart) } stream.standbyMessageTimeout = config.PgStandbyTimeout @@ -372,7 +373,7 @@ func (s *Stream) streamMessagesAsync() { } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) if s.decodingPlugin == "wal2json" { - if err = handler.Handle(clientXLogPos, xld); err != nil { + if err = handler.Handle(s.streamCtx, clientXLogPos, xld); err != nil { s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) @@ -394,7 +395,7 @@ func (s *Stream) streamMessagesAsync() { } if s.decodingPlugin == "pgoutput" { - if err = handler.Handle(clientXLogPos, xld); err != nil { + if err = handler.Handle(s.streamCtx, clientXLogPos, xld); err != nil { s.logger.Errorf("decodePgOutputChanges failed: %w", err) if err = s.Stop(); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 45b304ea64..500b77a31e 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -8,11 +8,16 @@ package pglogicalstream -import "github.com/jackc/pgx/v5/pgtype" +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" +) // PluginHandler is an interface that must be implemented by all plugin handlers type PluginHandler interface { - Handle(clientXLogPos LSN, xld XLogData) error + Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) 
error } // Wal2JsonPluginHandler is a handler for wal2json output plugin @@ -30,7 +35,7 @@ func NewWal2JsonPluginHandler(messages chan StreamMessage, monitor *Monitor) *Wa } // Handle handles the wal2json output -func (w *Wal2JsonPluginHandler) Handle(clientXLogPos LSN, xld XLogData) error { +func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld XLogData) error { // get current stream metrics metrics := w.monitor.Report() message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) @@ -56,7 +61,7 @@ type PgOutputPluginHandler struct { typeMap *pgtype.Map pgoutputChanges []StreamMessageChanges - lsnWatermark *watermark[LSN] + lsnWatermark *watermark.Value[LSN] } // NewPgOutputPluginHandler creates a new PgOutputPluginHandler @@ -64,7 +69,7 @@ func NewPgOutputPluginHandler( messages chan StreamMessage, streamUncommitted bool, monitor *Monitor, - lsnWatermark *watermark[LSN], + lsnWatermark *watermark.Value[LSN], ) *PgOutputPluginHandler { return &PgOutputPluginHandler{ messages: messages, @@ -78,7 +83,7 @@ func NewPgOutputPluginHandler( } // Handle handles the pgoutput output -func (p *PgOutputPluginHandler) Handle(clientXLogPos LSN, xld XLogData) error { +func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) error { if p.streamUncommitted { // parse changes inside the transaction message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) @@ -94,9 +99,11 @@ func (p *PgOutputPluginHandler) Handle(clientXLogPos LSN, xld XLogData) error { // when receiving a commit message, we need to acknowledge the LSN // but we must wait for connect to flush the messages before we can do that if isCommit { - p.lsnWatermark.WaitFor(func(lsn LSN) bool { - return lsn >= clientXLogPos - }) + select { + case <-p.lsnWatermark.WaitFor(clientXLogPos): + case <-ctx.Done(): + return ctx.Err() + } } else { if message == nil && !isCommit { return nil diff --git 
a/internal/impl/postgresql/pglogicalstream/watermark.go b/internal/impl/postgresql/pglogicalstream/watermark.go deleted file mode 100644 index bd24967994..0000000000 --- a/internal/impl/postgresql/pglogicalstream/watermark.go +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2024 Redpanda Data, Inc. - * - * Licensed as a Redpanda Enterprise file under the Redpanda Community - * License (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md - */ - -package pglogicalstream - -import ( - "cmp" - "sync" -) - -// watermark is a utility that allows you to store the highest value and subscribe to when -// a specific offset is reached -type watermark[T cmp.Ordered] struct { - val T - mu sync.Mutex - cond sync.Cond -} - -// create a new watermark at the initial value -func newWatermark[T cmp.Ordered](initial T) *watermark[T] { - w := &watermark[T]{val: initial} - w.cond = *sync.NewCond(&w.mu) - return w -} - -// Set the watermark value if it's newer -func (w *watermark[T]) Set(v T) { - w.mu.Lock() - defer w.mu.Unlock() - if v <= w.val { - return - } - w.val = v - w.cond.Broadcast() -} - -// Get the current watermark value -func (w *watermark[T]) Get() T { - w.mu.Lock() - cpy := w.val - w.mu.Unlock() - return cpy -} - -// WaitFor waits until the watermark satifies some predicate. -func (w *watermark[T]) WaitFor(pred func(T) bool) { - w.mu.Lock() - defer w.mu.Unlock() - for !pred(w.val) { - w.cond.Wait() - } -} diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark.go new file mode 100644 index 0000000000..56dc30783b --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/watermark/watermark.go @@ -0,0 +1,70 @@ +/* + * Copyright 2024 Redpanda Data, Inc. 
+ * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package watermark + +import ( + "cmp" + "sync" +) + +// Value is a utility that allows you to store the highest value and subscribe to when +// a specific offset is reached +type ( + Value[T cmp.Ordered] struct { + val T + mu sync.Mutex + waiters map[chan<- any]T + } +) + +// New makes a new Value holding `initial` +func New[T cmp.Ordered](initial T) *Value[T] { + w := &Value[T]{val: initial} + w.waiters = map[chan<- any]T{} + return w +} + +// Set the watermark value if it's newer +func (w *Value[T]) Set(v T) { + w.mu.Lock() + defer w.mu.Unlock() + if v <= w.val { + return + } + w.val = v + for notify, val := range w.waiters { + if val <= w.val { + notify <- nil + delete(w.waiters, notify) + } + } +} + +// Get the current watermark value +func (w *Value[T]) Get() T { + w.mu.Lock() + cpy := w.val + w.mu.Unlock() + return cpy +} + +// WaitFor returns a channel that recieves a value when the watermark reaches `val`. +func (w *Value[T]) WaitFor(val T) <-chan any { + w.mu.Lock() + defer w.mu.Unlock() + ch := make(chan any, 1) + if w.val >= val { + ch <- nil + return ch + } + w.waiters[ch] = val + return ch +} diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go new file mode 100644 index 0000000000..c24aee5851 --- /dev/null +++ b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go @@ -0,0 +1,51 @@ +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ + +package watermark_test + +import ( + "testing" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" + "github.com/stretchr/testify/require" +) + +func TestWatermark(t *testing.T) { + w := watermark.New(5) + require.Equal(t, 5, w.Get()) + w.Set(3) + require.Equal(t, 5, w.Get()) + ch1 := w.WaitFor(9) + ch2 := w.WaitFor(10) + ch3 := w.WaitFor(10) + ch4 := w.WaitFor(100) + require.Len(t, ch1, 0) + require.Len(t, ch2, 0) + require.Len(t, ch3, 0) + require.Len(t, ch4, 0) + w.Set(8) + require.Equal(t, 8, w.Get()) + require.Len(t, ch1, 0) + require.Len(t, ch2, 0) + require.Len(t, ch3, 0) + require.Len(t, ch4, 0) + w.Set(9) + require.Equal(t, 9, w.Get()) + require.Len(t, ch1, 1) + require.Len(t, ch2, 0) + require.Len(t, ch3, 0) + require.Len(t, ch4, 0) + w.Set(10) + require.Equal(t, 10, w.Get()) + require.Len(t, ch1, 1) + require.Len(t, ch2, 1) + require.Len(t, ch3, 1) + require.Len(t, ch4, 0) +} From a74571ae7ef5dcbdd4174d100f8feef6a2bd75ef Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 03:05:13 +0000 Subject: [PATCH 086/118] pgcdc: properly ack only on commit messages, once everything is processed --- .../pglogicalstream/logical_stream.go | 46 ++++--------------- .../pglogicalstream/pluginhandlers.go | 36 +++++++++------ 2 files changed, 29 insertions(+), 53 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index d097ae9bfb..c1b907c4dc 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -372,46 +372,16 @@ func (s *Stream) streamMessagesAsync() { } } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) - if s.decodingPlugin == "wal2json" { - if err = handler.Handle(s.streamCtx, clientXLogPos, xld); err != 
nil { - s.logger.Errorf("decodeWal2JsonChanges failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - - // automatic ack for empty changes - // basically mean that the client is up-to-date, - // but we still need to acknowledge the LSN for standby - if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return - } - } - - if s.decodingPlugin == "pgoutput" { - if err = handler.Handle(s.streamCtx, clientXLogPos, xld); err != nil { - s.logger.Errorf("decodePgOutputChanges failed: %w", err) - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } + commit, err := handler.Handle(s.streamCtx, clientXLogPos, xld) + if err != nil { + s.logger.Errorf("decodePgOutputChanges failed: %w", err) + if err = s.Stop(); err != nil { + s.logger.Errorf("Failed to stop the stream: %v", err) } - - // automatic ack for empty changes - // basically mean that the client is up-to-date, - // but we still need to acknowledge the LSN for standby + } else if commit { + // This is a hack and we probably should not do it if err = s.AckLSN(clientXLogPos.String()); err != nil { - // stop reading from replication slot - // if we can't acknowledge the LSN - if err = s.Stop(); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + s.logger.Errorf("Failed to ack commit message: %v", err) } } } diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 500b77a31e..b07f57b0b9 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -10,6 +10,7 @@ package pglogicalstream import ( "context" + "fmt" "github.com/jackc/pgx/v5/pgtype" 
"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" @@ -17,7 +18,8 @@ import ( // PluginHandler is an interface that must be implemented by all plugin handlers type PluginHandler interface { - Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) error + // returns true if we need to ack the clientXLogPos + Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) } // Wal2JsonPluginHandler is a handler for wal2json output plugin @@ -35,12 +37,12 @@ func NewWal2JsonPluginHandler(messages chan StreamMessage, monitor *Monitor) *Wa } // Handle handles the wal2json output -func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld XLogData) error { +func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { // get current stream metrics metrics := w.monitor.Report() message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) if err != nil { - return err + return false, err } if message != nil && len(message.Changes) > 0 { @@ -48,7 +50,7 @@ func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld w.messages <- *message } - return nil + return false, nil } // PgOutputPluginHandler is a handler for pgoutput output plugin @@ -61,6 +63,7 @@ type PgOutputPluginHandler struct { typeMap *pgtype.Map pgoutputChanges []StreamMessageChanges + lastEmitted LSN lsnWatermark *watermark.Value[LSN] } @@ -78,37 +81,40 @@ func NewPgOutputPluginHandler( relations: map[uint32]*RelationMessage{}, typeMap: pgtype.NewMap(), pgoutputChanges: []StreamMessageChanges{}, + lastEmitted: lsnWatermark.Get(), lsnWatermark: lsnWatermark, } } // Handle handles the pgoutput output -func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) error { +func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { if p.streamUncommitted { // parse changes inside 
the transaction message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) if err != nil { - return err + return false, err } isCommit, _, err := isCommitMessage(xld.WALData) if err != nil { - return err + return false, err } // when receiving a commit message, we need to acknowledge the LSN // but we must wait for connect to flush the messages before we can do that if isCommit { select { - case <-p.lsnWatermark.WaitFor(clientXLogPos): + case <-p.lsnWatermark.WaitFor(p.lastEmitted): + return true, nil case <-ctx.Done(): - return ctx.Err() + return false, ctx.Err() } } else { if message == nil && !isCommit { - return nil + return false, nil } else if message != nil { lsn := clientXLogPos.String() + p.lastEmitted = clientXLogPos p.messages <- StreamMessage{ Lsn: &lsn, Changes: []StreamMessageChanges{ @@ -125,7 +131,7 @@ func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, x // and LSN ack will cause potential loss of changes isBegin, err := isBeginMessage(xld.WALData) if err != nil { - return err + return false, err } if isBegin { @@ -135,7 +141,7 @@ func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, x // parse changes inside the transaction message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) if err != nil { - return err + return false, err } if message != nil { @@ -144,12 +150,12 @@ func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, x isCommit, _, err := isCommitMessage(xld.WALData) if err != nil { - return err + return false, err } if isCommit { if len(p.pgoutputChanges) == 0 { - return nil + return false, nil } else { // send all collected changes lsn := clientXLogPos.String() @@ -163,5 +169,5 @@ func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, x } } - return nil + return false, nil } From 7d738647a08a305476b578f565a43ae49be9f13e Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 03:12:50 +0000 Subject: [PATCH 
087/118] pgcdc: there are actually 3 handlers --- .../pglogicalstream/pluginhandlers.go | 187 ++++++++++-------- 1 file changed, 102 insertions(+), 85 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index b07f57b0b9..807d081b40 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -10,7 +10,6 @@ package pglogicalstream import ( "context" - "fmt" "github.com/jackc/pgx/v5/pgtype" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" @@ -53,118 +52,136 @@ func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld return false, nil } -// PgOutputPluginHandler is a handler for pgoutput output plugin -type PgOutputPluginHandler struct { +// PgOutputUnbufferedPluginHandler is a native output handler that emits each message as it's received. +type PgOutputUnbufferedPluginHandler struct { messages chan StreamMessage monitor *Monitor - streamUncommitted bool - relations map[uint32]*RelationMessage - typeMap *pgtype.Map - pgoutputChanges []StreamMessageChanges + relations map[uint32]*RelationMessage + typeMap *pgtype.Map lastEmitted LSN lsnWatermark *watermark.Value[LSN] } +// PgOutputBufferedPluginHandler is a native output handler that buffers and emits each transaction together +type PgOutputBufferedPluginHandler struct { + messages chan StreamMessage + monitor *Monitor + + relations map[uint32]*RelationMessage + typeMap *pgtype.Map + pgoutputChanges []StreamMessageChanges +} + // NewPgOutputPluginHandler creates a new PgOutputPluginHandler func NewPgOutputPluginHandler( messages chan StreamMessage, streamUncommitted bool, monitor *Monitor, lsnWatermark *watermark.Value[LSN], -) *PgOutputPluginHandler { - return &PgOutputPluginHandler{ - messages: messages, - monitor: monitor, - streamUncommitted: streamUncommitted, - relations: 
map[uint32]*RelationMessage{}, - typeMap: pgtype.NewMap(), - pgoutputChanges: []StreamMessageChanges{}, - lastEmitted: lsnWatermark.Get(), - lsnWatermark: lsnWatermark, +) PluginHandler { + if streamUncommitted { + return &PgOutputUnbufferedPluginHandler{ + messages: messages, + monitor: monitor, + relations: map[uint32]*RelationMessage{}, + typeMap: pgtype.NewMap(), + lastEmitted: lsnWatermark.Get(), + lsnWatermark: lsnWatermark, + } + } + return &PgOutputBufferedPluginHandler{ + messages: messages, + monitor: monitor, + relations: map[uint32]*RelationMessage{}, + typeMap: pgtype.NewMap(), + pgoutputChanges: []StreamMessageChanges{}, } } // Handle handles the pgoutput output -func (p *PgOutputPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { - if p.streamUncommitted { - // parse changes inside the transaction - message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) - if err != nil { - return false, err - } +func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { + // parse changes inside the transaction + message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) + if err != nil { + return false, err + } - isCommit, _, err := isCommitMessage(xld.WALData) - if err != nil { - return false, err - } + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + return false, err + } - // when receiving a commit message, we need to acknowledge the LSN - // but we must wait for connect to flush the messages before we can do that - if isCommit { - select { - case <-p.lsnWatermark.WaitFor(p.lastEmitted): - return true, nil - case <-ctx.Done(): - return false, ctx.Err() - } - } else { - if message == nil && !isCommit { - return false, nil - } else if message != nil { - lsn := clientXLogPos.String() - p.lastEmitted = clientXLogPos - p.messages <- StreamMessage{ - Lsn: &lsn, - Changes: []StreamMessageChanges{ - *message, - }, - Mode: 
StreamModeStreaming, - WALLagBytes: &p.monitor.Report().WalLagInBytes, - } - } + // when receiving a commit message, we need to acknowledge the LSN + // but we must wait for connect to flush the messages before we can do that + if isCommit { + select { + case <-p.lsnWatermark.WaitFor(p.lastEmitted): + return true, nil + case <-ctx.Done(): + return false, ctx.Err() } } else { - // message changes must be collected in the buffer in the context of the same transaction - // as single transaction can contain multiple changes - // and LSN ack will cause potential loss of changes - isBegin, err := isBeginMessage(xld.WALData) - if err != nil { - return false, err + if message == nil && !isCommit { + return false, nil + } else if message != nil { + lsn := clientXLogPos.String() + p.lastEmitted = clientXLogPos + p.messages <- StreamMessage{ + Lsn: &lsn, + Changes: []StreamMessageChanges{ + *message, + }, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, + } } + } - if isBegin { - p.pgoutputChanges = []StreamMessageChanges{} - } + return false, nil +} - // parse changes inside the transaction - message, err := decodePgOutput(xld.WALData, p.relations, p.typeMap) - if err != nil { - return false, err - } +// Handle handles the pgoutput output +func (p *PgOutputBufferedPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { + // message changes must be collected in the buffer in the context of the same transaction + // as single transaction can contain multiple changes + // and LSN ack will cause potential loss of changes + isBegin, err := isBeginMessage(xld.WALData) + if err != nil { + return false, err + } - if message != nil { - p.pgoutputChanges = append(p.pgoutputChanges, *message) - } + if isBegin { + p.pgoutputChanges = []StreamMessageChanges{} + } - isCommit, _, err := isCommitMessage(xld.WALData) - if err != nil { - return false, err - } + // parse changes inside the transaction + message, err := 
decodePgOutput(xld.WALData, p.relations, p.typeMap) + if err != nil { + return false, err + } + + if message != nil { + p.pgoutputChanges = append(p.pgoutputChanges, *message) + } - if isCommit { - if len(p.pgoutputChanges) == 0 { - return false, nil - } else { - // send all collected changes - lsn := clientXLogPos.String() - p.messages <- StreamMessage{ - Lsn: &lsn, - Changes: p.pgoutputChanges, - Mode: StreamModeStreaming, - WALLagBytes: &p.monitor.Report().WalLagInBytes, - } + isCommit, _, err := isCommitMessage(xld.WALData) + if err != nil { + return false, err + } + + if isCommit { + if len(p.pgoutputChanges) == 0 { + return false, nil + } else { + // send all collected changes + lsn := clientXLogPos.String() + p.messages <- StreamMessage{ + Lsn: &lsn, + Changes: p.pgoutputChanges, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, } } } From 8d0aaed6a9d90bc3bcab3e3973f73b50a5a7d0d5 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 03:20:55 +0000 Subject: [PATCH 088/118] pgcdc: simplify plugin handling code --- .../pglogicalstream/pluginhandlers.go | 50 +++++++++---------- 1 file changed, 24 insertions(+), 26 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 807d081b40..3da396a925 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -122,20 +122,18 @@ func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLog case <-ctx.Done(): return false, ctx.Err() } - } else { - if message == nil && !isCommit { - return false, nil - } else if message != nil { - lsn := clientXLogPos.String() - p.lastEmitted = clientXLogPos - p.messages <- StreamMessage{ - Lsn: &lsn, - Changes: []StreamMessageChanges{ - *message, - }, - Mode: StreamModeStreaming, - WALLagBytes: &p.monitor.Report().WalLagInBytes, - } + } + + if message != 
nil { + lsn := clientXLogPos.String() + p.lastEmitted = clientXLogPos + p.messages <- StreamMessage{ + Lsn: &lsn, + Changes: []StreamMessageChanges{ + *message, + }, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, } } @@ -171,18 +169,18 @@ func (p *PgOutputBufferedPluginHandler) Handle(ctx context.Context, clientXLogPo return false, err } - if isCommit { - if len(p.pgoutputChanges) == 0 { - return false, nil - } else { - // send all collected changes - lsn := clientXLogPos.String() - p.messages <- StreamMessage{ - Lsn: &lsn, - Changes: p.pgoutputChanges, - Mode: StreamModeStreaming, - WALLagBytes: &p.monitor.Report().WalLagInBytes, - } + if !isCommit { + return false, nil + } + + if len(p.pgoutputChanges) >= 0 { + // send all collected changes + lsn := clientXLogPos.String() + p.messages <- StreamMessage{ + Lsn: &lsn, + Changes: p.pgoutputChanges, + Mode: StreamModeStreaming, + WALLagBytes: &p.monitor.Report().WalLagInBytes, } } From f75e003eaeddf9dd3831ceb3c01dc4a518ce582a Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 03:56:10 +0000 Subject: [PATCH 089/118] pgcdc: fix randomized ID uuid is invalid because we can't use dashes --- internal/impl/postgresql/input_pg_stream.go | 13 ++++++------- .../impl/postgresql/pglogicalstream/pglogrepl.go | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index a698037181..5f0590bd8f 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -17,8 +17,8 @@ import ( "time" "github.com/Jeffail/checkpoint" - "github.com/google/uuid" "github.com/jackc/pgx/v5/pgconn" + "github.com/matoous/go-nanoid/v2" "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" @@ -147,7 +147,10 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) 
(s ser } // Set the default to be a random string if dbSlotName == "" { - dbSlotName = uuid.NewString() + dbSlotName, err = gonanoid.Generate("0123456789ABCDEFGHJKMNPQRSTVWXYZ", 32) + if err != nil { + return nil, err + } } if err := validateSimpleString(dbSlotName); err != nil { @@ -266,15 +269,11 @@ func validateSimpleString(s string) error { isDigit := b >= '0' && b <= '9' isLower := b >= 'a' && b <= 'z' isUpper := b >= 'A' && b <= 'Z' - isDelimiter := b == '_' || b == '-' + isDelimiter := b == '_' if !isDigit && !isLower && !isUpper && !isDelimiter { return fmt.Errorf("invalid postgres identifier %q", s) } } - // See: https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p - if strings.Contains(s, "--") { - return fmt.Errorf("invalid postgres identifier %q", s) - } return nil } diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 0324fe81b3..ea272fff80 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -391,13 +391,13 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName var tablesToRemoveFromPublication = []string{} var tablesToAddToPublication = []string{} for _, table := range tables { - if !slices.Contains[[]string, string](pubTables, table) { + if !slices.Contains(pubTables, table) { tablesToAddToPublication = append(tablesToAddToPublication, table) } } for _, table := range pubTables { - if !slices.Contains[[]string, string](tables, table) { + if !slices.Contains(tables, table) { tablesToRemoveFromPublication = append(tablesToRemoveFromPublication, table) } } From fc29d429ca963ae7c3e4302ded0d3119c588d640 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 03:59:38 +0000 Subject: [PATCH 090/118] pgcdc: remove unused import --- internal/impl/postgresql/input_pg_stream.go | 1 - 1 file changed, 1 deletion(-) diff --git 
a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 5f0590bd8f..e757b47b50 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -12,7 +12,6 @@ import ( "context" "encoding/json" "fmt" - "strings" "sync" "time" From 82cc4e10513c1f9946352b8615d237d3cccaa8c4 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 04:11:22 +0000 Subject: [PATCH 091/118] pgcdc: always include mode --- .../postgresql/pglogicalstream/replication_message_decoders.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 5ad6117593..47bf3a6d26 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -185,6 +185,7 @@ func decodeWal2JsonChanges(clientXLogPosition string, WALData []byte) (*StreamMe message := &StreamMessage{ Lsn: &clientXLogPosition, Changes: []StreamMessageChanges{}, + Mode: StreamModeStreaming, } for _, change := range changes.Change { From 1bce9977be59b3e2e8081f82c7254db87e5aafd2 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 04:26:31 +0000 Subject: [PATCH 092/118] pgcdc: fix period batching and cleanup logic --- internal/impl/postgresql/input_pg_stream.go | 20 ++++++++++++------- .../pglogicalstream/stream_message.go | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index e757b47b50..53f7ab509e 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -322,8 +322,13 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { // offsets are nilable since we don't provide offset tracking during the snapshot phase var 
latestOffset *int64 cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) - for { + for ctx.Err() != nil { select { + case <-ctx.Done(): + if err = p.pgLogicalStream.Stop(); err != nil { + p.logger.Errorf("Failed to stop pglogical stream: %v", err) + } + return case <-nextTimedBatchChan: nextTimedBatchChan = nil flushedBatch, err := batchPolicy.Flush(ctx) @@ -358,7 +363,8 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { continue } - if mb, err = json.Marshal(message); err != nil { + // TODO this should only be the message + if mb, err = json.Marshal(message.Changes); err != nil { break } @@ -384,12 +390,12 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset); err != nil { break } + } else { + d, ok := batchPolicy.UntilNext() + if ok { + nextTimedBatchChan = time.After(d) + } } - case <-ctx.Done(): - if err = p.pgLogicalStream.Stop(); err != nil { - p.logger.Errorf("Failed to stop pglogical stream: %v", err) - } - return } } }() diff --git a/internal/impl/postgresql/pglogicalstream/stream_message.go b/internal/impl/postgresql/pglogicalstream/stream_message.go index 6d0dbdf087..99ebf1acd0 100644 --- a/internal/impl/postgresql/pglogicalstream/stream_message.go +++ b/internal/impl/postgresql/pglogicalstream/stream_message.go @@ -40,5 +40,5 @@ type StreamMessage struct { Lsn *string `json:"lsn"` Changes []StreamMessageChanges `json:"changes"` Mode StreamMode `json:"mode"` - WALLagBytes *int64 `json:"wal_lag_bytes"` + WALLagBytes *int64 `json:"wal_lag_bytes,omitempty"` } From b85c8bc00273d09db8d20a849fc3aa79fa634a93 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 04:40:17 +0000 Subject: [PATCH 093/118] pgcdc: fix lint error --- .../watermark/watermark_test.go | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go 
b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go index c24aee5851..4d176bab43 100644 --- a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go +++ b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go @@ -22,30 +22,31 @@ func TestWatermark(t *testing.T) { require.Equal(t, 5, w.Get()) w.Set(3) require.Equal(t, 5, w.Get()) + require.Len(t, w.WaitFor(1), 1) ch1 := w.WaitFor(9) ch2 := w.WaitFor(10) ch3 := w.WaitFor(10) ch4 := w.WaitFor(100) - require.Len(t, ch1, 0) - require.Len(t, ch2, 0) - require.Len(t, ch3, 0) - require.Len(t, ch4, 0) + require.Empty(t, ch1) + require.Empty(t, ch2) + require.Empty(t, ch3) + require.Empty(t, ch4) w.Set(8) require.Equal(t, 8, w.Get()) - require.Len(t, ch1, 0) - require.Len(t, ch2, 0) - require.Len(t, ch3, 0) - require.Len(t, ch4, 0) + require.Empty(t, ch1) + require.Empty(t, ch2) + require.Empty(t, ch3) + require.Empty(t, ch4) w.Set(9) require.Equal(t, 9, w.Get()) require.Len(t, ch1, 1) - require.Len(t, ch2, 0) - require.Len(t, ch3, 0) - require.Len(t, ch4, 0) + require.Empty(t, ch2) + require.Empty(t, ch3) + require.Empty(t, ch4) w.Set(10) require.Equal(t, 10, w.Get()) require.Len(t, ch1, 1) require.Len(t, ch2, 1) require.Len(t, ch3, 1) - require.Len(t, ch4, 0) + require.Empty(t, ch4) } From 30678edf60209c135540c859a312dcd326d2ff43 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 04:41:23 +0000 Subject: [PATCH 094/118] pgcdc: regen docs --- .../components/pages/inputs/pg_stream.adoc | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc index d11a472808..3b2ddb59b2 100644 --- a/docs/modules/components/pages/inputs/pg_stream.adoc +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -50,8 +50,8 @@ input: checkpoint_limit: 1024 temporary_slot: false slot_name: "" - pg_standby_timeout_sec: 10 - 
pg_wal_monitor_interval_sec: 3 + pg_standby_timeout: 10s + pg_wal_monitor_interval: 3s max_parallel_snapshot_tables: 1 auto_replay_nacks: true batching: @@ -82,8 +82,8 @@ input: checkpoint_limit: 1024 temporary_slot: false slot_name: "" - pg_standby_timeout_sec: 10 - pg_wal_monitor_interval_sec: 3 + pg_standby_timeout: 10s + pg_wal_monitor_interval: 3s max_parallel_snapshot_tables: 1 auto_replay_nacks: true batching: @@ -181,7 +181,7 @@ snapshot_batch_size: 10000 === `decoding_plugin` Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. - Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. +Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. *Type*: `string` @@ -263,34 +263,34 @@ The name of the PostgreSQL logical replication slot to use. If not provided, a r slot_name: my_test_slot ``` -=== `pg_standby_timeout_sec` +=== `pg_standby_timeout` -Int field that specifies default standby timeout for PostgreSQL replication connection +Specify the standby timeout before refreshing an idle connection. -*Type*: `int` +*Type*: `string` -*Default*: `10` +*Default*: `"10s"` ```yml # Examples -pg_standby_timeout_sec: 10 +pg_standby_timeout: 30s ``` -=== `pg_wal_monitor_interval_sec` +=== `pg_wal_monitor_interval` -Int field stat specifies ticker interval for WAL monitoring. Used to fetch replication slot lag +How often to report changes to the replication lag. 
-*Type*: `int` +*Type*: `string` -*Default*: `3` +*Default*: `"3s"` ```yml # Examples -pg_wal_monitor_interval_sec: 3 +pg_wal_monitor_interval: 6s ``` === `max_parallel_snapshot_tables` From 408394a50e602388aeb0f90e2cdc6ede33e39eb2 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Wed, 13 Nov 2024 17:44:42 +0100 Subject: [PATCH 095/118] chore(): added +1 to standby update to follow postgresql requirements --- internal/impl/postgresql/pglogicalstream/logical_stream.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 1d91cf1326..485a07d0de 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -267,9 +267,9 @@ func (s *Stream) AckLSN(lsn string) error { } err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ - WALApplyPosition: clientXLogPos, - WALWritePosition: clientXLogPos, - WALFlushPosition: clientXLogPos, + WALApplyPosition: clientXLogPos + 1, + WALWritePosition: clientXLogPos + 1, + WALFlushPosition: clientXLogPos + 1, ReplyRequested: true, }) From fe0526806d6383f556e48e9c64e3f7bcdff80fce Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 18:58:25 +0000 Subject: [PATCH 096/118] chore: goimports --- internal/impl/postgresql/input_pg_stream.go | 2 +- internal/impl/postgresql/pglogicalstream/logical_stream.go | 3 ++- internal/impl/postgresql/pglogicalstream/pglogrepl.go | 1 + internal/impl/postgresql/pglogicalstream/pluginhandlers.go | 1 + .../postgresql/pglogicalstream/watermark/watermark_test.go | 3 ++- 5 files changed, 7 insertions(+), 3 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 53f7ab509e..ca650231bf 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ 
b/internal/impl/postgresql/input_pg_stream.go @@ -17,7 +17,7 @@ import ( "github.com/Jeffail/checkpoint" "github.com/jackc/pgx/v5/pgconn" - "github.com/matoous/go-nanoid/v2" + gonanoid "github.com/matoous/go-nanoid/v2" "github.com/redpanda-data/benthos/v4/public/service" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream" diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index a7870e029b..c2e5d87a0b 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -21,9 +21,10 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/redpanda-data/benthos/v4/public/service" + "golang.org/x/sync/errgroup" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" - "golang.org/x/sync/errgroup" ) // Stream is a structure that represents a logical replication stream diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index ea272fff80..f9cbd8d1ef 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -30,6 +30,7 @@ import ( "github.com/jackc/pgio" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 3da396a925..8b2e4ecf85 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -12,6 +12,7 @@ import ( "context" "github.com/jackc/pgx/v5/pgtype" + 
"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" ) diff --git a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go index 4d176bab43..637deff653 100644 --- a/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go +++ b/internal/impl/postgresql/pglogicalstream/watermark/watermark_test.go @@ -13,8 +13,9 @@ package watermark_test import ( "testing" - "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" "github.com/stretchr/testify/require" + + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" ) func TestWatermark(t *testing.T) { From 581c7d4602624dd853e07ea202ba8ac5ecfe3a8b Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 19:55:17 +0000 Subject: [PATCH 097/118] pgcdc: simplify shutdown in the input Still need to simplify this in the internal logical_stream package, but this is a first step --- internal/impl/postgresql/input_pg_stream.go | 229 +++++++++--------- .../pglogicalstream/logical_stream.go | 22 +- 2 files changed, 131 insertions(+), 120 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index ca650231bf..2fc76ec744 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -12,10 +12,10 @@ import ( "context" "encoding/json" "fmt" - "sync" "time" "github.com/Jeffail/checkpoint" + "github.com/Jeffail/shutdown" "github.com/jackc/pgx/v5/pgconn" gonanoid "github.com/matoous/go-nanoid/v2" "github.com/redpanda-data/benthos/v4/public/service" @@ -245,15 +245,18 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser }, batching: batching, checkpointLimit: checkpointLimit, - cMut: sync.Mutex{}, msgChan: make(chan asyncMessage), mgr: mgr, logger: mgr.Logger(), snapshotMetrics: snapshotMetrics, 
replicationLag: replicationLag, + stopSig: shutdown.NewSignaller(), } + // Has stopped is how we notify that we're not connected. This will get reset at connection time. + i.stopSig.TriggerHasStopped() + r, err := service.AutoRetryNacksBatchedToggled(conf, i) if err != nil { return nil, err @@ -292,13 +295,13 @@ type pgStreamInput struct { pgLogicalStream *pglogicalstream.Stream logger *service.Logger mgr *service.Resources - cMut sync.Mutex msgChan chan asyncMessage batching service.BatchPolicy checkpointLimit int snapshotMetrics *service.MetricGauge replicationLag *service.MetricGauge + stopSig *shutdown.Signaller } func (p *pgStreamInput) Connect(ctx context.Context) error { @@ -308,112 +311,118 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { } p.pgLogicalStream = pgStream - batchPolicy, err := p.batching.NewBatcher(p.mgr) + batcher, err := p.batching.NewBatcher(p.mgr) if err != nil { return err } - go func() { - defer func() { - batchPolicy.Close(context.Background()) - }() - - var nextTimedBatchChan <-chan time.Time - - // offsets are nilable since we don't provide offset tracking during the snapshot phase - var latestOffset *int64 - cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) - for ctx.Err() != nil { - select { - case <-ctx.Done(): - if err = p.pgLogicalStream.Stop(); err != nil { - p.logger.Errorf("Failed to stop pglogical stream: %v", err) - } - return - case <-nextTimedBatchChan: - nextTimedBatchChan = nil - flushedBatch, err := batchPolicy.Flush(ctx) - if err != nil { - p.logger.Debugf("Timed flush batch error: %w", err) - break - } + // Reset our stop signal + p.stopSig = shutdown.NewSignaller() + go p.processStream(batcher) + return err +} - if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset); err != nil { - break - } +func (p *pgStreamInput) processStream(batcher *service.Batcher) { + ctx, _ := p.stopSig.SoftStopCtx(context.Background()) + defer func() { + ctx, _ := p.stopSig.HardStopCtx(context.Background()) + 
if err := batcher.Close(ctx); err != nil { + p.logger.Errorf("uneable to close batcher: %v", err) + } + p.stopSig.TriggerHasStopped() + }() - case message, open := <-p.pgLogicalStream.Messages(): - if !open { - break - } - var ( - mb []byte - err error - ) - if message.Lsn != nil { - parsedLSN, err := LSNToInt64(*message.Lsn) - if err != nil { - p.logger.Errorf("Failed to parse LSN: %v", err) - break - } - latestOffset = &parsedLSN - } + var nextTimedBatchChan <-chan time.Time + + // offsets are nilable since we don't provide offset tracking during the snapshot phase + cp := checkpoint.NewCapped[*int64](int64(p.checkpointLimit)) + for !p.stopSig.IsSoftStopSignalled() { + select { + case <-nextTimedBatchChan: + nextTimedBatchChan = nil + flushedBatch, err := batcher.Flush(ctx) + if err != nil { + p.logger.Debugf("timed flush batch error: %v", err) + break + } + if err := p.flushBatch(ctx, cp, flushedBatch); err != nil { + p.logger.Debugf("failed to flush batch: %v", err) + break + } - if len(message.Changes) == 0 { - p.logger.Debugf("Received empty message on LSN: %v", message.Lsn) - continue - } + case message := <-p.pgLogicalStream.Messages(): + var ( + mb []byte + err error + ) - // TODO this should only be the message - if mb, err = json.Marshal(message.Changes); err != nil { - break - } + if len(message.Changes) == 0 { + p.logger.Errorf("received empty message (LSN=%v)", message.Lsn) + break + } + + // TODO this should only be the message + if mb, err = json.Marshal(message.Changes); err != nil { + break + } - batchMsg := service.NewMessage(mb) + batchMsg := service.NewMessage(mb) - batchMsg.MetaSet("mode", string(message.Mode)) - batchMsg.MetaSet("table", message.Changes[0].Table) - batchMsg.MetaSet("operation", message.Changes[0].Operation) - if message.Changes[0].TableSnapshotProgress != nil { - p.snapshotMetrics.SetFloat64(*message.Changes[0].TableSnapshotProgress, message.Changes[0].Table) + batchMsg.MetaSet("mode", string(message.Mode)) + 
batchMsg.MetaSet("table", message.Changes[0].Table) + batchMsg.MetaSet("operation", message.Changes[0].Operation) + if message.Lsn != nil { + batchMsg.MetaSet("lsn", *message.Lsn) + } + if message.Changes[0].TableSnapshotProgress != nil { + p.snapshotMetrics.SetFloat64(*message.Changes[0].TableSnapshotProgress, message.Changes[0].Table) + } + if message.WALLagBytes != nil { + p.replicationLag.Set(*message.WALLagBytes) + } + + if batcher.Add(batchMsg) { + nextTimedBatchChan = nil + flushedBatch, err := batcher.Flush(ctx) + if err != nil { + p.logger.Debugf("error flushing batch: %v", err) + break } - if message.WALLagBytes != nil { - p.replicationLag.Set(*message.WALLagBytes) + if err := p.flushBatch(ctx, cp, flushedBatch); err != nil { + p.logger.Debugf("failed to flush batch: %v", err) + break } - - if batchPolicy.Add(batchMsg) { - nextTimedBatchChan = nil - flushedBatch, err := batchPolicy.Flush(ctx) - if err != nil { - p.logger.Debugf("Flush batch error: %w", err) - break - } - if err := p.flushBatch(ctx, cp, flushedBatch, latestOffset); err != nil { - break - } - } else { - d, ok := batchPolicy.UntilNext() - if ok { - nextTimedBatchChan = time.After(d) - } + } else { + d, ok := batcher.UntilNext() + if ok { + nextTimedBatchChan = time.After(d) } } } - }() - - return err + } } -func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint.Capped[*int64], msg service.MessageBatch, lsn *int64) error { - if msg == nil { +func (p *pgStreamInput) flushBatch( + ctx context.Context, + checkpointer *checkpoint.Capped[*int64], + batch service.MessageBatch, +) error { + if batch == nil { return nil } - resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(msg))) - if err != nil { - if ctx.Err() == nil { - p.mgr.Logger().Errorf("Failed to checkpoint offset: %v\n", err) + var lsn *int64 + lastMsg := batch[len(batch)-1] + lsnStr, ok := lastMsg.MetaGet("lsn") + if ok { + parsed, err := LSNToInt64(lsnStr) + if err != nil { + return fmt.Errorf("unable to 
extract LSN from last message in batch: %w", err) } - return err + lsn = &parsed + } + resolveFn, err := checkpointer.Track(ctx, lsn, int64(len(batch))) + if err != nil { + return fmt.Errorf("unable to checkpoint: %w", err) } ackFn := func(ctx context.Context, res error) error { @@ -421,18 +430,17 @@ func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint if maxOffset == nil { return nil } - p.cMut.Lock() - defer p.cMut.Unlock() + lsn := *maxOffset if lsn == nil { return nil } - if err = p.pgLogicalStream.AckLSN(Int64ToLSN(*lsn)); err != nil { - return err + if err = p.pgLogicalStream.AckLSN(ctx, Int64ToLSN(*lsn)); err != nil { + return fmt.Errorf("unable to ack LSN to postgres: %w", err) } return nil } select { - case p.msgChan <- asyncMessage{msg: msg, ackFn: ackFn}: + case p.msgChan <- asyncMessage{msg: batch, ackFn: ackFn}: case <-ctx.Done(): return ctx.Err() } @@ -440,29 +448,32 @@ func (p *pgStreamInput) flushBatch(ctx context.Context, checkpointer *checkpoint } func (p *pgStreamInput) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) { - p.cMut.Lock() - msgChan := p.msgChan - p.cMut.Unlock() - if msgChan == nil { - return nil, nil, service.ErrNotConnected - } - select { - case m, open := <-msgChan: - if !open { - return nil, nil, service.ErrNotConnected - } + case m := <-p.msgChan: return m.msg, m.ackFn, nil + case <-p.stopSig.HasStoppedChan(): + return nil, nil, service.ErrNotConnected case <-ctx.Done(): - + return nil, nil, ctx.Err() } - - return nil, nil, ctx.Err() } func (p *pgStreamInput) Close(ctx context.Context) error { + p.stopSig.TriggerSoftStop() + select { + case <-ctx.Done(): + case <-p.stopSig.HasStoppedChan(): + } if p.pgLogicalStream != nil { - return p.pgLogicalStream.Stop() + if err := p.pgLogicalStream.Stop(ctx); err != nil { + return err + } + } + p.stopSig.TriggerHardStop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-p.stopSig.HasStoppedChan(): } return nil } diff --git 
a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index c2e5d87a0b..5b221f369f 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -239,11 +239,11 @@ func (s *Stream) startLr(lsnStart LSN) error { // AckLSN acknowledges the LSN up to which the stream has processed the messages. // This makes Postgres to remove the WAL files that are no longer needed. -func (s *Stream) AckLSN(lsn string) error { +func (s *Stream) AckLSN(ctx context.Context, lsn string) error { clientXLogPos, err := ParseLSN(lsn) if err != nil { s.logger.Errorf("Failed to parse LSN for Acknowledge: %v", err) - if err = s.Stop(); err != nil { + if err = s.Stop(ctx); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } @@ -279,7 +279,7 @@ func (s *Stream) streamMessagesAsync() { handler = NewPgOutputPluginHandler(s.messages, s.streamUncommitted, s.monitor, s.clientXLogPos) default: s.logger.Error("Invalid decoding plugin. 
Cant find needed handler implementation") - if err := s.Stop(); err != nil { + if err := s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } @@ -304,7 +304,7 @@ func (s *Stream) streamMessagesAsync() { if err != nil { s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", pos.String(), err) - if err = s.Stop(); err != nil { + if err = s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } return @@ -328,7 +328,7 @@ func (s *Stream) streamMessagesAsync() { } s.logger.Errorf("Failed to receive messages from PostgreSQL: %v", err) - if err = s.Stop(); err != nil { + if err = s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } return @@ -336,7 +336,7 @@ func (s *Stream) streamMessagesAsync() { if errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { s.logger.Errorf("Received error message from Postgres: %v", errMsg) - if err = s.Stop(); err != nil { + if err = s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } return @@ -353,7 +353,7 @@ func (s *Stream) streamMessagesAsync() { pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { s.logger.Errorf("Failed to parse PrimaryKeepaliveMessage: %v", err) - if err = s.Stop(); err != nil { + if err = s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } } @@ -368,7 +368,7 @@ func (s *Stream) streamMessagesAsync() { xld, err := ParseXLogData(msg.Data[1:]) if err != nil { s.logger.Errorf("Failed to parse XLogData: %v", err) - if err = s.Stop(); err != nil { + if err = s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } } @@ -376,12 +376,12 @@ func (s *Stream) streamMessagesAsync() { commit, err := handler.Handle(s.streamCtx, clientXLogPos, xld) if err != nil { s.logger.Errorf("decodePgOutputChanges failed: %w", err) - if err = s.Stop(); err != nil { + if err = 
s.Stop(context.TODO()); err != nil { s.logger.Errorf("Failed to stop the stream: %v", err) } } else if commit { // This is a hack and we probably should not do it - if err = s.AckLSN(clientXLogPos.String()); err != nil { + if err = s.AckLSN(context.TODO(), clientXLogPos.String()); err != nil { s.logger.Errorf("Failed to ack commit message: %v", err) } } @@ -604,7 +604,7 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { } // Stop closes the stream conect and prevents from replication slot read -func (s *Stream) Stop() error { +func (s *Stream) Stop(ctx context.Context) error { if s == nil { return nil } From 998b4c5bbc1d1ccf63232a1353e717109e9c6290 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 20:05:22 +0000 Subject: [PATCH 098/118] pgcdc: localize the pg stream To make lifetime semantics and handling ErrNotConnected better --- internal/impl/postgresql/input_pg_stream.go | 30 ++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 2fc76ec744..0facf0f6fe 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -39,6 +39,8 @@ const ( fieldSlotName = "slot_name" fieldBatching = "batching" fieldMaxParallelSnapshotTables = "max_parallel_snapshot_tables" + + shutdownTimeout = 5 * time.Second ) type asyncMessage struct { @@ -292,7 +294,6 @@ func init() { type pgStreamInput struct { streamConfig *pglogicalstream.Config - pgLogicalStream *pglogicalstream.Stream logger *service.Logger mgr *service.Resources msgChan chan asyncMessage @@ -309,24 +310,25 @@ func (p *pgStreamInput) Connect(ctx context.Context) error { if err != nil { return fmt.Errorf("unable to create replication stream: %w", err) } - - p.pgLogicalStream = pgStream batcher, err := p.batching.NewBatcher(p.mgr) if err != nil { return err } // Reset our stop signal p.stopSig = 
shutdown.NewSignaller() - go p.processStream(batcher) + go p.processStream(pgStream, batcher) return err } -func (p *pgStreamInput) processStream(batcher *service.Batcher) { +func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher *service.Batcher) { ctx, _ := p.stopSig.SoftStopCtx(context.Background()) defer func() { ctx, _ := p.stopSig.HardStopCtx(context.Background()) if err := batcher.Close(ctx); err != nil { - p.logger.Errorf("uneable to close batcher: %v", err) + p.logger.Errorf("unable to close batcher: %v", err) + } + if err := pgStream.Stop(ctx); err != nil { + p.logger.Errorf("unable to stop replication stream: %v", err) } p.stopSig.TriggerHasStopped() }() @@ -344,12 +346,12 @@ func (p *pgStreamInput) processStream(batcher *service.Batcher) { p.logger.Debugf("timed flush batch error: %v", err) break } - if err := p.flushBatch(ctx, cp, flushedBatch); err != nil { + if err := p.flushBatch(ctx, pgStream, cp, flushedBatch); err != nil { p.logger.Debugf("failed to flush batch: %v", err) break } - case message := <-p.pgLogicalStream.Messages(): + case message := <-pgStream.Messages(): var ( mb []byte err error @@ -387,7 +389,7 @@ func (p *pgStreamInput) processStream(batcher *service.Batcher) { p.logger.Debugf("error flushing batch: %v", err) break } - if err := p.flushBatch(ctx, cp, flushedBatch); err != nil { + if err := p.flushBatch(ctx, pgStream, cp, flushedBatch); err != nil { p.logger.Debugf("failed to flush batch: %v", err) break } @@ -403,6 +405,7 @@ func (p *pgStreamInput) processStream(batcher *service.Batcher) { func (p *pgStreamInput) flushBatch( ctx context.Context, + pgStream *pglogicalstream.Stream, checkpointer *checkpoint.Capped[*int64], batch service.MessageBatch, ) error { @@ -434,7 +437,7 @@ func (p *pgStreamInput) flushBatch( if lsn == nil { return nil } - if err = p.pgLogicalStream.AckLSN(ctx, Int64ToLSN(*lsn)); err != nil { + if err = pgStream.AckLSN(ctx, Int64ToLSN(*lsn)); err != nil { return fmt.Errorf("unable to 
ack LSN to postgres: %w", err) } return nil @@ -462,17 +465,14 @@ func (p *pgStreamInput) Close(ctx context.Context) error { p.stopSig.TriggerSoftStop() select { case <-ctx.Done(): + case <-time.After(shutdownTimeout): case <-p.stopSig.HasStoppedChan(): } - if p.pgLogicalStream != nil { - if err := p.pgLogicalStream.Stop(ctx); err != nil { - return err - } - } p.stopSig.TriggerHardStop() select { case <-ctx.Done(): return ctx.Err() + case <-time.After(shutdownTimeout): case <-p.stopSig.HasStoppedChan(): } return nil From 080f3f7c9c9e116fa37a715fb7b65abaed3fef91 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 21:30:01 +0000 Subject: [PATCH 099/118] pgcdc: simplify internal flow control Simplify the internal flow control of the logical stream by just returning and handling errors at the top level. --- internal/impl/postgresql/input_pg_stream.go | 22 +- .../pglogicalstream/logical_stream.go | 379 +++++++++--------- .../postgresql/pglogicalstream/monitor.go | 4 +- .../pglogicalstream/pluginhandlers.go | 24 +- 4 files changed, 223 insertions(+), 206 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 0facf0f6fe..f546eb8f20 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -325,10 +325,11 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher defer func() { ctx, _ := p.stopSig.HardStopCtx(context.Background()) if err := batcher.Close(ctx); err != nil { - p.logger.Errorf("unable to close batcher: %v", err) + p.logger.Errorf("unable to close batcher: %s", err) } + // TODO(rockwood): We should wait for outstanding acks to be completed (best effort) if err := pgStream.Stop(ctx); err != nil { - p.logger.Errorf("unable to stop replication stream: %v", err) + p.logger.Errorf("unable to stop replication stream: %s", err) } p.stopSig.TriggerHasStopped() }() @@ -343,14 +344,13 @@ func (p *pgStreamInput) 
processStream(pgStream *pglogicalstream.Stream, batcher nextTimedBatchChan = nil flushedBatch, err := batcher.Flush(ctx) if err != nil { - p.logger.Debugf("timed flush batch error: %v", err) + p.logger.Debugf("timed flush batch error: %s", err) break } if err := p.flushBatch(ctx, pgStream, cp, flushedBatch); err != nil { - p.logger.Debugf("failed to flush batch: %v", err) + p.logger.Debugf("failed to flush batch: %s", err) break } - case message := <-pgStream.Messages(): var ( mb []byte @@ -358,7 +358,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher ) if len(message.Changes) == 0 { - p.logger.Errorf("received empty message (LSN=%v)", message.Lsn) + p.logger.Errorf("received empty message (LSN=%s)", message.Lsn) break } @@ -386,11 +386,11 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher nextTimedBatchChan = nil flushedBatch, err := batcher.Flush(ctx) if err != nil { - p.logger.Debugf("error flushing batch: %v", err) + p.logger.Debugf("error flushing batch: %s", err) break } if err := p.flushBatch(ctx, pgStream, cp, flushedBatch); err != nil { - p.logger.Debugf("failed to flush batch: %v", err) + p.logger.Debugf("failed to flush batch: %s", err) break } } else { @@ -399,6 +399,12 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher nextTimedBatchChan = time.After(d) } } + case err := <-pgStream.Errors(): + p.logger.Warnf("logical replication stream error: %s", err) + // If the stream has internally errored then we should stop and restart processing + p.stopSig.TriggerSoftStop() + case <-p.stopSig.SoftStopChan(): + p.logger.Debug("soft stop triggered, stopping logical replication stream") } } } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 5b221f369f..15b9534378 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ 
b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -15,9 +15,9 @@ import ( "fmt" "slices" "strings" - "sync" "time" + "github.com/Jeffail/shutdown" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" "github.com/redpanda-data/benthos/v4/public/service" @@ -30,20 +30,20 @@ import ( // Stream is a structure that represents a logical replication stream // It includes the connection to the database, the context for the stream, and snapshotting functionality type Stream struct { - pgConn *pgconn.PgConn - streamCtx context.Context - streamCancel context.CancelFunc + pgConn *pgconn.PgConn - standbyCtxCancel context.CancelFunc + shutSig *shutdown.Signaller clientXLogPos *watermark.Value[LSN] standbyMessageTimeout time.Duration nextStandbyMessageDeadline time.Time messages chan StreamMessage - snapshotName string - slotName string - schema string + errors chan error + + snapshotName string + slotName string + schema string // includes schema tableQualifiedName []string snapshotBatchSize int @@ -55,9 +55,6 @@ type Stream struct { streamUncommitted bool snapshotter *Snapshotter maxParallelSnapshotTables int - - m sync.Mutex - stopped bool } // NewPgStream creates a new instance of the Stream struct @@ -66,10 +63,24 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { return nil, errors.New("missing replication slot name") } + // Cleanup state - this will be accumulated as the function progresses and cleared + // if we successfully create a stream. 
+ var cleanups []func() + defer func() { + for i := len(cleanups) - 1; i >= 0; i-- { + cleanups[i]() + } + }() + dbConn, err := pgconn.ConnectConfig(ctx, config.DBConfig.Copy()) if err != nil { return nil, err } + cleanups = append(cleanups, func() { + if err := dbConn.Close(ctx); err != nil { + config.Logger.Warnf("unable to properly cleanup db connection on stream creation failure: %s", err) + } + }) if err = dbConn.Ping(ctx); err != nil { return nil, err @@ -82,6 +93,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream := &Stream{ pgConn: dbConn, messages: make(chan StreamMessage), + errors: make(chan error, 1), slotName: config.ReplicationSlotName, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, streamUncommitted: config.StreamUncommitted, @@ -90,8 +102,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tableQualifiedName: tableNames, maxParallelSnapshotTables: config.MaxParallelSnapshotTables, logger: config.Logger, - m: sync.Mutex{}, decodingPlugin: decodingPluginFromString(config.DecodingPlugin), + shutSig: shutdown.NewSignaller(), } var version int @@ -102,13 +114,14 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { snapshotter, err := NewSnapshotter(config.DBRawDSN, stream.logger, version) if err != nil { - stream.logger.Errorf("Failed to open SQL connection to prepare snapshot: %v", err.Error()) - if err = stream.cleanUpOnFailure(ctx); err != nil { - stream.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } return nil, err } stream.snapshotter = snapshotter + cleanups = append(cleanups, func() { + if err := snapshotter.closeConn(); err != nil { + config.Logger.Warnf("unable to properly cleanup snapshotter connection on stream creation failure: %s", err) + } + }) var pluginArguments []string if stream.decodingPlugin == "pgoutput" { @@ -139,6 +152,10 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { if err 
= CreatePublication(ctx, stream.pgConn, pubName, tableNames); err != nil { return nil, err } + cleanups = append(cleanups, func() { + // TODO: Drop publication if it was created (meaning it's not existing state we might want to keep). + }) + sysident, err := IdentifySystem(ctx, stream.pgConn) if err != nil { return nil, err @@ -160,15 +177,29 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { if len(connExecResult) == 0 || len(connExecResult[0].Rows) == 0 { // here we create a new replication slot because there is no slot found var createSlotResult CreateReplicationSlotResult - createSlotResult, err = CreateReplicationSlot(ctx, stream.pgConn, stream.slotName, stream.decodingPlugin.String(), - CreateReplicationSlotOptions{Temporary: config.TemporaryReplicationSlot, + createSlotResult, err = CreateReplicationSlot( + ctx, + stream.pgConn, + stream.slotName, + stream.decodingPlugin.String(), + CreateReplicationSlotOptions{ + Temporary: config.TemporaryReplicationSlot, SnapshotAction: "export", - }, version, stream.snapshotter) + }, + version, + stream.snapshotter, + ) if err != nil { return nil, err } stream.snapshotName = createSlotResult.SnapshotName freshlyCreatedSlot = true + cleanups = append(cleanups, func() { + err := DropReplicationSlot(ctx, stream.pgConn, stream.slotName, DropReplicationSlotOptions{Wait: true}) + if err != nil { + config.Logger.Warnf("unable to properly cleanup replication slot on stream creation failure: %s", err) + } + }) } else { slotCheckRow := connExecResult[0].Rows[0] confirmedLSNFromDB = string(slotCheckRow[0]) @@ -194,72 +225,101 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.standbyMessageTimeout = config.PgStandbyTimeout stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) - stream.streamCtx, stream.streamCancel = context.WithCancel(context.Background()) monitor, err := NewMonitor(config.DBRawDSN, stream.logger, tableNames, stream.slotName, 
config.WalMonitorInterval) if err != nil { return nil, err } stream.monitor = monitor + cleanups = append(cleanups, func() { + if err := monitor.Stop(); err != nil { + config.Logger.Warnf("unable to properly cleanup monitor on stream creation failure: %s", err) + } + }) - stream.logger.Debugf("Starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.Get().String(), stream.snapshotName) + stream.logger.Debugf("starting stream from LSN %s with clientXLogPos %s and snapshot name %s", lsnrestart.String(), stream.clientXLogPos.Get().String(), stream.snapshotName) + // TODO(le-vlad): if snapshot processing is restarted we will just skip right to streaming... if !freshlyCreatedSlot || !config.StreamOldData { - if err = stream.startLr(lsnrestart); err != nil { + if err = stream.startLr(ctx, lsnrestart); err != nil { return nil, err } - go stream.streamMessagesAsync() + go func() { + defer stream.shutSig.TriggerHasStopped() + if err := stream.streamMessages(); err != nil { + stream.errors <- fmt.Errorf("logical replication stream error: %w", err) + } + }() } else { - // New messages will be streamed after the snapshot has been processed. 
- // stream.startLr() and stream.streamMessagesAsync() will be called inside stream.processSnapshot() go func() { - if err := stream.processSnapshot(ctx, lsnrestart); err != nil { - stream.logger.Errorf("Failed to process snapshot: %v", err.Error()) + defer stream.shutSig.TriggerHasStopped() + if err := stream.processSnapshot(); err != nil { + stream.errors <- fmt.Errorf("failed to process snapshot: %w", err) + return + } + ctx, _ := stream.shutSig.SoftStopCtx(context.Background()) + if err := stream.startLr(ctx, lsnrestart); err != nil { + stream.errors <- fmt.Errorf("failed to start logical replication: %w", err) + return + } + if err := stream.streamMessages(); err != nil { + stream.errors <- fmt.Errorf("logical replication stream error: %w", err) } }() } - return stream, err + // Success! No need to cleanup + cleanups = nil + return stream, nil } // GetProgress returns the progress of the stream. -// including the % of snapsho messages processed and the WAL lag in bytes. +// including the % of snapshot messages processed and the WAL lag in bytes. func (s *Stream) GetProgress() *Report { return s.monitor.Report() } -func (s *Stream) startLr(lsnStart LSN) error { - if err := StartReplication(context.Background(), s.pgConn, s.slotName, lsnStart, StartReplicationOptions{PluginArgs: s.decodingPluginArguments}); err != nil { +func (s *Stream) startLr(ctx context.Context, lsnStart LSN) error { + err := StartReplication( + ctx, + s.pgConn, + s.slotName, + lsnStart, + StartReplicationOptions{ + PluginArgs: s.decodingPluginArguments, + }, + ) + if err != nil { return err } - - s.logger.Infof("Started logical replication on slot slot-name: %v", s.slotName) + s.logger.Debugf("Started logical replication on slot slot-name: %v", s.slotName) return nil } // AckLSN acknowledges the LSN up to which the stream has processed the messages. // This makes Postgres to remove the WAL files that are no longer needed. 
func (s *Stream) AckLSN(ctx context.Context, lsn string) error { + if s.shutSig.IsHardStopSignalled() { + return fmt.Errorf("unable to ack LSN %s stream shutting down", lsn) + } clientXLogPos, err := ParseLSN(lsn) if err != nil { - s.logger.Errorf("Failed to parse LSN for Acknowledge: %v", err) - if err = s.Stop(ctx); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return err } - err = SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ - WALApplyPosition: clientXLogPos + 1, - WALWritePosition: clientXLogPos + 1, - WALFlushPosition: clientXLogPos + 1, - ReplyRequested: true, - }) + err = SendStandbyStatusUpdate( + ctx, + s.pgConn, + StandbyStatusUpdate{ + WALApplyPosition: clientXLogPos + 1, + WALWritePosition: clientXLogPos + 1, + WALFlushPosition: clientXLogPos + 1, + ReplyRequested: true, + }, + ) if err != nil { - s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", clientXLogPos.String(), err) - return err + return fmt.Errorf("Failed to send Standby status message at LSN %s: %w", clientXLogPos.String(), err) } // Update client XLogPos after we ack the message @@ -270,7 +330,7 @@ func (s *Stream) AckLSN(ctx context.Context, lsn string) error { return nil } -func (s *Stream) streamMessagesAsync() { +func (s *Stream) streamMessages() error { var handler PluginHandler switch s.decodingPlugin { case "wal2json": @@ -278,73 +338,50 @@ func (s *Stream) streamMessagesAsync() { case "pgoutput": handler = NewPgOutputPluginHandler(s.messages, s.streamUncommitted, s.monitor, s.clientXLogPos) default: - s.logger.Error("Invalid decoding plugin. Cant find needed handler implementation") - if err := s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - - return + return fmt.Errorf("invalid decoding plugin: %q", s.decodingPlugin) } - for { - if s.streamCtx.Err() != nil { - s.logger.Debug("Stream was cancelled... 
exiting...") - return - } + ctx, _ := s.shutSig.SoftStopCtx(context.Background()) + for !s.shutSig.IsSoftStopSignalled() { if time.Now().After(s.nextStandbyMessageDeadline) { - if s.pgConn.IsClosed() { - s.logger.Warn("Postgres connection is closed...stop reading from replication slot") - return - } - pos := s.clientXLogPos.Get() - err := SendStandbyStatusUpdate(context.Background(), s.pgConn, StandbyStatusUpdate{ - WALWritePosition: pos, - }) - + err := SendStandbyStatusUpdate( + ctx, + s.pgConn, + StandbyStatusUpdate{ + WALWritePosition: pos, + }, + ) if err != nil { - s.logger.Errorf("Failed to send Standby status message at LSN#%s: %v", pos.String(), err) - if err = s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + return fmt.Errorf("unable to send standby status message at LSN %s: %w", pos, err) } s.logger.Debugf("Sent Standby status message at LSN#%s", pos.String()) s.nextStandbyMessageDeadline = time.Now().Add(s.standbyMessageTimeout) } - - ctx, cancel := context.WithDeadline(context.Background(), s.nextStandbyMessageDeadline) - rawMsg, err := s.pgConn.ReceiveMessage(ctx) - s.standbyCtxCancel = cancel - - if err != nil && (errors.Is(err, context.Canceled) || s.stopped) { - s.logger.Warn("Service was interrupted....stop reading from replication slot") - return - } - + recvCtx, cancel := context.WithDeadline(ctx, s.nextStandbyMessageDeadline) + rawMsg, err := s.pgConn.ReceiveMessage(recvCtx) + cancel() // don't leak goroutine + hitStandbyTimeout := errors.Is(err, context.DeadlineExceeded) && ctx.Err() == nil if err != nil { - if pgconn.Timeout(err) { + if hitStandbyTimeout || pgconn.Timeout(err) { + s.logger.Info("continue") continue } - - s.logger.Errorf("Failed to receive messages from PostgreSQL: %v", err) - if err = s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + return fmt.Errorf("failed to receive messages from Postgres: %w", err) } if 
errMsg, ok := rawMsg.(*pgproto3.ErrorResponse); ok { - s.logger.Errorf("Received error message from Postgres: %v", errMsg) - if err = s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } - return + return fmt.Errorf("received error message from Postgres: %v", errMsg) } msg, ok := rawMsg.(*pgproto3.CopyData) if !ok { - s.logger.Warnf("Received unexpected message: %T\n", rawMsg) + s.logger.Warnf("received unexpected message: %T", rawMsg) + continue + } + + if len(msg.Data) == 0 { + s.logger.Warn("received malformatted with no data") continue } @@ -352,12 +389,8 @@ func (s *Stream) streamMessagesAsync() { case PrimaryKeepaliveMessageByteID: pkm, err := ParsePrimaryKeepaliveMessage(msg.Data[1:]) if err != nil { - s.logger.Errorf("Failed to parse PrimaryKeepaliveMessage: %v", err) - if err = s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } + return fmt.Errorf("failed to parse PrimaryKeepaliveMessage: %w", err) } - if pkm.ReplyRequested { s.nextStandbyMessageDeadline = time.Time{} } @@ -367,63 +400,45 @@ func (s *Stream) streamMessagesAsync() { case XLogDataByteID: xld, err := ParseXLogData(msg.Data[1:]) if err != nil { - s.logger.Errorf("Failed to parse XLogData: %v", err) - if err = s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } + return fmt.Errorf("failed to parse XLogData: %w", err) } clientXLogPos := xld.WALStart + LSN(len(xld.WALData)) - commit, err := handler.Handle(s.streamCtx, clientXLogPos, xld) + commit, err := handler.Handle(ctx, clientXLogPos, xld) if err != nil { - s.logger.Errorf("decodePgOutputChanges failed: %w", err) - if err = s.Stop(context.TODO()); err != nil { - s.logger.Errorf("Failed to stop the stream: %v", err) - } + return fmt.Errorf("decoding postgres changes failed: %w", err) } else if commit { // This is a hack and we probably should not do it - if err = s.AckLSN(context.TODO(), 
clientXLogPos.String()); err != nil { - s.logger.Errorf("Failed to ack commit message: %v", err) + if err = s.AckLSN(ctx, clientXLogPos.String()); err != nil { + s.logger.Warnf("Failed to ack commit message LSN: %v", err) } } } } + // clean shutdown, return nil + return nil } -func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { +func (s *Stream) processSnapshot() error { if err := s.snapshotter.prepare(); err != nil { - s.logger.Errorf("Failed to prepare database snapshot. Probably snapshot is expired...: %v", err.Error()) - if err = s.cleanUpOnFailure(ctx); err != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", err.Error()) - } - return err + return fmt.Errorf("failed to prepare database snapshot - snapshot may be expired: %w", err) } defer func() { if err := s.snapshotter.releaseSnapshot(); err != nil { - s.logger.Errorf("Failed to release database snapshot: %v", err.Error()) + s.logger.Warnf("Failed to release database snapshot: %v", err.Error()) } if err := s.snapshotter.closeConn(); err != nil { - s.logger.Errorf("Failed to close database connection: %v", err.Error()) + s.logger.Warnf("Failed to close database connection: %v", err.Error()) } }() - s.logger.Infof("Starting snapshot processing") - sem := make(chan struct{}, s.maxParallelSnapshotTables) + s.logger.Debugf("Starting snapshot processing") var wg errgroup.Group + wg.SetLimit(s.maxParallelSnapshotTables) for _, table := range s.tableQualifiedName { tableName := table - sem <- struct{}{} wg.Go(func() (err error) { - s.logger.Infof("Processing snapshot for table: %v", table) - - defer func() { - defer func() { <-sem }() - if err != nil { - if cleanupErr := s.cleanUpOnFailure(ctx); cleanupErr != nil { - s.logger.Errorf("Failed to clean up resources on accident: %v", cleanupErr.Error()) - } - } - }() + s.logger.Debugf("Processing snapshot for table: %v", table) var ( avgRowSizeBytes sql.NullInt64 @@ -432,9 +447,7 @@ func (s *Stream) processSnapshot(ctx 
context.Context, lsnStart LSN) error { avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) if err != nil { - s.logger.Errorf("Failed to calculate average row size for table %v: %v", table, err.Error()) - - return err + return fmt.Errorf("Failed to calculate average row size for table %v: %w", table, err) } availableMemory := getAvailableMemory() @@ -443,12 +456,11 @@ func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { batchSize = s.snapshotBatchSize } - s.logger.Infof("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) + s.logger.Debugf("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) tablePk, err := s.getPrimaryKeyColumn(table) if err != nil { - s.logger.Errorf("Failed to get primary key column for table %v: %v", table, err.Error()) - return err + return fmt.Errorf("failed to get primary key column for table %v: %w", table, err) } var lastPkVal any @@ -457,30 +469,24 @@ func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { var snapshotRows *sql.Rows queryStart := time.Now() if snapshotRows, err = s.snapshotter.querySnapshotData(table, lastPkVal, tablePk, batchSize); err != nil { - s.logger.Errorf("Failed to query snapshot data for table %v: %v", table, err.Error()) - s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) - return err + return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) } queryDuration := time.Since(queryStart) s.logger.Debugf("Query duration: %v %s \n", queryDuration, tableName) if snapshotRows.Err() != nil { - s.logger.Errorf("Failed to get snapshot data for table %v: %v", table, snapshotRows.Err().Error()) - s.logger.Errorf("Failed to query snapshot for table %v: %v", table, err.Error()) - return err + return fmt.Errorf("failed to get snapshot data for table %v: %w", table, 
snapshotRows.Err()) } columnTypes, err := snapshotRows.ColumnTypes() if err != nil { - s.logger.Errorf("Failed to get column types for table %v: %v", table, err.Error()) - return err + return fmt.Errorf("failed to get column types for table %v: %w", table, err) } columnNames, err := snapshotRows.Columns() if err != nil { - s.logger.Errorf("Failed to get column names for table %v: %v", table, err.Error()) - return err + return fmt.Errorf("failed to get column names for table %v: %w", table, err) } var rowsCount = 0 @@ -499,8 +505,7 @@ func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { totalScanDuration += scanEnd if err != nil { - s.logger.Errorf("Failed to scan row for table %v: %v", table, err.Error()) - return err + return fmt.Errorf("failed to scan row for table %v: %v", table, err.Error()) } var data = make(map[string]any) @@ -532,7 +537,11 @@ func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { snapshotChangePacket.Mode = StreamModeSnapshot waitingFromBenthos := time.Now() - s.messages <- snapshotChangePacket + select { + case s.messages <- snapshotChangePacket: + case <-s.shutSig.SoftStopChan(): + return nil + } totalWaitingFromBenthos += time.Since(waitingFromBenthos) } @@ -551,17 +560,7 @@ func (s *Stream) processSnapshot(ctx context.Context, lsnStart LSN) error { return nil }) } - - if err := wg.Wait(); err != nil { - return err - } - - if err := s.startLr(lsnStart); err != nil { - s.logger.Errorf("Failed to start logical replication after snapshot: %v", err.Error()) - return err - } - go s.streamMessagesAsync() - return nil + return wg.Wait() } // Messages is a channel that can be used to consume messages from the plugin. 
It will contain LSN nil for snapshot messages @@ -569,17 +568,13 @@ func (s *Stream) Messages() chan StreamMessage { return s.messages } -// cleanUpOnFailure drops replication slot and publication if database snapshotting was failed for any reason -func (s *Stream) cleanUpOnFailure(ctx context.Context) error { - s.logger.Warnf("Cleaning up resources on accident: %v", s.slotName) - err := DropReplicationSlot(ctx, s.pgConn, s.slotName, DropReplicationSlotOptions{Wait: true}) - if err != nil { - s.logger.Errorf("Failed to drop replication slot: %s", err.Error()) - } - return s.pgConn.Close(ctx) +// Errors is a channel that can be used to see if and error has occured internally and the stream should be restarted +func (s *Stream) Errors() chan error { + return s.errors } func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { + // TODO(le-vlad): support composite primary keys q, err := sanitize.SQLQuery(` SELECT a.attname FROM pg_index i @@ -603,28 +598,30 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { return pkColName, nil } -// Stop closes the stream conect and prevents from replication slot read +// Stop closes the stream (hopefully gracefully) func (s *Stream) Stop(ctx context.Context) error { - if s == nil { - return nil - } - s.m.Lock() - s.stopped = true - s.m.Unlock() - s.monitor.Stop() - - if s.pgConn != nil { - if s.streamCtx != nil { - s.streamCancel() - // s.standbyCtxCancel is initialized later when starting reading from the replication slot. 
- // In case we failed to start replication of the process was shut down before starting the replication slot - // we need to check if the context is not nil before calling cancel - if s.standbyCtxCancel != nil { - s.standbyCtxCancel() - } + s.shutSig.TriggerSoftStop() + var wg errgroup.Group + stopNowCtx, _ := s.shutSig.HardStopCtx(ctx) + wg.Go(func() error { + return s.pgConn.Close(stopNowCtx) + }) + wg.Go(func() error { + return s.monitor.Stop() + }) + select { + case <-ctx.Done(): + case <-s.shutSig.HasStoppedChan(): + return wg.Wait() + } + s.shutSig.TriggerHardStop() + err := wg.Wait() + select { + case <-time.After(time.Second): + if err == nil { + return errors.New("unable to cleanly shutdown postgres logical replication stream") } - return s.pgConn.Close(context.Background()) + case <-s.shutSig.HasStoppedChan(): } - - return nil + return err } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index 0408c5a1c6..a22c56bf19 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -165,8 +165,8 @@ func (m *Monitor) Report() *Report { } // Stop stops the monitor -func (m *Monitor) Stop() { +func (m *Monitor) Stop() error { m.cancelTicker() m.ticker.Stop() - m.dbConn.Close() + return m.dbConn.Close() } diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 8b2e4ecf85..1d61204228 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -37,7 +37,7 @@ func NewWal2JsonPluginHandler(messages chan StreamMessage, monitor *Monitor) *Wa } // Handle handles the wal2json output -func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { +func (w *Wal2JsonPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld 
XLogData) (bool, error) { // get current stream metrics metrics := w.monitor.Report() message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) @@ -47,7 +47,11 @@ func (w *Wal2JsonPluginHandler) Handle(_ context.Context, clientXLogPos LSN, xld if message != nil && len(message.Changes) > 0 { message.WALLagBytes = &metrics.WalLagInBytes - w.messages <- *message + select { + case w.messages <- *message: + case <-ctx.Done(): + return false, ctx.Err() + } } return false, nil @@ -127,8 +131,7 @@ func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLog if message != nil { lsn := clientXLogPos.String() - p.lastEmitted = clientXLogPos - p.messages <- StreamMessage{ + msg := StreamMessage{ Lsn: &lsn, Changes: []StreamMessageChanges{ *message, @@ -136,6 +139,12 @@ func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLog Mode: StreamModeStreaming, WALLagBytes: &p.monitor.Report().WalLagInBytes, } + select { + case p.messages <- msg: + p.lastEmitted = clientXLogPos + case <-ctx.Done(): + return false, ctx.Err() + } } return false, nil @@ -177,12 +186,17 @@ func (p *PgOutputBufferedPluginHandler) Handle(ctx context.Context, clientXLogPo if len(p.pgoutputChanges) >= 0 { // send all collected changes lsn := clientXLogPos.String() - p.messages <- StreamMessage{ + msg := StreamMessage{ Lsn: &lsn, Changes: p.pgoutputChanges, Mode: StreamModeStreaming, WALLagBytes: &p.monitor.Report().WalLagInBytes, } + select { + case p.messages <- msg: + case <-ctx.Done(): + return false, ctx.Err() + } } return false, nil From bde31e27b5d1e35e94138ad544b789195a71e895 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 22:04:23 +0000 Subject: [PATCH 100/118] pgcdc: don't produce 0 messages --- internal/impl/postgresql/input_pg_stream.go | 2 +- .../impl/postgresql/pglogicalstream/pluginhandlers.go | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go 
b/internal/impl/postgresql/input_pg_stream.go index f546eb8f20..07779ddaf4 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -358,7 +358,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher ) if len(message.Changes) == 0 { - p.logger.Errorf("received empty message (LSN=%s)", message.Lsn) + p.logger.Errorf("received empty message (LSN=%v)", message.Lsn) break } diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 1d61204228..0f881c201a 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -132,10 +132,8 @@ func (p *PgOutputUnbufferedPluginHandler) Handle(ctx context.Context, clientXLog if message != nil { lsn := clientXLogPos.String() msg := StreamMessage{ - Lsn: &lsn, - Changes: []StreamMessageChanges{ - *message, - }, + Lsn: &lsn, + Changes: []StreamMessageChanges{*message}, Mode: StreamModeStreaming, WALLagBytes: &p.monitor.Report().WalLagInBytes, } @@ -183,7 +181,7 @@ func (p *PgOutputBufferedPluginHandler) Handle(ctx context.Context, clientXLogPo return false, nil } - if len(p.pgoutputChanges) >= 0 { + if len(p.pgoutputChanges) > 0 { // send all collected changes lsn := clientXLogPos.String() msg := StreamMessage{ From d1ea32525f2c2b7f0f0549fe2ca7fffe51cea6de Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 22:08:29 +0000 Subject: [PATCH 101/118] pgcdc: rename stream uncommitted to batch transactions --- internal/impl/postgresql/input_pg_stream.go | 14 +++++++------- .../postgresql/pglogicalstream/logical_stream.go | 10 +++++----- .../postgresql/pglogicalstream/pluginhandlers.go | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 07779ddaf4..f1a3ce4e68 100644 --- 
a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -25,7 +25,7 @@ import ( const ( fieldDSN = "dsn" - fieldStreamUncommitted = "stream_uncommitted" + fieldBatchTransactions = "batch_transactions" fieldStreamSnapshot = "stream_snapshot" fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" fieldSnapshotBatchSize = "snapshot_batch_size" @@ -66,9 +66,9 @@ This input adds the following metadata fields to each message: Field(service.NewStringField(fieldDSN). Description("The Data Source Name for the PostgreSQL database in the form of `postgres://[user[:password]@][netloc][:port][/dbname][?param1=value1&...]`. Please note that Postgres enforces SSL by default, you can override this with the parameter `sslmode=disable` if required."). Example("postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable")). - Field(service.NewBoolField(fieldStreamUncommitted). - Description("If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted."). - Default(false)). + Field(service.NewBoolField(fieldBatchTransactions). + Description("When set to true, transactions are batched into a single message. Note that this setting has no effect when using wal2json"). + Default(true)). Field(service.NewBoolField(fieldStreamSnapshot). Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. In order to use this the tables that are being snapshot MUST have a primary key set so that reading from the table can be parallelized."). Example(true). 
@@ -130,7 +130,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser streamSnapshot bool snapshotMemSafetyFactor float64 decodingPlugin string - streamUncommitted bool + batchTransactions bool snapshotBatchSize int checkpointLimit int walMonitorInterval time.Duration @@ -178,7 +178,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - if streamUncommitted, err = conf.FieldBool(fieldStreamUncommitted); err != nil { + if batchTransactions, err = conf.FieldBool(fieldBatchTransactions); err != nil { return nil, err } @@ -237,7 +237,7 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser BatchSize: snapshotBatchSize, StreamOldData: streamSnapshot, TemporaryReplicationSlot: temporarySlot, - StreamUncommitted: streamUncommitted, + BatchTransactions: batchTransactions, DecodingPlugin: decodingPlugin, SnapshotMemorySafetyFactor: snapshotMemSafetyFactor, PgStandbyTimeout: pgStandbyTimeout, diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 15b9534378..905f76a732 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -52,7 +52,7 @@ type Stream struct { snapshotMemorySafetyFactor float64 logger *service.Logger monitor *Monitor - streamUncommitted bool + batchTransactions bool snapshotter *Snapshotter maxParallelSnapshotTables int } @@ -96,7 +96,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { errors: make(chan error, 1), slotName: config.ReplicationSlotName, snapshotMemorySafetyFactor: config.SnapshotMemorySafetyFactor, - streamUncommitted: config.StreamUncommitted, + batchTransactions: config.BatchTransactions, snapshotBatchSize: config.BatchSize, schema: config.DBSchema, tableQualifiedName: tableNames, @@ -319,7 +319,7 @@ func (s *Stream) AckLSN(ctx 
context.Context, lsn string) error { ) if err != nil { - return fmt.Errorf("Failed to send Standby status message at LSN %s: %w", clientXLogPos.String(), err) + return fmt.Errorf("failed to send Standby status message at LSN %s: %w", clientXLogPos.String(), err) } // Update client XLogPos after we ack the message @@ -336,7 +336,7 @@ func (s *Stream) streamMessages() error { case "wal2json": handler = NewWal2JsonPluginHandler(s.messages, s.monitor) case "pgoutput": - handler = NewPgOutputPluginHandler(s.messages, s.streamUncommitted, s.monitor, s.clientXLogPos) + handler = NewPgOutputPluginHandler(s.messages, s.batchTransactions, s.monitor, s.clientXLogPos) default: return fmt.Errorf("invalid decoding plugin: %q", s.decodingPlugin) } @@ -447,7 +447,7 @@ func (s *Stream) processSnapshot() error { avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) if err != nil { - return fmt.Errorf("Failed to calculate average row size for table %v: %w", table, err) + return fmt.Errorf("failed to calculate average row size for table %v: %w", table, err) } availableMemory := getAvailableMemory() diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index 0f881c201a..e579a388a2 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -82,11 +82,11 @@ type PgOutputBufferedPluginHandler struct { // NewPgOutputPluginHandler creates a new PgOutputPluginHandler func NewPgOutputPluginHandler( messages chan StreamMessage, - streamUncommitted bool, + batchTransactions bool, monitor *Monitor, lsnWatermark *watermark.Value[LSN], ) PluginHandler { - if streamUncommitted { + if !batchTransactions { return &PgOutputUnbufferedPluginHandler{ messages: messages, monitor: monitor, From 093c0ac6456d92f50dcf7010cae05a730f4e5930 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 22:10:40 +0000 Subject:
pgcdc: fix config name --- internal/impl/postgresql/pglogicalstream/config.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index 99415d65f8..bcb31f4181 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -38,8 +38,9 @@ type Config struct { DecodingPlugin string // BatchSize is the batch size for streaming BatchSize int - // StreamUncommitted is whether to stream uncommitted messages before receiving commit message - StreamUncommitted bool + // BatchTransactions is whether to buffer transactions as an entire single message or to send + // each row in a transaction as a message. This has no effect for wal2json. + BatchTransactions bool Logger *service.Logger From a2a80daac41454bdeefabd72fa4ccf6c5e10c718 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 22:16:55 +0000 Subject: [PATCH 103/118] pgcdc: add some TODOs --- internal/impl/postgresql/pglogicalstream/monitor.go | 1 + internal/impl/postgresql/pglogicalstream/pglogrepl.go | 6 ++++-- internal/impl/postgresql/pglogicalstream/snapshotter.go | 4 ++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index a22c56bf19..e55ae39766 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -106,6 +106,7 @@ func (m *Monitor) readTablesStat(tables []string) error { for _, table := range tables { tableWithoutSchema := strings.Split(table, ".")[1] + // TODO(le-vlad): Implement proper SQL injection protection query := "SELECT COUNT(*) FROM " + tableWithoutSchema var count int64 diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index f9cbd8d1ef..5e9647a09a 
100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -251,6 +251,7 @@ func CreateReplicationSlot( snapshotString = options.SnapshotAction } + // NOTE: All strings passed into here have been validated and are not prone to SQL injection. newPgCreateSlotCommand := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") @@ -361,8 +362,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName tablesClause := "FOR ALL TABLES" if len(tables) > 0 { - // TODO: Implement proper SQL injection protection, potentially using parameterized queries - // or a SQL query builder that handles proper escaping + // TODO(le-vlad): Implement proper SQL injection protection tablesClause = "FOR TABLE " + strings.Join(tables, ",") } @@ -405,6 +405,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // remove tables from publication for _, dropTable := range tablesToRemoveFromPublication { + // TODO(le-vlad): Implement proper SQL injection protection result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to remove table from publication: %w", err) @@ -413,6 +414,7 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // add tables to publication for _, addTable := range tablesToAddToPublication { + // TODO(le-vlad): Implement proper SQL injection protection result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to add table to publication: %w", err) diff --git 
a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index a5f8931c3d..4129f19cd0 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -97,6 +97,7 @@ func (s *Snapshotter) prepare() error { if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { return err } + // TODO(le-vlad): Implement proper SQL injection protection if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil { return err } @@ -110,6 +111,7 @@ func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { rows *sql.Rows err error ) + // TODO(le-vlad): Implement proper SQL injection protection if rows, err = s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) } @@ -168,8 +170,10 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, limit int) (rows *sql.Rows, err error) { s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk) if lastSeenPk == nil { + // TODO(le-vlad): Implement proper SQL injection protection return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT $1;", table, pk), limit) } + // TODO(le-vlad): Implement proper SQL injection protection return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s WHERE %s > $1 ORDER BY %s LIMIT $2;", table, pk, pk), lastSeenPk, limit) } From 8084bdf8b03dd057dff3103487a5fccd6b7b832a Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Wed, 13 Nov 2024 22:26:54 +0000 Subject: [PATCH 104/118] pgcdc: update docs --- docs/modules/components/pages/inputs/pg_stream.adoc | 10 
+++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc index 3b2ddb59b2..1851cf1790 100644 --- a/docs/modules/components/pages/inputs/pg_stream.adoc +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -40,7 +40,7 @@ input: label: "" pg_stream: dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable # No default (required) - stream_uncommitted: false + batch_transactions: true stream_snapshot: false snapshot_memory_safety_factor: 1 snapshot_batch_size: 0 @@ -72,7 +72,7 @@ input: label: "" pg_stream: dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable # No default (required) - stream_uncommitted: false + batch_transactions: true stream_snapshot: false snapshot_memory_safety_factor: 1 snapshot_batch_size: 0 @@ -124,14 +124,14 @@ The Data Source Name for the PostgreSQL database in the form of `postgres://[use dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable ``` -=== `stream_uncommitted` +=== `batch_transactions` -If set to true, the plugin will stream uncommitted transactions before receiving a commit message from PostgreSQL. This may result in duplicate records if the connector is restarted. +When set to true, transactions are batched into a single message. 
Note that this setting has no effect when using wal2json *Type*: `bool` -*Default*: `false` +*Default*: `true` === `stream_snapshot` From 1f8a650d6591c2ff084cc3033f0d8f95d1fccd31 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Thu, 14 Nov 2024 02:33:59 +0000 Subject: [PATCH 105/118] pgcdc: review feedback --- internal/impl/postgresql/input_pg_stream.go | 6 +-- .../postgresql/pglogicalstream/connection.go | 1 + .../postgresql/pglogicalstream/debouncer.go | 43 ------------------- .../pglogicalstream/logical_stream.go | 7 +-- 4 files changed, 3 insertions(+), 54 deletions(-) delete mode 100644 internal/impl/postgresql/pglogicalstream/debouncer.go diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index f1a3ce4e68..d0fc05e8c4 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -282,11 +282,7 @@ func validateSimpleString(s string) error { } func init() { - err := service.RegisterBatchInput( - "pg_stream", pgStreamConfigSpec, - func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchInput, error) { - return newPgStreamInput(conf, mgr) - }) + err := service.RegisterBatchInput("pg_stream", pgStreamConfigSpec, newPgStreamInput) if err != nil { panic(err) } diff --git a/internal/impl/postgresql/pglogicalstream/connection.go b/internal/impl/postgresql/pglogicalstream/connection.go index 7b801236ef..9f81c3be2b 100644 --- a/internal/impl/postgresql/pglogicalstream/connection.go +++ b/internal/impl/postgresql/pglogicalstream/connection.go @@ -26,6 +26,7 @@ func getPostgresVersion(dbDSN string) (int, error) { if err != nil { return 0, fmt.Errorf("failed to connect to the database: %w", err) } + defer conn.Close() var versionString string err = conn.QueryRow("SHOW server_version").Scan(&versionString) diff --git a/internal/impl/postgresql/pglogicalstream/debouncer.go b/internal/impl/postgresql/pglogicalstream/debouncer.go deleted file mode 100644 index 
9fbd9ae4f0..0000000000 --- a/internal/impl/postgresql/pglogicalstream/debouncer.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pglogicalstream - -import ( - "sync" - "time" -) - -// NewDebouncer New returns a debounced function that takes another functions as its argument. -// This function will be called when the debounced function stops being called -// for the given duration. -// The debounced function can be invoked with different functions, if needed, -// the last one will win. -func NewDebouncer(after time.Duration) func(f func()) { - d := &debouncer{after: after} - - return func(f func()) { - d.add(f) - } -} - -type debouncer struct { - mu sync.Mutex - after time.Duration - timer *time.Timer -} - -func (d *debouncer) add(f func()) { - d.mu.Lock() - defer d.mu.Unlock() - - if d.timer != nil { - d.timer.Stop() - } - d.timer = time.AfterFunc(d.after, f) -} diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 905f76a732..19b237ef8f 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -216,12 +216,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } else { lsnrestart, _ = ParseLSN(confirmedLSNFromDB) } - - if freshlyCreatedSlot { - stream.clientXLogPos = watermark.New(sysident.XLogPos) - } else { - stream.clientXLogPos = watermark.New(lsnrestart) - } + stream.clientXLogPos = watermark.New(lsnrestart) stream.standbyMessageTimeout = config.PgStandbyTimeout stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) From 
cd352e004ef9ba174953ba9e79ba1a9e7d8ceb6c Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Thu, 14 Nov 2024 02:43:13 +0000 Subject: [PATCH 106/118] pgcdc: cleanup monitor with periodic utility --- .../pglogicalstream/logical_stream.go | 2 +- .../postgresql/pglogicalstream/monitor.go | 64 ++++++++----------- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 19b237ef8f..2c6c6bb11e 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -221,7 +221,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { stream.standbyMessageTimeout = config.PgStandbyTimeout stream.nextStandbyMessageDeadline = time.Now().Add(stream.standbyMessageTimeout) - monitor, err := NewMonitor(config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval) + monitor, err := NewMonitor(ctx, config.DBRawDSN, stream.logger, tableNames, stream.slotName, config.WalMonitorInterval) if err != nil { return nil, err } diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index e55ae39766..c1bd1ccd90 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -12,12 +12,14 @@ import ( "context" "database/sql" "fmt" + "maps" "math" "strings" "sync" "time" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/periodic" ) // Report is a structure that contains the current state of the Monitor @@ -38,16 +40,21 @@ type Monitor struct { // finding the difference between the latest LSN and the last confirmed LSN for the replication slot replicationLagInBytes int64 - dbConn *sql.DB - slotName string - logger *service.Logger - ticker *time.Ticker - cancelTicker 
context.CancelFunc - ctx context.Context + dbConn *sql.DB + slotName string + logger *service.Logger + loop *periodic.Periodic } // NewMonitor creates a new Monitor instance -func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName string, interval time.Duration) (*Monitor, error) { +func NewMonitor( + ctx context.Context, + dbDSN string, + logger *service.Logger, + tables []string, + slotName string, + interval time.Duration, +) (*Monitor, error) { dbConn, err := openPgConnectionFromConfig(dbDSN) if err != nil { return nil, err @@ -60,29 +67,11 @@ func NewMonitor(dbDSN string, logger *service.Logger, tables []string, slotName slotName: slotName, logger: logger, } - - if err = m.readTablesStat(tables); err != nil { + m.loop = periodic.NewWithContext(interval, m.readReplicationLag) + if err = m.readTablesStat(ctx, tables); err != nil { return nil, err } - - ctx, cancel := context.WithCancel(context.Background()) - m.ctx = ctx - m.cancelTicker = cancel - m.ticker = time.NewTicker(interval) - - go func() { - for { - select { - case <-m.ticker.C: - m.readReplicationLag() - break - case <-m.ctx.Done(): - m.ticker.Stop() - return - } - } - }() - + m.loop.Start() return m, nil } @@ -101,7 +90,7 @@ func (m *Monitor) UpdateSnapshotProgressForTable(table string, position int) { } // we need to read the tables stat to calculate the snapshot ingestion progress -func (m *Monitor) readTablesStat(tables []string) error { +func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error { results := make(map[string]int64) for _, table := range tables { @@ -110,7 +99,7 @@ func (m *Monitor) readTablesStat(tables []string) error { query := "SELECT COUNT(*) FROM " + tableWithoutSchema var count int64 - err := m.dbConn.QueryRow(query).Scan(&count) + err := m.dbConn.QueryRowContext(ctx, query).Scan(&count) if err != nil { // If the error is because the table doesn't exist, we'll set the count to 0 @@ -130,14 +119,14 @@ func (m *Monitor) 
readTablesStat(tables []string) error { return nil } -func (m *Monitor) readReplicationLag() { - result, err := m.dbConn.Query(`SELECT slot_name, +func (m *Monitor) readReplicationLag(ctx context.Context) { + result, err := m.dbConn.QueryContext(ctx, `SELECT slot_name, pg_wal_lsn_diff(pg_current_wal_lsn(), restart_lsn) AS lag_bytes FROM pg_replication_slots WHERE slot_name = $1;`, m.slotName) // calculate the replication lag in bytes // replicationLagInBytes = latestLsn - confirmedLsn if err != nil || result.Err() != nil { - m.logger.Errorf("Error reading replication lag: %v", err) + m.logger.Warnf("Error reading replication lag: %v", err) return } @@ -145,12 +134,14 @@ func (m *Monitor) readReplicationLag() { var lagbytes int64 for result.Next() { if err = result.Scan(&slotName, &lagbytes); err != nil { - m.logger.Errorf("Error reading replication lag: %v", err) + m.logger.Warnf("Error reading replication lag: %v", err) return } } + m.lock.Lock() m.replicationLagInBytes = lagbytes + m.lock.Unlock() } // Report returns a snapshot of the monitor's state @@ -161,13 +152,12 @@ func (m *Monitor) Report() *Report { // report the replication lag return &Report{ WalLagInBytes: m.replicationLagInBytes, - TableProgress: m.snapshotProgress, + TableProgress: maps.Clone(m.snapshotProgress), } } // Stop stops the monitor func (m *Monitor) Stop() error { - m.cancelTicker() - m.ticker.Stop() + m.loop.Stop() return m.dbConn.Close() } From 5594cbd54b2ad3ea4b13fec61545d94886576df0 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Thu, 14 Nov 2024 02:50:11 +0000 Subject: [PATCH 107/118] pgcdc: fmt --- internal/impl/postgresql/pglogicalstream/monitor.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index c1bd1ccd90..cf3f3853b7 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -19,6 +19,7 @@ import ( 
"time" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/periodic" ) From 272eef0e01db7b9af38e69b94431f91407f8c4b7 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Thu, 14 Nov 2024 04:16:18 +0000 Subject: [PATCH 108/118] pgcdc: check for non-zero duration --- internal/impl/postgresql/pglogicalstream/monitor.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index cf3f3853b7..b8e8e8f75e 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -60,6 +60,9 @@ func NewMonitor( if err != nil { return nil, err } + if interval <= 0 { + return nil, fmt.Errorf("invalid monitoring interval: %s", interval.String()) + } m := &Monitor{ snapshotProgress: map[string]float64{}, From b70f69454d6da27cae19068d89368eeed62fc829 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 14 Nov 2024 12:26:15 +0100 Subject: [PATCH 109/118] chore(): sanitized queries && fixed tests --- internal/impl/postgresql/integration_test.go | 10 ++-- .../pglogicalstream/logical_stream.go | 4 ++ .../postgresql/pglogicalstream/monitor.go | 11 +++-- .../postgresql/pglogicalstream/pglogrepl.go | 48 +++++++++++++++---- .../pglogicalstream/sanitize/sanitize.go | 30 ++++++++++++ .../postgresql/pglogicalstream/snapshotter.go | 32 ++++++++++--- 6 files changed, 112 insertions(+), 23 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 283a281c26..704f7bab49 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -352,7 +352,7 @@ file: `, tmpDir) streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) require.NoError(t, 
streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) @@ -455,7 +455,7 @@ pg_stream: snapshot_batch_size: 100000 stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: false + batch_transactions: false temporary_slot: true schema: public tables: @@ -550,7 +550,7 @@ pg_stream: snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: true + batch_transactions: true schema: public tables: - flights @@ -689,7 +689,7 @@ pg_stream: slot_name: test_slot_native_decoder stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: true + batch_transactions: true schema: public tables: - flights @@ -826,7 +826,7 @@ pg_stream: slot_name: test_slot_native_decoder stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: false + batch_transactions: false schema: public tables: - flights diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 2c6c6bb11e..9f2ea48f91 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -88,6 +88,10 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tableNames := slices.Clone(config.DBTables) for i, table := range tableNames { + if err := sanitize.ValidatePostgresIdentifier(table); err != nil { + return nil, fmt.Errorf("invalid table name %q: %w", table, err) + } + tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) } stream := &Stream{ diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index b8e8e8f75e..f800bb1591 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -20,6 +20,7 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" + 
"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" "github.com/redpanda-data/connect/v4/internal/periodic" ) @@ -99,11 +100,15 @@ func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error { for _, table := range tables { tableWithoutSchema := strings.Split(table, ".")[1] - // TODO(le-vlad): Implement proper SQL injection protection - query := "SELECT COUNT(*) FROM " + tableWithoutSchema + err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema) + + if err != nil { + return fmt.Errorf("error sanitizing query: %w", err) + } var count int64 - err := m.dbConn.QueryRowContext(ctx, query).Scan(&count) + // tableWithoutSchema has been validated so it's safe to use in the query + err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count) if err != nil { // If the error is because the table doesn't exist, we'll set the count to 0 diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 5e9647a09a..833b968938 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -21,6 +21,7 @@ import ( "context" "database/sql/driver" "encoding/binary" + "errors" "fmt" "slices" "strconv" @@ -253,7 +254,7 @@ func CreateReplicationSlot( // NOTE: All strings passed into here have been validated and are not prone to SQL injection.
newPgCreateSlotCommand := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) - oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") + oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") var snapshotName string if version > 14 { @@ -353,6 +354,17 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName return fmt.Errorf("failed to sanitize publication query: %w", err) } + // Since we need to pass table names without quoting, we need to validate it + for _, table := range tables { + if err := sanitize.ValidatePostgresIdentifier(table); err != nil { + return errors.New("invalid table name") + } + } + // the same for publication name + if err := sanitize.ValidatePostgresIdentifier(publicationName); err != nil { + return errors.New("invalid publication name") + } + result := conn.Exec(ctx, pubQuery) rows, err := result.ReadAll() @@ -362,13 +374,27 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName tablesClause := "FOR ALL TABLES" if len(tables) > 0 { - // TODO(le-vlad): Implement proper SQL injection protection - tablesClause = "FOR TABLE " + strings.Join(tables, ",") + // quotedTables := make([]string, len(tables)) + // for i, table := range tables { + // // Use sanitize.SQLIdentifier to properly quote and escape table names + // quoted, err := sanitize.SQLIdentifier(table) + // if err != nil { + // return fmt.Errorf("invalid table name %q: %w", table, err) + // } + // quotedTables[i] = quoted + // } + tablesClause = "FOR TABLE " + strings.Join(tables, ", ") } if len(rows) == 0 || len(rows[0].Rows) == 0 { + // tablesClause is sanitized, so we can safely interpolate it into the query + sq, err := 
sanitize.SQLQuery(fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) + fmt.Print(sq) + if err != nil { + return fmt.Errorf("failed to sanitize publication creation query: %w", err) + } // Publication doesn't exist, create new one - result = conn.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) + result = conn.Exec(ctx, sq) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to create publication: %w", err) } @@ -405,8 +431,11 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // remove tables from publication for _, dropTable := range tablesToRemoveFromPublication { - // TODO(le-vlad): Implement proper SQL injection protection - result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) + if err != nil { + return fmt.Errorf("failed to sanitize drop table query: %w", err) + } + result = conn.Exec(ctx, sq) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to remove table from publication: %w", err) } @@ -414,8 +443,11 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // add tables to publication for _, addTable := range tablesToAddToPublication { - // TODO(le-vlad): Implement proper SQL injection protection - result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) + if err != nil { + return fmt.Errorf("failed to sanitize add table query: %w", err) + } + result = conn.Exec(ctx, sq) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to add table to publication: %w", err) } diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go 
b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go index 95f35bd8c0..5bba854bbb 100644 --- a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go @@ -33,9 +33,13 @@ import ( "strconv" "strings" "time" + "unicode" "unicode/utf8" ) +// MaxIdentifierLength is PostgreSQL's maximum identifier length +const MaxIdentifierLength = 63 + // Part is either a string or an int. A string is raw SQL. An int is a // argument placeholder. type Part any @@ -358,3 +362,29 @@ func SQLQuery(sql string, args ...any) (string, error) { } return query.Sanitize(args...) } + +// ValidatePostgresIdentifier checks if a string is a valid PostgreSQL identifier +// This follows PostgreSQL's standard naming rules +func ValidatePostgresIdentifier(name string) error { + if len(name) == 0 { + return errors.New("empty identifier is not allowed") + } + + if len(name) > MaxIdentifierLength { + return fmt.Errorf("identifier length exceeds maximum of %d characters", MaxIdentifierLength) + } + + // First character must be a letter or underscore + if !unicode.IsLetter(rune(name[0])) && name[0] != '_' { + return errors.New("identifier must start with a letter or underscore") + } + + // Subsequent characters must be letters, numbers, underscores, or dots + for i, char := range name { + if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '_' && char != '.' 
{ + return fmt.Errorf("invalid character '%c' at position %d in identifier '%s'", char, i, name) + } + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 4129f19cd0..ce581c169c 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -16,6 +16,7 @@ import ( _ "github.com/lib/pq" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) // SnapshotCreationResponse is a structure that contains the name of the snapshot that was created @@ -97,8 +98,13 @@ func (s *Snapshotter) prepare() error { if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { return err } - // TODO(le-vlad): Implement proper SQL injection protection - if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil { + + sq, err := sanitize.SQLQuery("SET TRANSACTION SNAPSHOT $1;", s.snapshotName) + if err != nil { + return err + } + + if _, err := s.pgConnection.Exec(sq); err != nil { return err } @@ -111,7 +117,8 @@ func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { rows *sql.Rows err error ) - // TODO(le-vlad): Implement proper SQL injection protection + + // table is validated to be correct pg identifier, so we can use it directly if rows, err = s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) } @@ -170,11 +177,22 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, limit int) (rows *sql.Rows, err error) { s.logger.Infof("Query snapshot table: %v, limit: %v, 
lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk) if lastSeenPk == nil { - // TODO(le-vlad): Implement proper SQL injection protection - return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT $1;", table, pk), limit) + // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pk, limit)) + if err != nil { + return nil, err + } + + return s.pgConnection.Query(sq) + } + + // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pk, lastSeenPk, pk, limit)) + if err != nil { + return nil, err } - // TODO(le-vlad): Implement proper SQL injection protection - return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s WHERE %s > $1 ORDER BY %s LIMIT $2;", table, pk, pk), lastSeenPk, limit) + + return s.pgConnection.Query(sq) } func (s *Snapshotter) releaseSnapshot() error { From 2a4e42b2a40b280dfb38c076409aab7784184ea0 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 15 Nov 2024 11:51:22 +0100 Subject: [PATCH 110/118] chore(): removed wal2json support --- internal/impl/postgresql/input_pg_stream.go | 15 +- internal/impl/postgresql/integration_test.go | 195 ------------------ .../impl/postgresql/pglogicalstream/config.go | 4 +- .../impl/postgresql/pglogicalstream/consts.go | 34 --- .../pglogicalstream/logical_stream.go | 45 ++-- .../postgresql/pglogicalstream/monitor.go | 2 +- .../pglogicalstream/pluginhandlers.go | 35 ---- .../replication_message_decoders.go | 68 ------ 8 files changed, 16 insertions(+), 382 deletions(-) delete mode 100644 internal/impl/postgresql/pglogicalstream/consts.go diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 
d0fc05e8c4..7adaa20543 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -29,7 +29,6 @@ const ( fieldStreamSnapshot = "stream_snapshot" fieldSnapshotMemSafetyFactor = "snapshot_memory_safety_factor" fieldSnapshotBatchSize = "snapshot_batch_size" - fieldDecodingPlugin = "decoding_plugin" fieldSchema = "schema" fieldTables = "tables" fieldCheckpointLimit = "checkpoint_limit" @@ -67,7 +66,7 @@ This input adds the following metadata fields to each message: Description("The Data Source Name for the PostgreSQL database in the form of `postgres://[user[:password]@][netloc][:port][/dbname][?param1=value1&...]`. Please note that Postgres enforces SSL by default, you can override this with the parameter `sslmode=disable` if required."). Example("postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable")). Field(service.NewBoolField(fieldBatchTransactions). - Description("When set to true, transactions are batched into a single message. Note that this setting has no effect when using wal2json"). + Description("When set to true, transactions are batched into a single message."). Default(true)). Field(service.NewBoolField(fieldStreamSnapshot). Description("When set to true, the plugin will first stream a snapshot of all existing data in the database before streaming changes. In order to use this the tables that are being snapshot MUST have a primary key set so that reading from the table can be parallelized."). @@ -81,12 +80,6 @@ This input adds the following metadata fields to each message: Description("The number of rows to fetch in each batch when querying the snapshot. A value of 0 lets the plugin determine the batch size based on `snapshot_memory_safety_factor` property."). Example(10000). Default(0)). - Field(service.NewStringEnumField(fieldDecodingPlugin, "pgoutput", "wal2json"). - Description(`Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 
'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. -Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. - `). - Example("pgoutput"). - Default("pgoutput")). Field(service.NewStringField(fieldSchema). Description("The PostgreSQL schema from which to replicate data."). Example("public")). @@ -129,7 +122,6 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser tables []string streamSnapshot bool snapshotMemSafetyFactor float64 - decodingPlugin string batchTransactions bool snapshotBatchSize int checkpointLimit int @@ -182,10 +174,6 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser return nil, err } - if decodingPlugin, err = conf.FieldString(fieldDecodingPlugin); err != nil { - return nil, err - } - if snapshotMemSafetyFactor, err = conf.FieldFloat(fieldSnapshotMemSafetyFactor); err != nil { return nil, err } @@ -238,7 +226,6 @@ func newPgStreamInput(conf *service.ParsedConfig, mgr *service.Resources) (s ser StreamOldData: streamSnapshot, TemporaryReplicationSlot: temporarySlot, BatchTransactions: batchTransactions, - DecodingPlugin: decodingPlugin, SnapshotMemorySafetyFactor: snapshotMemSafetyFactor, PgStandbyTimeout: pgStandbyTimeout, WalMonitorInterval: walMonitorInterval, diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 704f7bab49..616c578e8e 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -116,196 +116,6 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version return resource, db, nil } -func TestIntegrationPgCDC(t *testing.T) { - integration.CheckSkip(t) - - tmpDir := t.TempDir() - pool, err := dockertest.NewPool("") - require.NoError(t, err) - - // Use custom PostgreSQL image with wal2json plugin compiled in - resource, err := 
pool.RunWithOptions(&dockertest.RunOptions{ - Repository: "usedatabrew/pgwal2json", - Tag: "16", - Env: []string{ - "POSTGRES_PASSWORD=l]YLSc|4[i56%{gY", - "POSTGRES_USER=user_name", - "POSTGRES_DB=dbname", - }, - Cmd: []string{ - "postgres", - "-c", "wal_level=logical", - }, - }, func(config *docker.HostConfig) { - config.AutoRemove = true - config.RestartPolicy = docker.RestartPolicy{Name: "no"} - }) - - require.NoError(t, err) - t.Cleanup(func() { - assert.NoError(t, pool.Purge(resource)) - }) - - require.NoError(t, resource.Expire(120)) - - hostAndPort := resource.GetHostPort("5432/tcp") - hostAndPortSplited := strings.Split(hostAndPort, ":") - password := "l]YLSc|4[i56%{gY" - databaseURL := fmt.Sprintf("user=user_name password=%s dbname=dbname sslmode=disable host=%s port=%s", password, hostAndPortSplited[0], hostAndPortSplited[1]) - - var db *sql.DB - - pool.MaxWait = 120 * time.Second - if err = pool.Retry(func() error { - if db, err = sql.Open("postgres", databaseURL); err != nil { - return err - } - - if err = db.Ping(); err != nil { - return err - } - - var walLevel string - if err = db.QueryRow("SHOW wal_level").Scan(&walLevel); err != nil { - return err - } - - var pgConfig string - if err = db.QueryRow("SHOW config_file").Scan(&pgConfig); err != nil { - return err - } - - if walLevel != "logical" { - return fmt.Errorf("wal_level is not logical") - } - - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - if err != nil { - return err - } - - // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming - _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") - - return err - }); err != nil { - panic(fmt.Errorf("could not connect to docker: %w", err)) - } - - for i := 0; i < 1000; i++ { - f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights 
(name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) - require.NoError(t, err) - } - - template := fmt.Sprintf(` -pg_stream: - dsn: %s - slot_name: test_slot - decoding_plugin: wal2json - stream_snapshot: true - schema: public - tables: - - flights -`, databaseURL) - - cacheConf := fmt.Sprintf(` -label: pg_stream_cache -file: - directory: %v -`, tmpDir) - - streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) - require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) - require.NoError(t, streamOutBuilder.AddInputYAML(template)) - - var outBatches []string - var outBatchMut sync.Mutex - require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { - msgBytes, err := mb[0].AsBytes() - require.NoError(t, err) - outBatchMut.Lock() - outBatches = append(outBatches, string(msgBytes)) - outBatchMut.Unlock() - return nil - })) - - streamOut, err := streamOutBuilder.Build() - require.NoError(t, err) - - go func() { - _ = streamOut.Run(context.Background()) - }() - - assert.Eventually(t, func() bool { - outBatchMut.Lock() - defer outBatchMut.Unlock() - return len(outBatches) == 1000 - }, time.Second*25, time.Millisecond*100) - - for i := 0; i < 1000; i++ { - f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) - _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) - require.NoError(t, err) - } - - assert.Eventually(t, func() bool { - outBatchMut.Lock() - defer outBatchMut.Unlock() - return len(outBatches) == 2000 - }, time.Second, 
time.Millisecond*100) - - require.NoError(t, streamOut.StopWithin(time.Second*10)) - - // Starting stream for the same replication slot should continue from the last LSN - // Meaning we must not receive any old messages again - - streamOutBuilder = service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) - require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) - require.NoError(t, streamOutBuilder.AddInputYAML(template)) - - outBatches = []string{} - require.NoError(t, streamOutBuilder.AddBatchConsumerFunc(func(c context.Context, mb service.MessageBatch) error { - msgBytes, err := mb[0].AsBytes() - require.NoError(t, err) - outBatchMut.Lock() - outBatches = append(outBatches, string(msgBytes)) - outBatchMut.Unlock() - return nil - })) - - streamOut, err = streamOutBuilder.Build() - require.NoError(t, err) - - go func() { - assert.NoError(t, streamOut.Run(context.Background())) - }() - - time.Sleep(time.Second * 5) - for i := 0; i < 50; i++ { - f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) - require.NoError(t, err) - } - - assert.Eventually(t, func() bool { - outBatchMut.Lock() - defer outBatchMut.Unlock() - return len(outBatches) == 50 - }, time.Second*20, time.Millisecond*100) - - require.NoError(t, streamOut.StopWithin(time.Second*10)) - t.Log("All the conditions are met 🎉") - - t.Cleanup(func() { - db.Close() - }) -} - func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { integration.CheckSkip(t) tmpDir := t.TempDir() @@ -339,7 +149,6 @@ pg_stream: dsn: %s slot_name: test_slot_native_decoder stream_snapshot: true - decoding_plugin: pgoutput schema: public tables: - flights @@ -454,7 +263,6 @@ pg_stream: slot_name: test_slot_native_decoder snapshot_batch_size: 100000 stream_snapshot: true - decoding_plugin: pgoutput batch_transactions: false temporary_slot: true schema: public @@ -549,7 
+357,6 @@ pg_stream: slot_name: test_slot_native_decoder snapshot_batch_size: 100 stream_snapshot: true - decoding_plugin: pgoutput batch_transactions: true schema: public tables: @@ -688,7 +495,6 @@ pg_stream: dsn: %s slot_name: test_slot_native_decoder stream_snapshot: true - decoding_plugin: pgoutput batch_transactions: true schema: public tables: @@ -825,7 +631,6 @@ pg_stream: dsn: %s slot_name: test_slot_native_decoder stream_snapshot: true - decoding_plugin: pgoutput batch_transactions: false schema: public tables: diff --git a/internal/impl/postgresql/pglogicalstream/config.go b/internal/impl/postgresql/pglogicalstream/config.go index bcb31f4181..d937813a36 100644 --- a/internal/impl/postgresql/pglogicalstream/config.go +++ b/internal/impl/postgresql/pglogicalstream/config.go @@ -34,12 +34,10 @@ type Config struct { StreamOldData bool // SnapshotMemorySafetyFactor is the memory safety factor for streaming snapshot SnapshotMemorySafetyFactor float64 - // DecodingPlugin is the decoding plugin to use - DecodingPlugin string // BatchSize is the batch size for streaming BatchSize int // BatchTransactions is whether to buffer transactions as an entire single message or to send - // each row in a transaction as a message. This has no effect for wal2json. + // each row in a transaction as a message. BatchTransactions bool Logger *service.Logger diff --git a/internal/impl/postgresql/pglogicalstream/consts.go b/internal/impl/postgresql/pglogicalstream/consts.go deleted file mode 100644 index 944fd82b39..0000000000 --- a/internal/impl/postgresql/pglogicalstream/consts.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed as a Redpanda Enterprise file under the Redpanda Community -// License (the "License"); you may not use this file except in compliance with -// the License. 
You may obtain a copy of the License at -// -// https://github.com/redpanda-data/connect/v4/blob/main/licenses/rcl.md - -package pglogicalstream - -// DecodingPlugin is a type for the decoding plugin -type DecodingPlugin string - -const ( - // Wal2JSON is the value for the wal2json decoding plugin. It requires wal2json extension to be installed on the PostgreSQL instance - Wal2JSON DecodingPlugin = "wal2json" - // PgOutput is the value for the pgoutput decoding plugin. It requires pgoutput extension to be installed on the PostgreSQL instance - PgOutput DecodingPlugin = "pgoutput" -) - -func decodingPluginFromString(plugin string) DecodingPlugin { - switch plugin { - case "wal2json": - return Wal2JSON - case "pgoutput": - return PgOutput - default: - return PgOutput - } -} - -func (d DecodingPlugin) String() string { - return string(d) -} diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 9f2ea48f91..2505a7b358 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -27,6 +27,8 @@ import ( "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/watermark" ) +const decodingPlugin = "pgoutput" + // Stream is a structure that represents a logical replication stream // It includes the connection to the database, the context for the stream, and snapshotting functionality type Stream struct { @@ -47,7 +49,6 @@ type Stream struct { // includes schema tableQualifiedName []string snapshotBatchSize int - decodingPlugin DecodingPlugin decodingPluginArguments []string snapshotMemorySafetyFactor float64 logger *service.Logger @@ -106,7 +107,6 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tableQualifiedName: tableNames, maxParallelSnapshotTables: config.MaxParallelSnapshotTables, logger: config.Logger, - decodingPlugin: 
decodingPluginFromString(config.DecodingPlugin), shutSig: shutdown.NewSignaller(), } @@ -127,26 +127,14 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { } }) - var pluginArguments []string - if stream.decodingPlugin == "pgoutput" { - pluginArguments = []string{ - "proto_version '1'", - // Sprintf is safe because we validate ReplicationSlotName is alphanumeric in the config - fmt.Sprintf("publication_names 'pglog_stream_%s'", config.ReplicationSlotName), - } + pluginArguments := []string{ + "proto_version '1'", + // Sprintf is safe because we validate ReplicationSlotName is alphanumeric in the config + fmt.Sprintf("publication_names 'pglog_stream_%s'", config.ReplicationSlotName), + } - if version > 14 { - pluginArguments = append(pluginArguments, "messages 'true'") - } - } else if stream.decodingPlugin == "wal2json" { - tablesFilterRule := strings.Join(tableNames, ", ") - pluginArguments = []string{ - "\"pretty-print\" 'true'", - // TODO: Validate this is escaped properly - fmt.Sprintf(`"add-tables" '%s'`, tablesFilterRule), - } - } else { - return nil, fmt.Errorf("unknown decoding plugin: %q", stream.decodingPlugin) + if version > 14 { + pluginArguments = append(pluginArguments, "messages 'true'") } stream.decodingPluginArguments = pluginArguments @@ -185,7 +173,7 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { ctx, stream.pgConn, stream.slotName, - stream.decodingPlugin.String(), + decodingPlugin, CreateReplicationSlotOptions{ Temporary: config.TemporaryReplicationSlot, SnapshotAction: "export", @@ -210,7 +198,8 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { outputPlugin = string(slotCheckRow[1]) } - if !freshlyCreatedSlot && outputPlugin != stream.decodingPlugin.String() { + // handling a case when replication slot already exists but with different output plugin created manually + if !freshlyCreatedSlot && outputPlugin != decodingPlugin { return nil, fmt.Errorf("replication slot 
%s already exists with different output plugin: %s", config.ReplicationSlotName, outputPlugin) } @@ -330,15 +319,7 @@ func (s *Stream) AckLSN(ctx context.Context, lsn string) error { } func (s *Stream) streamMessages() error { - var handler PluginHandler - switch s.decodingPlugin { - case "wal2json": - handler = NewWal2JsonPluginHandler(s.messages, s.monitor) - case "pgoutput": - handler = NewPgOutputPluginHandler(s.messages, s.batchTransactions, s.monitor, s.clientXLogPos) - default: - return fmt.Errorf("invalid decoding plugin: %q", s.decodingPlugin) - } + handler := NewPgOutputPluginHandler(s.messages, s.batchTransactions, s.monitor, s.clientXLogPos) ctx, _ := s.shutSig.SoftStopCtx(context.Background()) for !s.shutSig.IsSoftStopSignalled() { diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index f800bb1591..d9ed0ba4db 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -108,7 +108,7 @@ func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error { var count int64 // tableWithoutSchema has been validated so its safe to use in the query - err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM %s"+tableWithoutSchema).Scan(&count) + err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count) if err != nil { // If the error is because the table doesn't exist, we'll set the count to 0 diff --git a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go index e579a388a2..6e228dc21b 100644 --- a/internal/impl/postgresql/pglogicalstream/pluginhandlers.go +++ b/internal/impl/postgresql/pglogicalstream/pluginhandlers.go @@ -22,41 +22,6 @@ type PluginHandler interface { Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) } -// Wal2JsonPluginHandler is a handler for wal2json output plugin -type 
Wal2JsonPluginHandler struct { - messages chan StreamMessage - monitor *Monitor -} - -// NewWal2JsonPluginHandler creates a new Wal2JsonPluginHandler -func NewWal2JsonPluginHandler(messages chan StreamMessage, monitor *Monitor) *Wal2JsonPluginHandler { - return &Wal2JsonPluginHandler{ - messages: messages, - monitor: monitor, - } -} - -// Handle handles the wal2json output -func (w *Wal2JsonPluginHandler) Handle(ctx context.Context, clientXLogPos LSN, xld XLogData) (bool, error) { - // get current stream metrics - metrics := w.monitor.Report() - message, err := decodeWal2JsonChanges(clientXLogPos.String(), xld.WALData) - if err != nil { - return false, err - } - - if message != nil && len(message.Changes) > 0 { - message.WALLagBytes = &metrics.WalLagInBytes - select { - case w.messages <- *message: - case <-ctx.Done(): - return false, ctx.Err() - } - } - - return false, nil -} - // PgOutputUnbufferedPluginHandler is a native output handler that emits each message as it's received. type PgOutputUnbufferedPluginHandler struct { messages chan StreamMessage diff --git a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go index 47bf3a6d26..0c1ed3b236 100644 --- a/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go +++ b/internal/impl/postgresql/pglogicalstream/replication_message_decoders.go @@ -9,8 +9,6 @@ package pglogicalstream import ( - "bytes" - "encoding/json" "fmt" "log" @@ -153,69 +151,3 @@ func decodeTextColumnData(mi *pgtype.Map, data []byte, dataType uint32) (interfa } return string(data), nil } - -// ---------------------------------------------------------------------------- -// Wal2Json section - -type walMessageWal2JSON struct { - Change []struct { - Kind string `json:"kind"` - Schema string `json:"schema"` - Table string `json:"table"` - Columnnames []string `json:"columnnames"` - Columntypes []string `json:"columntypes"` - Columnvalues 
[]interface{} `json:"columnvalues"` - Oldkeys struct { - Keynames []string `json:"keynames"` - Keytypes []string `json:"keytypes"` - Keyvalues []interface{} `json:"keyvalues"` - } `json:"oldkeys"` - } `json:"change"` -} - -func decodeWal2JsonChanges(clientXLogPosition string, WALData []byte) (*StreamMessage, error) { - var changes walMessageWal2JSON - if err := json.NewDecoder(bytes.NewReader(WALData)).Decode(&changes); err != nil { - return nil, err - } - - if len(changes.Change) == 0 { - return nil, nil - } - message := &StreamMessage{ - Lsn: &clientXLogPosition, - Changes: []StreamMessageChanges{}, - Mode: StreamModeStreaming, - } - - for _, change := range changes.Change { - if change.Kind == "" { - continue - } - - messageChange := StreamMessageChanges{ - Operation: change.Kind, - Schema: change.Schema, - Table: change.Table, - Data: make(map[string]any), - } - - if change.Kind == "delete" { - for i, keyName := range change.Oldkeys.Keynames { - if len(change.Columnvalues) == 0 { - break - } - - messageChange.Data[keyName] = change.Oldkeys.Keyvalues[i] - } - } else { - for i, columnName := range change.Columnnames { - messageChange.Data[columnName] = change.Columnvalues[i] - } - } - - message.Changes = append(message.Changes, messageChange) - } - - return message, nil -} From 04697518bcf02cb09f37e450fd96192f96528055 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 15 Nov 2024 12:01:00 +0100 Subject: [PATCH 111/118] chore(): updated pgstream docs --- .../components/pages/inputs/pg_stream.adoc | 25 +------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/docs/modules/components/pages/inputs/pg_stream.adoc b/docs/modules/components/pages/inputs/pg_stream.adoc index 1851cf1790..5410431626 100644 --- a/docs/modules/components/pages/inputs/pg_stream.adoc +++ b/docs/modules/components/pages/inputs/pg_stream.adoc @@ -44,7 +44,6 @@ input: stream_snapshot: false snapshot_memory_safety_factor: 1 snapshot_batch_size: 0 - decoding_plugin: 
pgoutput schema: public # No default (required) tables: [] # No default (required) checkpoint_limit: 1024 @@ -76,7 +75,6 @@ input: stream_snapshot: false snapshot_memory_safety_factor: 1 snapshot_batch_size: 0 - decoding_plugin: pgoutput schema: public # No default (required) tables: [] # No default (required) checkpoint_limit: 1024 @@ -126,7 +124,7 @@ dsn: postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable === `batch_transactions` -When set to true, transactions are batched into a single message. Note that this setting has no effect when using wal2json +When set to true, transactions are batched into a single message. *Type*: `bool` @@ -178,27 +176,6 @@ The number of rows to fetch in each batch when querying the snapshot. A value of snapshot_batch_size: 10000 ``` -=== `decoding_plugin` - -Specifies the logical decoding plugin to use for streaming changes from PostgreSQL. 'pgoutput' is the native logical replication protocol, while 'wal2json' provides change data as JSON. -Important: No matter which plugin you choose, the data will be converted to JSON before sending it to Connect. - - -*Type*: `string` - -*Default*: `"pgoutput"` - -Options: -`pgoutput` -, `wal2json` -. - -```yml -# Examples - -decoding_plugin: pgoutput -``` - === `schema` The PostgreSQL schema from which to replicate data. 
From b550df99437cf4d579367d653911ff06553a3c22 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 15 Nov 2024 13:42:59 +0100 Subject: [PATCH 112/118] feat(): added support for composite primary keys --- internal/impl/postgresql/integration_test.go | 22 ++++-- .../pglogicalstream/logical_stream.go | 71 +++++++++++++------ .../postgresql/pglogicalstream/pglogrepl.go | 1 - .../postgresql/pglogicalstream/snapshotter.go | 25 +++++-- 4 files changed, 88 insertions(+), 31 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 616c578e8e..f22878d2cb 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -105,6 +105,15 @@ func ResourceWithPostgreSQLVersion(t *testing.T, pool *dockertest.Pool, version return err } + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS flights_composite_pks ( + id serial, seq integer, name VARCHAR(50), created_at TIMESTAMP, + PRIMARY KEY (id, seq) + );`) + if err != nil { + return err + } + // flights_non_streamed is a control table with data that should not be streamed or queried by snapshot streaming _, err = db.Exec("CREATE TABLE IF NOT EXISTS flights_non_streamed (id serial PRIMARY KEY, name VARCHAR(50), created_at TIMESTAMP);") @@ -139,7 +148,7 @@ func TestIntegrationPgCDCForPgOutputPlugin(t *testing.T) { for i := 0; i < 10; i++ { f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } @@ -149,9 +158,10 @@ pg_stream: dsn: %s slot_name: test_slot_native_decoder stream_snapshot: true + snapshot_batch_size: 5 schema: public tables: - - flights + - flights_composite_pks `, databaseURL) cacheConf := 
fmt.Sprintf(` @@ -189,9 +199,9 @@ file: return len(outBatches) == 10 }, time.Second*25, time.Millisecond*100) - for i := 0; i < 10; i++ { + for i := 10; i < 20; i++ { f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) _, err = db.Exec("INSERT INTO flights_non_streamed (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) @@ -231,9 +241,9 @@ file: }() time.Sleep(time.Second * 5) - for i := 0; i < 10; i++ { + for i := 20; i < 30; i++ { f := GetFakeFlightRecord() - _, err = db.Exec("INSERT INTO flights (name, created_at) VALUES ($1, $2);", f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) + _, err = db.Exec("INSERT INTO flights_composite_pks (seq, name, created_at) VALUES ($1, $2, $3);", i, f.RealAddress.City, time.Unix(f.CreatedAt, 0).Format(time.RFC3339)) require.NoError(t, err) } diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 2505a7b358..f6287d03c9 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -438,18 +438,34 @@ func (s *Stream) processSnapshot() error { s.logger.Debugf("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) - tablePk, err := s.getPrimaryKeyColumn(table) + tablePks, pksSlice, err := s.getPrimaryKeyColumn(table) if err != nil { return fmt.Errorf("failed to get primary key column for table %v: %w", table, err) } - var lastPkVal any + if len(tablePks) == 0 { + return fmt.Errorf("failed to get 
primary key column for table %s", table) + } + + tablePksOrderByStatement := strings.Join(pksSlice, ", ") + if len(tablePks) > 1 { + tablePksOrderByStatement = "(" + tablePksOrderByStatement + ")" + } + + var lastPkVals = map[string]any{} for { var snapshotRows *sql.Rows queryStart := time.Now() - if snapshotRows, err = s.snapshotter.querySnapshotData(table, lastPkVal, tablePk, batchSize); err != nil { - return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) + if offset == 0 { + lastPkVals = make(map[string]any) + if snapshotRows, err = s.snapshotter.querySnapshotData(table, nil, tablePksOrderByStatement, batchSize); err != nil { + return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) + } + } else { + if snapshotRows, err = s.snapshotter.querySnapshotData(table, &lastPkVals, tablePksOrderByStatement, batchSize); err != nil { + return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) + } } queryDuration := time.Since(queryStart) @@ -491,8 +507,8 @@ func (s *Stream) processSnapshot() error { var data = make(map[string]any) for i, getter := range valueGetters { data[columnNames[i]] = getter(scanArgs[i]) - if columnNames[i] == tablePk { - lastPkVal = getter(scanArgs[i]) + if _, ok := tablePks[columnNames[i]]; ok { + lastPkVals[columnNames[i]] = getter(scanArgs[i]) } } @@ -553,29 +569,44 @@ func (s *Stream) Errors() chan error { return s.errors } -func (s *Stream) getPrimaryKeyColumn(tableName string) (string, error) { - // TODO(le-vlad): support composite primary keys +func (s *Stream) getPrimaryKeyColumn(tableName string) (map[string]any, []string, error) { + /// Query to get all primary key columns in their correct order q, err := sanitize.SQLQuery(` - SELECT a.attname - FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid - AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = $1::regclass - AND i.indisprimary; - `, tableName) + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a 
ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = $1::regclass + AND i.indisprimary + ORDER BY array_position(i.indkey, a.attnum); + `, tableName) + if err != nil { - return "", err + return nil, nil, fmt.Errorf("failed to sanitize query: %w", err) } reader := s.pgConn.Exec(context.Background(), q) data, err := reader.ReadAll() if err != nil { - return "", err + return nil, nil, fmt.Errorf("failed to read query results: %w", err) + } + + if len(data) == 0 || len(data[0].Rows) == 0 { + return nil, nil, fmt.Errorf("no primary key found for table %s", tableName) + } + + // Extract all primary key column names + pkColumns := make([]string, len(data[0].Rows)) + for i, row := range data[0].Rows { + pkColumns[i] = string(row[0]) + } + + var pksMap = make(map[string]any) + for _, pk := range pkColumns { + pksMap[pk] = nil } - pkResultRow := data[0].Rows[0] - pkColName := string(pkResultRow[0]) - return pkColName, nil + return pksMap, pkColumns, nil } // Stop closes the stream (hopefully gracefully) diff --git a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 833b968938..99debf5190 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -389,7 +389,6 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName if len(rows) == 0 || len(rows[0].Rows) == 0 { // tablesClause is sanitized, so we can safely interpolate it into the query sq, err := sanitize.SQLQuery(fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) - fmt.Print(sq) if err != nil { return fmt.Errorf("failed to sanitize publication creation query: %w", err) } diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index ce581c169c..8d3a6ac70e 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ 
b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -11,6 +11,7 @@ package pglogicalstream import ( "database/sql" "fmt" + "strings" "errors" @@ -174,11 +175,13 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz return batchSize } -func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, limit int) (rows *sql.Rows, err error) { - s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk) +func (s *Snapshotter) querySnapshotData(table string, lastSeenPk *map[string]any, pksOrderBy string, limit int) (rows *sql.Rows, err error) { + + s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pksOrderBy) + if lastSeenPk == nil { // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. - sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pk, limit)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pksOrderBy, limit)) if err != nil { return nil, err } @@ -186,8 +189,22 @@ func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, return s.pgConnection.Query(sq) } + var ( + placeholders []string + lastSeenPksValues []any + i = 1 + ) + + for _, v := range *lastSeenPk { + placeholders = append(placeholders, fmt.Sprintf("$%d", i)) + i++ + lastSeenPksValues = append(lastSeenPksValues, v) + } + + lastSeenPlaceHolders := "(" + strings.Join(placeholders, ", ") + ")" + // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. 
- sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pk, lastSeenPk, pk, limit)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pksOrderBy, lastSeenPlaceHolders, pksOrderBy, limit), lastSeenPksValues...) if err != nil { return nil, err } From 813788338a2b6455114914b6fa6afe4517eafdac Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 15 Nov 2024 15:21:55 +0000 Subject: [PATCH 113/118] pgcdc: mark as enterprise licensed --- public/components/all/package.go | 1 + public/components/community/package.go | 1 - public/components/postgresql/package.go | 22 +++++++++------------- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/public/components/all/package.go b/public/components/all/package.go index d950cc3e5a..0b7b3a6c3e 100644 --- a/public/components/all/package.go +++ b/public/components/all/package.go @@ -23,6 +23,7 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/kafka/enterprise" _ "github.com/redpanda-data/connect/v4/public/components/ollama" _ "github.com/redpanda-data/connect/v4/public/components/openai" + _ "github.com/redpanda-data/connect/v4/public/components/postgresql" _ "github.com/redpanda-data/connect/v4/public/components/snowflake" _ "github.com/redpanda-data/connect/v4/public/components/splunk" ) diff --git a/public/components/community/package.go b/public/components/community/package.go index e66b1ca70f..324cae7040 100644 --- a/public/components/community/package.go +++ b/public/components/community/package.go @@ -54,7 +54,6 @@ import ( _ "github.com/redpanda-data/connect/v4/public/components/opensearch" _ "github.com/redpanda-data/connect/v4/public/components/otlp" _ "github.com/redpanda-data/connect/v4/public/components/pinecone" - _ "github.com/redpanda-data/connect/v4/public/components/postgresql" _ "github.com/redpanda-data/connect/v4/public/components/prometheus" _ 
"github.com/redpanda-data/connect/v4/public/components/pulsar" _ "github.com/redpanda-data/connect/v4/public/components/pure" diff --git a/public/components/postgresql/package.go b/public/components/postgresql/package.go index fa5d81b263..275ee41b31 100644 --- a/public/components/postgresql/package.go +++ b/public/components/postgresql/package.go @@ -1,16 +1,12 @@ -// Copyright 2024 Redpanda Data, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* + * Copyright 2024 Redpanda Data, Inc. + * + * Licensed as a Redpanda Enterprise file under the Redpanda Community + * License (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * https://github.com/redpanda-data/redpanda/blob/master/licenses/rcl.md + */ package postgresql From 93f701e6500e69b4a434e0cbbe188352b4b708db Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Fri, 15 Nov 2024 16:35:29 +0100 Subject: [PATCH 114/118] chore(): applied make fmt --- internal/impl/postgresql/pglogicalstream/snapshotter.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 8d3a6ac70e..891447ea7d 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -17,6 +17,7 @@ import ( _ "github.com/lib/pq" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) From 3cfc43687965456804456ada9b1a9f99fd50755c Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 15 Nov 2024 16:07:10 +0000 Subject: [PATCH 115/118] pgcdc/snapshot: use context for cancellation --- .../pglogicalstream/logical_stream.go | 33 ++++++++----------- .../postgresql/pglogicalstream/snapshotter.go | 22 +++++++------ 2 files changed, 26 insertions(+), 29 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index f6287d03c9..dbcb2c3f7f 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -425,7 +425,9 @@ func (s *Stream) processSnapshot() error { offset = 0 ) - avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(table) + ctx, _ := s.shutSig.SoftStopCtx(context.Background()) + + avgRowSizeBytes, err = s.snapshotter.findAvgRowSize(ctx, table) if err != nil { return fmt.Errorf("failed to calculate average row size for table %v: %w", table, err) } @@ -438,38 +440,31 @@ func (s *Stream) 
processSnapshot() error { s.logger.Debugf("Querying snapshot batch_side: %v, available_memory: %v, avg_row_size: %v", batchSize, availableMemory, avgRowSizeBytes.Int64) - tablePks, pksSlice, err := s.getPrimaryKeyColumn(table) + lastPrimaryKey, primaryKeyColumns, err := s.getPrimaryKeyColumn(ctx, table) if err != nil { return fmt.Errorf("failed to get primary key column for table %v: %w", table, err) } - if len(tablePks) == 0 { + if len(lastPrimaryKey) == 0 { return fmt.Errorf("failed to get primary key column for table %s", table) } - tablePksOrderByStatement := strings.Join(pksSlice, ", ") - if len(tablePks) > 1 { - tablePksOrderByStatement = "(" + tablePksOrderByStatement + ")" - } - var lastPkVals = map[string]any{} for { var snapshotRows *sql.Rows queryStart := time.Now() if offset == 0 { - lastPkVals = make(map[string]any) - if snapshotRows, err = s.snapshotter.querySnapshotData(table, nil, tablePksOrderByStatement, batchSize); err != nil { - return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) - } + snapshotRows, err = s.snapshotter.querySnapshotData(ctx, table, nil, primaryKeyColumns, batchSize) } else { - if snapshotRows, err = s.snapshotter.querySnapshotData(table, &lastPkVals, tablePksOrderByStatement, batchSize); err != nil { - return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) - } + snapshotRows, err = s.snapshotter.querySnapshotData(ctx, table, lastPkVals, primaryKeyColumns, batchSize) + } + if err != nil { + return fmt.Errorf("failed to query snapshot data for table %v: %w", table, err) } queryDuration := time.Since(queryStart) - s.logger.Debugf("Query duration: %v %s \n", queryDuration, tableName) + s.logger.Tracef("Query duration: %v %s \n", queryDuration, tableName) if snapshotRows.Err() != nil { return fmt.Errorf("failed to get snapshot data for table %v: %w", table, snapshotRows.Err()) @@ -507,7 +502,7 @@ func (s *Stream) processSnapshot() error { var data = make(map[string]any) for i, 
getter := range valueGetters { data[columnNames[i]] = getter(scanArgs[i]) - if _, ok := tablePks[columnNames[i]]; ok { + if _, ok := lastPrimaryKey[columnNames[i]]; ok { lastPkVals[columnNames[i]] = getter(scanArgs[i]) } } @@ -569,7 +564,7 @@ func (s *Stream) Errors() chan error { return s.errors } -func (s *Stream) getPrimaryKeyColumn(tableName string) (map[string]any, []string, error) { +func (s *Stream) getPrimaryKeyColumn(ctx context.Context, tableName string) (map[string]any, []string, error) { /// Query to get all primary key columns in their correct order q, err := sanitize.SQLQuery(` SELECT a.attname @@ -585,7 +580,7 @@ func (s *Stream) getPrimaryKeyColumn(tableName string) (map[string]any, []string return nil, nil, fmt.Errorf("failed to sanitize query: %w", err) } - reader := s.pgConn.Exec(context.Background(), q) + reader := s.pgConn.Exec(ctx, q) data, err := reader.ReadAll() if err != nil { return nil, nil, fmt.Errorf("failed to read query results: %w", err) diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 8d3a6ac70e..0277870b8f 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -9,6 +9,7 @@ package pglogicalstream import ( + "context" "database/sql" "fmt" "strings" @@ -112,7 +113,7 @@ func (s *Snapshotter) prepare() error { return nil } -func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { +func (s *Snapshotter) findAvgRowSize(ctx context.Context, table string) (sql.NullInt64, error) { var ( avgRowSize sql.NullInt64 rows *sql.Rows @@ -120,7 +121,7 @@ func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { ) // table is validated to be correct pg identifier, so we can use it directly - if rows, err = s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { + if rows, err = 
s.pgConnection.QueryContext(ctx, fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) } @@ -175,18 +176,18 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz return batchSize } -func (s *Snapshotter) querySnapshotData(table string, lastSeenPk *map[string]any, pksOrderBy string, limit int) (rows *sql.Rows, err error) { +func (s *Snapshotter) querySnapshotData(ctx context.Context, table string, lastSeenPk map[string]any, pkColumns []string, limit int) (rows *sql.Rows, err error) { - s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pksOrderBy) + s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pkColumns) if lastSeenPk == nil { // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. 
- sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pksOrderBy, limit)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pkColumns, limit)) if err != nil { return nil, err } - return s.pgConnection.Query(sq) + return s.pgConnection.QueryContext(ctx, sq) } var ( @@ -195,21 +196,22 @@ func (s *Snapshotter) querySnapshotData(table string, lastSeenPk *map[string]any i = 1 ) - for _, v := range *lastSeenPk { + for _, col := range pkColumns { placeholders = append(placeholders, fmt.Sprintf("$%d", i)) i++ - lastSeenPksValues = append(lastSeenPksValues, v) + lastSeenPksValues = append(lastSeenPksValues, lastSeenPk[col]) } lastSeenPlaceHolders := "(" + strings.Join(placeholders, ", ") + ")" + pkAsTuple := "(" + strings.Join(pkColumns, ", ") + ")" // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. - sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pksOrderBy, lastSeenPlaceHolders, pksOrderBy, limit), lastSeenPksValues...) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pkAsTuple, lastSeenPlaceHolders, pkColumns, limit), lastSeenPksValues...) 
if err != nil { return nil, err } - return s.pgConnection.Query(sq) + return s.pgConnection.QueryContext(ctx, sq) } func (s *Snapshotter) releaseSnapshot() error { From b64dc1b358acdef6b0e96bb99098794bff722064 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 15 Nov 2024 16:11:58 +0000 Subject: [PATCH 116/118] pgcdc: fix primary key order by clause --- internal/impl/postgresql/pglogicalstream/snapshotter.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 9df4e025f3..8f6c737da6 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -179,11 +179,11 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz func (s *Snapshotter) querySnapshotData(ctx context.Context, table string, lastSeenPk map[string]any, pkColumns []string, limit int) (rows *sql.Rows, err error) { - s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pkColumns) + s.logger.Debugf("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pkColumns) if lastSeenPk == nil { // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. - sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pkColumns, limit)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, strings.Join(pkColumns, ", "), limit)) if err != nil { return nil, err } @@ -207,7 +207,7 @@ func (s *Snapshotter) querySnapshotData(ctx context.Context, table string, lastS pkAsTuple := "(" + strings.Join(pkColumns, ", ") + ")" // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. 
- sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pkAsTuple, lastSeenPlaceHolders, pkColumns, limit), lastSeenPksValues...) + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > %s ORDER BY %s LIMIT %d;", table, pkAsTuple, lastSeenPlaceHolders, strings.Join(pkColumns, ", "), limit), lastSeenPksValues...) if err != nil { return nil, err } From d05aae0539a2393dd739feeb926f36b706b01de2 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 15 Nov 2024 16:14:55 +0000 Subject: [PATCH 117/118] pgcdc: fix zero batch check --- internal/impl/postgresql/input_pg_stream.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/impl/postgresql/input_pg_stream.go b/internal/impl/postgresql/input_pg_stream.go index 7adaa20543..fcdbfb217f 100644 --- a/internal/impl/postgresql/input_pg_stream.go +++ b/internal/impl/postgresql/input_pg_stream.go @@ -345,7 +345,7 @@ func (p *pgStreamInput) processStream(pgStream *pglogicalstream.Stream, batcher break } - // TODO this should only be the message + // TODO(rockwood): this should only be the message if mb, err = json.Marshal(message.Changes); err != nil { break } @@ -398,7 +398,7 @@ func (p *pgStreamInput) flushBatch( checkpointer *checkpoint.Capped[*int64], batch service.MessageBatch, ) error { - if batch == nil { + if len(batch) == 0 { return nil } From 438719ab3e6d0898e3d4a51ef78ecaf0d726b422 Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Fri, 15 Nov 2024 18:34:28 +0000 Subject: [PATCH 118/118] update changelog --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b96f9e0b8e..70fdb569e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,16 @@ Changelog All notable changes to this project will be documented in this file. 
+## 4.40.0 - TBD + +### Added + +- New `pg_stream` input supporting change data capture (CDC) from PostgreSQL (@le-vlad) + +### Changed + +- `snowflake_streaming` with `schema_evolution.enabled` set to true can now autocreate tables. + ## 4.39.0 - 2024-11-07 ### Added