From b70f69454d6da27cae19068d89368eeed62fc829 Mon Sep 17 00:00:00 2001 From: Vladyslav Len Date: Thu, 14 Nov 2024 12:26:15 +0100 Subject: [PATCH] chore(): sanitized queries && fixed tests --- internal/impl/postgresql/integration_test.go | 10 ++-- .../pglogicalstream/logical_stream.go | 4 ++ .../postgresql/pglogicalstream/monitor.go | 11 +++-- .../postgresql/pglogicalstream/pglogrepl.go | 48 +++++++++++++++---- .../pglogicalstream/sanitize/sanitize.go | 30 ++++++++++++ .../postgresql/pglogicalstream/snapshotter.go | 32 ++++++++++--- 6 files changed, 112 insertions(+), 23 deletions(-) diff --git a/internal/impl/postgresql/integration_test.go b/internal/impl/postgresql/integration_test.go index 283a281c2..704f7bab4 100644 --- a/internal/impl/postgresql/integration_test.go +++ b/internal/impl/postgresql/integration_test.go @@ -352,7 +352,7 @@ file: `, tmpDir) streamOutBuilder := service.NewStreamBuilder() - require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`)) + require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`)) require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf)) require.NoError(t, streamOutBuilder.AddInputYAML(template)) @@ -455,7 +455,7 @@ pg_stream: snapshot_batch_size: 100000 stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: false + batch_transactions: false temporary_slot: true schema: public tables: @@ -550,7 +550,7 @@ pg_stream: snapshot_batch_size: 100 stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: true + batch_transactions: true schema: public tables: - flights @@ -689,7 +689,7 @@ pg_stream: slot_name: test_slot_native_decoder stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: true + batch_transactions: true schema: public tables: - flights @@ -826,7 +826,7 @@ pg_stream: slot_name: test_slot_native_decoder stream_snapshot: true decoding_plugin: pgoutput - stream_uncommitted: false + batch_transactions: false schema: public tables: - flights diff --git 
a/internal/impl/postgresql/pglogicalstream/logical_stream.go b/internal/impl/postgresql/pglogicalstream/logical_stream.go index 2c6c6bb11..9f2ea48f9 100644 --- a/internal/impl/postgresql/pglogicalstream/logical_stream.go +++ b/internal/impl/postgresql/pglogicalstream/logical_stream.go @@ -88,6 +88,10 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) { tableNames := slices.Clone(config.DBTables) for i, table := range tableNames { + if err := sanitize.ValidatePostgresIdentifier(table); err != nil { + return nil, fmt.Errorf("invalid table name %q: %w", table, err) + } + tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table) } stream := &Stream{ diff --git a/internal/impl/postgresql/pglogicalstream/monitor.go b/internal/impl/postgresql/pglogicalstream/monitor.go index b8e8e8f75..f800bb159 100644 --- a/internal/impl/postgresql/pglogicalstream/monitor.go +++ b/internal/impl/postgresql/pglogicalstream/monitor.go @@ -20,6 +20,7 @@ import ( "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" "github.com/redpanda-data/connect/v4/internal/periodic" ) @@ -99,11 +100,15 @@ func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error { for _, table := range tables { tableWithoutSchema := strings.Split(table, ".")[1] - // TODO(le-vlad): Implement proper SQL injection protection - query := "SELECT COUNT(*) FROM " + tableWithoutSchema + err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema) + + if err != nil { + return fmt.Errorf("error sanitizing query: %w", err) + } var count int64 - err := m.dbConn.QueryRowContext(ctx, query).Scan(&count) + // tableWithoutSchema has been validated above, so it's safe to concatenate into the query + err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+tableWithoutSchema).Scan(&count) if err != nil { // If the error is because the table doesn't exist, we'll set the count to 0 diff --git 
a/internal/impl/postgresql/pglogicalstream/pglogrepl.go b/internal/impl/postgresql/pglogicalstream/pglogrepl.go index 5e9647a09..833b96893 100644 --- a/internal/impl/postgresql/pglogicalstream/pglogrepl.go +++ b/internal/impl/postgresql/pglogicalstream/pglogrepl.go @@ -21,6 +21,7 @@ import ( "context" "database/sql/driver" "encoding/binary" + "errors" "fmt" "slices" "strconv" @@ -253,7 +254,7 @@ func CreateReplicationSlot( // NOTE: All strings passed into here have been validated and are not prone to SQL injection. newPgCreateSlotCommand := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString) - oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") + oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY") var snapshotName string if version > 14 { @@ -353,6 +354,17 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName return fmt.Errorf("failed to sanitize publication query: %w", err) } + // Since we need to pass table names without quoting, we need to validate it + for _, table := range tables { + if err := sanitize.ValidatePostgresIdentifier(table); err != nil { + return errors.New("invalid table name") + } + } + // the same for publication name + if err := sanitize.ValidatePostgresIdentifier(publicationName); err != nil { + return errors.New("invalid publication name") + } + result := conn.Exec(ctx, pubQuery) rows, err := result.ReadAll() @@ -362,13 +374,27 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName tablesClause := "FOR ALL TABLES" if len(tables) > 0 { - // TODO(le-vlad): Implement proper SQL injection protection - tablesClause = "FOR TABLE " + strings.Join(tables, ",") + // quotedTables := make([]string, 
len(tables)) + // for i, table := range tables { + // // Use sanitize.SQLIdentifier to properly quote and escape table names + // quoted, err := sanitize.SQLIdentifier(table) + // if err != nil { + // return fmt.Errorf("invalid table name %q: %w", table, err) + // } + // quotedTables[i] = quoted + // } + tablesClause = "FOR TABLE " + strings.Join(tables, ", ") } if len(rows) == 0 || len(rows[0].Rows) == 0 { + // publicationName and tablesClause hold only validated identifiers or fixed SQL, so they are safe to interpolate + sq, err := sanitize.SQLQuery(fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) + + if err != nil { + return fmt.Errorf("failed to sanitize publication creation query: %w", err) + } // Publication doesn't exist, create new one - result = conn.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause)) + result = conn.Exec(ctx, sq) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to create publication: %w", err) } @@ -405,8 +431,11 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // remove tables from publication for _, dropTable := range tablesToRemoveFromPublication { - // TODO(le-vlad): Implement proper SQL injection protection - result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable)) + if err != nil { + return fmt.Errorf("failed to sanitize drop table query: %w", err) + } + result = conn.Exec(ctx, sq) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to remove table from publication: %w", err) } @@ -414,8 +443,11 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName // add tables to publication for _, addTable := range tablesToAddToPublication { - // TODO(le-vlad): Implement proper SQL injection protection - result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION 
%s ADD TABLE %s;", publicationName, addTable)) + sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable)) + if err != nil { + return fmt.Errorf("failed to sanitize add table query: %w", err) + } + result = conn.Exec(ctx, sq) if _, err := result.ReadAll(); err != nil { return fmt.Errorf("failed to add table to publication: %w", err) } diff --git a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go index 95f35bd8c..5bba854bb 100644 --- a/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go +++ b/internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go @@ -33,9 +33,13 @@ import ( "strconv" "strings" "time" + "unicode" "unicode/utf8" ) +// MaxIdentifierLength is PostgreSQL's default maximum identifier length, in bytes. +const MaxIdentifierLength = 63 + // Part is either a string or an int. A string is raw SQL. An int is a // argument placeholder. type Part any @@ -358,3 +362,29 @@ func SQLQuery(sql string, args ...any) (string, error) { } return query.Sanitize(args...) } + +// ValidatePostgresIdentifier checks that name can be used as an unquoted, optionally +// dot-qualified PostgreSQL identifier and returns an error describing the first violation. +func ValidatePostgresIdentifier(name string) error { + if len(name) == 0 { + return errors.New("empty identifier is not allowed") + } + + if len(name) > MaxIdentifierLength { + return fmt.Errorf("identifier length exceeds maximum of %d characters", MaxIdentifierLength) + } + + // First rune (decoded as UTF-8, not the raw first byte) must be a letter or underscore + if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) && first != '_' { + return errors.New("identifier must start with a letter or underscore") + } + + // Subsequent characters must be letters, numbers, underscores, or dots + for i, char := range name { + if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '_' && char != '.' 
{ + return fmt.Errorf("invalid character '%c' at position %d in identifier '%s'", char, i, name) + } + } + + return nil +} diff --git a/internal/impl/postgresql/pglogicalstream/snapshotter.go b/internal/impl/postgresql/pglogicalstream/snapshotter.go index 4129f19cd..ce581c169 100644 --- a/internal/impl/postgresql/pglogicalstream/snapshotter.go +++ b/internal/impl/postgresql/pglogicalstream/snapshotter.go @@ -16,6 +16,7 @@ import ( _ "github.com/lib/pq" "github.com/redpanda-data/benthos/v4/public/service" + "github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize" ) // SnapshotCreationResponse is a structure that contains the name of the snapshot that was created @@ -97,8 +98,13 @@ func (s *Snapshotter) prepare() error { if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil { return err } - // TODO(le-vlad): Implement proper SQL injection protection - if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil { + + sq, err := sanitize.SQLQuery("SET TRANSACTION SNAPSHOT $1;", s.snapshotName) + if err != nil { + return err + } + + if _, err := s.pgConnection.Exec(sq); err != nil { return err } @@ -111,7 +117,8 @@ func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) { rows *sql.Rows err error ) - // TODO(le-vlad): Implement proper SQL injection protection + + // table is validated to be correct pg identifier, so we can use it directly if rows, err = s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil { return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err) } @@ -170,11 +177,22 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, limit int) (rows *sql.Rows, err error) { s.logger.Infof("Query snapshot table: %v, limit: %v, 
lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk) if lastSeenPk == nil { - // TODO(le-vlad): Implement proper SQL injection protection - return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT $1;", table, pk), limit) + // NOTE: All strings passed into here have been validated or derived from the code/database, therefore not prone to SQL injection. + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pk, limit)) + if err != nil { + return nil, err + } + + return s.pgConnection.Query(sq) + } + + // NOTE: table and pk have been validated or derived from the database; lastSeenPk is row data, so it is passed to sanitize.SQLQuery as the $1 argument instead of being formatted into the SQL text. + sq, err := sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s WHERE %s > $1 ORDER BY %s LIMIT %d;", table, pk, pk, limit), lastSeenPk) + if err != nil { + return nil, err } - // TODO(le-vlad): Implement proper SQL injection protection - return s.pgConnection.Query(fmt.Sprintf("SELECT * FROM %s WHERE %s > $1 ORDER BY %s LIMIT $2;", table, pk, pk), lastSeenPk, limit) + + return s.pgConnection.Query(sq) } func (s *Snapshotter) releaseSnapshot() error {