[feature](nereids) support distinct count (#12159)

support distinct count with group by clause.
for example:
SELECT count(distinct c_custkey + 1) FROM customer group by c_nation;

TODO: support distinct count without group by clause.
This commit is contained in:
yinzhijian
2022-09-15 13:01:47 +08:00
committed by GitHub
parent b11791b9a8
commit 5b6d48ed5b
21 changed files with 421 additions and 114 deletions

View File

@ -256,6 +256,8 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
Count count = (Count) function;
if (count.isStar()) {
return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam());
} else if (count.isDistinct()) {
return new FunctionCallExpr(function.getName(), new FunctionParams(true, paramList));
}
}
return new FunctionCallExpr(function.getName(), paramList);

View File

@ -191,12 +191,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
// 3. generate output tuple
List<Slot> slotList = Lists.newArrayList();
TupleDescriptor outputTupleDesc;
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
if (aggregate.getAggPhase() == AggPhase.LOCAL) {
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
} else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase())
|| aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
slotList.addAll(groupSlotList);
slotList.addAll(aggFunctionOutput);
outputTupleDesc = generateTupleDesc(slotList, null, context);
} else {
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
// In the distinct agg scenario, global shares local's desc
AggregationNode localAggNode = (AggregationNode) inputPlanFragment.getPlanRoot().getChild(0);
outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
}
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
@ -204,6 +209,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
execAggregateFunction.setMergeForNereids(true);
}
}
if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
if (!execAggregateFunction.isDistinct()) {
execAggregateFunction.setMergeForNereids(true);
}
}
}
AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, execAggregateFunctions, outputTupleDesc,
outputTupleDesc, aggregate.getAggPhase().toExec());
AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(),
@ -216,6 +228,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
aggregationNode.setIntermediateTuple();
break;
case GLOBAL:
case DISTINCT_LOCAL:
if (currentFragment.getPlanRoot() instanceof ExchangeNode) {
ExchangeNode exchangeNode = (ExchangeNode) currentFragment.getPlanRoot();
currentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition);

View File

@ -80,12 +80,12 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
case LOCAL:
return new PhysicalProperties(childOutputProperty.getDistributionSpec());
case GLOBAL:
case DISTINCT_LOCAL:
List<ExprId> columns = agg.getPartitionExpressions().stream()
.map(SlotReference.class::cast)
.map(SlotReference::getExprId)
.collect(Collectors.toList());
return PhysicalProperties.createHash(new DistributionSpecHash(columns, ShuffleType.AGGREGATE));
case DISTINCT_LOCAL:
case DISTINCT_GLOBAL:
default:
throw new RuntimeException("Could not derive output properties for agg phase: " + agg.getAggPhase());

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
@ -82,14 +83,16 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
addToRequestPropertyToChildren(PhysicalProperties.ANY);
return null;
}
if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) {
addToRequestPropertyToChildren(requestPropertyFromParent);
return null;
}
// 2. second phase agg, need to return shuffle with partition key
List<Expression> partitionExpressions = agg.getPartitionExpressions();
if (partitionExpressions.isEmpty()) {
addToRequestPropertyToChildren(PhysicalProperties.GATHER);
return null;
}
// TODO: when parent is a join node,
// use requestPropertyFromParent to keep column order as join to avoid shuffle again.
if (partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) {

View File

@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
@ -115,6 +116,17 @@ public class BindFunction implements AnalysisRuleFactory {
@Override
public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) {
// FunctionRegistry can't support boolean arg now, tricky here.
if (unboundFunction.getName().equalsIgnoreCase("count")) {
List<Expression> arguments = unboundFunction.getArguments();
if ((arguments.size() == 0 && unboundFunction.isStar()) || arguments.stream()
.allMatch(Expression::isConstant)) {
return new Count();
}
if (arguments.size() == 1) {
return new Count(unboundFunction.getArguments().get(0), unboundFunction.isDistinct());
}
}
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundFunction.getName();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(

View File

@ -126,7 +126,7 @@ public class ExpressionRewrite implements RewriteRuleFactory {
return agg;
}
return new LogicalAggregate<>(newGroupByExprs, newOutputExpressions,
agg.isDisassembled(), agg.isNormalized(), agg.getAggPhase(), agg.child());
agg.isDisassembled(), agg.isNormalized(), agg.isFinalPhase(), agg.getAggPhase(), agg.child());
}).toRule(RuleType.REWRITE_AGG_EXPRESSION);
}
}

View File

@ -36,6 +36,7 @@ public class LogicalAggToPhysicalHashAgg extends OneImplementationRuleFactory {
ImmutableList.of(),
agg.getAggPhase(),
false,
agg.isFinalPhase(),
agg.getLogicalProperties(),
agg.child())
).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE);

View File

@ -49,92 +49,187 @@ import java.util.stream.Collectors;
* +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
* +-- childPlan
*
* Distinct Agg With Group By Processing:
* If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(COUNT(distinct v1 * v2) + 1) #2]
* , groupByExpr: [k + 1])
* +-- childPlan
* we should rewrite to:
* Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1, Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b])
* +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
* +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as a], groupByExpr: [k + 1, a])
* +-- childPlan
*
* TODO:
* 1. use different class represent different phase aggregate
* 2. if instance count is 1, shouldn't disassemble the agg plan
*/
public class AggregateDisassemble extends OneRewriteRuleFactory {
// used in secondDisassemble to transform local expressions into global
private final Map<Expression, Expression> globalOutputSubstitutionMap = Maps.newHashMap();
// used in secondDisassemble to transform local expressions into global
private final Map<Expression, Expression> globalGroupBySubstitutionMap = Maps.newHashMap();
// used to indicate the existence of a distinct function for the entire phase
private boolean hasDistinctAgg = false;
@Override
public Rule build() {
return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> {
LogicalAggregate<GroupPlan> aggregate = ctx.root;
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions();
List<Expression> originGroupByExprs = aggregate.getGroupByExpressions();
// 1. generate a map from local aggregate output to global aggregate expr substitution.
// inputSubstitutionMap use for replacing expression in global aggregate
// replace rule is:
// a: Expression is a group by key and is a slot reference. e.g. group by k1
// b. Expression is a group by key and is an expression. e.g. group by k1 + 1
// c. Expression is an aggregate function. e.g. sum(v1) in select list
// +-----------+---------------------+-------------------------+--------------------------------+
// | situation | origin expression | local output expression | expression in global aggregate |
// +-----------+---------------------+-------------------------+--------------------------------+
// | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 |
// +-----------+---------------------+-------------------------+--------------------------------+
// | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 |
// +-----------+---------------------+-------------------------+--------------------------------+
// | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) |
// +-----------+---------------------+-------------------------+--------------------------------+
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
// 2. collect local aggregate output expressions and local aggregate group by expression list
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
List<NamedExpression> localOutputExprs = Lists.newArrayList();
for (Expression originGroupByExpr : originGroupByExprs) {
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
continue;
}
if (originGroupByExpr instanceof SlotReference) {
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
localOutputExprs.add((SlotReference) originGroupByExpr);
} else {
NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql());
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
localOutputExprs.add(localOutputExpr);
}
LogicalAggregate firstAggregate = firstDisassemble(aggregate);
if (!hasDistinctAgg) {
return firstAggregate;
}
for (NamedExpression originOutputExpr : originOutputExprs) {
Set<AggregateFunction> aggregateFunctions
= originOutputExpr.collect(AggregateFunction.class::isInstance);
for (AggregateFunction aggregateFunction : aggregateFunctions) {
if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
NamedExpression localOutputExpr = new Alias(aggregateFunction, aggregateFunction.toSql());
Expression substitutionValue = aggregateFunction.withChildren(
Lists.newArrayList(localOutputExpr.toSlot()));
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
localOutputExprs.add(localOutputExpr);
}
}
// 3. replace expression in globalOutputExprs and globalGroupByExprs
List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressions().stream()
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
List<Expression> globalGroupByExprs = localGroupByExprs.stream()
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList());
// 4. generate new plan
LogicalAggregate localAggregate = new LogicalAggregate<>(
localGroupByExprs,
localOutputExprs,
true,
aggregate.isNormalized(),
AggPhase.LOCAL,
aggregate.child()
);
return new LogicalAggregate<>(
globalGroupByExprs,
globalOutputExprs,
true,
aggregate.isNormalized(),
AggPhase.GLOBAL,
localAggregate
);
return secondDisassemble(firstAggregate);
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
}
// only support distinct function with group by
// TODO: support distinct function without group by. (add second global phase)
private LogicalAggregate secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) {
LogicalAggregate<GroupPlan> local = aggregate.child();
// replace expression in globalOutputExprs and globalGroupByExprs
List<NamedExpression> globalOutputExprs = local.getOutputExpressions().stream()
.map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
List<Expression> globalGroupByExprs = local.getGroupByExpressions().stream()
.map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap))
.collect(Collectors.toList());
// generate new plan
LogicalAggregate globalAggregate = new LogicalAggregate<>(
globalGroupByExprs,
globalOutputExprs,
true,
aggregate.isNormalized(),
false,
AggPhase.GLOBAL,
local
);
return new LogicalAggregate<>(
aggregate.getGroupByExpressions(),
aggregate.getOutputExpressions(),
true,
aggregate.isNormalized(),
true,
AggPhase.DISTINCT_LOCAL,
globalAggregate
);
}
private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> aggregate) {
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions();
List<Expression> originGroupByExprs = aggregate.getGroupByExpressions();
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
// 1. generate a map from local aggregate output to global aggregate expr substitution.
// inputSubstitutionMap use for replacing expression in global aggregate
// replace rule is:
// a: Expression is a group by key and is a slot reference. e.g. group by k1
// b. Expression is a group by key and is an expression. e.g. group by k1 + 1
// c. Expression is an aggregate function. e.g. sum(v1) in select list
// +-----------+---------------------+-------------------------+--------------------------------+
// | situation | origin expression | local output expression | expression in global aggregate |
// +-----------+---------------------+-------------------------+--------------------------------+
// | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 |
// +-----------+---------------------+-------------------------+--------------------------------+
// | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 |
// +-----------+---------------------+-------------------------+--------------------------------+
// | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) |
// +-----------+---------------------+-------------------------+--------------------------------+
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
// 2. collect local aggregate output expressions and local aggregate group by expression list
List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
List<NamedExpression> localOutputExprs = Lists.newArrayList();
for (Expression originGroupByExpr : originGroupByExprs) {
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
continue;
}
if (originGroupByExpr instanceof SlotReference) {
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr);
localOutputExprs.add((SlotReference) originGroupByExpr);
} else {
NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql());
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot());
globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
localOutputExprs.add(localOutputExpr);
}
}
List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
List<NamedExpression> distinctExprsForLocalOutput = Lists.newArrayList();
for (NamedExpression originOutputExpr : originOutputExprs) {
Set<AggregateFunction> aggregateFunctions
= originOutputExpr.collect(AggregateFunction.class::isInstance);
for (AggregateFunction aggregateFunction : aggregateFunctions) {
if (inputSubstitutionMap.containsKey(aggregateFunction)) {
continue;
}
if (aggregateFunction.isDistinct()) {
hasDistinctAgg = true;
for (Expression expr : aggregateFunction.children()) {
if (expr instanceof SlotReference) {
distinctExprsForLocalOutput.add((SlotReference) expr);
if (!inputSubstitutionMap.containsKey(expr)) {
inputSubstitutionMap.put(expr, expr);
globalOutputSubstitutionMap.put(expr, expr);
globalGroupBySubstitutionMap.put(expr, expr);
}
} else {
NamedExpression globalOutputExpr = new Alias(expr, expr.toSql());
distinctExprsForLocalOutput.add(globalOutputExpr);
if (!inputSubstitutionMap.containsKey(expr)) {
inputSubstitutionMap.put(expr, globalOutputExpr.toSlot());
globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot());
globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot());
}
}
distinctExprsForLocalGroupBy.add(expr);
}
continue;
}
NamedExpression localOutputExpr = new Alias(aggregateFunction, aggregateFunction.toSql());
Expression substitutionValue = aggregateFunction.withChildren(
Lists.newArrayList(localOutputExpr.toSlot()));
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue);
localOutputExprs.add(localOutputExpr);
}
}
// 3. replace expression in globalOutputExprs and globalGroupByExprs
List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressions().stream()
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
.map(NamedExpression.class::cast)
.collect(Collectors.toList());
List<Expression> globalGroupByExprs = localGroupByExprs.stream()
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList());
// To avoid repeated substitution of distinct expressions,
// here the expressions are put into the local after the substitution is completed
localOutputExprs.addAll(distinctExprsForLocalOutput);
localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
// 4. generate new plan
LogicalAggregate localAggregate = new LogicalAggregate<>(
localGroupByExprs,
localOutputExprs,
true,
aggregate.isNormalized(),
false,
AggPhase.LOCAL,
aggregate.child()
);
return new LogicalAggregate<>(
globalGroupByExprs,
globalOutputExprs,
true,
aggregate.isNormalized(),
true,
AggPhase.GLOBAL,
localAggregate
);
}
}

View File

@ -124,7 +124,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory {
root = new LogicalProject<>(bottomProjections, root);
}
root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(),
true, aggregate.getAggPhase(), root);
true, aggregate.isFinalPhase(), aggregate.getAggPhase(), root);
List<NamedExpression> projections = outputs.stream()
.map(e -> ExpressionUtils.replace(e, substitutionMap))
.map(NamedExpression.class::cast)

View File

@ -21,19 +21,50 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import java.util.Objects;
/**
* The function which consume arguments in lots of rows and product one value.
*/
public abstract class AggregateFunction extends BoundFunction {
private DataType intermediate;
private final boolean isDistinct;
public AggregateFunction(String name, Expression... arguments) {
super(name, arguments);
isDistinct = false;
}
public AggregateFunction(String name, boolean isDistinct, Expression... arguments) {
super(name, arguments);
this.isDistinct = isDistinct;
}
public abstract DataType getIntermediateType();
public boolean isDistinct() {
return isDistinct;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AggregateFunction that = (AggregateFunction) o;
return Objects.equals(isDistinct, that.isDistinct) && Objects.equals(intermediate, that.intermediate)
&& Objects.equals(getName(), that.getName()) && Objects.equals(children, that.children);
}
@Override
public int hashCode() {
return Objects.hash(isDistinct, intermediate, getName(), children);
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitAggregateFunction(this, context);

View File

@ -37,8 +37,8 @@ public class Count extends AggregateFunction {
this.isStar = true;
}
public Count(Expression child) {
super("count", child);
public Count(Expression child, boolean isDistinct) {
super("count", isDistinct, child);
this.isStar = false;
}
@ -62,7 +62,7 @@ public class Count extends AggregateFunction {
if (children.size() == 0) {
return new Count();
}
return new Count(children.get(0));
return new Count(children.get(0), isDistinct());
}
@Override
@ -79,6 +79,9 @@ public class Count extends AggregateFunction {
.stream()
.map(Expression::toSql)
.collect(Collectors.joining(", "));
if (isDistinct()) {
return "count(distinct " + args + ")";
}
return "count(" + args + ")";
}
@ -91,6 +94,9 @@ public class Count extends AggregateFunction {
.stream()
.map(Expression::toString)
.collect(Collectors.joining(", "));
if (isDistinct()) {
return "count(distinct " + args + ")";
}
return "count(" + args + ")";
}
}

View File

@ -59,6 +59,13 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
private final List<NamedExpression> outputExpressions;
private final AggPhase aggPhase;
// use for scenes containing distinct agg
// 1. If there are LOCAL and GLOBAL phases, global is the final phase
// 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
// 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
// DISTINCT_GLOBAL is the final phase
private final boolean isFinalPhase;
/**
* Desc: Constructor for LogicalAggregate.
*/
@ -66,7 +73,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
List<Expression> groupByExpressions,
List<NamedExpression> outputExpressions,
CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, false, false, AggPhase.GLOBAL, child);
this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child);
}
public LogicalAggregate(
@ -74,9 +81,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
List<NamedExpression> outputExpressions,
boolean disassembled,
boolean normalized,
boolean isFinalPhase,
AggPhase aggPhase,
CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, disassembled, normalized,
this(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase,
aggPhase, Optional.empty(), Optional.empty(), child);
}
@ -88,6 +96,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
List<NamedExpression> outputExpressions,
boolean disassembled,
boolean normalized,
boolean isFinalPhase,
AggPhase aggPhase,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties,
@ -97,6 +106,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
this.outputExpressions = outputExpressions;
this.disassembled = disassembled;
this.normalized = normalized;
this.isFinalPhase = isFinalPhase;
this.aggPhase = aggPhase;
}
@ -149,6 +159,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
return normalized;
}
public boolean isFinalPhase() {
return isFinalPhase;
}
/**
* Determine the equality with another plan
*/
@ -164,37 +178,37 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
&& Objects.equals(outputExpressions, that.outputExpressions)
&& aggPhase == that.aggPhase
&& disassembled == that.disassembled
&& normalized == that.normalized;
&& normalized == that.normalized
&& isFinalPhase == that.isFinalPhase;
}
@Override
public int hashCode() {
return Objects.hash(groupByExpressions, outputExpressions, aggPhase, normalized, disassembled);
return Objects.hash(groupByExpressions, outputExpressions, aggPhase, normalized, disassembled, isFinalPhase);
}
@Override
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, aggPhase, children.get(0));
disassembled, normalized, isFinalPhase, aggPhase, children.get(0));
}
@Override
public LogicalAggregate<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, aggPhase, groupExpression, Optional.of(getLogicalProperties()),
children.get(0));
return new LogicalAggregate<>(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase,
aggPhase, groupExpression, Optional.of(getLogicalProperties()), children.get(0));
}
@Override
public LogicalAggregate<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
disassembled, normalized, aggPhase, Optional.empty(), logicalProperties, children.get(0));
return new LogicalAggregate<>(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase,
aggPhase, Optional.empty(), logicalProperties, children.get(0));
}
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList,
List<NamedExpression> outputExpressionList) {
return new LogicalAggregate<>(groupByExprList, outputExpressionList,
disassembled, normalized, aggPhase, child());
disassembled, normalized, isFinalPhase, aggPhase, child());
}
}

View File

@ -53,11 +53,18 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
private final boolean usingStream;
// use for scenes containing distinct agg
// 1. If there are LOCAL and GLOBAL phases, global is the final phase
// 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
// 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
// DISTINCT_GLOBAL is the final phase
private final boolean isFinalPhase;
public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream,
LogicalProperties logicalProperties, CHILD_TYPE child) {
boolean isFinalPhase, LogicalProperties logicalProperties, CHILD_TYPE child) {
this(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream,
Optional.empty(), logicalProperties, child);
isFinalPhase, Optional.empty(), logicalProperties, child);
}
/**
@ -69,7 +76,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
* @param usingStream whether it's stream agg.
*/
public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream,
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase,
Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties,
CHILD_TYPE child) {
super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, child);
@ -78,6 +85,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
this.aggPhase = aggPhase;
this.partitionExpressions = partitionExpressions;
this.usingStream = usingStream;
this.isFinalPhase = isFinalPhase;
}
/**
@ -89,7 +97,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
* @param usingStream whether it's stream agg.
*/
public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream,
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase,
Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties,
PhysicalProperties physicalProperties, CHILD_TYPE child) {
super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, physicalProperties, child);
@ -98,6 +106,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
this.aggPhase = aggPhase;
this.partitionExpressions = partitionExpressions;
this.usingStream = usingStream;
this.isFinalPhase = isFinalPhase;
}
public AggPhase getAggPhase() {
@ -112,6 +121,10 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
return outputExpressions;
}
public boolean isFinalPhase() {
return isFinalPhase;
}
public boolean isUsingStream() {
return usingStream;
}
@ -156,36 +169,38 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
&& Objects.equals(outputExpressions, that.outputExpressions)
&& Objects.equals(partitionExpressions, that.partitionExpressions)
&& usingStream == that.usingStream
&& aggPhase == that.aggPhase;
&& aggPhase == that.aggPhase
&& isFinalPhase == that.isFinalPhase;
}
@Override
public int hashCode() {
return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream);
return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream,
isFinalPhase);
}
@Override
public PhysicalAggregate<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
usingStream, getLogicalProperties(), children.get(0));
usingStream, isFinalPhase, getLogicalProperties(), children.get(0));
}
@Override
public PhysicalAggregate<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) {
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
usingStream, groupExpression, getLogicalProperties(), child());
usingStream, isFinalPhase, groupExpression, getLogicalProperties(), child());
}
@Override
public PhysicalAggregate<CHILD_TYPE> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
usingStream, Optional.empty(), logicalProperties.get(), child());
usingStream, isFinalPhase, Optional.empty(), logicalProperties.get(), child());
}
@Override
public PhysicalAggregate<CHILD_TYPE> withPhysicalProperties(PhysicalProperties physicalProperties) {
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
usingStream, Optional.empty(), getLogicalProperties(), physicalProperties, child());
usingStream, isFinalPhase, Optional.empty(), getLogicalProperties(), physicalProperties, child());
}
}

View File

@ -360,9 +360,9 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
Alias v1 = new Alias(new ExprId(0), new Count(a2), "v1");
Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1");
PlanChecker.from(connectContext).analyze(sql)
.matchesFromRoot(
logicalProject(

View File

@ -263,6 +263,7 @@ public class ChildOutputPropertyDeriverTest {
Lists.newArrayList(key),
AggPhase.LOCAL,
true,
true,
logicalProperties,
groupPlan
);
@ -286,6 +287,7 @@ public class ChildOutputPropertyDeriverTest {
Lists.newArrayList(partition),
AggPhase.GLOBAL,
true,
true,
logicalProperties,
groupPlan
);

View File

@ -146,6 +146,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(key),
AggPhase.LOCAL,
true,
true,
logicalProperties,
groupPlan
);
@ -168,6 +169,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(partition),
AggPhase.GLOBAL,
true,
true,
logicalProperties,
groupPlan
);
@ -192,6 +194,7 @@ public class RequestPropertyDeriverTest {
Lists.newArrayList(),
AggPhase.GLOBAL,
true,
true,
logicalProperties,
groupPlan
);

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.AggPhase;
@ -269,6 +270,86 @@ public class AggregateDisassembleTest {
global.getOutputExpressions().get(0).getExprId());
}
/**
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2) as c], groupByExpr: [id + 3])
* +-- childPlan(id, name, age)
* we should rewrite to:
* Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) + 2) as c], groupByExpr: [a])
* +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a, b])
* +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age + 1) as b], groupByExpr: [id + 3, age + 1])
* +-- childPlan(id, name, age)
*/
@Test
public void distinctAggregateWithGroupBy() {
List<Expression> groupExpressionList = Lists.newArrayList(
new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3)));
List<NamedExpression> outputExpressionList = Lists.newArrayList(new Alias(
new Add(new Count(new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true),
new IntegerLiteral(2)), "c"));
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
Plan after = rewrite(root);
Assertions.assertTrue(after instanceof LogicalUnary);
Assertions.assertTrue(after instanceof LogicalAggregate);
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0).child(0);
Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, distinctLocal.getAggPhase());
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
// check local:
// id + 3
Expression localOutput0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
// age + 1
Expression localOutput1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
// id + 3
Expression localGroupBy0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
// age + 1
Expression localGroupBy1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
Assertions.assertEquals(2, local.getOutputExpressions().size());
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias);
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0));
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
Assertions.assertEquals(2, local.getGroupByExpressions().size());
Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0));
Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1));
// check global:
Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
Expression globalOutput1 = local.getOutputExpressions().get(1).toSlot();
Expression globalGroupBy0 = local.getOutputExpressions().get(0).toSlot();
Expression globalGroupBy1 = local.getOutputExpressions().get(1).toSlot();
Assertions.assertEquals(2, global.getOutputExpressions().size());
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference);
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0));
Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof SlotReference);
Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1));
Assertions.assertEquals(2, global.getGroupByExpressions().size());
Assertions.assertEquals(globalGroupBy0, global.getGroupByExpressions().get(0));
Assertions.assertEquals(globalGroupBy1, global.getGroupByExpressions().get(1));
// check distinct local:
Expression distinctLocalOutput = new Add(new Count(local.getOutputExpressions().get(1).toSlot(), true),
new IntegerLiteral(2));
Expression distinctLocalGroupBy = local.getOutputExpressions().get(0).toSlot();
Assertions.assertEquals(1, distinctLocal.getOutputExpressions().size());
Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) instanceof Alias);
Assertions.assertEquals(distinctLocalOutput, distinctLocal.getOutputExpressions().get(0).child(0));
Assertions.assertEquals(1, distinctLocal.getGroupByExpressions().size());
Assertions.assertEquals(distinctLocalGroupBy, distinctLocal.getGroupByExpressions().get(0));
// check id:
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
distinctLocal.getOutputExpressions().get(0).getExprId());
}
private Plan rewrite(Plan input) {
return PlanRewriter.topDownRewrite(input, new ConnectContext(), new AggregateDisassemble());
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.analyzer.UnboundAlias;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.analyzer.UnboundStar;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.types.IntegerType;
@ -168,6 +169,25 @@ public class ExpressionEqualsTest {
Assertions.assertEquals(sum1.hashCode(), sum2.hashCode());
}
@Test
public void testAggregateFunction() {
Count count1 = new Count();
Count count2 = new Count();
Assertions.assertEquals(count1, count2);
Assertions.assertEquals(count1.hashCode(), count2.hashCode());
Count count3 = new Count(child1, true);
Count count4 = new Count(child2, true);
Assertions.assertEquals(count3, count4);
Assertions.assertEquals(count3.hashCode(), count4.hashCode());
// bad case
Count count5 = new Count(child1, true);
Count count6 = new Count(child2, false);
Assertions.assertNotEquals(count5, count6);
Assertions.assertNotEquals(count5.hashCode(), count6.hashCode());
}
@Test
public void testNamedExpression() {
ExprId aliasId = new ExprId(2);

View File

@ -71,17 +71,17 @@ public class PlanEqualsTest {
unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
true, false, AggPhase.GLOBAL, child);
true, false, true, AggPhase.GLOBAL, child);
Assertions.assertNotEquals(unexpected, actual);
unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
false, true, AggPhase.GLOBAL, child);
false, true, true, AggPhase.GLOBAL, child);
Assertions.assertNotEquals(unexpected, actual);
unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
false, false, AggPhase.LOCAL, child);
false, false, true, AggPhase.LOCAL, child);
Assertions.assertNotEquals(unexpected, actual);
}
@ -178,20 +178,20 @@ public class PlanEqualsTest {
List<NamedExpression> outputExpressionList = ImmutableList.of(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()));
PhysicalAggregate<Plan> actual = new PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList,
Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, child);
Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child);
List<NamedExpression> outputExpressionList1 = ImmutableList.of(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()));
PhysicalAggregate<Plan> expected = new PhysicalAggregate<>(Lists.newArrayList(),
outputExpressionList1,
Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, child);
Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child);
Assertions.assertEquals(expected, actual);
List<NamedExpression> outputExpressionList2 = ImmutableList.of(
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()));
PhysicalAggregate<Plan> unexpected = new PhysicalAggregate<>(Lists.newArrayList(),
outputExpressionList2,
Lists.newArrayList(), AggPhase.LOCAL, false, logicalProperties, child);
Lists.newArrayList(), AggPhase.LOCAL, false, true, logicalProperties, child);
Assertions.assertNotEquals(unexpected, actual);
}