[refactor](multicast) change the way multicast do filter, project and shuffle (#21412)

Co-authored-by: Jerry Hu <mrhhsg@gmail.com>

1. Filtering is done at the sending end rather than the receiving end
2. Projection is done at the sending end rather than the receiving end
3. Each sender can use different shuffle policies to send data
This commit is contained in:
morrySnow
2023-07-04 16:51:07 +08:00
committed by GitHub
parent 09f414e0f4
commit 90dd8716ed
29 changed files with 436 additions and 201 deletions

View File

@ -247,27 +247,21 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
@Override
public PlanFragment visitPhysicalDistribute(PhysicalDistribute<? extends Plan> distribute,
PlanTranslatorContext context) {
PlanFragment childFragment = distribute.child().accept(this, context);
PlanFragment inputFragment = distribute.child().accept(this, context);
// TODO: why need set streaming here? should remove this.
if (childFragment.getPlanRoot() instanceof AggregationNode
if (inputFragment.getPlanRoot() instanceof AggregationNode
&& distribute.child() instanceof PhysicalHashAggregate
&& context.getFirstAggregateInFragment(childFragment) == distribute.child()) {
&& context.getFirstAggregateInFragment(inputFragment) == distribute.child()) {
PhysicalHashAggregate<?> hashAggregate = (PhysicalHashAggregate<?>) distribute.child();
if (hashAggregate.getAggPhase() == AggPhase.LOCAL
&& hashAggregate.getAggMode() == AggMode.INPUT_TO_BUFFER) {
AggregationNode aggregationNode = (AggregationNode) childFragment.getPlanRoot();
AggregationNode aggregationNode = (AggregationNode) inputFragment.getPlanRoot();
aggregationNode.setUseStreamingPreagg(hashAggregate.isMaybeUsingStream());
}
}
ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot());
ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), inputFragment.getPlanRoot());
updateLegacyPlanIdToPhysicalPlan(exchangeNode, distribute);
exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances());
if (distribute.getDistributionSpec() instanceof DistributionSpecGather) {
// gather to one instance
exchangeNode.setNumInstances(1);
}
List<ExprId> validOutputIds = distribute.getOutputExprIds();
if (distribute.child() instanceof PhysicalHashAggregate) {
// we must add group by keys to output list,
@ -282,8 +276,28 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
DataPartition dataPartition = toDataPartition(distribute.getDistributionSpec(), validOutputIds, context);
PlanFragment parentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, dataPartition);
childFragment.setDestination(exchangeNode);
childFragment.setOutputPartition(dataPartition);
exchangeNode.setNumInstances(inputFragment.getPlanRoot().getNumInstances());
if (distribute.getDistributionSpec() instanceof DistributionSpecGather) {
// gather to one instance
exchangeNode.setNumInstances(1);
}
// process multicast sink
if (inputFragment instanceof MultiCastPlanFragment) {
MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
DataStreamSink dataStreamSink = multiCastDataSink.getDataStreamSinks().get(
multiCastDataSink.getDataStreamSinks().size() - 1);
TupleDescriptor tupleDescriptor = generateTupleDesc(distribute.getOutput(), null, context);
exchangeNode.updateTupleIds(tupleDescriptor);
dataStreamSink.setExchNodeId(exchangeNode.getId());
dataStreamSink.setOutputPartition(dataPartition);
parentFragment.addChild(inputFragment);
((MultiCastPlanFragment) inputFragment).addToDest(exchangeNode);
} else {
inputFragment.setDestination(exchangeNode);
inputFragment.setOutputPartition(dataPartition);
}
context.addPlanFragment(parentFragment);
return parentFragment;
}
@ -760,71 +774,23 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
MultiCastDataSink multiCastDataSink = (MultiCastDataSink) multiCastFragment.getSink();
Preconditions.checkState(multiCastDataSink != null, "invalid multiCastDataSink");
PhysicalCTEProducer cteProducer = context.getCteProduceMap().get(cteId);
PhysicalCTEProducer<?> cteProducer = context.getCteProduceMap().get(cteId);
Preconditions.checkState(cteProducer != null, "invalid cteProducer");
ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), multiCastFragment.getPlanRoot());
DataStreamSink streamSink = new DataStreamSink(exchangeNode.getId());
streamSink.setPartition(DataPartition.RANDOM);
// set datasink to multicast data sink but do not set target now
// target will be set when translate distribute
DataStreamSink streamSink = new DataStreamSink();
streamSink.setFragment(multiCastFragment);
multiCastDataSink.getDataStreamSinks().add(streamSink);
multiCastDataSink.getDestinations().add(Lists.newArrayList());
exchangeNode.setNumInstances(multiCastFragment.getPlanRoot().getNumInstances());
PlanFragment consumeFragment = new PlanFragment(context.nextFragmentId(), exchangeNode,
multiCastFragment.getDataPartition());
Map<Slot, Slot> projectMap = Maps.newHashMap();
projectMap.putAll(cteConsumer.getProducerToConsumerSlotMap());
List<NamedExpression> execList = new ArrayList<>();
PlanNode inputPlanNode = consumeFragment.getPlanRoot();
List<Slot> cteProjects = cteProducer.getProjects();
for (Slot slot : cteProjects) {
if (projectMap.containsKey(slot)) {
execList.add(projectMap.get(slot));
} else {
throw new RuntimeException("could not find slot in cte producer consumer projectMap");
}
// update expr to slot mapping
for (Slot producerSlot : cteProducer.getProjects()) {
Slot consumerSlot = cteConsumer.getProducerToConsumerSlotMap().get(producerSlot);
SlotRef slotRef = context.findSlotRef(producerSlot.getExprId());
context.addExprIdSlotRefPair(consumerSlot.getExprId(), slotRef);
}
List<Slot> slotList = execList
.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toList());
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
// update tuple list and tblTupleList
inputPlanNode.getTupleIds().clear();
inputPlanNode.getTupleIds().add(tupleDescriptor.getId());
inputPlanNode.getTblRefIds().clear();
inputPlanNode.getTblRefIds().add(tupleDescriptor.getId());
inputPlanNode.getNullableTupleIds().clear();
inputPlanNode.getNullableTupleIds().add(tupleDescriptor.getId());
List<Expr> execExprList = execList
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
inputPlanNode.setProjectList(execExprList);
inputPlanNode.setOutputTupleDesc(tupleDescriptor);
// update data partition
consumeFragment.setDataPartition(DataPartition.RANDOM);
SelectNode projectNode = new SelectNode(context.nextPlanNodeId(), inputPlanNode);
consumeFragment.setPlanRoot(projectNode);
multiCastFragment.getDestNodeList().add(exchangeNode);
consumeFragment.addChild(multiCastFragment);
context.getPlanFragments().add(consumeFragment);
return consumeFragment;
return multiCastFragment;
}
@Override
@ -859,6 +825,17 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
}
PlanFragment inputFragment = filter.child(0).accept(this, context);
// process multicast sink
if (inputFragment instanceof MultiCastPlanFragment) {
MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
DataStreamSink dataStreamSink = multiCastDataSink.getDataStreamSinks().get(
multiCastDataSink.getDataStreamSinks().size() - 1);
filter.getConjuncts().stream()
.map(e -> ExpressionTranslator.translate(e, context))
.forEach(dataStreamSink::addConjunct);
return inputFragment;
}
PlanNode planNode = inputFragment.getPlanRoot();
if (planNode instanceof ExchangeNode || planNode instanceof SortNode || planNode instanceof UnionNode) {
// the three nodes don't support conjuncts, need create a SelectNode to filter data
@ -1397,19 +1374,31 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
((AbstractPhysicalJoin<?, ?>) project.child(0).child(0)).setShouldTranslateOutput(false);
}
}
PlanFragment inputFragment = project.child(0).accept(this, context);
List<Expr> execExprList = project.getProjects()
.stream()
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
// TODO: fix the project alias of an aliased relation.
PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Slot> slotList = project.getProjects()
.stream()
.map(NamedExpression::toSlot)
.collect(Collectors.toList());
// process multicast sink
if (inputFragment instanceof MultiCastPlanFragment) {
MultiCastDataSink multiCastDataSink = (MultiCastDataSink) inputFragment.getSink();
DataStreamSink dataStreamSink = multiCastDataSink.getDataStreamSinks().get(
multiCastDataSink.getDataStreamSinks().size() - 1);
TupleDescriptor tupleDescriptor = generateTupleDesc(slotList, null, context);
dataStreamSink.setProjections(execExprList);
dataStreamSink.setOutputTupleDesc(tupleDescriptor);
return inputFragment;
}
PlanNode inputPlanNode = inputFragment.getPlanRoot();
List<Expr> predicateList = inputPlanNode.getConjuncts();
Set<SlotId> requiredSlotIdSet = Sets.newHashSet();
for (Expr expr : execExprList) {

View File

@ -114,7 +114,7 @@ public class ChildOutputPropertyDeriver extends PlanVisitor<PhysicalProperties,
public PhysicalProperties visitPhysicalCTEConsumer(
PhysicalCTEConsumer cteConsumer, PlanContext context) {
Preconditions.checkState(childrenOutputProperties.size() == 0);
return PhysicalProperties.ANY;
return PhysicalProperties.MUST_SHUFFLE;
}
@Override

View File

@ -27,9 +27,11 @@ import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.plans.AggMode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.JoinUtils;
@ -74,16 +76,33 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
@Override
public Boolean visit(Plan plan, Void context) {
// process must shuffle
for (int i = 0; i < children.size(); i++) {
DistributionSpec distributionSpec = childrenProperties.get(i).getDistributionSpec();
if (distributionSpec instanceof DistributionSpecMustShuffle) {
updateChildEnforceAndCost(i, PhysicalProperties.EXECUTION_ANY);
}
}
return true;
}
@Override
public Boolean visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> agg, Void context) {
// forbid one phase agg on distribute
if (agg.getAggMode() == AggMode.INPUT_TO_RESULT
&& children.get(0).getPlan() instanceof PhysicalDistribute) {
// this means one stage gather agg, usually bad pattern
return false;
}
// process must shuffle
visit(agg, context);
// process agg
return true;
}
@Override
public Boolean visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, Void context) {
// do not process must shuffle
return true;
}
@ -93,6 +112,9 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
Preconditions.checkArgument(children.size() == 2, "children.size() != 2");
Preconditions.checkArgument(childrenProperties.size() == 2);
Preconditions.checkArgument(requiredProperties.size() == 2);
// process must shuffle
visit(hashJoin, context);
// process hash join
DistributionSpec leftDistributionSpec = childrenProperties.get(0).getDistributionSpec();
DistributionSpec rightDistributionSpec = childrenProperties.get(1).getDistributionSpec();
@ -229,6 +251,9 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
Preconditions.checkArgument(children.size() == 2, String.format("children.size() is %d", children.size()));
Preconditions.checkArgument(childrenProperties.size() == 2);
Preconditions.checkArgument(requiredProperties.size() == 2);
// process must shuffle
visit(nestedLoopJoin, context);
// process nlj
DistributionSpec rightDistributionSpec = childrenProperties.get(1).getDistributionSpec();
if (rightDistributionSpec instanceof DistributionSpecStorageGather) {
updateChildEnforceAndCost(1, PhysicalProperties.GATHER);
@ -236,8 +261,17 @@ public class ChildrenPropertiesRegulator extends PlanVisitor<Boolean, Void> {
return true;
}
@Override
public Boolean visitPhysicalProject(PhysicalProject<? extends Plan> project, Void context) {
// do not process must shuffle
return true;
}
@Override
public Boolean visitPhysicalSetOperation(PhysicalSetOperation setOperation, Void context) {
// process must shuffle
visit(setOperation, context);
// process set operation
if (children.isEmpty()) {
return true;
}

View File

@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.properties;
/**
* present data must use after shuffle
*/
public class DistributionSpecMustShuffle extends DistributionSpec {
public static final DistributionSpecMustShuffle INSTANCE = new DistributionSpecMustShuffle();
public DistributionSpecMustShuffle() {
super();
}
@Override
public boolean satisfy(DistributionSpec other) {
return other instanceof DistributionSpecAny;
}
}

View File

@ -44,6 +44,8 @@ public class PhysicalProperties {
public static PhysicalProperties STORAGE_GATHER = new PhysicalProperties(DistributionSpecStorageGather.INSTANCE);
public static PhysicalProperties MUST_SHUFFLE = new PhysicalProperties(DistributionSpecMustShuffle.INSTANCE);
private final OrderSpec orderSpec;
private final DistributionSpec distributionSpec;

View File

@ -20,19 +20,38 @@
package org.apache.doris.planner;
import org.apache.doris.analysis.BitmapFilterPredicate;
import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.TupleDescriptor;
import org.apache.doris.thrift.TDataSink;
import org.apache.doris.thrift.TDataSinkType;
import org.apache.doris.thrift.TDataStreamSink;
import org.apache.doris.thrift.TExplainLevel;
import com.google.common.collect.Lists;
import org.springframework.util.CollectionUtils;
import java.util.List;
/**
* Data sink that forwards data to an exchange node.
*/
public class DataStreamSink extends DataSink {
private final PlanNodeId exchNodeId;
private PlanNodeId exchNodeId;
private DataPartition outputPartition;
protected TupleDescriptor outputTupleDesc;
protected List<Expr> projections;
protected List<Expr> conjuncts = Lists.newArrayList();
public DataStreamSink() {
}
public DataStreamSink(PlanNodeId exchNodeId) {
this.exchNodeId = exchNodeId;
}
@ -42,23 +61,66 @@ public class DataStreamSink extends DataSink {
return exchNodeId;
}
public void setExchNodeId(PlanNodeId exchNodeId) {
this.exchNodeId = exchNodeId;
}
@Override
public DataPartition getOutputPartition() {
return outputPartition;
}
public void setPartition(DataPartition partition) {
outputPartition = partition;
public void setOutputPartition(DataPartition outputPartition) {
this.outputPartition = outputPartition;
}
public TupleDescriptor getOutputTupleDesc() {
return outputTupleDesc;
}
public void setOutputTupleDesc(TupleDescriptor outputTupleDesc) {
this.outputTupleDesc = outputTupleDesc;
}
public List<Expr> getProjections() {
return projections;
}
public void setProjections(List<Expr> projections) {
this.projections = projections;
}
public List<Expr> getConjuncts() {
return conjuncts;
}
public void setConjuncts(List<Expr> conjuncts) {
this.conjuncts = conjuncts;
}
public void addConjunct(Expr conjunct) {
this.conjuncts.add(conjunct);
}
@Override
public String getExplainString(String prefix, TExplainLevel explainLevel) {
StringBuilder strBuilder = new StringBuilder();
strBuilder.append(prefix + "STREAM DATA SINK\n");
strBuilder.append(prefix + " EXCHANGE ID: " + exchNodeId + "\n");
strBuilder.append(prefix).append("STREAM DATA SINK\n");
strBuilder.append(prefix).append(" EXCHANGE ID: ").append(exchNodeId);
if (outputPartition != null) {
strBuilder.append(prefix + " " + outputPartition.getExplainString(explainLevel));
strBuilder.append("\n").append(prefix).append(" ").append(outputPartition.getExplainString(explainLevel));
}
if (!conjuncts.isEmpty()) {
Expr expr = PlanNode.convertConjunctsToAndCompoundPredicate(conjuncts);
strBuilder.append(prefix).append(" CONJUNCTS: ").append(expr.toSql()).append("\n");
}
if (!CollectionUtils.isEmpty(projections)) {
strBuilder.append(prefix).append(" PROJECTIONS: ")
.append(PlanNode.getExplainString(projections)).append("\n");
strBuilder.append(prefix).append(" PROJECTION TUPLE: ").append(outputTupleDesc.getId());
strBuilder.append("\n");
}
return strBuilder.toString();
}
@ -67,6 +129,19 @@ public class DataStreamSink extends DataSink {
TDataSink result = new TDataSink(TDataSinkType.DATA_STREAM_SINK);
TDataStreamSink tStreamSink =
new TDataStreamSink(exchNodeId.asInt(), outputPartition.toThrift());
for (Expr e : conjuncts) {
if (!(e instanceof BitmapFilterPredicate)) {
tStreamSink.addToConjuncts(e.treeToThrift());
}
}
if (projections != null) {
for (Expr expr : projections) {
tStreamSink.addToOutputExprs(expr.treeToThrift());
}
}
if (outputTupleDesc != null) {
tStreamSink.setOutputTupleId(outputTupleDesc.getId().asInt());
}
result.setStreamSink(tStreamSink);
return result;
}

View File

@ -108,15 +108,21 @@ public class ExchangeNode extends PlanNode {
public final void computeTupleIds() {
PlanNode inputNode = getChild(0);
TupleDescriptor outputTupleDesc = inputNode.getOutputTupleDesc();
updateTupleIds(outputTupleDesc);
}
public void updateTupleIds(TupleDescriptor outputTupleDesc) {
if (outputTupleDesc != null) {
tupleIds.clear();
tupleIds.add(outputTupleDesc.getId());
tblRefIds.add(outputTupleDesc.getId());
nullableTupleIds.add(outputTupleDesc.getId());
} else {
clearTupleIds();
tupleIds.addAll(getChild(0).getTupleIds());
tblRefIds.addAll(getChild(0).getTblRefIds());
nullableTupleIds.addAll(getChild(0).getNullableTupleIds());
}
tblRefIds.addAll(getChild(0).getTblRefIds());
nullableTupleIds.addAll(getChild(0).getNullableTupleIds());
}
@Override

View File

@ -35,10 +35,11 @@ public class MultiCastPlanFragment extends PlanFragment {
this.children.addAll(planFragment.getChildren());
}
public List<ExchangeNode> getDestNodeList() {
return destNodeList;
public void addToDest(ExchangeNode exchangeNode) {
destNodeList.add(exchangeNode);
}
public List<PlanFragment> getDestFragmentList() {
return destNodeList.stream().map(PlanNode::getFragment).collect(Collectors.toList());
}

View File

@ -256,7 +256,7 @@ public class PlanFragment extends TreeNode<PlanFragment> {
Preconditions.checkState(sink == null);
// we're streaming to an exchange node
DataStreamSink streamSink = new DataStreamSink(destNode.getId());
streamSink.setPartition(outputPartition);
streamSink.setOutputPartition(outputPartition);
streamSink.setFragment(this);
sink = streamSink;
} else {

View File

@ -415,7 +415,7 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
}
}
protected Expr convertConjunctsToAndCompoundPredicate(List<Expr> conjuncts) {
public static Expr convertConjunctsToAndCompoundPredicate(List<Expr> conjuncts) {
List<Expr> targetConjuncts = Lists.newArrayList(conjuncts);
while (targetConjuncts.size() > 1) {
List<Expr> newTargetConjuncts = Lists.newArrayList();
@ -824,7 +824,7 @@ public abstract class PlanNode extends TreeNode<PlanNode> implements PlanStats {
return output.toString();
}
protected String getExplainString(List<? extends Expr> exprs) {
public static String getExplainString(List<? extends Expr> exprs) {
if (exprs == null) {
return "";
}