diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java index 5688b9d48b..d6c783bbe9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java @@ -66,7 +66,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { .flatMap(Set::stream) .filter(s -> !projectOutputSet.contains(s)) .collect(Collectors.toSet()); - if (notExistedInProject.size() == 0) { + if (notExistedInProject.isEmpty()) { return null; } List projects = ImmutableList.builder() @@ -128,7 +128,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { .flatMap(Set::stream) .filter(s -> !childOutput.contains(s)) .collect(Collectors.toSet()); - if (notExistedInProject.size() == 0) { + if (notExistedInProject.isEmpty()) { return null; } LogicalProject project = sort.child().child(); @@ -165,7 +165,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { .flatMap(Set::stream) .filter(s -> !projectOutputSet.contains(s)) .collect(Collectors.toSet()); - if (notExistedInProject.size() == 0) { + if (notExistedInProject.isEmpty()) { return null; } List projects = ImmutableList.builder() @@ -189,7 +189,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory { outputExpressions = aggregate.getOutputExpressions(); groupByExpressions = aggregate.getGroupByExpressions(); outputSubstitutionMap = outputExpressions.stream().filter(Alias.class::isInstance) - .collect(Collectors.toMap(alias -> alias.toSlot(), alias -> alias.child(0), + .collect(Collectors.toMap(NamedExpression::toSlot, alias -> alias.child(0), (k1, k2) -> k1)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 5874c26e17..a8c3445261 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -37,12 +37,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.HashSet; @@ -144,8 +143,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot { // collect all trival-agg List aggregateOutput = aggregate.getOutputExpressions(); - List aggFuncs = Lists.newArrayList(); - aggregateOutput.forEach(o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggFuncs)); + List aggFuncs = CollectNonWindowedAggFuncs.collect(aggregateOutput); // split non-distinct agg child as two part // TRUE part 1: need push down itself, if it contains subqury or window expression diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 8437dc40b0..9451cc40bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -39,11 +39,10 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; @@ -174,9 +173,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { .flatMap(function -> function.getArguments().stream()) .collect(ImmutableSet.toImmutableSet()); - List aggregateFunctions = Lists.newArrayList(); - repeat.getOutputExpressions().forEach( - o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions)); + List aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions()); ImmutableSet argumentsOfAggregateFunction = aggregateFunctions.stream() .flatMap(function -> function.getArguments().stream().map(arg -> { @@ -271,9 +268,8 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { @NotNull LogicalAggregate aggregate) { LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); - List aggregateFunctions = Lists.newArrayList(); - aggregate.getOutputExpressions().forEach( - o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions)); + List aggregateFunctions = + CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions()); Set aggUsedSlots = aggregateFunctions.stream() .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java index efad94c665..683841a5f8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java @@ -29,13 +29,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Maps; -import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiFunction; -import java.util.stream.Collectors; import javax.annotation.Nullable; /** @@ -119,9 +117,11 @@ public interface NormalizeToSlot { public List normalizeToUseSlotRefWithoutWindowFunction( Collection expressions) { - return expressions.stream() - .map(e -> (E) e.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap)) - .collect(Collectors.toList()); + ImmutableList.Builder normalized = ImmutableList.builderWithExpectedSize(expressions.size()); + for (E expression : expressions) { + normalized.add((E) expression.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap)); + } + return normalized.build(); } /** @@ -155,8 +155,9 @@ public interface NormalizeToSlot { @Override public Expression visit(Expression expr, Map replaceMap) { - if (replaceMap.containsKey(expr)) { - return replaceMap.get(expr).remainExpr; + NormalizeToSlotTriplet triplet = replaceMap.get(expr); + if (triplet != null) { + return triplet.remainExpr; } return super.visit(expr, replaceMap); } @@ -164,10 +165,12 @@ public interface NormalizeToSlot { @Override public Expression visitWindow(WindowExpression windowExpression, Map replaceMap) { - if (replaceMap.containsKey(windowExpression)) { - return replaceMap.get(windowExpression).remainExpr; + NormalizeToSlotTriplet triplet = replaceMap.get(windowExpression); + if (triplet != null) { + return triplet.remainExpr; } - List newChildren = new ArrayList<>(); + ImmutableList.Builder newChildren = + ImmutableList.builderWithExpectedSize(windowExpression.arity()); Expression function = super.visit(windowExpression.getFunction(), replaceMap); newChildren.add(function); boolean hasNewChildren = function != windowExpression.getFunction(); @@ -185,10 +188,13 @@ public interface NormalizeToSlot { } newChildren.add(newChild); } + if (!hasNewChildren) { + return windowExpression; + } if (windowExpression.getWindowFrame().isPresent()) { newChildren.add(windowExpression.getWindowFrame().get()); } - return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression; + return windowExpression.withChildren(newChildren.build()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java index fb41384b5d..d37070865e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java @@ -144,9 +144,7 @@ public interface TreeNode> { boolean changed = false; for (NODE_TYPE child : children()) { NODE_TYPE newChild = child.rewriteUp(rewriteFunction); - if (child != newChild) { - changed = true; - } + changed |= child != newChild; newChildren.add(newChild); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java index 1aacb02949..a15cec4722 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Alias.java @@ -23,11 +23,13 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; /** * Expression for alias, such as col1 as c1. @@ -35,7 +37,7 @@ import java.util.Optional; public class Alias extends NamedExpression implements UnaryExpression { private final ExprId exprId; - private final String name; + private final Supplier name; private final List qualifier; private final boolean nameFromChild; @@ -50,7 +52,8 @@ public class Alias extends NamedExpression implements UnaryExpression { } public Alias(Expression child) { - this(StatementScopeIdGenerator.newExprId(), child, child.toSql(), true); + this(StatementScopeIdGenerator.newExprId(), ImmutableList.of(child), + Suppliers.memoize(child::toSql), ImmutableList.of(), true); } public Alias(ExprId exprId, Expression child, String name) { @@ -62,6 +65,11 @@ public class Alias extends NamedExpression implements UnaryExpression { } public Alias(ExprId exprId, List child, String name, List qualifier, boolean nameFromChild) { + this(exprId, child, Suppliers.memoize(() -> name), qualifier, nameFromChild); + } + + private Alias(ExprId exprId, List child, Supplier name, + List qualifier, boolean nameFromChild) { super(child); this.exprId = exprId; this.name = name; @@ -73,6 +81,10 @@ public class Alias extends NamedExpression implements UnaryExpression { public Slot toSlot() throws UnboundException { SlotReference slotReference = child() instanceof SlotReference ? (SlotReference) child() : null; + + Supplier> internalName = nameFromChild + ? Suppliers.memoize(() -> Optional.of(child().toString())) + : () -> Optional.of(name.get()); return new SlotReference(exprId, name, child().getDataType(), child().nullable(), qualifier, slotReference != null ? ((SlotReference) child()).getTable().orElse(null) @@ -80,14 +92,16 @@ public class Alias extends NamedExpression implements UnaryExpression { slotReference != null ? slotReference.getColumn().orElse(null) : null, - nameFromChild ? Optional.of(child().toString()) : Optional.of(name), slotReference != null - ? slotReference.getSubColPath() - : null); + internalName, + slotReference != null + ? slotReference.getSubColPath() + : null + ); } @Override public String getName() throws UnboundException { - return name; + return name.get(); } @Override @@ -107,7 +121,7 @@ public class Alias extends NamedExpression implements UnaryExpression { @Override public String toSql() { - return child().toSql() + " AS `" + name + "`"; + return child().toSql() + " AS `" + name.get() + "`"; } @Override @@ -116,35 +130,31 @@ public class Alias extends NamedExpression implements UnaryExpression { } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { + protected boolean extraEquals(Expression other) { + Alias that = (Alias) other; + if (!exprId.equals(that.exprId) || !qualifier.equals(that.qualifier)) { return false; } - Alias that = (Alias) o; - return exprId.equals(that.exprId) - && name.equals(that.name) - && qualifier.equals(that.qualifier) - && child().equals(that.child()); + + return nameFromChild || name.get().equals(that.name.get()); } @Override public int hashCode() { - return Objects.hash(exprId, name, qualifier, children()); + return Objects.hash(exprId, qualifier); } @Override public String toString() { - return child().toString() + " AS `" + name + "`#" + exprId; + return child().toString() + " AS `" + name.get() + "`#" + exprId; } @Override public Alias withChildren(List children) { Preconditions.checkArgument(children.size() == 1); if (nameFromChild) { - return new Alias(exprId, children, children.get(0).toSql(), qualifier, nameFromChild); + return new Alias(exprId, children, + Suppliers.memoize(() -> children.get(0).toSql()), qualifier, nameFromChild); } else { return new Alias(exprId, children, name, qualifier, nameFromChild); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java index 2e4bc745b2..287362d817 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Any.java @@ -74,4 +74,8 @@ public class Any extends Expression implements LeafExpression { public boolean deepEquals(TreeNode that) { return true; } + + protected boolean supportCompareWidthAndDepth() { + return false; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java index 3dd0ef6485..b43e481119 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ArrayItemReference.java @@ -147,7 +147,7 @@ public class ArrayItemReference extends NamedExpression implements ExpectsInputT @Override public ArrayItemSlot withExprId(ExprId exprId) { - return new ArrayItemSlot(exprId, name, dataType, nullable); + return new ArrayItemSlot(exprId, name.get(), dataType, nullable); } @Override @@ -157,7 +157,7 @@ public class ArrayItemReference extends NamedExpression implements ExpectsInputT @Override public SlotReference withNullable(boolean newNullable) { - return new ArrayItemSlot(exprId, name, dataType, nullable); + return new ArrayItemSlot(exprId, name.get(), dataType, nullable); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java index 2d06456d0a..01a61d576d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/BinaryOperator.java @@ -68,16 +68,4 @@ public abstract class BinaryOperator extends Expression implements BinaryExpress public int hashCode() { return Objects.hash(symbol, left(), right()); } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BinaryOperator other = (BinaryOperator) o; - return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right()); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 1660efa3a3..a7947c82a5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -42,14 +42,16 @@ import org.apache.doris.nereids.types.StructType; import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.ImmutableSet.Builder; import com.google.common.collect.Lists; import org.apache.commons.lang3.StringUtils; import java.util.List; -import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.function.Supplier; /** * Abstract class for all Expression in Nereids. @@ -62,21 +64,30 @@ public abstract class Expression extends AbstractTreeNode implements private final int width; // Mark this expression is from predicate infer or something else infer private final boolean inferred; + private final boolean hasUnbound; + private final boolean compareWidthAndDepth; + private final Supplier> inputSlots = Suppliers.memoize(() -> collect(Slot.class::isInstance)); protected Expression(Expression... children) { super(children); int maxChildDepth = 0; int sumChildWidth = 0; + boolean hasUnbound = false; + boolean compareWidthAndDepth = true; for (int i = 0; i < children.length; ++i) { Expression child = children[i]; maxChildDepth = Math.max(child.depth, maxChildDepth); sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); } this.depth = maxChildDepth + 1; this.width = sumChildWidth + ((children.length == 0) ? 1 : 0); + this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = false; + this.hasUnbound = hasUnbound || this instanceof Unbound; } protected Expression(List children) { @@ -87,16 +98,22 @@ public abstract class Expression extends AbstractTreeNode implements super(children); int maxChildDepth = 0; int sumChildWidth = 0; + boolean hasUnbound = false; + boolean compareWidthAndDepth = true; for (int i = 0; i < children.size(); ++i) { Expression child = children.get(i); maxChildDepth = Math.max(child.depth, maxChildDepth); sumChildWidth += child.width; + hasUnbound |= child.hasUnbound; + compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth()); } this.depth = maxChildDepth + 1; this.width = sumChildWidth + ((children.isEmpty()) ? 1 : 0); + this.compareWidthAndDepth = compareWidthAndDepth; checkLimit(); this.inferred = inferred; + this.hasUnbound = hasUnbound || this instanceof Unbound; } private void checkLimit() { @@ -293,7 +310,7 @@ public abstract class Expression extends AbstractTreeNode implements * Note that the input slots of subquery's inner plan is not included. */ public final Set getInputSlots() { - return collect(Slot.class::isInstance); + return inputSlots.get(); } /** @@ -302,13 +319,12 @@ public abstract class Expression extends AbstractTreeNode implements * Note that the input slots of subquery's inner plan is not included. */ public final Set getInputSlotExprIds() { - ImmutableSet.Builder result = ImmutableSet.builder(); - foreach(node -> { - if (node instanceof Slot) { - result.add(((Slot) node).getExprId()); - } - }); - return result.build(); + Set inputSlots = getInputSlots(); + Builder exprIds = ImmutableSet.builderWithExpectedSize(inputSlots.size()); + for (Slot inputSlot : inputSlots) { + exprIds.add(inputSlot.getExprId()); + } + return exprIds.build(); } public boolean isLiteral() { @@ -336,7 +352,26 @@ public abstract class Expression extends AbstractTreeNode implements return false; } Expression that = (Expression) o; - return Objects.equals(children(), that.children()); + if ((compareWidthAndDepth && (this.width != that.width || this.depth != that.depth)) + || arity() != that.arity() || !extraEquals(that)) { + return false; + } + return equalsChildren(that); + } + + protected boolean equalsChildren(Expression that) { + List children = children(); + List thatChildren = that.children(); + for (int i = 0; i < children.size(); i++) { + if (!children.get(i).equals(thatChildren.get(i))) { + return false; + } + } + return true; + } + + protected boolean extraEquals(Expression that) { + return true; } @Override @@ -348,18 +383,14 @@ public abstract class Expression extends AbstractTreeNode implements * This expression has unbound symbols or not. */ public boolean hasUnbound() { - if (this instanceof Unbound) { - return true; - } - for (Expression child : children) { - if (child.hasUnbound()) { - return true; - } - } - return false; + return this.hasUnbound; } public String shapeInfo() { return toSql(); } + + protected boolean supportCompareWidthAndDepth() { + return true; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java index b9711ed52a..bfbd24647e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/MarkJoinSlotReference.java @@ -71,7 +71,7 @@ public class MarkJoinSlotReference extends SlotReference { @Override public MarkJoinSlotReference withExprId(ExprId exprId) { - return new MarkJoinSlotReference(exprId, name, existsHasAgg); + return new MarkJoinSlotReference(exprId, name.get(), existsHasAgg); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 3be5a56447..7cfaad72a2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -26,11 +26,13 @@ import org.apache.doris.nereids.util.Utils; import org.apache.doris.qe.ConnectContext; import com.google.common.base.Preconditions; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.function.Supplier; import javax.annotation.Nullable; /** @@ -38,7 +40,7 @@ import javax.annotation.Nullable; */ public class SlotReference extends Slot { protected final ExprId exprId; - protected final String name; + protected final Supplier name; protected final DataType dataType; protected final boolean nullable; protected final List qualifier; @@ -50,7 +52,7 @@ public class SlotReference extends Slot { // the unique string representation of a SlotReference // different SlotReference will have different internalName // TODO: remove this member variable after mv selection is refactored - protected final Optional internalName; + protected final Supplier> internalName; private final TableIf table; private final Column column; @@ -84,6 +86,13 @@ public class SlotReference extends Slot { this(exprId, name, dataType, nullable, qualifier, table, column, internalName, null); } + public SlotReference(ExprId exprId, String name, DataType dataType, boolean nullable, + List qualifier, @Nullable TableIf table, @Nullable Column column, + Optional internalName, List subColLabels) { + this(exprId, () -> name, dataType, nullable, qualifier, table, column, + buildInternalName(() -> name, subColLabels, internalName), subColLabels); + } + /** * Constructor for SlotReference. * @@ -96,9 +105,9 @@ public class SlotReference extends Slot { * @param internalName the internalName of this slot * @param subColLabels subColumn access labels */ - public SlotReference(ExprId exprId, String name, DataType dataType, boolean nullable, + public SlotReference(ExprId exprId, Supplier name, DataType dataType, boolean nullable, List qualifier, @Nullable TableIf table, @Nullable Column column, - Optional internalName, List subColLabels) { + Supplier> internalName, List subColLabels) { this.exprId = exprId; this.name = name; this.dataType = dataType; @@ -108,14 +117,7 @@ public class SlotReference extends Slot { this.table = table; this.column = column; this.subColPath = subColLabels; - if (subColLabels != null && !this.subColPath.isEmpty()) { - // Modify internal name to distinguish from different sub-columns of same top level column, - // using the `.` to connect each part of paths - String fullName = internalName.orElse(name) + String.join(".", this.subColPath); - this.internalName = Optional.of(fullName); - } else { - this.internalName = internalName.isPresent() ? internalName : Optional.of(name); - } + this.internalName = internalName; } public static SlotReference of(String name, DataType type) { @@ -147,7 +149,7 @@ public class SlotReference extends Slot { @Override public String getName() { - return name; + return name.get(); } @Override @@ -172,7 +174,7 @@ public class SlotReference extends Slot { @Override public String getInternalName() { - return internalName.get(); + return internalName.get().get(); } public Optional getColumn() { @@ -185,21 +187,21 @@ public class SlotReference extends Slot { @Override public String toSql() { - return name; + return name.get(); } @Override public String toString() { // Just return name and exprId, add another method to show fully qualified name when it's necessary. - return name + "#" + exprId; + return name.get() + "#" + exprId; } @Override public String shapeInfo() { if (qualifier.isEmpty()) { - return name; + return name.get(); } else { - return qualifier.get(qualifier.size() - 1) + "." + name; + return qualifier.get(qualifier.size() - 1) + "." + name.get(); } } @@ -258,7 +260,8 @@ public class SlotReference extends Slot { @Override public SlotReference withName(String name) { - return new SlotReference(exprId, name, dataType, nullable, qualifier, table, column, internalName, subColPath); + return new SlotReference( + exprId, () -> name, dataType, nullable, qualifier, table, column, internalName, subColPath); } @Override @@ -277,4 +280,17 @@ public class SlotReference extends Slot { public boolean hasSubColPath() { return subColPath != null && !subColPath.isEmpty(); } + + private static Supplier> buildInternalName( + Supplier name, List subColLabels, Optional internalName) { + if (subColLabels != null && !subColLabels.isEmpty()) { + // Modify internal name to distinguish from different sub-columns of same top level column, + // using the `.` to connect each part of paths + return Suppliers.memoize(() -> + Optional.of(internalName.orElse(name.get()) + String.join(".", subColLabels))); + } else { + return Suppliers.memoize(() -> + internalName.isPresent() ? internalName : Optional.of(name.get())); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java index eb0829bf0e..b9b9eeb9b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/VirtualSlotReference.java @@ -124,13 +124,13 @@ public class VirtualSlotReference extends SlotReference implements SlotNotFromCh if (this.nullable == newNullable) { return this; } - return new VirtualSlotReference(exprId, name, dataType, newNullable, qualifier, + return new VirtualSlotReference(exprId, name.get(), dataType, newNullable, qualifier, originExpression, computeLongValueMethod); } @Override public VirtualSlotReference withQualifier(List qualifier) { - return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier, + return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, originExpression, computeLongValueMethod); } @@ -142,7 +142,7 @@ public class VirtualSlotReference extends SlotReference implements SlotNotFromCh @Override public VirtualSlotReference withExprId(ExprId exprId) { - return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier, + return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier, originExpression, computeLongValueMethod); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java index 3cc3586990..4ce77f22df 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WhenClause.java @@ -89,21 +89,6 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI return EXPECTS_INPUT_TYPES; } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - WhenClause other = (WhenClause) o; - return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right()); - } - @Override public int hashCode() { return Objects.hash(left(), right()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java index c9531d61b6..d9fa8d36dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java @@ -131,7 +131,7 @@ public class WindowExpression extends Expression { @Override public WindowExpression withChildren(List children) { - Preconditions.checkArgument(children.size() >= 1); + Preconditions.checkArgument(!children.isEmpty()); int index = 0; Expression func = children.get(index); index += 1; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java index d2089a8d32..33b587ce74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BoundFunction.java @@ -74,15 +74,8 @@ public abstract class BoundFunction extends Function implements ComputeSignature } @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - BoundFunction that = (BoundFunction) o; - return Objects.equals(name, that.name) && Objects.equals(children, that.children); + protected boolean extraEquals(Expression that) { + return Objects.equals(name, ((BoundFunction) that).name); } @Override @@ -92,11 +85,16 @@ public abstract class BoundFunction extends Function implements ComputeSignature @Override public String toSql() throws UnboundException { - String args = children() - .stream() - .map(Expression::toSql) - .collect(Collectors.joining(", ")); - return name + "(" + args + ")"; + StringBuilder sql = new StringBuilder(name).append("("); + int arity = arity(); + for (int i = 0; i < arity; i++) { + Expression arg = child(i); + sql.append(arg.toSql()); + if (i + 1 < arity) { + sql.append(", "); + } + } + return sql.append(")").toString(); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java index cd0cfb5429..4be6d35dc9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/AbstractPlan.java @@ -75,8 +75,8 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla this.type = Objects.requireNonNull(type, "type can not be null"); this.groupExpression = Objects.requireNonNull(groupExpression, "groupExpression can not be null"); Objects.requireNonNull(optLogicalProperties, "logicalProperties can not be null"); - this.logicalPropertiesSupplier = Suppliers.memoize(() -> optLogicalProperties.orElseGet( - this::computeLogicalProperties)); + this.logicalPropertiesSupplier = Suppliers.memoize(() -> + optLogicalProperties.orElseGet(this::computeLogicalProperties)); this.statistics = statistics; this.id = StatementScopeIdGenerator.newObjectId(); } @@ -166,8 +166,14 @@ public abstract class AbstractPlan extends AbstractTreeNode implements Pla @Override public LogicalProperties computeLogicalProperties() { - boolean hasUnboundChild = children.stream() - .anyMatch(child -> !child.bound()); + boolean hasUnboundChild = false; + for (Plan child : children) { + if (!child.bound()) { + hasUnboundChild = true; + break; + } + } + if (hasUnboundChild || hasUnboundExpression()) { return UnboundLogicalProperties.INSTANCE; } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java index f31c6e97a0..1b237c72fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/Plan.java @@ -61,8 +61,14 @@ public interface Plan extends TreeNode { return !(getLogicalProperties() instanceof UnboundLogicalProperties); } + /** hasUnboundExpression */ default boolean hasUnboundExpression() { - return getExpressions().stream().anyMatch(Expression::hasUnbound); + for (Expression expression : getExpressions()) { + if (expression.hasUnbound()) { + return true; + } + } + return false; } default boolean containsSlots(ImmutableSet slots) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index f35ae1a75d..1f44d128b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -645,9 +645,11 @@ public class ExpressionUtils { public static Set mutableCollect(List expressions, Predicate> predicate) { - return expressions.stream() - .flatMap(expr -> expr.>collect(predicate).stream()) - .collect(Collectors.toSet()); + Set set = new HashSet<>(); + for (Expression expr : expressions) { + set.addAll(expr.collect(predicate)); + } + return set; } public static List collectAll(Collection expressions, @@ -717,9 +719,11 @@ public class ExpressionUtils { * Get input slot set from list of expressions. */ public static Set getInputSlotSet(Collection exprs) { - return exprs.stream() - .flatMap(expr -> expr.getInputSlots().stream()) - .collect(ImmutableSet.toImmutableSet()); + Set set = new HashSet<>(); + for (Expression expr : exprs) { + set.addAll(expr.getInputSlots()); + } + return set; } public static boolean checkTypeSkipCast(Expression expression, Class cls) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index a4e25e2141..759b96c5b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation; @@ -37,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Sets; import java.util.Collection; @@ -125,16 +125,21 @@ public class PlanUtils { * get table set from plan root. */ public static ImmutableSet getTableSet(LogicalPlan plan) { - Set tableSet = new HashSet<>(); - tableSet.addAll((Collection) plan - .collect(LogicalCatalogRelation.class::isInstance)); - ImmutableSet resultSet = tableSet.stream().map(e -> e.getTable()) - .collect(ImmutableSet.toImmutableSet()); - return resultSet; + Set tableSet = plan.collect(LogicalCatalogRelation.class::isInstance); + return tableSet.stream() + .map(LogicalCatalogRelation::getTable) + .collect(ImmutableSet.toImmutableSet()); } /** fastGetChildrenOutput */ public static List fastGetChildrenOutputs(List children) { + switch (children.size()) { + case 1: return children.get(0).getOutput(); + case 0: return ImmutableList.of(); + default: { + } + } + int outputNum = 0; // child.output is cached by AbstractPlan.logicalProperties, // we can compute output num without the overhead of re-compute output @@ -175,22 +180,36 @@ public class PlanUtils { /** * collect non_window_agg_func */ - public static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor> { - - public static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs(); - - @Override - public Void visitWindow(WindowExpression windowExpression, List context) { - for (Expression child : windowExpression.getExpressionsInWindowSpec()) { - child.accept(this, context); + public static class CollectNonWindowedAggFuncs { + public static List collect(Collection expressions) { + List aggFunctions = Lists.newArrayList(); + for (Expression expression : expressions) { + doCollect(expression, aggFunctions); } - return null; + return aggFunctions; } - @Override - public Void visitAggregateFunction(AggregateFunction aggregateFunction, List context) { - context.add(aggregateFunction); - return null; + public static List collect(Expression expression) { + List aggFuns = Lists.newArrayList(); + doCollect(expression, aggFuns); + return aggFuns; + } + + private static void doCollect(Expression expression, List aggFunctions) { + expression.foreach(expr -> { + if (expr instanceof AggregateFunction) { + aggFunctions.add((AggregateFunction) expr); + return true; + } else if (expr instanceof WindowExpression) { + WindowExpression windowExpression = (WindowExpression) expr; + for (Expression exprInWindowsSpec : windowExpression.getExpressionsInWindowSpec()) { + doCollect(exprInWindowsSpec, aggFunctions); + } + return true; + } else { + return false; + } + }); } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java index f270ff4db0..e99a00e144 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlotsTest.java @@ -529,7 +529,7 @@ public class FillUpMissingSlotsTest extends AnalyzeCheckTestBase implements Memo Alias minXX = new Alias(new ExprId(5), new Min(xx.toSlot()), "min(xx)"); PlanChecker.from(connectContext).analyze(sql).printlnTree().matches(logicalProject( logicalSort(logicalProject(logicalAggregate(logicalProject(logicalOlapScan()) - .when(FieldChecker.check("projects", Lists.newArrayList(xx, a2, a1)))))) + .when(FieldChecker.check("projects", Lists.newArrayList(xx, a1, a2)))))) .when(FieldChecker.check("orderKeys", ImmutableList .of(new OrderKey(minXX.toSlot(), true, true)))))