diff --git a/modules/light-clients/08-wasm/keeper/genesis.go b/modules/light-clients/08-wasm/keeper/genesis.go index 772e3c01417..6291e83a380 100644 --- a/modules/light-clients/08-wasm/keeper/genesis.go +++ b/modules/light-clients/08-wasm/keeper/genesis.go @@ -1,6 +1,8 @@ package keeper import ( + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/ibc-go/modules/light-clients/08-wasm/types" @@ -29,7 +31,7 @@ func (k Keeper) ExportGenesis(ctx sdk.Context) types.GenesisState { // Grab code from wasmVM and add to genesis state. var genesisState types.GenesisState for _, codeHash := range codeHashes { - code, err := k.wasmVM.GetCode(codeHash) + code, err := k.wasmVM.GetCode(wasmvmtypes.Checksum(codeHash)) if err != nil { panic(err) } diff --git a/modules/light-clients/08-wasm/keeper/genesis_test.go b/modules/light-clients/08-wasm/keeper/genesis_test.go index 04d76252c83..0ae5c1e6812 100644 --- a/modules/light-clients/08-wasm/keeper/genesis_test.go +++ b/modules/light-clients/08-wasm/keeper/genesis_test.go @@ -59,7 +59,6 @@ func (suite *KeeperTestSuite) TestInitGenesis() { var storedHashes []string codeHashes, err := types.GetAllCodeHashes(suite.chainA.GetContext()) suite.Require().NoError(err) - suite.Require().NotNil(codeHashes) for _, hash := range codeHashes { storedHashes = append(storedHashes, hex.EncodeToString(hash)) diff --git a/modules/light-clients/08-wasm/keeper/snapshotter.go b/modules/light-clients/08-wasm/keeper/snapshotter.go index c3322e52500..348862ef359 100644 --- a/modules/light-clients/08-wasm/keeper/snapshotter.go +++ b/modules/light-clients/08-wasm/keeper/snapshotter.go @@ -4,6 +4,8 @@ import ( "encoding/hex" "io" + wasmvmtypes "github.com/CosmWasm/wasmvm/types" + errorsmod "cosmossdk.io/errors" snapshot "cosmossdk.io/store/snapshots/types" storetypes "cosmossdk.io/store/types" @@ -72,7 +74,7 @@ func (ws *WasmSnapshotter) SnapshotExtension(height uint64, payloadWriter snapsh } for _, codeHash := range codeHashes { - wasmCode, err := ws.keeper.wasmVM.GetCode(codeHash) + wasmCode, err := ws.keeper.wasmVM.GetCode(wasmvmtypes.Checksum(codeHash)) if err != nil { return err } diff --git a/modules/light-clients/08-wasm/types/wasm.go b/modules/light-clients/08-wasm/types/wasm.go index 3e048f93345..6f364b758fa 100644 --- a/modules/light-clients/08-wasm/types/wasm.go +++ b/modules/light-clients/08-wasm/types/wasm.go @@ -6,21 +6,25 @@ import ( "github.com/cosmos/ibc-go/modules/light-clients/08-wasm/internal/ibcwasm" ) +// CodeHash is a type alias used for wasm byte code checksums. +type CodeHash []byte + // GetAllCodeHashes is a helper to get all code hashes from the store. // It returns an empty slice if no code hashes are found -func GetAllCodeHashes(ctx sdk.Context) ([][]byte, error) { +func GetAllCodeHashes(ctx sdk.Context) ([]CodeHash, error) { iterator, err := ibcwasm.CodeHashes.Iterate(ctx, nil) if err != nil { return nil, err } - codeHashes, err := iterator.Keys() + keys, err := iterator.Keys() if err != nil { return nil, err } - if codeHashes == nil { - codeHashes = [][]byte{} + codeHashes := []CodeHash{} + for _, key := range keys { + codeHashes = append(codeHashes, key) } return codeHashes, nil @@ -28,7 +32,7 @@ func GetAllCodeHashes(ctx sdk.Context) ([][]byte, error) { // HasCodeHash returns true if the given code hash exists in the store and // false otherwise. -func HasCodeHash(ctx sdk.Context, codeHash []byte) bool { +func HasCodeHash(ctx sdk.Context, codeHash CodeHash) bool { found, err := ibcwasm.CodeHashes.Has(ctx, codeHash) if err != nil { return false diff --git a/modules/light-clients/08-wasm/types/wasm_test.go b/modules/light-clients/08-wasm/types/wasm_test.go index cefe4668ae2..c46e3e5b5c1 100644 --- a/modules/light-clients/08-wasm/types/wasm_test.go +++ b/modules/light-clients/08-wasm/types/wasm_test.go @@ -12,12 +12,12 @@ func (suite *TypesTestSuite) TestGetCodeHashes() { testCases := []struct { name string malleate func() - expResult func(codeHashes [][]byte) + expResult func(codeHashes []types.CodeHash) }{ { "success: no contract stored.", func() {}, - func(codeHashes [][]byte) { + func(codeHashes []types.CodeHash) { suite.Require().Len(codeHashes, 0) }, }, @@ -26,10 +26,10 @@ func (suite *TypesTestSuite) TestGetCodeHashes() { func() { suite.SetupWasmWithMockVM() }, - func(codeHashes [][]byte) { + func(codeHashes []types.CodeHash) { suite.Require().Len(codeHashes, 1) expectedCodeHash := sha256.Sum256(wasmtesting.Code) - suite.Require().Equal(expectedCodeHash[:], codeHashes[0]) + suite.Require().Equal(types.CodeHash(expectedCodeHash[:]), codeHashes[0]) }, }, { @@ -37,12 +37,12 @@ func (suite *TypesTestSuite) TestGetCodeHashes() { func() { suite.SetupWasmWithMockVM() - err := ibcwasm.CodeHashes.Set(suite.chainA.GetContext(), []byte("codehash")) + err := ibcwasm.CodeHashes.Set(suite.chainA.GetContext(), types.CodeHash("codehash")) suite.Require().NoError(err) }, - func(codeHashes [][]byte) { + func(codeHashes []types.CodeHash) { suite.Require().Len(codeHashes, 2) - suite.Require().Contains(codeHashes, []byte("codehash")) + suite.Require().Contains(codeHashes, types.CodeHash("codehash")) }, }, } @@ -67,8 +67,8 @@ func (suite *TypesTestSuite) TestAddCodeHash() { // default mock vm contract is stored suite.Require().Len(codeHashes, 1) - codeHash1 := []byte("codehash1") - codeHash2 := []byte("codehash2") + codeHash1 := types.CodeHash("codehash1") + codeHash2 := types.CodeHash("codehash2") err = ibcwasm.CodeHashes.Set(suite.chainA.GetContext(), codeHash1) suite.Require().NoError(err) err = ibcwasm.CodeHashes.Set(suite.chainA.GetContext(), codeHash2) @@ -86,7 +86,7 @@ func (suite *TypesTestSuite) TestAddCodeHash() { } func (suite *TypesTestSuite) TestHasCodeHash() { - var codeHash []byte + var codeHash types.CodeHash testCases := []struct { name string @@ -96,7 +96,7 @@ func (suite *TypesTestSuite) TestHasCodeHash() { { "success: code hash exists", func() { - codeHash = []byte("codehash") + codeHash = types.CodeHash("codehash") err := ibcwasm.CodeHashes.Set(suite.chainA.GetContext(), codeHash) suite.Require().NoError(err) }, @@ -105,7 +105,7 @@ func (suite *TypesTestSuite) TestHasCodeHash() { { "success: code hash does not exist", func() { - codeHash = []byte("non-existent-codehash") + codeHash = types.CodeHash("non-existent-codehash") }, false, },