[Fix-2.1](function) fix function covar core for not null input (#39943)

## Proposed changes

Issue Number: close #xxx

add testcases like:
```groovy
    qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from test_covar_samp"
    qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
    qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
```

before they will all coredump in 2.1
This commit is contained in:
zclllhhjj
2024-08-27 08:39:47 +08:00
committed by GitHub
parent 21bd4a4ac8
commit db0724dfe0
5 changed files with 77 additions and 36 deletions

View File

@ -76,13 +76,15 @@ AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string&
name, argument_types, result_is_nullable, NOTNULLABLE);
}
// register covar_pop for nullable/non_nullable both.
void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory) {
factory.register_function_both("covar", create_aggregate_function_covariance_pop);
factory.register_alias("covar", "covar_pop");
}
void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory) {
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NOTNULLABLE>);
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NOTNULLABLE>,
NOTNULLABLE);
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NULLABLE>,
NULLABLE);
}

View File

@ -17,17 +17,16 @@
#pragma once
#include "common/exception.h"
#include "common/status.h"
#define POP true
#define NOTPOP false
#define NULLABLE true
#define NOTNULLABLE false
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <boost/iterator/iterator_facade.hpp>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <type_traits>
@ -43,8 +42,8 @@
#include "vec/data_types/data_type_number.h"
#include "vec/io/io_helper.h"
namespace doris {
namespace vectorized {
namespace doris::vectorized {
class Arena;
class BufferReadable;
class BufferWritable;
@ -52,10 +51,6 @@ template <typename T>
class ColumnDecimal;
template <typename>
class ColumnVector;
} // namespace vectorized
} // namespace doris
namespace doris::vectorized {
template <typename T>
struct BaseData {
@ -228,17 +223,30 @@ struct SampData : Data {
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128V2>,
ColumnVector<Float64>>;
void insert_result_into(IColumn& to) const {
ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
if (this->count == 1 || this->count == 0) {
nullable_column.insert_default();
} else {
auto& col = assert_cast<ColVecResult&>(nullable_column.get_nested_column());
if constexpr (IsDecimalNumber<T>) {
col.get_data().push_back(this->get_samp_result().value());
if (to.is_nullable()) {
auto& nullable_column = assert_cast<ColumnNullable&>(to);
if (this->count == 1 || this->count == 0) {
nullable_column.insert_default();
} else {
col.get_data().push_back(this->get_samp_result());
auto& col = assert_cast<ColVecResult&>(nullable_column.get_nested_column());
if constexpr (IsDecimalNumber<T>) {
col.get_data().push_back(this->get_samp_result().value());
} else {
col.get_data().push_back(this->get_samp_result());
}
nullable_column.get_null_map_data().push_back(0);
}
} else {
if (this->count == 1 || this->count == 0) {
to.insert_default();
} else {
auto& col = assert_cast<ColVecResult&>(to);
if constexpr (IsDecimalNumber<T>) {
col.get_data().push_back(this->get_samp_result().value());
} else {
col.get_data().push_back(this->get_samp_result());
}
}
nullable_column.get_null_map_data().push_back(0);
}
}
};
@ -266,26 +274,45 @@ public:
String get_name() const override { return Data::name(); }
DataTypePtr get_return_type() const override {
if constexpr (is_pop) {
if constexpr (is_pop || !is_nullable) { // covar and covar_samp(non_nullable)
return Data::get_return_type();
} else {
} else { // covar_samp
return make_nullable(Data::get_return_type());
}
}
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
if constexpr (is_pop) {
if constexpr (is_pop) { // covar_samp
this->data(place).add(columns[0], columns[1], row_num);
} else {
} else { // covar
if constexpr (is_nullable) {
// nullable means at least one child is null.
// so here, maybe JUST ONE OF ups is null. so nullptr perhaps in ..._x or ..._y!
const auto* nullable_column_x = check_and_get_column<ColumnNullable>(columns[0]);
const auto* nullable_column_y = check_and_get_column<ColumnNullable>(columns[1]);
if (!nullable_column_x->is_null_at(row_num) &&
!nullable_column_y->is_null_at(row_num)) {
this->data(place).add(&nullable_column_x->get_nested_column(),
&nullable_column_y->get_nested_column(), row_num);
if (nullable_column_x && nullable_column_y) { // both nullable
if (!nullable_column_x->is_null_at(row_num) &&
!nullable_column_y->is_null_at(row_num)) {
this->data(place).add(&nullable_column_x->get_nested_column(),
&nullable_column_y->get_nested_column(), row_num);
}
} else if (nullable_column_x) { // x nullable
if (!nullable_column_x->is_null_at(row_num)) {
this->data(place).add(&nullable_column_x->get_nested_column(), columns[1],
row_num);
}
} else if (nullable_column_y) { // y nullable
if (!nullable_column_y->is_null_at(row_num)) {
this->data(place).add(columns[0], &nullable_column_y->get_nested_column(),
row_num);
}
} else {
throw Exception(ErrorCode::INTERNAL_ERROR,
"Nullable function {} get non-nullable columns!", get_name());
}
} else {
this->data(place).add(columns[0], columns[1], row_num);
}
@ -317,14 +344,14 @@ template <typename Data, bool is_nullable>
class AggregateFunctionSamp final
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
public:
AggregateFunctionSamp(const DataTypes& argument_types_)
AggregateFunctionSamp(const DataTypes& argument_types_) // covar_samp
: AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable>(argument_types_) {}
};
template <typename Data, bool is_nullable>
class AggregateFunctionPop final : public AggregateFunctionSampCovariance<POP, Data, is_nullable> {
public:
AggregateFunctionPop(const DataTypes& argument_types_)
AggregateFunctionPop(const DataTypes& argument_types_) // covar
: AggregateFunctionSampCovariance<POP, Data, is_nullable>(argument_types_) {}
};

View File

@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
@ -39,7 +39,7 @@ import java.util.List;
* AggregateFunction 'covar_samp'. This class is generated by GenerateFunction.
*/
public class CovarSamp extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),

View File

@ -1,6 +1,6 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !sql --
1
1.0
-- !sql --
-1.5
@ -12,4 +12,14 @@
4.5
-- !sql --
1.666667
1.666666666666666
-- !notnull1 --
1.666666666666666
-- !notnull2 --
1.666666666666666
-- !notnull3 --
1.666666666666666

View File

@ -86,5 +86,7 @@ suite("test_covar_samp") {
"""
qt_sql "select covar_samp(x,y) from test_covar_samp"
sql """ DROP TABLE IF EXISTS test_covar_samp """
qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from test_covar_samp"
qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
}