branch-2.1: [feature](group by)Support group by with order. (#53037) (#53840)

backport: https://github.com/apache/doris/pull/53037
This commit is contained in:
James
2025-07-25 14:37:10 +08:00
committed by GitHub
parent 77ff75b954
commit c1fa17af38
6 changed files with 378 additions and 8 deletions

View File

@ -1145,6 +1145,10 @@ relationHint
| HINT_START identifier (COMMA identifier)* HINT_END #commentRelationHint
;
expressionWithOrder
: expression ordering = (ASC | DESC)?
;
aggClause
: GROUP BY groupingElement
;
@ -1153,7 +1157,7 @@ groupingElement
: ROLLUP LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN
| CUBE LEFT_PAREN (expression (COMMA expression)*)? RIGHT_PAREN
| GROUPING SETS LEFT_PAREN groupingSet (COMMA groupingSet)* RIGHT_PAREN
| expression (COMMA expression)* (WITH ROLLUP)?
| expressionWithOrder (COMMA expressionWithOrder)* (WITH ROLLUP)?
;
groupingSet

View File

@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.parser;
import org.apache.doris.nereids.trees.expressions.Expression;
import java.util.Objects;
/**
* Represents the group by expression with order of a statement.
*/
public class GroupKeyWithOrder {
private final Expression expr;
// Order is ascending.
private final boolean hasOrder;
private final boolean isAsc;
/**
* Constructor of GroupKeyWithOrder.
*/
public GroupKeyWithOrder(Expression expr, boolean hasOrder, boolean isAsc) {
this.expr = expr;
this.hasOrder = hasOrder;
this.isAsc = isAsc;
}
public Expression getExpr() {
return expr;
}
public boolean isAsc() {
return isAsc;
}
public boolean hasOrder() {
return hasOrder;
}
public GroupKeyWithOrder withExpression(Expression expr) {
return new GroupKeyWithOrder(expr, isAsc, hasOrder);
}
public String toSql() {
return expr.toSql() + (hasOrder ? (isAsc ? " asc" : " desc") : "");
}
@Override
public String toString() {
return expr.toString() + (hasOrder ? (isAsc ? " asc" : " desc") : "");
}
@Override
public int hashCode() {
return Objects.hash(expr, isAsc, hasOrder);
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
GroupKeyWithOrder that = (GroupKeyWithOrder) o;
return isAsc == that.isAsc() && hasOrder == that.hasOrder() && expr.equals(that.getExpr());
}
}

View File

@ -97,6 +97,7 @@ import org.apache.doris.nereids.DorisParser.ElementAtContext;
import org.apache.doris.nereids.DorisParser.ExistContext;
import org.apache.doris.nereids.DorisParser.ExplainContext;
import org.apache.doris.nereids.DorisParser.ExportContext;
import org.apache.doris.nereids.DorisParser.ExpressionWithOrderContext;
import org.apache.doris.nereids.DorisParser.FixedPartitionDefContext;
import org.apache.doris.nereids.DorisParser.FromClauseContext;
import org.apache.doris.nereids.DorisParser.GroupingElementContext;
@ -2605,6 +2606,16 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
});
}
@Override
public GroupKeyWithOrder visitExpressionWithOrder(ExpressionWithOrderContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
boolean hasOrder = ctx.ASC() != null || ctx.DESC() != null;
boolean isAsc = ctx.DESC() == null;
Expression expression = typedVisit(ctx.expression());
return new GroupKeyWithOrder(expression, hasOrder, isAsc);
});
}
private <T> List<T> visit(List<? extends ParserRuleContext> contexts, Class<T> clazz) {
return contexts.stream()
.map(this::visit)
@ -3085,8 +3096,10 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
// from -> where -> group by -> having -> select
LogicalPlan filter = withFilter(inputRelation, whereClause);
SelectColumnClauseContext selectColumnCtx = selectClause.selectColumnClause();
LogicalPlan aggregate = withAggregate(filter, selectColumnCtx, aggClause);
List<OrderKey> orderKeys = Lists.newArrayList();
LogicalPlan aggregate = withAggregate(filter, selectColumnCtx, aggClause, orderKeys);
boolean isDistinct = (selectClause.DISTINCT() != null);
LogicalPlan selectPlan;
if (!(aggregate instanceof Aggregate) && havingClause.isPresent()) {
// create a project node for pattern match of ProjectToGlobalAggregate rule
// then ProjectToGlobalAggregate rule can insert agg node as LogicalHaving node's child
@ -3102,12 +3115,16 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
List<NamedExpression> projects = getNamedExpressions(selectColumnCtx.namedExpressionSeq());
project = new LogicalProject<>(projects, ImmutableList.of(), isDistinct, aggregate);
}
return new LogicalHaving<>(ExpressionUtils.extractConjunctionToSet(
selectPlan = new LogicalHaving<>(ExpressionUtils.extractConjunctionToSet(
getExpression((havingClause.get().booleanExpression()))), project);
} else {
LogicalPlan having = withHaving(aggregate, havingClause);
return withProjection(having, selectColumnCtx, aggClause, isDistinct);
selectPlan = withProjection(having, selectColumnCtx, aggClause, isDistinct);
}
if (!orderKeys.isEmpty()) {
selectPlan = new LogicalSort<>(orderKeys, selectPlan);
}
return selectPlan;
});
}
@ -3385,7 +3402,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
private LogicalPlan withAggregate(LogicalPlan input, SelectColumnClauseContext selectCtx,
Optional<AggClauseContext> aggCtx) {
Optional<AggClauseContext> aggCtx, List<OrderKey> orderKeys) {
return input.optionalMap(aggCtx, () -> {
GroupingElementContext groupingElementContext = aggCtx.get().groupingElement();
List<NamedExpression> namedExpressions = getNamedExpressions(selectCtx.namedExpressionSeq());
@ -3399,13 +3416,27 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
List<Expression> cubeExpressions = visit(groupingElementContext.expression(), Expression.class);
List<List<Expression>> groupingSets = ExpressionUtils.cubeToGroupingSets(cubeExpressions);
return new LogicalRepeat<>(groupingSets, namedExpressions, input);
} else if (groupingElementContext.ROLLUP() != null) {
} else if (groupingElementContext.ROLLUP() != null && groupingElementContext.WITH() == null) {
List<Expression> rollupExpressions = visit(groupingElementContext.expression(), Expression.class);
List<List<Expression>> groupingSets = ExpressionUtils.rollupToGroupingSets(rollupExpressions);
return new LogicalRepeat<>(groupingSets, namedExpressions, input);
} else {
List<Expression> groupByExpressions = visit(groupingElementContext.expression(), Expression.class);
return new LogicalAggregate<>(groupByExpressions, namedExpressions, input);
List<GroupKeyWithOrder> groupKeyWithOrders = visit(groupingElementContext.expressionWithOrder(),
GroupKeyWithOrder.class);
ImmutableList<Expression> groupByExpressions = groupKeyWithOrders.stream()
.map(GroupKeyWithOrder::getExpr)
.collect(ImmutableList.toImmutableList());
if (groupKeyWithOrders.stream().anyMatch(GroupKeyWithOrder::hasOrder)) {
groupKeyWithOrders.stream()
.map(e -> new OrderKey(e.getExpr(), e.isAsc(), e.isAsc()))
.forEach(orderKeys::add);
}
if (groupingElementContext.ROLLUP() != null) {
List<List<Expression>> groupingSets = ExpressionUtils.rollupToGroupingSets(groupByExpressions);
return new LogicalRepeat<>(groupingSets, namedExpressions, input);
} else {
return new LogicalAggregate<>(groupByExpressions, namedExpressions, input);
}
}
});
}

View File

@ -37,6 +37,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DecimalV2Type;
@ -419,4 +421,56 @@ public class NereidsParserTest extends ParserTestBase {
NereidsParser nereidsParser = new NereidsParser();
nereidsParser.parseSingle(sql);
}
private void checkQueryTopPlanClass(String sql, NereidsParser parser, Class<?> clazz) {
if (clazz == null) {
Assertions.assertThrows(ParseException.class, () -> parser.parseSingle(sql));
} else {
LogicalPlan logicalPlan = parser.parseSingle(sql);
Assertions.assertInstanceOf(clazz, logicalPlan.child(0));
}
}
@Test
public void testExpressionWithOrder() {
NereidsParser nereidsParser = new NereidsParser();
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a, b DESC",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a DESC, b",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a ASC, b",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a, b ASC",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a ASC, b ASC",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a DESC, b DESC",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a ASC, b DESC",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a DESC, b ASC",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a, b DESC WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a DESC, b WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a ASC, b WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a, b ASC WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a ASC, b ASC WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a DESC, b DESC WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a ASC, b DESC WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a DESC, b ASC WITH ROLLUP",
nereidsParser, LogicalSort.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a, b",
nereidsParser, LogicalAggregate.class);
checkQueryTopPlanClass("SELECT a, b, sum(c) from test group by a, b WITH ROLLUP",
nereidsParser, LogicalRepeat.class);
}
}