From 6871c964af911e0ae9ffb385136091670fa47a2c Mon Sep 17 00:00:00 2001 From: starocean999 <40539150+starocean999@users.noreply.github.com> Date: Wed, 20 Mar 2024 14:43:49 +0800 Subject: [PATCH] [fix](nereids)NullSafeEqualToEqual rule only change to equal if both children are not nullable (#32374) NullSafeEqualToEqual rule only change to equal if both children are not nullable --- be/src/pipeline/exec/hashjoin_build_sink.cpp | 12 ++++++ be/src/pipeline/exec/hashjoin_build_sink.h | 3 ++ .../pipeline/exec/hashjoin_probe_operator.cpp | 21 ++++++++++ .../pipeline/exec/hashjoin_probe_operator.h | 4 ++ be/src/vec/exec/join/vhash_join_node.cpp | 42 ++++++++++++------- be/src/vec/exec/join/vhash_join_node.h | 4 +- .../rules/NullSafeEqualToEqual.java | 2 +- .../rules/NullSafeEqualToEqualTest.java | 32 ++++++++------ .../test_half_join_nullable_build_side.groovy | 2 - 9 files changed, 90 insertions(+), 32 deletions(-) diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp index 94ae46690f..b3ee878a94 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.cpp +++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp @@ -175,7 +175,15 @@ Status HashJoinBuildSinkLocalState::_extract_join_column( vectorized::Block& block, vectorized::ColumnUInt8::MutablePtr& null_map, vectorized::ColumnRawPtrs& raw_ptrs, const std::vector& res_col_ids) { auto& shared_state = *_shared_state; + auto& p = _parent->cast(); for (size_t i = 0; i < shared_state.build_exprs_size; ++i) { + if (p._should_convert_to_nullable[i]) { + _key_columns_holder.emplace_back( + vectorized::make_nullable(block.get_by_position(res_col_ids[i]).column)); + raw_ptrs[i] = _key_columns_holder.back().get(); + continue; + } + if (shared_state.is_null_safe_eq_join[i]) { raw_ptrs[i] = block.get_by_position(res_col_ids[i]).column.get(); } else { @@ -411,7 +419,11 @@ Status HashJoinBuildSinkOperatorX::init(const TPlanNode& tnode, RuntimeState* st const bool is_null_safe_equal = eq_join_conjunct.__isset.opcode && eq_join_conjunct.opcode == TExprOpcode::EQ_FOR_NULL; + const bool should_convert_to_nullable = is_null_safe_equal && + !eq_join_conjunct.right.nodes[0].is_nullable && + eq_join_conjunct.left.nodes[0].is_nullable; _is_null_safe_eq_join.push_back(is_null_safe_equal); + _should_convert_to_nullable.emplace_back(should_convert_to_nullable); // if is null aware, build join column and probe join column both need dispose null value _store_null_in_hash_table.emplace_back( diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h b/be/src/pipeline/exec/hashjoin_build_sink.h index 5812c0861a..0849027e32 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.h +++ b/be/src/pipeline/exec/hashjoin_build_sink.h @@ -90,6 +90,7 @@ protected: // build expr vectorized::VExprContextSPtrs _build_expr_ctxs; + std::vector _key_columns_holder; bool _should_build_hash_table = true; int64_t _build_side_mem_used = 0; @@ -173,6 +174,8 @@ private: // mark the join column whether support null eq std::vector _is_null_safe_eq_join; + std::vector _should_convert_to_nullable; + bool _is_broadcast_join = false; std::shared_ptr _shared_hashtable_controller; diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.cpp b/be/src/pipeline/exec/hashjoin_probe_operator.cpp index 014d728584..e6a605405a 100644 --- a/be/src/pipeline/exec/hashjoin_probe_operator.cpp +++ b/be/src/pipeline/exec/hashjoin_probe_operator.cpp @@ -214,6 +214,7 @@ void HashJoinProbeLocalState::_prepare_probe_block() { column_type.column = remove_nullable(column_type.column); column_type.type = remove_nullable(column_type.type); } + _key_columns_holder.clear(); _probe_block.clear_column_data(_parent->get_child()->row_desc().num_materialized_slots()); } @@ -374,7 +375,15 @@ Status HashJoinProbeLocalState::_extract_join_column(vectorized::Block& block, vectorized::ColumnRawPtrs& raw_ptrs, const std::vector& res_col_ids) { auto& shared_state = *_shared_state; + auto& p = _parent->cast(); for (size_t i = 0; i < shared_state.build_exprs_size; ++i) { + if (p._should_convert_to_nullable[i]) { + _key_columns_holder.emplace_back( + vectorized::make_nullable(block.get_by_position(res_col_ids[i]).column)); + raw_ptrs[i] = _key_columns_holder.back().get(); + continue; + } + if (shared_state.is_null_safe_eq_join[i]) { raw_ptrs[i] = block.get_by_position(res_col_ids[i]).column.get(); } else { @@ -524,6 +533,18 @@ Status HashJoinProbeOperatorX::init(const TPlanNode& tnode, RuntimeState* state) null_aware || (_probe_expr_ctxs.back()->root()->is_nullable() && probe_dispose_null); conjuncts_index++; + const bool is_null_safe_equal = eq_join_conjunct.__isset.opcode && + eq_join_conjunct.opcode == TExprOpcode::EQ_FOR_NULL; + + /// If it's right anti join, + /// we should convert the probe to nullable if the build side is nullable. + /// And if it is 'null safe equal', + /// we must make sure the build side and the probe side are both nullable or non-nullable. + const bool should_convert_to_nullable = + (is_null_safe_equal || _join_op == TJoinOp::RIGHT_ANTI_JOIN) && + !eq_join_conjunct.left.nodes[0].is_nullable && + eq_join_conjunct.right.nodes[0].is_nullable; + _should_convert_to_nullable.emplace_back(should_convert_to_nullable); } for (size_t i = 0; i < _probe_expr_ctxs.size(); ++i) { _probe_ignore_null |= !probe_not_ignore_null[i]; diff --git a/be/src/pipeline/exec/hashjoin_probe_operator.h b/be/src/pipeline/exec/hashjoin_probe_operator.h index af2a255d61..b4930307bc 100644 --- a/be/src/pipeline/exec/hashjoin_probe_operator.h +++ b/be/src/pipeline/exec/hashjoin_probe_operator.h @@ -124,6 +124,8 @@ private: vectorized::VExprContextSPtrs _mark_join_conjuncts; + std::vector _key_columns_holder; + // probe expr vectorized::VExprContextSPtrs _probe_expr_ctxs; std::vector _probe_column_disguise_null; @@ -194,6 +196,8 @@ private: vectorized::VExprContextSPtrs _probe_expr_ctxs; bool _probe_ignore_null = false; + std::vector _should_convert_to_nullable; + vectorized::DataTypes _right_table_data_types; vectorized::DataTypes _left_table_data_types; std::vector _hash_output_slot_ids; diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp index 4d3a2aa947..44074f3bcf 100644 --- a/be/src/vec/exec/join/vhash_join_node.cpp +++ b/be/src/vec/exec/join/vhash_join_node.cpp @@ -123,13 +123,23 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { eq_join_conjunct.opcode == TExprOpcode::EQ_FOR_NULL; _is_null_safe_eq_join.push_back(null_aware); + const bool build_side_nullable = _build_expr_ctxs.back()->root()->is_nullable(); + const bool probe_side_nullable = _probe_expr_ctxs.back()->root()->is_nullable(); // if is null aware, build join column and probe join column both need dispose null value - _store_null_in_hash_table.emplace_back( - null_aware || - (_build_expr_ctxs.back()->root()->is_nullable() && build_stores_null)); + _store_null_in_hash_table.emplace_back(null_aware || + (build_side_nullable && build_stores_null)); probe_not_ignore_null[conjuncts_index] = - null_aware || - (_probe_expr_ctxs.back()->root()->is_nullable() && probe_dispose_null); + null_aware || (probe_side_nullable && probe_dispose_null); + + const bool should_convert_build_side_to_nullable = + null_aware && !build_side_nullable && probe_side_nullable; + const bool should_convert_probe_side_to_nullable = + (null_aware || _join_op == TJoinOp::RIGHT_ANTI_JOIN) && build_side_nullable && + !probe_side_nullable; + + _should_convert_build_side_to_nullable.emplace_back(should_convert_build_side_to_nullable); + _should_convert_probe_side_to_nullable.emplace_back(should_convert_probe_side_to_nullable); + conjuncts_index++; } for (size_t i = 0; i < _probe_expr_ctxs.size(); ++i) { @@ -594,7 +604,7 @@ void HashJoinNode::_prepare_probe_block() { column_type.column = remove_nullable(column_type.column); column_type.type = remove_nullable(column_type.type); } - _temp_probe_nullable_columns.clear(); + _key_columns_holder.clear(); release_block_memory(_probe_block); } @@ -841,8 +851,17 @@ Status HashJoinNode::_extract_join_column(Block& block, ColumnUInt8::MutablePtr& ColumnRawPtrs& raw_ptrs, const std::vector& res_col_ids) { DCHECK_EQ(_build_expr_ctxs.size(), _probe_expr_ctxs.size()); - _temp_probe_nullable_columns.clear(); + _key_columns_holder.clear(); + auto& should_convert_to_nullable = BuildSide ? _should_convert_build_side_to_nullable + : _should_convert_probe_side_to_nullable; for (size_t i = 0; i < _build_expr_ctxs.size(); ++i) { + if (should_convert_to_nullable[i]) { + _key_columns_holder.emplace_back( + make_nullable(block.get_by_position(res_col_ids[i]).column)); + raw_ptrs[i] = _key_columns_holder.back().get(); + continue; + } + if (_is_null_safe_eq_join[i]) { raw_ptrs[i] = block.get_by_position(res_col_ids[i]).column.get(); } else { @@ -865,15 +884,6 @@ Status HashJoinNode::_extract_join_column(Block& block, ColumnUInt8::MutablePtr& raw_ptrs[i] = &col_nested; } } else { - if constexpr (!BuildSide) { - if (_join_op == TJoinOp::RIGHT_ANTI_JOIN && - _build_expr_ctxs[i]->root()->is_nullable()) { - _temp_probe_nullable_columns.emplace_back(make_nullable( - block.get_by_position(res_col_ids[i]).column->assume_mutable())); - raw_ptrs[i] = _temp_probe_nullable_columns.back().get(); - continue; - } - } raw_ptrs[i] = column; } } diff --git a/be/src/vec/exec/join/vhash_join_node.h b/be/src/vec/exec/join/vhash_join_node.h index 58451c360e..1559a5dfa5 100644 --- a/be/src/vec/exec/join/vhash_join_node.h +++ b/be/src/vec/exec/join/vhash_join_node.h @@ -281,9 +281,11 @@ private: // mark the build hash table whether it needs to store null value std::vector _store_null_in_hash_table; + std::vector _should_convert_build_side_to_nullable; + std::vector _should_convert_probe_side_to_nullable; // In right anti join, if the probe side is not nullable and the build side is nullable, // we need to convert the probe column to nullable. - std::vector _temp_probe_nullable_columns; + std::vector _key_columns_holder; std::vector _probe_column_disguise_null; std::vector _probe_column_convert_to_null; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java index c215e65f72..6507f49825 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/NullSafeEqualToEqual.java @@ -54,7 +54,7 @@ public class NullSafeEqualToEqual extends DefaultExpressionRewriter "abc" to "A = "abc" + // "A(nullable)<=>B" not changed @Test - void testNullSafeEqualToEqual() { - executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); - SlotReference slot = new SlotReference("a", StringType.INSTANCE, true); - StringLiteral str = new StringLiteral("abc"); - assertRewrite(new NullSafeEqual(slot, str), new EqualTo(slot, str)); - } - - // "A<=>B" not changed - @Test - void testNullSafeEqualNotChanged() { + void testNullSafeEqualNotChangedLeft() { executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); SlotReference a = new SlotReference("a", StringType.INSTANCE, true); + SlotReference b = new SlotReference("b", StringType.INSTANCE, false); + assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); + } + + // "A<=>B(nullable)" not changed + @Test + void testNullSafeEqualNotChangedRight() { + executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + SlotReference a = new SlotReference("a", StringType.INSTANCE, false); SlotReference b = new SlotReference("b", StringType.INSTANCE, true); assertRewrite(new NullSafeEqual(a, b), new NullSafeEqual(a, b)); } + + // "A<=>B" changed + @Test + void testNullSafeEqualToEqual() { + executor = new ExpressionRuleExecutor(ImmutableList.of(NullSafeEqualToEqual.INSTANCE)); + SlotReference a = new SlotReference("a", StringType.INSTANCE, false); + SlotReference b = new SlotReference("b", StringType.INSTANCE, false); + assertRewrite(new NullSafeEqual(a, b), new EqualTo(a, b)); + } } diff --git a/regression-test/suites/query_p0/join/test_half_join_nullable_build_side.groovy b/regression-test/suites/query_p0/join/test_half_join_nullable_build_side.groovy index 428edf315f..bddccb26ab 100644 --- a/regression-test/suites/query_p0/join/test_half_join_nullable_build_side.groovy +++ b/regression-test/suites/query_p0/join/test_half_join_nullable_build_side.groovy @@ -16,8 +16,6 @@ // under the License. suite("test_half_join_nullable_build_side", "query,p0") { - /// TODO: fix on pipelinex - sql " set ENABLE_PIPELINE_X_ENGINE = 0; " sql " set disable_join_reorder = 1; " sql " drop table if exists test_half_join_nullable_build_side_l; "; sql " drop table if exists test_half_join_nullable_build_side_l2; ";