[enhancement](Nereids) planner performance speed up (#12858)

optimize planner by:
1. reduce duplicated calculation in equals, getOutput, computeOutput, etc.
2. getOnClauseUsedSlots: the two sides of an EqualTo are certainly slots, so there is no need to use a List.
This commit is contained in:
mch_ucchi
2022-09-29 16:01:10 +08:00
committed by GitHub
parent 34b14a71c8
commit 1ae9454771
9 changed files with 122 additions and 73 deletions

View File

@ -38,7 +38,7 @@ import java.util.List;
*/
public class PlanContext {
// array of children's derived stats
private final List<StatsDeriveResult> childrenStats = Lists.newArrayList();
private final List<StatsDeriveResult> childrenStats;
// attached group expression
private final GroupExpression groupExpression;
@ -47,6 +47,7 @@ public class PlanContext {
*/
public PlanContext(GroupExpression groupExpression) {
this.groupExpression = groupExpression;
childrenStats = Lists.newArrayListWithCapacity(groupExpression.children().size());
for (Group group : groupExpression.children()) {
childrenStats.add(group.getStatistics());
@ -76,11 +77,7 @@ public class PlanContext {
}
public List<Id> getChildOutputIds(int index) {
List<Id> ids = Lists.newArrayList();
childLogicalPropertyAt(index).getOutput().forEach(slot -> {
ids.add(slot.getExprId());
});
return ids;
return childLogicalPropertyAt(index).getOutputExprIds();
}
/**

View File

@ -19,8 +19,6 @@ package org.apache.doris.nereids.cost;
import com.google.common.base.Preconditions;
import java.util.stream.Stream;
/**
* Use for estimating the cost of plan.
*/
@ -88,13 +86,17 @@ public final class CostEstimate {
}
/**
* Sums partial cost estimates of some (single) plan node.
* sum of cost estimate
*/
public static CostEstimate sum(CostEstimate one, CostEstimate two, CostEstimate... more) {
return Stream.concat(Stream.of(one, two), Stream.of(more))
.reduce(zero(), (a, b) -> new CostEstimate(
a.cpuCost + b.cpuCost,
a.memoryCost + b.memoryCost,
a.networkCost + b.networkCost));
// Seed every component with the two mandatory operands.
double cpuCostSum = one.cpuCost + two.cpuCost;
double memoryCostSum = one.memoryCost + two.memoryCost;
// The network component must combine both operands, exactly like cpu and memory;
// adding one.networkCost to itself would silently drop two.networkCost.
double networkCostSum = one.networkCost + two.networkCost;
// Fold in the optional extra estimates passed through the varargs parameter.
for (CostEstimate costEstimate : more) {
    cpuCostSum += costEstimate.cpuCost;
    memoryCostSum += costEstimate.memoryCost;
    networkCostSum += costEstimate.networkCost;
}
return new CostEstimate(cpuCostSum, memoryCostSum, networkCostSum);
}
}

View File

@ -28,6 +28,7 @@ import com.google.common.collect.Sets;
import java.util.BitSet;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@ -68,9 +69,11 @@ public class DistributionSpecHash extends DistributionSpec {
this(leftColumns, shuffleType, -1L, Collections.emptySet());
Objects.requireNonNull(rightColumns);
Preconditions.checkArgument(leftColumns.size() == rightColumns.size());
for (int i = 0; i < rightColumns.size(); i++) {
exprIdToEquivalenceSet.put(rightColumns.get(i), i);
equivalenceExprIds.get(i).add(rightColumns.get(i));
int i = 0;
Iterator<Set<ExprId>> iter = equivalenceExprIds.iterator();
for (ExprId id : rightColumns) {
exprIdToEquivalenceSet.put(id, i++);
iter.next().add(id);
}
}
@ -81,21 +84,23 @@ public class DistributionSpecHash extends DistributionSpec {
long tableId, Set<Long> partitionIds) {
this.orderedShuffledColumns = Objects.requireNonNull(orderedShuffledColumns);
this.shuffleType = Objects.requireNonNull(shuffleType);
this.tableId = tableId;
this.partitionIds = Objects.requireNonNull(partitionIds);
this.equivalenceExprIds = Lists.newArrayList();
this.exprIdToEquivalenceSet = Maps.newHashMap();
orderedShuffledColumns.forEach(id -> {
exprIdToEquivalenceSet.put(id, equivalenceExprIds.size());
this.tableId = tableId;
equivalenceExprIds = Lists.newArrayListWithCapacity(orderedShuffledColumns.size());
exprIdToEquivalenceSet = Maps.newHashMapWithExpectedSize(orderedShuffledColumns.size());
int i = 0;
for (ExprId id : orderedShuffledColumns) {
exprIdToEquivalenceSet.put(id, i++);
equivalenceExprIds.add(Sets.newHashSet(id));
});
}
}
/**
* Used in merge outside and put result into it.
*/
public DistributionSpecHash(List<ExprId> orderedShuffledColumns, ShuffleType shuffleType, long tableId,
Set<Long> partitionIds, List<Set<ExprId>> equivalenceExprIds, Map<ExprId, Integer> exprIdToEquivalenceSet) {
Set<Long> partitionIds, List<Set<ExprId>> equivalenceExprIds,
Map<ExprId, Integer> exprIdToEquivalenceSet) {
this.orderedShuffledColumns = Objects.requireNonNull(orderedShuffledColumns);
this.shuffleType = Objects.requireNonNull(shuffleType);
this.tableId = tableId;
@ -113,7 +118,8 @@ public class DistributionSpecHash extends DistributionSpec {
equivalenceExprId.addAll(right.getEquivalenceExprIds().get(i));
equivalenceExprIds.add(equivalenceExprId);
}
Map<ExprId, Integer> exprIdToEquivalenceSet = Maps.newHashMap();
Map<ExprId, Integer> exprIdToEquivalenceSet = Maps.newHashMapWithExpectedSize(
left.getExprIdToEquivalenceSet().size() + right.getExprIdToEquivalenceSet().size());
exprIdToEquivalenceSet.putAll(left.getExprIdToEquivalenceSet());
exprIdToEquivalenceSet.putAll(right.getExprIdToEquivalenceSet());
return new DistributionSpecHash(orderedShuffledColumns, shuffleType,
@ -208,16 +214,12 @@ public class DistributionSpecHash extends DistributionSpec {
return false;
}
DistributionSpecHash that = (DistributionSpecHash) o;
return tableId == that.tableId && orderedShuffledColumns.equals(that.orderedShuffledColumns)
&& shuffleType == that.shuffleType && partitionIds.equals(that.partitionIds)
&& equivalenceExprIds.equals(that.equivalenceExprIds)
&& exprIdToEquivalenceSet.equals(that.exprIdToEquivalenceSet);
return shuffleType == that.shuffleType && orderedShuffledColumns.equals(that.orderedShuffledColumns);
}
@Override
public int hashCode() {
return Objects.hash(orderedShuffledColumns, shuffleType, tableId, partitionIds,
equivalenceExprIds, exprIdToEquivalenceSet);
return Objects.hash(shuffleType, orderedShuffledColumns);
}
@Override
@ -245,6 +247,5 @@ public class DistributionSpecHash extends DistributionSpec {
BUCKETED,
// output, all distribute enforce
ENFORCED,
;
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.properties;
import org.apache.doris.common.Id;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -27,6 +28,7 @@ import com.google.common.base.Suppliers;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
@ -35,6 +37,9 @@ import java.util.stream.Collectors;
public class LogicalProperties {
protected final Supplier<List<Slot>> outputSupplier;
protected final Supplier<HashSet<ExprId>> outputSetSupplier;
private Integer hashCode = null;
private Set<ExprId> outputExprIdSet;
private List<Id> outputExprIds;
/**
* constructor of LogicalProperties.
@ -56,6 +61,21 @@ public class LogicalProperties {
return outputSupplier.get();
}
/**
 * Returns the set of expression ids of the output slots.
 * The set is built from the output supplier on the first call and the same instance
 * is handed back on every later call.
 */
public Set<ExprId> getOutputExprIdSet() {
    if (outputExprIdSet != null) {
        return outputExprIdSet;
    }
    List<Slot> outputSlots = this.outputSupplier.get();
    outputExprIdSet = outputSlots.stream()
            .map(NamedExpression::getExprId)
            .collect(Collectors.toSet());
    return outputExprIdSet;
}
/**
 * Returns the expression ids of the output slots as a list of {@link Id}.
 * The list is computed on the first call and cached; ids appear in the same order
 * as the slots returned by the output supplier.
 *
 * @return cached list of output expression ids
 */
public List<Id> getOutputExprIds() {
    if (outputExprIds == null) {
        // Build from the output list itself instead of the outputExprIdSet field: that field
        // stays null until getOutputExprIdSet() has been called (which would cause an NPE here),
        // and a set would not keep the order of the output slots.
        outputExprIds = outputSupplier.get().stream()
                .map(NamedExpression::getExprId)
                .map(Id.class::cast)
                .collect(Collectors.toList());
    }
    return outputExprIds;
}
public LogicalProperties withOutput(List<Slot> output) {
return new LogicalProperties(Suppliers.ofInstance(output));
}
@ -74,6 +94,9 @@ public class LogicalProperties {
@Override
public int hashCode() {
return Objects.hash(outputSetSupplier.get());
if (hashCode == null) {
hashCode = Objects.hash(outputSetSupplier.get());
}
return hashCode;
}
}

View File

@ -34,6 +34,8 @@ public class PhysicalProperties {
private final DistributionSpec distributionSpec;
private Integer hashCode = null;
private PhysicalProperties() {
this.orderSpec = new OrderSpec();
this.distributionSpec = DistributionSpecAny.INSTANCE;
@ -80,12 +82,18 @@ public class PhysicalProperties {
return false;
}
PhysicalProperties that = (PhysicalProperties) o;
if (this.hashCode() != that.hashCode()) {
return false;
}
return orderSpec.equals(that.orderSpec)
&& distributionSpec.equals(that.distributionSpec);
}
@Override
public int hashCode() {
return Objects.hash(orderSpec, distributionSpec);
if (hashCode == null) {
hashCode = Objects.hash(orderSpec, distributionSpec);
}
return hashCode;
}
}

View File

@ -69,8 +69,9 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
@Override
public Void visit(Plan plan, PlanContext context) {
List<PhysicalProperties> requiredPropertyList = Lists.newArrayList();
for (int i = 0; i < context.getGroupExpression().arity(); i++) {
List<PhysicalProperties> requiredPropertyList =
Lists.newArrayListWithCapacity(context.getGroupExpression().arity());
for (int i = context.getGroupExpression().arity(); i > 0; --i) {
requiredPropertyList.add(PhysicalProperties.ANY);
}
requestPropertyToChildren.add(requiredPropertyList);

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.AbstractTreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.statistics.StatsDeriveResult;
@ -31,6 +32,7 @@ import com.google.common.base.Suppliers;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.annotation.Nullable;
/**
@ -122,6 +124,11 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
return getLogicalProperties().getOutput();
}
/**
 * Returns the set of output expression ids by delegating to this plan's logical properties.
 */
@Override
public Set<ExprId> getOutputExprIdSet() {
return getLogicalProperties().getOutputExprIdSet();
}
@Override
public Plan child(int index) {
return super.child(index);

View File

@ -21,7 +21,9 @@ import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.UnboundLogicalProperties;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
@ -30,6 +32,7 @@ import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Abstract class for all plan node.
@ -79,6 +82,10 @@ public interface Plan extends TreeNode<Plan> {
return ImmutableSet.copyOf(getOutput());
}
/**
 * Collects the expression ids of every output slot into a set.
 * A new set is computed from {@code getOutput()} on each call.
 */
default Set<ExprId> getOutputExprIdSet() {
    List<Slot> outputSlots = getOutput();
    return outputSlots.stream()
            .map(NamedExpression::getExprId)
            .collect(Collectors.toSet());
}
/**
* Get the input slot set of the plan.
* The result is collected from all the expressions' input slots appearing in the plan node.

View File

@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Join;
@ -37,7 +38,6 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@ -59,7 +59,7 @@ public class JoinUtils {
return !(join.getJoinType().isReturnUnmatchedRightJoin());
}
private static class JoinSlotCoverageChecker {
private static final class JoinSlotCoverageChecker {
Set<ExprId> leftExprIds;
Set<ExprId> rightExprIds;
@ -68,16 +68,20 @@ public class JoinUtils {
rightExprIds = right.stream().map(Slot::getExprId).collect(Collectors.toSet());
}
boolean isCoveredByLeftSlots(Set<Slot> slots) {
return slots.stream()
.map(Slot::getExprId)
.allMatch(leftExprIds::contains);
/**
 * Creates a checker directly from the expression-id sets of the two join sides.
 * The given sets are stored by reference, not copied.
 *
 * @param left expression ids regarded as belonging to the left side
 * @param right expression ids regarded as belonging to the right side
 */
JoinSlotCoverageChecker(Set<ExprId> left, Set<ExprId> right) {
leftExprIds = left;
rightExprIds = right;
}
boolean isCoveredByRightSlots(Set<Slot> slots) {
return slots.stream()
.map(Slot::getExprId)
.allMatch(rightExprIds::contains);
/**
 * Checks whether the given expression id is contained in the left side's expression ids.
 * PushDownExpressionInHashConjuncts ensures that each side of the condition is a single slot,
 * so a single id is checked here instead of a collection.
 *
 * @param slot expression id of the slot to look up
 * @return true if {@code leftExprIds} contains the id
 */
boolean isCoveredByLeftSlots(ExprId slot) {
return leftExprIds.contains(slot);
}
/**
 * Checks whether the given expression id is contained in the right side's expression ids.
 *
 * @param slot expression id of the slot to look up
 * @return true if {@code rightExprIds} contains the id
 */
boolean isCoveredByRightSlots(ExprId slot) {
return rightExprIds.contains(slot);
}
/**
@ -116,10 +120,11 @@ public class JoinUtils {
* @param join join node
* @return pair of expressions, for hash table or not.
*/
public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(LogicalJoin join) {
public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(
LogicalJoin<GroupPlan, GroupPlan> join) {
if (join.getOtherJoinCondition().isPresent()) {
List<Expression> onExprs = ExpressionUtils.extractConjunction(
(Expression) join.getOtherJoinCondition().get());
join.getOtherJoinCondition().get());
List<Slot> leftSlots = join.left().getOutput();
List<Slot> rightSlots = join.right().getOutput();
return extractExpressionForHashTable(leftSlots, rightSlots, onExprs);
@ -152,37 +157,35 @@ public class JoinUtils {
*/
public static Pair<List<ExprId>, List<ExprId>> getOnClauseUsedSlots(
AbstractPhysicalJoin<? extends Plan, ? extends Plan> join) {
Pair<List<ExprId>, List<ExprId>> childSlotsExprId =
Pair.of(Lists.newArrayList(), Lists.newArrayList());
List<Slot> leftSlots = join.left().getOutput();
List<Slot> rightSlots = join.right().getOutput();
List<EqualTo> equalToList = join.getHashJoinConjuncts().stream()
.map(e -> (EqualTo) e).collect(Collectors.toList());
JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(leftSlots, rightSlots);
List<ExprId> exprIds1 = Lists.newArrayListWithCapacity(join.getHashJoinConjuncts().size());
List<ExprId> exprIds2 = Lists.newArrayListWithCapacity(join.getHashJoinConjuncts().size());
for (EqualTo equalTo : equalToList) {
Set<Slot> leftOnSlots = equalTo.left().collect(Slot.class::isInstance);
Set<Slot> rightOnSlots = equalTo.right().collect(Slot.class::isInstance);
List<ExprId> leftOnSlotsExprId = leftOnSlots.stream()
.map(Slot::getExprId).collect(Collectors.toList());
List<ExprId> rightOnSlotsExprId = rightOnSlots.stream()
.map(Slot::getExprId).collect(Collectors.toList());
if (checker.isCoveredByLeftSlots(leftOnSlots)
&& checker.isCoveredByRightSlots(rightOnSlots)) {
childSlotsExprId.first.addAll(leftOnSlotsExprId);
childSlotsExprId.second.addAll(rightOnSlotsExprId);
} else if (checker.isCoveredByLeftSlots(rightOnSlots)
&& checker.isCoveredByRightSlots(leftOnSlots)) {
childSlotsExprId.first.addAll(rightOnSlotsExprId);
childSlotsExprId.second.addAll(leftOnSlotsExprId);
JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(
join.left().getOutputExprIdSet(),
join.right().getOutputExprIdSet());
for (Expression expr : join.getHashJoinConjuncts()) {
EqualTo equalTo = (EqualTo) expr;
if (!(equalTo.left() instanceof Slot) || !(equalTo.right() instanceof Slot)) {
continue;
}
ExprId leftExprId = ((Slot) equalTo.left()).getExprId();
ExprId rightExprId = ((Slot) equalTo.right()).getExprId();
if (checker.isCoveredByLeftSlots(leftExprId)
&& checker.isCoveredByRightSlots(rightExprId)) {
exprIds1.add(leftExprId);
exprIds2.add(rightExprId);
} else if (checker.isCoveredByLeftSlots(rightExprId)
&& checker.isCoveredByRightSlots(leftExprId)) {
exprIds1.add(rightExprId);
exprIds2.add(leftExprId);
} else {
throw new RuntimeException("Could not generate valid equal on clause slot pairs for join: " + join);
}
}
Preconditions.checkState(childSlotsExprId.first.size() == childSlotsExprId.second.size());
return childSlotsExprId;
return Pair.of(exprIds1, exprIds2);
}
public static boolean shouldNestedLoopJoin(Join join) {