[improvement](nereids) Support aggregate functions without from clause (#25500)
Support aggregate functions in select without from clause, here are some examples as following: SELECT 1, 'a', COUNT(), SUM(1) + 1, AVG(2) / COUNT(), MAX(3), MIN(4), RANK() OVER() AS w_rank, DENSE_RANK() OVER() AS w_dense_rank, ROW_NUMBER() OVER() AS w_row_number, SUM(5) OVER() AS w_sum, AVG(6) OVER() AS w_avg, COUNT() OVER() AS w_count, MAX(7) OVER() AS w_max, MIN(8) OVER() AS w_min;
This commit is contained in:
@ -34,6 +34,7 @@ import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
|
||||
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
|
||||
import org.apache.doris.nereids.rules.analysis.OneRowRelationExtractAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
|
||||
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
|
||||
@ -103,7 +104,8 @@ public class Analyzer extends AbstractBatchJobExecutor {
|
||||
// please see rule BindSlotReference or BindFunction for example
|
||||
new ProjectWithDistinctToAggregate(),
|
||||
new ResolveOrdinalInOrderByAndGroupBy(),
|
||||
new ReplaceExpressionByChildOutput()
|
||||
new ReplaceExpressionByChildOutput(),
|
||||
new OneRowRelationExtractAggregate()
|
||||
),
|
||||
topDown(
|
||||
new FillUpMissingSlots(),
|
||||
|
||||
@ -71,6 +71,7 @@ public enum RuleType {
|
||||
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
|
||||
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
|
||||
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
|
||||
ANALYZE_CTE(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -76,7 +76,6 @@ public class CheckAnalysis implements AnalysisRuleFactory {
|
||||
TableGeneratingFunction.class,
|
||||
WindowExpression.class))
|
||||
.put(LogicalOneRowRelation.class, ImmutableSet.of(
|
||||
AggregateFunction.class,
|
||||
GroupingScalarFunction.class,
|
||||
TableGeneratingFunction.class,
|
||||
WindowExpression.class))
|
||||
|
||||
@ -0,0 +1,68 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* OneRowRelationExtractAggregate.
|
||||
* <p>
|
||||
* example sql:
|
||||
* <pre>
|
||||
* SELECT 1, 'a', COUNT();
|
||||
* </pre>
|
||||
* <p>
|
||||
* origin plan:
|
||||
* <p>
|
||||
* LogicalOneRowRelation ( projects=[1 AS `1`#0, 'a' AS `'a'`#1, count(*) AS `count(*)`#2] )
|
||||
* transformed plan:
|
||||
* <p>
|
||||
* LogicalAggregate[23] ( groupByExpr=[], outputExpr=[1 AS `1`#0, 'a' AS `'a'`#1, count(*) AS `count(*)`#2],
|
||||
* hasRepeat=false )
|
||||
* LogicalOneRowRelation ( projects=[] )
|
||||
*/
|
||||
public class OneRowRelationExtractAggregate extends OneAnalysisRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return RuleType.ONE_ROW_RELATION_EXTRACT_AGGREGATE.build(
|
||||
logicalOneRowRelation().then(relation -> {
|
||||
List<NamedExpression> outputs = relation.getOutputs();
|
||||
boolean needGlobalAggregate = outputs
|
||||
.stream()
|
||||
.anyMatch(p -> p.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
|
||||
if (needGlobalAggregate) {
|
||||
LogicalRelation newRelation = new LogicalOneRowRelation(relation.getRelationId(),
|
||||
ImmutableList.of());
|
||||
return new LogicalAggregate<>(ImmutableList.of(), relation.getOutputs(), newRelation);
|
||||
} else {
|
||||
return relation;
|
||||
}
|
||||
})
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -19,10 +19,7 @@ package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -49,7 +46,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
|
||||
logicalProject().then(project -> {
|
||||
boolean needGlobalAggregate = project.getProjects()
|
||||
.stream()
|
||||
.anyMatch(p -> p.accept(ContainsAggregateChecker.INSTANCE, null));
|
||||
.anyMatch(p -> p.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
|
||||
|
||||
if (needGlobalAggregate) {
|
||||
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
|
||||
@ -59,32 +56,4 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
|
||||
|
||||
private static final ContainsAggregateChecker INSTANCE = new ContainsAggregateChecker();
|
||||
|
||||
@Override
|
||||
public Boolean visit(Expression expr, Void context) {
|
||||
boolean needAggregate = false;
|
||||
for (Expression child : expr.children()) {
|
||||
needAggregate = needAggregate || child.accept(this, context);
|
||||
}
|
||||
return needAggregate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean visitWindow(WindowExpression windowExpression, Void context) {
|
||||
boolean needAggregate = false;
|
||||
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
|
||||
needAggregate = needAggregate || child.accept(this, context);
|
||||
}
|
||||
return needAggregate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,6 +215,9 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
/** prune output */
|
||||
public static <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
|
||||
Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {
|
||||
if (originOutput.isEmpty()) {
|
||||
return plan;
|
||||
}
|
||||
List<NamedExpression> prunedOutputs = originOutput.stream()
|
||||
.filter(output -> context.requiredSlots.contains(output.toSlot()))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
@ -0,0 +1,57 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.visitor;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.WindowExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
|
||||
/**
|
||||
* This is the factory for all ExpressionVisitor instance.
|
||||
* All children instance of DefaultExpressionVisitor or ExpressionVisitor for common usage
|
||||
* should be here and expose self by class static final field.
|
||||
*/
|
||||
public class ExpressionVisitors {
|
||||
|
||||
public static final ContainsAggregateChecker CONTAINS_AGGREGATE_CHECKER = new ContainsAggregateChecker();
|
||||
|
||||
private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
|
||||
@Override
|
||||
public Boolean visit(Expression expr, Void context) {
|
||||
boolean needAggregate = false;
|
||||
for (Expression child : expr.children()) {
|
||||
needAggregate = needAggregate || child.accept(this, context);
|
||||
}
|
||||
return needAggregate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean visitWindow(WindowExpression windowExpression, Void context) {
|
||||
boolean needAggregate = false;
|
||||
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
|
||||
needAggregate = needAggregate || child.accept(this, context);
|
||||
}
|
||||
return needAggregate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -22,7 +22,6 @@ import org.apache.doris.nereids.properties.LogicalProperties;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.PlanType;
|
||||
import org.apache.doris.nereids.trees.plans.RelationId;
|
||||
@ -30,7 +29,6 @@ import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
@ -52,8 +50,6 @@ public class LogicalOneRowRelation extends LogicalRelation implements OneRowRela
|
||||
private LogicalOneRowRelation(RelationId relationId, List<NamedExpression> projects,
|
||||
Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) {
|
||||
super(relationId, PlanType.LOGICAL_ONE_ROW_RELATION, groupExpression, logicalProperties);
|
||||
Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(AggregateFunction.class)),
|
||||
"OneRowRelation can not contains any aggregate function");
|
||||
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !projectAggFuncs --
|
||||
1 a 1 2 2.0 3 4 1 1 1 5 6.0 1 7 8
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !projectAggFuncs --
|
||||
1 a 1 2 2.0 3 4 1 1 1 5 6.0 1 7 8
|
||||
|
||||
@ -1,4 +1,17 @@
|
||||
/*
|
||||
-- database: presto; groups: no_from
|
||||
SELECT COUNT(10), MAX(50), MIN(90.0)
|
||||
*/
|
||||
SELECT 1,
|
||||
'a',
|
||||
COUNT(),
|
||||
SUM(1) + 1,
|
||||
AVG(2) / COUNT(),
|
||||
MAX(3),
|
||||
MIN(4),
|
||||
RANK() OVER() AS w_rank,
|
||||
DENSE_RANK() OVER() AS w_dense_rank,
|
||||
ROW_NUMBER() OVER() AS w_row_number,
|
||||
SUM(5) OVER() AS w_sum,
|
||||
AVG(6) OVER() AS w_avg,
|
||||
COUNT() OVER() AS w_count,
|
||||
MAX(7) OVER() AS w_max,
|
||||
MIN(8) OVER() AS w_min;
|
||||
|
||||
|
||||
@ -31,11 +31,4 @@ suite("one_row_relation") {
|
||||
)a"""
|
||||
result([[100, "abc", "ab", "de", null]])
|
||||
}
|
||||
|
||||
test {
|
||||
sql """
|
||||
select sum(1);
|
||||
"""
|
||||
exception "OneRowRelation can not contains any aggregate function"
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,16 @@
|
||||
/*
|
||||
-- database: presto; groups: no_from
|
||||
SELECT COUNT(10), MAX(50), MIN(90.0)
|
||||
*/
|
||||
SELECT 1,
|
||||
'a',
|
||||
COUNT(),
|
||||
SUM(1) + 1,
|
||||
AVG(2) / COUNT(),
|
||||
MAX(3),
|
||||
MIN(4),
|
||||
RANK() OVER() AS w_rank,
|
||||
DENSE_RANK() OVER() AS w_dense_rank,
|
||||
ROW_NUMBER() OVER() AS w_row_number,
|
||||
SUM(5) OVER() AS w_sum,
|
||||
AVG(6) OVER() AS w_avg,
|
||||
COUNT() OVER() AS w_count,
|
||||
MAX(7) OVER() AS w_max,
|
||||
MIN(8) OVER() AS w_min;
|
||||
|
||||
Reference in New Issue
Block a user