Skip to content

Commit

Permalink
[DISCO-3043] multi armed bandit API
Browse files Browse the repository at this point in the history
Adding init, select and update functions to be
used in Thompson sampling
  • Loading branch information
bendk authored and misaniwere committed Oct 29, 2024
1 parent b8f83af commit aa5bbb5
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 1 deletion.
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions components/relevancy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ sql-support = { path = "../support/sql" }
log = "0.4"
md-5 = "0.10"
parking_lot = ">=0.11,<=0.12"
rand = "0.8"
rand_distr = "0.4"
rusqlite = { workspace = true, features = ["bundled"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Expand Down
61 changes: 61 additions & 0 deletions components/relevancy/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use interrupt_support::SqlInterruptScope;
use rusqlite::{Connection, OpenFlags};
use sql_support::{ConnExt, LazyDb};
use std::path::Path;
use crate::Error::BanditNotFound;

/// A thread-safe wrapper around an SQLite connection to the Relevancy database
pub struct RelevancyDb {
Expand Down Expand Up @@ -75,6 +76,15 @@ impl RelevancyDb {
}
}

/// One row of the `multi_armed_bandit` table: the stored state of a single
/// arm belonging to a named bandit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BanditData {
    /// First shape parameter of the arm's Beta distribution.
    pub alpha: usize,
    /// Second shape parameter of the arm's Beta distribution.
    pub beta: usize,
    /// Impression counter for the arm.
    // NOTE(review): presumably the number of times the arm was shown — confirm.
    pub impressions: usize,
    /// Click counter for the arm.
    // NOTE(review): presumably the number of times the arm was clicked — confirm.
    pub clicks: usize,
    /// Name of the arm within the bandit.
    pub arm: String,
    /// Name of the bandit this arm belongs to.
    pub bandit: String,
}

/// A data access object (DAO) that wraps a connection to the Relevancy database
///
/// Methods that only read from the database take an immutable reference to
Expand Down Expand Up @@ -172,8 +182,59 @@ impl<'a> RelevancyDao<'a> {
}
Ok(interest_vec)
}

/// Create the stored record for `arm` of `bandit` if it does not exist yet.
///
/// Looks up the `(bandit, arm)` row in `multi_armed_bandit`; when no row is
/// found, inserts one with `alpha = 1`, `beta = 1`, `impressions = 0` and
/// `clicks = 0`. An existing row is left untouched.
///
/// # Errors
///
/// Returns an error if preparing or running either statement fails, or if an
/// existing row cannot be decoded.
pub fn initialize_multi_armed_bandit(&mut self, bandit: String, arm: String) -> Result<()> {
    let mut select_stmt = self
        .conn
        .prepare("SELECT bandit, arm, alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;

    let existing = select_stmt
        .query_and_then((&bandit, &arm), |row| {
            Ok(BanditData {
                bandit: row.get(0)?,
                arm: row.get(1)?,
                alpha: row.get(2)?,
                beta: row.get(3)?,
                impressions: row.get(4)?,
                clicks: row.get(5)?,
            })
        })?
        .collect::<Result<Vec<BanditData>>>()?;

    if existing.is_empty() {
        // Name the columns explicitly: the table declares `clicks` before
        // `impressions`, so a positional INSERT would silently bind the two
        // counters to the wrong columns if their initial values ever differed.
        let mut insert_stmt = self.conn.prepare(
            "INSERT INTO multi_armed_bandit (bandit, arm, alpha, beta, impressions, clicks) VALUES (?, ?, ?, ?, ?, ?)",
        )?;
        insert_stmt.execute((&bandit, &arm, 1_i64, 1_i64, 0_i64, 0_i64))?;
    }

    Ok(())
}

/// Look up the Beta-distribution parameters stored for one arm of a bandit.
///
/// Returns `(alpha, beta)` read from the first `multi_armed_bandit` row whose
/// `bandit` and `arm` columns match the arguments.
///
/// # Errors
///
/// Returns `BanditNotFound` when no row matches, or a database error if the
/// query cannot be prepared, executed or decoded.
pub fn retrieve_bandit_arm_beta_distribution(&self, bandit: &str, arm: &str) -> Result<(usize, usize)> {
    let mut stmt = self
        .conn
        .prepare("SELECT alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
    let mut rows = stmt.query((&bandit, &arm))?;

    // A missing row means the arm has no stored record for this bandit.
    let row = rows.next()?.ok_or_else(|| BanditNotFound {
        bandit: bandit.to_string(),
        arm: arm.to_string(),
    })?;

    let alpha: usize = row.get(0)?;
    let beta: usize = row.get(1)?;
    Ok((alpha, beta))
}

/// Overwrite the stored Beta-distribution parameters for one arm of a bandit.
///
/// Sets the `alpha` and `beta` columns of the `multi_armed_bandit` row
/// identified by `bandit` and `arm` to the given values.
///
/// # Errors
///
/// Returns `BanditNotFound` when the statement changes no rows, or a database
/// error if the statement cannot be prepared or executed.
pub fn update_bandit_arm_beta_distribution(&self, bandit: &str, arm: &str, alpha: usize, beta: usize) -> Result<()> {
    // The assignments must be comma-separated. `SET alpha=? AND beta=?` is
    // parsed as `alpha = (? AND beta = ?)`, which stores a 0/1 boolean in
    // `alpha` and never writes `beta`.
    let mut stmt = self
        .conn
        .prepare("UPDATE multi_armed_bandit SET alpha=?, beta=? WHERE bandit=? AND arm=?")?;

    let rows_changed = stmt.execute((&alpha, &beta, &bandit, &arm))?;

    if rows_changed == 0 {
        return Err(BanditNotFound {
            bandit: bandit.to_string(),
            arm: arm.to_string(),
        });
    }

    Ok(())
}
}



#[cfg(test)]
mod test {
use super::*;
Expand Down
6 changes: 6 additions & 0 deletions components/relevancy/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ pub enum Error {

#[error("Base64 Decode Error: {0}")]
Base64DecodeError(String),

#[error("Error retrieving beta distribution for bandit {bandit} and arm {arm}")]
BanditNotFound{
bandit: String,
arm: String,
}
}

/// Result enum for the public API
Expand Down
56 changes: 56 additions & 0 deletions components/relevancy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ mod rs;
mod schema;
pub mod url_hash;

use rand_distr::{Beta, Distribution};

pub use db::RelevancyDb;
pub use error::{ApiResult, Error, RelevancyApiError, Result};
pub use interest::{Interest, InterestVector};
Expand Down Expand Up @@ -95,6 +97,60 @@ impl RelevancyStore {
pub fn user_interest_vector(&self) -> ApiResult<InterestVector> {
self.db.read(|dao| dao.get_frecency_user_interest_vector())
}

/// Make sure every arm in `arms` has a stored record under `bandit`.
///
/// Each arm is passed, in order, to `initialize_multi_armed_bandit` inside its
/// own `read_write` call. The first failure stops the iteration and is
/// returned; arms processed before it are not revisited.
#[handle_error(Error)]
pub fn bandit_init(&self, bandit: String, arms: Vec<String>) -> ApiResult<()> {
    // we can calculate click through rate to initialize a more accurate beta distribution
    arms.into_iter().try_for_each(|arm| {
        self.db
            .read_write(|dao| dao.initialize_multi_armed_bandit(bandit.clone(), arm.clone()))
    })?;

    Ok(())
}

/// Choose which arm of `bandit` to show the user.
///
/// For every arm, the stored `(alpha, beta)` pair is loaded, one value is
/// sampled from the corresponding Beta distribution, and the arm with the
/// strictly largest sample is returned. When `arms` is empty the result is an
/// empty string.
///
/// # Errors
///
/// Fails if the parameters of any arm cannot be retrieved from the database.
///
/// # Panics
///
/// Panics if `Beta::new` rejects the stored `alpha`/`beta` values.
#[handle_error(Error)]
pub fn bandit_select(&self, bandit: String, arms: Vec<String>) -> ApiResult<String> {
    // maybe cache the distribution so we don't retrieve each time
    let mut top_sample = f64::MIN;
    let mut winner: Option<String> = None;

    for candidate in arms {
        let (alpha, beta) = self
            .db
            .read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &candidate))?;

        // Build the arm's Beta distribution and draw a single sample from it.
        let draw = Beta::new(alpha as f64, beta as f64)
            .expect("computing betas dist unexpectedly failed")
            .sample(&mut rand::thread_rng());

        if draw > top_sample {
            top_sample = draw;
            winner = Some(candidate);
        }
    }

    Ok(winner.unwrap_or_default())
}

/// Update the model based on a user selection/non-selection.
///
/// Loads the stored `(alpha, beta)` pair for `arm` of `bandit` and writes it
/// back with `alpha` incremented when `selected` is true, or `beta`
/// incremented otherwise.
///
/// The read and the write are issued from a single `read_write` closure
/// instead of two separate database calls, so another update cannot slip in
/// between them and be overwritten.
/// NOTE(review): this relies on `read_write` keeping exclusive access to the
/// writer connection for the duration of the closure — confirm.
///
/// # Errors
///
/// Fails if the arm has no stored record or if a database operation fails.
#[handle_error(Error)]
pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> {
    self.db.read_write(|dao| {
        let (alpha, beta) = dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm)?;
        // A selection counts as a success (alpha); a non-selection as a failure (beta).
        let (alpha, beta) = if selected {
            (alpha + 1, beta)
        } else {
            (alpha, beta + 1)
        };
        dao.update_bandit_arm_beta_distribution(&bandit, &arm, alpha, beta)
    })?;

    Ok(())
}
}

impl RelevancyStore {
Expand Down
23 changes: 22 additions & 1 deletion components/relevancy/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use sql_support::open_database::{self, ConnectionInitializer};
/// 1. Bump this version.
/// 2. Add a migration from the old version to the new version in
/// [`RelevancyConnectionInitializer::upgrade_from`].
pub const VERSION: u32 = 14;
pub const VERSION: u32 = 15;

/// The current database schema.
pub const SQL: &str = "
Expand All @@ -30,6 +30,15 @@ pub const SQL: &str = "
count INTEGER NOT NULL,
PRIMARY KEY (kind, interest_code)
) WITHOUT ROWID;
CREATE TABLE multi_armed_bandit(
bandit TEXT NOT NULL,
arm TEXT NOT NULL,
alpha INTEGER NOT NULL,
beta INTEGER NOT NULL,
clicks INTEGER NOT NULL,
impressions INTEGER NOT NULL,
PRIMARY KEY (bandit, arm)
) WITHOUT ROWID;
";

/// Initializes an SQLite connection to the Relevancy database, performing
Expand Down Expand Up @@ -73,6 +82,18 @@ impl ConnectionInitializer for RelevancyConnectionInitializer {
)?;
Ok(())
}
14 => {
tx.execute("CREATE TABLE multi_armed_bandit(
bandit TEXT NOT NULL,
arm TEXT NOT NULL,
alpha INTEGER NOT NULL,
beta INTEGER NOT NULL,
clicks INTEGER NOT NULL,
impressions INTEGER NOT NULL,
PRIMARY KEY (bandit, arm)
) WITHOUT ROWID;", ()) ?;
Ok(())
}
_ => Err(open_database::Error::IncompatibleVersion(version)),
}
}
Expand Down

0 comments on commit aa5bbb5

Please sign in to comment.