diff --git a/internal/api/bindings.h b/internal/api/bindings.h index 00bb0ec3a..2d2835f5f 100644 --- a/internal/api/bindings.h +++ b/internal/api/bindings.h @@ -255,27 +255,34 @@ typedef struct U8SliceView { uintptr_t len; } U8SliceView; -typedef struct iterator_t { +/** + * A reference to some tables on the Go side which allow accessing + * the actual iterator instance. + */ +typedef struct IteratorReference { /** * An ID assigned to this contract call */ uint64_t call_id; - uint64_t iterator_index; -} iterator_t; + /** + * An ID assigned to this iterator + */ + uint64_t iterator_id; +} IteratorReference; typedef struct IteratorVtable { - int32_t (*next)(struct iterator_t iterator, + int32_t (*next)(struct IteratorReference iterator, struct gas_meter_t *gas_meter, uint64_t *gas_used, struct UnmanagedVector *key_out, struct UnmanagedVector *value_out, struct UnmanagedVector *err_msg_out); - int32_t (*next_key)(struct iterator_t iterator, + int32_t (*next_key)(struct IteratorReference iterator, struct gas_meter_t *gas_meter, uint64_t *gas_used, struct UnmanagedVector *key_out, struct UnmanagedVector *err_msg_out); - int32_t (*next_value)(struct iterator_t iterator, + int32_t (*next_value)(struct IteratorReference iterator, struct gas_meter_t *gas_meter, uint64_t *gas_used, struct UnmanagedVector *value_out, @@ -284,7 +291,11 @@ typedef struct IteratorVtable { typedef struct GoIter { struct gas_meter_t *gas_meter; - struct iterator_t state; + /** + * A reference which identifies the iterator and allows finding and accessing the + * actual iterator instance in Go. Once fully initalized, this is immutable. + */ + struct IteratorReference reference; struct IteratorVtable vtable; } GoIter; diff --git a/internal/api/callbacks.go b/internal/api/callbacks.go index 5db7ee817..1d60f5683 100644 --- a/internal/api/callbacks.go +++ b/internal/api/callbacks.go @@ -19,9 +19,9 @@ GoError cSet_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceV GoError cDelete_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView key, UnmanagedVector *errOut); GoError cScan_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView start, U8SliceView end, int32_t order, GoIter *out, UnmanagedVector *errOut); // iterator -GoError cNext_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut); -GoError cNextKey_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut); -GoError cNextValue_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut); +GoError cNext_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut); +GoError cNextKey_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut); +GoError cNextValue_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut); // api GoError cHumanizeAddress_cgo(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas); GoError cCanonicalizeAddress_cgo(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas); @@ -138,14 +138,14 @@ const frameLenLimit = 32768 // contract: original pointer/struct referenced must live longer than C.Db struct // since this is only used internally, we can verify the code that this is the case -func buildIterator(callID uint64, it types.Iterator) (C.iterator_t, error) { - idx, err := storeIterator(callID, it, frameLenLimit) +func buildIterator(callID uint64, it types.Iterator) (C.IteratorReference, error) { + iteratorID, err := storeIterator(callID, it, frameLenLimit) if err != nil { - return C.iterator_t{}, err + return C.IteratorReference{}, err } - return C.iterator_t{ - call_id: cu64(callID), - iterator_index: cu64(idx), + return C.IteratorReference{ + call_id: cu64(callID), + iterator_id: cu64(iteratorID), }, nil } @@ -257,7 +257,7 @@ func cScan(ptr *C.db_t, gasMeter *C.gas_meter_t, usedGas *cu64, start C.U8SliceV gasAfter := gm.GasConsumed() *usedGas = (cu64)(gasAfter - gasBefore) - cIterator, err := buildIterator(state.CallID, iter) + iteratorRef, err := buildIterator(state.CallID, iter) if err != nil { // store the actual error message in the return buffer *errOut = newUnmanagedVector([]byte(err.Error())) @@ -266,7 +266,7 @@ func cScan(ptr *C.db_t, gasMeter *C.gas_meter_t, usedGas *cu64, start C.U8SliceV *out = C.GoIter{ gas_meter: gasMeter, - state: cIterator, + reference: iteratorRef, vtable: iterator_vtable, } @@ -274,7 +274,7 @@ func cScan(ptr *C.db_t, gasMeter *C.gas_meter_t, usedGas *cu64, start C.U8SliceV } //export cNext -func cNext(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, val *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) { +func cNext(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, val *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) { // typical usage of iterator // for ; itr.Valid(); itr.Next() { // k, v := itr.Key(); itr.Value() @@ -291,7 +291,7 @@ func cNext(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.Unma } gm := *(*types.GasMeter)(unsafe.Pointer(gasMeter)) - iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_index)) + iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_id)) if iter == nil { panic("Unable to retrieve iterator.") } @@ -315,17 +315,17 @@ func cNext(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.Unma } //export cNextKey -func cNextKey(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) { +func cNextKey(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, key *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) { return nextPart(ref, gasMeter, usedGas, key, errOut, func(iter types.Iterator) []byte { return iter.Key() }) } //export cNextValue -func cNextValue(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, value *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) { +func cNextValue(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, value *C.UnmanagedVector, errOut *C.UnmanagedVector) (ret C.GoError) { return nextPart(ref, gasMeter, usedGas, value, errOut, func(iter types.Iterator) []byte { return iter.Value() }) } // nextPart is a helper function that contains the shared code for key- and value-only iteration. -func nextPart(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, output *C.UnmanagedVector, errOut *C.UnmanagedVector, valFn func(types.Iterator) []byte) (ret C.GoError) { +func nextPart(ref C.IteratorReference, gasMeter *C.gas_meter_t, usedGas *cu64, output *C.UnmanagedVector, errOut *C.UnmanagedVector, valFn func(types.Iterator) []byte) (ret C.GoError) { // typical usage of iterator // for ; itr.Valid(); itr.Next() { // k, v := itr.Key(); itr.Value() @@ -342,7 +342,7 @@ func nextPart(ref C.iterator_t, gasMeter *C.gas_meter_t, usedGas *cu64, output * } gm := *(*types.GasMeter)(unsafe.Pointer(gasMeter)) - iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_index)) + iter := retrieveIterator(uint64(ref.call_id), uint64(ref.iterator_id)) if iter == nil { panic("Unable to retrieve iterator.") } diff --git a/internal/api/callbacks_cgo.go b/internal/api/callbacks_cgo.go index c8f237f61..53d84c076 100644 --- a/internal/api/callbacks_cgo.go +++ b/internal/api/callbacks_cgo.go @@ -10,9 +10,9 @@ GoError cGet(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView GoError cDelete(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView key, UnmanagedVector *errOut); GoError cScan(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8SliceView start, U8SliceView end, int32_t order, GoIter *out, UnmanagedVector *errOut); // imports (iterator) -GoError cNext(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut); -GoError cNextKey(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut); -GoError cNextValue(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *value, UnmanagedVector *errOut); +GoError cNext(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut); +GoError cNextKey(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut); +GoError cNextValue(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *value, UnmanagedVector *errOut); // imports (api) GoError cHumanizeAddress(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas); GoError cCanonicalizeAddress(api_t *ptr, U8SliceView src, UnmanagedVector *dest, UnmanagedVector *errOut, uint64_t *used_gas); @@ -35,14 +35,14 @@ GoError cScan_cgo(db_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, U8Slice } // Gateway functions (iterator) -GoError cNext_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut) { - return cNext(ptr, gas_meter, used_gas, key, val, errOut); +GoError cNext_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *val, UnmanagedVector *errOut) { + return cNext(ref, gas_meter, used_gas, key, val, errOut); } -GoError cNextKey_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut) { - return cNextKey(ptr, gas_meter, used_gas, key, errOut); +GoError cNextKey_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *key, UnmanagedVector *errOut) { + return cNextKey(ref, gas_meter, used_gas, key, errOut); } -GoError cNextValue_cgo(iterator_t *ptr, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut) { - return cNextValue(ptr, gas_meter, used_gas, val, errOut); +GoError cNextValue_cgo(IteratorReference *ref, gas_meter_t *gas_meter, uint64_t *used_gas, UnmanagedVector *val, UnmanagedVector *errOut) { + return cNextValue(ref, gas_meter, used_gas, val, errOut); } // Gateway functions (api) diff --git a/internal/api/iterator.go b/internal/api/iterator.go index 4dff612c3..c9a768b40 100644 --- a/internal/api/iterator.go +++ b/internal/api/iterator.go @@ -2,6 +2,7 @@ package api import ( "fmt" + "math" "sync" "github.com/CosmWasm/wasmvm/v2/types" @@ -53,39 +54,71 @@ func endCall(callID uint64) { } } -// storeIterator will add this to the end of the frame for the given ID and return a reference to it. -// We start counting with 1, so the 0 value is flagged as an error. This means we must -// remember to do idx-1 when retrieving +// storeIterator will add this to the end of the frame for the given call ID and return +// an iterator ID to reference it. +// +// We assign iterator IDs starting with 1 for historic reasons. This could be changed to 0 +// I guess. func storeIterator(callID uint64, it types.Iterator, frameLenLimit int) (uint64, error) { iteratorFramesMutex.Lock() defer iteratorFramesMutex.Unlock() - old_frame_len := len(iteratorFrames[callID]) - if old_frame_len >= frameLenLimit { + new_index := len(iteratorFrames[callID]) + if new_index >= frameLenLimit { return 0, fmt.Errorf("Reached iterator limit (%d)", frameLenLimit) } - // store at array position `old_frame_len` + // store at array position `new_index` iteratorFrames[callID] = append(iteratorFrames[callID], it) - new_index := old_frame_len + 1 - return uint64(new_index), nil + iterator_id, ok := indexToIteratorID(new_index) + if !ok { + // This error case is not expected to happen since the above code ensures the + // index is in the range [0, frameLenLimit-1] + return 0, fmt.Errorf("could not convert index to iterator ID") + } + return iterator_id, nil } -// retrieveIterator will recover an iterator based on index. This ensures it will not be garbage collected. -// We start counting with 1, in storeIterator so the 0 value is flagged as an error. This means we must -// remember to do idx-1 when retrieving -func retrieveIterator(callID uint64, index uint64) types.Iterator { +// retrieveIterator will recover an iterator based on its ID. +func retrieveIterator(callID uint64, iteratorID uint64) types.Iterator { + indexInFrame, ok := iteratorIdToIndex(iteratorID) + if !ok { + return nil + } + iteratorFramesMutex.Lock() defer iteratorFramesMutex.Unlock() myFrame := iteratorFrames[callID] if myFrame == nil { return nil } - posInFrame := int(index) - 1 - if posInFrame < 0 || posInFrame >= len(myFrame) { + if indexInFrame >= len(myFrame) { // index out of range return nil } - return myFrame[posInFrame] + return myFrame[indexInFrame] +} + +// iteratorIdToIndex converts an iterator ID to an index in the frame. +// The second value marks if the conversion succeeded. +func iteratorIdToIndex(id uint64) (int, bool) { + if id < 1 || id > math.MaxInt32 { + // If success is false, the int value is undefined. We use an arbitrary constant for potential debugging purposes. + return 777777777, false + } + + // Int conversion safe because value is in signed 32bit integer range + return int(id) - 1, true +} + +// indexToIteratorID converts an index in the frame to an iterator ID. +// The second value marks if the conversion succeeded. +func indexToIteratorID(index int) (uint64, bool) { + if index < 0 || index > math.MaxInt32 { + // If success is false, the return value is undefined. We use an arbitrary constant for potential debugging purposes. + return 888888888, false + } + + return uint64(index) + 1, true } diff --git a/internal/api/iterator_test.go b/internal/api/iterator_test.go index 54366a4cc..0c81db775 100644 --- a/internal/api/iterator_test.go +++ b/internal/api/iterator_test.go @@ -131,7 +131,7 @@ func TestRetrieveIterator(t *testing.T) { var err error iter, _ = store.Iterator(nil, nil) - index11, err := storeIterator(callID1, iter, limit) + iteratorID11, err := storeIterator(callID1, iter, limit) require.NoError(t, err) iter, _ = store.Iterator(nil, nil) _, err = storeIterator(callID1, iter, limit) @@ -140,27 +140,33 @@ func TestRetrieveIterator(t *testing.T) { _, err = storeIterator(callID2, iter, limit) require.NoError(t, err) iter, _ = store.Iterator(nil, nil) - index22, err := storeIterator(callID2, iter, limit) + iteratorID22, err := storeIterator(callID2, iter, limit) require.NoError(t, err) iter, err = store.Iterator(nil, nil) require.NoError(t, err) - index23, err := storeIterator(callID2, iter, limit) + iteratorID23, err := storeIterator(callID2, iter, limit) require.NoError(t, err) // Retrieve existing - iter = retrieveIterator(callID1, index11) + iter = retrieveIterator(callID1, iteratorID11) require.NotNil(t, iter) - iter = retrieveIterator(callID2, index22) + iter = retrieveIterator(callID2, iteratorID22) require.NotNil(t, iter) - // Retrieve non-existent index - iter = retrieveIterator(callID1, index23) + // Retrieve with non-existent iterator ID + iter = retrieveIterator(callID1, iteratorID23) require.Nil(t, iter) iter = retrieveIterator(callID1, uint64(0)) require.Nil(t, iter) + iter = retrieveIterator(callID1, uint64(2147483647)) + require.Nil(t, iter) + iter = retrieveIterator(callID1, uint64(2147483648)) + require.Nil(t, iter) + iter = retrieveIterator(callID1, uint64(18446744073709551615)) + require.Nil(t, iter) - // Retrieve non-existent call ID - iter = retrieveIterator(callID1+1_234_567, index23) + // Retrieve with non-existent call ID + iter = retrieveIterator(callID1+1_234_567, iteratorID23) require.Nil(t, iter) endCall(callID1) diff --git a/libwasmvm/bindings.h b/libwasmvm/bindings.h index 00bb0ec3a..2d2835f5f 100644 --- a/libwasmvm/bindings.h +++ b/libwasmvm/bindings.h @@ -255,27 +255,34 @@ typedef struct U8SliceView { uintptr_t len; } U8SliceView; -typedef struct iterator_t { +/** + * A reference to some tables on the Go side which allow accessing + * the actual iterator instance. + */ +typedef struct IteratorReference { /** * An ID assigned to this contract call */ uint64_t call_id; - uint64_t iterator_index; -} iterator_t; + /** + * An ID assigned to this iterator + */ + uint64_t iterator_id; +} IteratorReference; typedef struct IteratorVtable { - int32_t (*next)(struct iterator_t iterator, + int32_t (*next)(struct IteratorReference iterator, struct gas_meter_t *gas_meter, uint64_t *gas_used, struct UnmanagedVector *key_out, struct UnmanagedVector *value_out, struct UnmanagedVector *err_msg_out); - int32_t (*next_key)(struct iterator_t iterator, + int32_t (*next_key)(struct IteratorReference iterator, struct gas_meter_t *gas_meter, uint64_t *gas_used, struct UnmanagedVector *key_out, struct UnmanagedVector *err_msg_out); - int32_t (*next_value)(struct iterator_t iterator, + int32_t (*next_value)(struct IteratorReference iterator, struct gas_meter_t *gas_meter, uint64_t *gas_used, struct UnmanagedVector *value_out, @@ -284,7 +291,11 @@ typedef struct IteratorVtable { typedef struct GoIter { struct gas_meter_t *gas_meter; - struct iterator_t state; + /** + * A reference which identifies the iterator and allows finding and accessing the + * actual iterator instance in Go. Once fully initalized, this is immutable. + */ + struct IteratorReference reference; struct IteratorVtable vtable; } GoIter; diff --git a/libwasmvm/src/iterator.rs b/libwasmvm/src/iterator.rs index 2b84dc8d7..408d7c4e7 100644 --- a/libwasmvm/src/iterator.rs +++ b/libwasmvm/src/iterator.rs @@ -6,13 +6,15 @@ use crate::gas_meter::gas_meter_t; use crate::memory::UnmanagedVector; use crate::vtables::Vtable; -// Iterator maintains integer references to some tables on the Go side +/// A reference to some tables on the Go side which allow accessing +/// the actual iterator instance. #[repr(C)] #[derive(Default, Copy, Clone)] -pub struct iterator_t { +pub struct IteratorReference { /// An ID assigned to this contract call pub call_id: u64, - pub iterator_index: u64, + /// An ID assigned to this iterator + pub iterator_id: u64, } // These functions should return GoError but because we don't trust them here, we treat the return value as i32 @@ -22,7 +24,7 @@ pub struct iterator_t { pub struct IteratorVtable { pub next: Option< extern "C" fn( - iterator: iterator_t, + iterator: IteratorReference, gas_meter: *mut gas_meter_t, gas_used: *mut u64, key_out: *mut UnmanagedVector, @@ -32,7 +34,7 @@ pub struct IteratorVtable { >, pub next_key: Option< extern "C" fn( - iterator: iterator_t, + iterator: IteratorReference, gas_meter: *mut gas_meter_t, gas_used: *mut u64, key_out: *mut UnmanagedVector, @@ -41,7 +43,7 @@ pub struct IteratorVtable { >, pub next_value: Option< extern "C" fn( - iterator: iterator_t, + iterator: IteratorReference, gas_meter: *mut gas_meter_t, gas_used: *mut u64, value_out: *mut UnmanagedVector, @@ -55,7 +57,9 @@ impl Vtable for IteratorVtable {} #[repr(C)] pub struct GoIter { pub gas_meter: *mut gas_meter_t, - pub state: iterator_t, + /// A reference which identifies the iterator and allows finding and accessing the + /// actual iterator instance in Go. Once fully initalized, this is immutable. + pub reference: IteratorReference, pub vtable: IteratorVtable, } @@ -67,8 +71,8 @@ impl GoIter { /// which is then filled in Go (see `fn scan`). pub fn stub() -> Self { GoIter { + reference: IteratorReference::default(), gas_meter: std::ptr::null_mut(), - state: iterator_t::default(), vtable: IteratorVtable::default(), } } @@ -84,7 +88,7 @@ impl GoIter { let mut error_msg = UnmanagedVector::default(); let mut used_gas = 0_u64; let go_result: GoError = (next)( - self.state, + self.reference, self.gas_meter, &mut used_gas as *mut u64, &mut output_key as *mut UnmanagedVector, @@ -141,7 +145,7 @@ impl GoIter { fn next_key_or_val( &mut self, next: extern "C" fn( - iterator: iterator_t, + iterator: IteratorReference, gas_meter: *mut gas_meter_t, gas_limit: *mut u64, key_or_value_out: *mut UnmanagedVector, // key if called from next_key; value if called from next_value @@ -152,7 +156,7 @@ impl GoIter { let mut error_msg = UnmanagedVector::default(); let mut used_gas = 0_u64; let go_result: GoError = (next)( - self.state, + self.reference, self.gas_meter, &mut used_gas as *mut u64, &mut output as *mut UnmanagedVector, @@ -190,8 +194,8 @@ mod test { // creates an all null-instance let iter = GoIter::stub(); assert!(iter.gas_meter.is_null()); - assert_eq!(iter.state.call_id, 0); - assert_eq!(iter.state.iterator_index, 0); + assert_eq!(iter.reference.call_id, 0); + assert_eq!(iter.reference.iterator_id, 0); assert!(iter.vtable.next.is_none()); assert!(iter.vtable.next_key.is_none()); assert!(iter.vtable.next_value.is_none()); diff --git a/libwasmvm/src/vtables.rs b/libwasmvm/src/vtables.rs index bd16517d5..4763a6877 100644 --- a/libwasmvm/src/vtables.rs +++ b/libwasmvm/src/vtables.rs @@ -16,12 +16,12 @@ /// /// ``` /// # use wasmvm::UnmanagedVector; -/// # struct iterator_t; +/// # struct IteratorReference; /// # struct gas_meter_t; /// pub struct IteratorVtable { /// pub next: Option< /// extern "C" fn( -/// iterator: iterator_t, +/// iterator: IteratorReference, /// gas_meter: *mut gas_meter_t, /// gas_used: *mut u64, /// key_out: *mut UnmanagedVector,