// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. // This file is copied from // https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/array/arrayAggregation.cpp // and modified by Doris #include #include "vec/aggregate_functions/aggregate_function_avg.h" #include "vec/aggregate_functions/aggregate_function_min_max.h" #include "vec/aggregate_functions/aggregate_function_null.h" #include "vec/aggregate_functions/aggregate_function_product.h" #include "vec/aggregate_functions/aggregate_function_sum.h" #include "vec/aggregate_functions/helpers.h" #include "vec/columns/column_nullable.h" #include "vec/common/arena.h" #include "vec/core/types.h" #include "vec/data_types/data_type.h" #include "vec/data_types/data_type_nullable.h" #include "vec/functions/array/function_array_join.h" #include "vec/functions/array/function_array_mapped.h" #include "vec/functions/simple_function_factory.h" namespace doris { namespace vectorized { enum class AggregateOperation { MIN, MAX, SUM, AVERAGE, PRODUCT }; template struct ArrayAggregateResultImpl; template struct ArrayAggregateResultImpl { using Result = Element; }; template struct ArrayAggregateResultImpl { using Result = Element; }; template struct ArrayAggregateResultImpl { using Result = std::conditional_t, Decimal128, Float64>; }; template struct ArrayAggregateResultImpl { using Result = std::conditional_t, Decimal128, Float64>; }; template struct ArrayAggregateResultImpl { using Result = std::conditional_t< IsDecimalNumber, Decimal128, std::conditional_t, Float64, std::conditional_t, Int128, Int64>>>; }; template using ArrayAggregateResult = typename ArrayAggregateResultImpl::Result; // For MIN/MAX, the type of result is the same as the type of elements, we can omit the // template specialization. template struct AggregateFunctionImpl; template <> struct AggregateFunctionImpl { template struct TypeTraits { using ResultType = ArrayAggregateResult; using AggregateDataType = AggregateFunctionSumData; using Function = AggregateFunctionSum; }; }; template <> struct AggregateFunctionImpl { template struct TypeTraits { using ResultType = ArrayAggregateResult; using AggregateDataType = AggregateFunctionAvgData; using Function = AggregateFunctionAvg; static_assert(std::is_same_v, "ResultType doesn't match."); }; }; template <> struct AggregateFunctionImpl { template struct TypeTraits { using ResultType = ArrayAggregateResult; using AggregateDataType = AggregateFunctionProductData; using Function = AggregateFunctionProduct; }; }; template struct AggregateFunction { template using Function = typename Derived::template TypeTraits::Function; static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr { DataTypes data_types = {remove_nullable(data_type_ptr)}; auto& data_type = *data_types.front(); AggregateFunctionPtr nested_function; if (is_decimal(data_types.front())) { nested_function = AggregateFunctionPtr( create_with_decimal_type(data_type, data_type, data_types)); } else { nested_function = AggregateFunctionPtr(create_with_numeric_type(data_type, data_types)); } AggregateFunctionPtr function; function.reset(new AggregateFunctionNullUnary(nested_function, {make_nullable(data_type_ptr)}, {})); return function; } }; template struct ArrayAggregateImpl { using column_type = ColumnArray; using data_type = DataTypeArray; static bool _is_variadic() { return false; } static size_t _get_number_of_arguments() { return 1; } static DataTypePtr get_return_type(const DataTypes& arguments) { using Function = AggregateFunction>; const DataTypeArray* data_type_array = static_cast(remove_nullable(arguments[0]).get()); auto function = Function::create(data_type_array->get_nested_type()); return function->get_return_type(); } static Status execute(Block& block, const ColumnNumbers& arguments, size_t result, const DataTypeArray* data_type_array, const ColumnArray& array) { ColumnPtr res; DataTypePtr type = data_type_array->get_nested_type(); const IColumn* data = array.get_data_ptr().get(); const auto& offsets = array.get_offsets(); if (execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets) || execute_type(res, type, data, offsets)) { block.replace_by_position(result, std::move(res)); return Status::OK(); } else { return Status::RuntimeError("Unexpected column for aggregation: {}", data->get_name()); } } template static bool execute_type(ColumnPtr& res_ptr, const DataTypePtr& type, const IColumn* data, const ColumnArray::Offsets& offsets) { using ColVecType = ColumnVectorOrDecimal; using ResultType = ArrayAggregateResult; using ColVecResultType = ColumnVectorOrDecimal; using Function = AggregateFunction>; const ColVecType* column = data->is_nullable() ? check_and_get_column( static_cast(data)->get_nested_column()) : check_and_get_column(&*data); if (!column) { return false; } ColumnPtr res_column; if constexpr (IsDecimalNumber) { res_column = ColVecResultType::create(0, column->get_scale()); } else { res_column = ColVecResultType::create(); } res_column = make_nullable(res_column); static_cast(res_column->assume_mutable_ref()).reserve(offsets.size()); auto function = Function::create(type); auto guard = AggregateFunctionGuard(function.get()); Arena arena; auto nullable_column = make_nullable(data->get_ptr()); const IColumn* columns[] = {nullable_column.get()}; for (int64_t i = 0; i < offsets.size(); ++i) { auto start = offsets[i - 1]; // -1 is ok. auto end = offsets[i]; bool is_empty = (start == end); if (is_empty) { res_column->assume_mutable()->insert_default(); continue; } function->reset(guard.data()); function->add_batch_range(start, end - 1, guard.data(), columns, &arena, data->is_nullable()); function->insert_result_into(guard.data(), res_column->assume_mutable_ref()); } res_ptr = std::move(res_column); return true; }; }; struct NameArrayMin { static constexpr auto name = "array_min"; }; template <> struct AggregateFunction> { static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr { DataTypes data_types = {remove_nullable(data_type_ptr)}; auto nested_function = AggregateFunctionPtr( create_aggregate_function_min(NameArrayMin::name, data_types, {}, false)); AggregateFunctionPtr function; function.reset(new AggregateFunctionNullUnary(nested_function, {make_nullable(data_type_ptr)}, {})); return function; } }; struct NameArrayMax { static constexpr auto name = "array_max"; }; template <> struct AggregateFunction> { static auto create(const DataTypePtr& data_type_ptr) -> AggregateFunctionPtr { DataTypes data_types = {remove_nullable(data_type_ptr)}; auto nested_function = AggregateFunctionPtr( create_aggregate_function_max(NameArrayMax::name, data_types, {}, false)); AggregateFunctionPtr function; function.reset(new AggregateFunctionNullUnary(nested_function, {make_nullable(data_type_ptr)}, {})); return function; } }; struct NameArraySum { static constexpr auto name = "array_sum"; }; struct NameArrayAverage { static constexpr auto name = "array_avg"; }; struct NameArrayProduct { static constexpr auto name = "array_product"; }; using FunctionArrayMin = FunctionArrayMapped, NameArrayMin>; using FunctionArrayMax = FunctionArrayMapped, NameArrayMax>; using FunctionArraySum = FunctionArrayMapped, NameArraySum>; using FunctionArrayAverage = FunctionArrayMapped, NameArrayAverage>; using FunctionArrayProduct = FunctionArrayMapped, NameArrayProduct>; using FunctionArrayJoin = FunctionArrayMapped; void register_function_array_aggregation(SimpleFunctionFactory& factory) { factory.register_function(); factory.register_function(); factory.register_function(); factory.register_function(); factory.register_function(); factory.register_function(); } } // namespace vectorized } // namespace doris