From c4462fb3a9f0a602bee32c83d06c65a3566da721 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Sun, 1 Dec 2024 14:32:57 -0500 Subject: [PATCH] first draft fully merge group --- optd-mvp/DESIGN.md | 155 ++++++++++++--- optd-mvp/entities.md | 2 +- optd-mvp/src/expression/logical_expression.rs | 71 +++++-- .../src/expression/physical_expression.rs | 26 +-- optd-mvp/src/lib.rs | 2 - .../src/memo/persistent/implementation.rs | 183 ++++++++++++++---- optd-mvp/src/memo/persistent/tests.rs | 101 ++++++++-- .../migrator/memo/m20241127_000001_group.rs | 4 +- 8 files changed, 424 insertions(+), 120 deletions(-) diff --git a/optd-mvp/DESIGN.md b/optd-mvp/DESIGN.md index 190eee3..986bfac 100644 --- a/optd-mvp/DESIGN.md +++ b/optd-mvp/DESIGN.md @@ -1,9 +1,12 @@ # Duplicate Elimination Memo Table +_Connor Tsui, December 2024_ + Note that most of the details are in `src/memo/persistent/implementation.rs`. -For this document, we are assuming that the memo table is backed by a database / ORM. A lot of these -problems would likely not be an issue if everything was in memory. +For this document, we are assuming that the memo table is backed by a database / ORM. Both the +problems and the features detailed in this document are unique to this design, and likely do not +apply to an in-memory memo table. ## Group Merging @@ -12,20 +15,21 @@ for this is to immediately merge two groups together when the engine determines expression would result in a duplicate expression from another group. However, if we want to support parallel exploration, this could be prone to high contention. By -definition, merging group G1 into group G2 would mean that _every expression_ that has a child of -group G1 with would need to be rewritten to point to group G2 instead. +definition, merging group 1 into group 2 would mean that _every expression_ that has a child of +group 1 with would need to be rewritten to point to group 2 instead. -This is unacceptable in a parallel setting, as that would mean every single task that gets affected -would need to either wait for the rewrites to happen before resuming work, or need to abort their -work because data has changed underneath them. +This is prohibitive in a parallel setting, as that would mean every single task that gets affected +would need to either wait for the rewrites to happen before resuming work, or potentially need to +abort their work because data has changed underneath them. -So immediate / eager group merging is not a great idea for parallel exploration. However, if we do -not ever merge two groups that are identical, we are subject to doing duplicate work for every +So immediate / eager group merging is not a great idea for parallel exploration. However, if we +don't merge two groups that are equivalent, we are subject to doing duplicate work for every duplicate expression in the memo table during physical optimization. Instead of merging groups together immediately, we can instead maintain an auxiliary data structure that records the groups that _eventually_ need to get merged, and "lazily" merge those groups -together once every group has finished exploration. +together once every group has finished exploration. We will refer to merging groups as the act of +recording that the groups should eventually be merged together after exploration is finished. ## Union-Find Group Sets @@ -33,20 +37,22 @@ We use the well-known Union-Find algorithm and corresponding data structure as t structure that tracks the to-be-merged groups. Union-Find supports `Union` and `Find` operations, where `Union` merges sets and `Find` searches for -a "canonical" or "root" element that is shared between all elements in a given set. +a "canonical" or "root" element that is shared between all elements in a given set. Note that we +will also support an iteration operation that iterates over all elements in a given set. We will +need this for [duplicate detection](#fingerprinting--group-merge), which is explained below. For more information about Union-Find, see these -[15-451 lecture notes](https://www.cs.cmu.edu/~15451-f24/lectures/lecture08-union-find.pdf). +[15-451 lecture notes](https://www.cs.cmu.edu/~15451-f24/lectures/lecture08-union-find.pdf). We will +use the exact same data structure, but add an additional `next` pointer for each node that embeds +a circular linked list for each set. -Here, we make the elements the groups themselves (really the Group IDs), which allows us to merge +Here, we make the elements the groups themselves (really the group IDs), which allows us to merge group sets together and also determine a "root group" that all groups in a set can agree on. When every group in a group set has finished exploration, we can safely begin to merge them together by moving all expressions from every group in the group set into a single large group. Other than making sure that any reference to an old group in the group set points to this new large -group, exploration of all groups are done and physical optimization can start. - -RFC: Do we need to support incremental search? +group, exploration of all groups is done and physical optimization can start. Note that since we are now waiting for exploration of all groups to finish, this algorithm is much closer to the Volcano framework than the Cascades' incremental search. However, since we eventually @@ -56,14 +62,115 @@ of a problem. ## Duplicate Detection -TODO explain the fingerprinting algorithm and how it relates to group merging - -Union find data structure with a circular linked list for linear iteration +Deciding that we will merge groups lazily does not solve all of our problems. We have to know _when_ +we want to merge these groups. -When taking the fingerprint of an expression, the child groups of an expression need to be root groups. If they are not, we need to try again. -Assuming that all children are root groups, the fingerprint we make for any expression that fulfills that is valid and can be looked up for duplicates. -In order to maintain that correctness, on a merge of two sets, the smaller one requires that a new fingerprint be generated for every expression that has a group in that smaller set. -For example, let's say we need to merge { 1, 2 } (root group 1) with { 3, 4, 5, 6, 7, 8 } (root group 3). We need to find every single expression that has a child group of 1 or 2 and we need to generate a new fingerprint for each where the child groups have been "rewritten" to 3 +A naive approach is to simply loop over every expression in the memo table and check if we are about +to insert a duplicate. This, of course, is bad for performance. -TODO this is incredibly expensive, but is potentially easily parallelizable? +We will use a fingerprinting / hashing method to detect when a duplicate expression might be +inserted into the memo table (returning an error instead of inserting), and we will use that to +trigger group merges. +For every logical expression we insert into the memo table, we will create a fingerprint that +contains both the kind of expression / relation (Scan, Filter, Join) and a hash of all +information that makes that expression unique. For example: + +- The fingerprint of a Scan should probably contain a hash of the table name and the pushdown + predicate. +- The fingerprint of a Filter should probably contain a hash of its child group ID and predicate. +- The fingerprint of a Join should probably contain a hash of the left group ID and the right group + ID, as well as the join predicate. + +Note that the above descriptions are slightly inaccurate, and we'll explain why in a later +[section](#fingerprinting--group-merge). + +Also, if we have duplicate detection for logical expression, and we do not start physical +optimization until after full plan enumeration, then we do not actually need to do duplicate +detection of physical expressions, since they are derivative of the deduplicated logical +expressions. + +### Fingerprint Matching Algorithm + +When an expression is added to the memo table, it will first calculate the fingerprint of the +expression. The memo table will compare this fingerprint with every fingerprint in the memo table to +check if we have seen this expression before (in any group). While this is effectively a scan +through every expression, supporting the fingerprint table with an B+tree index will speed up this +operation dramatically (since these fingerprints can be sorted by expression / relation kind). + +If there are no identical fingerprints, then there is no duplicate expression, and we can safely +add the expression into the memo table. However, if there are matching fingerprints, we need to +further check for false positives due to hash collisions. + +We do full exact match equality checks with every expression that had a fingerprint match. If there +are no exact matches, then we can safely add the expression into the memo table. However, if we find +an exact match (note that there can be at most one exact match since we have an invariant that there +cannot be duplicate expressions), then we know that the expression we are trying to add already +exists in the memo table. + +### Fingerprinting + Group Merge + +There is a slight problem with the algorithm described above. It does not account for when a child +group has merged into another group. + +For example, let's say we have groups 1, 2, and 3. We insert an expression Join(1, 2) into the +memo table with its fingerprint calculated with groups 1 and 2. It is possible that we find out that +groups 2 and 3 need to merged. This means that Join(1, 2) and Join (1, 3) are actually identical +expressions, and the fingerprinting strategies for expressions described above do not handle this. + +We will solve this problem by adding allowing multiple fingerprints to reference the same logical +expression, and we will generate a new fingerprint for every expression that is affected by a group +merge / every expression who's parent group now has a new root group. + +In the above scenario, we will find every expression in the memo table that has group 2 as a child. +For each expression, we will generate another fingerprint with group 2 "rewritten" as group 3 in the +hash. Note that we _do not_ modify the original expression, we are simply adding another fingerprint +into the memo table. + +Finally, we need to handle when multiple groups in a group set are merged into another group set. +For example, if a left group set { 1, 2, 3, 4, 5 } with root 1 needs to be merged into a right group +set { 6, 7, 8, 9, 10 } with root 6, then we need to generate a new fingerprint for every expression +in groups 1, 2, 3, 4, and 5 with group 1 "rewritten" as group 6. + +More formally, we are maintaining this invariant: +**For every expression, there exists a fingerprint that maps back to the expression that uses the** +**root groups of their children to calculate the hash.** + +For example, if we have a group set { 1, 3, 5 } with root group 1 and group set { 2, 4, 6 } with +root group 2, the fingerprint of Join(5, 4) should really be a fingerprint of Join(1, 2). + +This invariant means that when we are checking if some expression already exists, we should use the +root groups of the child groups in our expression to calculate the fingerprint, and we can guarantee +that no fingerprint matches implies no duplicates. + +A further implication of this invariant means that new fingerprints need to be generated every time +we merge groups. If we have a left group set { 1, 3, 5 } with root group 1 and right group set +{ 2, 4, 6 } with root group 2, and we merge the first group set into the second, then every +expression that has a child group of 1, 3, or 5 now has a stale fingerprint that uses root group 1 +instead of root group 2. + +Thus, when we merge the left group into the right group, we need to do the following: + +1. Gather the group set, i.e. every single group that has root group 1 (iterate) +2. Retrieve every single expression that has a child group in the group set (via junction table) +3. Generate a new fingerprint for each expression and add it into the memo table + +The speed of steps 2 and 3 above are largely dependent on the backing DBMS. However, we can support +step 1 directly in the union find data structure by maintain a circular linked list for every set. +Each group now tracks both a `parent` pointer and a `next` pointer. When merging / unioning a set +into another set, we swap the `next` pointers of the two roots to maintain the circular linked list. +This allows us to do step 1 in linear time relative to the size of the group set. + +## Efficiency and Parallelism + +Fingerprinting by itself is very efficient, as creating a fingerprint and looking up a fingerprint +can be made quite efficient with indexes. The real concern here is that merging two groups is very, +very expensive. Depending on the workload, it is both possible that the amortized cost is low or +that group merging takes a majority of the work. + +However, we must remember that we want to parallelize access to the memo table. The above algorithms +are notably **read and append only**. There is never a point where we need to update an expression +to maintain invariants. This is important, as it means that we can add and lookup expression and +groups _without having to take any locks_. If we enforce a serializable isolation level, then every +method on the memo table can be done in parallel with relatively low contention due to there being +zero write-write conflicts. diff --git a/optd-mvp/entities.md b/optd-mvp/entities.md index cfd082a..fc13a39 100644 --- a/optd-mvp/entities.md +++ b/optd-mvp/entities.md @@ -8,7 +8,7 @@ This assumes that you already have the `sqlite3` binary installed. First, make s $ cargo install sea-orm-cli ``` -Make sure your working directory is in the crate root: +Make sure your working directory is in the crate root (not workspace): ```sh $ cd optd-mvp diff --git a/optd-mvp/src/expression/logical_expression.rs b/optd-mvp/src/expression/logical_expression.rs index c7918de..ef0b9f2 100644 --- a/optd-mvp/src/expression/logical_expression.rs +++ b/optd-mvp/src/expression/logical_expression.rs @@ -2,24 +2,20 @@ //! //! FIXME: All fields are placeholders. //! -//! TODO Remove dead code. //! TODO Figure out if each relation should be in a different submodule. //! TODO This entire file is a WIP. -#![allow(dead_code)] - use crate::{entities::*, memo::GroupId}; use fxhash::hash; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub enum LogicalExpression { Scan(Scan), Filter(Filter), Join(Join), } -/// FIXME: Figure out how to make everything unsigned instead of signed. impl LogicalExpression { pub fn kind(&self) -> i16 { match self { @@ -29,11 +25,6 @@ impl LogicalExpression { } } - /// Definitions of custom fingerprinting strategies for each kind of logical expression. - pub fn fingerprint(&self) -> i64 { - self.fingerprint_with_rewrite(&[]) - } - /// Calculates the fingerprint of a given expression, but replaces all of the children group IDs /// with a new group ID if it is listed in the input `rewrites` list. /// @@ -55,7 +46,7 @@ impl LogicalExpression { let kind = self.kind() as u16 as usize; let hash = match self { - LogicalExpression::Scan(scan) => hash(scan.table_schema.as_str()), + LogicalExpression::Scan(scan) => hash(scan.table.as_str()), LogicalExpression::Filter(filter) => { hash(&rewrite(filter.child).0) ^ hash(filter.expression.as_str()) } @@ -69,27 +60,69 @@ impl LogicalExpression { // Mask out the bottom 16 bits of `hash` and replace them with `kind`. ((hash & !0xFFFF) | kind) as i64 } + + /// Checks equality between two expressions, with both expression rewriting their child group + /// IDs according to the input `rewrites` list. + pub fn eq_with_rewrite(&self, other: &Self, rewrites: &[(GroupId, GroupId)]) -> bool { + // Closure that rewrites a group ID if needed. + let rewrite = |x: GroupId| { + if rewrites.is_empty() { + return x; + } + + if let Some(i) = rewrites.iter().position(|(curr, _new)| &x == curr) { + assert_eq!(rewrites[i].0, x); + rewrites[i].1 + } else { + x + } + }; + + match (self, other) { + (LogicalExpression::Scan(scan_left), LogicalExpression::Scan(scan_right)) => { + scan_left.table == scan_right.table + } + (LogicalExpression::Filter(filter_left), LogicalExpression::Filter(filter_right)) => { + rewrite(filter_left.child) == rewrite(filter_right.child) + && filter_left.expression == filter_right.expression + } + (LogicalExpression::Join(join_left), LogicalExpression::Join(join_right)) => { + rewrite(join_left.left) == rewrite(join_right.left) + && rewrite(join_left.right) == rewrite(join_right.right) + && join_left.expression == join_right.expression + } + _ => false, + } + } + + pub fn children(&self) -> Vec { + match self { + LogicalExpression::Scan(_) => vec![], + LogicalExpression::Filter(filter) => vec![filter.child], + LogicalExpression::Join(join) => vec![join.left, join.right], + } + } } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct Scan { - table_schema: String, + table: String, } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct Filter { child: GroupId, expression: String, } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct Join { left: GroupId, right: GroupId, expression: String, } -/// TODO Use a macro instead. +/// TODO Use a macro. impl From for LogicalExpression { fn from(value: logical_expression::Model) -> Self { match value.kind { @@ -110,7 +143,7 @@ impl From for LogicalExpression { } } -/// TODO Use a macro instead. +/// TODO Use a macro. impl From for logical_expression::Model { fn from(value: LogicalExpression) -> logical_expression::Model { fn create_logical_expression( @@ -152,7 +185,9 @@ mod build { use crate::expression::LogicalExpression; pub fn scan(table_schema: String) -> LogicalExpression { - LogicalExpression::Scan(Scan { table_schema }) + LogicalExpression::Scan(Scan { + table: table_schema, + }) } pub fn filter(child_group: GroupId, expression: String) -> LogicalExpression { diff --git a/optd-mvp/src/expression/physical_expression.rs b/optd-mvp/src/expression/physical_expression.rs index aaaa9e7..9c451b7 100644 --- a/optd-mvp/src/expression/physical_expression.rs +++ b/optd-mvp/src/expression/physical_expression.rs @@ -2,12 +2,9 @@ //! //! FIXME: All fields are placeholders. //! -//! TODO Remove dead code. //! TODO Figure out if each operator should be in a different submodule. //! TODO This entire file is a WIP. -#![allow(dead_code)] - use crate::{entities::*, memo::GroupId}; use serde::{Deserialize, Serialize}; @@ -36,7 +33,7 @@ pub struct HashJoin { expression: String, } -/// TODO Use a macro instead. +/// TODO Use a macro. impl From for PhysicalExpression { fn from(value: physical_expression::Model) -> Self { match value.kind { @@ -57,7 +54,7 @@ impl From for PhysicalExpression { } } -/// TODO Use a macro instead. +/// TODO Use a macro. impl From for physical_expression::Model { fn from(value: PhysicalExpression) -> physical_expression::Model { fn create_physical_expression( @@ -100,23 +97,4 @@ mod build { pub fn table_scan(table_schema: String) -> PhysicalExpression { PhysicalExpression::TableScan(TableScan { table_schema }) } - - pub fn filter(child_group: GroupId, expression: String) -> PhysicalExpression { - PhysicalExpression::Filter(PhysicalFilter { - child: child_group, - expression, - }) - } - - pub fn hash_join( - left_group: GroupId, - right_group: GroupId, - expression: String, - ) -> PhysicalExpression { - PhysicalExpression::HashJoin(HashJoin { - left: left_group, - right: right_group, - expression, - }) - } } diff --git a/optd-mvp/src/lib.rs b/optd-mvp/src/lib.rs index 506eee4..48a4c78 100644 --- a/optd-mvp/src/lib.rs +++ b/optd-mvp/src/lib.rs @@ -18,8 +18,6 @@ pub const DATABASE_FILENAME: &str = "sqlite.db"; pub const DATABASE_URL: &str = "sqlite:./sqlite.db?mode=rwc"; /// An error type wrapping all the different kinds of error the optimizer might raise. -/// -/// TODO more docs. #[derive(Error, Debug)] pub enum OptimizerError { #[error("SeaORM error")] diff --git a/optd-mvp/src/memo/persistent/implementation.rs b/optd-mvp/src/memo/persistent/implementation.rs index d7e7c25..70b10e1 100644 --- a/optd-mvp/src/memo/persistent/implementation.rs +++ b/optd-mvp/src/memo/persistent/implementation.rs @@ -18,6 +18,7 @@ use sea_orm::{ entity::{IntoActiveModel, NotSet, Set}, Database, }; +use std::collections::HashSet; impl PersistentMemo { /// Creates a new `PersistentMemo` struct by connecting to a database defined at @@ -90,12 +91,55 @@ impl PersistentMemo { // For every group along the path that we walked, set their parent id pointer to the root. // This allows for an amortized O(1) cost for `get_root_group`. for group in path { - self.update_group_parent(GroupId(group.id), root_id).await?; + let mut active_group = group.into_active_model(); + + // Update the group to point to the new parent. + active_group.parent_id = Set(Some(root_id.0)); + active_group.update(&self.db).await?; } Ok(root_id) } + /// Retrieves every group ID of groups that share the same root group with the input group. + /// + /// If a group does not exist in the cycle, returns a [`MemoError::UnknownGroup`] error. + /// + /// The group records form a union-find data structure that also maintains a circular linked + /// list in every set that allows us to iterate over all elements in a set in linear time. + pub async fn get_group_set(&self, group_id: GroupId) -> OptimizerResult> { + // Iterate over the circular linked list until we reach ourselves again. + let base_group = self.get_group(group_id).await?; + + // The only case when `next_id` is set to `None` is if the current group is a root, which + // means that this group is the only group in the set. + if base_group.next_id.is_none() { + assert!(base_group.parent_id.is_none()); + return Ok(vec![group_id]); + } + + // Iterate over the circular linked list until we see ourselves again, collecting nodes + // along the way. + let mut set = vec![group_id]; + let mut next_id = base_group + .next_id + .expect("next pointer cannot be null if it is in a cycle"); + loop { + let curr_group = self.get_group(GroupId(next_id)).await?; + + if curr_group.id == group_id.0 { + break; + } + + set.push(GroupId(curr_group.id)); + next_id = curr_group + .next_id + .expect("next pointer cannot be null if it is in a cycle"); + } + + Ok(set) + } + /// Retrieves a [`physical_expression::Model`] given a [`PhysicalExpressionId`]. /// /// If the physical expression does not exist, returns a @@ -227,30 +271,6 @@ impl PersistentMemo { Ok(old_id) } - /// Updates / replaces a group's parent group. Optionally returns the previous parent. - /// - /// If either of the groups do not exist, returns a [`MemoError::UnknownGroup`] error. - pub async fn update_group_parent( - &self, - group_id: GroupId, - parent_id: GroupId, - ) -> OptimizerResult> { - // First retrieve the group record. - let mut group = self.get_group(group_id).await?.into_active_model(); - - // Check that the parent group exists. - let _ = self.get_group(parent_id).await?; - - // Update the group to point to the new parent. - let old_parent = group.parent_id; - group.parent_id = Set(Some(parent_id.0)); - group.update(&self.db).await?; - - // Note that the `unwrap` here is unwrapping the `ActiveValue`, not the `Option`. - let old_parent = old_parent.unwrap().map(GroupId); - Ok(old_parent) - } - /// Adds a logical expression to an existing group via its ID. /// /// The caller is required to pass in a slice of [`GroupId`] that represent the child groups of @@ -265,8 +285,6 @@ impl PersistentMemo { /// /// If the memo table detects that the input is unique, it will insert the expression into the /// input group and return an `Ok(Ok(expression_id))`. - /// - /// FIXME Check that all of the children are reduced groups? pub async fn add_logical_expression_to_group( &self, group_id: GroupId, @@ -323,7 +341,7 @@ impl PersistentMemo { kind: Set(kind), hash: Set(hash), }; - let _ = fingerprint::Entity::insert(fingerprint) + fingerprint::Entity::insert(fingerprint) .exec(&self.db) .await?; @@ -379,8 +397,6 @@ impl PersistentMemo { /// This function assumes that the child groups of the expression are currently roots of their /// group sets. For example, if G1 and G2 should be merged, and G1 is the root, then the input /// expression should _not_ have G2 as a child, and should be replaced with G1. - /// - /// TODO Check that all of the children are root groups? How to do this? pub async fn is_duplicate_logical_expression( &self, logical_expression: &LogicalExpression, @@ -422,8 +438,16 @@ impl PersistentMemo { let expr_id = LogicalExpressionId(potential_match.logical_expression_id); let (group_id, expr) = self.get_logical_expression(expr_id).await?; - // Check for an exact match. - if &expr == logical_expression { + // We need to add the root groups of the new expression to the rewrites vector. + // TODO make this much more efficient by making rewrites a hash map, potentially im::HashMap. + let mut rewrites = rewrites.clone(); + for child_id in expr.children() { + let root_id = self.get_root_group(child_id).await?; + rewrites.push((child_id, root_id)); + } + + // Check for an exact match after rewrites. + if logical_expression.eq_with_rewrite(&expr, &rewrites) { match_id = Some((group_id, expr_id)); // There should be at most one duplicate expression, so we can break here. @@ -447,8 +471,6 @@ impl PersistentMemo { /// /// If the expression does not exist, this function will create a new group and a new /// expression, returning brand new IDs for both. - /// - /// FIXME Check that all of the children are reduced groups? pub async fn add_group( &self, logical_expression: LogicalExpression, @@ -513,10 +535,101 @@ impl PersistentMemo { kind: Set(kind), hash: Set(hash), }; - let _ = fingerprint::Entity::insert(fingerprint) + fingerprint::Entity::insert(fingerprint) .exec(&self.db) .await?; Ok(Ok((GroupId(group_id), LogicalExpressionId(expr_id)))) } + + /// Merges two groups sets together. + /// + /// If either of the input groups do not exist, returns a [`MemoError::UnknownGroup`] error. + /// + /// TODO write docs. + /// TODO highly inefficient, need to understand metrics and performance testing. + /// TODO Optimization: add rank / size into data structure + pub async fn merge_groups( + &self, + left_group_id: GroupId, + right_group_id: GroupId, + ) -> OptimizerResult { + // Without a rank / size field, we have no way of determining which set is better to merge + // into the other. So we will arbitrarily choose to merge the left group into the right + // group here. If rank is added in the future, then merge the smaller set into the larger. + + let left_root_id = self.get_root_group(left_group_id).await?; + let left_root = self.get_group(left_root_id).await?; + // A `None` next pointer means it should technically be pointing to itself. + let left_next = left_root.next_id.unwrap_or(left_root_id.0); + let mut active_left_root = left_root.into_active_model(); + + let right_root_id = self.get_root_group(right_group_id).await?; + let right_root = self.get_group(right_root_id).await?; + // A `None` next pointer means it should technically be pointing to itself. + let right_next = right_root.next_id.unwrap_or(right_root_id.0); + let mut active_right_root = right_root.into_active_model(); + + // Before we actually update the group records, We first need to generate new fingerprints + // for every single expression that has a child group in the left set. + // TODO make this more efficient, this code is doing double work from `get_group_set`. + let group_set_ids = self.get_group_set(left_group_id).await?; + let mut left_group_models = Vec::with_capacity(group_set_ids.len()); + for &group_id in &group_set_ids { + left_group_models.push(self.get_group(group_id).await?); + } + + // Retrieve every single expression that has a child group in the left set. + let left_group_expressions: Vec> = left_group_models + .load_many_to_many( + logical_expression::Entity, + logical_children::Entity, + &self.db, + ) + .await?; + + // Need to replace every single occurrence of groups in the set with the new root. + let rewrites: Vec<(GroupId, GroupId)> = group_set_ids + .iter() + .map(|&group_id| (group_id, right_root_id)) + .collect(); + + // For each expression, generate a new fingerprint. + let mut seen = HashSet::new(); + for model in left_group_expressions.into_iter().flatten() { + let expr_id = model.id; + + // There may be duplicates in the expressions list. + if seen.contains(&expr_id) { + continue; + } else { + seen.insert(expr_id); + } + + let logical_expression: LogicalExpression = model.into(); + let hash = logical_expression.fingerprint_with_rewrite(&rewrites); + + let fingerprint = fingerprint::ActiveModel { + id: NotSet, + logical_expression_id: Set(expr_id), + kind: Set(logical_expression.kind()), + hash: Set(hash), + }; + fingerprint::Entity::insert(fingerprint) + .exec(&self.db) + .await?; + } + + // Update the left group root to point to the right group root. + active_left_root.parent_id = Set(Some(right_root_id.0)); + + // Swap the next pointers of each root to maintain the circular linked list. + active_left_root.next_id = Set(Some(right_next)); + active_right_root.next_id = Set(Some(left_next)); + + active_left_root.update(&self.db).await?; + active_right_root.update(&self.db).await?; + + Ok(right_root_id) + } } diff --git a/optd-mvp/src/memo/persistent/tests.rs b/optd-mvp/src/memo/persistent/tests.rs index 3dcddd6..0e07c81 100644 --- a/optd-mvp/src/memo/persistent/tests.rs +++ b/optd-mvp/src/memo/persistent/tests.rs @@ -114,8 +114,6 @@ async fn test_simple_tree() { ); // Create two join expression that should be in the same group. - // TODO: Eventually, the predicates will be in their own table, and the predicate representation - // will be a foreign key. For now, we represent them as strings. let join1 = join(scan_id_1, scan_id_2, "t1.a = t2.b".to_string()); let join2 = join(scan_id_2, scan_id_1, "t1.a = t2.b".to_string()); @@ -143,7 +141,7 @@ async fn test_simple_tree() { memo.cleanup().await; } -/// Tests basic group merging. See comments in the test itself for more information. +/// Tests a single group merge. See comments in the test itself for more information. #[ignore] #[tokio::test] async fn test_simple_group_link() { @@ -191,22 +189,97 @@ async fn test_simple_group_link() { // The above tells the application that the expression already exists in the memo, specifically // under `existing_group`. Thus, we should link these two groups together. - // Here, we arbitrarily choose to link group 1 into group 2. - memo.update_group_parent(join_group_1, join_group_2) - .await - .unwrap(); + memo.merge_groups(join_group_1, join_group_2).await.unwrap(); let test_root_1 = memo.get_root_group(join_group_1).await.unwrap(); let test_root_2 = memo.get_root_group(join_group_2).await.unwrap(); assert_eq!(test_root_1, test_root_2); - // TODO(Connor) - // - // We now need to find all logical expressions that had group 1 (or whatever the root group of - // the set that group 1 belongs to is, in this case it is just group 1) as a child, and add a - // new fingerprint for each one that uses group 2 as a child instead. - // - // In order to do this, we need to iterate through every group in group 1's set. + memo.cleanup().await; +} + +#[ignore] +#[tokio::test] +async fn test_group_merge() { + let memo = PersistentMemo::new().await; + memo.cleanup().await; + + // Create a base group. + let scan1 = scan("t1".to_string()); + let (scan_id_1, _) = memo.add_group(scan1, &[]).await.unwrap().ok().unwrap(); + + // Create a bunch of equivalent groups. + let filter0 = filter(scan_id_1, "true".to_string()); + let filter1 = filter(scan_id_1, "1 < 2".to_string()); + let filter2 = filter(scan_id_1, "2 > 1".to_string()); + let filter3 = filter(scan_id_1, "42 != 100".to_string()); + let filter4 = filter(scan_id_1, "10000 > 0".to_string()); + let filter5 = filter(scan_id_1, "1 + 2 = 3".to_string()); + let filter6 = filter(scan_id_1, "true OR false".to_string()); + let filter7 = filter(scan_id_1, "(1 + 1 > -1 AND true) OR false".to_string()); + let (filter_id_0, _) = memo.add_group(filter0, &[]).await.unwrap().ok().unwrap(); + let (filter_id_1, _) = memo.add_group(filter1, &[]).await.unwrap().ok().unwrap(); + let (filter_id_2, _) = memo.add_group(filter2, &[]).await.unwrap().ok().unwrap(); + let (filter_id_3, _) = memo.add_group(filter3, &[]).await.unwrap().ok().unwrap(); + let (filter_id_4, _) = memo.add_group(filter4, &[]).await.unwrap().ok().unwrap(); + let (filter_id_5, _) = memo.add_group(filter5, &[]).await.unwrap().ok().unwrap(); + let (filter_id_6, _) = memo.add_group(filter6, &[]).await.unwrap().ok().unwrap(); + let (filter_id_7, _) = memo.add_group(filter7, &[]).await.unwrap().ok().unwrap(); + let filters = vec![ + filter_id_0, + filter_id_1, + filter_id_2, + filter_id_3, + filter_id_4, + filter_id_5, + filter_id_6, + filter_id_7, + ]; + + // Merge them all together. + let quarter_0 = memo.merge_groups(filters[0], filters[1]).await.unwrap(); + let quarter_1 = memo.merge_groups(filters[2], filters[3]).await.unwrap(); + let quarter_2 = memo.merge_groups(filters[4], filters[5]).await.unwrap(); + let quarter_3 = memo.merge_groups(filters[6], filters[7]).await.unwrap(); + let semi_0 = memo.merge_groups(quarter_0, quarter_1).await.unwrap(); + let semi_1 = memo.merge_groups(quarter_2, quarter_3).await.unwrap(); + let final_id = memo.merge_groups(semi_0, semi_1).await.unwrap(); + + // Check that the group set is properly representative. + { + let set = memo.get_group_set(final_id).await.unwrap(); + assert_eq!(set.len(), 8); + for id in set { + assert!(filters.contains(&id)); + } + } + + // Create another base group. + let scan2 = scan("t2".to_string()); + let (scan_id_2, _) = memo.add_group(scan2, &[]).await.unwrap().ok().unwrap(); + + // Add a join group. + let join0 = join(filter_id_0, scan_id_2, "t1.a = t2.a".to_string()); + let (join_group_id, join_expr_id) = memo + .add_group(join0, &[filter_id_0, scan_id_2]) + .await + .unwrap() + .ok() + .unwrap(); + + // Adding the duplicate join expressions should return a duplication error containing the IDs of + // the already existing group and expression. + for filter_id in filters { + let join_test = join(filter_id, scan_id_2, "t1.a = t2.a".to_string()); + let (join_group_id_test, join_expr_id_test) = memo + .add_group(join_test, &[filter_id, scan_id_2]) + .await + .unwrap() + .err() + .unwrap(); + assert_eq!(join_group_id, join_group_id_test); + assert_eq!(join_expr_id, join_expr_id_test); + } memo.cleanup().await; } diff --git a/optd-mvp/src/migrator/memo/m20241127_000001_group.rs b/optd-mvp/src/migrator/memo/m20241127_000001_group.rs index d5bbe0e..59b5a09 100644 --- a/optd-mvp/src/migrator/memo/m20241127_000001_group.rs +++ b/optd-mvp/src/migrator/memo/m20241127_000001_group.rs @@ -28,8 +28,8 @@ //! `cost` foreign key reference to a cost record (FIXME). See the //! [section](#best-physical-plan-winner) below for more details. //! -//! Finally, we maintain a union-find graph structure embedded in the group records. -//! TODO write more information about this once this is implemented. +//! Finally, we maintain a union-find graph structure embedded in the group records. See the +//! `DESIGN.md` document for more information. //! //! # Entity Relationships //!