diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..29f1c22
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+/target
+/Cargo.lock
+*.dot
+*.png
+.idea
+*.sdb
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..362368d
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,23 @@
+workspace = { members = ["example"] }
+
+[package]
+name = "spacedb"
+version = "0.1.0"
+edition = "2021"
+description = "A cryptographically verifiable data store and universal accumulator for the Spaces protocol."
+repository = "https://github.com/spacesprotocol/spacedb"
+license = "Apache-2.0"
+
+[dependencies]
+libc = { version = "0.2.150", optional = true }
+bincode = { version = "2.0.0-rc.3", default-features = false, features = ["alloc"] }
+hex = { version = "0.4.3", optional = true }
+
+[dependencies.sha2]
+git = "https://github.com/risc0/RustCrypto-hashes"
+tag = "sha2-v0.10.6-risczero.0"
+default-features = false
+
+[features]
+default = ["std"]
+std = ["libc", "hex", "bincode/derive", "bincode/std"]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..bd16047
--- /dev/null
+++ b/README.md
@@ -0,0 +1,138 @@
+# SpaceDB
+
+Note: this project is still under active development and should be considered experimental.
+
+SpaceDB is a cryptographically verifiable data store and universal accumulator for the [Spaces protocol](https://spacesprotocol.com). It is a Merkle-ized binary trie, as described in the [Merklix](https://blog.vermorel.com/pdf/merklix-tree-for-bitcoin-2018-07.pdf) paper and explained in detail [here](https://spacesprotocol.org/#binary-trie).
+
+## Features
+
+- Fast, portable, single-file database.
+- MVCC-based concurrency control with multi-reader/single-writer lock-free access.
+- Compact proofs of membership/non-membership for batches of elements through subtrees.
+- Subtrees act as cryptographic accumulators and can be updated independently.
+- `no_std` support for use within the RISC0 zkVM, leveraging its SHA256 acceleration.
+- The accumulator state is constant size: a single 32-byte tree root.
+
+## Usage
+
+```rust
+use spacedb::db::Database;
+
+let db = Database::open("example.sdb")?;
+
+// Insert some data
+let mut tx = db.begin_write()?;
+for i in 0..100 {
+    let key = format!("key{}", i);
+    let value = format!("value{}", i);
+    tx.insert(db.hash(key.as_bytes()), value.into_bytes())?;
+}
+tx.commit()?;
+
+let mut snapshot = db.begin_read()?;
+println!("Tree root: {}", hex::encode(snapshot.root()?));
+
+// Prove a subset of the keys
+let keys_to_prove: Vec<_> = (0..10)
+    .map(|i| format!("key{}", i))
+    // prove exclusion of some other keys
+    .chain((0..5).map(|i| format!("other{}", i)))
+    .map(|key| db.hash(key.as_bytes()))
+    .collect();
+
+// Reveal the relevant nodes needed to prove the specified set of keys
+let mut subtree = snapshot.prove_all(&keys_to_prove)?;
+
+// Will have the exact same root as the snapshot
+println!("Subtree root: {}", hex::encode(subtree.root().unwrap()));
+
+// Inclusion and exclusion proofs
+assert!(subtree.contains(&db.hash("key0".as_bytes())).unwrap());
+assert!(!subtree.contains(&db.hash("other0".as_bytes())).unwrap());
+
+// Proving exclusion of "other100" fails since we didn't reveal
+// the relevant branches needed to traverse its path in this subtree
+assert!(subtree.contains(&db.hash("other100".as_bytes())).is_err());
+```
+
+## Subtrees
+
+Subtrees can function as cryptographic accumulators, allowing clients to verify and update their state without keeping a database. A client tracks only the 32-byte root; after modifying a revealed subtree, the subtree's new root becomes the new accumulator state.
+
+```rust
+use spacedb::subtree::ValueOrHash;
+
+// Client maintains a 32-byte tree root
+let mut accumulator_root = snapshot.root()?;
+assert_eq!(accumulator_root, subtree.root().unwrap(), "Roots must match");
+
+// Update leaves
+for (_key, value) in subtree.iter_mut() {
+    *value = "new value".to_string().into_bytes();
+}
+
+// Inserting a non-existent key (must be provably absent)
+let key = subtree.hash("other0".as_bytes());
+subtree.insert(key, ValueOrHash::Value(b"new value".to_vec())).unwrap();
+
+// Updating the accumulator root
+accumulator_root = subtree.root().unwrap();
+```
+
+## Using in RISC0 zkVM
+
+Subtrees work in `no_std` environments and use the SHA256 accelerator when running inside the RISC0 zkVM.
+
+```toml
+[dependencies]
+spacedb = { version = "0.1", default-features = false }
+```
+
+## Key Iteration
+
+Iterate over all keys in a given snapshot:
+
+```rust
+let db = Database::open("my.sdb")?;
+let snapshot = db.begin_read()?;
+
+for (key, value) in snapshot.iter().filter_map(Result::ok) {
+    // do something ...
+}
+```
+
+## Snapshot Iteration
+
+Iterate over all snapshots:
+
+```rust
+let db = Database::open("my.sdb")?;
+
+for mut snapshot in db.iter().filter_map(Result::ok) {
+    let root = snapshot.root()?;
+    println!("Snapshot Root: {}", hex::encode(root));
+}
+```
+
+## Prior Art
+
+Merkle-ized tries, including variations like Patricia tries and Merkle prefix trees, are foundational structures used in numerous projects and cryptocurrencies. Other libraries implementing some form of Merkle-ized binary trie include [liburkel](https://github.com/chjj/liburkel), which this library initially drew some inspiration from (SpaceDB is generally around ~20% faster), and [multiproof](https://github.com/gballet/multiproof-rs/tree/master). However, they either lack memory safety, lack core features such as the subtrees/accumulators needed for the Spaces protocol, or are unmaintained. Other popular cryptographically verifiable data stores include [Trillian](https://github.com/google/trillian), used for [Certificate Transparency](https://www.certificate-transparency.org/).
+
+## License
+
+This project is licensed under the [Apache 2.0](LICENSE) license.
\ No newline at end of file
diff --git a/example/Cargo.toml b/example/Cargo.toml
new file mode 100644
index 0000000..7a3d47e
--- /dev/null
+++ b/example/Cargo.toml
@@ -0,0 +1,11 @@
+[package]
+name = "example"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+spacedb = {version = "*", path = ".."
} +hex = "0.4.3" + diff --git a/example/src/main.rs b/example/src/main.rs new file mode 100644 index 0000000..d696f8a --- /dev/null +++ b/example/src/main.rs @@ -0,0 +1,43 @@ +use spacedb::db::Database; + +fn main() -> Result<(), std::io::Error> { + let db = Database::memory()?; + + // Insert some data + let mut tx = db.begin_write()?; + for i in 0..100 { + let key = format!("key{}", i); + let value = format!("value{}", i); + tx.insert(db.hash(key.as_bytes()), value.into_bytes())?; + } + tx.commit()?; + + // Get the committed snapshot + let mut snapshot = db.begin_read()?; + println!("Tree root: {}", hex::encode(snapshot.root()?)); + + // Prove a subset of the keys + let keys_to_prove: Vec<_> = (0..10) + .map(|i| format!("key{}", i)) + // prove exclusion of some other keys + .chain((0..5).map(|i| format!("other{}", i))) + .map(|key| db.hash(key.as_bytes())) + .collect(); + + // reveal the relevant nodes needed to prove the specified set of keys + let subtree = snapshot.prove_all(&keys_to_prove)?; + + // Will have the exact same root as the snapshot + println!("Subtree root: {}", hex::encode(subtree.root().unwrap())); + + // Prove inclusion + assert!(subtree.contains(&db.hash("key0".as_bytes())).unwrap()); + + // Prove exclusion + assert!(!subtree.contains(&db.hash("other0".as_bytes())).unwrap()); + + // We don't have enough data to prove key "other100" is not in the subtree + // as the relevant branches needed to prove it are not included + assert!(subtree.contains(&db.hash("other100".as_bytes())).is_err()); + Ok(()) +} diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..dbc5e43 --- /dev/null +++ b/src/db.rs @@ -0,0 +1,304 @@ +use crate::{ + fs::{FileBackend, StorageBackend}, + node::NodeInner, + tx::{ReadTransaction, WriteTransaction}, + Configuration, Hash, NodeHasher, Sha256Hasher, +}; +use bincode::{config, error::DecodeError, Decode, Encode}; +use sha2::{Digest as _, Sha256}; +use std::{ + fs::OpenOptions, + io, + sync::{Arc, Mutex}, +}; + +const HEADER_MAGIC: [u8; 9] = [b's', b'p', b'a', b'c', b'e', b':', b'/', b'/', b'.']; +pub(crate) const PAGE_SIZE: usize = 4096; + +#[derive(Debug, Encode, Decode, PartialEq, Eq)] +pub struct DatabaseHeader { + pub magic: [u8; 9], + pub version: u8, + pub savepoint: SavePoint, +} + +pub struct Database { + pub(crate) header: Arc>, + pub(crate) file: Box, + pub config: Configuration, +} + +#[derive(Copy, Clone, Encode, Decode, Debug, Eq, PartialEq, Hash)] +pub struct SavePoint { + pub root: Record, + pub previous_save_point: Record, +} + +#[derive(Copy, Clone, Encode, Decode, Debug, Eq, PartialEq, Hash)] +pub struct Record { + pub offset: u64, + pub size: u32, +} + +pub const EMPTY_RECORD: Record = Record { offset: 0, size: 0 }; + +impl DatabaseHeader { + pub fn new() -> Self { + Self { + magic: HEADER_MAGIC, + version: 0, + savepoint: SavePoint { + root: EMPTY_RECORD, + previous_save_point: EMPTY_RECORD, + }, + } + } + + pub(crate) fn to_bytes(&self) -> Vec { + let config = config::standard() + .with_fixed_int_encoding() + .with_little_endian(); + let mut raw = bincode::encode_to_vec(self, config).unwrap(); + // add 24 bytes padding + 4 bytes checksum + raw.extend_from_slice(&[0; 26]); + let mut hasher = Sha256::new(); + hasher.update(&raw); + let checksum = hasher.finalize(); + raw.extend_from_slice(&checksum[..4]); + raw + } + + fn from_bytes(bytes: &[u8]) -> Result { + // calc checksum + let mut hasher = Sha256::new(); + hasher.update(&bytes[..60]); + let checksum = hasher.finalize(); + + if bytes[60..64] != checksum[..4] { + 
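+            // Header is corrupt or was torn mid-write; recover_header falls back to the backup slot.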
+            return Err(DecodeError::Other("Checksum mismatch"));
+        }
+
+        let config = config::standard()
+            .with_fixed_int_encoding()
+            .with_little_endian();
+        let (h, _) = bincode::decode_from_slice(bytes, config)?;
+
+        Ok(h)
+    }
+
+    pub(crate) fn len(&self) -> u64 {
+        if self.savepoint.is_empty() {
+            return (PAGE_SIZE * 2) as u64;
+        }
+
+        // Round the save point's end offset up to the next page boundary
+        let save_point_len = self.savepoint.len();
+        (save_point_len + PAGE_SIZE as u64 - 1) / PAGE_SIZE as u64 * PAGE_SIZE as u64
+    }
+}
+
+impl Database<Sha256Hasher> {
+    pub fn open(path: &str) -> Result<Self, io::Error> {
+        let file = OpenOptions::new()
+            .read(true)
+            .write(true)
+            .create(true)
+            .open(path)?;
+        let config = Configuration::standard();
+        Self::new(Box::new(FileBackend::new(file)?), config)
+    }
+
+    pub fn memory() -> Result<Self, io::Error> {
+        let file = Box::new(crate::fs::MemoryBackend::new());
+        let config = Configuration::standard();
+        Self::new(file, config)
+    }
+}
+
+impl<H: NodeHasher> Database<H> {
+    pub fn new(file: Box<dyn StorageBackend>, config: Configuration<H>) -> Result<Self, io::Error> {
+        let header;
+        let mut has_header = false;
+
+        if file.len()? > 0 {
+            let result = Self::recover_header(&file)?;
+            header = result.0;
+            has_header = true;
+        } else {
+            header = DatabaseHeader::new();
+            let bytes = header.to_bytes();
+            file.set_len(bytes.len() as u64)?;
+            file.write(0, &bytes)?;
+            file.sync_data()?;
+        }
+
+        let db = Self {
+            header: Arc::new(Mutex::new(header)),
+            file,
+            config,
+        };
+
+        if !has_header {
+            db.write_header(&db.header.lock().unwrap())?;
+        }
+
+        Ok(db)
+    }
+
+    #[inline(always)]
+    pub fn hash(&self, data: &[u8]) -> Hash {
+        H::hash(data)
+    }
+
+    pub(crate) fn recover_header(
+        file: &Box<dyn StorageBackend>,
+    ) -> Result<(DatabaseHeader, bool), io::Error> {
+        // Attempt to read from slot 0
+        let bytes = file.read(0, 64)?;
+        if let Ok(header) = DatabaseHeader::from_bytes(&bytes) {
+            return Ok((header, false));
+        }
+
+        // Didn't work, try slot 1
+        let bytes = file.read(PAGE_SIZE as u64, 64)?;
+        let header = DatabaseHeader::from_bytes(&bytes)
+            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
+
+        Ok((header, true))
+    }
+
+    pub(crate) fn write_header(&self, hdr: &DatabaseHeader) -> Result<(), io::Error> {
+        // The database reserves the first two pages for metadata:
+        // page slot 0 contains the header,
+        // page slot 1 contains a backup of the header
+        if self.file.len()?
< PAGE_SIZE as u64 * 2 { + self.file.set_len(PAGE_SIZE as u64 * 2)?; + } + + let mut bytes = hdr.to_bytes(); + assert_eq!(bytes.len(), 64); + + bytes.extend_from_slice(&[0; PAGE_SIZE - 64]); + + self.file.write(0, &bytes)?; + self.file.sync_data()?; + + // write backup header + self.file.write(PAGE_SIZE as u64, &bytes)?; + self.file.sync_data()?; + Ok(()) + } + + fn read_save_point(&self, record: Record) -> Result { + let raw = self.file.read(record.offset, record.size as usize)?; + let (save_point, _) = bincode::decode_from_slice(&raw, config::standard()) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + Ok(save_point) + } + + pub fn begin_write(&self) -> Result, io::Error> { + Ok(WriteTransaction::new(self)) + } + + pub fn begin_read(&self) -> Result, io::Error> { + let result = Self::recover_header(&self.file)?; + // Use the stored configuration + Ok(ReadTransaction::new(self, result.0.savepoint)) + } + + pub(crate) fn load_node(&self, id: Record) -> Result { + let raw = self.file.read(id.offset, id.size as usize)?; + let config = config::standard(); + let (inner, _): (NodeInner, usize) = bincode::decode_from_slice(&raw, config) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(inner) + } + + pub fn iter(&self) -> SnapshotIterator { + SnapshotIterator::new(self) + } +} + +pub struct SnapshotIterator<'db, H: NodeHasher> { + current: Option, + started: bool, + db: &'db Database, +} + +impl<'db, H: NodeHasher> SnapshotIterator<'db, H> { + pub fn new(db: &'db Database) -> Self { + SnapshotIterator { + current: None, + started: false, + db, + } + } + + fn prev(&mut self) -> Result, io::Error> { + if !self.started { + let savepoint = Database::::recover_header(&self.db.file)?.0.savepoint; + self.current = if !savepoint.is_empty() { + Some(savepoint) + } else { + None + }; + self.started = true; + } + if self.current.is_none() { + return Ok(None); + } + + let savepoint = self.current.take().unwrap(); + if savepoint.is_empty() { + return Ok(None); + } + if savepoint.is_initial() { + return Ok(Some(savepoint)); + } + self.current = Some(self.db.read_save_point(savepoint.previous_save_point)?); + Ok(Some(savepoint)) + } +} + +impl<'db, H: NodeHasher> Iterator for SnapshotIterator<'db, H> { + type Item = Result, io::Error>; + fn next(&mut self) -> Option { + match self.prev() { + Ok(Some(prev_savepoint)) => Some(Ok(ReadTransaction::new(self.db, prev_savepoint))), + Ok(None) => None, + Err(e) => Some(Err(e)), + } + } +} + +impl SavePoint { + #[inline] + pub fn is_initial(&self) -> bool { + self.previous_save_point == EMPTY_RECORD + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.root == EMPTY_RECORD && self.previous_save_point == EMPTY_RECORD + } + + #[inline] + pub fn len(&self) -> u64 { + return self.root.size as u64 + self.root.offset; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_header() { + let header = DatabaseHeader::new(); + let bytes = header.to_bytes(); + let header2 = DatabaseHeader::from_bytes(&bytes).unwrap(); + assert_eq!(header, header2); + + assert_eq!(bytes.len(), 64); + } +} diff --git a/src/fs.rs b/src/fs.rs new file mode 100644 index 0000000..f886cf5 --- /dev/null +++ b/src/fs.rs @@ -0,0 +1,454 @@ +// Uses flock on Unix and LockFile on Windows to ensure exclusive access to the database file. 
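+// The lock is taken without blocking, so opening an already-open database fails fast with `WouldBlock`.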
+// based on https://github.com/cberner/redb/tree/master/src/tree_store/page_store/file_backend +use crate::{ + db::{Record, SavePoint, EMPTY_RECORD, PAGE_SIZE}, + node::Node, +}; +use bincode::config; +use std::{ + fs::File, + io, + ops::{Index, IndexMut, RangeFrom}, + sync::*, +}; + +pub trait StorageBackend { + fn len(&self) -> Result; + fn set_len(&self, len: u64) -> Result<(), io::Error>; + fn read(&self, offset: u64, len: usize) -> Result, io::Error>; + fn sync_data(&self) -> Result<(), io::Error>; + fn write(&self, offset: u64, data: &[u8]) -> Result<(), io::Error>; +} + +#[derive(Debug, Default)] +pub struct MemoryBackend(RwLock>); + +#[cfg(any(unix))] +use std::os::fd::AsRawFd; +use std::os::unix::fs::FileExt; + +#[cfg(windows)] +use std::os::windows::{ + fs::FileExt, + io::{AsRawHandle, RawHandle}, +}; + +#[cfg(windows)] +const ERROR_LOCK_VIOLATION: i32 = 0x21; + +#[cfg(windows)] +const ERROR_IO_PENDING: i32 = 997; + +#[cfg(windows)] +extern "system" { + /// + fn LockFile( + file: RawHandle, + offset_low: u32, + offset_high: u32, + length_low: u32, + length_high: u32, + ) -> i32; + + /// + fn UnlockFile( + file: RawHandle, + offset_low: u32, + offset_high: u32, + length_low: u32, + length_high: u32, + ) -> i32; +} + +#[cfg(not(any(windows, unix)))] +use std::sync::Mutex; + +#[cfg(any(windows, unix))] +pub struct FileBackend { + file: File, +} + +#[cfg(any(unix))] +impl FileBackend { + pub fn new(file: File) -> Result { + let fd = file.as_raw_fd(); + let result = unsafe { libc::flock(fd, libc::LOCK_EX | libc::LOCK_NB) }; + if result != 0 { + let err = io::Error::last_os_error(); + if err.kind() == io::ErrorKind::WouldBlock { + Err(io::Error::new( + io::ErrorKind::WouldBlock, + "Database already open", + )) + } else { + Err(err.into()) + } + } else { + Ok(Self { file }) + } + } +} + +#[cfg(any(unix))] +impl Drop for FileBackend { + fn drop(&mut self) { + unsafe { libc::flock(self.file.as_raw_fd(), libc::LOCK_UN) }; + } +} + +#[cfg(any(unix))] +impl StorageBackend for FileBackend { + fn len(&self) -> Result { + Ok(self.file.metadata()?.len()) + } + + fn set_len(&self, len: u64) -> Result<(), io::Error> { + self.file.set_len(len) + } + + fn read(&self, offset: u64, len: usize) -> Result, io::Error> { + let mut buffer = vec![0; len]; + self.file.read_exact_at(&mut buffer, offset)?; + Ok(buffer) + } + + fn sync_data(&self) -> Result<(), io::Error> { + self.file.sync_data() + } + + fn write(&self, offset: u64, data: &[u8]) -> Result<(), io::Error> { + self.file.write_all_at(data, offset) + } +} + +#[cfg(windows)] +impl FileBackend { + pub fn new(file: File) -> Result { + let handle = file.as_raw_handle(); + unsafe { + let result = LockFile(handle, 0, 0, u32::MAX, u32::MAX); + + if result == 0 { + let err = io::Error::last_os_error(); + return if err.raw_os_error() == Some(ERROR_IO_PENDING) + || err.raw_os_error() == Some(ERROR_LOCK_VIOLATION) + { + Err(io::Error::new( + io::ErrorKind::WouldBlock, + "Database already open", + )) + } else { + Err(err.into()) + }; + } + }; + + Ok(Self { file }) + } +} + +#[cfg(windows)] +impl Drop for FileBackend { + fn drop(&mut self) { + unsafe { UnlockFile(self.file.as_raw_handle(), 0, 0, u32::MAX, u32::MAX) }; + } +} + +#[cfg(windows)] +impl StorageBackend for FileBackend { + fn set_len(&self, len: u64) -> Result<(), io::Error> { + self.file.set_len(len) + } + + fn len(&self) -> Result { + Ok(self.file.metadata()?.len()) + } + + fn read(&self, mut offset: u64, len: usize) -> Result, io::Error> { + let mut buffer = vec![0; len]; + let mut data_offset = 
0; + while data_offset < buffer.len() { + let read = self.file.seek_read(&mut buffer[data_offset..], offset)?; + offset += read as u64; + data_offset += read; + } + Ok(buffer) + } + + fn sync_data(&self) -> Result<(), io::Error> { + self.file.sync_data() + } + + fn write(&self, mut offset: u64, data: &[u8]) -> Result<(), io::Error> { + let mut data_offset = 0; + while data_offset < data.len() { + let written = self.file.seek_write(&data[data_offset..], offset)?; + offset += written as u64; + data_offset += written; + } + Ok(()) + } +} + +// We use a mutex based lock on platforms that don't support flock +#[cfg(not(any(windows, unix)))] +struct FileBackend { + file: Mutex, +} + +#[cfg(not(any(windows, unix)))] +impl FileBackend { + fn new(file: File) -> Result { + Ok(Self { + file: Mutex::new(file), + }) + } +} + +#[cfg(not(any(windows, unix)))] +impl StorageBackend for FileBackend { + fn set_len(&self, len: u64) -> Result<(), io::Error> { + self.file.lock().unwrap().set_len(len) + } + + fn len(&self) -> Result { + Ok(self.file.lock().unwrap().metadata()?.len()) + } + + fn sync_data(&self, eventual: bool) -> Result<(), io::Error> { + self.file.lock().unwrap().sync_data() + } + + fn write(&self, offset: u64, data: &[u8]) -> Result<(), io::Error> { + let file = self.file.lock().unwrap(); + file.seek(SeekFrom::Start(offset))?; + file.write_all(data) + } + + fn read(&self, offset: u64, len: usize) -> Result, io::Error> { + let mut result = vec![0; len]; + let file = self.file.lock().unwrap(); + file.seek(SeekFrom::Start(offset))?; + file.read_exact(&mut result)?; + Ok(result) + } +} + +impl MemoryBackend { + fn out_of_range() -> io::Error { + io::Error::new(io::ErrorKind::InvalidInput, "Index out-of-range.") + } +} + +impl MemoryBackend { + /// Creates a new, empty memory backend. + pub fn new() -> Self { + Self::default() + } + + /// Gets a read guard for this backend. + fn read(&self) -> RwLockReadGuard<'_, Vec> { + self.0.read().expect("Could not acquire read lock.") + } + + /// Gets a write guard for this backend. 
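+    /// Panics if the lock was poisoned by a thread that panicked while holding it.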
+ fn write(&self) -> RwLockWriteGuard<'_, Vec> { + self.0.write().expect("Could not acquire write lock.") + } +} + +impl StorageBackend for MemoryBackend { + fn len(&self) -> Result { + Ok(self.read().len() as u64) + } + + fn set_len(&self, len: u64) -> Result<(), io::Error> { + let mut guard = self.write(); + let len = usize::try_from(len).map_err(|_| Self::out_of_range())?; + if guard.len() < len { + let additional = len - guard.len(); + guard.reserve(additional); + for _ in 0..additional { + guard.push(0); + } + } else { + guard.truncate(len); + } + + Ok(()) + } + + fn read(&self, offset: u64, len: usize) -> Result, io::Error> { + let guard = self.read(); + let offset = usize::try_from(offset).map_err(|_| Self::out_of_range())?; + if offset + len <= guard.len() { + Ok(guard[offset..offset + len].to_owned()) + } else { + Err(Self::out_of_range()) + } + } + + fn sync_data(&self) -> Result<(), io::Error> { + Ok(()) + } + + fn write(&self, offset: u64, data: &[u8]) -> Result<(), io::Error> { + let mut guard = self.write(); + let offset = usize::try_from(offset).map_err(|_| Self::out_of_range())?; + if offset + data.len() <= guard.len() { + guard[offset..offset + data.len()].copy_from_slice(data); + Ok(()) + } else { + Err(Self::out_of_range()) + } + } +} + +pub struct WriteBuffer<'file, const SIZE: usize> { + file: &'file Box, + buffer: [u8; SIZE], + len: usize, + file_len: u64, +} + +impl<'file, const SIZE: usize> WriteBuffer<'file, SIZE> { + pub(crate) fn new(file: &'file Box, file_len: u64) -> Self { + Self { + file, + buffer: [0u8; SIZE], + len: 0, + file_len, + } + } + + fn remaining(&self) -> usize { + SIZE - self.len + } + + fn tail(&mut self) -> &mut [u8] { + &mut self.buffer[self.len..] + } + + pub(crate) fn flush(&mut self) -> Result<(), io::Error> { + if self.len == 0 { + return Ok(()); + } + + let aligned_len = self.len - (self.len % PAGE_SIZE); + + // Write all full pages in one go, if any + if aligned_len > 0 { + self.file.set_len(self.file_len + aligned_len as u64)?; + self.file + .write(self.file_len, &self.buffer[0..aligned_len])?; + self.file_len += aligned_len as u64; + } + + // Handle the remaining data and pad to a full page + if aligned_len < self.len { + let remaining_len = self.len - aligned_len; + self.buffer.copy_within(aligned_len..self.len, 0); + self.buffer[remaining_len..PAGE_SIZE].fill(0); + + self.file.set_len(self.file_len + PAGE_SIZE as u64)?; + self.file.write(self.file_len, &self.buffer[0..PAGE_SIZE])?; + self.file_len += PAGE_SIZE as u64; + } + + self.len = 0; + Ok(()) + } + + pub fn write_save_point(&mut self, save_point: &SavePoint) -> Result { + let config = config::standard(); + let size = + bincode::encode_into_slice(save_point, &mut self.tail(), config).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Failed to encode save point: {}", e), + ) + })?; + let record = Record { + offset: self.file_len + self.len as u64, + size: size as u32, + }; + + self.len += size; + Ok(record) + } + + pub fn write_node(&mut self, node: &mut Node) -> Result { + if self.remaining() < node.mem_size() { + self.flush()?; + } + + let config = config::standard(); + + if node.inner.is_none() { + if node.id != EMPTY_RECORD { + return Ok(node.id); + } + return Err(io::Error::new(io::ErrorKind::NotFound, "Node not found")); + } + + let size = { + let inner = node.inner.as_mut().unwrap(); + bincode::encode_into_slice(inner, &mut self.tail(), config).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Failed to encode node: {}", e), + ) + })? 
+ }; + + let node_id = Record { + offset: self.file_len + self.len as u64, + size: size as u32, + }; + + self.len += size; + Ok(node_id) + } +} + +impl<'file, const SIZE: usize> Index for WriteBuffer<'file, SIZE> { + type Output = u8; + + fn index(&self, index: usize) -> &Self::Output { + &self.buffer[index] + } +} + +impl<'file, const SIZE: usize> Index> for WriteBuffer<'file, SIZE> { + type Output = [u8]; + + fn index(&self, range: std::ops::Range) -> &Self::Output { + &self.buffer[range] + } +} + +impl<'file, const SIZE: usize> IndexMut for WriteBuffer<'file, SIZE> { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.buffer[index] + } +} + +impl<'file, const SIZE: usize> IndexMut> for WriteBuffer<'file, SIZE> { + fn index_mut(&mut self, range: std::ops::Range) -> &mut Self::Output { + &mut self.buffer[range] + } +} + +impl<'file, const SIZE: usize> Index> for WriteBuffer<'file, SIZE> { + type Output = [u8]; + + fn index(&self, range: RangeFrom) -> &Self::Output { + &self.buffer[range] + } +} + +impl<'file, const SIZE: usize> IndexMut> for WriteBuffer<'file, SIZE> { + fn index_mut(&mut self, range: RangeFrom) -> &mut Self::Output { + &mut self.buffer[range] + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..4470e69 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,97 @@ +#![cfg_attr(not(feature = "std"), no_std)] +extern crate alloc; + +pub mod path; +pub mod subtree; + +#[cfg(feature = "std")] +pub mod node; + +#[cfg(feature = "std")] +pub mod db; + +#[cfg(feature = "std")] +pub mod tx; + +#[cfg(feature = "std")] +pub mod fs; + +#[cfg(feature = "debug")] +pub mod debug; + +#[cfg(feature = "std")] +pub(crate) const ZERO_HASH: Hash = [0; 32]; + +use core::marker::PhantomData; +use sha2::{Digest as _, Sha256}; + +pub type Hash = [u8; 32]; + +const LEAF_TAG: u8 = 0x00; +const INTERNAL_TAG: u8 = 0x01; + +#[derive(Clone)] +pub struct Sha256Hasher; + +const DEFAULT_CACHE_SIZE: usize = 1024 * 1024 * 1024; /* 1GB */ + +#[derive(Clone, Debug)] +pub struct Configuration { + pub cache_size: usize, + _marker: PhantomData, +} + +impl Configuration { + pub fn new() -> Self { + Self { + cache_size: DEFAULT_CACHE_SIZE, + _marker: PhantomData, + } + } + + pub fn with_cache_size(mut self, size: usize) -> Self { + self.cache_size = size; + self + } +} + +pub trait NodeHasher: Clone { + fn hash(data: &[u8]) -> Hash; + fn hash_leaf(key: &[u8], value_hash: &[u8]) -> Hash; + fn hash_internal(prefix: &[u8], left: &[u8], right: &[u8]) -> Hash; +} + +impl Configuration { + pub fn standard() -> Self { + Self::new().with_cache_size(DEFAULT_CACHE_SIZE) + } +} + +impl NodeHasher for Sha256Hasher { + fn hash(data: &[u8]) -> Hash { + let mut hasher = Sha256::new(); + hasher.update(data); + hasher.finalize().as_slice().try_into().unwrap() + } + + fn hash_leaf(key: &[u8], value_hash: &[u8]) -> Hash { + let mut hasher = Sha256::new(); + hasher.update([LEAF_TAG]); + hasher.update(&key); + hasher.update(&value_hash); + hasher.finalize().as_slice().try_into().unwrap() + } + + fn hash_internal(prefix: &[u8], left: &[u8], right: &[u8]) -> Hash { + let mut hasher = Sha256::new(); + + hasher.update([INTERNAL_TAG]); + let bit_len = prefix[0]; + hasher.update([bit_len]); + hasher.update(&prefix[1..]); + hasher.update(left); + hasher.update(right); + + hasher.finalize().as_slice().try_into().unwrap() + } +} diff --git a/src/node.rs b/src/node.rs new file mode 100644 index 0000000..9b9b58b --- /dev/null +++ b/src/node.rs @@ -0,0 +1,175 @@ +use crate::{ + db::{Record, EMPTY_RECORD}, + 
path::{Path, PathSegment, PathSegmentInner}, + Hash, +}; +use bincode::{ + de::Decoder, + enc::Encoder, + error::{DecodeError, EncodeError}, + impl_borrow_decode, Decode, Encode, +}; + +#[derive(Clone, Debug)] +pub struct Node { + pub id: Record, + pub inner: Option, + pub(crate) hash_cache: Option, +} + +#[derive(Clone, Debug)] +pub enum NodeInner { + Leaf { + key: Path, + value: Vec, + }, + Internal { + prefix: PathSegment, + left: Box, + right: Box, + }, +} + +impl Node { + #[inline] + pub fn from_internal( + prefix: PathSegment, + left: Box, + right: Box, + ) -> Self { + Self { + id: EMPTY_RECORD, + inner: Some(NodeInner::Internal { + prefix, + left, + right, + }), + hash_cache: None, + } + } + + #[inline] + pub fn from_leaf(key: Path, value: Vec) -> Self { + Self { + id: EMPTY_RECORD, + inner: Some(NodeInner::Leaf { key, value }), + hash_cache: None, + } + } + + #[inline] + pub(crate) fn from_id(id: Record) -> Self { + Self { + id, + inner: None, + hash_cache: None, + } + } + + #[inline] + pub fn mem_size(&self) -> usize { + let base_size = std::mem::size_of_val(&self); + let inner_size = std::mem::size_of_val(&self.inner) + + match &self.inner { + Some(NodeInner::Leaf { value, .. }) => value.capacity(), + Some(NodeInner::Internal { left, right, .. }) => left.mem_size() + right.mem_size(), + None => 0, + }; + + base_size + + inner_size + + std::mem::size_of_val(&self.hash_cache) + + std::mem::size_of_val(&self.id) + - 1 + } +} + +impl Encode for NodeInner { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + match self { + NodeInner::Leaf { key, value } => { + Encode::encode(&0u8, encoder)?; + Encode::encode(&key.0, encoder)?; + Encode::encode(value, encoder)?; + } + NodeInner::Internal { + prefix, + left, + right, + } => { + Encode::encode(&1u8, encoder)?; + Encode::encode(&prefix.0, encoder)?; + Encode::encode(left, encoder)?; + Encode::encode(right, encoder)?; + } + } + Ok(()) + } +} + +impl Decode for NodeInner { + fn decode(decoder: &mut D) -> Result { + let tag: u8 = Decode::decode(decoder)?; + match tag { + 0 => { + let key = Path(Decode::decode(decoder)?); + let value = Decode::decode(decoder)?; + Ok(NodeInner::Leaf { key, value }) + } + 1 => { + let seg: [u8; 33] = Decode::decode(decoder)?; + let prefix = PathSegment(seg); + let left: Node = Decode::decode(decoder)?; + let right: Node = Decode::decode(decoder)?; + Ok(NodeInner::Internal { + prefix, + left: Box::new(left), + right: Box::new(right), + }) + } + _ => Err(DecodeError::Other("Invalid tag")), + } + } +} + +impl<'a> Encode for &'a mut NodeInner { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + match self { + NodeInner::Leaf { key, value } => { + Encode::encode(&0u8, encoder)?; + Encode::encode(&key.0, encoder)?; + Encode::encode(value, encoder)?; + } + NodeInner::Internal { + prefix, + left, + right, + } => { + Encode::encode(&1u8, encoder)?; + Encode::encode(&prefix.0, encoder)?; + Encode::encode(left, encoder)?; + Encode::encode(right, encoder)?; + } + } + Ok(()) + } +} + +impl Encode for Node { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + if self.id == EMPTY_RECORD { + return Err(EncodeError::Other("Node id is zero")); + } + Encode::encode(&self.id, encoder) + } +} + +impl Decode for Node { + fn decode(decoder: &mut D) -> Result { + let id = Decode::decode(decoder)?; + Ok(Node::from_id(id)) + } +} + +impl_borrow_decode!(Node); +impl_borrow_decode!(NodeInner); diff --git a/src/path.rs b/src/path.rs new file mode 100644 index 0000000..e7e6cf8 --- /dev/null +++ 
b/src/path.rs @@ -0,0 +1,256 @@ +pub(crate) type PathSegmentInner = [u8; 33]; + +const BIT_MASK: [u8; 8] = [128, 64, 32, 16, 8, 4, 2, 1]; + +#[derive(PartialEq)] +pub enum Direction { + Left, + Right, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Path(pub T); + +#[derive(Clone, Copy, Debug)] +pub struct PathSegment(pub T); + +pub trait PathUtils { + /// Returns the direction at a specified index within a path in the binary trie. + /// + /// # Parameters + /// - `index`: The index of the bit within the path. + /// + /// # Returns + /// - `Direction::Right` if the bit at the specified index is set. + /// - `Direction::Left` if the bit at the specified index is unset. + /// + /// # Note + /// - The function uses MSB ordering, where the most significant bit is at index 0. + fn direction(&self, index: usize) -> Direction; + + /// Returns the first point of divergence relative to `start` . + /// The comparison begins at `start` index within `self` and proceeds until the end of the shortest segment. + /// + /// # Parameters + /// - `start`: The index within `self` from which the comparison should start. Must be within the bounds of `self`. + /// - `segment`: The segment to compare against `self`. Must implement `BitLength` and `PathUtils`. + /// + /// # Panics + /// - If `start` + `segment.bit_len()` is greater than the length of `self`. + /// + /// # Returns + /// - `None` if the paths are identical up to the length of the shortest segment. + /// - `Some(index)` where `index` is the first point of divergence relative to `start` in `self`. + /// + /// # Note + /// - The comparison stops at the end of the shortest segment. + /// - The function uses MSB ordering, where the most significant bit is at index 0. + fn split_point( + &self, + segment_start: usize, + segment: S, + ) -> Option; +} + +impl> PathSegment { + /// Copies all bits from `src` into `self` starting and ending at the specified bit indices. + /// `self` must be able to accommodate the copied bits within its length. + /// + /// # Parameters + /// - `src`: The source path from which bits will be copied. Must implement `BitLength` and `PathUtils`. + /// - `start`: The starting bit index in the `src` path (inclusive) + /// - `end`: The ending bit index in `src` path (exclusive) + /// + /// # Panics + /// - The function panics if `start` > `end`. + pub fn copy(&mut self, src: A, start: usize, end: usize) { + if start == end { + return; + } + assert!(start < end, "start {} must be less than end {}", start, end); + let bit_len = end - start; + self.set_len(bit_len); + + let (src, src_start, start_bit) = (src.inner(), start / 8, start % 8); + let (dst_slice, dst_end_idx, dst_end_bit) = + (self.as_mut_inner(), (bit_len - 1) / 8, bit_len % 8); + + // If aligned on byte boundary, use direct slice copy. + if start_bit == 0 { + dst_slice[..dst_end_idx + 1] + .copy_from_slice(&src[src_start..src_start + dst_end_idx + 1]); + } else { + // For non-aligned bits, copy bits with shifting. + for (i, j) in (src_start..src_start + dst_end_idx + 1).zip(0..) { + dst_slice[j] = src[i] << start_bit; + if i + 1 < src.len() { + dst_slice[j] |= src[i + 1] >> (8 - start_bit); + } + } + } + + // Handle the case where the last byte in dst_slice is copied + // but not all bits are needed. + if dst_end_bit != 0 { + // zero out the unused bits. 
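+            // e.g. with dst_end_bit == 3, the mask is 0b1110_0000: the top three bits survive.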
+            dst_slice[dst_end_idx] &= 0xFF << (8 - dst_end_bit);
+        }
+    }
+
+    #[inline(always)]
+    pub fn set_len(&mut self, len: usize) {
+        assert!(len <= 255, "PathSegment length must be <= 255");
+        self.0.as_mut()[0] = len as u8;
+    }
+
+    #[inline(always)]
+    pub fn as_mut_inner(&mut self) -> &mut [u8] {
+        &mut self.0.as_mut()[1..]
+    }
+}
+
+impl<T: BitLength> PathUtils for T {
+    #[inline(always)]
+    fn direction(&self, index: usize) -> Direction {
+        if self.inner()[index / 8] & BIT_MASK[index % 8] != 0 {
+            return Direction::Right;
+        }
+        Direction::Left
+    }
+
+    fn split_point<S: BitLength>(&self, start: usize, b: S) -> Option<usize> {
+        let max_bit_len = core::cmp::min(self.bit_len(), b.bit_len());
+        let (src_start_byte, src_start_bit, seg_end_byte) =
+            (start / 8, start % 8, (max_bit_len + 7) / 8);
+        let mut count = 0;
+
+        // Aligned on byte boundary
+        if src_start_bit == 0 {
+            let (a, b) = (&self.inner()[src_start_byte..], &b.inner()[..seg_end_byte]);
+            for (a_byte, b_byte) in a.iter().zip(b.iter()) {
+                if *a_byte != *b_byte {
+                    count += (a_byte ^ b_byte).leading_zeros();
+                    break;
+                }
+                count += 8;
+            }
+        } else {
+            // Non-aligned: we need to align self and then compare (b is already aligned)
+            let (a, b) = (&self.inner()[src_start_byte..], &b.inner()[..seg_end_byte]);
+            for (i, b_byte) in b.iter().enumerate() {
+                // Remove bits we don't care about at the start by shifting
+                let mut a_byte = a[i] << src_start_bit;
+                // We made room for some bits from the next byte
+                if i + 1 < a.len() {
+                    a_byte |= a[i + 1] >> (8 - src_start_bit);
+                }
+
+                // We now have an aligned a_byte
+                if a_byte != *b_byte {
+                    count += (a_byte ^ b_byte).leading_zeros();
+                    break;
+                }
+                count += 8;
+            }
+        }
+
+        let count = core::cmp::min(count as usize, max_bit_len);
+        if count == max_bit_len {
+            None
+        } else {
+            Some(count)
+        }
+    }
+}
+
+impl PathSegment<[u8; 33]> {
+    #[inline(always)]
+    pub fn from_path<A: BitLength>(src: A, from: usize, to: usize) -> Self {
+        let mut a = PathSegment([0; 33]);
+        a.copy(src, from, to);
+        a
+    }
+}
+
+impl<T: AsRef<[u8]>> AsRef<[u8]> for Path<T> {
+    #[inline(always)]
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_ref()
+    }
+}
+
+impl<T: AsMut<[u8]>> AsMut<[u8]> for Path<T> {
+    #[inline(always)]
+    fn as_mut(&mut self) -> &mut [u8] {
+        self.0.as_mut()
+    }
+}
+
+impl<T: AsRef<[u8]>> AsRef<[u8]> for PathSegment<T> {
+    #[inline(always)]
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_ref()
+    }
+}
+
+impl<T: AsMut<[u8]>> AsMut<[u8]> for PathSegment<T> {
+    #[inline(always)]
+    fn as_mut(&mut self) -> &mut [u8] {
+        self.0.as_mut()
+    }
+}
+
+impl<T: AsRef<[u8]>> BitLength for Path<T> {
+    #[inline(always)]
+    fn bit_len(&self) -> usize {
+        256
+    }
+
+    #[inline(always)]
+    fn inner(&self) -> &[u8] {
+        &self.0.as_ref()
+    }
+
+    #[inline(always)]
+    fn as_bytes(&self) -> &[u8] {
+        &self.0.as_ref()
+    }
+}
+
+// Byte 0 of a PathSegment stores the segment's bit length; the remaining bytes hold the path bits.
+impl<T: AsRef<[u8]>> BitLength for PathSegment<T> {
+    #[inline(always)]
+    fn bit_len(&self) -> usize {
+        self.0.as_ref()[0] as usize
+    }
+
+    #[inline(always)]
+    fn inner(&self) -> &[u8] {
+        &self.0.as_ref()[1..]
+ } + + #[inline(always)] + fn as_bytes(&self) -> &[u8] { + let byte_len = (self.bit_len() + 7) / 8; + &self.0.as_ref()[..(byte_len + 1)] + } +} + +#[cfg(feature = "std")] +impl PartialOrd for Path { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +#[cfg(feature = "std")] +impl Ord for Path { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.cmp(&other.0) + } +} + +pub trait BitLength { + fn bit_len(&self) -> usize; + fn inner(&self) -> &[u8]; + fn as_bytes(&self) -> &[u8]; +} diff --git a/src/subtree.rs b/src/subtree.rs new file mode 100644 index 0000000..d93a2d6 --- /dev/null +++ b/src/subtree.rs @@ -0,0 +1,464 @@ +use crate::{ + path::{BitLength, Direction, Path, PathSegment, PathSegmentInner, PathUtils}, + Hash, NodeHasher, +}; + +use alloc::{boxed::Box, vec, vec::Vec}; +use bincode::{ + de::Decoder, + enc::Encoder, + error::{DecodeError, EncodeError}, + impl_borrow_decode, Decode, Encode, +}; + +#[derive(Clone, Debug)] +pub struct SubTree { + pub root: SubTreeNode, + pub _marker: core::marker::PhantomData, +} + +#[derive(Clone, Debug)] +pub enum SubTreeNode { + Leaf { + key: Path, + value_or_hash: ValueOrHash, + }, + Internal { + prefix: PathSegment, + left: Box, + right: Box, + }, + Hash(Hash), + None, +} + +#[derive(Clone, Debug)] +pub enum ValueOrHash { + Value(Vec), + Hash(Hash), +} + +#[derive(Debug)] +pub enum VerifyError { + KeyExists, + IncompleteProof, +} + +impl SubTree { + pub fn empty() -> Self { + Self { + root: SubTreeNode::None, + _marker: core::marker::PhantomData, + } + } + + pub fn root(&self) -> Result { + if self.is_empty() { + return Ok(H::hash(&[])); + } + Self::hash_node(&self.root, 0) + } + + #[inline(always)] + pub fn hash(&self, value: &[u8]) -> Hash { + H::hash(value) + } + + #[inline(always)] + pub fn is_empty(&self) -> bool { + match self.root { + SubTreeNode::None => true, + _ => false, + } + } + + pub fn insert(&mut self, key: Hash, value_or_hash: ValueOrHash) -> Result<(), VerifyError> { + if self.is_empty() { + self.root = SubTreeNode::Leaf { + key: Path(key), + value_or_hash, + }; + return Ok(()); + } + + let mut node = &mut self.root; + let key = Path(key); + let mut depth = 0; + loop { + match node { + SubTreeNode::Leaf { key: node_key, .. 
} => { + // Same key + if key.0 == node_key.0 { + return Err(VerifyError::KeyExists); + } + + // A split point must exist: compress common path into an internal node + let point = node_key.split_point(0, key).unwrap(); + let prefix = PathSegment::from_path(*node_key, depth, point); + let depth = depth + prefix.bit_len() as usize; + let direction = key.direction(depth); + let current_node = core::mem::take(node); + let new_node = SubTreeNode::Leaf { key, value_or_hash }; + let (left, right) = match direction { + Direction::Right => (Box::new(current_node), Box::new(new_node)), + Direction::Left => (Box::new(new_node), Box::new(current_node)), + }; + *node = SubTreeNode::Internal { + prefix, + left, + right, + }; + return Ok(()); + } + SubTreeNode::Internal { + prefix, + left, + right, + } => { + let point = key.split_point(depth, *prefix); + if point.is_none() { + depth = depth + prefix.bit_len() as usize; + match key.direction(depth) { + Direction::Right => node = right, + Direction::Left => node = left, + } + depth += 1; + continue; + } + + // A split point exists: compress common path into an internal node + let point = point.unwrap(); + let parent_prefix = PathSegment::from_path(*prefix, 0, point); + let current_node_prefix = + PathSegment::from_path(*prefix, point + 1, prefix.bit_len()); + + let current_node = SubTreeNode::Internal { + prefix: current_node_prefix, + left: core::mem::take(left), + right: core::mem::take(right), + }; + + depth = depth + parent_prefix.bit_len(); + + let new_node = SubTreeNode::Leaf { key, value_or_hash }; + let (lefty, righty) = match key.direction(depth) { + Direction::Right => (Box::new(current_node), Box::new(new_node)), + Direction::Left => (Box::new(new_node), Box::new(current_node)), + }; + + *prefix = parent_prefix; + *left = lefty; + *right = righty; + + return Ok(()); + } + SubTreeNode::Hash(_hash) => { + return Err(VerifyError::IncompleteProof); + } + SubTreeNode::None => { + unreachable!("Unexpected None node") + } + } + } + } + + pub fn contains(&self, key: &Hash) -> Result { + if self.is_empty() { + return Ok(false); + } + + let mut node = &self.root; + let key = Path(key); + let mut depth = 0; + loop { + match node { + SubTreeNode::Leaf { key: node_key, .. 
} => { + return Ok(*key.0 == node_key.0); + } + SubTreeNode::Internal { + prefix, + left, + right, + } => { + depth = depth + prefix.bit_len() as usize; + match key.direction(depth) { + Direction::Left => node = left, + Direction::Right => node = right, + } + depth += 1; + } + SubTreeNode::Hash(_hash) => { + return Err(VerifyError::IncompleteProof); + } + SubTreeNode::None => { + unreachable!("None should not be inserted") + } + } + } + } + + fn hash_node(node: &SubTreeNode, depth: usize) -> Result { + match node { + SubTreeNode::Leaf { key, value_or_hash } => match value_or_hash { + ValueOrHash::Value(value) => { + let hash = H::hash(value); + Ok(H::hash_leaf(&key.0, &hash)) + } + ValueOrHash::Hash(hash) => Ok(H::hash_leaf(&key.0, hash)), + }, + SubTreeNode::Internal { + prefix, + left, + right, + } => { + let depth = depth + prefix.bit_len() + 1; + let left_hash = Self::hash_node(left, depth)?; + let right_hash = Self::hash_node(right, depth)?; + Ok(H::hash_internal(prefix.as_bytes(), &left_hash, &right_hash)) + } + SubTreeNode::Hash(hash) => Ok(hash.clone()), + SubTreeNode::None => { + unreachable!("None should not be inserted") + } + } + } + + pub fn iter(&self) -> SubtreeIter { + if self.is_empty() || !value_node(&self.root) { + return SubtreeIter { stack: vec![] }; + } + SubtreeIter { + stack: vec![(&self.root, 0)], + } + } + + pub fn iter_mut(&mut self) -> SubtreeIterMut { + if self.is_empty() || !value_node(&self.root) { + return SubtreeIterMut { stack: vec![] }; + } + SubtreeIterMut { + stack: vec![(&mut self.root, 0)], + } + } +} + +impl Encode for SubTree { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + Encode::encode(&self.root, encoder) + } +} + +impl Decode for SubTree { + fn decode(decoder: &mut D) -> Result { + let root = Decode::decode(decoder)?; + Ok(Self { + root, + _marker: core::marker::PhantomData, + }) + } +} + +impl Encode for SubTreeNode { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + match self { + SubTreeNode::Leaf { key, value_or_hash } => { + Encode::encode(&0u8, encoder)?; + Encode::encode(&key.0, encoder)?; + Encode::encode(value_or_hash, encoder)?; + } + SubTreeNode::Internal { + prefix, + left, + right, + } => { + Encode::encode(&1u8, encoder)?; + Encode::encode(&prefix.0, encoder)?; + Encode::encode(left, encoder)?; + Encode::encode(right, encoder)?; + } + SubTreeNode::Hash(hash) => { + Encode::encode(&2u8, encoder)?; + Encode::encode(hash, encoder)?; + } + SubTreeNode::None => { + unreachable!("None should not be encoded") + } + } + Ok(()) + } +} + +impl Decode for SubTreeNode { + fn decode(decoder: &mut D) -> Result { + let tag: u8 = Decode::decode(decoder)?; + match tag { + 0 => { + let key_raw: Hash = Decode::decode(decoder)?; + let key = Path(key_raw); + let value_or_hash = Decode::decode(decoder)?; + Ok(SubTreeNode::Leaf { key, value_or_hash }) + } + 1 => { + let seg: [u8; 33] = Decode::decode(decoder)?; + let prefix = PathSegment(seg); + let left: Box = Decode::decode(decoder)?; + let right: Box = Decode::decode(decoder)?; + Ok(SubTreeNode::Internal { + prefix, + left, + right, + }) + } + 2 => { + let hash: Hash = Decode::decode(decoder)?; + Ok(SubTreeNode::Hash(hash)) + } + _ => Err(DecodeError::Other("Invalid tag subtree node")), + } + } +} + +impl_borrow_decode!(SubTreeNode); + +impl Encode for ValueOrHash { + fn encode(&self, encoder: &mut E) -> Result<(), EncodeError> { + match self { + ValueOrHash::Value(value) => { + Encode::encode(&0u8, encoder)?; + Encode::encode(value, encoder)?; + } + 
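+            // Tag 1 stores only the value's hash, used for leaves whose value was not revealed in a proof.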
ValueOrHash::Hash(hash) => { + Encode::encode(&1u8, encoder)?; + Encode::encode(hash, encoder)?; + } + } + Ok(()) + } +} + +impl Decode for ValueOrHash { + fn decode(decoder: &mut D) -> Result { + let tag: u8 = Decode::decode(decoder)?; + match tag { + 0 => { + let value: Vec = Decode::decode(decoder)?; + Ok(ValueOrHash::Value(value)) + } + 1 => { + let hash: Hash = Decode::decode(decoder)?; + Ok(ValueOrHash::Hash(hash)) + } + _ => Err(DecodeError::Other("Invalid tag")), + } + } +} + +impl_borrow_decode!(ValueOrHash); + +impl Default for SubTreeNode { + fn default() -> Self { + SubTreeNode::None + } +} + +pub struct SubtreeIter<'a> { + stack: Vec<(&'a SubTreeNode, usize)>, +} + +impl<'a> Iterator for SubtreeIter<'a> { + type Item = (&'a Hash, &'a Vec); + + fn next(&mut self) -> Option { + loop { + let (node, depth) = match self.stack.pop() { + Some(x) => x, + None => return None, + }; + + match node { + SubTreeNode::Leaf { key, value_or_hash } => { + if let ValueOrHash::Value(value) = value_or_hash { + return Some((&key.0, value)); + } + unreachable!("Hashes of leaf nodes must not be in the stack") + } + SubTreeNode::Internal { + prefix, + left, + right, + } => { + let depth = depth + prefix.bit_len() + 1; + if value_node(right.as_ref()) { + self.stack.push((right, depth)); + } + if value_node(left.as_ref()) { + self.stack.push((left, depth)); + } + } + SubTreeNode::Hash(_hash) => { + unreachable!("Hashes must not be in the stack") + } + SubTreeNode::None => { + unreachable!("None should not be inserted") + } + } + } + } +} + +pub struct SubtreeIterMut<'a> { + stack: Vec<(&'a mut SubTreeNode, usize)>, +} + +impl<'a> Iterator for SubtreeIterMut<'a> { + // The Item type now is a tuple of an immutable reference to Hash and a mutable reference to Vec + type Item = (&'a Hash, &'a mut Vec); + + fn next(&mut self) -> Option { + loop { + let (node, depth) = self.stack.pop()?; + + match node { + SubTreeNode::Leaf { key, value_or_hash } => match value_or_hash { + ValueOrHash::Value(value) => { + return Some((&key.0, value)); + } + ValueOrHash::Hash(_) => { + unreachable!("Hash of leaf node must not be in the stack"); + } + }, + SubTreeNode::Internal { + prefix, + left, + right, + } => { + let depth = depth + prefix.bit_len() + 1; + if value_node(right.as_ref()) { + self.stack.push((right, depth)); + } + if value_node(left.as_ref()) { + self.stack.push((left, depth)); + } + } + SubTreeNode::Hash(_) => { + unreachable!("Hashes must not be in the stack") + } + SubTreeNode::None => { + unreachable!("None should not be inserted") + } + } + } + } +} + +#[inline(always)] +fn value_node(node: &SubTreeNode) -> bool { + match node { + SubTreeNode::Leaf { value_or_hash, .. } => matches!(value_or_hash, ValueOrHash::Value(_)), + SubTreeNode::Internal { .. 
} => true, + SubTreeNode::Hash(_) => false, + SubTreeNode::None => { + unreachable!("None should not be inserted") + } + } +} diff --git a/src/tx.rs b/src/tx.rs new file mode 100644 index 0000000..6bcee9b --- /dev/null +++ b/src/tx.rs @@ -0,0 +1,518 @@ +use crate::{ + path::{Direction, Path}, + subtree::ValueOrHash, +}; + +use bincode::config; +use core::marker::PhantomData; +use std::{io, sync::MutexGuard}; + +use crate::{ + db::{Database, Record, SavePoint, EMPTY_RECORD, PAGE_SIZE}, + node::{Node, NodeInner}, + path::{BitLength, PathUtils}, + subtree::{SubTree, SubTreeNode}, + Configuration, Hash, NodeHasher, +}; + +use crate::{db::DatabaseHeader, fs::WriteBuffer}; + +use crate::{ + path::{PathSegment, PathSegmentInner}, + ZERO_HASH, +}; + +const BUFFER_SIZE: usize = 16 * 64 * 1024; + +pub struct WriteTransaction<'db, H: NodeHasher> { + pub db: &'db Database, + pub(crate) state: Option, + header: MutexGuard<'db, DatabaseHeader>, +} + +pub struct ReadTransaction<'db, H: NodeHasher> { + pub db: &'db Database, + pub root: Record, + pub cache: Cache<'db, H>, + pub config: Configuration, +} + +pub struct Cache<'db, H: NodeHasher> { + db: &'db Database, + pub node: Option, + pub len: usize, + pub max_len: usize, +} + +struct CacheEntry<'n> { + node: &'n mut Node, + clean: bool, +} + +impl<'db, H: NodeHasher + 'db> ReadTransaction<'db, H> { + pub(crate) fn new(db: &'db Database, savepoint: SavePoint) -> Self { + Self { + db, + root: savepoint.root, + cache: Cache::new(db, savepoint.root, db.config.cache_size), + config: db.config.clone(), + } + } + + pub fn iter(&self) -> KeyIterator { + KeyIterator::new(self.db, self.root) + } + + pub fn get(&mut self, key: &Hash) -> Result, io::Error> { + let mut node = self.cache.node.take().unwrap(); + let result = Self::get_node(&mut self.cache, &mut node, Path(key), 0)?; + self.cache.node = Some(node); + Ok(result) + } + + pub fn root(&mut self) -> Result { + let mut n = self.cache.node.take().unwrap(); + if n.id == EMPTY_RECORD { + self.cache.node = Some(n); + return Ok(H::hash(&[])); + } + + let h = { + let entry = Self::load_hash(&mut self.cache, &mut n)?; + entry.node.hash_cache.clone().unwrap() + }; + self.cache.node = Some(n); + Ok(h) + } + + pub fn prove(&mut self, key: &Hash) -> Result, io::Error> { + self.prove_all(&[*key]) + } + + pub fn prove_all(&mut self, keys: &[Hash]) -> Result, io::Error> { + let mut node = self.cache.node.take().unwrap(); + if node.id == EMPTY_RECORD { + return Ok(SubTree::::empty()); + } + + let mut key_paths = keys.iter().map(|k| Path(k)).collect::>(); + key_paths.sort(); + + let subtree = Self::prove_nodes(&mut self.cache, &mut node, key_paths.as_slice(), 0)?; + self.cache.node = Some(node); + Ok(SubTree:: { + root: subtree, + _marker: PhantomData::, + }) + } + + fn prove_nodes( + cache: &mut Cache, + node: &mut Node, + keys: &[Path<&Hash>], + depth: usize, + ) -> Result { + let entry = cache.load_node(node)?; + match entry.node.inner.as_mut().unwrap() { + NodeInner::Leaf { + key: node_key, + value, + } => { + let include_value = keys.iter().any(|k| *k.0 == node_key.0); + let value_or_hash = if include_value { + ValueOrHash::Value(value.clone()) + } else { + ValueOrHash::Hash(H::hash(value)) + }; + Ok(SubTreeNode::Leaf { + key: node_key.clone(), + value_or_hash, + }) + } + NodeInner::Internal { + prefix, + left, + right, + } => { + // Exclude keys that are not in this subtree. 
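+                // Keys were sorted in prove_all, so partition_point can find the matching range.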
+ let end = keys.partition_point(|key| key.split_point(depth, *prefix).is_none()); + let keys = &keys[..end]; + + // Keys are split based on their direction at the current depth. + let depth = depth + prefix.bit_len(); + + let split = keys.partition_point(|key| key.direction(depth) == Direction::Left); + let (left_keys, right_keys) = keys.split_at(split); + + let left_subtree = if left_keys.is_empty() { + let left_entry = Self::load_hash(cache, left)?; + let left_hash = left_entry.node.hash_cache.clone().unwrap(); + SubTreeNode::Hash(left_hash) + } else { + Self::prove_nodes(cache, left, left_keys, depth + 1)? + }; + + let right_subtree = if right_keys.is_empty() { + let right_entry = Self::load_hash(cache, right)?; + let right_hash = right_entry.node.hash_cache.clone().unwrap(); + SubTreeNode::Hash(right_hash) + } else { + Self::prove_nodes(cache, right, right_keys, depth + 1)? + }; + + Ok(SubTreeNode::Internal { + prefix: prefix.clone(), + left: Box::new(left_subtree), + right: Box::new(right_subtree), + }) + } + } + } + + fn load_hash<'c>( + cache: &mut Cache, + node: &'c mut Node, + ) -> Result, io::Error> { + if node.hash_cache.is_some() { + return Ok(CacheEntry::new(node, false)); + } + + let entry = cache.load_node(node)?; + match entry.node.inner.as_mut().unwrap() { + NodeInner::Leaf { key, value } => { + let hash = H::hash(value); + entry.node.hash_cache = Some(H::hash_leaf(&key.0, &hash)); + } + NodeInner::Internal { + prefix, + left, + right, + } => { + let left_entry = Self::load_hash(cache, left)?; + let left_hash = left_entry.node.hash_cache.as_ref().unwrap(); + let right_entry = Self::load_hash(cache, right)?; + let right_hash = right_entry.node.hash_cache.as_ref().unwrap(); + entry.node.hash_cache = + Some(H::hash_internal(prefix.as_bytes(), left_hash, right_hash)); + } + } + Ok(entry) + } + + fn get_node<'c>( + cache: &mut Cache, + node: &'c mut Node, + key: Path<&Hash>, + depth: usize, + ) -> Result, io::Error> { + let entry = cache.load_node(node)?; + match entry.node.inner.as_mut().unwrap() { + NodeInner::Leaf { value, .. 
} => Ok(value.clone()), + NodeInner::Internal { + prefix, + left, + right, + } => { + let depth = depth + prefix.bit_len(); + match key.direction(depth) { + Direction::Right => Self::get_node(cache, right, key, depth + 1), + Direction::Left => Self::get_node(cache, left, key, depth + 1), + } + } + } + } +} + +impl<'db, H: NodeHasher> WriteTransaction<'db, H> { + pub(crate) fn new(db: &'db Database) -> Self { + let head = db.header.lock().unwrap(); + let state = if head.savepoint.root == EMPTY_RECORD { + None + } else { + Some(Node::from_id(head.savepoint.root)) + }; + + Self { + db, + state, + header: head, + } + } + + pub fn insert(&mut self, key: Hash, value: Vec) -> Result<(), io::Error> { + if self.state.is_none() { + self.state = Some(Node::from_leaf(Path(key), value)); + return Ok(()); + } + + let mut state = self.state.take().unwrap(); + state = self.insert_into_node(state, Path(key), value, 0)?; + self.state = Some(state); + Ok(()) + } + + fn insert_into_node( + &mut self, + node: Node, + key: Path, + value: Vec, + depth: usize, + ) -> Result { + let inner = match node.inner { + Some(node) => node, + None => { + if node.id == EMPTY_RECORD { + return Err(io::Error::new(io::ErrorKind::NotFound, "Node not found")); + } + let raw = self.db.file.read(node.id.offset, node.id.size as usize)?; + let config = config::standard(); + let (inner, _): (NodeInner, usize) = + bincode::decode_from_slice(&raw, config).unwrap(); + inner + } + }; + + match inner { + NodeInner::Leaf { + key: node_key, + value: node_value, + } => self.insert_leaf(node_key, node_value, key, value, depth), + NodeInner::Internal { + prefix, + left, + right, + } => self.insert_internal(prefix, left, right, key, value, depth), + } + } + + #[inline] + fn insert_internal( + &mut self, + prefix: PathSegment, + left: Box, + right: Box, + key: Path, + value: Vec, + depth: usize, + ) -> io::Result { + let point = key.split_point(depth, prefix); + if point.is_none() { + let depth = depth + prefix.bit_len(); + // Traverse further based on the direction + return match key.direction(depth) { + Direction::Right => { + let new_node = + Box::new(self.insert_into_node(*right, key, value, depth + 1)?); + Ok(Node::from_internal(prefix, left, new_node)) + } + Direction::Left => { + let new_node = Box::new(self.insert_into_node(*left, key, value, depth + 1)?); + Ok(Node::from_internal(prefix, new_node, right)) + } + }; + } + + // A split point exists: compress common path into an internal node + let point = point.unwrap(); + + // Prefix paths are relative to the depth + // Parent will be from 0th bit of the prefix (inclusive) to split point (exclusive) + let parent_prefix = PathSegment::from_path(prefix, 0, point); + + // Since current node is going down one level, we need to copy from split point+1 i.e. skipping 1 bit. 
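+        // The skipped bit is implied by which side (left or right) the node is placed on below.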
+    #[inline]
+    fn insert_internal(
+        &mut self,
+        prefix: PathSegment,
+        left: Box<Node>,
+        right: Box<Node>,
+        key: Path<Hash>,
+        value: Vec<u8>,
+        depth: usize,
+    ) -> io::Result<Node> {
+        let point = key.split_point(depth, prefix);
+        if point.is_none() {
+            let depth = depth + prefix.bit_len();
+            // Traverse further based on the direction
+            return match key.direction(depth) {
+                Direction::Right => {
+                    let new_node =
+                        Box::new(self.insert_into_node(*right, key, value, depth + 1)?);
+                    Ok(Node::from_internal(prefix, left, new_node))
+                }
+                Direction::Left => {
+                    let new_node = Box::new(self.insert_into_node(*left, key, value, depth + 1)?);
+                    Ok(Node::from_internal(prefix, new_node, right))
+                }
+            };
+        }
+
+        // A split point exists: compress the common path into an internal node
+        let point = point.unwrap();
+
+        // Prefix paths are relative to the depth.
+        // The parent spans the prefix from bit 0 (inclusive) to the split point (exclusive).
+        let parent_prefix = PathSegment::from_path(prefix, 0, point);
+
+        // Since the current node moves down one level, copy from split point + 1,
+        // i.e. skipping one bit.
+        let prefix = PathSegment::from_path(prefix, point + 1, prefix.bit_len());
+
+        let current_node = Node::from_internal(prefix, left, right);
+        let new_node = Node::from_leaf(key, value);
+
+        let depth = depth + parent_prefix.bit_len();
+        let (left, right) = match key.direction(depth) {
+            Direction::Right => (current_node, new_node),
+            Direction::Left => (new_node, current_node),
+        };
+
+        Ok(Node::from_internal(
+            parent_prefix,
+            Box::new(left),
+            Box::new(right),
+        ))
+    }
+
+    #[inline]
+    fn insert_leaf(
+        &mut self,
+        current_key: Path<Hash>,
+        current_value: Vec<u8>,
+        key: Path<Hash>,
+        value: Vec<u8>,
+        depth: usize,
+    ) -> io::Result<Node> {
+        // Empty root: leaf becomes root
+        if current_key == Path(ZERO_HASH) {
+            return Ok(Node::from_leaf(key, value));
+        }
+
+        // Same key: update value
+        if current_key == key {
+            return Ok(Node::from_leaf(key, value));
+        }
+
+        // A split point must exist: compress the common path into an internal node
+        let point = current_key.split_point(0, key).unwrap();
+        let prefix = PathSegment::from_path(current_key, depth, point);
+
+        let depth = depth + prefix.bit_len();
+        let node_direction = key.direction(depth);
+        let current_node = Node::from_leaf(current_key, current_value);
+        let node = Node::from_leaf(key, value);
+
+        let (left, right) = match node_direction {
+            Direction::Right => (current_node, node),
+            Direction::Left => (node, current_node),
+        };
+
+        Ok(Node::from_internal(prefix, Box::new(left), Box::new(right)))
+    }
+
+    fn write_all(
+        &mut self,
+        buf: &mut WriteBuffer,
+        node: &mut Node,
+    ) -> Result<Record, io::Error> {
+        match &mut node.inner {
+            Some(NodeInner::Leaf { .. }) => {
+                node.id = buf.write_node(node)?;
+            }
+            Some(NodeInner::Internal { left, right, .. }) => {
+                self.write_all(buf, left)?;
+                self.write_all(buf, right)?;
+                node.id = buf.write_node(node)?;
+            }
+            None => {
+                if node.id != EMPTY_RECORD {
+                    return Ok(node.id);
+                }
+                return Err(io::Error::new(io::ErrorKind::NotFound, "Node not found"));
+            }
+        }
+
+        Ok(node.id)
+    }
+
+    pub fn commit(mut self) -> Result<(), io::Error> {
+        if self.state.is_none() {
+            return Ok(());
+        }
+
+        let expected_file_length = self.header.len();
+        assert_eq!(
+            expected_file_length % PAGE_SIZE as u64,
+            0,
+            "Database length {} is not a multiple of the page size",
+            expected_file_length
+        );
+
+        let file_length = self.db.file.len()?;
+        if file_length != expected_file_length {
+            // Truncate/extend the file to the expected length
+            self.db.file.set_len(expected_file_length)?;
+        }
+
+        let mut buf: WriteBuffer = WriteBuffer::new(&self.db.file, file_length);
+        let mut state = self.state.take().unwrap();
+        let root = self.write_all(&mut buf, &mut state)?;
+
+        let previous_save_point = buf.write_save_point(&self.header.savepoint)?;
+        buf.flush()?;
+        self.db.file.sync_data()?;
+
+        self.header.savepoint = SavePoint {
+            root,
+            previous_save_point,
+        };
+
+        self.db.write_header(&self.header)?;
+        Ok(())
+    }
+}
+
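+/// Depth-first traversal over all leaves reachable from a root record.
+/// Internal nodes are expanded as they are popped off an explicit stack of
+/// child records; only leaf entries are yielded to the caller.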
+pub struct KeyIterator<'db, H: NodeHasher> {
+    db: &'db Database<H>,
+    stack: Vec<Record>,
+}
+
+impl<'db, H: NodeHasher> KeyIterator<'db, H> {
+    fn new(db: &'db Database<H>, root: Record) -> Self {
+        let stack = vec![root];
+        Self { db, stack }
+    }
+}
+
+impl<'db, H: NodeHasher> Iterator for KeyIterator<'db, H> {
+    type Item = Result<(Hash, Vec<u8>), io::Error>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        let record = self.stack.pop()?;
+        match self.db.load_node(record) {
+            Ok(inner) => match inner {
+                NodeInner::Leaf { key, value } => Some(Ok((key.0, value))),
+                NodeInner::Internal { left, right, .. } => {
+                    self.stack.push(left.id);
+                    self.stack.push(right.id);
+                    self.next()
+                }
+            },
+            Err(e) => Some(Err(e)),
+        }
+    }
+}
+
+impl<'n> CacheEntry<'n> {
+    fn new(node: &'n mut Node, clean: bool) -> Self {
+        Self { node, clean }
+    }
+}
+
+impl Drop for CacheEntry<'_> {
+    fn drop(&mut self) {
+        if self.clean {
+            self.node.inner = None;
+        }
+    }
+}
+
+impl<'db, H: NodeHasher> Cache<'db, H> {
+    fn new(db: &'db Database<H>, record: Record, capacity: usize) -> Self {
+        Self {
+            node: Some(Node::from_id(record)),
+            len: 0,
+            max_len: capacity,
+            db,
+        }
+    }
+
+    fn is_full(&self) -> bool {
+        self.len > self.max_len
+    }
+
+    fn load_node<'c>(&mut self, node: &'c mut Node) -> Result<CacheEntry<'c>, io::Error> {
+        if node.inner.is_some() {
+            return Ok(CacheEntry { node, clean: false });
+        }
+        assert_ne!(node.id, EMPTY_RECORD, "Attempted to read empty record");
+        let is_full = self.is_full();
+
+        let inner = self.db.load_node(node.id)?;
+
+        let empty_len = node.mem_size();
+        node.inner = Some(inner);
+        let new_len = node.mem_size();
+        if !is_full {
+            self.len += new_len - empty_len;
+        }
+
+        Ok(CacheEntry {
+            node,
+            clean: is_full,
+        })
+    }
+}
diff --git a/tests/integration_test.rs b/tests/integration_test.rs
new file mode 100644
index 0000000..c2e6ecb
--- /dev/null
+++ b/tests/integration_test.rs
@@ -0,0 +1,120 @@
+use spacedb::{db::Database, subtree::{SubTree, ValueOrHash}, NodeHasher, Sha256Hasher};
+
+#[test]
+fn it_works_with_empty_trees() {
+    let db = Database::memory().unwrap();
+
+    let mut snapshot = db.begin_read().unwrap();
+    let root = snapshot.root().unwrap();
+    assert_eq!(
+        root,
+        db.hash(&[]),
+        "empty tree root must equal the hash of empty input"
+    );
+
+    let foo = db.hash("foo".as_bytes());
+    let subtree = snapshot.prove(&foo).unwrap();
+
+    assert_eq!(
+        subtree.root().unwrap(),
+        root,
+        "empty subtree must have the same root as the tree"
+    );
+
+    assert_eq!(subtree.contains(&foo).unwrap(), false);
+}
+
+#[test]
+fn it_inserts_into_tree() {
+    let db = Database::memory().unwrap();
+    let mut tx = db.begin_write().unwrap();
+    let key = db.hash(&[]);
+    let value = "some data".as_bytes().to_vec();
+
+    tx.insert(key.clone(), value.clone()).unwrap();
+    tx.commit().unwrap();
+
+    let mut tree = db.begin_read().unwrap();
+
+    let mut subtree = SubTree::<Sha256Hasher>::empty();
+    subtree.insert(key, ValueOrHash::Value(value)).unwrap();
+
+    assert_eq!(
+        subtree.root().unwrap(),
+        tree.root().unwrap(),
+        "subtree root != tree root"
+    )
+}
+
+#[test]
+fn it_inserts_many_items_into_tree() {
+    let db = Database::memory().unwrap();
+    let mut tx = db.begin_write().unwrap();
+
+    // Initialize the subtree
+    let mut subtree = SubTree::<Sha256Hasher>::empty();
+
+    // Insert 100 key-value pairs into the transaction and the subtree
+    let mut keys = Vec::new();
+    for i in 0..100 {
+        let key = Sha256Hasher::hash(format!("key{}", i).as_bytes());
+        keys.push(key.clone());
+        let value = format!("data{}", i).as_bytes().to_vec();
+
+        tx.insert(key.clone(), value.clone()).unwrap();
+        subtree.insert(key, ValueOrHash::Value(value)).unwrap();
+    }
+
+    // Commit the transaction
+    tx.commit().unwrap();
+
+    let mut tree = db.begin_read().unwrap();
+    let subtree2 = tree.prove_all(&keys).unwrap();
+
+    assert_eq!(
+        subtree2.root().unwrap(),
+        tree.root().unwrap(),
+        "subtree2 root != tree root"
+    );
+
+    // Compare the root hash of the subtree and the main tree
+    assert_eq!(
+        subtree.root().unwrap(),
+        tree.root().unwrap(),
+        "subtree root != tree root after inserting many items"
+    );
+}
+
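+// Iterating a snapshot should visit every inserted leaf exactly once.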
+#[test]
+fn it_should_iterate_over_tree() {
+    use std::collections::HashSet;
+    let db = Database::memory().unwrap();
+    let mut tx = db.begin_write().unwrap();
+    let mut inserted_values = HashSet::new();
+
+    let n = 1000;
+    for i in 0..n {
+        let key = Sha256Hasher::hash(format!("key{}", i).as_bytes());
+        let value = format!("data{}", i).as_bytes().to_vec();
+        tx.insert(key.clone(), value.clone()).unwrap();
+        inserted_values.insert(String::from_utf8(value).unwrap());
+    }
+
+    tx.commit().unwrap();
+
+    let snapshot = db.begin_read().unwrap();
+    let mut iterated = 0;
+    for (_, value) in snapshot.iter().filter_map(Result::ok) {
+        let value_str = String::from_utf8(value).unwrap();
+        assert!(
+            inserted_values.contains(&value_str),
+            "Value not found in set: {}",
+            value_str
+        );
+        iterated += 1;
+    }
+
+    assert_eq!(
+        iterated,
+        n,
+        "The number of iterated items does not match the number of inserted items."
+    );
+}
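+
+// A sketch of one more property worth pinning down, using only the APIs
+// exercised above (`prove` on a snapshot, `contains` on the resulting
+// subtree): a key that was never inserted should be provably absent once
+// the branches along its path have been revealed.
+#[test]
+fn it_proves_absent_keys() {
+    let db = Database::memory().unwrap();
+
+    let mut tx = db.begin_write().unwrap();
+    tx.insert(db.hash("present".as_bytes()), b"data".to_vec())
+        .unwrap();
+    tx.commit().unwrap();
+
+    let mut snapshot = db.begin_read().unwrap();
+    let absent = db.hash("absent".as_bytes());
+    let subtree = snapshot.prove(&absent).unwrap();
+
+    assert_eq!(
+        subtree.root().unwrap(),
+        snapshot.root().unwrap(),
+        "revealing an absent key must not change the root"
+    );
+    assert_eq!(subtree.contains(&absent).unwrap(), false);
+}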