[feature](Nereids) bind slots in ORDER BY that do not appear in the project list (#14042)

1. Bind slots in ORDER BY that do not appear in the project list, such as:
SELECT c1 FROM t WHERE c2 > 0 ORDER BY c3

2. Do not check for unbound slots while binding slot references. Instead, do it in the analysis check.
This commit is contained in:
morrySnow
2022-11-09 13:25:41 +08:00
committed by GitHub
parent 7362460525
commit aff62655c4
11 changed files with 95 additions and 53 deletions

View File

@ -21,9 +21,9 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
import org.apache.doris.nereids.rules.analysis.RegisterCTE;
import org.apache.doris.nereids.rules.analysis.ResolveAggregateFunctions;
import org.apache.doris.nereids.rules.analysis.Scope;
import org.apache.doris.nereids.rules.analysis.UserAuthentication;
@ -55,7 +55,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
new ProjectToGlobalAggregate()
)),
topDownBatch(ImmutableList.of(
new ResolveAggregateFunctions()
new FillUpMissingSlots()
))
));
}

View File

@ -52,8 +52,8 @@ public class PlanPostProcessors {
Builder<PlanPostProcessor> builder = ImmutableList.builder();
if (cascadesContext.getConnectContext().getSessionVariable().isEnableNereidsRuntimeFilter()) {
builder.add(new RuntimeFilterGenerator());
builder.add(new Validator());
}
builder.add(new Validator());
return builder.build();
}
}

View File

@ -46,9 +46,10 @@ public enum RuleType {
BINDING_FILTER_FUNCTION(RuleTypeClass.REWRITE),
BINDING_HAVING_FUNCTION(RuleTypeClass.REWRITE),
BINDING_SORT_FUNCTION(RuleTypeClass.REWRITE),
RESOLVE_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
RESOLVE_SORT_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
RESOLVE_SORT_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
FILL_UP_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
FILL_UP_SORT_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
FILL_UP_SORT_HAVING_AGGREGATE_FUNCTIONS(RuleTypeClass.REWRITE),
FILL_UP_SORT_PROJECT(RuleTypeClass.REWRITE),
RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),

View File

@ -160,12 +160,15 @@ public class BindSlotReference implements AnalysisRuleFactory {
})
),
RuleType.BINDING_SORT_SLOT.build(
logicalSort().when(Plan::canBind).thenApply(ctx -> {
LogicalSort<GroupPlan> sort = ctx.root;
logicalSort(logicalProject()).when(Plan::canBind).thenApply(ctx -> {
LogicalSort<LogicalProject<GroupPlan>> sort = ctx.root;
List<OrderKey> sortItemList = sort.getOrderKeys()
.stream()
.map(orderKey -> {
Expression item = bind(orderKey.getExpr(), sort.children(), sort, ctx.cascadesContext);
if (item.containsType(UnboundSlot.class)) {
item = bind(item, sort.child().children(), sort, ctx.cascadesContext);
}
return new OrderKey(item, orderKey.isAsc(), orderKey.isNullFirst());
}).collect(Collectors.toList());
@ -291,7 +294,9 @@ public class BindSlotReference implements AnalysisRuleFactory {
List<Slot> bounded = boundedOpt.get();
switch (bounded.size()) {
case 0:
throw new AnalysisException(String.format("Cannot find column %s.", unboundSlot.toSql()));
// Just return the unbound slot so that a later bind attempt has a chance to resolve it.
// If it is still unbound in the end, the analysis check will throw an exception.
return unboundSlot;
case 1:
if (!foundInThisScope) {
getScope().getOuterScope().get().getCorrelatedSlots().add(bounded.get(0));

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
@ -24,7 +25,11 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.typecoercion.TypeCheckResult;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.commons.lang.StringUtils;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Analysis check rule that verifies semantic correctness after analysis by Nereids.
@ -32,7 +37,7 @@ import java.util.Optional;
public class CheckAnalysis extends OneAnalysisRuleFactory {
@Override
public Rule build() {
return any().then(this::checkExpressionInputTypes).toRule(RuleType.CHECK_ANALYSIS);
return any().then(plan -> checkExpressionInputTypes(checkBound(plan))).toRule(RuleType.CHECK_ANALYSIS);
}
private Plan checkExpressionInputTypes(Plan plan) {
@ -46,4 +51,18 @@ public class CheckAnalysis extends OneAnalysisRuleFactory {
}
return plan;
}
/**
 * Verifies that no expression of the given plan node still contains an {@link UnboundSlot}.
 *
 * @param plan the plan node whose expressions are inspected
 * @return the same plan instance, unchanged, when no unbound slot is found
 * @throws AnalysisException if at least one unbound slot remains; the message lists the
 *         SQL form of each offending slot once, in sorted order
 */
private Plan checkBound(Plan plan) {
    Set<UnboundSlot> unboundSlots = plan.getExpressions().stream()
            .<Set<UnboundSlot>>map(e -> e.collect(UnboundSlot.class::isInstance))
            .flatMap(Set::stream)
            .collect(Collectors.toSet());
    if (unboundSlots.isEmpty()) {
        return plan;
    }
    // Sort the names so the error message does not depend on set iteration order.
    throw new AnalysisException(String.format("Cannot find column %s.",
            StringUtils.join(unboundSlots.stream()
                    .map(UnboundSlot::toSql)
                    .distinct()
                    .sorted()
                    .collect(Collectors.toList()), ", ")));
}
}

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Streams;
@ -44,16 +45,17 @@ import com.google.common.collect.Streams;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Resolve having clause to the aggregation.
*/
public class ResolveAggregateFunctions implements AnalysisRuleFactory {
public class FillUpMissingSlots implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.RESOLVE_SORT_AGGREGATE_FUNCTIONS.build(
RuleType.FILL_UP_SORT_AGGREGATE_FUNCTIONS.build(
logicalSort(logicalAggregate())
.when(sort -> sort.getExpressions().stream()
.anyMatch(e -> e.containsType(AggregateFunction.class)))
@ -72,7 +74,7 @@ public class ResolveAggregateFunctions implements AnalysisRuleFactory {
});
})
),
RuleType.RESOLVE_SORT_HAVING_AGGREGATE_FUNCTIONS.build(
RuleType.FILL_UP_SORT_HAVING_AGGREGATE_FUNCTIONS.build(
logicalSort(logicalHaving(logicalAggregate()))
.when(sort -> sort.getExpressions().stream()
.anyMatch(e -> e.containsType(AggregateFunction.class)))
@ -91,7 +93,7 @@ public class ResolveAggregateFunctions implements AnalysisRuleFactory {
});
})
),
RuleType.RESOLVE_HAVING_AGGREGATE_FUNCTIONS.build(
RuleType.FILL_UP_HAVING_AGGREGATE_FUNCTIONS.build(
logicalHaving(logicalAggregate()).then(having -> {
LogicalAggregate<GroupPlan> aggregate = having.child();
Resolver resolver = new Resolver(aggregate);
@ -102,6 +104,27 @@ public class ResolveAggregateFunctions implements AnalysisRuleFactory {
return new LogicalFilter<>(newPredicates, a);
});
})
),
RuleType.FILL_UP_SORT_PROJECT.build(
logicalSort(logicalProject())
.when(sort -> sort.getExpressions().stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.anyMatch(s -> !sort.child().getOutputSet().contains(s)))
.then(sort -> {
final Builder<NamedExpression> projectionsBuilder = ImmutableList.builder();
projectionsBuilder.addAll(sort.child().getProjects());
Set<Slot> notExistedInProject = sort.getExpressions().stream()
.map(Expression::getInputSlots)
.flatMap(Set::stream)
.filter(s -> !sort.child().getOutputSet().contains(s))
.collect(Collectors.toSet());
projectionsBuilder.addAll(notExistedInProject);
return new LogicalProject(sort.child().getOutput(),
new LogicalSort<>(sort.getOrderKeys(),
new LogicalProject<>(projectionsBuilder.build(),
sort.child().child())));
})
)
);
}

View File

@ -140,20 +140,20 @@ public class AnalyzeClickBenchTest extends ClickBenchTestBase {
checkAnalyze(ClickBenchUtils.Q23);
}
// @Test
// public void q24() {
// checkAnalyze(ClickBenchUtils.Q24);
// }
@Test
public void q24() {
checkAnalyze(ClickBenchUtils.Q24);
}
@Test
public void q25() {
checkAnalyze(ClickBenchUtils.Q25);
}
// @Test
// public void q26() {
// checkAnalyze(ClickBenchUtils.Q26);
// }
@Test
public void q26() {
checkAnalyze(ClickBenchUtils.Q26);
}
@Test
public void q27() {

View File

@ -182,11 +182,10 @@ public class AnalyzeTPCHTest extends TPCHTestBase {
checkAnalyze(TPCHUtils.Q21_rewrite);
}
// NOTE: not support '1 for 2' syntax
/*@Test
@Test
public void q22() {
checkAnalyze(TPCHUtils.Q22);
}*/
}
@Test
public void q22_rewrite() {

View File

@ -21,7 +21,7 @@ import org.apache.doris.utframe.TestWithFeService;
public class TPCHUtils {
public static final String Q1 = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=false, enable_cost_based_join_reorder=false, enable_projection=false) */\n"
public static final String Q1 = "select\n"
+ " l_returnflag,\n"
+ " l_linestatus,\n"
+ " sum(l_quantity) as sum_qty,\n"
@ -88,7 +88,7 @@ public class TPCHUtils {
+ " p_partkey\n"
+ "limit 100;";
public static final String Q2_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=1, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=false, enable_projection=true) */\n"
public static final String Q2_rewrite = "select\n"
+ " s_acctbal,\n"
+ " s_name,\n"
+ " n_name,\n"
@ -152,7 +152,7 @@ public class TPCHUtils {
+ " o_orderdate\n"
+ "limit 10;";
public static String Q3_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=false, enable_projection=true) */\n"
public static String Q3_rewrite = "select \n"
+ " l_orderkey,\n"
+ " sum(l_extendedprice * (1 - l_discount)) as revenue,\n"
+ " o_orderdate,\n"
@ -198,7 +198,7 @@ public class TPCHUtils {
+ "order by\n"
+ " o_orderpriority;";
public static String Q4_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=1, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=false, enable_projection=true) */\n"
public static String Q4_rewrite = "select \n"
+ " o_orderpriority,\n"
+ " count(*) as order_count\n"
+ "from\n"
@ -398,7 +398,7 @@ public class TPCHUtils {
+ " revenue desc\n"
+ "limit 20;";
public static final String Q11 = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=2, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=false, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q11 = "select \n"
+ " ps_partkey,\n"
+ " sum(ps_supplycost * ps_availqty) as value\n"
+ "from\n"
@ -454,7 +454,7 @@ public class TPCHUtils {
+ "order by\n"
+ " l_shipmode;";
public static final String Q12_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=2, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=false, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q12_rewrite = "select \n"
+ " l_shipmode,\n"
+ " sum(case\n"
+ " when o_orderpriority = '1-URGENT'\n"
@ -518,7 +518,7 @@ public class TPCHUtils {
+ " and l_shipdate >= date '1995-09-01'\n"
+ " and l_shipdate < date '1995-10-01';";
public static final String Q14_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q14_rewrite = "select \n"
+ " 100.00 * sum(case\n"
+ " when p_type like 'PROMO%'\n"
+ " then l_extendedprice * (1 - l_discount)\n"
@ -563,7 +563,7 @@ public class TPCHUtils {
+ "order by\n"
+ "\ts_suppkey;";
public static final String Q15_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=4, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=false, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q15_rewrite = "select \n"
+ " s_suppkey,\n"
+ " s_name,\n"
+ " s_address,\n"
@ -632,7 +632,7 @@ public class TPCHUtils {
+ " l_partkey = p_partkey\n"
+ " );";
public static final String Q17_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=1, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=false, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q17_rewrite = "select \n"
+ " sum(l_extendedprice) / 7.0 as avg_yearly\n"
+ "from\n"
+ " lineitem join [broadcast]\n"
@ -686,7 +686,7 @@ public class TPCHUtils {
+ " o_orderdate\n"
+ "limit 100;";
public static final String Q18_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q18_rewrite = "select \n"
+ " c_name,\n"
+ " c_custkey,\n"
+ " t3.o_orderkey,\n"
@ -799,7 +799,7 @@ public class TPCHUtils {
+ "order by\n"
+ " s_name;";
public static final String Q20_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q20_rewrite = "select \n"
+ "s_name, s_address from\n"
+ "supplier left semi join\n"
+ "(\n"
@ -866,7 +866,7 @@ public class TPCHUtils {
+ " s_name\n"
+ "limit 100;";
public static final String Q21_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q21_rewrite = "select \n"
+ "s_name, count(*) as numwait\n"
+ "from orders join\n"
+ "(\n"
@ -905,12 +905,12 @@ public class TPCHUtils {
+ "from\n"
+ " (\n"
+ " select\n"
+ " substring(c_phone from 1 for 2) as cntrycode,\n"
+ " substring(c_phone, 1, 2) as cntrycode,\n"
+ " c_acctbal\n"
+ " from\n"
+ " customer\n"
+ " where\n"
+ " substring(c_phone from 1 for 2) in\n"
+ " substring(c_phone, 1, 2) in\n"
+ " ('13', '31', '23', '29', '30', '18', '17')\n"
+ " and c_acctbal > (\n"
+ " select\n"
@ -919,7 +919,7 @@ public class TPCHUtils {
+ " customer\n"
+ " where\n"
+ " c_acctbal > 0.00\n"
+ " and substring(c_phone from 1 for 2) in\n"
+ " and substring(c_phone, 1, 2) in\n"
+ " ('13', '31', '23', '29', '30', '18', '17')\n"
+ " )\n"
+ " and not exists (\n"
@ -936,7 +936,7 @@ public class TPCHUtils {
+ "order by\n"
+ " cntrycode;";
public static final String Q22_rewrite = "select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=8, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=false, enable_cost_based_join_reorder=true, enable_projection=true) */\n"
public static final String Q22_rewrite = "select \n"
+ " cntrycode,\n"
+ " count(*) as numcust,\n"
+ " sum(c_acctbal) as totacctbal\n"

View File

@ -48,7 +48,7 @@ import org.junit.jupiter.api.Test;
import java.util.stream.Collectors;
public class ResolveAggregateFunctionsTest extends AnalyzeCheckTestBase implements PatternMatchSupported {
public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements PatternMatchSupported {
@Override
public void runBeforeAll() throws Exception {

View File

@ -61,11 +61,6 @@ public class LogicalPlanBuilder {
return from(scan);
}
public LogicalPlanBuilder projectWithExprs(List<NamedExpression> projectExprs) {
LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
return from(project);
}
public LogicalPlanBuilder project(List<Integer> slotsIndex) {
List<NamedExpression> projectExprs = Lists.newArrayList();
for (Integer index : slotsIndex) {
@ -143,26 +138,26 @@ public class LogicalPlanBuilder {
for (Integer index : outputExprsIndex) {
outputBuilder.add(this.plan.getOutput().get(index));
}
ImmutableList<NamedExpression> outputExpresList = outputBuilder.build();
ImmutableList<NamedExpression> outputExprsList = outputBuilder.build();
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
return from(agg);
}
public LogicalPlanBuilder aggGroupUsingIndex(List<Integer> groupByKeysIndex,
List<NamedExpression> outputExpresList) {
List<NamedExpression> outputExprsList) {
Builder<Expression> groupByBuilder = ImmutableList.builder();
for (Integer index : groupByKeysIndex) {
groupByBuilder.add(this.plan.getOutput().get(index));
}
ImmutableList<Expression> groupByKeys = groupByBuilder.build();
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
return from(agg);
}
public LogicalPlanBuilder agg(List<Expression> groupByKeys, List<NamedExpression> outputExpresList) {
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExpresList, this.plan);
public LogicalPlanBuilder agg(List<Expression> groupByKeys, List<NamedExpression> outputExprsList) {
LogicalAggregate<Plan> agg = new LogicalAggregate<>(groupByKeys, outputExprsList, this.plan);
return from(agg);
}
}