diff --git a/x/consensus/keeper/estimate.go b/x/consensus/keeper/estimate.go index 15f5219e..a5630223 100644 --- a/x/consensus/keeper/estimate.go +++ b/x/consensus/keeper/estimate.go @@ -100,14 +100,14 @@ func (k Keeper) checkAndProcessEstimatedMessage(ctx context.Context, return fmt.Errorf("failed to set elected gas estimate: %w", err) } - if err := k.checkAndProcessEstimatedSubmitLogicCall(ctx, msg, q, estimate); err != nil { + if err := k.checkAndProcessEstimatedFeePayer(ctx, msg, q, estimate); err != nil { return fmt.Errorf("failed to process estimated submit logic call: %w", err) } return nil } -func (k Keeper) checkAndProcessEstimatedSubmitLogicCall( +func (k Keeper) checkAndProcessEstimatedFeePayer( ctx context.Context, msg types.QueuedSignedMessageI, q consensus.Queuer, @@ -117,9 +117,9 @@ func (k Keeper) checkAndProcessEstimatedSubmitLogicCall( if err != nil { return fmt.Errorf("failed to convert message to evm message: %w", err) } - action, ok := m.Action.(*evmtypes.Message_SubmitLogicCall) + action, ok := m.Action.(evmtypes.FeePayer) if !ok { - // Skip messages that are not SubmitLogicCall + // Skip messages that do not contain fees return nil } @@ -127,11 +127,14 @@ func (k Keeper) checkAndProcessEstimatedSubmitLogicCall( if err != nil { return fmt.Errorf("failed to parse validator address: %w", err) } + fees, err := k.calculateFeesForEstimate(ctx, valAddr, m.GetChainReferenceID(), estimate) if err != nil { return fmt.Errorf("failed to calculate fees for estimate: %w", err) } - action.SubmitLogicCall.Fees = fees + + action.SetFees(fees) + _, err = q.Put(ctx, m, &consensus.PutOptions{ MsgIDToReplace: msg.GetId(), }) diff --git a/x/consensus/keeper/estimate_test.go b/x/consensus/keeper/estimate_test.go index 7dee94d9..502c56ad 100644 --- a/x/consensus/keeper/estimate_test.go +++ b/x/consensus/keeper/estimate_test.go @@ -122,6 +122,32 @@ func Test_CheckAndProcessEstimatedMessages(t *testing.T) { return true }, }, + { + name: "Upload user smart contract message", + msg: &evmtypes.Message{ + TurnstoneID: "abc", + ChainReferenceID: chainReferenceID, + Assignee: validators[0].Address.String(), + Action: &evmtypes.Message_UploadUserSmartContract{ + UploadUserSmartContract: &evmtypes.UploadUserSmartContract{}, + }, + }, + slcCheck: func(m *evmtypes.Message, r *require.Assertions, expected bool) bool { + usc := m.GetUploadUserSmartContract() + if usc == nil { + return expected + } + if !expected { + return usc.Fees == nil + } + + r.NotNil(usc.Fees) + r.Equal(uint64(31500), usc.Fees.RelayerFee, "relayer fee: got %d", usc.Fees.RelayerFee) + r.Equal(uint64(9450), usc.Fees.CommunityFee, "community fee: got %d", usc.Fees.CommunityFee) + r.Equal(uint64(315), usc.Fees.SecurityFee, "security fee: got %d", usc.Fees.SecurityFee) + return true + }, + }, } for _, tc := range tt { diff --git a/x/evm/types/turnstone_message.go b/x/evm/types/turnstone_message.go index 03a57334..3fd69637 100644 --- a/x/evm/types/turnstone_message.go +++ b/x/evm/types/turnstone_message.go @@ -20,3 +20,20 @@ type SmartContractUploader interface { GetAbi() string GetBytecode() []byte } + +type FeePayer interface { + SetFees(fees *Fees) +} + +var ( + _ FeePayer = &Message_SubmitLogicCall{} + _ FeePayer = &Message_UploadUserSmartContract{} +) + +func (m *Message_SubmitLogicCall) SetFees(fees *Fees) { + m.SubmitLogicCall.Fees = fees +} + +func (m *Message_UploadUserSmartContract) SetFees(fees *Fees) { + m.UploadUserSmartContract.Fees = fees +}