From 2f89ec961f614789853821b993c0f889d8dd4478 Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Wed, 13 Mar 2024 14:07:40 +0800 Subject: [PATCH] [enhancement](Nereids) Optimize expression (#32067) 1. optimize expression comparison by a) flatter method call stack b) short circuit: if some simple field not equals, then return, and not compare to the big field, like Alias.name 2. lazy compute Alias.name which is computed by toSql. Now Alias.toSlot() will not generate long name immediately 3. cache Expresssion.inputSlots, it can save time when invoke this method multiple times 4. always compute Expression.unbound, it can avoid traverse the big expression tree this pr can save about 200ms when submit some long sqls --- .../rules/analysis/FillUpMissingSlots.java | 8 +-- .../rules/analysis/NormalizeAggregate.java | 6 +- .../rules/analysis/NormalizeRepeat.java | 12 ++-- .../rules/rewrite/NormalizeToSlot.java | 28 +++++--- .../apache/doris/nereids/trees/TreeNode.java | 4 +- .../nereids/trees/expressions/Alias.java | 50 ++++++++------ .../doris/nereids/trees/expressions/Any.java | 4 ++ .../trees/expressions/ArrayItemReference.java | 4 +- .../trees/expressions/BinaryOperator.java | 12 ---- .../nereids/trees/expressions/Expression.java | 69 ++++++++++++++----- .../expressions/MarkJoinSlotReference.java | 2 +- .../trees/expressions/SlotReference.java | 54 ++++++++++----- .../expressions/VirtualSlotReference.java | 6 +- .../nereids/trees/expressions/WhenClause.java | 15 ---- .../trees/expressions/WindowExpression.java | 2 +- .../expressions/functions/BoundFunction.java | 26 ++++--- .../nereids/trees/plans/AbstractPlan.java | 14 ++-- .../doris/nereids/trees/plans/Plan.java | 8 ++- .../doris/nereids/util/ExpressionUtils.java | 16 +++-- .../apache/doris/nereids/util/PlanUtils.java | 59 ++++++++++------ .../analysis/FillUpMissingSlotsTest.java | 2 +- 21 files changed, 233 insertions(+), 168 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index 5688b9d48b..d6c783bbe9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -66,7 +66,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { .flatMap(Set::stream) .filter(s -> !projectOutputSet.contains(s)) .collect(Collectors.toSet()); - if (notExistedInProject.size() == 0) { + if (notExistedInProject.isEmpty()) { return null; } List projects = ImmutableList.builder() @@ -128,7 +128,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { .flatMap(Set::stream) .filter(s -> !childOutput.contains(s)) .collect(Collectors.toSet()); - if (notExistedInProject.size() == 0) { + if (notExistedInProject.isEmpty()) { return null; } LogicalProject project = sort.child().child(); @@ -165,7 +165,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { .flatMap(Set::stream) .filter(s -> !projectOutputSet.contains(s)) .collect(Collectors.toSet()); - if (notExistedInProject.size() == 0) { + if (notExistedInProject.isEmpty()) { return null; } List projects = ImmutableList.builder() @@ -189,7 +189,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { outputExpressions = aggregate.getOutputExpressions(); groupByExpressions = aggregate.getGroupByExpressions(); outputSubstitutionMap = outputExpressions.stream().filter(Alias.class::isInstance) - .collect(Collectors.toMap(alias -> alias.toSlot(), alias -> alias.child(0), + .collect(Collectors.toMap(NamedExpression::toSlot, alias -> alias.child(0), (k1, k2) -> k1)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 5874c26e17..a8c3445261 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -37,12 +37,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.HashSet; @@ -144,8 +143,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { // collect all trival-agg List aggregateOutput = aggregate.getOutputExpressions(); - List aggFuncs = Lists.newArrayList(); - aggregateOutput.forEach(o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); + List aggFuncs = CollectNonWindowedAggFuncs.collect(aggregateOutput); // split non-distinct agg child as two part // TRUE part 1: need push down itself, if it contains subqury or window expression diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 8437dc40b0..9451cc40bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -39,11 +39,10 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; @@ -174,9 +173,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { .flatMap(function -> function.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); - List aggregateFunctions = Lists.newArrayList(); - repeat.getOutputExpressions().forEach( - o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions)); + List aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions()); ImmutableSet argumentsOfAggregateFunction = aggregateFunctions.stream() .flatMap(function -> function.getArguments().stream().map(arg -> { @@ -271,9 +268,8 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { @NotNull LogicalAggregate aggregate) { LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); - List aggregateFunctions = Lists.newArrayList(); - aggregate.getOutputExpressions().forEach( - o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions)); + List aggregateFunctions = + CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions()); Set aggUsedSlots = aggregateFunctions.stream() .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java index efad94c665..683841a5f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java @@ -29,13 +29,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; -import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiFunction; -import java.util.stream.Collectors; import javax.annotation.Nullable; /** @@ -119,9 +117,11 @@ public interface NormalizeToSlot { public List normalizeToUseSlotRefWithoutWindowFunction( Collection expressions) { - return expressions.stream() - .map(e -> (E) e.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap)) - .collect(Collectors.toList()); + ImmutableList.Builder normalized = ImmutableList.builderWithExpectedSize(expressions.size()); + for (E expression : expressions) { + normalized.add((E) expression.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap)); + } + return normalized.build(); } /** @@ -155,8 +155,9 @@ public interface NormalizeToSlot { @Override public Expression visit(Expression expr, Map replaceMap) { - if (replaceMap.containsKey(expr)) { - return replaceMap.get(expr).remainExpr; + NormalizeToSlotTriplet triplet = replaceMap.get(expr); + if (triplet != null) { + return triplet.remainExpr; } return super.visit(expr, replaceMap); } @@ -164,10 +165,12 @@ public interface NormalizeToSlot { @Override public Expression visitWindow(WindowExpression windowExpression, Map replaceMap) { - if (replaceMap.containsKey(windowExpression)) { - return replaceMap.get(windowExpression).remainExpr; + NormalizeToSlotTriplet triplet = replaceMap.get(windowExpression); + if (triplet != null) { + return triplet.remainExpr; } - List newChildren = new ArrayList<>(); + ImmutableList.Builder newChildren = + ImmutableList.builderWithExpectedSize(windowExpression.arity()); Expression function = super.visit(windowExpression.getFunction(), replaceMap); newChildren.add(function); boolean hasNewChildren = function != windowExpression.getFunction(); @@ -185,10 +188,13 @@ public interface NormalizeToSlot { } newChildren.add(newChild); } + if (!hasNewChildren) { + return windowExpression; + } if (windowExpression.getWindowFrame().isPresent()) { newChildren.add(windowExpression.getWindowFrame().get()); } - return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression; + return windowExpression.withChildren(newChildren.build()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index fb41384b5d..d37070865e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -144,9 +144,7 @@ public interface TreeNode> { boolean changed = false; for (NODE_TYPE child : children()) { NODE_TYPE newChild = child.rewriteUp(rewriteFunction); - if (child != newChild) { - changed = true; - } + changed |= child != newChild; newChildren.add(newChild); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java index 1aacb02949..a15cec4722 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java @@ -23,11 +23,13 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; /** * Expression for alias, such as col1 as c1. @@ -35,7 +37,7 @@ import java.util.Optional; public class Alias extends NamedExpression implements UnaryExpression { private final ExprId exprId; - private final String name; + private final Supplier name; private final List qualifier; private final boolean nameFromChild; @@ -50,7 +52,8 @@ public class Alias extends NamedExpression implements UnaryExpression { } public Alias(Expression child) { - this(StatementScopeIdGenerator.newExprId(), child, child.toSql(), true); + this(StatementScopeIdGenerator.newExprId(), ImmutableList.of(child), + Suppliers.memoize(child::toSql), ImmutableList.of(), true); } public Alias(ExprId exprId, Expression child, String name) { @@ -62,6 +65,11 @@ public class Alias extends NamedExpression implements UnaryExpression { } public Alias(ExprId exprId, List child, String name, List qualifier, boolean nameFromChild) { + this(exprId, child, Suppliers.memoize(() -> name), qualifier, nameFromChild); + } + + private Alias(ExprId exprId, List child, Supplier name, + List qualifier, boolean nameFromChild) { super(child); this.exprId = exprId; this.name = name; @@ -73,6 +81,10 @@ public class Alias extends NamedExpression implements UnaryExpression { public Slot toSlot() throws UnboundException { SlotReference slotReference = child() instanceof SlotReference ? (SlotReference) child() : null; + + Supplier> internalName = nameFromChild + ? Suppliers.memoize(() -> Optional.of(child().toString())) + : () -> Optional.of(name.get()); return new SlotReference(exprId, name, child().getDataType(), child().nullable(), qualifier, slotReference != null ? ((SlotReference) child()).getTable().orElse(null) @@ -80,14 +92,16 @@ public class Alias extends NamedExpression implements UnaryExpression { slotReference != null ? slotReference.getColumn().orElse(null) : null, - nameFromChild ? Optional.of(child().toString()) : Optional.of(name), slotReference != null - ? slotReference.getSubColPath() - : null); + internalName, + slotReference != null + ? slotReference.getSubColPath() + : null + ); } @Override public String getName() throws UnboundException { - return name; + return name.get(); } @Override @@ -107,7 +121,7 @@ public class Alias extends NamedExpression implements UnaryExpression { @Override public String toSql() { - return child().toSql() + " AS `" + name + "`"; + return child().toSql() + " AS `" + name.get() + "`"; } @Override @@ -116,35 +130,31 @@ public class Alias extends NamedExpression implements UnaryExpression { } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { + protected boolean extraEquals(Expression other) { + Alias that = (Alias) other; + if (!exprId.equals(that.exprId) || !qualifier.equals(that.qualifier)) { return false; } - Alias that = (Alias) o; - return exprId.equals(that.exprId) - && name.equals(that.name) - && qualifier.equals(that.qualifier) - && child().equals(that.child()); + + return nameFromChild || name.get().equals(that.name.get()); } @Override public int hashCode() { - return Objects.hash(exprId, name, qualifier, children()); + return Objects.hash(exprId, qualifier); } @Override public String toString() { - return child().toString() + " AS `" + name + "`#" + exprId; + return child().toString() + " AS `" + name.get() + "`#" + exprId; } @Override public Alias withChildren(List children) { Preconditions.checkArgument(children.size() == 1); if (nameFromChild) { - return new Alias(exprId, children, children.get(0).toSql(), qualifier, nameFromChild); + return new Alias(exprId, children, + Suppliers.memoize(() -> children.get(0).toSql()), qualifier, nameFromChild); } else { return new Alias(exprId, children, name, qualifier, nameFromChild); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java index 2e4bc745b2..287362d817 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java @@ -74,4 +74,8 @@ public class Any extends Expression implements LeafExpression { public boolean deepEquals(TreeNode that) { return true; } + + protected boolean supportCompareWidthAndDepth() { + return false; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java index 3dd0ef6485..b43e481119 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java @@ -147,7 +147,7 @@ public class ArrayItemReference extends NamedExpression implements ExpectsInputT @Override public ArrayItemSlot withExprId(ExprId exprId) { - return new ArrayItemSlot(exprId, name, dataType, nullable); + return new ArrayItemSlot(exprId, name.get(), dataType, nullable); } @Override @@ -157,7 +157,7 @@ public class ArrayItemReference extends NamedExpression implements ExpectsInputT @Override public SlotReference withNullable(boolean newNullable) { - return new ArrayItemSlot(exprId, name, dataType, nullable); + return new ArrayItemSlot(exprId, name.get(), dataType, nullable); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java index 2d06456d0a..01a61d576d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java @@ -68,16 +68,4 @@ public abstract class BinaryOperator extends Expression implements BinaryExpress public int hashCode() { return Objects.hash(symbol, left(), right()); } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BinaryOperator other = (BinaryOperator) o; - return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right()); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 1660efa3a3..a7947c82a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -42,14 +42,16 @@ import org.apache.doris.nereids.types.StructType; import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Lists; import org.apache.commons.lang3.StringUtils; import java.util.List; -import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Supplier; /** * Abstract class for all Expression in Nereids. @@ -62,21 +64,30 @@ public abstract class Expression extends AbstractTreeNode implements private final int width; // Mark this expression is from predicate infer or something else infer private final boolean inferred; + private final boolean hasUnbound; + private final boolean compareWidthAndDepth; + private final Supplier> inputSlots = Suppliers.memoize(() -> collect(Slot.class::isInstance)); protected Expression(Expression... children) { super(children); int maxChildDepth = 0; int sumChildWidth = 0; + boolean hasUnbound = false; + boolean compareWidthAndDepth = true; for (int i = 0; i < children.length; ++i) { Expression child = children[i]; maxChildDepth = Math.max(child.depth, maxChildDepth); sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); } this.depth = maxChildDepth + 1; this.width = sumChildWidth + ((children.length == 0) ? 1 : 0); + this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = false; + this.hasUnbound = hasUnbound || this instanceof Unbound; } protected Expression(List children) { @@ -87,16 +98,22 @@ public abstract class Expression extends AbstractTreeNode implements super(children); int maxChildDepth = 0; int sumChildWidth = 0; + boolean hasUnbound = false; + boolean compareWidthAndDepth = true; for (int i = 0; i < children.size(); ++i) { Expression child = children.get(i); maxChildDepth = Math.max(child.depth, maxChildDepth); sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); } this.depth = maxChildDepth + 1; this.width = sumChildWidth + ((children.isEmpty()) ? 1 : 0); + this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = inferred; + this.hasUnbound = hasUnbound || this instanceof Unbound; } private void checkLimit() { @@ -293,7 +310,7 @@ public abstract class Expression extends AbstractTreeNode implements * Note that the input slots of subquery's inner plan is not included. */ public final Set getInputSlots() { - return collect(Slot.class::isInstance); + return inputSlots.get(); } /** @@ -302,13 +319,12 @@ public abstract class Expression extends AbstractTreeNode implements * Note that the input slots of subquery's inner plan is not included. */ public final Set getInputSlotExprIds() { - ImmutableSet.Builder result = ImmutableSet.builder(); - foreach(node -> { - if (node instanceof Slot) { - result.add(((Slot) node).getExprId()); - } - }); - return result.build(); + Set inputSlots = getInputSlots(); + Builder exprIds = ImmutableSet.builderWithExpectedSize(inputSlots.size()); + for (Slot inputSlot : inputSlots) { + exprIds.add(inputSlot.getExprId()); + } + return exprIds.build(); } public boolean isLiteral() { @@ -336,7 +352,26 @@ public abstract class Expression extends AbstractTreeNode implements return false; } Expression that = (Expression) o; - return Objects.equals(children(), that.children()); + if ((compareWidthAndDepth && (this.width != that.width || this.depth != that.depth)) + || arity() != that.arity() || !extraEquals(that)) { + return false; + } + return equalsChildren(that); + } + + protected boolean equalsChildren(Expression that) { + List children = children(); + List thatChildren = that.children(); + for (int i = 0; i < children.size(); i++) { + if (!children.get(i).equals(thatChildren.get(i))) { + return false; + } + } + return true; + } + + protected boolean extraEquals(Expression that) { + return true; } @Override @@ -348,18 +383,14 @@ public abstract class Expression extends AbstractTreeNode implements * This expression has unbound symbols or not. */ public boolean hasUnbound() { - if (this instanceof Unbound) { - return true; - } - for (Expression child : children) { - if (child.hasUnbound()) { - return true; - } - } - return false; + return this.hasUnbound; } public String shapeInfo() { return toSql(); } + + protected boolean supportCompareWidthAndDepth() { + return true; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java index b9711ed52a..bfbd24647e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java @@ -71,7 +71,7 @@ public class MarkJoinSlotReference extends SlotReference { @Override public MarkJoinSlotReference withExprId(ExprId exprId) { - return new MarkJoinSlotReference(exprId, name, existsHasAgg); + return new MarkJoinSlotReference(exprId, name.get(), existsHasAgg); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 3be5a56447..7cfaad72a2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -26,11 +26,13 @@ import org.apache.doris.nereids.util.Utils; import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -38,7 +40,7 @@ import javax.annotation.Nullable; */ public class SlotReference extends Slot { protected final ExprId exprId; - protected final String name; + protected final Supplier name; protected final DataType dataType; protected final boolean nullable; protected final List qualifier; @@ -50,7 +52,7 @@ public class SlotReference extends Slot { // the unique string representation of a SlotReference // different SlotReference will have different internalName // TODO: remove this member variable after mv selection is refactored - protected final Optional internalName; + protected final Supplier> internalName; private final TableIf table; private final Column column; @@ -84,6 +86,13 @@ public class SlotReference extends Slot { this(exprId, name, dataType, nullable, qualifier, table, column, internalName, null); } + public SlotReference(ExprId exprId, String name, DataType dataType, boolean nullable, + List qualifier, @Nullable TableIf table, @Nullable Column column, + Optional internalName, List subColLabels) { + this(exprId, () -> name, dataType, nullable, qualifier, table, column, + buildInternalName(() -> name, subColLabels, internalName), subColLabels); + } + /** * Constructor for SlotReference. * @@ -96,9 +105,9 @@ public class SlotReference extends Slot { * @param internalName the internalName of this slot * @param subColLabels subColumn access labels */ - public SlotReference(ExprId exprId, String name, DataType dataType, boolean nullable, + public SlotReference(ExprId exprId, Supplier name, DataType dataType, boolean nullable, List qualifier, @Nullable TableIf table, @Nullable Column column, - Optional internalName, List subColLabels) { + Supplier> internalName, List subColLabels) { this.exprId = exprId; this.name = name; this.dataType = dataType; @@ -108,14 +117,7 @@ public class SlotReference extends Slot { this.table = table; this.column = column; this.subColPath = subColLabels; - if (subColLabels != null && !this.subColPath.isEmpty()) { - // Modify internal name to distinguish from different sub-columns of same top level column, - // using the `.` to connect each part of paths - String fullName = internalName.orElse(name) + String.join(".", this.subColPath); - this.internalName = Optional.of(fullName); - } else { - this.internalName = internalName.isPresent() ? internalName : Optional.of(name); - } + this.internalName = internalName; } public static SlotReference of(String name, DataType type) { @@ -147,7 +149,7 @@ public class SlotReference extends Slot { @Override public String getName() { - return name; + return name.get(); } @Override @@ -172,7 +174,7 @@ public class SlotReference extends Slot { @Override public String getInternalName() { - return internalName.get(); + return internalName.get().get(); } public Optional getColumn() { @@ -185,21 +187,21 @@ public class SlotReference extends Slot { @Override public String toSql() { - return name; + return name.get(); } @Override public String toString() { // Just return name and exprId, add another method to show fully qualified name when it's necessary. - return name + "#" + exprId; + return name.get() + "#" + exprId; } @Override public String shapeInfo() { if (qualifier.isEmpty()) { - return name; + return name.get(); } else { - return qualifier.get(qualifier.size() - 1) + "." + name; + return qualifier.get(qualifier.size() - 1) + "." + name.get(); } } @@ -258,7 +260,8 @@ public class SlotReference extends Slot { @Override public SlotReference withName(String name) { - return new SlotReference(exprId, name, dataType, nullable, qualifier, table, column, internalName, subColPath); + return new SlotReference( + exprId, () -> name, dataType, nullable, qualifier, table, column, internalName, subColPath); } @Override @@ -277,4 +280,17 @@ public class SlotReference extends Slot { public boolean hasSubColPath() { return subColPath != null && !subColPath.isEmpty(); } + + private static Supplier> buildInternalName( + Supplier name, List subColLabels, Optional internalName) { + if (subColLabels != null && !subColLabels.isEmpty()) { + // Modify internal name to distinguish from different sub-columns of same top level column, + // using the `.` to connect each part of paths + return Suppliers.memoize(() -> + Optional.of(internalName.orElse(name.get()) + String.join(".", subColLabels))); + } else { + return Suppliers.memoize(() -> + internalName.isPresent() ? internalName : Optional.of(name.get())); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java index eb0829bf0e..b9b9eeb9b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java @@ -124,13 +124,13 @@ public class VirtualSlotReference extends SlotReference implements SlotNotFromCh if (this.nullable == newNullable) { return this; } - return new VirtualSlotReference(exprId, name, dataType, newNullable, qualifier, + return new VirtualSlotReference(exprId, name.get(), dataType, newNullable, qualifier, originExpression, computeLongValueMethod); } @Override public VirtualSlotReference withQualifier(List qualifier) { - return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier, + return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, originExpression, computeLongValueMethod); } @@ -142,7 +142,7 @@ public class VirtualSlotReference extends SlotReference implements SlotNotFromCh @Override public VirtualSlotReference withExprId(ExprId exprId) { - return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier, + return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, originExpression, computeLongValueMethod); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java index 3cc3586990..4ce77f22df 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java @@ -89,21 +89,6 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI return EXPECTS_INPUT_TYPES; } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - WhenClause other = (WhenClause) o; - return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right()); - } - @Override public int hashCode() { return Objects.hash(left(), right()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java index c9531d61b6..d9fa8d36dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java @@ -131,7 +131,7 @@ public class WindowExpression extends Expression { @Override public WindowExpression withChildren(List children) { - Preconditions.checkArgument(children.size() >= 1); + Preconditions.checkArgument(!children.isEmpty()); int index = 0; Expression func = children.get(index); index += 1; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java index d2089a8d32..33b587ce74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java @@ -74,15 +74,8 @@ public abstract class BoundFunction extends Function implements ComputeSignature } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BoundFunction that = (BoundFunction) o; - return Objects.equals(name, that.name) && Objects.equals(children, that.children); + protected boolean extraEquals(Expression that) { + return Objects.equals(name, ((BoundFunction) that).name); } @Override @@ -92,11 +85,16 @@ public abstract class BoundFunction extends Function implements ComputeSignature @Override public String toSql() throws UnboundException { - String args = children() - .stream() - .map(Expression::toSql) - .collect(Collectors.joining(", ")); - return name + "(" + args + ")"; + StringBuilder sql = new StringBuilder(name).append("("); + int arity = arity(); + for (int i = 0; i < arity; i++) { + Expression arg = child(i); + sql.append(arg.toSql()); + if (i + 1 < arity) { + sql.append(", "); + } + } + return sql.append(")").toString(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java index cd0cfb5429..4be6d35dc9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java @@ -75,8 +75,8 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla this.type = Objects.requireNonNull(type, "type can not be null"); this.groupExpression = Objects.requireNonNull(groupExpression, "groupExpression can not be null"); Objects.requireNonNull(optLogicalProperties, "logicalProperties can not be null"); - this.logicalPropertiesSupplier = Suppliers.memoize(() -> optLogicalProperties.orElseGet( - this::computeLogicalProperties)); + this.logicalPropertiesSupplier = Suppliers.memoize(() -> + optLogicalProperties.orElseGet(this::computeLogicalProperties)); this.statistics = statistics; this.id = StatementScopeIdGenerator.newObjectId(); } @@ -166,8 +166,14 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla @Override public LogicalProperties computeLogicalProperties() { - boolean hasUnboundChild = children.stream() - .anyMatch(child -> !child.bound()); + boolean hasUnboundChild = false; + for (Plan child : children) { + if (!child.bound()) { + hasUnboundChild = true; + break; + } + } + if (hasUnboundChild || hasUnboundExpression()) { return UnboundLogicalProperties.INSTANCE; } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java index f31c6e97a0..1b237c72fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java @@ -61,8 +61,14 @@ public interface Plan extends TreeNode { return !(getLogicalProperties() instanceof UnboundLogicalProperties); } + /** hasUnboundExpression */ default boolean hasUnboundExpression() { - return getExpressions().stream().anyMatch(Expression::hasUnbound); + for (Expression expression : getExpressions()) { + if (expression.hasUnbound()) { + return true; + } + } + return false; } default boolean containsSlots(ImmutableSet slots) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index f35ae1a75d..1f44d128b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -645,9 +645,11 @@ public class ExpressionUtils { public static Set mutableCollect(List expressions, Predicate> predicate) { - return expressions.stream() - .flatMap(expr -> expr.>collect(predicate).stream()) - .collect(Collectors.toSet()); + Set set = new HashSet<>(); + for (Expression expr : expressions) { + set.addAll(expr.collect(predicate)); + } + return set; } public static List collectAll(Collection expressions, @@ -717,9 +719,11 @@ public class ExpressionUtils { * Get input slot set from list of expressions. */ public static Set getInputSlotSet(Collection exprs) { - return exprs.stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .collect(ImmutableSet.toImmutableSet()); + Set set = new HashSet<>(); + for (Expression expr : exprs) { + set.addAll(expr.getInputSlots()); + } + return set; } public static boolean checkTypeSkipCast(Expression expression, Class cls) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index a4e25e2141..759b96c5b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; @@ -37,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.Collection; @@ -125,16 +125,21 @@ public class PlanUtils { * get table set from plan root. */ public static ImmutableSet getTableSet(LogicalPlan plan) { - Set tableSet = new HashSet<>(); - tableSet.addAll((Collection) plan - .collect(LogicalCatalogRelation.class::isInstance)); - ImmutableSet resultSet = tableSet.stream().map(e -> e.getTable()) - .collect(ImmutableSet.toImmutableSet()); - return resultSet; + Set tableSet = plan.collect(LogicalCatalogRelation.class::isInstance); + return tableSet.stream() + .map(LogicalCatalogRelation::getTable) + .collect(ImmutableSet.toImmutableSet()); } /** fastGetChildrenOutput */ public static List fastGetChildrenOutputs(List children) { + switch (children.size()) { + case 1: return children.get(0).getOutput(); + case 0: return ImmutableList.of(); + default: { + } + } + int outputNum = 0; // child.output is cached by AbstractPlan.logicalProperties, // we can compute output num without the overhead of re-compute output @@ -175,22 +180,36 @@ public class PlanUtils { /** * collect non_window_agg_func */ - public static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor> { - - public static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs(); - - @Override - public Void visitWindow(WindowExpression windowExpression, List context) { - for (Expression child : windowExpression.getExpressionsInWindowSpec()) { - child.accept(this, context); + public static class CollectNonWindowedAggFuncs { + public static List collect(Collection expressions) { + List aggFunctions = Lists.newArrayList(); + for (Expression expression : expressions) { + doCollect(expression, aggFunctions); } - return null; + return aggFunctions; } - @Override - public Void visitAggregateFunction(AggregateFunction aggregateFunction, List context) { - context.add(aggregateFunction); - return null; + public static List collect(Expression expression) { + List aggFuns = Lists.newArrayList(); + doCollect(expression, aggFuns); + return aggFuns; + } + + private static void doCollect(Expression expression, List aggFunctions) { + expression.foreach(expr -> { + if (expr instanceof AggregateFunction) { + aggFunctions.add((AggregateFunction) expr); + return true; + } else if (expr instanceof WindowExpression) { + WindowExpression windowExpression = (WindowExpression) expr; + for (Expression exprInWindowsSpec : windowExpression.getExpressionsInWindowSpec()) { + doCollect(exprInWindowsSpec, aggFunctions); + } + return true; + } else { + return false; + } + }); } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index f270ff4db0..e99a00e144 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -529,7 +529,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo Alias minXX = new Alias(new ExprId(5), new Min(xx.toSlot()), "min(xx)"); PlanChecker.from(connectContext).analyze(sql).printlnTree().matches(logicalProject( logicalSort(logicalProject(logicalAggregate(logicalProject(logicalOlapScan()) - .when(FieldChecker.check("projects", Lists.newArrayList(xx, a2, a1)))))) + .when(FieldChecker.check("projects", Lists.newArrayList(xx, a1, a2)))))) .when(FieldChecker.check("orderKeys", ImmutableList .of(new OrderKey(minXX.toSlot(), true, true)))))