diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java index 3e73850b01..5560c369dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/ApplyRuleJob.java @@ -61,7 +61,7 @@ public class ApplyRuleJob extends Job { } @Override - public void execute() throws AnalysisException { + public final void execute() throws AnalysisException { if (groupExpression.hasApplied(rule) || groupExpression.isUnused()) { return; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java index f73ddcc886..e281e74a33 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/pattern/GroupExpressionMatching.java @@ -55,6 +55,7 @@ public class GroupExpressionMatching implements Iterable { public static class GroupExpressionIterator implements Iterator { private final List results = Lists.newArrayList(); private int resultIndex = 0; + private int resultsSize; /** * Constructor. @@ -103,7 +104,7 @@ public class GroupExpressionMatching implements Iterable { // matching children group, one List per child // first dimension is every child group's plan // second dimension is all matched plan in one group - List> childrenPlans = Lists.newArrayListWithCapacity(childrenGroupArity); + List[] childrenPlans = new List[childrenGroupArity]; for (int i = 0; i < childrenGroupArity; ++i) { Group childGroup = groupExpression.child(i); List childrenPlan = matchingChildGroup(pattern, childGroup, i); @@ -116,7 +117,7 @@ public class GroupExpressionMatching implements Iterable { return; } } - childrenPlans.add(childrenPlan); + childrenPlans[i] = childrenPlan; } assembleAllCombinationPlanTree(root, pattern, groupExpression, childrenPlans); } else if (patternArity == 1 && (pattern.hasMultiChild() || pattern.hasMultiGroupChild())) { @@ -127,6 +128,7 @@ public class GroupExpressionMatching implements Iterable { results.add(root); } } + this.resultsSize = results.size(); } private List matchingChildGroup(Pattern parentPattern, @@ -154,38 +156,35 @@ public class GroupExpressionMatching implements Iterable { } private void assembleAllCombinationPlanTree(Plan root, Pattern rootPattern, - GroupExpression groupExpression, - List> childrenPlans) { - int[] childrenPlanIndex = new int[childrenPlans.size()]; + GroupExpression groupExpression, List[] childrenPlans) { + int childrenPlansSize = childrenPlans.length; + int[] childrenPlanIndex = new int[childrenPlansSize]; int offset = 0; LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties(); // assemble all combination of plan tree by current root plan and children plan - while (offset < childrenPlans.size()) { - ImmutableList.Builder childrenBuilder = - ImmutableList.builderWithExpectedSize(childrenPlans.size()); - for (int i = 0; i < childrenPlans.size(); i++) { - childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i])); + Optional groupExprOption = Optional.of(groupExpression); + Optional logicalPropOption = Optional.of(logicalProperties); + while (offset < childrenPlansSize) { + ImmutableList.Builder childrenBuilder = ImmutableList.builderWithExpectedSize(childrenPlansSize); + for (int i = 0; i < childrenPlansSize; i++) { + childrenBuilder.add(childrenPlans[i].get(childrenPlanIndex[i])); } List children = childrenBuilder.build(); // assemble children: replace GroupPlan to real plan, // withChildren will erase groupExpression, so we must // withGroupExpression too. - Plan rootWithChildren = root.withGroupExprLogicalPropChildren(Optional.of(groupExpression), - Optional.of(logicalProperties), children); + Plan rootWithChildren = root.withGroupExprLogicalPropChildren(groupExprOption, + logicalPropOption, children); if (rootPattern.matchPredicates(rootWithChildren)) { results.add(rootWithChildren); } - offset = 0; - while (true) { + for (offset = 0; offset < childrenPlansSize; offset++) { childrenPlanIndex[offset]++; - if (childrenPlanIndex[offset] == childrenPlans.get(offset).size()) { + if (childrenPlanIndex[offset] == childrenPlans[offset].size()) { + // Reset the index when it reaches the size of the current child plan list childrenPlanIndex[offset] = 0; - offset++; - if (offset == childrenPlans.size()) { - break; - } } else { break; } @@ -195,7 +194,7 @@ public class GroupExpressionMatching implements Iterable { @Override public boolean hasNext() { - return resultIndex < results.size(); + return resultIndex < resultsSize; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java index 0305ae2afa..7a545ec17b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/AbstractTreeNode.java @@ -17,10 +17,6 @@ package org.apache.doris.nereids.trees; -import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; -import org.apache.doris.nereids.trees.plans.ObjectId; -import org.apache.doris.planner.PlanNodeId; - import com.google.common.collect.ImmutableList; import java.util.List; @@ -33,7 +29,6 @@ import java.util.List; */ public abstract class AbstractTreeNode> implements TreeNode { - protected final ObjectId id = StatementScopeIdGenerator.newObjectId(); protected final List children; // TODO: Maybe we should use a GroupPlan to avoid TreeNode hold the GroupExpression. // https://github.com/apache/doris/pull/9807#discussion_r884829067 @@ -59,12 +54,4 @@ public abstract class AbstractTreeNode> public int arity() { return children.size(); } - - /** - * used for PhysicalPlanTranslator only - * @return PlanNodeId - */ - public PlanNodeId translatePlanNodeId() { - return id.toPlanNodeId(); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 3f0370d7c3..12a3a9768c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -37,6 +37,7 @@ import org.apache.doris.nereids.types.coercion.AnyDataType; import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import org.apache.commons.lang3.StringUtils; @@ -44,7 +45,6 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.Set; -import java.util.stream.Collectors; /** * Abstract class for all Expression in Nereids. @@ -247,8 +247,19 @@ public abstract class Expression extends AbstractTreeNode implements return collect(Slot.class::isInstance); } + /** + * Get all the input slot ids of the expression. + *

+ * Note that the input slots of subquery's inner plan is not included. + */ public final Set getInputSlotExprIds() { - return getInputSlots().stream().map(NamedExpression::getExprId).collect(Collectors.toSet()); + ImmutableSet.Builder result = ImmutableSet.builder(); + foreach(node -> { + if (node instanceof Slot) { + result.add(((Slot) node).getExprId()); + } + }); + return result.build(); } public boolean isLiteral() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java index 38a209ff55..c223dd43b6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java @@ -24,9 +24,11 @@ import org.apache.doris.nereids.properties.UnboundLogicalProperties; import org.apache.doris.nereids.trees.AbstractTreeNode; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.util.MutableState; import org.apache.doris.nereids.util.MutableState.EmptyMutableState; import org.apache.doris.nereids.util.TreeStringUtils; +import org.apache.doris.planner.PlanNodeId; import org.apache.doris.statistics.Statistics; import com.google.common.base.Supplier; @@ -45,6 +47,7 @@ import javax.annotation.Nullable; */ public abstract class AbstractPlan extends AbstractTreeNode implements Plan { public static final String FRAGMENT_ID = "fragment"; + protected final ObjectId id = StatementScopeIdGenerator.newObjectId(); protected final Statistics statistics; protected final PlanType type; @@ -168,4 +171,12 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla public void setMutableState(String key, Object state) { this.mutableState = this.mutableState.set(key, state); } + + /** + * used for PhysicalPlanTranslator only + * @return PlanNodeId + */ + public PlanNodeId translatePlanNodeId() { + return id.toPlanNodeId(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java index b60afd6730..994b4d4f97 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java @@ -43,9 +43,9 @@ import org.apache.doris.statistics.Statistics; import org.apache.doris.thrift.TRuntimeFilterType; import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; @@ -113,22 +113,25 @@ public class PhysicalHashJoin< * Return pair of left used slots and right used slots. */ public Pair, List> getHashConjunctsExprIds() { - List exprIds1 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size()); - List exprIds2 = Lists.newArrayListWithCapacity(hashJoinConjuncts.size()); + int size = hashJoinConjuncts.size(); + + List exprIds1 = new ArrayList<>(size); + List exprIds2 = new ArrayList<>(size); Set leftExprIds = left().getOutputExprIdSet(); Set rightExprIds = right().getOutputExprIdSet(); for (Expression expr : hashJoinConjuncts) { - expr.getInputSlotExprIds().forEach(exprId -> { + for (ExprId exprId : expr.getInputSlotExprIds()) { if (leftExprIds.contains(exprId)) { exprIds1.add(exprId); } else if (rightExprIds.contains(exprId)) { exprIds2.add(exprId); } else { - throw new RuntimeException("Could not generate valid equal on clause slot pairs for join"); + throw new RuntimeException("Invalid ExprId found: " + exprId + + ". Cannot generate valid equal on clause slot pairs for join."); } - }); + } } return Pair.of(exprIds1, exprIds2); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index d1fb973dd6..bcf53ce29f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.qe.ConnectContext; +import org.apache.doris.qe.SessionVariable; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -66,9 +67,10 @@ public class JoinUtils { * check if the row count of the left child in the broadcast join is less than a threshold value. */ public static boolean checkBroadcastJoinStats(PhysicalHashJoin join) { - double memLimit = ConnectContext.get().getSessionVariable().getMaxExecMemByte(); - double rowsLimit = ConnectContext.get().getSessionVariable().getBroadcastRowCountLimit(); - double brMemlimit = ConnectContext.get().getSessionVariable().getBroadcastHashtableMemLimitPercentage(); + SessionVariable sessionVariable = ConnectContext.get().getSessionVariable(); + double memLimit = sessionVariable.getMaxExecMemByte(); + double rowsLimit = sessionVariable.getBroadcastRowCountLimit(); + double brMemlimit = sessionVariable.getBroadcastHashtableMemLimitPercentage(); double datasize = join.getGroupExpression().get().child(1).getStatistics().computeSize(); double rowCount = join.getGroupExpression().get().child(1).getStatistics().getRowCount(); return rowCount <= rowsLimit && datasize <= memLimit * brMemlimit; @@ -114,12 +116,12 @@ public class JoinUtils { * @return true if the equal can be used as hash join condition */ public boolean isHashJoinCondition(EqualTo equalTo) { - Set equalLeft = equalTo.left().collect(Slot.class::isInstance); + Set equalLeft = equalTo.left().getInputSlots(); if (equalLeft.isEmpty()) { return false; } - Set equalRight = equalTo.right().collect(Slot.class::isInstance); + Set equalRight = equalTo.right().getInputSlots(); if (equalRight.isEmpty()) { return false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index 17034c15e6..48eb452a74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -55,7 +55,7 @@ public class PlanUtils { * normalize comparison predicate on a binary plan to its two sides are corresponding to the child's output. */ public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) { - Set slots = expression.left().collect(Slot.class::isInstance); + Set slots = expression.left().getInputSlots(); Set leftSlots = left.getOutputSet(); Set buffer = Sets.newHashSet(slots); buffer.removeAll(leftSlots); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java index 53a459859b..6a4d38b5ad 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/pattern/GroupExpressionMatchingTest.java @@ -42,10 +42,10 @@ import org.junit.jupiter.api.Test; import java.util.Iterator; -public class GroupExpressionMatchingTest { +class GroupExpressionMatchingTest { @Test - public void testLeafNode() { + void testLeafNode() { Pattern pattern = new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION); Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); @@ -61,7 +61,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testDepth2() { + void testDepth2() { Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT, new Pattern<>(PlanType.LOGICAL_UNBOUND_RELATION)); @@ -93,7 +93,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testDepth2WithGroup() { + void testDepth2WithGroup() { Pattern pattern = new Pattern<>(PlanType.LOGICAL_PROJECT, Pattern.GROUP); Plan leaf = new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test")); @@ -119,7 +119,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testLeafAny() { + void testLeafAny() { Pattern pattern = Pattern.ANY; Memo memo = new Memo(null, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), Lists.newArrayList("test"))); @@ -135,7 +135,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testAnyWithChild() { + void testAnyWithChild() { Plan root = new LogicalProject( ImmutableList.of(new SlotReference("name", StringType.INSTANCE, true, ImmutableList.of("test"))), @@ -159,7 +159,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testInnerLogicalJoinMatch() { + void testInnerLogicalJoinMatch() { Plan root = new LogicalJoin(JoinType.INNER_JOIN, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")), new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b")) @@ -181,7 +181,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testInnerLogicalJoinMismatch() { + void testInnerLogicalJoinMismatch() { Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")), new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b")) @@ -198,7 +198,7 @@ public class GroupExpressionMatchingTest { } @Test - public void testTopMatchButChildrenNotMatch() { + void testTopMatchButChildrenNotMatch() { Plan root = new LogicalJoin(JoinType.LEFT_OUTER_JOIN, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")), new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b")) @@ -216,12 +216,12 @@ public class GroupExpressionMatchingTest { } @Test - public void testSubTreeMatch() { + void testSubTreeMatch() { Plan root = - new LogicalFilter(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")), + new LogicalFilter<>(ImmutableSet.of(new EqualTo(new UnboundSlot(Lists.newArrayList("a", "id")), new UnboundSlot(Lists.newArrayList("b", "id")))), - new LogicalJoin(JoinType.INNER_JOIN, - new LogicalJoin(JoinType.LEFT_OUTER_JOIN, + new LogicalJoin<>(JoinType.INNER_JOIN, + new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN, new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("a")), new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("b"))), new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("c"))) diff --git a/regression-test/suites/nereids_p0/expression/topn_to_max.groovy b/regression-test/suites/nereids_p0/expression/topn_to_max.groovy index ae848b5a24..4c05b42ccc 100644 --- a/regression-test/suites/nereids_p0/expression/topn_to_max.groovy +++ b/regression-test/suites/nereids_p0/expression/topn_to_max.groovy @@ -31,7 +31,7 @@ suite("test_topn_to_max") { group by k1; ''' res = sql ''' - explain rewritten plan select k1, max(k2) + explain rewritten plan select k1, topn(k2, 1) from test_topn_to_max group by k1; ''' @@ -42,7 +42,7 @@ suite("test_topn_to_max") { from test_topn_to_max; ''' res = sql ''' - explain rewritten plan select max(k2) + explain rewritten plan select topn(k2, 1) from test_topn_to_max; ''' assertTrue(res.toString().contains("max"))