diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go
index 33652da9be0..fa109946e4b 100644
--- a/tests/integrations/client/http_client_test.go
+++ b/tests/integrations/client/http_client_test.go
@@ -41,190 +41,169 @@ import (
 	"github.com/tikv/pd/tests"
 )
 
-type mode int
-
-// We have two ways to create HTTP client.
-// 1. using `NewClient` which created `DefaultPDServiceDiscovery`
-// 2. using `NewClientWithServiceDiscovery` which pass a `PDServiceDiscovery` as parameter
-// test cases should be run in both modes.
-const (
-	defaultServiceDiscovery mode = iota
-	specificServiceDiscovery
-)
-
 type httpClientTestSuite struct {
 	suite.Suite
-	env map[mode]*httpClientTestEnv
+	// 1. Using `NewClient` will create a `DefaultPDServiceDiscovery` internally.
+	// 2. Using `NewClientWithServiceDiscovery` will need a `PDServiceDiscovery` to be passed in.
+	withServiceDiscovery bool
+	ctx                  context.Context
+	cancelFunc           context.CancelFunc
+	cluster              *tests.TestCluster
+	endpoints            []string
+	client               pd.Client
 }
 
-type httpClientTestEnv struct {
-	ctx        context.Context
-	cancelFunc context.CancelFunc
-	cluster    *tests.TestCluster
-	endpoints  []string
+func TestHTTPClientTestSuite(t *testing.T) {
+	suite.Run(t, &httpClientTestSuite{
+		withServiceDiscovery: false,
+	})
 }
 
-func TestHTTPClientTestSuite(t *testing.T) {
-	suite.Run(t, new(httpClientTestSuite))
+func TestHTTPClientTestSuiteWithServiceDiscovery(t *testing.T) {
+	suite.Run(t, &httpClientTestSuite{
+		withServiceDiscovery: true,
+	})
 }
 
 func (suite *httpClientTestSuite) SetupSuite() {
-	suite.env = make(map[mode]*httpClientTestEnv)
 	re := suite.Require()
+	suite.ctx, suite.cancelFunc = context.WithCancel(context.Background())
 
-	for _, mode := range []mode{defaultServiceDiscovery, specificServiceDiscovery} {
-		env := &httpClientTestEnv{}
-		env.ctx, env.cancelFunc = context.WithCancel(context.Background())
+	cluster, err := tests.NewTestCluster(suite.ctx, 2)
+	re.NoError(err)
 
-		cluster, err := tests.NewTestCluster(env.ctx, 2)
-		re.NoError(err)
+	err = cluster.RunInitialServers()
+	re.NoError(err)
+	leader := cluster.WaitLeader()
+	re.NotEmpty(leader)
+	leaderServer := cluster.GetLeaderServer()
 
-		err = cluster.RunInitialServers()
+	err = leaderServer.BootstrapCluster()
+	re.NoError(err)
+	for _, region := range []*core.RegionInfo{
+		core.NewTestRegionInfo(10, 1, []byte("a1"), []byte("a2")),
+		core.NewTestRegionInfo(11, 1, []byte("a2"), []byte("a3")),
+	} {
+		err := leaderServer.GetRaftCluster().HandleRegionHeartbeat(region)
 		re.NoError(err)
-		leader := cluster.WaitLeader()
-		re.NotEmpty(leader)
-		leaderServer := cluster.GetLeaderServer()
-
-		err = leaderServer.BootstrapCluster()
+	}
+	var (
+		testServers = cluster.GetServers()
+		endpoints   = make([]string, 0, len(testServers))
+	)
+	for _, s := range testServers {
+		addr := s.GetConfig().AdvertiseClientUrls
+		url, err := url.Parse(addr)
 		re.NoError(err)
-		for _, region := range []*core.RegionInfo{
-			core.NewTestRegionInfo(10, 1, []byte("a1"), []byte("a2")),
-			core.NewTestRegionInfo(11, 1, []byte("a2"), []byte("a3")),
-		} {
-			err := leaderServer.GetRaftCluster().HandleRegionHeartbeat(region)
-			re.NoError(err)
-		}
-		var (
-			testServers = cluster.GetServers()
-			endpoints   = make([]string, 0, len(testServers))
-		)
-		for _, s := range testServers {
-			addr := s.GetConfig().AdvertiseClientUrls
-			url, err := url.Parse(addr)
-			re.NoError(err)
-			endpoints = append(endpoints, url.Host)
-		}
-		env.endpoints = endpoints
-		env.cluster = cluster
-
-		suite.env[mode] = env
+		endpoints = append(endpoints, url.Host)
 	}
-}
-
-func (suite *httpClientTestSuite) TearDownSuite() {
-	for _, env := range suite.env {
-		env.cancelFunc()
-		env.cluster.Destroy()
-	}
+	suite.endpoints = endpoints
+	suite.cluster = cluster
+
+	if suite.withServiceDiscovery {
+		// Run test with specific service discovery.
+		cli := setupCli(suite.ctx, re, suite.endpoints)
+		sd := cli.GetServiceDiscovery()
+		suite.client = pd.NewClientWithServiceDiscovery("pd-http-client-it-grpc", sd)
+	} else {
+		// Run test with default service discovery.
+		suite.client = pd.NewClient("pd-http-client-it-http", suite.endpoints)
+	}
 }
 
-// RunTestInTwoModes is to run test in two modes.
-func (suite *httpClientTestSuite) RunTestInTwoModes(test func(mode mode, client pd.Client)) {
-	// Run test with specific service discovery.
-	cli := setupCli(suite.env[specificServiceDiscovery].ctx, suite.Require(), suite.env[specificServiceDiscovery].endpoints)
-	sd := cli.GetServiceDiscovery()
-	client := pd.NewClientWithServiceDiscovery("pd-http-client-it-grpc", sd)
-	test(specificServiceDiscovery, client)
-	client.Close()
-
-	// Run test with default service discovery.
-	client = pd.NewClient("pd-http-client-it-http", suite.env[defaultServiceDiscovery].endpoints)
-	test(defaultServiceDiscovery, client)
-	client.Close()
+func (suite *httpClientTestSuite) TearDownSuite() {
+	suite.cancelFunc()
+	suite.client.Close()
+	suite.cluster.Destroy()
 }
 
 func (suite *httpClientTestSuite) TestMeta() {
-	suite.RunTestInTwoModes(suite.checkMeta)
-}
-
-func (suite *httpClientTestSuite) checkMeta(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-	replicateConfig, err := client.GetReplicateConfig(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	replicateConfig, err := client.GetReplicateConfig(ctx)
 	re.NoError(err)
 	re.Equal(3.0, replicateConfig["max-replicas"])
-	region, err := client.GetRegionByID(env.ctx, 10)
+	region, err := client.GetRegionByID(ctx, 10)
 	re.NoError(err)
 	re.Equal(int64(10), region.ID)
 	re.Equal(core.HexRegionKeyStr([]byte("a1")), region.StartKey)
 	re.Equal(core.HexRegionKeyStr([]byte("a2")), region.EndKey)
-	region, err = client.GetRegionByKey(env.ctx, []byte("a2"))
+	region, err = client.GetRegionByKey(ctx, []byte("a2"))
 	re.NoError(err)
 	re.Equal(int64(11), region.ID)
 	re.Equal(core.HexRegionKeyStr([]byte("a2")), region.StartKey)
 	re.Equal(core.HexRegionKeyStr([]byte("a3")), region.EndKey)
-	regions, err := client.GetRegions(env.ctx)
+	regions, err := client.GetRegions(ctx)
 	re.NoError(err)
 	re.Equal(int64(2), regions.Count)
 	re.Len(regions.Regions, 2)
-	regions, err = client.GetRegionsByKeyRange(env.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), -1)
+	regions, err = client.GetRegionsByKeyRange(ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), -1)
 	re.NoError(err)
 	re.Equal(int64(2), regions.Count)
 	re.Len(regions.Regions, 2)
-	regions, err = client.GetRegionsByStoreID(env.ctx, 1)
+	regions, err = client.GetRegionsByStoreID(ctx, 1)
 	re.NoError(err)
 	re.Equal(int64(2), regions.Count)
 	re.Len(regions.Regions, 2)
-	regions, err = client.GetEmptyRegions(env.ctx)
+	regions, err = client.GetEmptyRegions(ctx)
 	re.NoError(err)
 	re.Equal(int64(2), regions.Count)
 	re.Len(regions.Regions, 2)
-	state, err := client.GetRegionsReplicatedStateByKeyRange(env.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")))
+	state, err := client.GetRegionsReplicatedStateByKeyRange(ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")))
 	re.NoError(err)
 	re.Equal("INPROGRESS", state)
-	regionStats, err := client.GetRegionStatusByKeyRange(env.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), false)
+	regionStats, err := client.GetRegionStatusByKeyRange(ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), false)
 	re.NoError(err)
 	re.Positive(regionStats.Count)
 	re.NotEmpty(regionStats.StoreLeaderCount)
-	regionStats, err = client.GetRegionStatusByKeyRange(env.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), true)
+	regionStats, err = client.GetRegionStatusByKeyRange(ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), true)
 	re.NoError(err)
 	re.Positive(regionStats.Count)
 	re.Empty(regionStats.StoreLeaderCount)
-	hotReadRegions, err := client.GetHotReadRegions(env.ctx)
+	hotReadRegions, err := client.GetHotReadRegions(ctx)
 	re.NoError(err)
 	re.Len(hotReadRegions.AsPeer, 1)
 	re.Len(hotReadRegions.AsLeader, 1)
-	hotWriteRegions, err := client.GetHotWriteRegions(env.ctx)
+	hotWriteRegions, err := client.GetHotWriteRegions(ctx)
 	re.NoError(err)
 	re.Len(hotWriteRegions.AsPeer, 1)
 	re.Len(hotWriteRegions.AsLeader, 1)
-	historyHorRegions, err := client.GetHistoryHotRegions(env.ctx, &pd.HistoryHotRegionsRequest{
+	historyHorRegions, err := client.GetHistoryHotRegions(ctx, &pd.HistoryHotRegionsRequest{
 		StartTime: 0,
 		EndTime:   time.Now().AddDate(0, 0, 1).UnixNano() / int64(time.Millisecond),
 	})
 	re.NoError(err)
 	re.Empty(historyHorRegions.HistoryHotRegion)
-	store, err := client.GetStores(env.ctx)
+	store, err := client.GetStores(ctx)
 	re.NoError(err)
 	re.Equal(1, store.Count)
 	re.Len(store.Stores, 1)
 	storeID := uint64(store.Stores[0].Store.ID) // TODO: why type is different?
-	store2, err := client.GetStore(env.ctx, storeID)
+	store2, err := client.GetStore(ctx, storeID)
 	re.NoError(err)
 	re.EqualValues(storeID, store2.Store.ID)
-	version, err := client.GetClusterVersion(env.ctx)
+	version, err := client.GetClusterVersion(ctx)
 	re.NoError(err)
 	re.Equal("0.0.0", version)
-	rgs, _ := client.GetRegionsByKeyRange(env.ctx, pd.NewKeyRange([]byte("a"), []byte("a1")), 100)
+	rgs, _ := client.GetRegionsByKeyRange(ctx, pd.NewKeyRange([]byte("a"), []byte("a1")), 100)
 	re.Equal(int64(0), rgs.Count)
-	rgs, _ = client.GetRegionsByKeyRange(env.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), 100)
+	rgs, _ = client.GetRegionsByKeyRange(ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), 100)
 	re.Equal(int64(2), rgs.Count)
-	rgs, _ = client.GetRegionsByKeyRange(env.ctx, pd.NewKeyRange([]byte("a2"), []byte("b")), 100)
+	rgs, _ = client.GetRegionsByKeyRange(ctx, pd.NewKeyRange([]byte("a2"), []byte("b")), 100)
 	re.Equal(int64(1), rgs.Count)
-	rgs, _ = client.GetRegionsByKeyRange(env.ctx, pd.NewKeyRange([]byte(""), []byte("")), 100)
+	rgs, _ = client.GetRegionsByKeyRange(ctx, pd.NewKeyRange([]byte(""), []byte("")), 100)
 	re.Equal(int64(2), rgs.Count)
 }
 
 func (suite *httpClientTestSuite) TestGetMinResolvedTSByStoresIDs() {
-	suite.RunTestInTwoModes(suite.checkGetMinResolvedTSByStoresIDs)
-}
-
-func (suite *httpClientTestSuite) checkGetMinResolvedTSByStoresIDs(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
 	testMinResolvedTS := tsoutil.TimeToTS(time.Now())
-	raftCluster := env.cluster.GetLeaderServer().GetRaftCluster()
+	raftCluster := suite.cluster.GetLeaderServer().GetRaftCluster()
 	err := raftCluster.SetMinResolvedTS(1, testMinResolvedTS)
 	re.NoError(err)
 	// Make sure the min resolved TS is updated.
@@ -233,18 +212,18 @@ func (suite *httpClientTestSuite) checkGetMinResolvedTSByStoresIDs(mode mode, cl
 		return minResolvedTS == testMinResolvedTS
 	})
 	// Wait for the cluster-level min resolved TS to be initialized.
-	minResolvedTS, storeMinResolvedTSMap, err := client.GetMinResolvedTSByStoresIDs(env.ctx, nil)
+	minResolvedTS, storeMinResolvedTSMap, err := client.GetMinResolvedTSByStoresIDs(ctx, nil)
 	re.NoError(err)
 	re.Equal(testMinResolvedTS, minResolvedTS)
 	re.Empty(storeMinResolvedTSMap)
 	// Get the store-level min resolved TS.
-	minResolvedTS, storeMinResolvedTSMap, err = client.GetMinResolvedTSByStoresIDs(env.ctx, []uint64{1})
+	minResolvedTS, storeMinResolvedTSMap, err = client.GetMinResolvedTSByStoresIDs(ctx, []uint64{1})
 	re.NoError(err)
 	re.Equal(testMinResolvedTS, minResolvedTS)
 	re.Len(storeMinResolvedTSMap, 1)
 	re.Equal(minResolvedTS, storeMinResolvedTSMap[1])
 	// Get the store-level min resolved TS with an invalid store ID.
-	minResolvedTS, storeMinResolvedTSMap, err = client.GetMinResolvedTSByStoresIDs(env.ctx, []uint64{1, 2})
+	minResolvedTS, storeMinResolvedTSMap, err = client.GetMinResolvedTSByStoresIDs(ctx, []uint64{1, 2})
 	re.NoError(err)
 	re.Equal(testMinResolvedTS, minResolvedTS)
 	re.Len(storeMinResolvedTSMap, 2)
@@ -253,22 +232,19 @@ func (suite *httpClientTestSuite) checkGetMinResolvedTSByStoresIDs(mode mode, cl
 }
 
 func (suite *httpClientTestSuite) TestRule() {
-	suite.RunTestInTwoModes(suite.checkRule)
-}
-
-func (suite *httpClientTestSuite) checkRule(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	bundles, err := client.GetAllPlacementRuleBundles(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	bundles, err := client.GetAllPlacementRuleBundles(ctx)
 	re.NoError(err)
 	re.Len(bundles, 1)
 	re.Equal(placement.DefaultGroupID, bundles[0].ID)
-	bundle, err := client.GetPlacementRuleBundleByGroup(env.ctx, placement.DefaultGroupID)
+	bundle, err := client.GetPlacementRuleBundleByGroup(ctx, placement.DefaultGroupID)
 	re.NoError(err)
 	re.Equal(bundles[0], bundle)
 	// Check if we have the default rule.
-	checkRuleResult(re, env, client, &pd.Rule{
+	suite.checkRuleResult(ctx, re, &pd.Rule{
 		GroupID: placement.DefaultGroupID,
 		ID:      placement.DefaultRuleID,
 		Role:    pd.Voter,
@@ -277,7 +253,7 @@ func (suite *httpClientTestSuite) checkRule(mode mode, client pd.Client) {
 		EndKey:   []byte{},
 	}, 1, true)
 	// Should be the same as the rules in the bundle.
-	checkRuleResult(re, env, client, bundle.Rules[0], 1, true)
+	suite.checkRuleResult(ctx, re, bundle.Rules[0], 1, true)
 	testRule := &pd.Rule{
 		GroupID: placement.DefaultGroupID,
 		ID:      "test",
@@ -286,39 +262,39 @@ func (suite *httpClientTestSuite) checkRule(mode mode, client pd.Client) {
 		StartKey: []byte{},
 		EndKey:   []byte{},
 	}
-	err = client.SetPlacementRule(env.ctx, testRule)
+	err = client.SetPlacementRule(ctx, testRule)
 	re.NoError(err)
-	checkRuleResult(re, env, client, testRule, 2, true)
-	err = client.DeletePlacementRule(env.ctx, placement.DefaultGroupID, "test")
+	suite.checkRuleResult(ctx, re, testRule, 2, true)
+	err = client.DeletePlacementRule(ctx, placement.DefaultGroupID, "test")
 	re.NoError(err)
-	checkRuleResult(re, env, client, testRule, 1, false)
+	suite.checkRuleResult(ctx, re, testRule, 1, false)
 	testRuleOp := &pd.RuleOp{
 		Rule:   testRule,
 		Action: pd.RuleOpAdd,
 	}
-	err = client.SetPlacementRuleInBatch(env.ctx, []*pd.RuleOp{testRuleOp})
+	err = client.SetPlacementRuleInBatch(ctx, []*pd.RuleOp{testRuleOp})
 	re.NoError(err)
-	checkRuleResult(re, env, client, testRule, 2, true)
+	suite.checkRuleResult(ctx, re, testRule, 2, true)
 	testRuleOp = &pd.RuleOp{
 		Rule:   testRule,
 		Action: pd.RuleOpDel,
 	}
-	err = client.SetPlacementRuleInBatch(env.ctx, []*pd.RuleOp{testRuleOp})
+	err = client.SetPlacementRuleInBatch(ctx, []*pd.RuleOp{testRuleOp})
 	re.NoError(err)
-	checkRuleResult(re, env, client, testRule, 1, false)
-	err = client.SetPlacementRuleBundles(env.ctx, []*pd.GroupBundle{
+	suite.checkRuleResult(ctx, re, testRule, 1, false)
+	err = client.SetPlacementRuleBundles(ctx, []*pd.GroupBundle{
 		{
 			ID:    placement.DefaultGroupID,
 			Rules: []*pd.Rule{testRule},
 		},
 	}, true)
 	re.NoError(err)
-	checkRuleResult(re, env, client, testRule, 1, true)
-	ruleGroups, err := client.GetAllPlacementRuleGroups(env.ctx)
+	suite.checkRuleResult(ctx, re, testRule, 1, true)
+	ruleGroups, err := client.GetAllPlacementRuleGroups(ctx)
 	re.NoError(err)
 	re.Len(ruleGroups, 1)
 	re.Equal(placement.DefaultGroupID, ruleGroups[0].ID)
-	ruleGroup, err := client.GetPlacementRuleGroupByID(env.ctx, placement.DefaultGroupID)
+	ruleGroup, err := client.GetPlacementRuleGroupByID(ctx, placement.DefaultGroupID)
 	re.NoError(err)
 	re.Equal(ruleGroups[0], ruleGroup)
 	testRuleGroup := &pd.RuleGroup{
@@ -326,14 +302,14 @@ func (suite *httpClientTestSuite) checkRule(mode mode, client pd.Client) {
 		Index:    1,
 		Override: true,
 	}
-	err = client.SetPlacementRuleGroup(env.ctx, testRuleGroup)
+	err = client.SetPlacementRuleGroup(ctx, testRuleGroup)
 	re.NoError(err)
-	ruleGroup, err = client.GetPlacementRuleGroupByID(env.ctx, testRuleGroup.ID)
+	ruleGroup, err = client.GetPlacementRuleGroupByID(ctx, testRuleGroup.ID)
 	re.NoError(err)
 	re.Equal(testRuleGroup, ruleGroup)
-	err = client.DeletePlacementRuleGroupByID(env.ctx, testRuleGroup.ID)
+	err = client.DeletePlacementRuleGroupByID(ctx, testRuleGroup.ID)
 	re.NoError(err)
-	ruleGroup, err = client.GetPlacementRuleGroupByID(env.ctx, testRuleGroup.ID)
+	ruleGroup, err = client.GetPlacementRuleGroupByID(ctx, testRuleGroup.ID)
 	re.ErrorContains(err, http.StatusText(http.StatusNotFound))
 	re.Empty(ruleGroup)
 	// Test the start key and end key.
@@ -345,34 +321,33 @@ func (suite *httpClientTestSuite) checkRule(mode mode, client pd.Client) {
 		StartKey: []byte("a1"),
 		EndKey:   []byte(""),
 	}
-	err = client.SetPlacementRule(env.ctx, testRule)
+	err = client.SetPlacementRule(ctx, testRule)
 	re.NoError(err)
-	checkRuleResult(re, env, client, testRule, 1, true)
+	suite.checkRuleResult(ctx, re, testRule, 1, true)
 }
 
-func checkRuleResult(
-	re *require.Assertions,
-	env *httpClientTestEnv,
-	client pd.Client,
+func (suite *httpClientTestSuite) checkRuleResult(
+	ctx context.Context, re *require.Assertions,
 	rule *pd.Rule, totalRuleCount int, exist bool,
 ) {
+	client := suite.client
 	if exist {
-		got, err := client.GetPlacementRule(env.ctx, rule.GroupID, rule.ID)
+		got, err := client.GetPlacementRule(ctx, rule.GroupID, rule.ID)
 		re.NoError(err)
 		// skip comparison of the generated field
 		got.StartKeyHex = rule.StartKeyHex
 		got.EndKeyHex = rule.EndKeyHex
 		re.Equal(rule, got)
 	} else {
-		_, err := client.GetPlacementRule(env.ctx, rule.GroupID, rule.ID)
+		_, err := client.GetPlacementRule(ctx, rule.GroupID, rule.ID)
 		re.ErrorContains(err, http.StatusText(http.StatusNotFound))
 	}
 	// Check through the `GetPlacementRulesByGroup` API.
-	rules, err := client.GetPlacementRulesByGroup(env.ctx, rule.GroupID)
+	rules, err := client.GetPlacementRulesByGroup(ctx, rule.GroupID)
 	re.NoError(err)
 	checkRuleFunc(re, rules, rule, totalRuleCount, exist)
 	// Check through the `GetPlacementRuleBundleByGroup` API.
-	bundle, err := client.GetPlacementRuleBundleByGroup(env.ctx, rule.GroupID)
+	bundle, err := client.GetPlacementRuleBundleByGroup(ctx, rule.GroupID)
 	re.NoError(err)
 	checkRuleFunc(re, bundle.Rules, rule, totalRuleCount, exist)
 }
@@ -400,14 +375,11 @@ func checkRuleFunc(
 }
 
 func (suite *httpClientTestSuite) TestRegionLabel() {
-	suite.RunTestInTwoModes(suite.checkRegionLabel)
-}
-
-func (suite *httpClientTestSuite) checkRegionLabel(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	labelRules, err := client.GetAllRegionLabelRules(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	labelRules, err := client.GetAllRegionLabelRules(ctx)
 	re.NoError(err)
 	re.Len(labelRules, 1)
 	re.Equal("keyspaces/0", labelRules[0].ID)
@@ -418,9 +390,9 @@ func (suite *httpClientTestSuite) checkRegionLabel(mode mode, client pd.Client)
 		RuleType: "key-range",
 		Data:     labeler.MakeKeyRanges("1234", "5678"),
 	}
-	err = client.SetRegionLabelRule(env.ctx, labelRule)
+	err = client.SetRegionLabelRule(ctx, labelRule)
 	re.NoError(err)
-	labelRules, err = client.GetAllRegionLabelRules(env.ctx)
+	labelRules, err = client.GetAllRegionLabelRules(ctx)
 	re.NoError(err)
 	re.Len(labelRules, 2)
 	sort.Slice(labelRules, func(i, j int) bool {
@@ -440,9 +412,9 @@ func (suite *httpClientTestSuite) checkRegionLabel(mode mode, client pd.Client)
 		SetRules:    []*pd.LabelRule{labelRule},
 		DeleteRules: []string{"rule1"},
 	}
-	err = client.PatchRegionLabelRules(env.ctx, patch)
+	err = client.PatchRegionLabelRules(ctx, patch)
 	re.NoError(err)
-	allLabelRules, err := client.GetAllRegionLabelRules(env.ctx)
+	allLabelRules, err := client.GetAllRegionLabelRules(ctx)
 	re.NoError(err)
 	re.Len(labelRules, 2)
 	sort.Slice(allLabelRules, func(i, j int) bool {
@@ -451,7 +423,7 @@ func (suite *httpClientTestSuite) checkRegionLabel(mode mode, client pd.Client)
 	re.Equal(labelRule.ID, allLabelRules[1].ID)
 	re.Equal(labelRule.Labels, allLabelRules[1].Labels)
 	re.Equal(labelRule.RuleType, allLabelRules[1].RuleType)
-	labelRules, err = client.GetRegionLabelRulesByIDs(env.ctx, []string{"keyspaces/0", "rule2"})
+	labelRules, err = client.GetRegionLabelRulesByIDs(ctx, []string{"keyspaces/0", "rule2"})
 	re.NoError(err)
 	sort.Slice(labelRules, func(i, j int) bool {
 		return labelRules[i].ID < labelRules[j].ID
@@ -460,24 +432,21 @@ func (suite *httpClientTestSuite) checkRegionLabel(mode mode, client pd.Client)
 }
 
 func (suite *httpClientTestSuite) TestAccelerateSchedule() {
-	suite.RunTestInTwoModes(suite.checkAccelerateSchedule)
-}
-
-func (suite *httpClientTestSuite) checkAccelerateSchedule(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	raftCluster := env.cluster.GetLeaderServer().GetRaftCluster()
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	raftCluster := suite.cluster.GetLeaderServer().GetRaftCluster()
 	suspectRegions := raftCluster.GetSuspectRegions()
 	re.Empty(suspectRegions)
-	err := client.AccelerateSchedule(env.ctx, pd.NewKeyRange([]byte("a1"), []byte("a2")))
+	err := client.AccelerateSchedule(ctx, pd.NewKeyRange([]byte("a1"), []byte("a2")))
 	re.NoError(err)
 	suspectRegions = raftCluster.GetSuspectRegions()
 	re.Len(suspectRegions, 1)
 	raftCluster.ClearSuspectRegions()
 	suspectRegions = raftCluster.GetSuspectRegions()
 	re.Empty(suspectRegions)
-	err = client.AccelerateScheduleInBatch(env.ctx, []*pd.KeyRange{
+	err = client.AccelerateScheduleInBatch(ctx, []*pd.KeyRange{
 		pd.NewKeyRange([]byte("a1"), []byte("a2")),
 		pd.NewKeyRange([]byte("a2"), []byte("a3")),
 	})
@@ -487,24 +456,21 @@ func (suite *httpClientTestSuite) checkAccelerateSchedule(mode mode, client pd.C
 }
 
 func (suite *httpClientTestSuite) TestConfig() {
-	suite.RunTestInTwoModes(suite.checkConfig)
-}
-
-func (suite *httpClientTestSuite) checkConfig(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	config, err := client.GetConfig(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	config, err := client.GetConfig(ctx)
 	re.NoError(err)
 	re.Equal(float64(4), config["schedule"].(map[string]any)["leader-schedule-limit"])
 
 	newConfig := map[string]any{
 		"schedule.leader-schedule-limit": float64(8),
 	}
-	err = client.SetConfig(env.ctx, newConfig)
+	err = client.SetConfig(ctx, newConfig)
 	re.NoError(err)
 
-	config, err = client.GetConfig(env.ctx)
+	config, err = client.GetConfig(ctx)
 	re.NoError(err)
 	re.Equal(float64(8), config["schedule"].(map[string]any)["leader-schedule-limit"])
 
@@ -512,15 +478,15 @@ func (suite *httpClientTestSuite) checkConfig(mode mode, client pd.Client) {
 	newConfig = map[string]any{
 		"schedule.leader-schedule-limit": float64(16),
 	}
-	err = client.SetConfig(env.ctx, newConfig, 5)
+	err = client.SetConfig(ctx, newConfig, 5)
 	re.NoError(err)
-	resp, err := env.cluster.GetEtcdClient().Get(env.ctx, sc.TTLConfigPrefix+"/schedule.leader-schedule-limit")
+	resp, err := suite.cluster.GetEtcdClient().Get(ctx, sc.TTLConfigPrefix+"/schedule.leader-schedule-limit")
 	re.NoError(err)
 	re.Equal([]byte("16"), resp.Kvs[0].Value)
 	// delete the config with TTL.
-	err = client.SetConfig(env.ctx, newConfig, 0)
+	err = client.SetConfig(ctx, newConfig, 0)
 	re.NoError(err)
-	resp, err = env.cluster.GetEtcdClient().Get(env.ctx, sc.TTLConfigPrefix+"/schedule.leader-schedule-limit")
+	resp, err = suite.cluster.GetEtcdClient().Get(ctx, sc.TTLConfigPrefix+"/schedule.leader-schedule-limit")
 	re.NoError(err)
 	re.Empty(resp.Kvs)
 
@@ -528,81 +494,72 @@ func (suite *httpClientTestSuite) checkConfig(mode mode, client pd.Client) {
 	newConfig = map[string]any{
 		"schedule.max-pending-peer-count": uint64(math.MaxInt32),
 	}
-	err = client.SetConfig(env.ctx, newConfig, 4)
+	err = client.SetConfig(ctx, newConfig, 4)
 	re.NoError(err)
-	c := env.cluster.GetLeaderServer().GetRaftCluster().GetOpts().GetMaxPendingPeerCount()
+	c := suite.cluster.GetLeaderServer().GetRaftCluster().GetOpts().GetMaxPendingPeerCount()
 	re.Equal(uint64(math.MaxInt32), c)
-	err = client.SetConfig(env.ctx, newConfig, 0)
+	err = client.SetConfig(ctx, newConfig, 0)
 	re.NoError(err)
-	resp, err = env.cluster.GetEtcdClient().Get(env.ctx, sc.TTLConfigPrefix+"/schedule.max-pending-peer-count")
+	resp, err = suite.cluster.GetEtcdClient().Get(ctx, sc.TTLConfigPrefix+"/schedule.max-pending-peer-count")
 	re.NoError(err)
 	re.Empty(resp.Kvs)
 }
 
 func (suite *httpClientTestSuite) TestScheduleConfig() {
-	suite.RunTestInTwoModes(suite.checkScheduleConfig)
-}
-
-func (suite *httpClientTestSuite) checkScheduleConfig(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	config, err := client.GetScheduleConfig(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	config, err := client.GetScheduleConfig(ctx)
 	re.NoError(err)
 	re.Equal(float64(4), config["hot-region-schedule-limit"])
 	re.Equal(float64(2048), config["region-schedule-limit"])
 	config["hot-region-schedule-limit"] = float64(8)
-	err = client.SetScheduleConfig(env.ctx, config)
+	err = client.SetScheduleConfig(ctx, config)
 	re.NoError(err)
-	config, err = client.GetScheduleConfig(env.ctx)
+	config, err = client.GetScheduleConfig(ctx)
 	re.NoError(err)
 	re.Equal(float64(8), config["hot-region-schedule-limit"])
 	re.Equal(float64(2048), config["region-schedule-limit"])
 }
 
 func (suite *httpClientTestSuite) TestSchedulers() {
-	suite.RunTestInTwoModes(suite.checkSchedulers)
-}
-
-func (suite *httpClientTestSuite) checkSchedulers(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	schedulers, err := client.GetSchedulers(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	schedulers, err := client.GetSchedulers(ctx)
 	re.NoError(err)
 	re.Empty(schedulers)
-	err = client.CreateScheduler(env.ctx, "evict-leader-scheduler", 1)
+	err = client.CreateScheduler(ctx, "evict-leader-scheduler", 1)
 	re.NoError(err)
-	schedulers, err = client.GetSchedulers(env.ctx)
+	schedulers, err = client.GetSchedulers(ctx)
 	re.NoError(err)
 	re.Len(schedulers, 1)
-	err = client.SetSchedulerDelay(env.ctx, "evict-leader-scheduler", 100)
+	err = client.SetSchedulerDelay(ctx, "evict-leader-scheduler", 100)
 	re.NoError(err)
-	err = client.SetSchedulerDelay(env.ctx, "not-exist", 100)
+	err = client.SetSchedulerDelay(ctx, "not-exist", 100)
 	re.ErrorContains(err, "500 Internal Server Error") // TODO: should return friendly error message
 }
 
 func (suite *httpClientTestSuite) TestSetStoreLabels() {
-	suite.RunTestInTwoModes(suite.checkSetStoreLabels)
-}
-
-func (suite *httpClientTestSuite) checkSetStoreLabels(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	resp, err := client.GetStores(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	resp, err := client.GetStores(ctx)
 	re.NoError(err)
 	setStore := resp.Stores[0]
 	re.Empty(setStore.Store.Labels, nil)
 	storeLabels := map[string]string{
 		"zone": "zone1",
 	}
-	err = client.SetStoreLabels(env.ctx, 1, storeLabels)
+	err = client.SetStoreLabels(ctx, 1, storeLabels)
 	re.NoError(err)
-	resp, err = client.GetStores(env.ctx)
+	resp, err = client.GetStores(ctx)
 	re.NoError(err)
 	for _, store := range resp.Stores {
 		if store.Store.ID == setStore.Store.ID {
@@ -614,67 +571,52 @@ func (suite *httpClientTestSuite) checkSetStoreLabels(mode mode, client pd.Clien
 }
 
 func (suite *httpClientTestSuite) TestTransferLeader() {
-	suite.RunTestInTwoModes(suite.checkTransferLeader)
-}
-
-func (suite *httpClientTestSuite) checkTransferLeader(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	members, err := client.GetMembers(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	members, err := client.GetMembers(ctx)
 	re.NoError(err)
 	re.Len(members.Members, 2)
-	leader, err := client.GetLeader(env.ctx)
+	leader, err := client.GetLeader(ctx)
 	re.NoError(err)
 	// Transfer leader to another pd
 	for _, member := range members.Members {
 		if member.GetName() != leader.GetName() {
-			err = client.TransferLeader(env.ctx, member.GetName())
+			err = client.TransferLeader(ctx, member.GetName())
 			re.NoError(err)
 			break
 		}
 	}
-	newLeader := env.cluster.WaitLeader()
+	newLeader := suite.cluster.WaitLeader()
 	re.NotEmpty(newLeader)
 	re.NoError(err)
 	re.NotEqual(leader.GetName(), newLeader)
 	// Force to update the members info.
 	testutil.Eventually(re, func() bool {
-		leader, err = client.GetLeader(env.ctx)
+		leader, err = client.GetLeader(ctx)
 		re.NoError(err)
 		return newLeader == leader.GetName()
 	})
-	members, err = client.GetMembers(env.ctx)
+	members, err = client.GetMembers(ctx)
 	re.NoError(err)
 	re.Len(members.Members, 2)
 	re.Equal(leader.GetName(), members.Leader.GetName())
 }
 
 func (suite *httpClientTestSuite) TestVersion() {
-	suite.RunTestInTwoModes(suite.checkVersion)
-}
-
-func (suite *httpClientTestSuite) checkVersion(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	ver, err := client.GetPDVersion(env.ctx)
+	ver, err := suite.client.GetPDVersion(suite.ctx)
 	re.NoError(err)
 	re.Equal(versioninfo.PDReleaseVersion, ver)
 }
 
 func (suite *httpClientTestSuite) TestStatus() {
-	suite.RunTestInTwoModes(suite.checkStatus)
-}
-
-func (suite *httpClientTestSuite) checkStatus(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	status, err := client.GetStatus(env.ctx)
+	status, err := suite.client.GetStatus(suite.ctx)
 	re.NoError(err)
 	re.Equal(versioninfo.PDReleaseVersion, status.Version)
 	re.Equal(versioninfo.PDGitHash, status.GitHash)
@@ -683,48 +625,41 @@ func (suite *httpClientTestSuite) checkStatus(mode mode, client pd.Client) {
 }
 
 func (suite *httpClientTestSuite) TestAdmin() {
-	suite.RunTestInTwoModes(suite.checkAdmin)
-}
-
-func (suite *httpClientTestSuite) checkAdmin(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	err := client.SetSnapshotRecoveringMark(env.ctx)
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
+	err := client.SetSnapshotRecoveringMark(ctx)
 	re.NoError(err)
-	err = client.ResetTS(env.ctx, 123, true)
+	err = client.ResetTS(ctx, 123, true)
 	re.NoError(err)
-	err = client.ResetBaseAllocID(env.ctx, 456)
+	err = client.ResetBaseAllocID(ctx, 456)
 	re.NoError(err)
-	err = client.DeleteSnapshotRecoveringMark(env.ctx)
+	err = client.DeleteSnapshotRecoveringMark(ctx)
 	re.NoError(err)
 }
 
 func (suite *httpClientTestSuite) TestWithBackoffer() {
-	suite.RunTestInTwoModes(suite.checkWithBackoffer)
-}
-
-func (suite *httpClientTestSuite) checkWithBackoffer(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
 	// Should return with 404 error without backoffer.
-	rule, err := client.GetPlacementRule(env.ctx, "non-exist-group", "non-exist-rule")
+	rule, err := client.GetPlacementRule(ctx, "non-exist-group", "non-exist-rule")
 	re.ErrorContains(err, http.StatusText(http.StatusNotFound))
 	re.Nil(rule)
 	// Should return with 404 error even with an infinite backoffer.
 	rule, err = client.
 		WithBackoffer(retry.InitialBackoffer(100*time.Millisecond, time.Second, 0)).
-		GetPlacementRule(env.ctx, "non-exist-group", "non-exist-rule")
+		GetPlacementRule(ctx, "non-exist-group", "non-exist-rule")
 	re.ErrorContains(err, http.StatusText(http.StatusNotFound))
 	re.Nil(rule)
 }
 
 func (suite *httpClientTestSuite) TestRedirectWithMetrics() {
 	re := suite.Require()
-	env := suite.env[defaultServiceDiscovery]
-	cli := setupCli(env.ctx, suite.Require(), env.endpoints)
+	cli := setupCli(suite.ctx, re, suite.endpoints)
 	defer cli.Close()
 
 	sd := cli.GetServiceDiscovery()
@@ -785,12 +720,10 @@ func (suite *httpClientTestSuite) TestRedirectWithMetrics() {
 }
 
 func (suite *httpClientTestSuite) TestUpdateKeyspaceGCManagementType() {
-	suite.RunTestInTwoModes(suite.checkUpdateKeyspaceGCManagementType)
-}
-
-func (suite *httpClientTestSuite) checkUpdateKeyspaceGCManagementType(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
+	client := suite.client
+	ctx, cancel := context.WithCancel(suite.ctx)
+	defer cancel()
 
 	keyspaceName := "DEFAULT"
 	expectGCManagementType := "keyspace_level_gc"
@@ -800,10 +733,10 @@ func (suite *httpClientTestSuite) checkUpdateKeyspaceGCManagementType(mode mode,
 			GCManagementType: expectGCManagementType,
 		},
 	}
-	err := client.UpdateKeyspaceGCManagementType(env.ctx, keyspaceName, &keyspaceSafePointVersionConfig)
+	err := client.UpdateKeyspaceGCManagementType(ctx, keyspaceName, &keyspaceSafePointVersionConfig)
 	re.NoError(err)
 
-	keyspaceMetaRes, err := client.GetKeyspaceMetaByName(env.ctx, keyspaceName)
+	keyspaceMetaRes, err := client.GetKeyspaceMetaByName(ctx, keyspaceName)
 	re.NoError(err)
 
 	val, ok := keyspaceMetaRes.Config["gc_management_type"]
@@ -813,14 +746,8 @@ func (suite *httpClientTestSuite) checkUpdateKeyspaceGCManagementType(mode mode,
 }
 
 func (suite *httpClientTestSuite) TestGetHealthStatus() {
-	suite.RunTestInTwoModes(suite.checkGetHealthStatus)
-}
-
-func (suite *httpClientTestSuite) checkGetHealthStatus(mode mode, client pd.Client) {
 	re := suite.Require()
-	env := suite.env[mode]
-
-	healths, err := client.GetHealthStatus(env.ctx)
+	healths, err := suite.client.GetHealthStatus(suite.ctx)
 	re.NoError(err)
 	re.Len(healths, 2)
 	sort.Slice(healths, func(i, j int) bool {