diff --git a/go/test/endtoend/transaction/twopc/main_test.go b/go/test/endtoend/transaction/twopc/main_test.go index eaf835e678e..e7818e3088d 100644 --- a/go/test/endtoend/transaction/twopc/main_test.go +++ b/go/test/endtoend/transaction/twopc/main_test.go @@ -227,24 +227,21 @@ func getStatement(stmt string) string { } func runVStream(t *testing.T, ctx context.Context, ch chan *binlogdatapb.VEvent, vtgateConn *vtgateconn.VTGateConn) { - vgtid := &binlogdatapb.VGtid{ - ShardGtids: []*binlogdatapb.ShardGtid{ - {Keyspace: keyspaceName, Shard: "-40", Gtid: "current"}, - {Keyspace: keyspaceName, Shard: "40-80", Gtid: "current"}, - {Keyspace: keyspaceName, Shard: "80-", Gtid: "current"}, - }} - filter := &binlogdatapb.Filter{ - Rules: []*binlogdatapb.Rule{{ - Match: "/.*/", - }}, + shards := []string{"-40", "40-80", "80-"} + shardGtids := make([]*binlogdatapb.ShardGtid, 0, len(shards)) + var seen = make(map[string]bool, len(shards)) + var wg sync.WaitGroup + for _, shard := range shards { + shardGtids = append(shardGtids, &binlogdatapb.ShardGtid{Keyspace: keyspaceName, Shard: shard, Gtid: "current"}) + seen[shard] = false + wg.Add(1) } + vgtid := &binlogdatapb.VGtid{ShardGtids: shardGtids} + filter := &binlogdatapb.Filter{Rules: []*binlogdatapb.Rule{{Match: "/.*/"}}} + vReader, err := vtgateConn.VStream(ctx, topodatapb.TabletType_PRIMARY, vgtid, filter, nil) require.NoError(t, err) - // Use a channel to signal that the first VGTID event has been processed - firstEventProcessed := make(chan struct{}) - var once sync.Once - go func() { for { evs, err := vReader.Recv() @@ -254,9 +251,12 @@ func runVStream(t *testing.T, ctx context.Context, ch chan *binlogdatapb.VEvent, require.NoError(t, err) for _, ev := range evs { - // Signal the first event has been processed using sync.Once + // Mark VGTID event from each shard seen. if ev.Type == binlogdatapb.VEventType_VGTID { - once.Do(func() { close(firstEventProcessed) }) + if !seen[ev.Shard] { + seen[ev.Shard] = true + wg.Done() + } } if ev.Type == binlogdatapb.VEventType_ROW || ev.Type == binlogdatapb.VEventType_FIELD { ch <- ev @@ -265,8 +265,8 @@ func runVStream(t *testing.T, ctx context.Context, ch chan *binlogdatapb.VEvent, } }() - // Wait for the first event to be processed - <-firstEventProcessed + // Wait for VGTID event from all shards + wg.Wait() } func retrieveTransitions(t *testing.T, ch chan *binlogdatapb.VEvent, tableMap map[string][]*querypb.Field, dtMap map[string]string) map[string][]string {