[refactor](Nereids): refactor UT of Nereids (#11330)

refactor the UT of Nereids.

Extract the plan constructor (This PR extract all olapscan and table into PlanConstructor).
This commit is contained in:
jakevin
2022-08-01 22:53:00 +08:00
committed by GitHub
parent 4ccdd65bf6
commit 80ce027ea2
23 changed files with 187 additions and 174 deletions

View File

@ -348,7 +348,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
throw new RuntimeException("Physical hash join could not execute without equal join condition.");
} else {
Expression eqJoinExpression = hashJoin.getCondition().get();
List<Expr> execEqConjunctList = ExpressionUtils.extractConjunctive(eqJoinExpression).stream()
List<Expr> execEqConjunctList = ExpressionUtils.extractConjunction(eqJoinExpression).stream()
.map(EqualTo.class::cast)
.map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput()))
.map(e -> ExpressionTranslator.translate(e, context))
@ -400,7 +400,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
PlanFragment inputFragment = filter.child(0).accept(this, context);
PlanNode planNode = inputFragment.getPlanRoot();
Expression expression = filter.getPredicates();
List<Expression> expressionList = ExpressionUtils.extractConjunctive(expression);
List<Expression> expressionList = ExpressionUtils.extractConjunction(expression);
expressionList.stream().map(e -> ExpressionTranslator.translate(e, context)).forEach(planNode::addConjunct);
return inputFragment;
}

View File

@ -80,7 +80,7 @@ public class JoinLAsscom extends OneExplorationRuleFactory {
// Ignore join with some OnClause like:
// Join C = B + A for above example.
List<Expression> topJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(topJoinOnClause);
List<Expression> topJoinOnClauseConjuncts = ExpressionUtils.extractConjunction(topJoinOnClause);
for (Expression topJoinOnClauseConjunct : topJoinOnClauseConjuncts) {
if (ExpressionUtils.isIntersecting(topJoinOnClauseConjunct.collect(SlotReference.class::isInstance),
aOutputSlots)
@ -94,7 +94,7 @@ public class JoinLAsscom extends OneExplorationRuleFactory {
return null;
}
}
List<Expression> bottomJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(bottomJoinOnClause);
List<Expression> bottomJoinOnClauseConjuncts = ExpressionUtils.extractConjunction(bottomJoinOnClause);
List<Expression> allOnCondition = Lists.newArrayList();
allOnCondition.addAll(topJoinOnClauseConjuncts);

View File

@ -87,7 +87,7 @@ public class JoinProjectLAsscom extends OneExplorationRuleFactory {
// Ignore join with some OnClause like:
// Join C = B + A for above example.
List<Expression> topJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(topJoinOnClause);
List<Expression> topJoinOnClauseConjuncts = ExpressionUtils.extractConjunction(topJoinOnClause);
for (Expression topJoinOnClauseConjunct : topJoinOnClauseConjuncts) {
if (ExpressionUtils.isIntersecting(
topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), aOutputSlots)
@ -101,7 +101,7 @@ public class JoinProjectLAsscom extends OneExplorationRuleFactory {
return null;
}
}
List<Expression> bottomJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(
List<Expression> bottomJoinOnClauseConjuncts = ExpressionUtils.extractConjunction(
bottomJoinOnClause);
List<Expression> allOnCondition = Lists.newArrayList();

View File

@ -168,7 +168,7 @@ public class MultiJoin extends PlanVisitor<Void, Void> {
public Void visitLogicalFilter(LogicalFilter<Plan> filter, Void context) {
Plan child = filter.child();
if (child instanceof LogicalJoin) {
conjuncts.addAll(ExpressionUtils.extractConjunctive(filter.getPredicates()));
conjuncts.addAll(ExpressionUtils.extractConjunction(filter.getPredicates()));
}
child.accept(this, context);
@ -184,7 +184,7 @@ public class MultiJoin extends PlanVisitor<Void, Void> {
join.left().accept(this, context);
join.right().accept(this, context);
join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunctive(cond)));
join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunction(cond)));
if (!(join.left() instanceof LogicalJoin)) {
joinInputs.add(join.left());
}

View File

@ -77,7 +77,7 @@ public class PushPredicateThroughAggregation extends OneRewriteRuleFactory {
}
List<Expression> pushDownPredicates = Lists.newArrayList();
List<Expression> filterPredicates = Lists.newArrayList();
ExpressionUtils.extractConjunctive(filter.getPredicates()).forEach(conjunct -> {
ExpressionUtils.extractConjunction(filter.getPredicates()).forEach(conjunct -> {
Set<Slot> conjunctSlots = SlotExtractor.extractSlot(conjunct);
if (groupBySlots.containsAll(conjunctSlots)) {
pushDownPredicates.add(conjunct);

View File

@ -79,7 +79,7 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
List<Slot> leftInput = join.left().getOutput();
List<Slot> rightInput = join.right().getOutput();
ExpressionUtils.extractConjunctive(ExpressionUtils.and(onPredicates, wherePredicates))
ExpressionUtils.extractConjunction(ExpressionUtils.and(onPredicates, wherePredicates))
.forEach(predicate -> {
if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
eqConditions.add(predicate);

View File

@ -44,7 +44,7 @@ public class JoinEstimation {
Expression eqCondition, JoinType joinType) {
StatsDeriveResult statsDeriveResult = new StatsDeriveResult(leftStats);
statsDeriveResult.merge(rightStats);
List<Expression> eqConjunctList = ExpressionUtils.extractConjunctive(eqCondition);
List<Expression> eqConjunctList = ExpressionUtils.extractConjunction(eqCondition);
long rowCount = -1;
if (joinType.isSemiOrAntiJoin()) {
rowCount = getSemiJoinRowCount(leftStats, rightStats, eqConjunctList, joinType);

View File

@ -37,11 +37,11 @@ import java.util.Set;
*/
public class ExpressionUtils {
public static List<Expression> extractConjunctive(Expression expr) {
public static List<Expression> extractConjunction(Expression expr) {
return extract(And.class, expr);
}
public static List<Expression> extractDisjunctive(Expression expr) {
public static List<Expression> extractDisjunction(Expression expr) {
return extract(Or.class, expr);
}

View File

@ -59,7 +59,7 @@ public class JoinUtils {
.collect(Collectors.toList());
Expression onCondition = join.getCondition().get();
List<Expression> conjunctList = ExpressionUtils.extractConjunctive(onCondition);
List<Expression> conjunctList = ExpressionUtils.extractConjunction(onCondition);
for (Expression predicate : conjunctList) {
if (isEqualTo(leftSlots, rightSlots, predicate)) {
eqConjuncts.add((EqualTo) predicate);

View File

@ -17,10 +17,7 @@
package org.apache.doris.nereids.jobs;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.memo.Group;
@ -38,6 +35,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
@ -52,14 +50,9 @@ public class RewriteTopDownJobTest {
public static class FakeRule extends OneRewriteRuleFactory {
@Override
public Rule build() {
return unboundRelation().then(unboundRelation -> {
Table olapTable = new Table(0, "test", TableType.OLAP, ImmutableList.of(
new Column("id", Type.INT),
new Column("name", Type.STRING)
));
return new LogicalBoundRelation(olapTable, Lists.newArrayList("test"));
}
).toRule(RuleType.BINDING_RELATION);
return unboundRelation().then(unboundRelation ->
new LogicalBoundRelation(PlanConstructor.newTable(0L, "test"), Lists.newArrayList("test"))
).toRule(RuleType.BINDING_RELATION);
}
}

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.jobs.cascades;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.memo.Memo;
@ -34,6 +33,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.Statistics;
@ -42,6 +42,7 @@ import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
@ -84,14 +85,11 @@ public class DeriveStatsJobTest {
columnStats1.setNdv(10);
columnStats1.setNumNulls(5);
long tableId1 = 0;
String tableName1 = "t1";
TableStats tableStats1 = new TableStats();
tableStats1.putColumnStats("c1", columnStats1);
Statistics statistics = new Statistics();
statistics.putTableStats(tableId1, tableStats1);
List<String> qualifier = new ArrayList<>();
qualifier.add("test");
qualifier.add("t");
List<String> qualifier = ImmutableList.of("test", "t");
slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
new Expectations() {{
ConnectContext.get();
@ -104,7 +102,7 @@ public class DeriveStatsJobTest {
result = statistics;
}};
Table table1 = new Table(tableId1, tableName1, TableType.OLAP, Collections.emptyList());
Table table1 = PlanConstructor.newTable(tableId1, "t1");
return new LogicalOlapScan(table1, Collections.emptyList()).withLogicalProperties(
Optional.of(new LogicalProperties(new Supplier<List<Slot>>() {
@Override

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@ -31,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import mockit.Mocked;
@ -43,13 +40,8 @@ import java.util.Optional;
public class JoinCommuteTest {
@Test
public void testInnerJoinCommute(@Mocked PlannerContext plannerContext) {
Table table1 = new Table(0L, "table1", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan1 = new LogicalOlapScan(table1, ImmutableList.of());
Table table2 = new Table(0L, "table2", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan2 = new LogicalOlapScan(table2, ImmutableList.of());
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan("t2");
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan("t2");
Expression onCondition = new EqualTo(
new SlotReference("id", new BigIntType(), true, ImmutableList.of("table1")),
@ -66,5 +58,4 @@ public class JoinCommuteTest {
Assertions.assertEquals(join.child(0), newJoin.child(1));
Assertions.assertEquals(join.child(1), newJoin.child(0));
}
}

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.Rule;
@ -31,8 +27,8 @@ import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
@ -50,20 +46,10 @@ public class JoinLAsscomTest {
@BeforeAll
public static void init() {
Table t1 = new Table(0L, "t1", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan1 = new LogicalOlapScan(t1, ImmutableList.of());
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScanWithTable("t1");
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScanWithTable("t2");
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScanWithTable("t3");
Table t2 = new Table(0L, "t2", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan2 = new LogicalOlapScan(t2, ImmutableList.of());
Table t3 = new Table(0L, "t3", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan3 = new LogicalOlapScan(t3, ImmutableList.of());
scans.add(scan1);
scans.add(scan2);
scans.add(scan3);

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.rules.Rule;
@ -34,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -48,25 +45,15 @@ import java.util.stream.Collectors;
public class JoinProjectLAsscomTest {
private static List<LogicalOlapScan> scans = Lists.newArrayList();
private static List<List<SlotReference>> outputs = Lists.newArrayList();
private static final List<LogicalOlapScan> scans = Lists.newArrayList();
private static final List<List<SlotReference>> outputs = Lists.newArrayList();
@BeforeAll
public static void init() {
Table t1 = new Table(0L, "t1", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan1 = new LogicalOlapScan(t1, ImmutableList.of());
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScanWithTable("t1");
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScanWithTable("t2");
LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScanWithTable("t3");
Table t2 = new Table(0L, "t2", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan2 = new LogicalOlapScan(t2, ImmutableList.of());
Table t3 = new Table(0L, "t3", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "0", "")));
LogicalOlapScan scan3 = new LogicalOlapScan(t3, ImmutableList.of());
scans.add(scan1);
scans.add(scan2);
scans.add(scan3);
@ -77,6 +64,7 @@ public class JoinProjectLAsscomTest {
.collect(Collectors.toList());
List<SlotReference> t3Output = scan3.getOutput().stream().map(slot -> (SlotReference) slot)
.collect(Collectors.toList());
outputs.add(t1Output);
outputs.add(t2Output);
outputs.add(t3Output);

View File

@ -19,27 +19,67 @@ package org.apache.doris.nereids.rules.implementation;
import org.apache.doris.nereids.PlannerContext;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
public class LogicalProjectToPhysicalProjectTest {
@Test
public void projectionImplTest(@Mocked Group group, @Mocked PlannerContext plannerContext) {
Plan plan = new LogicalProject(Lists.newArrayList(), new GroupPlan(group));
Rule rule = new LogicalProjectToPhysicalProject().build();
List<Plan> transform = rule.transform(plan, plannerContext);
private final Map<String, Rule> rulesMap
= ImmutableMap.<String, Rule>builder()
.put(LogicalProject.class.getName(), (new LogicalProjectToPhysicalProject()).build())
.put(LogicalAggregate.class.getName(), (new LogicalAggToPhysicalHashAgg()).build())
.put(LogicalJoin.class.getName(), (new LogicalJoinToHashJoin()).build())
.put(LogicalOlapScan.class.getName(), (new LogicalOlapScanToPhysicalOlapScan()).build())
.put(LogicalFilter.class.getName(), (new LogicalFilterToPhysicalFilter()).build())
.put(LogicalSort.class.getName(), (new LogicalSortToPhysicalHeapSort()).build())
.build();
private PhysicalPlan rewriteLogicalToPhysical(Group group, PlannerContext plannerContext) {
List<Plan> children = Lists.newArrayList();
for (Group child : group.getLogicalExpression().children()) {
children.add(rewriteLogicalToPhysical(child, plannerContext));
}
Rule rule = rulesMap.get(group.getLogicalExpression().getPlan().getClass().getName());
List<Plan> transform = rule.transform(group.getLogicalExpression().getPlan(), plannerContext);
Assertions.assertEquals(1, transform.size());
Plan implPlan = transform.get(0);
Assertions.assertEquals(PlanType.PHYSICAL_PROJECT, implPlan.getType());
Assertions.assertTrue(transform.get(0) instanceof PhysicalPlan);
PhysicalPlan implPlanNode = (PhysicalPlan) transform.get(0);
return (PhysicalPlan) implPlanNode.withChildren(children);
}
@Test
public void projectionImplTest() {
LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan("a");
LogicalPlan project = new LogicalProject<>(Lists.newArrayList(), scan);
PlannerContext plannerContext = new Memo(project)
.newPlannerContext(new ConnectContext())
.setDefaultJobContext();
PhysicalPlan physicalProject = rewriteLogicalToPhysical(plannerContext.getMemo().getRoot(), plannerContext);
Assertions.assertEquals(PlanType.PHYSICAL_PROJECT, physicalProject.getType());
PhysicalPlan physicalScan = (PhysicalPlan) physicalProject.child(0);
Assertions.assertEquals(PlanType.PHYSICAL_OLAP_SCAN, physicalScan.getType());
}
}

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -34,6 +30,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
@ -52,11 +49,7 @@ public class AggregateDisassembleTest {
@BeforeAll
public final void beforeAll() {
Table student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, true, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, true, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, true, "", "")));
rStudent = new LogicalOlapScan(student, ImmutableList.of("student"));
rStudent = new LogicalOlapScan(PlanConstructor.student, ImmutableList.of("student"));
}
/**

View File

@ -17,10 +17,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionNormalization;
@ -41,6 +37,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
@ -59,9 +56,7 @@ import java.util.Optional;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class PushDownPredicateTest {
private Table student;
private Table score;
private Table course;
private Plan rStudent;
private Plan rScore;
@ -72,26 +67,11 @@ public class PushDownPredicateTest {
*/
@BeforeAll
public final void beforeAll() {
student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
rStudent = new LogicalOlapScan(PlanConstructor.student, ImmutableList.of("student"));
score = new Table(0L, "score", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("sid", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("cid", Type.INT, true, AggregateType.NONE, "", ""),
new Column("grade", Type.DOUBLE, true, AggregateType.NONE, "", "")));
rScore = new LogicalOlapScan(PlanConstructor.score, ImmutableList.of("score"));
course = new Table(0L, "course", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("cid", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("teacher", Type.STRING, true, AggregateType.NONE, "", "")));
rStudent = new LogicalOlapScan(student, ImmutableList.of("student"));
rScore = new LogicalOlapScan(score, ImmutableList.of("score"));
rCourse = new LogicalOlapScan(course, ImmutableList.of("course"));
rCourse = new LogicalOlapScan(PlanConstructor.course, ImmutableList.of("course"));
}
@Test

View File

@ -18,10 +18,6 @@
package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
@ -40,6 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.qe.ConnectContext;
@ -73,12 +70,7 @@ public class PushDownPredicateThroughAggregationTest {
*/
@Test
public void pushDownPredicateOneFilterTest() {
Table student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
Plan scan = new LogicalOlapScan(student, ImmutableList.of("student"));
Plan scan = new LogicalOlapScan(PlanConstructor.student, ImmutableList.of("student"));
Slot gender = scan.getOutput().get(1);
Slot age = scan.getOutput().get(3);
@ -138,12 +130,7 @@ public class PushDownPredicateThroughAggregationTest {
*/
@Test
public void pushDownPredicateTwoFilterTest() {
Table student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
Plan scan = new LogicalOlapScan(student, ImmutableList.of("student"));
Plan scan = new LogicalOlapScan(PlanConstructor.student, ImmutableList.of("student"));
Slot gender = scan.getOutput().get(1);
Slot name = scan.getOutput().get(2);
Slot age = scan.getOutput().get(3);

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.stats;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
@ -39,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStats;
import org.apache.doris.statistics.Statistics;
@ -47,6 +47,7 @@ import org.apache.doris.statistics.StatsDeriveResult;
import org.apache.doris.statistics.TableStats;
import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
import mockit.Expectations;
import mockit.Mocked;
import org.junit.jupiter.api.Assertions;
@ -202,14 +203,11 @@ public class StatsCalculatorTest {
columnStats1.setNdv(10);
columnStats1.setNumNulls(5);
long tableId1 = 0;
String tableName1 = "t1";
TableStats tableStats1 = new TableStats();
tableStats1.putColumnStats("c1", columnStats1);
Statistics statistics = new Statistics();
statistics.putTableStats(tableId1, tableStats1);
List<String> qualifier = new ArrayList<>();
qualifier.add("test");
qualifier.add("t");
List<String> qualifier = ImmutableList.of("test", "t");
SlotReference slot1 = new SlotReference("c1", IntegerType.INSTANCE, true, qualifier);
new Expectations() {{
ConnectContext.get();
@ -222,7 +220,7 @@ public class StatsCalculatorTest {
result = statistics;
}};
Table table1 = new Table(tableId1, tableName1, TableType.OLAP, Collections.emptyList());
Table table1 = PlanConstructor.newTable(tableId1, "t1");
LogicalOlapScan logicalOlapScan1 = new LogicalOlapScan(table1, Collections.emptyList()).withLogicalProperties(
Optional.of(new LogicalProperties(new Supplier<List<Slot>>() {
@Override

View File

@ -18,8 +18,6 @@
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.TableIf.TableType;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@ -40,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalHeapSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -111,11 +110,10 @@ public class PlanEqualsTest {
@Test
public void testLogicalOlapScan() {
LogicalOlapScan olapScan = new LogicalOlapScan(new Table(TableType.OLAP), Lists.newArrayList());
Assertions.assertEquals(olapScan, olapScan);
LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan("table");
LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan("table");
LogicalOlapScan olapScan1 = new LogicalOlapScan(new Table(TableType.OLAP), Lists.newArrayList());
Assertions.assertEquals(olapScan, olapScan1);
Assertions.assertEquals(scan1, scan2);
}
@Test

View File

@ -17,20 +17,17 @@
package org.apache.doris.nereids.trees.plans;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRelation;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
@ -39,14 +36,10 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Optional;
public class TestPlanOutput {
public class PlanOutputTest {
@Test
public void testComputeOutput() {
Table table = new Table(0L, "a", Table.TableType.OLAP, ImmutableList.<Column>of(
new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", "")
));
LogicalRelation relationPlan = new LogicalOlapScan(table, ImmutableList.of("db"));
LogicalOlapScan relationPlan = PlanConstructor.newLogicalOlapScanWithTable("a");
List<Slot> output = relationPlan.getOutput();
Assertions.assertEquals(2, output.size());
Assertions.assertEquals(output.get(0).getName(), "id");
@ -74,11 +67,7 @@ public class TestPlanOutput {
@Test
public void testWithOutput() {
Table table = new Table(0L, "a", Table.TableType.OLAP, ImmutableList.of(
new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", "")
));
LogicalRelation relationPlan = new LogicalOlapScan(table, ImmutableList.of("db"));
LogicalOlapScan relationPlan = PlanConstructor.newLogicalOlapScanWithTable("a");
List<Slot> output = relationPlan.getOutput();
// column prune

View File

@ -33,12 +33,12 @@ public class ExpressionUtilsTest {
private static final NereidsParser PARSER = new NereidsParser();
@Test
public void extractConjunctsTest() {
public void extractConjunctionTest() {
List<Expression> expressions;
Expression expr;
expr = PARSER.parseExpression("a");
expressions = ExpressionUtils.extractConjunctive(expr);
expressions = ExpressionUtils.extractConjunction(expr);
Assertions.assertEquals(1, expressions.size());
Assertions.assertEquals(expr, expressions.get(0));
@ -47,7 +47,7 @@ public class ExpressionUtilsTest {
Expression b = PARSER.parseExpression("b");
Expression c = PARSER.parseExpression("c");
expressions = ExpressionUtils.extractConjunctive(expr);
expressions = ExpressionUtils.extractConjunction(expr);
Assertions.assertEquals(3, expressions.size());
Assertions.assertEquals(a, expressions.get(0));
Assertions.assertEquals(b, expressions.get(1));
@ -55,7 +55,7 @@ public class ExpressionUtilsTest {
expr = PARSER.parseExpression("(a or b) and c and (e or f)");
expressions = ExpressionUtils.extractConjunctive(expr);
expressions = ExpressionUtils.extractConjunction(expr);
Expression aOrb = PARSER.parseExpression("a or b");
Expression eOrf = PARSER.parseExpression("e or f");
Assertions.assertEquals(3, expressions.size());
@ -65,12 +65,12 @@ public class ExpressionUtilsTest {
}
@Test
public void extractDisjunctsTest() {
public void extractDisjunctionTest() {
List<Expression> expressions;
Expression expr;
expr = PARSER.parseExpression("a");
expressions = ExpressionUtils.extractDisjunctive(expr);
expressions = ExpressionUtils.extractDisjunction(expr);
Assertions.assertEquals(1, expressions.size());
Assertions.assertEquals(expr, expressions.get(0));
@ -79,14 +79,14 @@ public class ExpressionUtilsTest {
Expression b = PARSER.parseExpression("b");
Expression c = PARSER.parseExpression("c");
expressions = ExpressionUtils.extractDisjunctive(expr);
expressions = ExpressionUtils.extractDisjunction(expr);
Assertions.assertEquals(3, expressions.size());
Assertions.assertEquals(a, expressions.get(0));
Assertions.assertEquals(b, expressions.get(1));
Assertions.assertEquals(c, expressions.get(2));
expr = PARSER.parseExpression("(a and b) or c or (e and f)");
expressions = ExpressionUtils.extractDisjunctive(expr);
expressions = ExpressionUtils.extractDisjunction(expr);
Expression aAndb = PARSER.parseExpression("a and b");
Expression eAndf = PARSER.parseExpression("e and f");
Assertions.assertEquals(3, expressions.size());

View File

@ -0,0 +1,72 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.util;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import com.google.common.collect.ImmutableList;
public class PlanConstructor {
public static Table student = new Table(0L, "student", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("gender", Type.INT, false, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("age", Type.INT, true, AggregateType.NONE, "", "")));
public static Table score = new Table(0L, "score", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("sid", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("cid", Type.INT, true, AggregateType.NONE, "", ""),
new Column("grade", Type.DOUBLE, true, AggregateType.NONE, "", "")));
public static Table course = new Table(0L, "course", Table.TableType.OLAP,
ImmutableList.<Column>of(new Column("cid", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", ""),
new Column("teacher", Type.STRING, true, AggregateType.NONE, "", "")));
public static OlapTable newOlapTable(long tableId, String tableName) {
return new OlapTable(0L, tableName,
ImmutableList.of(
new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", "")),
KeysType.PRIMARY_KEYS, null, null);
}
public static Table newTable(long tableId, String tableName) {
return new Table(tableId, tableName, Table.TableType.OLAP,
ImmutableList.<Column>of(
new Column("id", Type.INT, true, AggregateType.NONE, "0", ""),
new Column("name", Type.STRING, true, AggregateType.NONE, "", "")
));
}
// With OlapTable
public static LogicalOlapScan newLogicalOlapScan(String tableName) {
return new LogicalOlapScan(newOlapTable(0L, tableName), ImmutableList.of("db"));
}
// With Table
public static LogicalOlapScan newLogicalOlapScanWithTable(String tableName) {
return new LogicalOlapScan(newTable(0L, tableName), ImmutableList.of("db"));
}
}