[Enhancement](plan) Optimize preagg for aggregate function (#28886)

This commit is contained in:
Xujian Duan
2024-01-23 11:17:53 +08:00
committed by yiguolei
parent d61974db14
commit 2499ca6d89
8 changed files with 294 additions and 11 deletions

View File

@ -2627,5 +2627,13 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
expr.replaceSlot(tuple);
}
}
public boolean isNullLiteral() {
return this instanceof NullLiteral;
}
public boolean isZeroLiteral() {
return this instanceof LiteralExpr && ((LiteralExpr) this).isZero();
}
}

View File

@ -34,8 +34,10 @@ import org.apache.logging.log4j.Logger;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr> {
@ -449,4 +451,31 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr
public boolean matchExprs(List<Expr> exprs, SelectStmt stmt, boolean ignoreAlias, TupleDescriptor tuple) {
return true;
}
/** whether is ZERO value **/
public boolean isZero() {
boolean isZero = false;
switch (type.getPrimitiveType()) {
case TINYINT:
case SMALLINT:
case INT:
case BIGINT:
case LARGEINT:
isZero = this.getLongValue() == 0;
break;
case FLOAT:
case DOUBLE:
isZero = this.getDoubleValue() == 0.0f;
break;
case DECIMALV2:
case DECIMAL32:
case DECIMAL64:
case DECIMAL128:
case DECIMAL256:
isZero = Objects.equals(((DecimalLiteral) this).getValue(), BigDecimal.ZERO);
break;
default:
}
return isZero;
}
}

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -40,6 +41,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
@ -53,8 +55,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmapWithCheck;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
@ -79,6 +83,7 @@ import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
@ -879,6 +884,9 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
if (slotOpt.isPresent() && context.keyNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) {
return PreAggStatus.on();
}
if (count.child(0).arity() != 0) {
return checkSubExpressions(count, null, context);
}
}
return PreAggStatus.off(String.format(
"Count distinct is only valid for key columns, but meet %s.", count.toSql()));
@ -963,11 +971,106 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
return PreAggStatus.off(String.format("Aggregate operator don't match, aggregate function: %s"
+ ", column aggregate type: %s", aggFunc.toSql(), aggType));
}
} else if (!aggFunc.child(0).children().isEmpty()) {
return checkSubExpressions(aggFunc, matchingAggType, ctx);
} else {
return PreAggStatus.off(String.format("Slot(%s) in %s is neither key column nor value column.",
childNameWithFuncName, aggFunc.toSql()));
}
}
// check sub expressions in AggregateFunction.
private PreAggStatus checkSubExpressions(AggregateFunction aggFunc, AggregateType matchingAggType,
CheckContext ctx) {
Expression child = aggFunc.child(0);
List<Expression> conditionExps = new ArrayList<>();
List<Expression> returnExps = new ArrayList<>();
// ignore cast
while (child instanceof Cast) {
if (!((Cast) child).getDataType().isNumericType()) {
return PreAggStatus.off(String.format("[%s] is not numeric CAST.", child.toSql()));
}
child = child.child(0);
}
// step 1: extract all condition exprs and return exprs
if (child instanceof If) {
conditionExps.add(child.child(0));
returnExps.add(child.child(1));
returnExps.add(child.child(2));
} else if (child instanceof CaseWhen) {
CaseWhen caseWhen = (CaseWhen) child;
// WHEN THEN
for (WhenClause whenClause : caseWhen.getWhenClauses()) {
conditionExps.add(whenClause.getOperand());
returnExps.add(whenClause.getResult());
}
// ELSE
returnExps.add(caseWhen.getDefaultValue().orElse(new NullLiteral()));
} else {
// currently, only IF and CASE WHEN are supported
returnExps.add(child);
}
// step 2: check condition expressions
for (Expression conditionExp : conditionExps) {
if (!containsAllColumn(conditionExp, ctx.keyNameToColumn.keySet())) {
return PreAggStatus.off(String.format("some columns in condition [%s] is not key.",
conditionExp.toSql()));
}
}
// step 3: check return expressions
// NOTE: now we just support SUM, MIN, MAX and COUNT DISTINCT
int returnExprValidateNum = 0;
for (Expression returnExp : returnExps) {
// ignore cast in return expr
while (returnExp instanceof Cast) {
returnExp = returnExp.child(0);
}
// now we only check simple return expressions
String exprName = returnExp.getExpressionName();
if (!returnExp.children().isEmpty()) {
return PreAggStatus.off(String.format("do not support compound expression [%s] in %s.",
returnExp.toSql(), matchingAggType));
}
if (ctx.keyNameToColumn.containsKey(exprName)) {
if (matchingAggType != AggregateType.MAX && matchingAggType != AggregateType.MIN
&& (aggFunc instanceof Count && !aggFunc.isDistinct())) {
return PreAggStatus.off("agg on key column should be MAX, MIN or COUNT DISTINCT.");
}
}
if (matchingAggType == AggregateType.SUM) {
if ((ctx.valueNameToColumn.containsKey(exprName)
&& ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)
|| returnExp.isZeroLiteral() || returnExp.isNullLiteral()) {
returnExprValidateNum++;
} else {
return PreAggStatus.off(String.format("SUM cant preagg for [%s].", aggFunc.toSql()));
}
} else if (matchingAggType == AggregateType.MAX || matchingAggType == AggregateType.MIN) {
if (ctx.keyNameToColumn.containsKey(exprName) || returnExp.isNullLiteral()
|| (ctx.valueNameToColumn.containsKey(exprName)
&& ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)) {
returnExprValidateNum++;
} else {
return PreAggStatus.off(String.format("MAX/MIN cant preagg for [%s].", aggFunc.toSql()));
}
} else if (aggFunc.getName().equalsIgnoreCase("COUNT") && aggFunc.isDistinct()) {
if (ctx.keyNameToColumn.containsKey(exprName)
|| returnExp.isZeroLiteral() || returnExp.isNullLiteral()) {
returnExprValidateNum++;
} else {
return PreAggStatus.off(String.format("COUNT DISTINCT cant preagg for [%s].", aggFunc.toSql()));
}
}
}
if (returnExprValidateNum == returnExps.size()) {
return PreAggStatus.on();
}
return PreAggStatus.off(String.format("cant preagg for [%s].", aggFunc.toSql()));
}
}
private static class CheckContext {

View File

@ -260,6 +260,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
}
}
public boolean isZeroLiteral() {
return this instanceof Literal && ((Literal) this).isZero();
}
public final Expression castTo(DataType targetType) throws AnalysisException {
return uncheckedCastTo(targetType);
}

View File

@ -366,4 +366,27 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
public boolean isStringLikeLiteral() {
return dataType.isStringLikeType();
}
/** whether is ZERO value **/
public boolean isZero() {
if (isNullLiteral()) {
return false;
}
if (dataType.isSmallIntType() || dataType.isTinyIntType() || dataType.isIntegerType()) {
return getValue().equals(0);
} else if (dataType.isBigIntType()) {
return getValue().equals(0L);
} else if (dataType.isLargeIntType()) {
return getValue().equals(BigInteger.ZERO);
} else if (dataType.isFloatType()) {
return getValue().equals(0.0f);
} else if (dataType.isDoubleType()) {
return getValue().equals(0.0);
} else if (dataType.isDecimalV2Type()) {
return getValue().equals(BigDecimal.ZERO);
} else if (dataType.isDecimalV3Type()) {
return getValue().equals(BigDecimal.ZERO);
}
return false;
}
}

View File

@ -28,6 +28,7 @@ import org.apache.doris.analysis.AssertNumRowsElement;
import org.apache.doris.analysis.BaseTableRef;
import org.apache.doris.analysis.BinaryPredicate;
import org.apache.doris.analysis.CaseExpr;
import org.apache.doris.analysis.CaseWhenClause;
import org.apache.doris.analysis.CastExpr;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.ExprSubstitutionMap;
@ -77,6 +78,7 @@ import org.apache.doris.thrift.TPushAggOp;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@ -653,19 +655,38 @@ public class SingleNodePlanner {
List<Column> conditionColumns = Lists.newArrayList();
if (!(aggExpr.getChild(0) instanceof SlotRef)) {
Expr child = aggExpr.getChild(0);
if ((child instanceof CastExpr) && (child.getChild(0) instanceof SlotRef)) {
if (child.getType().isNumericType()
&& child.getChild(0).getType().isNumericType()) {
returnColumns.add(((SlotRef) child.getChild(0)).getDesc().getColumn());
} else {
turnOffReason = "aggExpr.getChild(0)["
// ignore cast
boolean castReturnExprValidate = true;
while (child instanceof CastExpr) {
if (child.getChild(0) instanceof SlotRef) {
if (child.getType().isNumericType() && child.getChild(0).getType().isNumericType()) {
returnColumns.add(((SlotRef) child.getChild(0)).getDesc().getColumn());
} else {
turnOffReason = "aggExpr.getChild(0)["
+ aggExpr.getChild(0).toSql()
+ "] is not Numeric CastExpr";
aggExprValidate = false;
break;
castReturnExprValidate = false;
break;
}
}
} else if (aggExpr.getChild(0) instanceof CaseExpr) {
CaseExpr caseExpr = (CaseExpr) aggExpr.getChild(0);
child = child.getChild(0);
}
if (!castReturnExprValidate) {
aggExprValidate = false;
break;
}
// convert IF to CASE WHEN.
// For example:
// IF(a > 1, 1, 0) -> CASE WHEN a > 1 THEN 1 ELSE 0 END
if (child instanceof FunctionCallExpr && ((FunctionCallExpr) child)
.getFnName().getFunction().equalsIgnoreCase("IF")) {
Preconditions.checkArgument(child.getChildren().size() == 3);
CaseWhenClause caseWhenClause = new CaseWhenClause(child.getChild(0), child.getChild(1));
child = new CaseExpr(ImmutableList.of(caseWhenClause), child.getChild(2));
}
if (child instanceof CaseExpr) {
CaseExpr caseExpr = (CaseExpr) child;
List<Expr> conditionExprs = caseExpr.getConditionExprs();
for (Expr conditionExpr : conditionExprs) {
List<TupleId> conditionTupleIds = Lists.newArrayList();
@ -680,8 +701,14 @@ public class SingleNodePlanner {
boolean caseReturnExprValidate = true;
List<Expr> returnExprs = caseExpr.getReturnExprs();
for (Expr returnExpr : returnExprs) {
// ignore cast in return expr
while (returnExpr instanceof CastExpr) {
returnExpr = returnExpr.getChild(0);
}
if (returnExpr instanceof SlotRef) {
returnColumns.add(((SlotRef) returnExpr).getDesc().getColumn());
} else if (returnExpr.isNullLiteral() || returnExpr.isZeroLiteral()) {
// If then expr is NULL or Zero, open the preaggregation
} else {
turnOffReason = "aggExpr.getChild(0)[" + aggExpr.getChild(0).toSql()
+ "] is not SlotExpr";

View File

@ -1251,4 +1251,93 @@ class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implements MemoP
Assertions.assertEquals(secondTableIndexName, scan1.getSelectedIndexName());
});
}
@Test
public void testSubExpressionsInAggregation() throws Exception {
createTable("CREATE TABLE db1.`test_pre_agg_tbl` (\n"
+ " `k1` int,\n"
+ " `k2` int,\n"
+ " `k3` char,\n"
+ " `k4` int,\n"
+ " `k5` bigint,\n"
+ " `k6` bigint,\n"
+ " `v1` int SUM,\n"
+ " `v2` bigint SUM,\n"
+ " `v3` bigint MAX,\n"
+ " `v4` bigint MIN,\n"
+ " `v5` float SUM,\n"
+ " `v6` double SUM,\n"
+ " `v7` decimal SUM\n"
+ ") ENGINE=OLAP\n"
+ "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`)\n"
+ "COMMENT \"OLAP\"\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n"
+ "PROPERTIES (\n"
+ "\"replication_num\" = \"1\"\n"
+ ");");
addRollup("alter table db1.test_pre_agg_tbl add rollup test_rollup(k1, k2, k3, v1, v2, v3, v4, v5, v6, v7)");
String sql1 = "select sum(case when k1 > 0 then v1 when k1 = 0 then 0 when k1 < 0 then v2 else 0 end),"
+ "sum(case when k2 = 1 then 0 else v1 end),"
+ "sum(case when k2 = 1 then null else v2 end),"
+ "sum(case when k2 = 1 then null else v5 end),"
+ "sum(case when k2 = 1 then null else v6 end),"
+ "sum(case when k2 = 1 then null else v7 end)"
+ "from db1.test_pre_agg_tbl";
// legacy planner
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql1).contains(
"TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON"));
// nereids planner
PlanChecker.from(connectContext)
.analyze(sql1)
.rewrite()
.matches(logicalOlapScan().when(scan -> {
Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get());
Assertions.assertTrue(scan.getPreAggStatus().isOn());
return true;
}));
String sql2 = "select sum(case when k1 > 0 then v1 else 1 end) from db1.test_pre_agg_tbl";
// legacy planner
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql2).contains("PREAGGREGATION: OFF"));
// nereids planner
PlanChecker.from(connectContext)
.analyze(sql2)
.rewrite()
.matches(logicalOlapScan().when(scan -> {
Assertions.assertEquals("test_pre_agg_tbl", scan.getSelectedMaterializedIndexName().get());
Assertions.assertTrue(scan.getPreAggStatus().isOff());
return true;
}));
String sql3 = "select max(case when k1 > 0 then v3 else null end),min(case when k1 > 0 then null else v4 end)"
+ " from db1.test_pre_agg_tbl";
// legacy planner
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql3).contains(
"TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON"));
// nereids planner
PlanChecker.from(connectContext)
.analyze(sql3)
.rewrite()
.matches(logicalOlapScan().when(scan -> {
Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get());
Assertions.assertTrue(scan.getPreAggStatus().isOn());
return true;
}));
String sql4 = "select count(distinct case when k1 > 0 then k1 else null end), "
+ "count(distinct if(k2 < 0, null, k2)) from db1.test_pre_agg_tbl";
// legacy planner
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql4).contains(
"TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON"));
// nereids planner
PlanChecker.from(connectContext)
.analyze(sql4)
.rewrite()
.matches(logicalOlapScan().when(scan -> {
Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get());
Assertions.assertTrue(scan.getPreAggStatus().isOn());
return true;
}));
}
}

View File

@ -242,7 +242,7 @@ class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements M
.matches(logicalOlapScan().when(scan -> {
PreAggStatus preAgg = scan.getPreAggStatus();
Assertions.assertTrue(preAgg.isOff());
Assertions.assertEquals("Slot((v1 + 1)) in sum((v1 + 1)) is neither key column nor value column.",
Assertions.assertEquals("do not support compound expression [(v1 + 1)] in SUM.",
preAgg.getOffReason());
return true;
}));