diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp index aa5bd511d9..4eb03c05fd 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_covar.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_covar.cpp @@ -76,13 +76,15 @@ AggregateFunctionPtr create_aggregate_function_covariance_pop(const std::string& name, argument_types, result_is_nullable, NOTNULLABLE); } +// register covar_pop for nullable/non_nullable both. void register_aggregate_function_covar_pop(AggregateFunctionSimpleFactory& factory) { factory.register_function_both("covar", create_aggregate_function_covariance_pop); factory.register_alias("covar", "covar_pop"); } void register_aggregate_function_covar_samp(AggregateFunctionSimpleFactory& factory) { - factory.register_function("covar_samp", create_aggregate_function_covariance_samp); + factory.register_function("covar_samp", create_aggregate_function_covariance_samp, + NOTNULLABLE); factory.register_function("covar_samp", create_aggregate_function_covariance_samp, NULLABLE); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_covar.h b/be/src/vec/aggregate_functions/aggregate_function_covar.h index 31f0d7d283..51a07f2114 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_covar.h +++ b/be/src/vec/aggregate_functions/aggregate_function_covar.h @@ -17,17 +17,16 @@ #pragma once +#include "common/exception.h" +#include "common/status.h" #define POP true #define NOTPOP false #define NULLABLE true #define NOTNULLABLE false -#include -#include - -#include #include -#include +#include +#include #include #include @@ -43,8 +42,8 @@ #include "vec/data_types/data_type_number.h" #include "vec/io/io_helper.h" -namespace doris { -namespace vectorized { +namespace doris::vectorized { + class Arena; class BufferReadable; class BufferWritable; @@ -52,10 +51,6 @@ template class ColumnDecimal; template class ColumnVector; -} // namespace vectorized -} // namespace doris - -namespace doris::vectorized { template struct BaseData { @@ -228,17 +223,30 @@ struct SampData : Data { using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>; void insert_result_into(IColumn& to) const { - ColumnNullable& nullable_column = assert_cast(to); - if (this->count == 1 || this->count == 0) { - nullable_column.insert_default(); - } else { - auto& col = assert_cast(nullable_column.get_nested_column()); - if constexpr (IsDecimalNumber) { - col.get_data().push_back(this->get_samp_result().value()); + if (to.is_nullable()) { + auto& nullable_column = assert_cast(to); + if (this->count == 1 || this->count == 0) { + nullable_column.insert_default(); } else { - col.get_data().push_back(this->get_samp_result()); + auto& col = assert_cast(nullable_column.get_nested_column()); + if constexpr (IsDecimalNumber) { + col.get_data().push_back(this->get_samp_result().value()); + } else { + col.get_data().push_back(this->get_samp_result()); + } + nullable_column.get_null_map_data().push_back(0); + } + } else { + if (this->count == 1 || this->count == 0) { + to.insert_default(); + } else { + auto& col = assert_cast(to); + if constexpr (IsDecimalNumber) { + col.get_data().push_back(this->get_samp_result().value()); + } else { + col.get_data().push_back(this->get_samp_result()); + } } - nullable_column.get_null_map_data().push_back(0); } } }; @@ -266,26 +274,45 @@ public: String get_name() const override { return Data::name(); } DataTypePtr get_return_type() const override { - if constexpr (is_pop) { + if constexpr (is_pop || !is_nullable) { // covar and covar_samp(non_nullable) return Data::get_return_type(); - } else { + } else { // covar_samp return make_nullable(Data::get_return_type()); } } void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num, Arena*) const override { - if constexpr (is_pop) { + if constexpr (is_pop) { // covar_samp this->data(place).add(columns[0], columns[1], row_num); - } else { + } else { // covar if constexpr (is_nullable) { + // nullable means at least one child is null. + // so here, maybe JUST ONE OF ups is null. so nullptr perhaps in ..._x or ..._y! const auto* nullable_column_x = check_and_get_column(columns[0]); const auto* nullable_column_y = check_and_get_column(columns[1]); - if (!nullable_column_x->is_null_at(row_num) && - !nullable_column_y->is_null_at(row_num)) { - this->data(place).add(&nullable_column_x->get_nested_column(), - &nullable_column_y->get_nested_column(), row_num); + + if (nullable_column_x && nullable_column_y) { // both nullable + if (!nullable_column_x->is_null_at(row_num) && + !nullable_column_y->is_null_at(row_num)) { + this->data(place).add(&nullable_column_x->get_nested_column(), + &nullable_column_y->get_nested_column(), row_num); + } + } else if (nullable_column_x) { // x nullable + if (!nullable_column_x->is_null_at(row_num)) { + this->data(place).add(&nullable_column_x->get_nested_column(), columns[1], + row_num); + } + } else if (nullable_column_y) { // y nullable + if (!nullable_column_y->is_null_at(row_num)) { + this->data(place).add(columns[0], &nullable_column_y->get_nested_column(), + row_num); + } + } else { + throw Exception(ErrorCode::INTERNAL_ERROR, + "Nullable function {} get non-nullable columns!", get_name()); } + } else { this->data(place).add(columns[0], columns[1], row_num); } @@ -317,14 +344,14 @@ template class AggregateFunctionSamp final : public AggregateFunctionSampCovariance { public: - AggregateFunctionSamp(const DataTypes& argument_types_) + AggregateFunctionSamp(const DataTypes& argument_types_) // covar_samp : AggregateFunctionSampCovariance(argument_types_) {} }; template class AggregateFunctionPop final : public AggregateFunctionSampCovariance { public: - AggregateFunctionPop(const DataTypes& argument_types_) + AggregateFunctionPop(const DataTypes& argument_types_) // covar : AggregateFunctionSampCovariance(argument_types_) {} }; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java index 2693d7636f..0ffe6e88af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/CovarSamp.java @@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.agg; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; @@ -39,7 +39,7 @@ import java.util.List; * AggregateFunction 'covar_samp'. This class is generated by GenerateFunction. */ public class CovarSamp extends AggregateFunction - implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { + implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), diff --git a/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out b/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out index 728beed6cc..806badbb46 100644 --- a/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out +++ b/regression-test/data/nereids_function_p0/agg_function/test_covar_samp.out @@ -1,6 +1,6 @@ -- This file is automatically generated. You should know what you did if you want to edit this -- !sql -- -1 +1.0 -- !sql -- -1.5 @@ -12,4 +12,14 @@ 4.5 -- !sql -- -1.666667 \ No newline at end of file +1.666666666666666 + +-- !notnull1 -- +1.666666666666666 + +-- !notnull2 -- +1.666666666666666 + +-- !notnull3 -- +1.666666666666666 + diff --git a/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy b/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy index a75a933e74..c9e82a86a9 100644 --- a/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy +++ b/regression-test/suites/nereids_function_p0/agg_function/test_covar_samp.groovy @@ -86,5 +86,7 @@ suite("test_covar_samp") { """ qt_sql "select covar_samp(x,y) from test_covar_samp" - sql """ DROP TABLE IF EXISTS test_covar_samp """ + qt_notnull1 "select covar_samp(non_nullable(x), non_nullable(y)) from test_covar_samp" + qt_notnull2 "select covar_samp(x, non_nullable(y)) from test_covar_samp" + qt_notnull3 "select covar_samp(non_nullable(x), y) from test_covar_samp" }