[enhancement](Nereids) refactor expression rewriter to pattern match (#32617)
this pr can improve the performance of the nereids planner, in plan stage. 1. refactor expression rewriter to pattern match, so the lots of expression rewrite rules can criss-crossed apply in a big bottom-up iteration, and rewrite until the expression became stable. now we can process more cases because original there has no loop, and sometimes only process the top expression, like `SimplifyArithmeticRule`. 2. replace `Collection.stream()` to `ImmutableXxx.Builder` to avoid useless method call 3. loop unrolling some codes, like `Expression.<init>`, `PlanTreeRewriteBottomUpJob.pushChildrenJobs` 4. use type/arity specified-code, like `OneRangePartitionEvaluator.toNereidsLiterals()`, `PartitionRangeExpander.tryExpandRange()`, `PartitionRangeExpander.enumerableCount()` 5. refactor `ExtractCommonFactorRule`, now we can extract more cases, and I fix the deed loop when use `ExtractCommonFactorRule` and `SimplifyRange` in one iterative, because `SimplifyRange` generate right deep tree, but `ExtractCommonFactorRule` generate left deep tree 6. refactor `FoldConstantRuleOnFE`, support visitor/pattern match mode, in ExpressionNormalization, pattern match can criss-crossed apply with other rules; in PartitionPruner, visitor can evaluate expression faster 7. lazy compute and cache some operation 8. use int field to compare date 9. use BitSet to find disableNereidsRules 10. two level loop usually faster then build Multimap when bind slot in Scope, so I revert the code 11. `PlanTreeRewriteBottomUpJob` don't need to clearStatePhase any more ### test case 100 threads parallel continuous send this sql which query an empty table, test in my mac machine(m2 chip, 8 core), enable sql cache ```sql select count(1),date_format(time_col,'%Y%m%d'),varchar_col1 from tbl where partition_date>'2024-02-15' and (varchar_col2 ='73130' or varchar_col3='73130') and time_col>'2024-03-04' and time_col<'2024-03-05' group by date_format(time_col,'%Y%m%d'),varchar_col1 order by date_format(time_col,'%Y%m%d') desc, varchar_col1 desc,count(1) asc limit 1000 ``` before this pr: 3100 peak QPS, about 2700 avg QPS after this pr: 4800 peak QPS, about 4400 avg QPS (cherry picked from commit 7338683fdbdf77711f2ce61e580c19f4ea100723)
This commit is contained in:
@ -1016,7 +1016,7 @@ under the License.
|
||||
<configuration>
|
||||
<proc>only</proc>
|
||||
<compilerArgs>
|
||||
<arg>-AplanPath=${basedir}/src/main/java/org/apache/doris/nereids</arg>
|
||||
<arg>-Apath=${basedir}/src/main/java/org/apache/doris/nereids</arg>
|
||||
</compilerArgs>
|
||||
<includes>
|
||||
<include>org/apache/doris/nereids/pattern/generator/PatternDescribableProcessPoint.java</include>
|
||||
|
||||
@ -570,11 +570,14 @@ public class DateLiteral extends LiteralExpr {
|
||||
switch (type.getPrimitiveType()) {
|
||||
case DATE:
|
||||
case DATEV2:
|
||||
return this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0;
|
||||
return year == 0 && month == 1 && day == 1
|
||||
&& this.getStringValue().compareTo(MIN_DATE.getStringValue()) == 0;
|
||||
case DATETIME:
|
||||
return this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0;
|
||||
return year == 0 && month == 1 && day == 1
|
||||
&& this.getStringValue().compareTo(MIN_DATETIME.getStringValue()) == 0;
|
||||
case DATETIMEV2:
|
||||
return this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0;
|
||||
return year == 0 && month == 1 && day == 1
|
||||
&& this.getStringValue().compareTo(MIN_DATETIMEV2.getStringValue()) == 0;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -2454,9 +2454,8 @@ public class OlapTable extends Table implements MTMVRelatedTableIf {
|
||||
}
|
||||
|
||||
public boolean isDupKeysOrMergeOnWrite() {
|
||||
return getKeysType() == KeysType.DUP_KEYS
|
||||
|| (getKeysType() == KeysType.UNIQUE_KEYS
|
||||
&& getEnableUniqueKeyMergeOnWrite());
|
||||
return keysType == KeysType.DUP_KEYS
|
||||
|| (keysType == KeysType.UNIQUE_KEYS && getEnableUniqueKeyMergeOnWrite());
|
||||
}
|
||||
|
||||
public void initAutoIncrementGenerator(long dbId) {
|
||||
|
||||
@ -67,7 +67,7 @@ public class MTMVRelationManager implements MTMVHookService {
|
||||
* @return
|
||||
*/
|
||||
public Set<MTMV> getAvailableMTMVs(List<BaseTableInfo> tableInfos, ConnectContext ctx) {
|
||||
Set<MTMV> res = Sets.newHashSet();
|
||||
Set<MTMV> res = Sets.newLinkedHashSet();
|
||||
Set<BaseTableInfo> mvInfos = getMTMVInfos(tableInfos);
|
||||
for (BaseTableInfo tableInfo : mvInfos) {
|
||||
try {
|
||||
@ -90,7 +90,7 @@ public class MTMVRelationManager implements MTMVHookService {
|
||||
}
|
||||
|
||||
private Set<BaseTableInfo> getMTMVInfos(List<BaseTableInfo> tableInfos) {
|
||||
Set<BaseTableInfo> mvInfos = Sets.newHashSet();
|
||||
Set<BaseTableInfo> mvInfos = Sets.newLinkedHashSet();
|
||||
for (BaseTableInfo tableInfo : tableInfos) {
|
||||
mvInfos.addAll(getMtmvsByBaseTable(tableInfo));
|
||||
}
|
||||
|
||||
@ -380,7 +380,9 @@ public class Role implements Writable, GsonPostProcessable {
|
||||
|
||||
public boolean checkColPriv(String ctl, String db, String tbl, String col, PrivPredicate wanted) {
|
||||
Optional<Privilege> colPrivilege = wanted.getColPrivilege();
|
||||
Preconditions.checkState(colPrivilege.isPresent(), "this privPredicate should not use checkColPriv:" + wanted);
|
||||
if (!colPrivilege.isPresent()) {
|
||||
throw new IllegalStateException("this privPredicate should not use checkColPriv:" + wanted);
|
||||
}
|
||||
return checkTblPriv(ctl, db, tbl, wanted) || onlyCheckColPriv(ctl, db, tbl, col, colPrivilege.get());
|
||||
}
|
||||
|
||||
|
||||
@ -76,6 +76,7 @@ import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@ -134,6 +135,11 @@ public class CascadesContext implements ScheduleContext {
|
||||
// trigger by rule and show by `explain plan process` statement
|
||||
private final List<PlanProcess> planProcesses = new ArrayList<>();
|
||||
|
||||
// this field is modified by FoldConstantRuleOnFE, it matters current traverse
|
||||
// into AggregateFunction with distinct, we can not fold constant in this case
|
||||
private int distinctAggLevel;
|
||||
private final boolean isEnableExprTrace;
|
||||
|
||||
/**
|
||||
* Constructor of OptimizerContext.
|
||||
*
|
||||
@ -156,6 +162,13 @@ public class CascadesContext implements ScheduleContext {
|
||||
this.subqueryExprIsAnalyzed = new HashMap<>();
|
||||
this.runtimeFilterContext = new RuntimeFilterContext(getConnectContext().getSessionVariable());
|
||||
this.materializationContexts = new ArrayList<>();
|
||||
if (statementContext.getConnectContext() != null) {
|
||||
ConnectContext connectContext = statementContext.getConnectContext();
|
||||
SessionVariable sessionVariable = connectContext.getSessionVariable();
|
||||
this.isEnableExprTrace = sessionVariable != null && sessionVariable.isEnableExprTrace();
|
||||
} else {
|
||||
this.isEnableExprTrace = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@ -256,7 +269,7 @@ public class CascadesContext implements ScheduleContext {
|
||||
this.tables = tables.stream().collect(Collectors.toMap(TableIf::getId, t -> t, (t1, t2) -> t1));
|
||||
}
|
||||
|
||||
public ConnectContext getConnectContext() {
|
||||
public final ConnectContext getConnectContext() {
|
||||
return statementContext.getConnectContext();
|
||||
}
|
||||
|
||||
@ -366,14 +379,20 @@ public class CascadesContext implements ScheduleContext {
|
||||
return defaultValue;
|
||||
}
|
||||
|
||||
StatementContext statementContext = getStatementContext();
|
||||
if (statementContext == null) {
|
||||
return defaultValue;
|
||||
}
|
||||
return statementContext.getOrRegisterCache(cacheName,
|
||||
return getStatementContext().getOrRegisterCache(cacheName,
|
||||
() -> variableSupplier.apply(connectContext.getSessionVariable()));
|
||||
}
|
||||
|
||||
/** getAndCacheDisableRules */
|
||||
public final BitSet getAndCacheDisableRules() {
|
||||
ConnectContext connectContext = getConnectContext();
|
||||
StatementContext statementContext = getStatementContext();
|
||||
if (connectContext == null || statementContext == null) {
|
||||
return new BitSet();
|
||||
}
|
||||
return statementContext.getOrCacheDisableRules(connectContext.getSessionVariable());
|
||||
}
|
||||
|
||||
private CascadesContext execute(Job job) {
|
||||
pushJob(job);
|
||||
jobScheduler.executeJobPool(this);
|
||||
@ -718,8 +737,28 @@ public class CascadesContext implements ScheduleContext {
|
||||
}
|
||||
|
||||
public void printPlanProcess() {
|
||||
printPlanProcess(this.planProcesses);
|
||||
}
|
||||
|
||||
public static void printPlanProcess(List<PlanProcess> planProcesses) {
|
||||
for (PlanProcess row : planProcesses) {
|
||||
LOG.info("RULE: " + row.ruleName + "\nBEFORE:\n" + row.beforeShape + "\nafter:\n" + row.afterShape);
|
||||
}
|
||||
}
|
||||
|
||||
public void incrementDistinctAggLevel() {
|
||||
this.distinctAggLevel++;
|
||||
}
|
||||
|
||||
public void decrementDistinctAggLevel() {
|
||||
this.distinctAggLevel--;
|
||||
}
|
||||
|
||||
public int getDistinctAggLevel() {
|
||||
return distinctAggLevel;
|
||||
}
|
||||
|
||||
public boolean isEnableExprTrace() {
|
||||
return isEnableExprTrace;
|
||||
}
|
||||
}
|
||||
|
||||
@ -387,7 +387,7 @@ public class NereidsPlanner extends Planner {
|
||||
if (hint instanceof DistributeHint) {
|
||||
distributeHintIndex++;
|
||||
if (!hint.getExplainString().equals("")) {
|
||||
distributeIndex = "_" + String.valueOf(distributeHintIndex);
|
||||
distributeIndex = "_" + distributeHintIndex;
|
||||
}
|
||||
}
|
||||
switch (hint.getStatus()) {
|
||||
|
||||
@ -36,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.OriginStatement;
|
||||
import org.apache.doris.qe.SessionVariable;
|
||||
|
||||
import com.google.common.base.Stopwatch;
|
||||
import com.google.common.base.Supplier;
|
||||
@ -45,6 +46,7 @@ import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.Collection;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@ -117,6 +119,8 @@ public class StatementContext {
|
||||
// Relation for example LogicalOlapScan
|
||||
private final Map<Slot, Relation> slotToRelation = Maps.newHashMap();
|
||||
|
||||
private BitSet disableRules;
|
||||
|
||||
public StatementContext() {
|
||||
this.connectContext = ConnectContext.get();
|
||||
}
|
||||
@ -259,11 +263,22 @@ public class StatementContext {
|
||||
return supplier.get();
|
||||
}
|
||||
|
||||
public synchronized BitSet getOrCacheDisableRules(SessionVariable sessionVariable) {
|
||||
if (this.disableRules != null) {
|
||||
return this.disableRules;
|
||||
}
|
||||
this.disableRules = sessionVariable.getDisableNereidsRules();
|
||||
return this.disableRules;
|
||||
}
|
||||
|
||||
/**
|
||||
* Some value of the cacheKey may change, invalid cache when value change
|
||||
*/
|
||||
public synchronized void invalidCache(String cacheKey) {
|
||||
contextCacheMap.remove(cacheKey);
|
||||
if (cacheKey.equalsIgnoreCase(SessionVariable.DISABLE_NEREIDS_RULES)) {
|
||||
this.disableRules = null;
|
||||
}
|
||||
}
|
||||
|
||||
public ColumnAliasGenerator getColumnAliasGenerator() {
|
||||
|
||||
@ -26,6 +26,7 @@ import com.google.common.collect.LinkedListMultimap;
|
||||
import com.google.common.collect.ListMultimap;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
@ -63,6 +64,7 @@ public class Scope {
|
||||
private final List<Slot> slots;
|
||||
private final Optional<SubqueryExpr> ownerSubquery;
|
||||
private final Set<Slot> correlatedSlots;
|
||||
private final boolean buildNameToSlot;
|
||||
private final Supplier<ListMultimap<String, Slot>> nameToSlot;
|
||||
|
||||
public Scope(List<? extends Slot> slots) {
|
||||
@ -75,7 +77,8 @@ public class Scope {
|
||||
this.slots = Utils.fastToImmutableList(Objects.requireNonNull(slots, "slots can not be null"));
|
||||
this.ownerSubquery = Objects.requireNonNull(subqueryExpr, "subqueryExpr can not be null");
|
||||
this.correlatedSlots = Sets.newLinkedHashSet();
|
||||
this.nameToSlot = Suppliers.memoize(this::buildNameToSlot);
|
||||
this.buildNameToSlot = slots.size() > 500;
|
||||
this.nameToSlot = buildNameToSlot ? Suppliers.memoize(this::buildNameToSlot) : null;
|
||||
}
|
||||
|
||||
public List<Slot> getSlots() {
|
||||
@ -96,7 +99,19 @@ public class Scope {
|
||||
|
||||
/** findSlotIgnoreCase */
|
||||
public List<Slot> findSlotIgnoreCase(String slotName) {
|
||||
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
|
||||
if (!buildNameToSlot) {
|
||||
Object[] array = new Object[slots.size()];
|
||||
int filterIndex = 0;
|
||||
for (int i = 0; i < slots.size(); i++) {
|
||||
Slot slot = slots.get(i);
|
||||
if (slot.getName().equalsIgnoreCase(slotName)) {
|
||||
array[filterIndex++] = slot;
|
||||
}
|
||||
}
|
||||
return (List) Arrays.asList(array).subList(0, filterIndex);
|
||||
} else {
|
||||
return nameToSlot.get().get(slotName.toUpperCase(Locale.ROOT));
|
||||
}
|
||||
}
|
||||
|
||||
private ListMultimap<String, Slot> buildNameToSlot() {
|
||||
|
||||
@ -34,16 +34,14 @@ import org.apache.doris.nereids.rules.RuleSet;
|
||||
import org.apache.doris.nereids.trees.expressions.CTEId;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
import org.apache.doris.qe.SessionVariable;
|
||||
import org.apache.doris.statistics.Statistics;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.BitSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Abstract class for all job using for analyze and optimize query plan in Nereids.
|
||||
@ -57,7 +55,7 @@ public abstract class Job implements TracerSupplier {
|
||||
protected JobType type;
|
||||
protected JobContext context;
|
||||
protected boolean once;
|
||||
protected final Set<Integer> disableRules;
|
||||
protected final BitSet disableRules;
|
||||
|
||||
protected Map<CTEId, Statistics> cteIdToStats;
|
||||
|
||||
@ -129,8 +127,7 @@ public abstract class Job implements TracerSupplier {
|
||||
groupExpression.getOwnerGroup(), groupExpression, groupExpression.getPlan()));
|
||||
}
|
||||
|
||||
public static Set<Integer> getDisableRules(JobContext context) {
|
||||
return context.getCascadesContext().getAndCacheSessionVariable(
|
||||
SessionVariable.DISABLE_NEREIDS_RULES, ImmutableSet.of(), SessionVariable::getDisableNereidsRules);
|
||||
public static BitSet getDisableRules(JobContext context) {
|
||||
return context.getCascadesContext().getAndCacheDisableRules();
|
||||
}
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ import org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProj
|
||||
import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionNormalizationAndOptimization;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
|
||||
import org.apache.doris.nereids.rules.rewrite.AddDefaultLimit;
|
||||
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
|
||||
@ -152,8 +152,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
// such as group by key matching and replaced
|
||||
// but we need to do some normalization before subquery unnesting,
|
||||
// such as extract common expression.
|
||||
new ExpressionNormalization(),
|
||||
new ExpressionOptimization(),
|
||||
new ExpressionNormalizationAndOptimization(),
|
||||
new AvgDistinctToSumDivCount(),
|
||||
new CountDistinctRewrite(),
|
||||
new ExtractFilterFromCrossJoin()
|
||||
@ -240,7 +239,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
|
||||
// efficient because it can find the new plans and apply transform wherever it is
|
||||
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
|
||||
// after push down, some new filters are generated, which needs to be optimized. (example: tpch q19)
|
||||
topDown(new ExpressionOptimization()),
|
||||
// topDown(new ExpressionOptimization()),
|
||||
topDown(
|
||||
new MergeFilters(),
|
||||
new ReorderJoin(),
|
||||
|
||||
@ -51,6 +51,7 @@ import com.google.common.collect.Sets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.BitSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -64,7 +65,7 @@ import javax.annotation.Nullable;
|
||||
*/
|
||||
public class HyperGraph {
|
||||
// record all edges that can be placed on the subgraph
|
||||
private final Map<Long, BitSet> treeEdgesCache = new HashMap<>();
|
||||
private final Map<Long, BitSet> treeEdgesCache = new LinkedHashMap<>();
|
||||
private final List<JoinEdge> joinEdges;
|
||||
private final List<FilterEdge> filterEdges;
|
||||
private final List<AbstractNode> nodes;
|
||||
@ -330,9 +331,9 @@ public class HyperGraph {
|
||||
private final List<AbstractNode> nodes = new ArrayList<>();
|
||||
|
||||
// These hyperGraphs should be replaced nodes when building all
|
||||
private final Map<Long, List<HyperGraph>> replacedHyperGraphs = new HashMap<>();
|
||||
private final HashMap<Slot, Long> slotToNodeMap = new HashMap<>();
|
||||
private final Map<Long, List<NamedExpression>> complexProject = new HashMap<>();
|
||||
private final Map<Long, List<HyperGraph>> replacedHyperGraphs = new LinkedHashMap<>();
|
||||
private final HashMap<Slot, Long> slotToNodeMap = new LinkedHashMap<>();
|
||||
private final Map<Long, List<NamedExpression>> complexProject = new LinkedHashMap<>();
|
||||
private Set<Slot> finalOutputs;
|
||||
|
||||
public List<AbstractNode> getNodes() {
|
||||
@ -522,7 +523,7 @@ public class HyperGraph {
|
||||
*/
|
||||
private BitSet addJoin(LogicalJoin<?, ?> join,
|
||||
Pair<BitSet, Long> leftEdgeNodes, Pair<BitSet, Long> rightEdgeNodes) {
|
||||
HashMap<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new HashMap<>();
|
||||
Map<Pair<Long, Long>, Pair<List<Expression>, List<Expression>>> conjuncts = new LinkedHashMap<>();
|
||||
for (Expression expression : join.getHashJoinConjuncts()) {
|
||||
// TODO: avoid calling calculateEnds if calNodeMap's results are same
|
||||
Pair<Long, Long> ends = calculateEnds(calNodeMap(expression.getInputSlots()), leftEdgeNodes,
|
||||
|
||||
@ -25,8 +25,8 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
|
||||
import java.util.BitSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Supplier;
|
||||
|
||||
/**
|
||||
@ -50,8 +50,8 @@ public class CustomRewriteJob implements RewriteJob {
|
||||
|
||||
@Override
|
||||
public void execute(JobContext context) {
|
||||
Set<Integer> disableRules = Job.getDisableRules(context);
|
||||
if (disableRules.contains(ruleType.type())) {
|
||||
BitSet disableRules = Job.getDisableRules(context);
|
||||
if (disableRules.get(ruleType.type())) {
|
||||
return;
|
||||
}
|
||||
CascadesContext cascadesContext = context.getCascadesContext();
|
||||
|
||||
@ -39,9 +39,9 @@ public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob {
|
||||
// Different 'RewriteState' has different actions,
|
||||
// so we will do specified action for each node based on their 'RewriteState'.
|
||||
private static final String REWRITE_STATE_KEY = "rewrite_state";
|
||||
|
||||
private final RewriteJobContext rewriteJobContext;
|
||||
private final List<Rule> rules;
|
||||
private final int batchId;
|
||||
|
||||
enum RewriteState {
|
||||
// 'REWRITE_THIS' means the current plan node can be handled immediately. If the plan state is 'REWRITE_THIS',
|
||||
@ -59,22 +59,15 @@ public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob {
|
||||
super(JobType.BOTTOM_UP_REWRITE, context);
|
||||
this.rewriteJobContext = Objects.requireNonNull(rewriteJobContext, "rewriteContext cannot be null");
|
||||
this.rules = Objects.requireNonNull(rules, "rules cannot be null");
|
||||
this.batchId = rewriteJobContext.batchId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void execute() {
|
||||
// For the bottom-up rewrite job, we need to reset the state of its children
|
||||
// if the plan has changed after the rewrite. So we use the 'childrenVisited' to check this situation.
|
||||
boolean clearStatePhase = !rewriteJobContext.childrenVisited;
|
||||
if (clearStatePhase) {
|
||||
traverseClearState();
|
||||
return;
|
||||
}
|
||||
|
||||
// We'll do different actions based on their different states.
|
||||
// You can check the comment in 'RewriteState' structure for more details.
|
||||
Plan plan = rewriteJobContext.plan;
|
||||
RewriteState state = getState(plan);
|
||||
RewriteState state = getState(plan, batchId);
|
||||
switch (state) {
|
||||
case REWRITE_THIS:
|
||||
rewriteThis();
|
||||
@ -90,33 +83,13 @@ public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob {
|
||||
}
|
||||
}
|
||||
|
||||
private void traverseClearState() {
|
||||
// Reset the state for current node.
|
||||
RewriteJobContext clearedStateContext = rewriteJobContext.withChildrenVisited(true);
|
||||
setState(clearedStateContext.plan, RewriteState.REWRITE_THIS);
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(clearedStateContext, context, rules));
|
||||
|
||||
// Generate the new rewrite job for its children. Because the character of stack is 'first in, last out',
|
||||
// so we can traverse reset the state for the plan node until the leaf node.
|
||||
List<Plan> children = clearedStateContext.plan.children();
|
||||
for (int i = children.size() - 1; i >= 0; i--) {
|
||||
Plan child = children.get(i);
|
||||
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
|
||||
child, clearedStateContext, i, false);
|
||||
// NOTICE: this relay on pull up cte anchor
|
||||
if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) {
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void rewriteThis() {
|
||||
// Link the current node with the sub-plan to get the current plan which is used in the rewrite phase later.
|
||||
Plan plan = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext);
|
||||
RewriteResult rewriteResult = rewrite(plan, rules, rewriteJobContext);
|
||||
if (rewriteResult.hasNewPlan) {
|
||||
RewriteJobContext newJobContext = rewriteJobContext.withPlan(rewriteResult.plan);
|
||||
RewriteState state = getState(rewriteResult.plan);
|
||||
RewriteState state = getState(rewriteResult.plan, batchId);
|
||||
// Some eliminate rule will return a rewritten plan, for example the current node is eliminated
|
||||
// and return the child plan. So we don't need to handle it again.
|
||||
if (state == RewriteState.REWRITTEN) {
|
||||
@ -125,40 +98,82 @@ public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob {
|
||||
}
|
||||
// After the rewrite take effect, we should handle the children part again.
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(newJobContext, context, rules));
|
||||
setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN);
|
||||
setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN, batchId);
|
||||
} else {
|
||||
// No new plan is generated, so just set the state of the current plan to 'REWRITTEN'.
|
||||
setState(rewriteResult.plan, RewriteState.REWRITTEN);
|
||||
setState(rewriteResult.plan, RewriteState.REWRITTEN, batchId);
|
||||
rewriteJobContext.setResult(rewriteResult.plan);
|
||||
}
|
||||
}
|
||||
|
||||
private void ensureChildrenRewritten() {
|
||||
// Similar to the function 'traverseClearState'.
|
||||
Plan plan = rewriteJobContext.plan;
|
||||
setState(plan, RewriteState.REWRITE_THIS);
|
||||
int batchId = rewriteJobContext.batchId;
|
||||
setState(plan, RewriteState.REWRITE_THIS, batchId);
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(rewriteJobContext, context, rules));
|
||||
|
||||
List<Plan> children = plan.children();
|
||||
for (int i = children.size() - 1; i >= 0; i--) {
|
||||
Plan child = children.get(i);
|
||||
// some rule return new plan tree, which the number of new plan node > 1,
|
||||
// we should transform this new plan nodes too.
|
||||
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
|
||||
child, rewriteJobContext, i, false);
|
||||
// NOTICE: this relay on pull up cte anchor
|
||||
if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) {
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules));
|
||||
}
|
||||
// some rule return new plan tree, which the number of new plan node > 1,
|
||||
// we should transform this new plan nodes too.
|
||||
// NOTICE: this relay on pull up cte anchor
|
||||
if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) {
|
||||
pushChildrenJobs(plan);
|
||||
}
|
||||
}
|
||||
|
||||
private static final RewriteState getState(Plan plan) {
|
||||
Optional<RewriteState> state = plan.getMutableState(REWRITE_STATE_KEY);
|
||||
return state.orElse(RewriteState.ENSURE_CHILDREN_REWRITTEN);
|
||||
private void pushChildrenJobs(Plan plan) {
|
||||
List<Plan> children = plan.children();
|
||||
switch (children.size()) {
|
||||
case 0: return;
|
||||
case 1:
|
||||
Plan child = children.get(0);
|
||||
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
|
||||
child, rewriteJobContext, 0, false, batchId);
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules));
|
||||
return;
|
||||
case 2:
|
||||
Plan right = children.get(1);
|
||||
RewriteJobContext rightRewriteJobContext = new RewriteJobContext(
|
||||
right, rewriteJobContext, 1, false, batchId);
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(rightRewriteJobContext, context, rules));
|
||||
|
||||
Plan left = children.get(0);
|
||||
RewriteJobContext leftRewriteJobContext = new RewriteJobContext(
|
||||
left, rewriteJobContext, 0, false, batchId);
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(leftRewriteJobContext, context, rules));
|
||||
return;
|
||||
default:
|
||||
for (int i = children.size() - 1; i >= 0; i--) {
|
||||
child = children.get(i);
|
||||
childRewriteJobContext = new RewriteJobContext(
|
||||
child, rewriteJobContext, i, false, batchId);
|
||||
pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, rules));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static final void setState(Plan plan, RewriteState state) {
|
||||
plan.setMutableState(REWRITE_STATE_KEY, state);
|
||||
private static RewriteState getState(Plan plan, int currentBatchId) {
|
||||
Optional<RewriteStateContext> state = plan.getMutableState(REWRITE_STATE_KEY);
|
||||
if (!state.isPresent()) {
|
||||
return RewriteState.ENSURE_CHILDREN_REWRITTEN;
|
||||
}
|
||||
RewriteStateContext context = state.get();
|
||||
if (context.batchId != currentBatchId) {
|
||||
return RewriteState.ENSURE_CHILDREN_REWRITTEN;
|
||||
}
|
||||
return context.rewriteState;
|
||||
}
|
||||
|
||||
private static void setState(Plan plan, RewriteState state, int batchId) {
|
||||
plan.setMutableState(REWRITE_STATE_KEY, new RewriteStateContext(state, batchId));
|
||||
}
|
||||
|
||||
private static class RewriteStateContext {
|
||||
private final RewriteState rewriteState;
|
||||
private final int batchId;
|
||||
|
||||
public RewriteStateContext(RewriteState rewriteState, int batchId) {
|
||||
this.rewriteState = rewriteState;
|
||||
this.batchId = batchId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,6 +28,8 @@ import org.apache.doris.nereids.pattern.Pattern;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/** PlanTreeRewriteJob */
|
||||
@ -43,7 +45,7 @@ public abstract class PlanTreeRewriteJob extends Job {
|
||||
|
||||
boolean showPlanProcess = cascadesContext.showPlanProcess();
|
||||
for (Rule rule : rules) {
|
||||
if (disableRules.contains(rule.getRuleType().type())) {
|
||||
if (disableRules.get(rule.getRuleType().type())) {
|
||||
continue;
|
||||
}
|
||||
Pattern<Plan> pattern = (Pattern<Plan>) rule.getPattern();
|
||||
@ -76,26 +78,50 @@ public abstract class PlanTreeRewriteJob extends Job {
|
||||
return new RewriteResult(false, plan);
|
||||
}
|
||||
|
||||
protected final Plan linkChildrenAndParent(Plan plan, RewriteJobContext rewriteJobContext) {
|
||||
Plan newPlan = linkChildren(plan, rewriteJobContext.childrenContext);
|
||||
rewriteJobContext.setResult(newPlan);
|
||||
return newPlan;
|
||||
}
|
||||
|
||||
protected final Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) {
|
||||
boolean changed = false;
|
||||
Plan[] newChildren = new Plan[childrenContext.length];
|
||||
for (int i = 0; i < childrenContext.length; ++i) {
|
||||
Plan result = childrenContext[i].result;
|
||||
Plan oldChild = plan.child(i);
|
||||
if (result != null && result != oldChild) {
|
||||
newChildren[i] = result;
|
||||
changed = true;
|
||||
} else {
|
||||
newChildren[i] = oldChild;
|
||||
protected static Plan linkChildren(Plan plan, RewriteJobContext[] childrenContext) {
|
||||
List<Plan> children = plan.children();
|
||||
// loop unrolling
|
||||
switch (children.size()) {
|
||||
case 0: {
|
||||
return plan;
|
||||
}
|
||||
case 1: {
|
||||
RewriteJobContext child = childrenContext[0];
|
||||
Plan firstResult = child == null ? plan.child(0) : child.result;
|
||||
return firstResult == null || firstResult == children.get(0)
|
||||
? plan : plan.withChildren(ImmutableList.of(firstResult));
|
||||
}
|
||||
case 2: {
|
||||
RewriteJobContext left = childrenContext[0];
|
||||
Plan firstResult = left == null ? plan.child(0) : left.result;
|
||||
RewriteJobContext right = childrenContext[1];
|
||||
Plan secondResult = right == null ? plan.child(1) : right.result;
|
||||
Plan firstOrigin = children.get(0);
|
||||
Plan secondOrigin = children.get(1);
|
||||
boolean firstChanged = firstResult != null && firstResult != firstOrigin;
|
||||
boolean secondChanged = secondResult != null && secondResult != secondOrigin;
|
||||
if (firstChanged || secondChanged) {
|
||||
ImmutableList.Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(2);
|
||||
newChildren.add(firstChanged ? firstResult : firstOrigin);
|
||||
newChildren.add(secondChanged ? secondResult : secondOrigin);
|
||||
return plan.withChildren(newChildren.build());
|
||||
} else {
|
||||
return plan;
|
||||
}
|
||||
}
|
||||
default: {
|
||||
boolean changed = false;
|
||||
int i = 0;
|
||||
Plan[] newChildren = new Plan[childrenContext.length];
|
||||
for (Plan oldChild : children) {
|
||||
Plan result = childrenContext[i].result;
|
||||
changed = result != null && result != oldChild;
|
||||
newChildren[i] = changed ? result : oldChild;
|
||||
i++;
|
||||
}
|
||||
return changed ? plan.withChildren(newChildren) : plan;
|
||||
}
|
||||
}
|
||||
return changed ? plan.withChildren(newChildren) : plan;
|
||||
}
|
||||
|
||||
private String getCurrentPlanTreeString() {
|
||||
|
||||
@ -56,21 +56,44 @@ public class PlanTreeRewriteTopDownJob extends PlanTreeRewriteJob {
|
||||
RewriteJobContext newRewriteJobContext = rewriteJobContext.withChildrenVisited(true);
|
||||
pushJob(new PlanTreeRewriteTopDownJob(newRewriteJobContext, context, rules));
|
||||
|
||||
List<Plan> children = newRewriteJobContext.plan.children();
|
||||
for (int i = children.size() - 1; i >= 0; i--) {
|
||||
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
|
||||
children.get(i), newRewriteJobContext, i, false);
|
||||
// NOTICE: this relay on pull up cte anchor
|
||||
if (!(rewriteJobContext.plan instanceof LogicalCTEAnchor)) {
|
||||
pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules));
|
||||
}
|
||||
// NOTICE: this relay on pull up cte anchor
|
||||
if (!(this.rewriteJobContext.plan instanceof LogicalCTEAnchor)) {
|
||||
pushChildrenJobs(newRewriteJobContext);
|
||||
}
|
||||
} else {
|
||||
// All the children part are already visited. Just link the children plan to the current node.
|
||||
Plan result = linkChildrenAndParent(rewriteJobContext.plan, rewriteJobContext);
|
||||
Plan result = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext);
|
||||
rewriteJobContext.setResult(result);
|
||||
if (rewriteJobContext.parentContext == null) {
|
||||
context.getCascadesContext().setRewritePlan(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void pushChildrenJobs(RewriteJobContext rewriteJobContext) {
|
||||
List<Plan> children = rewriteJobContext.plan.children();
|
||||
switch (children.size()) {
|
||||
case 0: return;
|
||||
case 1:
|
||||
RewriteJobContext childRewriteJobContext = new RewriteJobContext(
|
||||
children.get(0), rewriteJobContext, 0, false, this.rewriteJobContext.batchId);
|
||||
pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules));
|
||||
return;
|
||||
case 2:
|
||||
RewriteJobContext rightRewriteJobContext = new RewriteJobContext(
|
||||
children.get(1), rewriteJobContext, 1, false, this.rewriteJobContext.batchId);
|
||||
pushJob(new PlanTreeRewriteTopDownJob(rightRewriteJobContext, context, rules));
|
||||
|
||||
RewriteJobContext leftRewriteJobContext = new RewriteJobContext(
|
||||
children.get(0), rewriteJobContext, 0, false, this.rewriteJobContext.batchId);
|
||||
pushJob(new PlanTreeRewriteTopDownJob(leftRewriteJobContext, context, rules));
|
||||
return;
|
||||
default:
|
||||
for (int i = children.size() - 1; i >= 0; i--) {
|
||||
childRewriteJobContext = new RewriteJobContext(
|
||||
children.get(i), rewriteJobContext, i, false, this.rewriteJobContext.batchId);
|
||||
pushJob(new PlanTreeRewriteTopDownJob(childRewriteJobContext, context, rules));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -25,6 +25,7 @@ import javax.annotation.Nullable;
|
||||
public class RewriteJobContext {
|
||||
|
||||
final boolean childrenVisited;
|
||||
final int batchId;
|
||||
final RewriteJobContext parentContext;
|
||||
final int childIndexInParentContext;
|
||||
final Plan plan;
|
||||
@ -33,7 +34,7 @@ public class RewriteJobContext {
|
||||
|
||||
/** RewriteJobContext */
|
||||
public RewriteJobContext(Plan plan, @Nullable RewriteJobContext parentContext, int childIndexInParentContext,
|
||||
boolean childrenVisited) {
|
||||
boolean childrenVisited, int batchId) {
|
||||
this.plan = plan;
|
||||
this.parentContext = parentContext;
|
||||
this.childIndexInParentContext = childIndexInParentContext;
|
||||
@ -42,6 +43,7 @@ public class RewriteJobContext {
|
||||
if (parentContext != null) {
|
||||
parentContext.childrenContext[childIndexInParentContext] = this;
|
||||
}
|
||||
this.batchId = batchId;
|
||||
}
|
||||
|
||||
public void setResult(Plan result) {
|
||||
@ -49,15 +51,15 @@ public class RewriteJobContext {
|
||||
}
|
||||
|
||||
public RewriteJobContext withChildrenVisited(boolean childrenVisited) {
|
||||
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited);
|
||||
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId);
|
||||
}
|
||||
|
||||
public RewriteJobContext withPlan(Plan plan) {
|
||||
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited);
|
||||
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId);
|
||||
}
|
||||
|
||||
public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) {
|
||||
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited);
|
||||
return new RewriteJobContext(plan, parentContext, childIndexInParentContext, childrenVisited, batchId);
|
||||
}
|
||||
|
||||
public boolean isRewriteRoot() {
|
||||
|
||||
@ -27,9 +27,11 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/** RootPlanTreeRewriteJob */
|
||||
public class RootPlanTreeRewriteJob implements RewriteJob {
|
||||
private static final AtomicInteger BATCH_ID = new AtomicInteger();
|
||||
|
||||
private final List<Rule> rules;
|
||||
private final RewriteJobBuilder rewriteJobBuilder;
|
||||
@ -47,7 +49,9 @@ public class RootPlanTreeRewriteJob implements RewriteJob {
|
||||
// get plan from the cascades context
|
||||
Plan root = cascadesContext.getRewritePlan();
|
||||
// write rewritten root plan to cascades context by the RootRewriteJobContext
|
||||
RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext(root, false, context);
|
||||
int batchId = BATCH_ID.incrementAndGet();
|
||||
RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext(
|
||||
root, false, context, batchId);
|
||||
Job rewriteJob = rewriteJobBuilder.build(rewriteJobContext, context, rules);
|
||||
|
||||
context.getScheduleContext().pushJob(rewriteJob);
|
||||
@ -71,8 +75,8 @@ public class RootPlanTreeRewriteJob implements RewriteJob {
|
||||
|
||||
private final JobContext jobContext;
|
||||
|
||||
RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext) {
|
||||
super(plan, null, -1, childrenVisited);
|
||||
RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext, int batchId) {
|
||||
super(plan, null, -1, childrenVisited, batchId);
|
||||
this.jobContext = Objects.requireNonNull(jobContext, "jobContext cannot be null");
|
||||
jobContext.getCascadesContext().setCurrentRootRewriteJobContext(this);
|
||||
}
|
||||
@ -89,17 +93,17 @@ public class RootPlanTreeRewriteJob implements RewriteJob {
|
||||
|
||||
@Override
|
||||
public RewriteJobContext withChildrenVisited(boolean childrenVisited) {
|
||||
return new RootRewriteJobContext(plan, childrenVisited, jobContext);
|
||||
return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RewriteJobContext withPlan(Plan plan) {
|
||||
return new RootRewriteJobContext(plan, childrenVisited, jobContext);
|
||||
return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) {
|
||||
return new RootRewriteJobContext(plan, childrenVisited, jobContext);
|
||||
return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId);
|
||||
}
|
||||
|
||||
/** linkChildren */
|
||||
|
||||
@ -0,0 +1,112 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatchRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
/** ExpressionPatternMapping */
|
||||
public class ExpressionPatternRules extends TypeMappings<Expression, ExpressionPatternMatchRule> {
|
||||
private static final Logger LOG = LogManager.getLogger(ExpressionPatternRules.class);
|
||||
|
||||
public ExpressionPatternRules(List<ExpressionPatternMatchRule> typeMappings) {
|
||||
super(typeMappings);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Set<Class<? extends Expression>> getChildrenClasses(Class<? extends Expression> clazz) {
|
||||
return org.apache.doris.nereids.pattern.GeneratedExpressionRelations.CHILDREN_CLASS_MAP.get(clazz);
|
||||
}
|
||||
|
||||
/** matchesAndApply */
|
||||
public Optional<Expression> matchesAndApply(Expression expr, ExpressionRewriteContext context, Expression parent) {
|
||||
List<ExpressionPatternMatchRule> rules = singleMappings.get(expr.getClass());
|
||||
ExpressionMatchingContext<Expression> matchingContext
|
||||
= new ExpressionMatchingContext<>(expr, parent, context);
|
||||
switch (rules.size()) {
|
||||
case 0: {
|
||||
for (ExpressionPatternMatchRule multiMatchRule : multiMappings) {
|
||||
if (multiMatchRule.matchesTypeAndPredicates(matchingContext)) {
|
||||
Expression newExpr = multiMatchRule.apply(matchingContext);
|
||||
if (!newExpr.equals(expr)) {
|
||||
if (context.cascadesContext.isEnableExprTrace()) {
|
||||
traceExprChanged(multiMatchRule, expr, newExpr);
|
||||
}
|
||||
return Optional.of(newExpr);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
case 1: {
|
||||
ExpressionPatternMatchRule rule = rules.get(0);
|
||||
if (rule.matchesPredicates(matchingContext)) {
|
||||
Expression newExpr = rule.apply(matchingContext);
|
||||
if (!newExpr.equals(expr)) {
|
||||
if (context.cascadesContext.isEnableExprTrace()) {
|
||||
traceExprChanged(rule, expr, newExpr);
|
||||
}
|
||||
return Optional.of(newExpr);
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
default: {
|
||||
for (ExpressionPatternMatchRule rule : rules) {
|
||||
if (rule.matchesPredicates(matchingContext)) {
|
||||
Expression newExpr = rule.apply(matchingContext);
|
||||
if (!expr.equals(newExpr)) {
|
||||
if (context.cascadesContext.isEnableExprTrace()) {
|
||||
traceExprChanged(rule, expr, newExpr);
|
||||
}
|
||||
return Optional.of(newExpr);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void traceExprChanged(ExpressionPatternMatchRule rule, Expression expr, Expression newExpr) {
|
||||
try {
|
||||
Field[] declaredFields = (rule.matchingAction).getClass().getDeclaredFields();
|
||||
Class<?> ruleClass;
|
||||
if (declaredFields.length == 0) {
|
||||
ruleClass = rule.matchingAction.getClass();
|
||||
} else {
|
||||
Field field = declaredFields[0];
|
||||
field.setAccessible(true);
|
||||
ruleClass = field.get(rule.matchingAction).getClass();
|
||||
}
|
||||
LOG.info("RULE: " + ruleClass + "\nbefore: " + expr + "\nafter: " + newExpr);
|
||||
} catch (Throwable t) {
|
||||
LOG.error(t.getMessage(), t);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,112 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerMapping;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** ExpressionPatternTraverseListeners */
|
||||
public class ExpressionPatternTraverseListeners
|
||||
extends TypeMappings<Expression, ExpressionTraverseListenerMapping> {
|
||||
public ExpressionPatternTraverseListeners(
|
||||
List<ExpressionTraverseListenerMapping> typeMappings) {
|
||||
super(typeMappings);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Set<Class<? extends Expression>> getChildrenClasses(Class<? extends Expression> clazz) {
|
||||
return org.apache.doris.nereids.pattern.GeneratedExpressionRelations.CHILDREN_CLASS_MAP.get(clazz);
|
||||
}
|
||||
|
||||
/** matchesAndCombineListener */
|
||||
public @Nullable CombinedListener matchesAndCombineListeners(
|
||||
Expression expr, ExpressionRewriteContext context, Expression parent) {
|
||||
List<ExpressionTraverseListenerMapping> listenerSingleMappings = singleMappings.get(expr.getClass());
|
||||
ExpressionMatchingContext<Expression> matchingContext
|
||||
= new ExpressionMatchingContext<>(expr, parent, context);
|
||||
switch (listenerSingleMappings.size()) {
|
||||
case 0: {
|
||||
ImmutableList.Builder<ExpressionTraverseListener<Expression>> matchedListeners
|
||||
= ImmutableList.builder();
|
||||
for (ExpressionTraverseListenerMapping multiMapping : multiMappings) {
|
||||
if (multiMapping.matchesTypeAndPredicates(matchingContext)) {
|
||||
matchedListeners.add(multiMapping.listener);
|
||||
}
|
||||
}
|
||||
return CombinedListener.tryCombine(matchedListeners.build(), matchingContext);
|
||||
}
|
||||
case 1: {
|
||||
ExpressionTraverseListenerMapping listenerMapping = listenerSingleMappings.get(0);
|
||||
if (listenerMapping.matchesPredicates(matchingContext)) {
|
||||
return CombinedListener.tryCombine(ImmutableList.of(listenerMapping.listener), matchingContext);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
default: {
|
||||
ImmutableList.Builder<ExpressionTraverseListener<Expression>> matchedListeners
|
||||
= ImmutableList.builder();
|
||||
for (ExpressionTraverseListenerMapping singleMapping : listenerSingleMappings) {
|
||||
if (singleMapping.matchesPredicates(matchingContext)) {
|
||||
matchedListeners.add(singleMapping.listener);
|
||||
}
|
||||
}
|
||||
return CombinedListener.tryCombine(matchedListeners.build(), matchingContext);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** CombinedListener */
|
||||
public static class CombinedListener {
|
||||
private final ExpressionMatchingContext<Expression> context;
|
||||
private final List<ExpressionTraverseListener<Expression>> listeners;
|
||||
|
||||
/** CombinedListener */
|
||||
public CombinedListener(ExpressionMatchingContext<Expression> context,
|
||||
List<ExpressionTraverseListener<Expression>> listeners) {
|
||||
this.context = context;
|
||||
this.listeners = listeners;
|
||||
}
|
||||
|
||||
public static @Nullable CombinedListener tryCombine(
|
||||
List<ExpressionTraverseListener<Expression>> listenerMappings,
|
||||
ExpressionMatchingContext<Expression> context) {
|
||||
return listenerMappings.isEmpty() ? null : new CombinedListener(context, listenerMappings);
|
||||
}
|
||||
|
||||
public void onEnter() {
|
||||
for (ExpressionTraverseListener<Expression> listener : listeners) {
|
||||
listener.onEnter(context);
|
||||
}
|
||||
}
|
||||
|
||||
public void onExit(Expression rewritten) {
|
||||
for (ExpressionTraverseListener<Expression> listener : listeners) {
|
||||
listener.onExit(context, rewritten);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,59 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
/** ParentTypeIdMapping */
|
||||
public class ParentTypeIdMapping {
|
||||
|
||||
private final AtomicInteger idGenerator = new AtomicInteger();
|
||||
private final Map<Class<?>, Integer> classId = new ConcurrentHashMap<>(8192);
|
||||
|
||||
/** getId */
|
||||
public int getId(Class<?> clazz) {
|
||||
Integer id = classId.get(clazz);
|
||||
if (id != null) {
|
||||
return id;
|
||||
}
|
||||
return ensureClassHasId(clazz);
|
||||
}
|
||||
|
||||
private int ensureClassHasId(Class<?> clazz) {
|
||||
Class<?> superClass = clazz.getSuperclass();
|
||||
if (superClass != null) {
|
||||
ensureClassHasId(superClass);
|
||||
}
|
||||
|
||||
for (Class<?> interfaceClass : clazz.getInterfaces()) {
|
||||
ensureClassHasId(interfaceClass);
|
||||
}
|
||||
|
||||
return classId.computeIfAbsent(clazz, c -> idGenerator.incrementAndGet());
|
||||
}
|
||||
|
||||
public static void main(String[] args) {
|
||||
ParentTypeIdMapping mapping = new ParentTypeIdMapping();
|
||||
int id = mapping.getId(LessThanEqual.class);
|
||||
System.out.println(id);
|
||||
}
|
||||
}
|
||||
@ -152,6 +152,10 @@ public class Pattern<TYPE extends Plan>
|
||||
if (this instanceof SubTreePattern) {
|
||||
return matchPredicates((TYPE) plan);
|
||||
}
|
||||
return matchChildrenAndSelfPredicates(plan, childPatternNum);
|
||||
}
|
||||
|
||||
private boolean matchChildrenAndSelfPredicates(Plan plan, int childPatternNum) {
|
||||
List<Plan> childrenPlan = plan.children();
|
||||
for (int i = 0; i < childrenPlan.size(); i++) {
|
||||
Plan child = childrenPlan.get(i);
|
||||
|
||||
@ -0,0 +1,133 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern;
|
||||
|
||||
import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ArrayListMultimap;
|
||||
import com.google.common.collect.ListMultimap;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** ExpressionPatternMappings */
|
||||
public abstract class TypeMappings<K, T extends TypeMapping<K>> {
|
||||
protected final ListMultimap<Class<? extends K>, T> singleMappings;
|
||||
protected final List<T> multiMappings;
|
||||
|
||||
/** ExpressionPatternMappings */
|
||||
public TypeMappings(List<T> typeMappings) {
|
||||
this.singleMappings = ArrayListMultimap.create();
|
||||
this.multiMappings = Lists.newArrayList();
|
||||
|
||||
for (T mapping : typeMappings) {
|
||||
Set<Class<? extends K>> childrenClasses = getChildrenClasses(mapping.getType());
|
||||
if (childrenClasses == null || childrenClasses.isEmpty()) {
|
||||
// add some expressions which no child class
|
||||
// e.g. LessThanEqual
|
||||
addSimpleMapping(mapping);
|
||||
} else if (childrenClasses.size() <= 100) {
|
||||
// add some expressions which have children classes
|
||||
// e.g. ComparisonPredicate will be expanded to
|
||||
// ruleMappings.put(LessThanEqual.class, rule);
|
||||
// ruleMappings.put(LessThan.class, rule);
|
||||
// ruleMappings.put(GreaterThan.class, rule);
|
||||
// ruleMappings.put(GreaterThanEquals.class, rule);
|
||||
// ...
|
||||
addThisAndChildrenMapping(mapping, childrenClasses);
|
||||
} else {
|
||||
// some expressions have lots of children classes, e.g. Expression, ExpressionTrait, BinaryExpression,
|
||||
// we will not expand this types to child class, but also add this rules to other type matching.
|
||||
// for example, if we have three rules to matches this types: LessThanEqual, Abs and Expression,
|
||||
// then the ruleMappings would be:
|
||||
// {
|
||||
// LessThanEqual.class: [rule_of_LessThanEqual, rule_of_Expression],
|
||||
// Abs.class: [rule_of_Abs, rule_of_Expression]
|
||||
// }
|
||||
//
|
||||
// and the multiMatchRules would be: [rule_of_Expression]
|
||||
//
|
||||
// if we matches `a <= 1`, there have two rules would be applied because
|
||||
// ruleMappings.get(LessThanEqual.class) return two rules;
|
||||
// if we matches `a = 1`, ruleMappings.get(EqualTo.class) will return empty rules, so we use
|
||||
// all the rules in multiMatchRules to matches and apply, the rule_of_Expression will be applied.
|
||||
addMultiMapping(mapping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public @Nullable List<T> get(Class<? extends K> clazz) {
|
||||
return singleMappings.get(clazz);
|
||||
}
|
||||
|
||||
private void addSimpleMapping(T typeMapping) {
|
||||
Class<? extends K> clazz = typeMapping.getType();
|
||||
int modifiers = clazz.getModifiers();
|
||||
if (!Modifier.isAbstract(modifiers)) {
|
||||
addSingleMapping(clazz, typeMapping);
|
||||
}
|
||||
}
|
||||
|
||||
private void addThisAndChildrenMapping(
|
||||
T typeMapping, Set<Class<? extends K>> childrenClasses) {
|
||||
Class<? extends K> clazz = typeMapping.getType();
|
||||
if (!Modifier.isAbstract(clazz.getModifiers())) {
|
||||
addSingleMapping(clazz, typeMapping);
|
||||
}
|
||||
|
||||
for (Class<? extends K> childrenClass : childrenClasses) {
|
||||
if (!Modifier.isAbstract(childrenClass.getModifiers())) {
|
||||
addSingleMapping(childrenClass, typeMapping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addMultiMapping(T multiMapping) {
|
||||
multiMappings.add(multiMapping);
|
||||
|
||||
Set<Class<? extends K>> existSingleMappingTypes = Utils.fastToImmutableSet(singleMappings.keySet());
|
||||
for (Class<? extends K> existSingleType : existSingleMappingTypes) {
|
||||
Class<? extends K> type = multiMapping.getType();
|
||||
if (type.isAssignableFrom(existSingleType)) {
|
||||
singleMappings.put(existSingleType, multiMapping);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addSingleMapping(Class<? extends K> clazz, T singleMapping) {
|
||||
if (!singleMappings.containsKey(clazz) && !multiMappings.isEmpty()) {
|
||||
for (T multiMapping : multiMappings) {
|
||||
if (multiMapping.getType().isAssignableFrom(clazz)) {
|
||||
singleMappings.put(clazz, multiMapping);
|
||||
}
|
||||
}
|
||||
}
|
||||
singleMappings.put(clazz, singleMapping);
|
||||
}
|
||||
|
||||
protected abstract Set<Class<? extends K>> getChildrenClasses(Class<? extends K> clazz);
|
||||
|
||||
/** TypeMapping */
|
||||
public interface TypeMapping<K> {
|
||||
Class<? extends K> getType();
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,159 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern.generator;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.processing.ProcessingEnvironment;
|
||||
import javax.tools.StandardLocation;
|
||||
|
||||
/** ExpressionTypeMappingGenerator */
|
||||
public class ExpressionTypeMappingGenerator {
|
||||
private final JavaAstAnalyzer analyzer;
|
||||
|
||||
public ExpressionTypeMappingGenerator(JavaAstAnalyzer javaAstAnalyzer) {
|
||||
this.analyzer = javaAstAnalyzer;
|
||||
}
|
||||
|
||||
public JavaAstAnalyzer getAnalyzer() {
|
||||
return analyzer;
|
||||
}
|
||||
|
||||
/** generate */
|
||||
public void generate(ProcessingEnvironment processingEnv) throws IOException {
|
||||
Set<String> superExpressions = findSuperExpression();
|
||||
Map<String, Set<String>> childrenNameMap = analyzer.getChildrenNameMap();
|
||||
Map<String, Set<String>> parentNameMap = analyzer.getParentNameMap();
|
||||
String code = generateCode(childrenNameMap, parentNameMap, superExpressions);
|
||||
generateFile(processingEnv, code);
|
||||
}
|
||||
|
||||
private void generateFile(ProcessingEnvironment processingEnv, String code) throws IOException {
|
||||
File generatePatternFile = new File(processingEnv.getFiler()
|
||||
.getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern",
|
||||
"GeneratedExpressionRelations.java").toUri());
|
||||
if (generatePatternFile.exists()) {
|
||||
generatePatternFile.delete();
|
||||
}
|
||||
if (!generatePatternFile.getParentFile().exists()) {
|
||||
generatePatternFile.getParentFile().mkdirs();
|
||||
}
|
||||
|
||||
// bypass create file for processingEnv.getFiler(), compile GeneratePatterns in next compile term
|
||||
try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) {
|
||||
bufferedWriter.write(code);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<String> findSuperExpression() {
|
||||
Map<String, Set<String>> parentNameMap = analyzer.getParentNameMap();
|
||||
Map<String, Set<String>> childrenNameMap = analyzer.getChildrenNameMap();
|
||||
Set<String> superExpressions = Sets.newLinkedHashSet();
|
||||
for (Entry<String, Set<String>> entry : childrenNameMap.entrySet()) {
|
||||
String parentName = entry.getKey();
|
||||
Set<String> childrenNames = entry.getValue();
|
||||
|
||||
if (parentName.startsWith("org.apache.doris.nereids.trees.expressions.")) {
|
||||
for (String childrenName : childrenNames) {
|
||||
Set<String> parentNames = parentNameMap.get(childrenName);
|
||||
if (parentNames != null
|
||||
&& parentNames.contains("org.apache.doris.nereids.trees.expressions.Expression")) {
|
||||
superExpressions.add(parentName);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return superExpressions;
|
||||
}
|
||||
|
||||
private String generateCode(Map<String, Set<String>> childrenNameMap,
|
||||
Map<String, Set<String>> parentNameMap, Set<String> superExpressions) {
|
||||
String generateCode
|
||||
= "// Licensed to the Apache Software Foundation (ASF) under one\n"
|
||||
+ "// or more contributor license agreements. See the NOTICE file\n"
|
||||
+ "// distributed with this work for additional information\n"
|
||||
+ "// regarding copyright ownership. The ASF licenses this file\n"
|
||||
+ "// to you under the Apache License, Version 2.0 (the\n"
|
||||
+ "// \"License\"); you may not use this file except in compliance\n"
|
||||
+ "// with the License. You may obtain a copy of the License at\n"
|
||||
+ "//\n"
|
||||
+ "// http://www.apache.org/licenses/LICENSE-2.0\n"
|
||||
+ "//\n"
|
||||
+ "// Unless required by applicable law or agreed to in writing,\n"
|
||||
+ "// software distributed under the License is distributed on an\n"
|
||||
+ "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n"
|
||||
+ "// KIND, either express or implied. See the License for the\n"
|
||||
+ "// specific language governing permissions and limitations\n"
|
||||
+ "// under the License.\n"
|
||||
+ "\n"
|
||||
+ "package org.apache.doris.nereids.pattern;\n"
|
||||
+ "\n"
|
||||
+ "import org.apache.doris.nereids.trees.expressions.Expression;\n"
|
||||
+ "\n"
|
||||
+ "import com.google.common.collect.ImmutableMap;\n"
|
||||
+ "import com.google.common.collect.ImmutableSet;\n"
|
||||
+ "\n"
|
||||
+ "import java.util.Map;\n"
|
||||
+ "import java.util.Set;\n"
|
||||
+ "\n";
|
||||
generateCode += "/** GeneratedExpressionRelations */\npublic class GeneratedExpressionRelations {\n";
|
||||
String childrenClassesGenericType = "<Class<?>, Set<Class<? extends Expression>>>";
|
||||
generateCode +=
|
||||
" public static final Map" + childrenClassesGenericType + " CHILDREN_CLASS_MAP;\n\n";
|
||||
generateCode +=
|
||||
" static {\n"
|
||||
+ " ImmutableMap.Builder" + childrenClassesGenericType + " childrenClassesBuilder\n"
|
||||
+ " = ImmutableMap.builderWithExpectedSize(" + childrenNameMap.size() + ");\n";
|
||||
|
||||
for (String superExpression : superExpressions) {
|
||||
Set<String> childrenClasseSet = childrenNameMap.get(superExpression)
|
||||
.stream()
|
||||
.filter(childClass -> parentNameMap.get(childClass)
|
||||
.contains("org.apache.doris.nereids.trees.expressions.Expression")
|
||||
)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
List<String> childrenClasses = Lists.newArrayList(childrenClasseSet);
|
||||
Collections.sort(childrenClasses, Comparator.naturalOrder());
|
||||
|
||||
String childClassesString = childrenClasses.stream()
|
||||
.map(childClass -> " " + childClass + ".class")
|
||||
.collect(Collectors.joining(",\n"));
|
||||
generateCode += " childrenClassesBuilder.put(\n " + superExpression
|
||||
+ ".class,\n ImmutableSet.<Class<? extends Expression>>of(\n" + childClassesString
|
||||
+ "\n )\n );\n\n";
|
||||
}
|
||||
|
||||
generateCode += " CHILDREN_CLASS_MAP = childrenClassesBuilder.build();\n";
|
||||
|
||||
return generateCode + " }\n}\n";
|
||||
}
|
||||
}
|
||||
@ -29,25 +29,24 @@ import org.apache.doris.nereids.pattern.generator.javaast.TypeType;
|
||||
|
||||
import com.google.common.base.Joiner;
|
||||
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.ArrayList;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* used to analyze plan class extends hierarchy and then generated pattern builder methods.
|
||||
*/
|
||||
public class PatternGeneratorAnalyzer {
|
||||
private final Map<String, TypeDeclaration> name2Ast = new LinkedHashMap<>();
|
||||
private final IdentityHashMap<TypeDeclaration, String> ast2Name = new IdentityHashMap<>();
|
||||
private final IdentityHashMap<TypeDeclaration, Map<String, String>> ast2Import = new IdentityHashMap<>();
|
||||
private final IdentityHashMap<TypeDeclaration, Set<String>> parentClassMap = new IdentityHashMap<>();
|
||||
/** JavaAstAnalyzer */
|
||||
public class JavaAstAnalyzer {
|
||||
protected final Map<String, TypeDeclaration> name2Ast = new LinkedHashMap<>();
|
||||
protected final IdentityHashMap<TypeDeclaration, String> ast2Name = new IdentityHashMap<>();
|
||||
protected final IdentityHashMap<TypeDeclaration, Map<String, String>> ast2Import = new IdentityHashMap<>();
|
||||
protected final IdentityHashMap<TypeDeclaration, Set<String>> parentClassMap = new IdentityHashMap<>();
|
||||
protected final Map<String, Set<String>> parentNameMap = new LinkedHashMap<>();
|
||||
protected final Map<String, Set<String>> childrenNameMap = new LinkedHashMap<>();
|
||||
|
||||
/** add java AST. */
|
||||
public void addAsts(List<TypeDeclaration> typeDeclarations) {
|
||||
@ -56,14 +55,20 @@ public class PatternGeneratorAnalyzer {
|
||||
}
|
||||
}
|
||||
|
||||
/** generate pattern methods. */
|
||||
public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) {
|
||||
analyzeImport();
|
||||
analyzeParentClass();
|
||||
return doGenerate(className, parentClassName, isMemoPattern);
|
||||
public IdentityHashMap<TypeDeclaration, Set<String>> getParentClassMap() {
|
||||
return parentClassMap;
|
||||
}
|
||||
|
||||
Optional<TypeDeclaration> getType(TypeDeclaration typeDeclaration, TypeType type) {
|
||||
public Map<String, Set<String>> getParentNameMap() {
|
||||
return parentNameMap;
|
||||
}
|
||||
|
||||
public Map<String, Set<String>> getChildrenNameMap() {
|
||||
return childrenNameMap;
|
||||
}
|
||||
|
||||
/** getType */
|
||||
public Optional<TypeDeclaration> getType(TypeDeclaration typeDeclaration, TypeType type) {
|
||||
String typeName = analyzeClass(new LinkedHashSet<>(), typeDeclaration, type);
|
||||
if (typeName != null) {
|
||||
TypeDeclaration ast = name2Ast.get(typeName);
|
||||
@ -73,34 +78,11 @@ public class PatternGeneratorAnalyzer {
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private String doGenerate(String className, String parentClassName, boolean isMemoPattern) {
|
||||
Map<ClassDeclaration, Set<String>> planClassMap = parentClassMap.entrySet().stream()
|
||||
.filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan"))
|
||||
.filter(kv -> !kv.getKey().name.equals("GroupPlan"))
|
||||
.filter(kv -> !Modifier.isAbstract(kv.getKey().modifiers.mod)
|
||||
&& kv.getKey() instanceof ClassDeclaration)
|
||||
.collect(Collectors.toMap(kv -> (ClassDeclaration) kv.getKey(), kv -> kv.getValue()));
|
||||
|
||||
List<PatternGenerator> generators = planClassMap.entrySet()
|
||||
.stream()
|
||||
.map(kv -> PatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern))
|
||||
.filter(Optional::isPresent)
|
||||
.map(Optional::get)
|
||||
.sorted((g1, g2) -> {
|
||||
// logical first
|
||||
if (g1.isLogical() != g2.isLogical()) {
|
||||
return g1.isLogical() ? -1 : 1;
|
||||
}
|
||||
// leaf first
|
||||
if (g1.childrenNum() != g2.childrenNum()) {
|
||||
return g1.childrenNum() - g2.childrenNum();
|
||||
}
|
||||
// string dict sort
|
||||
return g1.opType.name.compareTo(g2.opType.name);
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return PatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern);
|
||||
protected void analyze() {
|
||||
analyzeImport();
|
||||
analyzeParentClass();
|
||||
analyzeParentName();
|
||||
analyzeChildrenName();
|
||||
}
|
||||
|
||||
private void analyzeImport() {
|
||||
@ -148,7 +130,28 @@ public class PatternGeneratorAnalyzer {
|
||||
parentClasses.addAll(currentParentClasses);
|
||||
}
|
||||
|
||||
String analyzeClass(Set<String> parentClasses, TypeDeclaration typeDeclaration, TypeType type) {
|
||||
private void analyzeParentName() {
|
||||
for (Entry<TypeDeclaration, Set<String>> entry : parentClassMap.entrySet()) {
|
||||
String parentName = entry.getKey().getFullQualifiedName();
|
||||
parentNameMap.put(parentName, entry.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
private void analyzeChildrenName() {
|
||||
for (Entry<String, TypeDeclaration> entry : name2Ast.entrySet()) {
|
||||
Set<String> parentNames = parentClassMap.get(entry.getValue());
|
||||
for (String parentName : parentNames) {
|
||||
Set<String> childrenNames = childrenNameMap.get(parentName);
|
||||
if (childrenNames == null) {
|
||||
childrenNames = new LinkedHashSet<>();
|
||||
childrenNameMap.put(parentName, childrenNames);
|
||||
}
|
||||
childrenNames.add(entry.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private String analyzeClass(Set<String> parentClasses, TypeDeclaration typeDeclaration, TypeType type) {
|
||||
if (type.classOrInterfaceType.isPresent()) {
|
||||
List<String> identifiers = new ArrayList<>();
|
||||
ClassOrInterfaceType classOrInterfaceType = type.classOrInterfaceType.get();
|
||||
@ -23,9 +23,9 @@ import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
/** used to generate pattern for LogicalBinary. */
|
||||
public class LogicalBinaryPatternGenerator extends PatternGenerator {
|
||||
public class LogicalBinaryPatternGenerator extends PlanPatternGenerator {
|
||||
|
||||
public LogicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
|
||||
public LogicalBinaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
super(analyzer, opType, parentClass, isMemoPattern);
|
||||
}
|
||||
|
||||
@ -23,9 +23,9 @@ import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
/** used to generate pattern for LogicalLeaf. */
|
||||
public class LogicalLeafPatternGenerator extends PatternGenerator {
|
||||
public class LogicalLeafPatternGenerator extends PlanPatternGenerator {
|
||||
|
||||
public LogicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer,
|
||||
public LogicalLeafPatternGenerator(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
super(analyzer, opType, parentClass, isMemoPattern);
|
||||
}
|
||||
|
||||
@ -23,9 +23,9 @@ import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
/** used to generate pattern for LogicalUnary. */
|
||||
public class LogicalUnaryPatternGenerator extends PatternGenerator {
|
||||
public class LogicalUnaryPatternGenerator extends PlanPatternGenerator {
|
||||
|
||||
public LogicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
|
||||
public LogicalUnaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
super(analyzer, opType, parentClass, isMemoPattern);
|
||||
}
|
||||
|
||||
@ -60,12 +60,12 @@ import javax.tools.StandardLocation;
|
||||
@SupportedSourceVersion(SourceVersion.RELEASE_8)
|
||||
@SupportedAnnotationTypes("org.apache.doris.nereids.pattern.generator.PatternDescribable")
|
||||
public class PatternDescribableProcessor extends AbstractProcessor {
|
||||
private List<File> planPaths;
|
||||
private List<File> paths;
|
||||
|
||||
@Override
|
||||
public synchronized void init(ProcessingEnvironment processingEnv) {
|
||||
super.init(processingEnv);
|
||||
this.planPaths = Arrays.stream(processingEnv.getOptions().get("planPath").split(","))
|
||||
this.paths = Arrays.stream(processingEnv.getOptions().get("path").split(","))
|
||||
.map(path -> path.trim())
|
||||
.filter(path -> !path.isEmpty())
|
||||
.collect(Collectors.toSet())
|
||||
@ -80,15 +80,25 @@ public class PatternDescribableProcessor extends AbstractProcessor {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
List<File> planFiles = findJavaFiles(planPaths);
|
||||
PatternGeneratorAnalyzer patternGeneratorAnalyzer = new PatternGeneratorAnalyzer();
|
||||
for (File file : planFiles) {
|
||||
List<File> javaFiles = findJavaFiles(paths);
|
||||
JavaAstAnalyzer javaAstAnalyzer = new JavaAstAnalyzer();
|
||||
for (File file : javaFiles) {
|
||||
List<TypeDeclaration> asts = parseJavaFile(file);
|
||||
patternGeneratorAnalyzer.addAsts(asts);
|
||||
javaAstAnalyzer.addAsts(asts);
|
||||
}
|
||||
|
||||
doGenerate("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer);
|
||||
doGenerate("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer);
|
||||
javaAstAnalyzer.analyze();
|
||||
|
||||
ExpressionTypeMappingGenerator expressionTypeMappingGenerator
|
||||
= new ExpressionTypeMappingGenerator(javaAstAnalyzer);
|
||||
expressionTypeMappingGenerator.generate(processingEnv);
|
||||
|
||||
PlanTypeMappingGenerator planTypeMappingGenerator = new PlanTypeMappingGenerator(javaAstAnalyzer);
|
||||
planTypeMappingGenerator.generate(processingEnv);
|
||||
|
||||
PlanPatternGeneratorAnalyzer patternGeneratorAnalyzer = new PlanPatternGeneratorAnalyzer(javaAstAnalyzer);
|
||||
generatePlanPatterns("GeneratedMemoPatterns", "MemoPatterns", true, patternGeneratorAnalyzer);
|
||||
generatePlanPatterns("GeneratedPlanPatterns", "PlanPatterns", false, patternGeneratorAnalyzer);
|
||||
} catch (Throwable t) {
|
||||
String exceptionMsg = Throwables.getStackTraceAsString(t);
|
||||
processingEnv.getMessager().printMessage(Kind.ERROR,
|
||||
@ -97,8 +107,12 @@ public class PatternDescribableProcessor extends AbstractProcessor {
|
||||
return false;
|
||||
}
|
||||
|
||||
private void doGenerate(String className, String parentClassName, boolean isMemoPattern,
|
||||
PatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException {
|
||||
private void generateExpressionTypeMapping() {
|
||||
|
||||
}
|
||||
|
||||
private void generatePlanPatterns(String className, String parentClassName, boolean isMemoPattern,
|
||||
PlanPatternGeneratorAnalyzer patternGeneratorAnalyzer) throws IOException {
|
||||
String generatePatternCode = patternGeneratorAnalyzer.generatePatterns(
|
||||
className, parentClassName, isMemoPattern);
|
||||
File generatePatternFile = new File(processingEnv.getFiler()
|
||||
|
||||
@ -23,9 +23,9 @@ import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
/** used to generate pattern for PhysicalBinary. */
|
||||
public class PhysicalBinaryPatternGenerator extends PatternGenerator {
|
||||
public class PhysicalBinaryPatternGenerator extends PlanPatternGenerator {
|
||||
|
||||
public PhysicalBinaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
|
||||
public PhysicalBinaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
super(analyzer, opType, parentClass, isMemoPattern);
|
||||
}
|
||||
|
||||
@ -23,9 +23,9 @@ import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
/** used to generate pattern for PhysicalLeaf. */
|
||||
public class PhysicalLeafPatternGenerator extends PatternGenerator {
|
||||
public class PhysicalLeafPatternGenerator extends PlanPatternGenerator {
|
||||
|
||||
public PhysicalLeafPatternGenerator(PatternGeneratorAnalyzer analyzer,
|
||||
public PhysicalLeafPatternGenerator(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
super(analyzer, opType, parentClass, isMemoPattern);
|
||||
}
|
||||
|
||||
@ -23,9 +23,9 @@ import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
/** used to generate pattern for PhysicalUnary. */
|
||||
public class PhysicalUnaryPatternGenerator extends PatternGenerator {
|
||||
public class PhysicalUnaryPatternGenerator extends PlanPatternGenerator {
|
||||
|
||||
public PhysicalUnaryPatternGenerator(PatternGeneratorAnalyzer analyzer,
|
||||
public PhysicalUnaryPatternGenerator(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
super(analyzer, opType, parentClass, isMemoPattern);
|
||||
}
|
||||
|
||||
@ -43,8 +43,8 @@ import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** used to generate pattern by plan. */
|
||||
public abstract class PatternGenerator {
|
||||
protected final PatternGeneratorAnalyzer analyzer;
|
||||
public abstract class PlanPatternGenerator {
|
||||
protected final JavaAstAnalyzer analyzer;
|
||||
protected final ClassDeclaration opType;
|
||||
protected final Set<String> parentClass;
|
||||
protected final List<EnumFieldPatternInfo> enumFieldPatternInfos;
|
||||
@ -52,9 +52,9 @@ public abstract class PatternGenerator {
|
||||
protected final boolean isMemoPattern;
|
||||
|
||||
/** constructor. */
|
||||
public PatternGenerator(PatternGeneratorAnalyzer analyzer, ClassDeclaration opType,
|
||||
public PlanPatternGenerator(PlanPatternGeneratorAnalyzer analyzer, ClassDeclaration opType,
|
||||
Set<String> parentClass, boolean isMemoPattern) {
|
||||
this.analyzer = analyzer;
|
||||
this.analyzer = analyzer.getAnalyzer();
|
||||
this.opType = opType;
|
||||
this.parentClass = parentClass;
|
||||
this.enumFieldPatternInfos = getEnumFieldPatternInfos();
|
||||
@ -76,8 +76,8 @@ public abstract class PatternGenerator {
|
||||
}
|
||||
|
||||
/** generate code by generators and analyzer. */
|
||||
public static String generateCode(String className, String parentClassName, List<PatternGenerator> generators,
|
||||
PatternGeneratorAnalyzer analyzer, boolean isMemoPattern) {
|
||||
public static String generateCode(String className, String parentClassName, List<PlanPatternGenerator> generators,
|
||||
PlanPatternGeneratorAnalyzer analyzer, boolean isMemoPattern) {
|
||||
String generateCode
|
||||
= "// Licensed to the Apache Software Foundation (ASF) under one\n"
|
||||
+ "// or more contributor license agreements. See the NOTICE file\n"
|
||||
@ -206,7 +206,7 @@ public abstract class PatternGenerator {
|
||||
}
|
||||
|
||||
/** create generator by plan's type. */
|
||||
public static Optional<PatternGenerator> create(PatternGeneratorAnalyzer analyzer,
|
||||
public static Optional<PlanPatternGenerator> create(PlanPatternGeneratorAnalyzer analyzer,
|
||||
ClassDeclaration opType, Set<String> parentClass, boolean isMemoPattern) {
|
||||
if (parentClass.contains("org.apache.doris.nereids.trees.plans.logical.LogicalLeaf")) {
|
||||
return Optional.of(new LogicalLeafPatternGenerator(analyzer, opType, parentClass, isMemoPattern));
|
||||
@ -225,9 +225,9 @@ public abstract class PatternGenerator {
|
||||
}
|
||||
}
|
||||
|
||||
private static String generateImports(List<PatternGenerator> generators) {
|
||||
private static String generateImports(List<PlanPatternGenerator> generators) {
|
||||
Set<String> imports = new HashSet<>();
|
||||
for (PatternGenerator generator : generators) {
|
||||
for (PlanPatternGenerator generator : generators) {
|
||||
imports.addAll(generator.getImports());
|
||||
}
|
||||
List<String> sortedImports = new ArrayList<>(imports);
|
||||
@ -0,0 +1,73 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern.generator;
|
||||
|
||||
import org.apache.doris.nereids.pattern.generator.javaast.ClassDeclaration;
|
||||
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* used to analyze plan class extends hierarchy and then generated pattern builder methods.
|
||||
*/
|
||||
public class PlanPatternGeneratorAnalyzer {
|
||||
private final JavaAstAnalyzer analyzer;
|
||||
|
||||
public PlanPatternGeneratorAnalyzer(JavaAstAnalyzer analyzer) {
|
||||
this.analyzer = analyzer;
|
||||
}
|
||||
|
||||
public JavaAstAnalyzer getAnalyzer() {
|
||||
return analyzer;
|
||||
}
|
||||
|
||||
/** generate pattern methods. */
|
||||
public String generatePatterns(String className, String parentClassName, boolean isMemoPattern) {
|
||||
Map<ClassDeclaration, Set<String>> planClassMap = analyzer.getParentClassMap().entrySet().stream()
|
||||
.filter(kv -> kv.getValue().contains("org.apache.doris.nereids.trees.plans.Plan"))
|
||||
.filter(kv -> !kv.getKey().name.equals("GroupPlan"))
|
||||
.filter(kv -> !Modifier.isAbstract(kv.getKey().modifiers.mod)
|
||||
&& kv.getKey() instanceof ClassDeclaration)
|
||||
.collect(Collectors.toMap(kv -> (ClassDeclaration) kv.getKey(), kv -> kv.getValue()));
|
||||
|
||||
List<PlanPatternGenerator> generators = planClassMap.entrySet()
|
||||
.stream()
|
||||
.map(kv -> PlanPatternGenerator.create(this, kv.getKey(), kv.getValue(), isMemoPattern))
|
||||
.filter(Optional::isPresent)
|
||||
.map(Optional::get)
|
||||
.sorted((g1, g2) -> {
|
||||
// logical first
|
||||
if (g1.isLogical() != g2.isLogical()) {
|
||||
return g1.isLogical() ? -1 : 1;
|
||||
}
|
||||
// leaf first
|
||||
if (g1.childrenNum() != g2.childrenNum()) {
|
||||
return g1.childrenNum() - g2.childrenNum();
|
||||
}
|
||||
// string dict sort
|
||||
return g1.opType.name.compareTo(g2.opType.name);
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
|
||||
return PlanPatternGenerator.generateCode(className, parentClassName, generators, this, isMemoPattern);
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,159 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.pattern.generator;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.FileWriter;
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.processing.ProcessingEnvironment;
|
||||
import javax.tools.StandardLocation;
|
||||
|
||||
/** PlanTypeMappingGenerator */
|
||||
public class PlanTypeMappingGenerator {
|
||||
private final JavaAstAnalyzer analyzer;
|
||||
|
||||
public PlanTypeMappingGenerator(JavaAstAnalyzer javaAstAnalyzer) {
|
||||
this.analyzer = javaAstAnalyzer;
|
||||
}
|
||||
|
||||
public JavaAstAnalyzer getAnalyzer() {
|
||||
return analyzer;
|
||||
}
|
||||
|
||||
/** generate */
|
||||
public void generate(ProcessingEnvironment processingEnv) throws IOException {
|
||||
Set<String> superPlans = findSuperPlan();
|
||||
Map<String, Set<String>> childrenNameMap = analyzer.getChildrenNameMap();
|
||||
Map<String, Set<String>> parentNameMap = analyzer.getParentNameMap();
|
||||
String code = generateCode(childrenNameMap, parentNameMap, superPlans);
|
||||
generateFile(processingEnv, code);
|
||||
}
|
||||
|
||||
private void generateFile(ProcessingEnvironment processingEnv, String code) throws IOException {
|
||||
File generatePatternFile = new File(processingEnv.getFiler()
|
||||
.getResource(StandardLocation.SOURCE_OUTPUT, "org.apache.doris.nereids.pattern",
|
||||
"GeneratedPlanRelations.java").toUri());
|
||||
if (generatePatternFile.exists()) {
|
||||
generatePatternFile.delete();
|
||||
}
|
||||
if (!generatePatternFile.getParentFile().exists()) {
|
||||
generatePatternFile.getParentFile().mkdirs();
|
||||
}
|
||||
|
||||
// bypass create file for processingEnv.getFiler(), compile GeneratePatterns in next compile term
|
||||
try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(generatePatternFile))) {
|
||||
bufferedWriter.write(code);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<String> findSuperPlan() {
|
||||
Map<String, Set<String>> parentNameMap = analyzer.getParentNameMap();
|
||||
Map<String, Set<String>> childrenNameMap = analyzer.getChildrenNameMap();
|
||||
Set<String> superPlans = Sets.newLinkedHashSet();
|
||||
for (Entry<String, Set<String>> entry : childrenNameMap.entrySet()) {
|
||||
String parentName = entry.getKey();
|
||||
Set<String> childrenNames = entry.getValue();
|
||||
|
||||
if (parentName.startsWith("org.apache.doris.nereids.trees.plans.")) {
|
||||
for (String childrenName : childrenNames) {
|
||||
Set<String> parentNames = parentNameMap.get(childrenName);
|
||||
if (parentNames != null
|
||||
&& parentNames.contains("org.apache.doris.nereids.trees.plans.Plan")) {
|
||||
superPlans.add(parentName);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return superPlans;
|
||||
}
|
||||
|
||||
private String generateCode(Map<String, Set<String>> childrenNameMap,
|
||||
Map<String, Set<String>> parentNameMap, Set<String> superPlans) {
|
||||
String generateCode
|
||||
= "// Licensed to the Apache Software Foundation (ASF) under one\n"
|
||||
+ "// or more contributor license agreements. See the NOTICE file\n"
|
||||
+ "// distributed with this work for additional information\n"
|
||||
+ "// regarding copyright ownership. The ASF licenses this file\n"
|
||||
+ "// to you under the Apache License, Version 2.0 (the\n"
|
||||
+ "// \"License\"); you may not use this file except in compliance\n"
|
||||
+ "// with the License. You may obtain a copy of the License at\n"
|
||||
+ "//\n"
|
||||
+ "// http://www.apache.org/licenses/LICENSE-2.0\n"
|
||||
+ "//\n"
|
||||
+ "// Unless required by applicable law or agreed to in writing,\n"
|
||||
+ "// software distributed under the License is distributed on an\n"
|
||||
+ "// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n"
|
||||
+ "// KIND, either express or implied. See the License for the\n"
|
||||
+ "// specific language governing permissions and limitations\n"
|
||||
+ "// under the License.\n"
|
||||
+ "\n"
|
||||
+ "package org.apache.doris.nereids.pattern;\n"
|
||||
+ "\n"
|
||||
+ "import org.apache.doris.nereids.trees.plans.Plan;\n"
|
||||
+ "\n"
|
||||
+ "import com.google.common.collect.ImmutableMap;\n"
|
||||
+ "import com.google.common.collect.ImmutableSet;\n"
|
||||
+ "\n"
|
||||
+ "import java.util.Map;\n"
|
||||
+ "import java.util.Set;\n"
|
||||
+ "\n";
|
||||
generateCode += "/** GeneratedPlanRelations */\npublic class GeneratedPlanRelations {\n";
|
||||
String childrenClassesGenericType = "<Class<?>, Set<Class<? extends Plan>>>";
|
||||
generateCode +=
|
||||
" public static final Map" + childrenClassesGenericType + " CHILDREN_CLASS_MAP;\n\n";
|
||||
generateCode +=
|
||||
" static {\n"
|
||||
+ " ImmutableMap.Builder" + childrenClassesGenericType + " childrenClassesBuilder\n"
|
||||
+ " = ImmutableMap.builderWithExpectedSize(" + childrenNameMap.size() + ");\n";
|
||||
|
||||
for (String superPlan : superPlans) {
|
||||
Set<String> childrenClasseSet = childrenNameMap.get(superPlan)
|
||||
.stream()
|
||||
.filter(childClass -> parentNameMap.get(childClass)
|
||||
.contains("org.apache.doris.nereids.trees.plans.Plan")
|
||||
)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
List<String> childrenClasses = Lists.newArrayList(childrenClasseSet);
|
||||
Collections.sort(childrenClasses, Comparator.naturalOrder());
|
||||
|
||||
String childClassesString = childrenClasses.stream()
|
||||
.map(childClass -> " " + childClass + ".class")
|
||||
.collect(Collectors.joining(",\n"));
|
||||
generateCode += " childrenClassesBuilder.put(\n " + superPlan
|
||||
+ ".class,\n ImmutableSet.<Class<? extends Plan>>of(\n" + childClassesString
|
||||
+ "\n )\n );\n\n";
|
||||
}
|
||||
|
||||
generateCode += " CHILDREN_CLASS_MAP = childrenClassesBuilder.build();\n";
|
||||
|
||||
return generateCode + " }\n}\n";
|
||||
}
|
||||
}
|
||||
@ -195,9 +195,20 @@ public class RuntimeFilterPruner extends PlanPostProcessor {
|
||||
@Override
|
||||
public PhysicalFilter visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, CascadesContext context) {
|
||||
filter.child().accept(this, context);
|
||||
boolean visibleFilter = filter.getExpressions().stream()
|
||||
.flatMap(expression -> expression.getInputSlots().stream())
|
||||
.anyMatch(slot -> isVisibleColumn(slot));
|
||||
|
||||
boolean visibleFilter = false;
|
||||
|
||||
for (Expression expr : filter.getExpressions()) {
|
||||
for (Slot inputSlot : expr.getInputSlots()) {
|
||||
if (isVisibleColumn(inputSlot)) {
|
||||
visibleFilter = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (visibleFilter) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (visibleFilter) {
|
||||
// skip filters like: __DORIS_DELETE_SIGN__ = 0
|
||||
context.getRuntimeFilterContext().addEffectiveSrcNode(filter, RuntimeFilterContext.EffectiveSrcType.NATIVE);
|
||||
|
||||
@ -26,6 +26,8 @@ import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
@ -69,7 +71,10 @@ public class Validator extends PlanPostProcessor {
|
||||
|
||||
@Override
|
||||
public Plan visit(Plan plan, CascadesContext context) {
|
||||
plan.children().forEach(child -> child.accept(this, context));
|
||||
for (Plan child : plan.children()) {
|
||||
child.accept(this, context);
|
||||
}
|
||||
|
||||
Optional<Slot> opt = checkAllSlotFromChildren(plan);
|
||||
if (opt.isPresent()) {
|
||||
List<Slot> childrenOutput = plan.children().stream().flatMap(p -> p.getOutput().stream()).collect(
|
||||
@ -93,8 +98,7 @@ public class Validator extends PlanPostProcessor {
|
||||
if (plan instanceof Aggregate) {
|
||||
return Optional.empty();
|
||||
}
|
||||
Set<Slot> childOutputSet = plan.children().stream().flatMap(child -> child.getOutputSet().stream())
|
||||
.collect(Collectors.toSet());
|
||||
Set<Slot> childOutputSet = Utils.fastToImmutableSet(PlanUtils.fastGetChildrenOutputs(plan.children()));
|
||||
Set<Slot> inputSlots = plan.getInputSlots();
|
||||
for (Slot slot : inputSlots) {
|
||||
if (slot.getName().startsWith("mv") || slot instanceof SlotNotFromChildren) {
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.properties;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.Map;
|
||||
@ -196,12 +197,23 @@ public class FunctionalDependencies {
|
||||
}
|
||||
|
||||
public void removeNotContain(Set<Slot> slotSet) {
|
||||
slots = slots.stream()
|
||||
.filter(slotSet::contains)
|
||||
.collect(Collectors.toSet());
|
||||
slotSets = slotSets.stream()
|
||||
.filter(slotSet::containsAll)
|
||||
.collect(Collectors.toSet());
|
||||
if (!slotSet.isEmpty()) {
|
||||
Set<Slot> newSlots = Sets.newLinkedHashSetWithExpectedSize(slots.size());
|
||||
for (Slot slot : slots) {
|
||||
if (slotSet.contains(slot)) {
|
||||
newSlots.add(slot);
|
||||
}
|
||||
}
|
||||
this.slots = newSlots;
|
||||
|
||||
Set<ImmutableSet<Slot>> newSlotSets = Sets.newLinkedHashSetWithExpectedSize(slots.size());
|
||||
for (ImmutableSet<Slot> set : slotSets) {
|
||||
if (slotSet.containsAll(set)) {
|
||||
newSlotSets.add(set);
|
||||
}
|
||||
}
|
||||
this.slotSets = newSlotSets;
|
||||
}
|
||||
}
|
||||
|
||||
public void add(Slot slot) {
|
||||
|
||||
@ -19,7 +19,6 @@ package org.apache.doris.nereids.properties;
|
||||
|
||||
import org.apache.doris.common.Id;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
|
||||
import com.google.common.base.Supplier;
|
||||
@ -62,21 +61,40 @@ public class LogicalProperties {
|
||||
this.outputSupplier = Suppliers.memoize(
|
||||
Objects.requireNonNull(outputSupplier, "outputSupplier can not be null")
|
||||
);
|
||||
this.outputExprIdsSupplier = Suppliers.memoize(
|
||||
() -> this.outputSupplier.get().stream().map(NamedExpression::getExprId).map(Id.class::cast)
|
||||
.collect(ImmutableList.toImmutableList())
|
||||
);
|
||||
this.outputSetSupplier = Suppliers.memoize(
|
||||
() -> ImmutableSet.copyOf(this.outputSupplier.get())
|
||||
);
|
||||
this.outputMapSupplier = Suppliers.memoize(
|
||||
() -> this.outputSetSupplier.get().stream().collect(ImmutableMap.toImmutableMap(s -> s, s -> s))
|
||||
);
|
||||
this.outputExprIdSetSupplier = Suppliers.memoize(
|
||||
() -> this.outputSupplier.get().stream()
|
||||
.map(NamedExpression::getExprId)
|
||||
.collect(ImmutableSet.toImmutableSet())
|
||||
);
|
||||
this.outputExprIdsSupplier = Suppliers.memoize(() -> {
|
||||
List<Slot> output = this.outputSupplier.get();
|
||||
ImmutableList.Builder<Id> exprIdSet
|
||||
= ImmutableList.builderWithExpectedSize(output.size());
|
||||
for (Slot slot : output) {
|
||||
exprIdSet.add(slot.getExprId());
|
||||
}
|
||||
return exprIdSet.build();
|
||||
});
|
||||
this.outputSetSupplier = Suppliers.memoize(() -> {
|
||||
List<Slot> output = outputSupplier.get();
|
||||
ImmutableSet.Builder<Slot> slots = ImmutableSet.builderWithExpectedSize(output.size());
|
||||
for (Slot slot : output) {
|
||||
slots.add(slot);
|
||||
}
|
||||
return slots.build();
|
||||
});
|
||||
this.outputMapSupplier = Suppliers.memoize(() -> {
|
||||
Set<Slot> slots = outputSetSupplier.get();
|
||||
ImmutableMap.Builder<Slot, Slot> map = ImmutableMap.builderWithExpectedSize(slots.size());
|
||||
for (Slot slot : slots) {
|
||||
map.put(slot, slot);
|
||||
}
|
||||
return map.build();
|
||||
});
|
||||
this.outputExprIdSetSupplier = Suppliers.memoize(() -> {
|
||||
List<Slot> output = this.outputSupplier.get();
|
||||
ImmutableSet.Builder<ExprId> exprIdSet
|
||||
= ImmutableSet.builderWithExpectedSize(output.size());
|
||||
for (Slot slot : output) {
|
||||
exprIdSet.add(slot.getExprId());
|
||||
}
|
||||
return exprIdSet.build();
|
||||
});
|
||||
this.fdSupplier = Suppliers.memoize(
|
||||
Objects.requireNonNull(fdSupplier, "FunctionalDependencies can not be null")
|
||||
);
|
||||
|
||||
@ -24,8 +24,8 @@ import org.apache.doris.nereids.pattern.Pattern;
|
||||
import org.apache.doris.nereids.rules.RuleType.RuleTypeClass;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
|
||||
import java.util.BitSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* Abstract class for all rules.
|
||||
@ -79,8 +79,8 @@ public abstract class Rule {
|
||||
/**
|
||||
* Filter out already applied rules and rules that are not matched on root node.
|
||||
*/
|
||||
public boolean isInvalid(Set<Integer> disableRules, GroupExpression groupExpression) {
|
||||
return disableRules.contains(this.getRuleType().type())
|
||||
public boolean isInvalid(BitSet disableRules, GroupExpression groupExpression) {
|
||||
return disableRules.get(this.getRuleType().type())
|
||||
|| !groupExpression.notApplied(this)
|
||||
|| !this.getPattern().matchRoot(groupExpression.getPlan());
|
||||
}
|
||||
|
||||
@ -49,6 +49,7 @@ import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectAggr
|
||||
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterAggregateRule;
|
||||
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectFilterJoinRule;
|
||||
import org.apache.doris.nereids.rules.exploration.mv.MaterializedViewProjectJoinRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
|
||||
import org.apache.doris.nereids.rules.implementation.AggregateStrategies;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalCTEAnchorToPhysicalCTEAnchor;
|
||||
@ -153,7 +154,8 @@ public class RuleSet {
|
||||
new MergeLimits(),
|
||||
new PushDownAliasThroughJoin(),
|
||||
new PushDownFilterThroughWindow(),
|
||||
new PushDownFilterThroughPartitionTopN()
|
||||
new PushDownFilterThroughPartitionTopN(),
|
||||
new ExpressionOptimization()
|
||||
);
|
||||
|
||||
public static final List<Rule> IMPLEMENTATION_RULES = planRuleFactories()
|
||||
|
||||
@ -46,21 +46,30 @@ public class AdjustAggregateNullableForEmptySet implements RewriteRuleFactory {
|
||||
RuleType.ADJUST_NULLABLE_FOR_AGGREGATE_SLOT.build(
|
||||
logicalAggregate()
|
||||
.then(agg -> {
|
||||
List<NamedExpression> output = agg.getOutputExpressions().stream()
|
||||
.map(ne -> ((NamedExpression) FunctionReplacer.INSTANCE.replace(ne,
|
||||
agg.getGroupByExpressions().isEmpty())))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return agg.withAggOutput(output);
|
||||
List<NamedExpression> outputExprs = agg.getOutputExpressions();
|
||||
boolean noGroupBy = agg.getGroupByExpressions().isEmpty();
|
||||
ImmutableList.Builder<NamedExpression> newOutput
|
||||
= ImmutableList.builderWithExpectedSize(outputExprs.size());
|
||||
for (NamedExpression ne : outputExprs) {
|
||||
NamedExpression newExpr =
|
||||
((NamedExpression) FunctionReplacer.INSTANCE.replace(ne, noGroupBy));
|
||||
newOutput.add(newExpr);
|
||||
}
|
||||
return agg.withAggOutput(newOutput.build());
|
||||
})
|
||||
),
|
||||
RuleType.ADJUST_NULLABLE_FOR_HAVING_SLOT.build(
|
||||
logicalHaving(logicalAggregate())
|
||||
.then(having -> {
|
||||
Set<Expression> newConjuncts = having.getConjuncts().stream()
|
||||
.map(ne -> FunctionReplacer.INSTANCE.replace(ne,
|
||||
having.child().getGroupByExpressions().isEmpty()))
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
return new LogicalHaving<>(newConjuncts, having.child());
|
||||
Set<Expression> conjuncts = having.getConjuncts();
|
||||
boolean noGroupBy = having.child().getGroupByExpressions().isEmpty();
|
||||
ImmutableSet.Builder<Expression> newConjuncts
|
||||
= ImmutableSet.builderWithExpectedSize(conjuncts.size());
|
||||
for (Expression expr : conjuncts) {
|
||||
Expression newExpr = FunctionReplacer.INSTANCE.replace(expr, noGroupBy);
|
||||
newConjuncts.add(newExpr);
|
||||
}
|
||||
return new LogicalHaving<>(newConjuncts.build(), having.child());
|
||||
})
|
||||
)
|
||||
);
|
||||
|
||||
@ -333,10 +333,11 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
List<LogicalPlan> relations
|
||||
= Lists.newArrayListWithCapacity(logicalInlineTable.getConstantExprsList().size());
|
||||
for (int i = 0; i < logicalInlineTable.getConstantExprsList().size(); i++) {
|
||||
if (logicalInlineTable.getConstantExprsList().get(i).stream()
|
||||
.anyMatch(DefaultValueSlot.class::isInstance)) {
|
||||
throw new AnalysisException("Default expression"
|
||||
+ " can't exist in SELECT statement at row " + (i + 1));
|
||||
for (NamedExpression constantExpr : logicalInlineTable.getConstantExprsList().get(i)) {
|
||||
if (constantExpr instanceof DefaultValueSlot) {
|
||||
throw new AnalysisException("Default expression"
|
||||
+ " can't exist in SELECT statement at row " + (i + 1));
|
||||
}
|
||||
}
|
||||
relations.add(new UnboundOneRowRelation(StatementScopeIdGenerator.newRelationId(),
|
||||
logicalInlineTable.getConstantExprsList().get(i)));
|
||||
@ -590,7 +591,7 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(
|
||||
filter, cascadesContext, filter.children(), true, true);
|
||||
ImmutableSet.Builder<Expression> boundConjuncts = ImmutableSet.builderWithExpectedSize(
|
||||
filter.getConjuncts().size() * 2);
|
||||
filter.getConjuncts().size());
|
||||
for (Expression conjunct : filter.getConjuncts()) {
|
||||
Expression boundConjunct = analyzer.analyze(conjunct);
|
||||
boundConjunct = TypeCoercionUtils.castIfNotSameType(boundConjunct, BooleanType.INSTANCE);
|
||||
@ -828,15 +829,22 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
if (output.stream().noneMatch(Alias.class::isInstance)) {
|
||||
return;
|
||||
}
|
||||
List<Alias> aliasList = output.stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).collect(Collectors.toList());
|
||||
List<Alias> aliasList = ExpressionUtils.filter(output, Alias.class);
|
||||
|
||||
List<NamedExpression> exprAliasList =
|
||||
ExpressionUtils.collectAll(expressions, NamedExpression.class::isInstance);
|
||||
|
||||
boolean isGroupByContainAlias = exprAliasList.stream().anyMatch(ne ->
|
||||
aliasList.stream().anyMatch(alias -> !alias.getExprId().equals(ne.getExprId())
|
||||
&& alias.getName().equals(ne.getName())));
|
||||
boolean isGroupByContainAlias = false;
|
||||
for (NamedExpression ne : exprAliasList) {
|
||||
for (Alias alias : aliasList) {
|
||||
if (!alias.getExprId().equals(ne.getExprId()) && alias.getName().equalsIgnoreCase(ne.getName())) {
|
||||
isGroupByContainAlias = true;
|
||||
}
|
||||
}
|
||||
if (isGroupByContainAlias) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (isGroupByContainAlias
|
||||
&& ConnectContext.get() != null
|
||||
|
||||
@ -22,7 +22,6 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
@ -34,7 +33,6 @@ import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Rule to bind slot with path in query plan.
|
||||
@ -60,21 +58,18 @@ public class BindSlotWithPaths implements AnalysisRuleFactory {
|
||||
Set<SlotReference> pathsSlots = ctx.statementContext.getAllPathsSlots();
|
||||
// With new logical properties that contains new slots with paths
|
||||
StatementContext stmtCtx = ConnectContext.get().getStatementContext();
|
||||
List<Slot> olapScanPathSlots = pathsSlots.stream().filter(
|
||||
slot -> {
|
||||
Preconditions.checkNotNull(stmtCtx.getRelationBySlot(slot),
|
||||
"[Not implemented] Slot not found in relation map, slot ", slot);
|
||||
return stmtCtx.getRelationBySlot(slot).getRelationId()
|
||||
== logicalOlapScan.getRelationId();
|
||||
}).collect(
|
||||
Collectors.toList());
|
||||
List<NamedExpression> newExprs = olapScanPathSlots.stream()
|
||||
.map(SlotReference.class::cast)
|
||||
.map(slotReference ->
|
||||
new Alias(slotReference.getExprId(),
|
||||
stmtCtx.getOriginalExpr(slotReference), slotReference.getName()))
|
||||
.collect(
|
||||
Collectors.toList());
|
||||
ImmutableList.Builder<NamedExpression> newExprsBuilder
|
||||
= ImmutableList.builderWithExpectedSize(pathsSlots.size());
|
||||
for (SlotReference slot : pathsSlots) {
|
||||
Preconditions.checkNotNull(stmtCtx.getRelationBySlot(slot),
|
||||
"[Not implemented] Slot not found in relation map, slot ", slot);
|
||||
if (stmtCtx.getRelationBySlot(slot).getRelationId()
|
||||
== logicalOlapScan.getRelationId()) {
|
||||
newExprsBuilder.add(new Alias(slot.getExprId(),
|
||||
stmtCtx.getOriginalExpr(slot), slot.getName()));
|
||||
}
|
||||
}
|
||||
ImmutableList<NamedExpression> newExprs = newExprsBuilder.build();
|
||||
if (newExprs.isEmpty()) {
|
||||
return ctx.root;
|
||||
}
|
||||
|
||||
@ -46,6 +46,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
@ -69,42 +70,43 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory {
|
||||
}
|
||||
|
||||
private void checkUnexpectedExpression(Plan plan) {
|
||||
if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(SubqueryExpr.class::isInstance))) {
|
||||
throw new AnalysisException("Subquery is not allowed in " + plan.getType());
|
||||
}
|
||||
if (!(plan instanceof Generate)) {
|
||||
if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(TableGeneratingFunction.class::isInstance))) {
|
||||
throw new AnalysisException("table generating function is not allowed in " + plan.getType());
|
||||
}
|
||||
}
|
||||
if (!(plan instanceof LogicalAggregate || plan instanceof LogicalWindow)) {
|
||||
if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(AggregateFunction.class::isInstance))) {
|
||||
throw new AnalysisException("aggregate function is not allowed in " + plan.getType());
|
||||
}
|
||||
}
|
||||
if (!(plan instanceof LogicalAggregate)) {
|
||||
if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(GroupingScalarFunction.class::isInstance))) {
|
||||
throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType());
|
||||
}
|
||||
}
|
||||
if (!(plan instanceof LogicalWindow)) {
|
||||
if (plan.getExpressions().stream().anyMatch(e -> e.anyMatch(WindowExpression.class::isInstance))) {
|
||||
throw new AnalysisException("analytic function is not allowed in " + plan.getType());
|
||||
}
|
||||
boolean isGenerate = plan instanceof Generate;
|
||||
boolean isAgg = plan instanceof LogicalAggregate;
|
||||
boolean isWindow = plan instanceof LogicalWindow;
|
||||
boolean notAggAndWindow = !isAgg && !isWindow;
|
||||
|
||||
for (Expression expression : plan.getExpressions()) {
|
||||
expression.foreach(expr -> {
|
||||
if (expr instanceof SubqueryExpr) {
|
||||
throw new AnalysisException("Subquery is not allowed in " + plan.getType());
|
||||
} else if (!isGenerate && expr instanceof TableGeneratingFunction) {
|
||||
throw new AnalysisException("table generating function is not allowed in " + plan.getType());
|
||||
} else if (notAggAndWindow && expr instanceof AggregateFunction) {
|
||||
throw new AnalysisException("aggregate function is not allowed in " + plan.getType());
|
||||
} else if (!isAgg && expr instanceof GroupingScalarFunction) {
|
||||
throw new AnalysisException("grouping scalar function is not allowed in " + plan.getType());
|
||||
} else if (!isWindow && expr instanceof WindowExpression) {
|
||||
throw new AnalysisException("analytic function is not allowed in " + plan.getType());
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private void checkAllSlotReferenceFromChildren(Plan plan) {
|
||||
Set<Slot> notFromChildren = plan.getExpressions().stream()
|
||||
.flatMap(expr -> expr.getInputSlots().stream())
|
||||
.collect(Collectors.toSet());
|
||||
Set<ExprId> childrenOutput = plan.children().stream()
|
||||
.flatMap(child -> child.getOutput().stream())
|
||||
.map(NamedExpression::getExprId)
|
||||
.collect(Collectors.toSet());
|
||||
notFromChildren = notFromChildren.stream()
|
||||
.filter(s -> !childrenOutput.contains(s.getExprId()))
|
||||
.collect(Collectors.toSet());
|
||||
Set<Slot> inputSlots = plan.getInputSlots();
|
||||
Set<ExprId> childrenOutput = plan.getChildrenOutputExprIdSet();
|
||||
|
||||
ImmutableSet.Builder<Slot> notFromChildrenBuilder = ImmutableSet.builderWithExpectedSize(inputSlots.size());
|
||||
for (Slot inputSlot : inputSlots) {
|
||||
if (!childrenOutput.contains(inputSlot.getExprId())) {
|
||||
notFromChildrenBuilder.add(inputSlot);
|
||||
}
|
||||
}
|
||||
Set<Slot> notFromChildren = notFromChildrenBuilder.build();
|
||||
if (notFromChildren.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, childrenOutput);
|
||||
if (!notFromChildren.isEmpty()) {
|
||||
if (plan.arity() != 0 && plan.child(0) instanceof LogicalAggregate) {
|
||||
@ -181,17 +183,18 @@ public class CheckAfterRewrite extends OneAnalysisRuleFactory {
|
||||
}
|
||||
|
||||
private void checkMatchIsUsedCorrectly(Plan plan) {
|
||||
if (plan.getExpressions().stream().anyMatch(
|
||||
expression -> expression instanceof Match)) {
|
||||
if (plan instanceof LogicalFilter && (plan.child(0) instanceof LogicalOlapScan
|
||||
|| plan.child(0) instanceof LogicalDeferMaterializeOlapScan
|
||||
|| plan.child(0) instanceof LogicalProject
|
||||
for (Expression expression : plan.getExpressions()) {
|
||||
if (expression instanceof Match) {
|
||||
if (plan instanceof LogicalFilter && (plan.child(0) instanceof LogicalOlapScan
|
||||
|| plan.child(0) instanceof LogicalDeferMaterializeOlapScan
|
||||
|| plan.child(0) instanceof LogicalProject
|
||||
&& ((LogicalProject<?>) plan.child(0)).hasPushedDownToProjectionFunctions())) {
|
||||
return;
|
||||
} else {
|
||||
throw new AnalysisException(String.format(
|
||||
"Not support match in %s in plan: %s, only support in olapScan filter",
|
||||
plan.child(0), plan));
|
||||
return;
|
||||
} else {
|
||||
throw new AnalysisException(String.format(
|
||||
"Not support match in %s in plan: %s, only support in olapScan filter",
|
||||
plan.child(0), plan));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,7 +45,6 @@ import com.google.common.collect.ImmutableSet;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
@ -117,14 +116,16 @@ public class CheckAnalysis implements AnalysisRuleFactory {
|
||||
if (unexpectedExpressionTypes.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
plan.getExpressions().forEach(c -> c.foreachUp(e -> {
|
||||
for (Class<? extends Expression> type : unexpectedExpressionTypes) {
|
||||
if (type.isInstance(e)) {
|
||||
throw new AnalysisException(plan.getType() + " can not contains "
|
||||
+ type.getSimpleName() + " expression: " + ((Expression) e).toSql());
|
||||
for (Expression expr : plan.getExpressions()) {
|
||||
expr.foreachUp(e -> {
|
||||
for (Class<? extends Expression> type : unexpectedExpressionTypes) {
|
||||
if (type.isInstance(e)) {
|
||||
throw new AnalysisException(plan.getType() + " can not contains "
|
||||
+ type.getSimpleName() + " expression: " + ((Expression) e).toSql());
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private void checkExpressionInputTypes(Plan plan) {
|
||||
@ -157,20 +158,21 @@ public class CheckAnalysis implements AnalysisRuleFactory {
|
||||
break;
|
||||
}
|
||||
}
|
||||
long distinctFunctionNum = aggregateFunctions.stream()
|
||||
.filter(AggregateFunction::isDistinct)
|
||||
.count();
|
||||
|
||||
long distinctFunctionNum = 0;
|
||||
for (AggregateFunction aggregateFunction : aggregateFunctions) {
|
||||
distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0;
|
||||
}
|
||||
|
||||
if (distinctMultiColumns && distinctFunctionNum > 1) {
|
||||
throw new AnalysisException(
|
||||
"The query contains multi count distinct or sum distinct, each can't have multi columns");
|
||||
}
|
||||
Optional<Expression> expr = aggregate.getGroupByExpressions().stream()
|
||||
.filter(expression -> expression.containsType(AggregateFunction.class)).findFirst();
|
||||
if (expr.isPresent()) {
|
||||
throw new AnalysisException(
|
||||
"GROUP BY expression must not contain aggregate functions: "
|
||||
+ expr.get().toSql());
|
||||
for (Expression expr : aggregate.getGroupByExpressions()) {
|
||||
if (expr.anyMatch(AggregateFunction.class::isInstance)) {
|
||||
throw new AnalysisException(
|
||||
"GROUP BY expression must not contain aggregate functions: " + expr.toSql());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,7 +64,7 @@ public class EliminateGroupByConstant extends OneRewriteRuleFactory {
|
||||
// because we rely on expression matching to replace subtree that same as group by expr in output
|
||||
// if we do constant folding before normalize aggregate, the subtree will change and matching fail
|
||||
// such as: select a + 1 + 2 + 3, sum(b) from t group by a + 1 + 2
|
||||
Expression foldExpression = FoldConstantRule.INSTANCE.rewrite(expression, context);
|
||||
Expression foldExpression = FoldConstantRule.evaluate(expression, context);
|
||||
if (!foldExpression.isConstant()) {
|
||||
slotGroupByExprs.add(expression);
|
||||
} else {
|
||||
|
||||
@ -297,7 +297,7 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
|
||||
if (unboundFunction.isHighOrder()) {
|
||||
unboundFunction = bindHighOrderFunction(unboundFunction, context);
|
||||
} else {
|
||||
unboundFunction = (UnboundFunction) rewriteChildren(this, unboundFunction, context);
|
||||
unboundFunction = (UnboundFunction) super.visit(unboundFunction, context);
|
||||
}
|
||||
|
||||
// bind function
|
||||
|
||||
@ -316,13 +316,18 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
|
||||
}
|
||||
|
||||
private boolean checkSort(LogicalSort<? extends Plan> logicalSort) {
|
||||
return logicalSort.getOrderKeys().stream()
|
||||
.map(OrderKey::getExpr)
|
||||
.map(Expression::getInputSlots)
|
||||
.flatMap(Set::stream)
|
||||
.anyMatch(s -> !logicalSort.child().getOutputSet().contains(s))
|
||||
|| logicalSort.getOrderKeys().stream()
|
||||
.map(OrderKey::getExpr)
|
||||
.anyMatch(e -> e.containsType(AggregateFunction.class));
|
||||
Plan child = logicalSort.child();
|
||||
for (OrderKey orderKey : logicalSort.getOrderKeys()) {
|
||||
Expression expr = orderKey.getExpr();
|
||||
if (expr.anyMatch(AggregateFunction.class::isInstance)) {
|
||||
return true;
|
||||
}
|
||||
for (Slot inputSlot : expr.getInputSlots()) {
|
||||
if (!child.getOutputSet().contains(inputSlot)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
@ -139,8 +140,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
||||
|
||||
// Push down exprs:
|
||||
// collect group by exprs
|
||||
Set<Expression> groupingByExprs =
|
||||
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
|
||||
Set<Expression> groupingByExprs = Utils.fastToImmutableSet(aggregate.getGroupByExpressions());
|
||||
|
||||
// collect all trivial-agg
|
||||
List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
|
||||
@ -149,27 +149,31 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
||||
// split non-distinct agg child as two part
|
||||
// TRUE part 1: need push down itself, if it contains subquery or window expression
|
||||
// FALSE part 2: need push down its input slots, if it DOES NOT contain subquery or window expression
|
||||
Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
|
||||
Map<Boolean, ImmutableSet<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
|
||||
.filter(aggFunc -> !aggFunc.isDistinct())
|
||||
.flatMap(agg -> agg.children().stream())
|
||||
.collect(Collectors.groupingBy(
|
||||
child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
|
||||
Collectors.toSet()));
|
||||
ImmutableSet.toImmutableSet()));
|
||||
|
||||
// split distinct agg child as two parts
|
||||
// TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
|
||||
// FALSE part 2: need push down its input slots, if it is SlotReference or Literal
|
||||
Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
|
||||
Map<Object, ImmutableSet<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
|
||||
.filter(AggregateFunction::isDistinct)
|
||||
.flatMap(agg -> agg.children().stream())
|
||||
.collect(Collectors.groupingBy(child -> !(child instanceof SlotReference), Collectors.toSet()));
|
||||
.collect(
|
||||
Collectors.groupingBy(
|
||||
child -> !(child instanceof SlotReference),
|
||||
ImmutableSet.toImmutableSet())
|
||||
);
|
||||
|
||||
Set<Expression> needPushSelf = Sets.union(
|
||||
categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
|
||||
categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
|
||||
categorizedNoDistinctAggsChildren.getOrDefault(true, ImmutableSet.of()),
|
||||
categorizedDistinctAggsChildren.getOrDefault(true, ImmutableSet.of()));
|
||||
Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
|
||||
categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
|
||||
categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
|
||||
categorizedNoDistinctAggsChildren.getOrDefault(false, ImmutableSet.of()),
|
||||
categorizedDistinctAggsChildren.getOrDefault(false, ImmutableSet.of())));
|
||||
|
||||
Set<Alias> existsAlias =
|
||||
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
|
||||
@ -194,8 +198,7 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
||||
// create bottom project
|
||||
Plan bottomPlan;
|
||||
if (!bottomProjects.isEmpty()) {
|
||||
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
|
||||
aggregate.child());
|
||||
bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects), aggregate.child());
|
||||
} else {
|
||||
bottomPlan = aggregate.child();
|
||||
}
|
||||
@ -230,13 +233,17 @@ public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
|
||||
|
||||
// agg output include 2 parts
|
||||
// pushedGroupByExprs and normalized agg functions
|
||||
List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
|
||||
.addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
|
||||
.addAll(normalizedAggFuncsToSlotContext
|
||||
.pushDownToNamedExpression(normalizedAggFuncs))
|
||||
.build();
|
||||
|
||||
ImmutableList.Builder<NamedExpression> normalizedAggOutputBuilder
|
||||
= ImmutableList.builderWithExpectedSize(groupingByExprs.size() + normalizedAggFuncs.size());
|
||||
for (NamedExpression pushedGroupByExpr : pushedGroupByExprs) {
|
||||
normalizedAggOutputBuilder.add(pushedGroupByExpr.toSlot());
|
||||
}
|
||||
normalizedAggOutputBuilder.addAll(
|
||||
normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)
|
||||
);
|
||||
// create new agg node
|
||||
ImmutableList<NamedExpression> normalizedAggOutput = normalizedAggOutputBuilder.build();
|
||||
LogicalAggregate<?> newAggregate =
|
||||
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
|
||||
|
||||
|
||||
@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
@ -35,7 +36,6 @@ import com.google.common.collect.Maps;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* replace.
|
||||
@ -47,52 +47,50 @@ public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
|
||||
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
|
||||
logicalSort(logicalProject()).then(sort -> {
|
||||
LogicalProject<Plan> project = sort.child();
|
||||
Map<Expression, Slot> sMap = Maps.newHashMap();
|
||||
project.getProjects().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.forEach(p -> sMap.put(p.child(), p.toSlot()));
|
||||
Map<Expression, Slot> sMap = buildOutputAliasMap(project.getProjects());
|
||||
return replaceSortExpression(sort, sMap);
|
||||
})
|
||||
))
|
||||
.add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
|
||||
logicalSort(logicalAggregate()).then(sort -> {
|
||||
LogicalAggregate<Plan> aggregate = sort.child();
|
||||
Map<Expression, Slot> sMap = Maps.newHashMap();
|
||||
aggregate.getOutputExpressions().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.forEach(p -> sMap.put(p.child(), p.toSlot()));
|
||||
Map<Expression, Slot> sMap = buildOutputAliasMap(aggregate.getOutputExpressions());
|
||||
return replaceSortExpression(sort, sMap);
|
||||
})
|
||||
)).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
|
||||
logicalSort(logicalHaving(logicalAggregate())).then(sort -> {
|
||||
LogicalAggregate<Plan> aggregate = sort.child().child();
|
||||
Map<Expression, Slot> sMap = Maps.newHashMap();
|
||||
aggregate.getOutputExpressions().stream()
|
||||
.filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast)
|
||||
.forEach(p -> sMap.put(p.child(), p.toSlot()));
|
||||
Map<Expression, Slot> sMap = buildOutputAliasMap(aggregate.getOutputExpressions());
|
||||
return replaceSortExpression(sort, sMap);
|
||||
})
|
||||
))
|
||||
.build();
|
||||
}
|
||||
|
||||
private Map<Expression, Slot> buildOutputAliasMap(List<NamedExpression> output) {
|
||||
Map<Expression, Slot> sMap = Maps.newHashMapWithExpectedSize(output.size());
|
||||
for (NamedExpression expr : output) {
|
||||
if (expr instanceof Alias) {
|
||||
Alias alias = (Alias) expr;
|
||||
sMap.put(alias.child(), alias.toSlot());
|
||||
}
|
||||
}
|
||||
return sMap;
|
||||
}
|
||||
|
||||
private LogicalPlan replaceSortExpression(LogicalSort<? extends LogicalPlan> sort, Map<Expression, Slot> sMap) {
|
||||
List<OrderKey> orderKeys = sort.getOrderKeys();
|
||||
AtomicBoolean changed = new AtomicBoolean(false);
|
||||
List<OrderKey> newKeys = orderKeys.stream().map(k -> {
|
||||
|
||||
boolean changed = false;
|
||||
ImmutableList.Builder<OrderKey> newKeys = ImmutableList.builderWithExpectedSize(orderKeys.size());
|
||||
for (OrderKey k : orderKeys) {
|
||||
Expression newExpr = ExpressionUtils.replace(k.getExpr(), sMap);
|
||||
if (newExpr != k.getExpr()) {
|
||||
changed.set(true);
|
||||
changed = true;
|
||||
}
|
||||
return new OrderKey(newExpr, k.isAsc(), k.isNullFirst());
|
||||
}).collect(ImmutableList.toImmutableList());
|
||||
if (changed.get()) {
|
||||
return new LogicalSort<>(newKeys, sort.child());
|
||||
} else {
|
||||
return sort;
|
||||
newKeys.add(new OrderKey(newExpr, k.isAsc(), k.isNullFirst()));
|
||||
}
|
||||
|
||||
return changed ? new LogicalSort<>(newKeys.build(), sort.child()) : sort;
|
||||
}
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot;
|
||||
import org.apache.doris.nereids.trees.TreeNode;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
@ -51,6 +52,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@ -77,24 +79,21 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
RuleType.FILTER_SUBQUERY_TO_APPLY.build(
|
||||
logicalFilter().thenApply(ctx -> {
|
||||
LogicalFilter<Plan> filter = ctx.root;
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = filter.getConjuncts().stream()
|
||||
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryToApply::canConvertToSupply))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (subqueryExprsList.stream()
|
||||
.flatMap(Collection::stream).noneMatch(SubqueryExpr.class::isInstance)) {
|
||||
|
||||
Set<Expression> conjuncts = filter.getConjuncts();
|
||||
CollectSubquerys collectSubquerys = collectSubquerys(conjuncts);
|
||||
if (!collectSubquerys.hasSubquery) {
|
||||
return filter;
|
||||
}
|
||||
ImmutableList<Boolean> shouldOutputMarkJoinSlot =
|
||||
filter.getConjuncts().stream()
|
||||
.map(expr -> !(expr instanceof SubqueryExpr)
|
||||
&& expr.containsType(SubqueryExpr.class))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
List<Expression> oldConjuncts = ImmutableList.copyOf(filter.getConjuncts());
|
||||
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
|
||||
List<Boolean> shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot(conjuncts);
|
||||
|
||||
List<Expression> oldConjuncts = Utils.fastToImmutableList(conjuncts);
|
||||
ImmutableSet.Builder<Expression> newConjuncts = new ImmutableSet.Builder<>();
|
||||
LogicalPlan applyPlan = null;
|
||||
LogicalPlan tmpPlan = (LogicalPlan) filter.child();
|
||||
|
||||
List<Set<SubqueryExpr>> subqueryExprsList = collectSubquerys.subqueies;
|
||||
// Subquery traversal with the conjunct of and as the granularity.
|
||||
for (int i = 0; i < subqueryExprsList.size(); ++i) {
|
||||
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
|
||||
@ -119,9 +118,11 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
* if it's semi join with non-null mark slot
|
||||
* we can safely change the mark conjunct to hash conjunct
|
||||
*/
|
||||
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
|
||||
? ExpressionUtils.canInferNotNullForMarkSlot(
|
||||
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null))
|
||||
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct,
|
||||
rewriteContext), rewriteContext)
|
||||
: false;
|
||||
|
||||
applyPlan = subqueryToApply(subqueryExprs.stream()
|
||||
@ -132,21 +133,22 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
tmpPlan = applyPlan;
|
||||
newConjuncts.add(conjunct);
|
||||
}
|
||||
Set<Expression> conjuncts = ImmutableSet.copyOf(newConjuncts.build());
|
||||
Plan newFilter = new LogicalFilter<>(conjuncts, applyPlan);
|
||||
Plan newFilter = new LogicalFilter<>(newConjuncts.build(), applyPlan);
|
||||
return new LogicalProject<>(filter.getOutput().stream().collect(ImmutableList.toImmutableList()),
|
||||
newFilter);
|
||||
})
|
||||
),
|
||||
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
|
||||
LogicalProject<Plan> project = ctx.root;
|
||||
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = project.getProjects().stream()
|
||||
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryToApply::canConvertToSupply))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (subqueryExprsList.stream().flatMap(Collection::stream).count() == 0) {
|
||||
|
||||
List<NamedExpression> projects = project.getProjects();
|
||||
CollectSubquerys collectSubquerys = collectSubquerys(projects);
|
||||
if (!collectSubquerys.hasSubquery) {
|
||||
return project;
|
||||
}
|
||||
List<NamedExpression> oldProjects = ImmutableList.copyOf(project.getProjects());
|
||||
|
||||
List<Set<SubqueryExpr>> subqueryExprsList = collectSubquerys.subqueies;
|
||||
List<NamedExpression> oldProjects = ImmutableList.copyOf(projects);
|
||||
ImmutableList.Builder<NamedExpression> newProjects = new ImmutableList.Builder<>();
|
||||
LogicalPlan childPlan = (LogicalPlan) project.child();
|
||||
LogicalPlan applyPlan;
|
||||
@ -166,7 +168,7 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
replaceSubquery.replace(oldProjects.get(i), context);
|
||||
|
||||
applyPlan = subqueryToApply(
|
||||
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
|
||||
Utils.fastToImmutableList(subqueryExprs),
|
||||
childPlan, context.getSubqueryToMarkJoinSlot(),
|
||||
ctx.cascadesContext,
|
||||
Optional.of(newProject), true, false);
|
||||
@ -240,9 +242,11 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
* if it's semi join with non-null mark slot
|
||||
* we can safely change the mark conjunct to hash conjunct
|
||||
*/
|
||||
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
|
||||
? ExpressionUtils.canInferNotNullForMarkSlot(
|
||||
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, null))
|
||||
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext),
|
||||
rewriteContext)
|
||||
: false;
|
||||
applyPlan = subqueryToApply(
|
||||
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
|
||||
@ -566,4 +570,33 @@ public class SubqueryToApply implements AnalysisRuleFactory {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private List<Boolean> shouldOutputMarkJoinSlot(Collection<Expression> conjuncts) {
|
||||
ImmutableList.Builder<Boolean> result = ImmutableList.builderWithExpectedSize(conjuncts.size());
|
||||
for (Expression expr : conjuncts) {
|
||||
result.add(!(expr instanceof SubqueryExpr) && expr.containsType(SubqueryExpr.class));
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
|
||||
private CollectSubquerys collectSubquerys(Collection<? extends Expression> exprs) {
|
||||
boolean hasSubqueryExpr = false;
|
||||
ImmutableList.Builder<Set<SubqueryExpr>> subqueryExprsListBuilder = ImmutableList.builder();
|
||||
for (Expression expression : exprs) {
|
||||
Set<SubqueryExpr> subqueries = expression.collect(SubqueryToApply::canConvertToSupply);
|
||||
hasSubqueryExpr |= !subqueries.isEmpty();
|
||||
subqueryExprsListBuilder.add(subqueries);
|
||||
}
|
||||
return new CollectSubquerys(subqueryExprsListBuilder.build(), hasSubqueryExpr);
|
||||
}
|
||||
|
||||
private static class CollectSubquerys {
|
||||
final List<Set<SubqueryExpr>> subqueies;
|
||||
final boolean hasSubquery;
|
||||
|
||||
public CollectSubquerys(List<Set<SubqueryExpr>> subqueies, boolean hasSubquery) {
|
||||
this.subqueies = subqueies;
|
||||
this.hasSubquery = hasSubquery;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -122,10 +122,11 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
|
||||
// TODO Just Check query queryPlan firstly, support multi later.
|
||||
StructInfo queryStructInfo = queryStructInfos.get(0);
|
||||
if (!checkPattern(queryStructInfo)) {
|
||||
cascadesContext.getMaterializationContexts().forEach(ctx ->
|
||||
ctx.recordFailReason(queryStructInfo, "Query struct info is invalid",
|
||||
() -> String.format("queryPlan is %s", queryPlan.treeString())
|
||||
));
|
||||
for (MaterializationContext ctx : cascadesContext.getMaterializationContexts()) {
|
||||
ctx.recordFailReason(queryStructInfo, "Query struct info is invalid",
|
||||
() -> String.format("queryPlan is %s", queryPlan.treeString())
|
||||
);
|
||||
}
|
||||
return validQueryStructInfos;
|
||||
}
|
||||
validQueryStructInfos.add(queryStructInfo);
|
||||
@ -228,7 +229,7 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
|
||||
viewToQuerySlotMapping));
|
||||
continue;
|
||||
}
|
||||
rewrittenPlan = new LogicalFilter<>(Sets.newHashSet(rewriteCompensatePredicates), mvScan);
|
||||
rewrittenPlan = new LogicalFilter<>(Sets.newLinkedHashSet(rewriteCompensatePredicates), mvScan);
|
||||
}
|
||||
// Rewrite query by view
|
||||
rewrittenPlan = rewriteQueryByView(matchMode, queryStructInfo, viewStructInfo, viewToQuerySlotMapping,
|
||||
@ -293,7 +294,7 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
|
||||
if (originOutputs.size() != rewrittenPlan.getOutput().size()) {
|
||||
return null;
|
||||
}
|
||||
Map<Slot, ExprId> originSlotToRewrittenExprId = Maps.newHashMap();
|
||||
Map<Slot, ExprId> originSlotToRewrittenExprId = Maps.newLinkedHashMap();
|
||||
for (int i = 0; i < originOutputs.size(); i++) {
|
||||
originSlotToRewrittenExprId.put(originOutputs.get(i), rewrittenPlan.getOutput().get(i).getExprId());
|
||||
}
|
||||
@ -305,7 +306,7 @@ public abstract class AbstractMaterializedViewRule implements ExplorationRuleFac
|
||||
rewrittenPlan = rewrittenPlanContext.getRewritePlan();
|
||||
|
||||
// for get right nullable after rewritten, we need this map
|
||||
Map<ExprId, Slot> exprIdToNewRewrittenSlot = Maps.newHashMap();
|
||||
Map<ExprId, Slot> exprIdToNewRewrittenSlot = Maps.newLinkedHashMap();
|
||||
for (Slot slot : rewrittenPlan.getOutput()) {
|
||||
exprIdToNewRewrittenSlot.put(slot.getExprId(), slot);
|
||||
}
|
||||
|
||||
@ -79,7 +79,7 @@ public class InitMaterializationContextHook implements PlannerHook {
|
||||
if (availableMTMVs.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
availableMTMVs.forEach(materializedView -> {
|
||||
for (MTMV materializedView : availableMTMVs) {
|
||||
// generate outside, maybe add partition filter in the future
|
||||
LogicalOlapScan mvScan = new LogicalOlapScan(
|
||||
cascadesContext.getStatementContext().getNextRelationId(),
|
||||
@ -96,6 +96,6 @@ public class InitMaterializationContextHook implements PlannerHook {
|
||||
Plan projectScan = new LogicalProject<Plan>(mvProjects, mvScan);
|
||||
cascadesContext.addMaterializationContext(
|
||||
MaterializationContext.fromMaterializedView(materializedView, projectScan, cascadesContext));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,11 +33,12 @@ import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@ -67,7 +68,7 @@ public class MaterializationContext {
|
||||
private boolean success = false;
|
||||
// if rewrite by mv fail, record the reason, if success the failReason should be empty.
|
||||
// The key is the query belonged group expression objectId, the value is the fail reason
|
||||
private final Map<ObjectId, Pair<String, String>> failReason = new HashMap<>();
|
||||
private final Map<ObjectId, Pair<String, String>> failReason = new LinkedHashMap<>();
|
||||
private boolean enableRecordFailureDetail = false;
|
||||
|
||||
/**
|
||||
@ -163,7 +164,6 @@ public class MaterializationContext {
|
||||
if (this.success) {
|
||||
return;
|
||||
}
|
||||
this.success = false;
|
||||
this.failReason.put(structInfo.getOriginalPlanId(),
|
||||
Pair.of(summary, this.isEnableRecordFailureDetail() ? failureReasonSupplier.get() : ""));
|
||||
}
|
||||
@ -233,7 +233,7 @@ public class MaterializationContext {
|
||||
for (MaterializationContext ctx : materializationContexts) {
|
||||
if (!ctx.isSuccess()) {
|
||||
Set<String> failReasonSet =
|
||||
ctx.getFailReason().values().stream().map(Pair::key).collect(Collectors.toSet());
|
||||
ctx.getFailReason().values().stream().map(Pair::key).collect(ImmutableSet.toImmutableSet());
|
||||
builder.append("\n")
|
||||
.append(" Name: ").append(ctx.getMTMV().getName())
|
||||
.append("\n")
|
||||
|
||||
@ -53,8 +53,8 @@ import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@ -116,7 +116,7 @@ public class StructInfo {
|
||||
this.predicates = predicates;
|
||||
if (predicates == null) {
|
||||
// collect predicate from top plan which not in hyper graph
|
||||
Set<Expression> topPlanPredicates = new HashSet<>();
|
||||
Set<Expression> topPlanPredicates = new LinkedHashSet<>();
|
||||
topPlan.accept(PREDICATE_COLLECTOR, topPlanPredicates);
|
||||
this.predicates = Predicates.of(topPlanPredicates);
|
||||
}
|
||||
@ -241,7 +241,9 @@ public class StructInfo {
|
||||
public static List<StructInfo> of(Plan originalPlan) {
|
||||
// TODO only consider the inner join currently, Should support outer join
|
||||
// Split plan by the boundary which contains multi child
|
||||
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class));
|
||||
LinkedHashSet<Class<? extends Plan>> set = Sets.newLinkedHashSet();
|
||||
set.add(LogicalJoin.class);
|
||||
PlanSplitContext planSplitContext = new PlanSplitContext(set);
|
||||
// if single table without join, the bottom is
|
||||
originalPlan.accept(PLAN_SPLITTER, planSplitContext);
|
||||
|
||||
@ -261,16 +263,18 @@ public class StructInfo {
|
||||
.map(GroupExpression::getId).orElseGet(() -> new ObjectId(-1));
|
||||
// if any of topPlan or bottomPlan is null, split the top plan to two parts by join node
|
||||
if (topPlan == null || bottomPlan == null) {
|
||||
PlanSplitContext planSplitContext = new PlanSplitContext(Sets.newHashSet(LogicalJoin.class));
|
||||
Set<Class<? extends Plan>> set = Sets.newLinkedHashSet();
|
||||
set.add(LogicalJoin.class);
|
||||
PlanSplitContext planSplitContext = new PlanSplitContext(set);
|
||||
originalPlan.accept(PLAN_SPLITTER, planSplitContext);
|
||||
bottomPlan = planSplitContext.getBottomPlan();
|
||||
topPlan = planSplitContext.getTopPlan();
|
||||
}
|
||||
// collect struct info fromGraph
|
||||
ImmutableList.Builder<CatalogRelation> relationBuilder = ImmutableList.builder();
|
||||
Map<RelationId, StructInfoNode> relationIdStructInfoNodeMap = new HashMap<>();
|
||||
Map<Expression, Expression> shuttledHashConjunctsToConjunctsMap = new HashMap<>();
|
||||
Map<ExprId, Expression> namedExprIdAndExprMapping = new HashMap<>();
|
||||
Map<RelationId, StructInfoNode> relationIdStructInfoNodeMap = new LinkedHashMap<>();
|
||||
Map<Expression, Expression> shuttledHashConjunctsToConjunctsMap = new LinkedHashMap<>();
|
||||
Map<ExprId, Expression> namedExprIdAndExprMapping = new LinkedHashMap<>();
|
||||
boolean valid = collectStructInfoFromGraph(hyperGraph, topPlan, shuttledHashConjunctsToConjunctsMap,
|
||||
namedExprIdAndExprMapping,
|
||||
relationBuilder,
|
||||
|
||||
@ -0,0 +1,124 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternRules;
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners.CombinedListener;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import javax.annotation.Nullable;
|
||||
|
||||
/** ExpressionBottomUpRewriter */
|
||||
public class ExpressionBottomUpRewriter implements ExpressionRewriteRule<ExpressionRewriteContext> {
|
||||
public static final String BATCH_ID_KEY = "batch_id";
|
||||
private static final Logger LOG = LogManager.getLogger(ExpressionBottomUpRewriter.class);
|
||||
private static final AtomicInteger rewriteBatchId = new AtomicInteger();
|
||||
private final ExpressionPatternRules rules;
|
||||
private final ExpressionPatternTraverseListeners listeners;
|
||||
|
||||
public ExpressionBottomUpRewriter(ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {
|
||||
this.rules = rules;
|
||||
this.listeners = listeners;
|
||||
}
|
||||
|
||||
// entrance
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
int currentBatch = rewriteBatchId.incrementAndGet();
|
||||
return rewriteBottomUp(expr, ctx, currentBatch, null, rules, listeners);
|
||||
}
|
||||
|
||||
private static Expression rewriteBottomUp(
|
||||
Expression expression, ExpressionRewriteContext context, int currentBatch, @Nullable Expression parent,
|
||||
ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {
|
||||
|
||||
Optional<Integer> rewriteState = expression.getMutableState(BATCH_ID_KEY);
|
||||
if (!rewriteState.isPresent() || rewriteState.get() != currentBatch) {
|
||||
CombinedListener listener = null;
|
||||
boolean hasChildren = expression.arity() > 0;
|
||||
if (hasChildren) {
|
||||
listener = listeners.matchesAndCombineListeners(expression, context, parent);
|
||||
if (listener != null) {
|
||||
listener.onEnter();
|
||||
}
|
||||
}
|
||||
|
||||
Expression afterRewrite = expression;
|
||||
try {
|
||||
Expression beforeRewrite;
|
||||
afterRewrite = rewriteChildren(expression, context, currentBatch, rules, listeners);
|
||||
// use rewriteTimes to avoid dead loop
|
||||
int rewriteTimes = 0;
|
||||
boolean changed;
|
||||
do {
|
||||
beforeRewrite = afterRewrite;
|
||||
|
||||
// rewrite this
|
||||
Optional<Expression> applied = rules.matchesAndApply(beforeRewrite, context, parent);
|
||||
|
||||
changed = applied.isPresent();
|
||||
if (changed) {
|
||||
afterRewrite = applied.get();
|
||||
// ensure children are rewritten
|
||||
afterRewrite = rewriteChildren(afterRewrite, context, currentBatch, rules, listeners);
|
||||
}
|
||||
rewriteTimes++;
|
||||
} while (changed && rewriteTimes < 100);
|
||||
|
||||
// set rewritten
|
||||
afterRewrite.setMutableState(BATCH_ID_KEY, currentBatch);
|
||||
} finally {
|
||||
if (hasChildren && listener != null) {
|
||||
listener.onExit(afterRewrite);
|
||||
}
|
||||
}
|
||||
|
||||
return afterRewrite;
|
||||
}
|
||||
|
||||
// already rewritten
|
||||
return expression;
|
||||
}
|
||||
|
||||
private static Expression rewriteChildren(Expression parent, ExpressionRewriteContext context, int currentBatch,
|
||||
ExpressionPatternRules rules, ExpressionPatternTraverseListeners listeners) {
|
||||
boolean changed = false;
|
||||
ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(parent.arity());
|
||||
for (Expression child : parent.children()) {
|
||||
Expression newChild = rewriteBottomUp(child, context, currentBatch, parent, rules, listeners);
|
||||
changed |= !child.equals(newChild);
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
|
||||
Expression result = parent;
|
||||
if (changed) {
|
||||
result = parent.withChildren(newChildren.build());
|
||||
}
|
||||
if (changed && context.cascadesContext.isEnableExprTrace()) {
|
||||
LOG.info("WithChildren: \nbefore: " + parent + "\nafter: " + result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,41 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/** ExpressionListenerMatcher */
|
||||
public class ExpressionListenerMatcher<E extends Expression> {
|
||||
public final Class<E> typePattern;
|
||||
public final List<Predicate<ExpressionMatchingContext<E>>> predicates;
|
||||
public final ExpressionTraverseListener<E> listener;
|
||||
|
||||
public ExpressionListenerMatcher(Class<E> typePattern,
|
||||
List<Predicate<ExpressionMatchingContext<E>>> predicates,
|
||||
ExpressionTraverseListener<E> listener) {
|
||||
this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null");
|
||||
this.predicates = predicates == null ? ImmutableList.of() : predicates;
|
||||
this.listener = Objects.requireNonNull(listener, "listener can not be null");
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
/** ExpressionMatchAction */
|
||||
public interface ExpressionMatchingAction<E extends Expression> {
|
||||
Expression apply(ExpressionMatchingContext<E> context);
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import java.util.Optional;
|
||||
|
||||
/** ExpressionMatchingContext */
|
||||
public class ExpressionMatchingContext<E extends Expression> {
|
||||
public final E expr;
|
||||
public final Optional<Expression> parent;
|
||||
public final ExpressionRewriteContext rewriteContext;
|
||||
public final CascadesContext cascadesContext;
|
||||
|
||||
public ExpressionMatchingContext(E expr, Expression parent, ExpressionRewriteContext context) {
|
||||
this.expr = expr;
|
||||
this.parent = Optional.ofNullable(parent);
|
||||
this.rewriteContext = context;
|
||||
this.cascadesContext = context.cascadesContext;
|
||||
}
|
||||
|
||||
public boolean isRoot() {
|
||||
return !parent.isPresent();
|
||||
}
|
||||
|
||||
public Expression parentOr(Expression defaultParent) {
|
||||
return parent.orElse(defaultParent);
|
||||
}
|
||||
}
|
||||
@ -42,20 +42,21 @@ public class ExpressionNormalization extends ExpressionRewrite {
|
||||
// we should run supportJavaDateFormatter before foldConstantRule or be will fold
|
||||
// from_unixtime(timestamp, 'yyyyMMdd') to 'yyyyMMdd'
|
||||
public static final List<ExpressionRewriteRule> NORMALIZE_REWRITE_RULES = ImmutableList.of(
|
||||
SupportJavaDateFormatter.INSTANCE,
|
||||
ReplaceVariableByLiteral.INSTANCE,
|
||||
NormalizeBinaryPredicatesRule.INSTANCE,
|
||||
InPredicateDedup.INSTANCE,
|
||||
InPredicateToEqualToRule.INSTANCE,
|
||||
SimplifyNotExprRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyCastRule.INSTANCE,
|
||||
DigitalMaskingConvert.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
SupportJavaDateFormatter.INSTANCE,
|
||||
ConvertAggStateCast.INSTANCE,
|
||||
CheckCast.INSTANCE
|
||||
bottomUp(
|
||||
ReplaceVariableByLiteral.INSTANCE,
|
||||
SupportJavaDateFormatter.INSTANCE,
|
||||
NormalizeBinaryPredicatesRule.INSTANCE,
|
||||
InPredicateDedup.INSTANCE,
|
||||
InPredicateToEqualToRule.INSTANCE,
|
||||
SimplifyNotExprRule.INSTANCE,
|
||||
SimplifyArithmeticRule.INSTANCE,
|
||||
FoldConstantRule.INSTANCE,
|
||||
SimplifyCastRule.INSTANCE,
|
||||
DigitalMaskingConvert.INSTANCE,
|
||||
SimplifyArithmeticComparisonRule.INSTANCE,
|
||||
ConvertAggStateCast.INSTANCE,
|
||||
CheckCast.INSTANCE
|
||||
)
|
||||
);
|
||||
|
||||
public ExpressionNormalization() {
|
||||
|
||||
@ -0,0 +1,33 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
/** ExpressionNormalizationAndOptimization */
|
||||
public class ExpressionNormalizationAndOptimization extends ExpressionRewrite {
|
||||
/** ExpressionNormalizationAndOptimization */
|
||||
public ExpressionNormalizationAndOptimization() {
|
||||
super(new ExpressionRuleExecutor(
|
||||
ImmutableList.<ExpressionRewriteRule>builder()
|
||||
.addAll(ExpressionNormalization.NORMALIZE_REWRITE_RULES)
|
||||
.addAll(ExpressionOptimization.OPTIMIZE_REWRITE_RULES)
|
||||
.build()
|
||||
));
|
||||
}
|
||||
}
|
||||
@ -39,18 +39,20 @@ import java.util.List;
|
||||
*/
|
||||
public class ExpressionOptimization extends ExpressionRewrite {
|
||||
public static final List<ExpressionRewriteRule> OPTIMIZE_REWRITE_RULES = ImmutableList.of(
|
||||
ExtractCommonFactorRule.INSTANCE,
|
||||
DistinctPredicatesRule.INSTANCE,
|
||||
SimplifyComparisonPredicate.INSTANCE,
|
||||
SimplifyInPredicate.INSTANCE,
|
||||
SimplifyDecimalV3Comparison.INSTANCE,
|
||||
SimplifyRange.INSTANCE,
|
||||
DateFunctionRewrite.INSTANCE,
|
||||
OrToIn.INSTANCE,
|
||||
ArrayContainToArrayOverlap.INSTANCE,
|
||||
CaseWhenToIf.INSTANCE,
|
||||
TopnToMax.INSTANCE,
|
||||
NullSafeEqualToEqual.INSTANCE
|
||||
bottomUp(
|
||||
ExtractCommonFactorRule.INSTANCE,
|
||||
DistinctPredicatesRule.INSTANCE,
|
||||
SimplifyComparisonPredicate.INSTANCE,
|
||||
SimplifyInPredicate.INSTANCE,
|
||||
SimplifyDecimalV3Comparison.INSTANCE,
|
||||
OrToIn.INSTANCE,
|
||||
SimplifyRange.INSTANCE,
|
||||
DateFunctionRewrite.INSTANCE,
|
||||
ArrayContainToArrayOverlap.INSTANCE,
|
||||
CaseWhenToIf.INSTANCE,
|
||||
TopnToMax.INSTANCE,
|
||||
NullSafeEqualToEqual.INSTANCE
|
||||
)
|
||||
);
|
||||
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
|
||||
|
||||
|
||||
@ -0,0 +1,64 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/** ExpressionPatternMatcherRule */
|
||||
public class ExpressionPatternMatchRule implements TypeMapping<Expression> {
|
||||
public final Class<? extends Expression> typePattern;
|
||||
public final List<Predicate<ExpressionMatchingContext<Expression>>> predicates;
|
||||
public final ExpressionMatchingAction<Expression> matchingAction;
|
||||
|
||||
public ExpressionPatternMatchRule(ExpressionPatternMatcher patternMatcher) {
|
||||
this.typePattern = patternMatcher.typePattern;
|
||||
this.predicates = patternMatcher.predicates;
|
||||
this.matchingAction = patternMatcher.matchingAction;
|
||||
}
|
||||
|
||||
/** matches */
|
||||
public boolean matchesTypeAndPredicates(ExpressionMatchingContext<Expression> context) {
|
||||
return typePattern.isInstance(context.expr) && matchesPredicates(context);
|
||||
}
|
||||
|
||||
/** matchesPredicates */
|
||||
public boolean matchesPredicates(ExpressionMatchingContext<Expression> context) {
|
||||
if (!predicates.isEmpty()) {
|
||||
for (Predicate<ExpressionMatchingContext<Expression>> predicate : predicates) {
|
||||
if (!predicate.test(context)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public Expression apply(ExpressionMatchingContext<Expression> context) {
|
||||
Expression newResult = matchingAction.apply(context);
|
||||
return newResult == null ? context.expr : newResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<? extends Expression> getType() {
|
||||
return typePattern;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,41 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/** ExpressionPattern */
|
||||
public class ExpressionPatternMatcher<E extends Expression> {
|
||||
public final Class<E> typePattern;
|
||||
public final List<Predicate<ExpressionMatchingContext<E>>> predicates;
|
||||
public final ExpressionMatchingAction<E> matchingAction;
|
||||
|
||||
public ExpressionPatternMatcher(Class<E> typePattern,
|
||||
List<Predicate<ExpressionMatchingContext<E>>> predicates,
|
||||
ExpressionMatchingAction<E> matchingAction) {
|
||||
this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null");
|
||||
this.predicates = predicates == null ? ImmutableList.of() : predicates;
|
||||
this.matchingAction = Objects.requireNonNull(matchingAction, "matchingAction can not be null");
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,84 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/** ExpressionPatternRuleFactory */
|
||||
public interface ExpressionPatternRuleFactory {
|
||||
List<ExpressionPatternMatcher<? extends Expression>> buildRules();
|
||||
|
||||
default <E extends Expression> ExpressionPatternDescriptor<E> matchesType(Class<E> clazz) {
|
||||
return new ExpressionPatternDescriptor<>(clazz);
|
||||
}
|
||||
|
||||
default <E extends Expression> ExpressionPatternDescriptor<E> root(Class<E> clazz) {
|
||||
return new ExpressionPatternDescriptor<>(clazz)
|
||||
.whenCtx(ctx -> ctx.isRoot());
|
||||
}
|
||||
|
||||
default <E extends Expression> ExpressionPatternDescriptor<E> matchesTopType(Class<E> clazz) {
|
||||
return new ExpressionPatternDescriptor<>(clazz)
|
||||
.whenCtx(ctx -> ctx.isRoot() || !clazz.isInstance(ctx.parent.get()));
|
||||
}
|
||||
|
||||
/** ExpressionPatternDescriptor */
|
||||
class ExpressionPatternDescriptor<E extends Expression> {
|
||||
private final Class<E> typePattern;
|
||||
private final ImmutableList<Predicate<ExpressionMatchingContext<E>>> predicates;
|
||||
|
||||
public ExpressionPatternDescriptor(Class<E> typePattern) {
|
||||
this(typePattern, ImmutableList.of());
|
||||
}
|
||||
|
||||
public ExpressionPatternDescriptor(
|
||||
Class<E> typePattern, ImmutableList<Predicate<ExpressionMatchingContext<E>>> predicates) {
|
||||
this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null");
|
||||
this.predicates = Objects.requireNonNull(predicates, "predicates can not be null");
|
||||
}
|
||||
|
||||
public ExpressionPatternDescriptor<E> when(Predicate<E> predicate) {
|
||||
return whenCtx(ctx -> predicate.test(ctx.expr));
|
||||
}
|
||||
|
||||
public ExpressionPatternDescriptor<E> whenCtx(Predicate<ExpressionMatchingContext<E>> predicate) {
|
||||
ImmutableList.Builder<Predicate<ExpressionMatchingContext<E>>> newPredicates
|
||||
= ImmutableList.builderWithExpectedSize(predicates.size() + 1);
|
||||
newPredicates.addAll(predicates);
|
||||
newPredicates.add(predicate);
|
||||
return new ExpressionPatternDescriptor<>(typePattern, newPredicates.build());
|
||||
}
|
||||
|
||||
/** then */
|
||||
public ExpressionPatternMatcher<E> then(Function<E, Expression> rewriter) {
|
||||
return new ExpressionPatternMatcher<>(
|
||||
typePattern, predicates, (context) -> rewriter.apply(context.expr));
|
||||
}
|
||||
|
||||
public ExpressionPatternMatcher<E> thenApply(ExpressionMatchingAction<E> action) {
|
||||
return new ExpressionPatternMatcher<>(typePattern, predicates, action);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -18,6 +18,8 @@
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternRules;
|
||||
import org.apache.doris.nereids.pattern.ExpressionPatternTraverseListeners;
|
||||
import org.apache.doris.nereids.properties.OrderKey;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
@ -41,7 +43,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
@ -123,9 +125,7 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
LogicalProject<Plan> project = ctx.root;
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
List<NamedExpression> projects = project.getProjects();
|
||||
List<NamedExpression> newProjects = projects.stream()
|
||||
.map(expr -> (NamedExpression) rewriter.rewrite(expr, context))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
List<NamedExpression> newProjects = rewriteAll(projects, rewriter, context);
|
||||
if (projects.equals(newProjects)) {
|
||||
return project;
|
||||
}
|
||||
@ -160,9 +160,7 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
List<Expression> newGroupByExprs = rewriter.rewrite(groupByExprs, context);
|
||||
|
||||
List<NamedExpression> outputExpressions = agg.getOutputExpressions();
|
||||
List<NamedExpression> newOutputExpressions = outputExpressions.stream()
|
||||
.map(expr -> (NamedExpression) rewriter.rewrite(expr, context))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
List<NamedExpression> newOutputExpressions = rewriteAll(outputExpressions, rewriter, context);
|
||||
if (outputExpressions.equals(newOutputExpressions)) {
|
||||
return agg;
|
||||
}
|
||||
@ -222,13 +220,16 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
return logicalSort().thenApply(ctx -> {
|
||||
LogicalSort<Plan> sort = ctx.root;
|
||||
List<OrderKey> orderKeys = sort.getOrderKeys();
|
||||
List<OrderKey> rewrittenOrderKeys = new ArrayList<>();
|
||||
ImmutableList.Builder<OrderKey> rewrittenOrderKeys
|
||||
= ImmutableList.builderWithExpectedSize(orderKeys.size());
|
||||
ExpressionRewriteContext context = new ExpressionRewriteContext(ctx.cascadesContext);
|
||||
boolean changed = false;
|
||||
for (OrderKey k : orderKeys) {
|
||||
Expression expression = rewriter.rewrite(k.getExpr(), context);
|
||||
changed |= expression != k.getExpr();
|
||||
rewrittenOrderKeys.add(new OrderKey(expression, k.isAsc(), k.isNullFirst()));
|
||||
}
|
||||
return sort.withOrderKeys(rewrittenOrderKeys);
|
||||
return changed ? sort.withOrderKeys(rewrittenOrderKeys.build()) : sort;
|
||||
}).toRule(RuleType.REWRITE_SORT_EXPRESSION);
|
||||
}
|
||||
}
|
||||
@ -270,4 +271,36 @@ public class ExpressionRewrite implements RewriteRuleFactory {
|
||||
}).toRule(RuleType.REWRITE_REPEAT_EXPRESSION);
|
||||
}
|
||||
}
|
||||
|
||||
/** bottomUp */
|
||||
public static ExpressionBottomUpRewriter bottomUp(ExpressionPatternRuleFactory... ruleFactories) {
|
||||
ImmutableList.Builder<ExpressionPatternMatchRule> rules = ImmutableList.builder();
|
||||
ImmutableList.Builder<ExpressionTraverseListenerMapping> listeners = ImmutableList.builder();
|
||||
for (ExpressionPatternRuleFactory ruleFactory : ruleFactories) {
|
||||
if (ruleFactory instanceof ExpressionTraverseListenerFactory) {
|
||||
List<ExpressionListenerMatcher<? extends Expression>> listenersMatcher
|
||||
= ((ExpressionTraverseListenerFactory) ruleFactory).buildListeners();
|
||||
for (ExpressionListenerMatcher<? extends Expression> listenerMatcher : listenersMatcher) {
|
||||
listeners.add(new ExpressionTraverseListenerMapping(listenerMatcher));
|
||||
}
|
||||
}
|
||||
for (ExpressionPatternMatcher<? extends Expression> patternMatcher : ruleFactory.buildRules()) {
|
||||
rules.add(new ExpressionPatternMatchRule(patternMatcher));
|
||||
}
|
||||
}
|
||||
|
||||
return new ExpressionBottomUpRewriter(
|
||||
new ExpressionPatternRules(rules.build()),
|
||||
new ExpressionPatternTraverseListeners(listeners.build())
|
||||
);
|
||||
}
|
||||
|
||||
public static <E extends Expression> List<E> rewriteAll(
|
||||
Collection<E> exprs, ExpressionRuleExecutor rewriter, ExpressionRewriteContext context) {
|
||||
ImmutableList.Builder<E> result = ImmutableList.builderWithExpectedSize(exprs.size());
|
||||
for (E expr : exprs) {
|
||||
result.add((E) rewriter.rewrite(expr, context));
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
}
|
||||
|
||||
@ -19,6 +19,8 @@ package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* expression rewrite context.
|
||||
*/
|
||||
@ -27,7 +29,7 @@ public class ExpressionRewriteContext {
|
||||
public final CascadesContext cascadesContext;
|
||||
|
||||
public ExpressionRewriteContext(CascadesContext cascadesContext) {
|
||||
this.cascadesContext = cascadesContext;
|
||||
this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -36,7 +37,11 @@ public class ExpressionRuleExecutor {
|
||||
}
|
||||
|
||||
public List<Expression> rewrite(List<Expression> exprs, ExpressionRewriteContext ctx) {
|
||||
return exprs.stream().map(expr -> rewrite(expr, ctx)).collect(ImmutableList.toImmutableList());
|
||||
ImmutableList.Builder<Expression> result = ImmutableList.builderWithExpectedSize(exprs.size());
|
||||
for (Expression expr : exprs) {
|
||||
result.add(rewrite(expr, ctx));
|
||||
}
|
||||
return result.build();
|
||||
}
|
||||
|
||||
/**
|
||||
@ -61,8 +66,15 @@ public class ExpressionRuleExecutor {
|
||||
return rule.rewrite(expr, ctx);
|
||||
}
|
||||
|
||||
/** normalize */
|
||||
public static Expression normalize(Expression expression) {
|
||||
return NormalizeBinaryPredicatesRule.INSTANCE.rewrite(expression, null);
|
||||
return expression.rewriteUp(expr -> {
|
||||
if (expr instanceof ComparisonPredicate) {
|
||||
return NormalizeBinaryPredicatesRule.normalize((ComparisonPredicate) expression);
|
||||
} else {
|
||||
return expr;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -0,0 +1,31 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
/** ExpressionTraverseListener */
|
||||
public interface ExpressionTraverseListener<E extends Expression> {
|
||||
default void onEnter(ExpressionMatchingContext<E> context) {}
|
||||
|
||||
default void onExit(ExpressionMatchingContext<E> context, Expression rewritten) {}
|
||||
|
||||
default <CAST extends Expression> ExpressionTraverseListener<CAST> as() {
|
||||
return (ExpressionTraverseListener) this;
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,79 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/** ExpressionTraverseListenerFactory */
|
||||
public interface ExpressionTraverseListenerFactory {
|
||||
List<ExpressionListenerMatcher<? extends Expression>> buildListeners();
|
||||
|
||||
default <E extends Expression> ListenerDescriptor<E> listenerType(Class<E> clazz) {
|
||||
return new ListenerDescriptor<>(clazz);
|
||||
}
|
||||
|
||||
/** listenerTypes */
|
||||
default List<ListenerDescriptor<Expression>> listenerTypes(Class<? extends Expression>... classes) {
|
||||
ImmutableList.Builder<ListenerDescriptor<Expression>> listeners
|
||||
= ImmutableList.builderWithExpectedSize(classes.length);
|
||||
for (Class<? extends Expression> clazz : classes) {
|
||||
listeners.add((ListenerDescriptor<Expression>) listenerType(clazz));
|
||||
}
|
||||
return listeners.build();
|
||||
}
|
||||
|
||||
/** ListenerDescriptor */
|
||||
class ListenerDescriptor<E extends Expression> {
|
||||
|
||||
private final Class<E> typePattern;
|
||||
private final ImmutableList<Predicate<ExpressionMatchingContext<E>>> predicates;
|
||||
|
||||
public ListenerDescriptor(Class<E> typePattern) {
|
||||
this(typePattern, ImmutableList.of());
|
||||
}
|
||||
|
||||
public ListenerDescriptor(
|
||||
Class<E> typePattern, ImmutableList<Predicate<ExpressionMatchingContext<E>>> predicates) {
|
||||
this.typePattern = Objects.requireNonNull(typePattern, "typePattern can not be null");
|
||||
this.predicates = Objects.requireNonNull(predicates, "predicates can not be null");
|
||||
}
|
||||
|
||||
public ListenerDescriptor<E> when(Predicate<E> predicate) {
|
||||
return whenCtx(ctx -> predicate.test(ctx.expr));
|
||||
}
|
||||
|
||||
public ListenerDescriptor<E> whenCtx(Predicate<ExpressionMatchingContext<E>> predicate) {
|
||||
ImmutableList.Builder<Predicate<ExpressionMatchingContext<E>>> newPredicates
|
||||
= ImmutableList.builderWithExpectedSize(predicates.size() + 1);
|
||||
newPredicates.addAll(predicates);
|
||||
newPredicates.add(predicate);
|
||||
return new ListenerDescriptor<>(typePattern, newPredicates.build());
|
||||
}
|
||||
|
||||
/** then */
|
||||
public ExpressionListenerMatcher<E> then(ExpressionTraverseListener<E> listener) {
|
||||
return new ExpressionListenerMatcher<>(typePattern, predicates, listener);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,59 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression;
|
||||
|
||||
import org.apache.doris.nereids.pattern.TypeMappings.TypeMapping;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/** ExpressionTraverseListener */
|
||||
public class ExpressionTraverseListenerMapping implements TypeMapping<Expression> {
|
||||
public final Class<? extends Expression> typePattern;
|
||||
public final List<Predicate<ExpressionMatchingContext<Expression>>> predicates;
|
||||
public final ExpressionTraverseListener<Expression> listener;
|
||||
|
||||
public ExpressionTraverseListenerMapping(ExpressionListenerMatcher listenerMatcher) {
|
||||
this.typePattern = listenerMatcher.typePattern;
|
||||
this.predicates = listenerMatcher.predicates;
|
||||
this.listener = listenerMatcher.listener;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Class<? extends Expression> getType() {
|
||||
return typePattern;
|
||||
}
|
||||
|
||||
/** matches */
|
||||
public boolean matchesTypeAndPredicates(ExpressionMatchingContext<Expression> context) {
|
||||
return typePattern.isInstance(context.expr) && matchesPredicates(context);
|
||||
}
|
||||
|
||||
/** matchesPredicates */
|
||||
public boolean matchesPredicates(ExpressionMatchingContext<Expression> context) {
|
||||
if (!predicates.isEmpty()) {
|
||||
for (Predicate<ExpressionMatchingContext<Expression>> predicate : predicates) {
|
||||
if (!predicate.test(context)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@ -18,8 +18,8 @@
|
||||
package org.apache.doris.nereids.rules.expression.check;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
@ -31,18 +31,24 @@ import org.apache.doris.nereids.types.StructType;
|
||||
import org.apache.doris.nereids.types.coercion.CharacterType;
|
||||
import org.apache.doris.nereids.types.coercion.PrimitiveType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* check cast valid
|
||||
*/
|
||||
public class CheckCast extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static final CheckCast INSTANCE = new CheckCast();
|
||||
public class CheckCast implements ExpressionPatternRuleFactory {
|
||||
public static CheckCast INSTANCE = new CheckCast();
|
||||
|
||||
@Override
|
||||
public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
|
||||
rewrite(cast.child(), context);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(Cast.class).then(CheckCast::check)
|
||||
);
|
||||
}
|
||||
|
||||
private static Expression check(Cast cast) {
|
||||
DataType originalType = cast.child().getDataType();
|
||||
DataType targetType = cast.getDataType();
|
||||
if (!check(originalType, targetType)) {
|
||||
@ -51,7 +57,7 @@ public class CheckCast extends AbstractExpressionRewriteRule {
|
||||
return cast;
|
||||
}
|
||||
|
||||
private boolean check(DataType originalType, DataType targetType) {
|
||||
private static boolean check(DataType originalType, DataType targetType) {
|
||||
if (originalType.isVariantType() && (targetType instanceof PrimitiveType || targetType.isArrayType())) {
|
||||
// variant could cast to primitive types and array
|
||||
return true;
|
||||
@ -99,7 +105,7 @@ public class CheckCast extends AbstractExpressionRewriteRule {
|
||||
* 3. original type is same with target type
|
||||
* 4. target type is null type
|
||||
*/
|
||||
private boolean checkPrimitiveType(DataType originalType, DataType targetType) {
|
||||
private static boolean checkPrimitiveType(DataType originalType, DataType targetType) {
|
||||
if (!originalType.isPrimitive() || !targetType.isPrimitive()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -17,26 +17,29 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayContains;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraysOverlap;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Multimaps;
|
||||
import com.google.common.collect.SetMultimap;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* array_contains ( c_array, '1' )
|
||||
@ -44,56 +47,73 @@ import java.util.stream.Collectors;
|
||||
* =========================================>
|
||||
* array_overlap(c_array, ['1', '2'])
|
||||
*/
|
||||
public class ArrayContainToArrayOverlap extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
|
||||
ExpressionRewriteRule<ExpressionRewriteContext> {
|
||||
public class ArrayContainToArrayOverlap implements ExpressionPatternRuleFactory {
|
||||
|
||||
public static final ArrayContainToArrayOverlap INSTANCE = new ArrayContainToArrayOverlap();
|
||||
|
||||
private static final int REWRITE_PREDICATE_THRESHOLD = 2;
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return expr.accept(this, ctx);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(Or.class).then(ArrayContainToArrayOverlap::rewrite)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
|
||||
private static Expression rewrite(Or or) {
|
||||
List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
|
||||
Map<Boolean, List<Expression>> containFuncAndOtherFunc = disjuncts.stream()
|
||||
.collect(Collectors.partitioningBy(this::isValidArrayContains));
|
||||
Map<Expression, Set<Literal>> containLiteralSet = new HashMap<>();
|
||||
List<Expression> contains = containFuncAndOtherFunc.get(true);
|
||||
List<Expression> others = containFuncAndOtherFunc.get(false);
|
||||
|
||||
contains.forEach(containFunc ->
|
||||
containLiteralSet.computeIfAbsent(containFunc.child(0), k -> new HashSet<>())
|
||||
.add((Literal) containFunc.child(1)));
|
||||
List<Expression> contains = Lists.newArrayList();
|
||||
List<Expression> others = Lists.newArrayList();
|
||||
for (Expression expr : disjuncts) {
|
||||
if (ArrayContainToArrayOverlap.isValidArrayContains(expr)) {
|
||||
contains.add(expr);
|
||||
} else {
|
||||
others.add(expr);
|
||||
}
|
||||
}
|
||||
|
||||
if (contains.size() <= 1) {
|
||||
return or;
|
||||
}
|
||||
|
||||
SetMultimap<Expression, Literal> containLiteralSet = Multimaps.newSetMultimap(
|
||||
new LinkedHashMap<>(), LinkedHashSet::new
|
||||
);
|
||||
for (Expression contain : contains) {
|
||||
containLiteralSet.put(contain.child(0), (Literal) contain.child(1));
|
||||
}
|
||||
|
||||
Builder<Expression> newDisjunctsBuilder = new ImmutableList.Builder<>();
|
||||
containLiteralSet.forEach((left, literalSet) -> {
|
||||
for (Entry<Expression, Collection<Literal>> kv : containLiteralSet.asMap().entrySet()) {
|
||||
Expression left = kv.getKey();
|
||||
Collection<Literal> literalSet = kv.getValue();
|
||||
if (literalSet.size() > REWRITE_PREDICATE_THRESHOLD) {
|
||||
newDisjunctsBuilder.add(
|
||||
new ArraysOverlap(left,
|
||||
new ArrayLiteral(ImmutableList.copyOf(literalSet))));
|
||||
new ArraysOverlap(left, new ArrayLiteral(Utils.fastToImmutableList(literalSet)))
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
contains.stream()
|
||||
.filter(e -> !canCovertToArrayOverlap(e, containLiteralSet))
|
||||
.forEach(newDisjunctsBuilder::add);
|
||||
others.stream()
|
||||
.map(e -> e.accept(this, null))
|
||||
.forEach(newDisjunctsBuilder::add);
|
||||
for (Expression contain : contains) {
|
||||
if (!canCovertToArrayOverlap(contain, containLiteralSet)) {
|
||||
newDisjunctsBuilder.add(contain);
|
||||
}
|
||||
}
|
||||
newDisjunctsBuilder.addAll(others);
|
||||
return ExpressionUtils.or(newDisjunctsBuilder.build());
|
||||
}
|
||||
|
||||
private boolean isValidArrayContains(Expression expression) {
|
||||
private static boolean isValidArrayContains(Expression expression) {
|
||||
return expression instanceof ArrayContains && expression.child(1) instanceof Literal;
|
||||
}
|
||||
|
||||
private boolean canCovertToArrayOverlap(Expression expression, Map<Expression, Set<Literal>> containLiteralSet) {
|
||||
return expression instanceof ArrayContains
|
||||
&& containLiteralSet.getOrDefault(expression.child(0),
|
||||
new HashSet<>()).size() > REWRITE_PREDICATE_THRESHOLD;
|
||||
private static boolean canCovertToArrayOverlap(
|
||||
Expression expression, SetMultimap<Expression, Literal> containLiteralSet) {
|
||||
if (!(expression instanceof ArrayContains)) {
|
||||
return false;
|
||||
}
|
||||
Set<Literal> containLiteral = containLiteralSet.get(expression.child(0));
|
||||
return containLiteral.size() > REWRITE_PREDICATE_THRESHOLD;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,25 +17,35 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.CaseWhen;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Rewrite rule to convert CASE WHEN to IF.
|
||||
* For example:
|
||||
* CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0)
|
||||
*/
|
||||
public class CaseWhenToIf extends AbstractExpressionRewriteRule {
|
||||
public class CaseWhenToIf implements ExpressionPatternRuleFactory {
|
||||
|
||||
public static CaseWhenToIf INSTANCE = new CaseWhenToIf();
|
||||
|
||||
@Override
|
||||
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(CaseWhen.class).then(CaseWhenToIf::rewrite)
|
||||
);
|
||||
}
|
||||
|
||||
private static Expression rewrite(CaseWhen caseWhen) {
|
||||
Expression expr = caseWhen;
|
||||
if (caseWhen.getWhenClauses().size() == 1) {
|
||||
WhenClause whenClause = caseWhen.getWhenClauses().get(0);
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
|
||||
@ -30,29 +30,30 @@ import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Follow legacy planner cast agg_state combinator's children if we need cast it to another agg_state type when insert
|
||||
*/
|
||||
public class ConvertAggStateCast extends AbstractExpressionRewriteRule {
|
||||
public class ConvertAggStateCast implements ExpressionPatternRuleFactory {
|
||||
|
||||
public static ConvertAggStateCast INSTANCE = new ConvertAggStateCast();
|
||||
|
||||
@Override
|
||||
public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(Cast.class).then(ConvertAggStateCast::convert)
|
||||
);
|
||||
}
|
||||
|
||||
private static Expression convert(Cast cast) {
|
||||
Expression child = cast.child();
|
||||
DataType originalType = child.getDataType();
|
||||
DataType targetType = cast.getDataType();
|
||||
if (originalType instanceof AggStateType
|
||||
&& targetType instanceof AggStateType
|
||||
&& child instanceof StateCombinator) {
|
||||
AggStateType original = (AggStateType) originalType;
|
||||
AggStateType target = (AggStateType) targetType;
|
||||
if (original.getSubTypes().size() != target.getSubTypes().size()) {
|
||||
return processCastChild(cast, context);
|
||||
}
|
||||
if (!original.getFunctionName().equalsIgnoreCase(target.getFunctionName())) {
|
||||
return processCastChild(cast, context);
|
||||
}
|
||||
ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(child.arity());
|
||||
for (int i = 0; i < child.arity(); i++) {
|
||||
Expression newChild = TypeCoercionUtils.castIfNotSameType(child.child(i), target.getSubTypes().get(i));
|
||||
@ -66,15 +67,7 @@ public class ConvertAggStateCast extends AbstractExpressionRewriteRule {
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
child = child.withChildren(newChildren.build());
|
||||
return processCastChild(cast.withChildren(ImmutableList.of(child)), context);
|
||||
}
|
||||
return processCastChild(cast, context);
|
||||
}
|
||||
|
||||
private Expression processCastChild(Cast cast, ExpressionRewriteContext context) {
|
||||
Expression child = visit(cast.child(), context);
|
||||
if (child != cast.child()) {
|
||||
cast = cast.withChildren(ImmutableList.of(child));
|
||||
return cast.withChildren(ImmutableList.of(child));
|
||||
}
|
||||
return cast;
|
||||
}
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -34,17 +34,31 @@ import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* F: a DateTime or DateTimeV2 column
|
||||
* Date(F) > 2020-01-01 => F > 2020-01-02 00:00:00
|
||||
* Date(F) >= 2020-01-01 => F > 2020-01-01 00:00:00
|
||||
*
|
||||
*/
|
||||
public class DateFunctionRewrite extends AbstractExpressionRewriteRule {
|
||||
public class DateFunctionRewrite implements ExpressionPatternRuleFactory {
|
||||
public static DateFunctionRewrite INSTANCE = new DateFunctionRewrite();
|
||||
|
||||
@Override
|
||||
public Expression visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(EqualTo.class).then(DateFunctionRewrite::rewriteEqualTo),
|
||||
matchesType(GreaterThan.class).then(DateFunctionRewrite::rewriteGreaterThan),
|
||||
matchesType(GreaterThanEqual.class).then(DateFunctionRewrite::rewriteGreaterThanEqual),
|
||||
matchesType(LessThan.class).then(DateFunctionRewrite::rewriteLessThan),
|
||||
matchesType(LessThanEqual.class).then(DateFunctionRewrite::rewriteLessThanEqual)
|
||||
);
|
||||
}
|
||||
|
||||
private static Expression rewriteEqualTo(EqualTo equalTo) {
|
||||
if (equalTo.left() instanceof Date) {
|
||||
// V1
|
||||
if (equalTo.left().child(0).getDataType() instanceof DateTimeType
|
||||
@ -70,8 +84,7 @@ public class DateFunctionRewrite extends AbstractExpressionRewriteRule {
|
||||
return equalTo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
|
||||
private static Expression rewriteGreaterThan(GreaterThan greaterThan) {
|
||||
if (greaterThan.left() instanceof Date) {
|
||||
// V1
|
||||
if (greaterThan.left().child(0).getDataType() instanceof DateTimeType
|
||||
@ -91,8 +104,7 @@ public class DateFunctionRewrite extends AbstractExpressionRewriteRule {
|
||||
return greaterThan;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
|
||||
private static Expression rewriteGreaterThanEqual(GreaterThanEqual greaterThanEqual) {
|
||||
if (greaterThanEqual.left() instanceof Date) {
|
||||
// V1
|
||||
if (greaterThanEqual.left().child(0).getDataType() instanceof DateTimeType
|
||||
@ -111,8 +123,7 @@ public class DateFunctionRewrite extends AbstractExpressionRewriteRule {
|
||||
return greaterThanEqual;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
|
||||
private static Expression rewriteLessThan(LessThan lessThan) {
|
||||
if (lessThan.left() instanceof Date) {
|
||||
// V1
|
||||
if (lessThan.left().child(0).getDataType() instanceof DateTimeType
|
||||
@ -131,8 +142,7 @@ public class DateFunctionRewrite extends AbstractExpressionRewriteRule {
|
||||
return lessThan;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
|
||||
private static Expression rewriteLessThanEqual(LessThanEqual lessThanEqual) {
|
||||
if (lessThanEqual.left() instanceof Date) {
|
||||
// V1
|
||||
if (lessThanEqual.left().child(0).getDataType() instanceof DateTimeType
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.DigitalMasking;
|
||||
@ -26,16 +26,25 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.Left;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Right;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Convert DigitalMasking to Concat
|
||||
*/
|
||||
public class DigitalMaskingConvert extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class DigitalMaskingConvert implements ExpressionPatternRuleFactory {
|
||||
public static DigitalMaskingConvert INSTANCE = new DigitalMaskingConvert();
|
||||
|
||||
@Override
|
||||
public Expression visitDigitalMasking(DigitalMasking digitalMasking, ExpressionRewriteContext context) {
|
||||
return new Concat(new Left(digitalMasking.child(), Literal.of(3)), Literal.of("****"),
|
||||
new Right(digitalMasking.child(), Literal.of(4)));
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(DigitalMasking.class).then(digitalMasking ->
|
||||
new Concat(
|
||||
new Left(digitalMasking.child(), Literal.of(3)),
|
||||
Literal.of("****"),
|
||||
new Right(digitalMasking.child(), Literal.of(4)))
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,12 +17,13 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.LinkedHashSet;
|
||||
@ -35,16 +36,21 @@ import java.util.Set;
|
||||
* transform (a = 1) and (b > 2) and (a = 1) to (a = 1) and (b > 2)
|
||||
* transform (a = 1) or (a = 1) to (a = 1)
|
||||
*/
|
||||
public class DistinctPredicatesRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class DistinctPredicatesRule implements ExpressionPatternRuleFactory {
|
||||
public static final DistinctPredicatesRule INSTANCE = new DistinctPredicatesRule();
|
||||
|
||||
@Override
|
||||
public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(CompoundPredicate.class).then(DistinctPredicatesRule::distinct)
|
||||
);
|
||||
}
|
||||
|
||||
private static Expression distinct(CompoundPredicate expr) {
|
||||
List<Expression> extractExpressions = ExpressionUtils.extract(expr);
|
||||
Set<Expression> distinctExpressions = new LinkedHashSet<>(extractExpressions);
|
||||
if (distinctExpressions.size() != extractExpressions.size()) {
|
||||
return ExpressionUtils.combine(expr.getClass(), Lists.newArrayList(distinctExpressions));
|
||||
return ExpressionUtils.combineAsLeftDeepTree(expr.getClass(), Lists.newArrayList(distinctExpressions));
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
@ -18,21 +18,28 @@
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.collect.Multimaps;
|
||||
import com.google.common.collect.SetMultimap;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.Collection;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Extract common expr for `CompoundPredicate`.
|
||||
@ -41,42 +48,197 @@ import java.util.stream.Collectors;
|
||||
* transform (a and b) or (a and c) to a and (b or c)
|
||||
*/
|
||||
@Developing
|
||||
public class ExtractCommonFactorRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class ExtractCommonFactorRule implements ExpressionPatternRuleFactory {
|
||||
public static final ExtractCommonFactorRule INSTANCE = new ExtractCommonFactorRule();
|
||||
|
||||
@Override
|
||||
public Expression visitCompoundPredicate(CompoundPredicate expr, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(CompoundPredicate.class).then(ExtractCommonFactorRule::extractCommonFactor)
|
||||
);
|
||||
}
|
||||
|
||||
Expression rewrittenChildren = ExpressionUtils.combine(expr.getClass(), ExpressionUtils.extract(expr).stream()
|
||||
.map(predicate -> rewrite(predicate, context)).collect(ImmutableList.toImmutableList()));
|
||||
if (!(rewrittenChildren instanceof CompoundPredicate)) {
|
||||
return rewrittenChildren;
|
||||
private static Expression extractCommonFactor(CompoundPredicate originExpr) {
|
||||
// fast return
|
||||
if (!(originExpr.left() instanceof CompoundPredicate || originExpr.left() instanceof BooleanLiteral)
|
||||
&& !(originExpr.right() instanceof CompoundPredicate || originExpr.right() instanceof BooleanLiteral)) {
|
||||
return originExpr;
|
||||
}
|
||||
|
||||
CompoundPredicate compoundPredicate = (CompoundPredicate) rewrittenChildren;
|
||||
// flatten same type to a list
|
||||
// e.g. ((a and (b or c)) and c) -> [a, (b or c), c]
|
||||
List<Expression> flatten = ExpressionUtils.extract(originExpr);
|
||||
|
||||
List<List<Expression>> partitions = ExpressionUtils.extract(compoundPredicate).stream()
|
||||
.map(predicate -> predicate instanceof CompoundPredicate ? ExpressionUtils.extract(
|
||||
(CompoundPredicate) predicate) : Lists.newArrayList(predicate)).collect(Collectors.toList());
|
||||
// combine and delete some boolean literal predicate
|
||||
// e.g. (a and true) -> true
|
||||
Expression simplified = ExpressionUtils.combineAsLeftDeepTree(originExpr.getClass(), flatten);
|
||||
if (!(simplified instanceof CompoundPredicate)) {
|
||||
return simplified;
|
||||
}
|
||||
|
||||
Set<Expression> commons = partitions.stream()
|
||||
.<Set<Expression>>map(HashSet::new)
|
||||
.reduce(Sets::intersection)
|
||||
.orElse(Collections.emptySet());
|
||||
// separate two levels CompoundPredicate to partitions
|
||||
// e.g. ((a and (b or c)) and c) -> [[a], [b, c], c]
|
||||
CompoundPredicate leftDeapTree = (CompoundPredicate) simplified;
|
||||
ImmutableSet.Builder<List<Expression>> partitionsBuilder
|
||||
= ImmutableSet.builderWithExpectedSize(flatten.size());
|
||||
for (Expression onPartition : ExpressionUtils.extract(leftDeapTree)) {
|
||||
if (onPartition instanceof CompoundPredicate) {
|
||||
partitionsBuilder.add(ExpressionUtils.extract((CompoundPredicate) onPartition));
|
||||
} else {
|
||||
partitionsBuilder.add(ImmutableList.of(onPartition));
|
||||
}
|
||||
}
|
||||
Set<List<Expression>> partitions = partitionsBuilder.build();
|
||||
|
||||
List<List<Expression>> uncorrelated = partitions.stream()
|
||||
.map(predicates -> predicates.stream().filter(p -> !commons.contains(p)).collect(Collectors.toList()))
|
||||
.collect(Collectors.toList());
|
||||
Expression result = extractCommonFactors(originExpr, leftDeapTree, Utils.fastToImmutableList(partitions));
|
||||
return result;
|
||||
}
|
||||
|
||||
Expression combineUncorrelated = ExpressionUtils.combine(compoundPredicate.getClass(),
|
||||
uncorrelated.stream()
|
||||
.map(predicates -> ExpressionUtils.combine(compoundPredicate.flipType(), predicates))
|
||||
.collect(Collectors.toList()));
|
||||
private static Expression extractCommonFactors(CompoundPredicate originPredicate,
|
||||
CompoundPredicate leftDeapTreePredicate, List<List<Expression>> initPartitions) {
|
||||
// extract factor and fill into commonFactorToPartIds
|
||||
// e.g.
|
||||
// originPredicate: (a and (b and c)) and (b or c)
|
||||
// leftDeapTreePredicate: ((a and b) and c) and (b or c)
|
||||
// initPartitions: [[a], [b], [c], [b, c]]
|
||||
//
|
||||
// -> commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}.
|
||||
// so we can know `b` and `c` is a common factors
|
||||
SetMultimap<Expression, Integer> commonFactorToPartIds = Multimaps.newSetMultimap(
|
||||
Maps.newLinkedHashMap(), LinkedHashSet::new
|
||||
);
|
||||
int originExpressionNum = 0;
|
||||
int partId = 0;
|
||||
for (List<Expression> partition : initPartitions) {
|
||||
for (Expression expression : partition) {
|
||||
commonFactorToPartIds.put(expression, partId);
|
||||
originExpressionNum++;
|
||||
}
|
||||
partId++;
|
||||
}
|
||||
|
||||
List<Expression> finalCompound = Lists.newArrayList(commons);
|
||||
finalCompound.add(combineUncorrelated);
|
||||
// commonFactorToPartIds = {a: [0], b: [1, 3], c: [2, 3]}
|
||||
//
|
||||
// -> reverse key value of commonFactorToPartIds and remove intersecting partition
|
||||
//
|
||||
// -> 1. reverse: {[0]: [a], [1, 3]: [b], [2, 3]: [c]}
|
||||
// -> 2. sort by key size desc: {[1, 3]: [b], [2, 3]: [c], [0]: [a]}
|
||||
// -> 3. remove intersection partition: {[1, 3]: [b], [2]: [c], [0]: [a]},
|
||||
// because first part and second part intersect by partition 3
|
||||
LinkedHashMap<Set<Integer>, Set<Expression>> commonFactorPartitions
|
||||
= partitionByMostCommonFactors(commonFactorToPartIds);
|
||||
|
||||
return ExpressionUtils.combine(compoundPredicate.flipType(), finalCompound);
|
||||
int extractedExpressionNum = 0;
|
||||
for (Set<Expression> exprs : commonFactorPartitions.values()) {
|
||||
extractedExpressionNum += exprs.size();
|
||||
}
|
||||
|
||||
// no any common factor
|
||||
if (commonFactorPartitions.entrySet().iterator().next().getKey().size() <= 1
|
||||
&& !(originPredicate.getWidth() > leftDeapTreePredicate.getWidth())
|
||||
&& originExpressionNum <= extractedExpressionNum) {
|
||||
// this condition is important because it can avoid deap loop:
|
||||
// origin originExpr: A = 1 and (B > 0 and B < 10)
|
||||
// after ExtractCommonFactorRule: (A = 1 and B > 0) and (B < 10) (left deap tree)
|
||||
// after SimplifyRange: A = 1 and (B > 0 and B < 10) (right deap tree)
|
||||
return originPredicate;
|
||||
}
|
||||
|
||||
// now we can do extract common factors for each part:
|
||||
// originPredicate: (a and (b and c)) and (b or c)
|
||||
// leftDeapTreePredicate: ((a and b) and c) and (b or c)
|
||||
// initPartitions: [[a], [b], [c], [b, c]]
|
||||
// commonFactorPartitions: {[1, 3]: [b], [0]: [a]}
|
||||
//
|
||||
// -> extractedExprs: [
|
||||
// b or (false and c) = b,
|
||||
// a,
|
||||
// c
|
||||
// ]
|
||||
//
|
||||
// -> result: (b or c) and a and c
|
||||
ImmutableList.Builder<Expression> extractedExprs
|
||||
= ImmutableList.builderWithExpectedSize(commonFactorPartitions.size());
|
||||
for (Entry<Set<Integer>, Set<Expression>> kv : commonFactorPartitions.entrySet()) {
|
||||
Expression extracted = doExtractCommonFactors(
|
||||
leftDeapTreePredicate, initPartitions, kv.getKey(), kv.getValue()
|
||||
);
|
||||
extractedExprs.add(extracted);
|
||||
}
|
||||
|
||||
// combine and eliminate some boolean literal predicate
|
||||
return ExpressionUtils.combineAsLeftDeepTree(leftDeapTreePredicate.getClass(), extractedExprs.build());
|
||||
}
|
||||
|
||||
private static Expression doExtractCommonFactors(
|
||||
CompoundPredicate originPredicate,
|
||||
List<List<Expression>> partitions, Set<Integer> partitionIds, Set<Expression> commonFactors) {
|
||||
ImmutableList.Builder<Expression> uncorrelatedExprPartitionsBuilder
|
||||
= ImmutableList.builderWithExpectedSize(partitionIds.size());
|
||||
for (Integer partitionId : partitionIds) {
|
||||
List<Expression> partition = partitions.get(partitionId);
|
||||
ImmutableSet.Builder<Expression> uncorrelatedBuilder
|
||||
= ImmutableSet.builderWithExpectedSize(partition.size());
|
||||
for (Expression exprOfPart : partition) {
|
||||
if (!commonFactors.contains(exprOfPart)) {
|
||||
uncorrelatedBuilder.add(exprOfPart);
|
||||
}
|
||||
}
|
||||
|
||||
Set<Expression> uncorrelated = uncorrelatedBuilder.build();
|
||||
Expression partitionWithoutCommonFactor
|
||||
= ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), uncorrelated);
|
||||
if (partitionWithoutCommonFactor instanceof CompoundPredicate) {
|
||||
partitionWithoutCommonFactor = extractCommonFactor((CompoundPredicate) partitionWithoutCommonFactor);
|
||||
}
|
||||
uncorrelatedExprPartitionsBuilder.add(partitionWithoutCommonFactor);
|
||||
}
|
||||
|
||||
ImmutableList<Expression> uncorrelatedExprPartitions = uncorrelatedExprPartitionsBuilder.build();
|
||||
ImmutableList.Builder<Expression> allExprs = ImmutableList.builderWithExpectedSize(commonFactors.size() + 1);
|
||||
allExprs.addAll(commonFactors);
|
||||
|
||||
Expression combineUncorrelatedExpr = ExpressionUtils.combineAsLeftDeepTree(
|
||||
originPredicate.getClass(), uncorrelatedExprPartitions);
|
||||
allExprs.add(combineUncorrelatedExpr);
|
||||
|
||||
Expression result = ExpressionUtils.combineAsLeftDeepTree(originPredicate.flipType(), allExprs.build());
|
||||
return result;
|
||||
}
|
||||
|
||||
private static LinkedHashMap<Set<Integer>, Set<Expression>> partitionByMostCommonFactors(
|
||||
SetMultimap<Expression, Integer> commonFactorToPartIds) {
|
||||
SetMultimap<Set<Integer>, Expression> partWithCommonFactors = Multimaps.newSetMultimap(
|
||||
Maps.newLinkedHashMap(), LinkedHashSet::new
|
||||
);
|
||||
|
||||
for (Entry<Expression, Collection<Integer>> factorToId : commonFactorToPartIds.asMap().entrySet()) {
|
||||
partWithCommonFactors.put((Set<Integer>) factorToId.getValue(), factorToId.getKey());
|
||||
}
|
||||
|
||||
List<Set<Integer>> sortedPartitionIdHasCommonFactor = Lists.newArrayList(partWithCommonFactors.keySet());
|
||||
// place the most common factor at the head of this list
|
||||
sortedPartitionIdHasCommonFactor.sort((p1, p2) -> p2.size() - p1.size());
|
||||
|
||||
LinkedHashMap<Set<Integer>, Set<Expression>> shouldExtractFactors = Maps.newLinkedHashMap();
|
||||
|
||||
Set<Integer> allocatedPartitions = Sets.newLinkedHashSet();
|
||||
for (Set<Integer> originMostCommonFactorPartitions : sortedPartitionIdHasCommonFactor) {
|
||||
ImmutableSet.Builder<Integer> notAllocatePartitions = ImmutableSet.builderWithExpectedSize(
|
||||
originMostCommonFactorPartitions.size());
|
||||
for (Integer partId : originMostCommonFactorPartitions) {
|
||||
if (allocatedPartitions.add(partId)) {
|
||||
notAllocatePartitions.add(partId);
|
||||
}
|
||||
}
|
||||
|
||||
Set<Integer> mostCommonFactorPartitions = notAllocatePartitions.build();
|
||||
if (!mostCommonFactorPartitions.isEmpty()) {
|
||||
Set<Expression> commonFactors = partWithCommonFactors.get(originMostCommonFactorPartitions);
|
||||
shouldExtractFactors.put(mostCommonFactorPartitions, commonFactors);
|
||||
}
|
||||
}
|
||||
|
||||
return shouldExtractFactors;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,24 +17,46 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Constant evaluation of an expression.
|
||||
*/
|
||||
public class FoldConstantRule extends AbstractExpressionRewriteRule {
|
||||
public class FoldConstantRule implements ExpressionPatternRuleFactory {
|
||||
|
||||
public static final FoldConstantRule INSTANCE = new FoldConstantRule();
|
||||
|
||||
private static final ExpressionBottomUpRewriter FULL_FOLD_REWRITER = ExpressionRewrite.bottomUp(
|
||||
FoldConstantRuleOnFE.VISITOR_INSTANCE,
|
||||
FoldConstantRuleOnBE.INSTANCE
|
||||
);
|
||||
|
||||
/** evaluate by pattern match */
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.<ExpressionPatternMatcher<? extends Expression>>builder()
|
||||
.addAll(FoldConstantRuleOnFE.PATTERN_MATCH_INSTANCE.buildRules())
|
||||
.addAll(FoldConstantRuleOnBE.INSTANCE.buildRules())
|
||||
.build();
|
||||
}
|
||||
|
||||
/** evaluate by visitor */
|
||||
public static Expression evaluate(Expression expr, ExpressionRewriteContext ctx) {
|
||||
if (ctx.cascadesContext != null
|
||||
&& ctx.cascadesContext.getConnectContext() != null
|
||||
&& ctx.cascadesContext.getConnectContext().getSessionVariable().isEnableFoldConstantByBe()) {
|
||||
return new FoldConstantRuleOnBE().rewrite(expr, ctx);
|
||||
return FULL_FOLD_REWRITER.rewrite(expr, ctx);
|
||||
} else {
|
||||
return FoldConstantRuleOnFE.VISITOR_INSTANCE.rewrite(expr, ctx);
|
||||
}
|
||||
return FoldConstantRuleOnFE.INSTANCE.rewrite(expr, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
@ -27,8 +27,9 @@ import org.apache.doris.common.UserException;
|
||||
import org.apache.doris.common.util.DebugUtil;
|
||||
import org.apache.doris.common.util.TimeUtils;
|
||||
import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -55,6 +56,7 @@ import org.apache.doris.thrift.TQueryGlobals;
|
||||
import org.apache.doris.thrift.TQueryOptions;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Maps;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
@ -73,24 +75,38 @@ import java.util.concurrent.TimeUnit;
|
||||
/**
|
||||
* Constant evaluation of an expression.
|
||||
*/
|
||||
public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
|
||||
public class FoldConstantRuleOnBE implements ExpressionPatternRuleFactory {
|
||||
|
||||
public static final FoldConstantRuleOnBE INSTANCE = new FoldConstantRuleOnBE();
|
||||
private static final Logger LOG = LogManager.getLogger(FoldConstantRuleOnBE.class);
|
||||
private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator();
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expression, ExpressionRewriteContext ctx) {
|
||||
expression = FoldConstantRuleOnFE.INSTANCE.rewrite(expression, ctx);
|
||||
return foldByBE(expression, ctx);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
root(Expression.class)
|
||||
.whenCtx(FoldConstantRuleOnBE::isEnableFoldByBe)
|
||||
.thenApply(FoldConstantRuleOnBE::foldByBE)
|
||||
);
|
||||
}
|
||||
|
||||
private Expression foldByBE(Expression root, ExpressionRewriteContext context) {
|
||||
public static boolean isEnableFoldByBe(ExpressionMatchingContext<Expression> ctx) {
|
||||
return ctx.cascadesContext != null
|
||||
&& ctx.cascadesContext.getConnectContext() != null
|
||||
&& ctx.cascadesContext.getConnectContext().getSessionVariable().isEnableFoldConstantByBe();
|
||||
}
|
||||
|
||||
/** foldByBE */
|
||||
public static Expression foldByBE(ExpressionMatchingContext<Expression> context) {
|
||||
IdGenerator<ExprId> idGenerator = ExprId.createGenerator();
|
||||
|
||||
Expression root = context.expr;
|
||||
Map<String, Expression> constMap = Maps.newHashMap();
|
||||
Map<String, TExpr> staleConstTExprMap = Maps.newHashMap();
|
||||
Expression rootWithoutAlias = root;
|
||||
if (root instanceof Alias) {
|
||||
rootWithoutAlias = ((Alias) root).child();
|
||||
}
|
||||
collectConst(rootWithoutAlias, constMap, staleConstTExprMap);
|
||||
collectConst(rootWithoutAlias, constMap, staleConstTExprMap, idGenerator);
|
||||
if (constMap.isEmpty()) {
|
||||
return root;
|
||||
}
|
||||
@ -103,7 +119,8 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
|
||||
return root;
|
||||
}
|
||||
|
||||
private Expression replace(Expression root, Map<String, Expression> constMap, Map<String, Expression> resultMap) {
|
||||
private static Expression replace(
|
||||
Expression root, Map<String, Expression> constMap, Map<String, Expression> resultMap) {
|
||||
for (Entry<String, Expression> entry : constMap.entrySet()) {
|
||||
if (entry.getValue().equals(root)) {
|
||||
return resultMap.get(entry.getKey());
|
||||
@ -121,7 +138,8 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
|
||||
return hasNewChildren ? root.withChildren(newChildren) : root;
|
||||
}
|
||||
|
||||
private void collectConst(Expression expr, Map<String, Expression> constMap, Map<String, TExpr> tExprMap) {
|
||||
private static void collectConst(Expression expr, Map<String, Expression> constMap,
|
||||
Map<String, TExpr> tExprMap, IdGenerator<ExprId> idGenerator) {
|
||||
if (expr.isConstant()) {
|
||||
// Do not constant fold cast(null as dataType) because we cannot preserve the
|
||||
// cast-to-types and that can lead to query failures, e.g., CTAS
|
||||
@ -157,13 +175,13 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
|
||||
} else {
|
||||
for (int i = 0; i < expr.children().size(); i++) {
|
||||
final Expression child = expr.children().get(i);
|
||||
collectConst(child, constMap, tExprMap);
|
||||
collectConst(child, constMap, tExprMap, idGenerator);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if sleep(5) will cause rpc timeout
|
||||
private boolean skipSleepFunction(Expression expr) {
|
||||
private static boolean skipSleepFunction(Expression expr) {
|
||||
if (expr instanceof Sleep) {
|
||||
Expression param = expr.child(0);
|
||||
if (param instanceof Cast) {
|
||||
@ -176,7 +194,7 @@ public class FoldConstantRuleOnBE extends AbstractExpressionRewriteRule {
|
||||
return false;
|
||||
}
|
||||
|
||||
private Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMap,
|
||||
private static Map<String, Expression> evalOnBE(Map<String, Map<String, TExpr>> paramMap,
|
||||
Map<String, Expression> constMap, ConnectContext context) {
|
||||
|
||||
Map<String, Expression> resultMap = new HashMap<>();
|
||||
|
||||
@ -22,7 +22,13 @@ import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.cluster.ClusterNamespace;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionListenerMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionTraverseListener;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionTraverseListenerFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
|
||||
@ -80,6 +86,8 @@ import org.apache.doris.qe.GlobalVariable;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.commons.codec.digest.DigestUtils;
|
||||
|
||||
@ -87,13 +95,78 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.function.BiFunction;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* evaluate an expression on fe.
|
||||
*/
|
||||
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule
|
||||
implements ExpressionPatternRuleFactory, ExpressionTraverseListenerFactory {
|
||||
|
||||
public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE();
|
||||
public static final FoldConstantRuleOnFE VISITOR_INSTANCE = new FoldConstantRuleOnFE(true);
|
||||
public static final FoldConstantRuleOnFE PATTERN_MATCH_INSTANCE = new FoldConstantRuleOnFE(false);
|
||||
|
||||
// record whether current expression is in an aggregate function with distinct,
|
||||
// if is, we will skip to fold constant
|
||||
private static final ListenAggDistinct LISTEN_AGG_DISTINCT = new ListenAggDistinct();
|
||||
private static final CheckWhetherUnderAggDistinct NOT_UNDER_AGG_DISTINCT = new CheckWhetherUnderAggDistinct();
|
||||
|
||||
private final boolean deepRewrite;
|
||||
|
||||
public FoldConstantRuleOnFE(boolean deepRewrite) {
|
||||
this.deepRewrite = deepRewrite;
|
||||
}
|
||||
|
||||
public static Expression evaluate(Expression expression, ExpressionRewriteContext expressionRewriteContext) {
|
||||
return VISITOR_INSTANCE.rewrite(expression, expressionRewriteContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ExpressionListenerMatcher<? extends Expression>> buildListeners() {
|
||||
return ImmutableList.of(
|
||||
listenerType(AggregateFunction.class)
|
||||
.when(AggregateFunction::isDistinct)
|
||||
.then(LISTEN_AGG_DISTINCT.as()),
|
||||
|
||||
listenerType(AggregateExpression.class)
|
||||
.when(AggregateExpression::isDistinct)
|
||||
.then(LISTEN_AGG_DISTINCT.as())
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matches(EncryptKeyRef.class, this::visitEncryptKeyRef),
|
||||
matches(EqualTo.class, this::visitEqualTo),
|
||||
matches(GreaterThan.class, this::visitGreaterThan),
|
||||
matches(GreaterThanEqual.class, this::visitGreaterThanEqual),
|
||||
matches(LessThan.class, this::visitLessThan),
|
||||
matches(LessThanEqual.class, this::visitLessThanEqual),
|
||||
matches(NullSafeEqual.class, this::visitNullSafeEqual),
|
||||
matches(Not.class, this::visitNot),
|
||||
matches(Database.class, this::visitDatabase),
|
||||
matches(CurrentUser.class, this::visitCurrentUser),
|
||||
matches(CurrentCatalog.class, this::visitCurrentCatalog),
|
||||
matches(User.class, this::visitUser),
|
||||
matches(ConnectionId.class, this::visitConnectionId),
|
||||
matches(And.class, this::visitAnd),
|
||||
matches(Or.class, this::visitOr),
|
||||
matches(Cast.class, this::visitCast),
|
||||
matches(BoundFunction.class, this::visitBoundFunction),
|
||||
matches(BinaryArithmetic.class, this::visitBinaryArithmetic),
|
||||
matches(CaseWhen.class, this::visitCaseWhen),
|
||||
matches(If.class, this::visitIf),
|
||||
matches(InPredicate.class, this::visitInPredicate),
|
||||
matches(IsNull.class, this::visitIsNull),
|
||||
matches(TimestampArithmetic.class, this::visitTimestampArithmetic),
|
||||
matches(Password.class, this::visitPassword),
|
||||
matches(Array.class, this::visitArray),
|
||||
matches(Date.class, this::visitDate),
|
||||
matches(Version.class, this::visitVersion)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
@ -253,7 +326,7 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
List<Expression> nonTrueLiteral = Lists.newArrayList();
|
||||
int nullCount = 0;
|
||||
for (Expression e : and.children()) {
|
||||
e = e.accept(this, context);
|
||||
e = deepRewrite ? e.accept(this, context) : e;
|
||||
if (BooleanLiteral.FALSE.equals(e)) {
|
||||
return BooleanLiteral.FALSE;
|
||||
} else if (e instanceof NullLiteral) {
|
||||
@ -294,7 +367,7 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
List<Expression> nonFalseLiteral = Lists.newArrayList();
|
||||
int nullCount = 0;
|
||||
for (Expression e : or.children()) {
|
||||
e = e.accept(this, context);
|
||||
e = deepRewrite ? e.accept(this, context) : e;
|
||||
if (BooleanLiteral.TRUE.equals(e)) {
|
||||
return BooleanLiteral.TRUE;
|
||||
} else if (e instanceof NullLiteral) {
|
||||
@ -412,9 +485,13 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
Expression defaultResult = caseWhen.getDefaultValue().isPresent()
|
||||
? rewrite(caseWhen.getDefaultValue().get(), context)
|
||||
: null;
|
||||
Expression defaultResult = null;
|
||||
if (caseWhen.getDefaultValue().isPresent()) {
|
||||
defaultResult = caseWhen.getDefaultValue().get();
|
||||
if (deepRewrite) {
|
||||
defaultResult = rewrite(defaultResult, context);
|
||||
}
|
||||
}
|
||||
if (foundNewDefault) {
|
||||
defaultResult = newDefault;
|
||||
}
|
||||
@ -537,28 +614,83 @@ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
return new StringLiteral(GlobalVariable.version);
|
||||
}
|
||||
|
||||
private <E> E rewriteChildren(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return (E) super.visit(expr, ctx);
|
||||
}
|
||||
|
||||
private boolean allArgsIsAllLiteral(Expression expression) {
|
||||
return ExpressionUtils.isAllLiteral(expression.getArguments());
|
||||
}
|
||||
|
||||
private boolean argsHasNullLiteral(Expression expression) {
|
||||
return ExpressionUtils.hasNullLiteral(expression.getArguments());
|
||||
private <E extends Expression> E rewriteChildren(E expr, ExpressionRewriteContext context) {
|
||||
if (!deepRewrite) {
|
||||
return expr;
|
||||
}
|
||||
switch (expr.arity()) {
|
||||
case 1: {
|
||||
Expression originChild = expr.child(0);
|
||||
Expression newChild = originChild.accept(this, context);
|
||||
return (originChild != newChild) ? (E) expr.withChildren(ImmutableList.of(newChild)) : expr;
|
||||
}
|
||||
case 2: {
|
||||
Expression originLeft = expr.child(0);
|
||||
Expression newLeft = originLeft.accept(this, context);
|
||||
Expression originRight = expr.child(1);
|
||||
Expression newRight = originRight.accept(this, context);
|
||||
return (originLeft != newLeft || originRight != newRight)
|
||||
? (E) expr.withChildren(ImmutableList.of(newLeft, newRight))
|
||||
: expr;
|
||||
}
|
||||
case 0: {
|
||||
return expr;
|
||||
}
|
||||
default: {
|
||||
boolean hasNewChildren = false;
|
||||
Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(expr.arity());
|
||||
for (Expression child : expr.children()) {
|
||||
Expression newChild = child.accept(this, context);
|
||||
if (newChild != child) {
|
||||
hasNewChildren = true;
|
||||
}
|
||||
newChildren.add(newChild);
|
||||
}
|
||||
return hasNewChildren ? (E) expr.withChildren(newChildren.build()) : expr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Optional<Expression> preProcess(Expression expression) {
|
||||
if (expression instanceof AggregateFunction || expression instanceof TableGeneratingFunction) {
|
||||
return Optional.of(expression);
|
||||
}
|
||||
if (expression instanceof PropagateNullable && argsHasNullLiteral(expression)) {
|
||||
if (expression instanceof PropagateNullable && ExpressionUtils.hasNullLiteral(expression.getArguments())) {
|
||||
return Optional.of(new NullLiteral(expression.getDataType()));
|
||||
}
|
||||
if (!allArgsIsAllLiteral(expression)) {
|
||||
if (!ExpressionUtils.isAllLiteral(expression.getArguments())) {
|
||||
return Optional.of(expression);
|
||||
}
|
||||
return Optional.empty();
|
||||
}
|
||||
|
||||
private static class ListenAggDistinct implements ExpressionTraverseListener<Expression> {
|
||||
@Override
|
||||
public void onEnter(ExpressionMatchingContext<Expression> context) {
|
||||
context.cascadesContext.incrementDistinctAggLevel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onExit(ExpressionMatchingContext<Expression> context, Expression rewritten) {
|
||||
context.cascadesContext.decrementDistinctAggLevel();
|
||||
}
|
||||
}
|
||||
|
||||
private static class CheckWhetherUnderAggDistinct implements Predicate<ExpressionMatchingContext<Expression>> {
|
||||
@Override
|
||||
public boolean test(ExpressionMatchingContext<Expression> context) {
|
||||
return context.cascadesContext.getDistinctAggLevel() == 0;
|
||||
}
|
||||
|
||||
public <E extends Expression> Predicate<ExpressionMatchingContext<E>> as() {
|
||||
return (Predicate) this;
|
||||
}
|
||||
}
|
||||
|
||||
private <E extends Expression> ExpressionPatternMatcher<? extends Expression> matches(
|
||||
Class<E> clazz, BiFunction<E, ExpressionRewriteContext, Expression> visitMethod) {
|
||||
return matchesType(clazz)
|
||||
.whenCtx(NOT_UNDER_AGG_DISTINCT.as())
|
||||
.thenApply(ctx -> visitMethod.apply(ctx.expr, ctx.rewriteContext));
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,13 +17,14 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@ -31,25 +32,32 @@ import java.util.Set;
|
||||
* Deduplicate InPredicate, For example:
|
||||
* where A in (x, x) ==> where A in (x)
|
||||
*/
|
||||
public class InPredicateDedup extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static InPredicateDedup INSTANCE = new InPredicateDedup();
|
||||
public class InPredicateDedup implements ExpressionPatternRuleFactory {
|
||||
public static final InPredicateDedup INSTANCE = new InPredicateDedup();
|
||||
|
||||
@Override
|
||||
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(InPredicate.class).then(InPredicateDedup::dedup)
|
||||
);
|
||||
}
|
||||
|
||||
/** dedup */
|
||||
public static Expression dedup(InPredicate inPredicate) {
|
||||
// In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options.
|
||||
// It takes a long time to apply this rule. So set a threshold for the max number.
|
||||
if (inPredicate.getOptions().size() > 200) {
|
||||
int optionSize = inPredicate.getOptions().size();
|
||||
if (optionSize > 200) {
|
||||
return inPredicate;
|
||||
}
|
||||
Set<Expression> dedupExpr = new HashSet<>();
|
||||
List<Expression> newOptions = new ArrayList<>();
|
||||
ImmutableSet.Builder<Expression> newOptionsBuilder = ImmutableSet.builderWithExpectedSize(inPredicate.arity());
|
||||
for (Expression option : inPredicate.getOptions()) {
|
||||
if (dedupExpr.contains(option)) {
|
||||
continue;
|
||||
}
|
||||
dedupExpr.add(option);
|
||||
newOptions.add(option);
|
||||
newOptionsBuilder.add(option);
|
||||
}
|
||||
|
||||
Set<Expression> newOptions = newOptionsBuilder.build();
|
||||
if (newOptions.size() == optionSize) {
|
||||
return inPredicate;
|
||||
}
|
||||
return new InPredicate(inPredicate.getCompareExpr(), newOptions);
|
||||
}
|
||||
|
||||
@ -17,12 +17,14 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@ -36,17 +38,16 @@ import java.util.List;
|
||||
* NOTICE: it's related with `SimplifyRange`.
|
||||
* They are same processes, so must change synchronously.
|
||||
*/
|
||||
public class InPredicateToEqualToRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static InPredicateToEqualToRule INSTANCE = new InPredicateToEqualToRule();
|
||||
public class InPredicateToEqualToRule implements ExpressionPatternRuleFactory {
|
||||
public static final InPredicateToEqualToRule INSTANCE = new InPredicateToEqualToRule();
|
||||
|
||||
@Override
|
||||
public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
|
||||
Expression left = inPredicate.getCompareExpr();
|
||||
List<Expression> right = inPredicate.getOptions();
|
||||
if (right.size() != 1) {
|
||||
return new InPredicate(left.accept(this, context), right);
|
||||
}
|
||||
return new EqualTo(left.accept(this, context), right.get(0).accept(this, context));
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(InPredicate.class)
|
||||
.when(in -> in.getOptions().size() == 1)
|
||||
.then(in -> new EqualTo(in.getCompareExpr(), in.getOptions().get(0))
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,22 +17,31 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Normalizes binary predicates of the form 'expr' op 'slot' so that the slot is on the left-hand side.
|
||||
* For example:
|
||||
* 5 > id -> id < 5
|
||||
*/
|
||||
public class NormalizeBinaryPredicatesRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static NormalizeBinaryPredicatesRule INSTANCE = new NormalizeBinaryPredicatesRule();
|
||||
public class NormalizeBinaryPredicatesRule implements ExpressionPatternRuleFactory {
|
||||
public static final NormalizeBinaryPredicatesRule INSTANCE = new NormalizeBinaryPredicatesRule();
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate expr, ExpressionRewriteContext context) {
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(ComparisonPredicate.class).then(NormalizeBinaryPredicatesRule::normalize)
|
||||
);
|
||||
}
|
||||
|
||||
public static Expression normalize(ComparisonPredicate expr) {
|
||||
return expr.left().isConstant() && !expr.right().isConstant() ? expr.commute() : expr;
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,31 +17,34 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.IsNull;
|
||||
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* convert "<=>" to "=", if any side is not nullable
|
||||
* convert "A <=> null" to "A is null"
|
||||
*/
|
||||
public class NullSafeEqualToEqual extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
|
||||
ExpressionRewriteRule<ExpressionRewriteContext> {
|
||||
public class NullSafeEqualToEqual implements ExpressionPatternRuleFactory {
|
||||
public static final NullSafeEqualToEqual INSTANCE = new NullSafeEqualToEqual();
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return expr.accept(this, null);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(NullSafeEqual.class).then(NullSafeEqualToEqual::rewrite)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitNullSafeEqual(NullSafeEqual nullSafeEqual, ExpressionRewriteContext ctx) {
|
||||
private static Expression rewrite(NullSafeEqual nullSafeEqual) {
|
||||
if (nullSafeEqual.left() instanceof NullLiteral) {
|
||||
if (nullSafeEqual.right().nullable()) {
|
||||
return new IsNull(nullSafeEqual.right());
|
||||
|
||||
@ -82,7 +82,7 @@ public class OneListPartitionEvaluator
|
||||
expr = super.visit(expr, context);
|
||||
if (!(expr instanceof Literal)) {
|
||||
// just forward to fold constant rule
|
||||
return expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
|
||||
return FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext);
|
||||
}
|
||||
return expr;
|
||||
}
|
||||
|
||||
@ -92,7 +92,7 @@ public class OneRangePartitionEvaluator
|
||||
|
||||
/** OneRangePartitionEvaluator */
|
||||
public OneRangePartitionEvaluator(long partitionId, List<Slot> partitionSlots,
|
||||
RangePartitionItem partitionItem, CascadesContext cascadesContext) {
|
||||
RangePartitionItem partitionItem, CascadesContext cascadesContext, int expandThreshold) {
|
||||
this.partitionId = partitionId;
|
||||
this.partitionSlots = Objects.requireNonNull(partitionSlots, "partitionSlots cannot be null");
|
||||
this.partitionItem = Objects.requireNonNull(partitionItem, "partitionItem cannot be null");
|
||||
@ -103,41 +103,46 @@ public class OneRangePartitionEvaluator
|
||||
this.lowers = toNereidsLiterals(range.lowerEndpoint());
|
||||
this.uppers = toNereidsLiterals(range.upperEndpoint());
|
||||
|
||||
PartitionRangeExpander expander = new PartitionRangeExpander();
|
||||
this.partitionSlotTypes = expander.computePartitionSlotTypes(lowers, uppers);
|
||||
this.slotToType = Maps.newHashMapWithExpectedSize(16);
|
||||
for (int i = 0; i < partitionSlots.size(); i++) {
|
||||
slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i));
|
||||
this.partitionSlotTypes = PartitionRangeExpander.computePartitionSlotTypes(lowers, uppers);
|
||||
|
||||
if (partitionSlots.size() == 1) {
|
||||
// fast path
|
||||
Slot partSlot = partitionSlots.get(0);
|
||||
this.slotToType = ImmutableMap.of(partSlot, partitionSlotTypes.get(0));
|
||||
this.partitionSlotContainsNull
|
||||
= ImmutableMap.of(partSlot, range.lowerEndpoint().getKeys().get(0).isMinValue());
|
||||
} else {
|
||||
// slow path
|
||||
this.slotToType = Maps.newHashMap();
|
||||
for (int i = 0; i < partitionSlots.size(); i++) {
|
||||
slotToType.put(partitionSlots.get(i), partitionSlotTypes.get(i));
|
||||
}
|
||||
|
||||
this.partitionSlotContainsNull = Maps.newHashMap();
|
||||
for (int i = 0; i < partitionSlots.size(); i++) {
|
||||
Slot slot = partitionSlots.get(i);
|
||||
if (!slot.nullable()) {
|
||||
partitionSlotContainsNull.put(slot, false);
|
||||
continue;
|
||||
}
|
||||
PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
|
||||
boolean maybeNull;
|
||||
switch (partitionSlotType) {
|
||||
case CONST:
|
||||
case RANGE:
|
||||
maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue();
|
||||
break;
|
||||
case OTHER:
|
||||
maybeNull = true;
|
||||
break;
|
||||
default:
|
||||
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
|
||||
}
|
||||
partitionSlotContainsNull.put(slot, maybeNull);
|
||||
}
|
||||
}
|
||||
|
||||
this.partitionSlotContainsNull = Maps.newHashMapWithExpectedSize(16);
|
||||
for (int i = 0; i < partitionSlots.size(); i++) {
|
||||
Slot slot = partitionSlots.get(i);
|
||||
if (!slot.nullable()) {
|
||||
partitionSlotContainsNull.put(slot, false);
|
||||
continue;
|
||||
}
|
||||
PartitionSlotType partitionSlotType = partitionSlotTypes.get(i);
|
||||
boolean maybeNull = false;
|
||||
switch (partitionSlotType) {
|
||||
case CONST:
|
||||
case RANGE:
|
||||
maybeNull = range.lowerEndpoint().getKeys().get(i).isMinValue();
|
||||
break;
|
||||
case OTHER:
|
||||
maybeNull = true;
|
||||
break;
|
||||
default:
|
||||
throw new AnalysisException("Unknown partition slot type: " + partitionSlotType);
|
||||
}
|
||||
partitionSlotContainsNull.put(slot, maybeNull);
|
||||
}
|
||||
|
||||
int expandThreshold = cascadesContext.getAndCacheSessionVariable(
|
||||
"partitionPruningExpandThreshold",
|
||||
10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold);
|
||||
|
||||
List<List<Expression>> expandInputs = expander.tryExpandRange(
|
||||
List<List<Expression>> expandInputs = PartitionRangeExpander.tryExpandRange(
|
||||
partitionSlots, lowers, uppers, partitionSlotTypes, expandThreshold);
|
||||
// after expand range, we will get 2 dimension list like list:
|
||||
// part_col1: [1], part_col2:[4, 5, 6], we should combine it to
|
||||
@ -451,10 +456,13 @@ public class OneRangePartitionEvaluator
|
||||
|
||||
private EvaluateRangeResult evaluateChildrenThenThis(Expression expr, EvaluateRangeInput context) {
|
||||
// evaluate children
|
||||
List<Expression> newChildren = new ArrayList<>();
|
||||
List<EvaluateRangeResult> childrenResults = new ArrayList<>();
|
||||
List<Expression> children = expr.children();
|
||||
ImmutableList.Builder<Expression> newChildren = ImmutableList.builderWithExpectedSize(children.size());
|
||||
List<EvaluateRangeResult> childrenResults = new ArrayList<>(children.size());
|
||||
boolean hasNewChildren = false;
|
||||
for (Expression child : expr.children()) {
|
||||
|
||||
for (int i = 0; i < children.size(); i++) {
|
||||
Expression child = children.get(i);
|
||||
EvaluateRangeResult childResult = child.accept(this, context);
|
||||
if (childResult.result != child) {
|
||||
hasNewChildren = true;
|
||||
@ -463,11 +471,11 @@ public class OneRangePartitionEvaluator
|
||||
newChildren.add(childResult.result);
|
||||
}
|
||||
if (hasNewChildren) {
|
||||
expr = expr.withChildren(newChildren);
|
||||
expr = expr.withChildren(newChildren.build());
|
||||
}
|
||||
|
||||
// evaluate this
|
||||
expr = expr.accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
|
||||
expr = FoldConstantRuleOnFE.evaluate(expr, expressionRewriteContext);
|
||||
return new EvaluateRangeResult(expr, context.defaultColumnRanges, childrenResults);
|
||||
}
|
||||
|
||||
@ -575,9 +583,28 @@ public class OneRangePartitionEvaluator
|
||||
}
|
||||
|
||||
private List<Literal> toNereidsLiterals(PartitionKey partitionKey) {
|
||||
List<Literal> literals = Lists.newArrayListWithCapacity(partitionKey.getKeys().size());
|
||||
for (int i = 0; i < partitionKey.getKeys().size(); i++) {
|
||||
LiteralExpr literalExpr = partitionKey.getKeys().get(i);
|
||||
if (partitionKey.getKeys().size() == 1) {
|
||||
// fast path
|
||||
return toSingleNereidsLiteral(partitionKey);
|
||||
}
|
||||
|
||||
// slow path
|
||||
return toMultiNereidsLiterals(partitionKey);
|
||||
}
|
||||
|
||||
private List<Literal> toSingleNereidsLiteral(PartitionKey partitionKey) {
|
||||
List<LiteralExpr> keys = partitionKey.getKeys();
|
||||
LiteralExpr literalExpr = keys.get(0);
|
||||
PrimitiveType primitiveType = partitionKey.getTypes().get(0);
|
||||
Type type = Type.fromPrimitiveType(primitiveType);
|
||||
return ImmutableList.of(Literal.fromLegacyLiteral(literalExpr, type));
|
||||
}
|
||||
|
||||
private List<Literal> toMultiNereidsLiterals(PartitionKey partitionKey) {
|
||||
List<LiteralExpr> keys = partitionKey.getKeys();
|
||||
List<Literal> literals = Lists.newArrayListWithCapacity(keys.size());
|
||||
for (int i = 0; i < keys.size(); i++) {
|
||||
LiteralExpr literalExpr = keys.get(i);
|
||||
PrimitiveType primitiveType = partitionKey.getTypes().get(i);
|
||||
Type type = Type.fromPrimitiveType(primitiveType);
|
||||
literals.add(Literal.fromLegacyLiteral(literalExpr, type));
|
||||
@ -613,8 +640,8 @@ public class OneRangePartitionEvaluator
|
||||
Literal lower = span.lowerEndpoint().getValue();
|
||||
Literal upper = span.upperEndpoint().getValue();
|
||||
|
||||
Expression lowerDate = new Date(lower).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
|
||||
Expression upperDate = new Date(upper).accept(FoldConstantRuleOnFE.INSTANCE, expressionRewriteContext);
|
||||
Expression lowerDate = FoldConstantRuleOnFE.evaluate(new Date(lower), expressionRewriteContext);
|
||||
Expression upperDate = FoldConstantRuleOnFE.evaluate(new Date(upper), expressionRewriteContext);
|
||||
|
||||
if (lowerDate instanceof Literal && upperDate instanceof Literal && lowerDate.equals(upperDate)) {
|
||||
return new EvaluateRangeResult(lowerDate, result.columnRanges, result.childrenResult);
|
||||
@ -696,7 +723,7 @@ public class OneRangePartitionEvaluator
|
||||
|
||||
public EvaluateRangeResult(Expression result, Map<Slot, ColumnRange> columnRanges,
|
||||
List<EvaluateRangeResult> childrenResult) {
|
||||
this(result, columnRanges, childrenResult, childrenResult.stream().allMatch(r -> r.isRejectNot()));
|
||||
this(result, columnRanges, childrenResult, allIsRejectNot(childrenResult));
|
||||
}
|
||||
|
||||
public EvaluateRangeResult withRejectNot(boolean rejectNot) {
|
||||
@ -706,6 +733,15 @@ public class OneRangePartitionEvaluator
|
||||
public boolean isRejectNot() {
|
||||
return rejectNot;
|
||||
}
|
||||
|
||||
private static boolean allIsRejectNot(List<EvaluateRangeResult> childrenResult) {
|
||||
for (EvaluateRangeResult evaluateRangeResult : childrenResult) {
|
||||
if (!evaluateRangeResult.isRejectNot()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -17,15 +17,17 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionBottomUpRewriter;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Or;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -54,20 +56,25 @@ import java.util.Set;
|
||||
* adding any additional rule-specific fields to the default ExpressionRewriteContext. However, the entire expression
|
||||
* rewrite framework always passes an ExpressionRewriteContext of type context to all rules.
|
||||
*/
|
||||
public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext> implements
|
||||
ExpressionRewriteRule<ExpressionRewriteContext> {
|
||||
public class OrToIn implements ExpressionPatternRuleFactory {
|
||||
|
||||
public static final OrToIn INSTANCE = new OrToIn();
|
||||
|
||||
public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;
|
||||
|
||||
@Override
|
||||
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
|
||||
return expr.accept(this, null);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(Or.class).then(OrToIn::rewrite)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
|
||||
public Expression rewriteTree(Expression expr, ExpressionRewriteContext context) {
|
||||
ExpressionBottomUpRewriter bottomUpRewriter = ExpressionRewrite.bottomUp(this);
|
||||
return bottomUpRewriter.rewrite(expr, context);
|
||||
}
|
||||
|
||||
private static Expression rewrite(Or or) {
|
||||
// NOTICE: use linked hash map to avoid unstable order or entry.
|
||||
// unstable order entry lead to dead loop since return expression always un-equals to original one.
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral = Maps.newLinkedHashMap();
|
||||
@ -80,6 +87,10 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
handleInPredicate((InPredicate) expression, slotNameToLiteral, disConjunctToSlot);
|
||||
}
|
||||
}
|
||||
if (disConjunctToSlot.isEmpty()) {
|
||||
return or;
|
||||
}
|
||||
|
||||
List<Expression> rewrittenOr = new ArrayList<>();
|
||||
for (Map.Entry<NamedExpression, Set<Literal>> entry : slotNameToLiteral.entrySet()) {
|
||||
Set<Literal> literals = entry.getValue();
|
||||
@ -90,7 +101,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
}
|
||||
for (Expression expression : expressions) {
|
||||
if (disConjunctToSlot.get(expression) == null) {
|
||||
rewrittenOr.add(expression.accept(this, null));
|
||||
rewrittenOr.add(expression);
|
||||
} else {
|
||||
Set<Literal> literals = slotNameToLiteral.get(disConjunctToSlot.get(expression));
|
||||
if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
|
||||
@ -102,7 +113,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
return ExpressionUtils.or(rewrittenOr);
|
||||
}
|
||||
|
||||
private void handleEqualTo(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
|
||||
private static void handleEqualTo(EqualTo equal, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
|
||||
Map<Expression, NamedExpression> disConjunctToSlot) {
|
||||
Expression left = equal.left();
|
||||
Expression right = equal.right();
|
||||
@ -115,7 +126,7 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
}
|
||||
}
|
||||
|
||||
private void handleInPredicate(InPredicate inPredicate, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
|
||||
private static void handleInPredicate(InPredicate inPredicate, Map<NamedExpression, Set<Literal>> slotNameToLiteral,
|
||||
Map<Expression, NamedExpression> disConjunctToSlot) {
|
||||
// TODO a+b in (1,2,3...) is not supported now
|
||||
if (inPredicate.getCompareExpr() instanceof NamedExpression
|
||||
@ -127,10 +138,9 @@ public class OrToIn extends DefaultExpressionRewriter<ExpressionRewriteContext>
|
||||
}
|
||||
}
|
||||
|
||||
public void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
|
||||
private static void addSlotToLiteral(NamedExpression namedExpression, Literal literal,
|
||||
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
|
||||
Set<Literal> literals = slotNameToLiteral.computeIfAbsent(namedExpression, k -> new LinkedHashSet<>());
|
||||
literals.add(literal);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import org.apache.doris.catalog.ListPartitionItem;
|
||||
import org.apache.doris.catalog.PartitionItem;
|
||||
import org.apache.doris.catalog.RangePartitionItem;
|
||||
import org.apache.doris.nereids.CascadesContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -81,14 +82,19 @@ public class PartitionPruner extends DefaultExpressionRewriter<Void> {
|
||||
&& ((Cast) right).child().getDataType().isDateType()) {
|
||||
DateTimeLiteral dt = (DateTimeLiteral) left;
|
||||
Cast cast = (Cast) right;
|
||||
return cp.withChildren(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child());
|
||||
return cp.withChildren(
|
||||
ImmutableList.of(new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()), cast.child())
|
||||
);
|
||||
} else if (right instanceof DateTimeLiteral && ((DateTimeLiteral) right).isMidnight()
|
||||
&& left instanceof Cast
|
||||
&& ((Cast) left).child() instanceof SlotReference
|
||||
&& ((Cast) left).child().getDataType().isDateType()) {
|
||||
DateTimeLiteral dt = (DateTimeLiteral) right;
|
||||
Cast cast = (Cast) left;
|
||||
return cp.withChildren(cast.child(), new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()));
|
||||
return cp.withChildren(ImmutableList.of(
|
||||
cast.child(),
|
||||
new DateLiteral(dt.getYear(), dt.getMonth(), dt.getDay()))
|
||||
);
|
||||
} else {
|
||||
return cp;
|
||||
}
|
||||
@ -115,13 +121,18 @@ public class PartitionPruner extends DefaultExpressionRewriter<Void> {
|
||||
partitionPredicate, ImmutableSet.copyOf(partitionSlots), cascadesContext);
|
||||
partitionPredicate = PredicateRewriteForPartitionPrune.rewrite(partitionPredicate, cascadesContext);
|
||||
|
||||
int expandThreshold = cascadesContext.getAndCacheSessionVariable(
|
||||
"partitionPruningExpandThreshold",
|
||||
10, sessionVariable -> sessionVariable.partitionPruningExpandThreshold);
|
||||
|
||||
List<OnePartitionEvaluator> evaluators = Lists.newArrayListWithCapacity(idToPartitions.size());
|
||||
for (Entry<Long, PartitionItem> kv : idToPartitions.entrySet()) {
|
||||
evaluators.add(toPartitionEvaluator(
|
||||
kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, partitionTableType));
|
||||
kv.getKey(), kv.getValue(), partitionSlots, cascadesContext, expandThreshold));
|
||||
}
|
||||
|
||||
partitionPredicate = OrToIn.INSTANCE.rewrite(partitionPredicate, null);
|
||||
partitionPredicate = OrToIn.INSTANCE.rewriteTree(
|
||||
partitionPredicate, new ExpressionRewriteContext(cascadesContext));
|
||||
PartitionPruner partitionPruner = new PartitionPruner(evaluators, partitionPredicate);
|
||||
//TODO: we keep default partition because it's too hard to prune it, we return false in canPrune().
|
||||
return partitionPruner.prune();
|
||||
@ -131,13 +142,13 @@ public class PartitionPruner extends DefaultExpressionRewriter<Void> {
|
||||
* convert partition item to partition evaluator
|
||||
*/
|
||||
public static final OnePartitionEvaluator toPartitionEvaluator(long id, PartitionItem partitionItem,
|
||||
List<Slot> partitionSlots, CascadesContext cascadesContext, PartitionTableType partitionTableType) {
|
||||
List<Slot> partitionSlots, CascadesContext cascadesContext, int expandThreshold) {
|
||||
if (partitionItem instanceof ListPartitionItem) {
|
||||
return new OneListPartitionEvaluator(
|
||||
id, partitionSlots, (ListPartitionItem) partitionItem, cascadesContext);
|
||||
} else if (partitionItem instanceof RangePartitionItem) {
|
||||
return new OneRangePartitionEvaluator(
|
||||
id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext);
|
||||
id, partitionSlots, (RangePartitionItem) partitionItem, cascadesContext, expandThreshold);
|
||||
} else {
|
||||
return new UnknownPartitionEvaluator(id, partitionItem);
|
||||
}
|
||||
|
||||
@ -41,7 +41,6 @@ import java.time.LocalDate;
|
||||
import java.time.ZoneOffset;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.function.Function;
|
||||
|
||||
/**
|
||||
@ -74,10 +73,44 @@ public class PartitionRangeExpander {
|
||||
}
|
||||
|
||||
/** expandRangeLiterals */
|
||||
public final List<List<Expression>> tryExpandRange(
|
||||
public static final List<List<Expression>> tryExpandRange(
|
||||
List<Slot> partitionSlots, List<Literal> lowers, List<Literal> uppers,
|
||||
List<PartitionSlotType> partitionSlotTypes, int expandThreshold) {
|
||||
if (partitionSlots.size() == 1) {
|
||||
return tryExpandSingleColumnRange(partitionSlots.get(0), lowers.get(0),
|
||||
uppers.get(0), expandThreshold);
|
||||
} else {
|
||||
// slow path
|
||||
return commonTryExpandRange(partitionSlots, lowers, uppers, partitionSlotTypes, expandThreshold);
|
||||
}
|
||||
}
|
||||
|
||||
private static List<List<Expression>> tryExpandSingleColumnRange(Slot partitionSlot, Literal lower,
|
||||
Literal upper, int expandThreshold) {
|
||||
// must be range slot
|
||||
try {
|
||||
if (canExpandRange(partitionSlot, lower, upper, 1, expandThreshold)) {
|
||||
Iterator<? extends Expression> iterator = enumerableIterator(
|
||||
partitionSlot, lower, upper, true);
|
||||
if (iterator instanceof SingletonIterator) {
|
||||
return ImmutableList.of(ImmutableList.of(iterator.next()));
|
||||
} else {
|
||||
return ImmutableList.of(
|
||||
ImmutableList.copyOf(iterator)
|
||||
);
|
||||
}
|
||||
} else {
|
||||
return ImmutableList.of(ImmutableList.of(partitionSlot));
|
||||
}
|
||||
} catch (Throwable t) {
|
||||
// catch for safety, should not invoke here
|
||||
return ImmutableList.of(ImmutableList.of(partitionSlot));
|
||||
}
|
||||
}
|
||||
|
||||
private static List<List<Expression>> commonTryExpandRange(
|
||||
List<Slot> partitionSlots, List<Literal> lowers, List<Literal> uppers,
|
||||
List<PartitionSlotType> partitionSlotTypes, int expandThreshold) {
|
||||
long expandedCount = 1;
|
||||
List<List<Expression>> expandedLists = Lists.newArrayListWithCapacity(lowers.size());
|
||||
for (int i = 0; i < partitionSlotTypes.size(); i++) {
|
||||
@ -126,7 +159,7 @@ public class PartitionRangeExpander {
|
||||
return expandedLists;
|
||||
}
|
||||
|
||||
private boolean canExpandRange(Slot slot, Literal lower, Literal upper,
|
||||
private static boolean canExpandRange(Slot slot, Literal lower, Literal upper,
|
||||
long expandedCount, int expandThreshold) {
|
||||
DataType type = slot.getDataType();
|
||||
if (!type.isIntegerLikeType() && !type.isDateType() && !type.isDateV2Type()) {
|
||||
@ -139,7 +172,7 @@ public class PartitionRangeExpander {
|
||||
}
|
||||
// too much expanded will consuming resources of frontend,
|
||||
// e.g. [1, 100000000), we should skip expand it
|
||||
return (expandedCount * count) <= expandThreshold;
|
||||
return count == 1 || (expandedCount * count) <= expandThreshold;
|
||||
} catch (Throwable t) {
|
||||
// e.g. max_value can not expand
|
||||
return false;
|
||||
@ -147,7 +180,7 @@ public class PartitionRangeExpander {
|
||||
}
|
||||
|
||||
/** the types will like this: [CONST, CONST, ..., RANGE, OTHER, OTHER, ...] */
|
||||
public List<PartitionSlotType> computePartitionSlotTypes(List<Literal> lowers, List<Literal> uppers) {
|
||||
public static List<PartitionSlotType> computePartitionSlotTypes(List<Literal> lowers, List<Literal> uppers) {
|
||||
PartitionSlotType previousType = PartitionSlotType.CONST;
|
||||
List<PartitionSlotType> types = Lists.newArrayListWithCapacity(lowers.size());
|
||||
for (int i = 0; i < lowers.size(); ++i) {
|
||||
@ -167,7 +200,7 @@ public class PartitionRangeExpander {
|
||||
return types;
|
||||
}
|
||||
|
||||
private long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) {
|
||||
private static long enumerableCount(DataType dataType, Literal startInclusive, Literal endExclusive) {
|
||||
if (dataType.isIntegerLikeType()) {
|
||||
BigInteger start = new BigInteger(startInclusive.getStringValue());
|
||||
BigInteger end = new BigInteger(endExclusive.getStringValue());
|
||||
@ -175,6 +208,12 @@ public class PartitionRangeExpander {
|
||||
} else if (dataType.isDateType()) {
|
||||
DateLiteral startInclusiveDate = (DateLiteral) startInclusive;
|
||||
DateLiteral endExclusiveDate = (DateLiteral) endExclusive;
|
||||
|
||||
if (startInclusiveDate.getYear() == endExclusiveDate.getYear()
|
||||
&& startInclusiveDate.getMonth() == endExclusiveDate.getMonth()) {
|
||||
return endExclusiveDate.getDay() - startInclusiveDate.getDay();
|
||||
}
|
||||
|
||||
LocalDate startDate = LocalDate.of(
|
||||
(int) startInclusiveDate.getYear(),
|
||||
(int) startInclusiveDate.getMonth(),
|
||||
@ -192,6 +231,12 @@ public class PartitionRangeExpander {
|
||||
} else if (dataType.isDateV2Type()) {
|
||||
DateV2Literal startInclusiveDate = (DateV2Literal) startInclusive;
|
||||
DateV2Literal endExclusiveDate = (DateV2Literal) endExclusive;
|
||||
|
||||
if (startInclusiveDate.getYear() == endExclusiveDate.getYear()
|
||||
&& startInclusiveDate.getMonth() == endExclusiveDate.getMonth()) {
|
||||
return endExclusiveDate.getDay() - startInclusiveDate.getDay();
|
||||
}
|
||||
|
||||
LocalDate startDate = LocalDate.of(
|
||||
(int) startInclusiveDate.getYear(),
|
||||
(int) startInclusiveDate.getMonth(),
|
||||
@ -212,7 +257,7 @@ public class PartitionRangeExpander {
|
||||
return -1;
|
||||
}
|
||||
|
||||
private Iterator<? extends Expression> enumerableIterator(
|
||||
private static Iterator<? extends Expression> enumerableIterator(
|
||||
Slot slot, Literal startInclusive, Literal endLiteral, boolean endExclusive) {
|
||||
DataType dataType = slot.getDataType();
|
||||
if (dataType.isIntegerLikeType()) {
|
||||
@ -237,6 +282,12 @@ public class PartitionRangeExpander {
|
||||
} else if (dataType.isDateType()) {
|
||||
DateLiteral startInclusiveDate = (DateLiteral) startInclusive;
|
||||
DateLiteral endLiteralDate = (DateLiteral) endLiteral;
|
||||
if (endExclusive && startInclusiveDate.getYear() == endLiteralDate.getYear()
|
||||
&& startInclusiveDate.getMonth() == endLiteralDate.getMonth()
|
||||
&& startInclusiveDate.getDay() + 1 == endLiteralDate.getDay()) {
|
||||
return new SingletonIterator(startInclusive);
|
||||
}
|
||||
|
||||
LocalDate startDate = LocalDate.of(
|
||||
(int) startInclusiveDate.getYear(),
|
||||
(int) startInclusiveDate.getMonth(),
|
||||
@ -258,6 +309,13 @@ public class PartitionRangeExpander {
|
||||
} else if (dataType.isDateV2Type()) {
|
||||
DateV2Literal startInclusiveDate = (DateV2Literal) startInclusive;
|
||||
DateV2Literal endLiteralDate = (DateV2Literal) endLiteral;
|
||||
|
||||
if (endExclusive && startInclusiveDate.getYear() == endLiteralDate.getYear()
|
||||
&& startInclusiveDate.getMonth() == endLiteralDate.getMonth()
|
||||
&& startInclusiveDate.getDay() + 1 == endLiteralDate.getDay()) {
|
||||
return new SingletonIterator(startInclusive);
|
||||
}
|
||||
|
||||
LocalDate startDate = LocalDate.of(
|
||||
(int) startInclusiveDate.getYear(),
|
||||
(int) startInclusiveDate.getMonth(),
|
||||
@ -282,7 +340,7 @@ public class PartitionRangeExpander {
|
||||
return Iterators.singletonIterator(slot);
|
||||
}
|
||||
|
||||
private class IntegerLikeRangePartitionValueIterator<L extends IntegerLikeLiteral>
|
||||
private static class IntegerLikeRangePartitionValueIterator<L extends IntegerLikeLiteral>
|
||||
extends RangePartitionValueIterator<BigInteger, L> {
|
||||
|
||||
public IntegerLikeRangePartitionValueIterator(BigInteger startInclusive, BigInteger end,
|
||||
@ -296,7 +354,7 @@ public class PartitionRangeExpander {
|
||||
}
|
||||
}
|
||||
|
||||
private class DateLikeRangePartitionValueIterator<L extends Literal>
|
||||
private static class DateLikeRangePartitionValueIterator<L extends Literal>
|
||||
extends RangePartitionValueIterator<LocalDate, L> {
|
||||
|
||||
public DateLikeRangePartitionValueIterator(
|
||||
@ -309,43 +367,4 @@ public class PartitionRangeExpander {
|
||||
return current.plusDays(1);
|
||||
}
|
||||
}
|
||||
|
||||
private abstract class RangePartitionValueIterator<C extends Comparable, L extends Literal>
|
||||
implements Iterator<L> {
|
||||
private final C startInclusive;
|
||||
private final C end;
|
||||
private final boolean endExclusive;
|
||||
private C current;
|
||||
|
||||
private final Function<C, L> toLiteral;
|
||||
|
||||
public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function<C, L> toLiteral) {
|
||||
this.startInclusive = startInclusive;
|
||||
this.end = end;
|
||||
this.endExclusive = endExclusive;
|
||||
this.current = this.startInclusive;
|
||||
this.toLiteral = toLiteral;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
if (endExclusive) {
|
||||
return current.compareTo(end) < 0;
|
||||
} else {
|
||||
return current.compareTo(end) <= 0;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public L next() {
|
||||
if (hasNext()) {
|
||||
C value = current;
|
||||
current = doGetNext(current);
|
||||
return toLiteral.apply(value);
|
||||
}
|
||||
throw new NoSuchElementException();
|
||||
}
|
||||
|
||||
protected abstract C doGetNext(C current);
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ public class PredicateRewriteForPartitionPrune
|
||||
}
|
||||
}
|
||||
if (convertable) {
|
||||
Expression or = ExpressionUtils.combine(Or.class, splitIn);
|
||||
Expression or = ExpressionUtils.combineAsLeftDeepTree(Or.class, splitIn);
|
||||
return or;
|
||||
}
|
||||
} else if (dateChild.getDataType() instanceof DateTimeV2Type) {
|
||||
@ -87,7 +87,7 @@ public class PredicateRewriteForPartitionPrune
|
||||
}
|
||||
}
|
||||
if (convertable) {
|
||||
Expression or = ExpressionUtils.combine(Or.class, splitIn);
|
||||
Expression or = ExpressionUtils.combineAsLeftDeepTree(Or.class, splitIn);
|
||||
return or;
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,64 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.function.Function;
|
||||
|
||||
/** RangePartitionValueIterator */
|
||||
public abstract class RangePartitionValueIterator<C extends Comparable, L extends Literal>
|
||||
implements Iterator<L> {
|
||||
private final C startInclusive;
|
||||
private final C end;
|
||||
private final boolean endExclusive;
|
||||
private C current;
|
||||
|
||||
private final Function<C, L> toLiteral;
|
||||
|
||||
public RangePartitionValueIterator(C startInclusive, C end, boolean endExclusive, Function<C, L> toLiteral) {
|
||||
this.startInclusive = startInclusive;
|
||||
this.end = end;
|
||||
this.endExclusive = endExclusive;
|
||||
this.current = this.startInclusive;
|
||||
this.toLiteral = toLiteral;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
if (endExclusive) {
|
||||
return current.compareTo(end) < 0;
|
||||
} else {
|
||||
return current.compareTo(end) <= 0;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public L next() {
|
||||
if (hasNext()) {
|
||||
C value = current;
|
||||
current = doGetNext(current);
|
||||
return toLiteral.apply(value);
|
||||
}
|
||||
throw new NoSuchElementException();
|
||||
}
|
||||
|
||||
protected abstract C doGetNext(C current);
|
||||
}
|
||||
@ -17,20 +17,25 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Variable;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* replace varaible to real expression
|
||||
*/
|
||||
public class ReplaceVariableByLiteral extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class ReplaceVariableByLiteral implements ExpressionPatternRuleFactory {
|
||||
public static ReplaceVariableByLiteral INSTANCE = new ReplaceVariableByLiteral();
|
||||
|
||||
@Override
|
||||
public Expression visitVariable(Variable variable, ExpressionRewriteContext context) {
|
||||
return variable.getRealExpression();
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(Variable.class).then(Variable::getRealExpression)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,7 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
@ -43,6 +44,7 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
|
||||
import java.util.Arrays;
|
||||
@ -55,11 +57,11 @@ import javax.annotation.Nullable;
|
||||
* a + 1 > 1 => a > 0
|
||||
* a / -2 > 1 => a < -2
|
||||
*/
|
||||
public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteRule {
|
||||
public static final SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule();
|
||||
public class SimplifyArithmeticComparisonRule implements ExpressionPatternRuleFactory {
|
||||
public static SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule();
|
||||
|
||||
// don't rearrange multiplication because divide may loss precision
|
||||
final Map<Class<? extends Expression>, Class<? extends Expression>> rearrangementMap = ImmutableMap
|
||||
private static final Map<Class<? extends Expression>, Class<? extends Expression>> REARRANGEMENT_MAP = ImmutableMap
|
||||
.<Class<? extends Expression>, Class<? extends Expression>>builder()
|
||||
.put(Add.class, Subtract.class)
|
||||
.put(Subtract.class, Add.class)
|
||||
@ -81,41 +83,54 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR
|
||||
.build();
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate comparison, ExpressionRewriteContext context) {
|
||||
if (couldRearrange(comparison)) {
|
||||
ComparisonPredicate newComparison = normalize(comparison);
|
||||
if (newComparison == null) {
|
||||
return comparison;
|
||||
}
|
||||
try {
|
||||
List<Expression> children =
|
||||
tryRearrangeChildren(newComparison.left(), newComparison.right(), context);
|
||||
newComparison = (ComparisonPredicate) visitComparisonPredicate(
|
||||
(ComparisonPredicate) newComparison.withChildren(children), context);
|
||||
} catch (Exception e) {
|
||||
return comparison;
|
||||
}
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(ComparisonPredicate.class)
|
||||
.thenApply(ctx -> simplify(ctx.expr, new ExpressionRewriteContext(ctx.cascadesContext)))
|
||||
);
|
||||
}
|
||||
|
||||
/** simplify */
|
||||
public static Expression simplify(ComparisonPredicate comparison, ExpressionRewriteContext context) {
|
||||
if (!couldRearrange(comparison)) {
|
||||
return comparison;
|
||||
}
|
||||
ComparisonPredicate newComparison = normalize(comparison);
|
||||
if (newComparison == null) {
|
||||
return comparison;
|
||||
}
|
||||
try {
|
||||
List<Expression> children = tryRearrangeChildren(newComparison.left(), newComparison.right(), context);
|
||||
newComparison = (ComparisonPredicate) simplify(
|
||||
(ComparisonPredicate) newComparison.withChildren(children), context);
|
||||
return TypeCoercionUtils.processComparisonPredicate(newComparison);
|
||||
} else {
|
||||
} catch (Exception e) {
|
||||
return comparison;
|
||||
}
|
||||
}
|
||||
|
||||
private boolean couldRearrange(ComparisonPredicate cmp) {
|
||||
return rearrangementMap.containsKey(cmp.left().getClass())
|
||||
&& !cmp.left().isConstant()
|
||||
&& cmp.left().children().stream().anyMatch(Expression::isConstant);
|
||||
private static boolean couldRearrange(ComparisonPredicate cmp) {
|
||||
if (!REARRANGEMENT_MAP.containsKey(cmp.left().getClass()) || cmp.left().isConstant()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (Expression child : cmp.left().children()) {
|
||||
if (child.isConstant()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private List<Expression> tryRearrangeChildren(Expression left, Expression right,
|
||||
private static List<Expression> tryRearrangeChildren(Expression left, Expression right,
|
||||
ExpressionRewriteContext context) throws Exception {
|
||||
if (!left.child(1).isConstant()) {
|
||||
throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left));
|
||||
}
|
||||
Literal leftLiteral = (Literal) FoldConstantRule.INSTANCE.rewrite(left.child(1), context);
|
||||
Literal leftLiteral = (Literal) FoldConstantRule.evaluate(left.child(1), context);
|
||||
Expression leftExpr = left.child(0);
|
||||
|
||||
Class<? extends Expression> oppositeOperator = rearrangementMap.get(left.getClass());
|
||||
Class<? extends Expression> oppositeOperator = REARRANGEMENT_MAP.get(left.getClass());
|
||||
Expression newChild = oppositeOperator.getConstructor(Expression.class, Expression.class)
|
||||
.newInstance(right, leftLiteral);
|
||||
|
||||
@ -127,25 +142,25 @@ public class SimplifyArithmeticComparisonRule extends AbstractExpressionRewriteR
|
||||
}
|
||||
|
||||
// Ensure that the second child must be Literal, such as
|
||||
private @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) {
|
||||
if (!(comparison.left().child(1) instanceof Literal)) {
|
||||
Expression left = comparison.left();
|
||||
if (comparison.left() instanceof Add) {
|
||||
// 1 + a > 1 => a + 1 > 1
|
||||
Expression newLeft = left.withChildren(left.child(1), left.child(0));
|
||||
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right());
|
||||
} else if (comparison.left() instanceof Subtract) {
|
||||
// 1 - a > 1 => a + 1 < 1
|
||||
Expression newLeft = left.child(0);
|
||||
Expression newRight = new Add(left.child(1), comparison.right());
|
||||
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight);
|
||||
comparison = comparison.commute();
|
||||
} else {
|
||||
// Don't normalize division/multiplication because the slot sign is undecided.
|
||||
return null;
|
||||
}
|
||||
private static @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) {
|
||||
Expression left = comparison.left();
|
||||
Expression leftRight = left.child(1);
|
||||
if (leftRight instanceof Literal) {
|
||||
return comparison;
|
||||
}
|
||||
if (left instanceof Add) {
|
||||
// 1 + a > 1 => a + 1 > 1
|
||||
Expression newLeft = left.withChildren(leftRight, left.child(0));
|
||||
return (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right());
|
||||
} else if (left instanceof Subtract) {
|
||||
// 1 - a > 1 => a + 1 < 1
|
||||
Expression newLeft = left.child(0);
|
||||
Expression newRight = new Add(leftRight, comparison.right());
|
||||
comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight);
|
||||
return comparison.commute();
|
||||
} else {
|
||||
// Don't normalize division/multiplication because the slot sign is undecided.
|
||||
return null;
|
||||
}
|
||||
return comparison;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.Divide;
|
||||
@ -27,7 +27,9 @@ import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.Subtract;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
import org.apache.doris.nereids.util.TypeUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
@ -43,27 +45,24 @@ import java.util.Optional;
|
||||
*
|
||||
* TODO: handle cases like: '1 - IA < 1' to 'IA > 0'
|
||||
*/
|
||||
public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
public class SimplifyArithmeticRule implements ExpressionPatternRuleFactory {
|
||||
public static final SimplifyArithmeticRule INSTANCE = new SimplifyArithmeticRule();
|
||||
|
||||
@Override
|
||||
public Expression visitAdd(Add add, ExpressionRewriteContext context) {
|
||||
return process(add, true);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesTopType(BinaryArithmetic.class).then(SimplifyArithmeticRule::simplify)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitSubtract(Subtract subtract, ExpressionRewriteContext context) {
|
||||
return process(subtract, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
|
||||
return process(divide, false);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitMultiply(Multiply multiply, ExpressionRewriteContext context) {
|
||||
return process(multiply, false);
|
||||
/** simplify */
|
||||
public static Expression simplify(BinaryArithmetic binaryArithmetic) {
|
||||
if (binaryArithmetic instanceof Add || binaryArithmetic instanceof Subtract) {
|
||||
return process(binaryArithmetic, true);
|
||||
} else if (binaryArithmetic instanceof Multiply || binaryArithmetic instanceof Divide) {
|
||||
return process(binaryArithmetic, false);
|
||||
}
|
||||
return binaryArithmetic;
|
||||
}
|
||||
|
||||
/**
|
||||
@ -75,7 +74,7 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
* 3.build new arithmetic expression.
|
||||
* (a + b - c + d) + (1 - 2 - 1)
|
||||
*/
|
||||
private Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) {
|
||||
private static Expression process(BinaryArithmetic arithmetic, boolean isAddOrSub) {
|
||||
// 1. flatten the arithmetic expression.
|
||||
List<Operand> flattedExpressions = flatten(arithmetic, isAddOrSub);
|
||||
|
||||
@ -83,22 +82,24 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
List<Operand> constants = Lists.newArrayList();
|
||||
|
||||
// TODO currently we don't process decimal for simplicity.
|
||||
if (flattedExpressions.stream().anyMatch(operand -> operand.expression.getDataType().isDecimalLikeType())) {
|
||||
return arithmetic;
|
||||
for (Operand operand : flattedExpressions) {
|
||||
if (operand.expression.getDataType().isDecimalLikeType()) {
|
||||
return arithmetic;
|
||||
}
|
||||
}
|
||||
// 2. move variables to left side and move constants to right sid.
|
||||
flattedExpressions.forEach(operand -> {
|
||||
for (Operand operand : flattedExpressions) {
|
||||
if (operand.expression.isConstant()) {
|
||||
constants.add(operand);
|
||||
} else {
|
||||
variables.add(operand);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 3. build new arithmetic expression.
|
||||
if (!constants.isEmpty()) {
|
||||
boolean isOpposite = !constants.get(0).flag;
|
||||
Optional<Operand> c = constants.stream().reduce((x, y) -> {
|
||||
Optional<Operand> c = Utils.fastReduce(constants, (x, y) -> {
|
||||
Expression expr;
|
||||
if (isOpposite && y.flag || !isOpposite && !y.flag) {
|
||||
expr = getSubOrDivide(isAddOrSub, x, y);
|
||||
@ -115,10 +116,10 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
Optional<Operand> result = variables.stream().reduce((x, y) -> !y.flag
|
||||
Optional<Operand> result = Utils.fastReduce(variables, (x, y) -> !y.flag
|
||||
? Operand.of(true, getSubOrDivide(isAddOrSub, x, y))
|
||||
: Operand.of(true, getAddOrMultiply(isAddOrSub, x, y)));
|
||||
|
||||
: Operand.of(true, getAddOrMultiply(isAddOrSub, x, y))
|
||||
);
|
||||
if (result.isPresent()) {
|
||||
return TypeCoercionUtils.castIfNotSameType(result.get().expression, arithmetic.getDataType());
|
||||
} else {
|
||||
@ -126,7 +127,7 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private List<Operand> flatten(Expression expr, boolean isAddOrSub) {
|
||||
private static List<Operand> flatten(Expression expr, boolean isAddOrSub) {
|
||||
List<Operand> result = Lists.newArrayList();
|
||||
if (isAddOrSub) {
|
||||
flattenAddSubtract(true, expr, result);
|
||||
@ -136,7 +137,7 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
return result;
|
||||
}
|
||||
|
||||
private void flattenAddSubtract(boolean flag, Expression expr, List<Operand> result) {
|
||||
private static void flattenAddSubtract(boolean flag, Expression expr, List<Operand> result) {
|
||||
if (TypeUtils.isAddOrSubtract(expr)) {
|
||||
BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
|
||||
flattenAddSubtract(flag, arithmetic.left(), result);
|
||||
@ -152,7 +153,7 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private void flattenMultiplyDivide(boolean flag, Expression expr, List<Operand> result) {
|
||||
private static void flattenMultiplyDivide(boolean flag, Expression expr, List<Operand> result) {
|
||||
if (TypeUtils.isMultiplyOrDivide(expr)) {
|
||||
BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
|
||||
flattenMultiplyDivide(flag, arithmetic.left(), result);
|
||||
@ -168,13 +169,13 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private Expression getSubOrDivide(boolean isAddOrSub, Operand x, Operand y) {
|
||||
return isAddOrSub ? new Subtract(x.expression, y.expression)
|
||||
private static Expression getSubOrDivide(boolean isSubOrDivide, Operand x, Operand y) {
|
||||
return isSubOrDivide ? new Subtract(x.expression, y.expression)
|
||||
: new Divide(x.expression, y.expression);
|
||||
}
|
||||
|
||||
private Expression getAddOrMultiply(boolean isAddOrSub, Operand x, Operand y) {
|
||||
return isAddOrSub ? new Add(x.expression, y.expression)
|
||||
private static Expression getAddOrMultiply(boolean isAddOrMultiply, Operand x, Operand y) {
|
||||
return isAddOrMultiply ? new Add(x.expression, y.expression)
|
||||
: new Multiply(x.expression, y.expression);
|
||||
}
|
||||
|
||||
@ -204,3 +205,4 @@ public class SimplifyArithmeticRule extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
|
||||
@ -37,7 +37,10 @@ import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Rewrite rule of simplify CAST expression.
|
||||
@ -46,17 +49,19 @@ import java.math.BigDecimal;
|
||||
* Merge cast like
|
||||
* - cast(cast(1 as bigint) as string) -> cast(1 as string).
|
||||
*/
|
||||
public class SimplifyCastRule extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class SimplifyCastRule implements ExpressionPatternRuleFactory {
|
||||
public static SimplifyCastRule INSTANCE = new SimplifyCastRule();
|
||||
|
||||
@Override
|
||||
public Expression visitCast(Cast origin, ExpressionRewriteContext context) {
|
||||
return simplify(origin, context);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(Cast.class).then(SimplifyCastRule::simplifyCast)
|
||||
);
|
||||
}
|
||||
|
||||
private Expression simplify(Cast cast, ExpressionRewriteContext context) {
|
||||
Expression child = rewrite(cast.child(), context);
|
||||
/** simplifyCast */
|
||||
public static Expression simplifyCast(Cast cast) {
|
||||
Expression child = cast.child();
|
||||
|
||||
// remove redundant cast
|
||||
// CAST(value as type), value is type
|
||||
|
||||
@ -18,6 +18,8 @@
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.And;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
@ -55,17 +57,18 @@ import org.apache.doris.nereids.types.coercion.DateLikeType;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* simplify comparison
|
||||
* such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
|
||||
* cast(c1 AS double) > 2.0 --> c1 >= 2 (c1 is integer like type)
|
||||
*/
|
||||
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule implements ExpressionPatternRuleFactory {
|
||||
public static SimplifyComparisonPredicate INSTANCE = new SimplifyComparisonPredicate();
|
||||
|
||||
enum AdjustType {
|
||||
@ -75,9 +78,19 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
|
||||
cp = (ComparisonPredicate) visit(cp, context);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(ComparisonPredicate.class).then(SimplifyComparisonPredicate::simplify)
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
|
||||
return simplify(cp);
|
||||
}
|
||||
|
||||
/** simplify */
|
||||
public static Expression simplify(ComparisonPredicate cp) {
|
||||
if (cp.left() instanceof Literal && !(cp.right() instanceof Literal)) {
|
||||
cp = cp.commute();
|
||||
}
|
||||
@ -146,7 +159,7 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
return comparisonPredicate;
|
||||
}
|
||||
|
||||
private Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
|
||||
private static Expression processDateLikeTypeCoercion(ComparisonPredicate cp, Expression left, Expression right) {
|
||||
if (left instanceof Cast && right instanceof DateLiteral) {
|
||||
Cast cast = (Cast) left;
|
||||
if (cast.child().getDataType() instanceof DateTimeType) {
|
||||
@ -196,7 +209,7 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate,
|
||||
private static Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate,
|
||||
Expression left, Expression right) {
|
||||
if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType()
|
||||
&& (right instanceof DoubleLiteral || right instanceof FloatLiteral)) {
|
||||
@ -209,7 +222,7 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate,
|
||||
private static Expression processDecimalV3TypeCoercion(ComparisonPredicate comparisonPredicate,
|
||||
Expression left, Expression right) {
|
||||
if (left instanceof Cast && right instanceof DecimalV3Literal) {
|
||||
Cast cast = (Cast) left;
|
||||
@ -264,7 +277,7 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
return comparisonPredicate;
|
||||
}
|
||||
|
||||
private Expression processIntegerDecimalLiteralComparison(
|
||||
private static Expression processIntegerDecimalLiteralComparison(
|
||||
ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
|
||||
// we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
|
||||
if (literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
|
||||
@ -306,7 +319,7 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
return comparisonPredicate;
|
||||
}
|
||||
|
||||
private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
|
||||
private static IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
|
||||
Preconditions.checkArgument(
|
||||
decimal.scale() <= 0 && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
|
||||
"decimal literal must have 0 scale and smaller than Long.MAX_VALUE");
|
||||
@ -322,15 +335,15 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private Expression migrateToDateTime(DateTimeV2Literal l) {
|
||||
private static Expression migrateToDateTime(DateTimeV2Literal l) {
|
||||
return new DateTimeLiteral(l.getYear(), l.getMonth(), l.getDay(), l.getHour(), l.getMinute(), l.getSecond());
|
||||
}
|
||||
|
||||
private boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) {
|
||||
private static boolean cannotAdjust(DateTimeLiteral l, ComparisonPredicate cp) {
|
||||
return cp instanceof EqualTo && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0);
|
||||
}
|
||||
|
||||
private Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) {
|
||||
private static Expression migrateToDateV2(DateTimeLiteral l, AdjustType type) {
|
||||
DateV2Literal d = new DateV2Literal(l.getYear(), l.getMonth(), l.getDay());
|
||||
if (type == AdjustType.UPPER && (l.getHour() != 0 || l.getMinute() != 0 || l.getSecond() != 0)) {
|
||||
d = ((DateV2Literal) d.plusDays(1));
|
||||
@ -338,7 +351,7 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
|
||||
return d;
|
||||
}
|
||||
|
||||
private Expression migrateToDate(DateV2Literal l) {
|
||||
private static Expression migrateToDate(DateV2Literal l) {
|
||||
return new DateLiteral(l.getYear(), l.getMonth(), l.getDay());
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,8 +17,8 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
@ -26,8 +26,10 @@ import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* if we have a column with decimalv3 type and set enable_decimal_conversion = false.
|
||||
@ -37,14 +39,20 @@ import java.math.BigDecimal;
|
||||
* and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand
|
||||
* this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6
|
||||
*/
|
||||
public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
|
||||
|
||||
public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory {
|
||||
public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
|
||||
Expression left = rewrite(cp.left(), context);
|
||||
Expression right = rewrite(cp.right(), context);
|
||||
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
|
||||
return ImmutableList.of(
|
||||
matchesType(ComparisonPredicate.class).then(SimplifyDecimalV3Comparison::simplify)
|
||||
);
|
||||
}
|
||||
|
||||
/** simplify */
|
||||
public static Expression simplify(ComparisonPredicate cp) {
|
||||
Expression left = cp.left();
|
||||
Expression right = cp.right();
|
||||
|
||||
if (left.getDataType() instanceof DecimalV3Type
|
||||
&& left instanceof Cast
|
||||
@ -60,7 +68,7 @@ public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
}
|
||||
|
||||
private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
|
||||
private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
|
||||
BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
|
||||
int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
|
||||
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user