diff --git a/expression/expressions/binop.go b/expression/expressions/binop.go index 80eeb0f175..f86852421c 100644 --- a/expression/expressions/binop.go +++ b/expression/expressions/binop.go @@ -364,6 +364,25 @@ func (o *BinaryOperation) evalLogicOp(ctx context.Context, args map[interface{}] } } +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + // operator: >=, >, <=, <, !=, <>, = <=>, etc. // see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { @@ -383,22 +402,12 @@ func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interfa return nil, o.traceErr(err) } - switch o.Op { - case opcode.LT: - return n < 0, nil - case opcode.LE: - return n <= 0, nil - case opcode.GE: - return n >= 0, nil - case opcode.GT: - return n > 0, nil - case opcode.EQ: - return n == 0, nil - case opcode.NE: - return n != 0, nil - default: - return nil, o.errorf("invalid op %v in comparision operation", o.Op) + r, err := getCompResult(o.Op, n) + if err != nil { + return nil, o.errorf(err.Error()) } + + return r, nil } func (o *BinaryOperation) evalPlus(a interface{}, b interface{}) (interface{}, error) { diff --git a/expression/expressions/cmp_subquery.go b/expression/expressions/cmp_subquery.go new file mode 100644 index 0000000000..d3a8031973 --- /dev/null +++ b/expression/expressions/cmp_subquery.go @@ -0,0 +1,191 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expressions + +import ( + "fmt" + + "github.com/juju/errors" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +// CompareSubQuery is the expression for "expr cmp (select ...)". +// See: https://dev.mysql.com/doc/refman/5.7/en/comparisons-using-subqueries.html +// See: https://dev.mysql.com/doc/refman/5.7/en/any-in-some-subqueries.html +// See: https://dev.mysql.com/doc/refman/5.7/en/all-subqueries.html +type CompareSubQuery struct { + // L is the left expression + L expression.Expression + // Op is the comparison opcode. + Op opcode.Op + // R is the sub query for right expression. + R *SubQuery + // All is true, we should compare all records in subquery. + All bool +} + +// Clone implements the Expression Clone interface. +func (cs *CompareSubQuery) Clone() (expression.Expression, error) { + l, err := cs.L.Clone() + if err != nil { + return nil, errors.Trace(err) + } + + r, err := cs.R.Clone() + if err != nil { + return nil, errors.Trace(err) + } + + return &CompareSubQuery{L: l, Op: cs.Op, R: r.(*SubQuery), All: cs.All}, nil +} + +// IsStatic implements the Expression IsStatic interface. +func (cs *CompareSubQuery) IsStatic() bool { + return cs.L.IsStatic() && cs.R.IsStatic() +} + +// String implements the Expression String interface. +func (cs *CompareSubQuery) String() string { + anyOrAll := "ANY" + if cs.All { + anyOrAll = "ALL" + } + + return fmt.Sprintf("%s %s %s %s", cs.L, cs.Op, anyOrAll, cs.R) +} + +// Eval implements the Expression Eval interface. +func (cs *CompareSubQuery) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { + if err := hasSameColumnCount(ctx, cs.L, cs.R); err != nil { + return nil, errors.Trace(err) + } + + lv, err := cs.L.Eval(ctx, args) + if err != nil { + return nil, errors.Trace(err) + } + if lv == nil { + return nil, nil + } + + if cs.R.Value != nil { + return cs.checkResult(lv, cs.R.Value.([]interface{})) + } + + p, err := cs.R.Plan(ctx) + if err != nil { + return nil, errors.Trace(err) + } + + res := []interface{}{} + err = p.Do(ctx, func(id interface{}, data []interface{}) (bool, error) { + if len(data) == 1 { + res = append(res, data[0]) + } else { + res = append(res, data) + } + return true, nil + }) + if err != nil { + return nil, errors.Trace(err) + } + + cs.R.Value = res + return cs.checkResult(lv, cs.R.Value.([]interface{})) +} + +func (cs *CompareSubQuery) checkAllResult(lv interface{}, result []interface{}) (interface{}, error) { + hasNull := false + for _, v := range result { + if v == nil { + hasNull = true + continue + } + + comRes, err := types.Compare(lv, v) + if err != nil { + return nil, errors.Trace(err) + } + + res, err := getCompResult(cs.Op, comRes) + if err != nil { + return nil, errors.Trace(err) + } + if !res { + return false, nil + } + } + + if hasNull { + // If no matched but we get null, return null. + // Like `insert t (c) values (1),(2),(null)`, then + // `select 3 > all (select c from t)`, returns null. + return nil, nil + } + + return true, nil +} + +func (cs *CompareSubQuery) checkAnyResult(lv interface{}, result []interface{}) (interface{}, error) { + hasNull := false + for _, v := range result { + if v == nil { + hasNull = true + continue + } + + comRes, err := types.Compare(lv, v) + if err != nil { + return nil, errors.Trace(err) + } + + res, err := getCompResult(cs.Op, comRes) + if err != nil { + return nil, errors.Trace(err) + } + if res { + return true, nil + } + } + + if hasNull { + // If no matched but we get null, return null. + // Like `insert t (c) values (1),(2),(null)`, then + // `select 0 > any (select c from t)`, returns null. + return nil, nil + } + + return false, nil +} + +func (cs *CompareSubQuery) checkResult(lv interface{}, result []interface{}) (interface{}, error) { + if cs.All { + return cs.checkAllResult(lv, result) + } + + return cs.checkAnyResult(lv, result) +} + +// NewCompareSubQuery creates a CompareSubQuery object. +func NewCompareSubQuery(op opcode.Op, lhs expression.Expression, rhs *SubQuery, all bool) *CompareSubQuery { + return &CompareSubQuery{ + Op: op, + L: lhs, + R: rhs, + All: all, + } +} diff --git a/expression/expressions/cmp_subquery_test.go b/expression/expressions/cmp_subquery_test.go new file mode 100644 index 0000000000..8357d88b69 --- /dev/null +++ b/expression/expressions/cmp_subquery_test.go @@ -0,0 +1,170 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expressions + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +var _ = Suite(&testCompSubQuerySuite{}) + +type testCompSubQuerySuite struct { +} + +func (s *testCompSubQuerySuite) convert(v interface{}) interface{} { + switch x := v.(type) { + case nil: + return nil + case int: + return int64(x) + } + + return v +} + +func (s *testCompSubQuerySuite) TestCompSubQuery(c *C) { + tbl := []struct { + lhs interface{} + op opcode.Op + rhs []interface{} + all bool + result interface{} // 0 for false, 1 for true, nil for nil. + }{ + // Test any subquery. + {nil, opcode.EQ, []interface{}{1, 2}, false, nil}, + {0, opcode.EQ, []interface{}{1, 2}, false, 0}, + {0, opcode.EQ, []interface{}{1, 2, nil}, false, nil}, + {1, opcode.EQ, []interface{}{1, 1}, false, 1}, + {1, opcode.EQ, []interface{}{1, 1, nil}, false, 1}, + {nil, opcode.NE, []interface{}{1, 2}, false, nil}, + {1, opcode.NE, []interface{}{1, 2}, false, 1}, + {1, opcode.NE, []interface{}{1, 2, nil}, false, 1}, + {1, opcode.NE, []interface{}{1, 1}, false, 0}, + {1, opcode.NE, []interface{}{1, 1, nil}, false, nil}, + {nil, opcode.GT, []interface{}{1, 2}, false, nil}, + {1, opcode.GT, []interface{}{1, 2}, false, 0}, + {1, opcode.GT, []interface{}{1, 2, nil}, false, nil}, + {2, opcode.GT, []interface{}{1, 2}, false, 1}, + {2, opcode.GT, []interface{}{1, 2, nil}, false, 1}, + {3, opcode.GT, []interface{}{1, 2}, false, 1}, + {3, opcode.GT, []interface{}{1, 2, nil}, false, 1}, + {nil, opcode.GE, []interface{}{1, 2}, false, nil}, + {0, opcode.GE, []interface{}{1, 2}, false, 0}, + {0, opcode.GE, []interface{}{1, 2, nil}, false, nil}, + {1, opcode.GE, []interface{}{1, 2}, false, 1}, + {1, opcode.GE, []interface{}{1, 2, nil}, false, 1}, + {2, opcode.GE, []interface{}{1, 2}, false, 1}, + {3, opcode.GE, []interface{}{1, 2}, false, 1}, + {nil, opcode.LT, []interface{}{1, 2}, false, nil}, + {0, opcode.LT, []interface{}{1, 2}, false, 1}, + {0, opcode.LT, []interface{}{1, 2, nil}, false, 1}, + {1, opcode.LT, []interface{}{1, 2}, false, 1}, + {2, opcode.LT, []interface{}{1, 2}, false, 0}, + {2, opcode.LT, []interface{}{1, 2, nil}, false, nil}, + {3, opcode.LT, []interface{}{1, 2}, false, 0}, + {nil, opcode.LE, []interface{}{1, 2}, false, nil}, + {0, opcode.LE, []interface{}{1, 2}, false, 1}, + {0, opcode.LE, []interface{}{1, 2, nil}, false, 1}, + {1, opcode.LE, []interface{}{1, 2}, false, 1}, + {2, opcode.LE, []interface{}{1, 2}, false, 1}, + {3, opcode.LE, []interface{}{1, 2}, false, 0}, + {3, opcode.LE, []interface{}{1, 2, nil}, false, nil}, + + // Test all subquery. + {nil, opcode.EQ, []interface{}{1, 2}, true, nil}, + {0, opcode.EQ, []interface{}{1, 2}, true, 0}, + {0, opcode.EQ, []interface{}{1, 2, nil}, true, 0}, + {1, opcode.EQ, []interface{}{1, 2}, true, 0}, + {1, opcode.EQ, []interface{}{1, 2, nil}, true, 0}, + {1, opcode.EQ, []interface{}{1, 1}, true, 1}, + {1, opcode.EQ, []interface{}{1, 1, nil}, true, nil}, + {nil, opcode.NE, []interface{}{1, 2}, true, nil}, + {0, opcode.NE, []interface{}{1, 2}, true, 1}, + {1, opcode.NE, []interface{}{1, 2, nil}, true, 0}, + {1, opcode.NE, []interface{}{1, 1}, true, 0}, + {1, opcode.NE, []interface{}{1, 1, nil}, true, 0}, + {nil, opcode.GT, []interface{}{1, 2}, true, nil}, + {1, opcode.GT, []interface{}{1, 2}, true, 0}, + {1, opcode.GT, []interface{}{1, 2, nil}, true, 0}, + {2, opcode.GT, []interface{}{1, 2}, true, 0}, + {2, opcode.GT, []interface{}{1, 2, nil}, true, 0}, + {3, opcode.GT, []interface{}{1, 2}, true, 1}, + {3, opcode.GT, []interface{}{1, 2, nil}, true, nil}, + {nil, opcode.GE, []interface{}{1, 2}, true, nil}, + {0, opcode.GE, []interface{}{1, 2}, true, 0}, + {0, opcode.GE, []interface{}{1, 2, nil}, true, 0}, + {1, opcode.GE, []interface{}{1, 2}, true, 0}, + {1, opcode.GE, []interface{}{1, 2, nil}, true, 0}, + {2, opcode.GE, []interface{}{1, 2}, true, 1}, + {3, opcode.GE, []interface{}{1, 2}, true, 1}, + {3, opcode.GE, []interface{}{1, 2, nil}, true, nil}, + {nil, opcode.LT, []interface{}{1, 2}, true, nil}, + {0, opcode.LT, []interface{}{1, 2}, true, 1}, + {0, opcode.LT, []interface{}{1, 2, nil}, true, nil}, + {1, opcode.LT, []interface{}{1, 2}, true, 0}, + {2, opcode.LT, []interface{}{1, 2}, true, 0}, + {2, opcode.LT, []interface{}{1, 2, nil}, true, 0}, + {3, opcode.LT, []interface{}{1, 2}, true, 0}, + {nil, opcode.LE, []interface{}{1, 2}, true, nil}, + {0, opcode.LE, []interface{}{1, 2}, true, 1}, + {0, opcode.LE, []interface{}{1, 2, nil}, true, nil}, + {1, opcode.LE, []interface{}{1, 2}, true, 1}, + {2, opcode.LE, []interface{}{1, 2}, true, 0}, + {3, opcode.LE, []interface{}{1, 2}, true, 0}, + {3, opcode.LE, []interface{}{1, 2, nil}, true, 0}, + } + + for _, t := range tbl { + lhs := s.convert(t.lhs) + + rhs := make([][]interface{}, 0, len(t.rhs)) + for _, v := range t.rhs { + rhs = append(rhs, []interface{}{s.convert(v)}) + } + + sq := newMockSubQuery(rhs, []string{"c"}) + expr := NewCompareSubQuery(t.op, Value{lhs}, sq, t.all) + + c.Assert(expr.IsStatic(), IsFalse) + + str := expr.String() + c.Assert(len(str), Greater, 0) + + v, err := expr.Eval(nil, nil) + c.Assert(err, IsNil) + + switch x := t.result.(type) { + case nil: + c.Assert(v, IsNil) + case int: + val, err := types.ToBool(v) + c.Assert(err, IsNil) + c.Assert(val, Equals, int64(x)) + } + } + + // Test error. + sq := newMockSubQuery([][]interface{}{{1, 2}}, []string{"c1", "c2"}) + expr := NewCompareSubQuery(opcode.EQ, Value{1}, sq, true) + + _, err := expr.Eval(nil, nil) + c.Assert(err, NotNil) + + expr = NewCompareSubQuery(opcode.EQ, Value{1}, sq, false) + + _, err = expr.Eval(nil, nil) + c.Assert(err, NotNil) +} diff --git a/expression/expressions/helper.go b/expression/expressions/helper.go index 10dc8ad862..8e1cd3fd06 100644 --- a/expression/expressions/helper.go +++ b/expression/expressions/helper.go @@ -226,6 +226,8 @@ func mentionedAggregateFuncs(e expression.Expression, m *[]expression.Expression for _, expr := range x.Values { mentionedAggregateFuncs(expr, m) } + case *CompareSubQuery: + mentionedAggregateFuncs(x.L, m) default: log.Errorf("Unknown Expression: %T", e) } @@ -314,6 +316,8 @@ func mentionedColumns(e expression.Expression, m map[string]bool, names *[]strin for _, expr := range x.Values { mentionedColumns(expr, m, names) } + case *CompareSubQuery: + mentionedColumns(x.L, m, names) default: log.Errorf("Unknown Expression: %T", e) } diff --git a/parser/parser.y b/parser/parser.y index 0d9d93263e..1fdda1a5e7 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -70,6 +70,7 @@ import ( and "AND" andand "&&" andnot "&^" + any "ANY" as "AS" asc "ASC" autoIncrement "AUTO_INCREMENT" @@ -175,6 +176,7 @@ import ( share "SHARE" show "SHOW" signed "SIGNED" + some "SOME" start "START" stringType "string" substring "SUBSTRING" @@ -266,6 +268,7 @@ import ( AlterTableStmt "Alter table statement" AlterSpecification "Alter table specification" AlterSpecificationList "Alter table specification list" + AnyOrAll "Any or All for subquery" AsOpt "as optional" Assignment "assignment" AssignmentList "assignment list" @@ -1387,8 +1390,50 @@ Factor: { $$ = expressions.NewBinaryOperation(opcode.EQ, $1.(expression.Expression), $3.(expression.Expression)) } +| Factor ">=" AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.GE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } +| Factor '>' AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.GT, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } +| Factor "<=" AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.LE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } +| Factor '<' AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.LT, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } +| Factor "!=" AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.NE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } +| Factor "<>" AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.NE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } +| Factor "=" AnyOrAll SubSelect %prec eq + { + $$ = expressions.NewCompareSubQuery(opcode.EQ, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool)) + } | Factor1 +AnyOrAll: + "ANY" + { + $$ = false + } +| "SOME" + { + $$ = false + } +| "ALL" + { + $$ = true + } + Factor1: PrimaryFactor NotOpt "IN" '(' ExpressionList ')' { diff --git a/parser/parser_test.go b/parser/parser_test.go index c6a4438448..a0d6691f3a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -285,6 +285,12 @@ func (s *testParserSuite) TestParser0(c *C) { {"SHOW SESSION VARIABLES LIKE 'character_set_results'", true}, {"SHOW VARIABLES", true}, {"SHOW GLOBAL VARIABLES", true}, + + // For compare subquery + {"SELECT 1 > (select 1)", true}, + {"SELECT 1 > ANY (select 1)", true}, + {"SELECT 1 > ALL (select 1)", true}, + {"SELECT 1 > SOME (select 1)", true}, } for _, t := range table { diff --git a/parser/scanner.l b/parser/scanner.l index 68a3e5c1e4..557700e400 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -229,6 +229,7 @@ after {a}{f}{t}{e}{r} all {a}{l}{l} alter {a}{l}{t}{e}{r} and {a}{n}{d} +any {a}{n}{y} as {a}{s} asc {a}{s}{c} auto_increment {a}{u}{t}{o}_{i}{n}{c}{r}{e}{m}{e}{n}{t} @@ -319,6 +320,7 @@ session {s}{e}{s}{s}{i}{o}{n} set {s}{e}{t} share {s}{h}{a}{r}{e} show {s}{h}{o}{w} +some {s}{o}{m}{e} start {s}{t}{a}{r}{t} substring {s}{u}{b}{s}{t}{r}{i}{n}{g} table {t}{a}{b}{l}{e} @@ -459,6 +461,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} {all} return all {alter} return alter {and} return and +{any} lval.item = string(l.val) + return any {asc} return asc {as} return as {auto_increment} lval.item = string(l.val) @@ -565,6 +569,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} {schemas} return schemas {session} lval.item = string(l.val) return session +{some} lval.item = string(l.val) + return some {start} lval.item = string(l.val) return start {global} lval.item = string(l.val)