[enhancement](session var) variable to control whether to rewrite OR to IN or not (#15437)
This commit is contained in:
@ -204,6 +204,9 @@ public class SessionVariable implements Serializable, Writable {
|
||||
|
||||
// percentage of EXEC_MEM_LIMIT
|
||||
public static final String BROADCAST_HASHTABLE_MEM_LIMIT_PERCENTAGE = "broadcast_hashtable_mem_limit_percentage";
|
||||
|
||||
public static final String REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = "rewrite_or_to_in_predicate_threshold";
|
||||
|
||||
public static final String NEREIDS_STAR_SCHEMA_SUPPORT = "nereids_star_schema_support";
|
||||
|
||||
public static final String NEREIDS_CBO_PENALTY_FACTOR = "nereids_cbo_penalty_factor";
|
||||
@ -554,6 +557,9 @@ public class SessionVariable implements Serializable, Writable {
|
||||
@VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
|
||||
private boolean nereidsStarSchemaSupport = true;
|
||||
|
||||
@VariableMgr.VarAttr(name = REWRITE_OR_TO_IN_PREDICATE_THRESHOLD)
|
||||
private int rewriteOrToInPredicateThreshold = 2;
|
||||
|
||||
@VariableMgr.VarAttr(name = NEREIDS_CBO_PENALTY_FACTOR)
|
||||
private double nereidsCboPenaltyFactor = 0.7;
|
||||
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_TRACE)
|
||||
@ -702,6 +708,14 @@ public class SessionVariable implements Serializable, Writable {
|
||||
this.blockEncryptionMode = blockEncryptionMode;
|
||||
}
|
||||
|
||||
public void setRewriteOrToInPredicateThreshold(int threshold) {
|
||||
this.rewriteOrToInPredicateThreshold = threshold;
|
||||
}
|
||||
|
||||
public int getRewriteOrToInPredicateThreshold() {
|
||||
return rewriteOrToInPredicateThreshold;
|
||||
}
|
||||
|
||||
public long getMaxExecMemByte() {
|
||||
return maxExecMemByte;
|
||||
}
|
||||
|
||||
@ -28,6 +28,7 @@ import org.apache.doris.analysis.SlotRef;
|
||||
import org.apache.doris.analysis.TableName;
|
||||
import org.apache.doris.common.AnalysisException;
|
||||
import org.apache.doris.planner.PlanNode;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.rewrite.ExprRewriter.ClauseType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
@ -43,6 +44,7 @@ import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -462,6 +464,13 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
|
||||
boolean isOrToInAllowed = true;
|
||||
Set<String> slotSet = new LinkedHashSet<>();
|
||||
|
||||
int rewriteThreshold;
|
||||
if (ConnectContext.get() == null) {
|
||||
rewriteThreshold = 2;
|
||||
} else {
|
||||
rewriteThreshold = ConnectContext.get().getSessionVariable().getRewriteOrToInPredicateThreshold();
|
||||
}
|
||||
|
||||
for (int i = 0; i < exprs.size(); i++) {
|
||||
Expr predicate = exprs.get(i);
|
||||
if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) {
|
||||
@ -492,22 +501,44 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
|
||||
// isOrToInAllowed : true, means can rewrite
|
||||
// slotSet.size : nums of columnName in exprs, should be 1
|
||||
if (isOrToInAllowed && slotSet.size() == 1) {
|
||||
// slotRef to get ColumnName
|
||||
|
||||
// SlotRef firstSlot = (SlotRef) exprs.get(0).getChild(0);
|
||||
List<Expr> childrenList = exprs.get(0).getChildren();
|
||||
inPredicate = new InPredicate(exprs.get(0).getChild(0),
|
||||
childrenList.subList(1, childrenList.size()), false);
|
||||
|
||||
for (int i = 1; i < exprs.size(); i++) {
|
||||
childrenList = exprs.get(i).getChildren();
|
||||
inPredicate.addChildren(childrenList.subList(1, childrenList.size()));
|
||||
if (exprs.size() < rewriteThreshold) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// get deduplication list
|
||||
List<Expr> deduplicationExprs = getDeduplicationList(exprs);
|
||||
inPredicate = new InPredicate(deduplicationExprs.get(0),
|
||||
deduplicationExprs.subList(1, deduplicationExprs.size()), false);
|
||||
}
|
||||
|
||||
return inPredicate;
|
||||
}
|
||||
|
||||
public List<Expr> getDeduplicationList(List<Expr> exprs) {
|
||||
Set<Expr> set = new HashSet<>();
|
||||
List<Expr> deduplicationExprList = new ArrayList<>();
|
||||
|
||||
deduplicationExprList.add(exprs.get(0).getChild(0));
|
||||
|
||||
for (Expr expr : exprs) {
|
||||
if (expr instanceof BinaryPredicate) {
|
||||
if (!set.contains(expr.getChild(1))) {
|
||||
set.add(expr.getChild(1));
|
||||
deduplicationExprList.add(expr.getChild(1));
|
||||
}
|
||||
} else {
|
||||
List<Expr> childrenExprs = expr.getChildren();
|
||||
for (Expr childrenExpr : childrenExprs.subList(1, childrenExprs.size())) {
|
||||
if (!set.contains(childrenExpr)) {
|
||||
set.add(childrenExpr);
|
||||
deduplicationExprList.add(childrenExpr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return deduplicationExprList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert RangeSet to Compound Predicate
|
||||
* @param slotRef: <k1>
|
||||
|
||||
@ -2238,5 +2238,22 @@ public class QueryPlanTest extends TestWithFeService {
|
||||
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and (scan_bytes = 2 or scan_bytes = 3)";
|
||||
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
|
||||
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)"));
|
||||
|
||||
sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time = 3 or query_time = 1";
|
||||
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
|
||||
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)"));
|
||||
|
||||
sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 2)";
|
||||
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
|
||||
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)"));
|
||||
|
||||
connectContext.getSessionVariable().setRewriteOrToInPredicateThreshold(100);
|
||||
sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)";
|
||||
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
|
||||
Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2 OR `query_time` IN (3, 4))"));
|
||||
|
||||
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)";
|
||||
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
|
||||
Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2), `query_time` IN (3, 4)"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user