plan: fix cost estimation (#3896)

This commit is contained in:
Haibin Xie
2017-07-26 15:24:13 +08:00
committed by GitHub
parent e008c6c36a
commit fad1e7eca1
2 changed files with 54 additions and 0 deletions

View File

@ -108,6 +108,54 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) {
}
}
func (s *testAnalyzeSuite) TestEmptyTable(c *C) {
defer func() {
testleak.AfterTest(c)()
}()
store, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
testKit := testkit.NewTestKit(c, store)
defer func() {
store.Close()
}()
testKit.MustExec("use test")
testKit.MustExec("drop table if exists t, t1")
testKit.MustExec("create table t (c1 int)")
testKit.MustExec("create table t1 (c1 int)")
testKit.MustExec("analyze table t, t1")
tests := []struct {
sql string
best string
}{
{
sql: "select * from t where t.c1 <= 50",
best: "TableReader(Table(t)->Sel([le(test.t.c1, 50)]))",
},
{
sql: "select * from t where c1 in (select c1 from t1)",
best: "SemiJoin{TableReader(Table(t))->TableReader(Table(t1))}(test.t.c1,test.t1.c1)",
},
{
sql: "select * from t, t1 where t.c1 = t1.c1",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t1))}(test.t.c1,test.t1.c1)",
},
}
for _, tt := range tests {
ctx := testKit.Se.(context.Context)
stmts, err := tidb.Parse(ctx, tt.sql)
c.Assert(err, IsNil)
c.Assert(stmts, HasLen, 1)
stmt := stmts[0]
is := sessionctx.GetDomain(ctx).InfoSchema()
err = plan.ResolveName(stmt, is, ctx)
c.Assert(err, IsNil)
err = expression.InferType(ctx.GetSessionVars().StmtCtx, stmt)
c.Assert(err, IsNil)
p, err := plan.Optimize(ctx, stmt, is)
c.Assert(plan.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql))
}
}
func (s *testAnalyzeSuite) TestAnalyze(c *C) {
defer func() {
testleak.AfterTest(c)()

View File

@ -150,6 +150,9 @@ func (p *PhysicalHashJoin) getCost(lCnt, rCnt float64) float64 {
if p.SmallTable == 1 {
smallTableCnt = rCnt
}
if smallTableCnt <= 1 {
smallTableCnt = 1
}
return (lCnt + rCnt) * (1 + math.Log2(smallTableCnt))
}
@ -180,6 +183,9 @@ func (p *PhysicalMergeJoin) attach2Task(tasks ...task) task {
}
func (p *PhysicalHashSemiJoin) getCost(lCnt, rCnt float64) float64 {
if rCnt <= 1 {
rCnt = 1
}
return (lCnt + rCnt) * (1 + math.Log2(rCnt))
}