[feature](nereids) support distinct count (#12159)
support distinct count with group by clause. for example: SELECT count(distinct c_custkey + 1) FROM customer group by c_nation; TODO: support distinct count without group by clause.
This commit is contained in:
@ -256,6 +256,8 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
|
||||
Count count = (Count) function;
|
||||
if (count.isStar()) {
|
||||
return new FunctionCallExpr(function.getName(), FunctionParams.createStarParam());
|
||||
} else if (count.isDistinct()) {
|
||||
return new FunctionCallExpr(function.getName(), new FunctionParams(true, paramList));
|
||||
}
|
||||
}
|
||||
return new FunctionCallExpr(function.getName(), paramList);
|
||||
|
||||
@ -191,12 +191,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
// 3. generate output tuple
|
||||
List<Slot> slotList = Lists.newArrayList();
|
||||
TupleDescriptor outputTupleDesc;
|
||||
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
|
||||
if (aggregate.getAggPhase() == AggPhase.LOCAL) {
|
||||
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
|
||||
} else if ((aggregate.getAggPhase() == AggPhase.GLOBAL && aggregate.isFinalPhase())
|
||||
|| aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
|
||||
slotList.addAll(groupSlotList);
|
||||
slotList.addAll(aggFunctionOutput);
|
||||
outputTupleDesc = generateTupleDesc(slotList, null, context);
|
||||
} else {
|
||||
outputTupleDesc = generateTupleDesc(aggregate.getOutput(), null, context);
|
||||
// In the distinct agg scenario, global shares local's desc
|
||||
AggregationNode localAggNode = (AggregationNode) inputPlanFragment.getPlanRoot().getChild(0);
|
||||
outputTupleDesc = localAggNode.getAggInfo().getOutputTupleDesc();
|
||||
}
|
||||
|
||||
if (aggregate.getAggPhase() == AggPhase.GLOBAL) {
|
||||
@ -204,6 +209,13 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
execAggregateFunction.setMergeForNereids(true);
|
||||
}
|
||||
}
|
||||
if (aggregate.getAggPhase() == AggPhase.DISTINCT_LOCAL) {
|
||||
for (FunctionCallExpr execAggregateFunction : execAggregateFunctions) {
|
||||
if (!execAggregateFunction.isDistinct()) {
|
||||
execAggregateFunction.setMergeForNereids(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
AggregateInfo aggInfo = AggregateInfo.create(execGroupingExpressions, execAggregateFunctions, outputTupleDesc,
|
||||
outputTupleDesc, aggregate.getAggPhase().toExec());
|
||||
AggregationNode aggregationNode = new AggregationNode(context.nextPlanNodeId(),
|
||||
@ -216,6 +228,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
|
||||
aggregationNode.setIntermediateTuple();
|
||||
break;
|
||||
case GLOBAL:
|
||||
case DISTINCT_LOCAL:
|
||||
if (currentFragment.getPlanRoot() instanceof ExchangeNode) {
|
||||
ExchangeNode exchangeNode = (ExchangeNode) currentFragment.getPlanRoot();
|
||||
currentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, mergePartition);
|
||||
|
||||
@ -80,12 +80,12 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
|
||||
case LOCAL:
|
||||
return new PhysicalProperties(childOutputProperty.getDistributionSpec());
|
||||
case GLOBAL:
|
||||
case DISTINCT_LOCAL:
|
||||
List<ExprId> columns = agg.getPartitionExpressions().stream()
|
||||
.map(SlotReference.class::cast)
|
||||
.map(SlotReference::getExprId)
|
||||
.collect(Collectors.toList());
|
||||
return PhysicalProperties.createHash(new DistributionSpecHash(columns, ShuffleType.AGGREGATE));
|
||||
case DISTINCT_LOCAL:
|
||||
case DISTINCT_GLOBAL:
|
||||
default:
|
||||
throw new RuntimeException("Could not derive output properties for agg phase: " + agg.getAggPhase());
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.AggPhase;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
|
||||
@ -82,14 +83,16 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
|
||||
addToRequestPropertyToChildren(PhysicalProperties.ANY);
|
||||
return null;
|
||||
}
|
||||
|
||||
if (agg.getAggPhase() == AggPhase.GLOBAL && !agg.isFinalPhase()) {
|
||||
addToRequestPropertyToChildren(requestPropertyFromParent);
|
||||
return null;
|
||||
}
|
||||
// 2. second phase agg, need to return shuffle with partition key
|
||||
List<Expression> partitionExpressions = agg.getPartitionExpressions();
|
||||
if (partitionExpressions.isEmpty()) {
|
||||
addToRequestPropertyToChildren(PhysicalProperties.GATHER);
|
||||
return null;
|
||||
}
|
||||
|
||||
// TODO: when parent is a join node,
|
||||
// use requestPropertyFromParent to keep column order as join to avoid shuffle again.
|
||||
if (partitionExpressions.stream().allMatch(SlotReference.class::isInstance)) {
|
||||
|
||||
@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
@ -115,6 +116,17 @@ public class BindFunction implements AnalysisRuleFactory {
|
||||
|
||||
@Override
|
||||
public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) {
|
||||
// FunctionRegistry can't support boolean arg now, tricky here.
|
||||
if (unboundFunction.getName().equalsIgnoreCase("count")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if ((arguments.size() == 0 && unboundFunction.isStar()) || arguments.stream()
|
||||
.allMatch(Expression::isConstant)) {
|
||||
return new Count();
|
||||
}
|
||||
if (arguments.size() == 1) {
|
||||
return new Count(unboundFunction.getArguments().get(0), unboundFunction.isDistinct());
|
||||
}
|
||||
}
|
||||
FunctionRegistry functionRegistry = env.getFunctionRegistry();
|
||||
String functionName = unboundFunction.getName();
|
||||
FunctionBuilder builder = functionRegistry.findFunctionBuilder(
|
||||
|
||||
@ -126,7 +126,7 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
return agg;
|
||||
}
|
||||
return new LogicalAggregate<>(newGroupByExprs, newOutputExpressions,
|
||||
agg.isDisassembled(), agg.isNormalized(), agg.getAggPhase(), agg.child());
|
||||
agg.isDisassembled(), agg.isNormalized(), agg.isFinalPhase(), agg.getAggPhase(), agg.child());
|
||||
}).toRule(RuleType.REWRITE_AGG_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,6 +36,7 @@ public class LogicalAggToPhysicalHashAgg extends OneImplementationRuleFactory {
|
||||
ImmutableList.of(),
|
||||
agg.getAggPhase(),
|
||||
false,
|
||||
agg.isFinalPhase(),
|
||||
agg.getLogicalProperties(),
|
||||
agg.child())
|
||||
).toRule(RuleType.LOGICAL_AGG_TO_PHYSICAL_HASH_AGG_RULE);
|
||||
|
||||
@ -49,92 +49,187 @@ import java.util.stream.Collectors;
|
||||
* +-- Aggregate(phase: [LOCAL], outputExpr: [SUM(v1 * v2) as a, (k + 1) as b], groupByExpr: [k + 1])
|
||||
* +-- childPlan
|
||||
*
|
||||
* Distinct Agg With Group By Processing:
|
||||
* If we have a query: SELECT count(distinct v1 * v2) + 1 FROM t GROUP BY k + 1
|
||||
* the initial plan is:
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [Alias(k + 1) #1, Alias(COUNT(distinct v1 * v2) + 1) #2]
|
||||
* , groupByExpr: [k + 1])
|
||||
* +-- childPlan
|
||||
* we should rewrite to:
|
||||
* Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [Alias(b) #1, Alias(COUNT(distinct a) + 1) #2], groupByExpr: [b])
|
||||
* +-- Aggregate(phase: [GLOBAL], outputExpr: [b, a], groupByExpr: [b, a])
|
||||
* +-- Aggregate(phase: [LOCAL], outputExpr: [(k + 1) as b, (v1 * v2) as a], groupByExpr: [k + 1, a])
|
||||
* +-- childPlan
|
||||
*
|
||||
* TODO:
|
||||
* 1. use different class represent different phase aggregate
|
||||
* 2. if instance count is 1, shouldn't disassemble the agg plan
|
||||
*/
|
||||
public class AggregateDisassemble extends OneRewriteRuleFactory {
|
||||
// used in secondDisassemble to transform local expressions into global
|
||||
private final Map<Expression, Expression> globalOutputSubstitutionMap = Maps.newHashMap();
|
||||
// used in secondDisassemble to transform local expressions into global
|
||||
private final Map<Expression, Expression> globalGroupBySubstitutionMap = Maps.newHashMap();
|
||||
// used to indicate the existence of a distinct function for the entire phase
|
||||
private boolean hasDistinctAgg = false;
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate().when(agg -> !agg.isDisassembled()).thenApply(ctx -> {
|
||||
LogicalAggregate<GroupPlan> aggregate = ctx.root;
|
||||
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions();
|
||||
List<Expression> originGroupByExprs = aggregate.getGroupByExpressions();
|
||||
|
||||
// 1. generate a map from local aggregate output to global aggregate expr substitution.
|
||||
// inputSubstitutionMap use for replacing expression in global aggregate
|
||||
// replace rule is:
|
||||
// a: Expression is a group by key and is a slot reference. e.g. group by k1
|
||||
// b. Expression is a group by key and is an expression. e.g. group by k1 + 1
|
||||
// c. Expression is an aggregate function. e.g. sum(v1) in select list
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | situation | origin expression | local output expression | expression in global aggregate |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
|
||||
// 2. collect local aggregate output expressions and local aggregate group by expression list
|
||||
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
|
||||
List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
|
||||
List<NamedExpression> localOutputExprs = Lists.newArrayList();
|
||||
for (Expression originGroupByExpr : originGroupByExprs) {
|
||||
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
|
||||
continue;
|
||||
}
|
||||
if (originGroupByExpr instanceof SlotReference) {
|
||||
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
localOutputExprs.add((SlotReference) originGroupByExpr);
|
||||
} else {
|
||||
NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql());
|
||||
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
|
||||
localOutputExprs.add(localOutputExpr);
|
||||
}
|
||||
LogicalAggregate firstAggregate = firstDisassemble(aggregate);
|
||||
if (!hasDistinctAgg) {
|
||||
return firstAggregate;
|
||||
}
|
||||
for (NamedExpression originOutputExpr : originOutputExprs) {
|
||||
Set<AggregateFunction> aggregateFunctions
|
||||
= originOutputExpr.collect(AggregateFunction.class::isInstance);
|
||||
for (AggregateFunction aggregateFunction : aggregateFunctions) {
|
||||
if (inputSubstitutionMap.containsKey(aggregateFunction)) {
|
||||
continue;
|
||||
}
|
||||
NamedExpression localOutputExpr = new Alias(aggregateFunction, aggregateFunction.toSql());
|
||||
Expression substitutionValue = aggregateFunction.withChildren(
|
||||
Lists.newArrayList(localOutputExpr.toSlot()));
|
||||
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
|
||||
localOutputExprs.add(localOutputExpr);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. replace expression in globalOutputExprs and globalGroupByExprs
|
||||
List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressions().stream()
|
||||
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(Collectors.toList());
|
||||
List<Expression> globalGroupByExprs = localGroupByExprs.stream()
|
||||
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList());
|
||||
|
||||
// 4. generate new plan
|
||||
LogicalAggregate localAggregate = new LogicalAggregate<>(
|
||||
localGroupByExprs,
|
||||
localOutputExprs,
|
||||
true,
|
||||
aggregate.isNormalized(),
|
||||
AggPhase.LOCAL,
|
||||
aggregate.child()
|
||||
);
|
||||
return new LogicalAggregate<>(
|
||||
globalGroupByExprs,
|
||||
globalOutputExprs,
|
||||
true,
|
||||
aggregate.isNormalized(),
|
||||
AggPhase.GLOBAL,
|
||||
localAggregate
|
||||
);
|
||||
return secondDisassemble(firstAggregate);
|
||||
}).toRule(RuleType.AGGREGATE_DISASSEMBLE);
|
||||
}
|
||||
|
||||
// only support distinct function with group by
|
||||
// TODO: support distinct function without group by. (add second global phase)
|
||||
private LogicalAggregate secondDisassemble(LogicalAggregate<LogicalAggregate> aggregate) {
|
||||
LogicalAggregate<GroupPlan> local = aggregate.child();
|
||||
// replace expression in globalOutputExprs and globalGroupByExprs
|
||||
List<NamedExpression> globalOutputExprs = local.getOutputExpressions().stream()
|
||||
.map(e -> ExpressionUtils.replace(e, globalOutputSubstitutionMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(Collectors.toList());
|
||||
List<Expression> globalGroupByExprs = local.getGroupByExpressions().stream()
|
||||
.map(e -> ExpressionUtils.replace(e, globalGroupBySubstitutionMap))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// generate new plan
|
||||
LogicalAggregate globalAggregate = new LogicalAggregate<>(
|
||||
globalGroupByExprs,
|
||||
globalOutputExprs,
|
||||
true,
|
||||
aggregate.isNormalized(),
|
||||
false,
|
||||
AggPhase.GLOBAL,
|
||||
local
|
||||
);
|
||||
return new LogicalAggregate<>(
|
||||
aggregate.getGroupByExpressions(),
|
||||
aggregate.getOutputExpressions(),
|
||||
true,
|
||||
aggregate.isNormalized(),
|
||||
true,
|
||||
AggPhase.DISTINCT_LOCAL,
|
||||
globalAggregate
|
||||
);
|
||||
}
|
||||
|
||||
private LogicalAggregate firstDisassemble(LogicalAggregate<GroupPlan> aggregate) {
|
||||
List<NamedExpression> originOutputExprs = aggregate.getOutputExpressions();
|
||||
List<Expression> originGroupByExprs = aggregate.getGroupByExpressions();
|
||||
Map<Expression, Expression> inputSubstitutionMap = Maps.newHashMap();
|
||||
|
||||
// 1. generate a map from local aggregate output to global aggregate expr substitution.
|
||||
// inputSubstitutionMap use for replacing expression in global aggregate
|
||||
// replace rule is:
|
||||
// a: Expression is a group by key and is a slot reference. e.g. group by k1
|
||||
// b. Expression is a group by key and is an expression. e.g. group by k1 + 1
|
||||
// c. Expression is an aggregate function. e.g. sum(v1) in select list
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | situation | origin expression | local output expression | expression in global aggregate |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | a | Ref(k1)#1 | Ref(k1)#1 | Ref(k1)#1 |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | b | Ref(k1)#1 + 1 | A(Ref(k1)#1 + 1, key)#2 | Ref(key)#2 |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// | c | A(AF(v1#1), 'af')#2 | A(AF(v1#1), 'af')#3 | AF(af#3) |
|
||||
// +-----------+---------------------+-------------------------+--------------------------------+
|
||||
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x: ExprId x
|
||||
// 2. collect local aggregate output expressions and local aggregate group by expression list
|
||||
List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
|
||||
List<NamedExpression> localOutputExprs = Lists.newArrayList();
|
||||
for (Expression originGroupByExpr : originGroupByExprs) {
|
||||
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
|
||||
continue;
|
||||
}
|
||||
if (originGroupByExpr instanceof SlotReference) {
|
||||
inputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
globalOutputSubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
globalGroupBySubstitutionMap.put(originGroupByExpr, originGroupByExpr);
|
||||
localOutputExprs.add((SlotReference) originGroupByExpr);
|
||||
} else {
|
||||
NamedExpression localOutputExpr = new Alias(originGroupByExpr, originGroupByExpr.toSql());
|
||||
inputSubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
|
||||
globalOutputSubstitutionMap.put(localOutputExpr, localOutputExpr.toSlot());
|
||||
globalGroupBySubstitutionMap.put(originGroupByExpr, localOutputExpr.toSlot());
|
||||
localOutputExprs.add(localOutputExpr);
|
||||
}
|
||||
}
|
||||
List<Expression> distinctExprsForLocalGroupBy = Lists.newArrayList();
|
||||
List<NamedExpression> distinctExprsForLocalOutput = Lists.newArrayList();
|
||||
for (NamedExpression originOutputExpr : originOutputExprs) {
|
||||
Set<AggregateFunction> aggregateFunctions
|
||||
= originOutputExpr.collect(AggregateFunction.class::isInstance);
|
||||
for (AggregateFunction aggregateFunction : aggregateFunctions) {
|
||||
if (inputSubstitutionMap.containsKey(aggregateFunction)) {
|
||||
continue;
|
||||
}
|
||||
if (aggregateFunction.isDistinct()) {
|
||||
hasDistinctAgg = true;
|
||||
for (Expression expr : aggregateFunction.children()) {
|
||||
if (expr instanceof SlotReference) {
|
||||
distinctExprsForLocalOutput.add((SlotReference) expr);
|
||||
if (!inputSubstitutionMap.containsKey(expr)) {
|
||||
inputSubstitutionMap.put(expr, expr);
|
||||
globalOutputSubstitutionMap.put(expr, expr);
|
||||
globalGroupBySubstitutionMap.put(expr, expr);
|
||||
}
|
||||
} else {
|
||||
NamedExpression globalOutputExpr = new Alias(expr, expr.toSql());
|
||||
distinctExprsForLocalOutput.add(globalOutputExpr);
|
||||
if (!inputSubstitutionMap.containsKey(expr)) {
|
||||
inputSubstitutionMap.put(expr, globalOutputExpr.toSlot());
|
||||
globalOutputSubstitutionMap.put(globalOutputExpr, globalOutputExpr.toSlot());
|
||||
globalGroupBySubstitutionMap.put(expr, globalOutputExpr.toSlot());
|
||||
}
|
||||
}
|
||||
distinctExprsForLocalGroupBy.add(expr);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
NamedExpression localOutputExpr = new Alias(aggregateFunction, aggregateFunction.toSql());
|
||||
Expression substitutionValue = aggregateFunction.withChildren(
|
||||
Lists.newArrayList(localOutputExpr.toSlot()));
|
||||
inputSubstitutionMap.put(aggregateFunction, substitutionValue);
|
||||
globalOutputSubstitutionMap.put(aggregateFunction, substitutionValue);
|
||||
localOutputExprs.add(localOutputExpr);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. replace expression in globalOutputExprs and globalGroupByExprs
|
||||
List<NamedExpression> globalOutputExprs = aggregate.getOutputExpressions().stream()
|
||||
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
.collect(Collectors.toList());
|
||||
List<Expression> globalGroupByExprs = localGroupByExprs.stream()
|
||||
.map(e -> ExpressionUtils.replace(e, inputSubstitutionMap)).collect(Collectors.toList());
|
||||
// To avoid repeated substitution of distinct expressions,
|
||||
// here the expressions are put into the local after the substitution is completed
|
||||
localOutputExprs.addAll(distinctExprsForLocalOutput);
|
||||
localGroupByExprs.addAll(distinctExprsForLocalGroupBy);
|
||||
// 4. generate new plan
|
||||
LogicalAggregate localAggregate = new LogicalAggregate<>(
|
||||
localGroupByExprs,
|
||||
localOutputExprs,
|
||||
true,
|
||||
aggregate.isNormalized(),
|
||||
false,
|
||||
AggPhase.LOCAL,
|
||||
aggregate.child()
|
||||
);
|
||||
return new LogicalAggregate<>(
|
||||
globalGroupByExprs,
|
||||
globalOutputExprs,
|
||||
true,
|
||||
aggregate.isNormalized(),
|
||||
true,
|
||||
AggPhase.GLOBAL,
|
||||
localAggregate
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -124,7 +124,7 @@ public class NormalizeAggregate extends OneRewriteRuleFactory {
|
||||
root = new LogicalProject<>(bottomProjections, root);
|
||||
}
|
||||
root = new LogicalAggregate<>(newKeys, newOutputs, aggregate.isDisassembled(),
|
||||
true, aggregate.getAggPhase(), root);
|
||||
true, aggregate.isFinalPhase(), aggregate.getAggPhase(), root);
|
||||
List<NamedExpression> projections = outputs.stream()
|
||||
.map(e -> ExpressionUtils.replace(e, substitutionMap))
|
||||
.map(NamedExpression.class::cast)
|
||||
|
||||
@ -21,19 +21,50 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* The function which consume arguments in lots of rows and product one value.
|
||||
*/
|
||||
public abstract class AggregateFunction extends BoundFunction {
|
||||
|
||||
private DataType intermediate;
|
||||
private final boolean isDistinct;
|
||||
|
||||
public AggregateFunction(String name, Expression... arguments) {
|
||||
super(name, arguments);
|
||||
isDistinct = false;
|
||||
}
|
||||
|
||||
public AggregateFunction(String name, boolean isDistinct, Expression... arguments) {
|
||||
super(name, arguments);
|
||||
this.isDistinct = isDistinct;
|
||||
}
|
||||
|
||||
public abstract DataType getIntermediateType();
|
||||
|
||||
public boolean isDistinct() {
|
||||
return isDistinct;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
return true;
|
||||
}
|
||||
if (o == null || getClass() != o.getClass()) {
|
||||
return false;
|
||||
}
|
||||
AggregateFunction that = (AggregateFunction) o;
|
||||
return Objects.equals(isDistinct, that.isDistinct) && Objects.equals(intermediate, that.intermediate)
|
||||
&& Objects.equals(getName(), that.getName()) && Objects.equals(children, that.children);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(isDistinct, intermediate, getName(), children);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitAggregateFunction(this, context);
|
||||
|
||||
@ -37,8 +37,8 @@ public class Count extends AggregateFunction {
|
||||
this.isStar = true;
|
||||
}
|
||||
|
||||
public Count(Expression child) {
|
||||
super("count", child);
|
||||
public Count(Expression child, boolean isDistinct) {
|
||||
super("count", isDistinct, child);
|
||||
this.isStar = false;
|
||||
}
|
||||
|
||||
@ -62,7 +62,7 @@ public class Count extends AggregateFunction {
|
||||
if (children.size() == 0) {
|
||||
return new Count();
|
||||
}
|
||||
return new Count(children.get(0));
|
||||
return new Count(children.get(0), isDistinct());
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -79,6 +79,9 @@ public class Count extends AggregateFunction {
|
||||
.stream()
|
||||
.map(Expression::toSql)
|
||||
.collect(Collectors.joining(", "));
|
||||
if (isDistinct()) {
|
||||
return "count(distinct " + args + ")";
|
||||
}
|
||||
return "count(" + args + ")";
|
||||
}
|
||||
|
||||
@ -91,6 +94,9 @@ public class Count extends AggregateFunction {
|
||||
.stream()
|
||||
.map(Expression::toString)
|
||||
.collect(Collectors.joining(", "));
|
||||
if (isDistinct()) {
|
||||
return "count(distinct " + args + ")";
|
||||
}
|
||||
return "count(" + args + ")";
|
||||
}
|
||||
}
|
||||
|
||||
@ -59,6 +59,13 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
private final List<NamedExpression> outputExpressions;
|
||||
private final AggPhase aggPhase;
|
||||
|
||||
// use for scenes containing distinct agg
|
||||
// 1. If there are LOCAL and GLOBAL phases, global is the final phase
|
||||
// 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
|
||||
// 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
|
||||
// DISTINCT_GLOBAL is the final phase
|
||||
private final boolean isFinalPhase;
|
||||
|
||||
/**
|
||||
* Desc: Constructor for LogicalAggregate.
|
||||
*/
|
||||
@ -66,7 +73,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
List<Expression> groupByExpressions,
|
||||
List<NamedExpression> outputExpressions,
|
||||
CHILD_TYPE child) {
|
||||
this(groupByExpressions, outputExpressions, false, false, AggPhase.GLOBAL, child);
|
||||
this(groupByExpressions, outputExpressions, false, false, true, AggPhase.GLOBAL, child);
|
||||
}
|
||||
|
||||
public LogicalAggregate(
|
||||
@ -74,9 +81,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
List<NamedExpression> outputExpressions,
|
||||
boolean disassembled,
|
||||
boolean normalized,
|
||||
boolean isFinalPhase,
|
||||
AggPhase aggPhase,
|
||||
CHILD_TYPE child) {
|
||||
this(groupByExpressions, outputExpressions, disassembled, normalized,
|
||||
this(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase,
|
||||
aggPhase, Optional.empty(), Optional.empty(), child);
|
||||
}
|
||||
|
||||
@ -88,6 +96,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
List<NamedExpression> outputExpressions,
|
||||
boolean disassembled,
|
||||
boolean normalized,
|
||||
boolean isFinalPhase,
|
||||
AggPhase aggPhase,
|
||||
Optional<GroupExpression> groupExpression,
|
||||
Optional<LogicalProperties> logicalProperties,
|
||||
@ -97,6 +106,7 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
this.outputExpressions = outputExpressions;
|
||||
this.disassembled = disassembled;
|
||||
this.normalized = normalized;
|
||||
this.isFinalPhase = isFinalPhase;
|
||||
this.aggPhase = aggPhase;
|
||||
}
|
||||
|
||||
@ -149,6 +159,10 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
return normalized;
|
||||
}
|
||||
|
||||
public boolean isFinalPhase() {
|
||||
return isFinalPhase;
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine the equality with another plan
|
||||
*/
|
||||
@ -164,37 +178,37 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> extends LogicalUnary<CHIL
|
||||
&& Objects.equals(outputExpressions, that.outputExpressions)
|
||||
&& aggPhase == that.aggPhase
|
||||
&& disassembled == that.disassembled
|
||||
&& normalized == that.normalized;
|
||||
&& normalized == that.normalized
|
||||
&& isFinalPhase == that.isFinalPhase;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(groupByExpressions, outputExpressions, aggPhase, normalized, disassembled);
|
||||
return Objects.hash(groupByExpressions, outputExpressions, aggPhase, normalized, disassembled, isFinalPhase);
|
||||
}
|
||||
|
||||
@Override
|
||||
public LogicalAggregate<Plan> withChildren(List<Plan> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
disassembled, normalized, aggPhase, children.get(0));
|
||||
disassembled, normalized, isFinalPhase, aggPhase, children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public LogicalAggregate<Plan> withGroupExpression(Optional<GroupExpression> groupExpression) {
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
disassembled, normalized, aggPhase, groupExpression, Optional.of(getLogicalProperties()),
|
||||
children.get(0));
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase,
|
||||
aggPhase, groupExpression, Optional.of(getLogicalProperties()), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public LogicalAggregate<Plan> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions,
|
||||
disassembled, normalized, aggPhase, Optional.empty(), logicalProperties, children.get(0));
|
||||
return new LogicalAggregate<>(groupByExpressions, outputExpressions, disassembled, normalized, isFinalPhase,
|
||||
aggPhase, Optional.empty(), logicalProperties, children.get(0));
|
||||
}
|
||||
|
||||
public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList,
|
||||
List<NamedExpression> outputExpressionList) {
|
||||
return new LogicalAggregate<>(groupByExprList, outputExpressionList,
|
||||
disassembled, normalized, aggPhase, child());
|
||||
disassembled, normalized, isFinalPhase, aggPhase, child());
|
||||
}
|
||||
}
|
||||
|
||||
@ -53,11 +53,18 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
|
||||
private final boolean usingStream;
|
||||
|
||||
// use for scenes containing distinct agg
|
||||
// 1. If there are LOCAL and GLOBAL phases, global is the final phase
|
||||
// 2. If there are LOCAL, GLOBAL and DISTINCT_LOCAL phases, DISTINCT_LOCAL is the final phase
|
||||
// 3. If there are LOCAL, GLOBAL, DISTINCT_LOCAL, DISTINCT_GLOBAL phases,
|
||||
// DISTINCT_GLOBAL is the final phase
|
||||
private final boolean isFinalPhase;
|
||||
|
||||
public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
|
||||
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream,
|
||||
LogicalProperties logicalProperties, CHILD_TYPE child) {
|
||||
boolean isFinalPhase, LogicalProperties logicalProperties, CHILD_TYPE child) {
|
||||
this(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream,
|
||||
Optional.empty(), logicalProperties, child);
|
||||
isFinalPhase, Optional.empty(), logicalProperties, child);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -69,7 +76,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
* @param usingStream whether it's stream agg.
|
||||
*/
|
||||
public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
|
||||
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream,
|
||||
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase,
|
||||
Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties,
|
||||
CHILD_TYPE child) {
|
||||
super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, child);
|
||||
@ -78,6 +85,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
this.aggPhase = aggPhase;
|
||||
this.partitionExpressions = partitionExpressions;
|
||||
this.usingStream = usingStream;
|
||||
this.isFinalPhase = isFinalPhase;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -89,7 +97,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
* @param usingStream whether it's stream agg.
|
||||
*/
|
||||
public PhysicalAggregate(List<Expression> groupByExpressions, List<NamedExpression> outputExpressions,
|
||||
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream,
|
||||
List<Expression> partitionExpressions, AggPhase aggPhase, boolean usingStream, boolean isFinalPhase,
|
||||
Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties,
|
||||
PhysicalProperties physicalProperties, CHILD_TYPE child) {
|
||||
super(PlanType.PHYSICAL_AGGREGATE, groupExpression, logicalProperties, physicalProperties, child);
|
||||
@ -98,6 +106,7 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
this.aggPhase = aggPhase;
|
||||
this.partitionExpressions = partitionExpressions;
|
||||
this.usingStream = usingStream;
|
||||
this.isFinalPhase = isFinalPhase;
|
||||
}
|
||||
|
||||
public AggPhase getAggPhase() {
|
||||
@ -112,6 +121,10 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
return outputExpressions;
|
||||
}
|
||||
|
||||
public boolean isFinalPhase() {
|
||||
return isFinalPhase;
|
||||
}
|
||||
|
||||
public boolean isUsingStream() {
|
||||
return usingStream;
|
||||
}
|
||||
@ -156,36 +169,38 @@ public class PhysicalAggregate<CHILD_TYPE extends Plan> extends PhysicalUnary<CH
|
||||
&& Objects.equals(outputExpressions, that.outputExpressions)
|
||||
&& Objects.equals(partitionExpressions, that.partitionExpressions)
|
||||
&& usingStream == that.usingStream
|
||||
&& aggPhase == that.aggPhase;
|
||||
&& aggPhase == that.aggPhase
|
||||
&& isFinalPhase == that.isFinalPhase;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream);
|
||||
return Objects.hash(groupByExpressions, outputExpressions, partitionExpressions, aggPhase, usingStream,
|
||||
isFinalPhase);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PhysicalAggregate<Plan> withChildren(List<Plan> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
|
||||
usingStream, getLogicalProperties(), children.get(0));
|
||||
usingStream, isFinalPhase, getLogicalProperties(), children.get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public PhysicalAggregate<CHILD_TYPE> withGroupExpression(Optional<GroupExpression> groupExpression) {
|
||||
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
|
||||
usingStream, groupExpression, getLogicalProperties(), child());
|
||||
usingStream, isFinalPhase, groupExpression, getLogicalProperties(), child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public PhysicalAggregate<CHILD_TYPE> withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
|
||||
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
|
||||
usingStream, Optional.empty(), logicalProperties.get(), child());
|
||||
usingStream, isFinalPhase, Optional.empty(), logicalProperties.get(), child());
|
||||
}
|
||||
|
||||
@Override
|
||||
public PhysicalAggregate<CHILD_TYPE> withPhysicalProperties(PhysicalProperties physicalProperties) {
|
||||
return new PhysicalAggregate<>(groupByExpressions, outputExpressions, partitionExpressions, aggPhase,
|
||||
usingStream, Optional.empty(), getLogicalProperties(), physicalProperties, child());
|
||||
usingStream, isFinalPhase, Optional.empty(), getLogicalProperties(), physicalProperties, child());
|
||||
}
|
||||
}
|
||||
|
||||
@ -360,9 +360,9 @@ public class HavingClauseTest extends AnalyzeCheckTestBase implements PatternMat
|
||||
Alias pk11 = new Alias(new ExprId(8), new Add(new Add(pk, Literal.of((byte) 1)), Literal.of((byte) 1)), "((pk + 1) + 1)");
|
||||
Alias pk2 = new Alias(new ExprId(9), new Add(pk, Literal.of((byte) 2)), "(pk + 2)");
|
||||
Alias sumA1 = new Alias(new ExprId(10), new Sum(a1), "SUM(a1)");
|
||||
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
|
||||
Alias countA11 = new Alias(new ExprId(11), new Add(new Count(a1, false), Literal.of((byte) 1)), "(COUNT(a1) + 1)");
|
||||
Alias sumA1A2 = new Alias(new ExprId(12), new Sum(new Add(a1, a2)), "SUM((a1 + a2))");
|
||||
Alias v1 = new Alias(new ExprId(0), new Count(a2), "v1");
|
||||
Alias v1 = new Alias(new ExprId(0), new Count(a2, false), "v1");
|
||||
PlanChecker.from(connectContext).analyze(sql)
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
|
||||
@ -263,6 +263,7 @@ public class ChildOutputPropertyDeriverTest {
|
||||
Lists.newArrayList(key),
|
||||
AggPhase.LOCAL,
|
||||
true,
|
||||
true,
|
||||
logicalProperties,
|
||||
groupPlan
|
||||
);
|
||||
@ -286,6 +287,7 @@ public class ChildOutputPropertyDeriverTest {
|
||||
Lists.newArrayList(partition),
|
||||
AggPhase.GLOBAL,
|
||||
true,
|
||||
true,
|
||||
logicalProperties,
|
||||
groupPlan
|
||||
);
|
||||
|
||||
@ -146,6 +146,7 @@ public class RequestPropertyDeriverTest {
|
||||
Lists.newArrayList(key),
|
||||
AggPhase.LOCAL,
|
||||
true,
|
||||
true,
|
||||
logicalProperties,
|
||||
groupPlan
|
||||
);
|
||||
@ -168,6 +169,7 @@ public class RequestPropertyDeriverTest {
|
||||
Lists.newArrayList(partition),
|
||||
AggPhase.GLOBAL,
|
||||
true,
|
||||
true,
|
||||
logicalProperties,
|
||||
groupPlan
|
||||
);
|
||||
@ -192,6 +194,7 @@ public class RequestPropertyDeriverTest {
|
||||
Lists.newArrayList(),
|
||||
AggPhase.GLOBAL,
|
||||
true,
|
||||
true,
|
||||
logicalProperties,
|
||||
groupPlan
|
||||
);
|
||||
|
||||
@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.AggPhase;
|
||||
@ -269,6 +270,86 @@ public class AggregateDisassembleTest {
|
||||
global.getOutputExpressions().get(0).getExprId());
|
||||
}
|
||||
|
||||
/**
|
||||
* the initial plan is:
|
||||
* Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age + 1) + 2) as c], groupByExpr: [id + 3])
|
||||
* +-- childPlan(id, name, age)
|
||||
* we should rewrite to:
|
||||
* Aggregate(phase: [DISTINCT_LOCAL], outputExpr: [(COUNT(distinct b) + 2) as c], groupByExpr: [a])
|
||||
* +-- Aggregate(phase: [GLOBAL], outputExpr: [a, b], groupByExpr: [a, b])
|
||||
* +-- Aggregate(phase: [LOCAL], outputExpr: [(id + 3) as a, (age + 1) as b], groupByExpr: [id + 3, age + 1])
|
||||
* +-- childPlan(id, name, age)
|
||||
*/
|
||||
@Test
|
||||
public void distinctAggregateWithGroupBy() {
|
||||
List<Expression> groupExpressionList = Lists.newArrayList(
|
||||
new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3)));
|
||||
List<NamedExpression> outputExpressionList = Lists.newArrayList(new Alias(
|
||||
new Add(new Count(new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1)), true),
|
||||
new IntegerLiteral(2)), "c"));
|
||||
Plan root = new LogicalAggregate<>(groupExpressionList, outputExpressionList, rStudent);
|
||||
|
||||
Plan after = rewrite(root);
|
||||
|
||||
Assertions.assertTrue(after instanceof LogicalUnary);
|
||||
Assertions.assertTrue(after instanceof LogicalAggregate);
|
||||
Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
|
||||
LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
|
||||
LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
|
||||
LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0).child(0);
|
||||
Assertions.assertEquals(AggPhase.DISTINCT_LOCAL, distinctLocal.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
|
||||
Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
|
||||
// check local:
|
||||
// id + 3
|
||||
Expression localOutput0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
|
||||
// age + 1
|
||||
Expression localOutput1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
|
||||
// id + 3
|
||||
Expression localGroupBy0 = new Add(rStudent.getOutput().get(0).toSlot(), new IntegerLiteral(3));
|
||||
// age + 1
|
||||
Expression localGroupBy1 = new Add(rStudent.getOutput().get(2).toSlot(), new IntegerLiteral(1));
|
||||
|
||||
Assertions.assertEquals(2, local.getOutputExpressions().size());
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof Alias);
|
||||
Assertions.assertEquals(localOutput0, local.getOutputExpressions().get(0).child(0));
|
||||
Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof Alias);
|
||||
Assertions.assertEquals(localOutput1, local.getOutputExpressions().get(1).child(0));
|
||||
Assertions.assertEquals(2, local.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(localGroupBy0, local.getGroupByExpressions().get(0));
|
||||
Assertions.assertEquals(localGroupBy1, local.getGroupByExpressions().get(1));
|
||||
|
||||
// check global:
|
||||
Expression globalOutput0 = local.getOutputExpressions().get(0).toSlot();
|
||||
Expression globalOutput1 = local.getOutputExpressions().get(1).toSlot();
|
||||
Expression globalGroupBy0 = local.getOutputExpressions().get(0).toSlot();
|
||||
Expression globalGroupBy1 = local.getOutputExpressions().get(1).toSlot();
|
||||
|
||||
Assertions.assertEquals(2, global.getOutputExpressions().size());
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof SlotReference);
|
||||
Assertions.assertEquals(globalOutput0, global.getOutputExpressions().get(0));
|
||||
Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof SlotReference);
|
||||
Assertions.assertEquals(globalOutput1, global.getOutputExpressions().get(1));
|
||||
Assertions.assertEquals(2, global.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(globalGroupBy0, global.getGroupByExpressions().get(0));
|
||||
Assertions.assertEquals(globalGroupBy1, global.getGroupByExpressions().get(1));
|
||||
|
||||
// check distinct local:
|
||||
Expression distinctLocalOutput = new Add(new Count(local.getOutputExpressions().get(1).toSlot(), true),
|
||||
new IntegerLiteral(2));
|
||||
Expression distinctLocalGroupBy = local.getOutputExpressions().get(0).toSlot();
|
||||
|
||||
Assertions.assertEquals(1, distinctLocal.getOutputExpressions().size());
|
||||
Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0) instanceof Alias);
|
||||
Assertions.assertEquals(distinctLocalOutput, distinctLocal.getOutputExpressions().get(0).child(0));
|
||||
Assertions.assertEquals(1, distinctLocal.getGroupByExpressions().size());
|
||||
Assertions.assertEquals(distinctLocalGroupBy, distinctLocal.getGroupByExpressions().get(0));
|
||||
|
||||
// check id:
|
||||
Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
|
||||
distinctLocal.getOutputExpressions().get(0).getExprId());
|
||||
}
|
||||
|
||||
private Plan rewrite(Plan input) {
|
||||
return PlanRewriter.topDownRewrite(input, new ConnectContext(), new AggregateDisassemble());
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
|
||||
import org.apache.doris.nereids.analyzer.UnboundAlias;
|
||||
import org.apache.doris.nereids.analyzer.UnboundFunction;
|
||||
import org.apache.doris.nereids.analyzer.UnboundStar;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Sum;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
|
||||
@ -168,6 +169,25 @@ public class ExpressionEqualsTest {
|
||||
Assertions.assertEquals(sum1.hashCode(), sum2.hashCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAggregateFunction() {
|
||||
Count count1 = new Count();
|
||||
Count count2 = new Count();
|
||||
Assertions.assertEquals(count1, count2);
|
||||
Assertions.assertEquals(count1.hashCode(), count2.hashCode());
|
||||
|
||||
Count count3 = new Count(child1, true);
|
||||
Count count4 = new Count(child2, true);
|
||||
Assertions.assertEquals(count3, count4);
|
||||
Assertions.assertEquals(count3.hashCode(), count4.hashCode());
|
||||
|
||||
// bad case
|
||||
Count count5 = new Count(child1, true);
|
||||
Count count6 = new Count(child2, false);
|
||||
Assertions.assertNotEquals(count5, count6);
|
||||
Assertions.assertNotEquals(count5.hashCode(), count6.hashCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testNamedExpression() {
|
||||
ExprId aliasId = new ExprId(2);
|
||||
|
||||
@ -71,17 +71,17 @@ public class PlanEqualsTest {
|
||||
|
||||
unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
|
||||
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
|
||||
true, false, AggPhase.GLOBAL, child);
|
||||
true, false, true, AggPhase.GLOBAL, child);
|
||||
Assertions.assertNotEquals(unexpected, actual);
|
||||
|
||||
unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
|
||||
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
|
||||
false, true, AggPhase.GLOBAL, child);
|
||||
false, true, true, AggPhase.GLOBAL, child);
|
||||
Assertions.assertNotEquals(unexpected, actual);
|
||||
|
||||
unexpected = new LogicalAggregate<>(Lists.newArrayList(), ImmutableList.of(
|
||||
new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())),
|
||||
false, false, AggPhase.LOCAL, child);
|
||||
false, false, true, AggPhase.LOCAL, child);
|
||||
Assertions.assertNotEquals(unexpected, actual);
|
||||
}
|
||||
|
||||
@ -178,20 +178,20 @@ public class PlanEqualsTest {
|
||||
List<NamedExpression> outputExpressionList = ImmutableList.of(
|
||||
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()));
|
||||
PhysicalAggregate<Plan> actual = new PhysicalAggregate<>(Lists.newArrayList(), outputExpressionList,
|
||||
Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, child);
|
||||
Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child);
|
||||
|
||||
List<NamedExpression> outputExpressionList1 = ImmutableList.of(
|
||||
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()));
|
||||
PhysicalAggregate<Plan> expected = new PhysicalAggregate<>(Lists.newArrayList(),
|
||||
outputExpressionList1,
|
||||
Lists.newArrayList(), AggPhase.LOCAL, true, logicalProperties, child);
|
||||
Lists.newArrayList(), AggPhase.LOCAL, true, true, logicalProperties, child);
|
||||
Assertions.assertEquals(expected, actual);
|
||||
|
||||
List<NamedExpression> outputExpressionList2 = ImmutableList.of(
|
||||
new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()));
|
||||
PhysicalAggregate<Plan> unexpected = new PhysicalAggregate<>(Lists.newArrayList(),
|
||||
outputExpressionList2,
|
||||
Lists.newArrayList(), AggPhase.LOCAL, false, logicalProperties, child);
|
||||
Lists.newArrayList(), AggPhase.LOCAL, false, true, logicalProperties, child);
|
||||
Assertions.assertNotEquals(unexpected, actual);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user