[Nereids][Improve] infer predicate after push down predicate (#12996)

This PR implements the function of predicate inference

For example:

``` sql
select * from student left join score on student.id = score.sid where score.sid > 1
```
transformed logical plan tree:

                    left join
             /                    \
       filter(sid >1)     filter(id > 1) <---- inferred predicate
         |                           |
      scan                      scan  

See `InferPredicatesTest`  for more cases

 The logic is as follows:
  1. pull up the bottom predicates, then infer additional predicates
    for example:
    select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id
    1. pull up the bottom predicate
       select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1
    2. infer
       select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 and t2.id = 1
    finally transformed sql:
       select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
  2. put these predicates into `otherJoinConjuncts`; they are processed in the next
    round of predicate push-down


Now only support infer `ComparisonPredicate`.

TODO: We should determine whether `expression` satisfies the conditions for replacement,
             e.g. reject the inference when `expression` is non-deterministic
This commit is contained in:
shee
2022-11-08 21:36:17 +08:00
committed by GitHub
parent b6f91b6eff
commit 3f3f2eb098
10 changed files with 995 additions and 0 deletions

View File

@ -29,6 +29,7 @@ public enum JobType {
APPLY_RULE,
DERIVE_STATS,
TOP_DOWN_REWRITE,
VISITOR_REWRITE,
BOTTOM_UP_REWRITE
;
}

View File

@ -19,11 +19,14 @@ package org.apache.doris.nereids.jobs.batch;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.VisitorRewriteJob;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.ArrayList;
import java.util.List;
@ -71,6 +74,10 @@ public abstract class BatchRulesJob {
cascadesContext.getCurrentJobContext(), once);
}
/**
 * Creates a job that rewrites the whole plan with the given visitor instead of
 * pattern-matched rules. The job is constructed with {@code once = true}, so it
 * is expected to run a single pass per batch.
 *
 * @param planRewriter visitor that produces the rewritten plan
 * @return a {@code VisitorRewriteJob} bound to the current cascades context
 */
protected Job visitorJob(DefaultPlanRewriter<JobContext> planRewriter) {
    return new VisitorRewriteJob(cascadesContext, planRewriter, true);
}
protected Job optimize() {
return new OptimizeGroupJob(
cascadesContext.getMemo().getRoot(),

View File

@ -30,6 +30,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.EliminateLimit;
import org.apache.doris.nereids.rules.rewrite.logical.EliminateUnnecessaryProject;
import org.apache.doris.nereids.rules.rewrite.logical.ExtractSingleTableExpressionFromDisjunction;
import org.apache.doris.nereids.rules.rewrite.logical.FindHashConditionForJoin;
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
@ -65,9 +66,11 @@ public class NereidsRewriteJobExecutor extends BatchRulesJob {
.add(topDownBatch(ImmutableList.of(new ExtractSingleTableExpressionFromDisjunction())))
.add(topDownBatch(ImmutableList.of(new NormalizeAggregate())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(new InferPredicates()))
.add(topDownBatch(ImmutableList.of(new ReorderJoin())))
.add(topDownBatch(ImmutableList.of(new ColumnPruning())))
.add(topDownBatch(RuleSet.PUSH_DOWN_FILTERS, false))
.add(visitorJob(new InferPredicates()))
.add(topDownBatch(ImmutableList.of(PushFilterInsideJoin.INSTANCE)))
.add(topDownBatch(ImmutableList.of(new FindHashConditionForJoin())))
.add(topDownBatch(ImmutableList.of(new LimitPushDown())))

View File

@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import java.util.Objects;
/**
 * A rewrite job that transforms the plan with a plan visitor rather than with rule patterns.
 * It copies the current logical plan out of the memo root, runs the visitor over it, and
 * copies the rewritten plan back into the root group.
 */
public class VisitorRewriteJob extends Job {

    // root group of the memo at construction time; the whole tree under it is rewritten
    private final Group rootGroup;
    private final DefaultPlanRewriter<JobContext> planRewriter;

    /**
     * Constructor.
     *
     * @param cascadesContext context supplying the memo root and current job context
     * @param rewriter visitor applied to the copied-out plan
     * @param once whether this job should run only once
     */
    public VisitorRewriteJob(CascadesContext cascadesContext, DefaultPlanRewriter<JobContext> rewriter, boolean once) {
        super(JobType.VISITOR_REWRITE, cascadesContext.getCurrentJobContext(), once);
        this.rootGroup = Objects.requireNonNull(cascadesContext.getMemo().getRoot(), "group cannot be null");
        this.planRewriter = Objects.requireNonNull(rewriter, "planRewriter cannot be null");
    }

    @Override
    public void execute() {
        // materialize the plan tree from the memo, rewrite it, then merge it back in place
        GroupExpression rootExpression = rootGroup.getLogicalExpression();
        Plan planToRewrite = context.getCascadesContext().getMemo().copyOut(rootExpression, true);
        Plan rewrittenPlan = planToRewrite.accept(planRewriter, context);
        context.getCascadesContext().getMemo().copyIn(rewrittenPlan, rootGroup, true);
    }
}

View File

@ -0,0 +1,117 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Infers additional predicates for {@code LogicalFilter} and {@code LogicalJoin}.
 * The logic is as follows:
 * 1. pull up the bottom predicates, then infer additional predicates
 *    for example:
 *    select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id
 *    1. pull up the bottom predicate:
 *       select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1
 *    2. infer:
 *       select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t.id = 1 and t2.id = 1
 *    finally transformed sql:
 *       select * from (select * from t1 where t1.id = 1) t join t2 on t.id = t2.id and t2.id = 1
 * 2. put the inferred predicates into {@code otherJoinConjuncts}; they are pushed down
 *    by the next round of predicate push-down.
 */
public class InferPredicates extends DefaultPlanRewriter<JobContext> {

    private final PredicatePropagation propagation = new PredicatePropagation();
    private final PullUpPredicates predicateCollector = new PullUpPredicates();

    @Override
    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, JobContext context) {
        // rewrite children first (bottom-up)
        join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, context);
        Plan leftChild = join.left();
        Plan rightChild = join.right();
        Set<Expression> candidates = getAllExpressions(leftChild, rightChild, join.getOnClauseCondition());
        List<Expression> conjuncts = Lists.newArrayList(join.getOtherJoinConjuncts());
        switch (join.getJoinType()) {
            case INNER_JOIN:
            case CROSS_JOIN:
            case LEFT_SEMI_JOIN:
            case RIGHT_SEMI_JOIN:
                // predicates may be attached to either side
                conjuncts.addAll(inferNewPredicate(leftChild, candidates));
                conjuncts.addAll(inferNewPredicate(rightChild, candidates));
                break;
            case LEFT_OUTER_JOIN:
            case LEFT_ANTI_JOIN:
                // only the right side may safely take inferred predicates
                conjuncts.addAll(inferNewPredicate(rightChild, candidates));
                break;
            case RIGHT_OUTER_JOIN:
            case RIGHT_ANTI_JOIN:
                // only the left side may safely take inferred predicates
                conjuncts.addAll(inferNewPredicate(leftChild, candidates));
                break;
            default:
                return join;
        }
        return join.withOtherJoinConjuncts(conjuncts);
    }

    @Override
    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, JobContext context) {
        filter = (LogicalFilter<? extends Plan>) super.visit(filter, context);
        // everything provable at this filter, minus what the child already guarantees and
        // what this filter already states, is newly inferred
        Set<Expression> inferred = pullUpPredicates(filter);
        inferred.removeAll(pullUpPredicates(filter.child()));
        filter.getConjuncts().forEach(inferred::remove);
        if (inferred.isEmpty()) {
            return filter;
        }
        inferred.addAll(filter.getConjuncts());
        return new LogicalFilter<>(ExpressionUtils.and(Lists.newArrayList(inferred)), filter.child());
    }

    /** Collects both children's pulled-up predicates plus the on-clause conjuncts, then derives equivalences. */
    private Set<Expression> getAllExpressions(Plan left, Plan right, Optional<Expression> condition) {
        Set<Expression> baseExpressions = pullUpPredicates(left);
        baseExpressions.addAll(pullUpPredicates(right));
        condition.ifPresent(on -> baseExpressions.addAll(ExpressionUtils.extractConjunction(on)));
        baseExpressions.addAll(propagation.infer(baseExpressions));
        return baseExpressions;
    }

    private Set<Expression> pullUpPredicates(Plan plan) {
        return Sets.newHashSet(plan.accept(predicateCollector, null));
    }

    /** Keeps only predicates the given child can evaluate and does not already imply. */
    private List<Expression> inferNewPredicate(Plan plan, Set<Expression> expressions) {
        List<Expression> predicates = expressions.stream()
                .filter(p -> !p.getInputSlots().isEmpty()
                        && plan.getOutputSet().containsAll(p.getInputSlots()))
                .collect(Collectors.toList());
        predicates.removeAll(plan.accept(predicateCollector, null));
        return predicates;
    }
}

View File

@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Derives additional predicates from equivalences.
 * for example:
 * a = b and a = 1 => b = 1
 */
public class PredicatePropagation {

    /**
     * Infers additional predicates: for every slot-equality in {@code predicates}, rewrite
     * every other predicate by substituting one side of the equality for the other.
     * Predicates already present in the input are not returned.
     */
    public Set<Expression> infer(Set<Expression> predicates) {
        Set<Expression> derived = Sets.newHashSet();
        for (Expression equivalence : predicates) {
            if (!canEquivalentInfer(equivalence)) {
                continue;
            }
            for (Expression target : predicates) {
                if (!target.equals(equivalence)) {
                    derived.add(doInfer(equivalence, target));
                }
            }
        }
        // unchanged rewrites fall out here, so only genuinely new predicates remain
        derived.removeAll(predicates);
        return derived;
    }

    /**
     * Uses the left or right child of {@code leftSlotEqualToRightSlot} to replace the
     * matching child of {@code expression}.
     * Now only support infer {@code ComparisonPredicate}.
     * TODO: We should determine whether {@code expression} satisfies the conditions for
     * replacement, e.g. reject the inference when {@code expression} is non-deterministic.
     */
    private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) {
        return expression.accept(new DefaultExpressionRewriter<Void>() {
            @Override
            public Expression visit(Expression expr, Void context) {
                // default: leave anything that is not a comparison untouched
                return expr;
            }

            @Override
            public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) {
                // only slot-vs-constant comparisons are candidates for substitution
                boolean slotVsConstant = (cp.left().isSlot() && cp.right().isConstant())
                        || (cp.left().isConstant() && cp.right().isSlot());
                return slotVsConstant ? replaceSlot(cp) : super.visit(cp, context);
            }

            private Expression replaceSlot(Expression expr) {
                Expression lhs = leftSlotEqualToRightSlot.child(0);
                Expression rhs = leftSlotEqualToRightSlot.child(1);
                return expr.rewriteUp(e -> {
                    if (e.equals(lhs)) {
                        return rhs;
                    }
                    if (e.equals(rhs)) {
                        return lhs;
                    }
                    return e;
                });
            }
        }, null);
    }

    /**
     * Currently only equivalence derivation is supported: both sides of the predicate
     * must be slots of the same data type.
     */
    private boolean canEquivalentInfer(Expression predicate) {
        if (!(predicate instanceof EqualTo)) {
            return false;
        }
        for (Expression child : predicate.children()) {
            if (!(child instanceof SlotReference)) {
                return false;
            }
        }
        return predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
    }
}

View File

@ -0,0 +1,165 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
/**
 * Pulls up the effective predicates from an operator's children, used by
 * {@code InferPredicates} to decide which predicates a plan already guarantees.
 * Returns, for each visited plan, the set of predicates that hold on its output
 * (restricted to expressions whose input slots are all produced by that plan).
 */
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {

    private final PredicatePropagation propagation = new PredicatePropagation();
    // identity-keyed memo: the same plan node object may be visited several times
    private final Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();

    @Override
    public ImmutableSet<Expression> visit(Plan plan, Void context) {
        // predicates pass transparently through single-child operators with no rule of their own
        if (plan.arity() == 1) {
            return plan.child(0).accept(this, context);
        }
        return ImmutableSet.of();
    }

    @Override
    public ImmutableSet<Expression> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
        return cacheOrElse(filter, () -> {
            // the filter guarantees its own conjuncts plus everything its child guarantees
            List<Expression> predicates = Lists.newArrayList(filter.getConjuncts());
            predicates.addAll(filter.child().accept(this, context));
            return getAvailableExpressions(predicates, filter);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
        return cacheOrElse(join, () -> {
            Set<Expression> predicates = Sets.newHashSet();
            ImmutableSet<Expression> leftPredicates = join.left().accept(this, context);
            ImmutableSet<Expression> rightPredicates = join.right().accept(this, context);
            switch (join.getJoinType()) {
                case INNER_JOIN:
                case CROSS_JOIN:
                    // both sides survive, and the on-clause holds on the output
                    predicates.addAll(leftPredicates);
                    predicates.addAll(rightPredicates);
                    addOnClauseConjuncts(join, predicates);
                    break;
                case LEFT_SEMI_JOIN:
                    predicates.addAll(leftPredicates);
                    addOnClauseConjuncts(join, predicates);
                    break;
                case RIGHT_SEMI_JOIN:
                    predicates.addAll(rightPredicates);
                    addOnClauseConjuncts(join, predicates);
                    break;
                case LEFT_OUTER_JOIN:
                case LEFT_ANTI_JOIN:
                    // the right side may be null-extended, so only left predicates are guaranteed
                    predicates.addAll(leftPredicates);
                    break;
                case RIGHT_OUTER_JOIN:
                case RIGHT_ANTI_JOIN:
                    predicates.addAll(rightPredicates);
                    break;
                default:
                    break;
            }
            return getAvailableExpressions(predicates, join);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
        return cacheOrElse(project, () -> {
            ImmutableSet<Expression> childPredicates = project.child().accept(this, context);
            // rewrite child predicates in terms of the project's output slots
            // (invert the alias map: producer expression -> output slot)
            Map<Expression, Slot> expressionSlotMap = project.getAliasToProducer()
                    .entrySet()
                    .stream()
                    .collect(Collectors.toMap(Entry::getValue, Entry::getKey));
            Expression expression = ExpressionUtils.replace(ExpressionUtils.and(Lists.newArrayList(childPredicates)),
                    expressionSlotMap);
            List<Expression> predicates = ExpressionUtils.extractConjunction(expression);
            return getAvailableExpressions(predicates, project);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
        return cacheOrElse(aggregate, () -> {
            ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
            // map each aggregate output expression back to its output slot so that child
            // predicates can be re-expressed over the aggregate's output
            // NOTE(review): Collectors.toMap throws on duplicate keys — assumes output
            // expressions containing aggregates are distinct; TODO confirm
            Map<Expression, Slot> expressionSlotMap = aggregate.getOutputExpressions()
                    .stream()
                    .filter(this::hasAgg)
                    .collect(Collectors.toMap(
                            namedExpr -> namedExpr instanceof Alias ? ((Alias) namedExpr).child() : namedExpr,
                            NamedExpression::toSlot)
                    );
            Expression expression = ExpressionUtils.replace(ExpressionUtils.and(Lists.newArrayList(childPredicates)),
                    expressionSlotMap);
            List<Expression> predicates = ExpressionUtils.extractConjunction(expression);
            return getAvailableExpressions(predicates, aggregate);
        });
    }

    /**
     * Memoizes per plan node. Implemented manually (not computeIfAbsent) because the
     * supplier recursively visits children, which mutates the cache during computation.
     */
    private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Expression>> predicatesSupplier) {
        ImmutableSet<Expression> predicates = cache.get(plan);
        if (predicates != null) {
            return predicates;
        }
        predicates = predicatesSupplier.get();
        cache.put(plan, predicates);
        return predicates;
    }

    /** Adds the conjuncts of the join's on-clause condition, if present, to {@code predicates}. */
    private void addOnClauseConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join, Set<Expression> predicates) {
        join.getOnClauseCondition()
                .ifPresent(on -> predicates.addAll(ExpressionUtils.extractConjunction(on)));
    }

    /**
     * Expands {@code predicates} with derived equivalences, then keeps only those whose
     * input slots are all produced by {@code plan}.
     */
    private ImmutableSet<Expression> getAvailableExpressions(Collection<Expression> predicates, Plan plan) {
        Set<Expression> expressions = Sets.newHashSet(predicates);
        expressions.addAll(propagation.infer(expressions));
        return expressions.stream()
                .filter(p -> plan.getOutputSet().containsAll(p.getInputSlots()))
                .collect(ImmutableSet.toImmutableSet());
    }

    /** Returns true if the expression contains any aggregate function. */
    private boolean hasAgg(Expression expression) {
        return expression.anyMatch(AggregateFunction.class::isInstance);
    }
}

View File

@ -139,6 +139,10 @@ public abstract class Expression extends AbstractTreeNode<Expression> implements
return this instanceof NullLiteral;
}
/**
 * Returns true if this expression is a {@code Slot} (e.g. a slot reference),
 * i.e. a direct column/attribute rather than a computed expression.
 */
public boolean isSlot() {
    return this instanceof Slot;
}
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@ -258,4 +258,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
/**
 * Returns a copy of this join with the given join type; hash/other conjuncts,
 * children and the reorder context are carried over unchanged.
 */
public LogicalJoin withJoinType(JoinType joinType) {
    return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, left(), right(), joinReorderContext);
}
/**
 * Returns a copy of this join with the given non-hash (other) join conjuncts;
 * join type, hash conjuncts, children and the reorder context are unchanged.
 */
public LogicalJoin withOtherJoinConjuncts(List<Expression> otherJoinConjuncts) {
    return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, left(), right(),
            joinReorderContext);
}
}

View File

@ -0,0 +1,531 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
import org.junit.jupiter.api.Test;
public class InferPredicatesTest extends TestWithFeService implements PatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
createTable("create table test.student (\n"
+ "id int not null,\n"
+ "name varchar(128),\n"
+ "age int,sex int)\n"
+ "distributed by hash(id) buckets 10\n"
+ "properties('replication_num' = '1');");
createTable("create table test.score (\n"
+ "sid int not null, \n"
+ "cid int not null, \n"
+ "grade double)\n"
+ "distributed by hash(sid,cid) buckets 10\n"
+ "properties('replication_num' = '1');");
createTable("create table test.course (\n"
+ "id int not null, \n"
+ "name varchar(128), \n"
+ "teacher varchar(128))\n"
+ "distributed by hash(id) buckets 10\n"
+ "properties('replication_num' = '1');");
createTables("create table test.subquery1\n"
+ "(k1 bigint, k2 bigint)\n"
+ "duplicate key(k1)\n"
+ "distributed by hash(k2) buckets 1\n"
+ "properties('replication_num' = '1');\n",
"create table test.subquery2\n"
+ "(k1 varchar(10), k2 bigint)\n"
+ "partition by range(k2)\n"
+ "(partition p1 values less than(\"10\"))\n"
+ "distributed by hash(k2) buckets 1\n"
+ "properties('replication_num' = '1');",
"create table test.subquery3\n"
+ "(k1 int not null, k2 varchar(128), k3 bigint, v1 bigint, v2 bigint)\n"
+ "distributed by hash(k2) buckets 1\n"
+ "properties('replication_num' = '1');",
"create table test.subquery4\n"
+ "(k1 bigint, k2 bigint)\n"
+ "duplicate key(k1)\n"
+ "distributed by hash(k2) buckets 1\n"
+ "properties('replication_num' = '1');");
connectContext.setDatabase("default_cluster:test");
}
@Test
public void inferPredicatesTest01() {
String sql = "select * from student join score on student.id = score.sid where student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
);
}
@Test
public void inferPredicatesTest02() {
String sql = "select * from student join score on student.id = score.sid";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
);
}
@Test
public void inferPredicatesTest03() {
String sql = "select * from student join score on student.id = score.sid where student.id in (1,2,3)";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
)
);
}
@Test
public void inferPredicatesTest04() {
String sql = "select * from student join score on student.id = score.sid and student.id in (1,2,3)";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id IN (1, 2, 3)")),
logicalOlapScan()
)
);
}
@Test
public void inferPredicatesTest05() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id where student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1"))
)
);
}
@Test
public void inferPredicatesTest06() {
String sql = "select * from student join score on student.id = score.sid join course on score.sid = course.id and score.sid > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1"))
)
);
}
@Test
public void inferPredicatesTest07() {
String sql = "select * from student left join score on student.id = score.sid where student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
);
}
@Test
public void inferPredicatesTest08() {
String sql = "select * from student left join score on student.id = score.sid and student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
);
}
@Test
public void inferPredicatesTest09() {
// convert left join to inner join
String sql = "select * from student left join score on student.id = score.sid where score.sid > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
);
}
@Test
public void inferPredicatesTest10() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid where t.nid > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("id > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
);
}
@Test
public void inferPredicatesTest11() {
String sql = "select * from (select id as nid, name from student) t left join score on t.nid = score.sid and t.nid > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalOlapScan()
),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
);
}
@Test
public void inferPredicatesTest12() {
String sql = "select * from student left join (select sid as nid, sum(grade) from score group by sid) s on s.nid = student.id where student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("id > 1")),
logicalProject(
logicalAggregate(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
))
)
)
);
}
@Test
public void inferPredicatesTest13() {
String sql = "select * from (select id, name from student where id = 1) t left join score on t.id = score.sid";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("id = 1"))
),
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid = 1"))
)
);
}
@Test
public void inferPredicatesTest14() {
String sql = "select * from student left semi join score on student.id = score.sid where student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
)
);
}
@Test
public void inferPredicatesTest15() {
String sql = "select * from student left semi join score on student.id = score.sid and student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filer -> filer.getPredicates().toSql().contains("sid > 1"))
)
)
);
}
@Test
public void inferPredicatesTest16() {
String sql = "select * from student left anti join score on student.id = score.sid and student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
)
)
);
}
@Test
public void inferPredicatesTest17() {
String sql = "select * from student left anti join score on student.id = score.sid and score.sid > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalOlapScan(),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
)
)
);
}
@Test
public void inferPredicatesTest18() {
String sql = "select * from student left anti join score on student.id = score.sid where student.id > 1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
)
)
);
}
@Test
public void inferPredicatesTest19() {
String sql = "select * from subquery1 left semi join (select t1.k3 from (select * from subquery3 left semi join (select k1 from subquery4 where k1 = 3) t on subquery3.k3 = t.k1) t1 inner join (select k2,sum(k2) as sk2 from subquery2 group by k2) t2 on t2.k2 = t1.v1 and t1.v2 > t2.sk2) t3 on t3.k3 = subquery1.k1";
Plan plan = PlanChecker.from(connectContext).analyze(sql).rewrite().getPlan();
System.out.println(plan.treeString());
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("k1 = 3")),
logicalProject(
logicalJoin(
logicalJoin(
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("k3 = 3"))
),
logicalProject(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("k1 = 3"))
)
),
logicalProject()
)
)
)
);
}
@Test
public void inferPredicatesTest20() {
String sql = "select * from student left join score on student.id = score.sid and score.sid > 1 inner join course on course.id = score.sid";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalJoin(
logicalOlapScan(),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
),
logicalOlapScan()
)
);
}
@Test
public void inferPredicatesTest21() {
String sql = "select * from student,score,course where student.id = score.sid and score.sid = course.id and score.sid > 1";
PlanChecker.from(connectContext).analyze(sql).rewrite().printlnTree();
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matchesFromRoot(
logicalJoin(
logicalJoin(
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1")),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("sid > 1"))
),
logicalFilter(
logicalOlapScan()
).when(filter -> filter.getPredicates().toSql().contains("id > 1"))
)
);
}
}