Skip to content

Commit

Permalink
POS-2700: fix: allow one msg per tx
Browse files Browse the repository at this point in the history
  • Loading branch information
marcello33 committed Oct 8, 2024
1 parent 09e70f0 commit 08f48b4
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 30 deletions.
3 changes: 3 additions & 0 deletions types/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,7 @@ var (

// ErrRemoveBlockProposer defines an error when removing the block proposer fails.
ErrRemoveBlockProposer = errorsmod.Register(RootCodespace, 44, "error removing block proposer")

// ErrTooManyMsgsInTx defines an error when number of messages is not correct.
ErrTooManyMsgsInTx = errorsmod.Register(RootCodespace, 45, "too many messages in tx")
)
45 changes: 28 additions & 17 deletions x/auth/ante/ante_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,11 @@ func TestAnteHandlerAccountNumbers(t *testing.T) {
false,
},
{
"new tx with another signer and incorrect account numbers",
"new tx with more than 1 message",
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(2)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress(), accs[1].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[1].acc.GetAddress(), accs[0].acc.GetAddress())
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil)

return TestCaseArgs{
accNums: []uint64{2, 0},
Expand All @@ -257,6 +256,25 @@ func TestAnteHandlerAccountNumbers(t *testing.T) {
},
false,
false,
sdkerrors.ErrTooManyMsgsInTx,
false,
},
{
"new tx with another signer and incorrect account numbers",
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(2)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress(), accs[1].acc.GetAddress())
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil)

return TestCaseArgs{
accNums: []uint64{2, 0},
accSeqs: []uint64{0, 0},
msgs: []sdk.Msg{msg1},
privs: []cryptotypes.PrivKey{accs[0].priv, accs[1].priv},
}
},
false,
false,
sdkerrors.ErrTooManySignatures,
false,
},
Expand All @@ -265,11 +283,10 @@ func TestAnteHandlerAccountNumbers(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())
suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)

return TestCaseArgs{
msgs: []sdk.Msg{msg1, msg2},
msgs: []sdk.Msg{msg1},
}.WithAccountsInfo(accs)
},
false,
Expand Down Expand Up @@ -338,14 +355,13 @@ func TestAnteHandlerAccountNumbersAtBlockHeightZero(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())

suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil)

return TestCaseArgs{
accNums: []uint64{1, 0}, // wrong account numbers
accSeqs: []uint64{0, 0},
msgs: []sdk.Msg{msg1, msg2},
msgs: []sdk.Msg{msg1},
privs: []cryptotypes.PrivKey{accs[0].priv},
}
},
Expand All @@ -359,14 +375,13 @@ func TestAnteHandlerAccountNumbersAtBlockHeightZero(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())

suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), accs[0].acc.GetAddress(), gomock.Any(), gomock.Any()).Return(nil)

return TestCaseArgs{
accNums: []uint64{0}, // correct account numbers
accSeqs: []uint64{0},
msgs: []sdk.Msg{msg1, msg2},
msgs: []sdk.Msg{msg1},
privs: []cryptotypes.PrivKey{accs[0].priv},
}
},
Expand Down Expand Up @@ -472,12 +487,11 @@ func TestAnteHandlerSequences(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())

suite.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)

return TestCaseArgs{
msgs: []sdk.Msg{msg1, msg2},
msgs: []sdk.Msg{msg1},
}.WithAccountsInfo(accs)
},
false,
Expand All @@ -490,8 +504,7 @@ func TestAnteHandlerSequences(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msgs := []sdk.Msg{msg1, msg2}
msgs := []sdk.Msg{msg1}

privs := []cryptotypes.PrivKey{accs[0].priv}
accNums := []uint64{accs[0].acc.GetAccountNumber()}
Expand All @@ -504,7 +517,7 @@ func TestAnteHandlerSequences(t *testing.T) {
require.NoError(t, err)

return TestCaseArgs{
msgs: []sdk.Msg{msg1, msg2},
msgs: []sdk.Msg{msg1},
}.WithAccountsInfo(accs)
},
false,
Expand All @@ -517,8 +530,7 @@ func TestAnteHandlerSequences(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msgs := []sdk.Msg{msg1, msg2}
msgs := []sdk.Msg{msg1}

privs := []cryptotypes.PrivKey{accs[0].priv}
accNums := []uint64{accs[0].acc.GetAccountNumber()}
Expand Down Expand Up @@ -552,8 +564,7 @@ func TestAnteHandlerSequences(t *testing.T) {
func(suite *AnteTestSuite) TestCaseArgs {
accs := suite.CreateTestAccounts(1)
msg1 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msg2 := testdata.NewTestMsg(accs[0].acc.GetAddress())
msgs := []sdk.Msg{msg1, msg2}
msgs := []sdk.Msg{msg1}

privs := []cryptotypes.PrivKey{accs[0].priv}
accNums := []uint64{accs[0].acc.GetAccountNumber()}
Expand Down
27 changes: 27 additions & 0 deletions x/auth/ante/basic.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ func NewValidateBasicDecorator() ValidateBasicDecorator {
}

func (vbd ValidateBasicDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bool, next sdk.AnteHandler) (sdk.Context, error) {
msgV2, err := tx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(tx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

// no need to validate basic on recheck tx, call next antehandler
if ctx.IsReCheckTx() {
return next(ctx, tx, simulate)
Expand Down Expand Up @@ -58,6 +67,15 @@ func (vmd ValidateMemoDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate
return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}

msgV2, err := memoTx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(memoTx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

memoLength := len(memoTx.GetMemo())
if memoLength > 0 {
params := vmd.ak.GetParams(ctx)
Expand Down Expand Up @@ -98,6 +116,15 @@ func (cgts ConsumeTxSizeGasDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim
}
params := cgts.ak.GetParams(ctx)

msgV2, err := sigTx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

// HV2: removed `ConsumeGas` in it as done in heimdall's `auth/ante.go` (original ancestor's method `newCtx.GasMeter().ConsumeGas`)
// ctx.GasMeter().ConsumeGas(params.TxSizeCostPerByte*storetypes.Gas(len(ctx.TxBytes())), "txSize")

Expand Down
9 changes: 9 additions & 0 deletions x/auth/ante/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ func (sud SetUpContextDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate
return newCtx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "Tx must be GasTx")
}

msgV2, err := gasTx.GetMsgsV2()
if err != nil {
return newCtx, err
}

if len(msgV2) > 1 || len(gasTx.GetMsgs()) > 1 {
return newCtx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

newCtx = SetGasMeter(simulate, ctx, gasTx.GetGas())

if cp := ctx.ConsensusParams(); cp.Block != nil {
Expand Down
36 changes: 36 additions & 0 deletions x/auth/ante/sigverify.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func (spkd SetPubKeyDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate b
return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid tx type")
}

msgV2, err := sigTx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

pubkeys, err := sigTx.GetPubKeys()
if err != nil {
return ctx, err
Expand Down Expand Up @@ -167,6 +176,15 @@ func (sgcd SigGasConsumeDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simula
return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}

msgV2, err := sigTx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

params := sgcd.ak.GetParams(ctx)
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
Expand Down Expand Up @@ -271,6 +289,15 @@ func (svd SigVerificationDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simul
return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}

msgV2, err := sigTx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

// stdSigs contains the sequence number, account number, and signatures.
// When simulating, this would just be a 0-length slice.
sigs, err := sigTx.GetSignaturesV2()
Expand Down Expand Up @@ -394,6 +421,15 @@ func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim
return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}

msgV2, err := sigTx.GetMsgsV2()
if err != nil {
return ctx, err
}

if len(msgV2) > 1 || len(sigTx.GetMsgs()) > 1 {
return ctx, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

// increment sequence of all signers
signers, err := sigTx.GetSigners()
if err != nil {
Expand Down
24 changes: 11 additions & 13 deletions x/auth/ante/sigverify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@ func TestSetPubKey(t *testing.T) {

// keys and addresses
priv1, pub1, addr1 := testdata.KeyTestPubAddr()
priv2, pub2, addr2 := testdata.KeyTestPubAddr()

addrs := []sdk.AccAddress{addr1, addr2}
pubs := []cryptotypes.PubKey{pub1, pub2}
addrs := []sdk.AccAddress{addr1}
pubs := []cryptotypes.PubKey{pub1}

msgs := make([]sdk.Msg, len(addrs))
// set accounts and create msg for each address
Expand All @@ -51,7 +50,7 @@ func TestSetPubKey(t *testing.T) {
suite.txBuilder.SetFeeAmount(testdata.NewTestFeeAmount())
suite.txBuilder.SetGasLimit(testdata.NewTestGasLimit())

privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1, priv2}, []uint64{0, 1}, []uint64{0, 0}
privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}
tx, err := suite.CreateTestTx(suite.ctx, privs, accNums, accSeqs, suite.ctx.ChainID(), signing.SignMode_SIGN_MODE_DIRECT)
require.NoError(t, err)

Expand Down Expand Up @@ -148,10 +147,9 @@ func TestSigVerification(t *testing.T) {

// keys and addresses
priv1, _, addr1 := testdata.KeyTestPubAddr()
priv2, _, addr2 := testdata.KeyTestPubAddr()
priv3, _, addr3 := testdata.KeyTestPubAddr()
priv2, _, _ := testdata.KeyTestPubAddr()

addrs := []sdk.AccAddress{addr1, addr2, addr3}
addrs := []sdk.AccAddress{addr1}

msgs := make([]sdk.Msg, len(addrs))
accs := make([]sdk.AccountI, len(addrs))
Expand Down Expand Up @@ -194,12 +192,12 @@ func TestSigVerification(t *testing.T) {
validSigs := false
testCases := []testCase{
{"no signers", []cryptotypes.PrivKey{}, []uint64{}, []uint64{}, validSigs, false, true},
{"not enough signers", []cryptotypes.PrivKey{priv1, priv2}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber()}, []uint64{0, 0}, validSigs, false, true},
{"wrong order signers", []cryptotypes.PrivKey{priv3, priv2, priv1}, []uint64{accs[2].GetAccountNumber(), accs[1].GetAccountNumber(), accs[0].GetAccountNumber()}, []uint64{0, 0, 0}, validSigs, false, true},
{"wrong accnums", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{7, 8, 9}, []uint64{0, 0, 0}, validSigs, false, true},
{"wrong sequences", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber(), accs[2].GetAccountNumber()}, []uint64{3, 4, 5}, validSigs, false, true},
{"valid tx", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{accs[0].GetAccountNumber(), accs[1].GetAccountNumber(), accs[2].GetAccountNumber()}, []uint64{0, 0, 0}, validSigs, false, false},
{"no err on recheck", []cryptotypes.PrivKey{priv1, priv2, priv3}, []uint64{0, 0, 0}, []uint64{0, 0, 0}, !validSigs, true, false},
{"not enough signers", []cryptotypes.PrivKey{}, []uint64{accs[0].GetAccountNumber()}, []uint64{0}, validSigs, false, true},
{"wrong order signers", []cryptotypes.PrivKey{priv2}, []uint64{accs[0].GetAccountNumber()}, []uint64{0}, validSigs, false, true},
{"wrong accnums", []cryptotypes.PrivKey{priv1}, []uint64{7}, []uint64{0}, validSigs, false, true},
{"wrong sequences", []cryptotypes.PrivKey{priv1}, []uint64{accs[0].GetAccountNumber()}, []uint64{3}, validSigs, false, true},
{"valid tx", []cryptotypes.PrivKey{priv1}, []uint64{accs[0].GetAccountNumber()}, []uint64{0}, validSigs, false, false},
{"no err on recheck", []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0}, !validSigs, true, false},
}

for i, tc := range testCases {
Expand Down
9 changes: 9 additions & 0 deletions x/auth/ante/validator_tx_fee.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ func checkTxFeeWithValidatorMinGasPrices(ctx sdk.Context, tx sdk.Tx, params type
return nil, 0, errorsmod.Wrap(sdkerrors.ErrInvalidTxFees, "must provide correct txFees")
}

msgV2, err := tx.GetMsgsV2()
if err != nil {
return nil, 0, err
}

if len(msgV2) > 1 || len(tx.GetMsgs()) > 1 {
return nil, 0, errorsmod.Wrap(sdkerrors.ErrTooManyMsgsInTx, "Tx must contain only one message")
}

// HV2: gas is retrieved from Params as currently done in heimdall
gas := params.GetMaxTxGas()
feeCoins := sdk.Coins{sdk.Coin{Denom: types.FeeToken, Amount: amount}}
Expand Down

0 comments on commit 08f48b4

Please sign in to comment.