From df96f76f7879b4cea85639c2b3695149bfddeb69 Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Tue, 23 Apr 2024 09:18:34 +0800 Subject: [PATCH] [featrue](pipelineX) check output type in some node (#33716) --- .../exec/aggregation_sink_operator.cpp | 9 ++++++-- .../pipeline/exec/aggregation_sink_operator.h | 2 ++ be/src/pipeline/exec/hashjoin_build_sink.cpp | 22 ++++++++++++++++--- .../exec/streaming_aggregation_operator.cpp | 9 ++++++-- .../exec/streaming_aggregation_operator.h | 1 + be/src/pipeline/exec/union_sink_operator.cpp | 1 + be/src/vec/exprs/vectorized_agg_fn.cpp | 19 ++++++++++++++++ be/src/vec/exprs/vectorized_agg_fn.h | 4 ++++ 8 files changed, 60 insertions(+), 7 deletions(-) diff --git a/be/src/pipeline/exec/aggregation_sink_operator.cpp b/be/src/pipeline/exec/aggregation_sink_operator.cpp index e29d6de286..869685f02e 100644 --- a/be/src/pipeline/exec/aggregation_sink_operator.cpp +++ b/be/src/pipeline/exec/aggregation_sink_operator.cpp @@ -631,7 +631,8 @@ AggSinkOperatorX::AggSinkOperatorX(ObjectPool* pool, int operator_id, const TPla (tnode.__isset.conjuncts && !tnode.conjuncts.empty())), _partition_exprs(tnode.__isset.distribute_expr_lists ? tnode.distribute_expr_lists[0] : std::vector {}), - _is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate) {} + _is_colocate(tnode.agg_node.__isset.is_colocate && tnode.agg_node.is_colocate), + _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {} Status AggSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(DataSinkOperatorX::init(tnode, state)); @@ -714,7 +715,11 @@ Status AggSinkOperatorX::prepare(RuntimeState* state) { alignment_of_next_state * alignment_of_next_state; } } - + // check output type + if (_needs_finalize) { + RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output( + _probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor)); + } return Status::OK(); } diff --git a/be/src/pipeline/exec/aggregation_sink_operator.h b/be/src/pipeline/exec/aggregation_sink_operator.h index e3d8baad39..b3ffa19d6d 100644 --- a/be/src/pipeline/exec/aggregation_sink_operator.h +++ b/be/src/pipeline/exec/aggregation_sink_operator.h @@ -213,6 +213,8 @@ protected: const std::vector _partition_exprs; const bool _is_colocate; + + RowDescriptor _agg_fn_output_row_descriptor; }; } // namespace pipeline diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp index 176eaf33b1..a0d111c63a 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.cpp +++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp @@ -22,6 +22,7 @@ #include "exprs/bloom_filter_func.h" #include "pipeline/exec/hashjoin_probe_operator.h" #include "pipeline/exec/operator.h" +#include "vec/data_types/data_type_nullable.h" #include "vec/exec/join/vhash_join_node.h" #include "vec/utils/template_helpers.hpp" @@ -461,9 +462,24 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st const std::vector& eq_join_conjuncts = tnode.hash_join_node.eq_join_conjuncts; for (const auto& eq_join_conjunct : eq_join_conjuncts) { - vectorized::VExprContextSPtr ctx; - RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, ctx)); - _build_expr_ctxs.push_back(ctx); + vectorized::VExprContextSPtr build_ctx; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.right, build_ctx)); + { + // for type check + vectorized::VExprContextSPtr probe_ctx; + RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(eq_join_conjunct.left, probe_ctx)); + auto build_side_expr_type = build_ctx->root()->data_type(); + auto probe_side_expr_type = probe_ctx->root()->data_type(); + if (!vectorized::make_nullable(build_side_expr_type) + ->equals(*vectorized::make_nullable(probe_side_expr_type))) { + return Status::InternalError( + "build side type {}, not match probe side type {} , node info " + "{}", + build_side_expr_type->get_name(), probe_side_expr_type->get_name(), + this->debug_string(0)); + } + } + _build_expr_ctxs.push_back(build_ctx); const auto vexpr = _build_expr_ctxs.back()->root(); diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.cpp b/be/src/pipeline/exec/streaming_aggregation_operator.cpp index dfcfb0ebc4..f33d799db4 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.cpp +++ b/be/src/pipeline/exec/streaming_aggregation_operator.cpp @@ -1145,7 +1145,8 @@ StreamingAggOperatorX::StreamingAggOperatorX(ObjectPool* pool, int operator_id, _needs_finalize(tnode.agg_node.need_finalize), _is_merge(false), _is_first_phase(tnode.agg_node.__isset.is_first_phase && tnode.agg_node.is_first_phase), - _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()) {} + _have_conjuncts(tnode.__isset.vconjunct && !tnode.vconjunct.nodes.empty()), + _agg_fn_output_row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples) {} Status StreamingAggOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { RETURN_IF_ERROR(StatefulOperatorX::init(tnode, state)); @@ -1235,7 +1236,11 @@ Status StreamingAggOperatorX::prepare(RuntimeState* state) { alignment_of_next_state * alignment_of_next_state; } } - + // check output type + if (_needs_finalize) { + RETURN_IF_ERROR(vectorized::AggFnEvaluator::check_agg_fn_output( + _probe_expr_ctxs.size(), _aggregate_evaluators, _agg_fn_output_row_descriptor)); + } return Status::OK(); } diff --git a/be/src/pipeline/exec/streaming_aggregation_operator.h b/be/src/pipeline/exec/streaming_aggregation_operator.h index 2895fc63f3..caaee88b3c 100644 --- a/be/src/pipeline/exec/streaming_aggregation_operator.h +++ b/be/src/pipeline/exec/streaming_aggregation_operator.h @@ -243,6 +243,7 @@ private: bool _can_short_circuit = false; std::vector _make_nullable_keys; bool _have_conjuncts; + RowDescriptor _agg_fn_output_row_descriptor; }; } // namespace pipeline diff --git a/be/src/pipeline/exec/union_sink_operator.cpp b/be/src/pipeline/exec/union_sink_operator.cpp index 5acf6c8e1a..e466237a37 100644 --- a/be/src/pipeline/exec/union_sink_operator.cpp +++ b/be/src/pipeline/exec/union_sink_operator.cpp @@ -138,6 +138,7 @@ Status UnionSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* state) { Status UnionSinkOperatorX::prepare(RuntimeState* state) { RETURN_IF_ERROR(vectorized::VExpr::prepare(_child_expr, state, _child_x->row_desc())); + RETURN_IF_ERROR(vectorized::VExpr::check_expr_output_type(_child_expr, _row_descriptor)); return Status::OK(); } diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 4dfdff7820..d0fbf36372 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -350,4 +350,23 @@ AggFnEvaluator::AggFnEvaluator(AggFnEvaluator& evaluator, RuntimeState* state) } } +Status AggFnEvaluator::check_agg_fn_output(int key_size, + const std::vector& agg_fn, + const RowDescriptor& output_row_desc) { + auto name_and_types = VectorizedUtils::create_name_and_data_types(output_row_desc); + for (int i = key_size, j = 0; i < name_and_types.size(); i++, j++) { + auto&& [name, column_type] = name_and_types[i]; + auto agg_return_type = agg_fn[j]->function()->get_return_type(); + if (!column_type->equals(*agg_return_type)) { + if (!column_type->is_nullable() || agg_return_type->is_nullable() || + !remove_nullable(column_type)->equals(*agg_return_type)) { + return Status::InternalError( + "column_type not match data_types in agg node, column_type={}, " + "data_types={},column name={}", + column_type->get_name(), agg_return_type->get_name(), name); + } + } + } + return Status::OK(); +} } // namespace doris::vectorized diff --git a/be/src/vec/exprs/vectorized_agg_fn.h b/be/src/vec/exprs/vectorized_agg_fn.h index 546b939ddf..7dcd1b3e02 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.h +++ b/be/src/vec/exprs/vectorized_agg_fn.h @@ -97,6 +97,10 @@ public: bool is_merge() const { return _is_merge; } const VExprContextSPtrs& input_exprs_ctxs() const { return _input_exprs_ctxs; } + static Status check_agg_fn_output(int key_size, + const std::vector& agg_fn, + const RowDescriptor& output_row_desc); + void set_version(const int version) { _function->set_version(version); } AggFnEvaluator* clone(RuntimeState* state, ObjectPool* pool);