// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. // This file is copied from // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionAvg.h // and modified by Doris #pragma once #include "common/status.h" #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/columns_number.h" #include "vec/data_types/data_type_decimal.h" #include "vec/data_types/data_type_number.h" #include "vec/io/io_helper.h" namespace doris::vectorized { template struct AggregateFunctionAvgData { T sum = 0; UInt64 count = 0; template ResultT result() const { if constexpr (std::is_floating_point_v) if constexpr (std::numeric_limits::is_iec559) return static_cast(sum) / count; /// allow division by zero if (!count) throw Exception("AggregateFunctionAvg with zero values", TStatusCode::VEC_LOGIC_ERROR); return static_cast(sum) / count; } void write(BufferWritable& buf) const { write_binary(sum, buf); write_binary(count, buf); } void read(BufferReadable& buf) { read_binary(sum, buf); read_binary(count, buf); } }; /// Calculates arithmetic mean of numbers. template class AggregateFunctionAvg final : public IAggregateFunctionDataHelper> { public: using ResultType = std::conditional_t, Decimal128, Float64>; using ResultDataType = std::conditional_t, DataTypeDecimal, DataTypeNumber>; using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>; /// ctor for native types AggregateFunctionAvg(const DataTypes& argument_types_) : IAggregateFunctionDataHelper>(argument_types_, {}), scale(0) {} /// ctor for Decimals AggregateFunctionAvg(const IDataType& data_type, const DataTypes& argument_types_) : IAggregateFunctionDataHelper>(argument_types_, {}), scale(get_decimal_scale(data_type)) {} String get_name() const override { return "avg"; } DataTypePtr get_return_type() const override { if constexpr (IsDecimalNumber) return std::make_shared(ResultDataType::max_precision(), scale); else return std::make_shared(); } void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { const auto& column = static_cast(*columns[0]); this->data(place).sum += column.get_data()[row_num]; ++this->data(place).count; } void reset(AggregateDataPtr place) const override { this->data(place).sum = 0; this->data(place).count = 0; } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena*) const override { this->data(place).sum += this->data(rhs).sum; this->data(place).count += this->data(rhs).count; } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { this->data(place).write(buf); } void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena*) const override { this->data(place).read(buf); } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { auto& column = static_cast(to); column.get_data().push_back(this->data(place).template result()); } const char* get_header_file_path() const override { return __FILE__; } private: UInt32 scale; }; } // namespace doris::vectorized