[enhancement](Nereids) Optimize expression (#32067)

1. optimize expression comparison by
      a) flatter method call stack
      b) short circuit: if some simple field not equals, then return, and not compare to the big field, like Alias.name
    2. lazy compute Alias.name which is computed by toSql. Now Alias.toSlot() will not generate long name immediately
    3. cache Expresssion.inputSlots, it can save time when invoke this method multiple times
    4. always compute Expression.unbound, it can avoid traverse the big expression tree

this pr can save about 200ms when submit some long sqls
This commit is contained in:
924060929
2024-03-13 14:07:40 +08:00
committed by yiguolei
parent 3c4234111b
commit 2f89ec961f
21 changed files with 233 additions and 168 deletions

View File

@ -66,7 +66,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
.flatMap(Set::stream)
.filter(s -> !projectOutputSet.contains(s))
.collect(Collectors.toSet());
if (notExistedInProject.size() == 0) {
if (notExistedInProject.isEmpty()) {
return null;
}
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
@ -128,7 +128,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
.flatMap(Set::stream)
.filter(s -> !childOutput.contains(s))
.collect(Collectors.toSet());
if (notExistedInProject.size() == 0) {
if (notExistedInProject.isEmpty()) {
return null;
}
LogicalProject<?> project = sort.child().child();
@ -165,7 +165,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
.flatMap(Set::stream)
.filter(s -> !projectOutputSet.contains(s))
.collect(Collectors.toSet());
if (notExistedInProject.size() == 0) {
if (notExistedInProject.isEmpty()) {
return null;
}
List<NamedExpression> projects = ImmutableList.<NamedExpression>builder()
@ -189,7 +189,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
outputExpressions = aggregate.getOutputExpressions();
groupByExpressions = aggregate.getGroupByExpressions();
outputSubstitutionMap = outputExpressions.stream().filter(Alias.class::isInstance)
.collect(Collectors.toMap(alias -> alias.toSlot(), alias -> alias.child(0),
.collect(Collectors.toMap(NamedExpression::toSlot, alias -> alias.child(0),
(k1, k2) -> k1));
}

View File

@ -37,12 +37,11 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.HashSet;
@ -144,8 +143,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
// collect all trival-agg
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
List<AggregateFunction> aggFuncs = Lists.newArrayList();
aggregateOutput.forEach(o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
List<AggregateFunction> aggFuncs = CollectNonWindowedAggFuncs.collect(aggregateOutput);
// split non-distinct agg child as two part
// TRUE part 1: need push down itself, if it contains subqury or window expression

View File

@ -39,11 +39,10 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.collect.Sets.SetView;
@ -174,9 +173,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
.flatMap(function -> function.getArguments().stream())
.collect(ImmutableSet.toImmutableSet());
List<AggregateFunction> aggregateFunctions = Lists.newArrayList();
repeat.getOutputExpressions().forEach(
o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions));
List<AggregateFunction> aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions());
ImmutableSet<Expression> argumentsOfAggregateFunction = aggregateFunctions.stream()
.flatMap(function -> function.getArguments().stream().map(arg -> {
@ -271,9 +268,8 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
@NotNull LogicalAggregate<Plan> aggregate) {
LogicalRepeat<Plan> repeat = (LogicalRepeat<Plan>) aggregate.child();
List<AggregateFunction> aggregateFunctions = Lists.newArrayList();
aggregate.getOutputExpressions().forEach(
o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions));
List<AggregateFunction> aggregateFunctions =
CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions());
Set<Slot> aggUsedSlots = aggregateFunctions.stream()
.flatMap(e -> e.<Set<SlotReference>>collect(SlotReference.class::isInstance).stream())
.collect(ImmutableSet.toImmutableSet());

View File

@ -29,13 +29,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
/**
@ -119,9 +117,11 @@ public interface NormalizeToSlot {
public <E extends Expression> List<E> normalizeToUseSlotRefWithoutWindowFunction(
Collection<E> expressions) {
return expressions.stream()
.map(e -> (E) e.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap))
.collect(Collectors.toList());
ImmutableList.Builder<E> normalized = ImmutableList.builderWithExpectedSize(expressions.size());
for (E expression : expressions) {
normalized.add((E) expression.accept(NormalizeWithoutWindowFunction.INSTANCE, normalizeToSlotMap));
}
return normalized.build();
}
/**
@ -155,8 +155,9 @@ public interface NormalizeToSlot {
@Override
public Expression visit(Expression expr, Map<Expression, NormalizeToSlotTriplet> replaceMap) {
if (replaceMap.containsKey(expr)) {
return replaceMap.get(expr).remainExpr;
NormalizeToSlotTriplet triplet = replaceMap.get(expr);
if (triplet != null) {
return triplet.remainExpr;
}
return super.visit(expr, replaceMap);
}
@ -164,10 +165,12 @@ public interface NormalizeToSlot {
@Override
public Expression visitWindow(WindowExpression windowExpression,
Map<Expression, NormalizeToSlotTriplet> replaceMap) {
if (replaceMap.containsKey(windowExpression)) {
return replaceMap.get(windowExpression).remainExpr;
NormalizeToSlotTriplet triplet = replaceMap.get(windowExpression);
if (triplet != null) {
return triplet.remainExpr;
}
List<Expression> newChildren = new ArrayList<>();
ImmutableList.Builder<Expression> newChildren =
ImmutableList.builderWithExpectedSize(windowExpression.arity());
Expression function = super.visit(windowExpression.getFunction(), replaceMap);
newChildren.add(function);
boolean hasNewChildren = function != windowExpression.getFunction();
@ -185,10 +188,13 @@ public interface NormalizeToSlot {
}
newChildren.add(newChild);
}
if (!hasNewChildren) {
return windowExpression;
}
if (windowExpression.getWindowFrame().isPresent()) {
newChildren.add(windowExpression.getWindowFrame().get());
}
return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression;
return windowExpression.withChildren(newChildren.build());
}
}

View File

@ -144,9 +144,7 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
boolean changed = false;
for (NODE_TYPE child : children()) {
NODE_TYPE newChild = child.rewriteUp(rewriteFunction);
if (child != newChild) {
changed = true;
}
changed |= child != newChild;
newChildren.add(newChild);
}

View File

@ -23,11 +23,13 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
/**
* Expression for alias, such as col1 as c1.
@ -35,7 +37,7 @@ import java.util.Optional;
public class Alias extends NamedExpression implements UnaryExpression {
private final ExprId exprId;
private final String name;
private final Supplier<String> name;
private final List<String> qualifier;
private final boolean nameFromChild;
@ -50,7 +52,8 @@ public class Alias extends NamedExpression implements UnaryExpression {
}
public Alias(Expression child) {
this(StatementScopeIdGenerator.newExprId(), child, child.toSql(), true);
this(StatementScopeIdGenerator.newExprId(), ImmutableList.of(child),
Suppliers.memoize(child::toSql), ImmutableList.of(), true);
}
public Alias(ExprId exprId, Expression child, String name) {
@ -62,6 +65,11 @@ public class Alias extends NamedExpression implements UnaryExpression {
}
public Alias(ExprId exprId, List<Expression> child, String name, List<String> qualifier, boolean nameFromChild) {
this(exprId, child, Suppliers.memoize(() -> name), qualifier, nameFromChild);
}
private Alias(ExprId exprId, List<Expression> child, Supplier<String> name,
List<String> qualifier, boolean nameFromChild) {
super(child);
this.exprId = exprId;
this.name = name;
@ -73,6 +81,10 @@ public class Alias extends NamedExpression implements UnaryExpression {
public Slot toSlot() throws UnboundException {
SlotReference slotReference = child() instanceof SlotReference
? (SlotReference) child() : null;
Supplier<Optional<String>> internalName = nameFromChild
? Suppliers.memoize(() -> Optional.of(child().toString()))
: () -> Optional.of(name.get());
return new SlotReference(exprId, name, child().getDataType(), child().nullable(), qualifier,
slotReference != null
? ((SlotReference) child()).getTable().orElse(null)
@ -80,14 +92,16 @@ public class Alias extends NamedExpression implements UnaryExpression {
slotReference != null
? slotReference.getColumn().orElse(null)
: null,
nameFromChild ? Optional.of(child().toString()) : Optional.of(name), slotReference != null
? slotReference.getSubColPath()
: null);
internalName,
slotReference != null
? slotReference.getSubColPath()
: null
);
}
@Override
public String getName() throws UnboundException {
return name;
return name.get();
}
@Override
@ -107,7 +121,7 @@ public class Alias extends NamedExpression implements UnaryExpression {
@Override
public String toSql() {
return child().toSql() + " AS `" + name + "`";
return child().toSql() + " AS `" + name.get() + "`";
}
@Override
@ -116,35 +130,31 @@ public class Alias extends NamedExpression implements UnaryExpression {
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
protected boolean extraEquals(Expression other) {
Alias that = (Alias) other;
if (!exprId.equals(that.exprId) || !qualifier.equals(that.qualifier)) {
return false;
}
Alias that = (Alias) o;
return exprId.equals(that.exprId)
&& name.equals(that.name)
&& qualifier.equals(that.qualifier)
&& child().equals(that.child());
return nameFromChild || name.get().equals(that.name.get());
}
@Override
public int hashCode() {
return Objects.hash(exprId, name, qualifier, children());
return Objects.hash(exprId, qualifier);
}
@Override
public String toString() {
return child().toString() + " AS `" + name + "`#" + exprId;
return child().toString() + " AS `" + name.get() + "`#" + exprId;
}
@Override
public Alias withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
if (nameFromChild) {
return new Alias(exprId, children, children.get(0).toSql(), qualifier, nameFromChild);
return new Alias(exprId, children,
Suppliers.memoize(() -> children.get(0).toSql()), qualifier, nameFromChild);
} else {
return new Alias(exprId, children, name, qualifier, nameFromChild);
}

View File

@ -74,4 +74,8 @@ public class Any extends Expression implements LeafExpression {
public boolean deepEquals(TreeNode<?> that) {
return true;
}
protected boolean supportCompareWidthAndDepth() {
return false;
}
}

View File

@ -147,7 +147,7 @@ public class ArrayItemReference extends NamedExpression implements ExpectsInputT
@Override
public ArrayItemSlot withExprId(ExprId exprId) {
return new ArrayItemSlot(exprId, name, dataType, nullable);
return new ArrayItemSlot(exprId, name.get(), dataType, nullable);
}
@Override
@ -157,7 +157,7 @@ public class ArrayItemReference extends NamedExpression implements ExpectsInputT
@Override
public SlotReference withNullable(boolean newNullable) {
return new ArrayItemSlot(exprId, name, dataType, nullable);
return new ArrayItemSlot(exprId, name.get(), dataType, nullable);
}
@Override

View File

@ -68,16 +68,4 @@ public abstract class BinaryOperator extends Expression implements BinaryExpress
public int hashCode() {
return Objects.hash(symbol, left(), right());
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
BinaryOperator other = (BinaryOperator) o;
return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right());
}
}

View File

@ -42,14 +42,16 @@ import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSet.Builder;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
/**
* Abstract class for all Expression in Nereids.
@ -62,21 +64,30 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
private final int width;
// Mark this expression is from predicate infer or something else infer
private final boolean inferred;
private final boolean hasUnbound;
private final boolean compareWidthAndDepth;
private final Supplier<Set<Slot>> inputSlots = Suppliers.memoize(() -> collect(Slot.class::isInstance));
protected Expression(Expression... children) {
super(children);
int maxChildDepth = 0;
int sumChildWidth = 0;
boolean hasUnbound = false;
boolean compareWidthAndDepth = true;
for (int i = 0; i < children.length; ++i) {
Expression child = children[i];
maxChildDepth = Math.max(child.depth, maxChildDepth);
sumChildWidth += child.width;
hasUnbound |= child.hasUnbound;
compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth());
}
this.depth = maxChildDepth + 1;
this.width = sumChildWidth + ((children.length == 0) ? 1 : 0);
this.compareWidthAndDepth = compareWidthAndDepth;
checkLimit();
this.inferred = false;
this.hasUnbound = hasUnbound || this instanceof Unbound;
}
protected Expression(List<Expression> children) {
@ -87,16 +98,22 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
super(children);
int maxChildDepth = 0;
int sumChildWidth = 0;
boolean hasUnbound = false;
boolean compareWidthAndDepth = true;
for (int i = 0; i < children.size(); ++i) {
Expression child = children.get(i);
maxChildDepth = Math.max(child.depth, maxChildDepth);
sumChildWidth += child.width;
hasUnbound |= child.hasUnbound;
compareWidthAndDepth &= (child.compareWidthAndDepth & child.supportCompareWidthAndDepth());
}
this.depth = maxChildDepth + 1;
this.width = sumChildWidth + ((children.isEmpty()) ? 1 : 0);
this.compareWidthAndDepth = compareWidthAndDepth;
checkLimit();
this.inferred = inferred;
this.hasUnbound = hasUnbound || this instanceof Unbound;
}
private void checkLimit() {
@ -293,7 +310,7 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
* Note that the input slots of subquery's inner plan is not included.
*/
public final Set<Slot> getInputSlots() {
return collect(Slot.class::isInstance);
return inputSlots.get();
}
/**
@ -302,13 +319,12 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
* Note that the input slots of subquery's inner plan is not included.
*/
public final Set<ExprId> getInputSlotExprIds() {
ImmutableSet.Builder<ExprId> result = ImmutableSet.builder();
foreach(node -> {
if (node instanceof Slot) {
result.add(((Slot) node).getExprId());
}
});
return result.build();
Set<Slot> inputSlots = getInputSlots();
Builder<ExprId> exprIds = ImmutableSet.builderWithExpectedSize(inputSlots.size());
for (Slot inputSlot : inputSlots) {
exprIds.add(inputSlot.getExprId());
}
return exprIds.build();
}
public boolean isLiteral() {
@ -336,7 +352,26 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return false;
}
Expression that = (Expression) o;
return Objects.equals(children(), that.children());
if ((compareWidthAndDepth && (this.width != that.width || this.depth != that.depth))
|| arity() != that.arity() || !extraEquals(that)) {
return false;
}
return equalsChildren(that);
}
protected boolean equalsChildren(Expression that) {
List<Expression> children = children();
List<Expression> thatChildren = that.children();
for (int i = 0; i < children.size(); i++) {
if (!children.get(i).equals(thatChildren.get(i))) {
return false;
}
}
return true;
}
protected boolean extraEquals(Expression that) {
return true;
}
@Override
@ -348,18 +383,14 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
* This expression has unbound symbols or not.
*/
public boolean hasUnbound() {
if (this instanceof Unbound) {
return true;
}
for (Expression child : children) {
if (child.hasUnbound()) {
return true;
}
}
return false;
return this.hasUnbound;
}
public String shapeInfo() {
return toSql();
}
protected boolean supportCompareWidthAndDepth() {
return true;
}
}

View File

@ -71,7 +71,7 @@ public class MarkJoinSlotReference extends SlotReference {
@Override
public MarkJoinSlotReference withExprId(ExprId exprId) {
return new MarkJoinSlotReference(exprId, name, existsHasAgg);
return new MarkJoinSlotReference(exprId, name.get(), existsHasAgg);
}
@Override

View File

@ -26,11 +26,13 @@ import org.apache.doris.nereids.util.Utils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import javax.annotation.Nullable;
/**
@ -38,7 +40,7 @@ import javax.annotation.Nullable;
*/
public class SlotReference extends Slot {
protected final ExprId exprId;
protected final String name;
protected final Supplier<String> name;
protected final DataType dataType;
protected final boolean nullable;
protected final List<String> qualifier;
@ -50,7 +52,7 @@ public class SlotReference extends Slot {
// the unique string representation of a SlotReference
// different SlotReference will have different internalName
// TODO: remove this member variable after mv selection is refactored
protected final Optional<String> internalName;
protected final Supplier<Optional<String>> internalName;
private final TableIf table;
private final Column column;
@ -84,6 +86,13 @@ public class SlotReference extends Slot {
this(exprId, name, dataType, nullable, qualifier, table, column, internalName, null);
}
public SlotReference(ExprId exprId, String name, DataType dataType, boolean nullable,
List<String> qualifier, @Nullable TableIf table, @Nullable Column column,
Optional<String> internalName, List<String> subColLabels) {
this(exprId, () -> name, dataType, nullable, qualifier, table, column,
buildInternalName(() -> name, subColLabels, internalName), subColLabels);
}
/**
* Constructor for SlotReference.
*
@ -96,9 +105,9 @@ public class SlotReference extends Slot {
* @param internalName the internalName of this slot
* @param subColLabels subColumn access labels
*/
public SlotReference(ExprId exprId, String name, DataType dataType, boolean nullable,
public SlotReference(ExprId exprId, Supplier<String> name, DataType dataType, boolean nullable,
List<String> qualifier, @Nullable TableIf table, @Nullable Column column,
Optional<String> internalName, List<String> subColLabels) {
Supplier<Optional<String>> internalName, List<String> subColLabels) {
this.exprId = exprId;
this.name = name;
this.dataType = dataType;
@ -108,14 +117,7 @@ public class SlotReference extends Slot {
this.table = table;
this.column = column;
this.subColPath = subColLabels;
if (subColLabels != null && !this.subColPath.isEmpty()) {
// Modify internal name to distinguish from different sub-columns of same top level column,
// using the `.` to connect each part of paths
String fullName = internalName.orElse(name) + String.join(".", this.subColPath);
this.internalName = Optional.of(fullName);
} else {
this.internalName = internalName.isPresent() ? internalName : Optional.of(name);
}
this.internalName = internalName;
}
public static SlotReference of(String name, DataType type) {
@ -147,7 +149,7 @@ public class SlotReference extends Slot {
@Override
public String getName() {
return name;
return name.get();
}
@Override
@ -172,7 +174,7 @@ public class SlotReference extends Slot {
@Override
public String getInternalName() {
return internalName.get();
return internalName.get().get();
}
public Optional<Column> getColumn() {
@ -185,21 +187,21 @@ public class SlotReference extends Slot {
@Override
public String toSql() {
return name;
return name.get();
}
@Override
public String toString() {
// Just return name and exprId, add another method to show fully qualified name when it's necessary.
return name + "#" + exprId;
return name.get() + "#" + exprId;
}
@Override
public String shapeInfo() {
if (qualifier.isEmpty()) {
return name;
return name.get();
} else {
return qualifier.get(qualifier.size() - 1) + "." + name;
return qualifier.get(qualifier.size() - 1) + "." + name.get();
}
}
@ -258,7 +260,8 @@ public class SlotReference extends Slot {
@Override
public SlotReference withName(String name) {
return new SlotReference(exprId, name, dataType, nullable, qualifier, table, column, internalName, subColPath);
return new SlotReference(
exprId, () -> name, dataType, nullable, qualifier, table, column, internalName, subColPath);
}
@Override
@ -277,4 +280,17 @@ public class SlotReference extends Slot {
public boolean hasSubColPath() {
return subColPath != null && !subColPath.isEmpty();
}
private static Supplier<Optional<String>> buildInternalName(
Supplier<String> name, List<String> subColLabels, Optional<String> internalName) {
if (subColLabels != null && !subColLabels.isEmpty()) {
// Modify internal name to distinguish from different sub-columns of same top level column,
// using the `.` to connect each part of paths
return Suppliers.memoize(() ->
Optional.of(internalName.orElse(name.get()) + String.join(".", subColLabels)));
} else {
return Suppliers.memoize(() ->
internalName.isPresent() ? internalName : Optional.of(name.get()));
}
}
}

View File

@ -124,13 +124,13 @@ public class VirtualSlotReference extends SlotReference implements SlotNotFromCh
if (this.nullable == newNullable) {
return this;
}
return new VirtualSlotReference(exprId, name, dataType, newNullable, qualifier,
return new VirtualSlotReference(exprId, name.get(), dataType, newNullable, qualifier,
originExpression, computeLongValueMethod);
}
@Override
public VirtualSlotReference withQualifier(List<String> qualifier) {
return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier,
return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier,
originExpression, computeLongValueMethod);
}
@ -142,7 +142,7 @@ public class VirtualSlotReference extends SlotReference implements SlotNotFromCh
@Override
public VirtualSlotReference withExprId(ExprId exprId) {
return new VirtualSlotReference(exprId, name, dataType, nullable, qualifier,
return new VirtualSlotReference(exprId, name.get(), dataType, nullable, qualifier,
originExpression, computeLongValueMethod);
}
}

View File

@ -89,21 +89,6 @@ public class WhenClause extends Expression implements BinaryExpression, ExpectsI
return EXPECTS_INPUT_TYPES;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
WhenClause other = (WhenClause) o;
return Objects.equals(left(), other.left()) && Objects.equals(right(), other.right());
}
@Override
public int hashCode() {
return Objects.hash(left(), right());

View File

@ -131,7 +131,7 @@ public class WindowExpression extends Expression {
@Override
public WindowExpression withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 1);
Preconditions.checkArgument(!children.isEmpty());
int index = 0;
Expression func = children.get(index);
index += 1;

View File

@ -74,15 +74,8 @@ public abstract class BoundFunction extends Function implements ComputeSignature
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
BoundFunction that = (BoundFunction) o;
return Objects.equals(name, that.name) && Objects.equals(children, that.children);
protected boolean extraEquals(Expression that) {
return Objects.equals(name, ((BoundFunction) that).name);
}
@Override
@ -92,11 +85,16 @@ public abstract class BoundFunction extends Function implements ComputeSignature
@Override
public String toSql() throws UnboundException {
String args = children()
.stream()
.map(Expression::toSql)
.collect(Collectors.joining(", "));
return name + "(" + args + ")";
StringBuilder sql = new StringBuilder(name).append("(");
int arity = arity();
for (int i = 0; i < arity; i++) {
Expression arg = child(i);
sql.append(arg.toSql());
if (i + 1 < arity) {
sql.append(", ");
}
}
return sql.append(")").toString();
}
@Override

View File

@ -75,8 +75,8 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
this.type = Objects.requireNonNull(type, "type can not be null");
this.groupExpression = Objects.requireNonNull(groupExpression, "groupExpression can not be null");
Objects.requireNonNull(optLogicalProperties, "logicalProperties can not be null");
this.logicalPropertiesSupplier = Suppliers.memoize(() -> optLogicalProperties.orElseGet(
this::computeLogicalProperties));
this.logicalPropertiesSupplier = Suppliers.memoize(() ->
optLogicalProperties.orElseGet(this::computeLogicalProperties));
this.statistics = statistics;
this.id = StatementScopeIdGenerator.newObjectId();
}
@ -166,8 +166,14 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
@Override
public LogicalProperties computeLogicalProperties() {
boolean hasUnboundChild = children.stream()
.anyMatch(child -> !child.bound());
boolean hasUnboundChild = false;
for (Plan child : children) {
if (!child.bound()) {
hasUnboundChild = true;
break;
}
}
if (hasUnboundChild || hasUnboundExpression()) {
return UnboundLogicalProperties.INSTANCE;
} else {

View File

@ -61,8 +61,14 @@ public interface Plan extends TreeNode<Plan> {
return !(getLogicalProperties() instanceof UnboundLogicalProperties);
}
/** hasUnboundExpression */
default boolean hasUnboundExpression() {
return getExpressions().stream().anyMatch(Expression::hasUnbound);
for (Expression expression : getExpressions()) {
if (expression.hasUnbound()) {
return true;
}
}
return false;
}
default boolean containsSlots(ImmutableSet<Slot> slots) {

View File

@ -645,9 +645,11 @@ public class ExpressionUtils {
public static <E> Set<E> mutableCollect(List<? extends Expression> expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
.flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
.collect(Collectors.toSet());
Set<E> set = new HashSet<>();
for (Expression expr : expressions) {
set.addAll(expr.collect(predicate));
}
return set;
}
public static <E> List<E> collectAll(Collection<? extends Expression> expressions,
@ -717,9 +719,11 @@ public class ExpressionUtils {
* Get input slot set from list of expressions.
*/
public static Set<Slot> getInputSlotSet(Collection<? extends Expression> exprs) {
return exprs.stream()
.flatMap(expr -> expr.getInputSlots().stream())
.collect(ImmutableSet.toImmutableSet());
Set<Slot> set = new HashSet<>();
for (Expression expr : exprs) {
set.addAll(expr.getInputSlots());
}
return set;
}
public static boolean checkTypeSkipCast(Expression expression, Class<? extends Expression> cls) {

View File

@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
@ -37,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Collection;
@ -125,16 +125,21 @@ public class PlanUtils {
* get table set from plan root.
*/
public static ImmutableSet<TableIf> getTableSet(LogicalPlan plan) {
Set<LogicalCatalogRelation> tableSet = new HashSet<>();
tableSet.addAll((Collection<? extends LogicalCatalogRelation>) plan
.collect(LogicalCatalogRelation.class::isInstance));
ImmutableSet<TableIf> resultSet = tableSet.stream().map(e -> e.getTable())
.collect(ImmutableSet.toImmutableSet());
return resultSet;
Set<LogicalCatalogRelation> tableSet = plan.collect(LogicalCatalogRelation.class::isInstance);
return tableSet.stream()
.map(LogicalCatalogRelation::getTable)
.collect(ImmutableSet.<TableIf>toImmutableSet());
}
/** fastGetChildrenOutput */
public static List<Slot> fastGetChildrenOutputs(List<Plan> children) {
switch (children.size()) {
case 1: return children.get(0).getOutput();
case 0: return ImmutableList.of();
default: {
}
}
int outputNum = 0;
// child.output is cached by AbstractPlan.logicalProperties,
// we can compute output num without the overhead of re-compute output
@ -175,22 +180,36 @@ public class PlanUtils {
/**
* collect non_window_agg_func
*/
public static class CollectNonWindowedAggFuncs extends DefaultExpressionVisitor<Void, List<AggregateFunction>> {
public static final CollectNonWindowedAggFuncs INSTANCE = new CollectNonWindowedAggFuncs();
@Override
public Void visitWindow(WindowExpression windowExpression, List<AggregateFunction> context) {
for (Expression child : windowExpression.getExpressionsInWindowSpec()) {
child.accept(this, context);
public static class CollectNonWindowedAggFuncs {
public static List<AggregateFunction> collect(Collection<? extends Expression> expressions) {
List<AggregateFunction> aggFunctions = Lists.newArrayList();
for (Expression expression : expressions) {
doCollect(expression, aggFunctions);
}
return null;
return aggFunctions;
}
@Override
public Void visitAggregateFunction(AggregateFunction aggregateFunction, List<AggregateFunction> context) {
context.add(aggregateFunction);
return null;
public static List<AggregateFunction> collect(Expression expression) {
List<AggregateFunction> aggFuns = Lists.newArrayList();
doCollect(expression, aggFuns);
return aggFuns;
}
private static void doCollect(Expression expression, List<AggregateFunction> aggFunctions) {
expression.foreach(expr -> {
if (expr instanceof AggregateFunction) {
aggFunctions.add((AggregateFunction) expr);
return true;
} else if (expr instanceof WindowExpression) {
WindowExpression windowExpression = (WindowExpression) expr;
for (Expression exprInWindowsSpec : windowExpression.getExpressionsInWindowSpec()) {
doCollect(exprInWindowsSpec, aggFunctions);
}
return true;
} else {
return false;
}
});
}
}
}