[feature](agg-func) support corr function #30822

This commit is contained in:
nanfeng
2024-02-07 08:32:06 +08:00
committed by yiguolei
parent 4052746f1c
commit 2bb477bae7
11 changed files with 539 additions and 0 deletions

View File

@ -0,0 +1,130 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <glog/logging.h>
#include <cmath>
#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/factory_helpers.h"
#include "vec/aggregate_functions/helpers.h"
#include "vec/columns/column_decimal.h"
#include "vec/columns/column_vector.h"
#include "vec/common/arithmetic_overflow.h"
#include "vec/common/string_buffer.hpp"
#include "vec/core/types.h"
#include "vec/data_types/data_type_decimal.h"
#include "vec/data_types/data_type_nullable.h"
#include "vec/data_types/data_type_number.h"
#include "vec/io/io_helper.h"
namespace doris::vectorized {
template <typename T1, typename T2, template <typename> typename Moments>
struct StatFunc {
using Type1 = T1;
using Type2 = T2;
using ResultType = std::conditional_t<std::is_same_v<T1, T2> && std::is_same_v<T1, Float32>,
Float32, Float64>;
using Data = Moments<ResultType>;
};
template <typename StatFunc>
struct AggregateFunctionBinary
: public IAggregateFunctionDataHelper<typename StatFunc::Data,
AggregateFunctionBinary<StatFunc>> {
using ResultType = typename StatFunc::ResultType;
using ColVecT1 = ColumnVectorOrDecimal<typename StatFunc::Type1>;
using ColVecT2 = ColumnVectorOrDecimal<typename StatFunc::Type2>;
using ColVecResult = ColumnVector<ResultType>;
static constexpr UInt32 num_args = 2;
AggregateFunctionBinary(const DataTypes& argument_types_)
: IAggregateFunctionDataHelper<typename StatFunc::Data,
AggregateFunctionBinary<StatFunc>>(argument_types_) {}
String get_name() const override { return StatFunc::Data::name(); }
DataTypePtr get_return_type() const override {
return std::make_shared<DataTypeNumber<ResultType>>();
}
bool allocates_memory_in_arena() const override { return false; }
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
Arena*) const override {
this->data(place).add(
static_cast<ResultType>(
static_cast<const ColVecT1&>(*columns[0]).get_data()[row_num]),
static_cast<ResultType>(
static_cast<const ColVecT2&>(*columns[1]).get_data()[row_num]));
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
Arena*) const override {
this->data(place).merge(this->data(rhs));
}
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
this->data(place).write(buf);
}
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
Arena*) const override {
this->data(place).read(buf);
}
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
const auto& data = this->data(place);
auto& dst = static_cast<ColVecResult&>(to).get_data();
dst.push_back(data.get());
}
};
template <template <typename> typename Moments, typename FirstType, typename... TArgs>
AggregateFunctionPtr create_with_two_basic_numeric_types_second(const DataTypePtr& second_type,
TArgs&&... args) {
WhichDataType which(remove_nullable(second_type));
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return creator_without_type::create< \
AggregateFunctionBinary<StatFunc<FirstType, TYPE, Moments>>>( \
std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
template <template <typename> typename Moments, typename... TArgs>
AggregateFunctionPtr create_with_two_basic_numeric_types(const DataTypePtr& first_type,
const DataTypePtr& second_type,
TArgs&&... args) {
WhichDataType which(remove_nullable(first_type));
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return create_with_two_basic_numeric_types_second<Moments, TYPE>( \
second_type, std::forward<TArgs>(args)...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
return nullptr;
}
} // namespace doris::vectorized

View File

@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/aggregate_functions/aggregate_function_binary.h"
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
#include "vec/core/types.h"
namespace doris::vectorized {
template <typename T>
struct CorrMoment {
T m0 {};
T x1 {};
T y1 {};
T xy {};
T x2 {};
T y2 {};
void add(T x, T y) {
++m0;
x1 += x;
y1 += y;
xy += x * y;
x2 += x * x;
y2 += y * y;
}
void merge(const CorrMoment& rhs) {
m0 += rhs.m0;
x1 += rhs.x1;
y1 += rhs.y1;
xy += rhs.xy;
x2 += rhs.x2;
y2 += rhs.y2;
}
void write(BufferWritable& buf) const {
write_binary(m0, buf);
write_binary(x1, buf);
write_binary(y1, buf);
write_binary(xy, buf);
write_binary(x2, buf);
write_binary(y2, buf);
}
void read(BufferReadable& buf) {
read_binary(m0, buf);
read_binary(x1, buf);
read_binary(y1, buf);
read_binary(xy, buf);
read_binary(x2, buf);
read_binary(y2, buf);
}
T get() const {
if ((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1) == 0) [[unlikely]] {
return 0;
}
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
}
static String name() { return "corr"; }
};
AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
const DataTypes& argument_types,
const bool result_is_nullable) {
assert_binary(name, argument_types);
return create_with_two_basic_numeric_types<CorrMoment>(argument_types[0], argument_types[1],
argument_types, result_is_nullable);
}
void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("corr", create_aggregate_corr_function);
}
} // namespace doris::vectorized

View File

@ -60,6 +60,7 @@ void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& fa
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory);
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
@ -100,6 +101,8 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_replace_reader_load(instance);
register_aggregate_function_window_lead_lag_first_last(instance);
register_aggregate_function_HLL_union_agg(instance);
register_aggregate_functions_corr(instance);
});
return instance;
}

View File

@ -0,0 +1,49 @@
---
{
"title": "CORR",
"language": "en"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
## CORR
### Description
#### Syntax
` double corr(x, y)`
Calculate the Pearson correlation coefficient, which is returned as the covariance of x and y divided by the product of the standard deviations of x and y.
If the standard deviation of x or y is 0, the result will be 0.
### example
```
mysql> select corr(x,y) from baseall;
+---------------------+
| corr(x, y) |
+---------------------+
| 0.89442719099991586 |
+---------------------+
1 row in set (0.21 sec)
```
### keywords
CORR

View File

@ -0,0 +1,50 @@
---
{
"title": "CORR",
"language": "zh-CN"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
## CORR
### Description
#### Syntax
` double corr(x, y)`
计算皮尔逊系数, 即返回结果为: x和y的协方差,除x和y的标准差乘积。
如果x或y的标准差为0, 将返回0。
### example
```
mysql> select corr(x,y) from baseall;
+---------------------+
| corr(x, y) |
+---------------------+
| 0.89442719099991586 |
+---------------------+
1 row in set (0.21 sec)
```
### keywords
CORR

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@ -93,6 +94,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(BitmapUnionInt.class, "bitmap_union_int"),
agg(CollectList.class, "collect_list", "group_array"),
agg(CollectSet.class, "collect_set", "group_uniq_array"),
agg(Corr.class, "corr"),
agg(Count.class, "count"),
agg(CountByEnum.class, "count_by_enum"),
agg(GroupBitAnd.class, "group_bit_and"),

View File

@ -1710,6 +1710,31 @@ public class FunctionSet<T> {
"",
false, true, false, true));
// corr
addBuiltin(AggregateFunction.createBuiltin("corr",
Lists.<Type>newArrayList(Type.TINYINT, Type.TINYINT), Type.DOUBLE, Type.DOUBLE,
"", "", "", "", "", "", "",
false, false, false, true));
addBuiltin(AggregateFunction.createBuiltin("corr",
Lists.<Type>newArrayList(Type.SMALLINT, Type.SMALLINT), Type.DOUBLE, Type.DOUBLE,
"", "", "", "", "", "", "",
false, false, false, true));
addBuiltin(AggregateFunction.createBuiltin("corr",
Lists.<Type>newArrayList(Type.INT, Type.INT), Type.DOUBLE, Type.DOUBLE,
"", "", "", "", "", "", "",
false, false, false, true));
addBuiltin(AggregateFunction.createBuiltin("corr",
Lists.<Type>newArrayList(Type.BIGINT, Type.BIGINT), Type.DOUBLE, Type.DOUBLE,
"", "", "", "", "", "", "",
false, false, false, true));
addBuiltin(AggregateFunction.createBuiltin("corr",
Lists.<Type>newArrayList(Type.FLOAT, Type.FLOAT), Type.DOUBLE, Type.DOUBLE,
"", "", "", "", "", "", "",
false, false, false, true));
addBuiltin(AggregateFunction.createBuiltin("corr",
Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE,
"", "", "", "", "", "", "",
false, false, false, true));
}
public Map<String, List<Function>> getVectorizedFunctions() {

View File

@ -0,0 +1,85 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.TinyIntType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* AggregateFunction 'corr'. This class is generated by GenerateFunction.
*/
public class Corr extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
);
/**
* constructor with 2 argument.
*/
public Corr(Expression arg1, Expression arg2) {
super("corr", arg1, arg2);
}
/**
* constructor with 3 arguments.
*/
public Corr(boolean distinct, Expression arg1, Expression arg2) {
super("corr", distinct, arg1, arg2);
}
/**
* withDistinctAndChildren.
*/
@Override
public Corr withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new Corr(distinct, children.get(0), children.get(1));
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitCorr(this, context);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@ -126,6 +127,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(collectSet, context);
}
default R visitCorr(Corr corr, C context) {
return visitAggregateFunction(corr, context);
}
default R visitCount(Count count, C context) {
return visitAggregateFunction(count, context);
}

View File

@ -0,0 +1,13 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !sql --
1.0
-- !sql --
-1.0
-- !sql --
0.0
-- !sql --
0.8944271909999159

View File

@ -0,0 +1,85 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("test_corr") {
sql """ DROP TABLE IF EXISTS test_corr """
sql """ SET enable_nereids_planner=true """
sql """ SET enable_fallback_to_original_planner=false """
sql """
CREATE TABLE test_corr (
`id` int,
`x` int,
`y` int,
) ENGINE=OLAP
Duplicate KEY (`id`)
DISTRIBUTED BY HASH(`id`) BUCKETS 4
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""
// Perfect positive correlation
sql """
insert into test_corr values
(1, 1, 1),
(2, 2, 2),
(3, 3, 3),
(4, 4, 4),
(5, 5, 5)
"""
qt_sql "select corr(x,y) from test_corr"
sql """ truncate table test_corr """
// Perfect negative correlation
sql """
insert into test_corr values
(1, 1, 5),
(2, 2, 4),
(3, 3, 3),
(4, 4, 2),
(5, 5, 1)
"""
qt_sql "select corr(x,y) from test_corr"
sql """ truncate table test_corr """
// Zero correlation
sql """
insert into test_corr values
(1, 1, 1),
(2, 1, 2),
(3, 1, 3),
(4, 1, 4),
(5, 1, 5)
"""
qt_sql "select corr(x,y) from test_corr"
sql """ truncate table test_corr """
// Partial linear correlation
sql """
insert into test_corr values
(1, 1, 1),
(2, 2, 2),
(3, 3, 3),
(4, 4, 4),
(5, 5, 10)
"""
qt_sql "select corr(x,y) from test_corr"
sql """ DROP TABLE IF EXISTS test_corr """
}