expressions: support row in IN expression

This commit is contained in:
siddontang
2015-09-10 18:34:35 +08:00
parent af42efef32
commit 2055d891ed
3 changed files with 62 additions and 23 deletions

View File

@ -181,7 +181,7 @@ func newMockPlan(rset *mockRecordset) *mockPlan {
func (p *mockPlan) Do(ctx context.Context, f plan.RowIterFunc) error {
for _, data := range p.rset.rows {
if more, err := f(nil, data); !more || err != nil {
if more, err := f(nil, data[:p.rset.offset]); !more || err != nil {
return err
}
}
@ -192,7 +192,7 @@ func (p *mockPlan) Explain(w format.Formatter) {}
func (p *mockPlan) GetFields() []*field.ResultField {
fields, _ := p.rset.Fields()
return fields
return fields[:p.rset.offset]
}
func (p *mockPlan) Filter(ctx context.Context, expr expression.Expression) (plan.Plan, bool, error) {

View File

@ -24,7 +24,6 @@ import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan"
"github.com/pingcap/tidb/util/types"
)
@ -100,28 +99,20 @@ func (n *PatternIn) String() string {
return fmt.Sprintf("%s IN (%s)", n.Expr, n.Sel)
}
func (n *PatternIn) evalInList(ctx context.Context, args map[interface{}]interface{},
in interface{}, list []expression.Expression) (interface{}, error) {
func (n *PatternIn) checkInList(in interface{}, list []interface{}) (interface{}, error) {
hasNull := false
for _, v := range list {
b := NewBinaryOperation(opcode.EQ, Value{in}, v)
eVal, err := b.Eval(ctx, args)
if err != nil {
return nil, err
}
if eVal == nil {
if v == nil {
hasNull = true
continue
}
r, err := types.ToBool(eVal)
r, err := types.Compare(in, v)
if err != nil {
return nil, err
}
if r == 1 {
if r == 0 {
return !n.Not, nil
}
}
@ -135,6 +126,19 @@ func (n *PatternIn) evalInList(ctx context.Context, args map[interface{}]interfa
return n.Not, nil
}
func evalExprList(ctx context.Context, args map[interface{}]interface{}, list []expression.Expression) ([]interface{}, error) {
var err error
values := make([]interface{}, len(list))
for i := range values {
values[i], err = list[i].Eval(ctx, args)
if err != nil {
return nil, errors.Trace(err)
}
}
return values, nil
}
// Eval implements the Expression Eval interface.
func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{}) (v interface{}, err error) {
lhs, err := n.Expr.Eval(ctx, args)
@ -147,10 +151,20 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{})
}
if n.Sel == nil {
return n.evalInList(ctx, args, lhs, n.List)
if err := hasSameColumn(n.Expr, n.List...); err != nil {
return nil, errors.Trace(err)
}
var values []interface{}
values, err = evalExprList(ctx, args, n.List)
if err != nil {
return nil, errors.Trace(err)
}
return n.checkInList(lhs, values)
}
var res []expression.Expression
var res []interface{}
if ev, ok := args[n]; !ok {
// select not yet evaluated
r, err := n.Sel.Plan(ctx)
@ -158,14 +172,18 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{})
return nil, err
}
if g, e := len(r.GetFields()), 1; g != e {
if g, e := len(r.GetFields()), columnNumber(n.Expr); g != e {
return false, errors.Errorf("IN (%s): mismatched field count, have %d, need %d", n.Sel, g, e)
}
res = make([]expression.Expression, 0)
res = make([]interface{}, 0)
// evaluate select and save its result for later in expression check
err = r.Do(ctx, func(id interface{}, data []interface{}) (more bool, err error) {
res = append(res, Value{data[0]})
if len(data) == 1 {
res = append(res, data[0])
} else {
res = append(res, data)
}
args[n] = res
return true, nil
})
@ -175,8 +193,8 @@ func (n *PatternIn) Eval(ctx context.Context, args map[interface{}]interface{})
args[n] = res
} else {
res = ev.([]expression.Expression)
res = ev.([]interface{})
}
return n.evalInList(ctx, args, lhs, res)
return n.checkInList(lhs, res)
}

View File

@ -108,9 +108,30 @@ func (t *testPatternInSuite) TestPatternIn(c *C) {
c.Assert(err, IsNil)
c.Assert(vv, IsTrue)
args[e2] = []expression.Expression{Value{1}, Value{2}}
args[e2] = []interface{}{1, 2}
vv, err = e2.Eval(nil, args)
c.Assert(err, IsNil)
c.Assert(vv, IsTrue)
delete(args, e2)
e2.Expr = newTestRow(1, 2)
sel.SetFieldOffset(2)
_, err = e2.Eval(nil, args)
c.Assert(err, IsNil)
e2.Expr = newTestRow(1, 2, 3)
sel.SetFieldOffset(2)
_, err = e2.Eval(nil, args)
c.Assert(err, NotNil)
delete(args, e2)
e2.Sel = nil
e2.List = []expression.Expression{newTestRow(1, 2, 3)}
_, err = e2.Eval(nil, args)
c.Assert(err, IsNil)
e2.List = []expression.Expression{Value{1}}
_, err = e2.Eval(nil, args)
c.Assert(err, NotNil)
}