[Refactor](Nereids) Fix expression constant and improve SlotExtractor (#11513)

1. Fix expression constant and add unit test.
2. Improve logic in SlotExtractor and remove useless class IterationVisitor.
This commit is contained in:
Shuo Wang
2022-08-08 17:36:21 +08:00
committed by GitHub
parent 9349746987
commit c1c635e944
16 changed files with 82 additions and 254 deletions

View File

@ -51,11 +51,6 @@ public class UnboundSlot extends Slot implements Unbound {
}).reduce((left, right) -> left + "." + right).orElse("");
}
@Override
public boolean isConstant() {
return false;
}
@Override
public String toSql() {
return nameParts.stream().map(Utils::quoteIfNeeded).reduce((left, right) -> left + "." + right).orElse("");

View File

@ -37,7 +37,6 @@ import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.AggPhase;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -53,6 +52,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.SlotExtractor;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.DataPartition;

View File

@ -23,8 +23,8 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

View File

@ -20,13 +20,13 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.base.Preconditions;

View File

@ -23,8 +23,8 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

View File

@ -19,11 +19,11 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

View File

@ -21,11 +21,11 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.ImmutableList;

View File

@ -19,11 +19,11 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

View File

@ -22,12 +22,12 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.Lists;

View File

@ -24,12 +24,12 @@ import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.SlotExtractor;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.SlotExtractor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;

View File

@ -79,8 +79,12 @@ public abstract class Expression extends AbstractTreeNode<Expression> {
/**
* Whether the expression is a constant.
*/
public boolean isConstant() {
return children().stream().allMatch(Expression::isConstant);
public final boolean isConstant() {
if (this instanceof LeafExpression) {
return this instanceof Literal;
} else {
return children().stream().allMatch(Expression::isConstant);
}
}
public final Expression castTo(DataType targetType) throws AnalysisException {

View File

@ -79,11 +79,6 @@ public abstract class Literal extends Expression implements LeafExpression {
return visitor.visitLiteral(this, context);
}
@Override
public boolean isConstant() {
return true;
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
@ -37,6 +38,14 @@ public class SlotReference extends Slot {
private final DataType dataType;
private final boolean nullable;
public SlotReference(String name, DataType dataType) {
this(NamedExpressionUtil.newExprId(), name, dataType, true, ImmutableList.of());
}
public SlotReference(String name, DataType dataType, boolean nullable) {
this(NamedExpressionUtil.newExprId(), name, dataType, nullable, ImmutableList.of());
}
public SlotReference(String name, DataType dataType, boolean nullable, List<String> qualifier) {
this(NamedExpressionUtil.newExprId(), name, dataType, nullable, qualifier);
}

View File

@ -1,212 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.visitor;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Arithmetic;
import org.apache.doris.nereids.trees.expressions.Between;
import org.apache.doris.nereids.trees.expressions.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullLiteral;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StringLiteral;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
/**
* Iterative traversal of an expression.
*/
public abstract class IterationVisitor<C> extends DefaultExpressionVisitor<Void, C> {
@Override
public Void visit(Expression expr, C context) {
return expr.accept(this, context);
}
@Override
public Void visitNot(Not expr, C context) {
visit(expr.child(), context);
return null;
}
@Override
public Void visitCompoundPredicate(CompoundPredicate expr, C context) {
visit(expr.left(), context);
visit(expr.right(), context);
return null;
}
@Override
public Void visitArithmetic(Arithmetic arithmetic, C context) {
visit(arithmetic.child(0), context);
if (arithmetic.getArithmeticOperator().isBinary()) {
visit(arithmetic.child(1), context);
}
return null;
}
@Override
public Void visitBetween(Between betweenPredicate, C context) {
visit(betweenPredicate.getCompareExpr(), context);
visit(betweenPredicate.getLowerBound(), context);
visit(betweenPredicate.getUpperBound(), context);
return null;
}
@Override
public Void visitAlias(Alias alias, C context) {
return visitNamedExpression(alias, context);
}
@Override
public Void visitComparisonPredicate(ComparisonPredicate cp, C context) {
visit(cp.left(), context);
visit(cp.right(), context);
return null;
}
@Override
public Void visitEqualTo(EqualTo equalTo, C context) {
return visitComparisonPredicate(equalTo, context);
}
@Override
public Void visitGreaterThan(GreaterThan greaterThan, C context) {
return visitComparisonPredicate(greaterThan, context);
}
@Override
public Void visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, C context) {
return visitComparisonPredicate(greaterThanEqual, context);
}
@Override
public Void visitLessThan(LessThan lessThan, C context) {
return visitComparisonPredicate(lessThan, context);
}
@Override
public Void visitLessThanEqual(LessThanEqual lessThanEqual, C context) {
return visitComparisonPredicate(lessThanEqual, context);
}
@Override
public Void visitNullSafeEqual(NullSafeEqual nullSafeEqual, C context) {
return visitComparisonPredicate(nullSafeEqual, context);
}
@Override
public Void visitSlot(Slot slot, C context) {
return null;
}
@Override
public Void visitNamedExpression(NamedExpression namedExpression, C context) {
for (Expression child : namedExpression.children()) {
visit(child, context);
}
return null;
}
@Override
public Void visitBoundFunction(BoundFunction boundFunction, C context) {
for (Expression argument : boundFunction.getArguments()) {
visit(argument, context);
}
return null;
}
@Override
public Void visitAggregateFunction(AggregateFunction aggregateFunction, C context) {
return visitBoundFunction(aggregateFunction, context);
}
@Override
public Void visitAdd(Add add, C context) {
return visitArithmetic(add, context);
}
@Override
public Void visitSubtract(Subtract subtract, C context) {
return visitArithmetic(subtract, context);
}
@Override
public Void visitMultiply(Multiply multiply, C context) {
return visitArithmetic(multiply, context);
}
@Override
public Void visitDivide(Divide divide, C context) {
return visitArithmetic(divide, context);
}
@Override
public Void visitMod(Mod mod, C context) {
return visitArithmetic(mod, context);
}
@Override
public Void visitSlotReference(SlotReference slotReference, C context) {
return super.visitSlotReference(slotReference, context);
}
@Override
public Void visitBooleanLiteral(BooleanLiteral booleanLiteral, C context) {
return null;
}
@Override
public Void visitStringLiteral(StringLiteral stringLiteral, C context) {
return null;
}
@Override
public Void visitIntegerLiteral(IntegerLiteral integerLiteral, C context) {
return null;
}
@Override
public Void visitNullLiteral(NullLiteral nullLiteral, C context) {
return null;
}
@Override
public Void visitDoubleLiteral(DoubleLiteral doubleLiteral, C context) {
return null;
}
}

View File

@ -15,32 +15,38 @@
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.visitor;
package org.apache.doris.nereids.util;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.List;
import java.util.Set;
/**
* Extracts the SlotReference contained in the expression.
*/
public class SlotExtractor extends IterationVisitor<List<Slot>> {
public class SlotExtractor {
private static final DefaultExpressionVisitor<Void, Set<Slot>> SLOT_COLLECTOR
= new DefaultExpressionVisitor<Void, Set<Slot>>() {
@Override
public Void visitSlotReference(SlotReference slotReference, Set<Slot> context) {
context.add(slotReference);
return null;
}
};
/**
* extract slot reference.
*/
public static Set<Slot> extractSlot(Collection<Expression> expressions) {
Set<Slot> slots = Sets.newLinkedHashSet();
Set<Slot> slots = Sets.newHashSet();
for (Expression expression : expressions) {
slots.addAll(extractSlot(expression));
extractSlot(expression, slots);
}
return slots;
}
@ -49,24 +55,14 @@ public class SlotExtractor extends IterationVisitor<List<Slot>> {
* extract slot reference.
*/
public static Set<Slot> extractSlot(Expression... expressions) {
Set<Slot> slots = Sets.newLinkedHashSet();
Set<Slot> slots = Sets.newHashSet();
for (Expression expression : expressions) {
slots.addAll(extractSlot(expression));
extractSlot(expression, slots);
}
return slots;
}
private static List<Slot> extractSlot(Expression expression) {
List<Slot> slots = Lists.newArrayList();
new SlotExtractor().visit(expression, slots);
return slots;
}
@Override
public Void visitSlotReference(SlotReference slotReference, List<Slot> context) {
context.add(slotReference);
return null;
private static void extractSlot(Expression expression, Set<Slot> slots) {
expression.accept(SLOT_COLLECTOR, slots);
}
}

View File

@ -0,0 +1,41 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.types.IntegerType;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class ExpressionTest {
@Test
public void testConstantExpression() {
// literal is constant
Assertions.assertTrue(new StringLiteral("abc").isConstant());
// slot reference is not constant
Assertions.assertFalse(new SlotReference("a", IntegerType.INSTANCE).isConstant());
// `1 + 2` is constant
Assertions.assertTrue(new Add(new IntegerLiteral(1), new IntegerLiteral(2)).isConstant());
// `a + 1` is not constant
Assertions.assertFalse(
new Add(new SlotReference("a", IntegerType.INSTANCE), new IntegerLiteral(1)).isConstant());
}
}