plan: fix cost estimation (#3896)
This commit is contained in:
@ -108,6 +108,54 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testAnalyzeSuite) TestEmptyTable(c *C) {
|
||||
defer func() {
|
||||
testleak.AfterTest(c)()
|
||||
}()
|
||||
store, err := newStoreWithBootstrap()
|
||||
c.Assert(err, IsNil)
|
||||
testKit := testkit.NewTestKit(c, store)
|
||||
defer func() {
|
||||
store.Close()
|
||||
}()
|
||||
testKit.MustExec("use test")
|
||||
testKit.MustExec("drop table if exists t, t1")
|
||||
testKit.MustExec("create table t (c1 int)")
|
||||
testKit.MustExec("create table t1 (c1 int)")
|
||||
testKit.MustExec("analyze table t, t1")
|
||||
tests := []struct {
|
||||
sql string
|
||||
best string
|
||||
}{
|
||||
{
|
||||
sql: "select * from t where t.c1 <= 50",
|
||||
best: "TableReader(Table(t)->Sel([le(test.t.c1, 50)]))",
|
||||
},
|
||||
{
|
||||
sql: "select * from t where c1 in (select c1 from t1)",
|
||||
best: "SemiJoin{TableReader(Table(t))->TableReader(Table(t1))}(test.t.c1,test.t1.c1)",
|
||||
},
|
||||
{
|
||||
sql: "select * from t, t1 where t.c1 = t1.c1",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t1))}(test.t.c1,test.t1.c1)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
ctx := testKit.Se.(context.Context)
|
||||
stmts, err := tidb.Parse(ctx, tt.sql)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(stmts, HasLen, 1)
|
||||
stmt := stmts[0]
|
||||
is := sessionctx.GetDomain(ctx).InfoSchema()
|
||||
err = plan.ResolveName(stmt, is, ctx)
|
||||
c.Assert(err, IsNil)
|
||||
err = expression.InferType(ctx.GetSessionVars().StmtCtx, stmt)
|
||||
c.Assert(err, IsNil)
|
||||
p, err := plan.Optimize(ctx, stmt, is)
|
||||
c.Assert(plan.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testAnalyzeSuite) TestAnalyze(c *C) {
|
||||
defer func() {
|
||||
testleak.AfterTest(c)()
|
||||
|
||||
@ -150,6 +150,9 @@ func (p *PhysicalHashJoin) getCost(lCnt, rCnt float64) float64 {
|
||||
if p.SmallTable == 1 {
|
||||
smallTableCnt = rCnt
|
||||
}
|
||||
if smallTableCnt <= 1 {
|
||||
smallTableCnt = 1
|
||||
}
|
||||
return (lCnt + rCnt) * (1 + math.Log2(smallTableCnt))
|
||||
}
|
||||
|
||||
@ -180,6 +183,9 @@ func (p *PhysicalMergeJoin) attach2Task(tasks ...task) task {
|
||||
}
|
||||
|
||||
func (p *PhysicalHashSemiJoin) getCost(lCnt, rCnt float64) float64 {
|
||||
if rCnt <= 1 {
|
||||
rCnt = 1
|
||||
}
|
||||
return (lCnt + rCnt) * (1 + math.Log2(rCnt))
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user