[feature](agg-func) support corr function #30822
This commit is contained in:
130
be/src/vec/aggregate_functions/aggregate_function_binary.h
Normal file
130
be/src/vec/aggregate_functions/aggregate_function_binary.h
Normal file
@ -0,0 +1,130 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <glog/logging.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "common/status.h"
|
||||
#include "vec/aggregate_functions/aggregate_function.h"
|
||||
#include "vec/aggregate_functions/factory_helpers.h"
|
||||
#include "vec/aggregate_functions/helpers.h"
|
||||
#include "vec/columns/column_decimal.h"
|
||||
#include "vec/columns/column_vector.h"
|
||||
#include "vec/common/arithmetic_overflow.h"
|
||||
#include "vec/common/string_buffer.hpp"
|
||||
#include "vec/core/types.h"
|
||||
#include "vec/data_types/data_type_decimal.h"
|
||||
#include "vec/data_types/data_type_nullable.h"
|
||||
#include "vec/data_types/data_type_number.h"
|
||||
#include "vec/io/io_helper.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
template <typename T1, typename T2, template <typename> typename Moments>
|
||||
struct StatFunc {
|
||||
using Type1 = T1;
|
||||
using Type2 = T2;
|
||||
using ResultType = std::conditional_t<std::is_same_v<T1, T2> && std::is_same_v<T1, Float32>,
|
||||
Float32, Float64>;
|
||||
using Data = Moments<ResultType>;
|
||||
};
|
||||
|
||||
template <typename StatFunc>
|
||||
struct AggregateFunctionBinary
|
||||
: public IAggregateFunctionDataHelper<typename StatFunc::Data,
|
||||
AggregateFunctionBinary<StatFunc>> {
|
||||
using ResultType = typename StatFunc::ResultType;
|
||||
|
||||
using ColVecT1 = ColumnVectorOrDecimal<typename StatFunc::Type1>;
|
||||
using ColVecT2 = ColumnVectorOrDecimal<typename StatFunc::Type2>;
|
||||
using ColVecResult = ColumnVector<ResultType>;
|
||||
static constexpr UInt32 num_args = 2;
|
||||
|
||||
AggregateFunctionBinary(const DataTypes& argument_types_)
|
||||
: IAggregateFunctionDataHelper<typename StatFunc::Data,
|
||||
AggregateFunctionBinary<StatFunc>>(argument_types_) {}
|
||||
|
||||
String get_name() const override { return StatFunc::Data::name(); }
|
||||
|
||||
DataTypePtr get_return_type() const override {
|
||||
return std::make_shared<DataTypeNumber<ResultType>>();
|
||||
}
|
||||
|
||||
bool allocates_memory_in_arena() const override { return false; }
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num,
|
||||
Arena*) const override {
|
||||
this->data(place).add(
|
||||
static_cast<ResultType>(
|
||||
static_cast<const ColVecT1&>(*columns[0]).get_data()[row_num]),
|
||||
static_cast<ResultType>(
|
||||
static_cast<const ColVecT2&>(*columns[1]).get_data()[row_num]));
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
|
||||
Arena*) const override {
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
|
||||
this->data(place).write(buf);
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
|
||||
Arena*) const override {
|
||||
this->data(place).read(buf);
|
||||
}
|
||||
|
||||
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
|
||||
const auto& data = this->data(place);
|
||||
auto& dst = static_cast<ColVecResult&>(to).get_data();
|
||||
dst.push_back(data.get());
|
||||
}
|
||||
};
|
||||
|
||||
template <template <typename> typename Moments, typename FirstType, typename... TArgs>
|
||||
AggregateFunctionPtr create_with_two_basic_numeric_types_second(const DataTypePtr& second_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(remove_nullable(second_type));
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return creator_without_type::create< \
|
||||
AggregateFunctionBinary<StatFunc<FirstType, TYPE, Moments>>>( \
|
||||
std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename> typename Moments, typename... TArgs>
|
||||
AggregateFunctionPtr create_with_two_basic_numeric_types(const DataTypePtr& first_type,
|
||||
const DataTypePtr& second_type,
|
||||
TArgs&&... args) {
|
||||
WhichDataType which(remove_nullable(first_type));
|
||||
#define DISPATCH(TYPE) \
|
||||
if (which.idx == TypeIndex::TYPE) \
|
||||
return create_with_two_basic_numeric_types_second<Moments, TYPE>( \
|
||||
second_type, std::forward<TArgs>(args)...);
|
||||
FOR_NUMERIC_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
92
be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Normal file
92
be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Normal file
@ -0,0 +1,92 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#include "vec/aggregate_functions/aggregate_function.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_binary.h"
|
||||
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
|
||||
#include "vec/core/types.h"
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
template <typename T>
|
||||
struct CorrMoment {
|
||||
T m0 {};
|
||||
T x1 {};
|
||||
T y1 {};
|
||||
T xy {};
|
||||
T x2 {};
|
||||
T y2 {};
|
||||
|
||||
void add(T x, T y) {
|
||||
++m0;
|
||||
x1 += x;
|
||||
y1 += y;
|
||||
xy += x * y;
|
||||
x2 += x * x;
|
||||
y2 += y * y;
|
||||
}
|
||||
|
||||
void merge(const CorrMoment& rhs) {
|
||||
m0 += rhs.m0;
|
||||
x1 += rhs.x1;
|
||||
y1 += rhs.y1;
|
||||
xy += rhs.xy;
|
||||
x2 += rhs.x2;
|
||||
y2 += rhs.y2;
|
||||
}
|
||||
|
||||
void write(BufferWritable& buf) const {
|
||||
write_binary(m0, buf);
|
||||
write_binary(x1, buf);
|
||||
write_binary(y1, buf);
|
||||
write_binary(xy, buf);
|
||||
write_binary(x2, buf);
|
||||
write_binary(y2, buf);
|
||||
}
|
||||
|
||||
void read(BufferReadable& buf) {
|
||||
read_binary(m0, buf);
|
||||
read_binary(x1, buf);
|
||||
read_binary(y1, buf);
|
||||
read_binary(xy, buf);
|
||||
read_binary(x2, buf);
|
||||
read_binary(y2, buf);
|
||||
}
|
||||
|
||||
T get() const {
|
||||
if ((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1) == 0) [[unlikely]] {
|
||||
return 0;
|
||||
}
|
||||
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
|
||||
}
|
||||
|
||||
static String name() { return "corr"; }
|
||||
};
|
||||
|
||||
/// Factory callback for "corr": validates the arguments with assert_binary,
/// then builds the function for the two argument types.
///
/// @return the created function, or nullptr when the argument types are not
///         handled by create_with_two_basic_numeric_types.
AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
                                                    const DataTypes& argument_types,
                                                    const bool result_is_nullable) {
    assert_binary(name, argument_types);
    const DataTypePtr& x_type = argument_types[0];
    const DataTypePtr& y_type = argument_types[1];
    return create_with_two_basic_numeric_types<CorrMoment>(x_type, y_type, argument_types,
                                                           result_is_nullable);
}
|
||||
|
||||
/// Registers the "corr" aggregate with `factory` through register_function_both.
void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory) {
    factory.register_function_both("corr", create_aggregate_corr_function);
}
|
||||
|
||||
} // namespace doris::vectorized
|
||||
@ -60,6 +60,7 @@ void register_aggregate_function_avg_weighted(AggregateFunctionSimpleFactory& fa
|
||||
void register_aggregate_function_histogram(AggregateFunctionSimpleFactory& factory);
|
||||
void register_aggregate_function_map_agg(AggregateFunctionSimpleFactory& factory);
|
||||
void register_aggregate_function_bitmap_agg(AggregateFunctionSimpleFactory& factory);
|
||||
void register_aggregate_functions_corr(AggregateFunctionSimpleFactory& factory);
|
||||
|
||||
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
|
||||
static std::once_flag oc;
|
||||
@ -100,6 +101,8 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
|
||||
register_aggregate_function_replace_reader_load(instance);
|
||||
register_aggregate_function_window_lead_lag_first_last(instance);
|
||||
register_aggregate_function_HLL_union_agg(instance);
|
||||
|
||||
register_aggregate_functions_corr(instance);
|
||||
});
|
||||
return instance;
|
||||
}
|
||||
|
||||
@ -0,0 +1,49 @@
|
||||
---
|
||||
{
|
||||
"title": "CORR",
|
||||
"language": "en"
|
||||
}
|
||||
---
|
||||
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
## CORR
|
||||
### Description
|
||||
#### Syntax
|
||||
|
||||
` double corr(x, y)`
|
||||
|
||||
Calculates the Pearson correlation coefficient of x and y: the covariance of x and y divided by the product of their standard deviations.
|
||||
If the standard deviation of x or y is 0, the result will be 0.
|
||||
|
||||
### example
|
||||
|
||||
```
|
||||
mysql> select corr(x,y) from baseall;
|
||||
+---------------------+
|
||||
| corr(x, y) |
|
||||
+---------------------+
|
||||
| 0.89442719099991586 |
|
||||
+---------------------+
|
||||
1 row in set (0.21 sec)
|
||||
|
||||
```
|
||||
### keywords
|
||||
CORR
|
||||
@ -0,0 +1,50 @@
|
||||
---
|
||||
{
|
||||
"title": "CORR",
|
||||
"language": "zh-CN"
|
||||
}
|
||||
---
|
||||
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
|
||||
## CORR
|
||||
### Description
|
||||
#### Syntax
|
||||
|
||||
` double corr(x, y)`
|
||||
|
||||
计算皮尔逊相关系数,即 x 和 y 的协方差除以 x 和 y 的标准差的乘积。
|
||||
如果x或y的标准差为0, 将返回0。
|
||||
|
||||
|
||||
### example
|
||||
|
||||
```
|
||||
mysql> select corr(x,y) from baseall;
|
||||
+---------------------+
|
||||
| corr(x, y) |
|
||||
+---------------------+
|
||||
| 0.89442719099991586 |
|
||||
+---------------------+
|
||||
1 row in set (0.21 sec)
|
||||
|
||||
```
|
||||
### keywords
|
||||
CORR
|
||||
@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
|
||||
@ -93,6 +94,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
|
||||
agg(BitmapUnionInt.class, "bitmap_union_int"),
|
||||
agg(CollectList.class, "collect_list", "group_array"),
|
||||
agg(CollectSet.class, "collect_set", "group_uniq_array"),
|
||||
agg(Corr.class, "corr"),
|
||||
agg(Count.class, "count"),
|
||||
agg(CountByEnum.class, "count_by_enum"),
|
||||
agg(GroupBitAnd.class, "group_bit_and"),
|
||||
|
||||
@ -1710,6 +1710,31 @@ public class FunctionSet<T> {
|
||||
"",
|
||||
false, true, false, true));
|
||||
|
||||
// corr
|
||||
addBuiltin(AggregateFunction.createBuiltin("corr",
|
||||
Lists.<Type>newArrayList(Type.TINYINT, Type.TINYINT), Type.DOUBLE, Type.DOUBLE,
|
||||
"", "", "", "", "", "", "",
|
||||
false, false, false, true));
|
||||
addBuiltin(AggregateFunction.createBuiltin("corr",
|
||||
Lists.<Type>newArrayList(Type.SMALLINT, Type.SMALLINT), Type.DOUBLE, Type.DOUBLE,
|
||||
"", "", "", "", "", "", "",
|
||||
false, false, false, true));
|
||||
addBuiltin(AggregateFunction.createBuiltin("corr",
|
||||
Lists.<Type>newArrayList(Type.INT, Type.INT), Type.DOUBLE, Type.DOUBLE,
|
||||
"", "", "", "", "", "", "",
|
||||
false, false, false, true));
|
||||
addBuiltin(AggregateFunction.createBuiltin("corr",
|
||||
Lists.<Type>newArrayList(Type.BIGINT, Type.BIGINT), Type.DOUBLE, Type.DOUBLE,
|
||||
"", "", "", "", "", "", "",
|
||||
false, false, false, true));
|
||||
addBuiltin(AggregateFunction.createBuiltin("corr",
|
||||
Lists.<Type>newArrayList(Type.FLOAT, Type.FLOAT), Type.DOUBLE, Type.DOUBLE,
|
||||
"", "", "", "", "", "", "",
|
||||
false, false, false, true));
|
||||
addBuiltin(AggregateFunction.createBuiltin("corr",
|
||||
Lists.<Type>newArrayList(Type.DOUBLE, Type.DOUBLE), Type.DOUBLE, Type.DOUBLE,
|
||||
"", "", "", "", "", "", "",
|
||||
false, false, false, true));
|
||||
}
|
||||
|
||||
public Map<String, List<Function>> getVectorizedFunctions() {
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AggregateFunction 'corr'. This class is generated by GenerateFunction.
|
||||
*/
|
||||
public class Corr extends AggregateFunction
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(TinyIntType.INSTANCE, TinyIntType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(SmallIntType.INSTANCE, SmallIntType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
|
||||
);
|
||||
|
||||
/**
|
||||
* constructor with 2 argument.
|
||||
*/
|
||||
public Corr(Expression arg1, Expression arg2) {
|
||||
super("corr", arg1, arg2);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 3 arguments.
|
||||
*/
|
||||
public Corr(boolean distinct, Expression arg1, Expression arg2) {
|
||||
super("corr", distinct, arg1, arg2);
|
||||
}
|
||||
|
||||
/**
|
||||
* withDistinctAndChildren.
|
||||
*/
|
||||
@Override
|
||||
public Corr withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 2);
|
||||
return new Corr(distinct, children.get(0), children.get(1));
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitCorr(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return SIGNATURES;
|
||||
}
|
||||
}
|
||||
@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionInt;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectList;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.CollectSet;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Corr;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
|
||||
@ -126,6 +127,10 @@ public interface AggregateFunctionVisitor<R, C> {
|
||||
return visitAggregateFunction(collectSet, context);
|
||||
}
|
||||
|
||||
default R visitCorr(Corr corr, C context) {
|
||||
return visitAggregateFunction(corr, context);
|
||||
}
|
||||
|
||||
default R visitCount(Count count, C context) {
|
||||
return visitAggregateFunction(count, context);
|
||||
}
|
||||
|
||||
@ -0,0 +1,13 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !sql --
|
||||
1.0
|
||||
|
||||
-- !sql --
|
||||
-1.0
|
||||
|
||||
-- !sql --
|
||||
0.0
|
||||
|
||||
-- !sql --
|
||||
0.8944271909999159
|
||||
|
||||
@ -0,0 +1,85 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Regression suite for corr(x, y): runs four data sets through the Nereids
// planner, truncating the table after each of the first three queries.
suite("test_corr") {
    sql """ DROP TABLE IF EXISTS test_corr """

    // Force the Nereids planner, with no fallback to the original planner.
    sql """ SET enable_nereids_planner=true """
    sql """ SET enable_fallback_to_original_planner=false """

    sql """
        CREATE TABLE test_corr (
          `id` int,
          `x` int,
          `y` int,
        ) ENGINE=OLAP
        Duplicate KEY (`id`)
        DISTRIBUTED BY HASH(`id`) BUCKETS 4
        PROPERTIES (
          "replication_allocation" = "tag.location.default: 1"
        );
        """

    // Case 1: y equals x (perfect positive correlation).
    sql """
        insert into test_corr values
        (1, 1, 1),
        (2, 2, 2),
        (3, 3, 3),
        (4, 4, 4),
        (5, 5, 5)
        """
    qt_sql "select corr(x,y) from test_corr"
    sql """ truncate table test_corr """

    // Case 2: y falls as x rises (perfect negative correlation).
    sql """
        insert into test_corr values
        (1, 1, 5),
        (2, 2, 4),
        (3, 3, 3),
        (4, 4, 2),
        (5, 5, 1)
        """
    qt_sql "select corr(x,y) from test_corr"
    sql """ truncate table test_corr """

    // Case 3: x is constant, so its standard deviation is 0 (zero correlation).
    sql """
        insert into test_corr values
        (1, 1, 1),
        (2, 1, 2),
        (3, 1, 3),
        (4, 1, 4),
        (5, 1, 5)
        """
    qt_sql "select corr(x,y) from test_corr"
    sql """ truncate table test_corr """

    // Case 4: linear except for the last row (partial linear correlation).
    sql """
        insert into test_corr values
        (1, 1, 1),
        (2, 2, 2),
        (3, 3, 3),
        (4, 4, 4),
        (5, 5, 10)
        """
    qt_sql "select corr(x,y) from test_corr"

    sql """ DROP TABLE IF EXISTS test_corr """
}
|
||||
Reference in New Issue
Block a user