diff --git a/backend/pkg/integration-test/integration-test-util.go b/backend/pkg/integration-test/integration-test-util.go
new file mode 100644
index 0000000000..3523116728
--- /dev/null
+++ b/backend/pkg/integration-test/integration-test-util.go
@@ -0,0 +1,73 @@
+package integrationtests_test
+
+import (
+ "context"
+ "testing"
+
+ "connectrpc.com/connect"
+ mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
+ "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
+ "github.com/stretchr/testify/require"
+)
+
+func CreatePersonalAccount(
+ ctx context.Context,
+ t *testing.T,
+ userclient mgmtv1alpha1connect.UserAccountServiceClient,
+) string {
+ resp, err := userclient.SetPersonalAccount(ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{}))
+ RequireNoErrResp(t, resp, err)
+ return resp.Msg.AccountId
+}
+
+func CreatePostgresConnection(
+ ctx context.Context,
+ t *testing.T,
+ connclient mgmtv1alpha1connect.ConnectionServiceClient,
+ accountId string,
+ name string,
+ pgurl string,
+) *mgmtv1alpha1.Connection {
+ resp, err := connclient.CreateConnection(
+ ctx,
+ connect.NewRequest(&mgmtv1alpha1.CreateConnectionRequest{
+ AccountId: accountId,
+ Name: name,
+ ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{
+ Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{
+ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{
+ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{
+ Url: pgurl,
+ },
+ },
+ },
+ },
+ }),
+ )
+ RequireNoErrResp(t, resp, err)
+ return resp.Msg.GetConnection()
+}
+
+func SetUser(ctx context.Context, t *testing.T, client mgmtv1alpha1connect.UserAccountServiceClient) string {
+ resp, err := client.SetUser(ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{}))
+ RequireNoErrResp(t, resp, err)
+ return resp.Msg.GetUserId()
+}
+
+func CreateTeamAccount(ctx context.Context, t *testing.T, client mgmtv1alpha1connect.UserAccountServiceClient, name string) string {
+ resp, err := client.CreateTeamAccount(ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: name}))
+ RequireNoErrResp(t, resp, err)
+ return resp.Msg.AccountId
+}
+
+func RequireNoErrResp[T any](t testing.TB, resp *connect.Response[T], err error) {
+ t.Helper()
+ require.NoError(t, err)
+ require.NotNil(t, resp)
+}
+
+func RequireErrResp[T any](t testing.TB, resp *connect.Response[T], err error) {
+ t.Helper()
+ require.Error(t, err)
+ require.Nil(t, resp)
+}
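Taken together, these helpers let a downstream integration test bootstrap an account and a connection in a few lines. A minimal usage sketch, assuming the package is imported under the tcneosyncapi alias used later in this diff; the test and connection names are illustrative, and WithMigrationsDirectory may be needed depending on where the calling package sits relative to sql/postgresql/schema:

    package example_test

    import (
        "context"
        "testing"

        tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
        "github.com/stretchr/testify/require"
    )

    func Test_AccountAndConnectionBootstrap(t *testing.T) {
        ctx := context.Background()
        // Spins up the Postgres container, wires the services, and runs migrations.
        api, err := tcneosyncapi.NewNeosyncApiTestClient(ctx, t)
        require.NoError(t, err)
        defer func() { _ = api.TearDown(ctx) }()

        accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, api.UnauthdClients.Users)
        conn := tcneosyncapi.CreatePostgresConnection(
            ctx, t, api.UnauthdClients.Connections,
            accountId, "example-conn", api.Pgcontainer.URL,
        )
        require.NotEmpty(t, conn.GetId())
    }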
diff --git a/backend/pkg/integration-test/integration-test.go b/backend/pkg/integration-test/integration-test.go
new file mode 100644
index 0000000000..45d6ff1f57
--- /dev/null
+++ b/backend/pkg/integration-test/integration-test.go
@@ -0,0 +1,450 @@
+package integrationtests_test
+
+import (
+ "context"
+ "io"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "testing"
+
+ "connectrpc.com/connect"
+ db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db"
+ mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql"
+ pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql"
+ "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
+ "github.com/nucleuscloud/neosync/backend/internal/apikey"
+ auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey"
+ auth_client "github.com/nucleuscloud/neosync/backend/internal/auth/client"
+ auth_jwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt"
+ "github.com/nucleuscloud/neosync/backend/internal/authmgmt"
+ auth_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth"
+ neosync_gcp "github.com/nucleuscloud/neosync/backend/internal/gcp"
+ "github.com/nucleuscloud/neosync/backend/internal/neosyncdb"
+ clientmanager "github.com/nucleuscloud/neosync/backend/internal/temporal/client-manager"
+ "github.com/nucleuscloud/neosync/backend/internal/utils"
+ "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect"
+ mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier"
+ "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect"
+ "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
+ v1alpha_anonymizationservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/anonymization-service"
+ v1alpha1_connectiondataservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/connection-data-service"
+ v1alpha1_connectionservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/connection-service"
+ v1alpha1_jobservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/job-service"
+ v1alpha1_transformersservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/transformers-service"
+ v1alpha1_useraccountservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/user-account-service"
+ awsmanager "github.com/nucleuscloud/neosync/internal/aws"
+ "github.com/nucleuscloud/neosync/internal/billing"
+ presidioapi "github.com/nucleuscloud/neosync/internal/ee/presidio"
+ neomigrate "github.com/nucleuscloud/neosync/internal/migrate"
+ promapiv1mock "github.com/nucleuscloud/neosync/internal/mocks/github.com/prometheus/client_golang/api/prometheus/v1"
+ tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres"
+ http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client"
+)
+
+var (
+ validAuthUser = &authmgmt.User{Name: "foo", Email: "bar", Picture: "baz"}
+)
+
+type UnauthdClients struct {
+ Users mgmtv1alpha1connect.UserAccountServiceClient
+ Transformers mgmtv1alpha1connect.TransformersServiceClient
+ Connections mgmtv1alpha1connect.ConnectionServiceClient
+ ConnectionData mgmtv1alpha1connect.ConnectionDataServiceClient
+ Jobs mgmtv1alpha1connect.JobServiceClient
+ Anonymize mgmtv1alpha1connect.AnonymizationServiceClient
+}
+
+type Mocks struct {
+ TemporalClientManager *clientmanager.MockTemporalClientManagerClient
+ Authclient *auth_client.MockInterface
+ Authmanagerclient *authmgmt.MockInterface
+ Prometheusclient *promapiv1mock.MockAPI
+ Billingclient *billing.MockInterface
+ Presidio Presidiomocks
+}
+
+type Presidiomocks struct {
+ Analyzer *presidioapi.MockAnalyzeInterface
+ Anonymizer *presidioapi.MockAnonymizeInterface
+ Entities *presidioapi.MockEntityInterface
+}
+
+type NeosyncApiTestClient struct {
+ NeosyncQuerier db_queries.Querier
+ systemQuerier pg_queries.Querier
+
+ Pgcontainer *tcpostgres.PostgresTestContainer
+ migrationsDir string
+
+ httpsrv *httptest.Server
+
+ UnauthdClients *UnauthdClients
+ NeosyncCloudClients *NeosyncCloudClients
+ AuthdClients *AuthdClients
+
+ Mocks *Mocks
+}
+
+// Option is a functional option for configuring the Neosync API test client.
+type Option func(*NeosyncApiTestClient)
+
+func NewNeosyncApiTestClient(ctx context.Context, t *testing.T, opts ...Option) (*NeosyncApiTestClient, error) {
+ neoApi := &NeosyncApiTestClient{
+ migrationsDir: "../../../../sql/postgresql/schema",
+ }
+ for _, opt := range opts {
+ opt(neoApi)
+ }
+ err := neoApi.Setup(ctx, t)
+ if err != nil {
+ return nil, err
+ }
+ return neoApi, nil
+}
+
+// WithMigrationsDirectory sets the Neosync database migrations directory path.
+func WithMigrationsDirectory(directoryPath string) Option {
+ return func(a *NeosyncApiTestClient) {
+ a.migrationsDir = directoryPath
+ }
+}
+
+type NeosyncCloudClients struct {
+ httpsrv *httptest.Server
+ basepath string
+}
+
+func (s *NeosyncCloudClients) GetUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient {
+ return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath)
+}
+
+func (s *NeosyncCloudClients) GetConnectionClient(authUserId string) mgmtv1alpha1connect.ConnectionServiceClient {
+ return mgmtv1alpha1connect.NewConnectionServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath)
+}
+
+func (s *NeosyncCloudClients) GetAnonymizeClient(authUserId string) mgmtv1alpha1connect.AnonymizationServiceClient {
+ return mgmtv1alpha1connect.NewAnonymizationServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath)
+}
+
+type AuthdClients struct {
+ httpsrv *httptest.Server
+}
+
+func (s *AuthdClients) GetUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient {
+ return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+"/auth")
+}
+
+func (s *AuthdClients) GetConnectionClient(authUserId string) mgmtv1alpha1connect.ConnectionServiceClient {
+ return mgmtv1alpha1connect.NewConnectionServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+"/auth")
+}
+
+func (s *NeosyncApiTestClient) Setup(ctx context.Context, t *testing.T) error {
+ pgcontainer, err := tcpostgres.NewPostgresTestContainer(ctx)
+ if err != nil {
+ return err
+ }
+ s.Pgcontainer = pgcontainer
+ s.NeosyncQuerier = db_queries.New()
+ s.systemQuerier = pg_queries.New()
+
+ s.Mocks = &Mocks{
+ TemporalClientManager: clientmanager.NewMockTemporalClientManagerClient(t),
+ Authclient: auth_client.NewMockInterface(t),
+ Authmanagerclient: authmgmt.NewMockInterface(t),
+ Prometheusclient: promapiv1mock.NewMockAPI(t),
+ Billingclient: billing.NewMockInterface(t),
+ Presidio: Presidiomocks{
+ Analyzer: presidioapi.NewMockAnalyzeInterface(t),
+ Anonymizer: presidioapi.NewMockAnonymizeInterface(t),
+ Entities: presidioapi.NewMockEntityInterface(t),
+ },
+ }
+
+ maxAllowed := int64(10000)
+ unauthdUserService := v1alpha1_useraccountservice.New(
+ &v1alpha1_useraccountservice.Config{IsAuthEnabled: false, IsNeosyncCloud: false, DefaultMaxAllowedRecords: &maxAllowed},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ s.Mocks.TemporalClientManager,
+ s.Mocks.Authclient,
+ s.Mocks.Authmanagerclient,
+ nil,
+ )
+
+ authdUserService := v1alpha1_useraccountservice.New(
+ &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: false},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ s.Mocks.TemporalClientManager,
+ s.Mocks.Authclient,
+ s.Mocks.Authmanagerclient,
+ nil,
+ )
+
+ authdConnectionService := v1alpha1_connectionservice.New(
+ &v1alpha1_connectionservice.Config{},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ authdUserService,
+ &sqlconnect.SqlOpenConnector{},
+ pg_queries.New(),
+ mysql_queries.New(),
+ mssql_queries.New(),
+ mongoconnect.NewConnector(),
+ awsmanager.New(),
+ )
+
+ neoCloudAuthdUserService := v1alpha1_useraccountservice.New(
+ &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ s.Mocks.TemporalClientManager,
+ s.Mocks.Authclient,
+ s.Mocks.Authmanagerclient,
+ s.Mocks.Billingclient,
+ )
+ neoCloudAuthdAnonymizeService := v1alpha_anonymizationservice.New(
+ &v1alpha_anonymizationservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true, IsPresidioEnabled: false},
+ nil,
+ neoCloudAuthdUserService,
+ s.Mocks.Presidio.Analyzer,
+ s.Mocks.Presidio.Anonymizer,
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ )
+
+ neoCloudConnectionService := v1alpha1_connectionservice.New(
+ &v1alpha1_connectionservice.Config{},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ neoCloudAuthdUserService,
+ &sqlconnect.SqlOpenConnector{},
+ pg_queries.New(),
+ mysql_queries.New(),
+ mssql_queries.New(),
+ mongoconnect.NewConnector(),
+ awsmanager.New(),
+ )
+
+ unauthdTransformersService := v1alpha1_transformersservice.New(
+ &v1alpha1_transformersservice.Config{
+ IsPresidioEnabled: true,
+ IsNeosyncCloud: false,
+ },
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ unauthdUserService,
+ s.Mocks.Presidio.Entities,
+ )
+
+ unauthdConnectionsService := v1alpha1_connectionservice.New(
+ &v1alpha1_connectionservice.Config{},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ unauthdUserService,
+ &sqlconnect.SqlOpenConnector{},
+ pg_queries.New(),
+ mysql_queries.New(),
+ mssql_queries.New(),
+ mongoconnect.NewConnector(),
+ awsmanager.New(),
+ )
+
+ unauthdJobsService := v1alpha1_jobservice.New(
+ &v1alpha1_jobservice.Config{},
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ s.Mocks.TemporalClientManager,
+ unauthdConnectionsService,
+ unauthdUserService,
+ sqlmanager.NewSqlManager(
+ &sync.Map{}, pg_queries.New(),
+ &sync.Map{}, mysql_queries.New(),
+ &sync.Map{}, mssql_queries.New(),
+ &sqlconnect.SqlOpenConnector{},
+ ),
+ )
+
+ unauthdConnectionDataService := v1alpha1_connectiondataservice.New(
+ &v1alpha1_connectiondataservice.Config{},
+ unauthdUserService,
+ unauthdConnectionsService,
+ unauthdJobsService,
+ awsmanager.New(),
+ &sqlconnect.SqlOpenConnector{},
+ pg_queries.New(),
+ mysql_queries.New(),
+ mongoconnect.NewConnector(),
+ sqlmanager.NewSqlManager(
+ &sync.Map{}, pg_queries.New(),
+ &sync.Map{}, mysql_queries.New(),
+ &sync.Map{}, mssql_queries.New(),
+ &sqlconnect.SqlOpenConnector{},
+ ),
+ neosync_gcp.NewManager(),
+ )
+
+ var presAnalyzeClient presidioapi.AnalyzeInterface
+ var presAnonClient presidioapi.AnonymizeInterface
+
+ unauthdAnonymizationService := v1alpha_anonymizationservice.New(
+ &v1alpha_anonymizationservice.Config{IsPresidioEnabled: false},
+ nil,
+ unauthdUserService,
+ presAnalyzeClient, presAnonClient,
+ neosyncdb.New(pgcontainer.DB, db_queries.New()),
+ )
+
+ rootmux := http.NewServeMux()
+
+ unauthmux := http.NewServeMux()
+ unauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler(
+ unauthdUserService,
+ ))
+ unauthmux.Handle(mgmtv1alpha1connect.NewTransformersServiceHandler(
+ unauthdTransformersService,
+ ))
+ unauthmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler(
+ unauthdConnectionsService,
+ ))
+ unauthmux.Handle(mgmtv1alpha1connect.NewJobServiceHandler(
+ unauthdJobsService,
+ ))
+ unauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler(
+ unauthdAnonymizationService,
+ ))
+ unauthmux.Handle(mgmtv1alpha1connect.NewConnectionDataServiceHandler(
+ unauthdConnectionDataService,
+ ))
+ rootmux.Handle("/unauth/", http.StripPrefix("/unauth", unauthmux))
+
+ authinterceptors := connect.WithInterceptors(
+ auth_interceptor.NewInterceptor(func(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) {
+ // This will need to be filled out further as the tests grow.
+ authuserid, err := utils.GetBearerTokenFromHeader(header, "Authorization")
+ if err != nil {
+ return nil, err
+ }
+ if apikey.IsValidV1WorkerKey(authuserid) {
+ return auth_apikey.SetTokenData(ctx, &auth_apikey.TokenContextData{
+ RawToken: authuserid,
+ ApiKey: nil,
+ ApiKeyType: apikey.WorkerApiKey,
+ }), nil
+ }
+ return auth_jwt.SetTokenData(ctx, &auth_jwt.TokenContextData{
+ AuthUserId: authuserid,
+ Claims: &auth_jwt.CustomClaims{Email: &validAuthUser.Email},
+ }), nil
+ }),
+ )
+
+ authmux := http.NewServeMux()
+ authmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler(
+ authdUserService,
+ authinterceptors,
+ ))
+ authmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler(
+ authdConnectionService,
+ authinterceptors,
+ ))
+ rootmux.Handle("/auth/", http.StripPrefix("/auth", authmux))
+
+ ncauthmux := http.NewServeMux()
+ ncauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler(
+ neoCloudAuthdUserService,
+ authinterceptors,
+ ))
+ ncauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler(
+ neoCloudAuthdAnonymizeService,
+ authinterceptors,
+ ))
+ ncauthmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler(
+ neoCloudConnectionService,
+ authinterceptors,
+ ))
+ rootmux.Handle("/ncauth/", http.StripPrefix("/ncauth", ncauthmux))
+
+ s.httpsrv = startHTTPServer(t, rootmux)
+
+ s.UnauthdClients = &UnauthdClients{
+ Users: mgmtv1alpha1connect.NewUserAccountServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
+ Transformers: mgmtv1alpha1connect.NewTransformersServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
+ Connections: mgmtv1alpha1connect.NewConnectionServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
+ ConnectionData: mgmtv1alpha1connect.NewConnectionDataServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
+ Jobs: mgmtv1alpha1connect.NewJobServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
+ Anonymize: mgmtv1alpha1connect.NewAnonymizationServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
+ }
+
+ s.AuthdClients = &AuthdClients{
+ httpsrv: s.httpsrv,
+ }
+ s.NeosyncCloudClients = &NeosyncCloudClients{
+ httpsrv: s.httpsrv,
+ basepath: "/ncauth",
+ }
+
+ err = s.InitializeTest(ctx)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *NeosyncApiTestClient) InitializeTest(ctx context.Context) error {
+ discardLogger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
+ err := neomigrate.Up(ctx, s.Pgcontainer.URL, s.migrationsDir, discardLogger)
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *NeosyncApiTestClient) CleanupTest(ctx context.Context) error {
+ // Dropping the schema here because 1) it is more efficient and 2) we have a bad down migration
+ // (_jobs-connection-id-null.down) that breaks due to a null connection_id column.
+ // We should fix that at some point; until then, running this single drop is simpler.
+ _, err := s.Pgcontainer.DB.Exec(ctx, "DROP SCHEMA IF EXISTS neosync_api CASCADE")
+ if err != nil {
+ return err
+ }
+ _, err = s.Pgcontainer.DB.Exec(ctx, "DROP TABLE IF EXISTS public.schema_migrations")
+ if err != nil {
+ return err
+ }
+ return nil
+}
+
+func (s *NeosyncApiTestClient) TearDown(ctx context.Context) error {
+ if s.Pgcontainer != nil {
+ _, err := s.Pgcontainer.DB.Exec(ctx, "DROP SCHEMA IF EXISTS neosync_api CASCADE")
+ if err != nil {
+ return err
+ }
+ _, err = s.Pgcontainer.DB.Exec(ctx, "DROP TABLE IF EXISTS public.schema_migrations")
+ if err != nil {
+ return err
+ }
+ if s.Pgcontainer.DB != nil {
+ s.Pgcontainer.DB.Close()
+ }
+ if s.Pgcontainer.TestContainer != nil {
+ err := s.Pgcontainer.TestContainer.Terminate(ctx)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server {
+ tb.Helper()
+ srv := httptest.NewUnstartedServer(h)
+ srv.EnableHTTP2 = true
+ srv.Start()
+ tb.Cleanup(srv.Close)
+ return srv
+}
+
+func NewTestSqlManagerClient() *sqlmanager.SqlManager {
+ return sqlmanager.NewSqlManager(
+ &sync.Map{}, pg_queries.New(),
+ &sync.Map{}, mysql_queries.New(),
+ &sync.Map{}, mssql_queries.New(),
+ &sqlconnect.SqlOpenConnector{},
+ )
+}
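The rootmux above multiplexes three auth modes onto a single httptest.Server via path prefixes: /unauth (auth disabled), /auth (token interceptor), and /ncauth (Neosync Cloud with billing). A test picks a mode by picking the matching client getter; a short sketch, where the user id is an arbitrary test value that the interceptor trusts verbatim as the bearer token:

    import (
        "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
        tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
    )

    func userClients(api *tcneosyncapi.NeosyncApiTestClient) []mgmtv1alpha1connect.UserAccountServiceClient {
        return []mgmtv1alpha1connect.UserAccountServiceClient{
            api.UnauthdClients.Users,                              // /unauth: auth disabled
            api.AuthdClients.GetUserClient("test-user-id"),        // /auth: bearer token becomes the auth user id
            api.NeosyncCloudClients.GetUserClient("test-user-id"), // /ncauth: Neosync Cloud mode
        }
    }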
diff --git a/backend/pkg/sqlmanager/mysql/integration_test.go b/backend/pkg/sqlmanager/mysql/integration_test.go
index 216596ea40..867ef3b375 100644
--- a/backend/pkg/sqlmanager/mysql/integration_test.go
+++ b/backend/pkg/sqlmanager/mysql/integration_test.go
@@ -2,201 +2,74 @@ package sqlmanager_mysql
import (
"context"
- "database/sql"
"fmt"
"log/slog"
"os"
"testing"
- "time"
_ "github.com/go-sql-driver/mysql"
- "golang.org/x/sync/errgroup"
mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql"
- sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
+ tcmysql "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/mysql"
"github.com/stretchr/testify/suite"
- "github.com/testcontainers/testcontainers-go"
- testmysql "github.com/testcontainers/testcontainers-go/modules/mysql"
- "github.com/testcontainers/testcontainers-go/wait"
)
type IntegrationTestSuite struct {
suite.Suite
- initSql string
- setupSql string
- teardownSql string
-
ctx context.Context
- source *mysqlTestContainer
- target *mysqlTestContainer
-}
-
-type mysqlTestContainer struct {
- pool *sql.DB
- querier mysql_queries.Querier
- container *testmysql.MySQLContainer
- url string
- close func()
-}
-
-type mysqlTest struct {
- source *mysqlTestContainer
- target *mysqlTestContainer
-}
-
-func (s *IntegrationTestSuite) SetupMysql() (*mysqlTest, error) {
- var source *mysqlTestContainer
- var target *mysqlTestContainer
-
- errgrp := errgroup.Group{}
- errgrp.Go(func() error {
- sourcecontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-source")
- if err != nil {
- return err
- }
- source = sourcecontainer
- return nil
- })
-
- errgrp.Go(func() error {
- targetcontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-target")
- if err != nil {
- return err
- }
- target = targetcontainer
- return nil
- })
-
- err := errgrp.Wait()
- if err != nil {
- return nil, err
- }
-
- return &mysqlTest{
- source: source,
- target: target,
- }, nil
-}
-
-func createMysqlTestContainer(
- ctx context.Context,
- database, username, password string,
-) (*mysqlTestContainer, error) {
- container, err := testmysql.Run(ctx,
- "mysql:8.0.36",
- testmysql.WithDatabase(database),
- testmysql.WithUsername(username),
- testmysql.WithPassword(password),
- testcontainers.WithWaitStrategy(
- wait.ForLog("port: 3306 MySQL Community Server").
- WithOccurrence(1).WithStartupTimeout(20*time.Second),
- ),
- )
- if err != nil {
- return nil, err
- }
- connstr, err := container.ConnectionString(ctx, "multiStatements=true&parseTime=true")
- if err != nil {
- panic(err)
- }
- pool, err := sql.Open(sqlmanager_shared.MysqlDriver, connstr)
- if err != nil {
- panic(err)
- }
- containerPort, err := container.MappedPort(ctx, "3306/tcp")
- if err != nil {
- return nil, err
- }
- containerHost, err := container.Host(ctx)
- if err != nil {
- return nil, err
- }
-
- connUrl := fmt.Sprintf("mysql://%s:%s@%s:%s/%s?multiStatements=true&parseTime=true", username, password, containerHost, containerPort.Port(), database)
- return &mysqlTestContainer{
- pool: pool,
- querier: mysql_queries.New(),
- url: connUrl,
- container: container,
- close: func() {
- if pool != nil {
- pool.Close()
- }
- },
- }, nil
+ querier mysql_queries.Querier
+ containers *tcmysql.MysqlTestSyncContainer
}
func (s *IntegrationTestSuite) SetupSuite() {
s.ctx = context.Background()
- m, err := s.SetupMysql()
- if err != nil {
- panic(err)
- }
- s.source = m.source
- s.target = m.target
-
- initSql, err := os.ReadFile("./testdata/init.sql")
- if err != nil {
- panic(err)
- }
- s.initSql = string(initSql)
-
- setupSql, err := os.ReadFile("./testdata/setup.sql")
- if err != nil {
- panic(err)
- }
- s.setupSql = string(setupSql)
-
- teardownSql, err := os.ReadFile("./testdata/teardown.sql")
+ container, err := tcmysql.NewMysqlTestSyncContainer(s.ctx, []tcmysql.Option{}, []tcmysql.Option{})
if err != nil {
panic(err)
}
- s.teardownSql = string(teardownSql)
+ s.containers = container
+ s.querier = mysql_queries.New()
}
// Runs before each test
func (s *IntegrationTestSuite) SetupTest() {
- _, err := s.target.pool.ExecContext(s.ctx, s.initSql)
+ err := s.containers.Target.RunSqlFiles(s.ctx, nil, []string{"testdata/init.sql"})
if err != nil {
panic(err)
}
- _, err = s.source.pool.ExecContext(s.ctx, s.setupSql)
+ err = s.containers.Source.RunSqlFiles(s.ctx, nil, []string{"testdata/setup.sql"})
if err != nil {
panic(err)
}
}
func (s *IntegrationTestSuite) TearDownTest() {
- _, err := s.target.pool.ExecContext(s.ctx, s.teardownSql)
+ err := s.containers.Source.RunSqlFiles(s.ctx, nil, []string{"testdata/teardown.sql"})
if err != nil {
panic(err)
}
- _, err = s.source.pool.ExecContext(s.ctx, s.teardownSql)
+ err = s.containers.Target.RunSqlFiles(s.ctx, nil, []string{"testdata/teardown.sql"})
if err != nil {
panic(err)
}
}
func (s *IntegrationTestSuite) TearDownSuite() {
- if s.source.pool != nil {
- s.source.close()
- }
- if s.target.pool != nil {
- s.target.close()
- }
- if s.source != nil {
- err := s.source.container.Terminate(s.ctx)
- if err != nil {
- panic(err)
+ if s.containers != nil {
+ if s.containers.Source != nil {
+ err := s.containers.Source.TearDown(s.ctx)
+ if err != nil {
+ panic(err)
+ }
}
- }
- if s.target != nil {
- err := s.target.container.Terminate(s.ctx)
- if err != nil {
- panic(err)
+ if s.containers.Target != nil {
+ err := s.containers.Target.TearDown(s.ctx)
+ if err != nil {
+ panic(err)
+ }
}
}
}
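The tcmysql helpers this suite now depends on are not part of this diff; their shape can be inferred from the call sites above. A hedged sketch of the implied surface (only DB, URL, Source, Target, RunSqlFiles, and TearDown are confirmed by usage here; the stub bodies are placeholders, not the real implementation in internal/testutil/testcontainers/mysql):

    package tcmysqlsketch

    import (
        "context"
        "database/sql"
    )

    // MysqlTestContainer mirrors the surface the tests use.
    type MysqlTestContainer struct {
        DB  *sql.DB // open pool against the container
        URL string  // mysql:// connection url
    }

    // RunSqlFiles executes the given .sql files, optionally joined onto folder.
    func (c *MysqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error {
        return nil // real version reads each file and execs it against c.DB
    }

    // TearDown closes the pool and terminates the container.
    func (c *MysqlTestContainer) TearDown(ctx context.Context) error {
        return nil
    }

    // MysqlTestSyncContainer pairs source and target containers for sync tests.
    type MysqlTestSyncContainer struct {
        Source *MysqlTestContainer
        Target *MysqlTestContainer
    }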
diff --git a/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go b/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go
index 754f1e0504..341d30db19 100644
--- a/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go
+++ b/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go
@@ -11,7 +11,7 @@ import (
)
func (s *IntegrationTestSuite) Test_GetTableConstraintsBySchema() {
- manager := MysqlManager{querier: s.source.querier, pool: s.source.pool}
+ manager := MysqlManager{querier: s.querier, pool: s.containers.Source.DB}
expected := &sqlmanager_shared.TableConstraints{
ForeignKeyConstraints: map[string][]*sqlmanager_shared.ForeignConstraint{
@@ -43,7 +43,7 @@ func (s *IntegrationTestSuite) Test_GetTableConstraintsBySchema() {
}
func (s *IntegrationTestSuite) Test_GetSchemaColumnMap() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetSchemaColumnMap(context.Background())
require.NoError(s.T(), err)
@@ -58,7 +58,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaColumnMap() {
}
func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{schema})
require.NoError(s.T(), err)
@@ -114,7 +114,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap() {
}
func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_BasicCircular() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{schema})
require.NoError(s.T(), err)
@@ -155,7 +155,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_BasicCircular()
}
func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_Composite() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{schema})
@@ -175,7 +175,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_Composite() {
}
func (s *IntegrationTestSuite) Test_GetPrimaryKeyConstraintsMap() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{schema})
@@ -193,7 +193,7 @@ func (s *IntegrationTestSuite) Test_GetPrimaryKeyConstraintsMap() {
}
func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{schema})
@@ -207,7 +207,7 @@ func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap() {
}
func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap_Composite() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{schema})
@@ -223,7 +223,7 @@ func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap_Composite() {
}
func (s *IntegrationTestSuite) Test_GetRolePermissionsMap() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetRolePermissionsMap(context.Background())
@@ -242,18 +242,18 @@ func (s *IntegrationTestSuite) Test_GetRolePermissionsMap() {
}
func (s *IntegrationTestSuite) Test_GetCreateTableStatement() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetCreateTableStatement(context.Background(), schema, "users")
require.NoError(s.T(), err)
require.NotEmpty(s.T(), actual)
- _, err = s.target.pool.ExecContext(context.Background(), actual)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), actual)
require.NoError(s.T(), err)
}
func (s *IntegrationTestSuite) Test_GetTableInitStatements() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
actual, err := manager.GetTableInitStatements(
@@ -268,25 +268,25 @@ func (s *IntegrationTestSuite) Test_GetTableInitStatements() {
require.NoError(s.T(), err)
require.NotEmpty(s.T(), actual)
for _, stmt := range actual {
- _, err = s.target.pool.ExecContext(context.Background(), stmt.CreateTableStatement)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt.CreateTableStatement)
require.NoError(s.T(), err)
}
for _, stmt := range actual {
for _, index := range stmt.IndexStatements {
- _, err = s.target.pool.ExecContext(context.Background(), index)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), index)
require.NoError(s.T(), err)
}
}
for _, stmt := range actual {
for _, alter := range stmt.AlterTableStatements {
- _, err = s.target.pool.ExecContext(context.Background(), alter.Statement)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), alter.Statement)
require.NoError(s.T(), err)
}
}
}
func (s *IntegrationTestSuite) Test_Exec() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
err := manager.Exec(context.Background(), fmt.Sprintf("SELECT 1 FROM %s.%s", schema, "users"))
@@ -294,7 +294,7 @@ func (s *IntegrationTestSuite) Test_Exec() {
}
func (s *IntegrationTestSuite) Test_BatchExec() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
stmt := fmt.Sprintf("SELECT 1 FROM %s.%s;", schema, "users")
@@ -303,7 +303,7 @@ func (s *IntegrationTestSuite) Test_BatchExec() {
}
func (s *IntegrationTestSuite) Test_BatchExec_With_Prefix() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
stmt := fmt.Sprintf("SELECT 1 FROM %s.%s;", schema, "users")
@@ -314,7 +314,7 @@ func (s *IntegrationTestSuite) Test_BatchExec_With_Prefix() {
}
func (s *IntegrationTestSuite) Test_GetSchemaInitStatements() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
statements, err := manager.GetSchemaInitStatements(context.Background(), []*sqlmanager_shared.SchemaTable{
@@ -337,29 +337,29 @@ func (s *IntegrationTestSuite) Test_GetSchemaInitStatements() {
lableStmtMap[st.Label] = append(lableStmtMap[st.Label], st.Statements...)
}
for _, stmt := range lableStmtMap["create table"] {
- _, err = s.target.pool.ExecContext(context.Background(), stmt)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt)
require.NoError(s.T(), err)
}
for _, stmt := range lableStmtMap["table triggers"] {
- _, err = s.target.pool.ExecContext(context.Background(), stmt)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt)
require.NoError(s.T(), err)
}
for _, stmt := range lableStmtMap["table index"] {
- _, err = s.target.pool.ExecContext(context.Background(), stmt)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt)
require.NoError(s.T(), err)
}
for _, stmt := range lableStmtMap["non-fk alter table"] {
- _, err = s.target.pool.ExecContext(context.Background(), stmt)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt)
require.NoError(s.T(), err)
}
for _, stmt := range lableStmtMap["fk alter table"] {
- _, err = s.target.pool.ExecContext(context.Background(), stmt)
+ _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt)
require.NoError(s.T(), err)
}
}
func (s *IntegrationTestSuite) Test_GetSchemaInitStatements_customtable() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
statements, err := manager.GetSchemaInitStatements(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: schema, Table: "custom_table"}})
@@ -368,7 +368,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaInitStatements_customtable() {
}
func (s *IntegrationTestSuite) Test_GetSchemaTableTriggers_customtable() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
triggers, err := manager.GetSchemaTableTriggers(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: schema, Table: "employee_log"}})
@@ -377,7 +377,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaTableTriggers_customtable() {
}
func (s *IntegrationTestSuite) Test_GetSchemaTableDataTypes_customtable() {
- manager := NewManager(s.source.querier, s.source.pool, func() {})
+ manager := NewManager(s.querier, s.containers.Source.DB, func() {})
schema := "sqlmanagermysql3"
resp, err := manager.GetSchemaTableDataTypes(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: schema, Table: "custom_table"}})
diff --git a/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go b/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go
index 5cc081f01d..3f8dda54e7 100644
--- a/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go
+++ b/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go
@@ -7,7 +7,6 @@ import (
"os"
"sync"
"testing"
- "time"
"github.com/google/uuid"
mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql"
@@ -16,15 +15,13 @@ import (
"github.com/stretchr/testify/suite"
_ "github.com/go-sql-driver/mysql"
- "github.com/testcontainers/testcontainers-go"
- testmysql "github.com/testcontainers/testcontainers-go/modules/mysql"
- "github.com/testcontainers/testcontainers-go/wait"
+ tcmysql "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/mysql"
)
type MysqlIntegrationTestSuite struct {
suite.Suite
- mysqlcontainer *testmysql.MySQLContainer
+ mysqlcontainer *tcmysql.MysqlTestContainer
ctx context.Context
@@ -34,54 +31,20 @@ type MysqlIntegrationTestSuite struct {
conncfg *mgmtv1alpha1.MysqlConnectionConfig
// mgmt connection
mgmtconn *mgmtv1alpha1.Connection
-
- // dsn format of connection url
- dsn string
}
func (s *MysqlIntegrationTestSuite) SetupSuite() {
s.ctx = context.Background()
- dbname := "testdb"
- user := "root"
- pass := "test-password"
-
- container, err := testmysql.Run(s.ctx,
- "mysql:8.0.36",
- testmysql.WithDatabase(dbname),
- testmysql.WithUsername(user),
- testmysql.WithPassword(pass),
- testcontainers.WithWaitStrategy(
- wait.ForLog("port: 3306 MySQL Community Server").
- WithOccurrence(1).WithStartupTimeout(20*time.Second),
- ),
- )
+ container, err := tcmysql.NewMysqlTestContainer(s.ctx)
if err != nil {
panic(err)
}
-
- connstr, err := container.ConnectionString(s.ctx, "multiStatements=true&&parseTime=true")
- if err != nil {
- panic(err)
- }
- s.dsn = connstr
-
s.mysqlcontainer = container
- containerPort, err := container.MappedPort(s.ctx, "3306/tcp")
- if err != nil {
- panic(err)
- }
- containerHost, err := container.Host(s.ctx)
- if err != nil {
- panic(err)
- }
-
- connUrl := fmt.Sprintf("mysql://%s:%s@%s:%s/%s?multiStatements=true&parseTime=true", user, pass, containerHost, containerPort.Port(), dbname)
-
s.conncfg = &mgmtv1alpha1.MysqlConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{
- Url: connUrl,
+ Url: container.URL,
},
}
s.mgmtconn = &mgmtv1alpha1.Connection{
@@ -106,7 +69,7 @@ func (s *MysqlIntegrationTestSuite) TearDownTest() {
func (s *MysqlIntegrationTestSuite) TearDownSuite() {
if s.mysqlcontainer != nil {
- err := s.mysqlcontainer.Terminate(s.ctx)
+ err := s.mysqlcontainer.TearDown(s.ctx)
if err != nil {
panic(err)
}
@@ -145,7 +108,7 @@ func (s *MysqlIntegrationTestSuite) Test_NewSqlDb() {
func (s *MysqlIntegrationTestSuite) Test_NewSqlDbFromUrl() {
t := s.T()
- conn, err := s.sqlmanager.NewSqlDbFromUrl(s.ctx, "mysql", s.dsn) // NewSqlDbFromUrl requires dsn format
+ conn, err := s.sqlmanager.NewSqlDbFromUrl(s.ctx, "mysql", s.mysqlcontainer.URL)
requireNoConnErr(t, conn, err)
requireValidDatabase(t, s.ctx, conn, "mysql", "SELECT 1")
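One note on the NewSqlDbFromUrl change above: the suite previously kept a separate dsn field because the go-sql-driver DSN and the URL form differ; the refactor passes the container's URL instead. For context, the two shapes side by side (placeholder credentials and port):

    // go-sql-driver DSN, as returned by the testcontainers ConnectionString helper:
    const mysqlDsn = "root:pass@tcp(localhost:3306)/testdb?multiStatements=true&parseTime=true"

    // URL form, as stored on the tcmysql container and in mgmt connection configs:
    const mysqlUrl = "mysql://root:pass@localhost:3306/testdb?multiStatements=true&parseTime=true"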
diff --git a/backend/pkg/sqlmanager/postgres/integration_test.go b/backend/pkg/sqlmanager/postgres/integration_test.go
index 2f082cb4c7..e09f0c8215 100644
--- a/backend/pkg/sqlmanager/postgres/integration_test.go
+++ b/backend/pkg/sqlmanager/postgres/integration_test.go
@@ -7,13 +7,15 @@ import (
"log/slog"
"os"
"testing"
- "time"
pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql"
+ sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
+ tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres"
"github.com/stretchr/testify/suite"
- "github.com/testcontainers/testcontainers-go"
- testpg "github.com/testcontainers/testcontainers-go/modules/postgres"
- "github.com/testcontainers/testcontainers-go/wait"
+)
+
+var (
+ testdataFolder = "testdata"
)
type IntegrationTestSuite struct {
@@ -22,12 +24,9 @@ type IntegrationTestSuite struct {
db *sql.DB
querier pg_queries.Querier
- setupSql string
- teardownSql string
-
ctx context.Context
- pgcontainer *testpg.PostgresContainer
+ pgcontainer *tcpostgres.PostgresTestContainer
schema string
}
@@ -40,36 +39,13 @@ func (s *IntegrationTestSuite) SetupSuite() {
s.ctx = context.Background()
s.schema = "sqlmanagerpostgres@special"
- pgcontainer, err := testpg.Run(
- s.ctx,
- "postgres:15",
- testcontainers.WithWaitStrategy(
- wait.ForLog("database system is ready to accept connections").
- WithOccurrence(2).WithStartupTimeout(5*time.Second),
- ),
- )
+ pgcontainer, err := tcpostgres.NewPostgresTestContainer(s.ctx)
if err != nil {
panic(err)
}
s.pgcontainer = pgcontainer
- connstr, err := pgcontainer.ConnectionString(s.ctx)
- if err != nil {
- panic(err)
- }
-
- setupSql, err := os.ReadFile("./testdata/setup.sql")
- if err != nil {
- panic(err)
- }
- s.setupSql = string(setupSql)
-
- teardownSql, err := os.ReadFile("./testdata/teardown.sql")
- if err != nil {
- panic(err)
- }
- s.teardownSql = string(teardownSql)
- db, err := sql.Open("pgx", connstr)
+ db, err := sql.Open(sqlmanager_shared.PostgresDriver, s.pgcontainer.URL)
if err != nil {
panic(err)
}
@@ -79,14 +55,14 @@ func (s *IntegrationTestSuite) SetupSuite() {
// Runs before each test
func (s *IntegrationTestSuite) SetupTest() {
- _, err := s.db.ExecContext(s.ctx, s.setupSql)
+ err := s.pgcontainer.RunSqlFiles(s.ctx, &testdataFolder, []string{"setup.sql"})
if err != nil {
panic(err)
}
}
func (s *IntegrationTestSuite) TearDownTest() {
- _, err := s.db.ExecContext(s.ctx, s.teardownSql)
+ err := s.pgcontainer.RunSqlFiles(s.ctx, &testdataFolder, []string{"teardown.sql"})
if err != nil {
panic(err)
}
@@ -96,11 +72,9 @@ func (s *IntegrationTestSuite) TearDownSuite() {
if s.db != nil {
s.db.Close()
}
- if s.pgcontainer != nil {
- err := s.pgcontainer.Terminate(s.ctx)
- if err != nil {
- panic(err)
- }
+ err := s.pgcontainer.TearDown(s.ctx)
+ if err != nil {
+ panic(err)
}
}
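RunSqlFiles takes an optional folder alongside the file names; this suite passes the folder once, while the mysql suite earlier inlines the testdata/ prefix. Assuming the folder argument is simply joined with each file name, the two styles are interchangeable:

    import (
        "context"

        tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres"
    )

    func runSetupSql(ctx context.Context, pg *tcpostgres.PostgresTestContainer) {
        folder := "testdata"
        _ = pg.RunSqlFiles(ctx, &folder, []string{"setup.sql"})      // folder + bare file name
        _ = pg.RunSqlFiles(ctx, nil, []string{"testdata/setup.sql"}) // nil folder, prefixed path
    }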
diff --git a/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go b/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go
index 04354c32f9..2aee422bbd 100644
--- a/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go
+++ b/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go
@@ -107,7 +107,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewPooledSqlDb() {
conn, err := s.sqlmanager.NewPooledSqlDb(s.ctx, slog.Default(), s.mgmtconn)
requireNoConnErr(t, conn, err)
- requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1")
+ requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1")
conn.Db.Close()
}
@@ -118,7 +118,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewSqlDb() {
conn, err := s.sqlmanager.NewSqlDb(s.ctx, slog.Default(), s.mgmtconn, &connTimeout)
requireNoConnErr(t, conn, err)
- requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1")
+ requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1")
conn.Db.Close()
}
@@ -127,7 +127,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewSqlDbFromUrl() {
conn, err := s.sqlmanager.NewSqlDbFromUrl(s.ctx, "postgres", s.pgcfg.GetUrl())
requireNoConnErr(t, conn, err)
- requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1")
+ requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1")
conn.Db.Close()
}
@@ -137,7 +137,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewSqlDbFromConnectionConfig() {
conn, err := s.sqlmanager.NewSqlDbFromConnectionConfig(s.ctx, slog.Default(), s.mgmtconn.GetConnectionConfig(), &connTimeout)
requireNoConnErr(t, conn, err)
- requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1")
+ requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1")
conn.Db.Close()
}
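The expected driver name flips from "postgres" to "pgx" because these pools are opened through the pgx stdlib driver (the postgres suite above likewise swapped a literal "pgx" open for sqlmanager_shared.PostgresDriver), and database/sql reports the registered driver name rather than the URL scheme. A minimal illustration of the distinction, using the pgx/v5 stdlib registration the repo already depends on:

    import (
        "database/sql"

        _ "github.com/jackc/pgx/v5/stdlib" // registers the "pgx" database/sql driver
    )

    // The connection URL scheme stays postgres://, but the driver name,
    // which these assertions now check, is "pgx".
    func openPg(url string) (*sql.DB, error) {
        return sql.Open("pgx", url)
    }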
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go
index 37255bcccd..cd841063d5 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go
@@ -17,10 +17,10 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeMany() {
t := s.T()
t.Run("OSS-fail", func(t *testing.T) {
- userclient := s.unauthdClients.users
+ userclient := s.UnauthdClients.Users
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
- resp, err := s.unauthdClients.anonymize.AnonymizeMany(
+ resp, err := s.UnauthdClients.Anonymize.AnonymizeMany(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.AnonymizeManyRequest{
AccountId: accountId,
@@ -35,8 +35,8 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeMany() {
})
t.Run("cloud-personal-fail", func(t *testing.T) {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
- anonclient := s.neosyncCloudClients.getAnonymizeClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
+ anonclient := s.NeosyncCloudClients.GetAnonymizeClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
resp, err := anonclient.AnonymizeMany(
@@ -87,12 +87,12 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeMany() {
}`,
}
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
- anonclient := s.neosyncCloudClients.getAnonymizeClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
+ anonclient := s.NeosyncCloudClients.GetAnonymizeClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createBilledTeamAccount(s.ctx, userclient, "team1", "foo")
- s.mocks.billingclient.On("GetSubscriptions", "foo").Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
+ s.Mocks.Billingclient.On("GetSubscriptions", "foo").Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
{Status: stripe.SubscriptionStatusIncompleteExpired},
{Status: stripe.SubscriptionStatusActive},
}}, nil)
@@ -170,8 +170,8 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle() {
"sports": ["basketball", "golf", "swimming"]
}`
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- resp, err := s.unauthdClients.anonymize.AnonymizeSingle(
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
+ resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{
AccountId: accountId,
@@ -228,11 +228,11 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr
t := s.T()
t.Run("OSS", func(t *testing.T) {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
t.Run("transformpiitext", func(t *testing.T) {
t.Run("mappings", func(t *testing.T) {
- resp, err := s.unauthdClients.anonymize.AnonymizeSingle(
+ resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{
AccountId: accountId,
@@ -252,7 +252,7 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr
t.Run("defaults", func(t *testing.T) {
t.Run("Bool", func(t *testing.T) {
- resp, err := s.unauthdClients.anonymize.AnonymizeSingle(
+ resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{
AccountId: accountId,
@@ -268,7 +268,7 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr
requireConnectError(t, err, connect.CodePermissionDenied)
})
t.Run("S", func(t *testing.T) {
- resp, err := s.unauthdClients.anonymize.AnonymizeSingle(
+ resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{
AccountId: accountId,
@@ -284,7 +284,7 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr
requireConnectError(t, err, connect.CodePermissionDenied)
})
t.Run("N", func(t *testing.T) {
- resp, err := s.unauthdClients.anonymize.AnonymizeSingle(
+ resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{
AccountId: accountId,
@@ -304,8 +304,8 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr
})
t.Run("cloud-personal", func(t *testing.T) {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
- anonclient := s.neosyncCloudClients.getAnonymizeClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
+ anonclient := s.NeosyncCloudClients.GetAnonymizeClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go
index 081339d92f..fee2e210da 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go
@@ -9,9 +9,9 @@ import (
)
func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_Available() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.connections.IsConnectionNameAvailable(
+ resp, err := s.UnauthdClients.Connections.IsConnectionNameAvailable(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.IsConnectionNameAvailableRequest{
AccountId: accountId,
@@ -23,10 +23,10 @@ func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_
}
func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_NotAvailable() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", "test-url")
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
+ s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", "test-url")
- resp, err := s.unauthdClients.connections.IsConnectionNameAvailable(
+ resp, err := s.UnauthdClients.Connections.IsConnectionNameAvailable(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.IsConnectionNameAvailableRequest{
AccountId: accountId,
@@ -39,16 +39,14 @@ func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_
func (s *IntegrationTestSuite) Test_ConnectionService_CheckConnectionConfig() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
t.Run("valid-pg-connstr", func(t *testing.T) {
t.Parallel()
- resp, err := s.unauthdClients.connections.CheckConnectionConfig(
+ resp, err := s.UnauthdClients.Connections.CheckConnectionConfig(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.CheckConnectionConfigRequest{
ConnectionConfig: conn.GetConnectionConfig(),
@@ -62,25 +60,21 @@ func (s *IntegrationTestSuite) Test_ConnectionService_CheckConnectionConfig() {
func (s *IntegrationTestSuite) Test_ConnectionService_CreateConnection() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
t.Run("postgres-success", func(t *testing.T) {
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
- s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
})
}
func (s *IntegrationTestSuite) Test_ConnectionService_UpdateConnection() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
t.Run("postgres-success", func(t *testing.T) {
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
- conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
- resp, err := s.unauthdClients.connections.UpdateConnection(
+ resp, err := s.UnauthdClients.Connections.UpdateConnection(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.UpdateConnectionRequest{
Id: conn.GetId(),
@@ -95,13 +89,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_UpdateConnection() {
func (s *IntegrationTestSuite) Test_ConnectionService_GetConnection() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
- resp, err := s.unauthdClients.connections.GetConnection(
+ resp, err := s.UnauthdClients.Connections.GetConnection(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{
Id: conn.GetId(),
@@ -113,13 +105,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_GetConnection() {
func (s *IntegrationTestSuite) Test_ConnectionService_GetConnections() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
- resp, err := s.unauthdClients.connections.GetConnections(
+ resp, err := s.UnauthdClients.Connections.GetConnections(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.GetConnectionsRequest{
AccountId: accountId,
@@ -131,13 +121,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_GetConnections() {
func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
- resp, err := s.unauthdClients.connections.GetConnections(
+ resp, err := s.UnauthdClients.Connections.GetConnections(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.GetConnectionsRequest{
AccountId: accountId,
@@ -146,7 +134,7 @@ func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() {
requireNoErrResp(t, resp, err)
require.NotEmpty(t, resp.Msg.GetConnections())
- resp2, err := s.unauthdClients.connections.DeleteConnection(
+ resp2, err := s.UnauthdClients.Connections.DeleteConnection(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.DeleteConnectionRequest{
Id: conn.GetId(),
@@ -155,7 +143,7 @@ func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() {
requireNoErrResp(t, resp2, err)
// again to test idempotency
- resp2, err = s.unauthdClients.connections.DeleteConnection(
+ resp2, err = s.UnauthdClients.Connections.DeleteConnection(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.DeleteConnectionRequest{
Id: conn.GetId(),
@@ -166,13 +154,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() {
func (s *IntegrationTestSuite) Test_ConnectionService_CheckSqlQuery() {
t := s.T()
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- require.NoError(t, err)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr)
+ conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL)
- resp, err := s.unauthdClients.connections.CheckSqlQuery(
+ resp, err := s.UnauthdClients.Connections.CheckSqlQuery(
s.ctx,
connect.NewRequest(&mgmtv1alpha1.CheckSqlQueryRequest{
Id: conn.GetId(),
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go
index 3b40117809..dc4c160563 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go
@@ -3,349 +3,49 @@ package integrationtests_test
import (
"context"
"fmt"
- "io"
"log/slog"
- "net/http"
- "net/http/httptest"
"os"
- "sync"
"testing"
- "time"
- "connectrpc.com/connect"
- "github.com/jackc/pgx/v5/pgxpool"
- db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db"
- mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql"
- pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql"
- "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
- "github.com/nucleuscloud/neosync/backend/internal/apikey"
- auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey"
- auth_client "github.com/nucleuscloud/neosync/backend/internal/auth/client"
- auth_jwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt"
- "github.com/nucleuscloud/neosync/backend/internal/authmgmt"
- auth_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth"
- "github.com/nucleuscloud/neosync/backend/internal/neosyncdb"
- clientmanager "github.com/nucleuscloud/neosync/backend/internal/temporal/client-manager"
- "github.com/nucleuscloud/neosync/backend/internal/utils"
- "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect"
- mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier"
- "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect"
- "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager"
- v1alpha_anonymizationservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/anonymization-service"
- v1alpha1_connectionservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/connection-service"
- v1alpha1_jobservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/job-service"
- v1alpha1_transformersservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/transformers-service"
- v1alpha1_useraccountservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/user-account-service"
- awsmanager "github.com/nucleuscloud/neosync/internal/aws"
- "github.com/nucleuscloud/neosync/internal/billing"
- presidioapi "github.com/nucleuscloud/neosync/internal/ee/presidio"
- neomigrate "github.com/nucleuscloud/neosync/internal/migrate"
- promapiv1mock "github.com/nucleuscloud/neosync/internal/mocks/github.com/prometheus/client_golang/api/prometheus/v1"
- http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client"
+ tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
"github.com/stretchr/testify/suite"
- "github.com/testcontainers/testcontainers-go"
- testpg "github.com/testcontainers/testcontainers-go/modules/postgres"
- "github.com/testcontainers/testcontainers-go/wait"
)
-type unauthdClients struct {
- users mgmtv1alpha1connect.UserAccountServiceClient
- transformers mgmtv1alpha1connect.TransformersServiceClient
- connections mgmtv1alpha1connect.ConnectionServiceClient
- jobs mgmtv1alpha1connect.JobServiceClient
- anonymize mgmtv1alpha1connect.AnonymizationServiceClient
-}
-
-type neosyncCloudClients struct {
- httpsrv *httptest.Server
- basepath string
-}
-
-func (s *neosyncCloudClients) getUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient {
- return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath)
-}
-func (s *neosyncCloudClients) getAnonymizeClient(authUserId string) mgmtv1alpha1connect.AnonymizationServiceClient {
- return mgmtv1alpha1connect.NewAnonymizationServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath)
-}
-
-type authdClients struct {
- httpsrv *httptest.Server
-}
-
-func (s *authdClients) getUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient {
- return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+"/auth")
-}
-
-type mocks struct {
- temporalClientManager *clientmanager.MockTemporalClientManagerClient
- authclient *auth_client.MockInterface
- authmanagerclient *authmgmt.MockInterface
- prometheusclient *promapiv1mock.MockAPI
- billingclient *billing.MockInterface
- presidio presidiomocks
-}
-
-type presidiomocks struct {
- analyzer *presidioapi.MockAnalyzeInterface
- anonymizer *presidioapi.MockAnonymizeInterface
- entities *presidioapi.MockEntityInterface
-}
-
type IntegrationTestSuite struct {
suite.Suite
-
- pgpool *pgxpool.Pool
- neosyncQuerier db_queries.Querier
- systemQuerier pg_queries.Querier
-
+ tcneosyncapi.NeosyncApiTestClient
ctx context.Context
-
- pgcontainer *testpg.PostgresContainer
- connstr string
- migrationsDir string
-
- httpsrv *httptest.Server
-
- unauthdClients *unauthdClients
- neosyncCloudClients *neosyncCloudClients
- authdClients *authdClients
-
- mocks *mocks
}
+// TODO: update service integration tests to not use the testify suite
func (s *IntegrationTestSuite) SetupSuite() {
s.ctx = context.Background()
-
- pgcontainer, err := testpg.Run(
- s.ctx,
- "postgres:15",
- testcontainers.WithWaitStrategy(
- wait.ForLog("database system is ready to accept connections").
- WithOccurrence(2).WithStartupTimeout(5*time.Second),
- ),
- )
- if err != nil {
- panic(err)
- }
- s.pgcontainer = pgcontainer
- connstr, err := pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- if err != nil {
- panic(err)
- }
- s.connstr = connstr
-
- pool, err := pgxpool.New(s.ctx, connstr)
+ api, err := tcneosyncapi.NewNeosyncApiTestClient(s.ctx, s.T())
if err != nil {
panic(err)
}
- s.pgpool = pool
- s.neosyncQuerier = db_queries.New()
- s.systemQuerier = pg_queries.New()
- s.migrationsDir = "../../../../sql/postgresql/schema"
-
- s.mocks = &mocks{
- temporalClientManager: clientmanager.NewMockTemporalClientManagerClient(s.T()),
- authclient: auth_client.NewMockInterface(s.T()),
- authmanagerclient: authmgmt.NewMockInterface(s.T()),
- prometheusclient: promapiv1mock.NewMockAPI(s.T()),
- billingclient: billing.NewMockInterface(s.T()),
- presidio: presidiomocks{
- analyzer: presidioapi.NewMockAnalyzeInterface(s.T()),
- anonymizer: presidioapi.NewMockAnonymizeInterface(s.T()),
- entities: presidioapi.NewMockEntityInterface(s.T()),
- },
- }
-
- maxAllowed := int64(10000)
- unauthdUserService := v1alpha1_useraccountservice.New(
- &v1alpha1_useraccountservice.Config{IsAuthEnabled: false, IsNeosyncCloud: false, DefaultMaxAllowedRecords: &maxAllowed},
- neosyncdb.New(pool, db_queries.New()),
- s.mocks.temporalClientManager,
- s.mocks.authclient,
- s.mocks.authmanagerclient,
- nil,
- )
-
- authdUserService := v1alpha1_useraccountservice.New(
- &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: false},
- neosyncdb.New(pool, db_queries.New()),
- s.mocks.temporalClientManager,
- s.mocks.authclient,
- s.mocks.authmanagerclient,
- nil,
- )
-
- neoCloudAuthdUserService := v1alpha1_useraccountservice.New(
- &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true},
- neosyncdb.New(pool, db_queries.New()),
- s.mocks.temporalClientManager,
- s.mocks.authclient,
- s.mocks.authmanagerclient,
- s.mocks.billingclient,
- )
- neoCloudAuthdAnonymizeService := v1alpha_anonymizationservice.New(
- &v1alpha_anonymizationservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true, IsPresidioEnabled: false},
- nil,
- neoCloudAuthdUserService,
- s.mocks.presidio.analyzer,
- s.mocks.presidio.anonymizer,
- neosyncdb.New(pool, db_queries.New()),
- )
-
- unauthdTransformersService := v1alpha1_transformersservice.New(
- &v1alpha1_transformersservice.Config{IsPresidioEnabled: true, IsNeosyncCloud: false},
- neosyncdb.New(pool, db_queries.New()),
- unauthdUserService,
- s.mocks.presidio.entities,
- )
-
- unauthdConnectionsService := v1alpha1_connectionservice.New(
- &v1alpha1_connectionservice.Config{},
- neosyncdb.New(pool, db_queries.New()),
- unauthdUserService,
- &sqlconnect.SqlOpenConnector{},
- pg_queries.New(),
- mysql_queries.New(),
- mssql_queries.New(),
- mongoconnect.NewConnector(),
- awsmanager.New(),
- )
- unauthdJobsService := v1alpha1_jobservice.New(
- &v1alpha1_jobservice.Config{},
- neosyncdb.New(pool, db_queries.New()),
- s.mocks.temporalClientManager,
- unauthdConnectionsService,
- unauthdUserService,
- sqlmanager.NewSqlManager(
- &sync.Map{}, pg_queries.New(),
- &sync.Map{}, mysql_queries.New(),
- &sync.Map{}, mssql_queries.New(),
- &sqlconnect.SqlOpenConnector{},
- ),
- )
-
- var presAnalyzeClient presidioapi.AnalyzeInterface
- var presAnonClient presidioapi.AnonymizeInterface
-
- unauthdAnonymizationService := v1alpha_anonymizationservice.New(
- &v1alpha_anonymizationservice.Config{IsPresidioEnabled: false},
- nil,
- unauthdUserService,
- presAnalyzeClient, presAnonClient,
- neosyncdb.New(pool, db_queries.New()),
- )
-
- rootmux := http.NewServeMux()
-
- unauthmux := http.NewServeMux()
- unauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler(
- unauthdUserService,
- ))
- unauthmux.Handle(mgmtv1alpha1connect.NewTransformersServiceHandler(
- unauthdTransformersService,
- ))
- unauthmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler(
- unauthdConnectionsService,
- ))
- unauthmux.Handle(mgmtv1alpha1connect.NewJobServiceHandler(
- unauthdJobsService,
- ))
-
- unauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler(
- unauthdAnonymizationService,
- ))
- rootmux.Handle("/unauth/", http.StripPrefix("/unauth", unauthmux))
-
- authinterceptors := connect.WithInterceptors(
- auth_interceptor.NewInterceptor(func(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) {
- // will need to further fill this out as the tests grow
- authuserid, err := utils.GetBearerTokenFromHeader(header, "Authorization")
- if err != nil {
- return nil, err
- }
- if apikey.IsValidV1WorkerKey(authuserid) {
- return auth_apikey.SetTokenData(ctx, &auth_apikey.TokenContextData{
- RawToken: authuserid,
- ApiKey: nil,
- ApiKeyType: apikey.WorkerApiKey,
- }), nil
- }
- return auth_jwt.SetTokenData(ctx, &auth_jwt.TokenContextData{
- AuthUserId: authuserid,
- Claims: &auth_jwt.CustomClaims{Email: &validAuthUser.Email},
- }), nil
- }),
- )
-
- authmux := http.NewServeMux()
- authmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler(
- authdUserService,
- authinterceptors,
- ))
- rootmux.Handle("/auth/", http.StripPrefix("/auth", authmux))
-
- ncauthmux := http.NewServeMux()
- ncauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler(
- neoCloudAuthdUserService,
- authinterceptors,
- ))
- ncauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler(
- neoCloudAuthdAnonymizeService,
- authinterceptors,
- ))
- rootmux.Handle("/ncauth/", http.StripPrefix("/ncauth", ncauthmux))
-
- s.httpsrv = startHTTPServer(s.T(), rootmux)
-
- s.unauthdClients = &unauthdClients{
- users: mgmtv1alpha1connect.NewUserAccountServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
- transformers: mgmtv1alpha1connect.NewTransformersServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
- connections: mgmtv1alpha1connect.NewConnectionServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
- jobs: mgmtv1alpha1connect.NewJobServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
- anonymize: mgmtv1alpha1connect.NewAnonymizationServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"),
- }
-
- s.authdClients = &authdClients{
- httpsrv: s.httpsrv,
- }
- s.neosyncCloudClients = &neosyncCloudClients{
- httpsrv: s.httpsrv,
- basepath: "/ncauth",
- }
+ s.NeosyncApiTestClient = *api
}
// Runs before each test
func (s *IntegrationTestSuite) SetupTest() {
- discardLogger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
- err := neomigrate.Up(s.ctx, s.connstr, s.migrationsDir, discardLogger)
+ err := s.InitializeTest(s.ctx)
if err != nil {
panic(err)
}
}
func (s *IntegrationTestSuite) TearDownTest() {
- // Dropping here because 1) more efficient and 2) we have a bad down migration
- // _jobs-connection-id-null.down that breaks due to having a null connection_id column.
- // we should do something about that at some point. Running this single drop is easier though
- _, err := s.pgpool.Exec(s.ctx, "DROP SCHEMA IF EXISTS neosync_api CASCADE")
- if err != nil {
- panic(err)
- }
- _, err = s.pgpool.Exec(s.ctx, "DROP TABLE IF EXISTS public.schema_migrations")
+ err := s.CleanupTest(s.ctx)
if err != nil {
panic(err)
}
}
func (s *IntegrationTestSuite) TearDownSuite() {
- if s.pgpool != nil {
- s.pgpool.Close()
- }
- if s.pgcontainer != nil {
- err := s.pgcontainer.Terminate(s.ctx)
- if err != nil {
- panic(err)
- }
+ err := s.TearDown(s.ctx)
+ if err != nil {
+ panic(err)
}
}
@@ -358,12 +58,3 @@ func TestIntegrationTestSuite(t *testing.T) {
}
suite.Run(t, new(IntegrationTestSuite))
}
-
-func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server {
- tb.Helper()
- srv := httptest.NewUnstartedServer(h)
- srv.EnableHTTP2 = true
- srv.Start()
- tb.Cleanup(srv.Close)
- return srv
-}
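Apropos the TODO in SetupSuite above: a minimal sketch of what dropping the testify suite could look like, driving the shared tcneosyncapi client from a plain testing.T test. The constructor, lifecycle methods, and exported fields are the ones introduced in this diff; the test file and body are illustrative only, not part of this PR.

// Hypothetical example file; not part of this PR.
package integrationtests_test

import (
	"context"
	"testing"

	tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
	"github.com/stretchr/testify/require"
)

func TestSetPersonalAccount_NoSuite(t *testing.T) {
	ctx := context.Background()
	// Stands up the test environment (replaces the removed SetupSuite body above).
	api, err := tcneosyncapi.NewNeosyncApiTestClient(ctx, t)
	require.NoError(t, err)
	t.Cleanup(func() {
		// Mirrors TearDownSuite.
		if err := api.TearDown(ctx); err != nil {
			t.Error(err)
		}
	})
	// Runs migrations for this test, mirroring SetupTest.
	require.NoError(t, api.InitializeTest(ctx))

	accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, api.UnauthdClients.Users)
	require.NotEmpty(t, accountId)
}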
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go b/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go
index 7fd0720b36..30a04fe718 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go
@@ -12,6 +12,7 @@ import (
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
"github.com/nucleuscloud/neosync/backend/internal/neosyncdb"
+ tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
"github.com/stretchr/testify/require"
)
@@ -20,9 +21,7 @@ func (s *IntegrationTestSuite) createPersonalAccount(
userclient mgmtv1alpha1connect.UserAccountServiceClient,
) string {
s.T().Helper()
- resp, err := userclient.SetPersonalAccount(ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{}))
- requireNoErrResp(s.T(), resp, err)
- return resp.Msg.AccountId
+ return tcneosyncapi.CreatePersonalAccount(ctx, s.T(), userclient)
}
func requireNoErrResp[T any](t testing.TB, resp *connect.Response[T], err error) {
@@ -53,7 +52,7 @@ func (s *IntegrationTestSuite) setAccountCreatedAt(
if err != nil {
return err
}
- _, err = s.neosyncQuerier.SetAccountCreatedAt(ctx, s.pgpool, db_queries.SetAccountCreatedAtParams{
+ _, err = s.NeosyncQuerier.SetAccountCreatedAt(ctx, s.Pgcontainer.DB, db_queries.SetAccountCreatedAtParams{
CreatedAt: pgtype.Timestamp{Time: createdAt, Valid: true},
AccountId: accountUuid,
})
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go
index 339dc52ada..53c41e1320 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go
@@ -11,9 +11,9 @@ import (
)
func (s *IntegrationTestSuite) Test_GetJobs_Empty() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.jobs.GetJobs(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetJobsRequest{
+ resp, err := s.UnauthdClients.Jobs.GetJobs(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetJobsRequest{
AccountId: accountId,
}))
requireNoErrResp(s.T(), resp, err)
@@ -22,23 +22,23 @@ func (s *IntegrationTestSuite) Test_GetJobs_Empty() {
}
func (s *IntegrationTestSuite) Test_CreateJob_Ok() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- srcconn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "source", "test")
- destconn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "dest", "test2")
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
+ srcconn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "source", "test")
+ destconn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "dest", "test2")
mockScheduleClient := temporalmocks.NewScheduleClient(s.T())
mockScheduleHandle := temporalmocks.NewScheduleHandle(s.T())
- s.mocks.temporalClientManager.
+ s.Mocks.TemporalClientManager.
On(
"DoesAccountHaveTemporalWorkspace", mock.Anything, mock.Anything, mock.Anything,
).
Return(true, nil).
Once()
- s.mocks.temporalClientManager.
+ s.Mocks.TemporalClientManager.
On("GetScheduleClientByAccount", mock.Anything, mock.Anything, mock.Anything).
Return(mockScheduleClient, nil).
Once()
- s.mocks.temporalClientManager.
+ s.Mocks.TemporalClientManager.
On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
Return(&pg_models.TemporalConfig{}, nil).
Once()
@@ -51,7 +51,7 @@ func (s *IntegrationTestSuite) Test_CreateJob_Ok() {
Return("test-id").
Once()
- resp, err := s.unauthdClients.jobs.CreateJob(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateJobRequest{
+ resp, err := s.UnauthdClients.Jobs.CreateJob(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateJobRequest{
AccountId: accountId,
JobName: "test",
Mappings: []*mgmtv1alpha1.JobMapping{},
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go
index 7e55874a44..65fdd83285 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go
@@ -11,7 +11,7 @@ import (
)
func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformers() {
- resp, err := s.unauthdClients.transformers.GetSystemTransformers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformersRequest{}))
+ resp, err := s.UnauthdClients.Transformers.GetSystemTransformers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformersRequest{}))
requireNoErrResp(s.T(), resp, err)
require.NotEmpty(s.T(), resp.Msg.GetTransformers())
}
@@ -19,7 +19,7 @@ func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformers()
func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformersBySource() {
t := s.T()
t.Run("ok", func(t *testing.T) {
- resp, err := s.unauthdClients.transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{
+ resp, err := s.UnauthdClients.Transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{
Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BOOL,
}))
requireNoErrResp(t, resp, err)
@@ -28,7 +28,7 @@ func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformersByS
require.Equal(t, transformer.GetSource(), mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BOOL)
})
t.Run("not_found", func(t *testing.T) {
- resp, err := s.unauthdClients.transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{
+ resp, err := s.UnauthdClients.Transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{
Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_UNSPECIFIED,
}))
requireErrResp(t, resp, err)
@@ -41,14 +41,14 @@ func (s *IntegrationTestSuite) Test_TransformersService_GetTransformPiiRecognize
t.Run("ok", func(t *testing.T) {
allowed := []string{"foo", "bar"}
- s.mocks.presidio.entities.On("GetSupportedentitiesWithResponse", mock.Anything, mock.Anything).
+ s.Mocks.Presidio.Entities.On("GetSupportedentitiesWithResponse", mock.Anything, mock.Anything).
Once().
Return(&presidioapi.GetSupportedentitiesResponse{
JSON200: &allowed,
}, nil)
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- resp, err := s.unauthdClients.transformers.GetTransformPiiEntities(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTransformPiiEntitiesRequest{
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
+ resp, err := s.UnauthdClients.Transformers.GetTransformPiiEntities(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTransformPiiEntitiesRequest{
AccountId: accountId,
}))
requireNoErrResp(t, resp, err)
diff --git a/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go
index bfc52a801e..3d50cf4e1d 100644
--- a/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go
+++ b/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go
@@ -27,9 +27,9 @@ var (
)
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountOnboardingConfig() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: accountId}))
+ resp, err := s.UnauthdClients.Users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: accountId}))
requireNoErrResp(s.T(), resp, err)
onboardingConfig := resp.Msg.GetConfig()
require.NotNil(s.T(), onboardingConfig)
@@ -38,21 +38,21 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountOnboardingConfi
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountOnboardingConfig_NoAccount() {
- resp, err := s.unauthdClients.users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: uuid.NewString()}))
+ resp, err := s.UnauthdClients.Users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: uuid.NewString()}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfig_NoAccount() {
- resp, err := s.unauthdClients.users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: uuid.NewString(), Config: &mgmtv1alpha1.AccountOnboardingConfig{}}))
+ resp, err := s.UnauthdClients.Users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: uuid.NewString(), Config: &mgmtv1alpha1.AccountOnboardingConfig{}}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfig_NoConfig() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: accountId, Config: nil}))
+ resp, err := s.UnauthdClients.Users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: accountId, Config: nil}))
requireNoErrResp(s.T(), resp, err)
onboardingConfig := resp.Msg.GetConfig()
require.NotNil(s.T(), onboardingConfig)
@@ -61,9 +61,9 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfi
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfig() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{
+ resp, err := s.UnauthdClients.Users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{
AccountId: accountId, Config: &mgmtv1alpha1.AccountOnboardingConfig{
HasCompletedOnboarding: true,
}},
@@ -85,12 +85,12 @@ var (
)
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- s.mocks.temporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
+ s.Mocks.TemporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
Return(validTemporalConfigModel, nil)
- resp, err := s.unauthdClients.users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: accountId}))
+ resp, err := s.UnauthdClients.Users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: accountId}))
requireNoErrResp(s.T(), resp, err)
tc := resp.Msg.GetConfig()
@@ -102,13 +102,13 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig(
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig_NoAccount() {
- resp, err := s.unauthdClients.users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: uuid.NewString()}))
+ resp, err := s.UnauthdClients.Users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: uuid.NewString()}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig_NeosyncCloud() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
@@ -118,13 +118,13 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig_
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_NoAccount() {
- resp, err := s.unauthdClients.users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: uuid.NewString()}))
+ resp, err := s.UnauthdClients.Users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: uuid.NewString()}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_NeosyncCloud() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
@@ -134,12 +134,12 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_NoConfig() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- s.mocks.temporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
+ s.Mocks.TemporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
Return(validTemporalConfigModel, nil)
- resp, err := s.unauthdClients.users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: accountId, Config: nil}))
+ resp, err := s.UnauthdClients.Users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: accountId, Config: nil}))
requireNoErrResp(s.T(), resp, err)
tc := resp.Msg.GetConfig()
@@ -151,13 +151,13 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
// kind of a bad test since we are mocking this client wholesale, but it at least verifies we can write the config
- s.mocks.temporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
+ s.Mocks.TemporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything).
Return(validTemporalConfigModel, nil)
- resp, err := s.unauthdClients.users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{
+ resp, err := s.UnauthdClients.Users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{
AccountId: accountId, Config: &mgmtv1alpha1.AccountTemporalConfig{
Url: "test",
Namespace: "test",
@@ -174,7 +174,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig(
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetUser_Auth() {
- client := s.authdClients.getUserClient("test-user1")
+ client := s.AuthdClients.GetUserClient("test-user1")
userId := s.setUser(s.ctx, client)
resp, err := client.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
@@ -183,21 +183,21 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetUser_Auth() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetUser_Auth_NotFound() {
- client := s.authdClients.getUserClient(testAuthUserId)
+ client := s.AuthdClients.GetUserClient(testAuthUserId)
resp, err := client.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodeNotFound)
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetUser_Auth() {
- client := s.authdClients.getUserClient(testAuthUserId)
+ client := s.AuthdClients.GetUserClient(testAuthUserId)
userId := s.setUser(s.ctx, client)
require.NotEmpty(s.T(), userId)
require.NotEqual(s.T(), "00000000-0000-0000-0000-000000000000", userId)
}
func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_Auth() {
- client := s.authdClients.getUserClient(testAuthUserId)
+ client := s.AuthdClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, client)
resp, err := client.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"}))
@@ -206,12 +206,12 @@ func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_Auth()
}
func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_NeosyncCloud() {
- client := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ client := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, client)
- s.mocks.billingclient.On("NewCustomer", mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once().
Return(&stripe.Customer{ID: "test-stripe-id"}, nil)
- s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
Return(&stripe.CheckoutSession{URL: "test-url"}, nil)
resp, err := client.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"}))
@@ -221,11 +221,11 @@ func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_Neosync
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetTeamAccountMembers_Auth() {
- client := s.authdClients.getUserClient(testAuthUserId)
+ client := s.AuthdClients.GetUserClient(testAuthUserId)
userId := s.setUser(s.ctx, client)
accountId := s.createTeamAccount(s.ctx, client, "test-team")
- s.mocks.authmanagerclient.On("GetUserBySub", mock.Anything, testAuthUserId).
+ s.Mocks.Authmanagerclient.On("GetUserBySub", mock.Anything, testAuthUserId).
Return(validAuthUser, nil)
resp, err := client.GetTeamAccountMembers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountMembersRequest{AccountId: accountId}))
@@ -255,7 +255,7 @@ func (s *IntegrationTestSuite) createTeamAccount(ctx context.Context, client mgm
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetUser() {
- resp, err := s.unauthdClients.users.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
+ resp, err := s.UnauthdClients.Users.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
requireNoErrResp(s.T(), resp, err)
userId := resp.Msg.GetUserId()
@@ -263,7 +263,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetUser() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetUser() {
- resp, err := s.unauthdClients.users.SetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{}))
+ resp, err := s.UnauthdClients.Users.SetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{}))
requireNoErrResp(s.T(), resp, err)
userId := resp.Msg.UserId
@@ -272,7 +272,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetUser() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_Empty() {
- resp, err := s.unauthdClients.users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{}))
+ resp, err := s.UnauthdClients.Users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{}))
requireNoErrResp(s.T(), resp, err)
accounts := resp.Msg.GetAccounts()
@@ -280,7 +280,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_Empty() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_SetPersonalAccount() {
- resp, err := s.unauthdClients.users.SetPersonalAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{}))
+ resp, err := s.UnauthdClients.Users.SetPersonalAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{}))
requireNoErrResp(s.T(), resp, err)
accountId := resp.Msg.GetAccountId()
@@ -288,9 +288,9 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetPersonalAccount() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_NotEmpty() {
- s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- accResp, err := s.unauthdClients.users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{}))
+ accResp, err := s.UnauthdClients.Users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{}))
requireNoErrResp(s.T(), accResp, err)
accounts := accResp.Msg.GetAccounts()
@@ -299,15 +299,15 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_NotEmpty() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_IsUserInAccount() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{
+ resp, err := s.UnauthdClients.Users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{
AccountId: accountId,
}))
requireNoErrResp(s.T(), resp, err)
require.True(s.T(), resp.Msg.GetOk())
- resp, err = s.unauthdClients.users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{
+ resp, err = s.UnauthdClients.Users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{
AccountId: uuid.NewString(),
}))
requireNoErrResp(s.T(), resp, err)
@@ -315,63 +315,63 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsUserInAccount() {
}
func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_NoAuth() {
- s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"}))
+ resp, err := s.UnauthdClients.Users.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetTeamAccountMembers_NoAuth_Personal() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.GetTeamAccountMembers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountMembersRequest{AccountId: accountId}))
+ resp, err := s.UnauthdClients.Users.GetTeamAccountMembers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountMembersRequest{AccountId: accountId}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_RemoveTeamAccountMember_NoAuth_Personal() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.RemoveTeamAccountMember(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountMemberRequest{AccountId: accountId, UserId: uuid.NewString()}))
+ resp, err := s.UnauthdClients.Users.RemoveTeamAccountMember(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountMemberRequest{AccountId: accountId, UserId: uuid.NewString()}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_InviteUserToTeamAccount_NoAuth_Personal() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.InviteUserToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.InviteUserToTeamAccountRequest{AccountId: accountId, Email: "test@example.com"}))
+ resp, err := s.UnauthdClients.Users.InviteUserToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.InviteUserToTeamAccountRequest{AccountId: accountId, Email: "test@example.com"}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetTeamAccountInvites_NoAuth_Personal() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.GetTeamAccountInvites(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountInvitesRequest{AccountId: accountId}))
+ resp, err := s.UnauthdClients.Users.GetTeamAccountInvites(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountInvitesRequest{AccountId: accountId}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodePermissionDenied)
}
func (s *IntegrationTestSuite) Test_UserAccountService_RemoveTeamAccountInvite_NoAuth_Personal() {
- resp, err := s.unauthdClients.users.RemoveTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountInviteRequest{Id: uuid.NewString()}))
+ resp, err := s.UnauthdClients.Users.RemoveTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountInviteRequest{Id: uuid.NewString()}))
requireNoErrResp(s.T(), resp, err)
}
func (s *IntegrationTestSuite) Test_UserAccountService_AcceptTeamAccountInvite_NoAuth_Personal() {
- resp, err := s.unauthdClients.users.AcceptTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.AcceptTeamAccountInviteRequest{Token: uuid.NewString()}))
+ resp, err := s.UnauthdClients.Users.AcceptTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.AcceptTeamAccountInviteRequest{Token: uuid.NewString()}))
requireErrResp(s.T(), resp, err)
requireConnectError(s.T(), err, connect.CodeUnauthenticated)
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetSystemInformation() {
- resp, err := s.unauthdClients.users.GetSystemInformation(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemInformationRequest{}))
+ resp, err := s.UnauthdClients.Users.GetSystemInformation(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemInformationRequest{}))
requireNoErrResp(s.T(), resp, err)
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncCloud_Personal() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
@@ -401,7 +401,7 @@ func (t *testSubscriptionIter) Err() error {
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncCloud_Billed() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
t := s.T()
@@ -409,7 +409,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncC
t.Run("active_sub", func(t *testing.T) {
custId := "cust_id1"
accountId := s.createBilledTeamAccount(s.ctx, userclient, "test-team", custId)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
{Status: stripe.SubscriptionStatusIncompleteExpired},
{Status: stripe.SubscriptionStatusActive},
}}, nil)
@@ -428,7 +428,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncC
err := s.setAccountCreatedAt(s.ctx, accountId, time.Now().UTC().Add(-30*24*time.Hour))
assert.NoError(s.T(), err)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
{Status: stripe.SubscriptionStatusIncompleteExpired},
{Status: stripe.SubscriptionStatusIncompleteExpired},
}}, nil)
@@ -443,9 +443,9 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncC
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_OSS_Personal() {
- accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
- resp, err := s.unauthdClients.users.GetAccountStatus(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountStatusRequest{
+ resp, err := s.UnauthdClients.Users.GetAccountStatus(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountStatusRequest{
AccountId: accountId,
}))
requireNoErrResp(s.T(), resp, err)
@@ -454,7 +454,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_OSS_Pers
}
func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_NeosyncCloud_Personal() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
@@ -469,14 +469,14 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos
}
func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_NeosyncCloud_Billed() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
t := s.T()
t.Run("active", func(t *testing.T) {
custId := "cust_id1"
accountId := s.createBilledTeamAccount(s.ctx, userclient, "test1", custId)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
{Status: stripe.SubscriptionStatusActive},
}}, nil)
resp, err := userclient.IsAccountStatusValid(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{
@@ -494,7 +494,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos
accountId := s.createBilledTeamAccount(s.ctx, userclient, "test2", custId)
err := s.setAccountCreatedAt(s.ctx, accountId, time.Now().UTC().Add(-30*24*time.Hour))
assert.NoError(t, err)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
{Status: stripe.SubscriptionStatusIncompleteExpired},
}}, nil)
@@ -511,7 +511,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos
t.Run("no_subs_active_trial", func(t *testing.T) {
custId := "cust_id3"
accountId := s.createBilledTeamAccount(s.ctx, userclient, "test3", custId)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil)
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil)
resp, err := userclient.IsAccountStatusValid(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{
AccountId: accountId,
@@ -527,7 +527,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos
accountId := s.createBilledTeamAccount(s.ctx, userclient, "test4", custId)
err := s.setAccountCreatedAt(s.ctx, accountId, time.Now().UTC().Add(-30*24*time.Hour))
assert.NoError(t, err)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil)
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil)
resp, err := userclient.IsAccountStatusValid(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{
AccountId: accountId,
@@ -542,7 +542,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos
t.Run("no_active_subs_active_trial", func(t *testing.T) {
custId := "cust_id5"
accountId := s.createBilledTeamAccount(s.ctx, userclient, "test5", custId)
- s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
+ s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{
{Status: stripe.SubscriptionStatusIncompleteExpired},
}}, nil)
@@ -558,7 +558,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckoutSession() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
t := s.T()
@@ -566,7 +566,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckout
t.Run("billed account - allowed", func(t *testing.T) {
teamAccountId := s.createBilledTeamAccount(s.ctx, userclient, "test-team", "test-stripe-id")
- s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
Return(&stripe.CheckoutSession{URL: "new-test-url"}, nil)
resp, err := userclient.GetAccountBillingCheckoutSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingCheckoutSessionRequest{
@@ -585,7 +585,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckout
})
t.Run("non-neosynccloud - disallowed", func(t *testing.T) {
- personalAccountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
+ personalAccountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
resp, err := userclient.GetAccountBillingCheckoutSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingCheckoutSessionRequest{
AccountId: personalAccountId,
}))
@@ -594,7 +594,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckout
}
func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSession() {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
t := s.T()
@@ -602,7 +602,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSe
t.Run("billed account - allowed", func(t *testing.T) {
teamAccountId := s.createBilledTeamAccount(s.ctx, userclient, "test-team", "test-stripe-id")
- s.mocks.billingclient.On("NewBillingPortalSession", mock.Anything, mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewBillingPortalSession", mock.Anything, mock.Anything).Once().
Return(&stripe.BillingPortalSession{URL: "new-test-url"}, nil)
resp, err := userclient.GetAccountBillingPortalSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingPortalSessionRequest{
@@ -621,8 +621,8 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSe
})
t.Run("non-neosynccloud - disallowed", func(t *testing.T) {
- personalAccountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users)
- resp, err := s.unauthdClients.users.GetAccountBillingPortalSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingPortalSessionRequest{
+ personalAccountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users)
+ resp, err := s.UnauthdClients.Users.GetAccountBillingPortalSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingPortalSessionRequest{
AccountId: personalAccountId,
}))
requireErrResp(s.T(), resp, err)
@@ -630,22 +630,22 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSe
}
func (s *IntegrationTestSuite) createBilledTeamAccount(ctx context.Context, client mgmtv1alpha1connect.UserAccountServiceClient, name, stripeCustomerId string) string {
- s.mocks.billingclient.On("NewCustomer", mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once().
Return(&stripe.Customer{ID: stripeCustomerId}, nil)
- s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
Return(&stripe.CheckoutSession{URL: "test-url"}, nil)
return s.createTeamAccount(ctx, client, name)
}
func (s *IntegrationTestSuite) Test_GetBillingAccounts() {
- userclient1 := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient1 := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient1)
- userclient2 := s.neosyncCloudClients.getUserClient(testAuthUserId2)
+ userclient2 := s.NeosyncCloudClients.GetUserClient(testAuthUserId2)
s.setUser(s.ctx, userclient2)
workerapikey := apikey.NewV1WorkerKey()
- workeruserclient := s.neosyncCloudClients.getUserClient(workerapikey)
+ workeruserclient := s.NeosyncCloudClients.GetUserClient(workerapikey)
t := s.T()
@@ -691,8 +691,8 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() {
t := s.T()
t.Run("OSS unauth", func(t *testing.T) {
- s.setUser(s.ctx, s.unauthdClients.users)
- resp, err := s.unauthdClients.users.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{
+ s.setUser(s.ctx, s.UnauthdClients.Users)
+ resp, err := s.UnauthdClients.Users.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{
Name: "unauthteamname",
}))
requireErrResp(t, resp, err)
@@ -700,7 +700,7 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() {
})
t.Run("OSS auth success", func(t *testing.T) {
- userclient := s.authdClients.getUserClient(testAuthUserId)
+ userclient := s.AuthdClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
@@ -713,14 +713,14 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() {
})
t.Run("cloud billing success", func(t *testing.T) {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
accountId := s.createPersonalAccount(s.ctx, userclient)
stripeCustomerId := "foo"
- s.mocks.billingclient.On("NewCustomer", mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once().
Return(&stripe.Customer{ID: stripeCustomerId}, nil)
- s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
Return(&stripe.CheckoutSession{URL: "test-url"}, nil)
resp, err := userclient.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{
Name: "newname2",
@@ -731,13 +731,13 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() {
})
t.Run("cloud success unspecified account", func(t *testing.T) {
- userclient := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient)
stripeCustomerId := "foo"
- s.mocks.billingclient.On("NewCustomer", mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once().
Return(&stripe.Customer{ID: stripeCustomerId}, nil)
- s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
+ s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once().
Return(&stripe.CheckoutSession{URL: "test-url"}, nil)
resp, err := userclient.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{
Name: "newname3",
@@ -747,11 +747,11 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() {
}
func (s *IntegrationTestSuite) Test_SetBillingMeterEvent() {
- userclient1 := s.neosyncCloudClients.getUserClient(testAuthUserId)
+ userclient1 := s.NeosyncCloudClients.GetUserClient(testAuthUserId)
s.setUser(s.ctx, userclient1)
workerapikey := apikey.NewV1WorkerKey()
- workeruserclient := s.neosyncCloudClients.getUserClient(workerapikey)
+ workeruserclient := s.NeosyncCloudClients.GetUserClient(workerapikey)
t := s.T()
@@ -759,7 +759,7 @@ func (s *IntegrationTestSuite) Test_SetBillingMeterEvent() {
au1TeamAccountId1 := s.createBilledTeamAccount(s.ctx, userclient1, "test-team", "test-stripe-id")
t.Run("new event", func(t *testing.T) {
- s.mocks.billingclient.On("NewMeterEvent", mock.Anything).Once().Return(&stripe.BillingMeterEvent{}, nil)
+ s.Mocks.Billingclient.On("NewMeterEvent", mock.Anything).Once().Return(&stripe.BillingMeterEvent{}, nil)
ts := uint64(1)
resp, err := workeruserclient.SetBillingMeterEvent(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetBillingMeterEventRequest{
AccountId: au1TeamAccountId1,
@@ -795,7 +795,7 @@ func (s *IntegrationTestSuite) Test_SetBillingMeterEvent() {
t.Run("squashes meter already existing", func(t *testing.T) {
eventId := "test-event-id"
stripeerr := &stripe.Error{Type: stripe.ErrorTypeInvalidRequest, Msg: fmt.Sprintf("An event already exists with identifier %s", eventId)}
- s.mocks.billingclient.On("NewMeterEvent", mock.Anything).Once().Return(nil, stripeerr)
+ s.Mocks.Billingclient.On("NewMeterEvent", mock.Anything).Once().Return(nil, stripeerr)
ts := uint64(1)
resp, err := workeruserclient.SetBillingMeterEvent(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetBillingMeterEventRequest{
AccountId: au1TeamAccountId1,
diff --git a/cli/internal/auth/tokens.go b/cli/internal/auth/tokens.go
index 53015ddd90..199f80bced 100644
--- a/cli/internal/auth/tokens.go
+++ b/cli/internal/auth/tokens.go
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
+ "net/http"
"connectrpc.com/connect"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
@@ -91,3 +92,41 @@ func IsAuthEnabled(ctx context.Context) (bool, error) {
}
return isEnabledResp.Msg.IsEnabled, nil
}
+
+// GetNeosyncUrl returns the Neosync API URL found in the environment, otherwise defaults to localhost.
+func GetNeosyncUrl() string {
+ return serverconfig.GetApiBaseUrl()
+}
+
+// GetNeosyncHttpClient returns an *http.Client that, when auth is enabled, attaches a bearer token: the provided API key if set, otherwise the locally stored access token.
+func GetNeosyncHttpClient(ctx context.Context, apiKey *string, logger *slog.Logger) (*http.Client, error) {
+ token, err := GetToken(ctx, apiKey, logger)
+ if err != nil {
+ return nil, err
+ }
+ return http_client.NewWithBearerAuth(token), nil
+}
+
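+// GetToken returns the bearer token to use when auth is enabled: an explicit
+// API key takes precedence, otherwise the locally stored access token is used.
+// When auth is disabled, the returned token is nil.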
+func GetToken(ctx context.Context, apiKey *string, logger *slog.Logger) (*string, error) {
+ isAuthEnabled, err := IsAuthEnabled(ctx)
+ if err != nil {
+ return nil, err
+ }
+ var token *string
+ if isAuthEnabled {
+ logger.Debug("Auth Enabled")
+ if apiKey != nil && *apiKey != "" {
+ logger.Debug("found API Key")
+ token = apiKey
+ } else {
+ logger.Debug("Retrieving Access Token")
+ accessToken, err := userconfig.GetAccessToken()
+ if err != nil {
+ logger.Error("Unable to retrieve access token. Please use neosync login command and try again.")
+ return nil, err
+ }
+ token = &accessToken
+ }
+ }
+ return token, nil
+}
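For context, a hypothetical wrapper (not part of this diff) showing how a CLI command could combine the two new helpers to build an authenticated Connect client. The client constructor is the generated one used throughout this PR; the wrapper name, its package, and the cli auth import path are assumptions.

// Hypothetical helper package; illustrative only.
package cmdutil

import (
	"context"
	"log/slog"

	"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
	"github.com/nucleuscloud/neosync/cli/internal/auth"
)

// newUserAccountClient wires the new helpers into a ready-to-use service client.
func newUserAccountClient(ctx context.Context, apiKey *string, logger *slog.Logger) (mgmtv1alpha1connect.UserAccountServiceClient, error) {
	// Resolves the bearer token (API key or stored access token) when auth is enabled.
	httpclient, err := auth.GetNeosyncHttpClient(ctx, apiKey, logger)
	if err != nil {
		return nil, err
	}
	// GetNeosyncUrl falls back to localhost when no environment override is set.
	return mgmtv1alpha1connect.NewUserAccountServiceClient(httpclient, auth.GetNeosyncUrl()), nil
}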
diff --git a/cli/internal/benthos/config.go b/cli/internal/benthos/config.go
deleted file mode 100644
index b5c9a88cf1..0000000000
--- a/cli/internal/benthos/config.go
+++ /dev/null
@@ -1,164 +0,0 @@
-package cli_neosync_benthos
-
-type BenthosConfig struct {
- // HTTP HTTPConfig `json:"http" yaml:"http"`
- StreamConfig `json:",inline" yaml:",inline"`
-}
-
-type HTTPConfig struct {
- Address string `json:"address" yaml:"address"`
- Enabled bool `json:"enabled" yaml:"enabled"`
- // RootPath string `json:"root_path" yaml:"root_path"`
- // DebugEndpoints bool `json:"debug_endpoints" yaml:"debug_endpoints"`
- // CertFile string `json:"cert_file" yaml:"cert_file"`
- // KeyFile string `json:"key_file" yaml:"key_file"`
- // CORS httpserver.CORSConfig `json:"cors" yaml:"cors"`
- // BasicAuth httpserver.BasicAuthConfig `json:"basic_auth" yaml:"basic_auth"`
-}
-
-type StreamConfig struct {
- Logger *LoggerConfig `json:"logger" yaml:"logger,omitempty"`
- Input *InputConfig `json:"input" yaml:"input"`
- Buffer *BufferConfig `json:"buffer,omitempty" yaml:"buffer,omitempty"`
- Pipeline *PipelineConfig `json:"pipeline" yaml:"pipeline"`
- Output *OutputConfig `json:"output" yaml:"output"`
-}
-
-type LoggerConfig struct {
- Level string `json:"level" yaml:"level"`
- AddTimestamp bool `json:"add_timestamp" yaml:"add_timestamp"`
-}
-type InputConfig struct {
- Label string `json:"label" yaml:"label"`
- Inputs `json:",inline" yaml:",inline"`
-}
-
-type Inputs struct {
- NeosyncConnectionData *NeosyncConnectionData `json:"neosync_connection_data,omitempty" yaml:"neosync_connection_data,omitempty"`
-}
-
-type NeosyncConnectionData struct {
- ApiKey *string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
- ApiUrl string `json:"api_url" yaml:"api_url"`
- ConnectionId string `json:"connection_id" yaml:"connection_id"`
- ConnectionType string `json:"connection_type" yaml:"connection_type"`
- JobId *string `json:"job_id,omitempty" yaml:"job_id,omitempty"`
- JobRunId *string `json:"job_run_id,omitempty" yaml:"job_run_id,omitempty"`
- Schema string `json:"schema" yaml:"schema"`
- Table string `json:"table" yaml:"table"`
-}
-
-type BufferConfig struct{}
-
-type PipelineConfig struct {
- Threads int `json:"threads" yaml:"threads"`
- Processors []ProcessorConfig `json:"processors" yaml:"processors"`
-}
-
-type ProcessorConfig struct {
-}
-
-type BranchConfig struct {
- Processors []ProcessorConfig `json:"processors" yaml:"processors"`
- RequestMap *string `json:"request_map,omitempty" yaml:"request_map,omitempty"`
- ResultMap *string `json:"result_map,omitempty" yaml:"result_map,omitempty"`
-}
-
-type OutputConfig struct {
- Label string `json:"label" yaml:"label"`
- Outputs `json:",inline" yaml:",inline"`
- Processors []ProcessorConfig `json:"processors,omitempty" yaml:"processors,omitempty"`
- // Broker *OutputBrokerConfig `json:"broker,omitempty" yaml:"broker,omitempty"`
-}
-
-type Outputs struct {
- PooledSqlInsert *PooledSqlInsert `json:"pooled_sql_insert,omitempty" yaml:"pooled_sql_insert,omitempty"`
- PooledSqlUpdate *PooledSqlUpdate `json:"pooled_sql_update,omitempty" yaml:"pooled_sql_update,omitempty"`
- AwsS3 *AwsS3Insert `json:"aws_s3,omitempty" yaml:"aws_s3,omitempty"`
- AwsDynamoDB *OutputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"`
-}
-
-type OutputAwsDynamoDB struct {
- Table string `json:"table" yaml:"table"`
- JsonMapColumns map[string]string `json:"json_map_columns,omitempty" yaml:"json_map_columns,omitempty"`
-
- Region string `json:"region,omitempty" yaml:"region,omitempty"`
- Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"`
-
- Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"`
-
- MaxInFlight *int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"`
- Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
-}
-
-type PooledSqlUpdate struct {
- Driver string `json:"driver" yaml:"driver"`
- Dsn string `json:"dsn" yaml:"dsn"`
- Schema string `json:"schema" yaml:"schema"`
- Table string `json:"table" yaml:"table"`
- Columns []string `json:"columns" yaml:"columns"`
- WhereColumns []string `json:"where_columns" yaml:"where_columns"`
- ArgsMapping string `json:"args_mapping" yaml:"args_mapping"`
- Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
-}
-
-type PooledSqlInsert struct {
- Driver string `json:"driver" yaml:"driver"`
- Dsn string `json:"dsn" yaml:"dsn"`
- Schema string `json:"schema" yaml:"schema"`
- Table string `json:"table" yaml:"table"`
- Columns []string `json:"columns" yaml:"columns"`
- OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"`
- TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"`
- ArgsMapping string `json:"args_mapping" yaml:"args_mapping"`
- Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
-}
-
-type AwsS3Insert struct {
- Bucket string `json:"bucket" yaml:"bucket"`
- MaxInFlight int `json:"max_in_flight" yaml:"max_in_flight"`
- Path string `json:"path" yaml:"path"`
- Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"`
-
- Region string `json:"region,omitempty" yaml:"region,omitempty"`
- Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"`
-
- Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"`
-}
-
-type AwsCredentials struct {
- Profile string `json:"profile,omitempty" yaml:"profile,omitempty"`
- Id string `json:"id,omitempty" yaml:"id,omitempty"`
- Secret string `json:"secret,omitempty" yaml:"secret,omitempty"`
- Token string `json:"token,omitempty" yaml:"token,omitempty"`
- FromEc2Role bool `json:"from_ec2_role,omitempty" yaml:"from_ec2_role,omitempty"`
- Role string `json:"role,omitempty" yaml:"role,omitempty"`
- RoleExternalId string `json:"role_external_id,omitempty" yaml:"role_external_id,omitempty"`
-}
-
-type Batching struct {
- Count int `json:"count" yaml:"count"`
- ByteSize int `json:"byte_size" yaml:"byte_size"`
- Period string `json:"period" yaml:"period"`
- Check string `json:"check" yaml:"check"`
- Processors []*BatchProcessor `json:"processors" yaml:"processors"`
-}
-
-type BatchProcessor struct {
- Archive *ArchiveProcessor `json:"archive,omitempty" yaml:"archive,omitempty"`
- Compress *CompressProcessor `json:"compress,omitempty" yaml:"compress,omitempty"`
-}
-
-type ArchiveProcessor struct {
- Format string `json:"format" yaml:"format"`
- Path *string `json:"path,omitempty" yaml:"path,omitempty"`
-}
-
-type CompressProcessor struct {
- Algorithm string `json:"algorithm" yaml:"algorithm"`
-}
-
-type OutputBrokerConfig struct {
- Pattern string `json:"pattern" yaml:"pattern"`
- Outputs []Outputs `json:"outputs" yaml:"outputs"`
-}
diff --git a/cli/internal/cmds/neosync/connections/connections_integration_test.go b/cli/internal/cmds/neosync/connections/connections_integration_test.go
new file mode 100644
index 0000000000..a194c071cf
--- /dev/null
+++ b/cli/internal/cmds/neosync/connections/connections_integration_test.go
@@ -0,0 +1,71 @@
+package connections_cmd
+
+import (
+ "context"
+ "testing"
+
+ mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
+ tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
+ "github.com/nucleuscloud/neosync/internal/testutil"
+ "github.com/stretchr/testify/require"
+)
+
+const neosyncDbMigrationsPath = "../../../../../backend/sql/postgresql/schema"
+
+func Test_Connections(t *testing.T) {
+ t.Parallel()
+ ok := testutil.ShouldRunIntegrationTest()
+ if !ok {
+ return
+ }
+ ctx := context.Background()
+
+ neosyncApi, err := tcneosyncapi.NewNeosyncApiTestClient(ctx, t, tcneosyncapi.WithMigrationsDirectory(neosyncDbMigrationsPath))
+ if err != nil {
+ panic(err)
+ }
+ postgresUrl := "postgresql://postgres:foofar@localhost:5434/neosync"
+
+ t.Run("list_unauthed", func(t *testing.T) {
+ accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, neosyncApi.UnauthdClients.Users)
+ conn1 := tcneosyncapi.CreatePostgresConnection(ctx, t, neosyncApi.UnauthdClients.Connections, accountId, "conn1", postgresUrl)
+ conn2 := tcneosyncapi.CreatePostgresConnection(ctx, t, neosyncApi.UnauthdClients.Connections, accountId, "conn2", postgresUrl)
+ conns := []*mgmtv1alpha1.Connection{conn1, conn2}
+ connections, err := getConnections(ctx, neosyncApi.UnauthdClients.Connections, accountId)
+ require.NoError(t, err)
+ require.Len(t, connections, len(conns))
+ })
+
+ t.Run("list_auth", func(t *testing.T) {
+ testAuthUserId := "c3b32842-9b70-4f4e-ad45-9cab26c6f2f1"
+ userclient := neosyncApi.AuthdClients.GetUserClient(testAuthUserId)
+ connclient := neosyncApi.AuthdClients.GetConnectionClient(testAuthUserId)
+ tcneosyncapi.SetUser(ctx, t, userclient)
+ accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, userclient)
+ conn1 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn1", postgresUrl)
+ conn2 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn2", postgresUrl)
+ conns := []*mgmtv1alpha1.Connection{conn1, conn2}
+ connections, err := getConnections(ctx, connclient, accountId)
+ require.NoError(t, err)
+ require.Len(t, connections, len(conns))
+ })
+
+ t.Run("list_cloud", func(t *testing.T) {
+ testAuthUserId := "34f3e404-c995-452b-89e4-9c486b491dab"
+ userclient := neosyncApi.NeosyncCloudClients.GetUserClient(testAuthUserId)
+ connclient := neosyncApi.NeosyncCloudClients.GetConnectionClient(testAuthUserId)
+ tcneosyncapi.SetUser(ctx, t, userclient)
+ accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, userclient)
+ conn1 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn1", postgresUrl)
+ conn2 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn2", postgresUrl)
+ conns := []*mgmtv1alpha1.Connection{conn1, conn2}
+ connections, err := getConnections(ctx, connclient, accountId)
+ require.NoError(t, err)
+ require.Len(t, connections, len(conns))
+ })
+
+ err = neosyncApi.TearDown(ctx)
+ if err != nil {
+ panic(err)
+ }
+}
diff --git a/cli/internal/cmds/neosync/connections/list.go b/cli/internal/cmds/neosync/connections/list.go
index e7f041ba73..c067c994de 100644
--- a/cli/internal/cmds/neosync/connections/list.go
+++ b/cli/internal/cmds/neosync/connections/list.go
@@ -4,18 +4,17 @@ import (
"context"
"errors"
"fmt"
+ "log/slog"
+ "os"
"time"
"connectrpc.com/connect"
+ charmlog "github.com/charmbracelet/log"
"github.com/fatih/color"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
"github.com/nucleuscloud/neosync/cli/internal/auth"
- auth_interceptor "github.com/nucleuscloud/neosync/cli/internal/connect/interceptors/auth"
- "github.com/nucleuscloud/neosync/cli/internal/serverconfig"
"github.com/nucleuscloud/neosync/cli/internal/userconfig"
- "github.com/nucleuscloud/neosync/cli/internal/version"
- http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client"
"github.com/rodaine/table"
"github.com/spf13/cobra"
)
@@ -35,8 +34,13 @@ func newListCmd() *cobra.Command {
if err != nil {
return err
}
+
+ debugMode, err := cmd.Flags().GetBool("debug")
+ if err != nil {
+ return err
+ }
cmd.SilenceUsage = true
- return listConnections(cmd.Context(), &apiKey, &accountId)
+ return listConnections(cmd.Context(), debugMode, &apiKey, &accountId)
},
}
cmd.Flags().String("account-id", "", "Account to list connections for. Defaults to account id in cli context")
@@ -45,18 +49,24 @@ func newListCmd() *cobra.Command {
func listConnections(
ctx context.Context,
+ debugMode bool,
apiKey, accountIdFlag *string,
) error {
- isAuthEnabled, err := auth.IsAuthEnabled(ctx)
- if err != nil {
- return err
+ logLevel := charmlog.InfoLevel
+ if debugMode {
+ logLevel = charmlog.DebugLevel
}
+ charmlogger := charmlog.NewWithOptions(os.Stderr, charmlog.Options{
+ ReportTimestamp: true,
+ Level: logLevel,
+ })
+ logger := slog.New(charmlogger)
var accountId = accountIdFlag
if accountId == nil || *accountId == "" {
aId, err := userconfig.GetAccountId()
if err != nil {
- fmt.Println("Unable to retrieve account id. Please use account switch command to set account.") //nolint:forbidigo
+ logger.Error("Unable to retrieve account id. Please use account switch command to set account.")
return err
}
accountId = &aId
@@ -66,26 +76,40 @@ func listConnections(
return errors.New("Account Id not found. Please use account switch command to set account.")
}
- connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(
- http_client.NewWithHeaders(version.Get().Headers()),
- serverconfig.GetApiBaseUrl(),
- connect.WithInterceptors(
- auth_interceptor.NewInterceptor(isAuthEnabled, auth.AuthHeader, auth.GetAuthHeaderTokenFn(apiKey)),
- ),
- )
- res, err := connectionclient.GetConnections(ctx, connect.NewRequest[mgmtv1alpha1.GetConnectionsRequest](&mgmtv1alpha1.GetConnectionsRequest{
- AccountId: *accountId,
- }))
+ connectInterceptors := []connect.Interceptor{}
+ neosyncurl := auth.GetNeosyncUrl()
+ httpclient, err := auth.GetNeosyncHttpClient(ctx, apiKey, logger)
+ if err != nil {
+ return err
+ }
+ connectInterceptorOption := connect.WithInterceptors(connectInterceptors...)
+ connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(httpclient, neosyncurl, connectInterceptorOption)
+
+ connections, err := getConnections(ctx, connectionclient, *accountId)
if err != nil {
return err
}
fmt.Println() //nolint:forbidigo
- printConnectionsTable(res.Msg.Connections)
+ printConnectionsTable(connections)
fmt.Println() //nolint:forbidigo
return nil
}
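+
+// getConnections lists all connections for the given account id; split out of
+// listConnections so integration tests can exercise it directly.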
+func getConnections(
+ ctx context.Context,
+ connectionclient mgmtv1alpha1connect.ConnectionServiceClient,
+ accountId string,
+) ([]*mgmtv1alpha1.Connection, error) {
+ res, err := connectionclient.GetConnections(ctx, connect.NewRequest[mgmtv1alpha1.GetConnectionsRequest](&mgmtv1alpha1.GetConnectionsRequest{
+ AccountId: accountId,
+ }))
+ if err != nil {
+ return nil, err
+ }
+ return res.Msg.GetConnections(), nil
+}
+
func printConnectionsTable(
connections []*mgmtv1alpha1.Connection,
) {
diff --git a/cli/internal/cmds/neosync/neosync.go b/cli/internal/cmds/neosync/neosync.go
index 1d7ae8a872..ef7b874f12 100644
--- a/cli/internal/cmds/neosync/neosync.go
+++ b/cli/internal/cmds/neosync/neosync.go
@@ -66,6 +66,8 @@ func Execute() {
)
rootCmd.PersistentFlags().String(apiKeyFlag, "", fmt.Sprintf("Neosync API Key. Takes precedence over $%s", apiKeyEnvVarName))
+ rootCmd.PersistentFlags().Bool("debug", false, "Run in debug mode")
+
rootCmd.AddCommand(jobs_cmd.NewCmd())
rootCmd.AddCommand(version_cmd.NewCmd())
rootCmd.AddCommand(whoami_cmd.NewCmd())
diff --git a/cli/internal/cmds/neosync/sync/config.go b/cli/internal/cmds/neosync/sync/config.go
new file mode 100644
index 0000000000..8791ef6670
--- /dev/null
+++ b/cli/internal/cmds/neosync/sync/config.go
@@ -0,0 +1,199 @@
+package sync_cmd
+
+import (
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+
+ mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
+ "github.com/nucleuscloud/neosync/cli/internal/output"
+ "github.com/nucleuscloud/neosync/cli/internal/userconfig"
+ "github.com/spf13/cobra"
+ "gopkg.in/yaml.v2"
+)
+
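+// buildCmdConfig assembles the sync configuration by first loading the
+// optional yaml config file and then overlaying any flags that were
+// explicitly set, so flags take precedence over file values.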
+func buildCmdConfig(cmd *cobra.Command) (*cmdConfig, error) {
+ config := &cmdConfig{
+ Source: &sourceConfig{
+ ConnectionOpts: &connectionOpts{},
+ },
+ Destination: &sqlDestinationConfig{},
+ AwsDynamoDbDestination: &dynamoDbDestinationConfig{},
+ }
+ configPath, err := cmd.Flags().GetString("config")
+ if err != nil {
+ return nil, err
+ }
+
+ if configPath != "" {
+ fileBytes, err := os.ReadFile(configPath)
+ if err != nil {
+ return nil, fmt.Errorf("error reading config file: %w", err)
+ }
+ err = yaml.Unmarshal(fileBytes, &config)
+ if err != nil {
+ return nil, fmt.Errorf("error parsing config file: %w", err)
+ }
+ }
+
+ connectionId, err := cmd.Flags().GetString("connection-id")
+ if err != nil {
+ return nil, err
+ }
+ if connectionId != "" {
+ config.Source.ConnectionId = connectionId
+ }
+
+ destConnUrl, err := cmd.Flags().GetString("destination-connection-url")
+ if err != nil {
+ return nil, err
+ }
+ if destConnUrl != "" {
+ config.Destination.ConnectionUrl = destConnUrl
+ }
+
+ driver, err := cmd.Flags().GetString("destination-driver")
+ if err != nil {
+ return nil, err
+ }
+ pDriver, ok := parseDriverString(driver)
+ if ok {
+ config.Destination.Driver = pDriver
+ }
+
+ initSchema, err := cmd.Flags().GetBool("init-schema")
+ if err != nil {
+ return nil, err
+ }
+ if initSchema {
+ config.Destination.InitSchema = initSchema
+ }
+
+ truncateBeforeInsert, err := cmd.Flags().GetBool("truncate-before-insert")
+ if err != nil {
+ return nil, err
+ }
+ if truncateBeforeInsert {
+ config.Destination.TruncateBeforeInsert = truncateBeforeInsert
+ }
+
+ truncateCascade, err := cmd.Flags().GetBool("truncate-cascade")
+ if err != nil {
+ return nil, err
+ }
+ if truncateCascade {
+ config.Destination.TruncateCascade = truncateCascade
+ }
+
+ onConflictDoNothing, err := cmd.Flags().GetBool("on-conflict-do-nothing")
+ if err != nil {
+ return nil, err
+ }
+ if onConflictDoNothing {
+ config.Destination.OnConflict.DoNothing = onConflictDoNothing
+ }
+
+ jobId, err := cmd.Flags().GetString("job-id")
+ if err != nil {
+ return nil, err
+ }
+ if jobId != "" {
+ config.Source.ConnectionOpts.JobId = &jobId
+ }
+
+ jobRunId, err := cmd.Flags().GetString("job-run-id")
+ if err != nil {
+ return nil, err
+ }
+ if jobRunId != "" {
+ config.Source.ConnectionOpts.JobRunId = &jobRunId
+ }
+
+ config, err = buildAwsCredConfig(cmd, config)
+ if err != nil {
+ return nil, err
+ }
+
+ if config.Source.ConnectionId == "" {
+ return nil, fmt.Errorf("must provide connection-id")
+ }
+
+ accountIdFlag, err := cmd.Flags().GetString("account-id")
+ if err != nil {
+ return nil, err
+ }
+ accountId := accountIdFlag
+ if accountId == "" {
+ aId, err := userconfig.GetAccountId()
+ if err != nil {
+ return nil, errors.New("Unable to retrieve account id. Please use account switch command to set account.")
+ }
+ accountId = aId
+ }
+ config.AccountId = &accountId
+
+ if accountId == "" {
+ return nil, errors.New("Account Id not found. Please use account switch command to set account.")
+ }
+
+ outputType, err := output.ValidateAndRetrieveOutputFlag(cmd)
+ if err != nil {
+ return nil, err
+ }
+ config.OutputType = &outputType
+
+ debug, err := cmd.Flags().GetBool("debug")
+ if err != nil {
+ return nil, err
+ }
+ config.Debug = debug
+ return config, nil
+}
+
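+// isConfigValid enforces per-connection-type requirements (job ids for S3/GCP
+// sources, driver and url for SQL destinations, region for DynamoDB) and
+// checks that the source connection belongs to the configured account.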
+func isConfigValid(cmd *cmdConfig, logger *slog.Logger, sourceConnection *mgmtv1alpha1.Connection, sourceConnectionType ConnectionType) error {
+ if sourceConnectionType == awsS3Connection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") {
+ return errors.New("S3 source connection type requires job-id or job-run-id.")
+ }
+ if sourceConnectionType == gcpCloudStorageConnection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") {
+ return errors.New("GCP Cloud Storage source connection type requires job-id or job-run-id")
+ }
+
+ if cmd.Destination.TruncateCascade && cmd.Destination.Driver == mysqlDriver {
+ return fmt.Errorf("truncate cascade is only supported in postgres")
+ }
+
+ if sourceConnectionType == mysqlConnection || sourceConnectionType == postgresConnection {
+ if cmd.Destination.Driver == "" {
+ return fmt.Errorf("must provide destination-driver")
+ }
+ if cmd.Destination.ConnectionUrl == "" {
+ return fmt.Errorf("must provide destination-connection-url")
+ }
+
+ if cmd.Destination.Driver != mysqlDriver && cmd.Destination.Driver != postgresDriver {
+ return errors.New("unsupported destination driver. only pgx (postgres) and mysql are currently supported")
+ }
+ }
+
+ if sourceConnectionType == awsDynamoDBConnection {
+ if cmd.AwsDynamoDbDestination == nil {
+ return fmt.Errorf("must provide destination aws credentials")
+ }
+
+ if cmd.AwsDynamoDbDestination.AwsCredConfig.Region == "" {
+ return fmt.Errorf("must provide destination aws region")
+ }
+ }
+
+ if sourceConnection.AccountId != *cmd.AccountId {
+ return fmt.Errorf("Connection not found. AccountId: %s", *cmd.AccountId)
+ }
+
+ logger.Debug("Checking if source and destination are compatible")
+ err := areSourceAndDestCompatible(sourceConnection, cmd.Destination.Driver)
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/cli/internal/cmds/neosync/sync/dynamodb.go b/cli/internal/cmds/neosync/sync/dynamodb.go
index c8e6f9e4df..81a7923c96 100644
--- a/cli/internal/cmds/neosync/sync/dynamodb.go
+++ b/cli/internal/cmds/neosync/sync/dynamodb.go
@@ -2,7 +2,7 @@ package sync_cmd
import (
tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
- cli_neosync_benthos "github.com/nucleuscloud/neosync/cli/internal/benthos"
+ neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos"
"github.com/spf13/cobra"
)
@@ -62,21 +62,19 @@ func buildAwsCredConfig(cmd *cobra.Command, config *cmdConfig) (*cmdConfig, erro
func generateDynamoDbBenthosConfig(
cmd *cmdConfig,
- apiUrl string,
- authToken *string,
table string,
) *benthosConfigResponse {
- bc := &cli_neosync_benthos.BenthosConfig{
- StreamConfig: cli_neosync_benthos.StreamConfig{
- Logger: &cli_neosync_benthos.LoggerConfig{
+ bc := &neosync_benthos.BenthosConfig{
+ StreamConfig: neosync_benthos.StreamConfig{
+ Logger: &neosync_benthos.LoggerConfig{
Level: "ERROR",
AddTimestamp: true,
},
- Input: &cli_neosync_benthos.InputConfig{
- Inputs: cli_neosync_benthos.Inputs{
- NeosyncConnectionData: &cli_neosync_benthos.NeosyncConnectionData{
- ApiKey: authToken,
- ApiUrl: apiUrl,
+ Input: &neosync_benthos.InputConfig{
+ Inputs: neosync_benthos.Inputs{
+ NeosyncConnectionData: &neosync_benthos.NeosyncConnectionData{
+ // ApiKey: authToken,
+ // ApiUrl: apiUrl,
ConnectionId: cmd.Source.ConnectionId,
ConnectionType: string(awsDynamoDBConnection),
Schema: "dynamodb",
@@ -84,16 +82,16 @@ func generateDynamoDbBenthosConfig(
},
},
},
- Pipeline: &cli_neosync_benthos.PipelineConfig{},
- Output: &cli_neosync_benthos.OutputConfig{
- Outputs: cli_neosync_benthos.Outputs{
- AwsDynamoDB: &cli_neosync_benthos.OutputAwsDynamoDB{
+ Pipeline: &neosync_benthos.PipelineConfig{},
+ Output: &neosync_benthos.OutputConfig{
+ Outputs: neosync_benthos.Outputs{
+ AwsDynamoDB: &neosync_benthos.OutputAwsDynamoDB{
Table: table,
JsonMapColumns: map[string]string{
"": ".",
},
- Batching: &cli_neosync_benthos.Batching{
+ Batching: &neosync_benthos.Batching{
// https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_BatchWriteItem.html
// A single call to BatchWriteItem can transmit up to 16MB of data over the network, consisting of up to 25 item put or delete operations
// Specifying the count here may not be enough if the overall data is above 16MB.
@@ -119,12 +117,12 @@ func generateDynamoDbBenthosConfig(
}
}
-func buildBenthosAwsCredentials(cmd *cmdConfig) *cli_neosync_benthos.AwsCredentials {
+func buildBenthosAwsCredentials(cmd *cmdConfig) *neosync_benthos.AwsCredentials {
if cmd.AwsDynamoDbDestination == nil || cmd.AwsDynamoDbDestination.AwsCredConfig == nil {
return nil
}
cc := cmd.AwsDynamoDbDestination.AwsCredConfig
- creds := &cli_neosync_benthos.AwsCredentials{}
+ creds := &neosync_benthos.AwsCredentials{}
if cc.Profile != nil {
creds.Profile = *cc.Profile
}
diff --git a/cli/internal/cmds/neosync/sync/sync.go b/cli/internal/cmds/neosync/sync/sync.go
index ba3a474518..9c8c561ef2 100644
--- a/cli/internal/cmds/neosync/sync/sync.go
+++ b/cli/internal/cmds/neosync/sync/sync.go
@@ -6,7 +6,6 @@ import (
"fmt"
"log/slog"
"os"
- "slices"
"strings"
syncmap "sync"
"time"
@@ -27,22 +26,17 @@ import (
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
"github.com/nucleuscloud/neosync/cli/internal/auth"
- cli_neosync_benthos "github.com/nucleuscloud/neosync/cli/internal/benthos"
- auth_interceptor "github.com/nucleuscloud/neosync/cli/internal/connect/interceptors/auth"
"github.com/nucleuscloud/neosync/cli/internal/output"
- "github.com/nucleuscloud/neosync/cli/internal/serverconfig"
- "github.com/nucleuscloud/neosync/cli/internal/userconfig"
- "github.com/nucleuscloud/neosync/cli/internal/version"
connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager"
pool_sql_provider "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/pool/providers/sql"
"github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/providers"
"github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/providers/mongoprovider"
"github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/providers/sqlprovider"
+ neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos"
"github.com/spf13/cobra"
"golang.org/x/sync/errgroup"
"gopkg.in/yaml.v2"
- _ "github.com/nucleuscloud/neosync/cli/internal/benthos/inputs"
benthos_environment "github.com/nucleuscloud/neosync/worker/pkg/benthos/environment"
_ "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
"github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared"
@@ -52,8 +46,6 @@ import (
_ "github.com/warpstreamlabs/bento/public/components/pure"
_ "github.com/warpstreamlabs/bento/public/components/pure/extended"
- http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client"
-
"github.com/warpstreamlabs/bento/public/service"
)
@@ -86,6 +78,8 @@ type cmdConfig struct {
Destination *sqlDestinationConfig `yaml:"destination"`
AwsDynamoDbDestination *dynamoDbDestinationConfig `yaml:"aws-dynamodb-destination,omitempty"`
Debug bool
+ OutputType *output.OutputType `yaml:"output-type,omitempty"`
+ AccountId *string `yaml:"account-id,omitempty"`
}
type sourceConfig struct {
@@ -141,133 +135,15 @@ func NewCmd() *cobra.Command {
apiKey = &apiKeyStr
}
- config := &cmdConfig{
- Source: &sourceConfig{
- ConnectionOpts: &connectionOpts{},
- },
- Destination: &sqlDestinationConfig{},
- AwsDynamoDbDestination: &dynamoDbDestinationConfig{},
- }
- configPath, err := cmd.Flags().GetString("config")
- if err != nil {
- return err
- }
-
- if configPath != "" {
- fileBytes, err := os.ReadFile(configPath)
- if err != nil {
- return fmt.Errorf("error reading config file: %w", err)
- }
- err = yaml.Unmarshal(fileBytes, &config)
- if err != nil {
- return fmt.Errorf("error parsing config file: %w", err)
- }
- }
-
- connectionId, err := cmd.Flags().GetString("connection-id")
- if err != nil {
- return err
- }
- if connectionId != "" {
- config.Source.ConnectionId = connectionId
- }
-
- destConnUrl, err := cmd.Flags().GetString("destination-connection-url")
- if err != nil {
- return err
- }
- if destConnUrl != "" {
- config.Destination.ConnectionUrl = destConnUrl
- }
-
- driver, err := cmd.Flags().GetString("destination-driver")
- if err != nil {
- return err
- }
- pDriver, ok := parseDriverString(driver)
- if ok {
- config.Destination.Driver = pDriver
- }
-
- initSchema, err := cmd.Flags().GetBool("init-schema")
- if err != nil {
- return err
- }
- if initSchema {
- config.Destination.InitSchema = initSchema
- }
-
- truncateBeforeInsert, err := cmd.Flags().GetBool("truncate-before-insert")
- if err != nil {
- return err
- }
- if truncateBeforeInsert {
- config.Destination.TruncateBeforeInsert = truncateBeforeInsert
- }
-
- truncateCascade, err := cmd.Flags().GetBool("truncate-cascade")
- if err != nil {
- return err
- }
- if truncateCascade {
- config.Destination.TruncateCascade = truncateCascade
- }
-
- onConflictDoNothing, err := cmd.Flags().GetBool("on-conflict-do-nothing")
- if err != nil {
- return err
- }
- if onConflictDoNothing {
- config.Destination.OnConflict.DoNothing = onConflictDoNothing
- }
-
- jobId, err := cmd.Flags().GetString("job-id")
- if err != nil {
- return err
- }
- if jobId != "" {
- config.Source.ConnectionOpts.JobId = &jobId
- }
-
- jobRunId, err := cmd.Flags().GetString("job-run-id")
- if err != nil {
- return err
- }
- if jobRunId != "" {
- config.Source.ConnectionOpts.JobRunId = &jobRunId
- }
-
- config, err = buildAwsCredConfig(cmd, config)
+ config, err := buildCmdConfig(cmd)
if err != nil {
return err
}
- if config.Source.ConnectionId == "" {
- return fmt.Errorf("must provide connection-id")
- }
-
- accountId, err := cmd.Flags().GetString("account-id")
- if err != nil {
- return err
- }
-
- outputType, err := output.ValidateAndRetrieveOutputFlag(cmd)
- if err != nil {
- return err
- }
-
- debug, err := cmd.Flags().GetBool("debug")
- if err != nil {
- return err
- }
- config.Debug = debug
-
- return sync(cmd.Context(), outputType, apiKey, &accountId, config)
+ return sync(cmd.Context(), apiKey, config)
},
}
- cmd.Flags().Bool("debug", false, "Run in debug mode")
-
cmd.Flags().String("connection-id", "", "Connection id for sync source")
cmd.Flags().String("job-id", "", "Id of Job to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-run-id instead.")
cmd.Flags().String("job-run-id", "", "Id of Job run to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-id instead.")
@@ -294,10 +170,20 @@ func NewCmd() *cobra.Command {
return cmd
}
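+
+// clisync carries the API clients, sql manager, benthos environment, parsed
+// command config, logger, and context shared across the sync steps.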
+type clisync struct {
+ connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient
+ connectionclient mgmtv1alpha1connect.ConnectionServiceClient
+ sqlmanagerclient *sqlmanager.SqlManager
+ sqlconnector *sqlconnect.SqlOpenConnector
+ benv *service.Environment
+ cmd *cmdConfig
+ logger *slog.Logger
+ ctx context.Context
+}
+
func sync(
ctx context.Context,
- outputType output.OutputType,
- apiKey, accountIdFlag *string,
+ apiKey *string,
cmd *cmdConfig,
) error {
logLevel := charmlog.InfoLevel
@@ -311,27 +197,16 @@ func sync(
logger := slog.New(charmlogger)
logger.Info("Starting sync")
- isAuthEnabled, err := auth.IsAuthEnabled(ctx)
+
+ connectInterceptors := []connect.Interceptor{}
+ neosyncurl := auth.GetNeosyncUrl()
+ httpclient, err := auth.GetNeosyncHttpClient(ctx, apiKey, logger)
if err != nil {
return err
}
-
- httpclient := http_client.NewWithHeaders(version.Get().Headers())
- connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(
- httpclient,
- serverconfig.GetApiBaseUrl(),
- connect.WithInterceptors(
- auth_interceptor.NewInterceptor(isAuthEnabled, auth.AuthHeader, auth.GetAuthHeaderTokenFn(apiKey)),
- ),
- )
-
- connectiondataclient := mgmtv1alpha1connect.NewConnectionDataServiceClient(
- httpclient,
- serverconfig.GetApiBaseUrl(),
- connect.WithInterceptors(
- auth_interceptor.NewInterceptor(isAuthEnabled, auth.AuthHeader, auth.GetAuthHeaderTokenFn(apiKey)),
- ),
- )
+ connectInterceptorOption := connect.WithInterceptors(connectInterceptors...)
+ connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(httpclient, neosyncurl, connectInterceptorOption)
+ connectiondataclient := mgmtv1alpha1connect.NewConnectionDataServiceClient(httpclient, neosyncurl, connectInterceptorOption)
pgpoolmap := &syncmap.Map{}
mysqlpoolmap := &syncmap.Map{}
@@ -342,99 +217,23 @@ func sync(
sqlConnector := &sqlconnect.SqlOpenConnector{}
sqlmanagerclient := sqlmanager.NewSqlManager(pgpoolmap, pgquerier, mysqlpoolmap, mysqlquerier, mssqlpoolmap, mssqlquerier, sqlConnector)
- logger.Debug("Retrieving neosync source connection")
- connResp, err := connectionclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{
- Id: cmd.Source.ConnectionId,
- }))
- if err != nil {
- return err
- }
- sourceConnection := connResp.Msg.GetConnection()
- sourceConnectionType, err := getConnectionType(sourceConnection)
- if err != nil {
- return err
- }
- logger.Debug(fmt.Sprintf("Source connection type: %s", sourceConnectionType))
-
- if sourceConnectionType == awsS3Connection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") {
- return errors.New("S3 source connection type requires job-id or job-run-id.")
- }
- if sourceConnectionType == gcpCloudStorageConnection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") {
- return errors.New("GCP Cloud Storage source connection type requires job-id or job-run-id")
- }
-
- if cmd.Destination.TruncateCascade && cmd.Destination.Driver == mysqlDriver {
- return fmt.Errorf("truncate cascade is only supported in postgres")
- }
-
- if sourceConnectionType == mysqlConnection || sourceConnectionType == postgresConnection {
- if cmd.Destination.Driver == "" {
- return fmt.Errorf("must provide destination-driver")
- }
- if cmd.Destination.ConnectionUrl == "" {
- return fmt.Errorf("must provide destination-connection-url")
- }
-
- if cmd.Destination.Driver != mysqlDriver && cmd.Destination.Driver != postgresDriver {
- return errors.New("unsupported destination driver. only postgres and mysql are currently supported")
- }
- }
-
- if sourceConnectionType == awsDynamoDBConnection {
- if cmd.AwsDynamoDbDestination == nil {
- return fmt.Errorf("must provide destination aws credentials")
- }
-
- if cmd.AwsDynamoDbDestination.AwsCredConfig.Region == "" {
- return fmt.Errorf("must provide destination aws region")
- }
- }
- logger.Debug("Validated config")
-
- var token *string
- if isAuthEnabled {
- logger.Debug("Auth Enabled")
- if apiKey != nil && *apiKey != "" {
- logger.Debug("found API Key")
- token = apiKey
- } else {
- logger.Debug("Retrieving Access Token")
- accessToken, err := userconfig.GetAccessToken()
- if err != nil {
- logger.Error("Unable to retrieve access token. Please use neosync login command and try again.")
- return err
- }
- token = &accessToken
- logger.Debug("Setting account id")
- var accountId = accountIdFlag
- if accountId == nil || *accountId == "" {
- aId, err := userconfig.GetAccountId()
- if err != nil {
- logger.Error("Unable to retrieve account id. Please use account switch command to set account.")
- return err
- }
- accountId = &aId
- }
-
- if accountId == nil || *accountId == "" {
- return errors.New("Account Id not found. Please use account switch command to set account.")
- }
-
- if sourceConnection.AccountId != *accountId {
- return fmt.Errorf("Connection not found. AccountId: %s", *accountId)
- }
- }
+ sync := &clisync{
+ connectiondataclient: connectiondataclient,
+ connectionclient: connectionclient,
+ sqlmanagerclient: sqlmanagerclient,
+ sqlconnector: sqlConnector,
+ cmd: cmd,
+ logger: logger,
+ ctx: ctx,
}
- logger.Debug("Checking if source and destination are compatible")
- err = areSourceAndDestCompatible(sourceConnection, cmd.Destination.Driver)
- if err != nil {
- return err
- }
+ return sync.configureAndRunSync()
+}
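+
+// configureAndRunSync wires up the tunnel manager and benthos environment,
+// builds the grouped benthos configs, and runs them in dependency order.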
+func (c *clisync) configureAndRunSync() error {
connectionprovider := providers.NewProvider(
mongoprovider.NewProvider(),
- sqlprovider.NewProvider(sqlConnector),
+ sqlprovider.NewProvider(c.sqlconnector),
)
tunnelmanager := connectiontunnelmanager.NewConnectionTunnelManager(connectionprovider)
session := uuid.NewString()
@@ -443,15 +242,15 @@ func sync(
tunnelmanager.ReleaseSession(session)
}()
- destConnection := cmdConfigToDestinationConnection(cmd)
+ destConnection := cmdConfigToDestinationConnection(c.cmd)
dsnToConnIdMap := &syncmap.Map{}
var sqlDsn string
- if cmd.Destination != nil {
- sqlDsn = cmd.Destination.ConnectionUrl
+ if c.cmd.Destination != nil {
+ sqlDsn = c.cmd.Destination.ConnectionUrl
}
dsnToConnIdMap.Store(sqlDsn, destConnection.Id)
stopChan := make(chan error, 3)
- ctx, cancel := context.WithCancel(ctx)
+ ctx, cancel := context.WithCancel(c.ctx)
defer cancel()
go func() {
for {
@@ -464,8 +263,8 @@ func sync(
}
}
}()
- benv, err := benthos_environment.NewEnvironment(
- logger,
+ benthosEnv, err := benthos_environment.NewEnvironment(
+ c.logger,
benthos_environment.WithSqlConfig(&benthos_environment.SqlConfig{
Provider: pool_sql_provider.NewProvider(pool_sql_provider.GetSqlPoolProviderGetter(
tunnelmanager,
@@ -474,26 +273,63 @@ func sync(
destConnection.Id: destConnection,
},
session,
- logger,
+ c.logger,
)),
IsRetry: false,
}),
+ benthos_environment.WithConnectionDataConfig(&benthos_environment.ConnectionDataConfig{
+ NeosyncConnectionDataApi: c.connectiondataclient,
+ }),
benthos_environment.WithStopChannel(stopChan),
benthos_environment.WithBlobEnv(bloblang.NewEnvironment()),
)
if err != nil {
return err
}
+ c.benv = benthosEnv
+
+ groupedConfigs, err := c.configureSync()
+ if err != nil {
+ return err
+ }
+ if groupedConfigs == nil {
+ return nil
+ }
+
+ return runSync(c.ctx, *c.cmd.OutputType, c.benv, groupedConfigs, c.logger)
+}
+
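+// configureSync resolves the source connection, validates the config, builds
+// the schema config for the source type, and returns benthos configs grouped
+// by table dependency.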
+func (c *clisync) configureSync() ([][]*benthosConfigResponse, error) {
+ c.logger.Debug("Retrieving neosync connection")
+ connResp, err := c.connectionclient.GetConnection(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{
+ Id: c.cmd.Source.ConnectionId,
+ }))
+ if err != nil {
+ return nil, err
+ }
+ sourceConnection := connResp.Msg.GetConnection()
+ sourceConnectionType, err := getConnectionType(sourceConnection)
+ if err != nil {
+ return nil, err
+ }
+ c.logger.Debug(fmt.Sprintf("Source connection type: %s", sourceConnectionType))
+
+ err = isConfigValid(c.cmd, c.logger, sourceConnection, sourceConnectionType)
+ if err != nil {
+ return nil, err
+ }
+ c.logger.Debug("Validated config")
- logger.Info("Retrieving connection schema...")
+ c.logger.Info("Retrieving connection schema...")
var schemaConfig *schemaConfig
switch sourceConnectionType {
case awsS3Connection:
+ c.logger.Info("Building schema and table constraints...")
var cfg *mgmtv1alpha1.AwsS3SchemaConfig
- if cmd.Source.ConnectionOpts.JobRunId != nil && *cmd.Source.ConnectionOpts.JobRunId != "" {
- cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobRunId{JobRunId: *cmd.Source.ConnectionOpts.JobRunId}}
- } else if cmd.Source.ConnectionOpts.JobId != nil && *cmd.Source.ConnectionOpts.JobId != "" {
- cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobId{JobId: *cmd.Source.ConnectionOpts.JobId}}
+ if c.cmd.Source.ConnectionOpts.JobRunId != nil && *c.cmd.Source.ConnectionOpts.JobRunId != "" {
+ cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobRunId{JobRunId: *c.cmd.Source.ConnectionOpts.JobRunId}}
+ } else if c.cmd.Source.ConnectionOpts.JobId != nil && *c.cmd.Source.ConnectionOpts.JobId != "" {
+ cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobId{JobId: *c.cmd.Source.ConnectionOpts.JobId}}
}
s3Config := &mgmtv1alpha1.ConnectionSchemaConfig{
Config: &mgmtv1alpha1.ConnectionSchemaConfig_AwsS3Config{
@@ -501,21 +337,21 @@ func sync(
},
}
- schemaCfg, err := getDestinationSchemaConfig(ctx, connectiondataclient, sqlmanagerclient, sourceConnection, cmd, s3Config, logger)
+ schemaCfg, err := c.getDestinationSchemaConfig(sourceConnection, s3Config)
if err != nil {
- return err
+ return nil, err
}
if len(schemaCfg.Schemas) == 0 {
- logger.Warn("No tables found.")
- return nil
+ c.logger.Warn("No tables found.")
+ return nil, nil
}
schemaConfig = schemaCfg
case gcpCloudStorageConnection:
var cfg *mgmtv1alpha1.GcpCloudStorageSchemaConfig
- if cmd.Source.ConnectionOpts.JobRunId != nil && *cmd.Source.ConnectionOpts.JobRunId != "" {
- cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobRunId{JobRunId: *cmd.Source.ConnectionOpts.JobRunId}}
- } else if cmd.Source.ConnectionOpts.JobId != nil && *cmd.Source.ConnectionOpts.JobId != "" {
- cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobId{JobId: *cmd.Source.ConnectionOpts.JobId}}
+ if c.cmd.Source.ConnectionOpts.JobRunId != nil && *c.cmd.Source.ConnectionOpts.JobRunId != "" {
+ cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobRunId{JobRunId: *c.cmd.Source.ConnectionOpts.JobRunId}}
+ } else if c.cmd.Source.ConnectionOpts.JobId != nil && *c.cmd.Source.ConnectionOpts.JobId != "" {
+ cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobId{JobId: *c.cmd.Source.ConnectionOpts.JobId}}
}
gcpConfig := &mgmtv1alpha1.ConnectionSchemaConfig{
@@ -524,45 +360,45 @@ func sync(
},
}
- schemaCfg, err := getDestinationSchemaConfig(ctx, connectiondataclient, sqlmanagerclient, sourceConnection, cmd, gcpConfig, logger)
+ schemaCfg, err := c.getDestinationSchemaConfig(sourceConnection, gcpConfig)
if err != nil {
- return err
+ return nil, err
}
if len(schemaCfg.Schemas) == 0 {
- logger.Warn("No tables found.")
- return nil
+ c.logger.Warn("No tables found.")
+ return nil, nil
}
schemaConfig = schemaCfg
case mysqlConnection:
- logger.Info("Building schema and table constraints...")
+ c.logger.Info("Building schema and table constraints...")
mysqlCfg := &mgmtv1alpha1.ConnectionSchemaConfig{
Config: &mgmtv1alpha1.ConnectionSchemaConfig_MysqlConfig{
MysqlConfig: &mgmtv1alpha1.MysqlSchemaConfig{},
},
}
- schemaCfg, err := getConnectionSchemaConfig(ctx, logger, connectiondataclient, sourceConnection, cmd, mysqlCfg)
+ schemaCfg, err := c.getConnectionSchemaConfig(sourceConnection, mysqlCfg)
if err != nil {
- return err
+ return nil, err
}
if len(schemaCfg.Schemas) == 0 {
- logger.Warn("No tables found.")
- return nil
+ c.logger.Warn("No tables found.")
+ return nil, nil
}
schemaConfig = schemaCfg
case postgresConnection:
- logger.Info("Building schema and table constraints...")
+ c.logger.Info("Building schema and table constraints...")
postgresConfig := &mgmtv1alpha1.ConnectionSchemaConfig{
Config: &mgmtv1alpha1.ConnectionSchemaConfig_PgConfig{
PgConfig: &mgmtv1alpha1.PostgresSchemaConfig{},
},
}
- schemaCfg, err := getConnectionSchemaConfig(ctx, logger, connectiondataclient, sourceConnection, cmd, postgresConfig)
+ schemaCfg, err := c.getConnectionSchemaConfig(sourceConnection, postgresConfig)
if err != nil {
- return err
+ return nil, err
}
if len(schemaCfg.Schemas) == 0 {
- logger.Warn("No tables found.")
- return nil
+ c.logger.Warn("No tables found.")
+ return nil, nil
}
schemaConfig = schemaCfg
case awsDynamoDBConnection:
@@ -571,13 +407,13 @@ func sync(
DynamodbConfig: &mgmtv1alpha1.DynamoDBSchemaConfig{},
},
}
- schemaCfg, err := getConnectionSchemaConfig(ctx, logger, connectiondataclient, sourceConnection, cmd, dynamoConfig)
+ schemaCfg, err := c.getConnectionSchemaConfig(sourceConnection, dynamoConfig)
if err != nil {
- return err
+ return nil, err
}
if len(schemaCfg.Schemas) == 0 {
- logger.Warn("No tables found.")
- return nil
+ c.logger.Warn("No tables found.")
+ return nil, nil
}
tableMap := map[string]struct{}{}
for _, s := range schemaCfg.Schemas {
@@ -585,58 +421,38 @@ func sync(
}
configs := []*benthosConfigResponse{}
for t := range tableMap {
- benthosConfig := generateDynamoDbBenthosConfig(cmd, serverconfig.GetApiBaseUrl(), token, t)
+ benthosConfig := generateDynamoDbBenthosConfig(c.cmd, t)
configs = append(configs, benthosConfig)
}
-
- return runSync(ctx, outputType, benv, [][]*benthosConfigResponse{configs}, logger)
+ return [][]*benthosConfigResponse{configs}, nil
default:
- return fmt.Errorf("this connection type is not currently supported")
+ return nil, fmt.Errorf("this connection type is not currently supported")
}
- logger.Debug("Building sync configs")
- syncConfigs := buildSyncConfigs(schemaConfig, logger)
+ c.logger.Debug("Building sync configs")
+ syncConfigs := buildSyncConfigs(schemaConfig, c.logger)
if syncConfigs == nil {
- return nil
+ return nil, nil
}
- logger.Info("Running table init statements...")
- err = runDestinationInitStatements(ctx, logger, sqlmanagerclient, cmd, syncConfigs, schemaConfig)
+ c.logger.Info("Running table init statements...")
+ err = c.runDestinationInitStatements(syncConfigs, schemaConfig)
if err != nil {
- return err
+ return nil, err
}
syncConfigCount := len(syncConfigs)
- logger.Info(fmt.Sprintf("Generating %d sync configs...", syncConfigCount))
+ c.logger.Info(fmt.Sprintf("Generating %d sync configs...", syncConfigCount))
configs := []*benthosConfigResponse{}
for _, cfg := range syncConfigs {
- benthosConfig := generateBenthosConfig(cmd, sourceConnectionType, serverconfig.GetApiBaseUrl(), cfg, token)
+ benthosConfig := generateBenthosConfig(c.cmd, sourceConnectionType, cfg)
configs = append(configs, benthosConfig)
}
// order configs in run order by dependency
- groupedConfigs := groupConfigsByDependency(configs, logger)
- if groupedConfigs == nil {
- return nil
- }
+ c.logger.Debug("Ordering configs by dependency")
+ groupedConfigs := groupConfigsByDependency(configs, c.logger)
- return runSync(ctx, outputType, benv, groupedConfigs, logger)
-}
-
-func areSourceAndDestCompatible(connection *mgmtv1alpha1.Connection, destinationDriver DriverType) error {
- switch connection.ConnectionConfig.Config.(type) {
- case *mgmtv1alpha1.ConnectionConfig_PgConfig:
- if destinationDriver != postgresDriver {
- return fmt.Errorf("Connection and destination types are incompatible [postgres, %s]", destinationDriver)
- }
- case *mgmtv1alpha1.ConnectionConfig_MysqlConfig:
- if destinationDriver != mysqlDriver {
- return fmt.Errorf("Connection and destination types are incompatible [mysql, %s]", destinationDriver)
- }
- case *mgmtv1alpha1.ConnectionConfig_AwsS3Config, *mgmtv1alpha1.ConnectionConfig_GcpCloudstorageConfig, *mgmtv1alpha1.ConnectionConfig_DynamodbConfig:
- default:
- return errors.New("unsupported destination driver. only postgres and mysql are currently supported")
- }
- return nil
+ return groupedConfigs, nil
}
var (
@@ -678,9 +494,15 @@ func syncData(ctx context.Context, benv *service.Environment, cfg *benthosConfig
}
streamBuilderMu.Lock()
+ if benv == nil {
+ return fmt.Errorf("benthos env is nil")
+ }
streambldr := benv.NewStreamBuilder()
+ if streambldr == nil {
+ return fmt.Errorf("failed to create StreamBuilder")
+ }
if outputType == output.PlainOutput {
streambldr.SetLogger(logger.With("benthos", "true", "table", cfg.Table, "runType", runType))
}
err = streambldr.SetYAML(string(configbits))
if err != nil {
@@ -788,30 +610,26 @@ func cmdConfigToDestinationConnection(cmd *cmdConfig) *mgmtv1alpha1.Connection {
return &mgmtv1alpha1.Connection{}
}
-func runDestinationInitStatements(
- ctx context.Context,
- logger *slog.Logger,
- sqlmanagerclient sqlmanager.SqlManagerClient,
- cmd *cmdConfig,
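+// runDestinationInitStatements executes schema init and truncate statements
+// against the destination in dependency order before data is synced.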
+func (c *clisync) runDestinationInitStatements(
syncConfigs []*tabledependency.RunConfig,
schemaConfig *schemaConfig,
) error {
dependencyMap := buildDependencyMap(syncConfigs)
- db, err := sqlmanagerclient.NewSqlDbFromUrl(ctx, string(cmd.Destination.Driver), cmd.Destination.ConnectionUrl)
+ db, err := c.sqlmanagerclient.NewSqlDbFromUrl(c.ctx, string(c.cmd.Destination.Driver), c.cmd.Destination.ConnectionUrl)
if err != nil {
return err
}
defer db.Db.Close()
- if cmd.Destination.InitSchema {
+ if c.cmd.Destination.InitSchema {
if len(schemaConfig.InitSchemaStatements) != 0 {
for _, block := range schemaConfig.InitSchemaStatements {
- logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements)))
+ c.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements)))
if len(block.Statements) == 0 {
continue
}
- err = db.Db.BatchExec(ctx, batchSize, block.Statements, &sql_manager.BatchExecOpts{})
+ err = db.Db.BatchExec(c.ctx, batchSize, block.Statements, &sql_manager.BatchExecOpts{})
if err != nil {
- logger.Error(fmt.Sprintf("Error creating tables: %v", err))
+ c.logger.Error(fmt.Sprintf("Error creating tables: %v", err))
return fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err)
}
}
@@ -829,15 +647,15 @@ func runDestinationInitStatements(
orderedInitStatements = append(orderedInitStatements, schemaConfig.InitTableStatementsMap[t.String()])
}
- err = db.Db.BatchExec(ctx, batchSize, orderedInitStatements, &sql_manager.BatchExecOpts{})
+ err = db.Db.BatchExec(c.ctx, batchSize, orderedInitStatements, &sql_manager.BatchExecOpts{})
if err != nil {
- logger.Error(fmt.Sprintf("Error creating tables: %v", err))
+ c.logger.Error(fmt.Sprintf("Error creating tables: %v", err))
return err
}
}
}
- if cmd.Destination.Driver == postgresDriver {
- if cmd.Destination.TruncateCascade {
+ if c.cmd.Destination.Driver == postgresDriver {
+ if c.cmd.Destination.TruncateCascade {
truncateCascadeStmts := []string{}
for _, syncCfg := range syncConfigs {
stmt, ok := schemaConfig.TruncateTableStatementsMap[syncCfg.Table()]
@@ -845,12 +663,12 @@ func runDestinationInitStatements(
truncateCascadeStmts = append(truncateCascadeStmts, stmt)
}
}
- err = db.Db.BatchExec(ctx, batchSize, truncateCascadeStmts, &sql_manager.BatchExecOpts{})
+ err = db.Db.BatchExec(c.ctx, batchSize, truncateCascadeStmts, &sql_manager.BatchExecOpts{})
if err != nil {
- logger.Error(fmt.Sprintf("Error truncate cascade tables: %v", err))
+ c.logger.Error(fmt.Sprintf("Error truncate cascade tables: %v", err))
return err
}
- } else if cmd.Destination.TruncateBeforeInsert {
+ } else if c.cmd.Destination.TruncateBeforeInsert {
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
if err != nil {
return err
@@ -859,13 +677,13 @@ func runDestinationInitStatements(
if err != nil {
return err
}
- err = db.Db.Exec(ctx, orderedTruncateStatement)
+ err = db.Db.Exec(c.ctx, orderedTruncateStatement)
if err != nil {
- logger.Error(fmt.Sprintf("Error truncating tables: %v", err))
+ c.logger.Error(fmt.Sprintf("Error truncating tables: %v", err))
return err
}
}
- } else if cmd.Destination.Driver == mysqlDriver {
+ } else if c.cmd.Destination.Driver == mysqlDriver {
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
if err != nil {
return err
@@ -875,9 +693,9 @@ func runDestinationInitStatements(
orderedTableTruncateStatements = append(orderedTableTruncateStatements, schemaConfig.TruncateTableStatementsMap[t.String()])
}
disableFkChecks := sql_manager.DisableForeignKeyChecks
- err = db.Db.BatchExec(ctx, batchSize, orderedTableTruncateStatements, &sql_manager.BatchExecOpts{Prefix: &disableFkChecks})
+ err = db.Db.BatchExec(c.ctx, batchSize, orderedTableTruncateStatements, &sql_manager.BatchExecOpts{Prefix: &disableFkChecks})
if err != nil {
- logger.Error(fmt.Sprintf("Error truncating tables: %v", err))
+ c.logger.Error(fmt.Sprintf("Error truncating tables: %v", err))
return err
}
}
@@ -906,21 +724,6 @@ func buildSyncConfigs(
return runConfigs
}
-func buildDependencyMap(syncConfigs []*tabledependency.RunConfig) map[string][]string {
- dependencyMap := map[string][]string{}
- for _, cfg := range syncConfigs {
- _, dpOk := dependencyMap[cfg.Table()]
- if !dpOk {
- dependencyMap[cfg.Table()] = []string{}
- }
-
- for _, dep := range cfg.DependsOn() {
- dependencyMap[cfg.Table()] = append(dependencyMap[cfg.Table()], dep.Table)
- }
- }
- return dependencyMap
-}
-
func getTableInitStatementMap(
ctx context.Context,
logger *slog.Logger,
@@ -954,31 +757,10 @@ func getTableInitStatementMap(
return nil, nil
}
-type SqlTable struct {
- Schema string
- Table string
- Columns []string
-}
-
-func getTableColMap(schemas []*mgmtv1alpha1.DatabaseColumn) map[string][]string {
- tableColMap := map[string][]string{}
- for _, record := range schemas {
- table := sql_manager.BuildTable(record.Schema, record.Table)
- _, ok := tableColMap[table]
- if ok {
- tableColMap[table] = append(tableColMap[table], record.Column)
- } else {
- tableColMap[table] = []string{record.Column}
- }
- }
-
- return tableColMap
-}
-
type benthosConfigResponse struct {
Name string
DependsOn []*tabledependency.DependsOn
- Config *cli_neosync_benthos.BenthosConfig
+ Config *neosync_benthos.BenthosConfig
Table string
Columns []string
}
@@ -986,9 +768,7 @@ type benthosConfigResponse struct {
func generateBenthosConfig(
cmd *cmdConfig,
connectionType ConnectionType,
- apiUrl string,
syncConfig *tabledependency.RunConfig,
- authToken *string,
) *benthosConfigResponse {
schema, table := sqlmanager_shared.SplitTableKey(syncConfig.Table())
@@ -998,17 +778,15 @@ func generateBenthosConfig(
jobId = cmd.Source.ConnectionOpts.JobId
}
- bc := &cli_neosync_benthos.BenthosConfig{
- StreamConfig: cli_neosync_benthos.StreamConfig{
- Logger: &cli_neosync_benthos.LoggerConfig{
+ bc := &neosync_benthos.BenthosConfig{
+ StreamConfig: neosync_benthos.StreamConfig{
+ Logger: &neosync_benthos.LoggerConfig{
Level: "ERROR",
AddTimestamp: true,
},
- Input: &cli_neosync_benthos.InputConfig{
- Inputs: cli_neosync_benthos.Inputs{
- NeosyncConnectionData: &cli_neosync_benthos.NeosyncConnectionData{
- ApiKey: authToken,
- ApiUrl: apiUrl,
+ Input: &neosync_benthos.InputConfig{
+ Inputs: neosync_benthos.Inputs{
+ NeosyncConnectionData: &neosync_benthos.NeosyncConnectionData{
ConnectionId: cmd.Source.ConnectionId,
ConnectionType: string(connectionType),
JobId: jobId,
@@ -1018,17 +796,17 @@ func generateBenthosConfig(
},
},
},
- Pipeline: &cli_neosync_benthos.PipelineConfig{},
- Output: &cli_neosync_benthos.OutputConfig{},
+ Pipeline: &neosync_benthos.PipelineConfig{},
+ Output: &neosync_benthos.OutputConfig{},
},
}
if syncConfig.RunType() == tabledependency.RunTypeUpdate {
args := syncConfig.InsertColumns()
args = append(args, syncConfig.PrimaryKeys()...)
- bc.Output = &cli_neosync_benthos.OutputConfig{
- Outputs: cli_neosync_benthos.Outputs{
- PooledSqlUpdate: &cli_neosync_benthos.PooledSqlUpdate{
+ bc.Output = &neosync_benthos.OutputConfig{
+ Outputs: neosync_benthos.Outputs{
+ PooledSqlUpdate: &neosync_benthos.PooledSqlUpdate{
Driver: string(cmd.Destination.Driver),
Dsn: cmd.Destination.ConnectionUrl,
@@ -1038,7 +816,7 @@ func generateBenthosConfig(
WhereColumns: syncConfig.PrimaryKeys(),
ArgsMapping: buildPlainInsertArgs(args),
- Batching: &cli_neosync_benthos.Batching{
+ Batching: &neosync_benthos.Batching{
Period: "5s",
Count: 100,
},
@@ -1046,9 +824,9 @@ func generateBenthosConfig(
},
}
} else {
- bc.Output = &cli_neosync_benthos.OutputConfig{
- Outputs: cli_neosync_benthos.Outputs{
- PooledSqlInsert: &cli_neosync_benthos.PooledSqlInsert{
+ bc.Output = &neosync_benthos.OutputConfig{
+ Outputs: neosync_benthos.Outputs{
+ PooledSqlInsert: &neosync_benthos.PooledSqlInsert{
Driver: string(cmd.Destination.Driver),
Dsn: cmd.Destination.ConnectionUrl,
@@ -1058,7 +836,7 @@ func generateBenthosConfig(
OnConflictDoNothing: cmd.Destination.OnConflict.DoNothing,
ArgsMapping: buildPlainInsertArgs(syncConfig.SelectColumns()),
- Batching: &cli_neosync_benthos.Batching{
+ Batching: &neosync_benthos.Batching{
Period: "5s",
Count: 100,
},
@@ -1075,66 +853,6 @@ func generateBenthosConfig(
Columns: syncConfig.InsertColumns(),
}
}
-func groupConfigsByDependency(configs []*benthosConfigResponse, logger *slog.Logger) [][]*benthosConfigResponse {
- groupedConfigs := [][]*benthosConfigResponse{}
- configMap := map[string]*benthosConfigResponse{}
- queuedMap := map[string][]string{} // map -> table to cols
-
- // get root configs
- rootConfigs := []*benthosConfigResponse{}
- for _, c := range configs {
- if len(c.DependsOn) == 0 {
- rootConfigs = append(rootConfigs, c)
- queuedMap[c.Table] = c.Columns
- } else {
- configMap[c.Name] = c
- }
- }
- if len(rootConfigs) == 0 {
- logger.Info("No root configs found. There must be one config with no dependencies.")
- return nil
- }
- groupedConfigs = append(groupedConfigs, rootConfigs)
-
- prevTableLen := 0
- for len(configMap) > 0 {
- // prevents looping forever
- if prevTableLen == len(configMap) {
- logger.Info("Unable to order configs by dependency. No path found.")
- return nil
- }
- prevTableLen = len(configMap)
- dependentConfigs := []*benthosConfigResponse{}
- for _, c := range configMap {
- if isConfigReady(c, queuedMap) {
- dependentConfigs = append(dependentConfigs, c)
- delete(configMap, c.Name)
- }
- }
- if len(dependentConfigs) > 0 {
- groupedConfigs = append(groupedConfigs, dependentConfigs)
- for _, c := range dependentConfigs {
- queuedMap[c.Table] = append(queuedMap[c.Table], c.Columns...)
- }
- }
- }
-
- return groupedConfigs
-}
-func isConfigReady(config *benthosConfigResponse, queuedMap map[string][]string) bool {
- for _, dep := range config.DependsOn {
- if cols, ok := queuedMap[dep.Table]; ok {
- for _, dc := range dep.Columns {
- if !slices.Contains(cols, dc) {
- return false
- }
- }
- } else {
- return false
- }
- }
- return true
-}
type schemaConfig struct {
Schemas []*mgmtv1alpha1.DatabaseColumn
@@ -1145,12 +863,8 @@ type schemaConfig struct {
InitSchemaStatements []*mgmtv1alpha1.SchemaInitStatements
}
-func getConnectionSchemaConfig(
- ctx context.Context,
- logger *slog.Logger,
- connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient,
+func (c *clisync) getConnectionSchemaConfig(
connection *mgmtv1alpha1.Connection,
- cmd *cmdConfig,
sc *mgmtv1alpha1.ConnectionSchemaConfig,
) (*schemaConfig, error) {
var schemas []*mgmtv1alpha1.DatabaseColumn
@@ -1159,9 +873,9 @@ func getConnectionSchemaConfig(
var initTableStatementsMap map[string]string
var truncateTableStatementsMap map[string]string
var initSchemaStatements []*mgmtv1alpha1.SchemaInitStatements
- errgrp, errctx := errgroup.WithContext(ctx)
+ errgrp, errctx := errgroup.WithContext(c.ctx)
errgrp.Go(func() error {
- schemaResp, err := connectiondataclient.GetConnectionSchema(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{
+ schemaResp, err := c.connectiondataclient.GetConnectionSchema(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{
ConnectionId: connection.Id,
SchemaConfig: sc,
}))
@@ -1173,7 +887,7 @@ func getConnectionSchemaConfig(
})
errgrp.Go(func() error {
- constraintConnectionResp, err := connectiondataclient.GetConnectionTableConstraints(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionTableConstraintsRequest{ConnectionId: cmd.Source.ConnectionId}))
+ constraintConnectionResp, err := c.connectiondataclient.GetConnectionTableConstraints(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionTableConstraintsRequest{ConnectionId: c.cmd.Source.ConnectionId}))
if err != nil {
return err
}
@@ -1183,7 +897,7 @@ func getConnectionSchemaConfig(
})
errgrp.Go(func() error {
- initStatementsResp, err := getTableInitStatementMap(errctx, logger, connectiondataclient, cmd.Source.ConnectionId, cmd.Destination)
+ initStatementsResp, err := getTableInitStatementMap(errctx, c.logger, c.connectiondataclient, c.cmd.Source.ConnectionId, c.cmd.Destination)
if err != nil {
return err
}
@@ -1225,16 +939,11 @@ func getConnectionSchemaConfig(
}, nil
}
-func getDestinationSchemaConfig(
- ctx context.Context,
- connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient,
- sqlmanagerclient sqlmanager.SqlManagerClient,
+func (c *clisync) getDestinationSchemaConfig(
connection *mgmtv1alpha1.Connection,
- cmd *cmdConfig,
sc *mgmtv1alpha1.ConnectionSchemaConfig,
- logger *slog.Logger,
) (*schemaConfig, error) {
- schemaResp, err := connectiondataclient.GetConnectionSchema(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{
+ schemaResp, err := c.connectiondataclient.GetConnectionSchema(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{
ConnectionId: connection.Id,
SchemaConfig: sc,
}))
@@ -1244,7 +953,7 @@ func getDestinationSchemaConfig(
tableColMap := getTableColMap(schemaResp.Msg.GetSchemas())
if len(tableColMap) == 0 {
- logger.Info("No tables found.")
+ c.logger.Info("No tables found.")
return nil, nil
}
@@ -1257,8 +966,8 @@ func getDestinationSchemaConfig(
schemas = append(schemas, s)
}
- logger.Info("Building table constraints...")
- tableConstraints, err := getDestinationTableConstraints(ctx, sqlmanagerclient, cmd.Destination.Driver, cmd.Destination.ConnectionUrl, schemas)
+ c.logger.Info("Building table constraints...")
+ tableConstraints, err := c.getDestinationTableConstraints(schemas)
if err != nil {
return nil, err
}
@@ -1271,8 +980,8 @@ func getDestinationSchemaConfig(
}
truncateTableStatementsMap := map[string]string{}
- if cmd.Destination.Driver == postgresDriver {
- if cmd.Destination.TruncateCascade {
+ if c.cmd.Destination.Driver == postgresDriver {
+ if c.cmd.Destination.TruncateCascade {
for t := range tableColMap {
schema, table := sqlmanager_shared.SplitTableKey(t)
stmt, err := sqlmanager_postgres.BuildPgTruncateCascadeStatement(schema, table)
@@ -1284,7 +993,7 @@ func getDestinationSchemaConfig(
}
// truncate before insert handled in runDestinationInitStatements
} else {
- if cmd.Destination.TruncateBeforeInsert {
+ if c.cmd.Destination.TruncateBeforeInsert {
for t := range tableColMap {
schema, table := sqlmanager_shared.SplitTableKey(t)
stmt, err := sqlmanager_mysql.BuildMysqlTruncateStatement(schema, table)
@@ -1304,10 +1013,10 @@ func getDestinationSchemaConfig(
}, nil
}
-func getDestinationTableConstraints(ctx context.Context, sqlmanagerclient sqlmanager.SqlManagerClient, connectionDriver DriverType, connectionUrl string, schemas []string) (*sql_manager.TableConstraints, error) {
- cctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Second))
+func (c *clisync) getDestinationTableConstraints(schemas []string) (*sql_manager.TableConstraints, error) {
+ cctx, cancel := context.WithDeadline(c.ctx, time.Now().Add(5*time.Second))
defer cancel()
- db, err := sqlmanagerclient.NewSqlDbFromUrl(cctx, string(connectionDriver), connectionUrl)
+ db, err := c.sqlmanagerclient.NewSqlDbFromUrl(cctx, string(c.cmd.Destination.Driver), c.cmd.Destination.ConnectionUrl)
if err != nil {
return nil, err
}
@@ -1320,45 +1029,3 @@ func getDestinationTableConstraints(ctx context.Context, sqlmanagerclient sqlman
return constraints, nil
}
-
-func buildPlainInsertArgs(cols []string) string {
- if len(cols) == 0 {
- return ""
- }
- pieces := make([]string, len(cols))
- for idx := range cols {
- pieces[idx] = fmt.Sprintf("this.%q", cols[idx])
- }
- return fmt.Sprintf("root = [%s]", strings.Join(pieces, ", "))
-}
-
-func maxInt(a, b int) int {
- if a > b {
- return a
- }
- return b
-}
-
-func parseDriverString(str string) (DriverType, bool) {
- p, ok := driverMap[strings.ToLower(str)]
- return p, ok
-}
-
-func getConnectionType(connection *mgmtv1alpha1.Connection) (ConnectionType, error) {
- if connection.ConnectionConfig.GetAwsS3Config() != nil {
- return awsS3Connection, nil
- }
- if connection.GetConnectionConfig().GetGcpCloudstorageConfig() != nil {
- return gcpCloudStorageConnection, nil
- }
- if connection.ConnectionConfig.GetMysqlConfig() != nil {
- return mysqlConnection, nil
- }
- if connection.ConnectionConfig.GetPgConfig() != nil {
- return postgresConnection, nil
- }
- if connection.ConnectionConfig.GetDynamodbConfig() != nil {
- return awsDynamoDBConnection, nil
- }
- return "", errors.New("unsupported connection type")
-}
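The hunks above fold the free helper functions into methods on the clisync receiver, so the context, API clients, logger, and parsed command config no longer need to be threaded through every signature. A minimal sketch of the receiver, inferred from the fields the integration test below populates (the authoritative definition lives in sync.go and may carry more fields):

type clisync struct {
	connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient
	connectionclient     mgmtv1alpha1connect.ConnectionServiceClient
	sqlmanagerclient     sqlmanager.SqlManagerClient
	ctx                  context.Context
	logger               *slog.Logger
	cmd                  *cmdConfig
}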
diff --git a/cli/internal/cmds/neosync/sync/sync_integration_test.go b/cli/internal/cmds/neosync/sync/sync_integration_test.go
new file mode 100644
index 0000000000..1fda80b520
--- /dev/null
+++ b/cli/internal/cmds/neosync/sync/sync_integration_test.go
@@ -0,0 +1,107 @@
+package sync_cmd
+
+import (
+ "context"
+ "testing"
+
+ tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test"
+ "github.com/nucleuscloud/neosync/cli/internal/output"
+ "github.com/nucleuscloud/neosync/internal/testutil"
+ tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres"
+ "github.com/stretchr/testify/require"
+ "golang.org/x/sync/errgroup"
+)
+
+const neosyncDbMigrationsPath = "../../../../../backend/sql/postgresql/schema"
+
+func Test_Sync_Postgres(t *testing.T) {
+ t.Parallel()
+ ok := testutil.ShouldRunIntegrationTest()
+ if !ok {
+ return
+ }
+ ctx := context.Background()
+
+ var neosyncApi *tcneosyncapi.NeosyncApiTestClient
+ var postgres *tcpostgres.PostgresTestSyncContainer
+
+ errgrp := errgroup.Group{}
+ errgrp.Go(func() error {
+ p, err := tcpostgres.NewPostgresTestSyncContainer(ctx, []tcpostgres.Option{}, []tcpostgres.Option{})
+ if err != nil {
+ return err
+ }
+ postgres = p
+ return nil
+ })
+
+ errgrp.Go(func() error {
+ api, err := tcneosyncapi.NewNeosyncApiTestClient(ctx, t, tcneosyncapi.WithMigrationsDirectory(neosyncDbMigrationsPath))
+ if err != nil {
+ return err
+ }
+ neosyncApi = api
+ return nil
+ })
+
+ err := errgrp.Wait()
+ require.NoError(t, err)
+
+ testdataFolder := "../../../../../internal/testutil/testdata/postgres/humanresources"
+ err = postgres.Source.RunSqlFiles(ctx, &testdataFolder, []string{"create-tables.sql"})
+ require.NoError(t, err)
+ err = postgres.Target.RunSqlFiles(ctx, &testdataFolder, []string{"create-schema.sql"})
+ require.NoError(t, err)
+
+ connclient := neosyncApi.UnauthdClients.Connections
+ conndataclient := neosyncApi.UnauthdClients.ConnectionData
+
+ sqlmanagerclient := tcneosyncapi.NewTestSqlManagerClient()
+
+ discardLogger := testutil.GetTestCharmSlogger()
+
+ accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, neosyncApi.UnauthdClients.Users)
+ sourceConn := tcneosyncapi.CreatePostgresConnection(ctx, t, neosyncApi.UnauthdClients.Connections, accountId, "source", postgres.Source.URL)
+ t.Run("sync_postgres", func(t *testing.T) {
+ outputType := output.PlainOutput
+ cmdConfig := &cmdConfig{
+ Source: &sourceConfig{
+ ConnectionId: sourceConn.Id,
+ },
+ Destination: &sqlDestinationConfig{
+ ConnectionUrl: postgres.Target.URL,
+ Driver: postgresDriver,
+ InitSchema: true,
+ TruncateBeforeInsert: true,
+ TruncateCascade: true,
+ },
+ OutputType: &outputType,
+ AccountId: &accountId,
+ }
+ sync := &clisync{
+ connectiondataclient: conndataclient,
+ connectionclient: connclient,
+ sqlmanagerclient: sqlmanagerclient,
+ ctx: ctx,
+ logger: discardLogger,
+ cmd: cmdConfig,
+ }
+ err := sync.configureAndRunSync()
+ require.NoError(t, err)
+ })
+
+ err = postgres.TearDown(ctx)
+ require.NoError(t, err)
+ err = neosyncApi.TearDown(ctx)
+ require.NoError(t, err)
+}
diff --git a/cli/internal/cmds/neosync/sync/ui.go b/cli/internal/cmds/neosync/sync/ui.go
index 7535f9ed47..2bbe28772b 100644
--- a/cli/internal/cmds/neosync/sync/ui.go
+++ b/cli/internal/cmds/neosync/sync/ui.go
@@ -11,7 +11,6 @@ import (
"golang.org/x/sync/errgroup"
- _ "github.com/nucleuscloud/neosync/cli/internal/benthos/inputs"
"github.com/nucleuscloud/neosync/cli/internal/output"
_ "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
_ "github.com/warpstreamlabs/bento/public/components/aws"
@@ -67,7 +66,7 @@ func newModel(ctx context.Context, benv *service.Environment, groupedConfigs [][
}
func (m *model) Init() tea.Cmd {
- return tea.Batch(m.syncConfigs(m.ctx, m.benv, m.groupedConfigs[m.index]), m.spinner.Tick)
+ return tea.Batch(m.syncConfigs(m.ctx, m.groupedConfigs[m.index]), m.spinner.Tick)
}
func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
@@ -97,7 +96,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.index++
return m, tea.Batch(
tea.Println(strings.Join(successStrs, " \n")),
- m.syncConfigs(m.ctx, m.benv, m.groupedConfigs[m.index]),
+ m.syncConfigs(m.ctx, m.groupedConfigs[m.index]),
)
case spinner.TickMsg:
var cmd tea.Cmd
@@ -137,7 +136,7 @@ func (m *model) View() string {
type syncedDataMsg map[string]string
-func (m *model) syncConfigs(ctx context.Context, benv *service.Environment, configs []*benthosConfigResponse) tea.Cmd {
+func (m *model) syncConfigs(ctx context.Context, configs []*benthosConfigResponse) tea.Cmd {
return func() tea.Msg {
messageMap := syncmap.Map{}
errgrp, errctx := errgroup.WithContext(ctx)
@@ -147,7 +146,7 @@ func (m *model) syncConfigs(ctx context.Context, benv *service.Environment, conf
errgrp.Go(func() error {
start := time.Now()
m.logger.Info(fmt.Sprintf("Syncing table %s", cfg.Name))
- err := syncData(errctx, benv, cfg, m.logger, m.outputType)
+ err := syncData(errctx, m.benv, cfg, m.logger, m.outputType)
if err != nil {
fmt.Printf("Error syncing table: %s", err.Error()) //nolint:forbidigo
return err
diff --git a/cli/internal/cmds/neosync/sync/util.go b/cli/internal/cmds/neosync/sync/util.go
new file mode 100644
index 0000000000..ba8da54a29
--- /dev/null
+++ b/cli/internal/cmds/neosync/sync/util.go
@@ -0,0 +1,164 @@
+package sync_cmd
+
+import (
+ "errors"
+ "fmt"
+ "log/slog"
+ "slices"
+ "strings"
+
+ mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
+ sql_manager "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
+ tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency"
+)
+
+func buildPlainInsertArgs(cols []string) string {
+ if len(cols) == 0 {
+ return ""
+ }
+ pieces := make([]string, len(cols))
+ for idx := range cols {
+ pieces[idx] = fmt.Sprintf("this.%q", cols[idx])
+ }
+ return fmt.Sprintf("root = [%s]", strings.Join(pieces, ", "))
+}
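For example, given two columns, buildPlainInsertArgs produces a Bloblang mapping that selects them in order:

// buildPlainInsertArgs([]string{"id", "name"})
// returns: root = [this."id", this."name"]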
+
+func maxInt(a, b int) int {
+ if a > b {
+ return a
+ }
+ return b
+}
+
+func parseDriverString(str string) (DriverType, bool) {
+ p, ok := driverMap[strings.ToLower(str)]
+ return p, ok
+}
+
+func getConnectionType(connection *mgmtv1alpha1.Connection) (ConnectionType, error) {
+ if connection.ConnectionConfig.GetAwsS3Config() != nil {
+ return awsS3Connection, nil
+ }
+ if connection.GetConnectionConfig().GetGcpCloudstorageConfig() != nil {
+ return gcpCloudStorageConnection, nil
+ }
+ if connection.ConnectionConfig.GetMysqlConfig() != nil {
+ return mysqlConnection, nil
+ }
+ if connection.ConnectionConfig.GetPgConfig() != nil {
+ return postgresConnection, nil
+ }
+ if connection.ConnectionConfig.GetDynamodbConfig() != nil {
+ return awsDynamoDBConnection, nil
+ }
+ return "", errors.New("unsupported connection type")
+}
+
+func isConfigReady(config *benthosConfigResponse, queuedMap map[string][]string) bool {
+ for _, dep := range config.DependsOn {
+ if cols, ok := queuedMap[dep.Table]; ok {
+ for _, dc := range dep.Columns {
+ if !slices.Contains(cols, dc) {
+ return false
+ }
+ }
+ } else {
+ return false
+ }
+ }
+ return true
+}
+
+func groupConfigsByDependency(configs []*benthosConfigResponse, logger *slog.Logger) [][]*benthosConfigResponse {
+ groupedConfigs := [][]*benthosConfigResponse{}
+ configMap := map[string]*benthosConfigResponse{}
+ queuedMap := map[string][]string{} // table -> queued columns
+
+ // get root configs
+ rootConfigs := []*benthosConfigResponse{}
+ for _, c := range configs {
+ if len(c.DependsOn) == 0 {
+ rootConfigs = append(rootConfigs, c)
+ queuedMap[c.Table] = c.Columns
+ } else {
+ configMap[c.Name] = c
+ }
+ }
+ if len(rootConfigs) == 0 {
+ logger.Info("No root configs found. There must be one config with no dependencies.")
+ return nil
+ }
+ groupedConfigs = append(groupedConfigs, rootConfigs)
+
+ prevTableLen := 0
+ for len(configMap) > 0 {
+ // prevents looping forever
+ if prevTableLen == len(configMap) {
+ logger.Info("Unable to order configs by dependency. No path found.")
+ return nil
+ }
+ prevTableLen = len(configMap)
+ dependentConfigs := []*benthosConfigResponse{}
+ for _, c := range configMap {
+ if isConfigReady(c, queuedMap) {
+ dependentConfigs = append(dependentConfigs, c)
+ delete(configMap, c.Name)
+ }
+ }
+ if len(dependentConfigs) > 0 {
+ groupedConfigs = append(groupedConfigs, dependentConfigs)
+ for _, c := range dependentConfigs {
+ queuedMap[c.Table] = append(queuedMap[c.Table], c.Columns...)
+ }
+ }
+ }
+
+ return groupedConfigs
+}
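A worked example using two of the humanresources tables seeded later in this PR (assuming the DependsOn entries are tabledependency.DependsOn values whose Table and Columns fields isConfigReady reads):

// regions has no dependencies, so it forms the root group;
// countries is queued once regions' region_id column is available.
configs := []*benthosConfigResponse{
	{Name: "regions", Table: "humanresources.regions", Columns: []string{"region_id", "region_name"}},
	{
		Name:    "countries",
		Table:   "humanresources.countries",
		Columns: []string{"country_id", "region_id"},
		DependsOn: []*tabledependency.DependsOn{
			{Table: "humanresources.regions", Columns: []string{"region_id"}},
		},
	},
}
groups := groupConfigsByDependency(configs, slog.Default())
// groups[0] holds regions, groups[1] holds countries.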
+
+func getTableColMap(schemas []*mgmtv1alpha1.DatabaseColumn) map[string][]string {
+ tableColMap := map[string][]string{}
+ for _, record := range schemas {
+ table := sql_manager.BuildTable(record.Schema, record.Table)
+ _, ok := tableColMap[table]
+ if ok {
+ tableColMap[table] = append(tableColMap[table], record.Column)
+ } else {
+ tableColMap[table] = []string{record.Column}
+ }
+ }
+
+ return tableColMap
+}
+
+func buildDependencyMap(syncConfigs []*tabledependency.RunConfig) map[string][]string {
+ dependencyMap := map[string][]string{}
+ for _, cfg := range syncConfigs {
+ _, dpOk := dependencyMap[cfg.Table()]
+ if !dpOk {
+ dependencyMap[cfg.Table()] = []string{}
+ }
+
+ for _, dep := range cfg.DependsOn() {
+ dependencyMap[cfg.Table()] = append(dependencyMap[cfg.Table()], dep.Table)
+ }
+ }
+ return dependencyMap
+}
+
+func areSourceAndDestCompatible(connection *mgmtv1alpha1.Connection, destinationDriver DriverType) error {
+ switch connection.ConnectionConfig.Config.(type) {
+ case *mgmtv1alpha1.ConnectionConfig_PgConfig:
+ if destinationDriver != postgresDriver {
+ return fmt.Errorf("Connection and destination types are incompatible [postgres, %s]", destinationDriver)
+ }
+ case *mgmtv1alpha1.ConnectionConfig_MysqlConfig:
+ if destinationDriver != mysqlDriver {
+ return fmt.Errorf("Connection and destination types are incompatible [mysql, %s]", destinationDriver)
+ }
+ case *mgmtv1alpha1.ConnectionConfig_AwsS3Config, *mgmtv1alpha1.ConnectionConfig_GcpCloudstorageConfig, *mgmtv1alpha1.ConnectionConfig_DynamodbConfig:
+ default:
+ return errors.New("unsupported destination driver. only postgres and mysql are currently supported")
+ }
+ return nil
+}
diff --git a/internal/testutil/testcontainers/mysql/mysql.go b/internal/testutil/testcontainers/mysql/mysql.go
new file mode 100644
index 0000000000..a02081f3bf
--- /dev/null
+++ b/internal/testutil/testcontainers/mysql/mysql.go
@@ -0,0 +1,181 @@
+package testcontainers_mysql
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "os"
+ "time"
+
+ sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
+ "github.com/testcontainers/testcontainers-go"
+ "github.com/testcontainers/testcontainers-go/modules/mysql"
+ testmysql "github.com/testcontainers/testcontainers-go/modules/mysql"
+ "github.com/testcontainers/testcontainers-go/wait"
+ "golang.org/x/sync/errgroup"
+)
+
+type MysqlTestSyncContainer struct {
+ Source *MysqlTestContainer
+ Target *MysqlTestContainer
+}
+
+func NewMysqlTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*MysqlTestSyncContainer, error) {
+ tc := &MysqlTestSyncContainer{}
+ errgrp := errgroup.Group{}
+ errgrp.Go(func() error {
+ m, err := NewMysqlTestContainer(ctx, sourceOpts...)
+ if err != nil {
+ return err
+ }
+ tc.Source = m
+ return nil
+ })
+
+ errgrp.Go(func() error {
+ m, err := NewMysqlTestContainer(ctx, destOpts...)
+ if err != nil {
+ return err
+ }
+ tc.Target = m
+ return nil
+ })
+
+ err := errgrp.Wait()
+ if err != nil {
+ return nil, err
+ }
+
+ return tc, nil
+}
+
+func (m *MysqlTestSyncContainer) TearDown(ctx context.Context) error {
+ if m.Source != nil {
+ err := m.Source.TearDown(ctx)
+ if err != nil {
+ return err
+ }
+ }
+ if m.Target != nil {
+ err := m.Target.TearDown(ctx)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Holds the MySQL test container and connection pool.
+type MysqlTestContainer struct {
+ DB *sql.DB
+ URL string
+ TestContainer *mysql.MySQLContainer
+ database string
+ password string
+ username string
+}
+
+// Option is a functional option for configuring the Mysql Test Container
+type Option func(*MysqlTestContainer)
+
+// NewMysqlTestContainer initializes a new MySQL Test Container with functional options
+func NewMysqlTestContainer(ctx context.Context, opts ...Option) (*MysqlTestContainer, error) {
+ m := &MysqlTestContainer{
+ database: "testdb",
+ username: "root",
+ password: "pass",
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m.setup(ctx)
+}
+
+// Sets test container database
+func WithDatabase(database string) Option {
+ return func(a *MysqlTestContainer) {
+ a.database = database
+ }
+}
+
+// Sets test container username
+func WithUsername(username string) Option {
+ return func(a *MysqlTestContainer) {
+ a.username = username
+ }
+}
+
+// Sets test container password
+func WithPassword(password string) Option {
+ return func(a *MysqlTestContainer) {
+ a.password = password
+ }
+}
+
+// Creates and starts a MySQL test container and sets up the connection.
+func (m *MysqlTestContainer) setup(ctx context.Context) (*MysqlTestContainer, error) {
+ mysqlContainer, err := mysql.Run(
+ ctx,
+ "mysql:8.0.36",
+ mysql.WithDatabase(m.database),
+ mysql.WithUsername(m.username),
+ mysql.WithPassword(m.password),
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("port: 3306 MySQL Community Server").WithOccurrence(1).WithStartupTimeout(20*time.Second),
+ ),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ connStr, err := mysqlContainer.ConnectionString(ctx, "multiStatements=true&parseTime=true")
+ if err != nil {
+ return nil, err
+ }
+
+ db, err := sql.Open(sqlmanager_shared.MysqlDriver, connStr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &MysqlTestContainer{
+ DB: db,
+ URL: connStr,
+ TestContainer: mysqlContainer,
+ }, nil
+}
+
+// Closes the connection pool and terminates the container.
+func (m *MysqlTestContainer) TearDown(ctx context.Context) error {
+ if m.DB != nil {
+ m.DB.Close()
+ }
+
+ if m.TestContainer != nil {
+ err := m.TestContainer.Terminate(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to terminate MySQL container: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// Executes SQL files within the test container
+func (m *MysqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error {
+ for _, file := range files {
+ filePath := file
+ if folder != nil && *folder != "" {
+ filePath = fmt.Sprintf("./%s/%s", *folder, file)
+ }
+ sqlStr, err := os.ReadFile(filePath)
+ if err != nil {
+ return err
+ }
+ _, err = m.DB.ExecContext(ctx, string(sqlStr))
+ if err != nil {
+ return fmt.Errorf("unable to exec SQL when running MySQL SQL files: %w", err)
+ }
+ }
+ return nil
+}
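A sketch of how a test might use the container helper (the file and database names are hypothetical; the defaults and wait strategy above apply):

func TestWithMysqlContainer(t *testing.T) {
	ctx := context.Background()
	container, err := NewMysqlTestContainer(ctx, WithDatabase("mydb"))
	if err != nil {
		t.Fatal(err)
	}
	// Terminate the container and close the pool when the test ends.
	defer container.TearDown(ctx) //nolint:errcheck

	folder := "testdata"
	if err := container.RunSqlFiles(ctx, &folder, []string{"create-tables.sql"}); err != nil {
		t.Fatal(err)
	}
}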
diff --git a/internal/testutil/testcontainers/postgres/postgres.go b/internal/testutil/testcontainers/postgres/postgres.go
new file mode 100644
index 0000000000..6073c53e94
--- /dev/null
+++ b/internal/testutil/testcontainers/postgres/postgres.go
@@ -0,0 +1,181 @@
+package testcontainers_postgres
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "time"
+
+ "github.com/jackc/pgx/v5/pgxpool"
+ "github.com/testcontainers/testcontainers-go"
+ "github.com/testcontainers/testcontainers-go/modules/postgres"
+ testpg "github.com/testcontainers/testcontainers-go/modules/postgres"
+ "github.com/testcontainers/testcontainers-go/wait"
+ "golang.org/x/sync/errgroup"
+)
+
+type PostgresTestSyncContainer struct {
+ Source *PostgresTestContainer
+ Target *PostgresTestContainer
+}
+
+func NewPostgresTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*PostgresTestSyncContainer, error) {
+ tc := &PostgresTestSyncContainer{}
+ errgrp := errgroup.Group{}
+ errgrp.Go(func() error {
+ p, err := NewPostgresTestContainer(ctx, sourceOpts...)
+ if err != nil {
+ return err
+ }
+ tc.Source = p
+ return nil
+ })
+
+ errgrp.Go(func() error {
+ p, err := NewPostgresTestContainer(ctx, destOpts...)
+ if err != nil {
+ return err
+ }
+ tc.Target = p
+ return nil
+ })
+
+ err := errgrp.Wait()
+ if err != nil {
+ return nil, err
+ }
+
+ return tc, nil
+}
+
+func (p *PostgresTestSyncContainer) TearDown(ctx context.Context) error {
+ if p.Source != nil {
+ err := p.Source.TearDown(ctx)
+ if err != nil {
+ return err
+ }
+ }
+ if p.Target != nil {
+ err := p.Target.TearDown(ctx)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Holds the PostgreSQL test container and connection pool.
+type PostgresTestContainer struct {
+ DB *pgxpool.Pool
+ URL string
+ TestContainer *postgres.PostgresContainer
+ database string
+ username string
+ password string
+}
+
+// Option is a functional option for configuring the Postgres Test Container
+type Option func(*PostgresTestContainer)
+
+// NewPostgresTestContainer initializes a new Postgres Test Container with functional options
+func NewPostgresTestContainer(ctx context.Context, opts ...Option) (*PostgresTestContainer, error) {
+ p := &PostgresTestContainer{
+ database: "testdb",
+ username: "postgres",
+ password: "pass",
+ }
+ for _, opt := range opts {
+ opt(p)
+ }
+ return p.setup(ctx)
+}
+
+// Sets test container database
+func WithDatabase(database string) Option {
+ return func(a *PostgresTestContainer) {
+ a.database = database
+ }
+}
+
+// Sets test container username
+func WithUsername(username string) Option {
+ return func(a *PostgresTestContainer) {
+ a.username = username
+ }
+}
+
+// Sets test container password
+func WithPassword(password string) Option {
+ return func(a *PostgresTestContainer) {
+ a.password = password
+ }
+}
+
+// Creates and starts a PostgreSQL test container and sets up the connection.
+func (p *PostgresTestContainer) setup(ctx context.Context) (*PostgresTestContainer, error) {
+ pgContainer, err := postgres.Run(
+ ctx,
+ "postgres:15",
+ postgres.WithDatabase(p.database),
+ postgres.WithUsername(p.username),
+ postgres.WithPassword(p.password),
+ testcontainers.WithWaitStrategy(
+ wait.ForLog("database system is ready to accept connections").
+ WithOccurrence(2).WithStartupTimeout(20*time.Second),
+ ),
+ )
+ if err != nil {
+ return nil, err
+ }
+
+ connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+ if err != nil {
+ return nil, err
+ }
+
+ pool, err := pgxpool.New(ctx, connStr)
+ if err != nil {
+ return nil, err
+ }
+
+ return &PostgresTestContainer{
+ DB: pool,
+ URL: connStr,
+ TestContainer: pgContainer,
+ }, nil
+}
+
+// Closes the connection pool and terminates the container.
+func (p *PostgresTestContainer) TearDown(ctx context.Context) error {
+ if p.DB != nil {
+ p.DB.Close()
+ }
+
+ if p.TestContainer != nil {
+ err := p.TestContainer.Terminate(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to terminate postgres container: %w", err)
+ }
+ }
+
+ return nil
+}
+
+// Executes SQL files within the test container
+func (p *PostgresTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error {
+ for _, file := range files {
+ filePath := file
+ if folder != nil && *folder != "" {
+ filePath = fmt.Sprintf("./%s/%s", *folder, file)
+ }
+ sqlStr, err := os.ReadFile(filePath)
+ if err != nil {
+ return err
+ }
+ _, err = p.DB.Exec(ctx, string(sqlStr))
+ if err != nil {
+ return fmt.Errorf("unable to exec sql when running postgres sql files: %w", err)
+ }
+ }
+ return nil
+}
diff --git a/internal/testutil/testdata/postgres/alltypes/create-schema.sql b/internal/testutil/testdata/postgres/alltypes/create-schema.sql
new file mode 100644
index 0000000000..5ac3f363a9
--- /dev/null
+++ b/internal/testutil/testdata/postgres/alltypes/create-schema.sql
@@ -0,0 +1 @@
+CREATE SCHEMA IF NOT EXISTS alltypes;
diff --git a/internal/testutil/testdata/postgres/alltypes/create-tables.sql b/internal/testutil/testdata/postgres/alltypes/create-tables.sql
new file mode 100644
index 0000000000..9403da5d35
--- /dev/null
+++ b/internal/testutil/testdata/postgres/alltypes/create-tables.sql
@@ -0,0 +1,348 @@
+CREATE SCHEMA IF NOT EXISTS alltypes;
+CREATE TABLE IF NOT EXISTS alltypes.all_postgres_types (
+ id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY,
+ -- Numeric Types
+ smallint_col SMALLINT,
+ integer_col INTEGER,
+ bigint_col BIGINT,
+ decimal_col DECIMAL(10, 2),
+ numeric_col NUMERIC(10, 2),
+ real_col REAL,
+ double_precision_col DOUBLE PRECISION,
+ serial_col SERIAL,
+ bigserial_col BIGSERIAL,
+
+ -- Monetary Types
+ money_col MONEY,
+
+ -- Character Types
+ char_col CHAR(10),
+ varchar_col VARCHAR(50),
+ text_col TEXT,
+
+ -- Binary Types
+ bytea_col BYTEA,
+
+ -- Date/Time Types
+ timestamp_col TIMESTAMP,
+ timestamptz_col TIMESTAMPTZ,
+ date_col DATE,
+ time_col TIME,
+ timetz_col TIMETZ,
+ interval_col INTERVAL,
+
+ -- Boolean Type
+ boolean_col BOOLEAN,
+
+ -- UUID Type
+ uuid_col UUID,
+
+ -- Network Address Types
+ inet_col INET,
+ cidr_col CIDR,
+ macaddr_col MACADDR,
+
+ -- Bit String Types
+ bit_col BIT(8),
+ varbit_col VARBIT(8),
+
+ -- Geometric Types
+ point_col POINT,
+ line_col LINE,
+ lseg_col LSEG,
+ box_col BOX,
+ path_col PATH,
+ polygon_col POLYGON,
+ circle_col CIRCLE,
+
+ -- JSON Types
+ json_col JSON,
+ jsonb_col JSONB,
+
+ -- Range Types
+ int4range_col INT4RANGE,
+ int8range_col INT8RANGE,
+ numrange_col NUMRANGE,
+ tsrange_col TSRANGE,
+ tstzrange_col TSTZRANGE,
+ daterange_col DATERANGE,
+
+ -- Array Types
+ integer_array_col INTEGER[],
+ text_array_col TEXT[],
+
+ -- XML Type
+ xml_col XML,
+
+ -- TSVECTOR Type (Full-Text Search)
+ tsvector_col TSVECTOR,
+
+ -- OID Type
+ oid_col OID
+);
+
+
+INSERT INTO alltypes.all_postgres_types (
+ id,
+ smallint_col,
+ integer_col,
+ bigint_col,
+ decimal_col,
+ numeric_col,
+ real_col,
+ double_precision_col,
+ serial_col,
+ bigserial_col,
+ money_col,
+ char_col,
+ varchar_col,
+ text_col,
+ bytea_col,
+ timestamp_col,
+ timestamptz_col,
+ date_col,
+ time_col,
+ timetz_col,
+ interval_col,
+ boolean_col,
+ uuid_col,
+ inet_col,
+ cidr_col,
+ macaddr_col,
+ bit_col,
+ varbit_col,
+ point_col,
+ line_col,
+ lseg_col,
+ box_col,
+ path_col,
+ polygon_col,
+ circle_col,
+ json_col,
+ jsonb_col,
+ int4range_col,
+ int8range_col,
+ numrange_col,
+ tsrange_col,
+ tstzrange_col,
+ daterange_col,
+ integer_array_col,
+ text_array_col,
+ xml_col,
+ tsvector_col,
+ oid_col
+) VALUES (
+ DEFAULT,
+ 32767, -- smallint_col
+ 2147483647, -- integer_col
+ 9223372036854775807, -- bigint_col
+ 1234.56, -- decimal_col
+ 99999999.99, -- numeric_col
+ 12345.67, -- real_col
+ 123456789.123456789, -- double_precision_col
+ 1, -- serial_col (explicit value; the sequence default is bypassed)
+ 1, -- bigserial_col (explicit value; the sequence default is bypassed)
+ '$100.00', -- money_col
+ 'A', -- char_col
+ 'DEFAULT', -- varchar_col
+ 'default', -- text_col
+ decode('DEADBEEF', 'hex'), -- bytea_col
+ '2024-01-01 12:34:56', -- timestamp_col
+ '2024-01-01 12:34:56+00', -- timestamptz_col
+ '2024-01-01', -- date_col
+ '12:34:56', -- time_col
+ '12:34:56+00', -- timetz_col
+ '1 day', -- interval_col
+ TRUE, -- boolean_col
+ '123e4567-e89b-12d3-a456-426614174000', -- uuid_col
+ '192.168.1.1', -- inet_col
+ '192.168.1.0/24', -- cidr_col
+ '08:00:2b:01:02:03', -- macaddr_col
+ B'10101010', -- bit_col
+ B'1010', -- varbit_col
+ '(1, 2)', -- point_col
+ '{1, 1, 0}', -- line_col
+ '[(0,0), (1,1)]', -- lseg_col
+ '(0,0),(1,1)', -- box_col
+ '((0,0), (1,1), (2,2))', -- path_col
+ '((0,0), (1,1), (1,0))', -- polygon_col
+ '<(1,1),1>', -- circle_col
+ '{"name": "John", "age": 30}', -- json_col
+ '{"name": "John", "age": 30}', -- jsonb_col
+ '[1,10]', -- int4range_col
+ '[1,1000]', -- int8range_col
+ '[1.0,10.0]', -- numrange_col
+ '[2024-01-01 12:00:00, 2024-01-01 13:00:00]', -- tsrange_col
+ '[2024-01-01 12:00:00+00, 2024-01-01 13:00:00+00]', -- tstzrange_col
+ '[2024-01-01, 2024-01-02]', -- daterange_col
+ '{1, 2, 3}', -- integer_array_col
+ '{"one", "two", "three"}', -- text_array_col
+ 'bar', -- xml_col
+ 'example tsvector', -- tsvector_col
+ 123456 -- oid_col
+);
+
+INSERT INTO alltypes.all_postgres_types (
+ id
+) VALUES (
+ DEFAULT
+);
+
+
+CREATE TABLE IF NOT EXISTS alltypes.time_time (
+ id SERIAL PRIMARY KEY,
+ timestamp_col TIMESTAMP,
+ timestamptz_col TIMESTAMPTZ,
+ date_col DATE
+);
+
+INSERT INTO alltypes.time_time (
+ timestamp_col,
+ timestamptz_col,
+ date_col
+)
+VALUES (
+ '2024-03-18 10:30:00',
+ '2024-03-18 10:30:00+00',
+ '2024-03-18'
+);
+
+INSERT INTO alltypes.time_time (
+ timestamp_col,
+ timestamptz_col,
+ date_col
+)
+VALUES (
+ '0001-01-01 00:00:00 BC',
+ '0001-01-01 00:00:00+00 BC',
+ '0001-01-01 BC'
+);
+
+
+-- CREATE TABLE IF NOT EXISTS alltypes.array_types (
+-- "id" BIGINT NOT NULL PRIMARY KEY,
+-- "int_array" _int4,
+-- "smallint_array" _int2,
+-- "bigint_array" _int8,
+-- "real_array" _float4,
+-- "double_array" _float8,
+-- "text_array" _text,
+-- "varchar_array" _varchar,
+-- "char_array" _bpchar,
+-- "boolean_array" _bool,
+-- "date_array" _date,
+-- "time_array" _time,
+-- "timestamp_array" _timestamp,
+-- "timestamptz_array" _timestamptz,
+-- "interval_array" _interval,
+-- -- "inet_array" _inet, // broken
+-- -- "cidr_array" _cidr,
+-- "point_array" _point,
+-- "line_array" _line,
+-- "lseg_array" _lseg,
+-- -- "box_array" _box, // broken
+-- "path_array" _path,
+-- "polygon_array" _polygon,
+-- "circle_array" _circle,
+-- "uuid_array" _uuid,
+-- "json_array" _json,
+-- "jsonb_array" _jsonb,
+-- "bit_array" _bit,
+-- "varbit_array" _varbit,
+-- "numeric_array" _numeric,
+-- "money_array" _money,
+-- "xml_array" _xml,
+-- "int_double_array" _int4
+-- );
+
+
+-- INSERT INTO alltypes.array_types (
+-- id, int_array, smallint_array, bigint_array, real_array, double_array,
+-- text_array, varchar_array, char_array, boolean_array, date_array,
+-- time_array, timestamp_array, timestamptz_array, interval_array,
+-- -- inet_array, cidr_array,
+-- point_array, line_array, lseg_array,
+-- -- box_array,
+-- path_array, polygon_array, circle_array,
+-- uuid_array,
+-- json_array, jsonb_array,
+-- bit_array, varbit_array, numeric_array,
+-- money_array, xml_array, int_double_array
+-- ) VALUES (
+-- 1,
+-- ARRAY[1, 2, 3],
+-- ARRAY[10::smallint, 20::smallint],
+-- ARRAY[100::bigint, 200::bigint],
+-- ARRAY[1.1::real, 2.2::real],
+-- ARRAY[1.11::double precision, 2.22::double precision],
+-- ARRAY['text1', 'text2'],
+-- ARRAY['varchar1'::varchar, 'varchar2'::varchar],
+-- ARRAY['a'::char, 'b'::char],
+-- ARRAY[true, false],
+-- ARRAY['2023-01-01'::date, '2023-01-02'::date],
+-- ARRAY['12:00:00'::time, '13:00:00'::time],
+-- ARRAY['2023-01-01 12:00:00'::timestamp, '2023-01-02 13:00:00'::timestamp],
+-- ARRAY['2023-01-01 12:00:00+00'::timestamptz, '2023-01-02 13:00:00+00'::timestamptz],
+-- ARRAY['1 day'::interval, '2 hours'::interval],
+-- -- ARRAY['192.168.0.1'::inet, '10.0.0.1'::inet],
+-- -- ARRAY['192.168.0.0/24'::cidr, '10.0.0.0/8'::cidr],
+-- ARRAY['(1,1)'::point, '(2,2)'::point],
+-- ARRAY['{1,2,2}'::line, '{3,4,4}'::line],
+-- ARRAY['(1,1,2,2)'::lseg, '(3,3,4,4)'::lseg],
+-- -- ARRAY['(1,1,2,2)'::box, '(3,3,4,4)'::box],
+-- ARRAY['((1,1),(2,2),(3,3))'::path, '((4,4),(5,5),(6,6))'::path],
+-- ARRAY['((1,1),(2,2),(3,3))'::polygon, '((4,4),(5,5),(6,6))'::polygon],
+-- ARRAY['<(1,1),1>'::circle, '<(2,2),2>'::circle],
+-- ARRAY['a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid, 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid],
+-- ARRAY['{"key": "value1"}'::json, '{"key": "value2"}'::json],
+-- ARRAY['{"key": "value1"}'::jsonb, '{"key": "value2"}'::jsonb],
+-- ARRAY['101'::bit(3), '110'::bit(3)],
+-- ARRAY['10101'::bit varying(5), '01010'::bit varying(5)],
+-- ARRAY[1.23::numeric, 4.56::numeric],
+-- ARRAY[10.00::money, 20.00::money],
+-- ARRAY['value1'::xml, 'value2'::xml],
+-- ARRAY[[1, 2], [3, 4]]
+-- );
+
+
+CREATE TABLE alltypes.json_data (
+ id SERIAL PRIMARY KEY,
+ data JSONB
+);
+
+
+INSERT INTO alltypes.json_data (data) VALUES ('"Hello, world!"');
+INSERT INTO alltypes.json_data (data) VALUES ('42');
+INSERT INTO alltypes.json_data (data) VALUES ('3.14');
+INSERT INTO alltypes.json_data (data) VALUES ('true');
+INSERT INTO alltypes.json_data (data) VALUES ('false');
+INSERT INTO alltypes.json_data (data) VALUES ('null');
+
+INSERT INTO alltypes.json_data (data) VALUES ('{"name": "John", "age": 30}');
+INSERT INTO alltypes.json_data (data) VALUES ('{"coords": {"x": 10, "y": 20}}');
+
+INSERT INTO alltypes.json_data (data) VALUES ('[1, 2, 3, 4]');
+INSERT INTO alltypes.json_data (data) VALUES ('["apple", "banana", "cherry"]');
+
+INSERT INTO alltypes.json_data (data) VALUES ('{"items": ["book", "pen"], "count": 2, "in_stock": true}');
+
+INSERT INTO alltypes.json_data (data) VALUES (
+ '{
+ "user": {
+ "name": "Alice",
+ "age": 28,
+ "contacts": [
+ {"type": "email", "value": "alice@example.com"},
+ {"type": "phone", "value": "123-456-7890"}
+ ]
+ },
+ "orders": [
+ {"id": 1001, "total": 59.99},
+ {"id": 1002, "total": 24.50}
+ ],
+ "preferences": {
+ "notifications": true,
+ "theme": "dark"
+ }
+ }'
+);
diff --git a/internal/testutil/testdata/postgres/alltypes/teardown.sql b/internal/testutil/testdata/postgres/alltypes/teardown.sql
new file mode 100644
index 0000000000..7313dd27c3
--- /dev/null
+++ b/internal/testutil/testdata/postgres/alltypes/teardown.sql
@@ -0,0 +1 @@
+DROP SCHEMA IF EXISTS alltypes CASCADE;
diff --git a/internal/testutil/testdata/postgres/humanresources/create-schema.sql b/internal/testutil/testdata/postgres/humanresources/create-schema.sql
new file mode 100644
index 0000000000..9112cc24aa
--- /dev/null
+++ b/internal/testutil/testdata/postgres/humanresources/create-schema.sql
@@ -0,0 +1 @@
+CREATE SCHEMA IF NOT EXISTS humanresources;
diff --git a/internal/testutil/testdata/postgres/humanresources/create-tables.sql b/internal/testutil/testdata/postgres/humanresources/create-tables.sql
new file mode 100644
index 0000000000..5e9d8b3d58
--- /dev/null
+++ b/internal/testutil/testdata/postgres/humanresources/create-tables.sql
@@ -0,0 +1,225 @@
+CREATE SCHEMA IF NOT EXISTS humanresources;
+SET search_path TO humanresources;
+
+CREATE TABLE regions (
+ region_id SERIAL PRIMARY KEY,
+ region_name CHARACTER VARYING (25)
+);
+
+CREATE TABLE countries (
+ country_id CHARACTER (2) PRIMARY KEY,
+ country_name CHARACTER VARYING (40),
+ region_id INTEGER NOT NULL,
+ FOREIGN KEY (region_id) REFERENCES regions (region_id) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE locations (
+ location_id SERIAL PRIMARY KEY,
+ street_address CHARACTER VARYING (40),
+ postal_code CHARACTER VARYING (12),
+ city CHARACTER VARYING (30) NOT NULL,
+ state_province CHARACTER VARYING (25),
+ country_id CHARACTER (2) NOT NULL,
+ FOREIGN KEY (country_id) REFERENCES countries (country_id) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE departments (
+ department_id SERIAL PRIMARY KEY,
+ department_name CHARACTER VARYING (30) NOT NULL,
+ location_id INTEGER,
+ FOREIGN KEY (location_id) REFERENCES locations (location_id) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE jobs (
+ job_id SERIAL PRIMARY KEY,
+ job_title CHARACTER VARYING (35) NOT NULL,
+ min_salary NUMERIC (8, 2),
+ max_salary NUMERIC (8, 2)
+);
+
+CREATE TABLE employees (
+ employee_id SERIAL PRIMARY KEY,
+ first_name CHARACTER VARYING (20),
+ last_name CHARACTER VARYING (25) NOT NULL,
+ email CHARACTER VARYING (100) NOT NULL,
+ phone_number CHARACTER VARYING (20),
+ hire_date DATE NOT NULL,
+ job_id INTEGER NOT NULL,
+ salary NUMERIC (8, 2) NOT NULL,
+ manager_id INTEGER,
+ department_id INTEGER,
+ FOREIGN KEY (job_id) REFERENCES jobs (job_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ FOREIGN KEY (department_id) REFERENCES departments (department_id) ON UPDATE CASCADE ON DELETE CASCADE,
+ FOREIGN KEY (manager_id) REFERENCES employees (employee_id) ON UPDATE CASCADE ON DELETE CASCADE
+);
+
+CREATE TABLE dependents (
+ dependent_id SERIAL PRIMARY KEY,
+ first_name CHARACTER VARYING (50) NOT NULL,
+ last_name CHARACTER VARYING (50) NOT NULL,
+ relationship CHARACTER VARYING (25) NOT NULL,
+ employee_id INTEGER NOT NULL,
+ FOREIGN KEY (employee_id) REFERENCES employees (employee_id) ON DELETE CASCADE ON UPDATE CASCADE
+);
+
+
+/*Data for the table regions */
+
+INSERT INTO regions(region_id,region_name) VALUES (1,'Europe');
+INSERT INTO regions(region_id,region_name) VALUES (2,'Americas');
+INSERT INTO regions(region_id,region_name) VALUES (3,'Asia');
+INSERT INTO regions(region_id,region_name) VALUES (4,'Middle East and Africa');
+
+
+/*Data for the table countries */
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('AR','Argentina',2);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('AU','Australia',3);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('BE','Belgium',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('BR','Brazil',2);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('CA','Canada',2);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('CH','Switzerland',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('CN','China',3);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('DE','Germany',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('DK','Denmark',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('EG','Egypt',4);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('FR','France',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('HK','HongKong',3);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('IL','Israel',4);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('IN','India',3);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('IT','Italy',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('JP','Japan',3);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('KW','Kuwait',4);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('MX','Mexico',2);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('NG','Nigeria',4);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('NL','Netherlands',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('SG','Singapore',3);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('UK','United Kingdom',1);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('US','United States of America',2);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('ZM','Zambia',4);
+INSERT INTO countries(country_id,country_name,region_id) VALUES ('ZW','Zimbabwe',4);
+
+/*Data for the table locations */
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1400,'2014 Jabberwocky Rd','26192','Southlake','Texas','US');
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1500,'2011 Interiors Blvd','99236','South San Francisco','California','US');
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1700,'2004 Charade Rd','98199','Seattle','Washington','US');
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1800,'147 Spadina Ave','M5V 2L7','Toronto','Ontario','CA');
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (2400,'8204 Arthur St',NULL,'London',NULL,'UK');
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (2500,'Magdalen Centre, The Oxford Science Park','OX9 9ZB','Oxford','Oxford','UK');
+INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (2700,'Schwanthalerstr. 7031','80925','Munich','Bavaria','DE');
+
+
+/*Data for the table jobs */
+
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (1,'Public Accountant',4200.00,9000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (2,'Accounting Manager',8200.00,16000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (3,'Administration Assistant',3000.00,6000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (4,'President',20000.00,40000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (5,'Administration Vice President',15000.00,30000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (6,'Accountant',4200.00,9000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (7,'Finance Manager',8200.00,16000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (8,'Human Resources Representative',4000.00,9000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (9,'Programmer',4000.00,10000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (10,'Marketing Manager',9000.00,15000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (11,'Marketing Representative',4000.00,9000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (12,'Public Relations Representative',4500.00,10500.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (13,'Purchasing Clerk',2500.00,5500.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (14,'Purchasing Manager',8000.00,15000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (15,'Sales Manager',10000.00,20000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (16,'Sales Representative',6000.00,12000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (17,'Shipping Clerk',2500.00,5500.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (18,'Stock Clerk',2000.00,5000.00);
+INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (19,'Stock Manager',5500.00,8500.00);
+
+
+/*Data for the table departments */
+
+INSERT INTO departments(department_id,department_name,location_id) VALUES (1,'Administration',1700);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (2,'Marketing',1800);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (3,'Purchasing',1700);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (4,'Human Resources',2400);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (5,'Shipping',1500);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (6,'IT',1400);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (7,'Public Relations',2700);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (8,'Sales',2500);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (9,'Executive',1700);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (10,'Finance',1700);
+INSERT INTO departments(department_id,department_name,location_id) VALUES (11,'Accounting',1700);
+
+
+
+/*Data for the table employees */
+
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (100,'Steven','King','steven.king@sqltutorial.org','515.123.4567','1987-06-17',4,24000.00,NULL,9);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (101,'Neena','Kochhar','neena.kochhar@sqltutorial.org','515.123.4568','1989-09-21',5,17000.00,100,9);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (102,'Lex','De Haan','lex.de haan@sqltutorial.org','515.123.4569','1993-01-13',5,17000.00,100,9);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (103,'Alexander','Hunold','alexander.hunold@sqltutorial.org','590.423.4567','1990-01-03',9,9000.00,102,6);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (104,'Bruce','Ernst','bruce.ernst@sqltutorial.org','590.423.4568','1991-05-21',9,6000.00,103,6);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (105,'David','Austin','david.austin@sqltutorial.org','590.423.4569','1997-06-25',9,4800.00,103,6);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (106,'Valli','Pataballa','valli.pataballa@sqltutorial.org','590.423.4560','1998-02-05',9,4800.00,103,6);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (107,'Diana','Lorentz','diana.lorentz@sqltutorial.org','590.423.5567','1999-02-07',9,4200.00,103,6);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (108,'Nancy','Greenberg','nancy.greenberg@sqltutorial.org','515.124.4569','1994-08-17',7,12000.00,101,10);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (109,'Daniel','Faviet','daniel.faviet@sqltutorial.org','515.124.4169','1994-08-16',6,9000.00,108,10);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (110,'John','Chen','john.chen@sqltutorial.org','515.124.4269','1997-09-28',6,8200.00,108,10);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (111,'Ismael','Sciarra','ismael.sciarra@sqltutorial.org','515.124.4369','1997-09-30',6,7700.00,108,10);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (112,'Jose Manuel','Urman','jose manuel.urman@sqltutorial.org','515.124.4469','1998-03-07',6,7800.00,108,10);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (113,'Luis','Popp','luis.popp@sqltutorial.org','515.124.4567','1999-12-07',6,6900.00,108,10);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (114,'Den','Raphaely','den.raphaely@sqltutorial.org','515.127.4561','1994-12-07',14,11000.00,100,3);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (115,'Alexander','Khoo','alexander.khoo@sqltutorial.org','515.127.4562','1995-05-18',13,3100.00,114,3);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (116,'Shelli','Baida','shelli.baida@sqltutorial.org','515.127.4563','1997-12-24',13,2900.00,114,3);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (117,'Sigal','Tobias','sigal.tobias@sqltutorial.org','515.127.4564','1997-07-24',13,2800.00,114,3);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (118,'Guy','Himuro','guy.himuro@sqltutorial.org','515.127.4565','1998-11-15',13,2600.00,114,3);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (119,'Karen','Colmenares','karen.colmenares@sqltutorial.org','515.127.4566','1999-08-10',13,2500.00,114,3);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (120,'Matthew','Weiss','matthew.weiss@sqltutorial.org','650.123.1234','1996-07-18',19,8000.00,100,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (121,'Adam','Fripp','adam.fripp@sqltutorial.org','650.123.2234','1997-04-10',19,8200.00,100,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (122,'Payam','Kaufling','payam.kaufling@sqltutorial.org','650.123.3234','1995-05-01',19,7900.00,100,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (123,'Shanta','Vollman','shanta.vollman@sqltutorial.org','650.123.4234','1997-10-10',19,6500.00,100,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (126,'Irene','Mikkilineni','irene.mikkilineni@sqltutorial.org','650.124.1224','1998-09-28',18,2700.00,120,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (145,'John','Russell','john.russell@sqltutorial.org',NULL,'1996-10-01',15,14000.00,100,8);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (146,'Karen','Partners','karen.partners@sqltutorial.org',NULL,'1997-01-05',15,13500.00,100,8);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (176,'Jonathon','Taylor','jonathon.taylor@sqltutorial.org',NULL,'1998-03-24',16,8600.00,100,8);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (177,'Jack','Livingston','jack.livingston@sqltutorial.org',NULL,'1998-04-23',16,8400.00,100,8);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (178,'Kimberely','Grant','kimberely.grant@sqltutorial.org',NULL,'1999-05-24',16,7000.00,100,8);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (179,'Charles','Johnson','charles.johnson@sqltutorial.org',NULL,'2000-01-04',16,6200.00,100,8);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (192,'Sarah','Bell','sarah.bell@sqltutorial.org','650.501.1876','1996-02-04',17,4000.00,123,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (193,'Britney','Everett','britney.everett@sqltutorial.org','650.501.2876','1997-03-03',17,3900.00,123,5);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (200,'Jennifer','Whalen','jennifer.whalen@sqltutorial.org','515.123.4444','1987-09-17',3,4400.00,101,1);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (201,'Michael','Hartstein','michael.hartstein@sqltutorial.org','515.123.5555','1996-02-17',10,13000.00,100,2);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (202,'Pat','Fay','pat.fay@sqltutorial.org','603.123.6666','1997-08-17',11,6000.00,201,2);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (203,'Susan','Mavris','susan.mavris@sqltutorial.org','515.123.7777','1994-06-07',8,6500.00,101,4);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (204,'Hermann','Baer','hermann.baer@sqltutorial.org','515.123.8888','1994-06-07',12,10000.00,101,7);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (205,'Shelley','Higgins','shelley.higgins@sqltutorial.org','515.123.8080','1994-06-07',2,12000.00,101,11);
+INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (206,'William','Gietz','william.gietz@sqltutorial.org','515.123.8181','1994-06-07',1,8300.00,205,11);
+
+
+/*Data for the table dependents */
+
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (1,'Penelope','Gietz','Child',206);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (2,'Nick','Higgins','Child',205);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (3,'Ed','Whalen','Child',200);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (4,'Jennifer','King','Child',100);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (5,'Johnny','Kochhar','Child',101);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (6,'Bette','De Haan','Child',102);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (7,'Grace','Faviet','Child',109);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (8,'Matthew','Chen','Child',110);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (9,'Joe','Sciarra','Child',111);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (10,'Christian','Urman','Child',112);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (11,'Zero','Popp','Child',113);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (12,'Karl','Greenberg','Child',108);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (13,'Uma','Mavris','Child',203);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (14,'Vivien','Hunold','Child',103);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (15,'Cuba','Ernst','Child',104);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (16,'Fred','Austin','Child',105);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (17,'Helen','Pataballa','Child',106);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (18,'Dan','Lorentz','Child',107);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (19,'Bob','Hartstein','Child',201);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (20,'Lucille','Fay','Child',202);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (21,'Kirsten','Baer','Child',204);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (22,'Elvis','Khoo','Child',115);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (23,'Sandra','Baida','Child',116);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (24,'Cameron','Tobias','Child',117);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (25,'Kevin','Himuro','Child',118);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (26,'Rip','Colmenares','Child',119);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (27,'Julia','Raphaely','Child',114);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (28,'Woody','Russell','Child',145);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (29,'Alec','Partners','Child',146);
+INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (30,'Sandra','Taylor','Child',176);
diff --git a/internal/testutil/testdata/postgres/humanresources/teardown.sql b/internal/testutil/testdata/postgres/humanresources/teardown.sql
new file mode 100644
index 0000000000..dc2f4c304e
--- /dev/null
+++ b/internal/testutil/testdata/postgres/humanresources/teardown.sql
@@ -0,0 +1 @@
+DROP SCHEMA IF EXISTS humanresources CASCADE;
diff --git a/internal/testutil/utils.go b/internal/testutil/utils.go
new file mode 100644
index 0000000000..10001a4cf0
--- /dev/null
+++ b/internal/testutil/utils.go
@@ -0,0 +1,31 @@
+package testutil
+
+import (
+ "fmt"
+ "io"
+ "log/slog"
+ "os"
+
+ charmlog "github.com/charmbracelet/log"
+)
+
+func ShouldRunIntegrationTest() bool {
+ evkey := "INTEGRATION_TESTS_ENABLED"
+ shouldRun := os.Getenv(evkey)
+ if shouldRun != "1" {
+ slog.Warn(fmt.Sprintf("skipping integration tests, set %s=1 to enable", evkey))
+ return false
+ }
+ return true
+}
+
+func GetTestSlogger() *slog.Logger {
+ return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{}))
+}
+
+func GetTestCharmSlogger() *slog.Logger {
+ charmlogger := charmlog.NewWithOptions(io.Discard, charmlog.Options{
+ Level: charmlog.DebugLevel,
+ })
+ return slog.New(charmlogger)
+}
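
For context, a minimal sketch of how a test might consume these helpers. The package path and function signatures come from the diff above; the test name and body are hypothetical:

```go
package example_test

import (
	"testing"

	"github.com/nucleuscloud/neosync/internal/testutil"
)

func TestSync_Integration(t *testing.T) {
	// Gate on INTEGRATION_TESTS_ENABLED=1; ShouldRunIntegrationTest logs a
	// warning and returns false when the variable is unset.
	if !testutil.ShouldRunIntegrationTest() {
		return
	}
	// A slog.Logger that discards output keeps test logs quiet while still
	// satisfying components that require a logger.
	logger := testutil.GetTestSlogger()
	_ = logger
	// ... exercise the system under test ...
}
```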
diff --git a/worker/pkg/benthos/config.go b/worker/pkg/benthos/config.go
index 4e6a042c98..99388eeaf4 100644
--- a/worker/pkg/benthos/config.go
+++ b/worker/pkg/benthos/config.go
@@ -17,6 +17,7 @@ type HTTPConfig struct {
}
type StreamConfig struct {
+ Logger *LoggerConfig `json:"logger" yaml:"logger,omitempty"`
Input *InputConfig `json:"input" yaml:"input"`
Buffer *BufferConfig `json:"buffer,omitempty" yaml:"buffer,omitempty"`
Pipeline *PipelineConfig `json:"pipeline" yaml:"pipeline"`
@@ -25,6 +26,11 @@ type StreamConfig struct {
Metrics *Metrics `json:"metrics,omitempty" yaml:"metrics,omitempty"`
}
+type LoggerConfig struct {
+ Level string `json:"level" yaml:"level"`
+ AddTimestamp bool `json:"add_timestamp" yaml:"add_timestamp"`
+}
+
type Metrics struct {
OtelCollector *MetricsOtelCollector `json:"otel_collector,omitempty" yaml:"otel_collector,omitempty"`
Mapping string `json:"mapping,omitempty" yaml:"mapping,omitempty"`
@@ -54,13 +60,25 @@ type InputConfig struct {
}
type Inputs struct {
- SqlSelect *SqlSelect `json:"sql_select,omitempty" yaml:"sql_select,omitempty"`
- PooledSqlRaw *InputPooledSqlRaw `json:"pooled_sql_raw,omitempty" yaml:"pooled_sql_raw,omitempty"`
- Generate *Generate `json:"generate,omitempty" yaml:"generate,omitempty"`
- OpenAiGenerate *OpenAiGenerate `json:"openai_generate,omitempty" yaml:"openai_generate,omitempty"`
- MongoDB *InputMongoDb `json:"mongodb,omitempty" yaml:"mongodb,omitempty"`
- PooledMongoDB *InputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"`
- AwsDynamoDB *InputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"`
+ SqlSelect *SqlSelect `json:"sql_select,omitempty" yaml:"sql_select,omitempty"`
+ PooledSqlRaw *InputPooledSqlRaw `json:"pooled_sql_raw,omitempty" yaml:"pooled_sql_raw,omitempty"`
+ Generate *Generate `json:"generate,omitempty" yaml:"generate,omitempty"`
+ OpenAiGenerate *OpenAiGenerate `json:"openai_generate,omitempty" yaml:"openai_generate,omitempty"`
+ MongoDB *InputMongoDb `json:"mongodb,omitempty" yaml:"mongodb,omitempty"`
+ PooledMongoDB *InputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"`
+ AwsDynamoDB *InputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"`
+ NeosyncConnectionData *NeosyncConnectionData `json:"neosync_connection_data,omitempty" yaml:"neosync_connection_data,omitempty"`
+}
+
+type NeosyncConnectionData struct {
+ ApiKey *string `json:"api_key,omitempty" yaml:"api_key,omitempty"`
+ ApiUrl string `json:"api_url" yaml:"api_url"`
+ ConnectionId string `json:"connection_id" yaml:"connection_id"`
+ ConnectionType string `json:"connection_type" yaml:"connection_type"`
+ JobId *string `json:"job_id,omitempty" yaml:"job_id,omitempty"`
+ JobRunId *string `json:"job_run_id,omitempty" yaml:"job_run_id,omitempty"`
+ Schema string `json:"schema" yaml:"schema"`
+ Table string `json:"table" yaml:"table"`
}
type InputAwsDynamoDB struct {
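
As a rough illustration of the new config surface, here is a sketch that marshals the added LoggerConfig and a neosync_connection_data input to YAML. The struct names and tags come from the diff; the package alias neosync_benthos, the IDs, the URL, and the credential are assumptions:

```go
package main

import (
	"fmt"

	neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" // assumed alias
	"gopkg.in/yaml.v3"
)

func main() {
	apiKey := "example-api-key" // hypothetical credential
	input := neosync_benthos.NeosyncConnectionData{
		ApiKey:         &apiKey,
		ApiUrl:         "http://localhost:8080", // hypothetical URL
		ConnectionId:   "00000000-0000-0000-0000-000000000000",
		ConnectionType: "awsS3", // one type the input is known to handle
		Schema:         "public",
		Table:          "users",
	}
	out, err := yaml.Marshal(map[string]any{
		"logger": neosync_benthos.LoggerConfig{Level: "INFO", AddTimestamp: false},
		"input":  map[string]any{"neosync_connection_data": input},
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(out))
}
```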
diff --git a/worker/pkg/benthos/environment/environment.go b/worker/pkg/benthos/environment/environment.go
index fff3b1c769..a9547af794 100644
--- a/worker/pkg/benthos/environment/environment.go
+++ b/worker/pkg/benthos/environment/environment.go
@@ -5,11 +5,13 @@ import (
"fmt"
"log/slog"
+ "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
neosync_benthos_defaulttransform "github.com/nucleuscloud/neosync/worker/pkg/benthos/default_transform"
neosync_benthos_dynamodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/dynamodb"
neosync_benthos_error "github.com/nucleuscloud/neosync/worker/pkg/benthos/error"
benthos_metrics "github.com/nucleuscloud/neosync/worker/pkg/benthos/metrics"
neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb"
+ neosync_benthos_connectiondata "github.com/nucleuscloud/neosync/worker/pkg/benthos/neosync_connection_data"
openaigenerate "github.com/nucleuscloud/neosync/worker/pkg/benthos/openai_generate"
neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql"
"github.com/warpstreamlabs/bento/public/bloblang"
@@ -24,6 +26,8 @@ type RegisterConfig struct {
mongoConfig *MongoConfig // nil to disable
+ connectionDataConfig *ConnectionDataConfig // nil to disable
+
stopChannel chan<- error
blobEnv *bloblang.Environment
@@ -52,6 +56,11 @@ func WithMongoConfig(mongocfg *MongoConfig) Option {
cfg.mongoConfig = mongocfg
}
}
+func WithConnectionDataConfig(connectionDataCfg *ConnectionDataConfig) Option {
+ return func(cfg *RegisterConfig) {
+ cfg.connectionDataConfig = connectionDataCfg
+ }
+}
func WithBlobEnv(b *bloblang.Environment) Option {
return func(cfg *RegisterConfig) {
cfg.blobEnv = b
@@ -67,6 +76,10 @@ type MongoConfig struct {
Provider neosync_benthos_mongodb.MongoPoolProvider
}
+type ConnectionDataConfig struct {
+ NeosyncConnectionDataApi mgmtv1alpha1connect.ConnectionDataServiceClient
+}
+
func NewEnvironment(logger *slog.Logger, opts ...Option) (*service.Environment, error) {
return NewWithEnvironment(service.NewEnvironment(), logger, opts...)
}
@@ -118,6 +131,13 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O
}
}
+ if config.connectionDataConfig != nil {
+ err := neosync_benthos_connectiondata.RegisterNeosyncConnectionDataInput(env, config.connectionDataConfig.NeosyncConnectionDataApi)
+ if err != nil {
+ return nil, fmt.Errorf("unable to register neosync_connection_data input: %w", err)
+ }
+ }
+
err := openaigenerate.RegisterOpenaiGenerate(env)
if err != nil {
return nil, fmt.Errorf("unable to register openai_generate input to benthos instance: %w", err)
diff --git a/cli/internal/benthos/inputs/neosync-connection-data.go b/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go
similarity index 84%
rename from cli/internal/benthos/inputs/neosync-connection-data.go
rename to worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go
index fe77f1c029..207337c84a 100644
--- a/cli/internal/benthos/inputs/neosync-connection-data.go
+++ b/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go
@@ -1,4 +1,4 @@
-package input
+package neosync_benthos_connectiondata
import (
"context"
@@ -8,12 +8,8 @@ import (
"connectrpc.com/connect"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
- "github.com/nucleuscloud/neosync/cli/internal/auth"
- auth_interceptor "github.com/nucleuscloud/neosync/cli/internal/connect/interceptors/auth"
- "github.com/nucleuscloud/neosync/cli/internal/version"
neosync_dynamodb "github.com/nucleuscloud/neosync/internal/dynamodb"
neosync_metadata "github.com/nucleuscloud/neosync/worker/pkg/benthos/metadata"
- http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client"
"github.com/warpstreamlabs/bento/public/service"
)
@@ -28,21 +24,7 @@ var neosyncConnectionDataConfigSpec = service.NewConfigSpec().
Field(service.NewStringField("job_id").Optional()).
Field(service.NewStringField("job_run_id").Optional())
-func newNeosyncConnectionDataInput(conf *service.ParsedConfig) (service.Input, error) {
- var apiKey *string
- if conf.Contains("api_key") {
- apiKeyStr, err := conf.FieldString("api_key")
- if err != nil {
- return nil, err
- }
- apiKey = &apiKeyStr
- }
-
- apiUrl, err := conf.FieldString("api_url")
- if err != nil {
- return nil, err
- }
-
+func newNeosyncConnectionDataInput(conf *service.ParsedConfig, neosyncConnectApi mgmtv1alpha1connect.ConnectionDataServiceClient) (service.Input, error) {
connectionId, err := conf.FieldString("connection_id")
if err != nil {
return nil, err
@@ -80,8 +62,6 @@ func newNeosyncConnectionDataInput(conf *service.ParsedConfig) (service.Input, e
}
return service.AutoRetryNacks(&neosyncInput{
- apiKey: apiKey,
- apiUrl: apiUrl,
connectionId: connectionId,
connectionType: connectionType,
schema: schema,
@@ -90,18 +70,17 @@ func newNeosyncConnectionDataInput(conf *service.ParsedConfig) (service.Input, e
jobId: jobId,
jobRunId: jobRunId,
},
+ neosyncConnectApi: neosyncConnectApi,
}), nil
}
-func init() {
- err := service.RegisterInput(
+func RegisterNeosyncConnectionDataInput(env *service.Environment, neosyncConnectApi mgmtv1alpha1connect.ConnectionDataServiceClient) error {
+ return env.RegisterInput(
"neosync_connection_data", neosyncConnectionDataConfigSpec,
func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) {
- return newNeosyncConnectionDataInput(conf)
- })
- if err != nil {
- panic(err)
- }
+ return newNeosyncConnectionDataInput(conf, neosyncConnectApi)
+ },
+ )
}
//------------------------------------------------------------------------------
@@ -112,9 +91,6 @@ type connOpts struct {
}
type neosyncInput struct {
- apiKey *string
- apiUrl string
-
connectionId string
connectionType string
connectionOpts *connOpts
@@ -129,12 +105,6 @@ type neosyncInput struct {
}
func (g *neosyncInput) Connect(ctx context.Context) error {
- g.neosyncConnectApi = mgmtv1alpha1connect.NewConnectionDataServiceClient(
- http_client.NewWithHeaders(version.Get().Headers()),
- g.apiUrl,
- connect.WithInterceptors(auth_interceptor.NewInterceptor(g.apiKey != nil, auth.AuthHeader, auth.GetAuthHeaderTokenFn(g.apiKey))),
- )
-
var streamCfg *mgmtv1alpha1.ConnectionStreamConfig
if g.connectionType == "awsS3" {
diff --git a/worker/pkg/workflows/datasync/workflow/integration_test.go b/worker/pkg/workflows/datasync/workflow/integration_test.go
index c24f62ca74..14f98d5105 100644
--- a/worker/pkg/workflows/datasync/workflow/integration_test.go
+++ b/worker/pkg/workflows/datasync/workflow/integration_test.go
@@ -3,29 +3,24 @@ package datasync_workflow
import (
"context"
"database/sql"
- "errors"
"fmt"
"log/slog"
- "net/url"
"os"
"testing"
- "time"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
dyntypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/docker/go-connections/nat"
_ "github.com/go-sql-driver/mysql"
- "github.com/jackc/pgx/v5/pgxpool"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared"
awsmanager "github.com/nucleuscloud/neosync/internal/aws"
+ tcmysql "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/mysql"
+ tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres"
"github.com/stretchr/testify/suite"
"github.com/testcontainers/testcontainers-go"
testmongodb "github.com/testcontainers/testcontainers-go/modules/mongodb"
testmssql "github.com/testcontainers/testcontainers-go/modules/mssql"
- testmysql "github.com/testcontainers/testcontainers-go/modules/mysql"
- "github.com/testcontainers/testcontainers-go/modules/postgres"
- testpg "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/modules/redis"
"github.com/testcontainers/testcontainers-go/wait"
"go.mongodb.org/mongo-driver/mongo"
@@ -33,20 +28,6 @@ import (
"golang.org/x/sync/errgroup"
)
-type postgresTestContainer struct {
- pool *pgxpool.Pool
- url string
-}
-type postgresTest struct {
- pool *pgxpool.Pool
- testcontainer *testpg.PostgresContainer
-
- source *postgresTestContainer
- target *postgresTestContainer
-
- databases []string
-}
-
type mssqlTest struct {
pool *sql.DB
testcontainer *testmssql.MSSQLServerContainer
@@ -59,18 +40,6 @@ type mssqlTestContainer struct {
url string
}
-type mysqlTestContainer struct {
- pool *sql.DB
- container *testmysql.MySQLContainer
- url string
- close func()
-}
-
-type mysqlTest struct {
- source *mysqlTestContainer
- target *mysqlTestContainer
-}
-
type redisTest struct {
url string
testcontainer *redis.RedisContainer
@@ -91,8 +60,8 @@ type IntegrationTestSuite struct {
ctx context.Context
- mysql *mysqlTest
- postgres *postgresTest
+ mysql *tcmysql.MysqlTestSyncContainer
+ postgres *tcpostgres.PostgresTestSyncContainer
mssql *mssqlTest
redis *redisTest
dynamo *dynamodbTest
@@ -213,149 +182,20 @@ func createMssqlTest(ctx context.Context, mssqlcontainer *testmssql.MSSQLServerC
}, nil
}
-func (s *IntegrationTestSuite) SetupPostgres() (*postgresTest, error) {
- pgcontainer, err := testpg.Run(
- s.ctx,
- "postgres:15",
- postgres.WithDatabase("postgres"),
- testcontainers.WithWaitStrategy(
- wait.ForLog("database system is ready to accept connections").
- WithOccurrence(2).WithStartupTimeout(20*time.Second),
- ),
- )
- if err != nil {
- return nil, err
- }
- postgresTest := &postgresTest{
- testcontainer: pgcontainer,
- }
- connstr, err := pgcontainer.ConnectionString(s.ctx, "sslmode=disable")
- if err != nil {
- return nil, err
- }
-
- postgresTest.databases = []string{"datasync_source", "datasync_target"}
- pool, err := pgxpool.New(s.ctx, connstr)
- if err != nil {
- return nil, err
- }
- postgresTest.pool = pool
-
- s.T().Logf("creating databases. %+v \n", postgresTest.databases)
- for _, db := range postgresTest.databases {
- _, err = postgresTest.pool.Exec(s.ctx, fmt.Sprintf("CREATE DATABASE %s;", db))
- if err != nil {
- return nil, err
- }
- }
-
- srcUrl, err := getDbPgUrl(connstr, "datasync_source", "disable")
- if err != nil {
- return nil, err
- }
- postgresTest.source = &postgresTestContainer{
- url: srcUrl,
- }
- sourceConn, err := pgxpool.New(s.ctx, postgresTest.source.url)
- if err != nil {
- return nil, err
- }
- postgresTest.source.pool = sourceConn
-
- targetUrl, err := getDbPgUrl(connstr, "datasync_target", "disable")
- if err != nil {
- return nil, err
- }
- postgresTest.target = &postgresTestContainer{
- url: targetUrl,
- }
- targetConn, err := pgxpool.New(s.ctx, postgresTest.target.url)
+func (s *IntegrationTestSuite) SetupPostgres() (*tcpostgres.PostgresTestSyncContainer, error) {
+ container, err := tcpostgres.NewPostgresTestSyncContainer(s.ctx, []tcpostgres.Option{}, []tcpostgres.Option{})
if err != nil {
return nil, err
}
- postgresTest.target.pool = targetConn
- return postgresTest, nil
+ return container, nil
}
-func (s *IntegrationTestSuite) SetupMysql() (*mysqlTest, error) {
- var source *mysqlTestContainer
- var target *mysqlTestContainer
-
- errgrp := errgroup.Group{}
- errgrp.Go(func() error {
- sourcecontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-source")
- if err != nil {
- return err
- }
- source = sourcecontainer
- return nil
- })
-
- errgrp.Go(func() error {
- targetcontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-target")
- if err != nil {
- return err
- }
- target = targetcontainer
- return nil
- })
-
- err := errgrp.Wait()
+func (s *IntegrationTestSuite) SetupMysql() (*tcmysql.MysqlTestSyncContainer, error) {
+ container, err := tcmysql.NewMysqlTestSyncContainer(s.ctx, []tcmysql.Option{}, []tcmysql.Option{})
if err != nil {
return nil, err
}
-
- return &mysqlTest{
- source: source,
- target: target,
- }, nil
-}
-
-func createMysqlTestContainer(
- ctx context.Context,
- database, username, password string,
-) (*mysqlTestContainer, error) {
- container, err := testmysql.Run(ctx,
- "mysql:8.0.36",
- testmysql.WithDatabase(database),
- testmysql.WithUsername(username),
- testmysql.WithPassword(password),
- testcontainers.WithWaitStrategy(
- wait.ForLog("port: 3306 MySQL Community Server").
- WithOccurrence(1).WithStartupTimeout(20*time.Second),
- ),
- )
- if err != nil {
- return nil, err
- }
- connstr, err := container.ConnectionString(ctx, "multiStatements=true&parseTime=true")
- if err != nil {
- panic(err)
- }
- pool, err := sql.Open(sqlmanager_shared.MysqlDriver, connstr)
- if err != nil {
- panic(err)
- }
- containerPort, err := container.MappedPort(ctx, "3306/tcp")
- if err != nil {
- return nil, err
- }
- containerHost, err := container.Host(ctx)
- if err != nil {
- return nil, err
- }
-
- connUrl := fmt.Sprintf("mysql://%s:%s@%s:%s/%s?multiStatements=true&parseTime=true", username, password, containerHost, containerPort.Port(), database)
- return &mysqlTestContainer{
- pool: pool,
- url: connUrl,
- container: container,
- close: func() {
- if pool != nil {
- pool.Close()
- }
- },
- }, nil
+ return container, nil
}
func (s *IntegrationTestSuite) SetupRedis() (*redisTest, error) {
@@ -445,8 +285,8 @@ func (s *IntegrationTestSuite) SetupDynamoDB() (*dynamodbTest, error) {
func (s *IntegrationTestSuite) SetupSuite() {
s.ctx = context.Background()
- var postgresTest *postgresTest
- var mysqlTest *mysqlTest
+ var postgresTest *tcpostgres.PostgresTestSyncContainer
+ var mysqlTest *tcmysql.MysqlTestSyncContainer
var mssqlTest *mssqlTest
var redisTest *redisTest
var dynamoTest *dynamodbTest
@@ -520,20 +360,6 @@ func (s *IntegrationTestSuite) SetupSuite() {
s.mongodb = mongodbTest
}
-func (s *IntegrationTestSuite) RunPostgresSqlFiles(pool *pgxpool.Pool, testFolder string, files []string) {
- s.T().Logf("running postgres sql file. folder: %s \n", testFolder)
- for _, file := range files {
- sqlStr, err := os.ReadFile(fmt.Sprintf("./testdata/%s/%s", testFolder, file))
- if err != nil {
- panic(err)
- }
- _, err = pool.Exec(s.ctx, string(sqlStr))
- if err != nil {
- panic(fmt.Errorf("unable to exec sql when running postgres sql files: %w", err))
- }
- }
-}
-
func (s *IntegrationTestSuite) RunMysqlSqlFiles(pool *sql.DB, testFolder string, files []string) {
s.T().Logf("running mysql sql file. folder: %s \n", testFolder)
for _, file := range files {
@@ -662,26 +488,9 @@ func (s *IntegrationTestSuite) TearDownSuite() {
s.T().Log("tearing down test suite")
// postgres
if s.postgres != nil {
- for _, db := range s.postgres.databases {
- _, err := s.postgres.pool.Exec(s.ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE);", db))
- if err != nil {
- panic(err)
- }
- }
- if s.postgres.source.pool != nil {
- s.postgres.source.pool.Close()
- }
- if s.postgres.target.pool != nil {
- s.postgres.target.pool.Close()
- }
- if s.postgres.pool != nil {
- s.postgres.pool.Close()
- }
- if s.postgres.testcontainer != nil {
- err := s.postgres.testcontainer.Terminate(s.ctx)
- if err != nil {
- panic(err)
- }
+ err := s.postgres.TearDown(s.ctx)
+ if err != nil {
+ panic(err)
}
}
@@ -706,19 +515,9 @@ func (s *IntegrationTestSuite) TearDownSuite() {
// mysql
if s.mysql != nil {
- s.mysql.source.close()
- s.mysql.target.close()
- if s.mysql.source.container != nil {
- err := s.mysql.source.container.Terminate(s.ctx)
- if err != nil {
- panic(err)
- }
- }
- if s.mysql.target.container != nil {
- err := s.mysql.target.container.Terminate(s.ctx)
- if err != nil {
- panic(err)
- }
+ err := s.mysql.TearDown(s.ctx)
+ if err != nil {
+ panic(err)
}
}
@@ -774,19 +573,3 @@ func TestIntegrationTestSuite(t *testing.T) {
}
suite.Run(t, new(IntegrationTestSuite))
}
-
-func getDbPgUrl(dburl, database, sslmode string) (string, error) {
- u, err := url.Parse(dburl)
- if err != nil {
- var urlErr *url.Error
- if errors.As(err, &urlErr) {
- return "", fmt.Errorf("unable to parse postgres url [%s]: %w", urlErr.Op, urlErr.Err)
- }
- return "", fmt.Errorf("unable to parse postgres url: %w", err)
- }
-
- u.Path = database
- query := u.Query()
- query.Add("sslmode", sslmode)
- return u.String(), nil
-}
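
The suite now leans on the shared testcontainer helpers in place of the hand-rolled setup above. A sketch of standalone usage, where the constructor, RunSqlFiles, and TearDown signatures come from the diff and the folder and file names are hypothetical stand-ins:

```go
package example

import (
	"context"

	tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres"
)

func setupPostgresPair(ctx context.Context) (*tcpostgres.PostgresTestSyncContainer, error) {
	// Spins up paired source and target databases; each Option slice
	// customizes one side of the sync.
	container, err := tcpostgres.NewPostgresTestSyncContainer(ctx, []tcpostgres.Option{}, []tcpostgres.Option{})
	if err != nil {
		return nil, err
	}
	folder := "testdata/postgres/humanresources"                                                     // hypothetical folder
	if err := container.Source.RunSqlFiles(ctx, &folder, []string{"create-tables.sql"}); err != nil { // hypothetical file
		_ = container.TearDown(ctx) // best-effort cleanup on setup failure
		return nil, err
	}
	return container, nil
}
```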
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go
index 572b63a343..d6b6c9c32d 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go
@@ -9,7 +9,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Javascript transformer sync",
- Folder: "javascript-transformers",
+ Folder: "testdata/javascript-transformers",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create.sql"},
JobMappings: getJsTransformerJobmappings(),
@@ -19,7 +19,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
},
{
Name: "Javascript generator sync",
- Folder: "javascript-transformers",
+ Folder: "testdata/javascript-transformers",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create.sql"},
JobMappings: getJsGeneratorJobmappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go
index fbe7bc3ffd..b656fdff9f 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go
@@ -8,7 +8,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "All datatypes passthrough",
- Folder: "mysql/all-types",
+ Folder: "testdata/mysql/all-types",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create-dbs.sql"},
JobMappings: GetDefaultSyncJobMappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go
index ba260078da..d525b15c09 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go
@@ -10,7 +10,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Composite key transformation + truncate",
- Folder: "mysql/composite-keys",
+ Folder: "testdata/mysql/composite-keys",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create.sql", "insert.sql"},
JobMappings: getPkTransformerJobmappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go
index e571b84f74..c5709d0ace 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go
@@ -9,7 +9,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Init Schema",
- Folder: "mysql/init-schema",
+ Folder: "testdata/mysql/init-schema",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create-dbs.sql"},
JobMappings: getJobmappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go
index eb0698e241..ee2913649e 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go
@@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "multiple databases sync + init schema",
- Folder: "mysql/multiple-dbs",
+ Folder: "testdata/mysql/multiple-dbs",
SourceFilePaths: []string{"create-dbs.sql", "create.sql", "insert.sql"},
TargetFilePaths: []string{"create-dbs.sql"},
JobMappings: GetDefaultSyncJobMappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go
index e037a36bde..88f590328a 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go
@@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "All Postgres types",
- Folder: "postgres/all-types",
+ Folder: "testdata/postgres/all-types",
SourceFilePaths: []string{"setup.sql"},
TargetFilePaths: []string{"schema-create.sql", "setup.sql"},
JobMappings: GetDefaultSyncJobMappings(),
@@ -23,7 +23,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
},
{
Name: "All Postgres types + init schema",
- Folder: "postgres/all-types",
+ Folder: "testdata/postgres/all-types",
SourceFilePaths: []string{"setup.sql"},
TargetFilePaths: []string{"schema-create.sql"},
JobMappings: GetDefaultSyncJobMappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go
index eb8d6cb847..dcead5c2ab 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go
@@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Circular Dependency sync + init schema",
- Folder: "postgres/circular-dependencies",
+ Folder: "testdata/postgres/circular-dependencies",
SourceFilePaths: []string{"setup.sql"},
TargetFilePaths: []string{"schema-create.sql"},
JobMappings: GetDefaultSyncJobMappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go
index a3385bb792..b12d75798b 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go
@@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Double reference sync",
- Folder: "postgres/double-reference",
+ Folder: "testdata/postgres/double-reference",
SourceFilePaths: []string{"source-create.sql", "insert.sql"},
TargetFilePaths: []string{"source-create.sql"},
JobMappings: GetDefaultSyncJobMappings(),
@@ -19,7 +19,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
},
{
Name: "Double reference subset",
- Folder: "postgres/double-reference",
+ Folder: "testdata/postgres/double-reference",
SourceFilePaths: []string{"source-create.sql", "insert.sql"},
TargetFilePaths: []string{"source-create.sql"},
SubsetMap: map[string]string{
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go
index 7e4b02989e..2039f46e0d 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go
@@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Complex subsetting",
- Folder: "postgres/subsetting",
+ Folder: "testdata/postgres/subsetting",
SourceFilePaths: []string{"setup.sql"},
TargetFilePaths: []string{"schema-create.sql"},
JobMappings: GetDefaultSyncJobMappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go
index bb71642c20..4d80b78a33 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go
@@ -9,7 +9,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Virtual Foreign Keys sync",
- Folder: "postgres/virtual-foreign-keys",
+ Folder: "testdata/postgres/virtual-foreign-keys",
SourceFilePaths: []string{"source-setup.sql"},
TargetFilePaths: []string{"target-setup.sql"},
JobMappings: GetDefaultSyncJobMappings(),
@@ -26,7 +26,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
},
{
Name: "Virtual Foreign Keys subset",
- Folder: "postgres/virtual-foreign-keys",
+ Folder: "testdata/postgres/virtual-foreign-keys",
SourceFilePaths: []string{"source-setup.sql"},
TargetFilePaths: []string{"target-setup.sql"},
SubsetMap: map[string]string{
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go
index 098067bfd2..628dc0e17a 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go
@@ -10,7 +10,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Circular Dependency primary key transformation",
- Folder: "primary-key-transformer",
+ Folder: "testdata/primary-key-transformer",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create.sql"},
JobMappings: getPkTransformerJobmappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go
index 73cc68a622..c2259efe38 100644
--- a/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go
+++ b/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go
@@ -8,7 +8,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
return []*workflow_testdata.IntegrationTest{
{
Name: "Skip Foreign Key Violations",
- Folder: "skip-fk-violations",
+ Folder: "testdata/skip-fk-violations",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create.sql"},
JobMappings: GetDefaultSyncJobMappings(),
@@ -27,7 +27,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest {
},
{
Name: "Foreign Key Violations Error",
- Folder: "skip-fk-violations",
+ Folder: "testdata/skip-fk-violations",
SourceFilePaths: []string{"create.sql", "insert.sql"},
TargetFilePaths: []string{"create.sql"},
JobMappings: GetDefaultSyncJobMappings(),
diff --git a/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go b/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go
index 77153757e5..5cf2cf2c73 100644
--- a/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go
+++ b/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go
@@ -87,8 +87,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() {
t.Run(tt.Name, func(t *testing.T) {
t.Logf("running integration test: %s \n", tt.Name)
// setup
- s.RunPostgresSqlFiles(s.postgres.source.pool, tt.Folder, tt.SourceFilePaths)
- s.RunPostgresSqlFiles(s.postgres.target.pool, tt.Folder, tt.TargetFilePaths)
+ err := s.postgres.Source.RunSqlFiles(s.ctx, &tt.Folder, tt.SourceFilePaths)
+ require.NoError(t, err)
+ err = s.postgres.Target.RunSqlFiles(s.ctx, &tt.Folder, tt.TargetFilePaths)
+ require.NoError(t, err)
schemas := []*mgmtv1alpha1.PostgresSourceSchemaOption{}
subsetMap := map[string]*mgmtv1alpha1.PostgresSourceSchemaOption{}
@@ -179,7 +181,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() {
Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{
PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{
- Url: s.postgres.source.url,
+ Url: s.postgres.Source.URL,
},
},
},
@@ -196,7 +198,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() {
Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{
PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{
- Url: s.postgres.target.url,
+ Url: s.postgres.Target.URL,
},
},
},
@@ -212,7 +214,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() {
srv := startHTTPServer(t, mux)
env := executeWorkflow(t, srv, s.redis.url, "115aaf2c-776e-4847-8268-d914e3c15968")
require.Truef(t, env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", tt.Name))
- err := env.GetWorkflowError()
+ err = env.GetWorkflowError()
if tt.ExpectError {
require.Error(t, err, "Did not received Temporal Workflow Error", "testName", tt.Name)
return
@@ -220,7 +222,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() {
require.NoError(t, err, "Received Temporal Workflow Error", "testName", tt.Name)
for table, expected := range tt.Expected {
- rows, err := s.postgres.target.pool.Query(s.ctx, fmt.Sprintf("select * from %s;", table))
+ rows, err := s.postgres.Target.DB.Query(s.ctx, fmt.Sprintf("select * from %s;", table))
require.NoError(t, err)
count := 0
for rows.Next() {
@@ -230,8 +232,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() {
}
// tear down
- s.RunPostgresSqlFiles(s.postgres.source.pool, tt.Folder, []string{"teardown.sql"})
- s.RunPostgresSqlFiles(s.postgres.target.pool, tt.Folder, []string{"teardown.sql"})
+ err = s.postgres.Source.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"})
+ require.NoError(t, err)
+ err = s.postgres.Target.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"})
+ require.NoError(t, err)
})
}
})
@@ -447,10 +451,12 @@ func toRunContextKeyString(id *mgmtv1alpha1.RunContextKey) string {
}
func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() {
- testFolder := "postgres/virtual-foreign-keys"
+ testFolder := "testdata/postgres/virtual-foreign-keys"
// setup
- s.RunPostgresSqlFiles(s.postgres.source.pool, testFolder, []string{"source-setup.sql"})
- s.RunPostgresSqlFiles(s.postgres.target.pool, testFolder, []string{"target-setup.sql"})
+ err := s.postgres.Source.RunSqlFiles(s.ctx, &testFolder, []string{"source-setup.sql"})
+ require.NoError(s.T(), err)
+ err = s.postgres.Target.RunSqlFiles(s.ctx, &testFolder, []string{"target-setup.sql"})
+ require.NoError(s.T(), err)
virtualForeignKeys := testdata_virtualforeignkeys.GetVirtualForeignKeys()
jobmappings := testdata_virtualforeignkeys.GetDefaultSyncJobMappings()
@@ -515,7 +521,7 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() {
Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{
PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{
- Url: s.postgres.source.url,
+ Url: s.postgres.Source.URL,
},
},
},
@@ -532,7 +538,7 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() {
Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{
PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{
- Url: s.postgres.target.url,
+ Url: s.postgres.Target.URL,
},
},
},
@@ -549,12 +555,12 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() {
testName := "Virtual Foreign Key primary key transform"
env := executeWorkflow(s.T(), srv, s.redis.url, "fd4d8660-31a0-48b2-9adf-10f11b94898f")
require.Truef(s.T(), env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", testName))
- err := env.GetWorkflowError()
+ err = env.GetWorkflowError()
require.NoError(s.T(), err, "Received Temporal Workflow Error", "testName", testName)
tables := []string{"regions", "countries", "locations", "departments", "dependents", "jobs", "employees"}
for _, t := range tables {
- rows, err := s.postgres.target.pool.Query(s.ctx, fmt.Sprintf("select * from vfk_hr.%s;", t))
+ rows, err := s.postgres.Target.DB.Query(s.ctx, fmt.Sprintf("select * from vfk_hr.%s;", t))
require.NoError(s.T(), err)
count := 0
for rows.Next() {
@@ -564,35 +570,37 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() {
require.NoError(s.T(), err)
}
- rows := s.postgres.source.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';")
+ rows := s.postgres.Source.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';")
var rowCount int
err = rows.Scan(&rowCount)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, rowCount)
- rows = s.postgres.source.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'US';")
+ rows = s.postgres.Source.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'US';")
err = rows.Scan(&rowCount)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, rowCount)
- rows = s.postgres.target.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';")
+ rows = s.postgres.Target.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';")
err = rows.Scan(&rowCount)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, rowCount)
- rows = s.postgres.target.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'SU';")
+ rows = s.postgres.Target.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'SU';")
err = rows.Scan(&rowCount)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, rowCount)
- rows = s.postgres.target.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'SU';")
+ rows = s.postgres.Target.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'SU';")
err = rows.Scan(&rowCount)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, rowCount)
// tear down
- s.RunPostgresSqlFiles(s.postgres.source.pool, testFolder, []string{"teardown.sql"})
- s.RunPostgresSqlFiles(s.postgres.target.pool, testFolder, []string{"teardown.sql"})
+ err = s.postgres.Source.RunSqlFiles(s.ctx, &testFolder, []string{"teardown.sql"})
+ require.NoError(s.T(), err)
+ err = s.postgres.Target.RunSqlFiles(s.ctx, &testFolder, []string{"teardown.sql"})
+ require.NoError(s.T(), err)
}
func getAllMysqlSyncTests() map[string][]*workflow_testdata.IntegrationTest {
@@ -618,8 +626,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() {
t.Run(tt.Name, func(t *testing.T) {
t.Logf("running integration test: %s \n", tt.Name)
// setup
- s.RunMysqlSqlFiles(s.mysql.source.pool, tt.Folder, tt.SourceFilePaths)
- s.RunMysqlSqlFiles(s.mysql.target.pool, tt.Folder, tt.TargetFilePaths)
+ err := s.mysql.Source.RunSqlFiles(s.ctx, &tt.Folder, tt.SourceFilePaths)
+ require.NoError(t, err)
+ err = s.mysql.Target.RunSqlFiles(s.ctx, &tt.Folder, tt.TargetFilePaths)
+ require.NoError(t, err)
schemas := []*mgmtv1alpha1.MysqlSourceSchemaOption{}
subsetMap := map[string]*mgmtv1alpha1.MysqlSourceSchemaOption{}
@@ -710,7 +720,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() {
Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{
MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{
- Url: s.mysql.source.url,
+ Url: s.mysql.Source.URL,
},
},
},
@@ -727,7 +737,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() {
Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{
MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{
- Url: s.mysql.target.url,
+ Url: s.mysql.Target.URL,
},
},
},
@@ -742,7 +752,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() {
srv := startHTTPServer(t, mux)
env := executeWorkflow(t, srv, s.redis.url, "115aaf2c-776e-4847-8268-d914e3c15968")
require.Truef(t, env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", tt.Name))
- err := env.GetWorkflowError()
+ err = env.GetWorkflowError()
if tt.ExpectError {
require.Error(t, err, "Did not received Temporal Workflow Error", "testName", tt.Name)
return
@@ -750,7 +760,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() {
require.NoError(t, err, "Received Temporal Workflow Error", "testName", tt.Name)
for table, expected := range tt.Expected {
- rows, err := s.mysql.target.pool.QueryContext(s.ctx, fmt.Sprintf("select * from %s;", table))
+ rows, err := s.mysql.Target.DB.QueryContext(s.ctx, fmt.Sprintf("select * from %s;", table))
require.NoError(t, err)
count := 0
for rows.Next() {
@@ -760,8 +770,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() {
}
// tear down
- s.RunMysqlSqlFiles(s.mysql.source.pool, tt.Folder, []string{"teardown.sql"})
- s.RunMysqlSqlFiles(s.mysql.target.pool, tt.Folder, []string{"teardown.sql"})
+ err = s.mysql.Source.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"})
+ require.NoError(t, err)
+ err = s.mysql.Target.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"})
+ require.NoError(t, err)
})
}
})
@@ -1387,7 +1399,9 @@ func getAllMongoDBSyncTests() map[string][]*workflow_testdata.IntegrationTest {
func (s *IntegrationTestSuite) Test_Workflow_Generate() {
// setup
testName := "Generate Job"
- s.RunPostgresSqlFiles(s.postgres.target.pool, "generate-job", []string{"setup.sql"})
+ folder := "testdata/generate-job"
+ err := s.postgres.Target.RunSqlFiles(s.ctx, &folder, []string{"setup.sql"})
+ require.NoError(s.T(), err)
connectionId := "226add85-5751-4232-b085-a0ae93afc7ce"
schema := "generate_job"
@@ -1467,7 +1481,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Generate() {
Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{
PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{
ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{
- Url: s.postgres.target.url,
+ Url: s.postgres.Target.URL,
},
},
},
@@ -1483,10 +1497,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Generate() {
srv := startHTTPServer(s.T(), mux)
env := executeWorkflow(s.T(), srv, s.redis.url, "115aaf2c-776e-4847-8268-d914e3c15968")
require.Truef(s.T(), env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", testName))
- err := env.GetWorkflowError()
+ err = env.GetWorkflowError()
require.NoError(s.T(), err, "Received Temporal Workflow Error", "testName", testName)
- rows, err := s.postgres.target.pool.Query(s.ctx, fmt.Sprintf("select * from %s.%s;", schema, table))
+ rows, err := s.postgres.Target.DB.Query(s.ctx, fmt.Sprintf("select * from %s.%s;", schema, table))
require.NoError(s.T(), err)
count := 0
for rows.Next() {
@@ -1495,7 +1509,8 @@ func (s *IntegrationTestSuite) Test_Workflow_Generate() {
require.Equalf(s.T(), 10, count, fmt.Sprintf("Test: %s Table: %s", testName, table))
// tear down
- s.RunPostgresSqlFiles(s.postgres.target.pool, "generate-job", []string{"teardown.sql"})
+ err = s.postgres.Target.RunSqlFiles(s.ctx, &folder, []string{"teardown.sql"})
+ require.NoError(s.T(), err)
}
func executeWorkflow(