From 08f48b4f2878220dd93ae7f52543be3f86a28b1e Mon Sep 17 00:00:00 2001 From: marcello33 Date: Tue, 8 Oct 2024 17:41:49 +0200 Subject: [PATCH] POS-2700: fix: allow one msg per tx --- types/errors/errors.go | 3 +++ x/auth/ante/ante_test.go | 45 ++++++++++++++++++++------------- x/auth/ante/basic.go | 27 ++++++++++++++++++++ x/auth/ante/setup.go | 9 +++++++ x/auth/ante/sigverify.go | 36 ++++++++++++++++++++++++++ x/auth/ante/sigverify_test.go | 24 ++++++++---------- x/auth/ante/validator_tx_fee.go | 9 +++++++ 7 files changed, 123 insertions(+), 30 deletions(-) diff --git a/types/errors/errors.go b/types/errors/errors.go index 189946326b91..9452a8af49dd 100644 --- a/types/errors/errors.go +++ b/types/errors/errors.go @@ -151,4 +151,7 @@ var ( // ErrRemoveBlockProposer defines an error when removing the block proposer fails. ErrRemoveBlockProposer = errorsmod.Register(RootCodespace, 44, "error removing block proposer") + + // ErrTooManyMsgsInTx defines an error when number of messages is not correct. + ErrTooManyMsgsInTx = errorsmod.Register(RootCodespace, 45, "too many messages in tx") ) diff --git a/x/auth/ante/ante_test.go b/x/auth/ante/ante_test.go index 782767d9708f..2bf2deed1c05 100644 --- a/x/auth/ante/ante_test.go +++ b/x/auth/ante/ante_test.go @@ -241,12 +241,11 @@ func TestAnteHandlerAccountNumbers(t *testing.T) { false, }, { - "new tx with another signer and incorrect account numbers", + "new tx with more than 1 message", func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(2) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress(), accs[1].acc.GetAddress()) msg2 := testdata.NewTestMsg(accs[1].acc.GetAddress(), accs[0].acc.GetAddress()) - suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil) return TestCaseArgs{ accNums: []uint64{2, 0}, @@ -257,6 +256,25 @@ func TestAnteHandlerAccountNumbers(t *testing.T) { }, false, false, + sdkerrors.ErrTooManyMsgsInTx, + false, + }, + { + "new tx with another signer and incorrect account numbers", + func(suite *AnteTestSuite) TestCaseArgs { + accs := suite.CreateTestAccounts(2) + msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress(), accs[1].acc.GetAddress()) + suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil) + + return TestCaseArgs{ + accNums: []uint64{2, 0}, + accSeqs: []uint64{0, 0}, + msgs: []sdk.Msg{msg1}, + privs: []cryptotypes.PrivKey{accs[0].priv, accs[1].priv}, + } + }, + false, + false, sdkerrors.ErrTooManySignatures, false, }, @@ -265,11 +283,10 @@ func TestAnteHandlerAccountNumbers(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return TestCaseArgs{ - msgs: []sdk.Msg{msg1, msg2}, + msgs: []sdk.Msg{msg1}, }.WithAccountsInfo(accs) }, false, @@ -338,14 +355,13 @@ func TestAnteHandlerAccountNumbersAtBlockHeightZero(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil) return TestCaseArgs{ accNums: []uint64{1, 0}, // wrong account numbers accSeqs: []uint64{0, 0}, - msgs: []sdk.Msg{msg1, msg2}, + msgs: []sdk.Msg{msg1}, privs: []cryptotypes.PrivKey{accs[0].priv}, } }, @@ -359,14 +375,13 @@ func TestAnteHandlerAccountNumbersAtBlockHeightZero(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil) return TestCaseArgs{ accNums: []uint64{0}, // correct account numbers accSeqs: []uint64{0}, - msgs: []sdk.Msg{msg1, msg2}, + msgs: []sdk.Msg{msg1}, privs: []cryptotypes.PrivKey{accs[0].priv}, } }, @@ -472,12 +487,11 @@ func TestAnteHandlerSequences(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) return TestCaseArgs{ - msgs: []sdk.Msg{msg1, msg2}, + msgs: []sdk.Msg{msg1}, }.WithAccountsInfo(accs) }, false, @@ -490,8 +504,7 @@ func TestAnteHandlerSequences(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msgs := []sdk.Msg{msg1, msg2} + msgs := []sdk.Msg{msg1} privs := []cryptotypes.PrivKey{accs[0].priv} accNums := []uint64{accs[0].acc.GetAccountNumber()} @@ -504,7 +517,7 @@ func TestAnteHandlerSequences(t *testing.T) { require.NoError(t, err) return TestCaseArgs{ - msgs: []sdk.Msg{msg1, msg2}, + msgs: []sdk.Msg{msg1}, }.WithAccountsInfo(accs) }, false, @@ -517,8 +530,7 @@ func TestAnteHandlerSequences(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msgs := []sdk.Msg{msg1, msg2} + msgs := []sdk.Msg{msg1} privs := []cryptotypes.PrivKey{accs[0].priv} accNums := []uint64{accs[0].acc.GetAccountNumber()} @@ -552,8 +564,7 @@ func TestAnteHandlerSequences(t *testing.T) { func(suite *AnteTestSuite) TestCaseArgs { accs := suite.CreateTestAccounts(1) msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress()) - msgs := []sdk.Msg{msg1, msg2} + msgs := []sdk.Msg{msg1} privs := []cryptotypes.PrivKey{accs[0].priv} accNums := []uint64{accs[0].acc.GetAccountNumber()} diff --git a/x/auth/ante/basic.go b/x/auth/ante/basic.go index 94c88bd41d27..6dad8ea1b3d8 100644 --- a/x/auth/ante/basic.go +++ b/x/auth/ante/basic.go @@ -25,6 +25,15 @@ func NewValidateBasicDecorator() ValidateBasicDecorator { } func (vbd ValidateBasicDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) { + msgV2, err := tx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(tx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + // no need to validate basic on recheck tx, call next antehandler if ctx.IsReCheckTx() { return next(ctx, tx, simulate) @@ -58,6 +67,15 @@ func (vmd ValidateMemoDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } + msgV2, err := memoTx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(memoTx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + memoLength := len(memoTx.GetMemo()) if memoLength > 0 { params := vmd.ak.GetParams(ctx) @@ -98,6 +116,15 @@ func (cgts ConsumeTxSizeGasDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim } params := cgts.ak.GetParams(ctx) + msgV2, err := sigTx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + // HV2: removed `ConsumeGas` in it as done in heimdall's `auth/ante.go` (original ancestor's method `newCtx.GasMeter().ConsumeGas`) // ctx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*storetypes.Gas(len(ctx.TxBytes())), "txSize") diff --git a/x/auth/ante/setup.go b/x/auth/ante/setup.go index 35ba2dc36308..86ea1bcb343a 100644 --- a/x/auth/ante/setup.go +++ b/x/auth/ante/setup.go @@ -37,6 +37,15 @@ func (sud SetUpContextDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate return newCtx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "Tx must be GasTx") } + msgV2, err := gasTx.GetMsgsV2() + if err != nil { + return newCtx, err + } + + if len(msgV2) > 1 || len(gasTx.GetMsgs()) > 1 { + return newCtx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + newCtx = SetGasMeter(simulate, ctx, gasTx.GetGas()) if cp := ctx.ConsensusParams(); cp.Block != nil { diff --git a/x/auth/ante/sigverify.go b/x/auth/ante/sigverify.go index c1b70deeb9a9..e1bf2437aef6 100644 --- a/x/auth/ante/sigverify.go +++ b/x/auth/ante/sigverify.go @@ -64,6 +64,15 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid tx type") } + msgV2, err := sigTx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + pubkeys, err := sigTx.GetPubKeys() if err != nil { return ctx, err @@ -167,6 +176,15 @@ func (sgcd SigGasConsumeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simula return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } + msgV2, err := sigTx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + params := sgcd.ak.GetParams(ctx) sigs, err := sigTx.GetSignaturesV2() if err != nil { @@ -271,6 +289,15 @@ func (svd SigVerificationDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simul return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } + msgV2, err := sigTx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + // stdSigs contains the sequence number, account number, and signatures. // When simulating, this would just be a 0-length slice. sigs, err := sigTx.GetSignaturesV2() @@ -394,6 +421,15 @@ func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type") } + msgV2, err := sigTx.GetMsgsV2() + if err != nil { + return ctx, err + } + + if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 { + return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + // increment sequence of all signers signers, err := sigTx.GetSigners() if err != nil { diff --git a/x/auth/ante/sigverify_test.go b/x/auth/ante/sigverify_test.go index 0dc67c0072cb..f05bf6276a32 100644 --- a/x/auth/ante/sigverify_test.go +++ b/x/auth/ante/sigverify_test.go @@ -34,10 +34,9 @@ func TestSetPubKey(t *testing.T) { // keys and addresses priv1, pub1, addr1 := testdata.KeyTestPubAddr() - priv2, pub2, addr2 := testdata.KeyTestPubAddr() - addrs := []sdk.AccAddress{addr1, addr2} - pubs := []cryptotypes.PubKey{pub1, pub2} + addrs := []sdk.AccAddress{addr1} + pubs := []cryptotypes.PubKey{pub1} msgs := make([]sdk.Msg, len(addrs)) // set accounts and create msg for each address @@ -51,7 +50,7 @@ func TestSetPubKey(t *testing.T) { suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount()) suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit()) - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1, priv2}, []uint64{0, 1}, []uint64{0, 0} + privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} tx, err := suite.CreateTestTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), signing.SignMode_SIGN_MODE_DIRECT) require.NoError(t, err) @@ -148,10 +147,9 @@ func TestSigVerification(t *testing.T) { // keys and addresses priv1, _, addr1 := testdata.KeyTestPubAddr() - priv2, _, addr2 := testdata.KeyTestPubAddr() - priv3, _, addr3 := testdata.KeyTestPubAddr() + priv2, _, _ := testdata.KeyTestPubAddr() - addrs := []sdk.AccAddress{addr1, addr2, addr3} + addrs := []sdk.AccAddress{addr1} msgs := make([]sdk.Msg, len(addrs)) accs := make([]sdk.AccountI, len(addrs)) @@ -194,12 +192,12 @@ func TestSigVerification(t *testing.T) { validSigs := false testCases := []testCase{ {"no signers", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, validSigs, false, true}, - {"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber()}, []uint64{0, 0}, validSigs, false, true}, - {"wrong order signers", []cryptotypes.PrivKey{priv3, priv2, priv1}, []uint64{accs[2].GetAccountNumber(), accs[1].GetAccountNumber(), accs[0].GetAccountNumber()}, []uint64{0, 0, 0}, validSigs, false, true}, - {"wrong accnums", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{7, 8, 9}, []uint64{0, 0, 0}, validSigs, false, true}, - {"wrong sequences", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber(), accs[2].GetAccountNumber()}, []uint64{3, 4, 5}, validSigs, false, true}, - {"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber(), accs[2].GetAccountNumber()}, []uint64{0, 0, 0}, validSigs, false, false}, - {"no err on recheck", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 0, 0}, []uint64{0, 0, 0}, !validSigs, true, false}, + {"not enough signers", []cryptotypes.PrivKey{}, []uint64{accs[0].GetAccountNumber()}, []uint64{0}, validSigs, false, true}, + {"wrong order signers", []cryptotypes.PrivKey{priv2}, []uint64{accs[0].GetAccountNumber()}, []uint64{0}, validSigs, false, true}, + {"wrong accnums", []cryptotypes.PrivKey{priv1}, []uint64{7}, []uint64{0}, validSigs, false, true}, + {"wrong sequences", []cryptotypes.PrivKey{priv1}, []uint64{accs[0].GetAccountNumber()}, []uint64{3}, validSigs, false, true}, + {"valid tx", []cryptotypes.PrivKey{priv1}, []uint64{accs[0].GetAccountNumber()}, []uint64{0}, validSigs, false, false}, + {"no err on recheck", []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}, !validSigs, true, false}, } for i, tc := range testCases { diff --git a/x/auth/ante/validator_tx_fee.go b/x/auth/ante/validator_tx_fee.go index 628c1e2fdae5..553a9b9eedf1 100644 --- a/x/auth/ante/validator_tx_fee.go +++ b/x/auth/ante/validator_tx_fee.go @@ -20,6 +20,15 @@ func checkTxFeeWithValidatorMinGasPrices(ctx sdk.Context, tx sdk.Tx, params type return nil, 0, errorsmod.Wrap(sdkerrors.ErrInvalidTxFees, "must provide correct txFees") } + msgV2, err := tx.GetMsgsV2() + if err != nil { + return nil, 0, err + } + + if len(msgV2) > 1 || len(tx.GetMsgs()) > 1 { + return nil, 0, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message") + } + // HV2: gas is retrieved from Params as currently done in heimdall gas := params.GetMaxTxGas() feeCoins := sdk.Coins{sdk.Coin{Denom: types.FeeToken, Amount: amount}}