diff --git a/optd-cost-model/src/common/mod.rs b/optd-cost-model/src/common/mod.rs index a610705..ca5eb7e 100644 --- a/optd-cost-model/src/common/mod.rs +++ b/optd-cost-model/src/common/mod.rs @@ -1,4 +1,5 @@ pub mod nodes; pub mod predicates; +pub mod properties; pub mod types; pub mod values; diff --git a/optd-cost-model/src/common/properties/attr_ref.rs b/optd-cost-model/src/common/properties/attr_ref.rs new file mode 100644 index 0000000..eb10fbb --- /dev/null +++ b/optd-cost-model/src/common/properties/attr_ref.rs @@ -0,0 +1,245 @@ +use std::collections::HashSet; + +use crate::{common::types::TableId, utils::DisjointSets}; + +pub type AttrRefs = Vec; + +/// [`BaseTableAttrRef`] represents a reference to an attribute in a base table, +/// i.e. a table existing in the catalog. +#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)] +pub struct BaseTableAttrRef { + pub table_id: TableId, + pub attr_idx: u64, +} + +/// [`AttrRef`] represents a reference to an attribute in a query. +#[derive(Clone, Debug)] +pub enum AttrRef { + /// Reference to a base table attribute. + BaseTableAttrRef(BaseTableAttrRef), + /// Reference to a derived attribute (e.g. t.v1 + t.v2). + /// TODO: Better representation of derived attributes. + Derived, +} + +impl AttrRef { + pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self { + AttrRef::BaseTableAttrRef(BaseTableAttrRef { table_id, attr_idx }) + } +} + +impl From for AttrRef { + fn from(attr: BaseTableAttrRef) -> Self { + AttrRef::BaseTableAttrRef(attr) + } +} + +/// [`EqPredicate`] represents an equality predicate between two attributes. +#[derive(Clone, Debug, Eq, Hash, PartialEq)] +pub struct EqPredicate { + pub left: BaseTableAttrRef, + pub right: BaseTableAttrRef, +} + +impl EqPredicate { + pub fn new(left: BaseTableAttrRef, right: BaseTableAttrRef) -> Self { + Self { left, right } + } +} + +/// [`SemanticCorrelation`] represents the semantic correlation between attributes in a +/// query. "Semantic" means that the attributes are correlated based on the +/// semantics of the query, not the statistics. +/// +/// [`SemanticCorrelation`] contains equal attributes denoted by disjoint sets of base +/// table attributes, e.g. {{ t1.c1 = t2.c1 = t3.c1 }, { t1.c2 = t2.c2 }}. +#[derive(Clone, Debug, Default)] +pub struct SemanticCorrelation { + /// A disjoint set of base table attributes with equal values in the same row. + disjoint_eq_attr_sets: DisjointSets, + /// The predicates that define the equalities. + eq_predicates: HashSet, +} + +impl SemanticCorrelation { + pub fn new() -> Self { + Self { + disjoint_eq_attr_sets: DisjointSets::new(), + eq_predicates: HashSet::new(), + } + } + + pub fn add_predicate(&mut self, predicate: EqPredicate) { + let left = &predicate.left; + let right = &predicate.right; + + // Add the indices to the set if they do not exist. + if !self.disjoint_eq_attr_sets.contains(left) { + self.disjoint_eq_attr_sets + .make_set(left.clone()) + .expect("just checked left attribute index does not exist"); + } + if !self.disjoint_eq_attr_sets.contains(right) { + self.disjoint_eq_attr_sets + .make_set(right.clone()) + .expect("just checked right attribute index does not exist"); + } + // Union the attributes. + self.disjoint_eq_attr_sets + .union(left, right) + .expect("both attribute indices should exist"); + + // Keep track of the predicate. + self.eq_predicates.insert(predicate); + } + + /// Determine if two attributes are in the same set. + pub fn is_eq(&mut self, left: &BaseTableAttrRef, right: &BaseTableAttrRef) -> bool { + self.disjoint_eq_attr_sets + .same_set(left, right) + .unwrap_or(false) + } + + pub fn contains(&self, base_attr_ref: &BaseTableAttrRef) -> bool { + self.disjoint_eq_attr_sets.contains(base_attr_ref) + } + + /// Get the number of attributes that are equal to `attr`, including `attr` itself. + pub fn num_eq_attributes(&mut self, attr: &BaseTableAttrRef) -> usize { + self.disjoint_eq_attr_sets.set_size(attr).unwrap() + } + + /// Find the set of predicates that define the equality of the set of attributes `attr` belongs to. + pub fn find_predicates_for_eq_attr_set(&mut self, attr: &BaseTableAttrRef) -> Vec { + let mut predicates = Vec::new(); + for predicate in &self.eq_predicates { + let left = &predicate.left; + let right = &predicate.right; + if (left != attr && self.disjoint_eq_attr_sets.same_set(attr, left).unwrap()) + || (right != attr && self.disjoint_eq_attr_sets.same_set(attr, right).unwrap()) + { + predicates.push(predicate.clone()); + } + } + predicates + } + + /// Find the set of attributes that define the equality of the set of attributes `attr` belongs to. + pub fn find_attrs_for_eq_attribute_set( + &mut self, + attr: &BaseTableAttrRef, + ) -> HashSet { + let predicates = self.find_predicates_for_eq_attr_set(attr); + predicates + .into_iter() + .flat_map(|predicate| vec![predicate.left, predicate.right]) + .collect() + } + + /// Union two `EqBaseTableattributesets` to produce a new disjoint sets. + pub fn union(x: Self, y: Self) -> Self { + let mut eq_attr_sets = Self::new(); + for predicate in x + .eq_predicates + .into_iter() + .chain(y.eq_predicates.into_iter()) + { + eq_attr_sets.add_predicate(predicate); + } + eq_attr_sets + } + + pub fn merge(x: Option, y: Option) -> Option { + let eq_attr_sets = match (x, y) { + (Some(x), Some(y)) => Self::union(x, y), + (Some(x), None) => x.clone(), + (None, Some(y)) => y.clone(), + _ => return None, + }; + Some(eq_attr_sets) + } +} + +/// [`GroupAttrRefs`] represents the attributes of a group in a query. +#[derive(Clone, Debug)] +pub struct GroupAttrRefs { + attribute_refs: AttrRefs, + /// Correlation of the output attributes of the group. + output_correlation: Option, +} + +impl GroupAttrRefs { + pub fn new(attribute_refs: AttrRefs, output_correlation: Option) -> Self { + Self { + attribute_refs, + output_correlation, + } + } + + pub fn base_table_attribute_refs(&self) -> &AttrRefs { + &self.attribute_refs + } + + pub fn output_correlation(&self) -> Option<&SemanticCorrelation> { + self.output_correlation.as_ref() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_eq_base_table_attribute_sets() { + let attr1 = BaseTableAttrRef { + table_id: TableId(1), + attr_idx: 1, + }; + let attr2 = BaseTableAttrRef { + table_id: TableId(2), + attr_idx: 2, + }; + let attr3 = BaseTableAttrRef { + table_id: TableId(3), + attr_idx: 3, + }; + let attr4 = BaseTableAttrRef { + table_id: TableId(4), + attr_idx: 4, + }; + let pred1 = EqPredicate::new(attr1.clone(), attr2.clone()); + let pred2 = EqPredicate::new(attr3.clone(), attr4.clone()); + let pred3 = EqPredicate::new(attr1.clone(), attr3.clone()); + + let mut eq_attr_sets = SemanticCorrelation::new(); + + // (1, 2) + eq_attr_sets.add_predicate(pred1.clone()); + assert!(eq_attr_sets.is_eq(&attr1, &attr2)); + + // (1, 2), (3, 4) + eq_attr_sets.add_predicate(pred2.clone()); + assert!(eq_attr_sets.is_eq(&attr3, &attr4)); + assert!(!eq_attr_sets.is_eq(&attr2, &attr3)); + + let predicates = eq_attr_sets.find_predicates_for_eq_attr_set(&attr1); + assert_eq!(predicates.len(), 1); + assert!(predicates.contains(&pred1)); + + let predicates = eq_attr_sets.find_predicates_for_eq_attr_set(&attr3); + assert_eq!(predicates.len(), 1); + assert!(predicates.contains(&pred2)); + + // (1, 2, 3, 4) + eq_attr_sets.add_predicate(pred3.clone()); + assert!(eq_attr_sets.is_eq(&attr1, &attr3)); + assert!(eq_attr_sets.is_eq(&attr2, &attr4)); + assert!(eq_attr_sets.is_eq(&attr1, &attr4)); + + let predicates = eq_attr_sets.find_predicates_for_eq_attr_set(&attr1); + assert_eq!(predicates.len(), 3); + assert!(predicates.contains(&pred1)); + assert!(predicates.contains(&pred2)); + assert!(predicates.contains(&pred3)); + } +} diff --git a/optd-cost-model/src/common/properties/mod.rs b/optd-cost-model/src/common/properties/mod.rs new file mode 100644 index 0000000..c9acbd1 --- /dev/null +++ b/optd-cost-model/src/common/properties/mod.rs @@ -0,0 +1,23 @@ +use serde::{Deserialize, Serialize}; + +use super::predicates::constant_pred::ConstantType; + +pub mod attr_ref; +pub mod schema; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Attribute { + pub name: String, + pub typ: ConstantType, + pub nullable: bool, +} + +impl std::fmt::Display for Attribute { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.nullable { + write!(f, "{}:{:?}", self.name, self.typ) + } else { + write!(f, "{}:{:?}(non-null)", self.name, self.typ) + } + } +} diff --git a/optd-cost-model/src/common/properties/schema.rs b/optd-cost-model/src/common/properties/schema.rs new file mode 100644 index 0000000..4ee4fce --- /dev/null +++ b/optd-cost-model/src/common/properties/schema.rs @@ -0,0 +1,35 @@ +use itertools::Itertools; + +use serde::{Deserialize, Serialize}; + +use super::Attribute; + +/// [`Schema`] represents the schema of a group in the memo. It contains a list of attributes. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Schema { + pub attributes: Vec, +} + +impl std::fmt::Display for Schema { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "[{}]", + self.attributes.iter().map(|x| x.to_string()).join(", ") + ) + } +} + +impl Schema { + pub fn new(attributes: Vec) -> Self { + Self { attributes } + } + + pub fn len(&self) -> usize { + self.attributes.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} diff --git a/optd-cost-model/src/cost_model.rs b/optd-cost-model/src/cost_model.rs index c0b0677..e933add 100644 --- a/optd-cost-model/src/cost_model.rs +++ b/optd-cost-model/src/cost_model.rs @@ -12,6 +12,7 @@ use crate::{ nodes::{ArcPredicateNode, PhysicalNodeType}, types::{AttrId, EpochId, ExprId, TableId}, }, + memo_ext::MemoExt, storage::CostModelStorageManager, ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue, }; @@ -20,6 +21,7 @@ use crate::{ pub struct CostModelImpl { storage_manager: CostModelStorageManager, default_catalog_source: CatalogSource, + _memo: Arc, } impl CostModelImpl { @@ -27,21 +29,23 @@ impl CostModelImpl { pub fn new( storage_manager: CostModelStorageManager, default_catalog_source: CatalogSource, + memo: Arc, ) -> Self { Self { storage_manager, default_catalog_source, + _memo: memo, } } } -impl CostModel for CostModelImpl { +impl CostModel for CostModelImpl { fn compute_operation_cost( &self, node: &PhysicalNodeType, predicates: &[ArcPredicateNode], children_stats: &[Option<&EstimatedStatistic>], - context: Option, + context: ComputeCostContext, ) -> CostModelResult { todo!() } @@ -51,7 +55,7 @@ impl CostModel for CostM node: PhysicalNodeType, predicates: &[ArcPredicateNode], children_statistics: &[Option<&EstimatedStatistic>], - context: Option, + context: ComputeCostContext, ) -> CostModelResult { todo!() } diff --git a/optd-cost-model/src/lib.rs b/optd-cost-model/src/lib.rs index a635b66..5417f1c 100644 --- a/optd-cost-model/src/lib.rs +++ b/optd-cost-model/src/lib.rs @@ -10,8 +10,10 @@ use optd_persistent::{ pub mod common; pub mod cost; pub mod cost_model; +pub mod memo_ext; pub mod stats; pub mod storage; +pub mod utils; pub enum StatValue { Int(i64), @@ -63,7 +65,7 @@ pub trait CostModel: 'static + Send + Sync { node: &PhysicalNodeType, predicates: &[ArcPredicateNode], children_stats: &[Option<&EstimatedStatistic>], - context: Option, + context: ComputeCostContext, ) -> CostModelResult; /// TODO: documentation @@ -76,7 +78,7 @@ pub trait CostModel: 'static + Send + Sync { node: PhysicalNodeType, predicates: &[ArcPredicateNode], children_statistics: &[Option<&EstimatedStatistic>], - context: Option, + context: ComputeCostContext, ) -> CostModelResult; /// TODO: documentation diff --git a/optd-cost-model/src/memo_ext.rs b/optd-cost-model/src/memo_ext.rs new file mode 100644 index 0000000..16cddca --- /dev/null +++ b/optd-cost-model/src/memo_ext.rs @@ -0,0 +1,22 @@ +use crate::common::{ + properties::{attr_ref::GroupAttrRefs, schema::Schema, Attribute}, + types::GroupId, +}; + +/// [`MemoExt`] is a trait that provides methods to access the schema, column reference, and attribute +/// information of a group in the memo. The information are used by the cost model to compute the cost of +/// an expression. +/// +/// [`MemoExt`] should be implemented by the optimizer core to provide the necessary information to the cost +/// model. All information required here is already present in the memo, so the optimizer core should be able +/// to implement this trait without additional work. +pub trait MemoExt: Send + Sync + 'static { + /// Get the schema of a group in the memo. + fn get_schema(&self, group_id: GroupId) -> Schema; + /// Get the attribute reference of a group in the memo. + fn get_attribute_ref(&self, group_id: GroupId) -> GroupAttrRefs; + /// Get the attribute information of a given attribute in a group in the memo. + fn get_attribute_info(&self, group_id: GroupId, attr_ref_idx: u64) -> Attribute; + + // TODO: Figure out what other information is needed to compute the cost... +} diff --git a/optd-cost-model/src/utils.rs b/optd-cost-model/src/utils.rs new file mode 100644 index 0000000..125c499 --- /dev/null +++ b/optd-cost-model/src/utils.rs @@ -0,0 +1,118 @@ +//! optd's implementation of disjoint sets (union finds). It's send + sync + serializable. + +use std::{collections::HashMap, hash::Hash}; +#[derive(Clone, Default)] +pub struct DisjointSets { + data_idx: HashMap, + parents: Vec, +} + +impl std::fmt::Debug for DisjointSets { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "DisjointSets") + } +} + +impl DisjointSets { + pub fn new() -> Self { + Self { + data_idx: HashMap::new(), + parents: Vec::new(), + } + } + + pub fn contains(&self, data: &T) -> bool { + self.data_idx.contains_key(data) + } + + #[must_use] + pub fn make_set(&mut self, data: T) -> Option<()> { + if self.data_idx.contains_key(&data) { + return None; + } + let idx = self.parents.len(); + self.data_idx.insert(data.clone(), idx); + self.parents.push(idx); + Some(()) + } + + fn find(&mut self, mut idx: usize) -> usize { + while self.parents[idx] != idx { + self.parents[idx] = self.parents[self.parents[idx]]; + idx = self.parents[idx]; + } + idx + } + + fn find_const(&self, mut idx: usize) -> usize { + while self.parents[idx] != idx { + idx = self.parents[idx]; + } + idx + } + + #[must_use] + pub fn union(&mut self, data1: &T, data2: &T) -> Option<()> { + let idx1 = *self.data_idx.get(data1)?; + let idx2 = *self.data_idx.get(data2)?; + let parent1 = self.find(idx1); + let parent2 = self.find(idx2); + if parent1 != parent2 { + self.parents[parent1] = parent2; + } + Some(()) + } + + pub fn same_set(&self, data1: &T, data2: &T) -> Option { + let idx1 = *self.data_idx.get(data1)?; + let idx2 = *self.data_idx.get(data2)?; + Some(self.find_const(idx1) == self.find_const(idx2)) + } + + pub fn set_size(&self, data: &T) -> Option { + let idx = *self.data_idx.get(data)?; + let parent = self.find_const(idx); + Some( + self.parents + .iter() + .filter(|&&x| self.find_const(x) == parent) + .count(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_union_find() { + let mut set = DisjointSets::new(); + set.make_set("a").unwrap(); + set.make_set("b").unwrap(); + set.make_set("c").unwrap(); + set.make_set("d").unwrap(); + set.make_set("e").unwrap(); + assert!(set.same_set(&"a", &"a").unwrap()); + assert!(!set.same_set(&"a", &"b").unwrap()); + assert_eq!(set.set_size(&"a").unwrap(), 1); + assert_eq!(set.set_size(&"c").unwrap(), 1); + set.union(&"a", &"b").unwrap(); + assert_eq!(set.set_size(&"a").unwrap(), 2); + assert_eq!(set.set_size(&"c").unwrap(), 1); + assert!(set.same_set(&"a", &"b").unwrap()); + assert!(!set.same_set(&"a", &"c").unwrap()); + set.union(&"b", &"c").unwrap(); + assert!(set.same_set(&"a", &"c").unwrap()); + assert!(!set.same_set(&"a", &"d").unwrap()); + assert_eq!(set.set_size(&"a").unwrap(), 3); + assert_eq!(set.set_size(&"d").unwrap(), 1); + set.union(&"d", &"e").unwrap(); + assert!(set.same_set(&"d", &"e").unwrap()); + assert!(!set.same_set(&"a", &"d").unwrap()); + assert_eq!(set.set_size(&"a").unwrap(), 3); + assert_eq!(set.set_size(&"d").unwrap(), 2); + set.union(&"c", &"e").unwrap(); + assert!(set.same_set(&"a", &"e").unwrap()); + assert_eq!(set.set_size(&"d").unwrap(), 5); + } +}