diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index f5fd89a6d5..8531321e3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.StatsDeriveResult; +import com.google.common.base.Preconditions; import com.google.common.collect.Maps; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -37,7 +38,21 @@ import org.apache.logging.log4j.Logger; public class JoinEstimation { private static final Logger LOG = LogManager.getLogger(JoinEstimation.class); - private static double estimateInnerJoin(Join join, EqualTo equalto, + private static double estimateInnerJoin(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) { + Preconditions.checkArgument(join.getJoinType() == JoinType.INNER_JOIN); + double rowCount = Double.MAX_VALUE; + if (join.getHashJoinConjuncts().isEmpty()) { + rowCount = leftStats.getRowCount() * rightStats.getRowCount(); + } else { + for (Expression equalTo : join.getHashJoinConjuncts()) { + double tmpRowCount = estimateEqualJoinCondition((EqualTo) equalTo, leftStats, rightStats); + rowCount = Math.min(rowCount, tmpRowCount); + } + } + return rowCount; + } + + private static double estimateEqualJoinCondition(EqualTo equalto, StatsDeriveResult leftStats, StatsDeriveResult rightStats) { SlotReference eqRight = (SlotReference) equalto.child(1).getInputSlots().toArray()[0]; @@ -72,6 +87,11 @@ public class JoinEstimation { return leftCount - leftCount / Math.max(2, rightCount); } + private static double estimateFullOuterJoin(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) { + //TODO: after we have 
histogram, re-design this logic + return leftStats.getRowCount() + rightStats.getRowCount(); + } + /** * estimate join */ @@ -95,15 +115,7 @@ public class JoinEstimation { rowCount = estimateLeftSemiJoin(rightCount, leftCount); } } else if (joinType == JoinType.INNER_JOIN) { - if (join.getHashJoinConjuncts().isEmpty()) { - rowCount = leftStats.getRowCount() * rightStats.getRowCount(); - } else { - for (Expression joinConjunct : join.getHashJoinConjuncts()) { - double tmpRowCount = estimateInnerJoin(join, - (EqualTo) joinConjunct, leftStats, rightStats); - rowCount = Math.min(rowCount, tmpRowCount); - } - } + rowCount = estimateInnerJoin(leftStats, rightStats, join); } else if (joinType == JoinType.LEFT_OUTER_JOIN) { rowCount = leftStats.getRowCount(); } else if (joinType == JoinType.RIGHT_OUTER_JOIN) { @@ -111,6 +123,8 @@ public class JoinEstimation { } else if (joinType == JoinType.CROSS_JOIN) { rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(), rightStats.getRowCount()); + } else if (joinType == JoinType.FULL_OUTER_JOIN) { + rowCount = estimateFullOuterJoin(leftStats, rightStats, join); } else { LOG.warn("join type is not supported: " + joinType); throw new RuntimeException("joinType is not supported");