Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
forsaken628 committed Dec 31, 2024
1 parent e6ab004 commit 33b3724
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 51 deletions.
6 changes: 3 additions & 3 deletions src/query/expression/src/aggregate/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ use databend_common_exception::Result;
use super::AggrStateLoc;
use super::StateAddr;
use crate::types::DataType;
use crate::Column;
use crate::ColumnBuilder;
use crate::InputColumns;
use crate::Scalar;
Expand Down Expand Up @@ -99,9 +98,10 @@ pub trait AggregateFunction: fmt::Display + Sync + Send {
&self,
places: &[StateAddr],
loc: Box<[AggrStateLoc]>,
column: &Column,
columns: InputColumns,
) -> Result<()> {
let c = column.as_binary().unwrap();
let idx = *loc[0].as_custom().unwrap().0;
let c = columns[idx].as_binary().unwrap();
for (place, mut data) in places.iter().zip(c.iter()) {
self.merge(&AggrState::with_loc(*place, loc.clone()), &mut data)?;
}
Expand Down
76 changes: 31 additions & 45 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,32 +199,25 @@ impl AggregateHashTable {
}

let state_places = &state.state_places.as_slice()[0..row_count];

let states_layout = self.payload.state_layout.as_ref().unwrap();
if agg_states.is_empty() {
for ((aggr, params), loc) in self.payload.aggrs.iter().zip(params.iter()).zip(
self.payload
.state_layout
.as_ref()
.unwrap()
.loc
.iter()
.cloned(),
) {
for ((aggr, params), loc) in self
.payload
.aggrs
.iter()
.zip(params.iter())
.zip(states_layout.loc.iter().cloned())
{
aggr.accumulate_keys(state_places, loc, *params, row_count)?;
}
} else {
for ((aggr, agg_state), loc) in
self.payload.aggrs.iter().zip(agg_states.iter()).zip(
self.payload
.state_layout
.as_ref()
.unwrap()
.loc
.iter()
.cloned(),
)
for (aggr, loc) in self
.payload
.aggrs
.iter()
.zip(states_layout.loc.iter().cloned())
{
aggr.batch_merge(state_places, loc, agg_state)?;
aggr.batch_merge(state_places, loc, agg_states)?;
}
}
}
Expand Down Expand Up @@ -418,36 +411,30 @@ impl AggregateHashTable {
let state = &mut flush_state.probe_state;
let places = &state.state_places.as_slice()[0..row_count];
let rhses = &flush_state.state_places.as_slice()[0..row_count];
for (aggr, loc) in self.payload.aggrs.iter().zip(
self.payload
.state_layout
.as_ref()
.unwrap()
.loc
.iter()
.cloned(),
) {
aggr.batch_merge_states(places, rhses, loc)?;
if let Some(layout) = self.payload.state_layout.as_ref() {
for (aggr, loc) in self.payload.aggrs.iter().zip(layout.loc.iter().cloned()) {
aggr.batch_merge_states(places, rhses, loc)?;
}
}
}

Ok(())
}

pub fn merge_result(&mut self, flush_state: &mut PayloadFlushState) -> Result<bool> {
if self.payload.flush(flush_state) {
let row_count = flush_state.row_count;
if !self.payload.flush(flush_state) {
return Ok(false);
}

flush_state.aggregate_results.clear();
for (aggr, loc) in self.payload.aggrs.iter().zip(
self.payload
.state_layout
.as_ref()
.unwrap()
.loc
.iter()
.cloned(),
) {
let row_count = flush_state.row_count;
flush_state.aggregate_results.clear();
if let Some(states_layout) = self.payload.state_layout.as_ref() {
for (aggr, loc) in self
.payload
.aggrs
.iter()
.zip(states_layout.loc.iter().cloned())
{
let return_type = aggr.return_type()?;
let mut builder = ColumnBuilder::with_capacity(&return_type, row_count * 4);

Expand All @@ -458,9 +445,8 @@ impl AggregateHashTable {
)?;
flush_state.aggregate_results.push(builder.build());
}
return Ok(true);
}
Ok(false)
Ok(true)
}

fn maybe_repartition(&mut self) -> bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl SerializedPayload {
let capacity = AggregateHashTable::get_capacity_for_count(rows_num);
let config = HashTableConfig::default().with_initial_radix_bits(radix_bits);
let mut state = ProbeState::default();
let agg_len = aggrs.len();
let agg_len = aggrs.len(); // todo check
let group_len = group_types.len();
let mut hashtable = AggregateHashTable::new_directly(
group_types,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,15 @@ impl TransformPartialAggregate {
HashTable::AggregateHashTable(hashtable) => {
let (params_columns, states_index) = if is_agg_index_block {
let num_columns = block.num_columns();
let functions_count = self.params.aggregate_functions.len();
let states_count = self
.params
.states_layout
.as_ref()
.map(|layout| layout.states_count())
.unwrap_or(0);
(
vec![],
(num_columns - functions_count..num_columns).collect::<Vec<_>>(),
(num_columns - states_count..num_columns).collect::<Vec<_>>(),
)
} else {
(
Expand Down

0 comments on commit 33b3724

Please sign in to comment.