From bbb3af6ce656c800b2b35a8a9761d23a08b772e0 Mon Sep 17 00:00:00 2001 From: Pxl Date: Mon, 29 May 2023 13:07:29 +0800 Subject: [PATCH] [Feature](agg_state) support agg_state combinators (#19969) support agg_state combinators state/merge/union --- be/src/olap/memtable.cpp | 6 +- be/src/olap/tablet_schema.cpp | 20 ++- be/src/olap/tablet_schema.h | 3 +- be/src/runtime/types.h | 8 +- .../aggregate_functions/aggregate_function.h | 52 ++++++-- .../aggregate_function_avg.h | 4 +- .../aggregate_function_count.h | 22 +-- .../aggregate_function_distinct.h | 1 - .../aggregate_function_min_max.h | 6 +- .../aggregate_function_null.h | 10 -- .../aggregate_function_state_merge.h | 50 +++++++ .../aggregate_function_state_union.h | 99 ++++++++++++++ .../aggregate_function_sum.h | 4 +- be/src/vec/data_types/data_type_agg_state.h | 6 +- be/src/vec/exec/vaggregation_node.cpp | 2 +- .../table_function/table_function_factory.cpp | 12 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 28 ++++ be/src/vec/exprs/vectorized_fn_call.cpp | 30 +++++ be/src/vec/functions/function_agg_state.h | 85 ++++++++++++ be/src/vec/olap/block_reader.cpp | 2 +- be/src/vec/utils/util.hpp | 13 ++ .../analysis/BuiltinAggregateFunction.java | 4 +- .../apache/doris/analysis/DescribeStmt.java | 7 +- .../java/org/apache/doris/analysis/Expr.java | 125 +++++++++++++++++- .../doris/catalog/AggregateFunction.java | 4 +- .../java/org/apache/doris/catalog/Column.java | 8 ++ .../org/apache/doris/catalog/Function.java | 39 +++++- .../apache/doris/catalog/ScalarFunction.java | 6 +- .../common/proc/IndexSchemaProcNode.java | 7 +- 29 files changed, 564 insertions(+), 99 deletions(-) create mode 100644 be/src/vec/aggregate_functions/aggregate_function_state_merge.h create mode 100644 be/src/vec/aggregate_functions/aggregate_function_state_union.h create mode 100644 be/src/vec/functions/function_agg_state.h diff --git a/be/src/olap/memtable.cpp b/be/src/olap/memtable.cpp index 9890b0a374..16aa81ee61 100644 --- a/be/src/olap/memtable.cpp +++ b/be/src/olap/memtable.cpp @@ -233,7 +233,7 @@ void MemTable::_aggregate_two_row_in_block(vectorized::MutableBlock& mutable_blo auto col_ptr = mutable_block.mutable_columns()[cid].get(); _agg_functions[cid]->add(row_in_skiplist->agg_places(cid), const_cast(&col_ptr), - new_row->_row_pos, nullptr); + new_row->_row_pos, _arena.get()); } } void MemTable::_put_into_output(vectorized::Block& in_block) { @@ -298,7 +298,7 @@ void MemTable::_finalize_one_row(RowInBlock* row, } else { function->reset(agg_place); function->add(agg_place, const_cast(&col_ptr), - row_pos, nullptr); + row_pos, _arena.get()); } } if constexpr (is_final) { @@ -343,7 +343,7 @@ void MemTable::_aggregate() { _agg_functions[cid]->create(data); _agg_functions[cid]->add( data, const_cast(&col_ptr), - prev_row->_row_pos, nullptr); + prev_row->_row_pos, _arena.get()); } } _stat.merged_rows++; diff --git a/be/src/olap/tablet_schema.cpp b/be/src/olap/tablet_schema.cpp index 699d938938..cf0a2bd266 100644 --- a/be/src/olap/tablet_schema.cpp +++ b/be/src/olap/tablet_schema.cpp @@ -39,6 +39,7 @@ #include "runtime/thread_context.h" #include "tablet_meta.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_state_union.h" #include "vec/core/block.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_factory.hpp" @@ -491,14 +492,21 @@ bool TabletColumn::is_row_store_column() const { return _col_name == BeConsts::ROW_STORE_COL; } -vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function_merge() const { +vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function_union( + vectorized::DataTypePtr type) const { + auto state_type = dynamic_cast(type.get()); + if (!state_type) { + return nullptr; + } vectorized::DataTypes argument_types; for (auto col : _sub_columns) { - argument_types.push_back(vectorized::DataTypeFactory::instance().create_data_type(col)); + auto sub_type = vectorized::DataTypeFactory::instance().create_data_type(col); + state_type->add_sub_type(sub_type); } - auto function = vectorized::AggregateFunctionSimpleFactory::instance().get( - _aggregation_name, argument_types, false); - return function; + auto agg_function = vectorized::AggregateFunctionSimpleFactory::instance().get( + _aggregation_name, state_type->get_sub_types(), false); + + return vectorized::AggregateStateUnion::create(agg_function, {type}, type); } vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(std::string suffix) const { @@ -514,7 +522,7 @@ vectorized::AggregateFunctionPtr TabletColumn::get_aggregate_function(std::strin if (function) { return function; } - return get_aggregate_function_merge(); + return get_aggregate_function_union(type); } void TabletIndex::init_from_thrift(const TOlapTableIndex& index, diff --git a/be/src/olap/tablet_schema.h b/be/src/olap/tablet_schema.h index 7a1b34745c..3dad8ba875 100644 --- a/be/src/olap/tablet_schema.h +++ b/be/src/olap/tablet_schema.h @@ -89,7 +89,8 @@ public: void set_is_nullable(bool is_nullable) { _is_nullable = is_nullable; } void set_has_default_value(bool has) { _has_default_value = has; } FieldAggregationMethod aggregation() const { return _aggregation; } - vectorized::AggregateFunctionPtr get_aggregate_function_merge() const; + vectorized::AggregateFunctionPtr get_aggregate_function_union( + vectorized::DataTypePtr type) const; vectorized::AggregateFunctionPtr get_aggregate_function(std::string suffix) const; int precision() const { return _precision; } int frac() const { return _frac; } diff --git a/be/src/runtime/types.h b/be/src/runtime/types.h index 5950fe55d7..1d261caeb2 100644 --- a/be/src/runtime/types.h +++ b/be/src/runtime/types.h @@ -59,7 +59,6 @@ struct TypeDescriptor { /// The maximum precision representable by a 8-byte decimal (Decimal8Value) static constexpr int MAX_DECIMAL8_PRECISION = 18; - // Empty for scalar types std::vector children; // Only set if type == TYPE_STRUCT. The field name of each child. @@ -148,6 +147,13 @@ struct TypeDescriptor { int idx = 0; TypeDescriptor result(t.types, &idx); DCHECK_EQ(idx, t.types.size() - 1); + if (result.type == TYPE_AGG_STATE) { + DCHECK(t.__isset.sub_types); + for (auto sub : t.sub_types) { + result.children.push_back(from_thrift(sub)); + result.contains_nulls.push_back(sub.is_nullable); + } + } return result; } diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 73f3473d43..98faf3a389 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -134,7 +134,7 @@ public: MutableColumnPtr& dst, const size_t num_rows) const = 0; virtual void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const = 0; + IColumn& to) const = 0; /// Deserializes state. This function is called only for empty (just created) states. virtual void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, @@ -150,6 +150,10 @@ public: virtual void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf, Arena* arena) const = 0; + virtual void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place, + const IColumn& column, size_t begin, + size_t end, Arena* arena) const = 0; + virtual void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, Arena* arena) const = 0; @@ -332,8 +336,8 @@ public: } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const override { - VectorBufferWriter writter(assert_cast(*dst)); + IColumn& to) const override { + VectorBufferWriter writter(assert_cast(to)); assert_cast(this)->serialize(place, writter); writter.commit(); } @@ -382,6 +386,38 @@ public: } } } + + void deserialize_and_merge_from_column_range(AggregateDataPtr __restrict place, + const IColumn& column, size_t begin, size_t end, + Arena* arena) const override { + DCHECK(end <= column.size() && begin <= end) + << ", begin:" << begin << ", end:" << end << ", column.size():" << column.size(); + for (size_t i = begin; i <= end; ++i) { + VectorBufferReader buffer_reader( + (assert_cast(column)).get_data_at(i)); + deserialize_and_merge(place, buffer_reader, arena); + } + } + + void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, + Arena* arena) const override { + if (column.empty()) { + return; + } + deserialize_and_merge_from_column_range(place, column, 0, column.size() - 1, arena); + } + + void deserialize_and_merge(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena* arena) const override { + char deserialized_data[size_of_data()]; + AggregateDataPtr deserialized_place = (AggregateDataPtr)deserialized_data; + + auto derived = static_cast(this); + derived->create(deserialized_place); + derived->deserialize(deserialized_place, buf, arena); + derived->merge(place, deserialized_place, arena); + derived->destroy(deserialized_place); + } }; /// Implements several methods for manipulation with data. T - type of structure with data for aggregation. @@ -426,16 +462,6 @@ public: derived->deserialize(deserialized_place, buf, arena); derived->merge(place, deserialized_place, arena); } - - void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, - Arena* arena) const override { - size_t num_rows = column.size(); - for (size_t i = 0; i != num_rows; ++i) { - VectorBufferReader buffer_reader( - (assert_cast(column)).get_data_at(i)); - deserialize_and_merge(place, buffer_reader, arena); - } - } }; using AggregateFunctionPtr = std::shared_ptr; diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index bf9c71b90d..64ca64133a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -221,8 +221,8 @@ public: } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const override { - auto& col = assert_cast(*dst); + IColumn& to) const override { + auto& col = assert_cast(to); col.set_item_size(sizeof(Data)); col.resize(1); *reinterpret_cast(col.get_data().data()) = this->data(place); diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h b/be/src/vec/aggregate_functions/aggregate_function_count.h index 0b6ec8f2ba..a890c58ed5 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count.h @@ -104,7 +104,7 @@ public: col.resize(num_rows); auto* data = col.get_data().data(); for (size_t i = 0; i != num_rows; ++i) { - data[i] = this->data(places[i] + offset).count; + data[i] = AggregateFunctionCount::data(places[i] + offset).count; } } @@ -120,15 +120,16 @@ public: auto data = assert_cast(column).get_data().data(); const size_t num_rows = column.size(); for (size_t i = 0; i != num_rows; ++i) { - this->data(place).count += data[i]; + AggregateFunctionCount::data(place).count += data[i]; } } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const override { - auto& col = assert_cast(*dst); + IColumn& to) const override { + auto& col = assert_cast(to); col.resize(1); - reinterpret_cast(col.get_data().data())->count = this->data(place).count; + reinterpret_cast(col.get_data().data())->count = + AggregateFunctionCount::data(place).count; } MutableColumnPtr create_serialize_column() const override { @@ -199,7 +200,7 @@ public: col.resize(num_rows); auto* data = col.get_data().data(); for (size_t i = 0; i != num_rows; ++i) { - data[i] = this->data(places[i] + offset).count; + data[i] = AggregateFunctionCountNotNullUnary::data(places[i] + offset).count; } } @@ -219,15 +220,16 @@ public: auto data = assert_cast(column).get_data().data(); const size_t num_rows = column.size(); for (size_t i = 0; i != num_rows; ++i) { - this->data(place).count += data[i]; + AggregateFunctionCountNotNullUnary::data(place).count += data[i]; } } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const override { - auto& col = assert_cast(*dst); + IColumn& to) const override { + auto& col = assert_cast(to); col.resize(1); - reinterpret_cast(col.get_data().data())->count = this->data(place).count; + reinterpret_cast(col.get_data().data())->count = + AggregateFunctionCountNotNullUnary::data(place).count; } MutableColumnPtr create_serialize_column() const override { diff --git a/be/src/vec/aggregate_functions/aggregate_function_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_distinct.h index 1f17e472d5..769b4ff805 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_distinct.h +++ b/be/src/vec/aggregate_functions/aggregate_function_distinct.h @@ -219,7 +219,6 @@ public: this->data(place).deserialize(buf, arena); } - // void insert_result_into(AggregateDataPtr place, IColumn & to, Arena * arena) const override void insert_result_into(ConstAggregateDataPtr targetplace, IColumn& to) const override { auto place = const_cast(targetplace); auto arguments = this->data(place).get_arguments(this->argument_types); diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h b/be/src/vec/aggregate_functions/aggregate_function_min_max.h index c17aee4348..b22b6d2251 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h @@ -609,13 +609,13 @@ public: } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const override { + IColumn& to) const override { if constexpr (Data::IsFixedLength) { - auto& col = assert_cast(*dst); + auto& col = assert_cast(to); col.resize(1); *reinterpret_cast(col.get_data().data()) = this->data(place); } else { - Base::serialize_without_key_to_column(place, dst); + Base::serialize_without_key_to_column(place, to); } } diff --git a/be/src/vec/aggregate_functions/aggregate_function_null.h b/be/src/vec/aggregate_functions/aggregate_function_null.h index 180a334d19..50f5ab380e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_null.h +++ b/be/src/vec/aggregate_functions/aggregate_function_null.h @@ -155,16 +155,6 @@ public: } } - void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, - Arena* arena) const override { - size_t num_rows = column.size(); - for (size_t i = 0; i != num_rows; ++i) { - VectorBufferReader buffer_reader( - (assert_cast(column)).get_data_at(i)); - deserialize_and_merge(place, buffer_reader, arena); - } - } - void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { if constexpr (result_is_nullable) { ColumnNullable& to_concrete = assert_cast(to); diff --git a/be/src/vec/aggregate_functions/aggregate_function_state_merge.h b/be/src/vec/aggregate_functions/aggregate_function_state_merge.h new file mode 100644 index 0000000000..7afd79a1c8 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_state_merge.h @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "vec/aggregate_functions/aggregate_function_state_union.h" + +namespace doris::vectorized { +const static std::string AGG_MERGE_SUFFIX = "_merge"; + +class AggregateStateMerge : public AggregateStateUnion { +public: + AggregateStateMerge(AggregateFunctionPtr function, const DataTypes& argument_types, + const DataTypePtr& return_type) + : AggregateStateUnion(function, argument_types, return_type) {} + + static AggregateFunctionPtr create(AggregateFunctionPtr function, + const DataTypes& argument_types, + const DataTypePtr& return_type) { + CHECK(argument_types.size() == 1); + if (function == nullptr) { + return nullptr; + } + return std::make_shared(function, argument_types, return_type); + } + + String get_name() const override { return _function->get_name() + AGG_MERGE_SUFFIX; } + + DataTypePtr get_return_type() const override { return _function->get_return_type(); } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + _function->insert_result_into(place, to); + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_state_union.h b/be/src/vec/aggregate_functions/aggregate_function_state_union.h new file mode 100644 index 0000000000..d12e8e66d3 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_state_union.h @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/data_types/data_type_agg_state.h" + +namespace doris::vectorized { +const static std::string AGG_UNION_SUFFIX = "_union"; + +class AggregateStateUnion : public IAggregateFunctionHelper { +public: + AggregateStateUnion(AggregateFunctionPtr function, const DataTypes& argument_types, + const DataTypePtr& return_type) + : IAggregateFunctionHelper(argument_types), + _function(function), + _return_type(return_type) {} + ~AggregateStateUnion() override = default; + + static AggregateFunctionPtr create(AggregateFunctionPtr function, + const DataTypes& argument_types, + const DataTypePtr& return_type) { + CHECK(argument_types.size() == 1); + if (function == nullptr) { + return nullptr; + } + return std::make_shared(function, argument_types, return_type); + } + + void create(AggregateDataPtr __restrict place) const override { _function->create(place); } + + String get_name() const override { return _function->get_name() + AGG_UNION_SUFFIX; } + + DataTypePtr get_return_type() const override { return _return_type; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena* arena) const override { + VectorBufferReader buffer_reader( + (assert_cast(*columns[0])).get_data_at(row_num)); + deserialize_and_merge(place, buffer_reader, arena); + } + + void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, + Arena* arena) const override { + _function->deserialize_and_merge_from_column_range(place, *columns[0], 0, batch_size - 1, + arena); + } + + void reset(AggregateDataPtr place) const override { _function->reset(place); } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena* arena) const override { + _function->merge(place, rhs, arena); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + _function->serialize(place, buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena* arena) const override { + _function->deserialize(place, buf, arena); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + _function->serialize_without_key_to_column(place, to); + } + + void destroy(AggregateDataPtr __restrict place) const noexcept override { + _function->destroy(place); + } + + bool has_trivial_destructor() const override { return _function->has_trivial_destructor(); } + + size_t size_of_data() const override { return _function->size_of_data(); } + + size_t align_of_data() const override { return _function->align_of_data(); } + +protected: + AggregateFunctionPtr _function; + DataTypePtr _return_type; +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h b/be/src/vec/aggregate_functions/aggregate_function_sum.h index b56dca4c97..75e9609f88 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h @@ -160,8 +160,8 @@ public: } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, - MutableColumnPtr& dst) const override { - auto& col = assert_cast(*dst); + IColumn& to) const override { + auto& col = assert_cast(to); col.resize(1); reinterpret_cast(col.get_data().data())->sum = this->data(place).sum; } diff --git a/be/src/vec/data_types/data_type_agg_state.h b/be/src/vec/data_types/data_type_agg_state.h index 27d0d2c0e1..1bb5dc4e00 100644 --- a/be/src/vec/data_types/data_type_agg_state.h +++ b/be/src/vec/data_types/data_type_agg_state.h @@ -52,9 +52,9 @@ public: return TPrimitiveType::AGG_STATE; } - const DataTypes& get_sub_types() { return sub_types; } + const DataTypes& get_sub_types() const { return sub_types; } - void add_sub_type(DataTypePtr type) { sub_types.push_back(type); } + void add_sub_type(DataTypePtr type) const { sub_types.push_back(type); } void to_pb_column_meta(PColumnMeta* col_meta) const override { IDataType::to_pb_column_meta(col_meta); @@ -64,7 +64,7 @@ public: } private: - DataTypes sub_types; + mutable DataTypes sub_types; }; } // namespace doris::vectorized diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index e4f2074dce..f4bde7569a 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -716,7 +716,7 @@ Status AggregationNode::_serialize_without_key(RuntimeState* state, Block* block for (int i = 0; i < _aggregate_evaluators.size(); ++i) { _aggregate_evaluators[i]->function()->serialize_without_key_to_column( - _agg_data->without_key + _offsets_of_aggregate_states[i], value_columns[i]); + _agg_data->without_key + _offsets_of_aggregate_states[i], *value_columns[i]); } } else { std::vector value_buffer_writers; diff --git a/be/src/vec/exprs/table_function/table_function_factory.cpp b/be/src/vec/exprs/table_function/table_function_factory.cpp index f5846f867a..06b69a27c7 100644 --- a/be/src/vec/exprs/table_function/table_function_factory.cpp +++ b/be/src/vec/exprs/table_function/table_function_factory.cpp @@ -26,6 +26,7 @@ #include "vec/exprs/table_function/vexplode_json_array.h" #include "vec/exprs/table_function/vexplode_numbers.h" #include "vec/exprs/table_function/vexplode_split.h" +#include "vec/utils/util.hpp" namespace doris::vectorized { @@ -61,17 +62,6 @@ const std::unordered_map bool { - if (name.length() < suffix.length()) { - return false; - } - return name.substr(name.length() - suffix.length()) == suffix; - }; - - auto remove_suffix = [](const std::string& name, const std::string& suffix) -> std::string { - return name.substr(0, name.length() - suffix.length()); - }; - bool is_outer = match_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER); std::string fn_name_real = is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw; diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index d13138fc56..513eb8fe98 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -33,12 +33,16 @@ #include "vec/aggregate_functions/aggregate_function_rpc.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/aggregate_function_sort.h" +#include "vec/aggregate_functions/aggregate_function_state_merge.h" +#include "vec/aggregate_functions/aggregate_function_state_union.h" #include "vec/core/block.h" #include "vec/core/column_with_type_and_name.h" #include "vec/core/materialize_block.h" +#include "vec/data_types/data_type_agg_state.h" #include "vec/data_types/data_type_factory.hpp" #include "vec/exprs/vexpr.h" #include "vec/exprs/vexpr_context.h" +#include "vec/utils/util.hpp" namespace doris { class RowDescriptor; @@ -51,6 +55,17 @@ class IColumn; namespace doris::vectorized { +template +AggregateFunctionPtr get_agg_state_function(const std::string& name, + const DataTypes& argument_types, + DataTypePtr return_type) { + return FunctionType::create( + AggregateFunctionSimpleFactory::instance().get( + name, ((DataTypeAggState*)argument_types[0].get())->get_sub_types(), + return_type->is_nullable()), + argument_types, return_type); +} + AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) : _fn(desc.fn), _is_merge(desc.agg_expr.is_merge_agg), @@ -143,6 +158,19 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, } } else if (_fn.binary_type == TFunctionBinaryType::RPC) { _function = AggregateRpcUdaf::create(_fn, argument_types, _data_type); + } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { + if (match_suffix(_fn.name.function_name, AGG_UNION_SUFFIX)) { + _function = get_agg_state_function( + remove_suffix(_fn.name.function_name, AGG_UNION_SUFFIX), argument_types, + _data_type); + } else if (match_suffix(_fn.name.function_name, AGG_MERGE_SUFFIX)) { + _function = get_agg_state_function( + remove_suffix(_fn.name.function_name, AGG_MERGE_SUFFIX), argument_types, + _data_type); + } else { + return Status::InternalError( + "Aggregate Function {} is not endwith '_merge' or '_union'", _fn.signature); + } } else { _function = AggregateFunctionSimpleFactory::instance().get( _fn.name.function_name, argument_types, _data_type->is_nullable()); diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 636be579e7..53cd19ed0d 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -32,15 +32,19 @@ #include "common/status.h" #include "runtime/runtime_state.h" #include "udf/udf.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/columns/column.h" #include "vec/core/block.h" #include "vec/core/column_with_type_and_name.h" #include "vec/core/columns_with_type_and_name.h" #include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_agg_state.h" #include "vec/exprs/vexpr_context.h" +#include "vec/functions/function_agg_state.h" #include "vec/functions/function_java_udf.h" #include "vec/functions/function_rpc.h" #include "vec/functions/simple_function_factory.h" +#include "vec/utils/util.hpp" namespace doris { class RowDescriptor; @@ -50,6 +54,8 @@ class TExprNode; namespace doris::vectorized { +const std::string AGG_STATE_SUFFIX = "_state"; + VectorizedFnCall::VectorizedFnCall(const TExprNode& node) : VExpr(node) {} Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, @@ -62,6 +68,7 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, argument_template.emplace_back(nullptr, child->data_type(), child->expr_name()); child_expr_name.emplace_back(child->expr_name()); } + if (_fn.binary_type == TFunctionBinaryType::RPC) { _function = FunctionRPC::create(_fn, argument_template, _data_type); } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { @@ -72,6 +79,29 @@ Status VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, "Java UDF is not enabled, you can change be config enable_java_support to true " "and restart be."); } + } else if (_fn.binary_type == TFunctionBinaryType::AGG_STATE) { + DataTypes argument_types; + for (auto column : argument_template) { + argument_types.emplace_back(column.type); + } + + if (match_suffix(_fn.name.function_name, AGG_STATE_SUFFIX)) { + if (_data_type->is_nullable()) { + return Status::InternalError("State function's return type must be not nullable"); + } + if (_data_type->get_type_as_primitive_type() != PrimitiveType::TYPE_AGG_STATE) { + return Status::InternalError( + "State function's return type must be agg_state but get {}", + _data_type->get_family_name()); + } + _function = FunctionAggState::create( + argument_types, _data_type, + AggregateFunctionSimpleFactory::instance().get( + remove_suffix(_fn.name.function_name, AGG_STATE_SUFFIX), argument_types, + _data_type->is_nullable())); + } else { + return Status::InternalError("Function {} is not endwith '_state'", _fn.signature); + } } else { // get the function. won't prepare function. _function = SimpleFunctionFactory::instance().get_function( diff --git a/be/src/vec/functions/function_agg_state.h b/be/src/vec/functions/function_agg_state.h new file mode 100644 index 0000000000..e17fc06c60 --- /dev/null +++ b/be/src/vec/functions/function_agg_state.h @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/common/arena.h" +#include "vec/core/block.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type.h" +#include "vec/functions/function.h" + +namespace doris::vectorized { + +class FunctionAggState : public IFunction { +public: + FunctionAggState(const DataTypes& argument_types, const DataTypePtr& return_type, + AggregateFunctionPtr agg_function) + : _argument_types(argument_types), + _return_type(return_type), + _agg_function(agg_function) {} + + static FunctionBasePtr create(const DataTypes& argument_types, const DataTypePtr& return_type, + AggregateFunctionPtr agg_function) { + if (agg_function == nullptr) { + return nullptr; + } + return std::make_shared( + std::make_shared(argument_types, return_type, agg_function), + argument_types, return_type); + } + + size_t get_number_of_arguments() const override { return _argument_types.size(); } + + bool use_default_implementation_for_constants() const override { return true; } + bool use_default_implementation_for_nulls() const override { return false; } + + String get_name() const override { return fmt::format("{}_state", _agg_function->get_name()); } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return _return_type; + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) override { + auto col = _return_type->create_column(); + std::vector agg_columns; + + for (size_t index : arguments) { + agg_columns.push_back(block.get_by_position(index).column); + } + + VectorBufferWriter writter(assert_cast(*col)); + _agg_function->streaming_agg_serialize(agg_columns.data(), writter, input_rows_count, + &arena); + + block.replace_by_position(result, std::move(col)); + return Status::OK(); + } + +private: + DataTypes _argument_types; + DataTypePtr _return_type; + AggregateFunctionPtr _agg_function; + Arena arena; +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/olap/block_reader.cpp b/be/src/vec/olap/block_reader.cpp index 92ce6fe263..2e1c8d716d 100644 --- a/be/src/vec/olap/block_reader.cpp +++ b/be/src/vec/olap/block_reader.cpp @@ -465,7 +465,7 @@ void BlockReader::_update_agg_value(MutableColumns& columns, int begin, int end, if (begin <= end) { function->add_batch_range(begin, end, place, const_cast(&column_ptr), - nullptr, _stored_has_null_tag[idx]); + &_arena, _stored_has_null_tag[idx]); } if (is_close) { diff --git a/be/src/vec/utils/util.hpp b/be/src/vec/utils/util.hpp index 416987c31d..ba593c60cb 100644 --- a/be/src/vec/utils/util.hpp +++ b/be/src/vec/utils/util.hpp @@ -102,6 +102,19 @@ public: } }; +inline bool match_suffix(const std::string& name, const std::string& suffix) { + if (name.length() < suffix.length()) { + return false; + } + return name.substr(name.length() - suffix.length()) == suffix; +} + +inline std::string remove_suffix(const std::string& name, const std::string& suffix) { + CHECK(match_suffix(name, suffix)) + << ", suffix not match, name=" << name << ", suffix=" << suffix; + return name.substr(0, name.length() - suffix.length()); +}; + } // namespace doris::vectorized namespace apache::thrift { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java index 8bd6c50d42..f9113d9348 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/BuiltinAggregateFunction.java @@ -66,8 +66,8 @@ public class BuiltinAggregateFunction extends Function { } @Override - public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { - TFunction fn = super.toThrift(realReturnType, realArgTypes); + public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) { + TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables); // TODO: for now, just put the op_ enum as the id. if (op == BuiltinAggregateFunction.Operator.FIRST_VALUE_REWRITE) { fn.setId(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DescribeStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DescribeStmt.java index eb922ced12..5ef86bf4ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DescribeStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DescribeStmt.java @@ -17,7 +17,6 @@ package org.apache.doris.analysis; -import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; import org.apache.doris.catalog.DatabaseIf; import org.apache.doris.catalog.Env; @@ -200,11 +199,7 @@ public class DescribeStmt extends ShowStmt { // Extra string (aggregation and bloom filter) List extras = Lists.newArrayList(); if (column.getAggregationType() != null) { - if (column.getAggregationType() == AggregateType.GENERIC_AGGREGATION) { - extras.add(column.getGenericAggregationString()); - } else { - extras.add(column.getAggregationType().name()); - } + extras.add(column.getAggregationString()); } if (bfColumns != null && bfColumns.contains(column.getName())) { extras.add("BLOOM_FILTER"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java index 8e8c213e9b..93d5437564 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java @@ -21,12 +21,15 @@ package org.apache.doris.analysis; import org.apache.doris.analysis.ArithmeticExpr.Operator; +import org.apache.doris.catalog.AggregateFunction; import org.apache.doris.catalog.ArrayType; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.Function; +import org.apache.doris.catalog.Function.NullableMode; import org.apache.doris.catalog.FunctionSet; import org.apache.doris.catalog.MaterializedIndexMeta; import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.ScalarFunction; import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; @@ -40,6 +43,7 @@ import org.apache.doris.statistics.ExprStats; import org.apache.doris.thrift.TExpr; import org.apache.doris.thrift.TExprNode; import org.apache.doris.thrift.TExprOpcode; +import org.apache.doris.thrift.TFunctionBinaryType; import com.google.common.base.Joiner; import com.google.common.base.MoreObjects; @@ -75,6 +79,10 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl // supports negation. private static final String NEGATE_FN = "negate"; + public static final String AGG_STATE_SUFFIX = "_state"; + public static final String AGG_UNION_SUFFIX = "_union"; + public static final String AGG_MERGE_SUFFIX = "_merge"; + protected boolean disableTableName = false; // to be used where we can't come up with a better estimate @@ -436,6 +444,10 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl setSelectivity(); } analysisDone(); + if (type.isAggStateType() && !(this instanceof SlotRef) && ((ScalarType) type).getSubTypes() == null) { + type = new ScalarType(Arrays.asList(collectChildReturnTypes()), + Arrays.asList(collectChildReturnNullables())); + } } /** @@ -482,6 +494,14 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl return childTypes; } + protected Boolean[] collectChildReturnNullables() { + Boolean[] childNullables = new Boolean[children.size()]; + for (int i = 0; i < children.size(); ++i) { + childNullables[i] = children.get(i).isNullable(); + } + return childNullables; + } + public List getChildrenWithoutCast() { List result = new ArrayList<>(); for (int i = 0; i < children.size(); ++i) { @@ -984,10 +1004,14 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl // Append a flattened version of this expr, including all children, to 'container'. protected void treeToThriftHelper(TExpr container) { TExprNode msg = new TExprNode(); + if (type.isAggStateType() && ((ScalarType) type).getSubTypes() == null) { + type = new ScalarType(Arrays.asList(collectChildReturnTypes()), + Arrays.asList(collectChildReturnNullables())); + } msg.type = type.toThrift(); msg.num_children = children.size(); if (fn != null) { - msg.setFn(fn.toThrift(type, collectChildReturnTypes())); + msg.setFn(fn.toThrift(type, collectChildReturnTypes(), collectChildReturnNullables())); if (fn.hasVarArgs()) { msg.setVarargStartIdx(fn.getNumArgs() - 1); } @@ -1382,7 +1406,8 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl } public Expr checkTypeCompatibility(Type targetType) throws AnalysisException { - if (!targetType.isComplexType() && targetType.getPrimitiveType() == type.getPrimitiveType()) { + if (!targetType.isComplexType() && !targetType.isAggStateType() + && targetType.getPrimitiveType() == type.getPrimitiveType()) { if (targetType.isDecimalV2() && type.isDecimalV2()) { return this; } else if (!PrimitiveType.typeWithPrecision.contains(type.getPrimitiveType())) { @@ -1392,6 +1417,10 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl return this; } } + if (type.isAggStateType() != targetType.isAggStateType()) { + throw new AnalysisException("AggState can't cast from other type."); + } + // bitmap must match exactly if (targetType.getPrimitiveType() == PrimitiveType.BITMAP) { throw new AnalysisException("bitmap column require the function return type is BITMAP"); @@ -1452,8 +1481,32 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl return this; } - if (this.type.equals(targetType)) { - return this; + if (this.type.isAggStateType()) { + List subTypes = ((ScalarType) targetType).getSubTypes(); + + if (this instanceof FunctionCallExpr) { + if (subTypes.size() != getChildren().size()) { + throw new AnalysisException("AggState's subTypes size not euqal to children number"); + } + for (int i = 0; i < subTypes.size(); i++) { + setChild(i, getChild(i).castTo(subTypes.get(i))); + } + type = targetType; + } else { + List selfSubTypes = ((ScalarType) type).getSubTypes(); + if (subTypes.size() != selfSubTypes.size()) { + throw new AnalysisException("AggState's subTypes size did not match"); + } + for (int i = 0; i < subTypes.size(); i++) { + if (subTypes.get(i) != selfSubTypes.get(i)) { + throw new AnalysisException("AggState's subType did not match"); + } + } + } + } else { + if (this.type.equals(targetType)) { + return this; + } } if (targetType.getPrimitiveType() == PrimitiveType.DECIMALV2 @@ -1793,14 +1846,72 @@ public abstract class Expr extends TreeNode implements ParseNode, Cloneabl throws AnalysisException { FunctionName fnName = new FunctionName(name); Function searchDesc = new Function(fnName, Arrays.asList(getActualArgTypes(argTypes)), Type.INVALID, false, - VectorizedUtil.isVectorized()); + true); Function f = Env.getCurrentEnv().getFunction(searchDesc, mode); if (f != null && fnName.getFunction().equalsIgnoreCase("rand")) { - if (this.children.size() == 1 - && !(this.children.get(0) instanceof LiteralExpr)) { + if (this.children.size() == 1 && !(this.children.get(0) instanceof LiteralExpr)) { throw new AnalysisException("The param of rand function must be literal"); } } + if (f != null) { + return f; + } + + boolean isUnion = name.toLowerCase().endsWith(AGG_UNION_SUFFIX); + boolean isMerge = name.toLowerCase().endsWith(AGG_MERGE_SUFFIX); + boolean isState = name.toLowerCase().endsWith(AGG_STATE_SUFFIX); + if (isUnion || isMerge || isState) { + if (isUnion) { + name = name.substring(0, name.length() - AGG_UNION_SUFFIX.length()); + } + if (isMerge) { + name = name.substring(0, name.length() - AGG_MERGE_SUFFIX.length()); + } + if (isState) { + name = name.substring(0, name.length() - AGG_STATE_SUFFIX.length()); + } + + List argList = Arrays.asList(getActualArgTypes(argTypes)); + List nestedArgList; + if (isState) { + nestedArgList = argList; + } else { + if (argList.size() != 1 || !argList.get(0).isAggStateType()) { + throw new AnalysisException("merge/union function must input one agg_state"); + } + ScalarType aggState = (ScalarType) argList.get(0); + if (aggState.getSubTypes() == null) { + throw new AnalysisException("agg_state's subTypes is null"); + } + nestedArgList = aggState.getSubTypes(); + } + + searchDesc = new Function(new FunctionName(name), nestedArgList, Type.INVALID, false, true); + + f = Env.getCurrentEnv().getFunction(searchDesc, mode); + if (f == null || !(f instanceof AggregateFunction)) { + return null; + } + + if (isState) { + f = new ScalarFunction(new FunctionName(name + AGG_STATE_SUFFIX), Arrays.asList(f.getArgs()), + Type.AGG_STATE, f.hasVarArgs(), f.isUserVisible()); + f.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE); + } else { + f = ((AggregateFunction) f).clone(); + f.setArgs(argList); + if (isUnion) { + f.setName(new FunctionName(name + AGG_UNION_SUFFIX)); + f.setReturnType((ScalarType) argList.get(0)); + f.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE); + } + if (isMerge) { + f.setName(new FunctionName(name + AGG_MERGE_SUFFIX)); + } + } + f.setBinaryType(TFunctionBinaryType.AGG_STATE); + } + return f; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 7e3179da6e..185f328a6e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -589,8 +589,8 @@ public class AggregateFunction extends Function { } @Override - public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { - TFunction fn = super.toThrift(realReturnType, realArgTypes); + public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) { + TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables); TAggregateFunction aggFn = new TAggregateFunction(); aggFn.setIsAnalyticOnlyFn(isAnalyticFn && !isAggregateFn); aggFn.setUpdateFnSymbol(updateFnSymbol); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java index db5c9584c3..d9173f17bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Column.java @@ -397,6 +397,14 @@ public class Column implements Writable, GsonPostProcessable { return this.aggregationType; } + public String getAggregationString() { + if (getAggregationType() == AggregateType.GENERIC_AGGREGATION) { + return getGenericAggregationString(); + } else { + return getAggregationType().name(); + } + } + public boolean isAggregated() { return aggregationType != null && aggregationType != AggregateType.NONE; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java index f4d65bc593..22b36f604c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java @@ -17,6 +17,8 @@ package org.apache.doris.catalog; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.analysis.FunctionName; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.UserException; @@ -42,6 +44,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; /** * Base class for all functions. @@ -217,6 +220,10 @@ public class Function implements Writable { return argTypes; } + public void setArgs(List argTypes) { + this.argTypes = argTypes.toArray(new Type[argTypes.size()]); + } + // Returns the number of arguments to this function. public int getNumArgs() { return argTypes.length; @@ -506,7 +513,7 @@ public class Function implements Writable { return retType instanceof AnyType; } - public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { + public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) { TFunction fn = new TFunction(); fn.setSignature(signatureString()); fn.setName(name.toThrift()); @@ -514,16 +521,24 @@ public class Function implements Writable { if (location != null) { fn.setHdfsLocation(location.getLocation()); } - // `realArgTypes.length != argTypes.length` is true iff this is an aggregation function. - // For aggregation functions, `argTypes` here is already its real type with true precision and scale. + // `realArgTypes.length != argTypes.length` is true iff this is an aggregation + // function. + // For aggregation functions, `argTypes` here is already its real type with true + // precision and scale. if (realArgTypes.length != argTypes.length) { fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes))); } else { fn.setArgTypes(Type.toThrift(Lists.newArrayList(argTypes), Lists.newArrayList(realArgTypes))); } - // For types with different precisions and scales, return type only indicates a type with default + + if (realReturnType.isAggStateType()) { + realReturnType = new ScalarType(Arrays.asList(realArgTypes), Arrays.asList(realArgTypeNullables)); + } + + // For types with different precisions and scales, return type only indicates a + // type with default // precision and scale so we need to transform it to the correct type. - if (realReturnType.typeContainsPrecision()) { + if (realReturnType.typeContainsPrecision() || realReturnType.isAggStateType()) { fn.setRetType(realReturnType.toThrift()); } else { fn.setRetType(getReturnType().toThrift()); @@ -816,4 +831,18 @@ public class Function implements Writable { result = 31 * result + Arrays.hashCode(argTypes); return result; } + + public static FunctionCallExpr convertToStateCombinator(FunctionCallExpr fnCall) { + Function aggFunction = fnCall.getFn(); + List arguments = Arrays.asList(aggFunction.getArgs()); + ScalarFunction fn = new org.apache.doris.catalog.ScalarFunction( + new FunctionName(aggFunction.getFunctionName().getFunction() + Expr.AGG_STATE_SUFFIX), + arguments, + new ScalarType(arguments, fnCall.getChildren().stream().map(expr -> { + return expr.isNullable(); + }).collect(Collectors.toList())), aggFunction.hasVarArgs(), aggFunction.isUserVisible()); + fn.setNullableMode(NullableMode.ALWAYS_NOT_NULLABLE); + fn.setBinaryType(TFunctionBinaryType.AGG_STATE); + return new FunctionCallExpr(fn, fnCall.getParams()); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java index a8caf28532..31d97e9b53 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/ScalarFunction.java @@ -229,10 +229,10 @@ public class ScalarFunction extends Function { } @Override - public TFunction toThrift(Type realReturnType, Type[] realArgTypes) { - TFunction fn = super.toThrift(realReturnType, realArgTypes); + public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] realArgTypeNullables) { + TFunction fn = super.toThrift(realReturnType, realArgTypes, realArgTypeNullables); fn.setScalarFn(new TScalarFunction()); - if (getBinaryType() != TFunctionBinaryType.BUILTIN) { + if (getBinaryType() == TFunctionBinaryType.JAVA_UDF || getBinaryType() == TFunctionBinaryType.RPC) { fn.getScalarFn().setSymbol(symbolName); } else { fn.getScalarFn().setSymbol(""); diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/proc/IndexSchemaProcNode.java b/fe/fe-core/src/main/java/org/apache/doris/common/proc/IndexSchemaProcNode.java index 6e9ae64d94..9d3ea1882f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/proc/IndexSchemaProcNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/proc/IndexSchemaProcNode.java @@ -17,7 +17,6 @@ package org.apache.doris.common.proc; -import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.FeConstants; @@ -60,11 +59,7 @@ public class IndexSchemaProcNode implements ProcNodeInterface { // Extra string (aggregation and bloom filter) List extras = Lists.newArrayList(); if (column.getAggregationType() != null) { - if (column.getAggregationType() == AggregateType.GENERIC_AGGREGATION) { - extras.add(column.getGenericAggregationString()); - } else { - extras.add(column.getAggregationType().name()); - } + extras.add(column.getAggregationString()); } if (bfColumns != null && bfColumns.contains(column.getName())) { extras.add("BLOOM_FILTER");