Skip to content

Commit

Permalink
chore(): sanitized queries && fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
le-vlad committed Nov 14, 2024
1 parent 272eef0 commit b70f694
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 23 deletions.
10 changes: 5 additions & 5 deletions internal/impl/postgresql/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ file:
`, tmpDir)

streamOutBuilder := service.NewStreamBuilder()
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: OFF`))
require.NoError(t, streamOutBuilder.SetLoggerYAML(`level: INFO`))
require.NoError(t, streamOutBuilder.AddCacheYAML(cacheConf))
require.NoError(t, streamOutBuilder.AddInputYAML(template))

Expand Down Expand Up @@ -455,7 +455,7 @@ pg_stream:
snapshot_batch_size: 100000
stream_snapshot: true
decoding_plugin: pgoutput
stream_uncommitted: false
batch_transactions: false
temporary_slot: true
schema: public
tables:
Expand Down Expand Up @@ -550,7 +550,7 @@ pg_stream:
snapshot_batch_size: 100
stream_snapshot: true
decoding_plugin: pgoutput
stream_uncommitted: true
batch_transactions: true
schema: public
tables:
- flights
Expand Down Expand Up @@ -689,7 +689,7 @@ pg_stream:
slot_name: test_slot_native_decoder
stream_snapshot: true
decoding_plugin: pgoutput
stream_uncommitted: true
batch_transactions: true
schema: public
tables:
- flights
Expand Down Expand Up @@ -826,7 +826,7 @@ pg_stream:
slot_name: test_slot_native_decoder
stream_snapshot: true
decoding_plugin: pgoutput
stream_uncommitted: false
batch_transactions: false
schema: public
tables:
- flights
Expand Down
4 changes: 4 additions & 0 deletions internal/impl/postgresql/pglogicalstream/logical_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ func NewPgStream(ctx context.Context, config *Config) (*Stream, error) {

tableNames := slices.Clone(config.DBTables)
for i, table := range tableNames {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
return nil, fmt.Errorf("invalid table name %q: %w", table, err)
}

tableNames[i] = fmt.Sprintf("%s.%s", config.DBSchema, table)
}
stream := &Stream{
Expand Down
11 changes: 8 additions & 3 deletions internal/impl/postgresql/pglogicalstream/monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/redpanda-data/benthos/v4/public/service"

"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize"
"github.com/redpanda-data/connect/v4/internal/periodic"
)

Expand Down Expand Up @@ -99,11 +100,15 @@ func (m *Monitor) readTablesStat(ctx context.Context, tables []string) error {

for _, table := range tables {
tableWithoutSchema := strings.Split(table, ".")[1]
// TODO(le-vlad): Implement proper SQL injection protection
query := "SELECT COUNT(*) FROM " + tableWithoutSchema
err := sanitize.ValidatePostgresIdentifier(tableWithoutSchema)

if err != nil {
return fmt.Errorf("error sanitizing query: %w", err)
}

var count int64
err := m.dbConn.QueryRowContext(ctx, query).Scan(&count)
// tableWithoutSchema has been validated so its safe to use in the query
err = m.dbConn.QueryRowContext(ctx, "SELECT COUNT(*) FROM %s"+tableWithoutSchema).Scan(&count)

if err != nil {
// If the error is because the table doesn't exist, we'll set the count to 0
Expand Down
48 changes: 40 additions & 8 deletions internal/impl/postgresql/pglogicalstream/pglogrepl.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"slices"
"strconv"
Expand Down Expand Up @@ -253,7 +254,7 @@ func CreateReplicationSlot(

// NOTE: All strings passed into here have been validated and are not prone to SQL injection.
newPgCreateSlotCommand := fmt.Sprintf("CREATE_REPLICATION_SLOT %s %s %s %s %s", slotName, temporaryString, options.Mode, outputPlugin, snapshotString)
oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY")
oldPgCreateSlotCommand := fmt.Sprintf("SELECT * FROM pg_create_logical_replication_slot('%s', '%s', %v);", slotName, outputPlugin, temporaryString == "TEMPORARY")

var snapshotName string
if version > 14 {
Expand Down Expand Up @@ -353,6 +354,17 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName
return fmt.Errorf("failed to sanitize publication query: %w", err)
}

// Since we need to pass table names without quoting, we need to validate it
for _, table := range tables {
if err := sanitize.ValidatePostgresIdentifier(table); err != nil {
return errors.New("invalid table name")
}
}
// the same for publication name
if err := sanitize.ValidatePostgresIdentifier(publicationName); err != nil {
return errors.New("invalid publication name")
}

result := conn.Exec(ctx, pubQuery)

rows, err := result.ReadAll()
Expand All @@ -362,13 +374,27 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

tablesClause := "FOR ALL TABLES"
if len(tables) > 0 {
// TODO(le-vlad): Implement proper SQL injection protection
tablesClause = "FOR TABLE " + strings.Join(tables, ",")
// quotedTables := make([]string, len(tables))
// for i, table := range tables {
// // Use sanitize.SQLIdentifier to properly quote and escape table names
// quoted, err := sanitize.SQLIdentifier(table)
// if err != nil {
// return fmt.Errorf("invalid table name %q: %w", table, err)
// }
// quotedTables[i] = quoted
// }
tablesClause = "FOR TABLE " + strings.Join(tables, ", ")
}

if len(rows) == 0 || len(rows[0].Rows) == 0 {
// tablesClause is sanitized, so we can safely interpolate it into the query
sq, err := sanitize.SQLQuery(fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause))
fmt.Print(sq)
if err != nil {
return fmt.Errorf("failed to sanitize publication creation query: %w", err)
}
// Publication doesn't exist, create new one
result = conn.Exec(ctx, fmt.Sprintf("CREATE PUBLICATION %s %s;", publicationName, tablesClause))
result = conn.Exec(ctx, sq)
if _, err := result.ReadAll(); err != nil {
return fmt.Errorf("failed to create publication: %w", err)
}
Expand Down Expand Up @@ -405,17 +431,23 @@ func CreatePublication(ctx context.Context, conn *pgconn.PgConn, publicationName

// remove tables from publication
for _, dropTable := range tablesToRemoveFromPublication {
// TODO(le-vlad): Implement proper SQL injection protection
result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable))
sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s DROP TABLE %s;", publicationName, dropTable))
if err != nil {
return fmt.Errorf("failed to sanitize drop table query: %w", err)
}
result = conn.Exec(ctx, sq)
if _, err := result.ReadAll(); err != nil {
return fmt.Errorf("failed to remove table from publication: %w", err)
}
}

// add tables to publication
for _, addTable := range tablesToAddToPublication {
// TODO(le-vlad): Implement proper SQL injection protection
result = conn.Exec(ctx, fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable))
sq, err := sanitize.SQLQuery(fmt.Sprintf("ALTER PUBLICATION %s ADD TABLE %s;", publicationName, addTable))
if err != nil {
return fmt.Errorf("failed to sanitize add table query: %w", err)
}
result = conn.Exec(ctx, sq)
if _, err := result.ReadAll(); err != nil {
return fmt.Errorf("failed to add table to publication: %w", err)
}
Expand Down
30 changes: 30 additions & 0 deletions internal/impl/postgresql/pglogicalstream/sanitize/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ import (
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
)

// MaxIdentifierLength is PostgreSQL's maximum identifier length.
// ValidatePostgresIdentifier compares it against len(name), which is a byte
// count, so multi-byte characters use up more than one unit of this budget.
const MaxIdentifierLength = 63

// Part is either a string or an int. A string is raw SQL. An int is an
// argument placeholder.
type Part any
Expand Down Expand Up @@ -358,3 +362,29 @@ func SQLQuery(sql string, args ...any) (string, error) {
}
return query.Sanitize(args...)
}

// ValidatePostgresIdentifier reports whether name may be interpolated,
// unquoted, into a SQL statement as a PostgreSQL identifier.
//
// The accepted form is:
//   - non-empty and at most MaxIdentifierLength bytes long;
//   - the first character is a Unicode letter or an underscore;
//   - every character is a Unicode letter, a Unicode digit, an underscore,
//     or a dot (dots allow schema-qualified names such as "public.flights").
//
// It returns nil when name satisfies these rules and a descriptive error
// otherwise. Quotes, whitespace, semicolons and other punctuation are always
// rejected, which is what makes the result safe to splice into a query.
func ValidatePostgresIdentifier(name string) error {
	if len(name) == 0 {
		return errors.New("empty identifier is not allowed")
	}

	// len counts bytes, not runes, so non-ASCII names hit the limit sooner.
	if len(name) > MaxIdentifierLength {
		return fmt.Errorf("identifier length exceeds maximum of %d characters", MaxIdentifierLength)
	}

	// Decode the first rune instead of inspecting name[0]. Converting the
	// leading byte of a multi-byte UTF-8 sequence to a rune classifies an
	// unrelated Latin-1 code point, which both rejects valid letters and
	// lets non-letter runes slip past this check.
	first, _ := utf8.DecodeRuneInString(name)
	if !unicode.IsLetter(first) && first != '_' {
		return errors.New("identifier must start with a letter or underscore")
	}

	// Every character must be a letter, digit, underscore, or dot. Invalid
	// UTF-8 decodes to utf8.RuneError, which is none of those and is rejected.
	for i, char := range name {
		if !unicode.IsLetter(char) && !unicode.IsDigit(char) && char != '_' && char != '.' {
			return fmt.Errorf("invalid character '%c' at position %d in identifier '%s'", char, i, name)
		}
	}

	return nil
}
32 changes: 25 additions & 7 deletions internal/impl/postgresql/pglogicalstream/snapshotter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

_ "github.com/lib/pq"
"github.com/redpanda-data/benthos/v4/public/service"
"github.com/redpanda-data/connect/v4/internal/impl/postgresql/pglogicalstream/sanitize"
)

// SnapshotCreationResponse is a structure that contains the name of the snapshot that was created
Expand Down Expand Up @@ -97,8 +98,13 @@ func (s *Snapshotter) prepare() error {
if _, err := s.pgConnection.Exec("BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ;"); err != nil {
return err
}
// TODO(le-vlad): Implement proper SQL injection protection
if _, err := s.pgConnection.Exec(fmt.Sprintf("SET TRANSACTION SNAPSHOT '%s';", s.snapshotName)); err != nil {

sq, err := sanitize.SQLQuery("SET TRANSACTION SNAPSHOT $1;", s.snapshotName)
if err != nil {
return err
}

if _, err := s.pgConnection.Exec(sq); err != nil {
return err
}

Expand All @@ -111,7 +117,8 @@ func (s *Snapshotter) findAvgRowSize(table string) (sql.NullInt64, error) {
rows *sql.Rows
err error
)
// TODO(le-vlad): Implement proper SQL injection protection

// table is validated to be correct pg identifier, so we can use it directly
if rows, err = s.pgConnection.Query(fmt.Sprintf(`SELECT SUM(pg_column_size('%s.*')) / COUNT(*) FROM %s;`, table, table)); err != nil {
return avgRowSize, fmt.Errorf("can get avg row size due to query failure: %w", err)
}
Expand Down Expand Up @@ -170,11 +177,22 @@ func (s *Snapshotter) calculateBatchSize(availableMemory uint64, estimatedRowSiz
// querySnapshotData fetches one page of rows from table, ordered by pk.
//
// When lastSeenPk is nil the page starts at the beginning of the table;
// otherwise only rows whose pk is strictly greater than lastSeenPk are
// selected. At most limit rows are requested.
//
// table and pk are spliced into the statement text as identifiers, so callers
// must only pass values that have already been validated. lastSeenPk is a data
// value, so it is bound through the $1 placeholder of sanitize.SQLQuery.
// Formatting it into the SQL text with %s would produce unquoted (and
// injectable) strings and garbled output for non-string types.
//
// It returns the open result set, or a nil result set and an error if the
// query could not be built or executed.
func (s *Snapshotter) querySnapshotData(table string, lastSeenPk any, pk string, limit int) (rows *sql.Rows, err error) {
	s.logger.Infof("Query snapshot table: %v, limit: %v, lastSeenPkVal: %v, pk: %v", table, limit, lastSeenPk, pk)

	var sq string
	if lastSeenPk == nil {
		// First page: no lower bound on the primary key.
		sq, err = sanitize.SQLQuery(fmt.Sprintf("SELECT * FROM %s ORDER BY %s LIMIT %d;", table, pk, limit))
	} else {
		// NOTE(review): confirm sanitize.SQLQuery accepts every concrete type
		// lastSeenPk can hold for the primary-key columns in use.
		sq, err = sanitize.SQLQuery(
			fmt.Sprintf("SELECT * FROM %s WHERE %s > $1 ORDER BY %s LIMIT %d;", table, pk, pk, limit),
			lastSeenPk,
		)
	}
	if err != nil {
		return nil, err
	}

	return s.pgConnection.Query(sq)
}

func (s *Snapshotter) releaseSnapshot() error {
Expand Down

0 comments on commit b70f694

Please sign in to comment.