[mv](nereids) mv cost related PRs (#35652 #35701 #35864 #36368 #36789 #34970) (#37097)

## Proposed changes
Cherry-picked from #35652, #35701, #35864, #36368, #36789, and #34970.

Issue Number: close #xxx

<!--Describe your changes.-->
This commit is contained in:
minghong
2024-07-04 09:42:11 +08:00
committed by GitHub
parent 077fda4259
commit 26be313d40
25 changed files with 494 additions and 299 deletions

View File

@ -164,6 +164,18 @@ public class MaterializedIndexMeta implements Writable, GsonPostProcessable {
initColumnNameMap();
}
/**
 * Collects the leading key columns of this index's schema.
 *
 * <p>The schema is walked in order and the walk ends at the first column that
 * is not a key, so only the contiguous run of key columns at the front of the
 * schema is returned.
 *
 * @return a new list with the leading key columns; empty when the schema is
 *         empty or its first column is not a key
 */
public List<Column> getPrefixKeyColumns() {
    List<Column> prefixKeys = Lists.newArrayList();
    for (Column column : schema) {
        if (!column.isKey()) {
            // the first non-key column terminates the prefix
            return prefixKeys;
        }
        prefixKeys.add(column);
    }
    return prefixKeys;
}
public void setSchemaHash(int newSchemaHash) {
this.schemaHash = newSchemaHash;
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.cost;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.PlanContext;
@ -24,14 +25,19 @@ import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
@ -52,8 +58,11 @@ import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.List;
import java.util.Set;
class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
@ -113,6 +122,57 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
return CostV1.ofCpu(context.getSessionVariable(), rows - aggMvBonus);
}
/**
 * Collects the columns that are compared against a literal.
 *
 * <p>Only {@link ComparisonPredicate}s whose one side is a {@link SlotReference}
 * and whose other side is a {@link Literal} are considered. The slot's column is
 * added to the result when the slot carries one.
 *
 * @param expressions the predicates to inspect
 * @return a new set of columns referenced by slot-versus-literal comparisons
 */
private Set<Column> getColumnForRangePredicate(Set<Expression> expressions) {
    Set<Column> result = Sets.newHashSet();
    for (Expression expression : expressions) {
        if (!(expression instanceof ComparisonPredicate)) {
            continue;
        }
        ComparisonPredicate predicate = (ComparisonPredicate) expression;
        Expression lhs = predicate.left();
        Expression rhs = predicate.right();

        // pick the slot side; the opposite side has to be a literal
        SlotReference slot;
        if (lhs instanceof SlotReference && rhs instanceof Literal) {
            slot = (SlotReference) lhs;
        } else if (rhs instanceof SlotReference && lhs instanceof Literal) {
            slot = (SlotReference) rhs;
        } else {
            continue;
        }

        if (slot.getColumn().isPresent()) {
            result.add(slot.getColumn().get());
        }
    }
    return result;
}
/**
 * Estimates the CPU cost of a physical filter.
 *
 * <p>The cost is
 * {@code (conjunct count - matched prefix key columns + expression tree cost) * filterCostFactor}.
 * The factor comes from the session variable of the current {@code ConnectContext},
 * or is 0.0001 when there is no such context. When an {@link OlapScan} is found
 * among the children of the filter's group expression, every leading key column of
 * the selected index that is compared against a literal in the filter reduces the cost.
 *
 * @param filter the filter to cost
 * @param context the plan context supplying the session variable for the cost object
 * @return the CPU cost of the filter
 */
@Override
public Cost visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, PlanContext context) {
    double exprCost = expressionTreeCost(filter.getExpressions());

    ConnectContext connectContext = ConnectContext.get();
    double filterCostFactor = connectContext == null
            ? 0.0001
            : connectContext.getSessionVariable().filterCostFactor;

    int prefixIndexMatched = 0;
    if (filter.getGroupExpression().isPresent()) {
        OlapScan olapScan = (OlapScan) filter.getGroupExpression().get().getFirstChildPlan(OlapScan.class);
        if (olapScan != null) {
            // count how many leading key columns of the selected index are constrained
            long selectedIndexId = olapScan.getSelectedIndexId();
            List<Column> keyColumns = olapScan.getTable()
                    .getIndexMetaByIndexId(selectedIndexId)
                    .getPrefixKeyColumns();
            Set<Column> predicateColumns = getColumnForRangePredicate(filter.getConjuncts());
            while (prefixIndexMatched < keyColumns.size()
                    && predicateColumns.contains(keyColumns.get(prefixIndexMatched))) {
                prefixIndexMatched++;
            }
        }
    }

    int unmatchedConjuncts = filter.getConjuncts().size() - prefixIndexMatched;
    return CostV1.ofCpu(context.getSessionVariable(), (unmatchedConjuncts + exprCost) * filterCostFactor);
}
@Override
public Cost visitPhysicalDeferMaterializeOlapScan(PhysicalDeferMaterializeOlapScan deferMaterializeOlapScan,
PlanContext context) {
@ -141,7 +201,8 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
@Override
public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
return CostV1.ofCpu(context.getSessionVariable(), 1);
double exprCost = expressionTreeCost(physicalProject.getProjects());
return CostV1.ofCpu(context.getSessionVariable(), exprCost + 1);
}
@Override
@ -252,16 +313,29 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
intputRowCount * childStatistics.dataSizeFactor() * RANDOM_SHUFFLE_TO_HASH_SHUFFLE_FACTOR / beNumber);
}
/**
 * Sums the evaluation cost of the given expressions.
 *
 * <p>Top-level {@link SlotReference}s are skipped; every other expression is
 * costed with a single shared {@link ExpressionCostEvaluator}.
 *
 * @param expressions the expressions to cost
 * @return the accumulated cost, 0.0 when nothing but slot references is given
 */
private double expressionTreeCost(List<? extends Expression> expressions) {
    ExpressionCostEvaluator evaluator = new ExpressionCostEvaluator();
    double total = 0.0;
    for (Expression expression : expressions) {
        if (expression instanceof SlotReference) {
            continue;
        }
        total += expression.accept(evaluator, null);
    }
    return total;
}
@Override
public Cost visitPhysicalHashAggregate(
PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
Statistics inputStatistics = context.getChildStatistics(0);
double exprCost = expressionTreeCost(aggregate.getExpressions());
if (aggregate.getAggPhase().isLocal()) {
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount() / beNumber,
return CostV1.of(context.getSessionVariable(),
exprCost / 100 + inputStatistics.getRowCount() / beNumber,
inputStatistics.getRowCount() / beNumber, 0);
} else {
// global
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount(),
return CostV1.of(context.getSessionVariable(), exprCost / 100 + inputStatistics.getRowCount(),
inputStatistics.getRowCount(), 0);
}
}
@ -289,7 +363,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
double leftRowCount = probeStats.getRowCount();
double rightRowCount = buildStats.getRowCount();
if (leftRowCount == rightRowCount
if ((long) leftRowCount == (long) rightRowCount
&& physicalHashJoin.getGroupExpression().isPresent()
&& physicalHashJoin.getGroupExpression().get().getOwnerGroup() != null
&& !physicalHashJoin.getGroupExpression().get().getOwnerGroup().isStatsReliable()) {

View File

@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.cost;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.collect.Maps;
import java.util.Map;
/**
 * Computes a rough evaluation cost for an expression tree.
 *
 * <p>For every non-leaf node the cost is the sum, over its children, of
 * <ol>
 *   <li>the child's own (recursive) cost, and</li>
 *   <li>a weight chosen by the child's data type: decimal, string-like and
 *       complex (array/map/struct) types weigh more, every other type weighs 0.1.</li>
 * </ol>
 * Slot references and literals cost nothing by themselves; an alias costs as
 * much as the expression it wraps.
 */
public class ExpressionCostEvaluator extends ExpressionVisitor<Double, Void> {
    /** Weight added for a child with no entry in {@link #DATA_TYPE_COST}. */
    private static final double DEFAULT_DATA_TYPE_COST = 0.1;

    /** Per-child weight keyed by the concrete data type class of the child. */
    private static final Map<Class<?>, Double> DATA_TYPE_COST = Maps.newHashMap();

    static {
        DATA_TYPE_COST.put(DecimalV2Type.class, 1.5);
        DATA_TYPE_COST.put(DecimalV3Type.class, 1.5);
        DATA_TYPE_COST.put(StringType.class, 2.0);
        DATA_TYPE_COST.put(CharType.class, 2.0);
        DATA_TYPE_COST.put(VarcharType.class, 2.0);
        DATA_TYPE_COST.put(ArrayType.class, 3.0);
        DATA_TYPE_COST.put(MapType.class, 3.0);
        DATA_TYPE_COST.put(StructType.class, 3.0);
    }

    /**
     * Costs a generic expression as the sum of its children's costs plus a
     * data-type weight per child.
     *
     * @param expr the expression to cost
     * @param context unused
     * @return the accumulated cost; 0.0 for an expression without children
     */
    @Override
    public Double visit(Expression expr, Void context) {
        double cost = 0.0;
        for (Expression child : expr.children()) {
            cost += child.accept(this, context);
            // the more children, the more computing cost
            cost += DATA_TYPE_COST.getOrDefault(child.getDataType().getClass(), DEFAULT_DATA_TYPE_COST);
        }
        return cost;
    }

    /** A slot reference is a leaf and adds no cost by itself. */
    @Override
    public Double visitSlotReference(SlotReference slot, Void context) {
        return 0.0;
    }

    /** A literal is a leaf and adds no cost by itself. */
    @Override
    public Double visitLiteral(Literal literal, Void context) {
        return 0.0;
    }

    /**
     * An alias adds nothing on top of the wrapped expression; an alias over a
     * plain slot reference is free.
     */
    @Override
    public Double visitAlias(Alias alias, Void context) {
        Expression child = alias.child();
        if (child instanceof SlotReference) {
            return 0.0;
        }
        return child.accept(this, context);
    }
}

View File

@ -349,4 +349,28 @@ public class GroupExpression {
public ObjectId getId() {
return id;
}
/**
 * Finds the first plan of the given type among the child groups.
 *
 * <p>The logical expressions of all child groups are searched first, in child
 * order. Only when none of them matches are the physical expressions searched.
 *
 * @param clazz the operator type, like join/aggregate
 * @return the first child plan that is an instance of clazz, or null if there is none
 */
public Plan getFirstChildPlan(Class clazz) {
    for (Group child : children) {
        for (GroupExpression candidate : child.getLogicalExpressions()) {
            Plan plan = candidate.getPlan();
            if (clazz.isInstance(plan)) {
                return plan;
            }
        }
    }
    // for dphyp
    for (Group child : children) {
        for (GroupExpression candidate : child.getPhysicalExpressions()) {
            Plan plan = candidate.getPlan();
            if (clazz.isInstance(plan)) {
                return plan;
            }
        }
    }
    return null;
}
}

View File

@ -603,24 +603,24 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
private Statistics computeFilter(Filter filter) {
Statistics stats = groupExpression.childStatistics(0);
Plan plan = tryToFindChild(groupExpression);
boolean isOnBaseTable = false;
if (plan != null) {
if (plan instanceof OlapScan) {
isOnBaseTable = true;
} else if (plan instanceof Aggregate) {
Aggregate agg = ((Aggregate<?>) plan);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
.stream()
.filter(Alias.class::isInstance)
.filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
.map(NamedExpression::toSlot).collect(Collectors.toSet());
Expression predicate = filter.getPredicate();
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (plan instanceof LogicalJoin && filter instanceof LogicalFilter
if (groupExpression.getFirstChildPlan(OlapScan.class) != null) {
return new FilterEstimation(true).estimate(filter.getPredicate(), stats);
}
if (groupExpression.getFirstChildPlan(Aggregate.class) != null) {
Aggregate agg = (Aggregate<?>) groupExpression.getFirstChildPlan(Aggregate.class);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
.stream()
.filter(Alias.class::isInstance)
.filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
.map(NamedExpression::toSlot).collect(Collectors.toSet());
Expression predicate = filter.getPredicate();
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (groupExpression.getFirstChildPlan(LogicalJoin.class) != null) {
LogicalJoin plan = (LogicalJoin) groupExpression.getFirstChildPlan(LogicalJoin.class);
if (filter instanceof LogicalFilter
&& filter.getConjuncts().stream().anyMatch(e -> e instanceof IsNull)) {
Statistics isNullStats = computeGeneratedIsNullStats((LogicalJoin) plan, filter);
if (isNullStats != null) {
@ -640,8 +640,7 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
}
}
}
return new FilterEstimation(isOnBaseTable).estimate(filter.getPredicate(), stats);
return new FilterEstimation(false).estimate(filter.getPredicate(), stats);
}
private Statistics computeGeneratedIsNullStats(LogicalJoin join, Filter filter) {

View File

@ -1205,6 +1205,8 @@ public class SessionVariable implements Serializable, Writable {
@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true)
private boolean enableNewCostModel = false;
@VariableMgr.VarAttr(name = "filter_cost_factor", needForward = true)
public double filterCostFactor = 0.0001;
@VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
private boolean nereidsStarSchemaSupport = true;