diff --git a/expression/expressions/expression_test.go b/expression/expressions/expression_test.go index a5280a9e9f..7350f75d8a 100644 --- a/expression/expressions/expression_test.go +++ b/expression/expressions/expression_test.go @@ -181,7 +181,7 @@ func newMockPlan(rset *mockRecordset) *mockPlan { func (p *mockPlan) Do(ctx context.Context, f plan.RowIterFunc) error { for _, data := range p.rset.rows { - if more, err := f(nil, data); !more || err != nil { + if more, err := f(nil, data[:p.rset.offset]); !more || err != nil { return err } } @@ -192,7 +192,7 @@ func (p *mockPlan) Explain(w format.Formatter) {} func (p *mockPlan) GetFields() []*field.ResultField { fields, _ := p.rset.Fields() - return fields + return fields[:p.rset.offset] } func (p *mockPlan) Filter(ctx context.Context, expr expression.Expression) (plan.Plan, bool, error) { diff --git a/expression/expressions/in.go b/expression/expressions/in.go index ecd0c9f56e..a86d81e42b 100644 --- a/expression/expressions/in.go +++ b/expression/expressions/in.go @@ -24,7 +24,6 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/util/types" ) @@ -100,28 +99,20 @@ func (n *PatternIn) String() string { return fmt.Sprintf("%s IN (%s)", n.Expr, n.Sel) } -func (n *PatternIn) evalInList(ctx context.Context, args map[interface{}]interface{}, - in interface{}, list []expression.Expression) (interface{}, error) { +func (n *PatternIn) checkInList(in interface{}, list []interface{}) (interface{}, error) { hasNull := false for _, v := range list { - b := NewBinaryOperation(opcode.EQ, Value{in}, v) - - eVal, err := b.Eval(ctx, args) - if err != nil { - return nil, err - } - - if eVal == nil { + if v == nil { hasNull = true continue } - r, err := types.ToBool(eVal) + r, err := types.Compare(in, v) if err != nil { return nil, err } - if r == 1 { + if r == 0 { return !n.Not, nil } } @@ -135,6 +126,19 @@ func (n *PatternIn) evalInList(ctx context.Context, args map[interface{}]interfa return n.Not, nil } +func evalExprList(ctx context.Context, args map[interface{}]interface{}, list []expression.Expression) ([]interface{}, error) { + var err error + values := make([]interface{}, len(list)) + for i := range values { + values[i], err = list[i].Eval(ctx, args) + if err != nil { + return nil, errors.Trace(err) + } + } + + return values, nil +} + // Eval implements the Expression Eval interface. func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{}) (v interface{}, err error) { lhs, err := n.Expr.Eval(ctx, args) @@ -147,10 +151,20 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{}) } if n.Sel == nil { - return n.evalInList(ctx, args, lhs, n.List) + if err := hasSameColumn(n.Expr, n.List...); err != nil { + return nil, errors.Trace(err) + } + + var values []interface{} + values, err = evalExprList(ctx, args, n.List) + if err != nil { + return nil, errors.Trace(err) + } + + return n.checkInList(lhs, values) } - var res []expression.Expression + var res []interface{} if ev, ok := args[n]; !ok { // select not yet evaluated r, err := n.Sel.Plan(ctx) @@ -158,14 +172,18 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{}) return nil, err } - if g, e := len(r.GetFields()), 1; g != e { + if g, e := len(r.GetFields()), columnNumber(n.Expr); g != e { return false, errors.Errorf("IN (%s): mismatched field count, have %d, need %d", n.Sel, g, e) } - res = make([]expression.Expression, 0) + res = make([]interface{}, 0) // evaluate select and save its result for later in expression check err = r.Do(ctx, func(id interface{}, data []interface{}) (more bool, err error) { - res = append(res, Value{data[0]}) + if len(data) == 1 { + res = append(res, data[0]) + } else { + res = append(res, data) + } args[n] = res return true, nil }) @@ -175,8 +193,8 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{}) args[n] = res } else { - res = ev.([]expression.Expression) + res = ev.([]interface{}) } - return n.evalInList(ctx, args, lhs, res) + return n.checkInList(lhs, res) } diff --git a/expression/expressions/in_test.go b/expression/expressions/in_test.go index f64ec5db97..0613755382 100644 --- a/expression/expressions/in_test.go +++ b/expression/expressions/in_test.go @@ -108,9 +108,30 @@ func (t *testPatternInSuite) TestPatternIn(c *C) { c.Assert(err, IsNil) c.Assert(vv, IsTrue) - args[e2] = []expression.Expression{Value{1}, Value{2}} + args[e2] = []interface{}{1, 2} vv, err = e2.Eval(nil, args) c.Assert(err, IsNil) c.Assert(vv, IsTrue) + + delete(args, e2) + e2.Expr = newTestRow(1, 2) + sel.SetFieldOffset(2) + _, err = e2.Eval(nil, args) + c.Assert(err, IsNil) + + e2.Expr = newTestRow(1, 2, 3) + sel.SetFieldOffset(2) + _, err = e2.Eval(nil, args) + c.Assert(err, NotNil) + + delete(args, e2) + e2.Sel = nil + e2.List = []expression.Expression{newTestRow(1, 2, 3)} + _, err = e2.Eval(nil, args) + c.Assert(err, IsNil) + + e2.List = []expression.Expression{Value{1}} + _, err = e2.Eval(nil, args) + c.Assert(err, NotNil) }