[enhancement](session var) varariable to control whether to rewrite OR to IN or not (#15437)

This commit is contained in:
Henry2SS
2022-12-29 14:50:32 +08:00
committed by GitHub
parent e2603ca883
commit 25b257e37c
5 changed files with 74 additions and 11 deletions

View File

@ -204,6 +204,9 @@ public class SessionVariable implements Serializable, Writable {
// percentage of EXEC_MEM_LIMIT
public static final String BROADCAST_HASHTABLE_MEM_LIMIT_PERCENTAGE = "broadcast_hashtable_mem_limit_percentage";
public static final String REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = "rewrite_or_to_in_predicate_threshold";
public static final String NEREIDS_STAR_SCHEMA_SUPPORT = "nereids_star_schema_support";
public static final String NEREIDS_CBO_PENALTY_FACTOR = "nereids_cbo_penalty_factor";
@ -554,6 +557,9 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
private boolean nereidsStarSchemaSupport = true;
@VariableMgr.VarAttr(name = REWRITE_OR_TO_IN_PREDICATE_THRESHOLD)
private int rewriteOrToInPredicateThreshold = 2;
@VariableMgr.VarAttr(name = NEREIDS_CBO_PENALTY_FACTOR)
private double nereidsCboPenaltyFactor = 0.7;
@VariableMgr.VarAttr(name = ENABLE_NEREIDS_TRACE)
@ -702,6 +708,14 @@ public class SessionVariable implements Serializable, Writable {
this.blockEncryptionMode = blockEncryptionMode;
}
public void setRewriteOrToInPredicateThreshold(int threshold) {
this.rewriteOrToInPredicateThreshold = threshold;
}
public int getRewriteOrToInPredicateThreshold() {
return rewriteOrToInPredicateThreshold;
}
public long getMaxExecMemByte() {
return maxExecMemByte;
}

View File

@ -28,6 +28,7 @@ import org.apache.doris.analysis.SlotRef;
import org.apache.doris.analysis.TableName;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.rewrite.ExprRewriter.ClauseType;
import com.google.common.base.Preconditions;
@ -43,6 +44,7 @@ import org.apache.logging.log4j.Logger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@ -462,6 +464,13 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
boolean isOrToInAllowed = true;
Set<String> slotSet = new LinkedHashSet<>();
int rewriteThreshold;
if (ConnectContext.get() == null) {
rewriteThreshold = 2;
} else {
rewriteThreshold = ConnectContext.get().getSessionVariable().getRewriteOrToInPredicateThreshold();
}
for (int i = 0; i < exprs.size(); i++) {
Expr predicate = exprs.get(i);
if (!(predicate instanceof BinaryPredicate) && !(predicate instanceof InPredicate)) {
@ -492,22 +501,44 @@ public class ExtractCommonFactorsRule implements ExprRewriteRule {
// isOrToInAllowed : true, means can rewrite
// slotSet.size : nums of columnName in exprs, should be 1
if (isOrToInAllowed && slotSet.size() == 1) {
// slotRef to get ColumnName
// SlotRef firstSlot = (SlotRef) exprs.get(0).getChild(0);
List<Expr> childrenList = exprs.get(0).getChildren();
inPredicate = new InPredicate(exprs.get(0).getChild(0),
childrenList.subList(1, childrenList.size()), false);
for (int i = 1; i < exprs.size(); i++) {
childrenList = exprs.get(i).getChildren();
inPredicate.addChildren(childrenList.subList(1, childrenList.size()));
if (exprs.size() < rewriteThreshold) {
return null;
}
// get deduplication list
List<Expr> deduplicationExprs = getDeduplicationList(exprs);
inPredicate = new InPredicate(deduplicationExprs.get(0),
deduplicationExprs.subList(1, deduplicationExprs.size()), false);
}
return inPredicate;
}
public List<Expr> getDeduplicationList(List<Expr> exprs) {
Set<Expr> set = new HashSet<>();
List<Expr> deduplicationExprList = new ArrayList<>();
deduplicationExprList.add(exprs.get(0).getChild(0));
for (Expr expr : exprs) {
if (expr instanceof BinaryPredicate) {
if (!set.contains(expr.getChild(1))) {
set.add(expr.getChild(1));
deduplicationExprList.add(expr.getChild(1));
}
} else {
List<Expr> childrenExprs = expr.getChildren();
for (Expr childrenExpr : childrenExprs.subList(1, childrenExprs.size())) {
if (!set.contains(childrenExpr)) {
set.add(childrenExpr);
deduplicationExprList.add(childrenExpr);
}
}
}
}
return deduplicationExprList;
}
/**
* Convert RangeSet to Compound Predicate
* @param slotRef: <k1>

View File

@ -2238,5 +2238,22 @@ public class QueryPlanTest extends TestWithFeService {
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and (scan_bytes = 2 or scan_bytes = 3)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2), `scan_bytes` IN (2, 3)"));
sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time = 3 or query_time = 1";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)"));
sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 2)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: `query_time` IN (1, 2, 3)"));
connectContext.getSessionVariable().setRewriteOrToInPredicateThreshold(100);
sql = "SELECT * from test1 where query_time = 1 or query_time = 2 or query_time in (3, 4)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2 OR `query_time` IN (3, 4))"));
sql = "SELECT * from test1 where (query_time = 1 or query_time = 2) and query_time in (3, 4)";
explainString = UtFrameUtils.getSQLPlanOrErrorMsg(connectContext, "EXPLAIN " + sql);
Assert.assertTrue(explainString.contains("PREDICATES: (`query_time` = 1 OR `query_time` = 2), `query_time` IN (3, 4)"));
}
}