diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h b/be/src/vec/aggregate_functions/aggregate_function_binary.h new file mode 100644 index 0000000000..422919c52a --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include + +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_vector.h" +#include "vec/common/arithmetic_overflow.h" +#include "vec/common/string_buffer.hpp" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +template typename Moments> +struct StatFunc { + using Type1 = T1; + using Type2 = T2; + using ResultType = std::conditional_t && std::is_same_v, + Float32, Float64>; + using Data = Moments; +}; + +template +struct AggregateFunctionBinary + : public IAggregateFunctionDataHelper> { + using ResultType = typename StatFunc::ResultType; + + using ColVecT1 = ColumnVectorOrDecimal; + using ColVecT2 = ColumnVectorOrDecimal; + using ColVecResult = ColumnVector; + static constexpr UInt32 num_args = 2; + + AggregateFunctionBinary(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper>(argument_types_) {} + + String get_name() const override { return StatFunc::Data::name(); } + + DataTypePtr get_return_type() const override { + return std::make_shared>(); + } + + bool allocates_memory_in_arena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena*) const override { + this->data(place).add( + static_cast( + static_cast(*columns[0]).get_data()[row_num]), + static_cast( + static_cast(*columns[1]).get_data()[row_num])); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + const auto& data = this->data(place); + auto& dst = static_cast(to).get_data(); + dst.push_back(data.get()); + } +}; + +template