Skip to content

Commit

Permalink
convert from using the sqlite package to using sqlx for better compat…
Browse files Browse the repository at this point in the history
…ibility with consumer apps
  • Loading branch information
tomsanbear committed Jun 17, 2024
1 parent 4bf2d44 commit ed6924a
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 33 deletions.
2 changes: 1 addition & 1 deletion mai-sdk-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ libp2p = { version = "0.53", features = [
] }
either = "^1.12"
base64 = "^0.22"
sqlite = "^0.36"
sqlx = { version = "0.7", features = ["runtime-tokio", "tls-rustls", "sqlite"] }
68 changes: 44 additions & 24 deletions mai-sdk-core/src/distributed_kv_store.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use crate::{network::PeerId, task_queue::TaskId};
use async_channel::Sender;
use libp2p::futures::TryStreamExt;
use sqlx::Row;
use std::{collections::HashMap, fmt::Debug, sync::Arc};
use tokio::sync::{Mutex, RwLock};
use tokio::sync::RwLock;

pub type TaskAssignments = Arc<RwLock<HashMap<TaskId, PeerId>>>;

Expand All @@ -18,7 +20,7 @@ use crate::event_bridge::EventBridge;
pub struct DistributedKVStore {
logger: Logger,
bridge: EventBridge,
connection: Arc<Mutex<sqlite::Connection>>,
connection_pool: sqlx::SqlitePool,
}

impl Debug for DistributedKVStore {
Expand All @@ -43,32 +45,53 @@ pub struct GetEvent {
}

impl DistributedKVStore {
pub fn new(logger: &Logger, bridge: &EventBridge, persist: bool) -> Self {
let connection = if persist {
pub async fn new(logger: &Logger, bridge: &EventBridge, persist: bool) -> Self {
let connection_pool = if persist {
info!(logger, "Using persistent database");
sqlite::open("mai_sdk.sqlite").unwrap()
sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(1)
.connect_lazy("mai.db")
.unwrap()
} else {
warn!(logger, "Using in-memory database");
sqlite::open(":memory:").unwrap()
sqlx::sqlite::SqlitePoolOptions::new()
.max_connections(1)
.connect(":memory:")
.await
.unwrap()
};
let query = "CREATE TABLE IF NOT EXISTS kv (key TEXT PRIMARY KEY, value BLOB)";
connection.execute(query).unwrap();

// initialize the kv table
// TODO: convert to use sqlx::migrate! macro
sqlx::query(
r#"
CREATE TABLE IF NOT EXISTS kv (
key TEXT PRIMARY KEY,
value BLOB
)
"#,
)
.execute(&connection_pool)
.await
.unwrap();

DistributedKVStore {
logger: logger.clone(),
bridge: bridge.clone(),
connection: Arc::new(Mutex::new(connection)),
connection_pool,
}
}

pub async fn get(&self, key: String) -> Result<Option<Value>> {
// First check the local store, then remote store
if let Some(value) = {
let query = "SELECT value FROM kv WHERE key = ?";
let connection = self.connection.lock().await;
let mut statement = connection.prepare(query)?;
statement.bind((1, key.as_str()))?;
if let sqlite::State::Row = statement.next()? {
Some(statement.read(0)?)
let mut rows = sqlx::query("SELECT value FROM kv WHERE key = ?")
.bind(key.clone())
.fetch(&self.connection_pool);
let row = rows.try_next().await?;
if let Some(row) = row {
let value: Vec<u8> = row.get(0);
Some(value)
} else {
None
}
Expand All @@ -93,14 +116,11 @@ impl DistributedKVStore {
}

pub async fn set(&self, key: String, value: Value) -> Result<()> {
// Store locally then send a set event to the bridge
let connection = self.connection.lock().await;
let query = "INSERT OR REPLACE INTO kv (key, value) VALUES (?, ?)";
let mut statement = connection.prepare(query)?;
statement.bind((1, key.as_str()))?;
statement.bind((2, value.as_slice()))?;
statement.next()?;

sqlx::query("INSERT OR REPLACE INTO kv (key, value) VALUES (?, ?)")
.bind(key.clone())
.bind(value)
.execute(&self.connection_pool)
.await?;
Ok(())
}
}
Expand All @@ -115,7 +135,7 @@ mod tests {
let logger = slog::Logger::root(slog::Discard, slog::o!());

let bridge = EventBridge::new(&logger);
let store = DistributedKVStore::new(&logger, &bridge, false);
let store = DistributedKVStore::new(&logger, &bridge, false).await;

let key = "key".to_string();
let value = vec![1, 2, 3];
Expand Down
2 changes: 1 addition & 1 deletion mai-sdk-core/src/task_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ mod tests {
}

// Setup distributed kv store
let distributed_kv_store = DistributedKVStore::new(&logger, &event_bridge, false);
let distributed_kv_store = DistributedKVStore::new(&logger, &event_bridge, false).await;

// Setup the task queue
let runnable_state = ();
Expand Down
3 changes: 2 additions & 1 deletion mai-sdk-plugins/src/transcription/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ mod tests {
&logger,
&event_bridge,
false,
);
)
.await;
let task = TranscriptionPluginTaskTranscribe::new();
let state = TranscriptionPluginState {
logger,
Expand Down
12 changes: 8 additions & 4 deletions mai-sdk-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ pub struct PythonRuntimeArgs {
}

#[pyclass]
#[derive(Debug, Clone)]
pub struct PythonRuntime {
rt: tokio::runtime::Runtime,
state: RuntimeState,
}

Expand Down Expand Up @@ -50,15 +50,19 @@ impl PythonRuntime {
#[pyfunction]
fn start_worker(args: PythonRuntimeArgs) -> PyResult<PythonRuntime> {
let logger = slog::Logger::root(slog::Discard, o!());
let state = RuntimeState::new_worker(RuntimeStateArgs {
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let state = rt.block_on(RuntimeState::new_worker(RuntimeStateArgs {
logger,
listen_addrs: args.gossip_listen_addrs,
bootstrap_addrs: args.bootstrap_addrs,
gossipsub_heartbeat_interval: Duration::from_secs(args.gossipsub_heartbeat_interval),
ping_interval: Duration::from_secs(args.ping_interval),
psk: None,
});
Ok(PythonRuntime { state })
}));
Ok(PythonRuntime { state, rt })
}

/// A Python module implemented in Rust.
Expand Down
4 changes: 2 additions & 2 deletions mai-sdk-runtime/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ pub struct RuntimeStateArgs {

impl RuntimeState {
/// Create a new instance of the runtime state
pub fn new_worker(args: RuntimeStateArgs) -> Self {
pub async fn new_worker(args: RuntimeStateArgs) -> Self {
let event_bridge = EventBridge::new(&args.logger);
let p2p_network = P2PNetwork::new(P2PNetworkConfig {
logger: args.logger.clone(),
Expand All @@ -155,7 +155,7 @@ impl RuntimeState {
gossipsub_heartbeat_interval: args.gossipsub_heartbeat_interval,
psk: args.psk,
});
let distributed_kv_store = DistributedKVStore::new(&args.logger, &event_bridge, true);
let distributed_kv_store = DistributedKVStore::new(&args.logger, &event_bridge, true).await;
let distributed_task_queue = DistributedTaskQueue::new(
&args.logger,
&p2p_network.peer_id(),
Expand Down

0 comments on commit ed6924a

Please sign in to comment.