From a04f9814fe5a09c0d9e9399fe71cc4d765f8bff1 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Fri, 9 Sep 2022 10:43:57 +0800 Subject: [PATCH] [fix](Nereids) column prune generate empty project list on join's child (#12486) * [fix](Nereids) column prune generate empty project list on join's child --- .../rewrite/logical/PruneAggChildColumns.java | 16 ++------------ .../logical/PruneJoinChildrenColumns.java | 8 +++++++ .../logical/PushPredicateThroughJoin.java | 18 ++++++--------- .../doris/nereids/util/ExpressionUtils.java | 17 ++++++++++++++ .../rewrite/logical/ColumnPruningTest.java | 22 +++++++++++++++++++ 5 files changed, 56 insertions(+), 25 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java index b84d885e3a..05047c7ad8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneAggChildColumns.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; @@ -56,7 +57,7 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory { return RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> { List childOutput = agg.child().getOutput(); if (isAggregateWithConstant(agg)) { - Slot slot = selectMinimumColumn(childOutput); + Slot slot = ExpressionUtils.selectMinimumColumn(childOutput); if (childOutput.size() == 1 && childOutput.get(0).equals(slot)) { return agg; } @@ -86,17 +87,4 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory { } return true; } - - private Slot selectMinimumColumn(List outputList) { - Slot minSlot = null; - for (Slot slot : outputList) { - if (minSlot == null) { - minSlot = slot; - } else { - int slotDataTypeWidth = slot.getDataType().width(); - minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot; - } - } - return minSlot; - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java index e76a832c2d..33dbc63057 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PruneJoinChildrenColumns.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; @@ -76,6 +77,13 @@ public class PruneJoinChildrenColumns List rightInputs = joinPlan.right().getOutput().stream() .filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList()); + if (leftInputs.isEmpty()) { + leftInputs.add(ExpressionUtils.selectMinimumColumn(joinPlan.left().getOutput())); + } + if (rightInputs.isEmpty()) { + rightInputs.add(ExpressionUtils.selectMinimumColumn(joinPlan.right().getOutput())); + } + Plan leftPlan = joinPlan.left(); Plan rightPlan = joinPlan.right(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java index 70e75adc98..6923e0509e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java @@ -31,7 +31,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.Lists; -import com.google.common.collect.Sets; import java.util.List; import java.util.Objects; @@ -75,8 +74,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { List otherConditions = Lists.newArrayList(); List eqConditions = Lists.newArrayList(); - List leftInput = join.left().getOutput(); - List rightInput = join.right().getOutput(); + Set leftInput = join.left().getOutputSet(); + Set rightInput = join.right().getOutputSet(); ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates)) .forEach(predicate -> { @@ -122,18 +121,18 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { Plan leftPlan = joinPlan.left(); Plan rightPlan = joinPlan.right(); if (!left.equals(BooleanLiteral.TRUE)) { - leftPlan = new LogicalFilter(left, leftPlan); + leftPlan = new LogicalFilter<>(left, leftPlan); } if (!right.equals(BooleanLiteral.TRUE)) { - rightPlan = new LogicalFilter(right, rightPlan); + rightPlan = new LogicalFilter<>(right, rightPlan); } return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(), Optional.of(ExpressionUtils.and(joinConditions)), leftPlan, rightPlan); } - private Expression getJoinCondition(Expression predicate, List leftOutputs, List rightOutputs) { + private Expression getJoinCondition(Expression predicate, Set leftOutputs, Set rightOutputs) { if (!(predicate instanceof ComparisonPredicate)) { return null; } @@ -147,11 +146,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { return null; } - Set left = Sets.newLinkedHashSet(leftOutputs); - Set right = Sets.newLinkedHashSet(rightOutputs); - - if ((left.containsAll(leftSlots) && right.containsAll(rightSlots)) || (left.containsAll(rightSlots) - && right.containsAll(leftSlots))) { + if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots)) + || (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) { return predicate; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index a716e3e43f..07ba450ea2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; @@ -142,4 +143,20 @@ public class ExpressionUtils { } return false; } + + /** + * Choose the minimum slot from input parameter. + */ + public static Slot selectMinimumColumn(List slots) { + Slot minSlot = null; + for (Slot slot : slots) { + if (minSlot == null) { + minSlot = slot; + } else { + int slotDataTypeWidth = slot.getDataType().width(); + minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot; + } + } + return minSlot; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java index c26fa86462..bf2daa712e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ColumnPruningTest.java @@ -269,6 +269,28 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch ); } + @Test + public void pruneColumnForOneSideOnCrossJoin() { + PlanChecker.from(connectContext) + .analyze("select id,name from student cross join score") + .applyTopDown(new ColumnPruning()) + .matchesFromRoot( + logicalProject( + logicalJoin( + logicalProject(logicalRelation()) + .when(p -> getOutputQualifiedNames(p) + .containsAll(ImmutableList.of( + "default_cluster:test.student.id", + "default_cluster:test.student.name"))), + logicalProject(logicalRelation()) + .when(p -> getOutputQualifiedNames(p) + .containsAll(ImmutableList.of( + "default_cluster:test.score.sid"))) + ) + ) + ); + } + private List getOutputQualifiedNames(LogicalProject p) { return p.getProjects().stream().map(NamedExpression::getQualifiedName).collect(Collectors.toList()); }