[Fix-2.1](function) fix core dump in covar functions for non-nullable input (#39943)
## Proposed changes
Issue Number: close #xxx
Add test cases such as:
```groovy
qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from test_covar_samp"
qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
```
Before this fix, all of these queries caused a core dump in 2.1.
This commit is contained in:
@ -76,13 +76,15 @@ AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string&
|
||||
name, argument_types, result_is_nullable, NOTNULLABLE);
|
||||
}
|
||||
|
||||
// Register covar_pop for both nullable and non-nullable inputs.
|
||||
void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory) {
|
||||
factory.register_function_both("covar", create_aggregate_function_covariance_pop);
|
||||
factory.register_alias("covar", "covar_pop");
|
||||
}
|
||||
|
||||
void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory) {
|
||||
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NOTNULLABLE>);
|
||||
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NOTNULLABLE>,
|
||||
NOTNULLABLE);
|
||||
factory.register_function("covar_samp", create_aggregate_function_covariance_samp<NULLABLE>,
|
||||
NULLABLE);
|
||||
}
|
||||
|
||||
@ -17,17 +17,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common/exception.h"
|
||||
#include "common/status.h"
|
||||
#define POP true
|
||||
#define NOTPOP false
|
||||
#define NULLABLE true
|
||||
#define NOTNULLABLE false
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <boost/iterator/iterator_facade.hpp>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
@ -43,8 +42,8 @@
|
||||
#include "vec/data_types/data_type_number.h"
|
||||
#include "vec/io/io_helper.h"
|
||||
|
||||
namespace doris {
|
||||
namespace vectorized {
|
||||
namespace doris::vectorized {
|
||||
|
||||
class Arena;
|
||||
class BufferReadable;
|
||||
class BufferWritable;
|
||||
@ -52,10 +51,6 @@ template <typename T>
|
||||
class ColumnDecimal;
|
||||
template <typename>
|
||||
class ColumnVector;
|
||||
} // namespace vectorized
|
||||
} // namespace doris
|
||||
|
||||
namespace doris::vectorized {
|
||||
|
||||
template <typename T>
|
||||
struct BaseData {
|
||||
@ -228,17 +223,30 @@ struct SampData : Data {
|
||||
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128V2>,
|
||||
ColumnVector<Float64>>;
|
||||
void insert_result_into(IColumn& to) const {
|
||||
ColumnNullable& nullable_column = assert_cast<ColumnNullable&>(to);
|
||||
if (this->count == 1 || this->count == 0) {
|
||||
nullable_column.insert_default();
|
||||
} else {
|
||||
auto& col = assert_cast<ColVecResult&>(nullable_column.get_nested_column());
|
||||
if constexpr (IsDecimalNumber<T>) {
|
||||
col.get_data().push_back(this->get_samp_result().value());
|
||||
if (to.is_nullable()) {
|
||||
auto& nullable_column = assert_cast<ColumnNullable&>(to);
|
||||
if (this->count == 1 || this->count == 0) {
|
||||
nullable_column.insert_default();
|
||||
} else {
|
||||
col.get_data().push_back(this->get_samp_result());
|
||||
auto& col = assert_cast<ColVecResult&>(nullable_column.get_nested_column());
|
||||
if constexpr (IsDecimalNumber<T>) {
|
||||
col.get_data().push_back(this->get_samp_result().value());
|
||||
} else {
|
||||
col.get_data().push_back(this->get_samp_result());
|
||||
}
|
||||
nullable_column.get_null_map_data().push_back(0);
|
||||
}
|
||||
} else {
|
||||
if (this->count == 1 || this->count == 0) {
|
||||
to.insert_default();
|
||||
} else {
|
||||
auto& col = assert_cast<ColVecResult&>(to);
|
||||
if constexpr (IsDecimalNumber<T>) {
|
||||
col.get_data().push_back(this->get_samp_result().value());
|
||||
} else {
|
||||
col.get_data().push_back(this->get_samp_result());
|
||||
}
|
||||
}
|
||||
nullable_column.get_null_map_data().push_back(0);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -266,26 +274,45 @@ public:
|
||||
String get_name() const override { return Data::name(); }
|
||||
|
||||
DataTypePtr get_return_type() const override {
|
||||
if constexpr (is_pop) {
|
||||
if constexpr (is_pop || !is_nullable) { // covar and covar_samp(non_nullable)
|
||||
return Data::get_return_type();
|
||||
} else {
|
||||
} else { // covar_samp
|
||||
return make_nullable(Data::get_return_type());
|
||||
}
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
|
||||
Arena*) const override {
|
||||
if constexpr (is_pop) {
|
||||
if constexpr (is_pop) { // covar
|
||||
this->data(place).add(columns[0], columns[1], row_num);
|
||||
} else {
|
||||
} else { // covar_samp
|
||||
if constexpr (is_nullable) {
|
||||
// nullable means at least one child is nullable.
|
||||
// So possibly JUST ONE of the inputs is nullable, in which case nullable_column_x or nullable_column_y is nullptr!
|
||||
const auto* nullable_column_x = check_and_get_column<ColumnNullable>(columns[0]);
|
||||
const auto* nullable_column_y = check_and_get_column<ColumnNullable>(columns[1]);
|
||||
if (!nullable_column_x->is_null_at(row_num) &&
|
||||
!nullable_column_y->is_null_at(row_num)) {
|
||||
this->data(place).add(&nullable_column_x->get_nested_column(),
|
||||
&nullable_column_y->get_nested_column(), row_num);
|
||||
|
||||
if (nullable_column_x && nullable_column_y) { // both nullable
|
||||
if (!nullable_column_x->is_null_at(row_num) &&
|
||||
!nullable_column_y->is_null_at(row_num)) {
|
||||
this->data(place).add(&nullable_column_x->get_nested_column(),
|
||||
&nullable_column_y->get_nested_column(), row_num);
|
||||
}
|
||||
} else if (nullable_column_x) { // x nullable
|
||||
if (!nullable_column_x->is_null_at(row_num)) {
|
||||
this->data(place).add(&nullable_column_x->get_nested_column(), columns[1],
|
||||
row_num);
|
||||
}
|
||||
} else if (nullable_column_y) { // y nullable
|
||||
if (!nullable_column_y->is_null_at(row_num)) {
|
||||
this->data(place).add(columns[0], &nullable_column_y->get_nested_column(),
|
||||
row_num);
|
||||
}
|
||||
} else {
|
||||
throw Exception(ErrorCode::INTERNAL_ERROR,
|
||||
"Nullable function {} get non-nullable columns!", get_name());
|
||||
}
|
||||
|
||||
} else {
|
||||
this->data(place).add(columns[0], columns[1], row_num);
|
||||
}
|
||||
@ -317,14 +344,14 @@ template <typename Data, bool is_nullable>
|
||||
class AggregateFunctionSamp final
|
||||
: public AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable> {
|
||||
public:
|
||||
AggregateFunctionSamp(const DataTypes& argument_types_)
|
||||
AggregateFunctionSamp(const DataTypes& argument_types_) // covar_samp
|
||||
: AggregateFunctionSampCovariance<NOTPOP, Data, is_nullable>(argument_types_) {}
|
||||
};
|
||||
|
||||
template <typename Data, bool is_nullable>
|
||||
class AggregateFunctionPop final : public AggregateFunctionSampCovariance<POP, Data, is_nullable> {
|
||||
public:
|
||||
AggregateFunctionPop(const DataTypes& argument_types_)
|
||||
AggregateFunctionPop(const DataTypes& argument_types_) // covar
|
||||
: AggregateFunctionSampCovariance<POP, Data, is_nullable>(argument_types_) {}
|
||||
};
|
||||
|
||||
|
||||
@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
@ -39,7 +39,7 @@ import java.util.List;
|
||||
* AggregateFunction 'covar_samp'. This class is generated by GenerateFunction.
|
||||
*/
|
||||
public class CovarSamp extends AggregateFunction
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !sql --
|
||||
1
|
||||
1.0
|
||||
|
||||
-- !sql --
|
||||
-1.5
|
||||
@ -12,4 +12,14 @@
|
||||
4.5
|
||||
|
||||
-- !sql --
|
||||
1.666667
|
||||
1.666666666666666
|
||||
|
||||
-- !notnull1 --
|
||||
1.666666666666666
|
||||
|
||||
-- !notnull2 --
|
||||
1.666666666666666
|
||||
|
||||
-- !notnull3 --
|
||||
1.666666666666666
|
||||
|
||||
|
||||
@ -86,5 +86,7 @@ suite("test_covar_samp") {
|
||||
"""
|
||||
qt_sql "select covar_samp(x,y) from test_covar_samp"
|
||||
|
||||
sql """ DROP TABLE IF EXISTS test_covar_samp """
|
||||
qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from test_covar_samp"
|
||||
qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp"
|
||||
qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user