diff --git a/x/lockup/types/msgs.go b/x/lockup/types/msgs.go index cfb8dedeb..5753dc956 100644 --- a/x/lockup/types/msgs.go +++ b/x/lockup/types/msgs.go @@ -2,7 +2,6 @@ package types import ( "fmt" - "time" sdk "github.com/cosmos/cosmos-sdk/types" ) @@ -19,20 +18,21 @@ var ( _ sdk.Msg = (*MsgInitiateUnlock)(nil) ) -// NewMsgLockTokens creates a message to lock tokens. -func NewMsgLockTokens(owner sdk.AccAddress, duration time.Duration, coins sdk.Coins) *MsgLockTokens { - return &MsgLockTokens{ - Owner: owner.String(), - Duration: duration, - Coins: coins, - } -} - func (m MsgLockTokens) Route() string { return RouterKey } func (m MsgLockTokens) Type() string { return TypeMsgLockTokens } func (m MsgLockTokens) ValidateBasic() error { + if err := m.Coins.Validate(); err != nil { + return fmt.Errorf("invalid coins") + } + if m.Coins.IsZero() { + return fmt.Errorf("zero coins") + } if m.Duration <= 0 { - return fmt.Errorf("duration should be positive: %d < 0", m.Duration) + return fmt.Errorf("duration should be positive: %d <= 0", m.Duration) + } + + if _, err := sdk.AccAddressFromBech32(m.Owner); err != nil { + return fmt.Errorf("invalid address") } return nil } @@ -49,7 +49,7 @@ func (m MsgLockTokens) GetSigners() []sdk.AccAddress { func (m *MsgInitiateUnlock) ValidateBasic() error { _, err := sdk.AccAddressFromBech32(m.Owner) if err != nil { - return err + return fmt.Errorf("invalid address") } return nil } diff --git a/x/lockup/types/msgs_test.go b/x/lockup/types/msgs_test.go new file mode 100644 index 000000000..5248d9ed8 --- /dev/null +++ b/x/lockup/types/msgs_test.go @@ -0,0 +1,116 @@ +package types + +import ( + "testing" + "time" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" + + "github.com/NibiruChain/nibiru/x/testutil/sample" +) + +func TestMsgLockTokens_ValidateBasic(t *testing.T) { + type test struct { + msg *MsgLockTokens + wantErr string + } + + validAddr := sample.AccAddress().String() + validDuration := 1 * time.Hour + validCoins := sdk.NewCoins(sdk.NewInt64Coin("test", 100)) + + cases := map[string]test{ + "success": { + msg: &MsgLockTokens{ + Owner: validAddr, + Duration: validDuration, + Coins: validCoins, + }, + }, + "invalid address": { + msg: &MsgLockTokens{ + Owner: "", + Duration: validDuration, + Coins: validCoins, + }, + wantErr: "invalid address", + }, + "invalid coins": { + msg: &MsgLockTokens{ + Owner: validAddr, + Duration: validDuration, + Coins: sdk.Coins{sdk.Coin{}}, + }, + wantErr: "invalid coins", + }, + "zero coins": { + msg: &MsgLockTokens{ + Owner: validAddr, + Duration: validDuration, + Coins: sdk.NewCoins(), + }, + wantErr: "zero coins", + }, + "invalid duration": { + msg: &MsgLockTokens{ + Owner: validAddr, + Duration: 0, + Coins: validCoins, + }, + wantErr: "duration should be positive", + }, + } + + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.wantErr == "" && err != nil { + t.Fatalf("unexpected error: %s", err) + } + if tc.wantErr != "" && err == nil { + t.Fatalf("expected error: %s", err) + } + if tc.wantErr != "" { + require.Contains(t, err.Error(), tc.wantErr) + } + }) + } +} + +func TestMsgInitiateUnlock_ValidateBasic(t *testing.T) { + type test struct { + msg *MsgInitiateUnlock + wantErr string + } + + cases := map[string]test{ + "success": { + msg: &MsgInitiateUnlock{ + Owner: sample.AccAddress().String(), + LockId: 0, + }, + }, + "invalid address": { + msg: &MsgInitiateUnlock{Owner: "invalid address", LockId: 0}, + wantErr: "invalid address", + }, + } + + for name, tc := range cases { + tc := tc + t.Run(name, func(t *testing.T) { + err := tc.msg.ValidateBasic() + if tc.wantErr == "" && err != nil { + t.Fatalf("unexpected error: %s", err) + } + if tc.wantErr != "" && err == nil { + t.Fatalf("expected error: %s", err) + } + if tc.wantErr != "" { + require.Contains(t, err.Error(), tc.wantErr) + } + }) + } +}