[fix](nereids) subquery unnesting gets wrong result if a correlated conjunct is not of the form slot_a = slot_b (#37683)

pick from master https://github.com/apache/doris/pull/37644

## Proposed changes

Issue Number: close #xxx

<!--Describe your changes.-->
This commit is contained in:
starocean999
2024-07-16 15:06:40 +08:00
committed by GitHub
parent 02716598d4
commit 80ea98b371
8 changed files with 206 additions and 15 deletions

View File

@ -200,7 +200,11 @@ public class Rewriter extends AbstractBatchJobExecutor {
* TODO: group these rules to make sure the result plan is what we expected.
*/
new CorrelateApplyToUnCorrelateApply(),
new ApplyToJoin()
new ApplyToJoin(),
// The UnCorrelatedApplyAggregateFilter rule will create new aggregate outputs.
// The later rule CheckPrivileges, which inherits from ColumnPruning, only works
// if the aggregation node is normalized, so we need to call NormalizeAggregate here
new NormalizeAggregate()
)
),
// before `Subquery unnesting` topic, some correlate slots should have appeared at LogicalApply.left,

View File

@ -75,6 +75,11 @@ class SubExprAnalyzer<T> extends DefaultExpressionRewriter<T> {
@Override
public Expression visitExistsSubquery(Exists exists, T context) {
LogicalPlan queryPlan = exists.getQueryPlan();
// distinct is useless, remove it
if (queryPlan instanceof LogicalProject && ((LogicalProject) queryPlan).isDistinct()) {
exists = exists.withSubquery(((LogicalProject) queryPlan).withDistinct(false));
}
AnalyzedResult analyzedResult = analyzeSubquery(exists);
if (analyzedResult.rootIsLimitZero()) {
return BooleanLiteral.of(exists.isNot());
@ -89,6 +94,11 @@ class SubExprAnalyzer<T> extends DefaultExpressionRewriter<T> {
@Override
public Expression visitInSubquery(InSubquery expr, T context) {
LogicalPlan queryPlan = expr.getQueryPlan();
// distinct is useless, remove it
if (queryPlan instanceof LogicalProject && ((LogicalProject) queryPlan).isDistinct()) {
expr = expr.withSubquery(((LogicalProject) queryPlan).withDistinct(false));
}
AnalyzedResult analyzedResult = analyzeSubquery(expr);
checkOutputColumn(analyzedResult.getLogicalPlan());

View File

@ -19,8 +19,10 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
@ -31,6 +33,7 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.List;
@ -97,10 +100,19 @@ public class UnCorrelatedApplyAggregateFilter implements RewriteRuleFactory {
// pull up correlated filter into apply node
List<NamedExpression> newAggOutput = new ArrayList<>(agg.getOutputExpressions());
List<Expression> newGroupby =
Utils.getCorrelatedSlots(correlatedPredicate, apply.getCorrelationSlot());
Utils.getUnCorrelatedExprs(correlatedPredicate, apply.getCorrelationSlot());
newGroupby.addAll(agg.getGroupByExpressions());
newAggOutput.addAll(newGroupby.stream().map(NamedExpression.class::cast)
.collect(ImmutableList.toImmutableList()));
Map<Expression, Slot> unCorrelatedExprToSlot = Maps.newHashMap();
for (Expression expression : newGroupby) {
if (expression instanceof Slot) {
newAggOutput.add((NamedExpression) expression);
} else {
Alias alias = new Alias(expression);
unCorrelatedExprToSlot.put(expression, alias.toSlot());
newAggOutput.add(alias);
}
}
correlatedPredicate = ExpressionUtils.replace(correlatedPredicate, unCorrelatedExprToSlot);
LogicalAggregate newAgg = new LogicalAggregate<>(newGroupby, newAggOutput,
PlanUtils.filterOrSelf(ImmutableSet.copyOf(unCorrelatedPredicate), filter.child()));
return new LogicalApply<>(apply.getCorrelationSlot(), apply.getSubqueryExpr(),

View File

@ -199,6 +199,10 @@ public class LogicalProject<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_
return new LogicalProject<>(projects, excepts, isDistinct, canEliminate, ImmutableList.of(child));
}
/**
 * Builds a copy of this project whose distinct flag is replaced by the given value.
 * The projects, excepts, canEliminate flag and children of this node are passed through unchanged.
 *
 * @param isDistinct distinct flag of the returned project
 * @return a new {@code LogicalProject} carrying the requested distinct flag
 */
public LogicalProject<Plan> withDistinct(boolean isDistinct) {
    // the parameter shadows the field of the same name, so the new node receives the requested flag
    LogicalProject<Plan> replaced = new LogicalProject<>(
            projects, excepts, isDistinct, canEliminate, children);
    return replaced;
}
/**
 * Reports the distinct flag stored on this project.
 *
 * @return the value of the {@code isDistinct} field
 */
public boolean isDistinct() {
    return this.isDistinct;
}

View File

@ -17,8 +17,11 @@
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
@ -178,18 +181,51 @@ public class Utils {
}
/**
* Get the correlated columns that belong to the subquery,
* that is, the correlated columns that can be resolved within the subquery.
* Get the unCorrelated exprs that belong to the subquery,
* that is, the unCorrelated exprs that can be resolved within the subquery.
* eg:
* select * from t1 where t1.a = (select sum(t2.b) from t2 where t1.c = t2.d));
* correlatedPredicates : t1.c = t2.d
* correlatedSlots : t1.c
* return t2.d
* select * from t1 where t1.a = (select sum(t2.b) from t2 where t1.c = abs(t2.d));
* correlatedPredicates : t1.c = abs(t2.d)
* unCorrelatedExprs : abs(t2.d)
* return abs(t2.d)
*/
public static List<Expression> getCorrelatedSlots(List<Expression> correlatedPredicates,
List<Expression> correlatedSlots) {
return ExpressionUtils.getInputSlotSet(correlatedPredicates).stream()
.filter(slot -> !correlatedSlots.contains(slot)).collect(Collectors.toList());
/**
 * Collects, for every correlated predicate, the operand that does not reference any correlated slot.
 *
 * <p>Each predicate must be a {@link BinaryExpression} or a {@link Not} whose child is one. Exactly one
 * operand must reference correlated slots only (and at least one of them), while the other operand must
 * reference none of them. That other operand is collected, unless it references no slot at all
 * (for example a literal), in which case nothing is collected for the predicate.
 *
 * @param correlatedPredicates predicates that reference slots in {@code correlatedSlots}
 * @param correlatedSlots the correlated slots; membership is tested with {@code List.contains}
 * @return a new mutable list of the collected operands in first-seen order without duplicates;
 *         callers may append to it
 * @throws AnalysisException if a predicate is not (the negation of) a binary expression, or if neither
 *         operand consists purely of correlated slots while the other one is free of them
 */
public static List<Expression> getUnCorrelatedExprs(List<Expression> correlatedPredicates,
        List<Expression> correlatedSlots) {
    List<Expression> unCorrelatedExprs = new ArrayList<>();
    for (Expression predicate : correlatedPredicates) {
        // look through a single NOT so that NOT (a op b) is handled like (a op b)
        Expression comparison = predicate instanceof Not ? ((Not) predicate).child() : predicate;
        if (!(comparison instanceof BinaryExpression)) {
            throw new AnalysisException(
                    "Unsupported correlated subquery with correlated predicate " + predicate);
        }
        BinaryExpression binaryExpression = (BinaryExpression) comparison;
        Expression left = binaryExpression.left();
        Expression right = binaryExpression.right();
        Set<Slot> leftInputSlots = left.getInputSlots();
        Set<Slot> rightInputSlots = right.getInputSlots();
        // the two flags are mutually exclusive: each one requires the opposite side to be free of
        // correlated slots while its own side contains at least one
        boolean correlatedToLeft = isCorrelatedOperand(leftInputSlots, rightInputSlots, correlatedSlots);
        boolean correlatedToRight = isCorrelatedOperand(rightInputSlots, leftInputSlots, correlatedSlots);
        if (!correlatedToLeft && !correlatedToRight) {
            throw new AnalysisException(
                    "Unsupported correlated subquery with correlated predicate " + predicate);
        }
        Expression unCorrelated = correlatedToLeft ? right : left;
        Set<Slot> unCorrelatedInputSlots = correlatedToLeft ? rightInputSlots : leftInputSlots;
        // operands without any slot are not collected; an operand shared by several predicates is
        // collected once so the result does not carry repeated entries
        if (!unCorrelatedInputSlots.isEmpty() && !unCorrelatedExprs.contains(unCorrelated)) {
            unCorrelatedExprs.add(unCorrelated);
        }
    }
    return unCorrelatedExprs;
}

/** Returns true if {@code side} is non-empty and purely correlated while {@code otherSide} has no correlated slot. */
private static boolean isCorrelatedOperand(Set<Slot> side, Set<Slot> otherSide,
        List<Expression> correlatedSlots) {
    return !side.isEmpty()
            && side.stream().allMatch(correlatedSlots::contains)
            && otherSide.stream().noneMatch(correlatedSlots::contains);
}
private static List<Expression> collectCorrelatedSlotsFromChildren(

View File

@ -0,0 +1,54 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_simple_scalar --
-2 -2
2 2
3 2
-- !select_complex_scalar --
2 2
3 2
-- !select_simple_in --
1 1
2 1
-- !select_complex_in --
1 1
2 1
-- !select_simple_not_in --
-2 -2
-1 -1
1 1
2 1
2 2
3 2
-- !select_complex_not_in --
-2 -2
-1 -1
1 1
2 1
2 2
3 2
-- !select_simple_exists --
-2 -2
2 2
3 2
-- !select_complex_exists --
2 2
3 2
-- !select_simple_not_exists --
-1 -1
1 1
2 1
-- !select_complex_not_exists --
-2 -2
-1 -1
1 1
2 1

View File

@ -193,7 +193,7 @@ suite ("sub_query_diff_old_optimize") {
sql """
SELECT DISTINCT k1 FROM sub_query_diff_old_optimize_subquery1 i1 WHERE ((SELECT count(*) FROM sub_query_diff_old_optimize_subquery1 WHERE ((k1 = i1.k1) AND (k2 = 2)) or ((k2 = i1.k1) AND (k2 = 1)) ) > 0);
"""
exception "scalar subquery's correlatedPredicates's operator must be EQ"
exception "Unsupported correlated subquery with correlated predicate"
}
}

View File

@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression suite for correlated subqueries whose correlated conjunct compares an outer slot
// with an expression over the subquery's own columns (e.g. t2.c1 + t2.id = t1.c1).
suite("test_subquery_conjunct") {
// Session settings: turn the nereids planner on and turn fallback to the original planner off.
sql "set enable_nereids_planner=true"
sql "set enable_fallback_to_original_planner=false"
// Recreate the single two-column table that every query below joins against itself (t1 outer, t2 inner).
sql """drop table if exists subquery_conjunct_table;"""
sql """CREATE TABLE `subquery_conjunct_table` (
`id` INT NOT NULL,
`c1` INT NOT NULL
) ENGINE=OLAP
DUPLICATE KEY(`id`, `c1`)
DISTRIBUTED BY RANDOM BUCKETS AUTO
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);"""
sql """insert into subquery_conjunct_table values(1, 1),(2,2),(-1,-1),(-2,-2),(2,1),(3,2);"""
// Positive cases. Each "simple" query uses an arithmetic expression on the subquery side of the
// correlated conjunct; the matching "complex" query wraps that expression in abs().
// Scalar subquery.
qt_select_simple_scalar """select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_scalar """select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
// IN subquery.
qt_select_simple_in """select * from subquery_conjunct_table t1 where abs(t1.c1) in (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id -1 = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_in """select * from subquery_conjunct_table t1 where abs(t1.c1) in (select c1 from subquery_conjunct_table t2 where abs(t2.c1+ t2.id -1) = t1.c1) order by t1.id, t1.c1;"""
// NOT IN subquery.
qt_select_simple_not_in """select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_not_in """select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select c1 from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
// EXISTS subquery.
qt_select_simple_exists """select * from subquery_conjunct_table t1 where exists (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_exists """select * from subquery_conjunct_table t1 where exists (select c1 from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
// NOT EXISTS subquery.
qt_select_simple_not_exists """select * from subquery_conjunct_table t1 where not exists (select c1 from subquery_conjunct_table t2 where t2.c1 + t2.id = t1.c1) order by t1.id, t1.c1;"""
qt_select_complex_not_exists """select * from subquery_conjunct_table t1 where not exists (select c1 from subquery_conjunct_table t2 where abs(t2.c1 + t2.id) = t1.c1) order by t1.id, t1.c1;"""
// Negative cases: each query is expected to fail with the given message.
// One operand of the conjunct mixes an outer slot (t1.c1) with a subquery slot (t2.c1).
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1) - t1.c1 = 0) order by t1.id; """
exception "Unsupported correlated subquery with correlated predicate"
}
// Same situation: the left operand combines t2.c1 with the outer slot t1.id.
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) != ( select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1 -1) + t1.id = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with correlated predicate"
}
// Scalar subquery whose correlated conjunct uses '>' instead of '='.
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) != (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1) > t1.c1) order by t1.id; """
exception "scalar subquery's correlatedPredicates's operator must be EQ"
}
// IN / NOT IN over a correlated subquery that aggregates with sum().
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) in (select sum(c1) from subquery_conjunct_table t2 where t2.c1 + 1 = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) in (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1) = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select sum(c1) from subquery_conjunct_table t2 where t2.c1 + 1= t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
test {
sql """ select * from subquery_conjunct_table t1 where abs(t1.c1) not in (select sum(c1) from subquery_conjunct_table t2 where abs(t2.c1 -1) = t1.c1) order by t1.id, t1.c1; """
exception "Unsupported correlated subquery with grouping and/or aggregation"
}
}