[feature](Nereids): optimize logical group expression in dphyp (#30000)
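In short: DPHyp's PlanReceiver stops enumerating PhysicalHashJoin/PhysicalNestedLoopJoin alternatives itself. It now copies a LogicalJoin (optionally wrapped in a LogicalProject) into the memo and schedules OptimizeGroupExpressionJob on it, the DPHyp reorder rule set gains the materialized-view rules, and the isOtherJoinReorder flag and the makeLogicalExpression pass are removed. A rough sketch of the core idea (names taken from the diff below, not the literal patch):

    // Before: the receiver built physical joins directly and relied on a later
    // makeLogicalExpression pass to re-create logical expressions for rule matching.
    List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, left, right, hashConjuncts, otherConjuncts);

    // After: it emits a single logical join and lets OptimizeGroupExpressionJob apply the rules
    // (DPHYP_REORDER_RULES now also carries MATERIALIZED_VIEW_RULES).
    LogicalPlan logicalPlan = proposeJoin(joinType, left, right, hashConjuncts, otherConjuncts);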

谢健
2024-01-16 17:16:41 +08:00
committed by yiguolei
parent d11e797d4c
commit 4bf4239d7a
13 changed files with 1769 additions and 312 deletions

View File

@@ -72,7 +72,6 @@ public class StatementContext {
private int maxNAryInnerJoin = 0;
private boolean isDpHyp = false;
private boolean isOtherJoinReorder = false;
// hasUnknownColStats is true if any column stats in the tables used by this sql are unknown
// the algorithm to derive plans when column stats are unknown is implemented in the cascades framework, not in dphyper.
@@ -158,14 +157,6 @@ public class StatementContext {
isDpHyp = dpHyp;
}
public boolean isOtherJoinReorder() {
return isOtherJoinReorder;
}
public void setOtherJoinReorder(boolean otherJoinReorder) {
isOtherJoinReorder = otherJoinReorder;
}
public ExprId getNextExprId() {
return exprIdGenerator.getNextId();
}

View File

@@ -75,7 +75,6 @@ public class OptimizeGroupExpressionJob extends Job {
|| context.getCascadesContext().getMemo().getGroupExpressionsSize() > context.getCascadesContext()
.getConnectContext().getSessionVariable().memoMaxGroupExpressionSize;
boolean isDpHyp = context.getCascadesContext().getStatementContext().isDpHyp();
boolean isOtherJoinReorder = context.getCascadesContext().getStatementContext().isOtherJoinReorder();
boolean isEnableBushyTree = context.getCascadesContext().getConnectContext().getSessionVariable()
.isEnableBushyTree();
boolean isLeftZigZagTree = context.getCascadesContext().getConnectContext()
@@ -86,11 +85,7 @@ public class OptimizeGroupExpressionJob extends Job {
if (isDisableJoinReorder) {
return Collections.emptyList();
} else if (isDpHyp) {
if (isOtherJoinReorder) {
return getRuleSet().getDPHypReorderRules();
} else {
return Collections.emptyList();
}
return getRuleSet().getDPHypReorderRules();
} else if (isLeftZigZagTree) {
return getRuleSet().getLeftZigZagTreeJoinReorder();
} else if (isEnableBushyTree) {
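Since the removed and added lines are interleaved above, the resulting branch reads as follows (reassembled from the kept lines; the enclosing method and the trailing branches are truncated in the hunk):

    if (isDisableJoinReorder) {
        return Collections.emptyList();
    } else if (isDpHyp) {
        // the isOtherJoinReorder gate is gone; the DPHyp reorder rules are offered unconditionally
        return getRuleSet().getDPHypReorderRules();
    } else if (isLeftZigZagTree) {
        return getRuleSet().getLeftZigZagTreeJoinReorder();
    } else if (isEnableBushyTree) {
        // ... remaining branches unchanged (not shown in the hunk)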

View File

@@ -67,7 +67,6 @@ public class Optimizer {
boolean isDpHyp = getSessionVariable().enableDPHypOptimizer
|| maxJoinCount > maxTableCount;
cascadesContext.getStatementContext().setDpHyp(isDpHyp);
cascadesContext.getStatementContext().setOtherJoinReorder(false);
if (!getSessionVariable().isDisableJoinReorder() && isDpHyp
&& maxJoinCount <= getSessionVariable().getMaxJoinNumberOfReorder()) {
// Right now, dphyper can only order 64 join operators
@@ -85,7 +84,6 @@ public class Optimizer {
// Due to EnsureProjectOnTopJoin, the root group can't be a Join group, so DPHyp doesn't change the root group
cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
cascadesContext.getStatementContext().setOtherJoinReorder(true);
}
private SessionVariable getSessionVariable() {
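For context, with the setOtherJoinReorder bookkeeping gone, the DPHyp dispatch in Optimizer boils down to the following (a sketch assembled from the two hunks above; surrounding code elided):

    boolean isDpHyp = getSessionVariable().enableDPHypOptimizer || maxJoinCount > maxTableCount;
    cascadesContext.getStatementContext().setDpHyp(isDpHyp);
    if (!getSessionVariable().isDisableJoinReorder() && isDpHyp
            && maxJoinCount <= getSessionVariable().getMaxJoinNumberOfReorder()) {
        // right now, dphyper can only order 64 join operators
        // due to EnsureProjectOnTopJoin, the root group can't be a Join group, so DPHyp doesn't change it
        cascadesContext.pushJob(new JoinOrderJob(root, cascadesContext.getCurrentJobContext()));
        cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
    }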

View File

@@ -17,10 +17,9 @@
package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.hint.DistributeHint;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.CostAndEnforcerJob;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupExpressionJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
@@ -29,40 +28,28 @@ import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.FunctionalDependencies;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -91,7 +78,6 @@ public class PlanReceiver implements AbstractReceiver {
this.finalOutputs = outputs;
}
/**
* Emit a new plan from bottom to top
* <p>
@@ -130,21 +116,18 @@ public class PlanReceiver implements AbstractReceiver {
}
long fullKey = LongBitmap.newBitmapUnion(left, right);
List<Plan> physicalJoins = proposeAllPhysicalJoins(joinType, leftPlan, rightPlan, hashConjuncts,
LogicalPlan logicalPlan = proposeJoin(joinType, leftPlan, rightPlan, hashConjuncts,
otherConjuncts);
List<Plan> physicalPlans = proposeProject(physicalJoins, edges, left, right);
logicalPlan = proposeProject(logicalPlan, edges, left, right);
// Second, we copy the logical plan into the group, then generate properties and calculate cost
if (!planTable.containsKey(fullKey)) {
planTable.put(fullKey, memo.newGroup(physicalPlans.get(0).getLogicalProperties()));
planTable.put(fullKey, memo.newGroup(logicalPlan.getLogicalProperties()));
}
Group group = planTable.get(fullKey);
for (Plan plan : physicalPlans) {
CopyInResult copyInResult = memo.copyIn(plan, group, false, planTable);
GroupExpression physicalExpression = copyInResult.correspondingExpression;
proposeAllDistributedPlans(physicalExpression);
}
CopyInResult copyInResult = memo.copyIn(logicalPlan, group, false, planTable);
proposeAllDistributedPlans(copyInResult.correspondingExpression);
return true;
}
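Reassembled from the added lines, the tail of emitCsgCmp now reads (sketch; the earlier part of the method is unchanged and elided):

    long fullKey = LongBitmap.newBitmapUnion(left, right);
    LogicalPlan logicalPlan = proposeJoin(joinType, leftPlan, rightPlan, hashConjuncts, otherConjuncts);
    logicalPlan = proposeProject(logicalPlan, edges, left, right);
    // copy the logical plan into the group; properties, stats and cost come from the pushed jobs
    if (!planTable.containsKey(fullKey)) {
        planTable.put(fullKey, memo.newGroup(logicalPlan.getLogicalProperties()));
    }
    Group group = planTable.get(fullKey);
    CopyInResult copyInResult = memo.copyIn(logicalPlan, group, false, planTable);
    proposeAllDistributedPlans(copyInResult.correspondingExpression);
    return true;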
@@ -204,7 +187,7 @@ public class PlanReceiver implements AbstractReceiver {
}
private void proposeAllDistributedPlans(GroupExpression groupExpression) {
jobContext.getCascadesContext().pushJob(new CostAndEnforcerJob(groupExpression,
jobContext.getCascadesContext().pushJob(new OptimizeGroupExpressionJob(groupExpression,
new JobContext(jobContext.getCascadesContext(), PhysicalProperties.ANY, Double.MAX_VALUE)));
if (!groupExpression.isStatDerived()) {
jobContext.getCascadesContext().pushJob(new DeriveStatsJob(groupExpression,
@@ -213,42 +196,16 @@ public class PlanReceiver implements AbstractReceiver {
jobContext.getCascadesContext().getJobScheduler().executeJobPool(jobContext.getCascadesContext());
}
private List<Plan> proposeAllPhysicalJoins(JoinType joinType, Plan left, Plan right, List<Expression> hashConjuncts,
private LogicalPlan proposeJoin(JoinType joinType, Plan left, Plan right, List<Expression> hashConjuncts,
List<Expression> otherConjuncts) {
// Check whether only NLJ (nested loop join) can be performed
LogicalProperties joinProperties = new LogicalProperties(
() -> JoinUtils.getJoinOutput(joinType, left, right), () -> FunctionalDependencies.EMPTY_FUNC_DEPS);
List<Plan> plans = Lists.newArrayList();
if (JoinUtils.shouldNestedLoopJoin(joinType, hashConjuncts)) {
plans.add(new PhysicalNestedLoopJoin<>(joinType, hashConjuncts, otherConjuncts,
Optional.empty(), joinProperties,
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalNestedLoopJoin<>(joinType.swap(), hashConjuncts, otherConjuncts, Optional.empty(),
joinProperties,
right, left));
}
} else {
plans.add(new PhysicalHashJoin<>(joinType, hashConjuncts, otherConjuncts,
new DistributeHint(DistributeType.NONE), Optional.empty(),
joinProperties,
left, right));
if (joinType.isSwapJoinType()) {
plans.add(new PhysicalHashJoin<>(joinType.swap(), hashConjuncts, otherConjuncts,
new DistributeHint(DistributeType.NONE),
Optional.empty(),
joinProperties,
right, left));
}
}
return plans;
return new LogicalJoin<>(joinType, hashConjuncts, otherConjuncts, left, right);
}
@Override
public void addGroup(long bitmap, Group group) {
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1);
usdEdges.put(bitmap, new BitSet());
Plan plan = proposeProject(Lists.newArrayList(new GroupPlan(group)), new ArrayList<>(), bitmap, bitmap).get(0);
Plan plan = proposeProject(new GroupPlan(group), new ArrayList<>(), bitmap, bitmap);
if (!(plan instanceof GroupPlan)) {
CopyInResult copyInResult = jobContext.getCascadesContext().getMemo().copyIn(plan, null, false, planTable);
group = copyInResult.correspondingExpression.getOwnerGroup();
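proposeAllPhysicalJoins is replaced by a much smaller proposeJoin (reassembled below from the added lines). The hash-join versus nested-loop-join choice and the commuted variants that used to be enumerated here are presumably produced later by the rules applied through OptimizeGroupExpressionJob once the logical expression sits in the memo:

    private LogicalPlan proposeJoin(JoinType joinType, Plan left, Plan right, List<Expression> hashConjuncts,
            List<Expression> otherConjuncts) {
        return new LogicalJoin<>(joinType, hashConjuncts, otherConjuncts, left, right);
    }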
@@ -274,59 +231,13 @@ public class PlanReceiver implements AbstractReceiver {
@Override
public Group getBestPlan(long bitmap) {
// If there are some rules that rely on the logical join, we need to make the logical expression.
// However, it costs 15% of the total optimization time.
makeLogicalExpression(() -> planTable.get(bitmap));
return planTable.get(bitmap);
}
private void makeLogicalExpression(Supplier<Group> root) {
if (!root.get().getLogicalExpressions().isEmpty()) {
return;
}
// only makeLogicalExpression for those winners
Set<GroupExpression> hasGenerated = new HashSet<>();
for (PhysicalProperties physicalProperties : root.get().getAllProperties()) {
GroupExpression groupExpression = root.get().getBestPlan(physicalProperties);
if (hasGenerated.contains(groupExpression) || groupExpression.getPlan() instanceof PhysicalDistribute) {
continue;
}
hasGenerated.add(groupExpression);
// process children first, because the plan's children may be changed due to mergeGroup
// due to mergeGroup, the children Groups of groupExpression may be replaced, so we need to use a lambda to
// fetch the child lazily, so that we get the child at the moment we actually use it.
// If we iterate over groupExpression.children() directly, we take the children in advance, which may cause an NPE.
// work flow: get children() to get left, right -> copyIn left() -> mergeGroup -> right is merged -> NPE
Plan physicalPlan = groupExpression.getPlan();
for (int i = 0; i < groupExpression.children().size(); i++) {
int childIdx = i;
makeLogicalExpression(() -> groupExpression.child(childIdx));
}
Plan logicalPlan;
if (physicalPlan instanceof PhysicalProject) {
PhysicalProject physicalProject = (PhysicalProject) physicalPlan;
logicalPlan = new LogicalProject<>(physicalProject.getProjects(),
new GroupPlan(groupExpression.child(0)));
} else if (physicalPlan instanceof AbstractPhysicalJoin) {
AbstractPhysicalJoin physicalJoin = (AbstractPhysicalJoin) physicalPlan;
logicalPlan = new LogicalJoin<>(physicalJoin.getJoinType(), physicalJoin.getHashJoinConjuncts(),
physicalJoin.getOtherJoinConjuncts(),
new DistributeHint(DistributeType.NONE), physicalJoin.getMarkJoinSlotReference(),
groupExpression.children().stream().map(g -> new GroupPlan(g)).collect(Collectors.toList()));
} else {
throw new RuntimeException("DPhyp can only handle join and project operator");
}
jobContext.getCascadesContext().getMemo().copyIn(logicalPlan, root.get(), false, planTable);
}
}
private List<Plan> proposeProject(List<Plan> allChild, List<JoinEdge> edges, long left, long right) {
private LogicalPlan proposeProject(LogicalPlan join, List<JoinEdge> edges, long left, long right) {
long fullKey = LongBitmap.newBitmapUnion(left, right);
List<Slot> outputs = allChild.get(0).getOutput();
Set<Slot> outputSet = allChild.get(0).getOutputSet();
List<Slot> outputs = join.getOutput();
Set<Slot> outputSet = join.getOutputSet();
List<NamedExpression> complexProjects = new ArrayList<>();
// Calculating complex expressions should be done by the current (fullKey) node
@@ -354,40 +265,29 @@ public class PlanReceiver implements AbstractReceiver {
// calculate the columns required by all parents
Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges);
List<NamedExpression> allProjects = Stream.concat(
outputs.stream().filter(e -> requireSlots.contains(e)),
outputs.stream().filter(requireSlots::contains),
complexProjects.stream().filter(e -> requireSlots.contains(e.toSlot()))
).collect(Collectors.toList());
// propose physical project
// propose logical project
if (allProjects.isEmpty()) {
allProjects.add(ExpressionUtils.selectMinimumColumn(outputs));
}
if (outputSet.equals(new HashSet<>(allProjects))) {
return allChild;
return join;
}
Set<Slot> childOutputSet = allChild.get(0).getOutputSet();
Set<Slot> childOutputSet = join.getOutputSet();
List<NamedExpression> projects = allProjects.stream()
.filter(expr ->
childOutputSet.containsAll(expr.getInputSlots()))
.collect(Collectors.toList());
LogicalPlan project = join;
if (!outputSet.equals(new HashSet<>(projects))) {
LogicalProperties projectProperties = new LogicalProperties(
() -> projects.stream()
.map(NamedExpression::toSlot)
.collect(ImmutableList.toImmutableList()), () -> FunctionalDependencies.EMPTY_FUNC_DEPS);
allChild = allChild.stream()
.map(c -> new PhysicalProject<>(projects, projectProperties, c))
.collect(Collectors.toList());
}
if (!(!projects.isEmpty() && projects.size() == allProjects.size())) {
Set<NamedExpression> s1 = projects.stream().collect(Collectors.toSet());
List<NamedExpression> s2 = allProjects.stream().filter(e -> !s1.contains(e)).collect(Collectors.toList());
System.out.println(s2);
project = new LogicalProject<>(projects, join);
}
Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size(),
" there are some projects left " + projects + allProjects);
return allChild;
" there are some projects left %s %s", projects, allProjects);
return project;
}
}
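proposeProject now takes and returns a single LogicalPlan instead of a list of physical children. A sketch of its two call sites after this change (assembled from the hunks above):

    // in emitCsgCmp: wrap the freshly proposed logical join if extra projections are needed
    logicalPlan = proposeProject(logicalPlan, edges, left, right);

    // in addGroup: wrap a single-relation GroupPlan; only copy it in if a project was actually added
    Plan plan = proposeProject(new GroupPlan(group), new ArrayList<>(), bitmap, bitmap);
    if (!(plan instanceof GroupPlan)) {
        CopyInResult copyInResult = jobContext.getCascadesContext().getMemo().copyIn(plan, null, false, planTable);
        group = copyInResult.correspondingExpression.getOwnerGroup();
    }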

View File

@@ -226,10 +226,6 @@ public class RuleSet {
.addAll(OTHER_REORDER_RULES)
.build();
public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
.add(JoinCommute.BUSHY.build())
.build();
public static final List<Rule> MATERIALIZED_VIEW_RULES = planRuleFactories()
.add(MaterializedViewOnlyJoinRule.INSTANCE)
.add(MaterializedViewProjectJoinRule.INSTANCE)
@@ -243,6 +239,11 @@ public class RuleSet {
.add(MaterializedViewFilterProjectAggregateRule.INSTANCE)
.build();
public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
.addAll(MATERIALIZED_VIEW_RULES)
.add(JoinCommute.BUSHY.build())
.build();
public List<Rule> getDPHypReorderRules() {
return DPHYP_REORDER_RULES;
}
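Note that DPHYP_REORDER_RULES moves below MATERIALIZED_VIEW_RULES because it now references that list: static fields are initialized in declaration order, and referring to a later static field by its simple name inside an initializer is an illegal forward reference in Java. A minimal, hypothetical illustration (not Doris code):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    class ForwardRefDemo {
        static final List<String> BASE = Arrays.asList("a", "b");
        // BASE must be declared before COMBINED; swapping the two declarations
        // fails to compile with an "illegal forward reference" error.
        static final List<String> COMBINED = new ArrayList<>(BASE);
    }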

View File

@@ -91,6 +91,13 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
Optional.empty(), Optional.empty(), leftChild, rightChild);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts,
new DistributeHint(DistributeType.NONE), Optional.empty(),
Optional.empty(), Optional.empty(), leftChild, rightChild);
}
public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts,
DistributeHint hint, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) {
this(joinType, hashJoinConjuncts, otherJoinConjuncts, hint, Optional.empty(), Optional.empty(),
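The new constructor defaults the distribute hint and the optional fields, which is exactly what PlanReceiver.proposeJoin relies on above. A hypothetical usage sketch (variable names are placeholders, assuming the conjunct lists and child plans are already built):

    LogicalJoin<Plan, Plan> join = new LogicalJoin<>(
            JoinType.INNER_JOIN, hashConjuncts, otherConjuncts, leftChild, rightChild);
    // equivalent to passing new DistributeHint(DistributeType.NONE) and empty Optionals explicitly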