diff --git a/plan/plans/index.go b/plan/plans/index.go index a91af7f8f0..fcba257e7f 100644 --- a/plan/plans/index.go +++ b/plan/plans/index.go @@ -194,6 +194,7 @@ func (r *indexPlan) GetFields() []*field.ResultField { // Filter implements plan.Plan Filter interface. // Filter merges BinaryOperations and determines the lower and upper bound. func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (plan.Plan, bool, error) { + var spans []*indexSpan switch x := expr.(type) { case *expression.BinaryOperation: ok, name, val, err := x.IsIdentCompareVal() @@ -214,14 +215,12 @@ func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (pla if err != nil { return nil, false, errors.Trace(err) } - r.spans = filterSpans(r.spans, toSpans(x.Op, val, seekVal)) - return r, true, nil + spans = filterSpans(r.spans, toSpans(x.Op, val, seekVal)) case *expression.Ident: if r.col.Name.L != x.L { break } - r.spans = filterSpans(r.spans, toSpans(opcode.GE, minNotNullVal, nil)) - return r, true, nil + spans = filterSpans(r.spans, toSpans(opcode.GE, minNotNullVal, nil)) case *expression.UnaryOperation: if x.Op != '!' { break @@ -234,16 +233,25 @@ func (r *indexPlan) Filter(ctx context.Context, expr expression.Expression) (pla if r.col.Name.L != cname.L { break } - r.spans = filterSpans(r.spans, toSpans(opcode.EQ, nil, nil)) - return r, true, nil + spans = filterSpans(r.spans, toSpans(opcode.EQ, nil, nil)) } - return r, false, nil + if spans == nil { + return r, false, nil + } + + return &indexPlan{ + src: r.src, + col: r.col, + idxName: r.idxName, + idx: r.idx, + spans: spans, + }, true, nil } // return the intersection range between origin and filter. func filterSpans(origin []*indexSpan, filter []*indexSpan) []*indexSpan { - var newSpans []*indexSpan + newSpans := make([]*indexSpan, 0, len(filter)) for _, fSpan := range filter { for _, oSpan := range origin { newSpan := oSpan.cutOffLow(fSpan.lowVal, fSpan.lowExclude) diff --git a/rset/rsets/where.go b/rset/rsets/where.go index d9904cc856..30729792a9 100644 --- a/rset/rsets/where.go +++ b/rset/rsets/where.go @@ -18,6 +18,7 @@ package rsets import ( + "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" @@ -188,13 +189,18 @@ func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) { return r.planStatic(ctx, expr) } + var ( + src = r.Src + err error + ) + switch x := expr.(type) { case *expression.BinaryOperation: - return r.planBinOp(ctx, x) + src, err = r.planBinOp(ctx, x) case *expression.Ident: - return r.planIdent(ctx, x) + src, err = r.planIdent(ctx, x) case *expression.IsNull: - return r.planIsNull(ctx, x) + src, err = r.planIsNull(ctx, x) case *expression.PatternIn: // TODO: optimize // TODO: show plan @@ -203,10 +209,22 @@ func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) { case *expression.PatternRegexp: // TODO: optimize case *expression.UnaryOperation: - return r.planUnaryOp(ctx, x) + src, err = r.planUnaryOp(ctx, x) default: log.Warnf("%v not supported in where rset now", r.Expr) } - return &plans.FilterDefaultPlan{Plan: r.Src, Expr: expr}, nil + if err != nil { + return nil, errors.Trace(err) + } + + if _, ok := src.(*plans.FilterDefaultPlan); ok { + return src, nil + } + + // We must use a FilterDefaultPlan here to wrap filtered plan. + // Alghough we can check where condition using index plan, we still need + // to check again after the FROM phase if the FROM phase contains outer join. + // TODO: if FROM phase doesn't contain outer join, we can return filtered plan directly. + return &plans.FilterDefaultPlan{Plan: src, Expr: expr}, nil } diff --git a/tidb_test.go b/tidb_test.go index 84bdc17462..20ae252e28 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -735,6 +735,21 @@ func (s *testSessionSuite) TestIndex(c *C) { c.Assert(err, IsNil) c.Assert(rows, HasLen, 1) match(c, rows[0], 1) + + mustExecSQL(c, se, "drop table if exists t1, t2") + mustExecSQL(c, se, ` + create table t1 (c1 int, primary key(c1)); + create table t2 (c2 int, primary key(c2)); + insert into t1 values (1), (2); + insert into t2 values (2);`) + + r = mustExecSQL(c, se, "select * from t1 left join t2 on t1.c1 = t2.c2 order by t1.c1") + rows, err = r.Rows(-1, 0) + matches(c, rows, [][]interface{}{{1, nil}, {2, 2}}) + + r = mustExecSQL(c, se, "select * from t1 left join t2 on t1.c1 = t2.c2 where t2.c2 < 10") + rows, err = r.Rows(-1, 0) + matches(c, rows, [][]interface{}{{2, 2}}) } func (s *testSessionSuite) TestMySQLTypes(c *C) { @@ -848,6 +863,20 @@ func (s *testSessionSuite) TestSelect(c *C) { row, err = r.FirstRow() c.Assert(err, IsNil) c.Assert(row, IsNil) + + mustExecSQL(c, se, "drop table if exists t1, t2, t3") + mustExecSQL(c, se, ` + create table t1 (c1 int); + create table t2 (c2 int); + create table t3 (c3 int); + insert into t1 values (1), (2); + insert into t2 values (2); + insert into t3 values (3);`) + r = mustExecSQL(c, se, "select * from t1 left join t2 on t1.c1 = t2.c2 left join t3 on t1.c1 = t3.c3 order by t1.c1") + rows, err = r.Rows(-1, 0) + c.Assert(err, IsNil) + matches(c, rows, [][]interface{}{{1, nil, nil}, {2, 2, nil}}) + } func (s *testSessionSuite) TestSubQuery(c *C) { @@ -1151,3 +1180,10 @@ func match(c *C, row []interface{}, expected ...interface{}) { c.Assert(got, Equals, need) } } + +func matches(c *C, rows [][]interface{}, expected [][]interface{}) { + c.Assert(len(rows), Equals, len(expected)) + for i := 0; i < len(rows); i++ { + match(c, rows[i], expected[i]...) + } +}