Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cost-model): introduce MemoExt trait for property info #38

Merged
merged 10 commits into from
Nov 18, 2024
1 change: 1 addition & 0 deletions optd-cost-model/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod nodes;
pub mod predicates;
pub mod properties;
pub mod types;
pub mod values;
245 changes: 245 additions & 0 deletions optd-cost-model/src/common/properties/attr_ref.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
use std::collections::HashSet;

use crate::{common::types::TableId, utils::DisjointSets};

/// An ordered list of [`AttrRef`]s.
pub type AttrRefs = Vec<AttrRef>;

/// Identifies one attribute of a base table, i.e. a table that exists in the
/// catalog (as opposed to an attribute computed by the query).
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct BaseTableAttrRef {
    /// Identifier of the table that owns the attribute.
    pub table_id: TableId,
    /// Index of the attribute (presumably its position within the table).
    pub attr_idx: u64,
}

/// A reference to an attribute as it appears in a query.
#[derive(Debug, Clone)]
pub enum AttrRef {
    /// The attribute comes straight from a base table.
    BaseTableAttrRef(BaseTableAttrRef),
    /// The attribute is derived, e.g. computed as `t.v1 + t.v2`.
    /// TODO: Better representation of derived attributes.
    Derived,
}

impl AttrRef {
    /// Builds an [`AttrRef::BaseTableAttrRef`] from a table id and an
    /// attribute index.
    pub fn base_table_attr_ref(table_id: TableId, attr_idx: u64) -> Self {
        let base = BaseTableAttrRef { table_id, attr_idx };
        Self::BaseTableAttrRef(base)
    }
}

impl From<BaseTableAttrRef> for AttrRef {
fn from(attr: BaseTableAttrRef) -> Self {
AttrRef::BaseTableAttrRef(attr)
}
}

/// An equality predicate `left = right` between two base table attributes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EqPredicate {
    /// Attribute on the left-hand side of the equality.
    pub left: BaseTableAttrRef,
    /// Attribute on the right-hand side of the equality.
    pub right: BaseTableAttrRef,
}

impl EqPredicate {
pub fn new(left: BaseTableAttrRef, right: BaseTableAttrRef) -> Self {
Self { left, right }
}
}

/// [`SemanticCorrelation`] captures which attributes of a query are correlated
/// because of the query's semantics (its equality predicates) rather than
/// because of statistics.
///
/// Equal attributes are grouped into disjoint sets of base table attributes,
/// e.g. {{ t1.c1 = t2.c1 = t3.c1 }, { t1.c2 = t2.c2 }}.
#[derive(Debug, Default, Clone)]
pub struct SemanticCorrelation {
    /// Disjoint sets of base table attributes; attributes in the same set hold
    /// equal values in the same row.
    disjoint_eq_attr_sets: DisjointSets<BaseTableAttrRef>,
    /// The equality predicates from which the sets were built.
    eq_predicates: HashSet<EqPredicate>,
}

impl SemanticCorrelation {
    /// Creates an empty correlation with no attributes and no predicates.
    pub fn new() -> Self {
        Self {
            disjoint_eq_attr_sets: DisjointSets::new(),
            eq_predicates: HashSet::new(),
        }
    }

    /// Records `predicate` and merges the equivalence sets of its two sides.
    ///
    /// A side that is not yet tracked is first registered as its own set.
    pub fn add_predicate(&mut self, predicate: EqPredicate) {
        let left = &predicate.left;
        let right = &predicate.right;

        // Add the attributes to the disjoint sets if they do not exist yet.
        if !self.disjoint_eq_attr_sets.contains(left) {
            self.disjoint_eq_attr_sets
                .make_set(left.clone())
                .expect("just checked left attribute index does not exist");
        }
        if !self.disjoint_eq_attr_sets.contains(right) {
            self.disjoint_eq_attr_sets
                .make_set(right.clone())
                .expect("just checked right attribute index does not exist");
        }
        // Union the attributes.
        self.disjoint_eq_attr_sets
            .union(left, right)
            .expect("both attribute indices should exist");

        // Keep track of the predicate so the sets can be explained or rebuilt.
        self.eq_predicates.insert(predicate);
    }

    /// Determines whether two attributes are in the same set.
    ///
    /// Returns `false` when the underlying lookup fails (presumably because
    /// one of the attributes was never added).
    pub fn is_eq(&mut self, left: &BaseTableAttrRef, right: &BaseTableAttrRef) -> bool {
        self.disjoint_eq_attr_sets
            .same_set(left, right)
            .unwrap_or(false)
    }

    /// Returns whether `base_attr_ref` is tracked by the disjoint sets.
    pub fn contains(&self, base_attr_ref: &BaseTableAttrRef) -> bool {
        self.disjoint_eq_attr_sets.contains(base_attr_ref)
    }

    /// Gets the number of attributes that are equal to `attr`, including
    /// `attr` itself.
    ///
    /// # Panics
    ///
    /// Panics if the size lookup for `attr` fails (presumably because `attr`
    /// is not tracked; check with [`Self::contains`] first).
    pub fn num_eq_attributes(&mut self, attr: &BaseTableAttrRef) -> usize {
        self.disjoint_eq_attr_sets
            .set_size(attr)
            .expect("attribute should have been added through a predicate")
    }

    /// Finds the predicates that define the equality of the set of attributes
    /// `attr` belongs to.
    ///
    /// Returns an empty vector when `attr` is not tracked, mirroring
    /// [`Self::is_eq`], which reports `false` instead of panicking. A
    /// predicate whose two sides are both `attr` itself is not reported.
    pub fn find_predicates_for_eq_attr_set(&mut self, attr: &BaseTableAttrRef) -> Vec<EqPredicate> {
        // An attribute that never appeared in a predicate has no defining
        // predicates; bail out early instead of unwrapping a failed lookup.
        if !self.disjoint_eq_attr_sets.contains(attr) {
            return Vec::new();
        }

        let sets = &mut self.disjoint_eq_attr_sets;
        self.eq_predicates
            .iter()
            .filter(|predicate| {
                let (left, right) = (&predicate.left, &predicate.right);
                (left != attr && sets.same_set(attr, left).unwrap_or(false))
                    || (right != attr && sets.same_set(attr, right).unwrap_or(false))
            })
            .cloned()
            .collect()
    }

    /// Finds the attributes that appear in the predicates defining the
    /// equality of the set of attributes `attr` belongs to.
    ///
    /// The result is empty when no such predicate exists.
    pub fn find_attrs_for_eq_attribute_set(
        &mut self,
        attr: &BaseTableAttrRef,
    ) -> HashSet<BaseTableAttrRef> {
        self.find_predicates_for_eq_attr_set(attr)
            .into_iter()
            // A fixed-size array avoids a heap allocation per predicate.
            .flat_map(|predicate| [predicate.left, predicate.right])
            .collect()
    }

    /// Unions two [`SemanticCorrelation`]s into a new one by replaying the
    /// predicates of both.
    pub fn union(x: Self, y: Self) -> Self {
        let mut merged = Self::new();
        for predicate in x.eq_predicates.into_iter().chain(y.eq_predicates) {
            merged.add_predicate(predicate);
        }
        merged
    }

    /// Merges two optional correlations.
    ///
    /// When both are present they are combined with [`Self::union`]; when only
    /// one is present it is returned unchanged; otherwise the result is `None`.
    pub fn merge(x: Option<Self>, y: Option<Self>) -> Option<Self> {
        match (x, y) {
            (Some(x), Some(y)) => Some(Self::union(x, y)),
            // The values are already owned, so they are moved out directly.
            (x, None) => x,
            (None, y) => y,
        }
    }
}

/// [`GroupAttrRefs`] bundles the attribute references of a group in a query
/// with the optional correlation between the group's output attributes.
#[derive(Debug, Clone)]
pub struct GroupAttrRefs {
    /// References to the attributes of the group.
    attribute_refs: AttrRefs,
    /// Correlation of the output attributes of the group, if any.
    output_correlation: Option<SemanticCorrelation>,
}

impl GroupAttrRefs {
pub fn new(attribute_refs: AttrRefs, output_correlation: Option<SemanticCorrelation>) -> Self {
Self {
attribute_refs,
output_correlation,
}
}

pub fn base_table_attribute_refs(&self) -> &AttrRefs {
&self.attribute_refs
}

pub fn output_correlation(&self) -> Option<&SemanticCorrelation> {
self.output_correlation.as_ref()
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds three predicates into a [`SemanticCorrelation`] one at a time and
    /// checks the equivalence sets and their defining predicates after each.
    #[test]
    fn test_eq_base_table_attribute_sets() {
        let attrs = [
            BaseTableAttrRef {
                table_id: TableId(1),
                attr_idx: 1,
            },
            BaseTableAttrRef {
                table_id: TableId(2),
                attr_idx: 2,
            },
            BaseTableAttrRef {
                table_id: TableId(3),
                attr_idx: 3,
            },
            BaseTableAttrRef {
                table_id: TableId(4),
                attr_idx: 4,
            },
        ];
        let [a1, a2, a3, a4] = &attrs;

        let p12 = EqPredicate::new(a1.clone(), a2.clone());
        let p34 = EqPredicate::new(a3.clone(), a4.clone());
        let p13 = EqPredicate::new(a1.clone(), a3.clone());

        let mut correlation = SemanticCorrelation::new();

        // Sets: (1, 2)
        correlation.add_predicate(p12.clone());
        assert!(correlation.is_eq(a1, a2));

        // Sets: (1, 2), (3, 4)
        correlation.add_predicate(p34.clone());
        assert!(correlation.is_eq(a3, a4));
        assert!(!correlation.is_eq(a2, a3));

        let found = correlation.find_predicates_for_eq_attr_set(a1);
        assert_eq!(found.len(), 1);
        assert!(found.contains(&p12));

        let found = correlation.find_predicates_for_eq_attr_set(a3);
        assert_eq!(found.len(), 1);
        assert!(found.contains(&p34));

        // Sets: (1, 2, 3, 4)
        correlation.add_predicate(p13.clone());
        for (x, y) in [(a1, a3), (a2, a4), (a1, a4)] {
            assert!(correlation.is_eq(x, y));
        }

        let found = correlation.find_predicates_for_eq_attr_set(a1);
        assert_eq!(found.len(), 3);
        for expected in [&p12, &p34, &p13] {
            assert!(found.contains(expected));
        }
    }
}
23 changes: 23 additions & 0 deletions optd-cost-model/src/common/properties/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use serde::{Deserialize, Serialize};

use super::predicates::constant_pred::ConstantType;

pub mod attr_ref;
pub mod schema;

/// A named, typed attribute together with its nullability flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Attribute {
    /// Name of the attribute.
    pub name: String,
    /// Type of the attribute, expressed as a [`ConstantType`].
    pub typ: ConstantType,
    /// Whether the attribute is nullable.
    pub nullable: bool,
}

impl std::fmt::Display for Attribute {
    /// Writes `name:Type` for nullable attributes and `name:Type(non-null)`
    /// otherwise; the type is rendered with its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            typ,
            nullable,
        } = self;
        // Only non-nullable attributes carry an explicit marker.
        if !*nullable {
            return write!(f, "{}:{:?}(non-null)", name, typ);
        }
        write!(f, "{}:{:?}", name, typ)
    }
}
35 changes: 35 additions & 0 deletions optd-cost-model/src/common/properties/schema.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use itertools::Itertools;

use serde::{Deserialize, Serialize};

use super::Attribute;

/// [`Schema`] describes the schema of a group in the memo as an ordered list
/// of attributes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Schema {
    /// The attributes that make up the schema, in order.
    pub attributes: Vec<Attribute>,
}

impl std::fmt::Display for Schema {
    /// Writes the schema as `[attr, attr, ...]`, rendering each attribute with
    /// its `Display` implementation and separating them with `", "`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Itertools::join` formats every item through `Display` itself, so
        // there is no need to allocate an intermediate `String` per attribute.
        write!(f, "[{}]", self.attributes.iter().join(", "))
    }
}

impl Schema {
pub fn new(attributes: Vec<Attribute>) -> Self {
Self { attributes }
}

pub fn len(&self) -> usize {
self.attributes.len()
}

pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
10 changes: 7 additions & 3 deletions optd-cost-model/src/cost_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::{
nodes::{ArcPredicateNode, PhysicalNodeType},
types::{AttrId, EpochId, ExprId, TableId},
},
memo_ext::MemoExt,
storage::CostModelStorageManager,
ComputeCostContext, Cost, CostModel, CostModelResult, EstimatedStatistic, StatValue,
};
Expand All @@ -20,28 +21,31 @@ use crate::{
/// Cost model implementation, generic over the storage layer `S`.
pub struct CostModelImpl<S: CostModelStorageLayer> {
/// Storage manager backed by the storage layer `S`.
storage_manager: CostModelStorageManager<S>,
/// Catalog source supplied at construction time as the default.
default_catalog_source: CatalogSource,
/// Shared, dynamically dispatched handle to the memo extension.
/// NOTE(review): the underscore prefix suggests nothing reads it yet — confirm.
_memo: Arc<dyn MemoExt>,
}
Comment on lines 21 to 25
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether it's good to use Arc here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fine, otherwise you'll need to add the memo type as part of the cost model impl's generic type parameters.


impl<S: CostModelStorageLayer> CostModelImpl<S> {
/// TODO: documentation
pub fn new(
storage_manager: CostModelStorageManager<S>,
default_catalog_source: CatalogSource,
memo: Arc<dyn MemoExt>,
) -> Self {
Self {
storage_manager,
default_catalog_source,
_memo: memo,
}
}
}

impl<S: CostModelStorageLayer + std::marker::Sync + 'static> CostModel for CostModelImpl<S> {
impl<S: CostModelStorageLayer + Sync + 'static> CostModel for CostModelImpl<S> {
fn compute_operation_cost(
&self,
node: &PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_stats: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<Cost> {
todo!()
}
Expand All @@ -51,7 +55,7 @@ impl<S: CostModelStorageLayer + std::marker::Sync + 'static> CostModel for CostM
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_statistics: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<EstimatedStatistic> {
todo!()
}
Expand Down
6 changes: 4 additions & 2 deletions optd-cost-model/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ use optd_persistent::{
pub mod common;
pub mod cost;
pub mod cost_model;
pub mod memo_ext;
pub mod stats;
pub mod storage;
pub mod utils;

pub enum StatValue {
Int(i64),
Expand Down Expand Up @@ -63,7 +65,7 @@ pub trait CostModel: 'static + Send + Sync {
node: &PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_stats: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<Cost>;

/// TODO: documentation
Expand All @@ -76,7 +78,7 @@ pub trait CostModel: 'static + Send + Sync {
node: PhysicalNodeType,
predicates: &[ArcPredicateNode],
children_statistics: &[Option<&EstimatedStatistic>],
context: Option<ComputeCostContext>,
context: ComputeCostContext,
) -> CostModelResult<EstimatedStatistic>;

/// TODO: documentation
Expand Down
Loading