diff --git a/backend/pkg/integration-test/integration-test-util.go b/backend/pkg/integration-test/integration-test-util.go new file mode 100644 index 0000000000..3523116728 --- /dev/null +++ b/backend/pkg/integration-test/integration-test-util.go @@ -0,0 +1,73 @@ +package integrationtests_test + +import ( + "context" + "testing" + + "connectrpc.com/connect" + mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" + "github.com/stretchr/testify/require" +) + +func CreatePersonalAccount( + ctx context.Context, + t *testing.T, + userclient mgmtv1alpha1connect.UserAccountServiceClient, +) string { + resp, err := userclient.SetPersonalAccount(ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{})) + RequireNoErrResp(t, resp, err) + return resp.Msg.AccountId +} + +func CreatePostgresConnection( + ctx context.Context, + t *testing.T, + connclient mgmtv1alpha1connect.ConnectionServiceClient, + accountId string, + name string, + pgurl string, +) *mgmtv1alpha1.Connection { + resp, err := connclient.CreateConnection( + ctx, + connect.NewRequest(&mgmtv1alpha1.CreateConnectionRequest{ + AccountId: accountId, + Name: name, + ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ + PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ + Url: pgurl, + }, + }, + }, + }, + }), + ) + RequireNoErrResp(t, resp, err) + return resp.Msg.GetConnection() +} + +func SetUser(ctx context.Context, t *testing.T, client mgmtv1alpha1connect.UserAccountServiceClient) string { + resp, err := client.SetUser(ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{})) + RequireNoErrResp(t, resp, err) + return resp.Msg.GetUserId() +} + +func CreateTeamAccount(ctx context.Context, t *testing.T, client mgmtv1alpha1connect.UserAccountServiceClient, name string) string { + resp, err := client.CreateTeamAccount(ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: name})) + RequireNoErrResp(t, resp, err) + return resp.Msg.AccountId +} + +func RequireNoErrResp[T any](t testing.TB, resp *connect.Response[T], err error) { + t.Helper() + require.NoError(t, err) + require.NotNil(t, resp) +} + +func RequireErrResp[T any](t testing.TB, resp *connect.Response[T], err error) { + t.Helper() + require.Error(t, err) + require.Nil(t, resp) +} diff --git a/backend/pkg/integration-test/integration-test.go b/backend/pkg/integration-test/integration-test.go new file mode 100644 index 0000000000..45d6ff1f57 --- /dev/null +++ b/backend/pkg/integration-test/integration-test.go @@ -0,0 +1,450 @@ +package integrationtests_test + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "connectrpc.com/connect" + db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" + mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" + pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" + "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" + "github.com/nucleuscloud/neosync/backend/internal/apikey" + auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey" + auth_client "github.com/nucleuscloud/neosync/backend/internal/auth/client" + auth_jwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt" + "github.com/nucleuscloud/neosync/backend/internal/authmgmt" + auth_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth" + neosync_gcp "github.com/nucleuscloud/neosync/backend/internal/gcp" + "github.com/nucleuscloud/neosync/backend/internal/neosyncdb" + clientmanager "github.com/nucleuscloud/neosync/backend/internal/temporal/client-manager" + "github.com/nucleuscloud/neosync/backend/internal/utils" + "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect" + mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier" + "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" + "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager" + v1alpha_anonymizationservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/anonymization-service" + v1alpha1_connectiondataservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/connection-data-service" + v1alpha1_connectionservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/connection-service" + v1alpha1_jobservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/job-service" + v1alpha1_transformersservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/transformers-service" + v1alpha1_useraccountservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/user-account-service" + awsmanager "github.com/nucleuscloud/neosync/internal/aws" + "github.com/nucleuscloud/neosync/internal/billing" + presidioapi "github.com/nucleuscloud/neosync/internal/ee/presidio" + neomigrate "github.com/nucleuscloud/neosync/internal/migrate" + promapiv1mock "github.com/nucleuscloud/neosync/internal/mocks/github.com/prometheus/client_golang/api/prometheus/v1" + tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres" + http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client" +) + +var ( + validAuthUser = &authmgmt.User{Name: "foo", Email: "bar", Picture: "baz"} +) + +type UnauthdClients struct { + Users mgmtv1alpha1connect.UserAccountServiceClient + Transformers mgmtv1alpha1connect.TransformersServiceClient + Connections mgmtv1alpha1connect.ConnectionServiceClient + ConnectionData mgmtv1alpha1connect.ConnectionDataServiceClient + Jobs mgmtv1alpha1connect.JobServiceClient + Anonymize mgmtv1alpha1connect.AnonymizationServiceClient +} + +type Mocks struct { + TemporalClientManager *clientmanager.MockTemporalClientManagerClient + Authclient *auth_client.MockInterface + Authmanagerclient *authmgmt.MockInterface + Prometheusclient *promapiv1mock.MockAPI + Billingclient *billing.MockInterface + Presidio Presidiomocks +} + +type Presidiomocks struct { + Analyzer *presidioapi.MockAnalyzeInterface + Anonymizer *presidioapi.MockAnonymizeInterface + Entities *presidioapi.MockEntityInterface +} + +type NeosyncApiTestClient struct { + NeosyncQuerier db_queries.Querier + systemQuerier pg_queries.Querier + + Pgcontainer *tcpostgres.PostgresTestContainer + migrationsDir string + + httpsrv *httptest.Server + + UnauthdClients *UnauthdClients + NeosyncCloudClients *NeosyncCloudClients + AuthdClients *AuthdClients + + Mocks *Mocks +} + +// Option is a functional option for configuring Neosync Api Test Client +type Option func(*NeosyncApiTestClient) + +func NewNeosyncApiTestClient(ctx context.Context, t *testing.T, opts ...Option) (*NeosyncApiTestClient, error) { + neoApi := &NeosyncApiTestClient{ + migrationsDir: "../../../../sql/postgresql/schema", + } + for _, opt := range opts { + opt(neoApi) + } + err := neoApi.Setup(ctx, t) + if err != nil { + return nil, err + } + return neoApi, nil +} + +// Sets neosync database migrations directory path +func WithMigrationsDirectory(directoryPath string) Option { + return func(a *NeosyncApiTestClient) { + a.migrationsDir = directoryPath + } +} + +type NeosyncCloudClients struct { + httpsrv *httptest.Server + basepath string +} + +func (s *NeosyncCloudClients) GetUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient { + return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath) +} + +func (s *NeosyncCloudClients) GetConnectionClient(authUserId string) mgmtv1alpha1connect.ConnectionServiceClient { + return mgmtv1alpha1connect.NewConnectionServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath) +} + +func (s *NeosyncCloudClients) GetAnonymizeClient(authUserId string) mgmtv1alpha1connect.AnonymizationServiceClient { + return mgmtv1alpha1connect.NewAnonymizationServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath) +} + +type AuthdClients struct { + httpsrv *httptest.Server +} + +func (s *AuthdClients) GetUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient { + return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+"/auth") +} + +func (s *AuthdClients) GetConnectionClient(authUserId string) mgmtv1alpha1connect.ConnectionServiceClient { + return mgmtv1alpha1connect.NewConnectionServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+"/auth") +} + +func (s *NeosyncApiTestClient) Setup(ctx context.Context, t *testing.T) error { + pgcontainer, err := tcpostgres.NewPostgresTestContainer(ctx) + if err != nil { + return err + } + s.Pgcontainer = pgcontainer + s.NeosyncQuerier = db_queries.New() + s.systemQuerier = pg_queries.New() + + s.Mocks = &Mocks{ + TemporalClientManager: clientmanager.NewMockTemporalClientManagerClient(t), + Authclient: auth_client.NewMockInterface(t), + Authmanagerclient: authmgmt.NewMockInterface(t), + Prometheusclient: promapiv1mock.NewMockAPI(t), + Billingclient: billing.NewMockInterface(t), + Presidio: Presidiomocks{ + Analyzer: presidioapi.NewMockAnalyzeInterface(t), + Anonymizer: presidioapi.NewMockAnonymizeInterface(t), + Entities: presidioapi.NewMockEntityInterface(t), + }, + } + + maxAllowed := int64(10000) + unauthdUserService := v1alpha1_useraccountservice.New( + &v1alpha1_useraccountservice.Config{IsAuthEnabled: false, IsNeosyncCloud: false, DefaultMaxAllowedRecords: &maxAllowed}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + s.Mocks.TemporalClientManager, + s.Mocks.Authclient, + s.Mocks.Authmanagerclient, + nil, + ) + + authdUserService := v1alpha1_useraccountservice.New( + &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: false}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + s.Mocks.TemporalClientManager, + s.Mocks.Authclient, + s.Mocks.Authmanagerclient, + nil, + ) + + authdConnectionService := v1alpha1_connectionservice.New( + &v1alpha1_connectionservice.Config{}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + authdUserService, + &sqlconnect.SqlOpenConnector{}, + pg_queries.New(), + mysql_queries.New(), + mssql_queries.New(), + mongoconnect.NewConnector(), + awsmanager.New(), + ) + + neoCloudAuthdUserService := v1alpha1_useraccountservice.New( + &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + s.Mocks.TemporalClientManager, + s.Mocks.Authclient, + s.Mocks.Authmanagerclient, + s.Mocks.Billingclient, + ) + neoCloudAuthdAnonymizeService := v1alpha_anonymizationservice.New( + &v1alpha_anonymizationservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true, IsPresidioEnabled: false}, + nil, + neoCloudAuthdUserService, + s.Mocks.Presidio.Analyzer, + s.Mocks.Presidio.Anonymizer, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + ) + + neoCloudConnectionService := v1alpha1_connectionservice.New( + &v1alpha1_connectionservice.Config{}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + neoCloudAuthdUserService, + &sqlconnect.SqlOpenConnector{}, + pg_queries.New(), + mysql_queries.New(), + mssql_queries.New(), + mongoconnect.NewConnector(), + awsmanager.New(), + ) + + unauthdTransformersService := v1alpha1_transformersservice.New( + &v1alpha1_transformersservice.Config{ + IsPresidioEnabled: true, + IsNeosyncCloud: false, + }, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + unauthdUserService, + s.Mocks.Presidio.Entities, + ) + + unauthdConnectionsService := v1alpha1_connectionservice.New( + &v1alpha1_connectionservice.Config{}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + unauthdUserService, + &sqlconnect.SqlOpenConnector{}, + pg_queries.New(), + mysql_queries.New(), + mssql_queries.New(), + mongoconnect.NewConnector(), + awsmanager.New(), + ) + + unauthdJobsService := v1alpha1_jobservice.New( + &v1alpha1_jobservice.Config{}, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + s.Mocks.TemporalClientManager, + unauthdConnectionsService, + unauthdUserService, + sqlmanager.NewSqlManager( + &sync.Map{}, pg_queries.New(), + &sync.Map{}, mysql_queries.New(), + &sync.Map{}, mssql_queries.New(), + &sqlconnect.SqlOpenConnector{}, + ), + ) + + unauthdConnectionDataService := v1alpha1_connectiondataservice.New( + &v1alpha1_connectiondataservice.Config{}, + unauthdUserService, + unauthdConnectionsService, + unauthdJobsService, + awsmanager.New(), + &sqlconnect.SqlOpenConnector{}, + pg_queries.New(), + mysql_queries.New(), + mongoconnect.NewConnector(), + sqlmanager.NewSqlManager( + &sync.Map{}, pg_queries.New(), + &sync.Map{}, mysql_queries.New(), + &sync.Map{}, mssql_queries.New(), + &sqlconnect.SqlOpenConnector{}, + ), + neosync_gcp.NewManager(), + ) + + var presAnalyzeClient presidioapi.AnalyzeInterface + var presAnonClient presidioapi.AnonymizeInterface + + unauthdAnonymizationService := v1alpha_anonymizationservice.New( + &v1alpha_anonymizationservice.Config{IsPresidioEnabled: false}, + nil, + unauthdUserService, + presAnalyzeClient, presAnonClient, + neosyncdb.New(pgcontainer.DB, db_queries.New()), + ) + + rootmux := http.NewServeMux() + + unauthmux := http.NewServeMux() + unauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler( + unauthdUserService, + )) + unauthmux.Handle(mgmtv1alpha1connect.NewTransformersServiceHandler( + unauthdTransformersService, + )) + unauthmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler( + unauthdConnectionsService, + )) + unauthmux.Handle(mgmtv1alpha1connect.NewJobServiceHandler( + unauthdJobsService, + )) + unauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler( + unauthdAnonymizationService, + )) + unauthmux.Handle(mgmtv1alpha1connect.NewConnectionDataServiceHandler( + unauthdConnectionDataService, + )) + rootmux.Handle("/unauth/", http.StripPrefix("/unauth", unauthmux)) + + authinterceptors := connect.WithInterceptors( + auth_interceptor.NewInterceptor(func(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { + // will need to further fill this out as the tests grow + authuserid, err := utils.GetBearerTokenFromHeader(header, "Authorization") + if err != nil { + return nil, err + } + if apikey.IsValidV1WorkerKey(authuserid) { + return auth_apikey.SetTokenData(ctx, &auth_apikey.TokenContextData{ + RawToken: authuserid, + ApiKey: nil, + ApiKeyType: apikey.WorkerApiKey, + }), nil + } + return auth_jwt.SetTokenData(ctx, &auth_jwt.TokenContextData{ + AuthUserId: authuserid, + Claims: &auth_jwt.CustomClaims{Email: &validAuthUser.Email}, + }), nil + }), + ) + + authmux := http.NewServeMux() + authmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler( + authdUserService, + authinterceptors, + )) + authmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler( + authdConnectionService, + authinterceptors, + )) + rootmux.Handle("/auth/", http.StripPrefix("/auth", authmux)) + + ncauthmux := http.NewServeMux() + ncauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler( + neoCloudAuthdUserService, + authinterceptors, + )) + ncauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler( + neoCloudAuthdAnonymizeService, + authinterceptors, + )) + ncauthmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler( + neoCloudConnectionService, + authinterceptors, + )) + rootmux.Handle("/ncauth/", http.StripPrefix("/ncauth", ncauthmux)) + + s.httpsrv = startHTTPServer(t, rootmux) + + s.UnauthdClients = &UnauthdClients{ + Users: mgmtv1alpha1connect.NewUserAccountServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), + Transformers: mgmtv1alpha1connect.NewTransformersServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), + Connections: mgmtv1alpha1connect.NewConnectionServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), + ConnectionData: mgmtv1alpha1connect.NewConnectionDataServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), + Jobs: mgmtv1alpha1connect.NewJobServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), + Anonymize: mgmtv1alpha1connect.NewAnonymizationServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), + } + + s.AuthdClients = &AuthdClients{ + httpsrv: s.httpsrv, + } + s.NeosyncCloudClients = &NeosyncCloudClients{ + httpsrv: s.httpsrv, + basepath: "/ncauth", + } + + err = s.InitializeTest(ctx) + if err != nil { + return err + } + return nil +} + +func (s *NeosyncApiTestClient) InitializeTest(ctx context.Context) error { + discardLogger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + err := neomigrate.Up(ctx, s.Pgcontainer.URL, s.migrationsDir, discardLogger) + if err != nil { + return err + } + return nil +} + +func (s *NeosyncApiTestClient) CleanupTest(ctx context.Context) error { + // Dropping here because 1) more efficient and 2) we have a bad down migration + // _jobs-connection-id-null.down that breaks due to having a null connection_id column. + // we should do something about that at some point. Running this single drop is easier though + _, err := s.Pgcontainer.DB.Exec(ctx, "DROP SCHEMA IF EXISTS neosync_api CASCADE") + if err != nil { + return err + } + _, err = s.Pgcontainer.DB.Exec(ctx, "DROP TABLE IF EXISTS public.schema_migrations") + if err != nil { + return err + } + return nil +} + +func (s *NeosyncApiTestClient) TearDown(ctx context.Context) error { + if s.Pgcontainer != nil { + _, err := s.Pgcontainer.DB.Exec(ctx, "DROP SCHEMA IF EXISTS neosync_api CASCADE") + if err != nil { + return err + } + _, err = s.Pgcontainer.DB.Exec(ctx, "DROP TABLE IF EXISTS public.schema_migrations") + if err != nil { + return err + } + if s.Pgcontainer.DB != nil { + s.Pgcontainer.DB.Close() + } + if s.Pgcontainer.TestContainer != nil { + err := s.Pgcontainer.TestContainer.Terminate(ctx) + if err != nil { + return err + } + } + } + return nil +} + +func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server { + tb.Helper() + srv := httptest.NewUnstartedServer(h) + srv.EnableHTTP2 = true + srv.Start() + tb.Cleanup(srv.Close) + return srv +} + +func NewTestSqlManagerClient() *sqlmanager.SqlManager { + return sqlmanager.NewSqlManager( + &sync.Map{}, pg_queries.New(), + &sync.Map{}, mysql_queries.New(), + &sync.Map{}, mssql_queries.New(), + &sqlconnect.SqlOpenConnector{}, + ) +} diff --git a/backend/pkg/sqlmanager/mysql/integration_test.go b/backend/pkg/sqlmanager/mysql/integration_test.go index 216596ea40..867ef3b375 100644 --- a/backend/pkg/sqlmanager/mysql/integration_test.go +++ b/backend/pkg/sqlmanager/mysql/integration_test.go @@ -2,201 +2,74 @@ package sqlmanager_mysql import ( "context" - "database/sql" "fmt" "log/slog" "os" "testing" - "time" _ "github.com/go-sql-driver/mysql" - "golang.org/x/sync/errgroup" mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" - sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" + tcmysql "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/mysql" "github.com/stretchr/testify/suite" - "github.com/testcontainers/testcontainers-go" - testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" - "github.com/testcontainers/testcontainers-go/wait" ) type IntegrationTestSuite struct { suite.Suite - initSql string - setupSql string - teardownSql string - ctx context.Context - source *mysqlTestContainer - target *mysqlTestContainer -} - -type mysqlTestContainer struct { - pool *sql.DB - querier mysql_queries.Querier - container *testmysql.MySQLContainer - url string - close func() -} - -type mysqlTest struct { - source *mysqlTestContainer - target *mysqlTestContainer -} - -func (s *IntegrationTestSuite) SetupMysql() (*mysqlTest, error) { - var source *mysqlTestContainer - var target *mysqlTestContainer - - errgrp := errgroup.Group{} - errgrp.Go(func() error { - sourcecontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-source") - if err != nil { - return err - } - source = sourcecontainer - return nil - }) - - errgrp.Go(func() error { - targetcontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-target") - if err != nil { - return err - } - target = targetcontainer - return nil - }) - - err := errgrp.Wait() - if err != nil { - return nil, err - } - - return &mysqlTest{ - source: source, - target: target, - }, nil -} - -func createMysqlTestContainer( - ctx context.Context, - database, username, password string, -) (*mysqlTestContainer, error) { - container, err := testmysql.Run(ctx, - "mysql:8.0.36", - testmysql.WithDatabase(database), - testmysql.WithUsername(username), - testmysql.WithPassword(password), - testcontainers.WithWaitStrategy( - wait.ForLog("port: 3306 MySQL Community Server"). - WithOccurrence(1).WithStartupTimeout(20*time.Second), - ), - ) - if err != nil { - return nil, err - } - connstr, err := container.ConnectionString(ctx, "multiStatements=true&parseTime=true") - if err != nil { - panic(err) - } - pool, err := sql.Open(sqlmanager_shared.MysqlDriver, connstr) - if err != nil { - panic(err) - } - containerPort, err := container.MappedPort(ctx, "3306/tcp") - if err != nil { - return nil, err - } - containerHost, err := container.Host(ctx) - if err != nil { - return nil, err - } - - connUrl := fmt.Sprintf("mysql://%s:%s@%s:%s/%s?multiStatements=true&parseTime=true", username, password, containerHost, containerPort.Port(), database) - return &mysqlTestContainer{ - pool: pool, - querier: mysql_queries.New(), - url: connUrl, - container: container, - close: func() { - if pool != nil { - pool.Close() - } - }, - }, nil + querier mysql_queries.Querier + containers *tcmysql.MysqlTestSyncContainer } func (s *IntegrationTestSuite) SetupSuite() { s.ctx = context.Background() - m, err := s.SetupMysql() - if err != nil { - panic(err) - } - s.source = m.source - s.target = m.target - - initSql, err := os.ReadFile("./testdata/init.sql") - if err != nil { - panic(err) - } - s.initSql = string(initSql) - - setupSql, err := os.ReadFile("./testdata/setup.sql") - if err != nil { - panic(err) - } - s.setupSql = string(setupSql) - - teardownSql, err := os.ReadFile("./testdata/teardown.sql") + container, err := tcmysql.NewMysqlTestSyncContainer(s.ctx, []tcmysql.Option{}, []tcmysql.Option{}) if err != nil { panic(err) } - s.teardownSql = string(teardownSql) + s.containers = container + s.querier = mysql_queries.New() } // Runs before each test func (s *IntegrationTestSuite) SetupTest() { - _, err := s.target.pool.ExecContext(s.ctx, s.initSql) + err := s.containers.Target.RunSqlFiles(s.ctx, nil, []string{"testdata/init.sql"}) if err != nil { panic(err) } - _, err = s.source.pool.ExecContext(s.ctx, s.setupSql) + err = s.containers.Source.RunSqlFiles(s.ctx, nil, []string{"testdata/setup.sql"}) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownTest() { - _, err := s.target.pool.ExecContext(s.ctx, s.teardownSql) + err := s.containers.Source.RunSqlFiles(s.ctx, nil, []string{"testdata/teardown.sql"}) if err != nil { panic(err) } - _, err = s.source.pool.ExecContext(s.ctx, s.teardownSql) + err = s.containers.Target.RunSqlFiles(s.ctx, nil, []string{"testdata/teardown.sql"}) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownSuite() { - if s.source.pool != nil { - s.source.close() - } - if s.target.pool != nil { - s.target.close() - } - if s.source != nil { - err := s.source.container.Terminate(s.ctx) - if err != nil { - panic(err) + if s.containers != nil { + if s.containers.Source != nil { + err := s.containers.Source.TearDown(s.ctx) + if err != nil { + panic(err) + } } - } - if s.target != nil { - err := s.target.container.Terminate(s.ctx) - if err != nil { - panic(err) + if s.containers.Target != nil { + err := s.containers.Target.TearDown(s.ctx) + if err != nil { + panic(err) + } } } } diff --git a/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go b/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go index 754f1e0504..341d30db19 100644 --- a/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go @@ -11,7 +11,7 @@ import ( ) func (s *IntegrationTestSuite) Test_GetTableConstraintsBySchema() { - manager := MysqlManager{querier: s.source.querier, pool: s.source.pool} + manager := MysqlManager{querier: s.querier, pool: s.containers.Source.DB} expected := &sqlmanager_shared.TableConstraints{ ForeignKeyConstraints: map[string][]*sqlmanager_shared.ForeignConstraint{ @@ -43,7 +43,7 @@ func (s *IntegrationTestSuite) Test_GetTableConstraintsBySchema() { } func (s *IntegrationTestSuite) Test_GetSchemaColumnMap() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetSchemaColumnMap(context.Background()) require.NoError(s.T(), err) @@ -58,7 +58,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaColumnMap() { } func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{schema}) require.NoError(s.T(), err) @@ -114,7 +114,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap() { } func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_BasicCircular() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{schema}) require.NoError(s.T(), err) @@ -155,7 +155,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_BasicCircular() } func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_Composite() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{schema}) @@ -175,7 +175,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_Composite() { } func (s *IntegrationTestSuite) Test_GetPrimaryKeyConstraintsMap() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{schema}) @@ -193,7 +193,7 @@ func (s *IntegrationTestSuite) Test_GetPrimaryKeyConstraintsMap() { } func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{schema}) @@ -207,7 +207,7 @@ func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap() { } func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap_Composite() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{schema}) @@ -223,7 +223,7 @@ func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap_Composite() { } func (s *IntegrationTestSuite) Test_GetRolePermissionsMap() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetRolePermissionsMap(context.Background()) @@ -242,18 +242,18 @@ func (s *IntegrationTestSuite) Test_GetRolePermissionsMap() { } func (s *IntegrationTestSuite) Test_GetCreateTableStatement() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetCreateTableStatement(context.Background(), schema, "users") require.NoError(s.T(), err) require.NotEmpty(s.T(), actual) - _, err = s.target.pool.ExecContext(context.Background(), actual) + _, err = s.containers.Target.DB.ExecContext(context.Background(), actual) require.NoError(s.T(), err) } func (s *IntegrationTestSuite) Test_GetTableInitStatements() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" actual, err := manager.GetTableInitStatements( @@ -268,25 +268,25 @@ func (s *IntegrationTestSuite) Test_GetTableInitStatements() { require.NoError(s.T(), err) require.NotEmpty(s.T(), actual) for _, stmt := range actual { - _, err = s.target.pool.ExecContext(context.Background(), stmt.CreateTableStatement) + _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt.CreateTableStatement) require.NoError(s.T(), err) } for _, stmt := range actual { for _, index := range stmt.IndexStatements { - _, err = s.target.pool.ExecContext(context.Background(), index) + _, err = s.containers.Target.DB.ExecContext(context.Background(), index) require.NoError(s.T(), err) } } for _, stmt := range actual { for _, alter := range stmt.AlterTableStatements { - _, err = s.target.pool.ExecContext(context.Background(), alter.Statement) + _, err = s.containers.Target.DB.ExecContext(context.Background(), alter.Statement) require.NoError(s.T(), err) } } } func (s *IntegrationTestSuite) Test_Exec() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" err := manager.Exec(context.Background(), fmt.Sprintf("SELECT 1 FROM %s.%s", schema, "users")) @@ -294,7 +294,7 @@ func (s *IntegrationTestSuite) Test_Exec() { } func (s *IntegrationTestSuite) Test_BatchExec() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" stmt := fmt.Sprintf("SELECT 1 FROM %s.%s;", schema, "users") @@ -303,7 +303,7 @@ func (s *IntegrationTestSuite) Test_BatchExec() { } func (s *IntegrationTestSuite) Test_BatchExec_With_Prefix() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" stmt := fmt.Sprintf("SELECT 1 FROM %s.%s;", schema, "users") @@ -314,7 +314,7 @@ func (s *IntegrationTestSuite) Test_BatchExec_With_Prefix() { } func (s *IntegrationTestSuite) Test_GetSchemaInitStatements() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" statements, err := manager.GetSchemaInitStatements(context.Background(), []*sqlmanager_shared.SchemaTable{ @@ -337,29 +337,29 @@ func (s *IntegrationTestSuite) Test_GetSchemaInitStatements() { lableStmtMap[st.Label] = append(lableStmtMap[st.Label], st.Statements...) } for _, stmt := range lableStmtMap["create table"] { - _, err = s.target.pool.ExecContext(context.Background(), stmt) + _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt) require.NoError(s.T(), err) } for _, stmt := range lableStmtMap["table triggers"] { - _, err = s.target.pool.ExecContext(context.Background(), stmt) + _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt) require.NoError(s.T(), err) } for _, stmt := range lableStmtMap["table index"] { - _, err = s.target.pool.ExecContext(context.Background(), stmt) + _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt) require.NoError(s.T(), err) } for _, stmt := range lableStmtMap["non-fk alter table"] { - _, err = s.target.pool.ExecContext(context.Background(), stmt) + _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt) require.NoError(s.T(), err) } for _, stmt := range lableStmtMap["fk alter table"] { - _, err = s.target.pool.ExecContext(context.Background(), stmt) + _, err = s.containers.Target.DB.ExecContext(context.Background(), stmt) require.NoError(s.T(), err) } } func (s *IntegrationTestSuite) Test_GetSchemaInitStatements_customtable() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" statements, err := manager.GetSchemaInitStatements(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: schema, Table: "custom_table"}}) @@ -368,7 +368,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaInitStatements_customtable() { } func (s *IntegrationTestSuite) Test_GetSchemaTableTriggers_customtable() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" triggers, err := manager.GetSchemaTableTriggers(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: schema, Table: "employee_log"}}) @@ -377,7 +377,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaTableTriggers_customtable() { } func (s *IntegrationTestSuite) Test_GetSchemaTableDataTypes_customtable() { - manager := NewManager(s.source.querier, s.source.pool, func() {}) + manager := NewManager(s.querier, s.containers.Source.DB, func() {}) schema := "sqlmanagermysql3" resp, err := manager.GetSchemaTableDataTypes(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: schema, Table: "custom_table"}}) diff --git a/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go b/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go index 5cc081f01d..3f8dda54e7 100644 --- a/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go @@ -7,7 +7,6 @@ import ( "os" "sync" "testing" - "time" "github.com/google/uuid" mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" @@ -16,15 +15,13 @@ import ( "github.com/stretchr/testify/suite" _ "github.com/go-sql-driver/mysql" - "github.com/testcontainers/testcontainers-go" - testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" - "github.com/testcontainers/testcontainers-go/wait" + tcmysql "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/mysql" ) type MysqlIntegrationTestSuite struct { suite.Suite - mysqlcontainer *testmysql.MySQLContainer + mysqlcontainer *tcmysql.MysqlTestContainer ctx context.Context @@ -34,54 +31,20 @@ type MysqlIntegrationTestSuite struct { conncfg *mgmtv1alpha1.MysqlConnectionConfig // mgmt connection mgmtconn *mgmtv1alpha1.Connection - - // dsn format of connection url - dsn string } func (s *MysqlIntegrationTestSuite) SetupSuite() { s.ctx = context.Background() - dbname := "testdb" - user := "root" - pass := "test-password" - - container, err := testmysql.Run(s.ctx, - "mysql:8.0.36", - testmysql.WithDatabase(dbname), - testmysql.WithUsername(user), - testmysql.WithPassword(pass), - testcontainers.WithWaitStrategy( - wait.ForLog("port: 3306 MySQL Community Server"). - WithOccurrence(1).WithStartupTimeout(20*time.Second), - ), - ) + container, err := tcmysql.NewMysqlTestContainer(s.ctx) if err != nil { panic(err) } - - connstr, err := container.ConnectionString(s.ctx, "multiStatements=true&&parseTime=true") - if err != nil { - panic(err) - } - s.dsn = connstr - s.mysqlcontainer = container - containerPort, err := container.MappedPort(s.ctx, "3306/tcp") - if err != nil { - panic(err) - } - containerHost, err := container.Host(s.ctx) - if err != nil { - panic(err) - } - - connUrl := fmt.Sprintf("mysql://%s:%s@%s:%s/%s?multiStatements=true&parseTime=true", user, pass, containerHost, containerPort.Port(), dbname) - s.conncfg = &mgmtv1alpha1.MysqlConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: connUrl, + Url: container.URL, }, } s.mgmtconn = &mgmtv1alpha1.Connection{ @@ -106,7 +69,7 @@ func (s *MysqlIntegrationTestSuite) TearDownTest() { func (s *MysqlIntegrationTestSuite) TearDownSuite() { if s.mysqlcontainer != nil { - err := s.mysqlcontainer.Terminate(s.ctx) + err := s.mysqlcontainer.TearDown(s.ctx) if err != nil { panic(err) } @@ -145,7 +108,7 @@ func (s *MysqlIntegrationTestSuite) Test_NewSqlDb() { func (s *MysqlIntegrationTestSuite) Test_NewSqlDbFromUrl() { t := s.T() - conn, err := s.sqlmanager.NewSqlDbFromUrl(s.ctx, "mysql", s.dsn) // NewSqlDbFromUrl requires dsn format + conn, err := s.sqlmanager.NewSqlDbFromUrl(s.ctx, "mysql", s.mysqlcontainer.URL) requireNoConnErr(t, conn, err) requireValidDatabase(t, s.ctx, conn, "mysql", "SELECT 1") diff --git a/backend/pkg/sqlmanager/postgres/integration_test.go b/backend/pkg/sqlmanager/postgres/integration_test.go index 2f082cb4c7..e09f0c8215 100644 --- a/backend/pkg/sqlmanager/postgres/integration_test.go +++ b/backend/pkg/sqlmanager/postgres/integration_test.go @@ -7,13 +7,15 @@ import ( "log/slog" "os" "testing" - "time" pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" + sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" + tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres" "github.com/stretchr/testify/suite" - "github.com/testcontainers/testcontainers-go" - testpg "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" +) + +var ( + testdataFolder = "testdata" ) type IntegrationTestSuite struct { @@ -22,12 +24,9 @@ type IntegrationTestSuite struct { db *sql.DB querier pg_queries.Querier - setupSql string - teardownSql string - ctx context.Context - pgcontainer *testpg.PostgresContainer + pgcontainer *tcpostgres.PostgresTestContainer schema string } @@ -40,36 +39,13 @@ func (s *IntegrationTestSuite) SetupSuite() { s.ctx = context.Background() s.schema = "sqlmanagerpostgres@special" - pgcontainer, err := testpg.Run( - s.ctx, - "postgres:15", - testcontainers.WithWaitStrategy( - wait.ForLog("database system is ready to accept connections"). - WithOccurrence(2).WithStartupTimeout(5*time.Second), - ), - ) + pgcontainer, err := tcpostgres.NewPostgresTestContainer(s.ctx) if err != nil { panic(err) } s.pgcontainer = pgcontainer - connstr, err := pgcontainer.ConnectionString(s.ctx) - if err != nil { - panic(err) - } - - setupSql, err := os.ReadFile("./testdata/setup.sql") - if err != nil { - panic(err) - } - s.setupSql = string(setupSql) - - teardownSql, err := os.ReadFile("./testdata/teardown.sql") - if err != nil { - panic(err) - } - s.teardownSql = string(teardownSql) - db, err := sql.Open("pgx", connstr) + db, err := sql.Open(sqlmanager_shared.PostgresDriver, s.pgcontainer.URL) if err != nil { panic(err) } @@ -79,14 +55,14 @@ func (s *IntegrationTestSuite) SetupSuite() { // Runs before each test func (s *IntegrationTestSuite) SetupTest() { - _, err := s.db.ExecContext(s.ctx, s.setupSql) + err := s.pgcontainer.RunSqlFiles(s.ctx, &testdataFolder, []string{"setup.sql"}) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownTest() { - _, err := s.db.ExecContext(s.ctx, s.teardownSql) + err := s.pgcontainer.RunSqlFiles(s.ctx, &testdataFolder, []string{"teardown.sql"}) if err != nil { panic(err) } @@ -96,11 +72,9 @@ func (s *IntegrationTestSuite) TearDownSuite() { if s.db != nil { s.db.Close() } - if s.pgcontainer != nil { - err := s.pgcontainer.Terminate(s.ctx) - if err != nil { - panic(err) - } + err := s.pgcontainer.TearDown(s.ctx) + if err != nil { + panic(err) } } diff --git a/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go b/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go index 04354c32f9..2aee422bbd 100644 --- a/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go @@ -107,7 +107,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewPooledSqlDb() { conn, err := s.sqlmanager.NewPooledSqlDb(s.ctx, slog.Default(), s.mgmtconn) requireNoConnErr(t, conn, err) - requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1") + requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1") conn.Db.Close() } @@ -118,7 +118,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewSqlDb() { conn, err := s.sqlmanager.NewSqlDb(s.ctx, slog.Default(), s.mgmtconn, &connTimeout) requireNoConnErr(t, conn, err) - requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1") + requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1") conn.Db.Close() } @@ -127,7 +127,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewSqlDbFromUrl() { conn, err := s.sqlmanager.NewSqlDbFromUrl(s.ctx, "postgres", s.pgcfg.GetUrl()) requireNoConnErr(t, conn, err) - requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1") + requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1") conn.Db.Close() } @@ -137,7 +137,7 @@ func (s *PostgresIntegrationTestSuite) Test_NewSqlDbFromConnectionConfig() { conn, err := s.sqlmanager.NewSqlDbFromConnectionConfig(s.ctx, slog.Default(), s.mgmtconn.GetConnectionConfig(), &connTimeout) requireNoConnErr(t, conn, err) - requireValidDatabase(t, s.ctx, conn, "postgres", "SELECT 1") + requireValidDatabase(t, s.ctx, conn, "pgx", "SELECT 1") conn.Db.Close() } diff --git a/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go index 37255bcccd..cd841063d5 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/anonymization-service_integration_test.go @@ -17,10 +17,10 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeMany() { t := s.T() t.Run("OSS-fail", func(t *testing.T) { - userclient := s.unauthdClients.users + userclient := s.UnauthdClients.Users s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) - resp, err := s.unauthdClients.anonymize.AnonymizeMany( + resp, err := s.UnauthdClients.Anonymize.AnonymizeMany( s.ctx, connect.NewRequest(&mgmtv1alpha1.AnonymizeManyRequest{ AccountId: accountId, @@ -35,8 +35,8 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeMany() { }) t.Run("cloud-personal-fail", func(t *testing.T) { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) - anonclient := s.neosyncCloudClients.getAnonymizeClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) + anonclient := s.NeosyncCloudClients.GetAnonymizeClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) resp, err := anonclient.AnonymizeMany( @@ -87,12 +87,12 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeMany() { }`, } - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) - anonclient := s.neosyncCloudClients.getAnonymizeClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) + anonclient := s.NeosyncCloudClients.GetAnonymizeClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createBilledTeamAccount(s.ctx, userclient, "team1", "foo") - s.mocks.billingclient.On("GetSubscriptions", "foo").Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ + s.Mocks.Billingclient.On("GetSubscriptions", "foo").Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ {Status: stripe.SubscriptionStatusIncompleteExpired}, {Status: stripe.SubscriptionStatusActive}, }}, nil) @@ -170,8 +170,8 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle() { "sports": ["basketball", "golf", "swimming"] }` - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - resp, err := s.unauthdClients.anonymize.AnonymizeSingle( + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) + resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle( s.ctx, connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{ AccountId: accountId, @@ -228,11 +228,11 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr t := s.T() t.Run("OSS", func(t *testing.T) { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) t.Run("transformpiitext", func(t *testing.T) { t.Run("mappings", func(t *testing.T) { - resp, err := s.unauthdClients.anonymize.AnonymizeSingle( + resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle( s.ctx, connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{ AccountId: accountId, @@ -252,7 +252,7 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr t.Run("defaults", func(t *testing.T) { t.Run("Bool", func(t *testing.T) { - resp, err := s.unauthdClients.anonymize.AnonymizeSingle( + resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle( s.ctx, connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{ AccountId: accountId, @@ -268,7 +268,7 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr requireConnectError(t, err, connect.CodePermissionDenied) }) t.Run("S", func(t *testing.T) { - resp, err := s.unauthdClients.anonymize.AnonymizeSingle( + resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle( s.ctx, connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{ AccountId: accountId, @@ -284,7 +284,7 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr requireConnectError(t, err, connect.CodePermissionDenied) }) t.Run("N", func(t *testing.T) { - resp, err := s.unauthdClients.anonymize.AnonymizeSingle( + resp, err := s.UnauthdClients.Anonymize.AnonymizeSingle( s.ctx, connect.NewRequest(&mgmtv1alpha1.AnonymizeSingleRequest{ AccountId: accountId, @@ -304,8 +304,8 @@ func (s *IntegrationTestSuite) Test_AnonymizeService_AnonymizeSingle_ForbiddenTr }) t.Run("cloud-personal", func(t *testing.T) { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) - anonclient := s.neosyncCloudClients.getAnonymizeClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) + anonclient := s.NeosyncCloudClients.GetAnonymizeClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) diff --git a/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go index 081339d92f..fee2e210da 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go @@ -9,9 +9,9 @@ import ( ) func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_Available() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.connections.IsConnectionNameAvailable( + resp, err := s.UnauthdClients.Connections.IsConnectionNameAvailable( s.ctx, connect.NewRequest(&mgmtv1alpha1.IsConnectionNameAvailableRequest{ AccountId: accountId, @@ -23,10 +23,10 @@ func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_ } func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_NotAvailable() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", "test-url") + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) + s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", "test-url") - resp, err := s.unauthdClients.connections.IsConnectionNameAvailable( + resp, err := s.UnauthdClients.Connections.IsConnectionNameAvailable( s.ctx, connect.NewRequest(&mgmtv1alpha1.IsConnectionNameAvailableRequest{ AccountId: accountId, @@ -39,16 +39,14 @@ func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_ func (s *IntegrationTestSuite) Test_ConnectionService_CheckConnectionConfig() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) t.Run("valid-pg-connstr", func(t *testing.T) { t.Parallel() - resp, err := s.unauthdClients.connections.CheckConnectionConfig( + resp, err := s.UnauthdClients.Connections.CheckConnectionConfig( s.ctx, connect.NewRequest(&mgmtv1alpha1.CheckConnectionConfigRequest{ ConnectionConfig: conn.GetConnectionConfig(), @@ -62,25 +60,21 @@ func (s *IntegrationTestSuite) Test_ConnectionService_CheckConnectionConfig() { func (s *IntegrationTestSuite) Test_ConnectionService_CreateConnection() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) t.Run("postgres-success", func(t *testing.T) { - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) - s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) }) } func (s *IntegrationTestSuite) Test_ConnectionService_UpdateConnection() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) t.Run("postgres-success", func(t *testing.T) { - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) - conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) - resp, err := s.unauthdClients.connections.UpdateConnection( + resp, err := s.UnauthdClients.Connections.UpdateConnection( s.ctx, connect.NewRequest(&mgmtv1alpha1.UpdateConnectionRequest{ Id: conn.GetId(), @@ -95,13 +89,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_UpdateConnection() { func (s *IntegrationTestSuite) Test_ConnectionService_GetConnection() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) - resp, err := s.unauthdClients.connections.GetConnection( + resp, err := s.UnauthdClients.Connections.GetConnection( s.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ Id: conn.GetId(), @@ -113,13 +105,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_GetConnection() { func (s *IntegrationTestSuite) Test_ConnectionService_GetConnections() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) - resp, err := s.unauthdClients.connections.GetConnections( + resp, err := s.UnauthdClients.Connections.GetConnections( s.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionsRequest{ AccountId: accountId, @@ -131,13 +121,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_GetConnections() { func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) - resp, err := s.unauthdClients.connections.GetConnections( + resp, err := s.UnauthdClients.Connections.GetConnections( s.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionsRequest{ AccountId: accountId, @@ -146,7 +134,7 @@ func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() { requireNoErrResp(t, resp, err) require.NotEmpty(t, resp.Msg.GetConnections()) - resp2, err := s.unauthdClients.connections.DeleteConnection( + resp2, err := s.UnauthdClients.Connections.DeleteConnection( s.ctx, connect.NewRequest(&mgmtv1alpha1.DeleteConnectionRequest{ Id: conn.GetId(), @@ -155,7 +143,7 @@ func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() { requireNoErrResp(t, resp2, err) // again to test idempotency - resp2, err = s.unauthdClients.connections.DeleteConnection( + resp2, err = s.UnauthdClients.Connections.DeleteConnection( s.ctx, connect.NewRequest(&mgmtv1alpha1.DeleteConnectionRequest{ Id: conn.GetId(), @@ -166,13 +154,11 @@ func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() { func (s *IntegrationTestSuite) Test_ConnectionService_CheckSqlQuery() { t := s.T() - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - require.NoError(t, err) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + conn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "foo", s.Pgcontainer.URL) - resp, err := s.unauthdClients.connections.CheckSqlQuery( + resp, err := s.UnauthdClients.Connections.CheckSqlQuery( s.ctx, connect.NewRequest(&mgmtv1alpha1.CheckSqlQueryRequest{ Id: conn.GetId(), diff --git a/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go index 3b40117809..dc4c160563 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/integration_test.go @@ -3,349 +3,49 @@ package integrationtests_test import ( "context" "fmt" - "io" "log/slog" - "net/http" - "net/http/httptest" "os" - "sync" "testing" - "time" - "connectrpc.com/connect" - "github.com/jackc/pgx/v5/pgxpool" - db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" - mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" - pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" - "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" - "github.com/nucleuscloud/neosync/backend/internal/apikey" - auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey" - auth_client "github.com/nucleuscloud/neosync/backend/internal/auth/client" - auth_jwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt" - "github.com/nucleuscloud/neosync/backend/internal/authmgmt" - auth_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth" - "github.com/nucleuscloud/neosync/backend/internal/neosyncdb" - clientmanager "github.com/nucleuscloud/neosync/backend/internal/temporal/client-manager" - "github.com/nucleuscloud/neosync/backend/internal/utils" - "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect" - mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier" - "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" - "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager" - v1alpha_anonymizationservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/anonymization-service" - v1alpha1_connectionservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/connection-service" - v1alpha1_jobservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/job-service" - v1alpha1_transformersservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/transformers-service" - v1alpha1_useraccountservice "github.com/nucleuscloud/neosync/backend/services/mgmt/v1alpha1/user-account-service" - awsmanager "github.com/nucleuscloud/neosync/internal/aws" - "github.com/nucleuscloud/neosync/internal/billing" - presidioapi "github.com/nucleuscloud/neosync/internal/ee/presidio" - neomigrate "github.com/nucleuscloud/neosync/internal/migrate" - promapiv1mock "github.com/nucleuscloud/neosync/internal/mocks/github.com/prometheus/client_golang/api/prometheus/v1" - http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client" + tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test" "github.com/stretchr/testify/suite" - "github.com/testcontainers/testcontainers-go" - testpg "github.com/testcontainers/testcontainers-go/modules/postgres" - "github.com/testcontainers/testcontainers-go/wait" ) -type unauthdClients struct { - users mgmtv1alpha1connect.UserAccountServiceClient - transformers mgmtv1alpha1connect.TransformersServiceClient - connections mgmtv1alpha1connect.ConnectionServiceClient - jobs mgmtv1alpha1connect.JobServiceClient - anonymize mgmtv1alpha1connect.AnonymizationServiceClient -} - -type neosyncCloudClients struct { - httpsrv *httptest.Server - basepath string -} - -func (s *neosyncCloudClients) getUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient { - return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath) -} -func (s *neosyncCloudClients) getAnonymizeClient(authUserId string) mgmtv1alpha1connect.AnonymizationServiceClient { - return mgmtv1alpha1connect.NewAnonymizationServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+s.basepath) -} - -type authdClients struct { - httpsrv *httptest.Server -} - -func (s *authdClients) getUserClient(authUserId string) mgmtv1alpha1connect.UserAccountServiceClient { - return mgmtv1alpha1connect.NewUserAccountServiceClient(http_client.WithBearerAuth(&http.Client{}, &authUserId), s.httpsrv.URL+"/auth") -} - -type mocks struct { - temporalClientManager *clientmanager.MockTemporalClientManagerClient - authclient *auth_client.MockInterface - authmanagerclient *authmgmt.MockInterface - prometheusclient *promapiv1mock.MockAPI - billingclient *billing.MockInterface - presidio presidiomocks -} - -type presidiomocks struct { - analyzer *presidioapi.MockAnalyzeInterface - anonymizer *presidioapi.MockAnonymizeInterface - entities *presidioapi.MockEntityInterface -} - type IntegrationTestSuite struct { suite.Suite - - pgpool *pgxpool.Pool - neosyncQuerier db_queries.Querier - systemQuerier pg_queries.Querier - + tcneosyncapi.NeosyncApiTestClient ctx context.Context - - pgcontainer *testpg.PostgresContainer - connstr string - migrationsDir string - - httpsrv *httptest.Server - - unauthdClients *unauthdClients - neosyncCloudClients *neosyncCloudClients - authdClients *authdClients - - mocks *mocks } +// TODO update service integration tests to not use testify suite func (s *IntegrationTestSuite) SetupSuite() { s.ctx = context.Background() - - pgcontainer, err := testpg.Run( - s.ctx, - "postgres:15", - testcontainers.WithWaitStrategy( - wait.ForLog("database system is ready to accept connections"). - WithOccurrence(2).WithStartupTimeout(5*time.Second), - ), - ) - if err != nil { - panic(err) - } - s.pgcontainer = pgcontainer - connstr, err := pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - if err != nil { - panic(err) - } - s.connstr = connstr - - pool, err := pgxpool.New(s.ctx, connstr) + api, err := tcneosyncapi.NewNeosyncApiTestClient(s.ctx, s.T()) if err != nil { panic(err) } - s.pgpool = pool - s.neosyncQuerier = db_queries.New() - s.systemQuerier = pg_queries.New() - s.migrationsDir = "../../../../sql/postgresql/schema" - - s.mocks = &mocks{ - temporalClientManager: clientmanager.NewMockTemporalClientManagerClient(s.T()), - authclient: auth_client.NewMockInterface(s.T()), - authmanagerclient: authmgmt.NewMockInterface(s.T()), - prometheusclient: promapiv1mock.NewMockAPI(s.T()), - billingclient: billing.NewMockInterface(s.T()), - presidio: presidiomocks{ - analyzer: presidioapi.NewMockAnalyzeInterface(s.T()), - anonymizer: presidioapi.NewMockAnonymizeInterface(s.T()), - entities: presidioapi.NewMockEntityInterface(s.T()), - }, - } - - maxAllowed := int64(10000) - unauthdUserService := v1alpha1_useraccountservice.New( - &v1alpha1_useraccountservice.Config{IsAuthEnabled: false, IsNeosyncCloud: false, DefaultMaxAllowedRecords: &maxAllowed}, - neosyncdb.New(pool, db_queries.New()), - s.mocks.temporalClientManager, - s.mocks.authclient, - s.mocks.authmanagerclient, - nil, - ) - - authdUserService := v1alpha1_useraccountservice.New( - &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: false}, - neosyncdb.New(pool, db_queries.New()), - s.mocks.temporalClientManager, - s.mocks.authclient, - s.mocks.authmanagerclient, - nil, - ) - - neoCloudAuthdUserService := v1alpha1_useraccountservice.New( - &v1alpha1_useraccountservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true}, - neosyncdb.New(pool, db_queries.New()), - s.mocks.temporalClientManager, - s.mocks.authclient, - s.mocks.authmanagerclient, - s.mocks.billingclient, - ) - neoCloudAuthdAnonymizeService := v1alpha_anonymizationservice.New( - &v1alpha_anonymizationservice.Config{IsAuthEnabled: true, IsNeosyncCloud: true, IsPresidioEnabled: false}, - nil, - neoCloudAuthdUserService, - s.mocks.presidio.analyzer, - s.mocks.presidio.anonymizer, - neosyncdb.New(pool, db_queries.New()), - ) - - unauthdTransformersService := v1alpha1_transformersservice.New( - &v1alpha1_transformersservice.Config{IsPresidioEnabled: true, IsNeosyncCloud: false}, - neosyncdb.New(pool, db_queries.New()), - unauthdUserService, - s.mocks.presidio.entities, - ) - - unauthdConnectionsService := v1alpha1_connectionservice.New( - &v1alpha1_connectionservice.Config{}, - neosyncdb.New(pool, db_queries.New()), - unauthdUserService, - &sqlconnect.SqlOpenConnector{}, - pg_queries.New(), - mysql_queries.New(), - mssql_queries.New(), - mongoconnect.NewConnector(), - awsmanager.New(), - ) - unauthdJobsService := v1alpha1_jobservice.New( - &v1alpha1_jobservice.Config{}, - neosyncdb.New(pool, db_queries.New()), - s.mocks.temporalClientManager, - unauthdConnectionsService, - unauthdUserService, - sqlmanager.NewSqlManager( - &sync.Map{}, pg_queries.New(), - &sync.Map{}, mysql_queries.New(), - &sync.Map{}, mssql_queries.New(), - &sqlconnect.SqlOpenConnector{}, - ), - ) - - var presAnalyzeClient presidioapi.AnalyzeInterface - var presAnonClient presidioapi.AnonymizeInterface - - unauthdAnonymizationService := v1alpha_anonymizationservice.New( - &v1alpha_anonymizationservice.Config{IsPresidioEnabled: false}, - nil, - unauthdUserService, - presAnalyzeClient, presAnonClient, - neosyncdb.New(pool, db_queries.New()), - ) - - rootmux := http.NewServeMux() - - unauthmux := http.NewServeMux() - unauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler( - unauthdUserService, - )) - unauthmux.Handle(mgmtv1alpha1connect.NewTransformersServiceHandler( - unauthdTransformersService, - )) - unauthmux.Handle(mgmtv1alpha1connect.NewConnectionServiceHandler( - unauthdConnectionsService, - )) - unauthmux.Handle(mgmtv1alpha1connect.NewJobServiceHandler( - unauthdJobsService, - )) - - unauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler( - unauthdAnonymizationService, - )) - rootmux.Handle("/unauth/", http.StripPrefix("/unauth", unauthmux)) - - authinterceptors := connect.WithInterceptors( - auth_interceptor.NewInterceptor(func(ctx context.Context, header http.Header, spec connect.Spec) (context.Context, error) { - // will need to further fill this out as the tests grow - authuserid, err := utils.GetBearerTokenFromHeader(header, "Authorization") - if err != nil { - return nil, err - } - if apikey.IsValidV1WorkerKey(authuserid) { - return auth_apikey.SetTokenData(ctx, &auth_apikey.TokenContextData{ - RawToken: authuserid, - ApiKey: nil, - ApiKeyType: apikey.WorkerApiKey, - }), nil - } - return auth_jwt.SetTokenData(ctx, &auth_jwt.TokenContextData{ - AuthUserId: authuserid, - Claims: &auth_jwt.CustomClaims{Email: &validAuthUser.Email}, - }), nil - }), - ) - - authmux := http.NewServeMux() - authmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler( - authdUserService, - authinterceptors, - )) - rootmux.Handle("/auth/", http.StripPrefix("/auth", authmux)) - - ncauthmux := http.NewServeMux() - ncauthmux.Handle(mgmtv1alpha1connect.NewUserAccountServiceHandler( - neoCloudAuthdUserService, - authinterceptors, - )) - ncauthmux.Handle(mgmtv1alpha1connect.NewAnonymizationServiceHandler( - neoCloudAuthdAnonymizeService, - authinterceptors, - )) - rootmux.Handle("/ncauth/", http.StripPrefix("/ncauth", ncauthmux)) - - s.httpsrv = startHTTPServer(s.T(), rootmux) - - s.unauthdClients = &unauthdClients{ - users: mgmtv1alpha1connect.NewUserAccountServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), - transformers: mgmtv1alpha1connect.NewTransformersServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), - connections: mgmtv1alpha1connect.NewConnectionServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), - jobs: mgmtv1alpha1connect.NewJobServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), - anonymize: mgmtv1alpha1connect.NewAnonymizationServiceClient(s.httpsrv.Client(), s.httpsrv.URL+"/unauth"), - } - - s.authdClients = &authdClients{ - httpsrv: s.httpsrv, - } - s.neosyncCloudClients = &neosyncCloudClients{ - httpsrv: s.httpsrv, - basepath: "/ncauth", - } + s.NeosyncApiTestClient = *api } // Runs before each test func (s *IntegrationTestSuite) SetupTest() { - discardLogger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) - err := neomigrate.Up(s.ctx, s.connstr, s.migrationsDir, discardLogger) + err := s.InitializeTest(s.ctx) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownTest() { - // Dropping here because 1) more efficient and 2) we have a bad down migration - // _jobs-connection-id-null.down that breaks due to having a null connection_id column. - // we should do something about that at some point. Running this single drop is easier though - _, err := s.pgpool.Exec(s.ctx, "DROP SCHEMA IF EXISTS neosync_api CASCADE") - if err != nil { - panic(err) - } - _, err = s.pgpool.Exec(s.ctx, "DROP TABLE IF EXISTS public.schema_migrations") + err := s.CleanupTest(s.ctx) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownSuite() { - if s.pgpool != nil { - s.pgpool.Close() - } - if s.pgcontainer != nil { - err := s.pgcontainer.Terminate(s.ctx) - if err != nil { - panic(err) - } + err := s.TearDown(s.ctx) + if err != nil { + panic(err) } } @@ -358,12 +58,3 @@ func TestIntegrationTestSuite(t *testing.T) { } suite.Run(t, new(IntegrationTestSuite)) } - -func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server { - tb.Helper() - srv := httptest.NewUnstartedServer(h) - srv.EnableHTTP2 = true - srv.Start() - tb.Cleanup(srv.Close) - return srv -} diff --git a/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go b/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go index 7fd0720b36..30a04fe718 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/integration_util_test.go @@ -12,6 +12,7 @@ import ( mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" "github.com/nucleuscloud/neosync/backend/internal/neosyncdb" + tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test" "github.com/stretchr/testify/require" ) @@ -20,9 +21,7 @@ func (s *IntegrationTestSuite) createPersonalAccount( userclient mgmtv1alpha1connect.UserAccountServiceClient, ) string { s.T().Helper() - resp, err := userclient.SetPersonalAccount(ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{})) - requireNoErrResp(s.T(), resp, err) - return resp.Msg.AccountId + return tcneosyncapi.CreatePersonalAccount(ctx, s.T(), userclient) } func requireNoErrResp[T any](t testing.TB, resp *connect.Response[T], err error) { @@ -53,7 +52,7 @@ func (s *IntegrationTestSuite) setAccountCreatedAt( if err != nil { return err } - _, err = s.neosyncQuerier.SetAccountCreatedAt(ctx, s.pgpool, db_queries.SetAccountCreatedAtParams{ + _, err = s.NeosyncQuerier.SetAccountCreatedAt(ctx, s.Pgcontainer.DB, db_queries.SetAccountCreatedAtParams{ CreatedAt: pgtype.Timestamp{Time: createdAt, Valid: true}, AccountId: accountUuid, }) diff --git a/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go index 339dc52ada..53c41e1320 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/jobs-service_integration_test.go @@ -11,9 +11,9 @@ import ( ) func (s *IntegrationTestSuite) Test_GetJobs_Empty() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.jobs.GetJobs(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetJobsRequest{ + resp, err := s.UnauthdClients.Jobs.GetJobs(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetJobsRequest{ AccountId: accountId, })) requireNoErrResp(s.T(), resp, err) @@ -22,23 +22,23 @@ func (s *IntegrationTestSuite) Test_GetJobs_Empty() { } func (s *IntegrationTestSuite) Test_CreateJob_Ok() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - srcconn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "source", "test") - destconn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "dest", "test2") + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) + srcconn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "source", "test") + destconn := s.createPostgresConnection(s.UnauthdClients.Connections, accountId, "dest", "test2") mockScheduleClient := temporalmocks.NewScheduleClient(s.T()) mockScheduleHandle := temporalmocks.NewScheduleHandle(s.T()) - s.mocks.temporalClientManager. + s.Mocks.TemporalClientManager. On( "DoesAccountHaveTemporalWorkspace", mock.Anything, mock.Anything, mock.Anything, ). Return(true, nil). Once() - s.mocks.temporalClientManager. + s.Mocks.TemporalClientManager. On("GetScheduleClientByAccount", mock.Anything, mock.Anything, mock.Anything). Return(mockScheduleClient, nil). Once() - s.mocks.temporalClientManager. + s.Mocks.TemporalClientManager. On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). Return(&pg_models.TemporalConfig{}, nil). Once() @@ -51,7 +51,7 @@ func (s *IntegrationTestSuite) Test_CreateJob_Ok() { Return("test-id"). Once() - resp, err := s.unauthdClients.jobs.CreateJob(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateJobRequest{ + resp, err := s.UnauthdClients.Jobs.CreateJob(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateJobRequest{ AccountId: accountId, JobName: "test", Mappings: []*mgmtv1alpha1.JobMapping{}, diff --git a/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go index 7e55874a44..65fdd83285 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/transformers-service_integration_test.go @@ -11,7 +11,7 @@ import ( ) func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformers() { - resp, err := s.unauthdClients.transformers.GetSystemTransformers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformersRequest{})) + resp, err := s.UnauthdClients.Transformers.GetSystemTransformers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformersRequest{})) requireNoErrResp(s.T(), resp, err) require.NotEmpty(s.T(), resp.Msg.GetTransformers()) } @@ -19,7 +19,7 @@ func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformers() func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformersBySource() { t := s.T() t.Run("ok", func(t *testing.T) { - resp, err := s.unauthdClients.transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{ + resp, err := s.UnauthdClients.Transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{ Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BOOL, })) requireNoErrResp(t, resp, err) @@ -28,7 +28,7 @@ func (s *IntegrationTestSuite) Test_TransformersService_GetSystemTransformersByS require.Equal(t, transformer.GetSource(), mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_BOOL) }) t.Run("not_found", func(t *testing.T) { - resp, err := s.unauthdClients.transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{ + resp, err := s.UnauthdClients.Transformers.GetSystemTransformerBySource(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemTransformerBySourceRequest{ Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_UNSPECIFIED, })) requireErrResp(t, resp, err) @@ -41,14 +41,14 @@ func (s *IntegrationTestSuite) Test_TransformersService_GetTransformPiiRecognize t.Run("ok", func(t *testing.T) { allowed := []string{"foo", "bar"} - s.mocks.presidio.entities.On("GetSupportedentitiesWithResponse", mock.Anything, mock.Anything). + s.Mocks.Presidio.Entities.On("GetSupportedentitiesWithResponse", mock.Anything, mock.Anything). Once(). Return(&presidioapi.GetSupportedentitiesResponse{ JSON200: &allowed, }, nil) - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - resp, err := s.unauthdClients.transformers.GetTransformPiiEntities(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTransformPiiEntitiesRequest{ + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) + resp, err := s.UnauthdClients.Transformers.GetTransformPiiEntities(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTransformPiiEntitiesRequest{ AccountId: accountId, })) requireNoErrResp(t, resp, err) diff --git a/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go index bfc52a801e..3d50cf4e1d 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/user-account-service_integration_test.go @@ -27,9 +27,9 @@ var ( ) func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountOnboardingConfig() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: accountId})) + resp, err := s.UnauthdClients.Users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: accountId})) requireNoErrResp(s.T(), resp, err) onboardingConfig := resp.Msg.GetConfig() require.NotNil(s.T(), onboardingConfig) @@ -38,21 +38,21 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountOnboardingConfi } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountOnboardingConfig_NoAccount() { - resp, err := s.unauthdClients.users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: uuid.NewString()})) + resp, err := s.UnauthdClients.Users.GetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountOnboardingConfigRequest{AccountId: uuid.NewString()})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfig_NoAccount() { - resp, err := s.unauthdClients.users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: uuid.NewString(), Config: &mgmtv1alpha1.AccountOnboardingConfig{}})) + resp, err := s.UnauthdClients.Users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: uuid.NewString(), Config: &mgmtv1alpha1.AccountOnboardingConfig{}})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfig_NoConfig() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: accountId, Config: nil})) + resp, err := s.UnauthdClients.Users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{AccountId: accountId, Config: nil})) requireNoErrResp(s.T(), resp, err) onboardingConfig := resp.Msg.GetConfig() require.NotNil(s.T(), onboardingConfig) @@ -61,9 +61,9 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfi } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountOnboardingConfig() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{ + resp, err := s.UnauthdClients.Users.SetAccountOnboardingConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountOnboardingConfigRequest{ AccountId: accountId, Config: &mgmtv1alpha1.AccountOnboardingConfig{ HasCompletedOnboarding: true, }}, @@ -85,12 +85,12 @@ var ( ) func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - s.mocks.temporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). + s.Mocks.TemporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). Return(validTemporalConfigModel, nil) - resp, err := s.unauthdClients.users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: accountId})) + resp, err := s.UnauthdClients.Users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: accountId})) requireNoErrResp(s.T(), resp, err) tc := resp.Msg.GetConfig() @@ -102,13 +102,13 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig( } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig_NoAccount() { - resp, err := s.unauthdClients.users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: uuid.NewString()})) + resp, err := s.UnauthdClients.Users.GetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountTemporalConfigRequest{AccountId: uuid.NewString()})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig_NeosyncCloud() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) @@ -118,13 +118,13 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountTemporalConfig_ } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_NoAccount() { - resp, err := s.unauthdClients.users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: uuid.NewString()})) + resp, err := s.UnauthdClients.Users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: uuid.NewString()})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_NeosyncCloud() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) @@ -134,12 +134,12 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_ } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_NoConfig() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - s.mocks.temporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). + s.Mocks.TemporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). Return(validTemporalConfigModel, nil) - resp, err := s.unauthdClients.users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: accountId, Config: nil})) + resp, err := s.UnauthdClients.Users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{AccountId: accountId, Config: nil})) requireNoErrResp(s.T(), resp, err) tc := resp.Msg.GetConfig() @@ -151,13 +151,13 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig_ } func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) // kind of a bad test since we are mocking this client wholesale, but it at least verifies we can write the config - s.mocks.temporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). + s.Mocks.TemporalClientManager.On("GetTemporalConfigByAccount", mock.Anything, mock.Anything). Return(validTemporalConfigModel, nil) - resp, err := s.unauthdClients.users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{ + resp, err := s.UnauthdClients.Users.SetAccountTemporalConfig(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetAccountTemporalConfigRequest{ AccountId: accountId, Config: &mgmtv1alpha1.AccountTemporalConfig{ Url: "test", Namespace: "test", @@ -174,7 +174,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetAccountTemporalConfig( } func (s *IntegrationTestSuite) Test_UserAccountService_GetUser_Auth() { - client := s.authdClients.getUserClient("test-user1") + client := s.AuthdClients.GetUserClient("test-user1") userId := s.setUser(s.ctx, client) resp, err := client.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{})) @@ -183,21 +183,21 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetUser_Auth() { } func (s *IntegrationTestSuite) Test_UserAccountService_GetUser_Auth_NotFound() { - client := s.authdClients.getUserClient(testAuthUserId) + client := s.AuthdClients.GetUserClient(testAuthUserId) resp, err := client.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodeNotFound) } func (s *IntegrationTestSuite) Test_UserAccountService_SetUser_Auth() { - client := s.authdClients.getUserClient(testAuthUserId) + client := s.AuthdClients.GetUserClient(testAuthUserId) userId := s.setUser(s.ctx, client) require.NotEmpty(s.T(), userId) require.NotEqual(s.T(), "00000000-0000-0000-0000-000000000000", userId) } func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_Auth() { - client := s.authdClients.getUserClient(testAuthUserId) + client := s.AuthdClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, client) resp, err := client.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"})) @@ -206,12 +206,12 @@ func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_Auth() } func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_NeosyncCloud() { - client := s.neosyncCloudClients.getUserClient(testAuthUserId) + client := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, client) - s.mocks.billingclient.On("NewCustomer", mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once(). Return(&stripe.Customer{ID: "test-stripe-id"}, nil) - s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). Return(&stripe.CheckoutSession{URL: "test-url"}, nil) resp, err := client.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"})) @@ -221,11 +221,11 @@ func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_Neosync } func (s *IntegrationTestSuite) Test_UserAccountService_GetTeamAccountMembers_Auth() { - client := s.authdClients.getUserClient(testAuthUserId) + client := s.AuthdClients.GetUserClient(testAuthUserId) userId := s.setUser(s.ctx, client) accountId := s.createTeamAccount(s.ctx, client, "test-team") - s.mocks.authmanagerclient.On("GetUserBySub", mock.Anything, testAuthUserId). + s.Mocks.Authmanagerclient.On("GetUserBySub", mock.Anything, testAuthUserId). Return(validAuthUser, nil) resp, err := client.GetTeamAccountMembers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountMembersRequest{AccountId: accountId})) @@ -255,7 +255,7 @@ func (s *IntegrationTestSuite) createTeamAccount(ctx context.Context, client mgm } func (s *IntegrationTestSuite) Test_UserAccountService_GetUser() { - resp, err := s.unauthdClients.users.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{})) + resp, err := s.UnauthdClients.Users.GetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserRequest{})) requireNoErrResp(s.T(), resp, err) userId := resp.Msg.GetUserId() @@ -263,7 +263,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetUser() { } func (s *IntegrationTestSuite) Test_UserAccountService_SetUser() { - resp, err := s.unauthdClients.users.SetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{})) + resp, err := s.UnauthdClients.Users.SetUser(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetUserRequest{})) requireNoErrResp(s.T(), resp, err) userId := resp.Msg.UserId @@ -272,7 +272,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetUser() { } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_Empty() { - resp, err := s.unauthdClients.users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{})) + resp, err := s.UnauthdClients.Users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{})) requireNoErrResp(s.T(), resp, err) accounts := resp.Msg.GetAccounts() @@ -280,7 +280,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_Empty() { } func (s *IntegrationTestSuite) Test_UserAccountService_SetPersonalAccount() { - resp, err := s.unauthdClients.users.SetPersonalAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{})) + resp, err := s.UnauthdClients.Users.SetPersonalAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetPersonalAccountRequest{})) requireNoErrResp(s.T(), resp, err) accountId := resp.Msg.GetAccountId() @@ -288,9 +288,9 @@ func (s *IntegrationTestSuite) Test_UserAccountService_SetPersonalAccount() { } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_NotEmpty() { - s.createPersonalAccount(s.ctx, s.unauthdClients.users) + s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - accResp, err := s.unauthdClients.users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{})) + accResp, err := s.UnauthdClients.Users.GetUserAccounts(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetUserAccountsRequest{})) requireNoErrResp(s.T(), accResp, err) accounts := accResp.Msg.GetAccounts() @@ -299,15 +299,15 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccounts_NotEmpty() { } func (s *IntegrationTestSuite) Test_UserAccountService_IsUserInAccount() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{ + resp, err := s.UnauthdClients.Users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{ AccountId: accountId, })) requireNoErrResp(s.T(), resp, err) require.True(s.T(), resp.Msg.GetOk()) - resp, err = s.unauthdClients.users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{ + resp, err = s.UnauthdClients.Users.IsUserInAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsUserInAccountRequest{ AccountId: uuid.NewString(), })) requireNoErrResp(s.T(), resp, err) @@ -315,63 +315,63 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsUserInAccount() { } func (s *IntegrationTestSuite) Test_UserAccountService_CreateTeamAccount_NoAuth() { - s.createPersonalAccount(s.ctx, s.unauthdClients.users) + s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"})) + resp, err := s.UnauthdClients.Users.CreateTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.CreateTeamAccountRequest{Name: "test-name"})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_GetTeamAccountMembers_NoAuth_Personal() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.GetTeamAccountMembers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountMembersRequest{AccountId: accountId})) + resp, err := s.UnauthdClients.Users.GetTeamAccountMembers(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountMembersRequest{AccountId: accountId})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_RemoveTeamAccountMember_NoAuth_Personal() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.RemoveTeamAccountMember(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountMemberRequest{AccountId: accountId, UserId: uuid.NewString()})) + resp, err := s.UnauthdClients.Users.RemoveTeamAccountMember(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountMemberRequest{AccountId: accountId, UserId: uuid.NewString()})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_InviteUserToTeamAccount_NoAuth_Personal() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.InviteUserToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.InviteUserToTeamAccountRequest{AccountId: accountId, Email: "test@example.com"})) + resp, err := s.UnauthdClients.Users.InviteUserToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.InviteUserToTeamAccountRequest{AccountId: accountId, Email: "test@example.com"})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_GetTeamAccountInvites_NoAuth_Personal() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.GetTeamAccountInvites(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountInvitesRequest{AccountId: accountId})) + resp, err := s.UnauthdClients.Users.GetTeamAccountInvites(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetTeamAccountInvitesRequest{AccountId: accountId})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodePermissionDenied) } func (s *IntegrationTestSuite) Test_UserAccountService_RemoveTeamAccountInvite_NoAuth_Personal() { - resp, err := s.unauthdClients.users.RemoveTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountInviteRequest{Id: uuid.NewString()})) + resp, err := s.UnauthdClients.Users.RemoveTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.RemoveTeamAccountInviteRequest{Id: uuid.NewString()})) requireNoErrResp(s.T(), resp, err) } func (s *IntegrationTestSuite) Test_UserAccountService_AcceptTeamAccountInvite_NoAuth_Personal() { - resp, err := s.unauthdClients.users.AcceptTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.AcceptTeamAccountInviteRequest{Token: uuid.NewString()})) + resp, err := s.UnauthdClients.Users.AcceptTeamAccountInvite(s.ctx, connect.NewRequest(&mgmtv1alpha1.AcceptTeamAccountInviteRequest{Token: uuid.NewString()})) requireErrResp(s.T(), resp, err) requireConnectError(s.T(), err, connect.CodeUnauthenticated) } func (s *IntegrationTestSuite) Test_UserAccountService_GetSystemInformation() { - resp, err := s.unauthdClients.users.GetSystemInformation(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemInformationRequest{})) + resp, err := s.UnauthdClients.Users.GetSystemInformation(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetSystemInformationRequest{})) requireNoErrResp(s.T(), resp, err) } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncCloud_Personal() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) @@ -401,7 +401,7 @@ func (t *testSubscriptionIter) Err() error { } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncCloud_Billed() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) t := s.T() @@ -409,7 +409,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncC t.Run("active_sub", func(t *testing.T) { custId := "cust_id1" accountId := s.createBilledTeamAccount(s.ctx, userclient, "test-team", custId) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ {Status: stripe.SubscriptionStatusIncompleteExpired}, {Status: stripe.SubscriptionStatusActive}, }}, nil) @@ -428,7 +428,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncC err := s.setAccountCreatedAt(s.ctx, accountId, time.Now().UTC().Add(-30*24*time.Hour)) assert.NoError(s.T(), err) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ {Status: stripe.SubscriptionStatusIncompleteExpired}, {Status: stripe.SubscriptionStatusIncompleteExpired}, }}, nil) @@ -443,9 +443,9 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_NeosyncC } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_OSS_Personal() { - accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + accountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) - resp, err := s.unauthdClients.users.GetAccountStatus(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountStatusRequest{ + resp, err := s.UnauthdClients.Users.GetAccountStatus(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountStatusRequest{ AccountId: accountId, })) requireNoErrResp(s.T(), resp, err) @@ -454,7 +454,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountStatus_OSS_Pers } func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_NeosyncCloud_Personal() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) @@ -469,14 +469,14 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos } func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_NeosyncCloud_Billed() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) t := s.T() t.Run("active", func(t *testing.T) { custId := "cust_id1" accountId := s.createBilledTeamAccount(s.ctx, userclient, "test1", custId) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ {Status: stripe.SubscriptionStatusActive}, }}, nil) resp, err := userclient.IsAccountStatusValid(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ @@ -494,7 +494,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos accountId := s.createBilledTeamAccount(s.ctx, userclient, "test2", custId) err := s.setAccountCreatedAt(s.ctx, accountId, time.Now().UTC().Add(-30*24*time.Hour)) assert.NoError(t, err) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ {Status: stripe.SubscriptionStatusIncompleteExpired}, }}, nil) @@ -511,7 +511,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos t.Run("no_subs_active_trial", func(t *testing.T) { custId := "cust_id3" accountId := s.createBilledTeamAccount(s.ctx, userclient, "test3", custId) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil) + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil) resp, err := userclient.IsAccountStatusValid(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ AccountId: accountId, @@ -527,7 +527,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos accountId := s.createBilledTeamAccount(s.ctx, userclient, "test4", custId) err := s.setAccountCreatedAt(s.ctx, accountId, time.Now().UTC().Add(-30*24*time.Hour)) assert.NoError(t, err) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil) + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{}}, nil) resp, err := userclient.IsAccountStatusValid(s.ctx, connect.NewRequest(&mgmtv1alpha1.IsAccountStatusValidRequest{ AccountId: accountId, @@ -542,7 +542,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos t.Run("no_active_subs_active_trial", func(t *testing.T) { custId := "cust_id5" accountId := s.createBilledTeamAccount(s.ctx, userclient, "test5", custId) - s.mocks.billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ + s.Mocks.Billingclient.On("GetSubscriptions", custId).Once().Return(&testSubscriptionIter{subscriptions: []*stripe.Subscription{ {Status: stripe.SubscriptionStatusIncompleteExpired}, }}, nil) @@ -558,7 +558,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_IsAccountStatusValid_Neos } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckoutSession() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) t := s.T() @@ -566,7 +566,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckout t.Run("billed account - allowed", func(t *testing.T) { teamAccountId := s.createBilledTeamAccount(s.ctx, userclient, "test-team", "test-stripe-id") - s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). Return(&stripe.CheckoutSession{URL: "new-test-url"}, nil) resp, err := userclient.GetAccountBillingCheckoutSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingCheckoutSessionRequest{ @@ -585,7 +585,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckout }) t.Run("non-neosynccloud - disallowed", func(t *testing.T) { - personalAccountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + personalAccountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) resp, err := userclient.GetAccountBillingCheckoutSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingCheckoutSessionRequest{ AccountId: personalAccountId, })) @@ -594,7 +594,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingCheckout } func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSession() { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) t := s.T() @@ -602,7 +602,7 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSe t.Run("billed account - allowed", func(t *testing.T) { teamAccountId := s.createBilledTeamAccount(s.ctx, userclient, "test-team", "test-stripe-id") - s.mocks.billingclient.On("NewBillingPortalSession", mock.Anything, mock.Anything).Once(). + s.Mocks.Billingclient.On("NewBillingPortalSession", mock.Anything, mock.Anything).Once(). Return(&stripe.BillingPortalSession{URL: "new-test-url"}, nil) resp, err := userclient.GetAccountBillingPortalSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingPortalSessionRequest{ @@ -621,8 +621,8 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSe }) t.Run("non-neosynccloud - disallowed", func(t *testing.T) { - personalAccountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) - resp, err := s.unauthdClients.users.GetAccountBillingPortalSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingPortalSessionRequest{ + personalAccountId := s.createPersonalAccount(s.ctx, s.UnauthdClients.Users) + resp, err := s.UnauthdClients.Users.GetAccountBillingPortalSession(s.ctx, connect.NewRequest(&mgmtv1alpha1.GetAccountBillingPortalSessionRequest{ AccountId: personalAccountId, })) requireErrResp(s.T(), resp, err) @@ -630,22 +630,22 @@ func (s *IntegrationTestSuite) Test_UserAccountService_GetAccountBillingPortalSe } func (s *IntegrationTestSuite) createBilledTeamAccount(ctx context.Context, client mgmtv1alpha1connect.UserAccountServiceClient, name, stripeCustomerId string) string { - s.mocks.billingclient.On("NewCustomer", mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once(). Return(&stripe.Customer{ID: stripeCustomerId}, nil) - s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). Return(&stripe.CheckoutSession{URL: "test-url"}, nil) return s.createTeamAccount(ctx, client, name) } func (s *IntegrationTestSuite) Test_GetBillingAccounts() { - userclient1 := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient1 := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient1) - userclient2 := s.neosyncCloudClients.getUserClient(testAuthUserId2) + userclient2 := s.NeosyncCloudClients.GetUserClient(testAuthUserId2) s.setUser(s.ctx, userclient2) workerapikey := apikey.NewV1WorkerKey() - workeruserclient := s.neosyncCloudClients.getUserClient(workerapikey) + workeruserclient := s.NeosyncCloudClients.GetUserClient(workerapikey) t := s.T() @@ -691,8 +691,8 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() { t := s.T() t.Run("OSS unauth", func(t *testing.T) { - s.setUser(s.ctx, s.unauthdClients.users) - resp, err := s.unauthdClients.users.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{ + s.setUser(s.ctx, s.UnauthdClients.Users) + resp, err := s.UnauthdClients.Users.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{ Name: "unauthteamname", })) requireErrResp(t, resp, err) @@ -700,7 +700,7 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() { }) t.Run("OSS auth success", func(t *testing.T) { - userclient := s.authdClients.getUserClient(testAuthUserId) + userclient := s.AuthdClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) @@ -713,14 +713,14 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() { }) t.Run("cloud billing success", func(t *testing.T) { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) accountId := s.createPersonalAccount(s.ctx, userclient) stripeCustomerId := "foo" - s.mocks.billingclient.On("NewCustomer", mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once(). Return(&stripe.Customer{ID: stripeCustomerId}, nil) - s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). Return(&stripe.CheckoutSession{URL: "test-url"}, nil) resp, err := userclient.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{ Name: "newname2", @@ -731,13 +731,13 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() { }) t.Run("cloud success unspecified account", func(t *testing.T) { - userclient := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient) stripeCustomerId := "foo" - s.mocks.billingclient.On("NewCustomer", mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCustomer", mock.Anything).Once(). Return(&stripe.Customer{ID: stripeCustomerId}, nil) - s.mocks.billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). + s.Mocks.Billingclient.On("NewCheckoutSession", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Once(). Return(&stripe.CheckoutSession{URL: "test-url"}, nil) resp, err := userclient.ConvertPersonalToTeamAccount(s.ctx, connect.NewRequest(&mgmtv1alpha1.ConvertPersonalToTeamAccountRequest{ Name: "newname3", @@ -747,11 +747,11 @@ func (s *IntegrationTestSuite) Test_ConvertPersonalToTeamAccount() { } func (s *IntegrationTestSuite) Test_SetBillingMeterEvent() { - userclient1 := s.neosyncCloudClients.getUserClient(testAuthUserId) + userclient1 := s.NeosyncCloudClients.GetUserClient(testAuthUserId) s.setUser(s.ctx, userclient1) workerapikey := apikey.NewV1WorkerKey() - workeruserclient := s.neosyncCloudClients.getUserClient(workerapikey) + workeruserclient := s.NeosyncCloudClients.GetUserClient(workerapikey) t := s.T() @@ -759,7 +759,7 @@ func (s *IntegrationTestSuite) Test_SetBillingMeterEvent() { au1TeamAccountId1 := s.createBilledTeamAccount(s.ctx, userclient1, "test-team", "test-stripe-id") t.Run("new event", func(t *testing.T) { - s.mocks.billingclient.On("NewMeterEvent", mock.Anything).Once().Return(&stripe.BillingMeterEvent{}, nil) + s.Mocks.Billingclient.On("NewMeterEvent", mock.Anything).Once().Return(&stripe.BillingMeterEvent{}, nil) ts := uint64(1) resp, err := workeruserclient.SetBillingMeterEvent(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetBillingMeterEventRequest{ AccountId: au1TeamAccountId1, @@ -795,7 +795,7 @@ func (s *IntegrationTestSuite) Test_SetBillingMeterEvent() { t.Run("squashes meter already existing", func(t *testing.T) { eventId := "test-event-id" stripeerr := &stripe.Error{Type: stripe.ErrorTypeInvalidRequest, Msg: fmt.Sprintf("An event already exists with identifier %s", eventId)} - s.mocks.billingclient.On("NewMeterEvent", mock.Anything).Once().Return(nil, stripeerr) + s.Mocks.Billingclient.On("NewMeterEvent", mock.Anything).Once().Return(nil, stripeerr) ts := uint64(1) resp, err := workeruserclient.SetBillingMeterEvent(s.ctx, connect.NewRequest(&mgmtv1alpha1.SetBillingMeterEventRequest{ AccountId: au1TeamAccountId1, diff --git a/cli/internal/auth/tokens.go b/cli/internal/auth/tokens.go index 53015ddd90..199f80bced 100644 --- a/cli/internal/auth/tokens.go +++ b/cli/internal/auth/tokens.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "net/http" "connectrpc.com/connect" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" @@ -91,3 +92,41 @@ func IsAuthEnabled(ctx context.Context) (bool, error) { } return isEnabledResp.Msg.IsEnabled, nil } + +// Returns the neosync url found in the environment, otherwise defaults to localhost +func GetNeosyncUrl() string { + return serverconfig.GetApiBaseUrl() +} + +// Returns an instance of *http.Client that includes the Neosync API Token if one was found in the environment +func GetNeosyncHttpClient(ctx context.Context, apiKey *string, logger *slog.Logger) (*http.Client, error) { + token, err := GetToken(ctx, apiKey, logger) + if err != nil { + return nil, err + } + return http_client.NewWithBearerAuth(token), nil +} + +func GetToken(ctx context.Context, apiKey *string, logger *slog.Logger) (*string, error) { + isAuthEnabled, err := IsAuthEnabled(ctx) + if err != nil { + return nil, err + } + var token *string + if isAuthEnabled { + logger.Debug("Auth Enabled") + if apiKey != nil && *apiKey != "" { + logger.Debug("found API Key") + token = apiKey + } else { + logger.Debug("Retrieving Access Token") + accessToken, err := userconfig.GetAccessToken() + if err != nil { + logger.Error("Unable to retrieve access token. Please use neosync login command and try again.") + return nil, err + } + token = &accessToken + } + } + return token, nil +} diff --git a/cli/internal/benthos/config.go b/cli/internal/benthos/config.go deleted file mode 100644 index b5c9a88cf1..0000000000 --- a/cli/internal/benthos/config.go +++ /dev/null @@ -1,164 +0,0 @@ -package cli_neosync_benthos - -type BenthosConfig struct { - // HTTP HTTPConfig `json:"http" yaml:"http"` - StreamConfig `json:",inline" yaml:",inline"` -} - -type HTTPConfig struct { - Address string `json:"address" yaml:"address"` - Enabled bool `json:"enabled" yaml:"enabled"` - // RootPath string `json:"root_path" yaml:"root_path"` - // DebugEndpoints bool `json:"debug_endpoints" yaml:"debug_endpoints"` - // CertFile string `json:"cert_file" yaml:"cert_file"` - // KeyFile string `json:"key_file" yaml:"key_file"` - // CORS httpserver.CORSConfig `json:"cors" yaml:"cors"` - // BasicAuth httpserver.BasicAuthConfig `json:"basic_auth" yaml:"basic_auth"` -} - -type StreamConfig struct { - Logger *LoggerConfig `json:"logger" yaml:"logger,omitempty"` - Input *InputConfig `json:"input" yaml:"input"` - Buffer *BufferConfig `json:"buffer,omitempty" yaml:"buffer,omitempty"` - Pipeline *PipelineConfig `json:"pipeline" yaml:"pipeline"` - Output *OutputConfig `json:"output" yaml:"output"` -} - -type LoggerConfig struct { - Level string `json:"level" yaml:"level"` - AddTimestamp bool `json:"add_timestamp" yaml:"add_timestamp"` -} -type InputConfig struct { - Label string `json:"label" yaml:"label"` - Inputs `json:",inline" yaml:",inline"` -} - -type Inputs struct { - NeosyncConnectionData *NeosyncConnectionData `json:"neosync_connection_data,omitempty" yaml:"neosync_connection_data,omitempty"` -} - -type NeosyncConnectionData struct { - ApiKey *string `json:"api_key,omitempty" yaml:"api_key,omitempty"` - ApiUrl string `json:"api_url" yaml:"api_url"` - ConnectionId string `json:"connection_id" yaml:"connection_id"` - ConnectionType string `json:"connection_type" yaml:"connection_type"` - JobId *string `json:"job_id,omitempty" yaml:"job_id,omitempty"` - JobRunId *string `json:"job_run_id,omitempty" yaml:"job_run_id,omitempty"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` -} - -type BufferConfig struct{} - -type PipelineConfig struct { - Threads int `json:"threads" yaml:"threads"` - Processors []ProcessorConfig `json:"processors" yaml:"processors"` -} - -type ProcessorConfig struct { -} - -type BranchConfig struct { - Processors []ProcessorConfig `json:"processors" yaml:"processors"` - RequestMap *string `json:"request_map,omitempty" yaml:"request_map,omitempty"` - ResultMap *string `json:"result_map,omitempty" yaml:"result_map,omitempty"` -} - -type OutputConfig struct { - Label string `json:"label" yaml:"label"` - Outputs `json:",inline" yaml:",inline"` - Processors []ProcessorConfig `json:"processors,omitempty" yaml:"processors,omitempty"` - // Broker *OutputBrokerConfig `json:"broker,omitempty" yaml:"broker,omitempty"` -} - -type Outputs struct { - PooledSqlInsert *PooledSqlInsert `json:"pooled_sql_insert,omitempty" yaml:"pooled_sql_insert,omitempty"` - PooledSqlUpdate *PooledSqlUpdate `json:"pooled_sql_update,omitempty" yaml:"pooled_sql_update,omitempty"` - AwsS3 *AwsS3Insert `json:"aws_s3,omitempty" yaml:"aws_s3,omitempty"` - AwsDynamoDB *OutputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` -} - -type OutputAwsDynamoDB struct { - Table string `json:"table" yaml:"table"` - JsonMapColumns map[string]string `json:"json_map_columns,omitempty" yaml:"json_map_columns,omitempty"` - - Region string `json:"region,omitempty" yaml:"region,omitempty"` - Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"` - - Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"` - - MaxInFlight *int `json:"max_in_flight,omitempty" yaml:"max_in_flight,omitempty"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` -} - -type PooledSqlUpdate struct { - Driver string `json:"driver" yaml:"driver"` - Dsn string `json:"dsn" yaml:"dsn"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` - Columns []string `json:"columns" yaml:"columns"` - WhereColumns []string `json:"where_columns" yaml:"where_columns"` - ArgsMapping string `json:"args_mapping" yaml:"args_mapping"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` -} - -type PooledSqlInsert struct { - Driver string `json:"driver" yaml:"driver"` - Dsn string `json:"dsn" yaml:"dsn"` - Schema string `json:"schema" yaml:"schema"` - Table string `json:"table" yaml:"table"` - Columns []string `json:"columns" yaml:"columns"` - OnConflictDoNothing bool `json:"on_conflict_do_nothing" yaml:"on_conflict_do_nothing"` - TruncateOnRetry bool `json:"truncate_on_retry" yaml:"truncate_on_retry"` - ArgsMapping string `json:"args_mapping" yaml:"args_mapping"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` -} - -type AwsS3Insert struct { - Bucket string `json:"bucket" yaml:"bucket"` - MaxInFlight int `json:"max_in_flight" yaml:"max_in_flight"` - Path string `json:"path" yaml:"path"` - Batching *Batching `json:"batching,omitempty" yaml:"batching,omitempty"` - - Region string `json:"region,omitempty" yaml:"region,omitempty"` - Endpoint string `json:"endpoint,omitempty" yaml:"endpoint,omitempty"` - - Credentials *AwsCredentials `json:"credentials,omitempty" yaml:"credentials,omitempty"` -} - -type AwsCredentials struct { - Profile string `json:"profile,omitempty" yaml:"profile,omitempty"` - Id string `json:"id,omitempty" yaml:"id,omitempty"` - Secret string `json:"secret,omitempty" yaml:"secret,omitempty"` - Token string `json:"token,omitempty" yaml:"token,omitempty"` - FromEc2Role bool `json:"from_ec2_role,omitempty" yaml:"from_ec2_role,omitempty"` - Role string `json:"role,omitempty" yaml:"role,omitempty"` - RoleExternalId string `json:"role_external_id,omitempty" yaml:"role_external_id,omitempty"` -} - -type Batching struct { - Count int `json:"count" yaml:"count"` - ByteSize int `json:"byte_size" yaml:"byte_size"` - Period string `json:"period" yaml:"period"` - Check string `json:"check" yaml:"check"` - Processors []*BatchProcessor `json:"processors" yaml:"processors"` -} - -type BatchProcessor struct { - Archive *ArchiveProcessor `json:"archive,omitempty" yaml:"archive,omitempty"` - Compress *CompressProcessor `json:"compress,omitempty" yaml:"compress,omitempty"` -} - -type ArchiveProcessor struct { - Format string `json:"format" yaml:"format"` - Path *string `json:"path,omitempty" yaml:"path,omitempty"` -} - -type CompressProcessor struct { - Algorithm string `json:"algorithm" yaml:"algorithm"` -} - -type OutputBrokerConfig struct { - Pattern string `json:"pattern" yaml:"pattern"` - Outputs []Outputs `json:"outputs" yaml:"outputs"` -} diff --git a/cli/internal/cmds/neosync/connections/connections_integration_test.go b/cli/internal/cmds/neosync/connections/connections_integration_test.go new file mode 100644 index 0000000000..a194c071cf --- /dev/null +++ b/cli/internal/cmds/neosync/connections/connections_integration_test.go @@ -0,0 +1,71 @@ +package connections_cmd + +import ( + "context" + "testing" + + mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test" + "github.com/nucleuscloud/neosync/internal/testutil" + "github.com/stretchr/testify/require" +) + +const neosyncDbMigrationsPath = "../../../../../backend/sql/postgresql/schema" + +func Test_Connections(t *testing.T) { + t.Parallel() + ok := testutil.ShouldRunIntegrationTest() + if !ok { + return + } + ctx := context.Background() + + neosyncApi, err := tcneosyncapi.NewNeosyncApiTestClient(ctx, t, tcneosyncapi.WithMigrationsDirectory(neosyncDbMigrationsPath)) + if err != nil { + panic(err) + } + postgresUrl := "postgresql://postgres:foofar@localhost:5434/neosync" + + t.Run("list_unauthed", func(t *testing.T) { + accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, neosyncApi.UnauthdClients.Users) + conn1 := tcneosyncapi.CreatePostgresConnection(ctx, t, neosyncApi.UnauthdClients.Connections, accountId, "conn1", postgresUrl) + conn2 := tcneosyncapi.CreatePostgresConnection(ctx, t, neosyncApi.UnauthdClients.Connections, accountId, "conn2", postgresUrl) + conns := []*mgmtv1alpha1.Connection{conn1, conn2} + connections, err := getConnections(ctx, neosyncApi.UnauthdClients.Connections, accountId) + require.NoError(t, err) + require.Len(t, connections, len(conns)) + }) + + t.Run("list_auth", func(t *testing.T) { + testAuthUserId := "c3b32842-9b70-4f4e-ad45-9cab26c6f2f1" + userclient := neosyncApi.AuthdClients.GetUserClient(testAuthUserId) + connclient := neosyncApi.AuthdClients.GetConnectionClient(testAuthUserId) + tcneosyncapi.SetUser(ctx, t, userclient) + accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, userclient) + conn1 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn1", postgresUrl) + conn2 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn2", postgresUrl) + conns := []*mgmtv1alpha1.Connection{conn1, conn2} + connections, err := getConnections(ctx, connclient, accountId) + require.NoError(t, err) + require.Len(t, connections, len(conns)) + }) + + t.Run("list_cloud", func(t *testing.T) { + testAuthUserId := "34f3e404-c995-452b-89e4-9c486b491dab" + userclient := neosyncApi.NeosyncCloudClients.GetUserClient(testAuthUserId) + connclient := neosyncApi.NeosyncCloudClients.GetConnectionClient(testAuthUserId) + tcneosyncapi.SetUser(ctx, t, userclient) + accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, userclient) + conn1 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn1", postgresUrl) + conn2 := tcneosyncapi.CreatePostgresConnection(ctx, t, connclient, accountId, "conn2", postgresUrl) + conns := []*mgmtv1alpha1.Connection{conn1, conn2} + connections, err := getConnections(ctx, connclient, accountId) + require.NoError(t, err) + require.Len(t, connections, len(conns)) + }) + + err = neosyncApi.TearDown(ctx) + if err != nil { + panic(err) + } +} diff --git a/cli/internal/cmds/neosync/connections/list.go b/cli/internal/cmds/neosync/connections/list.go index e7f041ba73..c067c994de 100644 --- a/cli/internal/cmds/neosync/connections/list.go +++ b/cli/internal/cmds/neosync/connections/list.go @@ -4,18 +4,17 @@ import ( "context" "errors" "fmt" + "log/slog" + "os" "time" "connectrpc.com/connect" + charmlog "github.com/charmbracelet/log" "github.com/fatih/color" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" "github.com/nucleuscloud/neosync/cli/internal/auth" - auth_interceptor "github.com/nucleuscloud/neosync/cli/internal/connect/interceptors/auth" - "github.com/nucleuscloud/neosync/cli/internal/serverconfig" "github.com/nucleuscloud/neosync/cli/internal/userconfig" - "github.com/nucleuscloud/neosync/cli/internal/version" - http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client" "github.com/rodaine/table" "github.com/spf13/cobra" ) @@ -35,8 +34,13 @@ func newListCmd() *cobra.Command { if err != nil { return err } + + debugMode, err := cmd.Flags().GetBool("debug") + if err != nil { + return err + } cmd.SilenceUsage = true - return listConnections(cmd.Context(), &apiKey, &accountId) + return listConnections(cmd.Context(), debugMode, &apiKey, &accountId) }, } cmd.Flags().String("account-id", "", "Account to list connections for. Defaults to account id in cli context") @@ -45,18 +49,24 @@ func newListCmd() *cobra.Command { func listConnections( ctx context.Context, + debugMode bool, apiKey, accountIdFlag *string, ) error { - isAuthEnabled, err := auth.IsAuthEnabled(ctx) - if err != nil { - return err + logLevel := charmlog.InfoLevel + if debugMode { + logLevel = charmlog.DebugLevel } + charmlogger := charmlog.NewWithOptions(os.Stderr, charmlog.Options{ + ReportTimestamp: true, + Level: logLevel, + }) + logger := slog.New(charmlogger) var accountId = accountIdFlag if accountId == nil || *accountId == "" { aId, err := userconfig.GetAccountId() if err != nil { - fmt.Println("Unable to retrieve account id. Please use account switch command to set account.") //nolint:forbidigo + logger.Error("Unable to retrieve account id. Please use account switch command to set account.") return err } accountId = &aId @@ -66,26 +76,40 @@ func listConnections( return errors.New("Account Id not found. Please use account switch command to set account.") } - connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient( - http_client.NewWithHeaders(version.Get().Headers()), - serverconfig.GetApiBaseUrl(), - connect.WithInterceptors( - auth_interceptor.NewInterceptor(isAuthEnabled, auth.AuthHeader, auth.GetAuthHeaderTokenFn(apiKey)), - ), - ) - res, err := connectionclient.GetConnections(ctx, connect.NewRequest[mgmtv1alpha1.GetConnectionsRequest](&mgmtv1alpha1.GetConnectionsRequest{ - AccountId: *accountId, - })) + connectInterceptors := []connect.Interceptor{} + neosyncurl := auth.GetNeosyncUrl() + httpclient, err := auth.GetNeosyncHttpClient(ctx, apiKey, logger) + if err != nil { + return err + } + connectInterceptorOption := connect.WithInterceptors(connectInterceptors...) + connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(httpclient, neosyncurl, connectInterceptorOption) + + connections, err := getConnections(ctx, connectionclient, *accountId) if err != nil { return err } fmt.Println() //nolint:forbidigo - printConnectionsTable(res.Msg.Connections) + printConnectionsTable(connections) fmt.Println() //nolint:forbidigo return nil } +func getConnections( + ctx context.Context, + connectionclient mgmtv1alpha1connect.ConnectionServiceClient, + accountId string, +) ([]*mgmtv1alpha1.Connection, error) { + res, err := connectionclient.GetConnections(ctx, connect.NewRequest[mgmtv1alpha1.GetConnectionsRequest](&mgmtv1alpha1.GetConnectionsRequest{ + AccountId: accountId, + })) + if err != nil { + return nil, err + } + return res.Msg.GetConnections(), nil +} + func printConnectionsTable( connections []*mgmtv1alpha1.Connection, ) { diff --git a/cli/internal/cmds/neosync/neosync.go b/cli/internal/cmds/neosync/neosync.go index 1d7ae8a872..ef7b874f12 100644 --- a/cli/internal/cmds/neosync/neosync.go +++ b/cli/internal/cmds/neosync/neosync.go @@ -66,6 +66,8 @@ func Execute() { ) rootCmd.PersistentFlags().String(apiKeyFlag, "", fmt.Sprintf("Neosync API Key. Takes precedence over $%s", apiKeyEnvVarName)) + rootCmd.PersistentFlags().Bool("debug", false, "Run in debug mode") + rootCmd.AddCommand(jobs_cmd.NewCmd()) rootCmd.AddCommand(version_cmd.NewCmd()) rootCmd.AddCommand(whoami_cmd.NewCmd()) diff --git a/cli/internal/cmds/neosync/sync/config.go b/cli/internal/cmds/neosync/sync/config.go new file mode 100644 index 0000000000..8791ef6670 --- /dev/null +++ b/cli/internal/cmds/neosync/sync/config.go @@ -0,0 +1,199 @@ +package sync_cmd + +import ( + "errors" + "fmt" + "log/slog" + "os" + + mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + "github.com/nucleuscloud/neosync/cli/internal/output" + "github.com/nucleuscloud/neosync/cli/internal/userconfig" + "github.com/spf13/cobra" + "gopkg.in/yaml.v2" +) + +func buildCmdConfig(cmd *cobra.Command) (*cmdConfig, error) { + config := &cmdConfig{ + Source: &sourceConfig{ + ConnectionOpts: &connectionOpts{}, + }, + Destination: &sqlDestinationConfig{}, + AwsDynamoDbDestination: &dynamoDbDestinationConfig{}, + } + configPath, err := cmd.Flags().GetString("config") + if err != nil { + return nil, err + } + + if configPath != "" { + fileBytes, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("error reading config file: %w", err) + } + err = yaml.Unmarshal(fileBytes, &config) + if err != nil { + return nil, fmt.Errorf("error parsing config file: %w", err) + } + } + + connectionId, err := cmd.Flags().GetString("connection-id") + if err != nil { + return nil, err + } + if connectionId != "" { + config.Source.ConnectionId = connectionId + } + + destConnUrl, err := cmd.Flags().GetString("destination-connection-url") + if err != nil { + return nil, err + } + if destConnUrl != "" { + config.Destination.ConnectionUrl = destConnUrl + } + + driver, err := cmd.Flags().GetString("destination-driver") + if err != nil { + return nil, err + } + pDriver, ok := parseDriverString(driver) + if ok { + config.Destination.Driver = pDriver + } + + initSchema, err := cmd.Flags().GetBool("init-schema") + if err != nil { + return nil, err + } + if initSchema { + config.Destination.InitSchema = initSchema + } + + truncateBeforeInsert, err := cmd.Flags().GetBool("truncate-before-insert") + if err != nil { + return nil, err + } + if truncateBeforeInsert { + config.Destination.TruncateBeforeInsert = truncateBeforeInsert + } + + truncateCascade, err := cmd.Flags().GetBool("truncate-cascade") + if err != nil { + return nil, err + } + if truncateCascade { + config.Destination.TruncateCascade = truncateCascade + } + + onConflictDoNothing, err := cmd.Flags().GetBool("on-conflict-do-nothing") + if err != nil { + return nil, err + } + if onConflictDoNothing { + config.Destination.OnConflict.DoNothing = onConflictDoNothing + } + + jobId, err := cmd.Flags().GetString("job-id") + if err != nil { + return nil, err + } + if jobId != "" { + config.Source.ConnectionOpts.JobId = &jobId + } + + jobRunId, err := cmd.Flags().GetString("job-run-id") + if err != nil { + return nil, err + } + if jobRunId != "" { + config.Source.ConnectionOpts.JobRunId = &jobRunId + } + + config, err = buildAwsCredConfig(cmd, config) + if err != nil { + return nil, err + } + + if config.Source.ConnectionId == "" { + return nil, fmt.Errorf("must provide connection-id") + } + + accountIdFlag, err := cmd.Flags().GetString("account-id") + if err != nil { + return nil, err + } + accountId := accountIdFlag + if accountId == "" { + aId, err := userconfig.GetAccountId() + if err != nil { + return nil, errors.New("Unable to retrieve account id. Please use account switch command to set account.") + } + accountId = aId + } + config.AccountId = &accountId + + if accountId == "" { + return nil, errors.New("Account Id not found. Please use account switch command to set account.") + } + + outputType, err := output.ValidateAndRetrieveOutputFlag(cmd) + if err != nil { + return nil, err + } + config.OutputType = &outputType + + debug, err := cmd.Flags().GetBool("debug") + if err != nil { + return nil, err + } + config.Debug = debug + return config, nil +} + +func isConfigValid(cmd *cmdConfig, logger *slog.Logger, sourceConnection *mgmtv1alpha1.Connection, sourceConnectionType ConnectionType) error { + if sourceConnectionType == awsS3Connection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { + return errors.New("S3 source connection type requires job-id or job-run-id.") + } + if sourceConnectionType == gcpCloudStorageConnection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { + return errors.New("GCP Cloud Storage source connection type requires job-id or job-run-id") + } + + if cmd.Destination.TruncateCascade && cmd.Destination.Driver == mysqlDriver { + return fmt.Errorf("truncate cascade is only supported in postgres") + } + + if sourceConnectionType == mysqlConnection || sourceConnectionType == postgresConnection { + if cmd.Destination.Driver == "" { + return fmt.Errorf("must provide destination-driver") + } + if cmd.Destination.ConnectionUrl == "" { + return fmt.Errorf("must provide destination-connection-url") + } + + if cmd.Destination.Driver != mysqlDriver && cmd.Destination.Driver != postgresDriver { + return errors.New("unsupported destination driver. only pgx (postgres) and mysql are currently supported") + } + } + + if sourceConnectionType == awsDynamoDBConnection { + if cmd.AwsDynamoDbDestination == nil { + return fmt.Errorf("must provide destination aws credentials") + } + + if cmd.AwsDynamoDbDestination.AwsCredConfig.Region == "" { + return fmt.Errorf("must provide destination aws region") + } + } + + if sourceConnection.AccountId != *cmd.AccountId { + return fmt.Errorf("Connection not found. AccountId: %s", *cmd.AccountId) + } + + logger.Debug("Checking if source and destination are compatible") + err := areSourceAndDestCompatible(sourceConnection, cmd.Destination.Driver) + if err != nil { + return err + } + return nil +} diff --git a/cli/internal/cmds/neosync/sync/dynamodb.go b/cli/internal/cmds/neosync/sync/dynamodb.go index c8e6f9e4df..81a7923c96 100644 --- a/cli/internal/cmds/neosync/sync/dynamodb.go +++ b/cli/internal/cmds/neosync/sync/dynamodb.go @@ -2,7 +2,7 @@ package sync_cmd import ( tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency" - cli_neosync_benthos "github.com/nucleuscloud/neosync/cli/internal/benthos" + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" "github.com/spf13/cobra" ) @@ -62,21 +62,19 @@ func buildAwsCredConfig(cmd *cobra.Command, config *cmdConfig) (*cmdConfig, erro func generateDynamoDbBenthosConfig( cmd *cmdConfig, - apiUrl string, - authToken *string, table string, ) *benthosConfigResponse { - bc := &cli_neosync_benthos.BenthosConfig{ - StreamConfig: cli_neosync_benthos.StreamConfig{ - Logger: &cli_neosync_benthos.LoggerConfig{ + bc := &neosync_benthos.BenthosConfig{ + StreamConfig: neosync_benthos.StreamConfig{ + Logger: &neosync_benthos.LoggerConfig{ Level: "ERROR", AddTimestamp: true, }, - Input: &cli_neosync_benthos.InputConfig{ - Inputs: cli_neosync_benthos.Inputs{ - NeosyncConnectionData: &cli_neosync_benthos.NeosyncConnectionData{ - ApiKey: authToken, - ApiUrl: apiUrl, + Input: &neosync_benthos.InputConfig{ + Inputs: neosync_benthos.Inputs{ + NeosyncConnectionData: &neosync_benthos.NeosyncConnectionData{ + // ApiKey: authToken, + // ApiUrl: apiUrl, ConnectionId: cmd.Source.ConnectionId, ConnectionType: string(awsDynamoDBConnection), Schema: "dynamodb", @@ -84,16 +82,16 @@ func generateDynamoDbBenthosConfig( }, }, }, - Pipeline: &cli_neosync_benthos.PipelineConfig{}, - Output: &cli_neosync_benthos.OutputConfig{ - Outputs: cli_neosync_benthos.Outputs{ - AwsDynamoDB: &cli_neosync_benthos.OutputAwsDynamoDB{ + Pipeline: &neosync_benthos.PipelineConfig{}, + Output: &neosync_benthos.OutputConfig{ + Outputs: neosync_benthos.Outputs{ + AwsDynamoDB: &neosync_benthos.OutputAwsDynamoDB{ Table: table, JsonMapColumns: map[string]string{ "": ".", }, - Batching: &cli_neosync_benthos.Batching{ + Batching: &neosync_benthos.Batching{ // https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_BatchWriteItem.html // A single call to BatchWriteItem can transmit up to 16MB of data over the network, consisting of up to 25 item put or delete operations // Specifying the count here may not be enough if the overall data is above 16MB. @@ -119,12 +117,12 @@ func generateDynamoDbBenthosConfig( } } -func buildBenthosAwsCredentials(cmd *cmdConfig) *cli_neosync_benthos.AwsCredentials { +func buildBenthosAwsCredentials(cmd *cmdConfig) *neosync_benthos.AwsCredentials { if cmd.AwsDynamoDbDestination == nil || cmd.AwsDynamoDbDestination.AwsCredConfig == nil { return nil } cc := cmd.AwsDynamoDbDestination.AwsCredConfig - creds := &cli_neosync_benthos.AwsCredentials{} + creds := &neosync_benthos.AwsCredentials{} if cc.Profile != nil { creds.Profile = *cc.Profile } diff --git a/cli/internal/cmds/neosync/sync/sync.go b/cli/internal/cmds/neosync/sync/sync.go index ba3a474518..9c8c561ef2 100644 --- a/cli/internal/cmds/neosync/sync/sync.go +++ b/cli/internal/cmds/neosync/sync/sync.go @@ -6,7 +6,6 @@ import ( "fmt" "log/slog" "os" - "slices" "strings" syncmap "sync" "time" @@ -27,22 +26,17 @@ import ( sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency" "github.com/nucleuscloud/neosync/cli/internal/auth" - cli_neosync_benthos "github.com/nucleuscloud/neosync/cli/internal/benthos" - auth_interceptor "github.com/nucleuscloud/neosync/cli/internal/connect/interceptors/auth" "github.com/nucleuscloud/neosync/cli/internal/output" - "github.com/nucleuscloud/neosync/cli/internal/serverconfig" - "github.com/nucleuscloud/neosync/cli/internal/userconfig" - "github.com/nucleuscloud/neosync/cli/internal/version" connectiontunnelmanager "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager" pool_sql_provider "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/pool/providers/sql" "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/providers" "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/providers/mongoprovider" "github.com/nucleuscloud/neosync/internal/connection-tunnel-manager/providers/sqlprovider" + neosync_benthos "github.com/nucleuscloud/neosync/worker/pkg/benthos" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" - _ "github.com/nucleuscloud/neosync/cli/internal/benthos/inputs" benthos_environment "github.com/nucleuscloud/neosync/worker/pkg/benthos/environment" _ "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared" @@ -52,8 +46,6 @@ import ( _ "github.com/warpstreamlabs/bento/public/components/pure" _ "github.com/warpstreamlabs/bento/public/components/pure/extended" - http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client" - "github.com/warpstreamlabs/bento/public/service" ) @@ -86,6 +78,8 @@ type cmdConfig struct { Destination *sqlDestinationConfig `yaml:"destination"` AwsDynamoDbDestination *dynamoDbDestinationConfig `yaml:"aws-dynamodb-destination,omitempty"` Debug bool + OutputType *output.OutputType `yaml:"output-type,omitempty"` + AccountId *string `yaml:"account-id,omitempty"` } type sourceConfig struct { @@ -141,133 +135,15 @@ func NewCmd() *cobra.Command { apiKey = &apiKeyStr } - config := &cmdConfig{ - Source: &sourceConfig{ - ConnectionOpts: &connectionOpts{}, - }, - Destination: &sqlDestinationConfig{}, - AwsDynamoDbDestination: &dynamoDbDestinationConfig{}, - } - configPath, err := cmd.Flags().GetString("config") - if err != nil { - return err - } - - if configPath != "" { - fileBytes, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("error reading config file: %w", err) - } - err = yaml.Unmarshal(fileBytes, &config) - if err != nil { - return fmt.Errorf("error parsing config file: %w", err) - } - } - - connectionId, err := cmd.Flags().GetString("connection-id") - if err != nil { - return err - } - if connectionId != "" { - config.Source.ConnectionId = connectionId - } - - destConnUrl, err := cmd.Flags().GetString("destination-connection-url") - if err != nil { - return err - } - if destConnUrl != "" { - config.Destination.ConnectionUrl = destConnUrl - } - - driver, err := cmd.Flags().GetString("destination-driver") - if err != nil { - return err - } - pDriver, ok := parseDriverString(driver) - if ok { - config.Destination.Driver = pDriver - } - - initSchema, err := cmd.Flags().GetBool("init-schema") - if err != nil { - return err - } - if initSchema { - config.Destination.InitSchema = initSchema - } - - truncateBeforeInsert, err := cmd.Flags().GetBool("truncate-before-insert") - if err != nil { - return err - } - if truncateBeforeInsert { - config.Destination.TruncateBeforeInsert = truncateBeforeInsert - } - - truncateCascade, err := cmd.Flags().GetBool("truncate-cascade") - if err != nil { - return err - } - if truncateCascade { - config.Destination.TruncateCascade = truncateCascade - } - - onConflictDoNothing, err := cmd.Flags().GetBool("on-conflict-do-nothing") - if err != nil { - return err - } - if onConflictDoNothing { - config.Destination.OnConflict.DoNothing = onConflictDoNothing - } - - jobId, err := cmd.Flags().GetString("job-id") - if err != nil { - return err - } - if jobId != "" { - config.Source.ConnectionOpts.JobId = &jobId - } - - jobRunId, err := cmd.Flags().GetString("job-run-id") - if err != nil { - return err - } - if jobRunId != "" { - config.Source.ConnectionOpts.JobRunId = &jobRunId - } - - config, err = buildAwsCredConfig(cmd, config) + config, err := buildCmdConfig(cmd) if err != nil { return err } - if config.Source.ConnectionId == "" { - return fmt.Errorf("must provide connection-id") - } - - accountId, err := cmd.Flags().GetString("account-id") - if err != nil { - return err - } - - outputType, err := output.ValidateAndRetrieveOutputFlag(cmd) - if err != nil { - return err - } - - debug, err := cmd.Flags().GetBool("debug") - if err != nil { - return err - } - config.Debug = debug - - return sync(cmd.Context(), outputType, apiKey, &accountId, config) + return sync(cmd.Context(), apiKey, config) }, } - cmd.Flags().Bool("debug", false, "Run in debug mode") - cmd.Flags().String("connection-id", "", "Connection id for sync source") cmd.Flags().String("job-id", "", "Id of Job to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-run-id instead.") cmd.Flags().String("job-run-id", "", "Id of Job run to sync data from. Only used with [AWS S3, GCP Cloud Storage] connections. Can use job-id instead.") @@ -294,10 +170,20 @@ func NewCmd() *cobra.Command { return cmd } +type clisync struct { + connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient + connectionclient mgmtv1alpha1connect.ConnectionServiceClient + sqlmanagerclient *sqlmanager.SqlManager + sqlconnector *sqlconnect.SqlOpenConnector + benv *service.Environment + cmd *cmdConfig + logger *slog.Logger + ctx context.Context +} + func sync( ctx context.Context, - outputType output.OutputType, - apiKey, accountIdFlag *string, + apiKey *string, cmd *cmdConfig, ) error { logLevel := charmlog.InfoLevel @@ -311,27 +197,16 @@ func sync( logger := slog.New(charmlogger) logger.Info("Starting sync") - isAuthEnabled, err := auth.IsAuthEnabled(ctx) + + connectInterceptors := []connect.Interceptor{} + neosyncurl := auth.GetNeosyncUrl() + httpclient, err := auth.GetNeosyncHttpClient(ctx, apiKey, logger) if err != nil { return err } - - httpclient := http_client.NewWithHeaders(version.Get().Headers()) - connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient( - httpclient, - serverconfig.GetApiBaseUrl(), - connect.WithInterceptors( - auth_interceptor.NewInterceptor(isAuthEnabled, auth.AuthHeader, auth.GetAuthHeaderTokenFn(apiKey)), - ), - ) - - connectiondataclient := mgmtv1alpha1connect.NewConnectionDataServiceClient( - httpclient, - serverconfig.GetApiBaseUrl(), - connect.WithInterceptors( - auth_interceptor.NewInterceptor(isAuthEnabled, auth.AuthHeader, auth.GetAuthHeaderTokenFn(apiKey)), - ), - ) + connectInterceptorOption := connect.WithInterceptors(connectInterceptors...) + connectionclient := mgmtv1alpha1connect.NewConnectionServiceClient(httpclient, neosyncurl, connectInterceptorOption) + connectiondataclient := mgmtv1alpha1connect.NewConnectionDataServiceClient(httpclient, neosyncurl, connectInterceptorOption) pgpoolmap := &syncmap.Map{} mysqlpoolmap := &syncmap.Map{} @@ -342,99 +217,23 @@ func sync( sqlConnector := &sqlconnect.SqlOpenConnector{} sqlmanagerclient := sqlmanager.NewSqlManager(pgpoolmap, pgquerier, mysqlpoolmap, mysqlquerier, mssqlpoolmap, mssqlquerier, sqlConnector) - logger.Debug("Retrieving neosync source connection") - connResp, err := connectionclient.GetConnection(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ - Id: cmd.Source.ConnectionId, - })) - if err != nil { - return err - } - sourceConnection := connResp.Msg.GetConnection() - sourceConnectionType, err := getConnectionType(sourceConnection) - if err != nil { - return err - } - logger.Debug(fmt.Sprintf("Source connection type: %s", sourceConnectionType)) - - if sourceConnectionType == awsS3Connection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { - return errors.New("S3 source connection type requires job-id or job-run-id.") - } - if sourceConnectionType == gcpCloudStorageConnection && (cmd.Source.ConnectionOpts.JobId == nil || *cmd.Source.ConnectionOpts.JobId == "") && (cmd.Source.ConnectionOpts.JobRunId == nil || *cmd.Source.ConnectionOpts.JobRunId == "") { - return errors.New("GCP Cloud Storage source connection type requires job-id or job-run-id") - } - - if cmd.Destination.TruncateCascade && cmd.Destination.Driver == mysqlDriver { - return fmt.Errorf("truncate cascade is only supported in postgres") - } - - if sourceConnectionType == mysqlConnection || sourceConnectionType == postgresConnection { - if cmd.Destination.Driver == "" { - return fmt.Errorf("must provide destination-driver") - } - if cmd.Destination.ConnectionUrl == "" { - return fmt.Errorf("must provide destination-connection-url") - } - - if cmd.Destination.Driver != mysqlDriver && cmd.Destination.Driver != postgresDriver { - return errors.New("unsupported destination driver. only postgres and mysql are currently supported") - } - } - - if sourceConnectionType == awsDynamoDBConnection { - if cmd.AwsDynamoDbDestination == nil { - return fmt.Errorf("must provide destination aws credentials") - } - - if cmd.AwsDynamoDbDestination.AwsCredConfig.Region == "" { - return fmt.Errorf("must provide destination aws region") - } - } - logger.Debug("Validated config") - - var token *string - if isAuthEnabled { - logger.Debug("Auth Enabled") - if apiKey != nil && *apiKey != "" { - logger.Debug("found API Key") - token = apiKey - } else { - logger.Debug("Retrieving Access Token") - accessToken, err := userconfig.GetAccessToken() - if err != nil { - logger.Error("Unable to retrieve access token. Please use neosync login command and try again.") - return err - } - token = &accessToken - logger.Debug("Setting account id") - var accountId = accountIdFlag - if accountId == nil || *accountId == "" { - aId, err := userconfig.GetAccountId() - if err != nil { - logger.Error("Unable to retrieve account id. Please use account switch command to set account.") - return err - } - accountId = &aId - } - - if accountId == nil || *accountId == "" { - return errors.New("Account Id not found. Please use account switch command to set account.") - } - - if sourceConnection.AccountId != *accountId { - return fmt.Errorf("Connection not found. AccountId: %s", *accountId) - } - } + sync := &clisync{ + connectiondataclient: connectiondataclient, + connectionclient: connectionclient, + sqlmanagerclient: sqlmanagerclient, + sqlconnector: sqlConnector, + cmd: cmd, + logger: logger, + ctx: ctx, } - logger.Debug("Checking if source and destination are compatible") - err = areSourceAndDestCompatible(sourceConnection, cmd.Destination.Driver) - if err != nil { - return err - } + return sync.configureAndRunSync() +} +func (c *clisync) configureAndRunSync() error { connectionprovider := providers.NewProvider( mongoprovider.NewProvider(), - sqlprovider.NewProvider(sqlConnector), + sqlprovider.NewProvider(c.sqlconnector), ) tunnelmanager := connectiontunnelmanager.NewConnectionTunnelManager(connectionprovider) session := uuid.NewString() @@ -443,15 +242,15 @@ func sync( tunnelmanager.ReleaseSession(session) }() - destConnection := cmdConfigToDestinationConnection(cmd) + destConnection := cmdConfigToDestinationConnection(c.cmd) dsnToConnIdMap := &syncmap.Map{} var sqlDsn string - if cmd.Destination != nil { - sqlDsn = cmd.Destination.ConnectionUrl + if c.cmd.Destination != nil { + sqlDsn = c.cmd.Destination.ConnectionUrl } dsnToConnIdMap.Store(sqlDsn, destConnection.Id) stopChan := make(chan error, 3) - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(c.ctx) defer cancel() go func() { for { @@ -464,8 +263,8 @@ func sync( } } }() - benv, err := benthos_environment.NewEnvironment( - logger, + benthosEnv, err := benthos_environment.NewEnvironment( + c.logger, benthos_environment.WithSqlConfig(&benthos_environment.SqlConfig{ Provider: pool_sql_provider.NewProvider(pool_sql_provider.GetSqlPoolProviderGetter( tunnelmanager, @@ -474,26 +273,63 @@ func sync( destConnection.Id: destConnection, }, session, - logger, + c.logger, )), IsRetry: false, }), + benthos_environment.WithConnectionDataConfig(&benthos_environment.ConnectionDataConfig{ + NeosyncConnectionDataApi: c.connectiondataclient, + }), benthos_environment.WithStopChannel(stopChan), benthos_environment.WithBlobEnv(bloblang.NewEnvironment()), ) if err != nil { return err } + c.benv = benthosEnv + + groupedConfigs, err := c.configureSync() + if err != nil { + return err + } + if groupedConfigs == nil { + return nil + } + + return runSync(c.ctx, *c.cmd.OutputType, c.benv, groupedConfigs, c.logger) +} + +func (c *clisync) configureSync() ([][]*benthosConfigResponse, error) { + c.logger.Debug("Retrieving neosync connection") + connResp, err := c.connectionclient.GetConnection(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: c.cmd.Source.ConnectionId, + })) + if err != nil { + return nil, err + } + sourceConnection := connResp.Msg.GetConnection() + sourceConnectionType, err := getConnectionType(sourceConnection) + if err != nil { + return nil, err + } + c.logger.Debug(fmt.Sprintf("Source connection type: %s", sourceConnectionType)) + + err = isConfigValid(c.cmd, c.logger, sourceConnection, sourceConnectionType) + if err != nil { + return nil, err + } + c.logger.Debug("Validated config") - logger.Info("Retrieving connection schema...") + c.logger.Info("Retrieving connection schema...") var schemaConfig *schemaConfig switch sourceConnectionType { case awsS3Connection: + c.logger.Info("Building schema and table constraints...") var cfg *mgmtv1alpha1.AwsS3SchemaConfig - if cmd.Source.ConnectionOpts.JobRunId != nil && *cmd.Source.ConnectionOpts.JobRunId != "" { - cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobRunId{JobRunId: *cmd.Source.ConnectionOpts.JobRunId}} - } else if cmd.Source.ConnectionOpts.JobId != nil && *cmd.Source.ConnectionOpts.JobId != "" { - cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobId{JobId: *cmd.Source.ConnectionOpts.JobId}} + if c.cmd.Source.ConnectionOpts.JobRunId != nil && *c.cmd.Source.ConnectionOpts.JobRunId != "" { + cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobRunId{JobRunId: *c.cmd.Source.ConnectionOpts.JobRunId}} + } else if c.cmd.Source.ConnectionOpts.JobId != nil && *c.cmd.Source.ConnectionOpts.JobId != "" { + cfg = &mgmtv1alpha1.AwsS3SchemaConfig{Id: &mgmtv1alpha1.AwsS3SchemaConfig_JobId{JobId: *c.cmd.Source.ConnectionOpts.JobId}} } s3Config := &mgmtv1alpha1.ConnectionSchemaConfig{ Config: &mgmtv1alpha1.ConnectionSchemaConfig_AwsS3Config{ @@ -501,21 +337,21 @@ func sync( }, } - schemaCfg, err := getDestinationSchemaConfig(ctx, connectiondataclient, sqlmanagerclient, sourceConnection, cmd, s3Config, logger) + schemaCfg, err := c.getDestinationSchemaConfig(sourceConnection, s3Config) if err != nil { - return err + return nil, err } if len(schemaCfg.Schemas) == 0 { - logger.Warn("No tables found.") - return nil + c.logger.Warn("No tables found.") + return nil, nil } schemaConfig = schemaCfg case gcpCloudStorageConnection: var cfg *mgmtv1alpha1.GcpCloudStorageSchemaConfig - if cmd.Source.ConnectionOpts.JobRunId != nil && *cmd.Source.ConnectionOpts.JobRunId != "" { - cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobRunId{JobRunId: *cmd.Source.ConnectionOpts.JobRunId}} - } else if cmd.Source.ConnectionOpts.JobId != nil && *cmd.Source.ConnectionOpts.JobId != "" { - cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobId{JobId: *cmd.Source.ConnectionOpts.JobId}} + if c.cmd.Source.ConnectionOpts.JobRunId != nil && *c.cmd.Source.ConnectionOpts.JobRunId != "" { + cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobRunId{JobRunId: *c.cmd.Source.ConnectionOpts.JobRunId}} + } else if c.cmd.Source.ConnectionOpts.JobId != nil && *c.cmd.Source.ConnectionOpts.JobId != "" { + cfg = &mgmtv1alpha1.GcpCloudStorageSchemaConfig{Id: &mgmtv1alpha1.GcpCloudStorageSchemaConfig_JobId{JobId: *c.cmd.Source.ConnectionOpts.JobId}} } gcpConfig := &mgmtv1alpha1.ConnectionSchemaConfig{ @@ -524,45 +360,45 @@ func sync( }, } - schemaCfg, err := getDestinationSchemaConfig(ctx, connectiondataclient, sqlmanagerclient, sourceConnection, cmd, gcpConfig, logger) + schemaCfg, err := c.getDestinationSchemaConfig(sourceConnection, gcpConfig) if err != nil { - return err + return nil, err } if len(schemaCfg.Schemas) == 0 { - logger.Warn("No tables found.") - return nil + c.logger.Warn("No tables found.") + return nil, nil } schemaConfig = schemaCfg case mysqlConnection: - logger.Info("Building schema and table constraints...") + c.logger.Info("Building schema and table constraints...") mysqlCfg := &mgmtv1alpha1.ConnectionSchemaConfig{ Config: &mgmtv1alpha1.ConnectionSchemaConfig_MysqlConfig{ MysqlConfig: &mgmtv1alpha1.MysqlSchemaConfig{}, }, } - schemaCfg, err := getConnectionSchemaConfig(ctx, logger, connectiondataclient, sourceConnection, cmd, mysqlCfg) + schemaCfg, err := c.getConnectionSchemaConfig(sourceConnection, mysqlCfg) if err != nil { - return err + return nil, err } if len(schemaCfg.Schemas) == 0 { - logger.Warn("No tables found.") - return nil + c.logger.Warn("No tables found.") + return nil, nil } schemaConfig = schemaCfg case postgresConnection: - logger.Info("Building schema and table constraints...") + c.logger.Info("Building schema and table constraints...") postgresConfig := &mgmtv1alpha1.ConnectionSchemaConfig{ Config: &mgmtv1alpha1.ConnectionSchemaConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresSchemaConfig{}, }, } - schemaCfg, err := getConnectionSchemaConfig(ctx, logger, connectiondataclient, sourceConnection, cmd, postgresConfig) + schemaCfg, err := c.getConnectionSchemaConfig(sourceConnection, postgresConfig) if err != nil { - return err + return nil, err } if len(schemaCfg.Schemas) == 0 { - logger.Warn("No tables found.") - return nil + c.logger.Warn("No tables found.") + return nil, nil } schemaConfig = schemaCfg case awsDynamoDBConnection: @@ -571,13 +407,13 @@ func sync( DynamodbConfig: &mgmtv1alpha1.DynamoDBSchemaConfig{}, }, } - schemaCfg, err := getConnectionSchemaConfig(ctx, logger, connectiondataclient, sourceConnection, cmd, dynamoConfig) + schemaCfg, err := c.getConnectionSchemaConfig(sourceConnection, dynamoConfig) if err != nil { - return err + return nil, err } if len(schemaCfg.Schemas) == 0 { - logger.Warn("No tables found.") - return nil + c.logger.Warn("No tables found.") + return nil, nil } tableMap := map[string]struct{}{} for _, s := range schemaCfg.Schemas { @@ -585,58 +421,38 @@ func sync( } configs := []*benthosConfigResponse{} for t := range tableMap { - benthosConfig := generateDynamoDbBenthosConfig(cmd, serverconfig.GetApiBaseUrl(), token, t) + benthosConfig := generateDynamoDbBenthosConfig(c.cmd, t) configs = append(configs, benthosConfig) } - - return runSync(ctx, outputType, benv, [][]*benthosConfigResponse{configs}, logger) + return [][]*benthosConfigResponse{configs}, nil default: - return fmt.Errorf("this connection type is not currently supported") + return nil, fmt.Errorf("this connection type is not currently supported") } - logger.Debug("Building sync configs") - syncConfigs := buildSyncConfigs(schemaConfig, logger) + c.logger.Debug("Building sync configs") + syncConfigs := buildSyncConfigs(schemaConfig, c.logger) if syncConfigs == nil { - return nil + return nil, nil } - logger.Info("Running table init statements...") - err = runDestinationInitStatements(ctx, logger, sqlmanagerclient, cmd, syncConfigs, schemaConfig) + c.logger.Info("Running table init statements...") + err = c.runDestinationInitStatements(syncConfigs, schemaConfig) if err != nil { - return err + return nil, err } syncConfigCount := len(syncConfigs) - logger.Info(fmt.Sprintf("Generating %d sync configs...", syncConfigCount)) + c.logger.Info(fmt.Sprintf("Generating %d sync configs...", syncConfigCount)) configs := []*benthosConfigResponse{} for _, cfg := range syncConfigs { - benthosConfig := generateBenthosConfig(cmd, sourceConnectionType, serverconfig.GetApiBaseUrl(), cfg, token) + benthosConfig := generateBenthosConfig(c.cmd, sourceConnectionType, cfg) configs = append(configs, benthosConfig) } // order configs in run order by dependency - groupedConfigs := groupConfigsByDependency(configs, logger) - if groupedConfigs == nil { - return nil - } + c.logger.Debug("Ordering configs by dependency") + groupedConfigs := groupConfigsByDependency(configs, c.logger) - return runSync(ctx, outputType, benv, groupedConfigs, logger) -} - -func areSourceAndDestCompatible(connection *mgmtv1alpha1.Connection, destinationDriver DriverType) error { - switch connection.ConnectionConfig.Config.(type) { - case *mgmtv1alpha1.ConnectionConfig_PgConfig: - if destinationDriver != postgresDriver { - return fmt.Errorf("Connection and destination types are incompatible [postgres, %s]", destinationDriver) - } - case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - if destinationDriver != mysqlDriver { - return fmt.Errorf("Connection and destination types are incompatible [mysql, %s]", destinationDriver) - } - case *mgmtv1alpha1.ConnectionConfig_AwsS3Config, *mgmtv1alpha1.ConnectionConfig_GcpCloudstorageConfig, *mgmtv1alpha1.ConnectionConfig_DynamodbConfig: - default: - return errors.New("unsupported destination driver. only postgres and mysql are currently supported") - } - return nil + return groupedConfigs, nil } var ( @@ -678,9 +494,15 @@ func syncData(ctx context.Context, benv *service.Environment, cfg *benthosConfig } streamBuilderMu.Lock() streambldr := benv.NewStreamBuilder() + if streambldr == nil { + return fmt.Errorf("failed to create StreamBuilder") + } if outputType == output.PlainOutput { streambldr.SetLogger(logger.With("benthos", "true", "table", cfg.Table, "runType", runType)) } + if benv == nil { + return fmt.Errorf("benthos env is nil") + } err = streambldr.SetYAML(string(configbits)) if err != nil { @@ -788,30 +610,26 @@ func cmdConfigToDestinationConnection(cmd *cmdConfig) *mgmtv1alpha1.Connection { return &mgmtv1alpha1.Connection{} } -func runDestinationInitStatements( - ctx context.Context, - logger *slog.Logger, - sqlmanagerclient sqlmanager.SqlManagerClient, - cmd *cmdConfig, +func (c *clisync) runDestinationInitStatements( syncConfigs []*tabledependency.RunConfig, schemaConfig *schemaConfig, ) error { dependencyMap := buildDependencyMap(syncConfigs) - db, err := sqlmanagerclient.NewSqlDbFromUrl(ctx, string(cmd.Destination.Driver), cmd.Destination.ConnectionUrl) + db, err := c.sqlmanagerclient.NewSqlDbFromUrl(c.ctx, string(c.cmd.Destination.Driver), c.cmd.Destination.ConnectionUrl) if err != nil { return err } defer db.Db.Close() - if cmd.Destination.InitSchema { + if c.cmd.Destination.InitSchema { if len(schemaConfig.InitSchemaStatements) != 0 { for _, block := range schemaConfig.InitSchemaStatements { - logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) + c.logger.Info(fmt.Sprintf("[%s] found %d statements to execute during schema initialization", block.Label, len(block.Statements))) if len(block.Statements) == 0 { continue } - err = db.Db.BatchExec(ctx, batchSize, block.Statements, &sql_manager.BatchExecOpts{}) + err = db.Db.BatchExec(c.ctx, batchSize, block.Statements, &sql_manager.BatchExecOpts{}) if err != nil { - logger.Error(fmt.Sprintf("Error creating tables: %v", err)) + c.logger.Error(fmt.Sprintf("Error creating tables: %v", err)) return fmt.Errorf("unable to exec pg %s statements: %w", block.Label, err) } } @@ -829,15 +647,15 @@ func runDestinationInitStatements( orderedInitStatements = append(orderedInitStatements, schemaConfig.InitTableStatementsMap[t.String()]) } - err = db.Db.BatchExec(ctx, batchSize, orderedInitStatements, &sql_manager.BatchExecOpts{}) + err = db.Db.BatchExec(c.ctx, batchSize, orderedInitStatements, &sql_manager.BatchExecOpts{}) if err != nil { - logger.Error(fmt.Sprintf("Error creating tables: %v", err)) + c.logger.Error(fmt.Sprintf("Error creating tables: %v", err)) return err } } } - if cmd.Destination.Driver == postgresDriver { - if cmd.Destination.TruncateCascade { + if c.cmd.Destination.Driver == postgresDriver { + if c.cmd.Destination.TruncateCascade { truncateCascadeStmts := []string{} for _, syncCfg := range syncConfigs { stmt, ok := schemaConfig.TruncateTableStatementsMap[syncCfg.Table()] @@ -845,12 +663,12 @@ func runDestinationInitStatements( truncateCascadeStmts = append(truncateCascadeStmts, stmt) } } - err = db.Db.BatchExec(ctx, batchSize, truncateCascadeStmts, &sql_manager.BatchExecOpts{}) + err = db.Db.BatchExec(c.ctx, batchSize, truncateCascadeStmts, &sql_manager.BatchExecOpts{}) if err != nil { - logger.Error(fmt.Sprintf("Error truncate cascade tables: %v", err)) + c.logger.Error(fmt.Sprintf("Error truncate cascade tables: %v", err)) return err } - } else if cmd.Destination.TruncateBeforeInsert { + } else if c.cmd.Destination.TruncateBeforeInsert { orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap) if err != nil { return err @@ -859,13 +677,13 @@ func runDestinationInitStatements( if err != nil { return err } - err = db.Db.Exec(ctx, orderedTruncateStatement) + err = db.Db.Exec(c.ctx, orderedTruncateStatement) if err != nil { - logger.Error(fmt.Sprintf("Error truncating tables: %v", err)) + c.logger.Error(fmt.Sprintf("Error truncating tables: %v", err)) return err } } - } else if cmd.Destination.Driver == mysqlDriver { + } else if c.cmd.Destination.Driver == mysqlDriver { orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap) if err != nil { return err @@ -875,9 +693,9 @@ func runDestinationInitStatements( orderedTableTruncateStatements = append(orderedTableTruncateStatements, schemaConfig.TruncateTableStatementsMap[t.String()]) } disableFkChecks := sql_manager.DisableForeignKeyChecks - err = db.Db.BatchExec(ctx, batchSize, orderedTableTruncateStatements, &sql_manager.BatchExecOpts{Prefix: &disableFkChecks}) + err = db.Db.BatchExec(c.ctx, batchSize, orderedTableTruncateStatements, &sql_manager.BatchExecOpts{Prefix: &disableFkChecks}) if err != nil { - logger.Error(fmt.Sprintf("Error truncating tables: %v", err)) + c.logger.Error(fmt.Sprintf("Error truncating tables: %v", err)) return err } } @@ -906,21 +724,6 @@ func buildSyncConfigs( return runConfigs } -func buildDependencyMap(syncConfigs []*tabledependency.RunConfig) map[string][]string { - dependencyMap := map[string][]string{} - for _, cfg := range syncConfigs { - _, dpOk := dependencyMap[cfg.Table()] - if !dpOk { - dependencyMap[cfg.Table()] = []string{} - } - - for _, dep := range cfg.DependsOn() { - dependencyMap[cfg.Table()] = append(dependencyMap[cfg.Table()], dep.Table) - } - } - return dependencyMap -} - func getTableInitStatementMap( ctx context.Context, logger *slog.Logger, @@ -954,31 +757,10 @@ func getTableInitStatementMap( return nil, nil } -type SqlTable struct { - Schema string - Table string - Columns []string -} - -func getTableColMap(schemas []*mgmtv1alpha1.DatabaseColumn) map[string][]string { - tableColMap := map[string][]string{} - for _, record := range schemas { - table := sql_manager.BuildTable(record.Schema, record.Table) - _, ok := tableColMap[table] - if ok { - tableColMap[table] = append(tableColMap[table], record.Column) - } else { - tableColMap[table] = []string{record.Column} - } - } - - return tableColMap -} - type benthosConfigResponse struct { Name string DependsOn []*tabledependency.DependsOn - Config *cli_neosync_benthos.BenthosConfig + Config *neosync_benthos.BenthosConfig Table string Columns []string } @@ -986,9 +768,7 @@ type benthosConfigResponse struct { func generateBenthosConfig( cmd *cmdConfig, connectionType ConnectionType, - apiUrl string, syncConfig *tabledependency.RunConfig, - authToken *string, ) *benthosConfigResponse { schema, table := sqlmanager_shared.SplitTableKey(syncConfig.Table()) @@ -998,17 +778,15 @@ func generateBenthosConfig( jobId = cmd.Source.ConnectionOpts.JobId } - bc := &cli_neosync_benthos.BenthosConfig{ - StreamConfig: cli_neosync_benthos.StreamConfig{ - Logger: &cli_neosync_benthos.LoggerConfig{ + bc := &neosync_benthos.BenthosConfig{ + StreamConfig: neosync_benthos.StreamConfig{ + Logger: &neosync_benthos.LoggerConfig{ Level: "ERROR", AddTimestamp: true, }, - Input: &cli_neosync_benthos.InputConfig{ - Inputs: cli_neosync_benthos.Inputs{ - NeosyncConnectionData: &cli_neosync_benthos.NeosyncConnectionData{ - ApiKey: authToken, - ApiUrl: apiUrl, + Input: &neosync_benthos.InputConfig{ + Inputs: neosync_benthos.Inputs{ + NeosyncConnectionData: &neosync_benthos.NeosyncConnectionData{ ConnectionId: cmd.Source.ConnectionId, ConnectionType: string(connectionType), JobId: jobId, @@ -1018,17 +796,17 @@ func generateBenthosConfig( }, }, }, - Pipeline: &cli_neosync_benthos.PipelineConfig{}, - Output: &cli_neosync_benthos.OutputConfig{}, + Pipeline: &neosync_benthos.PipelineConfig{}, + Output: &neosync_benthos.OutputConfig{}, }, } if syncConfig.RunType() == tabledependency.RunTypeUpdate { args := syncConfig.InsertColumns() args = append(args, syncConfig.PrimaryKeys()...) - bc.Output = &cli_neosync_benthos.OutputConfig{ - Outputs: cli_neosync_benthos.Outputs{ - PooledSqlUpdate: &cli_neosync_benthos.PooledSqlUpdate{ + bc.Output = &neosync_benthos.OutputConfig{ + Outputs: neosync_benthos.Outputs{ + PooledSqlUpdate: &neosync_benthos.PooledSqlUpdate{ Driver: string(cmd.Destination.Driver), Dsn: cmd.Destination.ConnectionUrl, @@ -1038,7 +816,7 @@ func generateBenthosConfig( WhereColumns: syncConfig.PrimaryKeys(), ArgsMapping: buildPlainInsertArgs(args), - Batching: &cli_neosync_benthos.Batching{ + Batching: &neosync_benthos.Batching{ Period: "5s", Count: 100, }, @@ -1046,9 +824,9 @@ func generateBenthosConfig( }, } } else { - bc.Output = &cli_neosync_benthos.OutputConfig{ - Outputs: cli_neosync_benthos.Outputs{ - PooledSqlInsert: &cli_neosync_benthos.PooledSqlInsert{ + bc.Output = &neosync_benthos.OutputConfig{ + Outputs: neosync_benthos.Outputs{ + PooledSqlInsert: &neosync_benthos.PooledSqlInsert{ Driver: string(cmd.Destination.Driver), Dsn: cmd.Destination.ConnectionUrl, @@ -1058,7 +836,7 @@ func generateBenthosConfig( OnConflictDoNothing: cmd.Destination.OnConflict.DoNothing, ArgsMapping: buildPlainInsertArgs(syncConfig.SelectColumns()), - Batching: &cli_neosync_benthos.Batching{ + Batching: &neosync_benthos.Batching{ Period: "5s", Count: 100, }, @@ -1075,66 +853,6 @@ func generateBenthosConfig( Columns: syncConfig.InsertColumns(), } } -func groupConfigsByDependency(configs []*benthosConfigResponse, logger *slog.Logger) [][]*benthosConfigResponse { - groupedConfigs := [][]*benthosConfigResponse{} - configMap := map[string]*benthosConfigResponse{} - queuedMap := map[string][]string{} // map -> table to cols - - // get root configs - rootConfigs := []*benthosConfigResponse{} - for _, c := range configs { - if len(c.DependsOn) == 0 { - rootConfigs = append(rootConfigs, c) - queuedMap[c.Table] = c.Columns - } else { - configMap[c.Name] = c - } - } - if len(rootConfigs) == 0 { - logger.Info("No root configs found. There must be one config with no dependencies.") - return nil - } - groupedConfigs = append(groupedConfigs, rootConfigs) - - prevTableLen := 0 - for len(configMap) > 0 { - // prevents looping forever - if prevTableLen == len(configMap) { - logger.Info("Unable to order configs by dependency. No path found.") - return nil - } - prevTableLen = len(configMap) - dependentConfigs := []*benthosConfigResponse{} - for _, c := range configMap { - if isConfigReady(c, queuedMap) { - dependentConfigs = append(dependentConfigs, c) - delete(configMap, c.Name) - } - } - if len(dependentConfigs) > 0 { - groupedConfigs = append(groupedConfigs, dependentConfigs) - for _, c := range dependentConfigs { - queuedMap[c.Table] = append(queuedMap[c.Table], c.Columns...) - } - } - } - - return groupedConfigs -} -func isConfigReady(config *benthosConfigResponse, queuedMap map[string][]string) bool { - for _, dep := range config.DependsOn { - if cols, ok := queuedMap[dep.Table]; ok { - for _, dc := range dep.Columns { - if !slices.Contains(cols, dc) { - return false - } - } - } else { - return false - } - } - return true -} type schemaConfig struct { Schemas []*mgmtv1alpha1.DatabaseColumn @@ -1145,12 +863,8 @@ type schemaConfig struct { InitSchemaStatements []*mgmtv1alpha1.SchemaInitStatements } -func getConnectionSchemaConfig( - ctx context.Context, - logger *slog.Logger, - connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient, +func (c *clisync) getConnectionSchemaConfig( connection *mgmtv1alpha1.Connection, - cmd *cmdConfig, sc *mgmtv1alpha1.ConnectionSchemaConfig, ) (*schemaConfig, error) { var schemas []*mgmtv1alpha1.DatabaseColumn @@ -1159,9 +873,9 @@ func getConnectionSchemaConfig( var initTableStatementsMap map[string]string var truncateTableStatementsMap map[string]string var initSchemaStatements []*mgmtv1alpha1.SchemaInitStatements - errgrp, errctx := errgroup.WithContext(ctx) + errgrp, errctx := errgroup.WithContext(c.ctx) errgrp.Go(func() error { - schemaResp, err := connectiondataclient.GetConnectionSchema(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ + schemaResp, err := c.connectiondataclient.GetConnectionSchema(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ ConnectionId: connection.Id, SchemaConfig: sc, })) @@ -1173,7 +887,7 @@ func getConnectionSchemaConfig( }) errgrp.Go(func() error { - constraintConnectionResp, err := connectiondataclient.GetConnectionTableConstraints(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionTableConstraintsRequest{ConnectionId: cmd.Source.ConnectionId})) + constraintConnectionResp, err := c.connectiondataclient.GetConnectionTableConstraints(errctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionTableConstraintsRequest{ConnectionId: c.cmd.Source.ConnectionId})) if err != nil { return err } @@ -1183,7 +897,7 @@ func getConnectionSchemaConfig( }) errgrp.Go(func() error { - initStatementsResp, err := getTableInitStatementMap(errctx, logger, connectiondataclient, cmd.Source.ConnectionId, cmd.Destination) + initStatementsResp, err := getTableInitStatementMap(errctx, c.logger, c.connectiondataclient, c.cmd.Source.ConnectionId, c.cmd.Destination) if err != nil { return err } @@ -1225,16 +939,11 @@ func getConnectionSchemaConfig( }, nil } -func getDestinationSchemaConfig( - ctx context.Context, - connectiondataclient mgmtv1alpha1connect.ConnectionDataServiceClient, - sqlmanagerclient sqlmanager.SqlManagerClient, +func (c *clisync) getDestinationSchemaConfig( connection *mgmtv1alpha1.Connection, - cmd *cmdConfig, sc *mgmtv1alpha1.ConnectionSchemaConfig, - logger *slog.Logger, ) (*schemaConfig, error) { - schemaResp, err := connectiondataclient.GetConnectionSchema(ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ + schemaResp, err := c.connectiondataclient.GetConnectionSchema(c.ctx, connect.NewRequest(&mgmtv1alpha1.GetConnectionSchemaRequest{ ConnectionId: connection.Id, SchemaConfig: sc, })) @@ -1244,7 +953,7 @@ func getDestinationSchemaConfig( tableColMap := getTableColMap(schemaResp.Msg.GetSchemas()) if len(tableColMap) == 0 { - logger.Info("No tables found.") + c.logger.Info("No tables found.") return nil, nil } @@ -1257,8 +966,8 @@ func getDestinationSchemaConfig( schemas = append(schemas, s) } - logger.Info("Building table constraints...") - tableConstraints, err := getDestinationTableConstraints(ctx, sqlmanagerclient, cmd.Destination.Driver, cmd.Destination.ConnectionUrl, schemas) + c.logger.Info("Building table constraints...") + tableConstraints, err := c.getDestinationTableConstraints(schemas) if err != nil { return nil, err } @@ -1271,8 +980,8 @@ func getDestinationSchemaConfig( } truncateTableStatementsMap := map[string]string{} - if cmd.Destination.Driver == postgresDriver { - if cmd.Destination.TruncateCascade { + if c.cmd.Destination.Driver == postgresDriver { + if c.cmd.Destination.TruncateCascade { for t := range tableColMap { schema, table := sqlmanager_shared.SplitTableKey(t) stmt, err := sqlmanager_postgres.BuildPgTruncateCascadeStatement(schema, table) @@ -1284,7 +993,7 @@ func getDestinationSchemaConfig( } // truncate before insert handled in runDestinationInitStatements } else { - if cmd.Destination.TruncateBeforeInsert { + if c.cmd.Destination.TruncateBeforeInsert { for t := range tableColMap { schema, table := sqlmanager_shared.SplitTableKey(t) stmt, err := sqlmanager_mysql.BuildMysqlTruncateStatement(schema, table) @@ -1304,10 +1013,10 @@ func getDestinationSchemaConfig( }, nil } -func getDestinationTableConstraints(ctx context.Context, sqlmanagerclient sqlmanager.SqlManagerClient, connectionDriver DriverType, connectionUrl string, schemas []string) (*sql_manager.TableConstraints, error) { - cctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Second)) +func (c *clisync) getDestinationTableConstraints(schemas []string) (*sql_manager.TableConstraints, error) { + cctx, cancel := context.WithDeadline(c.ctx, time.Now().Add(5*time.Second)) defer cancel() - db, err := sqlmanagerclient.NewSqlDbFromUrl(cctx, string(connectionDriver), connectionUrl) + db, err := c.sqlmanagerclient.NewSqlDbFromUrl(cctx, string(c.cmd.Destination.Driver), c.cmd.Destination.ConnectionUrl) if err != nil { return nil, err } @@ -1320,45 +1029,3 @@ func getDestinationTableConstraints(ctx context.Context, sqlmanagerclient sqlman return constraints, nil } - -func buildPlainInsertArgs(cols []string) string { - if len(cols) == 0 { - return "" - } - pieces := make([]string, len(cols)) - for idx := range cols { - pieces[idx] = fmt.Sprintf("this.%q", cols[idx]) - } - return fmt.Sprintf("root = [%s]", strings.Join(pieces, ", ")) -} - -func maxInt(a, b int) int { - if a > b { - return a - } - return b -} - -func parseDriverString(str string) (DriverType, bool) { - p, ok := driverMap[strings.ToLower(str)] - return p, ok -} - -func getConnectionType(connection *mgmtv1alpha1.Connection) (ConnectionType, error) { - if connection.ConnectionConfig.GetAwsS3Config() != nil { - return awsS3Connection, nil - } - if connection.GetConnectionConfig().GetGcpCloudstorageConfig() != nil { - return gcpCloudStorageConnection, nil - } - if connection.ConnectionConfig.GetMysqlConfig() != nil { - return mysqlConnection, nil - } - if connection.ConnectionConfig.GetPgConfig() != nil { - return postgresConnection, nil - } - if connection.ConnectionConfig.GetDynamodbConfig() != nil { - return awsDynamoDBConnection, nil - } - return "", errors.New("unsupported connection type") -} diff --git a/cli/internal/cmds/neosync/sync/sync_integration_test.go b/cli/internal/cmds/neosync/sync/sync_integration_test.go new file mode 100644 index 0000000000..1fda80b520 --- /dev/null +++ b/cli/internal/cmds/neosync/sync/sync_integration_test.go @@ -0,0 +1,107 @@ +package sync_cmd + +import ( + "context" + "testing" + + tcneosyncapi "github.com/nucleuscloud/neosync/backend/pkg/integration-test" + "github.com/nucleuscloud/neosync/cli/internal/output" + "github.com/nucleuscloud/neosync/internal/testutil" + tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +const neosyncDbMigrationsPath = "../../../../../backend/sql/postgresql/schema" + +func Test_Sync_Postgres(t *testing.T) { + t.Parallel() + ok := testutil.ShouldRunIntegrationTest() + if !ok { + return + } + ctx := context.Background() + + var neosyncApi *tcneosyncapi.NeosyncApiTestClient + var postgres *tcpostgres.PostgresTestSyncContainer + + errgrp := errgroup.Group{} + errgrp.Go(func() error { + p, err := tcpostgres.NewPostgresTestSyncContainer(ctx, []tcpostgres.Option{}, []tcpostgres.Option{}) + if err != nil { + return err + } + postgres = p + return nil + }) + + errgrp.Go(func() error { + api, err := tcneosyncapi.NewNeosyncApiTestClient(ctx, t, tcneosyncapi.WithMigrationsDirectory(neosyncDbMigrationsPath)) + if err != nil { + return err + } + neosyncApi = api + return nil + }) + + err := errgrp.Wait() + if err != nil { + panic(err) + } + + testdataFolder := "../../../../../internal/testutil/testdata/postgres/humanresources" + err = postgres.Source.RunSqlFiles(ctx, &testdataFolder, []string{"create-tables.sql"}) + if err != nil { + panic(err) + } + err = postgres.Target.RunSqlFiles(ctx, &testdataFolder, []string{"create-schema.sql"}) + if err != nil { + panic(err) + } + + connclient := neosyncApi.UnauthdClients.Connections + conndataclient := neosyncApi.UnauthdClients.ConnectionData + + sqlmanagerclient := tcneosyncapi.NewTestSqlManagerClient() + + discardLogger := testutil.GetTestCharmSlogger() + + accountId := tcneosyncapi.CreatePersonalAccount(ctx, t, neosyncApi.UnauthdClients.Users) + sourceConn := tcneosyncapi.CreatePostgresConnection(ctx, t, neosyncApi.UnauthdClients.Connections, accountId, "source", postgres.Source.URL) + t.Run("sync_postgres", func(t *testing.T) { + outputType := output.PlainOutput + cmdConfig := &cmdConfig{ + Source: &sourceConfig{ + ConnectionId: sourceConn.Id, + }, + Destination: &sqlDestinationConfig{ + ConnectionUrl: postgres.Target.URL, + Driver: postgresDriver, + InitSchema: true, + TruncateBeforeInsert: true, + TruncateCascade: true, + }, + OutputType: &outputType, + AccountId: &accountId, + } + sync := &clisync{ + connectiondataclient: conndataclient, + connectionclient: connclient, + sqlmanagerclient: sqlmanagerclient, + ctx: ctx, + logger: discardLogger, + cmd: cmdConfig, + } + err := sync.configureAndRunSync() + require.NoError(t, err) + }) + + err = postgres.TearDown(ctx) + if err != nil { + panic(err) + } + err = neosyncApi.TearDown(ctx) + if err != nil { + panic(err) + } +} diff --git a/cli/internal/cmds/neosync/sync/ui.go b/cli/internal/cmds/neosync/sync/ui.go index 7535f9ed47..2bbe28772b 100644 --- a/cli/internal/cmds/neosync/sync/ui.go +++ b/cli/internal/cmds/neosync/sync/ui.go @@ -11,7 +11,6 @@ import ( "golang.org/x/sync/errgroup" - _ "github.com/nucleuscloud/neosync/cli/internal/benthos/inputs" "github.com/nucleuscloud/neosync/cli/internal/output" _ "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" _ "github.com/warpstreamlabs/bento/public/components/aws" @@ -67,7 +66,7 @@ func newModel(ctx context.Context, benv *service.Environment, groupedConfigs [][ } func (m *model) Init() tea.Cmd { - return tea.Batch(m.syncConfigs(m.ctx, m.benv, m.groupedConfigs[m.index]), m.spinner.Tick) + return tea.Batch(m.syncConfigs(m.ctx, m.groupedConfigs[m.index]), m.spinner.Tick) } func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { @@ -97,7 +96,7 @@ func (m *model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.index++ return m, tea.Batch( tea.Println(strings.Join(successStrs, " \n")), - m.syncConfigs(m.ctx, m.benv, m.groupedConfigs[m.index]), + m.syncConfigs(m.ctx, m.groupedConfigs[m.index]), ) case spinner.TickMsg: var cmd tea.Cmd @@ -137,7 +136,7 @@ func (m *model) View() string { type syncedDataMsg map[string]string -func (m *model) syncConfigs(ctx context.Context, benv *service.Environment, configs []*benthosConfigResponse) tea.Cmd { +func (m *model) syncConfigs(ctx context.Context, configs []*benthosConfigResponse) tea.Cmd { return func() tea.Msg { messageMap := syncmap.Map{} errgrp, errctx := errgroup.WithContext(ctx) @@ -147,7 +146,7 @@ func (m *model) syncConfigs(ctx context.Context, benv *service.Environment, conf errgrp.Go(func() error { start := time.Now() m.logger.Info(fmt.Sprintf("Syncing table %s", cfg.Name)) - err := syncData(errctx, benv, cfg, m.logger, m.outputType) + err := syncData(errctx, m.benv, cfg, m.logger, m.outputType) if err != nil { fmt.Printf("Error syncing table: %s", err.Error()) //nolint:forbidigo return err diff --git a/cli/internal/cmds/neosync/sync/util.go b/cli/internal/cmds/neosync/sync/util.go new file mode 100644 index 0000000000..ba8da54a29 --- /dev/null +++ b/cli/internal/cmds/neosync/sync/util.go @@ -0,0 +1,164 @@ +package sync_cmd + +import ( + "errors" + "fmt" + "log/slog" + "slices" + "strings" + + mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" + sql_manager "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" + tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency" +) + +func buildPlainInsertArgs(cols []string) string { + if len(cols) == 0 { + return "" + } + pieces := make([]string, len(cols)) + for idx := range cols { + pieces[idx] = fmt.Sprintf("this.%q", cols[idx]) + } + return fmt.Sprintf("root = [%s]", strings.Join(pieces, ", ")) +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func parseDriverString(str string) (DriverType, bool) { + p, ok := driverMap[strings.ToLower(str)] + return p, ok +} + +func getConnectionType(connection *mgmtv1alpha1.Connection) (ConnectionType, error) { + if connection.ConnectionConfig.GetAwsS3Config() != nil { + return awsS3Connection, nil + } + if connection.GetConnectionConfig().GetGcpCloudstorageConfig() != nil { + return gcpCloudStorageConnection, nil + } + if connection.ConnectionConfig.GetMysqlConfig() != nil { + return mysqlConnection, nil + } + if connection.ConnectionConfig.GetPgConfig() != nil { + return postgresConnection, nil + } + if connection.ConnectionConfig.GetDynamodbConfig() != nil { + return awsDynamoDBConnection, nil + } + return "", errors.New("unsupported connection type") +} + +func isConfigReady(config *benthosConfigResponse, queuedMap map[string][]string) bool { + for _, dep := range config.DependsOn { + if cols, ok := queuedMap[dep.Table]; ok { + for _, dc := range dep.Columns { + if !slices.Contains(cols, dc) { + return false + } + } + } else { + return false + } + } + return true +} + +func groupConfigsByDependency(configs []*benthosConfigResponse, logger *slog.Logger) [][]*benthosConfigResponse { + groupedConfigs := [][]*benthosConfigResponse{} + configMap := map[string]*benthosConfigResponse{} + queuedMap := map[string][]string{} // map -> table to cols + + // get root configs + rootConfigs := []*benthosConfigResponse{} + for _, c := range configs { + if len(c.DependsOn) == 0 { + rootConfigs = append(rootConfigs, c) + queuedMap[c.Table] = c.Columns + } else { + configMap[c.Name] = c + } + } + if len(rootConfigs) == 0 { + logger.Info("No root configs found. There must be one config with no dependencies.") + return nil + } + groupedConfigs = append(groupedConfigs, rootConfigs) + + prevTableLen := 0 + for len(configMap) > 0 { + // prevents looping forever + if prevTableLen == len(configMap) { + logger.Info("Unable to order configs by dependency. No path found.") + return nil + } + prevTableLen = len(configMap) + dependentConfigs := []*benthosConfigResponse{} + for _, c := range configMap { + if isConfigReady(c, queuedMap) { + dependentConfigs = append(dependentConfigs, c) + delete(configMap, c.Name) + } + } + if len(dependentConfigs) > 0 { + groupedConfigs = append(groupedConfigs, dependentConfigs) + for _, c := range dependentConfigs { + queuedMap[c.Table] = append(queuedMap[c.Table], c.Columns...) + } + } + } + + return groupedConfigs +} + +func getTableColMap(schemas []*mgmtv1alpha1.DatabaseColumn) map[string][]string { + tableColMap := map[string][]string{} + for _, record := range schemas { + table := sql_manager.BuildTable(record.Schema, record.Table) + _, ok := tableColMap[table] + if ok { + tableColMap[table] = append(tableColMap[table], record.Column) + } else { + tableColMap[table] = []string{record.Column} + } + } + + return tableColMap +} + +func buildDependencyMap(syncConfigs []*tabledependency.RunConfig) map[string][]string { + dependencyMap := map[string][]string{} + for _, cfg := range syncConfigs { + _, dpOk := dependencyMap[cfg.Table()] + if !dpOk { + dependencyMap[cfg.Table()] = []string{} + } + + for _, dep := range cfg.DependsOn() { + dependencyMap[cfg.Table()] = append(dependencyMap[cfg.Table()], dep.Table) + } + } + return dependencyMap +} + +func areSourceAndDestCompatible(connection *mgmtv1alpha1.Connection, destinationDriver DriverType) error { + switch connection.ConnectionConfig.Config.(type) { + case *mgmtv1alpha1.ConnectionConfig_PgConfig: + if destinationDriver != postgresDriver { + return fmt.Errorf("Connection and destination types are incompatible [postgres, %s]", destinationDriver) + } + case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: + if destinationDriver != mysqlDriver { + return fmt.Errorf("Connection and destination types are incompatible [mysql, %s]", destinationDriver) + } + case *mgmtv1alpha1.ConnectionConfig_AwsS3Config, *mgmtv1alpha1.ConnectionConfig_GcpCloudstorageConfig, *mgmtv1alpha1.ConnectionConfig_DynamodbConfig: + default: + return errors.New("unsupported destination driver. only postgres and mysql are currently supported") + } + return nil +} diff --git a/internal/testutil/testcontainers/mysql/mysql.go b/internal/testutil/testcontainers/mysql/mysql.go new file mode 100644 index 0000000000..a02081f3bf --- /dev/null +++ b/internal/testutil/testcontainers/mysql/mysql.go @@ -0,0 +1,181 @@ +package testcontainers_mysql + +import ( + "context" + "database/sql" + "fmt" + "os" + "time" + + sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mysql" + testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + "github.com/testcontainers/testcontainers-go/wait" + "golang.org/x/sync/errgroup" +) + +type MysqlTestSyncContainer struct { + Source *MysqlTestContainer + Target *MysqlTestContainer +} + +func NewMysqlTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*MysqlTestSyncContainer, error) { + tc := &MysqlTestSyncContainer{} + errgrp := errgroup.Group{} + errgrp.Go(func() error { + m, err := NewMysqlTestContainer(ctx, sourceOpts...) + if err != nil { + return err + } + tc.Source = m + return nil + }) + + errgrp.Go(func() error { + m, err := NewMysqlTestContainer(ctx, destOpts...) + if err != nil { + return err + } + tc.Target = m + return nil + }) + + err := errgrp.Wait() + if err != nil { + return nil, err + } + + return tc, nil +} + +func (m *MysqlTestSyncContainer) TearDown(ctx context.Context) error { + if m.Source != nil { + err := m.Source.TearDown(ctx) + if err != nil { + return err + } + } + if m.Target != nil { + err := m.Target.TearDown(ctx) + if err != nil { + return err + } + } + return nil +} + +// Holds the MySQL test container and connection pool. +type MysqlTestContainer struct { + DB *sql.DB + URL string + TestContainer *testmysql.MySQLContainer + database string + password string + username string +} + +// Option is a functional option for configuring the Mysql Test Container +type Option func(*MysqlTestContainer) + +// NewMysqlTestContainer initializes a new MySQL Test Container with functional options +func NewMysqlTestContainer(ctx context.Context, opts ...Option) (*MysqlTestContainer, error) { + m := &MysqlTestContainer{ + database: "testdb", + username: "root", + password: "pass", + } + for _, opt := range opts { + opt(m) + } + return m.setup(ctx) +} + +// Sets test container database +func WithDatabase(database string) Option { + return func(a *MysqlTestContainer) { + a.database = database + } +} + +// Sets test container database +func WithUsername(username string) Option { + return func(a *MysqlTestContainer) { + a.username = username + } +} + +// Sets test container database +func WithPassword(password string) Option { + return func(a *MysqlTestContainer) { + a.password = password + } +} + +// Creates and starts a MySQL test container and sets up the connection. +func (m *MysqlTestContainer) setup(ctx context.Context) (*MysqlTestContainer, error) { + mysqlContainer, err := mysql.Run( + ctx, + "mysql:8.0.36", + mysql.WithDatabase(m.database), + mysql.WithUsername(m.username), + mysql.WithPassword(m.password), + testcontainers.WithWaitStrategy( + wait.ForLog("port: 3306 MySQL Community Server").WithOccurrence(1).WithStartupTimeout(20*time.Second), + ), + ) + if err != nil { + return nil, err + } + + connStr, err := mysqlContainer.ConnectionString(ctx, "multiStatements=true&parseTime=true") + if err != nil { + return nil, err + } + + db, err := sql.Open(sqlmanager_shared.MysqlDriver, connStr) + if err != nil { + return nil, err + } + + return &MysqlTestContainer{ + DB: db, + URL: connStr, + TestContainer: mysqlContainer, + }, nil +} + +// Closes the connection pool and terminates the container. +func (m *MysqlTestContainer) TearDown(ctx context.Context) error { + if m.DB != nil { + m.DB.Close() + } + + if m.TestContainer != nil { + err := m.TestContainer.Terminate(ctx) + if err != nil { + return fmt.Errorf("failed to terminate MySQL container: %w", err) + } + } + + return nil +} + +// Executes SQL files within the test container +func (m *MysqlTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error { + for _, file := range files { + filePath := file + if folder != nil && *folder != "" { + filePath = fmt.Sprintf("./%s/%s", *folder, file) + } + sqlStr, err := os.ReadFile(filePath) + if err != nil { + return err + } + _, err = m.DB.ExecContext(ctx, string(sqlStr)) + if err != nil { + return fmt.Errorf("unable to exec SQL when running MySQL SQL files: %w", err) + } + } + return nil +} diff --git a/internal/testutil/testcontainers/postgres/postgres.go b/internal/testutil/testcontainers/postgres/postgres.go new file mode 100644 index 0000000000..6073c53e94 --- /dev/null +++ b/internal/testutil/testcontainers/postgres/postgres.go @@ -0,0 +1,181 @@ +package testcontainers_postgres + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + testpg "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + "golang.org/x/sync/errgroup" +) + +type PostgresTestSyncContainer struct { + Source *PostgresTestContainer + Target *PostgresTestContainer +} + +func NewPostgresTestSyncContainer(ctx context.Context, sourceOpts, destOpts []Option) (*PostgresTestSyncContainer, error) { + tc := &PostgresTestSyncContainer{} + errgrp := errgroup.Group{} + errgrp.Go(func() error { + p, err := NewPostgresTestContainer(ctx, sourceOpts...) + if err != nil { + return err + } + tc.Source = p + return nil + }) + + errgrp.Go(func() error { + p, err := NewPostgresTestContainer(ctx, destOpts...) + if err != nil { + return err + } + tc.Target = p + return nil + }) + + err := errgrp.Wait() + if err != nil { + return nil, err + } + + return tc, nil +} + +func (p *PostgresTestSyncContainer) TearDown(ctx context.Context) error { + if p.Source != nil { + err := p.Source.TearDown(ctx) + if err != nil { + return err + } + } + if p.Target != nil { + err := p.Target.TearDown(ctx) + if err != nil { + return err + } + } + return nil +} + +// Holds the PostgreSQL test container and connection pool. +type PostgresTestContainer struct { + DB *pgxpool.Pool + URL string + TestContainer *testpg.PostgresContainer + database string + username string + password string +} + +// Option is a functional option for configuring the Postgres Test Container +type Option func(*PostgresTestContainer) + +// NewPostgresTestContainer initializes a new Postgres Test Container with functional options +func NewPostgresTestContainer(ctx context.Context, opts ...Option) (*PostgresTestContainer, error) { + p := &PostgresTestContainer{ + database: "testdb", + username: "postgres", + password: "pass", + } + for _, opt := range opts { + opt(p) + } + return p.Setup(ctx) +} + +// Sets test container database +func WithDatabase(database string) Option { + return func(a *PostgresTestContainer) { + a.database = database + } +} + +// Sets test container database +func WithUsername(username string) Option { + return func(a *PostgresTestContainer) { + a.username = username + } +} + +// Sets test container database +func WithPassword(password string) Option { + return func(a *PostgresTestContainer) { + a.password = password + } +} + +// Creates and starts a PostgreSQL test container and sets up the connection. +func (p *PostgresTestContainer) Setup(ctx context.Context) (*PostgresTestContainer, error) { + pgContainer, err := postgres.Run( + ctx, + "postgres:15", + postgres.WithDatabase(p.database), + postgres.WithUsername(p.username), + postgres.WithPassword(p.password), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(20*time.Second), + ), + ) + if err != nil { + return nil, err + } + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + if err != nil { + return nil, err + } + + pool, err := pgxpool.New(ctx, connStr) + if err != nil { + return nil, err + } + + return &PostgresTestContainer{ + DB: pool, + URL: connStr, + TestContainer: pgContainer, + }, nil +} + +// Closes the connection pool and terminates the container. +func (p *PostgresTestContainer) TearDown(ctx context.Context) error { + if p.DB != nil { + p.DB.Close() + } + + if p.TestContainer != nil { + err := p.TestContainer.Terminate(ctx) + if err != nil { + return fmt.Errorf("failed to terminate postgres container: %w", err) + } + } + + return nil +} + +// Executes SQL files within the test container +func (p *PostgresTestContainer) RunSqlFiles(ctx context.Context, folder *string, files []string) error { + for _, file := range files { + filePath := file + if folder != nil && *folder != "" { + filePath = fmt.Sprintf("./%s/%s", *folder, file) + } + sqlStr, err := os.ReadFile(filePath) + if err != nil { + return err + } + _, err = p.DB.Exec(ctx, string(sqlStr)) + if err != nil { + return fmt.Errorf("unable to exec sql when running postgres sql files: %w", err) + } + } + return nil +} diff --git a/internal/testutil/testdata/postgres/alltypes/create-schema.sql b/internal/testutil/testdata/postgres/alltypes/create-schema.sql new file mode 100644 index 0000000000..5ac3f363a9 --- /dev/null +++ b/internal/testutil/testdata/postgres/alltypes/create-schema.sql @@ -0,0 +1 @@ +CREATE SCHEMA IF NOT EXISTS alltypes; diff --git a/internal/testutil/testdata/postgres/alltypes/create-tables.sql b/internal/testutil/testdata/postgres/alltypes/create-tables.sql new file mode 100644 index 0000000000..9403da5d35 --- /dev/null +++ b/internal/testutil/testdata/postgres/alltypes/create-tables.sql @@ -0,0 +1,348 @@ +CREATE SCHEMA IF NOT EXISTS alltypes; +CREATE TABLE IF NOT EXISTS alltypes.all_postgres_types ( + id BIGINT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + -- Numeric Types + smallint_col SMALLINT, + integer_col INTEGER, + bigint_col BIGINT, + decimal_col DECIMAL(10, 2), + numeric_col NUMERIC(10, 2), + real_col REAL, + double_precision_col DOUBLE PRECISION, + serial_col SERIAL, + bigserial_col BIGSERIAL, + + -- Monetary Types + money_col MONEY, + + -- Character Types + char_col CHAR(10), + varchar_col VARCHAR(50), + text_col TEXT, + + -- Binary Types + bytea_col BYTEA, + + -- Date/Time Types + timestamp_col TIMESTAMP, + timestamptz_col TIMESTAMPTZ, + date_col DATE, + time_col TIME, + timetz_col TIMETZ, + interval_col INTERVAL, + + -- Boolean Type + boolean_col BOOLEAN, + + -- UUID Type + uuid_col UUID, + + -- Network Address Types + inet_col INET, + cidr_col CIDR, + macaddr_col MACADDR, + + -- Bit String Types + bit_col BIT(8), + varbit_col VARBIT(8), + + -- Geometric Types + point_col POINT, + line_col LINE, + lseg_col LSEG, + box_col BOX, + path_col PATH, + polygon_col POLYGON, + circle_col CIRCLE, + + -- JSON Types + json_col JSON, + jsonb_col JSONB, + + -- Range Types + int4range_col INT4RANGE, + int8range_col INT8RANGE, + numrange_col NUMRANGE, + tsrange_col TSRANGE, + tstzrange_col TSTZRANGE, + daterange_col DATERANGE, + + -- Array Types + integer_array_col INTEGER[], + text_array_col TEXT[], + + -- XML Type + xml_col XML, + + -- TSVECTOR Type (Full-Text Search) + tsvector_col TSVECTOR, + + -- OID Type + oid_col OID +); + + +INSERT INTO alltypes.all_postgres_types ( + Id, + smallint_col, + integer_col, + bigint_col, + decimal_col, + numeric_col, + real_col, + double_precision_col, + serial_col, + bigserial_col, + money_col, + char_col, + varchar_col, + text_col, + bytea_col, + timestamp_col, + timestamptz_col, + date_col, + time_col, + timetz_col, + interval_col, + boolean_col, + uuid_col, + inet_col, + cidr_col, + macaddr_col, + bit_col, + varbit_col, + point_col, + line_col, + lseg_col, + box_col, + path_col, + polygon_col, + circle_col, + json_col, + jsonb_col, + int4range_col, + int8range_col, + numrange_col, + tsrange_col, + tstzrange_col, + daterange_col, + integer_array_col, + text_array_col, + xml_col, + tsvector_col, + oid_col +) VALUES ( + DEFAULT, + 32767, -- smallint_col + 2147483647, -- integer_col + 9223372036854775807, -- bigint_col + 1234.56, -- decimal_col + 99999999.99, -- numeric_col + 12345.67, -- real_col + 123456789.123456789, -- double_precision_col + 1, -- serial_col (auto-incremented, will be generated) + 1, -- bigserial_col (auto-incremented, will be generated) + '$100.00', -- money_col + 'A', -- char_col + 'DEFAULT', -- varchar_col + 'default', -- text_col + decode('DEADBEEF', 'hex'), -- bytea_col + '2024-01-01 12:34:56', -- timestamp_col + '2024-01-01 12:34:56+00', -- timestamptz_col + '2024-01-01', -- date_col + '12:34:56', -- time_col + '12:34:56+00', -- timetz_col + '1 day', -- interval_col + TRUE, -- boolean_col + '123e4567-e89b-12d3-a456-426614174000', -- uuid_col + '192.168.1.1', -- inet_col + '192.168.1.0/24', -- cidr_col + '08:00:2b:01:02:03', -- macaddr_col + B'10101010', -- bit_col + B'1010', -- varbit_col + '(1, 2)', -- point_col + '{1, 1, 0}', -- line_col + '[(0,0), (1,1)]', -- lseg_col + '(0,0),(1,1)', -- box_col + '((0,0), (1,1), (2,2))', -- path_col + '((0,0), (1,1), (1,0))', -- polygon_col + '<(1,1),1>', -- circle_col + '{"name": "John", "age": 30}', -- json_col + '{"name": "John", "age": 30}', -- jsonb_col + '[1,10]', -- int4range_col + '[1,1000]', -- int8range_col + '[1.0,10.0]', -- numrange_col + '[2024-01-01 12:00:00, 2024-01-01 13:00:00]', -- tsrange_col + '[2024-01-01 12:00:00+00, 2024-01-01 13:00:00+00]', -- tstzrange_col + '[2024-01-01, 2024-01-02]', -- daterange_col + '{1, 2, 3}', -- integer_array_col + '{"one", "two", "three"}', -- text_array_col + 'bar', -- xml_col + 'example tsvector', -- tsvector_col + 123456 -- oid_col +); + +INSERT INTO alltypes.all_postgres_types ( + Id +) VALUES ( + DEFAULT +); + + +CREATE TABLE IF NOT EXISTS alltypes.time_time ( + id SERIAL PRIMARY KEY, + timestamp_col TIMESTAMP, + timestamptz_col TIMESTAMPTZ, + date_col DATE +); + +INSERT INTO alltypes.time_time ( + timestamp_col, + timestamptz_col, + date_col +) +VALUES ( + '2024-03-18 10:30:00', + '2024-03-18 10:30:00+00', + '2024-03-18' +); + +INSERT INTO alltypes.time_time ( + timestamp_col, + timestamptz_col, + date_col +) +VALUES ( + '0001-01-01 00:00:00 BC', + '0001-01-01 00:00:00+00 BC', + '0001-01-01 BC' +); + + +-- CREATE TABLE IF NOT EXISTS alltypes.array_types ( +-- "id" BIGINT NOT NULL PRIMARY KEY, +-- "int_array" _int4, +-- "smallint_array" _int2, +-- "bigint_array" _int8, +-- "real_array" _float4, +-- "double_array" _float8, +-- "text_array" _text, +-- "varchar_array" _varchar, +-- "char_array" _bpchar, +-- "boolean_array" _bool, +-- "date_array" _date, +-- "time_array" _time, +-- "timestamp_array" _timestamp, +-- "timestamptz_array" _timestamptz, +-- "interval_array" _interval, +-- -- "inet_array" _inet, // broken +-- -- "cidr_array" _cidr, +-- "point_array" _point, +-- "line_array" _line, +-- "lseg_array" _lseg, +-- -- "box_array" _box, // broken +-- "path_array" _path, +-- "polygon_array" _polygon, +-- "circle_array" _circle, +-- "uuid_array" _uuid, +-- "json_array" _json, +-- "jsonb_array" _jsonb, +-- "bit_array" _bit, +-- "varbit_array" _varbit, +-- "numeric_array" _numeric, +-- "money_array" _money, +-- "xml_array" _xml, +-- "int_double_array" _int4 +-- ); + + +-- INSERT INTO alltypes.array_types ( +-- id, int_array, smallint_array, bigint_array, real_array, double_array, +-- text_array, varchar_array, char_array, boolean_array, date_array, +-- time_array, timestamp_array, timestamptz_array, interval_array, +-- -- inet_array, cidr_array, +-- point_array, line_array, lseg_array, +-- -- box_array, +-- path_array, polygon_array, circle_array, +-- uuid_array, +-- json_array, jsonb_array, +-- bit_array, varbit_array, numeric_array, +-- money_array, xml_array, int_double_array +-- ) VALUES ( +-- 1, +-- ARRAY[1, 2, 3], +-- ARRAY[10::smallint, 20::smallint], +-- ARRAY[100::bigint, 200::bigint], +-- ARRAY[1.1::real, 2.2::real], +-- ARRAY[1.11::double precision, 2.22::double precision], +-- ARRAY['text1', 'text2'], +-- ARRAY['varchar1'::varchar, 'varchar2'::varchar], +-- ARRAY['a'::char, 'b'::char], +-- ARRAY[true, false], +-- ARRAY['2023-01-01'::date, '2023-01-02'::date], +-- ARRAY['12:00:00'::time, '13:00:00'::time], +-- ARRAY['2023-01-01 12:00:00'::timestamp, '2023-01-02 13:00:00'::timestamp], +-- ARRAY['2023-01-01 12:00:00+00'::timestamptz, '2023-01-02 13:00:00+00'::timestamptz], +-- ARRAY['1 day'::interval, '2 hours'::interval], +-- -- ARRAY['192.168.0.1'::inet, '10.0.0.1'::inet], +-- -- ARRAY['192.168.0.0/24'::cidr, '10.0.0.0/8'::cidr], +-- ARRAY['(1,1)'::point, '(2,2)'::point], +-- ARRAY['{1,2,2}'::line, '{3,4,4}'::line], +-- ARRAY['(1,1,2,2)'::lseg, '(3,3,4,4)'::lseg], +-- -- ARRAY['(1,1,2,2)'::box, '(3,3,4,4)'::box], +-- ARRAY['((1,1),(2,2),(3,3))'::path, '((4,4),(5,5),(6,6))'::path], +-- ARRAY['((1,1),(2,2),(3,3))'::polygon, '((4,4),(5,5),(6,6))'::polygon], +-- ARRAY['<(1,1),1>'::circle, '<(2,2),2>'::circle], +-- ARRAY['a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid, 'b0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'::uuid], +-- ARRAY['{"key": "value1"}'::json, '{"key": "value2"}'::json], +-- ARRAY['{"key": "value1"}'::jsonb, '{"key": "value2"}'::jsonb], +-- ARRAY['101'::bit(3), '110'::bit(3)], +-- ARRAY['10101'::bit varying(5), '01010'::bit varying(5)], +-- ARRAY[1.23::numeric, 4.56::numeric], +-- ARRAY[10.00::money, 20.00::money], +-- ARRAY['value1'::xml, 'value2'::xml], +-- ARRAY[[1, 2], [3, 4]] +-- ); + + +CREATE TABLE alltypes.json_data ( + id SERIAL PRIMARY KEY, + data JSONB +); + + +INSERT INTO alltypes.json_data (data) VALUES ('"Hello, world!"'); +INSERT INTO alltypes.json_data (data) VALUES ('42'); +INSERT INTO alltypes.json_data (data) VALUES ('3.14'); +INSERT INTO alltypes.json_data (data) VALUES ('true'); +INSERT INTO alltypes.json_data (data) VALUES ('false'); +INSERT INTO alltypes.json_data (data) VALUES ('null'); + +INSERT INTO alltypes.json_data (data) VALUES ('{"name": "John", "age": 30}'); +INSERT INTO alltypes.json_data (data) VALUES ('{"coords": {"x": 10, "y": 20}}'); + +INSERT INTO alltypes.json_data (data) VALUES ('[1, 2, 3, 4]'); +INSERT INTO alltypes.json_data (data) VALUES ('["apple", "banana", "cherry"]'); + +INSERT INTO alltypes.json_data (data) VALUES ('{"items": ["book", "pen"], "count": 2, "in_stock": true}'); + +INSERT INTO alltypes.json_data (data) VALUES ( + '{ + "user": { + "name": "Alice", + "age": 28, + "contacts": [ + {"type": "email", "value": "alice@example.com"}, + {"type": "phone", "value": "123-456-7890"} + ] + }, + "orders": [ + {"id": 1001, "total": 59.99}, + {"id": 1002, "total": 24.50} + ], + "preferences": { + "notifications": true, + "theme": "dark" + } + }' +); diff --git a/internal/testutil/testdata/postgres/alltypes/teardown.sql b/internal/testutil/testdata/postgres/alltypes/teardown.sql new file mode 100644 index 0000000000..7313dd27c3 --- /dev/null +++ b/internal/testutil/testdata/postgres/alltypes/teardown.sql @@ -0,0 +1 @@ +DROP SCHEMA IF EXISTS alltypes CASCADE; diff --git a/internal/testutil/testdata/postgres/humanresources/create-schema.sql b/internal/testutil/testdata/postgres/humanresources/create-schema.sql new file mode 100644 index 0000000000..9112cc24aa --- /dev/null +++ b/internal/testutil/testdata/postgres/humanresources/create-schema.sql @@ -0,0 +1 @@ +CREATE SCHEMA IF NOT EXISTS humanresources; diff --git a/internal/testutil/testdata/postgres/humanresources/create-tables.sql b/internal/testutil/testdata/postgres/humanresources/create-tables.sql new file mode 100644 index 0000000000..5e9d8b3d58 --- /dev/null +++ b/internal/testutil/testdata/postgres/humanresources/create-tables.sql @@ -0,0 +1,225 @@ +CREATE SCHEMA IF NOT EXISTS humanresources; +SET search_path TO humanresources; + +CREATE TABLE regions ( + region_id SERIAL PRIMARY KEY, + region_name CHARACTER VARYING (25) +); + +CREATE TABLE countries ( + country_id CHARACTER (2) PRIMARY KEY, + country_name CHARACTER VARYING (40), + region_id INTEGER NOT NULL, + FOREIGN KEY (region_id) REFERENCES regions (region_id) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE locations ( + location_id SERIAL PRIMARY KEY, + street_address CHARACTER VARYING (40), + postal_code CHARACTER VARYING (12), + city CHARACTER VARYING (30) NOT NULL, + state_province CHARACTER VARYING (25), + country_id CHARACTER (2) NOT NULL, + FOREIGN KEY (country_id) REFERENCES countries (country_id) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE departments ( + department_id SERIAL PRIMARY KEY, + department_name CHARACTER VARYING (30) NOT NULL, + location_id INTEGER, + FOREIGN KEY (location_id) REFERENCES locations (location_id) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE jobs ( + job_id SERIAL PRIMARY KEY, + job_title CHARACTER VARYING (35) NOT NULL, + min_salary NUMERIC (8, 2), + max_salary NUMERIC (8, 2) +); + +CREATE TABLE employees ( + employee_id SERIAL PRIMARY KEY, + first_name CHARACTER VARYING (20), + last_name CHARACTER VARYING (25) NOT NULL, + email CHARACTER VARYING (100) NOT NULL, + phone_number CHARACTER VARYING (20), + hire_date DATE NOT NULL, + job_id INTEGER NOT NULL, + salary NUMERIC (8, 2) NOT NULL, + manager_id INTEGER, + department_id INTEGER, + FOREIGN KEY (job_id) REFERENCES jobs (job_id) ON UPDATE CASCADE ON DELETE CASCADE, + FOREIGN KEY (department_id) REFERENCES departments (department_id) ON UPDATE CASCADE ON DELETE CASCADE, + FOREIGN KEY (manager_id) REFERENCES employees (employee_id) ON UPDATE CASCADE ON DELETE CASCADE +); + +CREATE TABLE dependents ( + dependent_id SERIAL PRIMARY KEY, + first_name CHARACTER VARYING (50) NOT NULL, + last_name CHARACTER VARYING (50) NOT NULL, + relationship CHARACTER VARYING (25) NOT NULL, + employee_id INTEGER NOT NULL, + FOREIGN KEY (employee_id) REFERENCES employees (employee_id) ON DELETE CASCADE ON UPDATE CASCADE +); + + +/*Data for the table regions */ + +INSERT INTO regions(region_id,region_name) VALUES (1,'Europe'); +INSERT INTO regions(region_id,region_name) VALUES (2,'Americas'); +INSERT INTO regions(region_id,region_name) VALUES (3,'Asia'); +INSERT INTO regions(region_id,region_name) VALUES (4,'Middle East and Africa'); + + +/*Data for the table countries */ +INSERT INTO countries(country_id,country_name,region_id) VALUES ('AR','Argentina',2); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('AU','Australia',3); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('BE','Belgium',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('BR','Brazil',2); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('CA','Canada',2); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('CH','Switzerland',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('CN','China',3); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('DE','Germany',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('DK','Denmark',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('EG','Egypt',4); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('FR','France',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('HK','HongKong',3); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('IL','Israel',4); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('IN','India',3); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('IT','Italy',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('JP','Japan',3); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('KW','Kuwait',4); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('MX','Mexico',2); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('NG','Nigeria',4); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('NL','Netherlands',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('SG','Singapore',3); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('UK','United Kingdom',1); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('US','United States of America',2); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('ZM','Zambia',4); +INSERT INTO countries(country_id,country_name,region_id) VALUES ('ZW','Zimbabwe',4); + +/*Data for the table locations */ +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1400,'2014 Jabberwocky Rd','26192','Southlake','Texas','US'); +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1500,'2011 Interiors Blvd','99236','South San Francisco','California','US'); +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1700,'2004 Charade Rd','98199','Seattle','Washington','US'); +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (1800,'147 Spadina Ave','M5V 2L7','Toronto','Ontario','CA'); +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (2400,'8204 Arthur St',NULL,'London',NULL,'UK'); +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (2500,'Magdalen Centre, The Oxford Science Park','OX9 9ZB','Oxford','Oxford','UK'); +INSERT INTO locations(location_id,street_address,postal_code,city,state_province,country_id) VALUES (2700,'Schwanthalerstr. 7031','80925','Munich','Bavaria','DE'); + + +/*Data for the table jobs */ + +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (1,'Public Accountant',4200.00,9000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (2,'Accounting Manager',8200.00,16000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (3,'Administration Assistant',3000.00,6000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (4,'President',20000.00,40000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (5,'Administration Vice President',15000.00,30000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (6,'Accountant',4200.00,9000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (7,'Finance Manager',8200.00,16000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (8,'Human Resources Representative',4000.00,9000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (9,'Programmer',4000.00,10000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (10,'Marketing Manager',9000.00,15000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (11,'Marketing Representative',4000.00,9000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (12,'Public Relations Representative',4500.00,10500.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (13,'Purchasing Clerk',2500.00,5500.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (14,'Purchasing Manager',8000.00,15000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (15,'Sales Manager',10000.00,20000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (16,'Sales Representative',6000.00,12000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (17,'Shipping Clerk',2500.00,5500.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (18,'Stock Clerk',2000.00,5000.00); +INSERT INTO jobs(job_id,job_title,min_salary,max_salary) VALUES (19,'Stock Manager',5500.00,8500.00); + + +/*Data for the table departments */ + +INSERT INTO departments(department_id,department_name,location_id) VALUES (1,'Administration',1700); +INSERT INTO departments(department_id,department_name,location_id) VALUES (2,'Marketing',1800); +INSERT INTO departments(department_id,department_name,location_id) VALUES (3,'Purchasing',1700); +INSERT INTO departments(department_id,department_name,location_id) VALUES (4,'Human Resources',2400); +INSERT INTO departments(department_id,department_name,location_id) VALUES (5,'Shipping',1500); +INSERT INTO departments(department_id,department_name,location_id) VALUES (6,'IT',1400); +INSERT INTO departments(department_id,department_name,location_id) VALUES (7,'Public Relations',2700); +INSERT INTO departments(department_id,department_name,location_id) VALUES (8,'Sales',2500); +INSERT INTO departments(department_id,department_name,location_id) VALUES (9,'Executive',1700); +INSERT INTO departments(department_id,department_name,location_id) VALUES (10,'Finance',1700); +INSERT INTO departments(department_id,department_name,location_id) VALUES (11,'Accounting',1700); + + + +/*Data for the table employees */ + +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (100,'Steven','King','steven.king@sqltutorial.org','515.123.4567','1987-06-17',4,24000.00,NULL,9); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (101,'Neena','Kochhar','neena.kochhar@sqltutorial.org','515.123.4568','1989-09-21',5,17000.00,100,9); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (102,'Lex','De Haan','lex.de haan@sqltutorial.org','515.123.4569','1993-01-13',5,17000.00,100,9); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (103,'Alexander','Hunold','alexander.hunold@sqltutorial.org','590.423.4567','1990-01-03',9,9000.00,102,6); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (104,'Bruce','Ernst','bruce.ernst@sqltutorial.org','590.423.4568','1991-05-21',9,6000.00,103,6); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (105,'David','Austin','david.austin@sqltutorial.org','590.423.4569','1997-06-25',9,4800.00,103,6); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (106,'Valli','Pataballa','valli.pataballa@sqltutorial.org','590.423.4560','1998-02-05',9,4800.00,103,6); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (107,'Diana','Lorentz','diana.lorentz@sqltutorial.org','590.423.5567','1999-02-07',9,4200.00,103,6); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (108,'Nancy','Greenberg','nancy.greenberg@sqltutorial.org','515.124.4569','1994-08-17',7,12000.00,101,10); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (109,'Daniel','Faviet','daniel.faviet@sqltutorial.org','515.124.4169','1994-08-16',6,9000.00,108,10); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (110,'John','Chen','john.chen@sqltutorial.org','515.124.4269','1997-09-28',6,8200.00,108,10); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (111,'Ismael','Sciarra','ismael.sciarra@sqltutorial.org','515.124.4369','1997-09-30',6,7700.00,108,10); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (112,'Jose Manuel','Urman','jose manuel.urman@sqltutorial.org','515.124.4469','1998-03-07',6,7800.00,108,10); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (113,'Luis','Popp','luis.popp@sqltutorial.org','515.124.4567','1999-12-07',6,6900.00,108,10); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (114,'Den','Raphaely','den.raphaely@sqltutorial.org','515.127.4561','1994-12-07',14,11000.00,100,3); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (115,'Alexander','Khoo','alexander.khoo@sqltutorial.org','515.127.4562','1995-05-18',13,3100.00,114,3); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (116,'Shelli','Baida','shelli.baida@sqltutorial.org','515.127.4563','1997-12-24',13,2900.00,114,3); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (117,'Sigal','Tobias','sigal.tobias@sqltutorial.org','515.127.4564','1997-07-24',13,2800.00,114,3); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (118,'Guy','Himuro','guy.himuro@sqltutorial.org','515.127.4565','1998-11-15',13,2600.00,114,3); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (119,'Karen','Colmenares','karen.colmenares@sqltutorial.org','515.127.4566','1999-08-10',13,2500.00,114,3); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (120,'Matthew','Weiss','matthew.weiss@sqltutorial.org','650.123.1234','1996-07-18',19,8000.00,100,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (121,'Adam','Fripp','adam.fripp@sqltutorial.org','650.123.2234','1997-04-10',19,8200.00,100,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (122,'Payam','Kaufling','payam.kaufling@sqltutorial.org','650.123.3234','1995-05-01',19,7900.00,100,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (123,'Shanta','Vollman','shanta.vollman@sqltutorial.org','650.123.4234','1997-10-10',19,6500.00,100,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (126,'Irene','Mikkilineni','irene.mikkilineni@sqltutorial.org','650.124.1224','1998-09-28',18,2700.00,120,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (145,'John','Russell','john.russell@sqltutorial.org',NULL,'1996-10-01',15,14000.00,100,8); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (146,'Karen','Partners','karen.partners@sqltutorial.org',NULL,'1997-01-05',15,13500.00,100,8); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (176,'Jonathon','Taylor','jonathon.taylor@sqltutorial.org',NULL,'1998-03-24',16,8600.00,100,8); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (177,'Jack','Livingston','jack.livingston@sqltutorial.org',NULL,'1998-04-23',16,8400.00,100,8); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (178,'Kimberely','Grant','kimberely.grant@sqltutorial.org',NULL,'1999-05-24',16,7000.00,100,8); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (179,'Charles','Johnson','charles.johnson@sqltutorial.org',NULL,'2000-01-04',16,6200.00,100,8); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (192,'Sarah','Bell','sarah.bell@sqltutorial.org','650.501.1876','1996-02-04',17,4000.00,123,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (193,'Britney','Everett','britney.everett@sqltutorial.org','650.501.2876','1997-03-03',17,3900.00,123,5); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (200,'Jennifer','Whalen','jennifer.whalen@sqltutorial.org','515.123.4444','1987-09-17',3,4400.00,101,1); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (201,'Michael','Hartstein','michael.hartstein@sqltutorial.org','515.123.5555','1996-02-17',10,13000.00,100,2); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (202,'Pat','Fay','pat.fay@sqltutorial.org','603.123.6666','1997-08-17',11,6000.00,201,2); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (203,'Susan','Mavris','susan.mavris@sqltutorial.org','515.123.7777','1994-06-07',8,6500.00,101,4); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (204,'Hermann','Baer','hermann.baer@sqltutorial.org','515.123.8888','1994-06-07',12,10000.00,101,7); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (205,'Shelley','Higgins','shelley.higgins@sqltutorial.org','515.123.8080','1994-06-07',2,12000.00,101,11); +INSERT INTO employees(employee_id,first_name,last_name,email,phone_number,hire_date,job_id,salary,manager_id,department_id) VALUES (206,'William','Gietz','william.gietz@sqltutorial.org','515.123.8181','1994-06-07',1,8300.00,205,11); + + +/*Data for the table dependents */ + +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (1,'Penelope','Gietz','Child',206); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (2,'Nick','Higgins','Child',205); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (3,'Ed','Whalen','Child',200); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (4,'Jennifer','King','Child',100); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (5,'Johnny','Kochhar','Child',101); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (6,'Bette','De Haan','Child',102); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (7,'Grace','Faviet','Child',109); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (8,'Matthew','Chen','Child',110); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (9,'Joe','Sciarra','Child',111); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (10,'Christian','Urman','Child',112); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (11,'Zero','Popp','Child',113); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (12,'Karl','Greenberg','Child',108); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (13,'Uma','Mavris','Child',203); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (14,'Vivien','Hunold','Child',103); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (15,'Cuba','Ernst','Child',104); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (16,'Fred','Austin','Child',105); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (17,'Helen','Pataballa','Child',106); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (18,'Dan','Lorentz','Child',107); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (19,'Bob','Hartstein','Child',201); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (20,'Lucille','Fay','Child',202); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (21,'Kirsten','Baer','Child',204); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (22,'Elvis','Khoo','Child',115); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (23,'Sandra','Baida','Child',116); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (24,'Cameron','Tobias','Child',117); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (25,'Kevin','Himuro','Child',118); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (26,'Rip','Colmenares','Child',119); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (27,'Julia','Raphaely','Child',114); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (28,'Woody','Russell','Child',145); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (29,'Alec','Partners','Child',146); +INSERT INTO dependents(dependent_id,first_name,last_name,relationship,employee_id) VALUES (30,'Sandra','Taylor','Child',176); diff --git a/internal/testutil/testdata/postgres/humanresources/teardown.sql b/internal/testutil/testdata/postgres/humanresources/teardown.sql new file mode 100644 index 0000000000..dc2f4c304e --- /dev/null +++ b/internal/testutil/testdata/postgres/humanresources/teardown.sql @@ -0,0 +1 @@ +DROP SCHEMA IF EXISTS humanresources CASCADE; diff --git a/internal/testutil/utils.go b/internal/testutil/utils.go new file mode 100644 index 0000000000..10001a4cf0 --- /dev/null +++ b/internal/testutil/utils.go @@ -0,0 +1,31 @@ +package testutil + +import ( + "fmt" + "io" + "log/slog" + "os" + + charmlog "github.com/charmbracelet/log" +) + +func ShouldRunIntegrationTest() bool { + evkey := "INTEGRATION_TESTS_ENABLED" + shouldRun := os.Getenv(evkey) + if shouldRun != "1" { + slog.Warn(fmt.Sprintf("skipping integration tests, set %s=1 to enable", evkey)) + return false + } + return true +} + +func GetTestSlogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) +} + +func GetTestCharmSlogger() *slog.Logger { + charmlogger := charmlog.NewWithOptions(io.Discard, charmlog.Options{ + Level: charmlog.DebugLevel, + }) + return slog.New(charmlogger) +} diff --git a/worker/pkg/benthos/config.go b/worker/pkg/benthos/config.go index 4e6a042c98..99388eeaf4 100644 --- a/worker/pkg/benthos/config.go +++ b/worker/pkg/benthos/config.go @@ -17,6 +17,7 @@ type HTTPConfig struct { } type StreamConfig struct { + Logger *LoggerConfig `json:"logger" yaml:"logger,omitempty"` Input *InputConfig `json:"input" yaml:"input"` Buffer *BufferConfig `json:"buffer,omitempty" yaml:"buffer,omitempty"` Pipeline *PipelineConfig `json:"pipeline" yaml:"pipeline"` @@ -25,6 +26,11 @@ type StreamConfig struct { Metrics *Metrics `json:"metrics,omitempty" yaml:"metrics,omitempty"` } +type LoggerConfig struct { + Level string `json:"level" yaml:"level"` + AddTimestamp bool `json:"add_timestamp" yaml:"add_timestamp"` +} + type Metrics struct { OtelCollector *MetricsOtelCollector `json:"otel_collector,omitempty" yaml:"otel_collector,omitempty"` Mapping string `json:"mapping,omitempty" yaml:"mapping,omitempty"` @@ -54,13 +60,25 @@ type InputConfig struct { } type Inputs struct { - SqlSelect *SqlSelect `json:"sql_select,omitempty" yaml:"sql_select,omitempty"` - PooledSqlRaw *InputPooledSqlRaw `json:"pooled_sql_raw,omitempty" yaml:"pooled_sql_raw,omitempty"` - Generate *Generate `json:"generate,omitempty" yaml:"generate,omitempty"` - OpenAiGenerate *OpenAiGenerate `json:"openai_generate,omitempty" yaml:"openai_generate,omitempty"` - MongoDB *InputMongoDb `json:"mongodb,omitempty" yaml:"mongodb,omitempty"` - PooledMongoDB *InputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"` - AwsDynamoDB *InputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` + SqlSelect *SqlSelect `json:"sql_select,omitempty" yaml:"sql_select,omitempty"` + PooledSqlRaw *InputPooledSqlRaw `json:"pooled_sql_raw,omitempty" yaml:"pooled_sql_raw,omitempty"` + Generate *Generate `json:"generate,omitempty" yaml:"generate,omitempty"` + OpenAiGenerate *OpenAiGenerate `json:"openai_generate,omitempty" yaml:"openai_generate,omitempty"` + MongoDB *InputMongoDb `json:"mongodb,omitempty" yaml:"mongodb,omitempty"` + PooledMongoDB *InputMongoDb `json:"pooled_mongodb,omitempty" yaml:"pooled_mongodb,omitempty"` + AwsDynamoDB *InputAwsDynamoDB `json:"aws_dynamodb,omitempty" yaml:"aws_dynamodb,omitempty"` + NeosyncConnectionData *NeosyncConnectionData `json:"neosync_connection_data,omitempty" yaml:"neosync_connection_data,omitempty"` +} + +type NeosyncConnectionData struct { + ApiKey *string `json:"api_key,omitempty" yaml:"api_key,omitempty"` + ApiUrl string `json:"api_url" yaml:"api_url"` + ConnectionId string `json:"connection_id" yaml:"connection_id"` + ConnectionType string `json:"connection_type" yaml:"connection_type"` + JobId *string `json:"job_id,omitempty" yaml:"job_id,omitempty"` + JobRunId *string `json:"job_run_id,omitempty" yaml:"job_run_id,omitempty"` + Schema string `json:"schema" yaml:"schema"` + Table string `json:"table" yaml:"table"` } type InputAwsDynamoDB struct { diff --git a/worker/pkg/benthos/environment/environment.go b/worker/pkg/benthos/environment/environment.go index fff3b1c769..a9547af794 100644 --- a/worker/pkg/benthos/environment/environment.go +++ b/worker/pkg/benthos/environment/environment.go @@ -5,11 +5,13 @@ import ( "fmt" "log/slog" + "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" neosync_benthos_defaulttransform "github.com/nucleuscloud/neosync/worker/pkg/benthos/default_transform" neosync_benthos_dynamodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/dynamodb" neosync_benthos_error "github.com/nucleuscloud/neosync/worker/pkg/benthos/error" benthos_metrics "github.com/nucleuscloud/neosync/worker/pkg/benthos/metrics" neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb" + neosync_benthos_connectiondata "github.com/nucleuscloud/neosync/worker/pkg/benthos/neosync_connection_data" openaigenerate "github.com/nucleuscloud/neosync/worker/pkg/benthos/openai_generate" neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" "github.com/warpstreamlabs/bento/public/bloblang" @@ -24,6 +26,8 @@ type RegisterConfig struct { mongoConfig *MongoConfig // nil to disable + connectionDataConfig *ConnectionDataConfig // nil to diable + stopChannel chan<- error blobEnv *bloblang.Environment @@ -52,6 +56,11 @@ func WithMongoConfig(mongocfg *MongoConfig) Option { cfg.mongoConfig = mongocfg } } +func WithConnectionDataConfig(connectionDataCfg *ConnectionDataConfig) Option { + return func(cfg *RegisterConfig) { + cfg.connectionDataConfig = connectionDataCfg + } +} func WithBlobEnv(b *bloblang.Environment) Option { return func(cfg *RegisterConfig) { cfg.blobEnv = b @@ -67,6 +76,10 @@ type MongoConfig struct { Provider neosync_benthos_mongodb.MongoPoolProvider } +type ConnectionDataConfig struct { + NeosyncConnectionDataApi mgmtv1alpha1connect.ConnectionDataServiceClient +} + func NewEnvironment(logger *slog.Logger, opts ...Option) (*service.Environment, error) { return NewWithEnvironment(service.NewEnvironment(), logger, opts...) } @@ -118,6 +131,13 @@ func NewWithEnvironment(env *service.Environment, logger *slog.Logger, opts ...O } } + if config.connectionDataConfig != nil { + err := neosync_benthos_connectiondata.RegisterNeosyncConnectionDataInput(env, config.connectionDataConfig.NeosyncConnectionDataApi) + if err != nil { + return nil, fmt.Errorf("unable to register neosync_connection_data input: %w", err) + } + } + err := openaigenerate.RegisterOpenaiGenerate(env) if err != nil { return nil, fmt.Errorf("unable to register openai_generate input to benthos instance: %w", err) diff --git a/cli/internal/benthos/inputs/neosync-connection-data.go b/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go similarity index 84% rename from cli/internal/benthos/inputs/neosync-connection-data.go rename to worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go index fe77f1c029..207337c84a 100644 --- a/cli/internal/benthos/inputs/neosync-connection-data.go +++ b/worker/pkg/benthos/neosync_connection_data/neosync_connection_data_input.go @@ -1,4 +1,4 @@ -package input +package neosync_benthos_connectiondata import ( "context" @@ -8,12 +8,8 @@ import ( "connectrpc.com/connect" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" - "github.com/nucleuscloud/neosync/cli/internal/auth" - auth_interceptor "github.com/nucleuscloud/neosync/cli/internal/connect/interceptors/auth" - "github.com/nucleuscloud/neosync/cli/internal/version" neosync_dynamodb "github.com/nucleuscloud/neosync/internal/dynamodb" neosync_metadata "github.com/nucleuscloud/neosync/worker/pkg/benthos/metadata" - http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client" "github.com/warpstreamlabs/bento/public/service" ) @@ -28,21 +24,7 @@ var neosyncConnectionDataConfigSpec = service.NewConfigSpec(). Field(service.NewStringField("job_id").Optional()). Field(service.NewStringField("job_run_id").Optional()) -func newNeosyncConnectionDataInput(conf *service.ParsedConfig) (service.Input, error) { - var apiKey *string - if conf.Contains("api_key") { - apiKeyStr, err := conf.FieldString("api_key") - if err != nil { - return nil, err - } - apiKey = &apiKeyStr - } - - apiUrl, err := conf.FieldString("api_url") - if err != nil { - return nil, err - } - +func newNeosyncConnectionDataInput(conf *service.ParsedConfig, neosyncConnectApi mgmtv1alpha1connect.ConnectionDataServiceClient) (service.Input, error) { connectionId, err := conf.FieldString("connection_id") if err != nil { return nil, err @@ -80,8 +62,6 @@ func newNeosyncConnectionDataInput(conf *service.ParsedConfig) (service.Input, e } return service.AutoRetryNacks(&neosyncInput{ - apiKey: apiKey, - apiUrl: apiUrl, connectionId: connectionId, connectionType: connectionType, schema: schema, @@ -90,18 +70,17 @@ func newNeosyncConnectionDataInput(conf *service.ParsedConfig) (service.Input, e jobId: jobId, jobRunId: jobRunId, }, + neosyncConnectApi: neosyncConnectApi, }), nil } -func init() { - err := service.RegisterInput( +func RegisterNeosyncConnectionDataInput(env *service.Environment, neosyncConnectApi mgmtv1alpha1connect.ConnectionDataServiceClient) error { + return env.RegisterInput( "neosync_connection_data", neosyncConnectionDataConfigSpec, func(conf *service.ParsedConfig, mgr *service.Resources) (service.Input, error) { - return newNeosyncConnectionDataInput(conf) - }) - if err != nil { - panic(err) - } + return newNeosyncConnectionDataInput(conf, neosyncConnectApi) + }, + ) } //------------------------------------------------------------------------------ @@ -112,9 +91,6 @@ type connOpts struct { } type neosyncInput struct { - apiKey *string - apiUrl string - connectionId string connectionType string connectionOpts *connOpts @@ -129,12 +105,6 @@ type neosyncInput struct { } func (g *neosyncInput) Connect(ctx context.Context) error { - g.neosyncConnectApi = mgmtv1alpha1connect.NewConnectionDataServiceClient( - http_client.NewWithHeaders(version.Get().Headers()), - g.apiUrl, - connect.WithInterceptors(auth_interceptor.NewInterceptor(g.apiKey != nil, auth.AuthHeader, auth.GetAuthHeaderTokenFn(g.apiKey))), - ) - var streamCfg *mgmtv1alpha1.ConnectionStreamConfig if g.connectionType == "awsS3" { diff --git a/worker/pkg/workflows/datasync/workflow/integration_test.go b/worker/pkg/workflows/datasync/workflow/integration_test.go index c24f62ca74..14f98d5105 100644 --- a/worker/pkg/workflows/datasync/workflow/integration_test.go +++ b/worker/pkg/workflows/datasync/workflow/integration_test.go @@ -3,29 +3,24 @@ package datasync_workflow import ( "context" "database/sql" - "errors" "fmt" "log/slog" - "net/url" "os" "testing" - "time" "github.com/aws/aws-sdk-go-v2/service/dynamodb" dyntypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/docker/go-connections/nat" _ "github.com/go-sql-driver/mysql" - "github.com/jackc/pgx/v5/pgxpool" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" awsmanager "github.com/nucleuscloud/neosync/internal/aws" + tcmysql "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/mysql" + tcpostgres "github.com/nucleuscloud/neosync/internal/testutil/testcontainers/postgres" "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go" testmongodb "github.com/testcontainers/testcontainers-go/modules/mongodb" testmssql "github.com/testcontainers/testcontainers-go/modules/mssql" - testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" - "github.com/testcontainers/testcontainers-go/modules/postgres" - testpg "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/modules/redis" "github.com/testcontainers/testcontainers-go/wait" "go.mongodb.org/mongo-driver/mongo" @@ -33,20 +28,6 @@ import ( "golang.org/x/sync/errgroup" ) -type postgresTestContainer struct { - pool *pgxpool.Pool - url string -} -type postgresTest struct { - pool *pgxpool.Pool - testcontainer *testpg.PostgresContainer - - source *postgresTestContainer - target *postgresTestContainer - - databases []string -} - type mssqlTest struct { pool *sql.DB testcontainer *testmssql.MSSQLServerContainer @@ -59,18 +40,6 @@ type mssqlTestContainer struct { url string } -type mysqlTestContainer struct { - pool *sql.DB - container *testmysql.MySQLContainer - url string - close func() -} - -type mysqlTest struct { - source *mysqlTestContainer - target *mysqlTestContainer -} - type redisTest struct { url string testcontainer *redis.RedisContainer @@ -91,8 +60,8 @@ type IntegrationTestSuite struct { ctx context.Context - mysql *mysqlTest - postgres *postgresTest + mysql *tcmysql.MysqlTestSyncContainer + postgres *tcpostgres.PostgresTestSyncContainer mssql *mssqlTest redis *redisTest dynamo *dynamodbTest @@ -213,149 +182,20 @@ func createMssqlTest(ctx context.Context, mssqlcontainer *testmssql.MSSQLServerC }, nil } -func (s *IntegrationTestSuite) SetupPostgres() (*postgresTest, error) { - pgcontainer, err := testpg.Run( - s.ctx, - "postgres:15", - postgres.WithDatabase("postgres"), - testcontainers.WithWaitStrategy( - wait.ForLog("database system is ready to accept connections"). - WithOccurrence(2).WithStartupTimeout(20*time.Second), - ), - ) - if err != nil { - return nil, err - } - postgresTest := &postgresTest{ - testcontainer: pgcontainer, - } - connstr, err := pgcontainer.ConnectionString(s.ctx, "sslmode=disable") - if err != nil { - return nil, err - } - - postgresTest.databases = []string{"datasync_source", "datasync_target"} - pool, err := pgxpool.New(s.ctx, connstr) - if err != nil { - return nil, err - } - postgresTest.pool = pool - - s.T().Logf("creating databases. %+v \n", postgresTest.databases) - for _, db := range postgresTest.databases { - _, err = postgresTest.pool.Exec(s.ctx, fmt.Sprintf("CREATE DATABASE %s;", db)) - if err != nil { - return nil, err - } - } - - srcUrl, err := getDbPgUrl(connstr, "datasync_source", "disable") - if err != nil { - return nil, err - } - postgresTest.source = &postgresTestContainer{ - url: srcUrl, - } - sourceConn, err := pgxpool.New(s.ctx, postgresTest.source.url) - if err != nil { - return nil, err - } - postgresTest.source.pool = sourceConn - - targetUrl, err := getDbPgUrl(connstr, "datasync_target", "disable") - if err != nil { - return nil, err - } - postgresTest.target = &postgresTestContainer{ - url: targetUrl, - } - targetConn, err := pgxpool.New(s.ctx, postgresTest.target.url) +func (s *IntegrationTestSuite) SetupPostgres() (*tcpostgres.PostgresTestSyncContainer, error) { + container, err := tcpostgres.NewPostgresTestSyncContainer(s.ctx, []tcpostgres.Option{}, []tcpostgres.Option{}) if err != nil { return nil, err } - postgresTest.target.pool = targetConn - return postgresTest, nil + return container, nil } -func (s *IntegrationTestSuite) SetupMysql() (*mysqlTest, error) { - var source *mysqlTestContainer - var target *mysqlTestContainer - - errgrp := errgroup.Group{} - errgrp.Go(func() error { - sourcecontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-source") - if err != nil { - return err - } - source = sourcecontainer - return nil - }) - - errgrp.Go(func() error { - targetcontainer, err := createMysqlTestContainer(s.ctx, "datasync", "root", "pass-target") - if err != nil { - return err - } - target = targetcontainer - return nil - }) - - err := errgrp.Wait() +func (s *IntegrationTestSuite) SetupMysql() (*tcmysql.MysqlTestSyncContainer, error) { + container, err := tcmysql.NewMysqlTestSyncContainer(s.ctx, []tcmysql.Option{}, []tcmysql.Option{}) if err != nil { return nil, err } - - return &mysqlTest{ - source: source, - target: target, - }, nil -} - -func createMysqlTestContainer( - ctx context.Context, - database, username, password string, -) (*mysqlTestContainer, error) { - container, err := testmysql.Run(ctx, - "mysql:8.0.36", - testmysql.WithDatabase(database), - testmysql.WithUsername(username), - testmysql.WithPassword(password), - testcontainers.WithWaitStrategy( - wait.ForLog("port: 3306 MySQL Community Server"). - WithOccurrence(1).WithStartupTimeout(20*time.Second), - ), - ) - if err != nil { - return nil, err - } - connstr, err := container.ConnectionString(ctx, "multiStatements=true&parseTime=true") - if err != nil { - panic(err) - } - pool, err := sql.Open(sqlmanager_shared.MysqlDriver, connstr) - if err != nil { - panic(err) - } - containerPort, err := container.MappedPort(ctx, "3306/tcp") - if err != nil { - return nil, err - } - containerHost, err := container.Host(ctx) - if err != nil { - return nil, err - } - - connUrl := fmt.Sprintf("mysql://%s:%s@%s:%s/%s?multiStatements=true&parseTime=true", username, password, containerHost, containerPort.Port(), database) - return &mysqlTestContainer{ - pool: pool, - url: connUrl, - container: container, - close: func() { - if pool != nil { - pool.Close() - } - }, - }, nil + return container, nil } func (s *IntegrationTestSuite) SetupRedis() (*redisTest, error) { @@ -445,8 +285,8 @@ func (s *IntegrationTestSuite) SetupDynamoDB() (*dynamodbTest, error) { func (s *IntegrationTestSuite) SetupSuite() { s.ctx = context.Background() - var postgresTest *postgresTest - var mysqlTest *mysqlTest + var postgresTest *tcpostgres.PostgresTestSyncContainer + var mysqlTest *tcmysql.MysqlTestSyncContainer var mssqlTest *mssqlTest var redisTest *redisTest var dynamoTest *dynamodbTest @@ -520,20 +360,6 @@ func (s *IntegrationTestSuite) SetupSuite() { s.mongodb = mongodbTest } -func (s *IntegrationTestSuite) RunPostgresSqlFiles(pool *pgxpool.Pool, testFolder string, files []string) { - s.T().Logf("running postgres sql file. folder: %s \n", testFolder) - for _, file := range files { - sqlStr, err := os.ReadFile(fmt.Sprintf("./testdata/%s/%s", testFolder, file)) - if err != nil { - panic(err) - } - _, err = pool.Exec(s.ctx, string(sqlStr)) - if err != nil { - panic(fmt.Errorf("unable to exec sql when running postgres sql files: %w", err)) - } - } -} - func (s *IntegrationTestSuite) RunMysqlSqlFiles(pool *sql.DB, testFolder string, files []string) { s.T().Logf("running mysql sql file. folder: %s \n", testFolder) for _, file := range files { @@ -662,26 +488,9 @@ func (s *IntegrationTestSuite) TearDownSuite() { s.T().Log("tearing down test suite") // postgres if s.postgres != nil { - for _, db := range s.postgres.databases { - _, err := s.postgres.pool.Exec(s.ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE);", db)) - if err != nil { - panic(err) - } - } - if s.postgres.source.pool != nil { - s.postgres.source.pool.Close() - } - if s.postgres.target.pool != nil { - s.postgres.target.pool.Close() - } - if s.postgres.pool != nil { - s.postgres.pool.Close() - } - if s.postgres.testcontainer != nil { - err := s.postgres.testcontainer.Terminate(s.ctx) - if err != nil { - panic(err) - } + err := s.postgres.TearDown(s.ctx) + if err != nil { + panic(err) } } @@ -706,19 +515,9 @@ func (s *IntegrationTestSuite) TearDownSuite() { // mysql if s.mysql != nil { - s.mysql.source.close() - s.mysql.target.close() - if s.mysql.source.container != nil { - err := s.mysql.source.container.Terminate(s.ctx) - if err != nil { - panic(err) - } - } - if s.mysql.target.container != nil { - err := s.mysql.target.container.Terminate(s.ctx) - if err != nil { - panic(err) - } + err := s.mysql.TearDown(s.ctx) + if err != nil { + panic(err) } } @@ -774,19 +573,3 @@ func TestIntegrationTestSuite(t *testing.T) { } suite.Run(t, new(IntegrationTestSuite)) } - -func getDbPgUrl(dburl, database, sslmode string) (string, error) { - u, err := url.Parse(dburl) - if err != nil { - var urlErr *url.Error - if errors.As(err, &urlErr) { - return "", fmt.Errorf("unable to parse postgres url [%s]: %w", urlErr.Op, urlErr.Err) - } - return "", fmt.Errorf("unable to parse postgres url: %w", err) - } - - u.Path = database - query := u.Query() - query.Add("sslmode", sslmode) - return u.String(), nil -} diff --git a/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go index 572b63a343..d6b6c9c32d 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/javascript-transformers/tests.go @@ -9,7 +9,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Javascript transformer sync", - Folder: "javascript-transformers", + Folder: "testdata/javascript-transformers", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create.sql"}, JobMappings: getJsTransformerJobmappings(), @@ -19,7 +19,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { }, { Name: "Javascript generator sync", - Folder: "javascript-transformers", + Folder: "testdata/javascript-transformers", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create.sql"}, JobMappings: getJsGeneratorJobmappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go index fbe7bc3ffd..b656fdff9f 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/all-types/tests.go @@ -8,7 +8,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "All datatypes passthrough", - Folder: "mysql/all-types", + Folder: "testdata/mysql/all-types", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create-dbs.sql"}, JobMappings: GetDefaultSyncJobMappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go index ba260078da..d525b15c09 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/composite-keys/tests.go @@ -10,7 +10,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Composite key transformation + truncate", - Folder: "mysql/composite-keys", + Folder: "testdata/mysql/composite-keys", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create.sql", "insert.sql"}, JobMappings: getPkTransformerJobmappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go index e571b84f74..c5709d0ace 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/init-schema/tests.go @@ -9,7 +9,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Init Schema", - Folder: "mysql/init-schema", + Folder: "testdata/mysql/init-schema", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create-dbs.sql"}, JobMappings: getJobmappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go index eb0698e241..ee2913649e 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/mysql/multiple-dbs/tests.go @@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "multiple databases sync + init schema", - Folder: "mysql/multiple-dbs", + Folder: "testdata/mysql/multiple-dbs", SourceFilePaths: []string{"create-dbs.sql", "create.sql", "insert.sql"}, TargetFilePaths: []string{"create-dbs.sql"}, JobMappings: GetDefaultSyncJobMappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go index e037a36bde..88f590328a 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/all-types/tests.go @@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "All Postgres types", - Folder: "postgres/all-types", + Folder: "testdata/postgres/all-types", SourceFilePaths: []string{"setup.sql"}, TargetFilePaths: []string{"schema-create.sql", "setup.sql"}, JobMappings: GetDefaultSyncJobMappings(), @@ -23,7 +23,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { }, { Name: "All Postgres types + init schema", - Folder: "postgres/all-types", + Folder: "testdata/postgres/all-types", SourceFilePaths: []string{"setup.sql"}, TargetFilePaths: []string{"schema-create.sql"}, JobMappings: GetDefaultSyncJobMappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go index eb8d6cb847..dcead5c2ab 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/circular-dependencies/tests.go @@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Circular Dependency sync + init schema", - Folder: "postgres/circular-dependencies", + Folder: "testdata/postgres/circular-dependencies", SourceFilePaths: []string{"setup.sql"}, TargetFilePaths: []string{"schema-create.sql"}, JobMappings: GetDefaultSyncJobMappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go index a3385bb792..b12d75798b 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/double-reference/tests.go @@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Double reference sync", - Folder: "postgres/double-reference", + Folder: "testdata/postgres/double-reference", SourceFilePaths: []string{"source-create.sql", "insert.sql"}, TargetFilePaths: []string{"source-create.sql"}, JobMappings: GetDefaultSyncJobMappings(), @@ -19,7 +19,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { }, { Name: "Double reference subset", - Folder: "postgres/double-reference", + Folder: "testdata/postgres/double-reference", SourceFilePaths: []string{"source-create.sql", "insert.sql"}, TargetFilePaths: []string{"source-create.sql"}, SubsetMap: map[string]string{ diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go index 7e4b02989e..2039f46e0d 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/subsetting/tests.go @@ -6,7 +6,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Complex subsetting", - Folder: "postgres/subsetting", + Folder: "testdata/postgres/subsetting", SourceFilePaths: []string{"setup.sql"}, TargetFilePaths: []string{"schema-create.sql"}, JobMappings: GetDefaultSyncJobMappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go index bb71642c20..4d80b78a33 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/postgres/virtual-foreign-keys/tests.go @@ -9,7 +9,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Virtual Foreign Keys sync", - Folder: "postgres/virtual-foreign-keys", + Folder: "testdata/postgres/virtual-foreign-keys", SourceFilePaths: []string{"source-setup.sql"}, TargetFilePaths: []string{"target-setup.sql"}, JobMappings: GetDefaultSyncJobMappings(), @@ -26,7 +26,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { }, { Name: "Virtual Foreign Keys subset", - Folder: "postgres/virtual-foreign-keys", + Folder: "testdata/postgres/virtual-foreign-keys", SourceFilePaths: []string{"source-setup.sql"}, TargetFilePaths: []string{"target-setup.sql"}, SubsetMap: map[string]string{ diff --git a/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go index 098067bfd2..628dc0e17a 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/primary-key-transformer/tests.go @@ -10,7 +10,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Circular Dependency primary key transformation", - Folder: "primary-key-transformer", + Folder: "testdata/primary-key-transformer", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create.sql"}, JobMappings: getPkTransformerJobmappings(), diff --git a/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go b/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go index 73cc68a622..c2259efe38 100644 --- a/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go +++ b/worker/pkg/workflows/datasync/workflow/testdata/skip-fk-violations/tests.go @@ -8,7 +8,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { return []*workflow_testdata.IntegrationTest{ { Name: "Skip Foreign Key Violations", - Folder: "skip-fk-violations", + Folder: "testdata/skip-fk-violations", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create.sql"}, JobMappings: GetDefaultSyncJobMappings(), @@ -27,7 +27,7 @@ func GetSyncTests() []*workflow_testdata.IntegrationTest { }, { Name: "Foreign Key Violations Error", - Folder: "skip-fk-violations", + Folder: "testdata/skip-fk-violations", SourceFilePaths: []string{"create.sql", "insert.sql"}, TargetFilePaths: []string{"create.sql"}, JobMappings: GetDefaultSyncJobMappings(), diff --git a/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go b/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go index 77153757e5..5cf2cf2c73 100644 --- a/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go +++ b/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go @@ -87,8 +87,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() { t.Run(tt.Name, func(t *testing.T) { t.Logf("running integration test: %s \n", tt.Name) // setup - s.RunPostgresSqlFiles(s.postgres.source.pool, tt.Folder, tt.SourceFilePaths) - s.RunPostgresSqlFiles(s.postgres.target.pool, tt.Folder, tt.TargetFilePaths) + err := s.postgres.Source.RunSqlFiles(s.ctx, &tt.Folder, tt.SourceFilePaths) + require.NoError(t, err) + err = s.postgres.Target.RunSqlFiles(s.ctx, &tt.Folder, tt.TargetFilePaths) + require.NoError(t, err) schemas := []*mgmtv1alpha1.PostgresSourceSchemaOption{} subsetMap := map[string]*mgmtv1alpha1.PostgresSourceSchemaOption{} @@ -179,7 +181,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() { Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ - Url: s.postgres.source.url, + Url: s.postgres.Source.URL, }, }, }, @@ -196,7 +198,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() { Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ - Url: s.postgres.target.url, + Url: s.postgres.Target.URL, }, }, }, @@ -212,7 +214,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() { srv := startHTTPServer(t, mux) env := executeWorkflow(t, srv, s.redis.url, "115aaf2c-776e-4847-8268-d914e3c15968") require.Truef(t, env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", tt.Name)) - err := env.GetWorkflowError() + err = env.GetWorkflowError() if tt.ExpectError { require.Error(t, err, "Did not received Temporal Workflow Error", "testName", tt.Name) return @@ -220,7 +222,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() { require.NoError(t, err, "Received Temporal Workflow Error", "testName", tt.Name) for table, expected := range tt.Expected { - rows, err := s.postgres.target.pool.Query(s.ctx, fmt.Sprintf("select * from %s;", table)) + rows, err := s.postgres.Target.DB.Query(s.ctx, fmt.Sprintf("select * from %s;", table)) require.NoError(t, err) count := 0 for rows.Next() { @@ -230,8 +232,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Postgres() { } // tear down - s.RunPostgresSqlFiles(s.postgres.source.pool, tt.Folder, []string{"teardown.sql"}) - s.RunPostgresSqlFiles(s.postgres.target.pool, tt.Folder, []string{"teardown.sql"}) + err = s.postgres.Source.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"}) + require.NoError(t, err) + err = s.postgres.Target.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"}) + require.NoError(t, err) }) } }) @@ -447,10 +451,12 @@ func toRunContextKeyString(id *mgmtv1alpha1.RunContextKey) string { } func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() { - testFolder := "postgres/virtual-foreign-keys" + testFolder := "testdata/postgres/virtual-foreign-keys" // setup - s.RunPostgresSqlFiles(s.postgres.source.pool, testFolder, []string{"source-setup.sql"}) - s.RunPostgresSqlFiles(s.postgres.target.pool, testFolder, []string{"target-setup.sql"}) + err := s.postgres.Source.RunSqlFiles(s.ctx, &testFolder, []string{"source-setup.sql"}) + require.NoError(s.T(), err) + err = s.postgres.Target.RunSqlFiles(s.ctx, &testFolder, []string{"target-setup.sql"}) + require.NoError(s.T(), err) virtualForeignKeys := testdata_virtualforeignkeys.GetVirtualForeignKeys() jobmappings := testdata_virtualforeignkeys.GetDefaultSyncJobMappings() @@ -515,7 +521,7 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() { Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ - Url: s.postgres.source.url, + Url: s.postgres.Source.URL, }, }, }, @@ -532,7 +538,7 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() { Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ - Url: s.postgres.target.url, + Url: s.postgres.Target.URL, }, }, }, @@ -549,12 +555,12 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() { testName := "Virtual Foreign Key primary key transform" env := executeWorkflow(s.T(), srv, s.redis.url, "fd4d8660-31a0-48b2-9adf-10f11b94898f") require.Truef(s.T(), env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", testName)) - err := env.GetWorkflowError() + err = env.GetWorkflowError() require.NoError(s.T(), err, "Received Temporal Workflow Error", "testName", testName) tables := []string{"regions", "countries", "locations", "departments", "dependents", "jobs", "employees"} for _, t := range tables { - rows, err := s.postgres.target.pool.Query(s.ctx, fmt.Sprintf("select * from vfk_hr.%s;", t)) + rows, err := s.postgres.Target.DB.Query(s.ctx, fmt.Sprintf("select * from vfk_hr.%s;", t)) require.NoError(s.T(), err) count := 0 for rows.Next() { @@ -564,35 +570,37 @@ func (s *IntegrationTestSuite) Test_Workflow_VirtualForeignKeys_Transform() { require.NoError(s.T(), err) } - rows := s.postgres.source.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';") + rows := s.postgres.Source.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';") var rowCount int err = rows.Scan(&rowCount) require.NoError(s.T(), err) require.Equal(s.T(), 1, rowCount) - rows = s.postgres.source.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'US';") + rows = s.postgres.Source.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'US';") err = rows.Scan(&rowCount) require.NoError(s.T(), err) require.Equal(s.T(), 3, rowCount) - rows = s.postgres.target.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';") + rows = s.postgres.Target.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'US';") err = rows.Scan(&rowCount) require.NoError(s.T(), err) require.Equal(s.T(), 0, rowCount) - rows = s.postgres.target.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'SU';") + rows = s.postgres.Target.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.countries where country_id = 'SU';") err = rows.Scan(&rowCount) require.NoError(s.T(), err) require.Equal(s.T(), 1, rowCount) - rows = s.postgres.target.pool.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'SU';") + rows = s.postgres.Target.DB.QueryRow(s.ctx, "select count(*) from vfk_hr.locations where country_id = 'SU';") err = rows.Scan(&rowCount) require.NoError(s.T(), err) require.Equal(s.T(), 3, rowCount) // tear down - s.RunPostgresSqlFiles(s.postgres.source.pool, testFolder, []string{"teardown.sql"}) - s.RunPostgresSqlFiles(s.postgres.target.pool, testFolder, []string{"teardown.sql"}) + err = s.postgres.Source.RunSqlFiles(s.ctx, &testFolder, []string{"teardown.sql"}) + require.NoError(s.T(), err) + err = s.postgres.Target.RunSqlFiles(s.ctx, &testFolder, []string{"teardown.sql"}) + require.NoError(s.T(), err) } func getAllMysqlSyncTests() map[string][]*workflow_testdata.IntegrationTest { @@ -618,8 +626,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() { t.Run(tt.Name, func(t *testing.T) { t.Logf("running integration test: %s \n", tt.Name) // setup - s.RunMysqlSqlFiles(s.mysql.source.pool, tt.Folder, tt.SourceFilePaths) - s.RunMysqlSqlFiles(s.mysql.target.pool, tt.Folder, tt.TargetFilePaths) + err := s.mysql.Source.RunSqlFiles(s.ctx, &tt.Folder, tt.SourceFilePaths) + require.NoError(t, err) + err = s.mysql.Target.RunSqlFiles(s.ctx, &tt.Folder, tt.TargetFilePaths) + require.NoError(t, err) schemas := []*mgmtv1alpha1.MysqlSourceSchemaOption{} subsetMap := map[string]*mgmtv1alpha1.MysqlSourceSchemaOption{} @@ -710,7 +720,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() { Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: s.mysql.source.url, + Url: s.mysql.Source.URL, }, }, }, @@ -727,7 +737,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() { Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Url{ - Url: s.mysql.target.url, + Url: s.mysql.Target.URL, }, }, }, @@ -742,7 +752,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() { srv := startHTTPServer(t, mux) env := executeWorkflow(t, srv, s.redis.url, "115aaf2c-776e-4847-8268-d914e3c15968") require.Truef(t, env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", tt.Name)) - err := env.GetWorkflowError() + err = env.GetWorkflowError() if tt.ExpectError { require.Error(t, err, "Did not received Temporal Workflow Error", "testName", tt.Name) return @@ -750,7 +760,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() { require.NoError(t, err, "Received Temporal Workflow Error", "testName", tt.Name) for table, expected := range tt.Expected { - rows, err := s.mysql.target.pool.QueryContext(s.ctx, fmt.Sprintf("select * from %s;", table)) + rows, err := s.mysql.Target.DB.QueryContext(s.ctx, fmt.Sprintf("select * from %s;", table)) require.NoError(t, err) count := 0 for rows.Next() { @@ -760,8 +770,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Sync_Mysql() { } // tear down - s.RunMysqlSqlFiles(s.mysql.source.pool, tt.Folder, []string{"teardown.sql"}) - s.RunMysqlSqlFiles(s.mysql.target.pool, tt.Folder, []string{"teardown.sql"}) + err = s.mysql.Source.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"}) + require.NoError(t, err) + err = s.mysql.Target.RunSqlFiles(s.ctx, &tt.Folder, []string{"teardown.sql"}) + require.NoError(t, err) }) } }) @@ -1387,7 +1399,9 @@ func getAllMongoDBSyncTests() map[string][]*workflow_testdata.IntegrationTest { func (s *IntegrationTestSuite) Test_Workflow_Generate() { // setup testName := "Generate Job" - s.RunPostgresSqlFiles(s.postgres.target.pool, "generate-job", []string{"setup.sql"}) + folder := "testdata/generate-job" + err := s.postgres.Target.RunSqlFiles(s.ctx, &folder, []string{"setup.sql"}) + require.NoError(s.T(), err) connectionId := "226add85-5751-4232-b085-a0ae93afc7ce" schema := "generate_job" @@ -1467,7 +1481,7 @@ func (s *IntegrationTestSuite) Test_Workflow_Generate() { Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Url{ - Url: s.postgres.target.url, + Url: s.postgres.Target.URL, }, }, }, @@ -1483,10 +1497,10 @@ func (s *IntegrationTestSuite) Test_Workflow_Generate() { srv := startHTTPServer(s.T(), mux) env := executeWorkflow(s.T(), srv, s.redis.url, "115aaf2c-776e-4847-8268-d914e3c15968") require.Truef(s.T(), env.IsWorkflowCompleted(), fmt.Sprintf("Workflow did not complete. Test: %s", testName)) - err := env.GetWorkflowError() + err = env.GetWorkflowError() require.NoError(s.T(), err, "Received Temporal Workflow Error", "testName", testName) - rows, err := s.postgres.target.pool.Query(s.ctx, fmt.Sprintf("select * from %s.%s;", schema, table)) + rows, err := s.postgres.Target.DB.Query(s.ctx, fmt.Sprintf("select * from %s.%s;", schema, table)) require.NoError(s.T(), err) count := 0 for rows.Next() { @@ -1495,7 +1509,8 @@ func (s *IntegrationTestSuite) Test_Workflow_Generate() { require.Equalf(s.T(), 10, count, fmt.Sprintf("Test: %s Table: %s", testName, table)) // tear down - s.RunPostgresSqlFiles(s.postgres.target.pool, "generate-job", []string{"teardown.sql"}) + err = s.postgres.Target.RunSqlFiles(s.ctx, &folder, []string{"teardown.sql"}) + require.NoError(s.T(), err) } func executeWorkflow(