[fix](Nereids) column prune generate empty project list on join's child (#12486)

* [fix](Nereids) column prune generate empty project list on join's child
This commit is contained in:
morrySnow
2022-09-09 10:43:57 +08:00
committed by GitHub
parent f98ec06783
commit a04f9814fe
5 changed files with 56 additions and 25 deletions

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
@ -56,7 +57,7 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
return RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
List<Slot> childOutput = agg.child().getOutput();
if (isAggregateWithConstant(agg)) {
Slot slot = selectMinimumColumn(childOutput);
Slot slot = ExpressionUtils.selectMinimumColumn(childOutput);
if (childOutput.size() == 1 && childOutput.get(0).equals(slot)) {
return agg;
}
@ -86,17 +87,4 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
}
return true;
}
private Slot selectMinimumColumn(List<Slot> outputList) {
Slot minSlot = null;
for (Slot slot : outputList) {
if (minSlot == null) {
minSlot = slot;
} else {
int slotDataTypeWidth = slot.getDataType().width();
minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot;
}
}
return minSlot;
}
}

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
@ -76,6 +77,13 @@ public class PruneJoinChildrenColumns
List<NamedExpression> rightInputs = joinPlan.right().getOutput().stream()
.filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList());
if (leftInputs.isEmpty()) {
leftInputs.add(ExpressionUtils.selectMinimumColumn(joinPlan.left().getOutput()));
}
if (rightInputs.isEmpty()) {
rightInputs.add(ExpressionUtils.selectMinimumColumn(joinPlan.right().getOutput()));
}
Plan leftPlan = joinPlan.left();
Plan rightPlan = joinPlan.right();

View File

@ -31,7 +31,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Objects;
@ -75,8 +74,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
List<Expression> otherConditions = Lists.newArrayList();
List<Expression> eqConditions = Lists.newArrayList();
List<Slot> leftInput = join.left().getOutput();
List<Slot> rightInput = join.right().getOutput();
Set<Slot> leftInput = join.left().getOutputSet();
Set<Slot> rightInput = join.right().getOutputSet();
ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates))
.forEach(predicate -> {
@ -122,18 +121,18 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
Plan leftPlan = joinPlan.left();
Plan rightPlan = joinPlan.right();
if (!left.equals(BooleanLiteral.TRUE)) {
leftPlan = new LogicalFilter(left, leftPlan);
leftPlan = new LogicalFilter<>(left, leftPlan);
}
if (!right.equals(BooleanLiteral.TRUE)) {
rightPlan = new LogicalFilter(right, rightPlan);
rightPlan = new LogicalFilter<>(right, rightPlan);
}
return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(),
Optional.of(ExpressionUtils.and(joinConditions)), leftPlan, rightPlan);
}
private Expression getJoinCondition(Expression predicate, List<Slot> leftOutputs, List<Slot> rightOutputs) {
private Expression getJoinCondition(Expression predicate, Set<Slot> leftOutputs, Set<Slot> rightOutputs) {
if (!(predicate instanceof ComparisonPredicate)) {
return null;
}
@ -147,11 +146,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
return null;
}
Set<Slot> left = Sets.newLinkedHashSet(leftOutputs);
Set<Slot> right = Sets.newLinkedHashSet(rightOutputs);
if ((left.containsAll(leftSlots) && right.containsAll(rightSlots)) || (left.containsAll(rightSlots)
&& right.containsAll(leftSlots))) {
if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots))
|| (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) {
return predicate;
}

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
@ -142,4 +143,20 @@ public class ExpressionUtils {
}
return false;
}
/**
* Choose the minimum slot from input parameter.
*/
public static Slot selectMinimumColumn(List<Slot> slots) {
Slot minSlot = null;
for (Slot slot : slots) {
if (minSlot == null) {
minSlot = slot;
} else {
int slotDataTypeWidth = slot.getDataType().width();
minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot;
}
}
return minSlot;
}
}

View File

@ -269,6 +269,28 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
);
}
@Test
public void pruneColumnForOneSideOnCrossJoin() {
PlanChecker.from(connectContext)
.analyze("select id,name from student cross join score")
.applyTopDown(new ColumnPruning())
.matchesFromRoot(
logicalProject(
logicalJoin(
logicalProject(logicalRelation())
.when(p -> getOutputQualifiedNames(p)
.containsAll(ImmutableList.of(
"default_cluster:test.student.id",
"default_cluster:test.student.name"))),
logicalProject(logicalRelation())
.when(p -> getOutputQualifiedNames(p)
.containsAll(ImmutableList.of(
"default_cluster:test.score.sid")))
)
)
);
}
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
return p.getProjects().stream().map(NamedExpression::getQualifiedName).collect(Collectors.toList());
}