From 2bb477bae7abefda454c53039af27df877a5978c Mon Sep 17 00:00:00 2001 From: nanfeng Date: Wed, 7 Feb 2024 08:32:06 +0800 Subject: [PATCH] [feature](agg-func) support corr function #30822 --- .../aggregate_function_binary.h | 130 ++++++++++++++++++ .../aggregate_function_corr.cpp | 92 +++++++++++++ .../aggregate_function_simple_factory.cpp | 3 + .../sql-functions/aggregate-functions/corr.md | 49 +++++++ .../sql-functions/aggregate-functions/corr.md | 50 +++++++ .../catalog/BuiltinAggregateFunctions.java | 2 + .../org/apache/doris/catalog/FunctionSet.java | 25 ++++ .../trees/expressions/functions/agg/Corr.java | 85 ++++++++++++ .../visitor/AggregateFunctionVisitor.java | 5 + .../agg_function/test_corr.out | 13 ++ .../agg_function/test_corr.groovy | 85 ++++++++++++ 11 files changed, 539 insertions(+) create mode 100644 be/src/vec/aggregate_functions/aggregate_function_binary.h create mode 100644 be/src/vec/aggregate_functions/aggregate_function_corr.cpp create mode 100644 docs/en/docs/sql-manual/sql-functions/aggregate-functions/corr.md create mode 100644 docs/zh-CN/docs/sql-manual/sql-functions/aggregate-functions/corr.md create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Corr.java create mode 100644 regression-test/data/nereids_function_p0/agg_function/test_corr.out create mode 100644 regression-test/suites/nereids_function_p0/agg_function/test_corr.groovy diff --git a/be/src/vec/aggregate_functions/aggregate_function_binary.h b/be/src/vec/aggregate_functions/aggregate_function_binary.h new file mode 100644 index 0000000000..422919c52a --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_binary.h @@ -0,0 +1,130 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include + +#include "common/status.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" +#include "vec/columns/column_decimal.h" +#include "vec/columns/column_vector.h" +#include "vec/common/arithmetic_overflow.h" +#include "vec/common/string_buffer.hpp" +#include "vec/core/types.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/data_types/data_type_nullable.h" +#include "vec/data_types/data_type_number.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +template typename Moments> +struct StatFunc { + using Type1 = T1; + using Type2 = T2; + using ResultType = std::conditional_t && std::is_same_v, + Float32, Float64>; + using Data = Moments; +}; + +template +struct AggregateFunctionBinary + : public IAggregateFunctionDataHelper> { + using ResultType = typename StatFunc::ResultType; + + using ColVecT1 = ColumnVectorOrDecimal; + using ColVecT2 = ColumnVectorOrDecimal; + using ColVecResult = ColumnVector; + static constexpr UInt32 num_args = 2; + + AggregateFunctionBinary(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper>(argument_types_) {} + + String get_name() const override { return StatFunc::Data::name(); } + + DataTypePtr get_return_type() const override { + return std::make_shared>(); + } + + bool allocates_memory_in_arena() const override { return false; } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena*) const override { + this->data(place).add( + static_cast( + static_cast(*columns[0]).get_data()[row_num]), + static_cast( + static_cast(*columns[1]).get_data()[row_num])); + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + const auto& data = this->data(place); + auto& dst = static_cast(to).get_data(); + dst.push_back(data.get()); + } +}; + +template