// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. #include "vec/exprs/vectorized_agg_fn.h" #include "fmt/format.h" #include "fmt/ranges.h" #include "runtime/descriptors.h" #include "vec/aggregate_functions/aggregate_function_java_udaf.h" #include "vec/aggregate_functions/aggregate_function_rpc.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/aggregate_functions/aggregate_function_sort.h" #include "vec/core/materialize_block.h" #include "vec/data_types/data_type_factory.hpp" #include "vec/exprs/vexpr.h" namespace doris::vectorized { AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) : _fn(desc.fn), _is_merge(desc.agg_expr.is_merge_agg), _return_type(TypeDescriptor::from_thrift(desc.fn.ret_type)), _intermediate_slot_desc(nullptr), _output_slot_desc(nullptr), _exec_timer(nullptr), _merge_timer(nullptr), _expr_timer(nullptr) { bool nullable = true; if (desc.__isset.is_nullable) { nullable = desc.is_nullable; } _data_type = DataTypeFactory::instance().create_data_type(_return_type, nullable); if (desc.agg_expr.__isset.param_types) { auto& param_types = desc.agg_expr.param_types; for (int i = 0; i < param_types.size(); i++) { _argument_types_with_sort.push_back( DataTypeFactory::instance().create_data_type(param_types[i])); } } } Status AggFnEvaluator::create(ObjectPool* pool, const TExpr& desc, const TSortInfo& sort_info, AggFnEvaluator** result) { *result = pool->add(new AggFnEvaluator(desc.nodes[0])); auto& agg_fn_evaluator = *result; int node_idx = 0; for (int i = 0; i < desc.nodes[0].num_children; ++i) { ++node_idx; VExpr* expr = nullptr; VExprContext* ctx = nullptr; RETURN_IF_ERROR( VExpr::create_tree_from_thrift(pool, desc.nodes, nullptr, &node_idx, &expr, &ctx)); agg_fn_evaluator->_input_exprs_ctxs.push_back(ctx); } auto sort_size = sort_info.ordering_exprs.size(); auto real_arguments_size = agg_fn_evaluator->_argument_types_with_sort.size() - sort_size; // Child arguments contains [real arguments, order by arguments], we pass the arguments // to the order by functions for (int i = 0; i < sort_size; ++i) { agg_fn_evaluator->_sort_description.emplace_back(real_arguments_size + i, sort_info.is_asc_order[i] == true, sort_info.nulls_first[i] == true); } // Pass the real arguments to get functions for (int i = 0; i < real_arguments_size; ++i) { agg_fn_evaluator->_real_argument_types.emplace_back( agg_fn_evaluator->_argument_types_with_sort[i]); } return Status::OK(); } Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, const SlotDescriptor* intermediate_slot_desc, const SlotDescriptor* output_slot_desc) { DCHECK(intermediate_slot_desc != nullptr); DCHECK(_intermediate_slot_desc == nullptr); _output_slot_desc = output_slot_desc; _intermediate_slot_desc = intermediate_slot_desc; Status status = VExpr::prepare(_input_exprs_ctxs, state, desc); RETURN_IF_ERROR(status); DataTypes tmp_argument_types; tmp_argument_types.reserve(_input_exprs_ctxs.size()); std::vector child_expr_name; // prepare for argument for (int i = 0; i < _input_exprs_ctxs.size(); ++i) { auto data_type = _input_exprs_ctxs[i]->root()->data_type(); tmp_argument_types.emplace_back(data_type); child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name()); } const DataTypes& argument_types = _real_argument_types.empty() ? tmp_argument_types : _real_argument_types; if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { if (config::enable_java_support) { _function = AggregateJavaUdaf::create(_fn, argument_types, {}, _data_type); RETURN_IF_ERROR(static_cast(_function.get())->check_udaf(_fn)); } else { return Status::InternalError( "Java UDAF is not enabled, you can change be config enable_java_support to " "true and restart be."); } } else if (_fn.binary_type == TFunctionBinaryType::RPC) { _function = AggregateRpcUdaf::create(_fn, argument_types, {}, _data_type); } else { _function = AggregateFunctionSimpleFactory::instance().get( _fn.name.function_name, argument_types, {}, _data_type->is_nullable()); } if (_function == nullptr) { return Status::InternalError("Agg Function {} is not implemented", _fn.signature); } if (!_sort_description.empty()) { _function = transform_to_sort_agg_function(_function, _argument_types_with_sort, _sort_description, state); } _expr_name = fmt::format("{}({})", _fn.name.function_name, child_expr_name); return Status::OK(); } Status AggFnEvaluator::open(RuntimeState* state) { return VExpr::open(_input_exprs_ctxs, state); } void AggFnEvaluator::close(RuntimeState* state) { VExpr::close(_input_exprs_ctxs, state); } void AggFnEvaluator::create(AggregateDataPtr place) { _function->create(place); } void AggFnEvaluator::destroy(AggregateDataPtr place) { _function->destroy(place); } Status AggFnEvaluator::execute_single_add(Block* block, AggregateDataPtr place, Arena* arena) { RETURN_IF_ERROR(_calc_argment_columns(block)); SCOPED_TIMER(_exec_timer); _function->add_batch_single_place(block->rows(), place, _agg_columns.data(), arena); return Status::OK(); } Status AggFnEvaluator::execute_batch_add(Block* block, size_t offset, AggregateDataPtr* places, Arena* arena, bool agg_many) { RETURN_IF_ERROR(_calc_argment_columns(block)); SCOPED_TIMER(_exec_timer); _function->add_batch(block->rows(), places, offset, _agg_columns.data(), arena, agg_many); return Status::OK(); } Status AggFnEvaluator::execute_batch_add_selected(Block* block, size_t offset, AggregateDataPtr* places, Arena* arena) { RETURN_IF_ERROR(_calc_argment_columns(block)); SCOPED_TIMER(_exec_timer); _function->add_batch_selected(block->rows(), places, offset, _agg_columns.data(), arena); return Status::OK(); } Status AggFnEvaluator::streaming_agg_serialize(Block* block, BufferWritable& buf, const size_t num_rows, Arena* arena) { RETURN_IF_ERROR(_calc_argment_columns(block)); SCOPED_TIMER(_exec_timer); _function->streaming_agg_serialize(_agg_columns.data(), buf, num_rows, arena); return Status::OK(); } Status AggFnEvaluator::streaming_agg_serialize_to_column(Block* block, MutableColumnPtr& dst, const size_t num_rows, Arena* arena) { RETURN_IF_ERROR(_calc_argment_columns(block)); SCOPED_TIMER(_exec_timer); _function->streaming_agg_serialize_to_column(_agg_columns.data(), dst, num_rows, arena); return Status::OK(); } void AggFnEvaluator::insert_result_info(AggregateDataPtr place, IColumn* column) { _function->insert_result_into(place, *column); } void AggFnEvaluator::insert_result_info_vec(const std::vector& places, size_t offset, IColumn* column, const size_t num_rows) { _function->insert_result_into_vec(places, offset, *column, num_rows); } void AggFnEvaluator::reset(AggregateDataPtr place) { _function->reset(place); } std::string AggFnEvaluator::debug_string(const std::vector& exprs) { std::stringstream out; out << "["; for (int i = 0; i < exprs.size(); ++i) { out << (i == 0 ? "" : " ") << exprs[i]->debug_string(); } out << "]"; return out.str(); } std::string AggFnEvaluator::debug_string() const { std::stringstream out; out << "AggFnEvaluator("; out << ")"; return out.str(); } Status AggFnEvaluator::_calc_argment_columns(Block* block) { SCOPED_TIMER(_expr_timer); _agg_columns.resize(_input_exprs_ctxs.size()); int column_ids[_input_exprs_ctxs.size()]; for (int i = 0; i < _input_exprs_ctxs.size(); ++i) { int column_id = -1; RETURN_IF_ERROR(_input_exprs_ctxs[i]->execute(block, &column_id)); column_ids[i] = column_id; } materialize_block_inplace(*block, column_ids, column_ids + _input_exprs_ctxs.size()); for (int i = 0; i < _input_exprs_ctxs.size(); ++i) { _agg_columns[i] = block->get_by_position(column_ids[i]).column.get(); } return Status::OK(); } } // namespace doris::vectorized