[Enhancement](plan) Optimize preagg for aggregate function (#28886)
This commit is contained in:
@ -2627,5 +2627,13 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
expr.replaceSlot(tuple);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isNullLiteral() {
|
||||
return this instanceof NullLiteral;
|
||||
}
|
||||
|
||||
public boolean isZeroLiteral() {
|
||||
return this instanceof LiteralExpr && ((LiteralExpr) this).isZero();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -34,8 +34,10 @@ import org.apache.logging.log4j.Logger;
|
||||
import java.io.DataInput;
|
||||
import java.io.DataOutput;
|
||||
import java.io.IOException;
|
||||
import java.math.BigDecimal;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
|
||||
public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr> {
|
||||
@ -449,4 +451,31 @@ public abstract class LiteralExpr extends Expr implements Comparable<LiteralExpr
|
||||
public boolean matchExprs(List<Expr> exprs, SelectStmt stmt, boolean ignoreAlias, TupleDescriptor tuple) {
|
||||
return true;
|
||||
}
|
||||
|
||||
/** whether is ZERO value **/
|
||||
public boolean isZero() {
|
||||
boolean isZero = false;
|
||||
switch (type.getPrimitiveType()) {
|
||||
case TINYINT:
|
||||
case SMALLINT:
|
||||
case INT:
|
||||
case BIGINT:
|
||||
case LARGEINT:
|
||||
isZero = this.getLongValue() == 0;
|
||||
break;
|
||||
case FLOAT:
|
||||
case DOUBLE:
|
||||
isZero = this.getDoubleValue() == 0.0f;
|
||||
break;
|
||||
case DECIMALV2:
|
||||
case DECIMAL32:
|
||||
case DECIMAL64:
|
||||
case DECIMAL128:
|
||||
case DECIMAL256:
|
||||
isZero = Objects.equals(((DecimalLiteral) this).getValue(), BigDecimal.ZERO);
|
||||
break;
|
||||
default:
|
||||
}
|
||||
return isZero;
|
||||
}
|
||||
}
|
||||
|
||||
@ -32,6 +32,7 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
|
||||
import org.apache.doris.nereids.rules.rewrite.mv.AbstractSelectMaterializedIndexRule.SlotContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -40,6 +41,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
|
||||
@ -53,8 +55,10 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmapWithCheck;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
@ -79,6 +83,7 @@ import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.google.common.collect.Streams;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@ -879,6 +884,9 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
|
||||
if (slotOpt.isPresent() && context.keyNameToColumn.containsKey(normalizeName(slotOpt.get().toSql()))) {
|
||||
return PreAggStatus.on();
|
||||
}
|
||||
if (count.child(0).arity() != 0) {
|
||||
return checkSubExpressions(count, null, context);
|
||||
}
|
||||
}
|
||||
return PreAggStatus.off(String.format(
|
||||
"Count distinct is only valid for key columns, but meet %s.", count.toSql()));
|
||||
@ -963,11 +971,106 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
|
||||
return PreAggStatus.off(String.format("Aggregate operator don't match, aggregate function: %s"
|
||||
+ ", column aggregate type: %s", aggFunc.toSql(), aggType));
|
||||
}
|
||||
} else if (!aggFunc.child(0).children().isEmpty()) {
|
||||
return checkSubExpressions(aggFunc, matchingAggType, ctx);
|
||||
} else {
|
||||
return PreAggStatus.off(String.format("Slot(%s) in %s is neither key column nor value column.",
|
||||
childNameWithFuncName, aggFunc.toSql()));
|
||||
}
|
||||
}
|
||||
|
||||
// check sub expressions in AggregateFunction.
|
||||
private PreAggStatus checkSubExpressions(AggregateFunction aggFunc, AggregateType matchingAggType,
|
||||
CheckContext ctx) {
|
||||
Expression child = aggFunc.child(0);
|
||||
List<Expression> conditionExps = new ArrayList<>();
|
||||
List<Expression> returnExps = new ArrayList<>();
|
||||
|
||||
// ignore cast
|
||||
while (child instanceof Cast) {
|
||||
if (!((Cast) child).getDataType().isNumericType()) {
|
||||
return PreAggStatus.off(String.format("[%s] is not numeric CAST.", child.toSql()));
|
||||
}
|
||||
child = child.child(0);
|
||||
}
|
||||
// step 1: extract all condition exprs and return exprs
|
||||
if (child instanceof If) {
|
||||
conditionExps.add(child.child(0));
|
||||
returnExps.add(child.child(1));
|
||||
returnExps.add(child.child(2));
|
||||
} else if (child instanceof CaseWhen) {
|
||||
CaseWhen caseWhen = (CaseWhen) child;
|
||||
// WHEN THEN
|
||||
for (WhenClause whenClause : caseWhen.getWhenClauses()) {
|
||||
conditionExps.add(whenClause.getOperand());
|
||||
returnExps.add(whenClause.getResult());
|
||||
}
|
||||
// ELSE
|
||||
returnExps.add(caseWhen.getDefaultValue().orElse(new NullLiteral()));
|
||||
} else {
|
||||
// currently, only IF and CASE WHEN are supported
|
||||
returnExps.add(child);
|
||||
}
|
||||
|
||||
// step 2: check condition expressions
|
||||
for (Expression conditionExp : conditionExps) {
|
||||
if (!containsAllColumn(conditionExp, ctx.keyNameToColumn.keySet())) {
|
||||
return PreAggStatus.off(String.format("some columns in condition [%s] is not key.",
|
||||
conditionExp.toSql()));
|
||||
}
|
||||
}
|
||||
|
||||
// step 3: check return expressions
|
||||
// NOTE: now we just support SUM, MIN, MAX and COUNT DISTINCT
|
||||
int returnExprValidateNum = 0;
|
||||
for (Expression returnExp : returnExps) {
|
||||
// ignore cast in return expr
|
||||
while (returnExp instanceof Cast) {
|
||||
returnExp = returnExp.child(0);
|
||||
}
|
||||
// now we only check simple return expressions
|
||||
String exprName = returnExp.getExpressionName();
|
||||
if (!returnExp.children().isEmpty()) {
|
||||
return PreAggStatus.off(String.format("do not support compound expression [%s] in %s.",
|
||||
returnExp.toSql(), matchingAggType));
|
||||
}
|
||||
if (ctx.keyNameToColumn.containsKey(exprName)) {
|
||||
if (matchingAggType != AggregateType.MAX && matchingAggType != AggregateType.MIN
|
||||
&& (aggFunc instanceof Count && !aggFunc.isDistinct())) {
|
||||
return PreAggStatus.off("agg on key column should be MAX, MIN or COUNT DISTINCT.");
|
||||
}
|
||||
}
|
||||
|
||||
if (matchingAggType == AggregateType.SUM) {
|
||||
if ((ctx.valueNameToColumn.containsKey(exprName)
|
||||
&& ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)
|
||||
|| returnExp.isZeroLiteral() || returnExp.isNullLiteral()) {
|
||||
returnExprValidateNum++;
|
||||
} else {
|
||||
return PreAggStatus.off(String.format("SUM cant preagg for [%s].", aggFunc.toSql()));
|
||||
}
|
||||
} else if (matchingAggType == AggregateType.MAX || matchingAggType == AggregateType.MIN) {
|
||||
if (ctx.keyNameToColumn.containsKey(exprName) || returnExp.isNullLiteral()
|
||||
|| (ctx.valueNameToColumn.containsKey(exprName)
|
||||
&& ctx.valueNameToColumn.get(exprName).getAggregationType() == matchingAggType)) {
|
||||
returnExprValidateNum++;
|
||||
} else {
|
||||
return PreAggStatus.off(String.format("MAX/MIN cant preagg for [%s].", aggFunc.toSql()));
|
||||
}
|
||||
} else if (aggFunc.getName().equalsIgnoreCase("COUNT") && aggFunc.isDistinct()) {
|
||||
if (ctx.keyNameToColumn.containsKey(exprName)
|
||||
|| returnExp.isZeroLiteral() || returnExp.isNullLiteral()) {
|
||||
returnExprValidateNum++;
|
||||
} else {
|
||||
return PreAggStatus.off(String.format("COUNT DISTINCT cant preagg for [%s].", aggFunc.toSql()));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (returnExprValidateNum == returnExps.size()) {
|
||||
return PreAggStatus.on();
|
||||
}
|
||||
return PreAggStatus.off(String.format("cant preagg for [%s].", aggFunc.toSql()));
|
||||
}
|
||||
}
|
||||
|
||||
private static class CheckContext {
|
||||
|
||||
@ -260,6 +260,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
|
||||
}
|
||||
}
|
||||
|
||||
public boolean isZeroLiteral() {
|
||||
return this instanceof Literal && ((Literal) this).isZero();
|
||||
}
|
||||
|
||||
public final Expression castTo(DataType targetType) throws AnalysisException {
|
||||
return uncheckedCastTo(targetType);
|
||||
}
|
||||
|
||||
@ -366,4 +366,27 @@ public abstract class Literal extends Expression implements LeafExpression, Comp
|
||||
public boolean isStringLikeLiteral() {
|
||||
return dataType.isStringLikeType();
|
||||
}
|
||||
|
||||
/** whether is ZERO value **/
|
||||
public boolean isZero() {
|
||||
if (isNullLiteral()) {
|
||||
return false;
|
||||
}
|
||||
if (dataType.isSmallIntType() || dataType.isTinyIntType() || dataType.isIntegerType()) {
|
||||
return getValue().equals(0);
|
||||
} else if (dataType.isBigIntType()) {
|
||||
return getValue().equals(0L);
|
||||
} else if (dataType.isLargeIntType()) {
|
||||
return getValue().equals(BigInteger.ZERO);
|
||||
} else if (dataType.isFloatType()) {
|
||||
return getValue().equals(0.0f);
|
||||
} else if (dataType.isDoubleType()) {
|
||||
return getValue().equals(0.0);
|
||||
} else if (dataType.isDecimalV2Type()) {
|
||||
return getValue().equals(BigDecimal.ZERO);
|
||||
} else if (dataType.isDecimalV3Type()) {
|
||||
return getValue().equals(BigDecimal.ZERO);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@ import org.apache.doris.analysis.AssertNumRowsElement;
|
||||
import org.apache.doris.analysis.BaseTableRef;
|
||||
import org.apache.doris.analysis.BinaryPredicate;
|
||||
import org.apache.doris.analysis.CaseExpr;
|
||||
import org.apache.doris.analysis.CaseWhenClause;
|
||||
import org.apache.doris.analysis.CastExpr;
|
||||
import org.apache.doris.analysis.Expr;
|
||||
import org.apache.doris.analysis.ExprSubstitutionMap;
|
||||
@ -77,6 +78,7 @@ import org.apache.doris.thrift.TPushAggOp;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Predicate;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
@ -653,19 +655,38 @@ public class SingleNodePlanner {
|
||||
List<Column> conditionColumns = Lists.newArrayList();
|
||||
if (!(aggExpr.getChild(0) instanceof SlotRef)) {
|
||||
Expr child = aggExpr.getChild(0);
|
||||
if ((child instanceof CastExpr) && (child.getChild(0) instanceof SlotRef)) {
|
||||
if (child.getType().isNumericType()
|
||||
&& child.getChild(0).getType().isNumericType()) {
|
||||
returnColumns.add(((SlotRef) child.getChild(0)).getDesc().getColumn());
|
||||
} else {
|
||||
turnOffReason = "aggExpr.getChild(0)["
|
||||
|
||||
// ignore cast
|
||||
boolean castReturnExprValidate = true;
|
||||
while (child instanceof CastExpr) {
|
||||
if (child.getChild(0) instanceof SlotRef) {
|
||||
if (child.getType().isNumericType() && child.getChild(0).getType().isNumericType()) {
|
||||
returnColumns.add(((SlotRef) child.getChild(0)).getDesc().getColumn());
|
||||
} else {
|
||||
turnOffReason = "aggExpr.getChild(0)["
|
||||
+ aggExpr.getChild(0).toSql()
|
||||
+ "] is not Numeric CastExpr";
|
||||
aggExprValidate = false;
|
||||
break;
|
||||
castReturnExprValidate = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if (aggExpr.getChild(0) instanceof CaseExpr) {
|
||||
CaseExpr caseExpr = (CaseExpr) aggExpr.getChild(0);
|
||||
child = child.getChild(0);
|
||||
}
|
||||
if (!castReturnExprValidate) {
|
||||
aggExprValidate = false;
|
||||
break;
|
||||
}
|
||||
// convert IF to CASE WHEN.
|
||||
// For example:
|
||||
// IF(a > 1, 1, 0) -> CASE WHEN a > 1 THEN 1 ELSE 0 END
|
||||
if (child instanceof FunctionCallExpr && ((FunctionCallExpr) child)
|
||||
.getFnName().getFunction().equalsIgnoreCase("IF")) {
|
||||
Preconditions.checkArgument(child.getChildren().size() == 3);
|
||||
CaseWhenClause caseWhenClause = new CaseWhenClause(child.getChild(0), child.getChild(1));
|
||||
child = new CaseExpr(ImmutableList.of(caseWhenClause), child.getChild(2));
|
||||
}
|
||||
if (child instanceof CaseExpr) {
|
||||
CaseExpr caseExpr = (CaseExpr) child;
|
||||
List<Expr> conditionExprs = caseExpr.getConditionExprs();
|
||||
for (Expr conditionExpr : conditionExprs) {
|
||||
List<TupleId> conditionTupleIds = Lists.newArrayList();
|
||||
@ -680,8 +701,14 @@ public class SingleNodePlanner {
|
||||
boolean caseReturnExprValidate = true;
|
||||
List<Expr> returnExprs = caseExpr.getReturnExprs();
|
||||
for (Expr returnExpr : returnExprs) {
|
||||
// ignore cast in return expr
|
||||
while (returnExpr instanceof CastExpr) {
|
||||
returnExpr = returnExpr.getChild(0);
|
||||
}
|
||||
if (returnExpr instanceof SlotRef) {
|
||||
returnColumns.add(((SlotRef) returnExpr).getDesc().getColumn());
|
||||
} else if (returnExpr.isNullLiteral() || returnExpr.isZeroLiteral()) {
|
||||
// If then expr is NULL or Zero, open the preaggregation
|
||||
} else {
|
||||
turnOffReason = "aggExpr.getChild(0)[" + aggExpr.getChild(0).toSql()
|
||||
+ "] is not SlotExpr";
|
||||
|
||||
@ -1251,4 +1251,93 @@ class SelectMvIndexTest extends BaseMaterializedIndexSelectTest implements MemoP
|
||||
Assertions.assertEquals(secondTableIndexName, scan1.getSelectedIndexName());
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSubExpressionsInAggregation() throws Exception {
|
||||
createTable("CREATE TABLE db1.`test_pre_agg_tbl` (\n"
|
||||
+ " `k1` int,\n"
|
||||
+ " `k2` int,\n"
|
||||
+ " `k3` char,\n"
|
||||
+ " `k4` int,\n"
|
||||
+ " `k5` bigint,\n"
|
||||
+ " `k6` bigint,\n"
|
||||
+ " `v1` int SUM,\n"
|
||||
+ " `v2` bigint SUM,\n"
|
||||
+ " `v3` bigint MAX,\n"
|
||||
+ " `v4` bigint MIN,\n"
|
||||
+ " `v5` float SUM,\n"
|
||||
+ " `v6` double SUM,\n"
|
||||
+ " `v7` decimal SUM\n"
|
||||
+ ") ENGINE=OLAP\n"
|
||||
+ "AGGREGATE KEY(`k1`, `k2`, `k3`, `k4`, `k5`, `k6`)\n"
|
||||
+ "COMMENT \"OLAP\"\n"
|
||||
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 5\n"
|
||||
+ "PROPERTIES (\n"
|
||||
+ "\"replication_num\" = \"1\"\n"
|
||||
+ ");");
|
||||
addRollup("alter table db1.test_pre_agg_tbl add rollup test_rollup(k1, k2, k3, v1, v2, v3, v4, v5, v6, v7)");
|
||||
|
||||
String sql1 = "select sum(case when k1 > 0 then v1 when k1 = 0 then 0 when k1 < 0 then v2 else 0 end),"
|
||||
+ "sum(case when k2 = 1 then 0 else v1 end),"
|
||||
+ "sum(case when k2 = 1 then null else v2 end),"
|
||||
+ "sum(case when k2 = 1 then null else v5 end),"
|
||||
+ "sum(case when k2 = 1 then null else v6 end),"
|
||||
+ "sum(case when k2 = 1 then null else v7 end)"
|
||||
+ "from db1.test_pre_agg_tbl";
|
||||
// legacy planner
|
||||
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql1).contains(
|
||||
"TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON"));
|
||||
// nereids planner
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql1)
|
||||
.rewrite()
|
||||
.matches(logicalOlapScan().when(scan -> {
|
||||
Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get());
|
||||
Assertions.assertTrue(scan.getPreAggStatus().isOn());
|
||||
return true;
|
||||
}));
|
||||
|
||||
String sql2 = "select sum(case when k1 > 0 then v1 else 1 end) from db1.test_pre_agg_tbl";
|
||||
// legacy planner
|
||||
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql2).contains("PREAGGREGATION: OFF"));
|
||||
// nereids planner
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql2)
|
||||
.rewrite()
|
||||
.matches(logicalOlapScan().when(scan -> {
|
||||
Assertions.assertEquals("test_pre_agg_tbl", scan.getSelectedMaterializedIndexName().get());
|
||||
Assertions.assertTrue(scan.getPreAggStatus().isOff());
|
||||
return true;
|
||||
}));
|
||||
|
||||
String sql3 = "select max(case when k1 > 0 then v3 else null end),min(case when k1 > 0 then null else v4 end)"
|
||||
+ " from db1.test_pre_agg_tbl";
|
||||
// legacy planner
|
||||
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql3).contains(
|
||||
"TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON"));
|
||||
// nereids planner
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql3)
|
||||
.rewrite()
|
||||
.matches(logicalOlapScan().when(scan -> {
|
||||
Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get());
|
||||
Assertions.assertTrue(scan.getPreAggStatus().isOn());
|
||||
return true;
|
||||
}));
|
||||
|
||||
String sql4 = "select count(distinct case when k1 > 0 then k1 else null end), "
|
||||
+ "count(distinct if(k2 < 0, null, k2)) from db1.test_pre_agg_tbl";
|
||||
// legacy planner
|
||||
Assertions.assertTrue(getSQLPlanOrErrorMsg(sql4).contains(
|
||||
"TABLE: db1.test_pre_agg_tbl(test_rollup), PREAGGREGATION: ON"));
|
||||
// nereids planner
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql4)
|
||||
.rewrite()
|
||||
.matches(logicalOlapScan().when(scan -> {
|
||||
Assertions.assertEquals("test_rollup", scan.getSelectedMaterializedIndexName().get());
|
||||
Assertions.assertTrue(scan.getPreAggStatus().isOn());
|
||||
return true;
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@ -242,7 +242,7 @@ class SelectRollupIndexTest extends BaseMaterializedIndexSelectTest implements M
|
||||
.matches(logicalOlapScan().when(scan -> {
|
||||
PreAggStatus preAgg = scan.getPreAggStatus();
|
||||
Assertions.assertTrue(preAgg.isOff());
|
||||
Assertions.assertEquals("Slot((v1 + 1)) in sum((v1 + 1)) is neither key column nor value column.",
|
||||
Assertions.assertEquals("do not support compound expression [(v1 + 1)] in SUM.",
|
||||
preAgg.getOffReason());
|
||||
return true;
|
||||
}));
|
||||
|
||||
Reference in New Issue
Block a user