diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index 41e2cd912b..41f4ae7cea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -27,10 +27,13 @@ import org.apache.doris.nereids.metrics.consumer.LogConsumer; import org.apache.doris.nereids.metrics.event.GroupMergeEvent; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.properties.RequestPropertyDeriver; +import org.apache.doris.nereids.properties.RequirePropertiesSupplier; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.LeafPlan; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -52,7 +55,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.PriorityQueue; +import java.util.Set; import java.util.stream.Collectors; +import java.util.stream.Stream; import javax.annotation.Nullable; /** @@ -796,71 +801,101 @@ public class Memo { public Pair rank(long n) { double threshold = 0.000000001; Preconditions.checkArgument(n > 0, "the n %d must be greater than 0 in nthPlan", n); - List> plans = rankGroup(root, PhysicalProperties.GATHER); - plans = plans.stream().filter( - p -> !p.second.equals(Double.NaN) - && !p.second.equals(Double.POSITIVE_INFINITY) - && !p.second.equals(Double.NEGATIVE_INFINITY)) + List> plans = rankGroup(root, PhysicalProperties.GATHER); + plans = plans.stream() + .filter(p -> Double.isFinite(p.second.getValue())) .collect(Collectors.toList()); // This is big heap, it always pops the element with larger cost or larger id. - PriorityQueue> pq = new PriorityQueue<>((l, r) -> Math.abs(l.second - r.second) < threshold - ? -Long.compare(l.first, r.first) : -Double.compare(l.second, r.second)); - for (Pair p : plans) { + PriorityQueue> pq = new PriorityQueue<>((l, r) -> + Math.abs(l.second.getValue() - r.second.getValue()) < threshold + ? -Long.compare(l.first, r.first) + : -Double.compare(l.second.getValue(), r.second.getValue())); + for (Pair p : plans) { pq.add(p); if (pq.size() > n) { pq.poll(); } } - return pq.peek(); + Preconditions.checkArgument(pq.peek() != null, "rank error because there is no valid plan"); + return Pair.of(pq.peek().first, pq.peek().second.getValue()); } - private List> rankGroup(Group group, PhysicalProperties prop) { - List> res = new ArrayList<>(); - int prefix = res.size(); - for (GroupExpression groupExpression : extractGroupExpressionContainsProp(group, prop)) { - for (Pair idCostPair : rankGroupExpression(groupExpression, prop)) { + /** + * return number of plan that can be ranked + */ + public int getRankSize() { + List> plans = rankGroup(root, PhysicalProperties.GATHER); + plans = plans.stream().filter( + p -> !p.second.equals(Double.NaN) + && !p.second.equals(Double.POSITIVE_INFINITY) + && !p.second.equals(Double.NEGATIVE_INFINITY)) + .collect(Collectors.toList()); + return plans.size(); + } + + private List> rankGroup(Group group, PhysicalProperties prop) { + List> res = new ArrayList<>(); + int prefix = 0; + List validGroupExprList = extractGroupExpressionSatisfyProp(group, prop); + for (GroupExpression groupExpression : validGroupExprList) { + for (Pair idCostPair : rankGroupExpression(groupExpression, prop)) { res.add(Pair.of(idCostPair.first + prefix, idCostPair.second)); } prefix = res.size(); + // avoid ranking all plans + if (res.size() > 1e2) { + break; + } } return res; } - private List> rankGroupExpression(GroupExpression groupExpression, + private List> rankGroupExpression(GroupExpression groupExpression, PhysicalProperties prop) { if (!groupExpression.getLowestCostTable().containsKey(prop)) { return new ArrayList<>(); } - List> res = new ArrayList<>(); - - List inputProperties = groupExpression.getInputPropertiesList(prop); + List> res = new ArrayList<>(); if (groupExpression.getPlan() instanceof LeafPlan) { - res.add(Pair.of(0L, groupExpression.getCostByProperties(prop))); + res.add(Pair.of(0L, groupExpression.getCostValueByProperties(prop))); return res; } - List>> children = new ArrayList<>(); - for (int i = 0; i < inputProperties.size(); i++) { - // To avoid reach a circle, we don't allow ranking the same group with the same physical properties. - Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup()) - || !prop.equals(inputProperties.get(i))); - List> idCostPair - = rankGroup(groupExpression.child(i), inputProperties.get(i)); - children.add(idCostPair); - } - List>> childrenId = new ArrayList<>(); - permute(children, 0, childrenId, new ArrayList<>()); - Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties); - for (Pair> c : childrenId) { - Cost totalCost = cost; - for (int i = 0; i < children.size(); i++) { - totalCost = CostCalculator.addChildCost(groupExpression.getPlan(), - totalCost, - groupExpression.child(i).getLowestCostPlan(inputProperties.get(i)).get().first, - i); + List> inputPropertiesList = extractInputProperties(groupExpression, prop); + for (List inputProperties : inputPropertiesList) { + int prefix = res.size(); + List>> children = new ArrayList<>(); + for (int i = 0; i < inputProperties.size(); i++) { + // To avoid reach a circle, we don't allow ranking the same group with the same physical properties. + Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup()) + || !prop.equals(inputProperties.get(i))); + List> idCostPair + = rankGroup(groupExpression.child(i), inputProperties.get(i)); + children.add(idCostPair); + } + + List>> childrenId = new ArrayList<>(); + permute(children, 0, childrenId, new ArrayList<>()); + Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties); + for (Pair> c : childrenId) { + Cost totalCost = cost; + for (int i = 0; i < children.size(); i++) { + totalCost = CostCalculator.addChildCost(groupExpression.getPlan(), + totalCost, + children.get(i).get(c.second.get(i)).second, + i); + } + if (res.isEmpty()) { + Preconditions.checkArgument( + Math.abs(totalCost.getValue() - groupExpression.getCostByProperties(prop)) < 0.0001, + "Please check operator %s, expected cost %s but found %s", + groupExpression.getPlan().shapeInfo(), totalCost.getValue(), + groupExpression.getCostByProperties(prop)); + } + res.add(Pair.of(prefix + c.first, totalCost)); } - res.add(Pair.of(c.first, totalCost.getValue())); } + return res; } @@ -869,7 +904,7 @@ public class Memo { * for children [1, 2] [1, 2, 3] * we can get: 0: [1,1] 1:[1, 2] 2:[1, 3] 3:[2, 1] 4:[2, 2] 5:[2, 3] */ - private void permute(List>> children, int index, + private void permute(List>> children, int index, List>> result, List current) { if (index == children.size()) { result.add(Pair.of(getUniqueId(children, current), current)); @@ -889,7 +924,7 @@ public class Memo { * [0, 0]: 0*1 + 0*1*2 * [0, 1]: 0*1 + 1*1*2 */ - private static long getUniqueId(List>> lists, List current) { + private static long getUniqueId(List>> lists, List current) { long id = 0; long factor = 1; for (int i = 0; i < lists.size(); i++) { @@ -899,54 +934,153 @@ public class Memo { return id; } - private List extractGroupExpressionContainsProp(Group group, PhysicalProperties prop) { - List validExpressions = new ArrayList<>(); + private List extractGroupExpressionSatisfyProp(Group group, PhysicalProperties prop) { GroupExpression bestExpr = group.getLowestCostPlan(prop).get().second; - validExpressions.add(bestExpr); - for (GroupExpression groupExpression : group.getPhysicalExpressions()) { - if (!groupExpression.equals(bestExpr) && groupExpression.getLowestCostTable().containsKey(prop)) { - validExpressions.add(groupExpression); - } - } - return validExpressions; + List exprs = Lists.newArrayList(bestExpr); + Set hasVisited = new HashSet<>(); + hasVisited.add(bestExpr); + Stream.concat(group.getPhysicalExpressions().stream(), group.getEnforcers().stream()) + .forEach(groupExpression -> { + if (!groupExpression.getInputPropertiesListOrEmpty(prop).isEmpty() + && !groupExpression.equals(bestExpr) && !hasVisited.contains(groupExpression)) { + hasVisited.add(groupExpression); + exprs.add(groupExpression); + } + }); + return exprs; } - private PhysicalPlan unrankGroup(Group group, PhysicalProperties prop, long rank) { + // ---------------------------------------------------------------- + // extract input properties for a given group expression and required output properties + // There are three cases: + // 1. If group expression is enforcer, return the input properties of the best expression + // 2. If group expression require any, return any input properties + // 3. Otherwise, return all input properties that satisfies the required output properties + private List> extractInputProperties(GroupExpression groupExpression, + PhysicalProperties prop) { + List> res = new ArrayList<>(); + res.add(groupExpression.getInputPropertiesList(prop)); + + // return optimized input for enforcer + if (groupExpression.getOwnerGroup().getEnforcers().contains(groupExpression)) { + return res; + } + + // return any if exits except RequirePropertiesSupplier and SetOperators + // Because PropRegulator could change their input properties + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(prop); + List> requestList = requestPropertyDeriver + .getRequestChildrenPropertyList(groupExpression); + Optional> any = requestList.stream() + .filter(e -> e.stream().allMatch(PhysicalProperties.ANY::equals)) + .findAny(); + if (any.isPresent() + && !(groupExpression.getPlan() instanceof RequirePropertiesSupplier) + && !(groupExpression.getPlan() instanceof SetOperation)) { + res.clear(); + res.add(any.get()); + return res; + } + + // return all optimized inputs + Set> inputProps = groupExpression.getLowestCostTable().keySet().stream() + .filter(physicalProperties -> physicalProperties.satisfy(prop)) + .map(groupExpression::getInputPropertiesList) + .collect(Collectors.toSet()); + res.addAll(inputProps); + return res; + } + + private int getGroupSize(Group group, PhysicalProperties prop, + Map>> exprSizeCache) { + List validGroupExprs = extractGroupExpressionSatisfyProp(group, prop); + int groupCount = 0; + for (GroupExpression groupExpression : validGroupExprs) { + int exprCount = getExprSize(groupExpression, prop, exprSizeCache); + groupCount += exprCount; + if (groupCount > 1e2) { + break; + } + } + return groupCount; + } + + // return size for each input properties + private int getExprSize(GroupExpression groupExpression, PhysicalProperties properties, + Map>> exprChildSizeCache) { + List> exprCount = new ArrayList<>(); + if (!groupExpression.getLowestCostTable().containsKey(properties)) { + exprCount.add(Lists.newArrayList(0)); + } else if (groupExpression.getPlan() instanceof LeafPlan) { + exprCount.add(Lists.newArrayList(1)); + } else { + List> inputPropertiesList = extractInputProperties(groupExpression, properties); + for (List inputProperties : inputPropertiesList) { + List groupExprSize = new ArrayList<>(); + for (int i = 0; i < inputProperties.size(); i++) { + groupExprSize.add( + getGroupSize(groupExpression.child(i), inputProperties.get(i), exprChildSizeCache)); + } + exprCount.add(groupExprSize); + } + } + exprChildSizeCache.put(groupExpression, exprCount); + return exprCount.stream() + .mapToInt(s -> s.stream().reduce(1, (a, b) -> a * b)) + .sum(); + } + + private PhysicalPlan unrankGroup(Group group, PhysicalProperties prop, long rank, + Map>> exprSizeCache) { int prefix = 0; - for (GroupExpression groupExpression : extractGroupExpressionContainsProp(group, prop)) { - List> possiblePlans = rankGroupExpression(groupExpression, prop); - if (!possiblePlans.isEmpty() && rank - prefix <= possiblePlans.get(possiblePlans.size() - 1).first) { - return unrankGroupExpression(groupExpression, prop, rank - prefix); + for (GroupExpression groupExpression : extractGroupExpressionSatisfyProp(group, prop)) { + int exprCount = exprSizeCache.get(groupExpression).stream() + .mapToInt(s -> s.stream().reduce(1, (a, b) -> a * b)) + .sum(); + // rank is start from 0 + if (exprCount != 0 && rank + 1 - prefix <= exprCount) { + return unrankGroupExpression(groupExpression, prop, rank - prefix, + exprSizeCache); } - prefix += possiblePlans.size(); + prefix += exprCount; } - Preconditions.checkArgument(false, "unrank Group error"); - return null; + throw new RuntimeException("the group has no plan for prop %s in rank job"); } - private PhysicalPlan unrankGroupExpression(GroupExpression groupExpression, - PhysicalProperties prop, long rank) { + private PhysicalPlan unrankGroupExpression(GroupExpression groupExpression, PhysicalProperties prop, long rank, + Map>> exprSizeCache) { if (groupExpression.getPlan() instanceof LeafPlan) { - Preconditions.checkArgument(rank == 0); + Preconditions.checkArgument(rank == 0, + "leaf plan's %s rank must be 0 but is %d", groupExpression, rank); return ((PhysicalPlan) groupExpression.getPlan()).withPhysicalPropertiesAndStats( groupExpression.getOutputProperties(prop), groupExpression.getOwnerGroup().getStatistics()); } - List>> children = new ArrayList<>(); - List properties = groupExpression.getInputPropertiesList(prop); - for (int i = 0; i < properties.size(); i++) { - children.add(rankGroup(groupExpression.child(i), properties.get(i))); - } - List childrenRanks = extractChildRanks(rank, children); - List childrenPlan = new ArrayList<>(); - for (int i = 0; i < properties.size(); i++) { - childrenPlan.add(unrankGroup(groupExpression.child(i), properties.get(i), childrenRanks.get(i))); + List> inputPropertiesList = extractInputProperties(groupExpression, prop); + for (int i = 0; i < inputPropertiesList.size(); i++) { + List properties = inputPropertiesList.get(i); + List childrenSize = exprSizeCache.get(groupExpression).get(i); + int count = childrenSize.stream().reduce(1, (a, b) -> a * b); + if (rank >= count) { + rank -= count; + continue; + } + List childrenRanks = extractChildRanks(rank, childrenSize); + List childrenPlan = new ArrayList<>(); + for (int j = 0; j < properties.size(); j++) { + Plan plan = unrankGroup(groupExpression.child(j), properties.get(j), + childrenRanks.get(j), exprSizeCache); + Preconditions.checkArgument(plan != null, "rank group get null"); + childrenPlan.add(plan); + } + + Plan plan = groupExpression.getPlan().withChildren(childrenPlan); + return ((PhysicalPlan) plan).withPhysicalPropertiesAndStats( + groupExpression.getOutputProperties(prop), + groupExpression.getOwnerGroup().getStatistics()); } - Plan plan = groupExpression.getPlan().withChildren(childrenPlan); - return ((PhysicalPlan) plan).withPhysicalPropertiesAndStats( - groupExpression.getOutputProperties(prop), - groupExpression.getOwnerGroup().getStatistics()); + throw new RuntimeException("the groupExpr has no plan for prop in rank job"); } /** @@ -955,20 +1089,20 @@ public class Memo { * 1: [1%1, 1%(1*2)] * 2: [2%1, 2%(1*2)] */ - private List extractChildRanks(long rank, List>> children) { - Preconditions.checkArgument(!children.isEmpty(), "children should not empty in extractChildRanks"); - int factor = children.get(0).size(); + private List extractChildRanks(long rank, List childrenSize) { + Preconditions.checkArgument(!childrenSize.isEmpty(), "children should not empty in extractChildRanks"); List indices = new ArrayList<>(); - for (int i = 0; i < children.size() - 1; i++) { + for (int i = 0; i < childrenSize.size(); i++) { + int factor = childrenSize.get(i); indices.add(rank % factor); rank = rank / factor; - factor *= children.get(i + 1).size(); } - indices.add(rank % factor); return indices; } public PhysicalPlan unrank(long id) { - return unrankGroup(getRoot(), PhysicalProperties.GATHER, id); + Map>> exprSizeCache = new HashMap<>(); + getGroupSize(getRoot(), PhysicalProperties.GATHER, exprSizeCache); + return unrankGroup(getRoot(), PhysicalProperties.GATHER, id, exprSizeCache); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index ef8fd61415..ec1d93ca70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -72,6 +72,10 @@ public class RequestPropertyDeriver extends PlanVisitor { this.requestPropertyFromParent = context.getRequiredProperties(); } + public RequestPropertyDeriver(PhysicalProperties requestPropertyFromParent) { + this.requestPropertyFromParent = requestPropertyFromParent; + } + /** * get request children property list */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java index 95ea22b705..4c81c0f6cc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/RankTest.java @@ -17,49 +17,44 @@ package org.apache.doris.nereids.memo; -import org.apache.doris.nereids.datasets.tpch.TPCHTestBase; -import org.apache.doris.nereids.datasets.tpch.TPCHUtils; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; +import org.apache.doris.nereids.util.HyperGraphBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; +import com.google.common.collect.Sets; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.lang.reflect.Field; +import java.util.HashSet; +import java.util.Set; -public class RankTest extends TPCHTestBase { +public class RankTest extends TestWithFeService { @Test - void testRank() throws NoSuchFieldException, IllegalAccessException { - for (int i = 1; i < 22; i++) { - Field field = TPCHUtils.class.getField("Q" + i); - System.out.println("Q" + i); - Memo memo = PlanChecker.from(connectContext) - .analyze(field.get(null).toString()) - .rewrite() - .optimize() - .getCascadesContext() - .getMemo(); - memo.rank(1); + void test() { + HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(Sets.newHashSet(JoinType.INNER_JOIN)); + hyperGraphBuilder.init(0, 1, 2); + Plan plan = hyperGraphBuilder.addEdge(JoinType.INNER_JOIN, 1, 2) + .addEdge(JoinType.INNER_JOIN, 0, 1) + .buildPlan(); + plan = new LogicalProject(plan.getOutput(), plan); + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan); + hyperGraphBuilder.initStats(cascadesContext); + PhysicalPlan bestPlan = PlanChecker.from(cascadesContext) + .optimize() + .getBestPlanTree(); + Memo memo = cascadesContext.getMemo(); + Set shape = new HashSet<>(); + for (int i = 0; i < memo.getRankSize(); i++) { + shape.add(memo.unrank(memo.rank(i + 1).first).shape("")); } - } - - //TODO re-open this case latter. the plan for q3 is different. But we do not have time to fix this bug now. - @Test - void testUnrank() throws NoSuchFieldException, IllegalAccessException { - //for (int i = 1; i < 22; i++) { - // Field field = TPCHUtils.class.getField("Q" + i); - // System.out.println("Q" + i); - // Memo memo = PlanChecker.from(connectContext) - // .analyze(field.get(null).toString()) - // .rewrite() - // .optimize() - // .getCascadesContext() - // .getMemo(); - // PhysicalPlan plan1 = memo.unrank(memo.rank(1).first); - // PhysicalPlan plan2 = PlanChecker.from(connectContext) - // .analyze(field.get(null).toString()) - // .rewrite() - // .optimize() - // .getBestPlanTree(PhysicalProperties.GATHER); - // Assertions.assertTrue(PlanChecker.isPlanEqualWithoutID(plan1, plan2)); - //} + System.out.println(shape); + Assertions.assertEquals(3, shape.size()); + Assertions.assertEquals(bestPlan.shape(""), memo.unrank(memo.rank(1).first).shape("")); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java index ac867003a3..e33c28ae93 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java @@ -61,14 +61,14 @@ public class HyperGraphBuilder { private final HashMap plans = new HashMap<>(); private final HashMap> schemas = new HashMap<>(); - private final ImmutableList fullJoinTypes = ImmutableList.of( + private ImmutableList fullJoinTypes = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.FULL_OUTER_JOIN ); - private final ImmutableList leftFullJoinTypes = ImmutableList.of( + private ImmutableList leftFullJoinTypes = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, @@ -78,7 +78,7 @@ public class HyperGraphBuilder { JoinType.NULL_AWARE_LEFT_ANTI_JOIN ); - private final ImmutableList rightFullJoinTypes = ImmutableList.of( + private ImmutableList rightFullJoinTypes = ImmutableList.of( JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, @@ -87,6 +87,20 @@ public class HyperGraphBuilder { JoinType.RIGHT_ANTI_JOIN ); + public HyperGraphBuilder() {} + + public HyperGraphBuilder(Set validJoinType) { + fullJoinTypes = fullJoinTypes.stream() + .filter(validJoinType::contains) + .collect(ImmutableList.toImmutableList()); + leftFullJoinTypes = leftFullJoinTypes.stream() + .filter(validJoinType::contains) + .collect(ImmutableList.toImmutableList()); + rightFullJoinTypes = rightFullJoinTypes.stream() + .filter(validJoinType::contains) + .collect(ImmutableList.toImmutableList()); + } + public HyperGraph build() { assert plans.size() == 1 : "there are cross join"; Plan plan = plans.values().iterator().next(); diff --git a/tools/cost_model_evaluate/evaluator.py b/tools/cost_model_evaluate/evaluator.py index e963a183ed..d653c18cf1 100644 --- a/tools/cost_model_evaluate/evaluator.py +++ b/tools/cost_model_evaluate/evaluator.py @@ -48,6 +48,7 @@ class Evaluator: plans = self.extract_all_plans() res: list[tuple[float, float]] = [] for n, (plan, cost) in plans.items(): + print(f"run {n}-th plan") time = self.sql_executor.get_execute_time(plan) res.append((cost, time)) if self.config.plot: @@ -60,7 +61,8 @@ class Evaluator: x_values = [t[0] for t in data] y_values = [t[1] for t in data] fig, ax = plt.subplots() - ax.scatter(x_values, y_values) + ax.scatter(x_values[:1], y_values[:1], c='r') + ax.scatter(x_values[1:], y_values[1:]) ax.set_xlabel('Cost') ax.set_ylabel('Time') plt.show() @@ -72,11 +74,13 @@ class Evaluator: def extract_all_plans(self): plan_set = set() plan_map: dict[int, tuple[str, float]] = {} - for n in range(1, self.config.plan_number): + n = 0 + while len(plan_set) < self.config.plan_number: + n += 1 query = self.inject_nth_optimized_hint(n) plan, cost = self.sql_executor.get_plan_with_cost(query) if plan in plan_set: - break + continue plan_set.add(plan) plan_map[n] = (query, cost) return plan_map diff --git a/tools/cost_model_evaluate/main.py b/tools/cost_model_evaluate/main.py index 3103fb2316..6a93ebeaea 100644 --- a/tools/cost_model_evaluate/main.py +++ b/tools/cost_model_evaluate/main.py @@ -24,7 +24,7 @@ config = Config( "", "127.0.0.1", 9030, - "regression_test_nereids_tpch_p0", + "tpch", 2, 50, True, diff --git a/tools/cost_model_evaluate/sql_executor.py b/tools/cost_model_evaluate/sql_executor.py index c38cc322db..9189eb7c2a 100644 --- a/tools/cost_model_evaluate/sql_executor.py +++ b/tools/cost_model_evaluate/sql_executor.py @@ -31,7 +31,7 @@ class SQLExecutor: database=database ) self.cursor = self.connection.cursor() - self.wait_fetch_time_index = 16 + self.wait_fetch_time_index = 4 def execute_query(self, query: str, parameters: Tuple | None) -> List[Tuple]: if parameters: