[improvement](nereids) Support aggregate functions without from clause (#25500)

Support aggregate functions in select without from clause, here are some examples as following:

SELECT 1,  
  'a',
   COUNT(),  
   SUM(1) + 1,
   AVG(2) / COUNT(),
   MAX(3),
   MIN(4),
   RANK() OVER() AS w_rank,
   DENSE_RANK() OVER() AS w_dense_rank,
   ROW_NUMBER() OVER() AS w_row_number,
   SUM(5) OVER() AS w_sum,
   AVG(6) OVER() AS w_avg,
   COUNT() OVER() AS w_count,
   MAX(7) OVER() AS w_max,
   MIN(8) OVER() AS w_min;
This commit is contained in:
JingDas
2023-10-19 12:07:37 +08:00
committed by GitHub
parent fcf7bdc9e0
commit b45f501e51
13 changed files with 173 additions and 52 deletions

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.OneRowRelationExtractAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
@ -103,7 +104,8 @@ public class Analyzer extends AbstractBatchJobExecutor {
// please see rule BindSlotReference or BindFunction for example
new ProjectWithDistinctToAggregate(),
new ResolveOrdinalInOrderByAndGroupBy(),
new ReplaceExpressionByChildOutput()
new ReplaceExpressionByChildOutput(),
new OneRowRelationExtractAggregate()
),
topDown(
new FillUpMissingSlots(),

View File

@ -71,6 +71,7 @@ public enum RuleType {
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
ANALYZE_CTE(RuleTypeClass.REWRITE),

View File

@ -76,7 +76,6 @@ public class CheckAnalysis implements AnalysisRuleFactory {
TableGeneratingFunction.class,
WindowExpression.class))
.put(LogicalOneRowRelation.class, ImmutableSet.of(
AggregateFunction.class,
GroupingScalarFunction.class,
TableGeneratingFunction.class,
WindowExpression.class))

View File

@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* OneRowRelationExtractAggregate.
* <p>
* example sql:
* <pre>
* SELECT 1, 'a', COUNT();
* </pre>
* <p>
* origin plan:
* <p>
* LogicalOneRowRelation ( projects=[1 AS `1`#0, 'a' AS `'a'`#1, count(*) AS `count(*)`#2] )
* transformed plan:
* <p>
* LogicalAggregate[23] ( groupByExpr=[], outputExpr=[1 AS `1`#0, 'a' AS `'a'`#1, count(*) AS `count(*)`#2],
* hasRepeat=false )
* LogicalOneRowRelation ( projects=[] )
*/
public class OneRowRelationExtractAggregate extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return RuleType.ONE_ROW_RELATION_EXTRACT_AGGREGATE.build(
logicalOneRowRelation().then(relation -> {
List<NamedExpression> outputs = relation.getOutputs();
boolean needGlobalAggregate = outputs
.stream()
.anyMatch(p -> p.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
if (needGlobalAggregate) {
LogicalRelation newRelation = new LogicalOneRowRelation(relation.getRelationId(),
ImmutableList.of());
return new LogicalAggregate<>(ImmutableList.of(), relation.getOutputs(), newRelation);
} else {
return relation;
}
})
);
}
}

View File

@ -19,10 +19,7 @@ package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import com.google.common.collect.ImmutableList;
@ -49,7 +46,7 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
logicalProject().then(project -> {
boolean needGlobalAggregate = project.getProjects()
.stream()
.anyMatch(p -> p.accept(ContainsAggregateChecker.INSTANCE, null));
.anyMatch(p -> p.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
if (needGlobalAggregate) {
return new LogicalAggregate<>(ImmutableList.of(), project.getProjects(), project.child());
@ -59,32 +56,4 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
})
);
}
private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
private static final ContainsAggregateChecker INSTANCE = new ContainsAggregateChecker();
@Override
public Boolean visit(Expression expr, Void context) {
boolean needAggregate = false;
for (Expression child : expr.children()) {
needAggregate = needAggregate || child.accept(this, context);
}
return needAggregate;
}
@Override
public Boolean visitWindow(WindowExpression windowExpression, Void context) {
boolean needAggregate = false;
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
needAggregate = needAggregate || child.accept(this, context);
}
return needAggregate;
}
@Override
public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
return true;
}
}
}

View File

@ -215,6 +215,9 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
/** prune output */
public static <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {
if (originOutput.isEmpty()) {
return plan;
}
List<NamedExpression> prunedOutputs = originOutput.stream()
.filter(output -> context.requiredSlots.contains(output.toSlot()))
.collect(ImmutableList.toImmutableList());

View File

@ -0,0 +1,57 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.visitor;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
/**
* This is the factory for all ExpressionVisitor instance.
* All children instance of DefaultExpressionVisitor or ExpressionVisitor for common usage
* should be here and expose self by class static final field.
*/
public class ExpressionVisitors {
public static final ContainsAggregateChecker CONTAINS_AGGREGATE_CHECKER = new ContainsAggregateChecker();
private static class ContainsAggregateChecker extends DefaultExpressionVisitor<Boolean, Void> {
@Override
public Boolean visit(Expression expr, Void context) {
boolean needAggregate = false;
for (Expression child : expr.children()) {
needAggregate = needAggregate || child.accept(this, context);
}
return needAggregate;
}
@Override
public Boolean visitWindow(WindowExpression windowExpression, Void context) {
boolean needAggregate = false;
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
needAggregate = needAggregate || child.accept(this, context);
}
return needAggregate;
}
@Override
public Boolean visitAggregateFunction(AggregateFunction aggregateFunction, Void context) {
return true;
}
}
}

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.RelationId;
@ -30,7 +29,6 @@ import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
@ -52,8 +50,6 @@ public class LogicalOneRowRelation extends LogicalRelation implements OneRowRela
private LogicalOneRowRelation(RelationId relationId, List<NamedExpression> projects,
Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) {
super(relationId, PlanType.LOGICAL_ONE_ROW_RELATION, groupExpression, logicalProperties);
Preconditions.checkArgument(projects.stream().noneMatch(p -> p.containsType(AggregateFunction.class)),
"OneRowRelation can not contains any aggregate function");
this.projects = ImmutableList.copyOf(Objects.requireNonNull(projects, "projects can not be null"));
}

View File

@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !projectAggFuncs --
1 a 1 2 2.0 3 4 1 1 1 5 6.0 1 7 8

View File

@ -0,0 +1,4 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !projectAggFuncs --
1 a 1 2 2.0 3 4 1 1 1 5 6.0 1 7 8

View File

@ -1,4 +1,17 @@
/*
-- database: presto; groups: no_from
SELECT COUNT(10), MAX(50), MIN(90.0)
*/
SELECT 1,
'a',
COUNT(),
SUM(1) + 1,
AVG(2) / COUNT(),
MAX(3),
MIN(4),
RANK() OVER() AS w_rank,
DENSE_RANK() OVER() AS w_dense_rank,
ROW_NUMBER() OVER() AS w_row_number,
SUM(5) OVER() AS w_sum,
AVG(6) OVER() AS w_avg,
COUNT() OVER() AS w_count,
MAX(7) OVER() AS w_max,
MIN(8) OVER() AS w_min;

View File

@ -31,11 +31,4 @@ suite("one_row_relation") {
)a"""
result([[100, "abc", "ab", "de", null]])
}
test {
sql """
select sum(1);
"""
exception "OneRowRelation can not contains any aggregate function"
}
}

View File

@ -1,4 +1,16 @@
/*
-- database: presto; groups: no_from
SELECT COUNT(10), MAX(50), MIN(90.0)
*/
SELECT 1,
'a',
COUNT(),
SUM(1) + 1,
AVG(2) / COUNT(),
MAX(3),
MIN(4),
RANK() OVER() AS w_rank,
DENSE_RANK() OVER() AS w_dense_rank,
ROW_NUMBER() OVER() AS w_row_number,
SUM(5) OVER() AS w_sum,
AVG(6) OVER() AS w_avg,
COUNT() OVER() AS w_count,
MAX(7) OVER() AS w_max,
MIN(8) OVER() AS w_min;