[Feature](agg_state) support agg_state combinators (#19969)

support agg_state combinators state/merge/union
This commit is contained in:
Pxl
2023-05-29 13:07:29 +08:00
committed by GitHub
parent f217e052d3
commit bbb3af6ce6
29 changed files with 564 additions and 99 deletions

View File

@ -233,7 +233,7 @@ void MemTable::_aggregate_two_row_in_block(vectorized::MutableBlock& mutable_blo
auto col_ptr = mutable_block.mutable_columns()[cid].get();
_agg_functions[cid]->add(row_in_skiplist->agg_places(cid),
const_cast<const doris::vectorized::IColumn**>(&col_ptr),
new_row->_row_pos, nullptr);
new_row->_row_pos, _arena.get());
}
}
void MemTable::_put_into_output(vectorized::Block& in_block) {
@ -298,7 +298,7 @@ void MemTable::_finalize_one_row(RowInBlock* row,
} else {
function->reset(agg_place);
function->add(agg_place, const_cast<const doris::vectorized::IColumn**>(&col_ptr),
row_pos, nullptr);
row_pos, _arena.get());
}
}
if constexpr (is_final) {
@ -343,7 +343,7 @@ void MemTable::_aggregate() {
_agg_functions[cid]->create(data);
_agg_functions[cid]->add(
data, const_cast<const doris::vectorized::IColumn**>(&col_ptr),
prev_row->_row_pos, nullptr);
prev_row->_row_pos, _arena.get());
}
}
_stat.merged_rows++;

View File

@ -39,6 +39,7 @@
#include "runtime/thread_context.h"
#include "tablet_meta.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_state_union.h"
#include "vec/core/block.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_factory.hpp"
@ -491,14 +492,21 @@ bool TabletColumn::is_row_store_column() const {
return _col_name == BeConsts::ROW_STORE_COL;
}
vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function_merge() const {
vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function_union(
vectorized::DataTypePtr type) const {
auto state_type = dynamic_cast<const vectorized::DataTypeAggState*>(type.get());
if (!state_type) {
return nullptr;
}
vectorized::DataTypes argument_types;
for (auto col : _sub_columns) {
argument_types.push_back(vectorized::DataTypeFactory::instance().create_data_type(col));
auto sub_type = vectorized::DataTypeFactory::instance().create_data_type(col);
state_type->add_sub_type(sub_type);
}
auto function = vectorized::AggregateFunctionSimpleFactory::instance().get(
_aggregation_name, argument_types, false);
return function;
auto agg_function = vectorized::AggregateFunctionSimpleFactory::instance().get(
_aggregation_name, state_type->get_sub_types(), false);
return vectorized::AggregateStateUnion::create(agg_function, {type}, type);
}
vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(std::string suffix) const {
@ -514,7 +522,7 @@ vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(std::strin
if (function) {
return function;
}
return get_aggregate_function_merge();
return get_aggregate_function_union(type);
}
void TabletIndex::init_from_thrift(const TOlapTableIndex& index,

View File

@ -89,7 +89,8 @@ public:
void set_is_nullable(bool is_nullable) { _is_nullable = is_nullable; }
void set_has_default_value(bool has) { _has_default_value = has; }
FieldAggregationMethod aggregation() const { return _aggregation; }
vectorized::AggregateFunctionPtr get_aggregate_function_merge() const;
vectorized::AggregateFunctionPtr get_aggregate_function_union(
vectorized::DataTypePtr type) const;
vectorized::AggregateFunctionPtr get_aggregate_function(std::string suffix) const;
int precision() const { return _precision; }
int frac() const { return _frac; }

View File

@ -59,7 +59,6 @@ struct TypeDescriptor {
/// The maximum precision representable by a 8-byte decimal (Decimal8Value)
static constexpr int MAX_DECIMAL8_PRECISION = 18;
// Empty for scalar types
std::vector<TypeDescriptor> children;
// Only set if type == TYPE_STRUCT. The field name of each child.
@ -148,6 +147,13 @@ struct TypeDescriptor {
int idx = 0;
TypeDescriptor result(t.types, &idx);
DCHECK_EQ(idx, t.types.size() - 1);
if (result.type == TYPE_AGG_STATE) {
DCHECK(t.__isset.sub_types);
for (auto sub : t.sub_types) {
result.children.push_back(from_thrift(sub));
result.contains_nulls.push_back(sub.is_nullable);
}
}
return result;
}

View File

@ -134,7 +134,7 @@ public:
MutableColumnPtr& dst, const size_t num_rows) const = 0;
virtual void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const = 0;
IColumn& to) const = 0;
/// Deserializes state. This function is called only for empty (just created) states.
virtual void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
@ -150,6 +150,10 @@ public:
virtual void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena* arena) const = 0;
virtual void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
const IColumn& column, size_t begin,
size_t end, Arena* arena) const = 0;
virtual void deserialize_and_merge_from_column(AggregateDataPtr __restrict place,
const IColumn& column, Arena* arena) const = 0;
@ -332,8 +336,8 @@ public:
}
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
VectorBufferWriter writter(assert_cast<ColumnString&>(*dst));
IColumn& to) const override {
VectorBufferWriter writter(assert_cast<ColumnString&>(to));
assert_cast<const Derived*>(this)->serialize(place, writter);
writter.commit();
}
@ -382,6 +386,38 @@ public:
}
}
}
/// Merges the serialized states stored in rows [begin, end] of `column` into `place`.
///
/// Both bounds are inclusive: the loop below visits row `end` as well, so `end` has to be
/// a valid row index of `column`. Each visited row is read through
/// ColumnString::get_data_at() and handed to deserialize_and_merge().
///
/// @param place  aggregate state that receives the merged result.
/// @param column column holding one serialized state per row; it is cast to ColumnString.
/// @param begin  first row to merge (inclusive).
/// @param end    last row to merge (inclusive).
/// @param arena  forwarded unchanged to deserialize_and_merge().
void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place,
                                             const IColumn& column, size_t begin, size_t end,
                                             Arena* arena) const override {
    // `end` is inclusive, therefore it must be strictly smaller than the row count;
    // accepting `end == column.size()` would let the loop read one row past the column.
    DCHECK(begin <= end && end < column.size())
            << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size();
    // The column type does not change between rows, so resolve the cast once.
    const auto& serialized_states = assert_cast<const ColumnString&>(column);
    for (size_t i = begin; i <= end; ++i) {
        VectorBufferReader buffer_reader(serialized_states.get_data_at(i));
        deserialize_and_merge(place, buffer_reader, arena);
    }
}
/// Merges every serialized state stored in `column` into `place`.
///
/// Delegates to deserialize_and_merge_from_column_range() over the inclusive row range
/// [0, column.size() - 1]. An empty column is a no-op.
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
                                       Arena* arena) const override {
    if (!column.empty()) {
        // The range helper takes an inclusive upper bound, hence "size - 1".
        const size_t last_row = column.size() - 1;
        deserialize_and_merge_from_column_range(place, column, 0, last_row, arena);
    }
}
/// Reads one serialized state from `buf` and merges it into `place`.
///
/// A temporary state of size_of_data() bytes is created in a local buffer, filled by the
/// derived class' deserialize(), merged into `place`, and destroyed again before returning.
void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf,
                           Arena* arena) const override {
    // Scratch storage for the temporary state; its size is only known at run time.
    // NOTE(review): this is a variable-length array (compiler extension) — confirm that
    // its alignment is sufficient for every aggregate state type.
    char scratch[size_of_data()];
    auto tmp_state = reinterpret_cast<AggregateDataPtr>(scratch);

    const auto* self = static_cast<const Derived*>(this);
    self->create(tmp_state);
    self->deserialize(tmp_state, buf, arena);
    self->merge(place, tmp_state, arena);
    self->destroy(tmp_state);
}
};
/// Implements several methods for manipulation with data. T - type of structure with data for aggregation.
@ -426,16 +462,6 @@ public:
derived->deserialize(deserialized_place, buf, arena);
derived->merge(place, deserialized_place, arena);
}
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
Arena* arena) const override {
size_t num_rows = column.size();
for (size_t i = 0; i != num_rows; ++i) {
VectorBufferReader buffer_reader(
(assert_cast<const ColumnString&>(column)).get_data_at(i));
deserialize_and_merge(place, buffer_reader, arena);
}
}
};
using AggregateFunctionPtr = std::shared_ptr<IAggregateFunction>;

View File

@ -221,8 +221,8 @@ public:
}
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
IColumn& to) const override {
auto& col = assert_cast<ColumnFixedLengthObject&>(to);
col.set_item_size(sizeof(Data));
col.resize(1);
*reinterpret_cast<Data*>(col.get_data().data()) = this->data(place);

View File

@ -104,7 +104,7 @@ public:
col.resize(num_rows);
auto* data = col.get_data().data();
for (size_t i = 0; i != num_rows; ++i) {
data[i] = this->data(places[i] + offset).count;
data[i] = AggregateFunctionCount::data(places[i] + offset).count;
}
}
@ -120,15 +120,16 @@ public:
auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
const size_t num_rows = column.size();
for (size_t i = 0; i != num_rows; ++i) {
this->data(place).count += data[i];
AggregateFunctionCount::data(place).count += data[i];
}
}
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
auto& col = assert_cast<ColumnUInt64&>(*dst);
IColumn& to) const override {
auto& col = assert_cast<ColumnUInt64&>(to);
col.resize(1);
reinterpret_cast<Data*>(col.get_data().data())->count = this->data(place).count;
reinterpret_cast<Data*>(col.get_data().data())->count =
AggregateFunctionCount::data(place).count;
}
MutableColumnPtr create_serialize_column() const override {
@ -199,7 +200,7 @@ public:
col.resize(num_rows);
auto* data = col.get_data().data();
for (size_t i = 0; i != num_rows; ++i) {
data[i] = this->data(places[i] + offset).count;
data[i] = AggregateFunctionCountNotNullUnary::data(places[i] + offset).count;
}
}
@ -219,15 +220,16 @@ public:
auto data = assert_cast<const ColumnUInt64&>(column).get_data().data();
const size_t num_rows = column.size();
for (size_t i = 0; i != num_rows; ++i) {
this->data(place).count += data[i];
AggregateFunctionCountNotNullUnary::data(place).count += data[i];
}
}
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
auto& col = assert_cast<ColumnUInt64&>(*dst);
IColumn& to) const override {
auto& col = assert_cast<ColumnUInt64&>(to);
col.resize(1);
reinterpret_cast<Data*>(col.get_data().data())->count = this->data(place).count;
reinterpret_cast<Data*>(col.get_data().data())->count =
AggregateFunctionCountNotNullUnary::data(place).count;
}
MutableColumnPtr create_serialize_column() const override {

View File

@ -219,7 +219,6 @@ public:
this->data(place).deserialize(buf, arena);
}
// void insert_result_into(AggregateDataPtr place, IColumn & to, Arena * arena) const override
void insert_result_into(ConstAggregateDataPtr targetplace, IColumn& to) const override {
auto place = const_cast<AggregateDataPtr>(targetplace);
auto arguments = this->data(place).get_arguments(this->argument_types);

View File

@ -609,13 +609,13 @@ public:
}
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
IColumn& to) const override {
if constexpr (Data::IsFixedLength) {
auto& col = assert_cast<ColumnFixedLengthObject&>(*dst);
auto& col = assert_cast<ColumnFixedLengthObject&>(to);
col.resize(1);
*reinterpret_cast<Data*>(col.get_data().data()) = this->data(place);
} else {
Base::serialize_without_key_to_column(place, dst);
Base::serialize_without_key_to_column(place, to);
}
}

View File

@ -155,16 +155,6 @@ public:
}
}
void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column,
Arena* arena) const override {
size_t num_rows = column.size();
for (size_t i = 0; i != num_rows; ++i) {
VectorBufferReader buffer_reader(
(assert_cast<const ColumnString&>(column)).get_data_at(i));
deserialize_and_merge(place, buffer_reader, arena);
}
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
if constexpr (result_is_nullable) {
ColumnNullable& to_concrete = assert_cast<ColumnNullable&>(to);

View File

@ -0,0 +1,50 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "vec/aggregate_functions/aggregate_function_state_union.h"
namespace doris::vectorized {
const static std::string AGG_MERGE_SUFFIX = "_merge";
/// `<agg>_merge` combinator.
///
/// Reuses everything from AggregateStateUnion and only changes how the function reports
/// itself and its result: the name carries AGG_MERGE_SUFFIX, and both the return type and
/// the inserted result come from the wrapped aggregate function instead of the union's
/// serialized state.
class AggregateStateMerge : public AggregateStateUnion {
public:
    AggregateStateMerge(AggregateFunctionPtr function, const DataTypes& argument_types,
                        const DataTypePtr& return_type)
            : AggregateStateUnion(function, argument_types, return_type) {}

    /// Factory. Requires exactly one argument type (checked with CHECK) and yields
    /// nullptr when no wrapped function was supplied.
    static AggregateFunctionPtr create(AggregateFunctionPtr function,
                                       const DataTypes& argument_types,
                                       const DataTypePtr& return_type) {
        CHECK(argument_types.size() == 1);
        if (!function) {
            return nullptr;
        }
        return std::make_shared<AggregateStateMerge>(function, argument_types, return_type);
    }

    /// Name of the wrapped function followed by the merge suffix.
    String get_name() const override {
        return _function->get_name() + AGG_MERGE_SUFFIX;
    }

    /// Same return type as the wrapped function.
    DataTypePtr get_return_type() const override {
        return _function->get_return_type();
    }

    /// Lets the wrapped function write its result for `place` into `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        _function->insert_result_into(place, to);
    }
};
} // namespace doris::vectorized

View File

@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/data_types/data_type_agg_state.h"
namespace doris::vectorized {
const static std::string AGG_UNION_SUFFIX = "_union";
/// `<agg>_union` combinator.
///
/// Wraps another aggregate function (`_function`) and consumes a single input column that
/// is read as a ColumnString of serialized states. State management (create / reset /
/// merge / serialize / deserialize / destroy, size and alignment) is forwarded to the
/// wrapped function. The result written by insert_result_into() is the wrapped function's
/// state, emitted through serialize_without_key_to_column().
class AggregateStateUnion : public IAggregateFunctionHelper<AggregateStateUnion> {
public:
    /// @param function       aggregate function whose states are unioned.
    /// @param argument_types argument types passed on to the helper base class.
    /// @param return_type    type reported by get_return_type().
    AggregateStateUnion(AggregateFunctionPtr function, const DataTypes& argument_types,
                        const DataTypePtr& return_type)
            : IAggregateFunctionHelper(argument_types),
              _function(function),
              _return_type(return_type) {}
    ~AggregateStateUnion() override = default;

    /// Factory. Requires exactly one argument type (checked with CHECK) and yields
    /// nullptr when no wrapped function was supplied.
    static AggregateFunctionPtr create(AggregateFunctionPtr function,
                                       const DataTypes& argument_types,
                                       const DataTypePtr& return_type) {
        CHECK(argument_types.size() == 1);
        if (function == nullptr) {
            return nullptr;
        }
        return std::make_shared<AggregateStateUnion>(function, argument_types, return_type);
    }

    void create(AggregateDataPtr __restrict place) const override { _function->create(place); }

    /// Name of the wrapped function followed by the union suffix.
    String get_name() const override { return _function->get_name() + AGG_UNION_SUFFIX; }

    DataTypePtr get_return_type() const override { return _return_type; }

    /// Merges the serialized state found at `row_num` of the first column into `place`.
    void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
             Arena* arena) const override {
        VectorBufferReader buffer_reader(
                (assert_cast<const ColumnString&>(*columns[0])).get_data_at(row_num));
        deserialize_and_merge(place, buffer_reader, arena);
    }

    /// Merges the first `batch_size` rows of the first column into `place`.
    void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns,
                                Arena* arena) const override {
        // The range helper takes an inclusive upper bound. With an empty batch,
        // `batch_size - 1` would wrap around to the largest size_t value and the helper
        // would walk far beyond the column, so there is nothing to do in that case.
        if (batch_size == 0) {
            return;
        }
        _function->deserialize_and_merge_from_column_range(place, *columns[0], 0, batch_size - 1,
                                                           arena);
    }

    void reset(AggregateDataPtr place) const override { _function->reset(place); }

    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena* arena) const override {
        _function->merge(place, rhs, arena);
    }

    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        _function->serialize(place, buf);
    }

    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena* arena) const override {
        _function->deserialize(place, buf, arena);
    }

    /// Writes the state itself (not a finalized value) into `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        _function->serialize_without_key_to_column(place, to);
    }

    void destroy(AggregateDataPtr __restrict place) const noexcept override {
        _function->destroy(place);
    }

    bool has_trivial_destructor() const override { return _function->has_trivial_destructor(); }

    size_t size_of_data() const override { return _function->size_of_data(); }

    size_t align_of_data() const override { return _function->align_of_data(); }

protected:
    // Wrapped aggregate function that owns the actual state layout and logic.
    AggregateFunctionPtr _function;
    // Type reported by get_return_type().
    DataTypePtr _return_type;
};
} // namespace doris::vectorized

View File

@ -160,8 +160,8 @@ public:
}
void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place,
MutableColumnPtr& dst) const override {
auto& col = assert_cast<ColVecResult&>(*dst);
IColumn& to) const override {
auto& col = assert_cast<ColVecResult&>(to);
col.resize(1);
reinterpret_cast<Data*>(col.get_data().data())->sum = this->data(place).sum;
}

View File

@ -52,9 +52,9 @@ public:
return TPrimitiveType::AGG_STATE;
}
const DataTypes& get_sub_types() { return sub_types; }
const DataTypes& get_sub_types() const { return sub_types; }
void add_sub_type(DataTypePtr type) { sub_types.push_back(type); }
void add_sub_type(DataTypePtr type) const { sub_types.push_back(type); }
void to_pb_column_meta(PColumnMeta* col_meta) const override {
IDataType::to_pb_column_meta(col_meta);
@ -64,7 +64,7 @@ public:
}
private:
DataTypes sub_types;
mutable DataTypes sub_types;
};
} // namespace doris::vectorized

View File

@ -716,7 +716,7 @@ Status AggregationNode::_serialize_without_key(RuntimeState* state, Block* block
for (int i = 0; i < _aggregate_evaluators.size(); ++i) {
_aggregate_evaluators[i]->function()->serialize_without_key_to_column(
_agg_data->without_key + _offsets_of_aggregate_states[i], value_columns[i]);
_agg_data->without_key + _offsets_of_aggregate_states[i], *value_columns[i]);
}
} else {
std::vector<VectorBufferWriter> value_buffer_writers;

View File

@ -26,6 +26,7 @@
#include "vec/exprs/table_function/vexplode_json_array.h"
#include "vec/exprs/table_function/vexplode_numbers.h"
#include "vec/exprs/table_function/vexplode_split.h"
#include "vec/utils/util.hpp"
namespace doris::vectorized {
@ -61,17 +62,6 @@ const std::unordered_map<std::string, std::function<std::unique_ptr<TableFunctio
Status TableFunctionFactory::get_fn(const std::string& fn_name_raw, ObjectPool* pool,
TableFunction** fn) {
auto match_suffix = [](const std::string& name, const std::string& suffix) -> bool {
if (name.length() < suffix.length()) {
return false;
}
return name.substr(name.length() - suffix.length()) == suffix;
};
auto remove_suffix = [](const std::string& name, const std::string& suffix) -> std::string {
return name.substr(0, name.length() - suffix.length());
};
bool is_outer = match_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER);
std::string fn_name_real =
is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw;

View File

@ -33,12 +33,16 @@
#include "vec/aggregate_functions/aggregate_function_rpc.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/aggregate_functions/aggregate_function_sort.h"
#include "vec/aggregate_functions/aggregate_function_state_merge.h"
#include "vec/aggregate_functions/aggregate_function_state_union.h"
#include "vec/core/block.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/materialize_block.h"
#include "vec/data_types/data_type_agg_state.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/exprs/vexpr.h"
#include "vec/exprs/vexpr_context.h"
#include "vec/utils/util.hpp"
namespace doris {
class RowDescriptor;
@ -51,6 +55,17 @@ class IColumn;
namespace doris::vectorized {
/// Builds an agg_state combinator of kind `FunctionType` around the aggregate function
/// called `name`.
///
/// The first argument type is treated as a DataTypeAggState; its sub types are used as the
/// argument types when looking the nested function up in AggregateFunctionSimpleFactory.
/// Whatever the factory returns is passed to `FunctionType::create` together with the
/// original argument types and the return type.
template <class FunctionType>
AggregateFunctionPtr get_agg_state_function(const std::string& name,
                                            const DataTypes& argument_types,
                                            DataTypePtr return_type) {
    // No run-time type check is performed here: the caller is expected to pass an
    // agg_state type as the first argument.
    const auto* agg_state_type = static_cast<const DataTypeAggState*>(argument_types[0].get());
    const DataTypes& nested_argument_types = agg_state_type->get_sub_types();

    AggregateFunctionPtr nested_function = AggregateFunctionSimpleFactory::instance().get(
            name, nested_argument_types, return_type->is_nullable());

    return FunctionType::create(nested_function, argument_types, return_type);
}
AggFnEvaluator::AggFnEvaluator(const TExprNode& desc)
: _fn(desc.fn),
_is_merge(desc.agg_expr.is_merge_agg),
@ -143,6 +158,19 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc,
}
} else if (_fn.binary_type == TFunctionBinaryType::RPC) {
_function = AggregateRpcUdaf::create(_fn, argument_types, _data_type);
} else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
if (match_suffix(_fn.name.function_name, AGG_UNION_SUFFIX)) {
_function = get_agg_state_function<AggregateStateUnion>(
remove_suffix(_fn.name.function_name, AGG_UNION_SUFFIX), argument_types,
_data_type);
} else if (match_suffix(_fn.name.function_name, AGG_MERGE_SUFFIX)) {
_function = get_agg_state_function<AggregateStateMerge>(
remove_suffix(_fn.name.function_name, AGG_MERGE_SUFFIX), argument_types,
_data_type);
} else {
return Status::InternalError(
"Aggregate Function {} is not endwith '_merge' or '_union'", _fn.signature);
}
} else {
_function = AggregateFunctionSimpleFactory::instance().get(
_fn.name.function_name, argument_types, _data_type->is_nullable());

View File

@ -32,15 +32,19 @@
#include "common/status.h"
#include "runtime/runtime_state.h"
#include "udf/udf.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/columns/column.h"
#include "vec/core/block.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/columns_with_type_and_name.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_agg_state.h"
#include "vec/exprs/vexpr_context.h"
#include "vec/functions/function_agg_state.h"
#include "vec/functions/function_java_udf.h"
#include "vec/functions/function_rpc.h"
#include "vec/functions/simple_function_factory.h"
#include "vec/utils/util.hpp"
namespace doris {
class RowDescriptor;
@ -50,6 +54,8 @@ class TExprNode;
namespace doris::vectorized {
const std::string AGG_STATE_SUFFIX = "_state";
VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) {}
Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
@ -62,6 +68,7 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
argument_template.emplace_back(nullptr, child->data_type(), child->expr_name());
child_expr_name.emplace_back(child->expr_name());
}
if (_fn.binary_type == TFunctionBinaryType::RPC) {
_function = FunctionRPC::create(_fn, argument_template, _data_type);
} else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
@ -72,6 +79,29 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc,
"Java UDF is not enabled, you can change be config enable_java_support to true "
"and restart be.");
}
} else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) {
DataTypes argument_types;
for (auto column : argument_template) {
argument_types.emplace_back(column.type);
}
if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) {
if (_data_type->is_nullable()) {
return Status::InternalError("State function's return type must be not nullable");
}
if (_data_type->get_type_as_primitive_type() != PrimitiveType::TYPE_AGG_STATE) {
return Status::InternalError(
"State function's return type must be agg_state but get {}",
_data_type->get_family_name());
}
_function = FunctionAggState::create(
argument_types, _data_type,
AggregateFunctionSimpleFactory::instance().get(
remove_suffix(_fn.name.function_name, AGG_STATE_SUFFIX), argument_types,
_data_type->is_nullable()));
} else {
return Status::InternalError("Function {} is not endwith '_state'", _fn.signature);
}
} else {
// get the function. won't prepare function.
_function = SimpleFunctionFactory::instance().get_function(

View File

@ -0,0 +1,85 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <fmt/format.h>
#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/common/arena.h"
#include "vec/core/block.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/functions/function.h"
namespace doris::vectorized {
/// Scalar function behind the `<agg>_state` combinator.
///
/// For every input row it asks the wrapped aggregate function to serialize a state built
/// from that row's arguments (via streaming_agg_serialize) and returns those serialized
/// states as the result column.
class FunctionAggState : public IFunction {
public:
    FunctionAggState(const DataTypes& argument_types, const DataTypePtr& return_type,
                     AggregateFunctionPtr agg_function)
            : _argument_types(argument_types),
              _return_type(return_type),
              _agg_function(agg_function) {}

    /// Factory. Wraps a new FunctionAggState in a DefaultFunction; yields nullptr when no
    /// aggregate function was supplied.
    static FunctionBasePtr create(const DataTypes& argument_types, const DataTypePtr& return_type,
                                  AggregateFunctionPtr agg_function) {
        if (!agg_function) {
            return nullptr;
        }
        auto impl = std::make_shared<FunctionAggState>(argument_types, return_type, agg_function);
        return std::make_shared<DefaultFunction>(impl, argument_types, return_type);
    }

    size_t get_number_of_arguments() const override { return _argument_types.size(); }

    bool use_default_implementation_for_constants() const override { return true; }

    bool use_default_implementation_for_nulls() const override { return false; }

    /// Name of the wrapped aggregate function followed by "_state".
    String get_name() const override {
        return fmt::format("{}_state", _agg_function->get_name());
    }

    /// The return type is fixed at construction time; `arguments` is not consulted.
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        return _return_type;
    }

    /// Serializes one aggregate state per input row into a fresh column of the return
    /// type and stores that column at position `result` of `block`.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        size_t result, size_t input_rows_count) override {
        auto state_column = _return_type->create_column();

        // Gather the argument columns in the order given by `arguments`.
        std::vector<const IColumn*> input_columns;
        for (size_t position : arguments) {
            input_columns.push_back(block.get_by_position(position).column);
        }

        // The result column is written through a string buffer writer.
        VectorBufferWriter writer(assert_cast<ColumnString&>(*state_column));
        _agg_function->streaming_agg_serialize(input_columns.data(), writer, input_rows_count,
                                               &arena);

        block.replace_by_position(result, std::move(state_column));
        return Status::OK();
    }

private:
    DataTypes _argument_types;
    DataTypePtr _return_type;
    AggregateFunctionPtr _agg_function;
    // Arena handed to streaming_agg_serialize() on every execute_impl() call.
    Arena arena;
};
} // namespace doris::vectorized

View File

@ -465,7 +465,7 @@ void BlockReader::_update_agg_value(MutableColumns& columns, int begin, int end,
if (begin <= end) {
function->add_batch_range(begin, end, place, const_cast<const IColumn**>(&column_ptr),
nullptr, _stored_has_null_tag[idx]);
&_arena, _stored_has_null_tag[idx]);
}
if (is_close) {

View File

@ -102,6 +102,19 @@ public:
}
};
/// Returns true when `name` ends with `suffix`.
/// An empty `suffix` matches every `name`; a `suffix` longer than `name` never matches.
inline bool match_suffix(const std::string& name, const std::string& suffix) {
    const auto name_len = name.length();
    const auto suffix_len = suffix.length();
    // Compare only the tail of `name` that has the same length as `suffix`.
    return suffix_len <= name_len &&
           name.compare(name_len - suffix_len, suffix_len, suffix) == 0;
}
/// Returns `name` without its trailing `suffix`.
/// `name` must end with `suffix`; this is enforced with CHECK before the suffix is cut off.
inline std::string remove_suffix(const std::string& name, const std::string& suffix) {
    CHECK(match_suffix(name, suffix))
            << ", suffix not match, name=" << name << ", suffix=" << suffix;
    // Keep everything in front of the matched suffix.
    return name.substr(0, name.length() - suffix.length());
}
} // namespace doris::vectorized
namespace apache::thrift {

View File

@ -66,8 +66,8 @@ public class BuiltinAggregateFunction extends Function {
}
@Override
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
TFunction fn = super.toThrift(realReturnType, realArgTypes);
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
// TODO: for now, just put the op_ enum as the id.
if (op == BuiltinAggregateFunction.Operator.FIRST_VALUE_REWRITE) {
fn.setId(0);

View File

@ -17,7 +17,6 @@
package org.apache.doris.analysis;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.DatabaseIf;
import org.apache.doris.catalog.Env;
@ -200,11 +199,7 @@ public class DescribeStmt extends ShowStmt {
// Extra string (aggregation and bloom filter)
List<String> extras = Lists.newArrayList();
if (column.getAggregationType() != null) {
if (column.getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
extras.add(column.getGenericAggregationString());
} else {
extras.add(column.getAggregationType().name());
}
extras.add(column.getAggregationString());
}
if (bfColumns != null && bfColumns.contains(column.getName())) {
extras.add("BLOOM_FILTER");

View File

@ -21,12 +21,15 @@
package org.apache.doris.analysis;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.AggregateFunction;
import org.apache.doris.catalog.ArrayType;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Function;
import org.apache.doris.catalog.Function.NullableMode;
import org.apache.doris.catalog.FunctionSet;
import org.apache.doris.catalog.MaterializedIndexMeta;
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
@ -40,6 +43,7 @@ import org.apache.doris.statistics.ExprStats;
import org.apache.doris.thrift.TExpr;
import org.apache.doris.thrift.TExprNode;
import org.apache.doris.thrift.TExprOpcode;
import org.apache.doris.thrift.TFunctionBinaryType;
import com.google.common.base.Joiner;
import com.google.common.base.MoreObjects;
@ -75,6 +79,10 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
// supports negation.
private static final String NEGATE_FN = "negate";
public static final String AGG_STATE_SUFFIX = "_state";
public static final String AGG_UNION_SUFFIX = "_union";
public static final String AGG_MERGE_SUFFIX = "_merge";
protected boolean disableTableName = false;
// to be used where we can't come up with a better estimate
@ -436,6 +444,10 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
setSelectivity();
}
analysisDone();
if (type.isAggStateType() && !(this instanceof SlotRef) && ((ScalarType) type).getSubTypes() == null) {
type = new ScalarType(Arrays.asList(collectChildReturnTypes()),
Arrays.asList(collectChildReturnNullables()));
}
}
/**
@ -482,6 +494,14 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return childTypes;
}
/**
 * Collects the nullability of every child expression.
 *
 * @return an array with one entry per child, in child order, holding that child's
 *         {@code isNullable()} value
 */
protected Boolean[] collectChildReturnNullables() {
    final int childCount = children.size();
    Boolean[] nullables = new Boolean[childCount];
    int idx = 0;
    for (Expr child : children) {
        nullables[idx++] = child.isNullable();
    }
    return nullables;
}
public List<Expr> getChildrenWithoutCast() {
List<Expr> result = new ArrayList<>();
for (int i = 0; i < children.size(); ++i) {
@ -984,10 +1004,14 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
// Append a flattened version of this expr, including all children, to 'container'.
protected void treeToThriftHelper(TExpr container) {
TExprNode msg = new TExprNode();
if (type.isAggStateType() && ((ScalarType) type).getSubTypes() == null) {
type = new ScalarType(Arrays.asList(collectChildReturnTypes()),
Arrays.asList(collectChildReturnNullables()));
}
msg.type = type.toThrift();
msg.num_children = children.size();
if (fn != null) {
msg.setFn(fn.toThrift(type, collectChildReturnTypes()));
msg.setFn(fn.toThrift(type, collectChildReturnTypes(), collectChildReturnNullables()));
if (fn.hasVarArgs()) {
msg.setVarargStartIdx(fn.getNumArgs() - 1);
}
@ -1382,7 +1406,8 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
}
public Expr checkTypeCompatibility(Type targetType) throws AnalysisException {
if (!targetType.isComplexType() && targetType.getPrimitiveType() == type.getPrimitiveType()) {
if (!targetType.isComplexType() && !targetType.isAggStateType()
&& targetType.getPrimitiveType() == type.getPrimitiveType()) {
if (targetType.isDecimalV2() && type.isDecimalV2()) {
return this;
} else if (!PrimitiveType.typeWithPrecision.contains(type.getPrimitiveType())) {
@ -1392,6 +1417,10 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return this;
}
}
if (type.isAggStateType() != targetType.isAggStateType()) {
throw new AnalysisException("AggState can't cast from other type.");
}
// bitmap must match exactly
if (targetType.getPrimitiveType() == PrimitiveType.BITMAP) {
throw new AnalysisException("bitmap column require the function return type is BITMAP");
@ -1452,8 +1481,32 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
return this;
}
if (this.type.equals(targetType)) {
return this;
if (this.type.isAggStateType()) {
List<Type> subTypes = ((ScalarType) targetType).getSubTypes();
if (this instanceof FunctionCallExpr) {
if (subTypes.size() != getChildren().size()) {
throw new AnalysisException("AggState's subTypes size not euqal to children number");
}
for (int i = 0; i < subTypes.size(); i++) {
setChild(i, getChild(i).castTo(subTypes.get(i)));
}
type = targetType;
} else {
List<Type> selfSubTypes = ((ScalarType) type).getSubTypes();
if (subTypes.size() != selfSubTypes.size()) {
throw new AnalysisException("AggState's subTypes size did not match");
}
for (int i = 0; i < subTypes.size(); i++) {
if (subTypes.get(i) != selfSubTypes.get(i)) {
throw new AnalysisException("AggState's subType did not match");
}
}
}
} else {
if (this.type.equals(targetType)) {
return this;
}
}
if (targetType.getPrimitiveType() == PrimitiveType.DECIMALV2
@ -1793,14 +1846,72 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
throws AnalysisException {
FunctionName fnName = new FunctionName(name);
Function searchDesc = new Function(fnName, Arrays.asList(getActualArgTypes(argTypes)), Type.INVALID, false,
VectorizedUtil.isVectorized());
true);
Function f = Env.getCurrentEnv().getFunction(searchDesc, mode);
if (f != null && fnName.getFunction().equalsIgnoreCase("rand")) {
if (this.children.size() == 1
&& !(this.children.get(0) instanceof LiteralExpr)) {
if (this.children.size() == 1 && !(this.children.get(0) instanceof LiteralExpr)) {
throw new AnalysisException("The param of rand function must be literal");
}
}
if (f != null) {
return f;
}
boolean isUnion = name.toLowerCase().endsWith(AGG_UNION_SUFFIX);
boolean isMerge = name.toLowerCase().endsWith(AGG_MERGE_SUFFIX);
boolean isState = name.toLowerCase().endsWith(AGG_STATE_SUFFIX);
if (isUnion || isMerge || isState) {
if (isUnion) {
name = name.substring(0, name.length() - AGG_UNION_SUFFIX.length());
}
if (isMerge) {
name = name.substring(0, name.length() - AGG_MERGE_SUFFIX.length());
}
if (isState) {
name = name.substring(0, name.length() - AGG_STATE_SUFFIX.length());
}
List<Type> argList = Arrays.asList(getActualArgTypes(argTypes));
List<Type> nestedArgList;
if (isState) {
nestedArgList = argList;
} else {
if (argList.size() != 1 || !argList.get(0).isAggStateType()) {
throw new AnalysisException("merge/union function must input one agg_state");
}
ScalarType aggState = (ScalarType) argList.get(0);
if (aggState.getSubTypes() == null) {
throw new AnalysisException("agg_state's subTypes is null");
}
nestedArgList = aggState.getSubTypes();
}
searchDesc = new Function(new FunctionName(name), nestedArgList, Type.INVALID, false, true);
f = Env.getCurrentEnv().getFunction(searchDesc, mode);
if (f == null || !(f instanceof AggregateFunction)) {
return null;
}
if (isState) {
f = new ScalarFunction(new FunctionName(name + AGG_STATE_SUFFIX), Arrays.asList(f.getArgs()),
Type.AGG_STATE, f.hasVarArgs(), f.isUserVisible());
f.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE);
} else {
f = ((AggregateFunction) f).clone();
f.setArgs(argList);
if (isUnion) {
f.setName(new FunctionName(name + AGG_UNION_SUFFIX));
f.setReturnType((ScalarType) argList.get(0));
f.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE);
}
if (isMerge) {
f.setName(new FunctionName(name + AGG_MERGE_SUFFIX));
}
}
f.setBinaryType(TFunctionBinaryType.AGG_STATE);
}
return f;
}

View File

@ -589,8 +589,8 @@ public class AggregateFunction extends Function {
}
@Override
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
TFunction fn = super.toThrift(realReturnType, realArgTypes);
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
TAggregateFunction aggFn = new TAggregateFunction();
aggFn.setIsAnalyticOnlyFn(isAnalyticFn && !isAggregateFn);
aggFn.setUpdateFnSymbol(updateFnSymbol);

View File

@ -397,6 +397,14 @@ public class Column implements Writable, GsonPostProcessable {
return this.aggregationType;
}
/**
 * Renders this column's aggregation as a string.
 *
 * @return the generic aggregation string when the aggregation type is
 *         {@code GENERIC_AGGREGATION}, otherwise the enum name of the
 *         aggregation type
 */
public String getAggregationString() {
    // Generic aggregations are rendered through their dedicated string form.
    if (getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
        return getGenericAggregationString();
    }
    // NOTE(review): a null aggregation type reaches name() and throws; callers
    // appear to be expected to check for null first -- confirm.
    return getAggregationType().name();
}
/**
 * Tells whether this column carries a real aggregation.
 *
 * @return {@code false} when the aggregation type is unset or is
 *         {@code AggregateType.NONE}; {@code true} otherwise
 */
public boolean isAggregated() {
    return !(aggregationType == null || aggregationType == AggregateType.NONE);
}

View File

@ -17,6 +17,8 @@
package org.apache.doris.catalog;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.FunctionCallExpr;
import org.apache.doris.analysis.FunctionName;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.UserException;
@ -42,6 +44,7 @@ import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Base class for all functions.
@ -217,6 +220,10 @@ public class Function implements Writable {
return argTypes;
}
/**
 * Replaces this function's argument types.
 *
 * @param argTypes the new argument types; their contents are copied into a
 *                 freshly allocated array, so later changes to the list do not
 *                 affect this function
 */
public void setArgs(List<Type> argTypes) {
    // toArray allocates an array of exactly argTypes.size() elements when the
    // zero-length seed array is too small to hold the list.
    this.argTypes = argTypes.toArray(new Type[0]);
}
// Returns the number of arguments to this function.
public int getNumArgs() {
return argTypes.length;
@ -506,7 +513,7 @@ public class Function implements Writable {
return retType instanceof AnyType;
}
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
TFunction fn = new TFunction();
fn.setSignature(signatureString());
fn.setName(name.toThrift());
@ -514,16 +521,24 @@ public class Function implements Writable {
if (location != null) {
fn.setHdfsLocation(location.getLocation());
}
// `realArgTypes.length != argTypes.length` is true iff this is an aggregation function.
// For aggregation functions, `argTypes` here is already its real type with true precision and scale.
// `realArgTypes.length != argTypes.length` is true iff this is an aggregation
// function.
// For aggregation functions, `argTypes` here is already its real type with true
// precision and scale.
if (realArgTypes.length != argTypes.length) {
fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes)));
} else {
fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes), Lists.newArrayList(realArgTypes)));
}
// For types with different precisions and scales, return type only indicates a type with default
if (realReturnType.isAggStateType()) {
realReturnType = new ScalarType(Arrays.asList(realArgTypes), Arrays.asList(realArgTypeNullables));
}
// For types with different precisions and scales, return type only indicates a
// type with default
// precision and scale so we need to transform it to the correct type.
if (realReturnType.typeContainsPrecision()) {
if (realReturnType.typeContainsPrecision() || realReturnType.isAggStateType()) {
fn.setRetType(realReturnType.toThrift());
} else {
fn.setRetType(getReturnType().toThrift());
@ -816,4 +831,18 @@ public class Function implements Writable {
result = 31 * result + Arrays.hashCode(argTypes);
return result;
}
/**
 * Wraps an aggregate function call into the matching "_state" scalar call.
 *
 * The produced function is named after the aggregate function with
 * {@code Expr.AGG_STATE_SUFFIX} appended, keeps the aggregate function's
 * argument types, and returns a {@code ScalarType} built from those argument
 * types plus the nullability of each child of {@code fnCall}. It is marked
 * {@code ALWAYS_NOT_NULLABLE} and given the {@code AGG_STATE} binary type.
 *
 * @param fnCall the call whose function and parameters are reused
 * @return a new call on the state-combinator function with the same parameters
 */
public static FunctionCallExpr convertToStateCombinator(FunctionCallExpr fnCall) {
    Function aggFunction = fnCall.getFn();
    List<Type> arguments = Arrays.asList(aggFunction.getArgs());

    FunctionName stateFnName = new FunctionName(
            aggFunction.getFunctionName().getFunction() + Expr.AGG_STATE_SUFFIX);

    // One nullability flag per child expression, in child order.
    List<Boolean> childNullables = fnCall.getChildren().stream()
            .map(child -> child.isNullable())
            .collect(Collectors.toList());
    ScalarType stateType = new ScalarType(arguments, childNullables);

    ScalarFunction stateFn = new ScalarFunction(stateFnName, arguments, stateType,
            aggFunction.hasVarArgs(), aggFunction.isUserVisible());
    stateFn.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE);
    stateFn.setBinaryType(TFunctionBinaryType.AGG_STATE);

    return new FunctionCallExpr(stateFn, fnCall.getParams());
}
}

View File

@ -229,10 +229,10 @@ public class ScalarFunction extends Function {
}
@Override
public TFunction toThrift(Type realReturnType, Type[] realArgTypes) {
TFunction fn = super.toThrift(realReturnType, realArgTypes);
public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) {
TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables);
fn.setScalarFn(new TScalarFunction());
if (getBinaryType() != TFunctionBinaryType.BUILTIN) {
if (getBinaryType() == TFunctionBinaryType.JAVA_UDF || getBinaryType() == TFunctionBinaryType.RPC) {
fn.getScalarFn().setSymbol(symbolName);
} else {
fn.getScalarFn().setSymbol("");

View File

@ -17,7 +17,6 @@
package org.apache.doris.common.proc;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.FeConstants;
@ -60,11 +59,7 @@ public class IndexSchemaProcNode implements ProcNodeInterface {
// Extra string (aggregation and bloom filter)
List<String> extras = Lists.newArrayList();
if (column.getAggregationType() != null) {
if (column.getAggregationType() == AggregateType.GENERIC_AGGREGATION) {
extras.add(column.getGenericAggregationString());
} else {
extras.add(column.getAggregationType().name());
}
extras.add(column.getAggregationString());
}
if (bfColumns != null && bfColumns.contains(column.getName())) {
extras.add("BLOOM_FILTER");