diff --git a/store/errors.go b/store/errors.go index 14f03f7e4dee..306d9c460131 100644 --- a/store/errors.go +++ b/store/errors.go @@ -10,4 +10,5 @@ var ( ErrRecordNotFound = errors.New("record not found") ErrUnknownStoreKey = errors.New("unknown store key") ErrInvalidVersion = errors.New("invalid version") + ErrKeyEmpty = errors.New("key empty") ) diff --git a/store/iterator.go b/store/iterator.go index ea9a88e92131..f37bd4051362 100644 --- a/store/iterator.go +++ b/store/iterator.go @@ -2,10 +2,13 @@ package store // Iterator ... type Iterator interface { - // Next moves the iterator to the next key/value pair. It returns whether - // the iterator successfully moved to a new key/value pair. The iterator may - // return false if the underlying database has been closed before the iteration - // has completed, in which case future calls to Error() must return ErrClosed. + // Domain returns the start (inclusive) and end (exclusive) limits of the iterator. + Domain() ([]byte, []byte) + + // Valid returns if the iterator is currently valid. + Valid() bool + + // Next moves the iterator to the next key/value pair. Next() bool // Error returns any accumulated error. Error() should be called after all @@ -25,26 +28,11 @@ type Iterator interface { // IteratorCreator ... type IteratorCreator interface { - // NewIterator creates an iterator over the entire key space contained within - // the backing key-value database. - NewIterator(storeKey string) Iterator - - // NewStartIterator creates an iterator over a subset of a database key space - // starting at a particular key. - NewStartIterator(storeKey string, start []byte) Iterator - - // NewEndIterator creates an iterator over a subset of a database key space - // ending at a particular key. - NewEndIterator(storeKey string, start []byte) Iterator - - // NewPrefixIterator creates an iterator over a subset of a database key space - // with a particular key prefix. - NewPrefixIterator(storeKey string, prefix []byte) Iterator + NewIterator(storeKey string, start, end []byte) (Iterator, error) + NewReverseIterator(storeKey string, start, end []byte) (Iterator, error) } type VersionedIteratorCreator interface { - NewIterator(storeKey string, version uint64) Iterator - NewStartIterator(storeKey string, version uint64, start []byte) Iterator - NewEndIterator(storeKey string, version uint64, start []byte) Iterator - NewPrefixIterator(storeKey string, version uint64, prefix []byte) Iterator + NewIterator(storeKey string, version uint64, start, end []byte) (Iterator, error) + NewReverseIterator(storeKey string, version uint64, start, end []byte) (Iterator, error) } diff --git a/store/storage/db.go b/store/storage/db.go index 1c32dae5a08f..ac04b72c2360 100644 --- a/store/storage/db.go +++ b/store/storage/db.go @@ -203,18 +203,10 @@ func (db *Database) GetLatestVersion() (uint64, error) { return db.vdb.GetLatestVersion() } -func (db *Database) NewIterator(storekey string, version uint64) store.Iterator { - panic("not implemented") +func (db *Database) NewIterator(storeKey string, version uint64, start, end []byte) (store.Iterator, error) { + panic("not implemented!") } -func (db *Database) NewStartIterator(storekey string, version uint64, start []byte) store.Iterator { - panic("not implemented") -} - -func (db *Database) NewEndIterator(storekey string, version uint64, start []byte) store.Iterator { - panic("not implemented") -} - -func (db *Database) NewPrefixIterator(storekey string, version uint64, prefix []byte) store.Iterator { - panic("not implemented") +func (db *Database) NewReverseIterator(storeKey string, version uint64, start, end []byte) (store.Iterator, error) { + panic("not implemented!") } diff --git a/store/storage/rocksdb/db.go b/store/storage/rocksdb/db.go index c1e6cde1bb7b..f51833baaa3c 100644 --- a/store/storage/rocksdb/db.go +++ b/store/storage/rocksdb/db.go @@ -45,7 +45,7 @@ func (db *Database) Close() error { return nil } -func (db *Database) GetSlice(storeKey string, version uint64, key []byte) (*grocksdb.Slice, error) { +func (db *Database) getSlice(storeKey string, version uint64, key []byte) (*grocksdb.Slice, error) { return db.storage.GetCF( newTSReadOptions(version), db.cfHandle, @@ -74,7 +74,7 @@ func (db *Database) GetLatestVersion() (uint64, error) { } func (db *Database) Has(storeKey string, version uint64, key []byte) (bool, error) { - slice, err := db.GetSlice(storeKey, version, key) + slice, err := db.getSlice(storeKey, version, key) if err != nil { return false, err } @@ -83,7 +83,7 @@ func (db *Database) Has(storeKey string, version uint64, key []byte) (bool, erro } func (db *Database) Get(storeKey string, version uint64, key []byte) ([]byte, error) { - slice, err := db.GetSlice(storeKey, version, key) + slice, err := db.getSlice(storeKey, version, key) if err != nil { return nil, fmt.Errorf("failed to get RocksDB slice: %w", err) } @@ -131,19 +131,19 @@ func (db *Database) NewBatch(version uint64) store.Batch { return NewBatch(db, version) } -func (db *Database) NewIterator(storeKey string, version uint64) store.Iterator { - panic("not implemented!") -} +func (db *Database) NewIterator(storeKey string, version uint64, start, end []byte) (store.Iterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, store.ErrKeyEmpty + } -func (db *Database) NewStartIterator(storeKey string, version uint64, start []byte) store.Iterator { - panic("not implemented!") -} + prefix := storePrefix(storeKey) + start, end = iterateWithPrefix(prefix, start, end) -func (db *Database) NewEndIterator(storeKey string, version uint64, start []byte) store.Iterator { - panic("not implemented!") + itr := db.storage.NewIteratorCF(newTSReadOptions(version), db.cfHandle) + return newRocksDBIterator(itr, prefix, start, end, false), nil } -func (db *Database) NewPrefixIterator(storeKey string, version uint64, prefix []byte) store.Iterator { +func (db *Database) NewReverseIterator(storeKey string, version uint64, start, end []byte) (store.Iterator, error) { panic("not implemented!") } @@ -187,3 +187,53 @@ func copyAndFreeSlice(s *grocksdb.Slice) []byte { return v } + +func cloneAppend(bz []byte, tail []byte) (res []byte) { + res = make([]byte, len(bz)+len(tail)) + + copy(res, bz) + copy(res[len(bz):], tail) + + return res +} + +func copyIncr(bz []byte) []byte { + if len(bz) == 0 { + panic("copyIncr expects non-zero bz length") + } + + ret := make([]byte, len(bz)) + copy(ret, bz) + + for i := len(bz) - 1; i >= 0; i-- { + if ret[i] < byte(0xFF) { + ret[i]++ + return ret + } + + ret[i] = byte(0x00) + + if i == 0 { + // overflow + return nil + } + } + + return nil +} + +func iterateWithPrefix(prefix, begin, end []byte) ([]byte, []byte) { + if len(prefix) == 0 { + return begin, end + } + + begin = cloneAppend(prefix, begin) + + if end == nil { + end = copyIncr(prefix) + } else { + end = cloneAppend(prefix, end) + } + + return begin, end +} diff --git a/store/storage/rocksdb/db_test.go b/store/storage/rocksdb/db_test.go index d2a457261158..329a5d40efad 100644 --- a/store/storage/rocksdb/db_test.go +++ b/store/storage/rocksdb/db_test.go @@ -1,11 +1,166 @@ package rocksdb import ( + "fmt" + "sort" "testing" "github.com/stretchr/testify/require" ) -func TestFoo(t *testing.T) { - require.Equal(t, 1, 1) +const ( + storeKey1 = "store1" +) + +func TestDatabase_Close(t *testing.T) { + db, err := New(t.TempDir()) + require.NoError(t, err) + require.NoError(t, db.Close()) +} + +func TestDatabase_LatestVersion(t *testing.T) { + db, err := New(t.TempDir()) + require.NoError(t, err) + + lv, err := db.GetLatestVersion() + require.NoError(t, err) + require.Zero(t, lv) + + expected := uint64(1) + + err = db.SetLatestVersion(expected) + require.NoError(t, err) + + lv, err = db.GetLatestVersion() + require.NoError(t, err) + require.Equal(t, expected, lv) +} + +func TestDatabase_CRUD(t *testing.T) { + db, err := New(t.TempDir()) + require.NoError(t, err) + + ok, err := db.Has(storeKey1, 1, []byte("key")) + require.NoError(t, err) + require.False(t, ok) + + err = db.Set(storeKey1, 1, []byte("key"), []byte("value")) + require.NoError(t, err) + + ok, err = db.Has(storeKey1, 1, []byte("key")) + require.NoError(t, err) + require.True(t, ok) + + val, err := db.Get(storeKey1, 1, []byte("key")) + require.NoError(t, err) + require.Equal(t, []byte("value"), val) + + err = db.Delete(storeKey1, 1, []byte("key")) + require.NoError(t, err) + + ok, err = db.Has(storeKey1, 1, []byte("key")) + require.NoError(t, err) + require.False(t, ok) + + val, err = db.Get(storeKey1, 1, []byte("key")) + require.NoError(t, err) + require.Nil(t, val) +} + +func TestDatabase_Batch(t *testing.T) { + db, err := New(t.TempDir()) + require.NoError(t, err) + + batch := db.NewBatch(1) + + for i := 0; i < 100; i++ { + err = batch.Set(storeKey1, []byte(fmt.Sprintf("key%d", i)), []byte("value")) + require.NoError(t, err) + } + + for i := 0; i < 100; i++ { + if i%10 == 0 { + err = batch.Delete(storeKey1, []byte(fmt.Sprintf("key%d", i))) + require.NoError(t, err) + } + } + + require.NotZero(t, batch.Size()) + + err = batch.Write() + require.NoError(t, err) + + lv, err := db.GetLatestVersion() + require.NoError(t, err) + require.Equal(t, uint64(1), lv) + + for i := 0; i < 100; i++ { + ok, err := db.Has(storeKey1, 1, []byte(fmt.Sprintf("key%d", i))) + require.NoError(t, err) + + if i%10 == 0 { + require.False(t, ok) + } else { + require.True(t, ok) + } + } +} + +func TestDatabase_ResetBatch(t *testing.T) { + db, err := New(t.TempDir()) + require.NoError(t, err) + + batch := db.NewBatch(1) + + for i := 0; i < 100; i++ { + err = batch.Set(storeKey1, []byte(fmt.Sprintf("key%d", i)), []byte("value")) + require.NoError(t, err) + } + + for i := 0; i < 100; i++ { + if i%10 == 0 { + err = batch.Delete(storeKey1, []byte(fmt.Sprintf("key%d", i))) + require.NoError(t, err) + } + } + + require.NotZero(t, batch.Size()) + batch.Reset() + require.NotPanics(t, func() { batch.Reset() }) + + // There is an initial cost of 12 bytes for the batch header + require.LessOrEqual(t, batch.Size(), 12) +} + +func TestDatabase_StartIterator(t *testing.T) { + db, err := New(t.TempDir()) + require.NoError(t, err) + + batch := db.NewBatch(1) + + keys := make([]string, 100) + for i := 0; i < 100; i++ { + key := fmt.Sprintf("key%d", i) + err = batch.Set(storeKey1, []byte(key), []byte("value")) + require.NoError(t, err) + + keys[i] = key + } + + sort.Strings(keys) + + err = batch.Write() + require.NoError(t, err) + + iter, err := db.NewIterator(storeKey1, 1, []byte("key0"), nil) + require.NoError(t, err) + + defer iter.Close() + + var i int + for ; iter.Valid(); iter.Next() { + require.Equal(t, []byte(keys[i]), iter.Key()) + require.Equal(t, []byte("value"), iter.Value()) + i++ + } } diff --git a/store/storage/rocksdb/iterator.go b/store/storage/rocksdb/iterator.go new file mode 100644 index 000000000000..69a2c581a2c2 --- /dev/null +++ b/store/storage/rocksdb/iterator.go @@ -0,0 +1,129 @@ +package rocksdb + +import ( + "bytes" + + "cosmossdk.io/store/v2" + "github.com/linxGnu/grocksdb" +) + +var _ store.Iterator = (*iterator)(nil) + +type iterator struct { + source *grocksdb.Iterator + prefix, start, end []byte + reverse bool + invalid bool +} + +func newRocksDBIterator(source *grocksdb.Iterator, prefix, start, end []byte, reverse bool) *iterator { + if reverse { + if end == nil { + source.SeekToLast() + } else { + source.Seek(end) + + if source.Valid() { + eoaKey := copyAndFreeSlice(source.Key()) // end or after key + if bytes.Compare(end, eoaKey) <= 0 { + source.Prev() + } + } else { + source.SeekToLast() + } + } + } else { + if start == nil { + source.SeekToFirst() + } else { + source.Seek(start) + } + } + + return &iterator{ + source: source, + prefix: prefix, + start: start, + end: end, + reverse: reverse, + invalid: false, + } +} + +func (itr *iterator) Domain() ([]byte, []byte) { + return itr.start, itr.end +} + +func (itr *iterator) Valid() bool { + // once invalid, forever invalid + if itr.invalid { + return false + } + + // if source has error, consider it invalid + if err := itr.source.Err(); err != nil { + itr.invalid = true + return false + } + + // if source is invalid, consider it invalid + if !itr.source.Valid() { + itr.invalid = true + return false + } + + // if key is at the end or past it, consider it invalid + start := itr.start + end := itr.end + key := copyAndFreeSlice(itr.source.Key()) + + if itr.reverse { + if start != nil && bytes.Compare(key, start) < 0 { + itr.invalid = true + return false + } + } else { + if end != nil && bytes.Compare(end, key) <= 0 { + itr.invalid = true + return false + } + } + + return true +} + +func (itr *iterator) Key() []byte { + itr.assertIsValid() + return copyAndFreeSlice(itr.source.Key())[len(itr.prefix):] +} + +func (itr *iterator) Value() []byte { + itr.assertIsValid() + return copyAndFreeSlice(itr.source.Value()) +} + +func (itr iterator) Next() bool { + itr.assertIsValid() + + if itr.reverse { + itr.source.Prev() + } else { + itr.source.Next() + } + + return itr.Valid() +} + +func (itr *iterator) Error() error { + return itr.source.Err() +} + +func (itr *iterator) Close() { + itr.source.Close() +} + +func (itr *iterator) assertIsValid() { + if !itr.Valid() { + panic("iterator is invalid") + } +}