diff --git a/plan/cbo_test.go b/plan/cbo_test.go index d44d38f859..b26e833204 100644 --- a/plan/cbo_test.go +++ b/plan/cbo_test.go @@ -108,6 +108,54 @@ func (s *testAnalyzeSuite) TestIndexRead(c *C) { } } +func (s *testAnalyzeSuite) TestEmptyTable(c *C) { + defer func() { + testleak.AfterTest(c)() + }() + store, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + testKit := testkit.NewTestKit(c, store) + defer func() { + store.Close() + }() + testKit.MustExec("use test") + testKit.MustExec("drop table if exists t, t1") + testKit.MustExec("create table t (c1 int)") + testKit.MustExec("create table t1 (c1 int)") + testKit.MustExec("analyze table t, t1") + tests := []struct { + sql string + best string + }{ + { + sql: "select * from t where t.c1 <= 50", + best: "TableReader(Table(t)->Sel([le(test.t.c1, 50)]))", + }, + { + sql: "select * from t where c1 in (select c1 from t1)", + best: "SemiJoin{TableReader(Table(t))->TableReader(Table(t1))}(test.t.c1,test.t1.c1)", + }, + { + sql: "select * from t, t1 where t.c1 = t1.c1", + best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t1))}(test.t.c1,test.t1.c1)", + }, + } + for _, tt := range tests { + ctx := testKit.Se.(context.Context) + stmts, err := tidb.Parse(ctx, tt.sql) + c.Assert(err, IsNil) + c.Assert(stmts, HasLen, 1) + stmt := stmts[0] + is := sessionctx.GetDomain(ctx).InfoSchema() + err = plan.ResolveName(stmt, is, ctx) + c.Assert(err, IsNil) + err = expression.InferType(ctx.GetSessionVars().StmtCtx, stmt) + c.Assert(err, IsNil) + p, err := plan.Optimize(ctx, stmt, is) + c.Assert(plan.ToString(p), Equals, tt.best, Commentf("for %s", tt.sql)) + } +} + func (s *testAnalyzeSuite) TestAnalyze(c *C) { defer func() { testleak.AfterTest(c)() diff --git a/plan/task.go b/plan/task.go index 603eb65d41..4592bf1141 100644 --- a/plan/task.go +++ b/plan/task.go @@ -150,6 +150,9 @@ func (p *PhysicalHashJoin) getCost(lCnt, rCnt float64) float64 { if p.SmallTable == 1 { smallTableCnt = rCnt } + if smallTableCnt <= 1 { + smallTableCnt = 1 + } return (lCnt + rCnt) * (1 + math.Log2(smallTableCnt)) } @@ -180,6 +183,9 @@ func (p *PhysicalMergeJoin) attach2Task(tasks ...task) task { } func (p *PhysicalHashSemiJoin) getCost(lCnt, rCnt float64) float64 { + if rCnt <= 1 { + rCnt = 1 + } return (lCnt + rCnt) * (1 + math.Log2(rCnt)) }