From aa5bbb573e33302643369d7d477a9c8284d59d4b Mon Sep 17 00:00:00 2001
From: Ben Dean-Kawamura
Date: Wed, 23 Oct 2024 14:50:44 -0400
Subject: [PATCH] [DISCO-3043] multi armed bandit API

adding init, select and update functions to be used in thompson sampling
---
Notes (ignored by `git am`, not part of the commit message):
  * update_bandit_arm_beta_distribution: the statement used
    `SET alpha=? AND beta=?`, which SQLite evaluates as
    `alpha = (? AND beta = ?)`; it now uses `SET alpha=?, beta=?`.
  * Generic parameters that were missing from the patch text were restored
    (`Vec<String>`, `ApiResult<String>`, `collect::<Result<Vec<_>>>()`).
  * The number of added lines in every hunk is unchanged, so the hunk
    headers and the diffstat below still hold.
  * NOTE(review): the author's e-mail address appears to be missing from the
    From: header - restore it before applying with `git am`.

 Cargo.lock                         | 13 +++++++
 components/relevancy/Cargo.toml    |  2 +
 components/relevancy/src/db.rs     | 61 ++++++++++++++++++++++++++++++
 components/relevancy/src/error.rs  |  6 +++
 components/relevancy/src/lib.rs    | 56 +++++++++++++++++++++++++++
 components/relevancy/src/schema.rs | 23 ++++++++++-
 6 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/Cargo.lock b/Cargo.lock
index 99669888e8..92e3050731 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2856,6 +2856,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
 dependencies = [
  "autocfg",
+ "libm",
 ]
 
 [[package]]
@@ -3477,6 +3478,16 @@ dependencies = [
  "getrandom",
 ]
 
+[[package]]
+name = "rand_distr"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
+dependencies = [
+ "num-traits",
+ "rand",
+]
+
 [[package]]
 name = "rand_rccrypto"
 version = "0.1.0"
@@ -3617,6 +3628,8 @@ dependencies = [
  "log",
  "md-5",
  "parking_lot",
+ "rand",
+ "rand_distr",
  "remote_settings",
  "rusqlite",
  "serde",
diff --git a/components/relevancy/Cargo.toml b/components/relevancy/Cargo.toml
index c768fb100f..819e274187 100644
--- a/components/relevancy/Cargo.toml
+++ b/components/relevancy/Cargo.toml
@@ -15,6 +15,8 @@ sql-support = { path = "../support/sql" }
 log = "0.4"
 md-5 = "0.10"
 parking_lot = ">=0.11,<=0.12"
+rand = "0.8"
+rand_distr = "0.4"
 rusqlite = { workspace = true, features = ["bundled"] }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
diff --git a/components/relevancy/src/db.rs b/components/relevancy/src/db.rs
index 6b2ce27d17..28a961f91b 100644
--- a/components/relevancy/src/db.rs
+++ b/components/relevancy/src/db.rs
@@ -13,6 +13,7 @@ use interrupt_support::SqlInterruptScope;
 use rusqlite::{Connection, OpenFlags};
 use sql_support::{ConnExt, LazyDb};
 use std::path::Path;
+use crate::Error::BanditNotFound;
 
 /// A thread-safe wrapper around an SQLite connection to the Relevancy database
 pub struct RelevancyDb {
@@ -75,6 +76,15 @@ impl RelevancyDb {
     }
 }
 
+pub struct BanditData {
+    pub alpha: usize,
+    pub beta: usize,
+    pub impressions: usize,
+    pub clicks: usize,
+    pub arm: String,
+    pub bandit: String,
+}
+
 /// A data access object (DAO) that wraps a connection to the Relevancy database
 ///
 /// Methods that only read from the database take an immutable reference to
@@ -172,8 +182,59 @@ impl<'a> RelevancyDao<'a> {
         }
         Ok(interest_vec)
     }
+
+    /// Inserts `(bandit, arm)` with `alpha = beta = 1` and zeroed counters unless the row exists.
+    pub fn initialize_multi_armed_bandit(&mut self, bandit: String, arm: String) -> Result<()> {
+        let mut stmt = self
+            .conn
+            .prepare("SELECT bandit, arm, alpha, beta, impressions, clicks FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
+
+        let data = stmt.query_and_then((&bandit, &arm), |row| { Ok(BanditData {
+            bandit: row.get(0)?, arm: row.get(1)?, alpha: row.get(2)?, beta: row.get(3)?, impressions: row.get(4)?, clicks: row.get(5)?,
+        }) })?.collect::<Result<Vec<_>>>()?;
+
+        // Only seed the row when the arm has never been stored; existing values are kept.
+        if data.is_empty() {
+            let bandit_data = (bandit, arm, 1, 1, 0, 0);
+            let mut new_statement = self.conn.prepare("INSERT INTO multi_armed_bandit VALUES (?, ?, ?, ?, ?, ?)")?;
+            new_statement.execute((&bandit_data.0, &bandit_data.1, bandit_data.2, bandit_data.3, bandit_data.4, bandit_data.5))?;
+        }
+
+        Ok(())
+    }
+
+    /// Returns the stored `(alpha, beta)` pair, or `BanditNotFound` when no row matches.
+    pub fn retrieve_bandit_arm_beta_distribution(&self, bandit: &str, arm: &str) -> Result<(usize, usize)> {
+        let mut stmt = self
+            .conn
+            .prepare("SELECT alpha, beta FROM multi_armed_bandit WHERE bandit=? AND arm=?")?;
+
+        let mut result = stmt.query((&bandit, &arm))?;
+
+        match result.next()? {
+            Some(row) => Ok((row.get(0)?, row.get(1)?)),
+            None => Err(BanditNotFound {bandit: bandit.to_string(), arm: arm.to_string()}),
+        }
+    }
+
+    /// Overwrites the stored `(alpha, beta)` pair; `BanditNotFound` when no row was updated.
+    pub fn update_bandit_arm_beta_distribution(&self, bandit: &str, arm: &str, alpha: usize, beta: usize) -> Result<()> {
+        // The assignments must be comma-separated: `SET alpha=? AND beta=?` would store the
+        // boolean value of `? AND beta=?` in `alpha` and never write `beta`.
+        let mut stmt = self
+            .conn
+            .prepare("UPDATE multi_armed_bandit SET alpha=?, beta=? WHERE bandit=? AND arm=?")?;
+
+        let result = stmt.execute((&alpha, &beta, &bandit, &arm))?;
+
+        if result == 0 {
+            return Err(BanditNotFound {bandit: bandit.to_string(), arm: arm.to_string()});
+        }
+
+        Ok(())
+    }
 }
 
 #[cfg(test)]
 mod test {
     use super::*;
diff --git a/components/relevancy/src/error.rs b/components/relevancy/src/error.rs
index 08a1c8cb5d..862eae8b3c 100644
--- a/components/relevancy/src/error.rs
+++ b/components/relevancy/src/error.rs
@@ -42,6 +42,12 @@ pub enum Error {
 
     #[error("Base64 Decode Error: {0}")]
     Base64DecodeError(String),
+
+    #[error("Error retrieving beta distribution for bandit {bandit} and arm {arm}")]
+    BanditNotFound{
+        bandit: String,
+        arm: String,
+    }
 }
 
 /// Result enum for the public API
diff --git a/components/relevancy/src/lib.rs b/components/relevancy/src/lib.rs
index da4fe19489..c2a1f5b87f 100644
--- a/components/relevancy/src/lib.rs
+++ b/components/relevancy/src/lib.rs
@@ -18,6 +18,8 @@ mod rs;
 mod schema;
 pub mod url_hash;
 
+use rand_distr::{Beta, Distribution};
+
 pub use db::RelevancyDb;
 pub use error::{ApiResult, Error, RelevancyApiError, Result};
 pub use interest::{Interest, InterestVector};
@@ -95,6 +97,60 @@ impl RelevancyStore {
     pub fn user_interest_vector(&self) -> ApiResult<InterestVector> {
         self.db.read(|dao| dao.get_frecency_user_interest_vector())
     }
+
+    /// Initialize the probabilities for any unknown items.
+    #[handle_error(Error)]
+    pub fn bandit_init(&self, bandit: String, arms: Vec<String>) -> ApiResult<()> {
+        // we can calculate click through rate to initialize a more accurate beta distribution
+        for arm in arms {
+            self.db.read_write(|dao| {
+                dao.initialize_multi_armed_bandit(bandit.clone(), arm.clone())
+            })?;
+        }
+
+        Ok(())
+    }
+
+    /// Pick an item to show the user
+    #[handle_error(Error)]
+    pub fn bandit_select(&self, bandit: String, arms: Vec<String>) -> ApiResult<String> {
+        // maybe cache the distribution so we don't retrieve each time
+
+        let mut best_sample = f64::MIN;
+        let mut selected_arm = None;
+
+        // Thompson sampling: draw once from each arm's Beta distribution, keep the largest draw.
+        for arm in arms {
+            let (alpha, beta) = self.db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
+            // this creates a Beta distribution for an alpha & beta pair
+            let beta_dist = Beta::new(alpha as f64, beta as f64).expect("computing betas dist unexpectedly failed");
+
+            // Sample from the Beta distribution
+            let sampled_prob = beta_dist.sample(&mut rand::thread_rng());
+
+            if sampled_prob > best_sample {
+                best_sample = sampled_prob;
+                selected_arm = Some(arm);
+            }
+        }
+
+        // With an empty `arms` list nothing is selected and the empty string is returned.
+        Ok(selected_arm.unwrap_or_default())
+    }
+
+    /// Update the model based on a user selection/non-selection
+    #[handle_error(Error)]
+    pub fn bandit_update(&self, bandit: String, arm: String, selected: bool) -> ApiResult<()> {
+        let (alpha, beta) = self.db.read(|dao| dao.retrieve_bandit_arm_beta_distribution(&bandit, &arm))?;
+        if selected {
+            self.db.read_write(|dao| dao.update_bandit_arm_beta_distribution(&bandit, &arm, alpha + 1, beta))?;
+        } else {
+            // update beta (failure) distribution (+1) for item
+            self.db.read_write(|dao| dao.update_bandit_arm_beta_distribution(&bandit, &arm, alpha, beta + 1))?;
+        }
+
+        Ok(())
+    }
 }
 
 impl RelevancyStore {
diff --git a/components/relevancy/src/schema.rs b/components/relevancy/src/schema.rs
index ab4cfd62df..e63d17f6f1 100644
--- a/components/relevancy/src/schema.rs
+++ b/components/relevancy/src/schema.rs
@@ -13,7 +13,7 @@ use sql_support::open_database::{self, ConnectionInitializer};
 ///  1. Bump this version.
 ///  2. Add a migration from the old version to the new version in
 ///     [`RelevancyConnectionInitializer::upgrade_from`].
-pub const VERSION: u32 = 14;
+pub const VERSION: u32 = 15;
 
 /// The current database schema.
 pub const SQL: &str = "
@@ -30,6 +30,15 @@ pub const SQL: &str = "
         count INTEGER NOT NULL,
         PRIMARY KEY (kind, interest_code)
     ) WITHOUT ROWID;
+    CREATE TABLE multi_armed_bandit(
+        bandit TEXT NOT NULL,
+        arm TEXT NOT NULL,
+        alpha INTEGER NOT NULL,
+        beta INTEGER NOT NULL,
+        clicks INTEGER NOT NULL,
+        impressions INTEGER NOT NULL,
+        PRIMARY KEY (bandit, arm)
+    ) WITHOUT ROWID;
 ";
 
 /// Initializes an SQLite connection to the Relevancy database, performing
@@ -73,6 +82,18 @@ impl ConnectionInitializer for RelevancyConnectionInitializer {
                 )?;
                 Ok(())
             }
+            14 => {
+                tx.execute("CREATE TABLE multi_armed_bandit(
+                    bandit TEXT NOT NULL,
+                    arm TEXT NOT NULL,
+                    alpha INTEGER NOT NULL,
+                    beta INTEGER NOT NULL,
+                    clicks INTEGER NOT NULL,
+                    impressions INTEGER NOT NULL,
+                    PRIMARY KEY (bandit, arm)
+                ) WITHOUT ROWID;", ())?;
+                Ok(())
+            }
             _ => Err(open_database::Error::IncompatibleVersion(version)),
         }
     }