diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index 161194456e..9650526e39 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -19,21 +19,19 @@ package org.apache.doris.nereids.stats; import org.apache.doris.common.CheckedMath; import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.algebra.Join; -import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.statistics.ColumnStats; import org.apache.doris.statistics.StatsDeriveResult; import com.google.common.base.Preconditions; +import com.google.common.collect.Maps; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * Estimate hash join stats. @@ -41,29 +39,46 @@ import java.util.stream.Collectors; */ public class JoinEstimation { + private static final double DEFAULT_JOIN_RATIO = 10.0; + /** * Do estimate. + * // TODO: since we have no column stats here. just use a fix ratio to compute the row count. */ public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) { JoinType joinType = join.getJoinType(); - StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats); - statsDeriveResult.merge(rightStats); // TODO: normalize join hashConjuncts. - List hashJoinConjuncts = join.getHashJoinConjuncts(); - List normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast) - .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet())) - .collect(Collectors.toList()); - long rowCount = -1; - if (joinType.isSemiOrAntiJoin()) { - rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType); - } else if (joinType.isInnerJoin() || joinType.isOuterJoin()) { - rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType); - } else if (joinType.isCrossJoin()) { + // List hashJoinConjuncts = join.getHashJoinConjuncts(); + // List normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast) + // .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet())) + // .collect(Collectors.toList()); + + long rowCount; + if (joinType == JoinType.LEFT_SEMI_JOIN || joinType == JoinType.LEFT_ANTI_JOIN) { + rowCount = Math.round(leftStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1; + } else if (joinType == JoinType.RIGHT_SEMI_JOIN || joinType == JoinType.RIGHT_ANTI_JOIN) { + rowCount = Math.round(rightStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1; + } else if (joinType == JoinType.INNER_JOIN) { + long childRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount()); + rowCount = Math.round(childRowCount / DEFAULT_JOIN_RATIO) + 1; + } else if (joinType == JoinType.LEFT_OUTER_JOIN) { + rowCount = leftStats.getRowCount(); + } else if (joinType == JoinType.RIGHT_OUTER_JOIN) { + rowCount = rightStats.getRowCount(); + } else if (joinType == JoinType.CROSS_JOIN) { rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(), rightStats.getRowCount()); } else { throw new RuntimeException("joinType is not supported"); } + + StatsDeriveResult statsDeriveResult = new StatsDeriveResult(rowCount, Maps.newHashMap()); + if (joinType.isRemainLeftJoin()) { + statsDeriveResult.merge(leftStats); + } + if (joinType.isRemainRightJoin()) { + statsDeriveResult.merge(rightStats); + } statsDeriveResult.setRowCount(rowCount); return statsDeriveResult; } @@ -78,7 +93,7 @@ public class JoinEstimation { // TODO: If the condition of Join Plan could any expression in addition to EqualTo type, // we should handle that properly. private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDeriveResult rightStats, - List eqConjunctList, JoinType joinType) { + List hashConjuncts, JoinType joinType) { long rowCount; if (JoinType.RIGHT_SEMI_JOIN.equals(joinType) || JoinType.RIGHT_ANTI_JOIN.equals(joinType)) { if (rightStats.getRowCount() == -1) { @@ -94,10 +109,11 @@ public class JoinEstimation { Map leftSlotToColStats = leftStats.getSlotToColumnStats(); Map rightSlotToColStats = rightStats.getSlotToColumnStats(); double minSelectivity = 1.0; - for (Expression eqJoinPredicate : eqConjunctList) { - long lhsNdv = leftSlotToColStats.get(removeCast(eqJoinPredicate.child(0))).getNdv(); + for (Expression hashConjunct : hashConjuncts) { + // TODO: since we have no column stats here. just use a fix ratio to compute the row count. + long lhsNdv = leftSlotToColStats.get(removeCast(hashConjunct.child(0))).getNdv(); lhsNdv = Math.min(lhsNdv, leftStats.getRowCount()); - long rhsNdv = rightSlotToColStats.get(removeCast(eqJoinPredicate.child(1))).getNdv(); + long rhsNdv = rightSlotToColStats.get(removeCast(hashConjunct.child(1))).getNdv(); rhsNdv = Math.min(rhsNdv, rightStats.getRowCount()); // Skip conjuncts with unknown NDV on either side. if (lhsNdv == -1 || rhsNdv == -1) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 6f16d06f04..ba25d62ce9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -23,7 +23,6 @@ import org.apache.doris.catalog.Partition; import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Pair; import org.apache.doris.nereids.memo.GroupExpression; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; @@ -83,6 +82,8 @@ import java.util.stream.Collectors; */ public class StatsCalculator extends DefaultPlanVisitor { + private static final int DEFAULT_AGGREGATE_RATIO = 1000; + private final GroupExpression groupExpression; private StatsCalculator(GroupExpression groupExpression) { @@ -163,7 +164,7 @@ public class StatsCalculator extends DefaultPlanVisitor @Override public StatsDeriveResult visitLogicalAssertNumRows( LogicalAssertNumRows assertNumRows, Void context) { - return groupExpression.getCopyOfChildStats(0); + return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows()); } @Override @@ -235,7 +236,13 @@ public class StatsCalculator extends DefaultPlanVisitor @Override public StatsDeriveResult visitPhysicalAssertNumRows(PhysicalAssertNumRows assertNumRows, Void context) { - return groupExpression.getCopyOfChildStats(0); + return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows()); + } + + private StatsDeriveResult computeAssertNumRows(long desiredNumOfRows) { + StatsDeriveResult statsDeriveResult = groupExpression.getCopyOfChildStats(0); + statsDeriveResult.updateRowCountByLimit(1); + return statsDeriveResult; } private StatsDeriveResult computeFilter(Filter filter) { @@ -301,22 +308,21 @@ public class StatsCalculator extends DefaultPlanVisitor } private StatsDeriveResult computeAggregate(Aggregate aggregate) { - List groupByExpressions = aggregate.getGroupByExpressions(); + // TODO: since we have no column stats here. just use a fix ratio to compute the row count. + // List groupByExpressions = aggregate.getGroupByExpressions(); StatsDeriveResult childStats = groupExpression.getCopyOfChildStats(0); - Map childSlotToColumnStats = childStats.getSlotToColumnStats(); - long resultSetCount = 1; - for (Expression groupByExpression : groupByExpressions) { - Set slots = groupByExpression.getInputSlots(); - // TODO: Support more complex group expr. - // For example: - // select max(col1+col3) from t1 group by col1+col3; - if (slots.size() != 1) { - continue; - } - Slot slotReference = slots.iterator().next(); - ColumnStats columnStats = childSlotToColumnStats.get(slotReference); - resultSetCount *= columnStats.getNdv(); + // Map childSlotToColumnStats = childStats.getSlotToColumnStats(); + // long resultSetCount = groupByExpressions.stream() + // .flatMap(expr -> expr.getInputSlots().stream()) + // .filter(childSlotToColumnStats::containsKey) + // .map(childSlotToColumnStats::get) + // .map(ColumnStats::getNdv) + // .reduce(1L, (a, b) -> a * b); + long resultSetCount = childStats.getRowCount() / DEFAULT_AGGREGATE_RATIO; + if (resultSetCount <= 0) { + resultSetCount = 1L; } + Map slotToColumnStats = Maps.newHashMap(); List outputExpressions = aggregate.getOutputExpressions(); // TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java index 9badf47586..9b071b8986 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java @@ -112,6 +112,14 @@ public enum JoinType { return this == LEFT_OUTER_JOIN || this == RIGHT_OUTER_JOIN || this == FULL_OUTER_JOIN; } + public final boolean isRemainLeftJoin() { + return this != RIGHT_SEMI_JOIN && this != RIGHT_ANTI_JOIN; + } + + public final boolean isRemainRightJoin() { + return this != LEFT_SEMI_JOIN && this != LEFT_ANTI_JOIN; + } + public final boolean isSwapJoinType() { return joinSwapMap.containsKey(this); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java index 3b5172ed2c..3a6f9ab4af 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJobTest.java @@ -74,7 +74,7 @@ public class DeriveStatsJobTest { } StatsDeriveResult statistics = cascadesContext.getMemo().getRoot().getStatistics(); Assertions.assertNotNull(statistics); - Assertions.assertEquals(10, statistics.getRowCount()); + Assertions.assertEquals(1, statistics.getRowCount()); } private LogicalOlapScan constructOlapSCan() { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java index 0f47477f4c..15501dced2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java @@ -22,21 +22,14 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.Sum; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; @@ -51,14 +44,12 @@ import org.apache.doris.statistics.TableStats; import com.google.common.base.Supplier; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import mockit.Expectations; import mockit.Mocked; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -74,43 +65,44 @@ public class StatsCalculatorTest { @Mocked StatisticsManager statisticsManager; - @Test - public void testAgg() { - List qualifier = new ArrayList<>(); - qualifier.add("test"); - qualifier.add("t"); - SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); - SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); - ColumnStats columnStats1 = new ColumnStats(); - columnStats1.setNdv(10); - columnStats1.setNumNulls(5); - ColumnStats columnStats2 = new ColumnStats(); - columnStats2.setNdv(20); - columnStats1.setNumNulls(10); - Map slotColumnStatsMap = new HashMap<>(); - slotColumnStatsMap.put(slot1, columnStats1); - slotColumnStatsMap.put(slot2, columnStats2); - List groupByExprList = new ArrayList<>(); - groupByExprList.add(slot1); - AggregateFunction sum = new Sum(slot2); - StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap); - Alias alias = new Alias(sum, "a"); - Group childGroup = new Group(); - childGroup.setLogicalProperties(new LogicalProperties(new Supplier>() { - @Override - public List get() { - return Collections.emptyList(); - } - })); - GroupPlan groupPlan = new GroupPlan(childGroup); - childGroup.setStatistics(childStats); - LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan); - GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup)); - Group ownerGroup = new Group(); - groupExpression.setOwnerGroup(ownerGroup); - StatsCalculator.estimate(groupExpression); - Assertions.assertEquals(groupExpression.getOwnerGroup().getStatistics().getRowCount(), 10); - } + // TODO: temporary disable this test, until we could get column stats + // @Test + // public void testAgg() { + // List qualifier = new ArrayList<>(); + // qualifier.add("test"); + // qualifier.add("t"); + // SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); + // SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); + // ColumnStats columnStats1 = new ColumnStats(); + // columnStats1.setNdv(10); + // columnStats1.setNumNulls(5); + // ColumnStats columnStats2 = new ColumnStats(); + // columnStats2.setNdv(20); + // columnStats1.setNumNulls(10); + // Map slotColumnStatsMap = new HashMap<>(); + // slotColumnStatsMap.put(slot1, columnStats1); + // slotColumnStatsMap.put(slot2, columnStats2); + // List groupByExprList = new ArrayList<>(); + // groupByExprList.add(slot1); + // AggregateFunction sum = new Sum(slot2); + // StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap); + // Alias alias = new Alias(sum, "a"); + // Group childGroup = new Group(); + // childGroup.setLogicalProperties(new LogicalProperties(new Supplier>() { + // @Override + // public List get() { + // return Collections.emptyList(); + // } + // })); + // GroupPlan groupPlan = new GroupPlan(childGroup); + // childGroup.setStatistics(childStats); + // LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan); + // GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup)); + // Group ownerGroup = new Group(); + // groupExpression.setOwnerGroup(ownerGroup); + // StatsCalculator.estimate(groupExpression); + // Assertions.assertEquals(groupExpression.getOwnerGroup().getStatistics().getRowCount(), 10); + // } @Test public void testFilter() { @@ -164,42 +156,43 @@ public class StatsCalculatorTest { ownerGroupOr.getStatistics().getRowCount(), 0.001); } - @Test - public void testHashJoin() { - List qualifier = ImmutableList.of("test", "t"); - SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); - SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); - ColumnStats columnStats1 = new ColumnStats(); - columnStats1.setNdv(10); - columnStats1.setNumNulls(5); - ColumnStats columnStats2 = new ColumnStats(); - columnStats2.setNdv(20); - columnStats1.setNumNulls(10); - Map slotColumnStatsMap1 = new HashMap<>(); - slotColumnStatsMap1.put(slot1, columnStats1); - - Map slotColumnStatsMap2 = new HashMap<>(); - slotColumnStatsMap2.put(slot2, columnStats2); - - final long leftRowCount = 5000; - StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1); - - final long rightRowCount = 10000; - StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2); - - EqualTo equalTo = new EqualTo(slot1, slot2); - - LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t", 0); - LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(0, "t", 0); - LogicalJoin fakeSemiJoin = new LogicalJoin<>( - JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); - LogicalJoin fakeInnerJoin = new LogicalJoin<>( - JoinType.INNER_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); - StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeSemiJoin); - Assertions.assertEquals(leftRowCount, semiJoinStats.getRowCount()); - StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeInnerJoin); - Assertions.assertEquals(2500000, innerJoinStats.getRowCount()); - } + // TODO: temporary disable this test, until we could get column stats + // @Test + // public void testHashJoin() { + // List qualifier = ImmutableList.of("test", "t"); + // SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier); + // SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier); + // ColumnStats columnStats1 = new ColumnStats(); + // columnStats1.setNdv(10); + // columnStats1.setNumNulls(5); + // ColumnStats columnStats2 = new ColumnStats(); + // columnStats2.setNdv(20); + // columnStats1.setNumNulls(10); + // Map slotColumnStatsMap1 = new HashMap<>(); + // slotColumnStatsMap1.put(slot1, columnStats1); + // + // Map slotColumnStatsMap2 = new HashMap<>(); + // slotColumnStatsMap2.put(slot2, columnStats2); + // + // final long leftRowCount = 5000; + // StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1); + // + // final long rightRowCount = 10000; + // StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2); + // + // EqualTo equalTo = new EqualTo(slot1, slot2); + // + // LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t", 0); + // LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(0, "t", 0); + // LogicalJoin fakeSemiJoin = new LogicalJoin<>( + // JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); + // LogicalJoin fakeInnerJoin = new LogicalJoin<>( + // JoinType.INNER_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2); + // StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeSemiJoin); + // Assertions.assertEquals(leftRowCount, semiJoinStats.getRowCount()); + // StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeInnerJoin); + // Assertions.assertEquals(2500000, innerJoinStats.getRowCount()); + // } @Test public void testOlapScan() {