Skip to content

Commit

Permalink
NEOS-1565: Refactors CLI to share DB Connection Pools (#2844)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzelei authored Oct 23, 2024
1 parent 8155544 commit 23637d2
Show file tree
Hide file tree
Showing 17 changed files with 261 additions and 205 deletions.
244 changes: 199 additions & 45 deletions cli/internal/cmds/neosync/sync/sync.go

Large diffs are not rendered by default.

13 changes: 4 additions & 9 deletions cli/internal/cmds/neosync/sync/sync_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package sync_cmd

import (
"os"
"io"
"log/slog"
"testing"

charmlog "github.com/charmbracelet/log"
tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -103,9 +103,7 @@ func Test_groupConfigsByDependency(t *testing.T) {
},
}

logger := charmlog.NewWithOptions(os.Stderr, charmlog.Options{
ReportTimestamp: true,
})
logger := slog.New(slog.NewTextHandler(io.Discard, nil))

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -133,10 +131,7 @@ func Test_groupConfigsByDependency_Error(t *testing.T) {
{Name: "public.b", DependsOn: []*tabledependency.DependsOn{{Table: "public.c", Columns: []string{"id"}}}, Table: "public.b", Columns: []string{"id", "c_id"}},
{Name: "public.c", DependsOn: []*tabledependency.DependsOn{{Table: "public.a", Columns: []string{"id"}}}, Table: "public.c", Columns: []string{"id", "a_id"}},
}
logger := charmlog.NewWithOptions(os.Stderr, charmlog.Options{
ReportTimestamp: true,
})
groups := groupConfigsByDependency(configs, logger)
groups := groupConfigsByDependency(configs, slog.New(slog.NewTextHandler(io.Discard, nil)))
require.Nil(t, groups)
}

Expand Down
35 changes: 19 additions & 16 deletions cli/internal/cmds/neosync/sync/ui.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"fmt"
"io"
"os"
"log/slog"
"strings"
syncmap "sync"
"time"
Expand All @@ -18,16 +18,17 @@ import (
_ "github.com/warpstreamlabs/bento/public/components/io"
_ "github.com/warpstreamlabs/bento/public/components/pure"
_ "github.com/warpstreamlabs/bento/public/components/pure/extended"
"github.com/warpstreamlabs/bento/public/service"

"github.com/charmbracelet/bubbles/spinner"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
charmlog "github.com/charmbracelet/log"
)

type model struct {
ctx context.Context
logger *charmlog.Logger
logger *slog.Logger
benv *service.Environment
groupedConfigs [][]*benthosConfigResponse
tableSynced int
index int
Expand All @@ -50,7 +51,7 @@ var (
durationStyle = dotStyle
)

func newModel(ctx context.Context, groupedConfigs [][]*benthosConfigResponse, logger *charmlog.Logger, outputType output.OutputType) *model {
func newModel(ctx context.Context, benv *service.Environment, groupedConfigs [][]*benthosConfigResponse, logger *slog.Logger, outputType output.OutputType) *model {
s := spinner.New()
s.Style = lipgloss.NewStyle().Foreground(lipgloss.Color("63"))
return &model{
Expand All @@ -61,11 +62,12 @@ func newModel(ctx context.Context, groupedConfigs [][]*benthosConfigResponse, lo
totalConfigCount: getConfigCount(groupedConfigs),
logger: logger,
outputType: outputType,
benv: benv,
}
}

func (m *model) Init() tea.Cmd {
return tea.Batch(m.syncConfigs(m.ctx, m.groupedConfigs[m.index]), m.spinner.Tick)
return tea.Batch(m.syncConfigs(m.ctx, m.benv, m.groupedConfigs[m.index]), m.spinner.Tick)
}

func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
Expand All @@ -85,7 +87,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}
if m.totalConfigCount == m.tableSynced {
m.done = true
m.logger.Infof("Done! Completed %d tables.", m.tableSynced)
m.logger.Info(fmt.Sprintf("Done! Completed %d tables.", m.tableSynced))
return m, tea.Sequence(
tea.Println(strings.Join(successStrs, " \n")),
tea.Quit,
Expand All @@ -95,7 +97,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.index++
return m, tea.Batch(
tea.Println(strings.Join(successStrs, " \n")),
m.syncConfigs(m.ctx, m.groupedConfigs[m.index]),
m.syncConfigs(m.ctx, m.benv, m.groupedConfigs[m.index]),
)
case spinner.TickMsg:
var cmd tea.Cmd
Expand Down Expand Up @@ -135,7 +137,7 @@ func (m *model) View() string {

type syncedDataMsg map[string]string

func (m *model) syncConfigs(ctx context.Context, configs []*benthosConfigResponse) tea.Cmd {
func (m *model) syncConfigs(ctx context.Context, benv *service.Environment, configs []*benthosConfigResponse) tea.Cmd {
return func() tea.Msg {
messageMap := syncmap.Map{}
errgrp, errctx := errgroup.WithContext(ctx)
Expand All @@ -144,15 +146,15 @@ func (m *model) syncConfigs(ctx context.Context, configs []*benthosConfigRespons
cfg := cfg
errgrp.Go(func() error {
start := time.Now()
m.logger.Infof("Syncing table %s", cfg.Name)
err := syncData(errctx, cfg, m.logger, m.outputType)
m.logger.Info(fmt.Sprintf("Syncing table %s", cfg.Name))
err := syncData(errctx, benv, cfg, m.logger, m.outputType)
if err != nil {
fmt.Printf("Error syncing table: %s", err.Error()) //nolint:forbidigo
return err
}
duration := time.Since(start)
messageMap.Store(cfg.Name, duration)
m.logger.Infof("Finished syncing table %s %s", cfg.Name, duration.String())
m.logger.Info(fmt.Sprintf("Finished syncing table %s %s", cfg.Name, duration.String()))
return nil
})
}
Expand Down Expand Up @@ -190,19 +192,20 @@ func getConfigCount(groupedConfigs [][]*benthosConfigResponse) int {
return count
}

func runSync(ctx context.Context, outputType output.OutputType, groupedConfigs [][]*benthosConfigResponse, logger *charmlog.Logger) error {
func runSync(ctx context.Context, outputType output.OutputType, benv *service.Environment, groupedConfigs [][]*benthosConfigResponse, logger *slog.Logger) error {
var opts []tea.ProgramOption
var synclogger = logger
if outputType == output.PlainOutput {
// Plain mode don't render the TUI
opts = []tea.ProgramOption{tea.WithoutRenderer(), tea.WithInput(nil)}
} else {
fmt.Println(bold.Render(" \n Completed Tables")) //nolint:forbidigo
// TUI mode, discard log output
logger.SetOutput(io.Discard)
synclogger = slog.New(slog.NewJSONHandler(io.Discard, nil))
}
if _, err := tea.NewProgram(newModel(ctx, groupedConfigs, logger, outputType), opts...).Run(); err != nil {
logger.Error("Error syncing data:", err)
os.Exit(1)
if _, err := tea.NewProgram(newModel(ctx, benv, groupedConfigs, synclogger, outputType), opts...).Run(); err != nil {
logger.Error(fmt.Sprintf("Error syncing data: %v", err))
return fmt.Errorf("unable to finish syncing data: %w", err)
}
return nil
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sync_activity
package pool_mongo_provider

import (
"errors"
Expand All @@ -7,36 +7,36 @@ import (
"sync"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb"
)

type mongoConnectionGetter = func(url string) (neosync_benthos_mongodb.MongoClient, error)
type Getter = func(url string) (neosync_benthos_mongodb.MongoClient, error)

// wrapper used for benthos sql-based connections to retrieve the connection they need
type mongoPoolPovider struct {
getter mongoConnectionGetter
// wrapper used for benthos mongo-based connections to retrieve the connection they need
type Provider struct {
getter Getter
}

var _ neosync_benthos_mongodb.MongoPoolProvider = &mongoPoolPovider{}
var _ neosync_benthos_mongodb.MongoPoolProvider = (*Provider)(nil)

func newMongoPoolProvider(getter mongoConnectionGetter) *mongoPoolPovider {
return &mongoPoolPovider{getter: getter}
func NewProvider(getter Getter) *Provider {
return &Provider{getter: getter}
}

func (p *mongoPoolPovider) GetClient(url string) (neosync_benthos_mongodb.MongoClient, error) {
func (p *Provider) GetClient(url string) (neosync_benthos_mongodb.MongoClient, error) {
return p.getter(url)
}

// Returns a function that converts a raw DSN directly to the relevant pooled sql client.
// Allows sharing connections across activities for effective pooling and SSH tunnel management.
func getMongoPoolProviderGetter(
func GetMongoPoolProviderGetter(
tunnelmanager connectiontunnelmanager.Interface[any],
dsnToConnectionIdMap *sync.Map,
connectionMap map[string]*mgmtv1alpha1.Connection,
session string,
slogger *slog.Logger,
) mongoConnectionGetter {
) Getter {
return func(url string) (neosync_benthos_mongodb.MongoClient, error) {
connid, ok := dsnToConnectionIdMap.Load(url)
if !ok {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sync_activity
package pool_sql_provider

import (
"errors"
Expand All @@ -7,34 +7,36 @@ import (
"sync"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
)

type sqlConnectionGetter = func(dsn string) (neosync_benthos_sql.SqlDbtx, error)
type Getter = func(dsn string) (neosync_benthos_sql.SqlDbtx, error)

// wrapper used for benthos sql-based connections to retrieve the connection they need
type sqlPoolProvider struct {
getter sqlConnectionGetter
type Provider struct {
getter Getter
}

func newSqlPoolProvider(getter sqlConnectionGetter) *sqlPoolProvider {
return &sqlPoolProvider{getter: getter}
var _ neosync_benthos_sql.DbPoolProvider = (*Provider)(nil)

func NewProvider(getter Getter) *Provider {
return &Provider{getter: getter}
}

func (p *sqlPoolProvider) GetDb(driver, dsn string) (neosync_benthos_sql.SqlDbtx, error) {
func (p *Provider) GetDb(driver, dsn string) (neosync_benthos_sql.SqlDbtx, error) {
return p.getter(dsn)
}

// Returns a function that converts a raw DSN directly to the relevant pooled sql client.
// Allows sharing connections across activities for effective pooling and SSH tunnel management.
func getSqlPoolProviderGetter(
func GetSqlPoolProviderGetter(
tunnelmanager connectiontunnelmanager.Interface[any],
dsnToConnectionIdMap *sync.Map,
connectionMap map[string]*mgmtv1alpha1.Connection,
session string,
slogger *slog.Logger,
) sqlConnectionGetter {
) Getter {
return func(dsn string) (neosync_benthos_sql.SqlDbtx, error) {
connid, ok := dsnToConnectionIdMap.Load(dsn)
if !ok {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package sync_activity
package pool_sql_provider

import (
"testing"
Expand All @@ -8,11 +8,11 @@ import (
)

func Test_newPoolProvider(t *testing.T) {
assert.NotNil(t, newSqlPoolProvider(nil))
assert.NotNil(t, NewProvider(nil))
}

func Test_newPoolProvider_GetDb(t *testing.T) {
provider := newSqlPoolProvider(func(dsn string) (neosync_benthos_sql.SqlDbtx, error) {
provider := NewProvider(func(dsn string) (neosync_benthos_sql.SqlDbtx, error) {
return neosync_benthos_sql.NewMockSqlDbtx(t), nil
})
assert.NotNil(t, provider)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"errors"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb"
neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"testing"

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb"
neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
"github.com/stretchr/testify/mock"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (

mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/pkg/sqlconnect"
connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
)

Expand Down
20 changes: 0 additions & 20 deletions worker/pkg/benthos/mongodb/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,6 @@ func outputSpec() *service.ConfigSpec {
return spec
}

// func init() {
// err := service.RegisterBatchOutput(
// "mongodb", outputSpec(),
// func(conf *service.ParsedConfig, mgr *service.Resources) (out service.BatchOutput, batchPol service.BatchPolicy, mif int, err error) {
// if batchPol, err = conf.FieldBatchPolicy(moFieldBatching); err != nil {
// return
// }
// if mif, err = conf.FieldMaxInFlight(); err != nil {
// return
// }
// if out, err = newOutputWriter(conf, mgr); err != nil {
// return
// }
// return
// })
// if err != nil {
// panic(err)
// }
// }

func RegisterPooledMongoDbOutput(env *service.Environment, clientProvider MongoPoolProvider) error {
return env.RegisterBatchOutput(
"pooled_mongodb", outputSpec(),
Expand Down
27 changes: 0 additions & 27 deletions worker/pkg/benthos/sql/output_sql_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"log/slog"
"os"
"strings"
"sync"

Expand Down Expand Up @@ -64,32 +63,6 @@ func RegisterPooledSqlInsertOutput(env *service.Environment, dbprovider DbPoolPr
)
}

func init() {
dbprovider := NewDbPoolProvider()
err := service.RegisterBatchOutput(
"pooled_sql_insert", sqlInsertOutputSpec(),
func(conf *service.ParsedConfig, mgr *service.Resources) (service.BatchOutput, service.BatchPolicy, int, error) {
batchPolicy, err := conf.FieldBatchPolicy("batching")
if err != nil {
return nil, batchPolicy, -1, err
}

maxInFlight, err := conf.FieldInt("max_in_flight")
if err != nil {
return nil, service.BatchPolicy{}, -1, err
}
logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{}))
out, err := newInsertOutput(conf, mgr, dbprovider, false, logger)
if err != nil {
return nil, service.BatchPolicy{}, -1, err
}
return out, batchPolicy, maxInFlight, nil
})
if err != nil {
panic(err)
}
}

var _ service.BatchOutput = &pooledInsertOutput{}

type pooledInsertOutput struct {
Expand Down
Loading

0 comments on commit 23637d2

Please sign in to comment.