diff --git a/be/src/exec/aggregation_node.cpp b/be/src/exec/aggregation_node.cpp index 5657d8fa1a..5f9beeb2ab 100644 --- a/be/src/exec/aggregation_node.cpp +++ b/be/src/exec/aggregation_node.cpp @@ -148,7 +148,8 @@ Status AggregationNode::prepare(RuntimeState* state) { // TODO: how many buckets? _hash_tbl.reset(new HashTable( - _build_expr_ctxs, _probe_expr_ctxs, 1, true, id(), mem_tracker(), 1024)); + _build_expr_ctxs, _probe_expr_ctxs, 1, true, + vector(_build_expr_ctxs.size(), false), id(), mem_tracker(), 1024)); if (_probe_expr_ctxs.empty()) { // create single output tuple now; we need to output something diff --git a/be/src/exec/hash_join_node.cpp b/be/src/exec/hash_join_node.cpp index bfb1b3f097..9d2a6c2a62 100644 --- a/be/src/exec/hash_join_node.cpp +++ b/be/src/exec/hash_join_node.cpp @@ -71,6 +71,12 @@ Status HashJoinNode::init(const TPlanNode& tnode, RuntimeState* state) { _probe_expr_ctxs.push_back(ctx); RETURN_IF_ERROR(Expr::create_expr_tree(_pool, eq_join_conjuncts[i].right, &ctx)); _build_expr_ctxs.push_back(ctx); + if (eq_join_conjuncts[i].__isset.opcode + && eq_join_conjuncts[i].opcode == TExprOpcode::EQ_FOR_NULL) { + _is_null_safe_eq_join.push_back(true); + } else { + _is_null_safe_eq_join.push_back(false); + } } RETURN_IF_ERROR( @@ -131,13 +137,13 @@ Status HashJoinNode::prepare(RuntimeState* state) { _build_tuple_row_size = num_build_tuples * sizeof(Tuple*); // TODO: default buckets - const bool stores_nulls = _join_op == TJoinOp::RIGHT_OUTER_JOIN + const bool null_preserved = _join_op == TJoinOp::RIGHT_OUTER_JOIN || _join_op == TJoinOp::FULL_OUTER_JOIN || _join_op == TJoinOp::RIGHT_ANTI_JOIN || _join_op == TJoinOp::RIGHT_SEMI_JOIN; _hash_tbl.reset(new HashTable( _build_expr_ctxs, _probe_expr_ctxs, _build_tuple_size, - stores_nulls, id(), mem_tracker(), 1024)); + null_preserved, _is_null_safe_eq_join, id(), mem_tracker(), 1024)); _probe_batch.reset(new RowBatch(child(0)->row_desc(), state->batch_size(), mem_tracker())); @@ -286,6 +292,15 @@ 
Status HashJoinNode::open(RuntimeState* state) { _is_push_down = false; } + // The predicate cannot be pushed down when there is a null-safe equal operator. + // The IN predicate would filter out the null values in child[0], but they are needed by the null-safe equal join. + // For example: select * from a join b where a.id<=>b.id + // the null values in table a should be returned by the scan node instead of being filtered out by the IN predicate. + if (std::find(_is_null_safe_eq_join.begin(), _is_null_safe_eq_join.end(), + true) != _is_null_safe_eq_join.end()) { + _is_push_down = false; + } + if (_is_push_down) { // Blocks until ConstructHashTable has returned, after which // the hash table is fully constructed and we can start the probe diff --git a/be/src/exec/hash_join_node.h b/be/src/exec/hash_join_node.h index 1f054d2409..b00c60881d 100644 --- a/be/src/exec/hash_join_node.h +++ b/be/src/exec/hash_join_node.h @@ -78,6 +78,9 @@ private: // _build_exprs (over child(1)) and _probe_exprs (over child(0)) std::vector _probe_expr_ctxs; std::vector _build_expr_ctxs; + // true: the operator of eq join predicate is null safe equal => '<=>' + // false: the operator of eq join predicate is equal => '=' + std::vector _is_null_safe_eq_join; std::list _push_down_expr_ctxs; // non-equi-join conjuncts from the JOIN clause diff --git a/be/src/exec/hash_table.cpp b/be/src/exec/hash_table.cpp index a312e25184..ac94470be4 100644 --- a/be/src/exec/hash_table.cpp +++ b/be/src/exec/hash_table.cpp @@ -42,12 +42,18 @@ const char* HashTable::_s_llvm_class_name = "class.doris::HashTable"; HashTable::HashTable(const vector& build_expr_ctxs, const vector& probe_expr_ctxs, - int num_build_tuples, bool stores_nulls, int32_t initial_seed, + int num_build_tuples, bool null_preserved, + const std::vector& finds_nulls, + int32_t initial_seed, MemTracker* mem_tracker, int64_t num_buckets) : _build_expr_ctxs(build_expr_ctxs), _probe_expr_ctxs(probe_expr_ctxs), _num_build_tuples(num_build_tuples), - 
_stores_nulls(stores_nulls), + _null_preserved(null_preserved), + _finds_nulls(finds_nulls), + _stores_nulls(null_preserved + || (std::find(finds_nulls.begin(), finds_nulls.end(), + true) != finds_nulls.end())), _initial_seed(initial_seed), _node_byte_size(sizeof(Node) + sizeof(Tuple*) * _num_build_tuples), _num_filled_buckets(0), @@ -180,7 +186,7 @@ bool HashTable::equals(TupleRow* build_row) { void* val = _build_expr_ctxs[i]->get_value(build_row); if (val == NULL) { - if (!_stores_nulls) { + if (!(_null_preserved && _finds_nulls[i])) { return false; } diff --git a/be/src/exec/hash_table.h b/be/src/exec/hash_table.h index fe22fc23bf..1b2c649a25 100644 --- a/be/src/exec/hash_table.h +++ b/be/src/exec/hash_table.h @@ -98,7 +98,9 @@ public: HashTable( const std::vector& build_exprs, const std::vector& probe_exprs, - int num_build_tuples, bool stores_nulls, int32_t initial_seed, + int num_build_tuples, bool stores_nulls, + const std::vector& finds_nulls, + int32_t initial_seed, MemTracker* mem_tracker, int64_t num_buckets); @@ -384,6 +386,12 @@ private: // Number of Tuple* in the build tuple row const int _num_build_tuples; + // the row in hash table is preserved such as RIGHT_OUTER_JOIN + const bool _null_preserved; + // true: the null-safe equal '<=>' is true. The row with null should be judged. + // false: the equal '=' is false. The row with null should be filtered. 
+ const std::vector _finds_nulls; + // outer join || has null equal join should be true const bool _stores_nulls; const int32_t _initial_seed; diff --git a/be/src/exec/olap_scan_node.cpp b/be/src/exec/olap_scan_node.cpp index c039c3ef26..aa911ee08b 100644 --- a/be/src/exec/olap_scan_node.cpp +++ b/be/src/exec/olap_scan_node.cpp @@ -193,9 +193,10 @@ Status OlapScanNode::open(RuntimeState* state) { SCOPED_TIMER(_runtime_profile->total_time_counter()); RETURN_IF_CANCELLED(state); RETURN_IF_ERROR(ExecNode::open(state)); - + for (int conj_idx = 0; conj_idx < _conjunct_ctxs.size(); ++conj_idx) { // if conjunct is constant, compute direct and set eos = true + if (_conjunct_ctxs[conj_idx]->root()->is_constant()) { void* value = _conjunct_ctxs[conj_idx]->get_value(NULL); if (value == NULL || *reinterpret_cast(value) == false) { diff --git a/fe/src/main/java/org/apache/doris/analysis/Analyzer.java b/fe/src/main/java/org/apache/doris/analysis/Analyzer.java index 50c8ae6d4c..2f2ef7c406 100644 --- a/fe/src/main/java/org/apache/doris/analysis/Analyzer.java +++ b/fe/src/main/java/org/apache/doris/analysis/Analyzer.java @@ -821,7 +821,7 @@ public class Analyzer { return; } BinaryPredicate binaryPred = (BinaryPredicate) e; - if (binaryPred.getOp() != BinaryPredicate.Operator.EQ) { + if (!binaryPred.getOp().isEquivalence()) { return; } if (tupleIds.size() < 2) { diff --git a/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java b/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java index f225dcc1aa..280eb28b43 100644 --- a/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java +++ b/fe/src/main/java/org/apache/doris/planner/DistributedPlanner.java @@ -18,6 +18,7 @@ package org.apache.doris.planner; import org.apache.doris.analysis.AggregateInfo; +import org.apache.doris.analysis.BinaryPredicate; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.InsertStmt; import org.apache.doris.analysis.JoinOperator; @@ -33,7 +34,6 @@ 
import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Table; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Config; -import org.apache.doris.common.Pair; import org.apache.doris.common.UserException; import org.apache.doris.qe.ConnectContext; import org.apache.doris.thrift.TPartitionType; @@ -384,13 +384,13 @@ public class DistributedPlanner { // TODO: create equivalence classes based on equality predicates // first, extract join exprs - List> eqJoinConjuncts = node.getEqJoinConjuncts(); + List eqJoinConjuncts = node.getEqJoinConjuncts(); List lhsJoinExprs = Lists.newArrayList(); List rhsJoinExprs = Lists.newArrayList(); - for (Pair pair : eqJoinConjuncts) { + for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) { // no remapping necessary - lhsJoinExprs.add(pair.first.clone(null)); - rhsJoinExprs.add(pair.second.clone(null)); + lhsJoinExprs.add(eqJoinPredicate.getChild(0).clone(null)); + rhsJoinExprs.add(eqJoinPredicate.getChild(1).clone(null)); } // create the parent fragment containing the HashJoin node @@ -489,14 +489,16 @@ public class DistributedPlanner { List leftColumns = ((HashDistributionInfo) leftDistribution).getDistributionColumns(); List rightColumns = ((HashDistributionInfo) rightDistribution).getDistributionColumns(); - List> eqJoinConjuncts = node.getEqJoinConjuncts(); - for (Pair eqJoinPredicate : eqJoinConjuncts) { - if (eqJoinPredicate.first.unwrapSlotRef() == null || eqJoinPredicate.second.unwrapSlotRef() == null) { + List eqJoinConjuncts = node.getEqJoinConjuncts(); + for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) { + Expr lhsJoinExpr = eqJoinPredicate.getChild(0); + Expr rhsJoinExpr = eqJoinPredicate.getChild(1); + if (lhsJoinExpr.unwrapSlotRef() == null || rhsJoinExpr.unwrapSlotRef() == null) { continue; } - SlotDescriptor leftSlot = eqJoinPredicate.first.unwrapSlotRef().getDesc(); - SlotDescriptor rightSlot = eqJoinPredicate.second.unwrapSlotRef().getDesc(); + SlotDescriptor 
leftSlot = lhsJoinExpr.unwrapSlotRef().getDesc(); + SlotDescriptor rightSlot = rhsJoinExpr.unwrapSlotRef().getDesc(); //3 the eqJoinConjuncts must contain the distributionColumns if (leftColumns.contains(leftSlot.getColumn()) && rightColumns.contains(rightSlot.getColumn())) { diff --git a/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java b/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java index cd35ab0604..d425791d88 100644 --- a/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java +++ b/fe/src/main/java/org/apache/doris/planner/HashJoinNode.java @@ -18,6 +18,7 @@ package org.apache.doris.planner; import org.apache.doris.analysis.Analyzer; +import org.apache.doris.analysis.BinaryPredicate; import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.ExprSubstitutionMap; import org.apache.doris.analysis.JoinOperator; @@ -26,7 +27,6 @@ import org.apache.doris.analysis.SlotId; import org.apache.doris.analysis.SlotRef; import org.apache.doris.analysis.TableRef; import org.apache.doris.catalog.ColumnStats; -import org.apache.doris.common.Pair; import org.apache.doris.common.UserException; import org.apache.doris.thrift.TEqJoinCondition; import org.apache.doris.thrift.TExplainLevel; @@ -42,6 +42,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import java.util.List; +import java.util.stream.Collectors; /** * Hash join between left child and right child. 
@@ -53,8 +54,8 @@ public class HashJoinNode extends PlanNode { private final TableRef innerRef; private final JoinOperator joinOp; - // conjuncts of the form " = ", recorded as Pair(, ) - private List> eqJoinConjuncts; + // predicates of the form 'a=b' or 'a<=>b' + private List eqJoinConjuncts = Lists.newArrayList(); // join conjuncts from the JOIN clause that aren't equi-join predicates private List otherJoinConjuncts; private boolean isPushDown; @@ -63,9 +64,9 @@ public class HashJoinNode extends PlanNode { private String colocateReason = ""; // if can not do colocate join, set reason here public HashJoinNode(PlanNodeId id, PlanNode outer, PlanNode inner, TableRef innerRef, - List> eqJoinConjuncts, List otherJoinConjuncts) { + List eqJoinConjuncts, List otherJoinConjuncts) { super(id, "HASH JOIN"); - Preconditions.checkArgument(eqJoinConjuncts != null); + Preconditions.checkArgument(eqJoinConjuncts != null && !eqJoinConjuncts.isEmpty()); Preconditions.checkArgument(otherJoinConjuncts != null); tupleIds.addAll(outer.getTupleIds()); tupleIds.addAll(inner.getTupleIds()); @@ -73,8 +74,11 @@ public class HashJoinNode extends PlanNode { tblRefIds.addAll(inner.getTblRefIds()); this.innerRef = innerRef; this.joinOp = innerRef.getJoinOp(); + for (Expr eqJoinPredicate : eqJoinConjuncts) { + Preconditions.checkArgument(eqJoinPredicate instanceof BinaryPredicate); + this.eqJoinConjuncts.add((BinaryPredicate) eqJoinPredicate); + } this.distrMode = DistributionMode.NONE; - this.eqJoinConjuncts = eqJoinConjuncts; this.otherJoinConjuncts = otherJoinConjuncts; children.add(outer); children.add(inner); @@ -94,7 +98,7 @@ public class HashJoinNode extends PlanNode { } } - public List> getEqJoinConjuncts() { + public List getEqJoinConjuncts() { return eqJoinConjuncts; } @@ -134,17 +138,12 @@ public class HashJoinNode extends PlanNode { //assignedConjuncts = analyzr.getAssignedConjuncts(); ExprSubstitutionMap combinedChildSmap = getCombinedChildWithoutTupleIsNullSmap(); + List 
newEqJoinConjuncts = + Expr.substituteList(eqJoinConjuncts, combinedChildSmap, analyzer, false); + eqJoinConjuncts = newEqJoinConjuncts.stream() + .map(entity -> (BinaryPredicate) entity).collect(Collectors.toList()); otherJoinConjuncts = Expr.substituteList(otherJoinConjuncts, combinedChildSmap, analyzer, false); - - List> newEqJoinConjuncts = Lists.newArrayList(); - for (Pair c: eqJoinConjuncts) { - Pair p = - new Pair(c.first.clone(combinedChildSmap), c.second.clone(combinedChildSmap)); - newEqJoinConjuncts.add( - new Pair(c.first.clone(combinedChildSmap), c.second.clone(combinedChildSmap))); - } - eqJoinConjuncts = newEqJoinConjuncts; } @Override @@ -174,11 +173,13 @@ public class HashJoinNode extends PlanNode { // - the output cardinality of the join would be F.cardinality * 0.2 long maxNumDistinct = 0; - for (Pair eqJoinPredicate : eqJoinConjuncts) { - if (eqJoinPredicate.first.unwrapSlotRef() == null) { + for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) { + Expr lhsJoinExpr = eqJoinPredicate.getChild(0); + Expr rhsJoinExpr = eqJoinPredicate.getChild(1); + if (lhsJoinExpr.unwrapSlotRef() == null) { continue; } - SlotRef rhsSlotRef = eqJoinPredicate.second.unwrapSlotRef(); + SlotRef rhsSlotRef = rhsJoinExpr.unwrapSlotRef(); if (rhsSlotRef == null) { continue; } @@ -229,8 +230,8 @@ public class HashJoinNode extends PlanNode { private String eqJoinConjunctsDebugString() { Objects.ToStringHelper helper = Objects.toStringHelper(this); - for (Pair entry : eqJoinConjuncts) { - helper.add("lhs", entry.first).add("rhs", entry.second); + for (BinaryPredicate expr : eqJoinConjuncts) { + helper.add("lhs", expr.getChild(0)).add("rhs", expr.getChild(1)); } return helper.toString(); } @@ -240,9 +241,8 @@ public class HashJoinNode extends PlanNode { super.getMaterializedIds(analyzer, ids); // we also need to materialize everything referenced by eqJoinConjuncts // and otherJoinConjuncts - for (Pair p : eqJoinConjuncts) { - p.first.getIds(null, ids); - 
p.second.getIds(null, ids); + for (Expr eqJoinPredicate : eqJoinConjuncts) { + eqJoinPredicate.getIds(null, ids); } for (Expr e : otherJoinConjuncts) { e.getIds(null, ids); @@ -258,9 +258,11 @@ public class HashJoinNode extends PlanNode { msg.node_type = TPlanNodeType.HASH_JOIN_NODE; msg.hash_join_node = new THashJoinNode(); msg.hash_join_node.join_op = joinOp.toThrift(); - for (Pair entry : eqJoinConjuncts) { + for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) { TEqJoinCondition eqJoinCondition = - new TEqJoinCondition(entry.first.treeToThrift(), entry.second.treeToThrift()); + new TEqJoinCondition(eqJoinPredicate.getChild(0).treeToThrift(), + eqJoinPredicate.getChild(1).treeToThrift()); + eqJoinCondition.setOpcode(eqJoinPredicate.getOp().getOpcode()); msg.hash_join_node.addToEq_join_conjuncts(eqJoinCondition); } for (Expr e : otherJoinConjuncts) { @@ -280,9 +282,8 @@ public class HashJoinNode extends PlanNode { output.append(detailPrefix + "colocate: " + isColocate + (isColocate? 
"" : ", reason: " + colocateReason) + "\n"); - for (Pair entry : eqJoinConjuncts) { - output.append(detailPrefix + " " + - entry.first.toSql() + " = " + entry.second.toSql() + "\n"); + for (BinaryPredicate eqJoinPredicate : eqJoinConjuncts) { + output.append(eqJoinPredicate.toSql() + "\n"); } if (!otherJoinConjuncts.isEmpty()) { output.append(detailPrefix + "other join predicates: ").append( diff --git a/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java b/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java index 390ac122e6..591c7ec9e7 100644 --- a/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java +++ b/fe/src/main/java/org/apache/doris/planner/SingleNodePlanner.java @@ -56,7 +56,6 @@ import org.apache.doris.catalog.Table; import org.apache.doris.catalog.FunctionSet; import org.apache.doris.common.AnalysisException; import org.apache.doris.catalog.AggregateFunction; -import org.apache.doris.common.Pair; import org.apache.doris.common.Reference; import org.apache.doris.common.UserException; import org.apache.logging.log4j.LogManager; @@ -1286,14 +1285,12 @@ public class SingleNodePlanner { /** * Return join conjuncts that can be used for hash table lookups. - for inner joins, those are equi-join predicates * in which one side is fully bound by lhsIds and the other by rhs' id; - for outer joins: same type of conjuncts as - * inner joins, but only from the JOIN clause Returns the conjuncts in 'joinConjuncts' (in which " = " is - * returned as Pair(, )) and also in their original form in 'joinPredicates'. + * inner joins, but only from the JOIN clause. Returns the conjuncts in their original form in 'joinConjuncts'. 
*/ private void getHashLookupJoinConjuncts(Analyzer analyzer, PlanNode left, PlanNode right, - List> joinConjuncts, List joinPredicates, + List joinConjuncts, Reference errMsg, JoinOperator op) { joinConjuncts.clear(); - joinPredicates.clear(); final List lhsIds = left.getTblRefIds(); final List rhsIds = right.getTblRefIds(); List candidates; @@ -1333,9 +1330,7 @@ public class SingleNodePlanner { } Preconditions.checkState(lhsExpr != rhsExpr); - joinPredicates.add(e); - Pair entry = Pair.create(lhsExpr, rhsExpr); - joinConjuncts.add(entry); + joinConjuncts.add(e); } } @@ -1351,14 +1346,13 @@ public class SingleNodePlanner { // materialized by that node PlanNode inner = createTableRefNode(analyzer, innerRef); - List> eqJoinConjuncts = Lists.newArrayList(); - List eqJoinPredicates = Lists.newArrayList(); + List eqJoinConjuncts = Lists.newArrayList(); Reference errMsg = new Reference(); // get eq join predicates for the TableRefs' ids (not the PlanNodes' ids, which // are materialized) - getHashLookupJoinConjuncts(analyzer, outer, inner, eqJoinConjuncts, - eqJoinPredicates, errMsg, innerRef.getJoinOp()); - if (eqJoinPredicates.isEmpty()) { + getHashLookupJoinConjuncts(analyzer, outer, inner, + eqJoinConjuncts, errMsg, innerRef.getJoinOp()); + if (eqJoinConjuncts.isEmpty()) { // only inner join can change to cross join if (innerRef.getJoinOp().isOuterJoin() || innerRef.getJoinOp().isSemiAntiJoin()) { @@ -1375,7 +1369,7 @@ public class SingleNodePlanner { result.init(analyzer); return result; } - analyzer.markConjunctsAssigned(eqJoinPredicates); + analyzer.markConjunctsAssigned(eqJoinConjuncts); List ojConjuncts = Lists.newArrayList(); if (innerRef.getJoinOp().isOuterJoin()) { diff --git a/gensrc/thrift/PlanNodes.thrift b/gensrc/thrift/PlanNodes.thrift index 73a639d21f..0809408b2d 100644 --- a/gensrc/thrift/PlanNodes.thrift +++ b/gensrc/thrift/PlanNodes.thrift @@ -20,6 +20,7 @@ namespace java org.apache.doris.thrift include "Exprs.thrift" include "Types.thrift" +include 
"Opcodes.thrift" include "Partitions.thrift" enum TPlanNodeType { @@ -260,6 +261,8 @@ struct TEqJoinCondition { 1: required Exprs.TExpr left; // right-hand side of " = " 2: required Exprs.TExpr right; + // operator of equal join + 3: optional Opcodes.TExprOpcode opcode; } enum TJoinOp {