diff --git a/bor/keeper.go b/bor/keeper.go index 8b7613e87..7e8269ccc 100644 --- a/bor/keeper.go +++ b/bor/keeper.go @@ -229,8 +229,14 @@ func (k *Keeper) SelectNextProducers(ctx sdk.Context, seed common.Hash) (vals [] return spanEligibleVals, nil } + // TODO remove old selection algorigthm // select next producers using seed as blockheader hash - newProducersIds, err := SelectNextProducers(seed, spanEligibleVals, producerCount) + fn := SelectNextProducers + if ctx.BlockHeight() < 375300 { + fn = XXXSelectNextProducers + } + + newProducersIds, err := fn(seed, spanEligibleVals, producerCount) if err != nil { return vals, err } diff --git a/bor/selection.go b/bor/selection.go index 4758a03cd..f0d0217ee 100644 --- a/bor/selection.go +++ b/bor/selection.go @@ -1,14 +1,18 @@ package bor import ( + "encoding/binary" + "math" + "math/rand" + "github.com/maticnetwork/bor/common" "github.com/maticnetwork/heimdall/bor/types" "github.com/maticnetwork/heimdall/helper" hmTypes "github.com/maticnetwork/heimdall/types" ) -// SelectNextProducers selects producers for next span by converting power to tickets -func SelectNextProducers(blkHash common.Hash, spanEligibleVals []hmTypes.Validator, producerCount uint64) (selectedIDs []uint64, err error) { +// XXXSelectNextProducers selects producers for next span by converting power to tickets +func XXXSelectNextProducers(blkHash common.Hash, spanEligibleVals []hmTypes.Validator, producerCount uint64) (selectedIDs []uint64, err error) { if len(spanEligibleVals) <= int(producerCount) { for _, val := range spanEligibleVals { selectedIDs = append(selectedIDs, uint64(val.ID)) @@ -37,3 +41,93 @@ func convertToSlots(vals []hmTypes.Validator) (validatorIndices []uint64) { } return validatorIndices } + +// +// New selection algorithm +// + +// SelectNextProducers selects producers for next span by converting power to tickets +func SelectNextProducers(blkHash common.Hash, spanEligibleValidators []hmTypes.Validator, producerCount uint64) ([]uint64, error) { + selectedProducers := make([]uint64, 0) + + if len(spanEligibleValidators) <= int(producerCount) { + for _, validator := range spanEligibleValidators { + selectedProducers = append(selectedProducers, uint64(validator.ID)) + } + + return selectedProducers, nil + } + + // extract seed from hash + seedBytes := helper.ToBytes32(blkHash.Bytes()[:32]) + seed := int64(binary.BigEndian.Uint64(seedBytes[:])) + rand.Seed(seed) + + // weighted range from validators' voting power + votingPower := make([]uint64, len(spanEligibleValidators)) + for idx, validator := range spanEligibleValidators { + votingPower[idx] = uint64(validator.VotingPower) + } + + weightedRanges, totalVotingPower := createWeightedRanges(votingPower) + // select producers, with replacement + for i := uint64(0); i < producerCount; i++ { + /* + random must be in [1, totalVotingPower] to avoid situation such as + 2 validators with 1 staking power each. + Weighted range will look like (1, 2) + Rolling inclusive will have a range of 0 - 2, making validator with staking power 1 chance of selection = 66% + */ + targetWeight := randomRangeInclusive(1, totalVotingPower) + index := binarySearch(weightedRanges, targetWeight) + selectedProducers = append(selectedProducers, spanEligibleValidators[index].ID.Uint64()) + } + + return selectedProducers[:producerCount], nil +} + +func binarySearch(array []uint64, search uint64) int { + if len(array) == 0 { + return -1 + } + l := 0 + r := len(array) - 1 + for l < r { + mid := (l + r) / 2 + if array[mid] >= search { + r = mid + } else { + l = mid + 1 + } + } + return l +} + +// randomRangeInclusive produces unbiased pseudo random in the range [min, max]. Uses rand.Uint64() and can be seeded beforehand. +func randomRangeInclusive(min uint64, max uint64) uint64 { + if max <= min { + return max + } + + rangeLength := max - min + 1 + maxAllowedValue := math.MaxUint64 - math.MaxUint64%rangeLength - 1 + randomValue := rand.Uint64() + + // reject anything that is beyond the reminder to avoid bias + for randomValue >= maxAllowedValue { + randomValue = rand.Uint64() + } + + return min + randomValue%rangeLength +} + +// createWeightedRanges converts array [1, 2, 3] into cumulative form [1, 3, 6] +func createWeightedRanges(weights []uint64) ([]uint64, uint64) { + weightedRanges := make([]uint64, len(weights)) + totalWeight := uint64(0) + for i := 0; i < len(weightedRanges); i++ { + totalWeight += weights[i] + weightedRanges[i] = totalWeight + } + return weightedRanges, totalWeight +} diff --git a/bor/selection_test.go b/bor/selection_test.go index a263673f0..c0deba8e2 100644 --- a/bor/selection_test.go +++ b/bor/selection_test.go @@ -1,77 +1,21 @@ package bor import ( + "encoding/binary" "encoding/json" "fmt" + "math/big" + "reflect" + "strconv" "testing" "github.com/maticnetwork/bor/common" + "github.com/maticnetwork/bor/crypto" + "github.com/maticnetwork/heimdall/types" hmTypes "github.com/maticnetwork/heimdall/types" "github.com/stretchr/testify/require" ) -type producerSelectionTestCase struct { - seed string - producerCount uint64 - resultSlots int64 - resultProducers int64 -} - -func TestSelectNextProducers(t *testing.T) { - testcases := []producerSelectionTestCase{ - producerSelectionTestCase{"0x8f5bab218b6bb34476f51ca588e9f4553a3a7ce5e13a66c660a5283e97e9a85a", 10, 5, 5}, - producerSelectionTestCase{"0x8f5bab218b6bb34476f51ca588e9f4553a3a7ce5e13a66c660a5283e97e9a85a", 5, 5, 5}, - producerSelectionTestCase{"0xe09cc356df20c7a2dd38cb85b680a16ec29bd8b3e1ecc1b20f2e5603d5e7ee85", 10, 5, 5}, - producerSelectionTestCase{"0xe09cc356df20c7a2dd38cb85b680a16ec29bd8b3e1ecc1b20f2e5603d5e7ee85", 5, 5, 5}, - - producerSelectionTestCase{"0x8f5bab218b6bb34476f51ca588e9f4553a3a7ce5e13a66c660a5283e97e9a85a", 4, 4, 3}, - producerSelectionTestCase{"0xe09cc356df20c7a2dd38cb85b680a16ec29bd8b3e1ecc1b20f2e5603d5e7ee85", 4, 4, 1}, - } - - var validators []hmTypes.Validator - json.Unmarshal([]byte(testValidators), &validators) - require.Equal(t, 5, len(validators), "Total validators should be 5") - - for i, testcase := range testcases { - seed := common.HexToHash(testcase.seed) - producerIds, err := SelectNextProducers(seed, validators, testcase.producerCount) - fmt.Println("producerIds", producerIds) - require.NoError(t, err, "Error should be nil") - producers, slots := getSelectedValidtorsFromIDs(validators, producerIds) - require.Equal(t, testcase.resultSlots, slots, "Total slots should be %v (Testcase %v)", testcase.resultSlots, i+1) - require.Equal(t, int(testcase.resultProducers), len(producers), "Total producers should be %v (Testcase %v)", testcase.resultProducers, i+1) - } -} - -func getSelectedValidtorsFromIDs(validators []hmTypes.Validator, producerIds []uint64) ([]hmTypes.Validator, int64) { - var vals []hmTypes.Validator - IDToPower := make(map[uint64]uint64) - for _, ID := range producerIds { - IDToPower[ID] = IDToPower[ID] + 1 - } - - var slots int64 - for key, value := range IDToPower { - if val, ok := findValidatorByID(validators, key); ok { - val.VotingPower = int64(value) - vals = append(vals, val) - slots = slots + int64(value) - } - } - - return vals, slots -} - -func findValidatorByID(validators []hmTypes.Validator, id uint64) (val hmTypes.Validator, ok bool) { - for _, v := range validators { - if v.ID.Uint64() == id { - return v, true - } - } - - return -} - const testValidators = `[ { "ID": 3, @@ -124,3 +68,226 @@ const testValidators = `[ "accum": 10000 } ]` + +func TestSelectNextProducers(t *testing.T) { + type producerSelectionTestCase struct { + seed string + producerCount uint64 + resultSlots int64 + resultProducers int64 + } + + testcases := []producerSelectionTestCase{ + producerSelectionTestCase{"0x8f5bab218b6bb34476f51ca588e9f4553a3a7ce5e13a66c660a5283e97e9a85a", 10, 5, 5}, + producerSelectionTestCase{"0x8f5bab218b6bb34476f51ca588e9f4553a3a7ce5e13a66c660a5283e97e9a85a", 5, 5, 5}, + producerSelectionTestCase{"0xe09cc356df20c7a2dd38cb85b680a16ec29bd8b3e1ecc1b20f2e5603d5e7ee85", 10, 5, 5}, + producerSelectionTestCase{"0xe09cc356df20c7a2dd38cb85b680a16ec29bd8b3e1ecc1b20f2e5603d5e7ee85", 5, 5, 5}, + producerSelectionTestCase{"0x8f5bab218b6bb34476f51ca588e9f4553a3a7ce5e13a66c660a5283e97e9a85a", 4, 4, 3}, + producerSelectionTestCase{"0xe09cc356df20c7a2dd38cb85b680a16ec29bd8b3e1ecc1b20f2e5603d5e7ee85", 4, 4, 4}, + } + + var validators []hmTypes.Validator + json.Unmarshal([]byte(testValidators), &validators) + require.Equal(t, 5, len(validators), "Total validators should be 5") + + for i, testcase := range testcases { + seed := common.HexToHash(testcase.seed) + producerIds, err := SelectNextProducers(seed, validators, testcase.producerCount) + fmt.Println("producerIds", producerIds) + require.NoError(t, err, "Error should be nil") + producers, slots := getSelectedValidatorsFromIDs(validators, producerIds) + require.Equal(t, testcase.resultSlots, slots, "Total slots should be %v (Testcase %v)", testcase.resultSlots, i+1) + require.Equal(t, int(testcase.resultProducers), len(producers), "Total producers should be %v (Testcase %v)", testcase.resultProducers, i+1) + } +} + +func getSelectedValidatorsFromIDs(validators []hmTypes.Validator, producerIds []uint64) ([]hmTypes.Validator, int64) { + var vals []hmTypes.Validator + IDToPower := make(map[uint64]uint64) + for _, ID := range producerIds { + IDToPower[ID] = IDToPower[ID] + 1 + } + + var slots int64 + for key, value := range IDToPower { + if val, ok := findValidatorByID(validators, key); ok { + val.VotingPower = int64(value) + vals = append(vals, val) + slots = slots + int64(value) + } + } + + return vals, slots +} + +func findValidatorByID(validators []hmTypes.Validator, id uint64) (val hmTypes.Validator, ok bool) { + for _, v := range validators { + if v.ID.Uint64() == id { + return v, true + } + } + + return +} + +func Test_createWeightedRanges(t *testing.T) { + type args struct { + vals []uint64 + } + tests := []struct { + name string + args args + ranges []uint64 + totalWeight uint64 + }{ + { + args: args{ + vals: []uint64{30, 20, 50, 50, 1}, + }, + ranges: []uint64{30, 50, 100, 150, 151}, + totalWeight: 151, + }, + { + args: args{ + vals: []uint64{1, 2, 1, 2, 1}, + }, + ranges: []uint64{1, 3, 4, 6, 7}, + totalWeight: 7, + }, + { + args: args{ + vals: []uint64{10, 1, 20, 1, 2}, + }, + ranges: []uint64{10, 11, 31, 32, 34}, + totalWeight: 34, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ranges, totalWeight := createWeightedRanges(tt.args.vals) + if !reflect.DeepEqual(ranges, tt.ranges) { + t.Errorf("createWeightedRange() got ranges = %v, want %v", ranges, tt.ranges) + } + if totalWeight != tt.totalWeight { + t.Errorf("createWeightedRange() got totalWeight = %v, want %v", totalWeight, tt.totalWeight) + } + }) + } +} + +func SimulateSelectionDistributionCorrectness() { + var validators []hmTypes.Validator + + validators = append(validators, hmTypes.Validator{ID: 1, VotingPower: 10}) + validators = append(validators, hmTypes.Validator{ID: 2, VotingPower: 10}) + validators = append(validators, hmTypes.Validator{ID: 3, VotingPower: 100}) + validators = append(validators, hmTypes.Validator{ID: 4, VotingPower: 100}) + validators = append(validators, hmTypes.Validator{ID: 5, VotingPower: 1000}) + validators = append(validators, hmTypes.Validator{ID: 6, VotingPower: 1000}) + validators = append(validators, hmTypes.Validator{ID: 7, VotingPower: 10000}) + validators = append(validators, hmTypes.Validator{ID: 8, VotingPower: 10000}) + validators = append(validators, hmTypes.Validator{ID: 9, VotingPower: 100000}) + validators = append(validators, hmTypes.Validator{ID: 10, VotingPower: 100000}) + validators = append(validators, hmTypes.Validator{ID: 11, VotingPower: 1000000}) + validators = append(validators, hmTypes.Validator{ID: 12, VotingPower: 1000000}) + + perfectProbabilities := make(map[types.ValidatorID]*big.Float) + totalPower := int64(0) + for _, validator := range validators { + totalPower += validator.VotingPower + } + + fmt.Printf("totalPower = %d\n", totalPower) + + totalPowerStr := strconv.FormatUint(uint64(totalPower), 10) + totalPowerF, _ := new(big.Float).SetString(totalPowerStr) + votingPowerF := new(big.Float) + for _, validator := range validators { + votingPowerF, _ := votingPowerF.SetString(strconv.FormatUint(uint64(validator.VotingPower), 10)) + perfectProbabilities[validator.ID] = new(big.Float).Quo(votingPowerF, totalPowerF) + } + + producerSlots := uint64(7) + iterations := uint64(10000000) + i := uint64(0) + buffer := make([]byte, 8) + selectedTimes := make(map[types.ValidatorID]uint64) + + for i < iterations { + i++ + binary.BigEndian.PutUint64(buffer, i) + keccak := crypto.Keccak256(buffer) + var hash common.Hash + copy(hash[:], keccak) + producerIds, _ := SelectNextProducers(hash, validators, producerSlots) + + for _, id := range producerIds { + selectedTimes[types.ValidatorID(id)]++ + } + } + + totalProducers, _ := new(big.Float).SetString(strconv.FormatUint(iterations*producerSlots, 10)) + fmt.Printf("Total producers selected = %d\n", iterations*producerSlots) + for _, validator := range validators { + wasSelected, _ := new(big.Float).SetString(strconv.FormatUint(selectedTimes[validator.ID], 10)) + prob := new(big.Float).Quo(wasSelected, totalProducers) + fmt.Printf("validator { ID = %d, Power = %d, Perfect Probability = %v%% } was selected %d times with %v%% probability\n", + validator.ID, validator.VotingPower, perfectProbabilities[validator.ID], selectedTimes[validator.ID], prob) + } +} + +func Test_binarySearch(t *testing.T) { + type args struct { + array []uint64 + search uint64 + } + + tests := []struct { + name string + args args + want int + }{ + { + args: args{ + array: []uint64{}, + search: 0, + }, + want: -1, + }, + { + args: args{ + array: []uint64{1}, + search: 100, + }, + want: 0, + }, + { + args: args{ + array: []uint64{1, 1000}, + search: 100, + }, + want: 1, + }, + { + args: args{ + array: []uint64{1, 100, 1000}, + search: 2, + }, + want: 1, + }, + { + args: args{ + array: []uint64{1, 100, 1000, 1000}, + search: 1001, + }, + want: 3, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := binarySearch(tt.args.array, tt.args.search); got != tt.want { + t.Errorf("binarySearch() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/bor/shuffle_test.go b/bor/shuffle_test.go index 36b2458d0..fdd07ffcc 100644 --- a/bor/shuffle_test.go +++ b/bor/shuffle_test.go @@ -47,7 +47,7 @@ func TestShuffleList(t *testing.T) { func TestValShuffle(t *testing.T) { seedHash1 := common.HexToHash("0xc46afc66ad9f4b237414c23a0cf0c469aeb60f52176565990644a9ee36a17667") initialVals := GenRandomVal(50, 0, 100, uint64(10), true, 1) - selectedProducerIndices, err := SelectNextProducers(seedHash1, initialVals, 40) + selectedProducerIndices, err := XXXSelectNextProducers(seedHash1, initialVals, 40) IDToPower := make(map[uint64]int64) for _, ID := range selectedProducerIndices { IDToPower[ID] = IDToPower[ID] + 1