Skip to content

Commit

Permalink
Merge pull request #501 from CosmWasm/iteratorRef
Browse files Browse the repository at this point in the history
Refactor IteratorReference
  • Loading branch information
webmaster128 authored Feb 23, 2024
2 parents 9530fce + 27e3ac8 commit 27656e4
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 79 deletions.
25 changes: 18 additions & 7 deletions internal/api/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,27 +255,34 @@ typedef struct U8SliceView {
uintptr_t len;
} U8SliceView;

typedef struct iterator_t {
/**
* A reference to some tables on the Go side which allow accessing
* the actual iterator instance.
*/
typedef struct IteratorReference {
/**
* An ID assigned to this contract call
*/
uint64_t call_id;
uint64_t iterator_index;
} iterator_t;
/**
* An ID assigned to this iterator
*/
uint64_t iterator_id;
} IteratorReference;

typedef struct IteratorVtable {
int32_t (*next)(struct iterator_t iterator,
int32_t (*next)(struct IteratorReference iterator,
struct gas_meter_t *gas_meter,
uint64_t *gas_used,
struct UnmanagedVector *key_out,
struct UnmanagedVector *value_out,
struct UnmanagedVector *err_msg_out);
int32_t (*next_key)(struct iterator_t iterator,
int32_t (*next_key)(struct IteratorReference iterator,
struct gas_meter_t *gas_meter,
uint64_t *gas_used,
struct UnmanagedVector *key_out,
struct UnmanagedVector *err_msg_out);
int32_t (*next_value)(struct iterator_t iterator,
int32_t (*next_value)(struct IteratorReference iterator,
struct gas_meter_t *gas_meter,
uint64_t *gas_used,
struct UnmanagedVector *value_out,
Expand All @@ -284,7 +291,11 @@ typedef struct IteratorVtable {

typedef struct GoIter {
struct gas_meter_t *gas_meter;
struct iterator_t state;
/**
* A reference which identifies the iterator and allows finding and accessing the
* actual iterator instance in Go. Once fully initalized, this is immutable.
*/
struct IteratorReference reference;
struct IteratorVtable vtable;
} GoIter;

Expand Down
34 changes: 17 additions & 17 deletions internal/api/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ GoError cSet_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceV
GoError cDelete_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView key, UnmanagedVector *errOut);
GoError cScan_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView start, U8SliceView end, int32_t order, GoIter *out, UnmanagedVector *errOut);
// iterator
GoError cNext_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut);
GoError cNextKey_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut);
GoError cNextValue_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut);
GoError cNext_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut);
GoError cNextKey_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut);
GoError cNextValue_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut);
// api
GoError cHumanizeAddress_cgo(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas);
GoError cCanonicalizeAddress_cgo(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas);
Expand Down Expand Up @@ -138,14 +138,14 @@ const frameLenLimit = 32768

// contract: original pointer/struct referenced must live longer than C.Db struct
// since this is only used internally, we can verify the code that this is the case
func buildIterator(callID uint64, it types.Iterator) (C.iterator_t, error) {
idx, err := storeIterator(callID, it, frameLenLimit)
func buildIterator(callID uint64, it types.Iterator) (C.IteratorReference, error) {
iteratorID, err := storeIterator(callID, it, frameLenLimit)
if err != nil {
return C.iterator_t{}, err
return C.IteratorReference{}, err
}
return C.iterator_t{
call_id: cu64(callID),
iterator_index: cu64(idx),
return C.IteratorReference{
call_id: cu64(callID),
iterator_id: cu64(iteratorID),
}, nil
}

Expand Down Expand Up @@ -257,7 +257,7 @@ func cScan(ptr *C.db_t, gasMeter *C.gas_meter_t, usedGas *cu64, start C.U8SliceV
gasAfter := gm.GasConsumed()
*usedGas = (cu64)(gasAfter - gasBefore)

cIterator, err := buildIterator(state.CallID, iter)
iteratorRef, err := buildIterator(state.CallID, iter)
if err != nil {
// store the actual error message in the return buffer
*errOut = newUnmanagedVector([]byte(err.Error()))
Expand All @@ -266,15 +266,15 @@ func cScan(ptr *C.db_t, gasMeter *C.gas_meter_t, usedGas *cu64, start C.U8SliceV

*out = C.GoIter{
gas_meter: gasMeter,
state: cIterator,
reference: iteratorRef,
vtable: iterator_vtable,
}

return C.GoError_None
}

//export cNext
func cNext(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, val *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) {
func cNext(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, val *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) {
// typical usage of iterator
// for ; itr.Valid(); itr.Next() {
// k, v := itr.Key(); itr.Value()
Expand All @@ -291,7 +291,7 @@ func cNext(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.Unma
}

gm := *(*types.GasMeter)(unsafe.Pointer(gasMeter))
iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_index))
iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_id))
if iter == nil {
panic("Unable to retrieve iterator.")
}
Expand All @@ -315,17 +315,17 @@ func cNext(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.Unma
}

//export cNextKey
func cNextKey(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) {
func cNextKey(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) {
return nextPart(ref, gasMeter, usedGas, key, errOut, func(iter types.Iterator) []byte { return iter.Key() })
}

//export cNextValue
func cNextValue(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, value *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) {
func cNextValue(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, value *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) {
return nextPart(ref, gasMeter, usedGas, value, errOut, func(iter types.Iterator) []byte { return iter.Value() })
}

// nextPart is a helper function that contains the shared code for key- and value-only iteration.
func nextPart(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, output *C.UnmanagedVector, errOut *C.UnmanagedVector, valFn func(types.Iterator) []byte) (ret C.GoError) {
func nextPart(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, output *C.UnmanagedVector, errOut *C.UnmanagedVector, valFn func(types.Iterator) []byte) (ret C.GoError) {
// typical usage of iterator
// for ; itr.Valid(); itr.Next() {
// k, v := itr.Key(); itr.Value()
Expand All @@ -342,7 +342,7 @@ func nextPart(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, output *
}

gm := *(*types.GasMeter)(unsafe.Pointer(gasMeter))
iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_index))
iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_id))
if iter == nil {
panic("Unable to retrieve iterator.")
}
Expand Down
18 changes: 9 additions & 9 deletions internal/api/callbacks_cgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ GoError cGet(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView
GoError cDelete(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView key, UnmanagedVector *errOut);
GoError cScan(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView start, U8SliceView end, int32_t order, GoIter *out, UnmanagedVector *errOut);
// imports (iterator)
GoError cNext(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut);
GoError cNextKey(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut);
GoError cNextValue(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *value, UnmanagedVector *errOut);
GoError cNext(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut);
GoError cNextKey(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut);
GoError cNextValue(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *value, UnmanagedVector *errOut);
// imports (api)
GoError cHumanizeAddress(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas);
GoError cCanonicalizeAddress(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas);
Expand All @@ -35,14 +35,14 @@ GoError cScan_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8Slice
}
// Gateway functions (iterator)
GoError cNext_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut) {
return cNext(ptr, gas_meter, used_gas, key, val, errOut);
GoError cNext_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut) {
return cNext(ref, gas_meter, used_gas, key, val, errOut);
}
GoError cNextKey_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut) {
return cNextKey(ptr, gas_meter, used_gas, key, errOut);
GoError cNextKey_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut) {
return cNextKey(ref, gas_meter, used_gas, key, errOut);
}
GoError cNextValue_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut) {
return cNextValue(ptr, gas_meter, used_gas, val, errOut);
GoError cNextValue_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut) {
return cNextValue(ref, gas_meter, used_gas, val, errOut);
}
// Gateway functions (api)
Expand Down
63 changes: 48 additions & 15 deletions internal/api/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"fmt"
"math"
"sync"

"github.com/CosmWasm/wasmvm/v2/types"
Expand Down Expand Up @@ -53,39 +54,71 @@ func endCall(callID uint64) {
}
}

// storeIterator will add this to the end of the frame for the given ID and return a reference to it.
// We start counting with 1, so the 0 value is flagged as an error. This means we must
// remember to do idx-1 when retrieving
// storeIterator will add this to the end of the frame for the given call ID and return
// an iterator ID to reference it.
//
// We assign iterator IDs starting with 1 for historic reasons. This could be changed to 0
// I guess.
func storeIterator(callID uint64, it types.Iterator, frameLenLimit int) (uint64, error) {
iteratorFramesMutex.Lock()
defer iteratorFramesMutex.Unlock()

old_frame_len := len(iteratorFrames[callID])
if old_frame_len >= frameLenLimit {
new_index := len(iteratorFrames[callID])
if new_index >= frameLenLimit {
return 0, fmt.Errorf("Reached iterator limit (%d)", frameLenLimit)
}

// store at array position `old_frame_len`
// store at array position `new_index`
iteratorFrames[callID] = append(iteratorFrames[callID], it)
new_index := old_frame_len + 1

return uint64(new_index), nil
iterator_id, ok := indexToIteratorID(new_index)
if !ok {
// This error case is not expected to happen since the above code ensures the
// index is in the range [0, frameLenLimit-1]
return 0, fmt.Errorf("could not convert index to iterator ID")
}
return iterator_id, nil
}

// retrieveIterator will recover an iterator based on index. This ensures it will not be garbage collected.
// We start counting with 1, in storeIterator so the 0 value is flagged as an error. This means we must
// remember to do idx-1 when retrieving
func retrieveIterator(callID uint64, index uint64) types.Iterator {
// retrieveIterator will recover an iterator based on its ID.
func retrieveIterator(callID uint64, iteratorID uint64) types.Iterator {
indexInFrame, ok := iteratorIdToIndex(iteratorID)
if !ok {
return nil
}

iteratorFramesMutex.Lock()
defer iteratorFramesMutex.Unlock()
myFrame := iteratorFrames[callID]
if myFrame == nil {
return nil
}
posInFrame := int(index) - 1
if posInFrame < 0 || posInFrame >= len(myFrame) {
if indexInFrame >= len(myFrame) {
// index out of range
return nil
}
return myFrame[posInFrame]
return myFrame[indexInFrame]
}

// iteratorIdToIndex converts an iterator ID to an index in the frame.
// The second value marks if the conversion succeeded.
func iteratorIdToIndex(id uint64) (int, bool) {
if id < 1 || id > math.MaxInt32 {
// If success is false, the int value is undefined. We use an arbitrary constant for potential debugging purposes.
return 777777777, false
}

// Int conversion safe because value is in signed 32bit integer range
return int(id) - 1, true
}

// indexToIteratorID converts an index in the frame to an iterator ID.
// The second value marks if the conversion succeeded.
func indexToIteratorID(index int) (uint64, bool) {
if index < 0 || index > math.MaxInt32 {
// If success is false, the return value is undefined. We use an arbitrary constant for potential debugging purposes.
return 888888888, false
}

return uint64(index) + 1, true
}
24 changes: 15 additions & 9 deletions internal/api/iterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestRetrieveIterator(t *testing.T) {
var err error

iter, _ = store.Iterator(nil, nil)
index11, err := storeIterator(callID1, iter, limit)
iteratorID11, err := storeIterator(callID1, iter, limit)
require.NoError(t, err)
iter, _ = store.Iterator(nil, nil)
_, err = storeIterator(callID1, iter, limit)
Expand All @@ -140,27 +140,33 @@ func TestRetrieveIterator(t *testing.T) {
_, err = storeIterator(callID2, iter, limit)
require.NoError(t, err)
iter, _ = store.Iterator(nil, nil)
index22, err := storeIterator(callID2, iter, limit)
iteratorID22, err := storeIterator(callID2, iter, limit)
require.NoError(t, err)
iter, err = store.Iterator(nil, nil)
require.NoError(t, err)
index23, err := storeIterator(callID2, iter, limit)
iteratorID23, err := storeIterator(callID2, iter, limit)
require.NoError(t, err)

// Retrieve existing
iter = retrieveIterator(callID1, index11)
iter = retrieveIterator(callID1, iteratorID11)
require.NotNil(t, iter)
iter = retrieveIterator(callID2, index22)
iter = retrieveIterator(callID2, iteratorID22)
require.NotNil(t, iter)

// Retrieve non-existent index
iter = retrieveIterator(callID1, index23)
// Retrieve with non-existent iterator ID
iter = retrieveIterator(callID1, iteratorID23)
require.Nil(t, iter)
iter = retrieveIterator(callID1, uint64(0))
require.Nil(t, iter)
iter = retrieveIterator(callID1, uint64(2147483647))
require.Nil(t, iter)
iter = retrieveIterator(callID1, uint64(2147483648))
require.Nil(t, iter)
iter = retrieveIterator(callID1, uint64(18446744073709551615))
require.Nil(t, iter)

// Retrieve non-existent call ID
iter = retrieveIterator(callID1+1_234_567, index23)
// Retrieve with non-existent call ID
iter = retrieveIterator(callID1+1_234_567, iteratorID23)
require.Nil(t, iter)

endCall(callID1)
Expand Down
25 changes: 18 additions & 7 deletions libwasmvm/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,27 +255,34 @@ typedef struct U8SliceView {
uintptr_t len;
} U8SliceView;

typedef struct iterator_t {
/**
* A reference to some tables on the Go side which allow accessing
* the actual iterator instance.
*/
typedef struct IteratorReference {
/**
* An ID assigned to this contract call
*/
uint64_t call_id;
uint64_t iterator_index;
} iterator_t;
/**
* An ID assigned to this iterator
*/
uint64_t iterator_id;
} IteratorReference;

typedef struct IteratorVtable {
int32_t (*next)(struct iterator_t iterator,
int32_t (*next)(struct IteratorReference iterator,
struct gas_meter_t *gas_meter,
uint64_t *gas_used,
struct UnmanagedVector *key_out,
struct UnmanagedVector *value_out,
struct UnmanagedVector *err_msg_out);
int32_t (*next_key)(struct iterator_t iterator,
int32_t (*next_key)(struct IteratorReference iterator,
struct gas_meter_t *gas_meter,
uint64_t *gas_used,
struct UnmanagedVector *key_out,
struct UnmanagedVector *err_msg_out);
int32_t (*next_value)(struct iterator_t iterator,
int32_t (*next_value)(struct IteratorReference iterator,
struct gas_meter_t *gas_meter,
uint64_t *gas_used,
struct UnmanagedVector *value_out,
Expand All @@ -284,7 +291,11 @@ typedef struct IteratorVtable {

typedef struct GoIter {
struct gas_meter_t *gas_meter;
struct iterator_t state;
/**
* A reference which identifies the iterator and allows finding and accessing the
* actual iterator instance in Go. Once fully initalized, this is immutable.
*/
struct IteratorReference reference;
struct IteratorVtable vtable;
} GoIter;

Expand Down
Loading

0 comments on commit 27656e4

Please sign in to comment.