Skip to content

Commit

Permalink
Merge pull request #3481 from CheshireFox/fix/chain-not-found-err
Browse files Browse the repository at this point in the history
Fix chain not found error
  • Loading branch information
jorgemmsilva authored Jul 25, 2024
2 parents e415fd6 + 54863e0 commit 54cb0e6
Show file tree
Hide file tree
Showing 12 changed files with 88 additions and 73 deletions.
8 changes: 7 additions & 1 deletion packages/isc/chainid.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,19 @@ func ChainIDFromBytes(data []byte) (ret ChainID, err error) {
}

func ChainIDFromString(bech32 string) (ChainID, error) {
_, addr, err := iotago.ParseBech32(bech32)
netPrefix, addr, err := iotago.ParseBech32(bech32)
if err != nil {
return ChainID{}, err
}
if addr.Type() != iotago.AddressAlias {
return ChainID{}, fmt.Errorf("chainID must be an alias address (%s)", bech32)
}

expectedNetPrefix := parameters.L1().Protocol.Bech32HRP
if netPrefix != expectedNetPrefix {
return ChainID{}, fmt.Errorf("invalid network prefix: %s", netPrefix)
}

return ChainIDFromAddress(addr.(*iotago.AliasAddress)), nil
}

Expand Down
9 changes: 9 additions & 0 deletions packages/isc/chainid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package isc
import (
"testing"

"github.com/stretchr/testify/require"

"github.com/iotaledger/wasp/packages/util/rwutil"
)

Expand All @@ -12,3 +14,10 @@ func TestChainIDSerialization(t *testing.T) {
rwutil.BytesTest(t, chainID, ChainIDFromBytes)
rwutil.StringTest(t, chainID, ChainIDFromString)
}

func TestIncorrectPrefix(t *testing.T) {
chainID := "rms1prxunz807j39nmhzy3gre4hwdlzvdjyrkfn59d27x6xh426y8ajt205mh9g"
_, err := ChainIDFromString(chainID)

require.ErrorContains(t, err, "invalid network prefix: rms")
}
4 changes: 2 additions & 2 deletions packages/webapi/apierrors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"strings"
)

func ChainNotFoundError(chainID string) *HTTPError {
return NewHTTPError(http.StatusNotFound, fmt.Sprintf("Chain ID: %v not found", chainID), nil)
func ChainNotFoundError() *HTTPError {
return NewHTTPError(http.StatusNotFound, "Chain ID not found", nil)
}

func UserNotFoundError(username string) *HTTPError {
Expand Down
2 changes: 1 addition & 1 deletion packages/webapi/controllers/chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (c *Controller) getCommitteeInfo(e echo.Context) error {

chain, err := c.chainService.GetChainInfoByChainID(chainID, "")
if err != nil {
return apierrors.ChainNotFoundError(chainID.String())
return apierrors.ChainNotFoundError()
}

chainNodeInfo, err := c.committeeService.GetCommitteeInfo(chainID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func ChainIDFromParams(c echo.Context, cs interfaces.ChainService) (isc.ChainID,
}

if !cs.HasChain(chainID) {
return isc.ChainID{}, apierrors.ChainNotFoundError(chainID.String())
return isc.ChainID{}, apierrors.ChainNotFoundError()
}
// set chainID to be used by the prometheus metrics
c.Set(EchoContextKeyChainID, chainID)
Expand Down
48 changes: 24 additions & 24 deletions packages/webapi/controllers/corecontracts/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
)

func (c *Controller) getTotalAssets(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

assets, err := corecontracts.GetTotalAssets(ch, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

assetsResponse := &models.AssetsResponse{
Expand All @@ -34,9 +34,9 @@ func (c *Controller) getTotalAssets(e echo.Context) error {
}

func (c *Controller) getAccountBalance(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

agentID, err := params.DecodeAgentID(e)
Expand All @@ -46,7 +46,7 @@ func (c *Controller) getAccountBalance(e echo.Context) error {

assets, err := corecontracts.GetAccountBalance(ch, agentID, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

assetsResponse := &models.AssetsResponse{
Expand All @@ -58,9 +58,9 @@ func (c *Controller) getAccountBalance(e echo.Context) error {
}

func (c *Controller) getAccountNFTs(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

agentID, err := params.DecodeAgentID(e)
Expand All @@ -70,7 +70,7 @@ func (c *Controller) getAccountNFTs(e echo.Context) error {

nfts, err := corecontracts.GetAccountNFTs(ch, agentID, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

nftsResponse := &models.AccountNFTsResponse{
Expand All @@ -85,9 +85,9 @@ func (c *Controller) getAccountNFTs(e echo.Context) error {
}

func (c *Controller) getAccountFoundries(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}
agentID, err := params.DecodeAgentID(e)
if err != nil {
Expand All @@ -96,7 +96,7 @@ func (c *Controller) getAccountFoundries(e echo.Context) error {

foundries, err := corecontracts.GetAccountFoundries(ch, agentID, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

return e.JSON(http.StatusOK, &models.AccountFoundriesResponse{
Expand All @@ -105,9 +105,9 @@ func (c *Controller) getAccountFoundries(e echo.Context) error {
}

func (c *Controller) getAccountNonce(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

agentID, err := params.DecodeAgentID(e)
Expand All @@ -117,7 +117,7 @@ func (c *Controller) getAccountNonce(e echo.Context) error {

nonce, err := corecontracts.GetAccountNonce(ch, agentID, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

nonceResponse := &models.AccountNonceResponse{
Expand All @@ -128,9 +128,9 @@ func (c *Controller) getAccountNonce(e echo.Context) error {
}

func (c *Controller) getNFTData(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

nftID, err := params.DecodeNFTID(e)
Expand All @@ -140,7 +140,7 @@ func (c *Controller) getNFTData(e echo.Context) error {

nftData, err := corecontracts.GetNFTData(ch, *nftID, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

nftDataResponse := isc.NFTToJSONObject(nftData)
Expand All @@ -149,14 +149,14 @@ func (c *Controller) getNFTData(e echo.Context) error {
}

func (c *Controller) getNativeTokenIDRegistry(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

registries, err := corecontracts.GetNativeTokenIDRegistry(ch, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

nativeTokenIDRegistryResponse := &models.NativeTokenIDRegistryResponse{
Expand All @@ -171,9 +171,9 @@ func (c *Controller) getNativeTokenIDRegistry(e echo.Context) error {
}

func (c *Controller) getFoundryOutput(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

serialNumber, err := params.DecodeUInt(e, "serialNumber")
Expand All @@ -183,7 +183,7 @@ func (c *Controller) getFoundryOutput(e echo.Context) error {

foundryOutput, err := corecontracts.GetFoundryOutput(ch, uint32(serialNumber), e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

foundryOutputID, err := foundryOutput.ID()
Expand Down
12 changes: 6 additions & 6 deletions packages/webapi/controllers/corecontracts/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ type BlobValueResponse struct {
}

func (c *Controller) getBlobValue(e echo.Context) error {
ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

blobHash, err := params.DecodeBlobHash(e)
Expand All @@ -40,7 +40,7 @@ func (c *Controller) getBlobValue(e echo.Context) error {

blobValueBytes, err := corecontracts.GetBlobValue(ch, *blobHash, fieldKey, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

blobValueResponse := &BlobValueResponse{
Expand All @@ -57,9 +57,9 @@ type BlobInfoResponse struct {
func (c *Controller) getBlobInfo(e echo.Context) error {
fmt.Println("GET BLOB INFO")

ch, chainID, err := controllerutils.ChainFromParams(e, c.chainService)
ch, _, err := controllerutils.ChainFromParams(e, c.chainService)
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

blobHash, err := params.DecodeBlobHash(e)
Expand All @@ -69,7 +69,7 @@ func (c *Controller) getBlobInfo(e echo.Context) error {

blobInfo, ok, err := corecontracts.GetBlobInfo(ch, *blobHash, e.QueryParam(params.ParamBlockIndexOrTrieRoot))
if err != nil {
return c.handleViewCallError(err, chainID)
return c.handleViewCallError(err)
}

fmt.Printf("GET BLOB INFO: ok:%v, err:%v", ok, err)
Expand Down
Loading

0 comments on commit 54cb0e6

Please sign in to comment.