[opt](Nerieds) add infer props to expression (#28953)

This commit is contained in:
seawinde
2023-12-27 12:21:25 +08:00
committed by GitHub
parent 86a0eb9344
commit 17f92155d7
13 changed files with 237 additions and 59 deletions

View File

@ -88,6 +88,9 @@ public class PredicatePropagation {
public Set<Expression> infer(Set<Expression> predicates) {
Set<Expression> inferred = Sets.newHashSet();
for (Expression predicate : predicates) {
// if we support more infer predicate expression type, we should impl withInferred() method.
// And should add inferred props in withChildren() method just like ComparisonPredicate,
// and it's subclass, to mark the predicate is from infer.
if (!(predicate instanceof ComparisonPredicate)) {
continue;
}
@ -130,7 +133,7 @@ public class PredicatePropagation {
.comparisonPredicate.withChildren(newLeft, newRight);
Expression expr = SimplifyComparisonPredicate.INSTANCE
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
return DateFunctionRewrite.INSTANCE.rewrite(expr, null);
return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true);
}
private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) {

View File

@ -34,7 +34,11 @@ public abstract class BinaryOperator extends Expression implements BinaryExpress
protected final String symbol;
public BinaryOperator(List<Expression> children, String symbol) {
super(children);
this(children, symbol, false);
}
public BinaryOperator(List<Expression> children, String symbol, boolean inferred) {
super(children, inferred);
this.symbol = symbol;
}

View File

@ -33,7 +33,11 @@ import java.util.List;
public abstract class ComparisonPredicate extends BinaryOperator {
public ComparisonPredicate(List<Expression> children, String symbol) {
super(children, symbol);
this(children, symbol, false);
}
public ComparisonPredicate(List<Expression> children, String symbol, boolean inferred) {
super(children, symbol, inferred);
}
@Override

View File

@ -25,7 +25,11 @@ import java.util.List;
public abstract class EqualPredicate extends ComparisonPredicate {
protected EqualPredicate(List<Expression> children, String symbol) {
super(children, symbol);
this(children, symbol, false);
}
protected EqualPredicate(List<Expression> children, String symbol, boolean inferred) {
super(children, symbol, inferred);
}
@Override

View File

@ -32,11 +32,19 @@ import java.util.List;
public class EqualTo extends EqualPredicate implements PropagateNullable {
public EqualTo(Expression left, Expression right) {
super(ImmutableList.of(left, right), "=");
this(left, right, false);
}
public EqualTo(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), "=", inferred);
}
private EqualTo(List<Expression> children) {
super(children, "=");
this(children, false);
}
private EqualTo(List<Expression> children, boolean inferred) {
super(children, "=", inferred);
}
@Override
@ -47,7 +55,12 @@ public class EqualTo extends EqualPredicate implements PropagateNullable {
@Override
public EqualTo withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new EqualTo(children);
return new EqualTo(children, this.isInferred());
}
@Override
public Expression withInferred(boolean inferred) {
return new EqualTo(this.children, inferred);
}
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {

View File

@ -59,6 +59,8 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
protected Optional<String> exprName = Optional.empty();
private final int depth;
private final int width;
// Mark this expression is from predicate infer or something else infer
private final boolean inferred;
protected Expression(Expression... children) {
super(children);
@ -69,6 +71,7 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
.mapToInt(e -> e.width)
.sum() + (children.length == 0 ? 1 : 0);
checkLimit();
this.inferred = false;
}
protected Expression(List<Expression> children) {
@ -80,6 +83,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
.mapToInt(e -> e.width)
.sum() + (children.isEmpty() ? 1 : 0);
checkLimit();
this.inferred = false;
}
protected Expression(List<Expression> children, boolean inferred) {
super(children);
depth = children.stream()
.mapToInt(e -> e.depth)
.max().orElse(0) + 1;
width = children.stream()
.mapToInt(e -> e.width)
.sum() + (children.isEmpty() ? 1 : 0);
checkLimit();
this.inferred = inferred;
}
private void checkLimit() {
@ -216,11 +232,19 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return depth;
}
public boolean isInferred() {
return inferred;
}
@Override
public Expression withChildren(List<Expression> children) {
throw new RuntimeException();
}
public Expression withInferred(boolean inferred) {
throw new RuntimeException("current expression has not impl the withInferred method");
}
/**
* Whether the expression is a constant.
*/

View File

@ -37,11 +37,19 @@ public class GreaterThan extends ComparisonPredicate implements PropagateNullabl
* @param right right child of greater than
*/
public GreaterThan(Expression left, Expression right) {
super(ImmutableList.of(left, right), ">");
this(left, right, false);
}
public GreaterThan(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), ">", inferred);
}
private GreaterThan(List<Expression> children) {
super(children, ">");
this(children, false);
}
private GreaterThan(List<Expression> children, boolean inferred) {
super(children, ">", inferred);
}
@Override
@ -57,7 +65,12 @@ public class GreaterThan extends ComparisonPredicate implements PropagateNullabl
@Override
public GreaterThan withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new GreaterThan(children);
return new GreaterThan(children, this.isInferred());
}
@Override
public Expression withInferred(boolean inferred) {
return new GreaterThan(this.children, inferred);
}
@Override

View File

@ -32,11 +32,19 @@ import java.util.List;
public class GreaterThanEqual extends ComparisonPredicate implements PropagateNullable {
public GreaterThanEqual(Expression left, Expression right) {
super(ImmutableList.of(left, right), ">=");
this(left, right, false);
}
public GreaterThanEqual(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), ">=", inferred);
}
private GreaterThanEqual(List<Expression> children) {
super(children, ">=");
this(children, false);
}
private GreaterThanEqual(List<Expression> children, boolean inferred) {
super(children, ">=", inferred);
}
@Override
@ -52,7 +60,12 @@ public class GreaterThanEqual extends ComparisonPredicate implements PropagateNu
@Override
public GreaterThanEqual withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new GreaterThanEqual(children);
return new GreaterThanEqual(children, this.isInferred());
}
@Override
public Expression withInferred(boolean inferred) {
return new GreaterThanEqual(this.children, inferred);
}
@Override

View File

@ -31,11 +31,19 @@ import java.util.List;
*/
public class LessThan extends ComparisonPredicate implements PropagateNullable {
public LessThan(Expression left, Expression right) {
super(ImmutableList.of(left, right), "<");
this(left, right, false);
}
public LessThan(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), "<", inferred);
}
private LessThan(List<Expression> children) {
super(children, "<");
this(children, false);
}
private LessThan(List<Expression> children, boolean inferred) {
super(children, "<", inferred);
}
@Override
@ -51,7 +59,12 @@ public class LessThan extends ComparisonPredicate implements PropagateNullable {
@Override
public LessThan withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new LessThan(children);
return new LessThan(children, this.isInferred());
}
@Override
public Expression withInferred(boolean inferred) {
return new LessThan(this.children, inferred);
}
@Override

View File

@ -37,11 +37,19 @@ public class LessThanEqual extends ComparisonPredicate implements PropagateNulla
* @param right right child of Less Than And Equal
*/
public LessThanEqual(Expression left, Expression right) {
super(ImmutableList.of(left, right), "<=");
this(left, right, false);
}
public LessThanEqual(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), "<=", inferred);
}
private LessThanEqual(List<Expression> children) {
super(children, "<=");
this(children, false);
}
private LessThanEqual(List<Expression> children, boolean inferred) {
super(children, "<=", inferred);
}
@Override
@ -57,7 +65,12 @@ public class LessThanEqual extends ComparisonPredicate implements PropagateNulla
@Override
public LessThanEqual withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new LessThanEqual(children);
return new LessThanEqual(children, this.isInferred());
}
@Override
public Expression withInferred(boolean inferred) {
return new LessThanEqual(this.children, inferred);
}
@Override

View File

@ -31,11 +31,19 @@ import java.util.List;
*/
public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable {
public NullSafeEqual(Expression left, Expression right) {
super(ImmutableList.of(left, right), "<=>");
this(left, right, false);
}
public NullSafeEqual(Expression left, Expression right, boolean inferred) {
super(ImmutableList.of(left, right), "<=>", inferred);
}
private NullSafeEqual(List<Expression> children) {
super(children, "<=>");
this(children, false);
}
private NullSafeEqual(List<Expression> children, boolean inferred) {
super(children, "<=>", inferred);
}
@Override
@ -51,7 +59,12 @@ public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable {
@Override
public NullSafeEqual withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 2);
return new NullSafeEqual(children);
return new NullSafeEqual(children, this.isInferred());
}
@Override
public Expression withInferred(boolean inferred) {
return new NullSafeEqual(this.children, inferred);
}
@Override

View File

@ -48,6 +48,7 @@ import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
@ -639,4 +640,25 @@ public class ExpressionUtils {
}
);
}
/**
* Check the expression is inferred or not, if inferred return true, nor return false
*/
public static boolean isInferred(Expression expression) {
return expression.accept(new DefaultExpressionVisitor<Boolean, Void>() {
@Override
public Boolean visit(Expression expr, Void context) {
boolean inferred = expr.isInferred();
if (expr.isInferred() || expr.children().isEmpty()) {
return inferred;
}
inferred = true;
for (Expression child : expr.children()) {
inferred = inferred && child.accept(this, context);
}
return inferred;
}
}, null);
}
}