[fix](Nereids) column prune generate empty project list on join's child (#12486)
* [fix](Nereids) column prune generate empty project list on join's child
This commit is contained in:
@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
@ -56,7 +57,7 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
|
||||
return RuleType.COLUMN_PRUNE_AGGREGATION_CHILD.build(logicalAggregate().then(agg -> {
|
||||
List<Slot> childOutput = agg.child().getOutput();
|
||||
if (isAggregateWithConstant(agg)) {
|
||||
Slot slot = selectMinimumColumn(childOutput);
|
||||
Slot slot = ExpressionUtils.selectMinimumColumn(childOutput);
|
||||
if (childOutput.size() == 1 && childOutput.get(0).equals(slot)) {
|
||||
return agg;
|
||||
}
|
||||
@ -86,17 +87,4 @@ public class PruneAggChildColumns extends OneRewriteRuleFactory {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private Slot selectMinimumColumn(List<Slot> outputList) {
|
||||
Slot minSlot = null;
|
||||
for (Slot slot : outputList) {
|
||||
if (minSlot == null) {
|
||||
minSlot = slot;
|
||||
} else {
|
||||
int slotDataTypeWidth = slot.getDataType().width();
|
||||
minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot;
|
||||
}
|
||||
}
|
||||
return minSlot;
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
@ -76,6 +77,13 @@ public class PruneJoinChildrenColumns
|
||||
List<NamedExpression> rightInputs = joinPlan.right().getOutput().stream()
|
||||
.filter(r -> exprIds.contains(r.getExprId())).collect(Collectors.toList());
|
||||
|
||||
if (leftInputs.isEmpty()) {
|
||||
leftInputs.add(ExpressionUtils.selectMinimumColumn(joinPlan.left().getOutput()));
|
||||
}
|
||||
if (rightInputs.isEmpty()) {
|
||||
rightInputs.add(ExpressionUtils.selectMinimumColumn(joinPlan.right().getOutput()));
|
||||
}
|
||||
|
||||
Plan leftPlan = joinPlan.left();
|
||||
Plan rightPlan = joinPlan.right();
|
||||
|
||||
|
||||
@ -31,7 +31,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@ -75,8 +74,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
List<Expression> otherConditions = Lists.newArrayList();
|
||||
List<Expression> eqConditions = Lists.newArrayList();
|
||||
|
||||
List<Slot> leftInput = join.left().getOutput();
|
||||
List<Slot> rightInput = join.right().getOutput();
|
||||
Set<Slot> leftInput = join.left().getOutputSet();
|
||||
Set<Slot> rightInput = join.right().getOutputSet();
|
||||
|
||||
ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates))
|
||||
.forEach(predicate -> {
|
||||
@ -122,18 +121,18 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
Plan leftPlan = joinPlan.left();
|
||||
Plan rightPlan = joinPlan.right();
|
||||
if (!left.equals(BooleanLiteral.TRUE)) {
|
||||
leftPlan = new LogicalFilter(left, leftPlan);
|
||||
leftPlan = new LogicalFilter<>(left, leftPlan);
|
||||
}
|
||||
|
||||
if (!right.equals(BooleanLiteral.TRUE)) {
|
||||
rightPlan = new LogicalFilter(right, rightPlan);
|
||||
rightPlan = new LogicalFilter<>(right, rightPlan);
|
||||
}
|
||||
|
||||
return new LogicalJoin<>(joinPlan.getJoinType(), joinPlan.getHashJoinConjuncts(),
|
||||
Optional.of(ExpressionUtils.and(joinConditions)), leftPlan, rightPlan);
|
||||
}
|
||||
|
||||
private Expression getJoinCondition(Expression predicate, List<Slot> leftOutputs, List<Slot> rightOutputs) {
|
||||
private Expression getJoinCondition(Expression predicate, Set<Slot> leftOutputs, Set<Slot> rightOutputs) {
|
||||
if (!(predicate instanceof ComparisonPredicate)) {
|
||||
return null;
|
||||
}
|
||||
@ -147,11 +146,8 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
|
||||
return null;
|
||||
}
|
||||
|
||||
Set<Slot> left = Sets.newLinkedHashSet(leftOutputs);
|
||||
Set<Slot> right = Sets.newLinkedHashSet(rightOutputs);
|
||||
|
||||
if ((left.containsAll(leftSlots) && right.containsAll(rightSlots)) || (left.containsAll(rightSlots)
|
||||
&& right.containsAll(leftSlots))) {
|
||||
if ((leftOutputs.containsAll(leftSlots) && rightOutputs.containsAll(rightSlots))
|
||||
|| (leftOutputs.containsAll(rightSlots) && rightOutputs.containsAll(leftSlots))) {
|
||||
return predicate;
|
||||
}
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
|
||||
@ -142,4 +143,20 @@ public class ExpressionUtils {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Choose the minimum slot from input parameter.
|
||||
*/
|
||||
public static Slot selectMinimumColumn(List<Slot> slots) {
|
||||
Slot minSlot = null;
|
||||
for (Slot slot : slots) {
|
||||
if (minSlot == null) {
|
||||
minSlot = slot;
|
||||
} else {
|
||||
int slotDataTypeWidth = slot.getDataType().width();
|
||||
minSlot = minSlot.getDataType().width() > slotDataTypeWidth ? slot : minSlot;
|
||||
}
|
||||
}
|
||||
return minSlot;
|
||||
}
|
||||
}
|
||||
|
||||
@ -269,6 +269,28 @@ public class ColumnPruningTest extends TestWithFeService implements PatternMatch
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void pruneColumnForOneSideOnCrossJoin() {
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze("select id,name from student cross join score")
|
||||
.applyTopDown(new ColumnPruning())
|
||||
.matchesFromRoot(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject(logicalRelation())
|
||||
.when(p -> getOutputQualifiedNames(p)
|
||||
.containsAll(ImmutableList.of(
|
||||
"default_cluster:test.student.id",
|
||||
"default_cluster:test.student.name"))),
|
||||
logicalProject(logicalRelation())
|
||||
.when(p -> getOutputQualifiedNames(p)
|
||||
.containsAll(ImmutableList.of(
|
||||
"default_cluster:test.score.sid")))
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private List<String> getOutputQualifiedNames(LogicalProject<? extends Plan> p) {
|
||||
return p.getProjects().stream().map(NamedExpression::getQualifiedName).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user