diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
index 44def4bcaa..261d4385ce 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
@@ -294,6 +294,13 @@ public class JoinEstimation {
      */
     public static Statistics estimate(Statistics leftStats, Statistics rightStats, Join join) {
         JoinType joinType = join.getJoinType();
+        // Combined left/right column statistics with a cross-product row count; the
+        // outer-join branches below return this with their own row count applied.
+        Statistics crossJoinStats = new StatisticsBuilder()
+                .setRowCount(Math.max(1, leftStats.getRowCount()) * Math.max(1, rightStats.getRowCount()))
+                .putColumnStatistics(leftStats.columnStatistics())
+                .putColumnStatistics(rightStats.columnStatistics())
+                .build();
         if (joinType.isSemiOrAntiJoin()) {
             return estimateSemiOrAnti(leftStats, rightStats, join);
         } else if (joinType == JoinType.INNER_JOIN) {
@@ -304,15 +311,15 @@ public class JoinEstimation {
             Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
             double rowCount = Math.max(leftStats.getRowCount(), innerJoinStats.getRowCount());
             rowCount = Math.max(leftStats.getRowCount(), rowCount);
-            return innerJoinStats.withRowCountAndEnforceValid(rowCount);
+            return crossJoinStats.withRowCountAndEnforceValid(rowCount);
         } else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
             Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
             double rowCount = Math.max(rightStats.getRowCount(), innerJoinStats.getRowCount());
             rowCount = Math.max(rowCount, rightStats.getRowCount());
-            return innerJoinStats.withRowCountAndEnforceValid(rowCount);
+            return crossJoinStats.withRowCountAndEnforceValid(rowCount);
         } else if (joinType == JoinType.FULL_OUTER_JOIN) {
             Statistics innerJoinStats = estimateInnerJoin(leftStats, rightStats, join);
-            return innerJoinStats.withRowCountAndEnforceValid(leftStats.getRowCount()
+            return crossJoinStats.withRowCountAndEnforceValid(leftStats.getRowCount()
                     + rightStats.getRowCount() + innerJoinStats.getRowCount());
         } else if (joinType == JoinType.CROSS_JOIN) {
             return new StatisticsBuilder()
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/JoinEstimateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/JoinEstimateTest.java
index f49bf94e2d..2735e26da4 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/JoinEstimateTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/JoinEstimateTest.java
@@ -91,4 +91,63 @@ public class JoinEstimateTest {
         Assertions.assertNotNull(outAStats);
         Assertions.assertEquals(5, outBStats.ndv);
     }
+
+    /**
+     * Estimates a LEFT OUTER JOIN on {@code a = b} and checks that the output carries
+     * column statistics for {@code a}, {@code b} and {@code c}, each with the ndv of
+     * its input side (10, 0 and 20).
+     */
+    @Test
+    public void testOuterJoinStats() {
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        SlotReference b = new SlotReference("b", IntegerType.INSTANCE);
+        SlotReference c = new SlotReference("c", IntegerType.INSTANCE);
+        EqualTo eq = new EqualTo(a, b);
+        Statistics leftStats = new StatisticsBuilder().setRowCount(100).build();
+        leftStats.addColumnStats(a,
+                new ColumnStatisticBuilder()
+                        .setCount(100)
+                        .setNdv(10)
+                        .build()
+        );
+        Statistics rightStats = new StatisticsBuilder().setRowCount(80).build();
+        rightStats.addColumnStats(b,
+                new ColumnStatisticBuilder()
+                        .setCount(80)
+                        .setNdv(0)
+                        .build()
+        ).addColumnStats(c,
+                new ColumnStatisticBuilder()
+                        .setCount(80)
+                        .setNdv(20)
+                        .build()
+        );
+        IdGenerator idGenerator = GroupId.createGenerator();
+        GroupPlan left = new GroupPlan(new Group(idGenerator.getNextId(), new LogicalProperties(
+                new Supplier<List<Slot>>() {
+                    @Override
+                    public List<Slot> get() {
+                        return Lists.newArrayList(a);
+                    }
+                })));
+        GroupPlan right = new GroupPlan(new Group(idGenerator.getNextId(), new LogicalProperties(
+                new Supplier<List<Slot>>() {
+                    @Override
+                    public List<Slot> get() {
+                        return Lists.newArrayList(b, c);
+                    }
+                })));
+        LogicalJoin join = new LogicalJoin(JoinType.LEFT_OUTER_JOIN, Lists.newArrayList(eq),
+                left, right);
+        Statistics outputStats = JoinEstimation.estimate(leftStats, rightStats, join);
+        ColumnStatistic outAStats = outputStats.findColumnStatistics(a);
+        Assertions.assertNotNull(outAStats);
+        Assertions.assertEquals(10, outAStats.ndv);
+        ColumnStatistic outBStats = outputStats.findColumnStatistics(b);
+        Assertions.assertNotNull(outBStats);
+        Assertions.assertEquals(0, outBStats.ndv);
+        ColumnStatistic outCStats = outputStats.findColumnStatistics(c);
+        Assertions.assertNotNull(outCStats);
+        Assertions.assertEquals(20.0, outCStats.ndv);
+    }
 }