diff --git a/kvstore/typedvalue.go b/kvstore/typedvalue.go index 95bb4d5f5..957c2f3ff 100644 --- a/kvstore/typedvalue.go +++ b/kvstore/typedvalue.go @@ -1,6 +1,9 @@ package kvstore -import "github.com/iotaledger/hive.go/ierrors" +import ( + "github.com/iotaledger/hive.go/ierrors" + "github.com/iotaledger/hive.go/runtime/syncutils" +) // TypedValue is a generically typed wrapper around a KVStore that provides access to a single value. type TypedValue[V any] struct { @@ -9,6 +12,10 @@ type TypedValue[V any] struct { vToBytes ObjectToBytes[V] bytesToV BytesToObject[V] + + valueCached *V + hasCached *bool + mutex syncutils.Mutex } // NewTypedValue is the constructor for TypedValue. @@ -26,51 +33,127 @@ func NewTypedValue[V any]( } } +// KVStore returns the underlying KVStore. func (t *TypedValue[V]) KVStore() KVStore { return t.kv } // Get gets the given key or an error if an error occurred. func (t *TypedValue[V]) Get() (value V, err error) { - valueBytes, err := t.kv.Get(t.keyBytes) - if err != nil { - return value, ierrors.Wrap(err, "failed to retrieve from KV store") + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.hasCached != nil && !*t.hasCached { + return value, ErrKeyNotFound + } + + if t.valueCached != nil { + return *t.valueCached, nil } - v, _, err := t.bytesToV(valueBytes) - if err != nil { + if valueBytes, valueBytesErr := t.kv.Get(t.keyBytes); valueBytesErr != nil { + if ierrors.Is(valueBytesErr, ErrKeyNotFound) { + t.hasCached = &falsePtr + } + + return value, ierrors.Wrap(valueBytesErr, "failed to retrieve value from KV store") + } else if value, _, err = t.bytesToV(valueBytes); err != nil { return value, ierrors.Wrap(err, "failed to decode value") } - return v, nil + t.valueCached = &value + t.hasCached = &truePtr + + return value, nil } // Has checks whether the given key exists. func (t *TypedValue[V]) Has() (has bool, err error) { - return t.kv.Has(t.keyBytes) + t.mutex.Lock() + defer t.mutex.Unlock() + + if t.hasCached != nil { + return *t.hasCached, nil + } else if has, err = t.kv.Has(t.keyBytes); err != nil { + return false, ierrors.Wrap(err, "failed to check whether key exists") + } + + t.hasCached = &has + + return has, nil } -// Set sets the given key and value. -func (t *TypedValue[V]) Set(value V) (err error) { - valueBytes, err := t.vToBytes(value) - if err != nil { - return ierrors.Wrap(err, "failed to encode value") +// Compute atomically computes and sets a new value based on the current value and some provided computation function. +func (t *TypedValue[V]) Compute(computeFunc func(currentValue V, exists bool) (newValue V, err error)) (newValue V, err error) { + t.mutex.Lock() + defer t.mutex.Unlock() + + currentValue, exists := t.cachedValue() + if !exists && t.hasCached == nil || *t.hasCached { + if valueBytes, valueBytesErr := t.kv.Get(t.keyBytes); valueBytesErr != nil { + if !ierrors.Is(valueBytesErr, ErrKeyNotFound) { + return newValue, ierrors.Wrap(valueBytesErr, "failed to retrieve value from KV store") + } + } else if currentValue, _, err = t.bytesToV(valueBytes); err != nil { + return newValue, ierrors.Wrap(err, "failed to decode value") + } else { + exists = true + } + } + + if newValue, err = computeFunc(currentValue, exists); err != nil { + return newValue, ierrors.Wrap(err, "failed to compute new value") } - err = t.kv.Set(t.keyBytes, valueBytes) - if err != nil { + t.valueCached = &newValue + t.hasCached = &truePtr + + return newValue, nil +} + +// Set sets the given key and value. +func (t *TypedValue[V]) Set(value V) error { + t.mutex.Lock() + defer t.mutex.Unlock() + + if valueBytes, err := t.vToBytes(value); err != nil { + return ierrors.Wrap(err, "failed to encode value") + } else if err = t.kv.Set(t.keyBytes, valueBytes); err != nil { return ierrors.Wrap(err, "failed to store in KV store") } + t.valueCached = &value + t.hasCached = &truePtr + return nil } // Delete deletes the given key from the store. func (t *TypedValue[V]) Delete() (err error) { - err = t.kv.Delete(t.keyBytes) - if err != nil { + t.mutex.Lock() + defer t.mutex.Unlock() + + if err = t.kv.Delete(t.keyBytes); err != nil { return ierrors.Wrap(err, "failed to delete entry from KV store") } + t.valueCached = nil + t.hasCached = &falsePtr + return nil } + +// cachedValue returns the cached value and a boolean indicating whether the value is cached. +func (t *TypedValue[V]) cachedValue() (value V, isCached bool) { + if t.valueCached == nil { + return value, false + } + + return *t.valueCached, true +} + +// truePtr is a pointer to a true value. +var truePtr = true + +// falsePtr is a pointer to a false value. +var falsePtr = false diff --git a/kvstore/typedvalue_test.go b/kvstore/typedvalue_test.go new file mode 100644 index 000000000..4bcefb2f2 --- /dev/null +++ b/kvstore/typedvalue_test.go @@ -0,0 +1,96 @@ +package kvstore_test + +import ( + "encoding/binary" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/iotaledger/hive.go/kvstore" + "github.com/iotaledger/hive.go/kvstore/mapdb" +) + +func TestTypedValue(t *testing.T) { + kvStore := mapdb.NewMapDB() + defer kvStore.Close() + + increase := func(currentValue int, exists bool) (newValue int, err error) { + if !exists { + return 1337, nil + } + + return currentValue + 1, nil + } + + typedValue := kvstore.NewTypedValue[int](kvStore, []byte("key"), intToBytes, bytesToInt) + + value, err := typedValue.Get() + require.Equal(t, 0, value) + require.ErrorIs(t, err, kvstore.ErrKeyNotFound) + + has, err := typedValue.Has() + require.False(t, has) + require.NoError(t, err) + + value, err = typedValue.Get() + require.Equal(t, 0, value) + require.ErrorIs(t, err, kvstore.ErrKeyNotFound) + + value, err = typedValue.Compute(increase) + require.Equal(t, 1337, value) + require.NoError(t, err) + + value, err = typedValue.Compute(increase) + require.Equal(t, 1338, value) + require.NoError(t, err) + + value, err = typedValue.Compute(increase) + require.Equal(t, 1339, value) + require.NoError(t, err) + + value, err = typedValue.Get() + require.Equal(t, 1339, value) + require.NoError(t, err) + + has, err = typedValue.Has() + require.True(t, has) + require.NoError(t, err) + + require.NoError(t, typedValue.Delete()) + + value, err = typedValue.Get() + require.Equal(t, 0, value) + require.ErrorIs(t, err, kvstore.ErrKeyNotFound) + + has, err = typedValue.Has() + require.False(t, has) + require.NoError(t, err) + + typedValue.Set(42) + value, err = typedValue.Get() + require.Equal(t, 42, value) + require.NoError(t, err) + + typedValueRestored := kvstore.NewTypedValue[int](kvStore, []byte("key"), intToBytes, bytesToInt) + has, err = typedValueRestored.Has() + require.True(t, has) + require.NoError(t, err) + + value, err = typedValueRestored.Get() + require.Equal(t, 42, value) + require.NoError(t, err) +} + +func intToBytes(value int) (encoded []byte, err error) { + encoded = make([]byte, 4) + + binary.LittleEndian.PutUint32(encoded, uint32(value)) + + return encoded, nil +} + +func bytesToInt(encoded []byte) (value int, consumed int, err error) { + value = int(binary.LittleEndian.Uint32(encoded)) + + return value, 4, nil +}