Skip to content

Commit

Permalink
Move ads in its own package again and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
karimodm committed Aug 1, 2023
1 parent 0203483 commit fea872b
Show file tree
Hide file tree
Showing 12 changed files with 631 additions and 181 deletions.
61 changes: 61 additions & 0 deletions ads/ads.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package ds

import (
"github.com/iotaledger/hive.go/ads/amap"
"github.com/iotaledger/hive.go/ads/aset"
"github.com/iotaledger/hive.go/ds/types"
"github.com/iotaledger/hive.go/kvstore"
)

// Map is a map that can produce proofs for its values which can be verified against a known merkle root
// that is formed using a sparse merkle tree.
type Map[K, V any] interface {
// Set sets the given key to the given value.
Set(key K, value V) error

// Get returns the value for the given key.
Get(key K) (value V, exists bool, err error)

// Has returns true if the given key exists.
Has(key K) (exists bool, err error)

// Delete deletes the given key.
Delete(key K) (deleted bool, err error)

// Stream streams all key-value pairs to the given consumer function.
Stream(consumerFunc func(key K, value V) error) error

// Commit commits the changes to the underlying store.
Commit() error

// Root returns the root of the sparse merkle tree.
Root() types.Identifier

// Size returns the number of elements in the map.
Size() int

// WasRestoredFromStorage returns true if the map was restored from an existing storage.
WasRestoredFromStorage() bool
}

// NewMap creates a new AuthenticatedMap.
func NewMap[K, V any](store kvstore.KVStore, kToBytes kvstore.ObjectToBytes[K], bytesToK kvstore.BytesToObject[K], vToBytes kvstore.ObjectToBytes[V], bytesToV kvstore.BytesToObject[V]) Map[K, V] {
return amap.NewAuthenticatedMap(store, kToBytes, bytesToK, vToBytes, bytesToV)
}

// Set is a sparse merkle tree based set.
type Set[K any] interface {
Root() types.Identifier
Add(key K) error
Has(key K) (exists bool, err error)
Delete(key K) (deleted bool, err error)
Stream(consumerFunc func(key K) error) error
Commit() error
Size() int
WasRestoredFromStorage() bool
}

// NewSet creates a new sparse merkle tree based map.
func NewSet[K any](store kvstore.KVStore, kToBytes kvstore.ObjectToBytes[K], bytesToK kvstore.BytesToObject[K]) Set[K] {
return aset.NewAuthenticatedSet(store, kToBytes, bytesToK)
}
51 changes: 25 additions & 26 deletions ds/authenticated_map_impl.go → ads/amap/authenticated_map.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package ds
package amap

import (
"crypto/sha256"
Expand All @@ -21,7 +21,8 @@ const (
nonEmptyLeaf = 1
)

type authenticatedMap[K, V any] struct {
// AuthenticatedMap is a sparse merkle tree based map.
type AuthenticatedMap[K, V any] struct {
rawKeysStore *kvstore.TypedStore[K, types.Empty]
tree *smt.SMT
size *typedkey.Number[uint64]
Expand All @@ -34,14 +35,15 @@ type authenticatedMap[K, V any] struct {
bytesToV kvstore.BytesToObject[V]
}

func newAuthenticatedMap[K, V any](
// NewAuthenticatedMap creates a new authenticated map.
func NewAuthenticatedMap[K, V any](
store kvstore.KVStore,
kToBytes kvstore.ObjectToBytes[K],
bytesToK kvstore.BytesToObject[K],
vToBytes kvstore.ObjectToBytes[V],
bytesToV kvstore.BytesToObject[V],
) *authenticatedMap[K, V] {
newMap := &authenticatedMap[K, V]{
) *AuthenticatedMap[K, V] {
newMap := &AuthenticatedMap[K, V]{
rawKeysStore: kvstore.NewTypedStore(lo.PanicOnErr(store.WithExtendedRealm([]byte{prefixRawKeysStorage})), kToBytes, bytesToK, types.Empty.Bytes, types.EmptyFromBytes),
size: typedkey.NewNumber[uint64](store, prefixSizeKey),
root: typedkey.NewBytes(store, prefixRootKey),
Expand All @@ -62,12 +64,12 @@ func newAuthenticatedMap[K, V any](
}

// WasRestoredFromStorage returns true if the map has been restored from storage.
func (m *authenticatedMap[K, V]) WasRestoredFromStorage() bool {
func (m *AuthenticatedMap[K, V]) WasRestoredFromStorage() bool {
return len(m.root.Get()) != 0
}

// Root returns the root of the state sparse merkle tree at the latest committed slot.
func (m *authenticatedMap[K, V]) Root() (root types.Identifier) {
func (m *AuthenticatedMap[K, V]) Root() (root types.Identifier) {
m.mutex.RLock()
defer m.mutex.RUnlock()

Expand All @@ -77,7 +79,7 @@ func (m *authenticatedMap[K, V]) Root() (root types.Identifier) {
}

// Set sets the output to unspent outputs set.
func (m *authenticatedMap[K, V]) Set(key K, value V) error {
func (m *AuthenticatedMap[K, V]) Set(key K, value V) error {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand Down Expand Up @@ -112,15 +114,15 @@ func (m *authenticatedMap[K, V]) Set(key K, value V) error {
}

// Size returns the number of elements in the map.
func (m *authenticatedMap[K, V]) Size() int {
func (m *AuthenticatedMap[K, V]) Size() int {
m.mutex.RLock()
defer m.mutex.RUnlock()

return int(m.size.Get())
}

// Commit persists the current state of the map to the storage.
func (m *authenticatedMap[K, V]) Commit() error {
func (m *AuthenticatedMap[K, V]) Commit() error {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -130,7 +132,7 @@ func (m *authenticatedMap[K, V]) Commit() error {
}

// Delete removes the key from the map.
func (m *authenticatedMap[K, V]) Delete(key K) (deleted bool, err error) {
func (m *AuthenticatedMap[K, V]) Delete(key K) (deleted bool, err error) {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand Down Expand Up @@ -164,7 +166,7 @@ func (m *authenticatedMap[K, V]) Delete(key K) (deleted bool, err error) {
}

// Has returns true if the key is in the set.
func (m *authenticatedMap[K, V]) Has(key K) (has bool, err error) {
func (m *AuthenticatedMap[K, V]) Has(key K) (has bool, err error) {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -177,7 +179,7 @@ func (m *authenticatedMap[K, V]) Has(key K) (has bool, err error) {
}

// Get returns the value for the given key.
func (m *authenticatedMap[K, V]) Get(key K) (value V, exists bool, err error) {
func (m *AuthenticatedMap[K, V]) Get(key K) (value V, exists bool, err error) {
m.mutex.Lock()
defer m.mutex.Unlock()

Expand All @@ -188,14 +190,10 @@ func (m *authenticatedMap[K, V]) Get(key K) (value V, exists bool, err error) {

valueBytes, err := m.tree.Get(keyBytes)
if err != nil {
if ierrors.Is(err, kvstore.ErrKeyNotFound) {
return value, false, err
}

return value, false, ierrors.Wrap(err, "failed to get from tree")
}

if len(valueBytes) == 0 {
if valueBytes == nil {
return value, false, err
}

Expand All @@ -212,48 +210,49 @@ func (m *authenticatedMap[K, V]) Get(key K) (value V, exists bool, err error) {
}

// Stream streams all the keys and values.
func (m *authenticatedMap[K, V]) Stream(callback func(key K, value V) error) (err error) {
func (m *AuthenticatedMap[K, V]) Stream(callback func(key K, value V) error) error {
m.mutex.Lock()
defer m.mutex.Unlock()

var innerErr error
if iterationErr := m.rawKeysStore.IterateKeys([]byte{}, func(key K) bool {
keyBytes, err := m.kToBytes(key)
if err != nil {
err = ierrors.Wrapf(err, "failed to serialize key %s", key)
innerErr = ierrors.Wrapf(err, "failed to serialize key %s", keyBytes)

return false
}

valueBytes, valueErr := m.tree.Get(keyBytes)
if valueErr != nil {
err = ierrors.Wrapf(valueErr, "failed to get value for key %s", keyBytes)
innerErr = ierrors.Wrapf(valueErr, "failed to get value for key %s", keyBytes)

return false
}

value, _, valueErr := m.bytesToV(valueBytes)
if valueErr != nil {
err = ierrors.Wrapf(valueErr, "failed to deserialize value %s", valueBytes)
innerErr = ierrors.Wrapf(valueErr, "failed to deserialize value %s", valueBytes)

return false
}

if callbackErr := callback(key, value); callbackErr != nil {
err = ierrors.Wrapf(callbackErr, "failed to execute callback for key %s", keyBytes)
innerErr = ierrors.Wrapf(callbackErr, "failed to execute callback for key %s", keyBytes)

return false
}

return true
}); iterationErr != nil {
err = ierrors.Wrap(iterationErr, "failed to iterate over raw keys")
return ierrors.Wrap(iterationErr, "failed to iterate over raw keys")
}

return
return innerErr
}

// has returns true if the key is in the map.
func (m *authenticatedMap[K, V]) has(keyBytes []byte) (has bool, err error) {
func (m *AuthenticatedMap[K, V]) has(keyBytes []byte) (has bool, err error) {
value, err := m.tree.Get(keyBytes)
if err != nil {
return false, ierrors.Wrap(err, "failed to get from tree")
Expand Down
61 changes: 41 additions & 20 deletions ds/authenticated_map_test.go → ads/amap/authenticated_map_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
package ds_test
package amap_test

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/iotaledger/hive.go/ads"
"github.com/iotaledger/hive.go/ads/amap"
"github.com/iotaledger/hive.go/ierrors"
"github.com/iotaledger/hive.go/kvstore/mapdb"
"github.com/iotaledger/hive.go/lo"
)

var ErrStopIteration = ierrors.New("stop")

func TestMap(t *testing.T) {
store := mapdb.NewMapDB()
newMap := ads.NewMap(store,
newMap := amap.NewAuthenticatedMap(store,
testKey.Bytes,
testKeyFromBytes,
testValue.Bytes,
Expand All @@ -20,62 +24,77 @@ func TestMap(t *testing.T) {
keys := []testKey{testKey([]byte{'a'}), testKey([]byte{'b'})}
values := []testValue{testValueFromString("test value"), testValueFromString("test value 1")}
// Test setting and getting a value
require.Equal(t, 0, newMap.Size())
require.False(t, newMap.WasRestoredFromStorage())

for i, k := range keys {
newMap.Set(k, values[i])
}

for i, k := range keys {
exist := newMap.Has(k)
exist, err := newMap.Has(k)
require.NoError(t, err)
require.True(t, exist)
gotValue, exists := newMap.Get(k)
gotValue, exists, err := newMap.Get(k)
require.NoError(t, err)
require.True(t, exists)
require.ElementsMatch(t, values[i], gotValue)
}

// Test setting a value to empty, which should panic
require.Panics(t, func() { newMap.Set(keys[0], testValue{}) })
require.Equal(t, len(keys), newMap.Size())

// Test setting a value to empty, which should be just fine
require.NoError(t, newMap.Set(keys[0], testValue{}))

// Test getting a non-existing key
gotValue, exists := newMap.Get(testKey([]byte{'c'}))
gotValue, exists, err := newMap.Get(testKey([]byte{'c'}))
require.NoError(t, err)
require.False(t, exists)
require.Nil(t, gotValue)

// overwrite the value of keys[0]
newValue := testValueFromString("test")
newMap.Set(keys[0], newValue)
gotValue, exists = newMap.Get(keys[0])
gotValue, exists, err = newMap.Get(keys[0])
require.NoError(t, err)
require.True(t, exists)
require.ElementsMatch(t, newValue, gotValue)

// get the root of having 2 keys
oldRoot := newMap.Root()

// Test deleting a key
require.True(t, newMap.Delete(keys[0]))
exists = newMap.Has(keys[0])
require.True(t, lo.PanicOnErr(newMap.Delete(keys[0])))
exists, err = newMap.Has(keys[0])
require.NoError(t, err)
require.False(t, exists)
_, exists = newMap.Get(keys[0])
_, exists, err = newMap.Get(keys[0])
require.NoError(t, err)
require.False(t, exists)

// The root now should be different
require.NotEqualValues(t, oldRoot, newMap.Root())

// Test deleting a non-existent key
require.False(t, newMap.Delete(keys[0]))
require.False(t, lo.PanicOnErr(newMap.Delete(keys[0])))

require.NoError(t, newMap.Commit())

// The root should be same if loading the same store to map
newMap1 := ads.NewMap(store,
newMap1 := amap.NewAuthenticatedMap(store,
testKey.Bytes,
testKeyFromBytes,
testValue.Bytes,
testValueFromBytes,
)

require.True(t, newMap.WasRestoredFromStorage())
require.EqualValues(t, newMap.Root(), newMap1.Root())
}

func TestStreamMap(t *testing.T) {
store := mapdb.NewMapDB()
newMap := ads.NewMap[testKey, testValue](store,
newMap := amap.NewAuthenticatedMap[testKey, testValue](store,
testKey.Bytes,
testKeyFromBytes,
testValue.Bytes,
Expand All @@ -91,9 +110,9 @@ func TestStreamMap(t *testing.T) {
}

seen := make(map[testKey]testValue)
err := newMap.Stream(func(key testKey, value testValue) bool {
err := newMap.Stream(func(key testKey, value testValue) error {
seen[key] = value
return true
return nil
})
require.NoError(t, err)

Expand All @@ -106,12 +125,14 @@ func TestStreamMap(t *testing.T) {

// consume function returns false, only 1 element is visited.
seenKV := make(map[testKey]testValue)
err = newMap.Stream(func(key testKey, value testValue) bool {
err = newMap.Stream(func(key testKey, value testValue) error {
seenKV[key] = value

return false
return ErrStopIteration
})
require.NoError(t, err)
// the error is expected because we stopped the iteration early
require.Error(t, err)
require.ErrorIs(t, err, ErrStopIteration)
require.Equal(t, 1, len(seenKV))
for k, v := range seenKV {
expectedV, has := kvMap[k]
Expand Down
Loading

0 comments on commit fea872b

Please sign in to comment.