// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. // This file is copied from // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionSum.h // and modified by Doris #pragma once #include "vec/aggregate_functions/aggregate_function.h" #include "vec/columns/column_vector.h" #include "vec/data_types/data_type_decimal.h" #include "vec/data_types/data_type_number.h" #include "vec/io/io_helper.h" namespace doris::vectorized { template struct AggregateFunctionSumData { T sum {}; void add(T value) { sum += value; } void merge(const AggregateFunctionSumData& rhs) { sum += rhs.sum; } void write(BufferWritable& buf) const { write_binary(sum, buf); } void read(BufferReadable& buf) { read_binary(sum, buf); } T get() const { return sum; } }; /// Counts the sum of the numbers. template class AggregateFunctionSum final : public IAggregateFunctionDataHelper> { public: using ResultDataType = std::conditional_t, DataTypeDecimal, DataTypeNumber>; using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>; String get_name() const override { return "sum"; } AggregateFunctionSum(const DataTypes& argument_types_) : IAggregateFunctionDataHelper>( argument_types_), scale(0) {} AggregateFunctionSum(const IDataType& data_type, const DataTypes& argument_types_) : IAggregateFunctionDataHelper>( argument_types_), scale(get_decimal_scale(data_type)) {} DataTypePtr get_return_type() const override { if constexpr (IsDecimalNumber) { return std::make_shared(ResultDataType::max_precision(), scale); } else { return std::make_shared(); } } void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { const auto& column = static_cast(*columns[0]); this->data(place).add(column.get_data()[row_num]); } void reset(AggregateDataPtr place) const override { this->data(place).sum = {}; } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena*) const override { this->data(place).merge(this->data(rhs)); } void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { this->data(place).write(buf); } void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, Arena*) const override { this->data(place).read(buf); } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { auto& column = static_cast(to); column.get_data().push_back(this->data(place).get()); } void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena, size_t num_rows) const override { auto data = assert_cast(column).get_data().data(); auto dst_data = reinterpret_cast(places); for (size_t i = 0; i != num_rows; ++i) { dst_data[i].sum = data[i]; } } void serialize_to_column(const std::vector& places, size_t offset, MutableColumnPtr& dst, const size_t num_rows) const override { auto& col = assert_cast(*dst); col.resize(num_rows); auto* data = col.get_data().data(); for (size_t i = 0; i != num_rows; ++i) { data[i] = this->data(places[i] + offset).sum; } } void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, const size_t num_rows, Arena* arena) const override { auto& col = assert_cast(*dst); auto& src = assert_cast(*columns[0]); col.resize(num_rows); auto* src_data = src.get_data().data(); auto* dst_data = col.get_data().data(); for (size_t i = 0; i != num_rows; ++i) { dst_data[i] = src_data[i]; } } void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, Arena* arena) const override { auto data = assert_cast(column).get_data().data(); const size_t num_rows = column.size(); for (size_t i = 0; i != num_rows; ++i) { this->data(place).sum += data[i]; } } void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, MutableColumnPtr& dst) const override { auto& col = assert_cast(*dst); col.resize(1); reinterpret_cast(col.get_data().data())->sum = this->data(place).sum; } MutableColumnPtr create_serialize_column() const override { return get_return_type()->create_column(); } DataTypePtr get_serialized_type() const override { return get_return_type(); } private: UInt32 scale; }; AggregateFunctionPtr create_aggregate_function_sum_reader(const std::string& name, const DataTypes& argument_types, const bool result_is_nullable); } // namespace doris::vectorized