[enhancement](Nereids) change aggregate and join stats calc algorithm (#12447)

The original statistics derivation algorithm relies on NDV and other column statistics, but these statistics are not available in production environments.
This PR changes the stats calculation algorithms of these operators to use a DEFAULT RATIO constant instead of column statistics.
We should revisit these algorithms once column stats become available in production environments.
This commit is contained in:
morrySnow
2022-09-09 01:00:07 +08:00
committed by GitHub
parent b4f0f39e77
commit d2a23a4cf9
5 changed files with 142 additions and 119 deletions

View File

@ -19,21 +19,19 @@ package org.apache.doris.nereids.stats;
import org.apache.doris.common.CheckedMath;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Estimate hash join stats.
@ -41,29 +39,46 @@ import java.util.stream.Collectors;
*/
public class JoinEstimation {
private static final double DEFAULT_JOIN_RATIO = 10.0;
/**
* Do estimate.
* // TODO: no column stats are available here, so a fixed ratio is used to compute the row count.
*/
public static StatsDeriveResult estimate(StatsDeriveResult leftStats, StatsDeriveResult rightStats, Join join) {
JoinType joinType = join.getJoinType();
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats);
statsDeriveResult.merge(rightStats);
// TODO: normalize join hashConjuncts.
List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
List<Expression> normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast)
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet()))
.collect(Collectors.toList());
long rowCount = -1;
if (joinType.isSemiOrAntiJoin()) {
rowCount = getSemiJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
} else if (joinType.isInnerJoin() || joinType.isOuterJoin()) {
rowCount = getJoinRowCount(leftStats, rightStats, normalizedConjuncts, joinType);
} else if (joinType.isCrossJoin()) {
// List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts();
// List<Expression> normalizedConjuncts = hashJoinConjuncts.stream().map(EqualTo.class::cast)
// .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, leftStats.getSlotToColumnStats().keySet()))
// .collect(Collectors.toList());
long rowCount;
if (joinType == JoinType.LEFT_SEMI_JOIN || joinType == JoinType.LEFT_ANTI_JOIN) {
rowCount = Math.round(leftStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1;
} else if (joinType == JoinType.RIGHT_SEMI_JOIN || joinType == JoinType.RIGHT_ANTI_JOIN) {
rowCount = Math.round(rightStats.getRowCount() / DEFAULT_JOIN_RATIO) + 1;
} else if (joinType == JoinType.INNER_JOIN) {
long childRowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
rowCount = Math.round(childRowCount / DEFAULT_JOIN_RATIO) + 1;
} else if (joinType == JoinType.LEFT_OUTER_JOIN) {
rowCount = leftStats.getRowCount();
} else if (joinType == JoinType.RIGHT_OUTER_JOIN) {
rowCount = rightStats.getRowCount();
} else if (joinType == JoinType.CROSS_JOIN) {
rowCount = CheckedMath.checkedMultiply(leftStats.getRowCount(),
rightStats.getRowCount());
} else {
throw new RuntimeException("joinType is not supported");
}
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(rowCount, Maps.newHashMap());
if (joinType.isRemainLeftJoin()) {
statsDeriveResult.merge(leftStats);
}
if (joinType.isRemainRightJoin()) {
statsDeriveResult.merge(rightStats);
}
statsDeriveResult.setRowCount(rowCount);
return statsDeriveResult;
}
@ -78,7 +93,7 @@ public class JoinEstimation {
// TODO: If the condition of a join plan can be an expression type other than EqualTo,
// we should handle that properly.
private static long getSemiJoinRowCount(StatsDeriveResult leftStats, StatsDeriveResult rightStats,
List<Expression> eqConjunctList, JoinType joinType) {
List<Expression> hashConjuncts, JoinType joinType) {
long rowCount;
if (JoinType.RIGHT_SEMI_JOIN.equals(joinType) || JoinType.RIGHT_ANTI_JOIN.equals(joinType)) {
if (rightStats.getRowCount() == -1) {
@ -94,10 +109,11 @@ public class JoinEstimation {
Map<Slot, ColumnStats> leftSlotToColStats = leftStats.getSlotToColumnStats();
Map<Slot, ColumnStats> rightSlotToColStats = rightStats.getSlotToColumnStats();
double minSelectivity = 1.0;
for (Expression eqJoinPredicate : eqConjunctList) {
long lhsNdv = leftSlotToColStats.get(removeCast(eqJoinPredicate.child(0))).getNdv();
for (Expression hashConjunct : hashConjuncts) {
// TODO: no column stats are available here, so a fixed ratio should be used to compute the row count.
long lhsNdv = leftSlotToColStats.get(removeCast(hashConjunct.child(0))).getNdv();
lhsNdv = Math.min(lhsNdv, leftStats.getRowCount());
long rhsNdv = rightSlotToColStats.get(removeCast(eqJoinPredicate.child(1))).getNdv();
long rhsNdv = rightSlotToColStats.get(removeCast(hashConjunct.child(1))).getNdv();
rhsNdv = Math.min(rhsNdv, rightStats.getRowCount());
// Skip conjuncts with unknown NDV on either side.
if (lhsNdv == -1 || rhsNdv == -1) {

View File

@ -23,7 +23,6 @@ import org.apache.doris.catalog.Partition;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
@ -83,6 +82,8 @@ import java.util.stream.Collectors;
*/
public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void> {
private static final int DEFAULT_AGGREGATE_RATIO = 1000;
private final GroupExpression groupExpression;
private StatsCalculator(GroupExpression groupExpression) {
@ -163,7 +164,7 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
@Override
public StatsDeriveResult visitLogicalAssertNumRows(
LogicalAssertNumRows<? extends Plan> assertNumRows, Void context) {
return groupExpression.getCopyOfChildStats(0);
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows());
}
@Override
@ -235,7 +236,13 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
@Override
public StatsDeriveResult visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows,
Void context) {
return groupExpression.getCopyOfChildStats(0);
return computeAssertNumRows(assertNumRows.getAssertNumRowsElement().getDesiredNumOfRows());
}
/**
 * Derives the stats of an assert-num-rows node from a copy of the stats of child 0.
 *
 * <p>NOTE(review): {@code desiredNumOfRows} is never read in this method; the value handed to
 * {@code updateRowCountByLimit} is the constant 1 — confirm that this is intended.
 *
 * @param desiredNumOfRows row count taken from the assertion element by the callers (unused here)
 * @return the copied child stats after {@code updateRowCountByLimit(1)} has been invoked on them
 */
private StatsDeriveResult computeAssertNumRows(long desiredNumOfRows) {
    final StatsDeriveResult childStatsCopy = groupExpression.getCopyOfChildStats(0);
    childStatsCopy.updateRowCountByLimit(1);
    return childStatsCopy;
}
private StatsDeriveResult computeFilter(Filter filter) {
@ -301,22 +308,21 @@ public class StatsCalculator extends DefaultPlanVisitor<StatsDeriveResult, Void>
}
private StatsDeriveResult computeAggregate(Aggregate aggregate) {
List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
// TODO: no column stats are available here, so a fixed ratio is used to compute the row count.
// List<Expression> groupByExpressions = aggregate.getGroupByExpressions();
StatsDeriveResult childStats = groupExpression.getCopyOfChildStats(0);
Map<Slot, ColumnStats> childSlotToColumnStats = childStats.getSlotToColumnStats();
long resultSetCount = 1;
for (Expression groupByExpression : groupByExpressions) {
Set<Slot> slots = groupByExpression.getInputSlots();
// TODO: Support more complex group expr.
// For example:
// select max(col1+col3) from t1 group by col1+col3;
if (slots.size() != 1) {
continue;
}
Slot slotReference = slots.iterator().next();
ColumnStats columnStats = childSlotToColumnStats.get(slotReference);
resultSetCount *= columnStats.getNdv();
// Map<Slot, ColumnStats> childSlotToColumnStats = childStats.getSlotToColumnStats();
// long resultSetCount = groupByExpressions.stream()
// .flatMap(expr -> expr.getInputSlots().stream())
// .filter(childSlotToColumnStats::containsKey)
// .map(childSlotToColumnStats::get)
// .map(ColumnStats::getNdv)
// .reduce(1L, (a, b) -> a * b);
long resultSetCount = childStats.getRowCount() / DEFAULT_AGGREGATE_RATIO;
if (resultSetCount <= 0) {
resultSetCount = 1L;
}
Map<Slot, ColumnStats> slotToColumnStats = Maps.newHashMap();
List<NamedExpression> outputExpressions = aggregate.getOutputExpressions();
// TODO: 1. Estimate the output unit size by the type of corresponding AggregateFunction

View File

@ -112,6 +112,14 @@ public enum JoinType {
return this == LEFT_OUTER_JOIN || this == RIGHT_OUTER_JOIN || this == FULL_OUTER_JOIN;
}
/**
 * Returns {@code true} for every join type except {@code RIGHT_SEMI_JOIN} and
 * {@code RIGHT_ANTI_JOIN}.
 */
public final boolean isRemainLeftJoin() {
    return !(this == RIGHT_SEMI_JOIN || this == RIGHT_ANTI_JOIN);
}
/**
 * Returns {@code true} for every join type except {@code LEFT_SEMI_JOIN} and
 * {@code LEFT_ANTI_JOIN}.
 */
public final boolean isRemainRightJoin() {
    return !(this == LEFT_SEMI_JOIN || this == LEFT_ANTI_JOIN);
}
/**
 * Reports whether {@code joinSwapMap} contains an entry keyed by this join type.
 */
public final boolean isSwapJoinType() {
    final boolean hasSwapEntry = joinSwapMap.containsKey(this);
    return hasSwapEntry;
}

View File

@ -74,7 +74,7 @@ public class DeriveStatsJobTest {
}
StatsDeriveResult statistics = cascadesContext.getMemo().getRoot().getStatistics();
Assertions.assertNotNull(statistics);
Assertions.assertEquals(10, statistics.getRowCount());
Assertions.assertEquals(1, statistics.getRowCount());
}
private LogicalOlapScan constructOlapSCan() {

View File

@ -22,21 +22,14 @@ import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
@ -51,14 +44,12 @@ import org.apache.doris.statistics.TableStats;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@ -74,43 +65,44 @@ public class StatsCalculatorTest {
@Mocked
StatisticsManager statisticsManager;
/**
 * Checks the row count that StatsCalculator.estimate stores on the owner group of a
 * LogicalAggregate grouping by slot c1 and outputting sum(c2) aliased as "a".
 * The child group is given stats with row count 20, NDV 10 for c1 and NDV 20 for c2;
 * the test asserts a resulting row count of 10.
 */
@Test
public void testAgg() {
// Two nullable integer slots sharing the qualifier ["test", "t"].
List<String> qualifier = new ArrayList<>();
qualifier.add("test");
qualifier.add("t");
SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
ColumnStats columnStats1 = new ColumnStats();
columnStats1.setNdv(10);
columnStats1.setNumNulls(5);
ColumnStats columnStats2 = new ColumnStats();
columnStats2.setNdv(20);
// NOTE(review): this sets numNulls on columnStats1 a second time (replacing the 5 set above);
// columnStats2 may have been intended — confirm.
columnStats1.setNumNulls(10);
Map<Slot, ColumnStats> slotColumnStatsMap = new HashMap<>();
slotColumnStatsMap.put(slot1, columnStats1);
slotColumnStatsMap.put(slot2, columnStats2);
// Group by c1 only.
List<Expression> groupByExprList = new ArrayList<>();
groupByExprList.add(slot1);
AggregateFunction sum = new Sum(slot2);
StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap);
Alias alias = new Alias(sum, "a");
// Child group whose logical properties are built from a supplier returning an empty slot list.
Group childGroup = new Group();
childGroup.setLogicalProperties(new LogicalProperties(new Supplier<List<Slot>>() {
@Override
public List<Slot> get() {
return Collections.emptyList();
}
}));
GroupPlan groupPlan = new GroupPlan(childGroup);
childGroup.setStatistics(childStats);
LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan);
GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup));
Group ownerGroup = new Group();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
// NOTE(review): the computed value is passed first and the literal 10 second, whereas JUnit 5's
// assertEquals is declared as (expected, actual) — confirm the intended argument order.
Assertions.assertEquals(groupExpression.getOwnerGroup().getStatistics().getRowCount(), 10);
}
// TODO: temporarily disable this test until column stats are available
// @Test
// public void testAgg() {
// List<String> qualifier = new ArrayList<>();
// qualifier.add("test");
// qualifier.add("t");
// SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
// SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
// ColumnStats columnStats1 = new ColumnStats();
// columnStats1.setNdv(10);
// columnStats1.setNumNulls(5);
// ColumnStats columnStats2 = new ColumnStats();
// columnStats2.setNdv(20);
// columnStats1.setNumNulls(10);
// Map<Slot, ColumnStats> slotColumnStatsMap = new HashMap<>();
// slotColumnStatsMap.put(slot1, columnStats1);
// slotColumnStatsMap.put(slot2, columnStats2);
// List<Expression> groupByExprList = new ArrayList<>();
// groupByExprList.add(slot1);
// AggregateFunction sum = new Sum(slot2);
// StatsDeriveResult childStats = new StatsDeriveResult(20, slotColumnStatsMap);
// Alias alias = new Alias(sum, "a");
// Group childGroup = new Group();
// childGroup.setLogicalProperties(new LogicalProperties(new Supplier<List<Slot>>() {
// @Override
// public List<Slot> get() {
// return Collections.emptyList();
// }
// }));
// GroupPlan groupPlan = new GroupPlan(childGroup);
// childGroup.setStatistics(childStats);
// LogicalAggregate logicalAggregate = new LogicalAggregate(groupByExprList, Arrays.asList(alias), groupPlan);
// GroupExpression groupExpression = new GroupExpression(logicalAggregate, Arrays.asList(childGroup));
// Group ownerGroup = new Group();
// groupExpression.setOwnerGroup(ownerGroup);
// StatsCalculator.estimate(groupExpression);
// Assertions.assertEquals(groupExpression.getOwnerGroup().getStatistics().getRowCount(), 10);
// }
@Test
public void testFilter() {
@ -164,42 +156,43 @@ public class StatsCalculatorTest {
ownerGroupOr.getStatistics().getRowCount(), 0.001);
}
/**
 * Checks the row counts returned by JoinEstimation.estimate for a LEFT_SEMI_JOIN and an
 * INNER_JOIN on c1 = c2, given left stats with 5000 rows (c1 NDV 10) and right stats with
 * 10000 rows (c2 NDV 20). The semi join is asserted to keep the left row count and the
 * inner join is asserted to produce 2500000 rows.
 */
@Test
public void testHashJoin() {
List<String> qualifier = ImmutableList.of("test", "t");
SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
ColumnStats columnStats1 = new ColumnStats();
columnStats1.setNdv(10);
columnStats1.setNumNulls(5);
ColumnStats columnStats2 = new ColumnStats();
columnStats2.setNdv(20);
// NOTE(review): this sets numNulls on columnStats1 a second time (replacing the 5 set above);
// columnStats2 may have been intended — confirm.
columnStats1.setNumNulls(10);
// Left side carries only c1's stats, right side only c2's.
Map<Slot, ColumnStats> slotColumnStatsMap1 = new HashMap<>();
slotColumnStatsMap1.put(slot1, columnStats1);
Map<Slot, ColumnStats> slotColumnStatsMap2 = new HashMap<>();
slotColumnStatsMap2.put(slot2, columnStats2);
final long leftRowCount = 5000;
StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1);
final long rightRowCount = 10000;
StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2);
EqualTo equalTo = new EqualTo(slot1, slot2);
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t", 0);
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(0, "t", 0);
// The same equality conjunct and the same two scans are used for both join types.
LogicalJoin<LogicalOlapScan, LogicalOlapScan> fakeSemiJoin = new LogicalJoin<>(
JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2);
LogicalJoin<LogicalOlapScan, LogicalOlapScan> fakeInnerJoin = new LogicalJoin<>(
JoinType.INNER_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2);
StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeSemiJoin);
Assertions.assertEquals(leftRowCount, semiJoinStats.getRowCount());
StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeInnerJoin);
Assertions.assertEquals(2500000, innerJoinStats.getRowCount());
}
// TODO: temporarily disable this test until column stats are available
// @Test
// public void testHashJoin() {
// List<String> qualifier = ImmutableList.of("test", "t");
// SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
// SlotReference slot2 = new SlotReference("c2", IntegerType.INSTANCE, true, qualifier);
// ColumnStats columnStats1 = new ColumnStats();
// columnStats1.setNdv(10);
// columnStats1.setNumNulls(5);
// ColumnStats columnStats2 = new ColumnStats();
// columnStats2.setNdv(20);
// columnStats1.setNumNulls(10);
// Map<Slot, ColumnStats> slotColumnStatsMap1 = new HashMap<>();
// slotColumnStatsMap1.put(slot1, columnStats1);
//
// Map<Slot, ColumnStats> slotColumnStatsMap2 = new HashMap<>();
// slotColumnStatsMap2.put(slot2, columnStats2);
//
// final long leftRowCount = 5000;
// StatsDeriveResult leftStats = new StatsDeriveResult(leftRowCount, slotColumnStatsMap1);
//
// final long rightRowCount = 10000;
// StatsDeriveResult rightStats = new StatsDeriveResult(rightRowCount, slotColumnStatsMap2);
//
// EqualTo equalTo = new EqualTo(slot1, slot2);
//
// LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t", 0);
// LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(0, "t", 0);
// LogicalJoin<LogicalOlapScan, LogicalOlapScan> fakeSemiJoin = new LogicalJoin<>(
// JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2);
// LogicalJoin<LogicalOlapScan, LogicalOlapScan> fakeInnerJoin = new LogicalJoin<>(
// JoinType.INNER_JOIN, Lists.newArrayList(equalTo), Optional.empty(), scan1, scan2);
// StatsDeriveResult semiJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeSemiJoin);
// Assertions.assertEquals(leftRowCount, semiJoinStats.getRowCount());
// StatsDeriveResult innerJoinStats = JoinEstimation.estimate(leftStats, rightStats, fakeInnerJoin);
// Assertions.assertEquals(2500000, innerJoinStats.getRowCount());
// }
@Test
public void testOlapScan() {