Merge pull request #125 from pingcap/qiuyesuifeng/any-all-support

Support any/some/all subquery.
This commit is contained in:
qiuyesuifeng
2015-09-14 10:33:06 +08:00
7 changed files with 446 additions and 15 deletions

View File

@ -364,6 +364,25 @@ func (o *BinaryOperation) evalLogicOp(ctx context.Context, args map[interface{}]
}
}
func getCompResult(op opcode.Op, value int) (bool, error) {
switch op {
case opcode.LT:
return value < 0, nil
case opcode.LE:
return value <= 0, nil
case opcode.GE:
return value >= 0, nil
case opcode.GT:
return value > 0, nil
case opcode.EQ:
return value == 0, nil
case opcode.NE:
return value != 0, nil
default:
return false, errors.Errorf("invalid op %v in comparision operation", op)
}
}
// operator: >=, >, <=, <, !=, <>, = <=>, etc.
// see https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html
func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
@ -383,22 +402,12 @@ func (o *BinaryOperation) evalComparisonOp(ctx context.Context, args map[interfa
return nil, o.traceErr(err)
}
switch o.Op {
case opcode.LT:
return n < 0, nil
case opcode.LE:
return n <= 0, nil
case opcode.GE:
return n >= 0, nil
case opcode.GT:
return n > 0, nil
case opcode.EQ:
return n == 0, nil
case opcode.NE:
return n != 0, nil
default:
return nil, o.errorf("invalid op %v in comparision operation", o.Op)
r, err := getCompResult(o.Op, n)
if err != nil {
return nil, o.errorf(err.Error())
}
return r, nil
}
func (o *BinaryOperation) evalPlus(a interface{}, b interface{}) (interface{}, error) {

View File

@ -0,0 +1,191 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expressions
import (
"fmt"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
// CompareSubQuery is the expression for "expr cmp (select ...)".
// See: https://dev.mysql.com/doc/refman/5.7/en/comparisons-using-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/any-in-some-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/all-subqueries.html
type CompareSubQuery struct {
// L is the left expression
L expression.Expression
// Op is the comparison opcode.
Op opcode.Op
// R is the sub query for right expression.
R *SubQuery
// All is true, we should compare all records in subquery.
All bool
}
// Clone implements the Expression Clone interface.
func (cs *CompareSubQuery) Clone() (expression.Expression, error) {
l, err := cs.L.Clone()
if err != nil {
return nil, errors.Trace(err)
}
r, err := cs.R.Clone()
if err != nil {
return nil, errors.Trace(err)
}
return &CompareSubQuery{L: l, Op: cs.Op, R: r.(*SubQuery), All: cs.All}, nil
}
// IsStatic implements the Expression IsStatic interface.
func (cs *CompareSubQuery) IsStatic() bool {
return cs.L.IsStatic() && cs.R.IsStatic()
}
// String implements the Expression String interface.
func (cs *CompareSubQuery) String() string {
anyOrAll := "ANY"
if cs.All {
anyOrAll = "ALL"
}
return fmt.Sprintf("%s %s %s %s", cs.L, cs.Op, anyOrAll, cs.R)
}
// Eval implements the Expression Eval interface.
func (cs *CompareSubQuery) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
if err := hasSameColumnCount(ctx, cs.L, cs.R); err != nil {
return nil, errors.Trace(err)
}
lv, err := cs.L.Eval(ctx, args)
if err != nil {
return nil, errors.Trace(err)
}
if lv == nil {
return nil, nil
}
if cs.R.Value != nil {
return cs.checkResult(lv, cs.R.Value.([]interface{}))
}
p, err := cs.R.Plan(ctx)
if err != nil {
return nil, errors.Trace(err)
}
res := []interface{}{}
err = p.Do(ctx, func(id interface{}, data []interface{}) (bool, error) {
if len(data) == 1 {
res = append(res, data[0])
} else {
res = append(res, data)
}
return true, nil
})
if err != nil {
return nil, errors.Trace(err)
}
cs.R.Value = res
return cs.checkResult(lv, cs.R.Value.([]interface{}))
}
func (cs *CompareSubQuery) checkAllResult(lv interface{}, result []interface{}) (interface{}, error) {
hasNull := false
for _, v := range result {
if v == nil {
hasNull = true
continue
}
comRes, err := types.Compare(lv, v)
if err != nil {
return nil, errors.Trace(err)
}
res, err := getCompResult(cs.Op, comRes)
if err != nil {
return nil, errors.Trace(err)
}
if !res {
return false, nil
}
}
if hasNull {
// If no matched but we get null, return null.
// Like `insert t (c) values (1),(2),(null)`, then
// `select 3 > all (select c from t)`, returns null.
return nil, nil
}
return true, nil
}
func (cs *CompareSubQuery) checkAnyResult(lv interface{}, result []interface{}) (interface{}, error) {
hasNull := false
for _, v := range result {
if v == nil {
hasNull = true
continue
}
comRes, err := types.Compare(lv, v)
if err != nil {
return nil, errors.Trace(err)
}
res, err := getCompResult(cs.Op, comRes)
if err != nil {
return nil, errors.Trace(err)
}
if res {
return true, nil
}
}
if hasNull {
// If no matched but we get null, return null.
// Like `insert t (c) values (1),(2),(null)`, then
// `select 0 > any (select c from t)`, returns null.
return nil, nil
}
return false, nil
}
func (cs *CompareSubQuery) checkResult(lv interface{}, result []interface{}) (interface{}, error) {
if cs.All {
return cs.checkAllResult(lv, result)
}
return cs.checkAnyResult(lv, result)
}
// NewCompareSubQuery creates a CompareSubQuery object.
func NewCompareSubQuery(op opcode.Op, lhs expression.Expression, rhs *SubQuery, all bool) *CompareSubQuery {
return &CompareSubQuery{
Op: op,
L: lhs,
R: rhs,
All: all,
}
}

View File

@ -0,0 +1,170 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expressions
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
var _ = Suite(&testCompSubQuerySuite{})
type testCompSubQuerySuite struct {
}
func (s *testCompSubQuerySuite) convert(v interface{}) interface{} {
switch x := v.(type) {
case nil:
return nil
case int:
return int64(x)
}
return v
}
func (s *testCompSubQuerySuite) TestCompSubQuery(c *C) {
tbl := []struct {
lhs interface{}
op opcode.Op
rhs []interface{}
all bool
result interface{} // 0 for false, 1 for true, nil for nil.
}{
// Test any subquery.
{nil, opcode.EQ, []interface{}{1, 2}, false, nil},
{0, opcode.EQ, []interface{}{1, 2}, false, 0},
{0, opcode.EQ, []interface{}{1, 2, nil}, false, nil},
{1, opcode.EQ, []interface{}{1, 1}, false, 1},
{1, opcode.EQ, []interface{}{1, 1, nil}, false, 1},
{nil, opcode.NE, []interface{}{1, 2}, false, nil},
{1, opcode.NE, []interface{}{1, 2}, false, 1},
{1, opcode.NE, []interface{}{1, 2, nil}, false, 1},
{1, opcode.NE, []interface{}{1, 1}, false, 0},
{1, opcode.NE, []interface{}{1, 1, nil}, false, nil},
{nil, opcode.GT, []interface{}{1, 2}, false, nil},
{1, opcode.GT, []interface{}{1, 2}, false, 0},
{1, opcode.GT, []interface{}{1, 2, nil}, false, nil},
{2, opcode.GT, []interface{}{1, 2}, false, 1},
{2, opcode.GT, []interface{}{1, 2, nil}, false, 1},
{3, opcode.GT, []interface{}{1, 2}, false, 1},
{3, opcode.GT, []interface{}{1, 2, nil}, false, 1},
{nil, opcode.GE, []interface{}{1, 2}, false, nil},
{0, opcode.GE, []interface{}{1, 2}, false, 0},
{0, opcode.GE, []interface{}{1, 2, nil}, false, nil},
{1, opcode.GE, []interface{}{1, 2}, false, 1},
{1, opcode.GE, []interface{}{1, 2, nil}, false, 1},
{2, opcode.GE, []interface{}{1, 2}, false, 1},
{3, opcode.GE, []interface{}{1, 2}, false, 1},
{nil, opcode.LT, []interface{}{1, 2}, false, nil},
{0, opcode.LT, []interface{}{1, 2}, false, 1},
{0, opcode.LT, []interface{}{1, 2, nil}, false, 1},
{1, opcode.LT, []interface{}{1, 2}, false, 1},
{2, opcode.LT, []interface{}{1, 2}, false, 0},
{2, opcode.LT, []interface{}{1, 2, nil}, false, nil},
{3, opcode.LT, []interface{}{1, 2}, false, 0},
{nil, opcode.LE, []interface{}{1, 2}, false, nil},
{0, opcode.LE, []interface{}{1, 2}, false, 1},
{0, opcode.LE, []interface{}{1, 2, nil}, false, 1},
{1, opcode.LE, []interface{}{1, 2}, false, 1},
{2, opcode.LE, []interface{}{1, 2}, false, 1},
{3, opcode.LE, []interface{}{1, 2}, false, 0},
{3, opcode.LE, []interface{}{1, 2, nil}, false, nil},
// Test all subquery.
{nil, opcode.EQ, []interface{}{1, 2}, true, nil},
{0, opcode.EQ, []interface{}{1, 2}, true, 0},
{0, opcode.EQ, []interface{}{1, 2, nil}, true, 0},
{1, opcode.EQ, []interface{}{1, 2}, true, 0},
{1, opcode.EQ, []interface{}{1, 2, nil}, true, 0},
{1, opcode.EQ, []interface{}{1, 1}, true, 1},
{1, opcode.EQ, []interface{}{1, 1, nil}, true, nil},
{nil, opcode.NE, []interface{}{1, 2}, true, nil},
{0, opcode.NE, []interface{}{1, 2}, true, 1},
{1, opcode.NE, []interface{}{1, 2, nil}, true, 0},
{1, opcode.NE, []interface{}{1, 1}, true, 0},
{1, opcode.NE, []interface{}{1, 1, nil}, true, 0},
{nil, opcode.GT, []interface{}{1, 2}, true, nil},
{1, opcode.GT, []interface{}{1, 2}, true, 0},
{1, opcode.GT, []interface{}{1, 2, nil}, true, 0},
{2, opcode.GT, []interface{}{1, 2}, true, 0},
{2, opcode.GT, []interface{}{1, 2, nil}, true, 0},
{3, opcode.GT, []interface{}{1, 2}, true, 1},
{3, opcode.GT, []interface{}{1, 2, nil}, true, nil},
{nil, opcode.GE, []interface{}{1, 2}, true, nil},
{0, opcode.GE, []interface{}{1, 2}, true, 0},
{0, opcode.GE, []interface{}{1, 2, nil}, true, 0},
{1, opcode.GE, []interface{}{1, 2}, true, 0},
{1, opcode.GE, []interface{}{1, 2, nil}, true, 0},
{2, opcode.GE, []interface{}{1, 2}, true, 1},
{3, opcode.GE, []interface{}{1, 2}, true, 1},
{3, opcode.GE, []interface{}{1, 2, nil}, true, nil},
{nil, opcode.LT, []interface{}{1, 2}, true, nil},
{0, opcode.LT, []interface{}{1, 2}, true, 1},
{0, opcode.LT, []interface{}{1, 2, nil}, true, nil},
{1, opcode.LT, []interface{}{1, 2}, true, 0},
{2, opcode.LT, []interface{}{1, 2}, true, 0},
{2, opcode.LT, []interface{}{1, 2, nil}, true, 0},
{3, opcode.LT, []interface{}{1, 2}, true, 0},
{nil, opcode.LE, []interface{}{1, 2}, true, nil},
{0, opcode.LE, []interface{}{1, 2}, true, 1},
{0, opcode.LE, []interface{}{1, 2, nil}, true, nil},
{1, opcode.LE, []interface{}{1, 2}, true, 1},
{2, opcode.LE, []interface{}{1, 2}, true, 0},
{3, opcode.LE, []interface{}{1, 2}, true, 0},
{3, opcode.LE, []interface{}{1, 2, nil}, true, 0},
}
for _, t := range tbl {
lhs := s.convert(t.lhs)
rhs := make([][]interface{}, 0, len(t.rhs))
for _, v := range t.rhs {
rhs = append(rhs, []interface{}{s.convert(v)})
}
sq := newMockSubQuery(rhs, []string{"c"})
expr := NewCompareSubQuery(t.op, Value{lhs}, sq, t.all)
c.Assert(expr.IsStatic(), IsFalse)
str := expr.String()
c.Assert(len(str), Greater, 0)
v, err := expr.Eval(nil, nil)
c.Assert(err, IsNil)
switch x := t.result.(type) {
case nil:
c.Assert(v, IsNil)
case int:
val, err := types.ToBool(v)
c.Assert(err, IsNil)
c.Assert(val, Equals, int64(x))
}
}
// Test error.
sq := newMockSubQuery([][]interface{}{{1, 2}}, []string{"c1", "c2"})
expr := NewCompareSubQuery(opcode.EQ, Value{1}, sq, true)
_, err := expr.Eval(nil, nil)
c.Assert(err, NotNil)
expr = NewCompareSubQuery(opcode.EQ, Value{1}, sq, false)
_, err = expr.Eval(nil, nil)
c.Assert(err, NotNil)
}

View File

@ -226,6 +226,8 @@ func mentionedAggregateFuncs(e expression.Expression, m *[]expression.Expression
for _, expr := range x.Values {
mentionedAggregateFuncs(expr, m)
}
case *CompareSubQuery:
mentionedAggregateFuncs(x.L, m)
default:
log.Errorf("Unknown Expression: %T", e)
}
@ -314,6 +316,8 @@ func mentionedColumns(e expression.Expression, m map[string]bool, names *[]strin
for _, expr := range x.Values {
mentionedColumns(expr, m, names)
}
case *CompareSubQuery:
mentionedColumns(x.L, m, names)
default:
log.Errorf("Unknown Expression: %T", e)
}

View File

@ -70,6 +70,7 @@ import (
and "AND"
andand "&&"
andnot "&^"
any "ANY"
as "AS"
asc "ASC"
autoIncrement "AUTO_INCREMENT"
@ -175,6 +176,7 @@ import (
share "SHARE"
show "SHOW"
signed "SIGNED"
some "SOME"
start "START"
stringType "string"
substring "SUBSTRING"
@ -266,6 +268,7 @@ import (
AlterTableStmt "Alter table statement"
AlterSpecification "Alter table specification"
AlterSpecificationList "Alter table specification list"
AnyOrAll "Any or All for subquery"
AsOpt "as optional"
Assignment "assignment"
AssignmentList "assignment list"
@ -1387,8 +1390,50 @@ Factor:
{
$$ = expressions.NewBinaryOperation(opcode.EQ, $1.(expression.Expression), $3.(expression.Expression))
}
| Factor ">=" AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.GE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor '>' AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.GT, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor "<=" AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.LE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor '<' AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.LT, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor "!=" AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.NE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor "<>" AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.NE, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor "=" AnyOrAll SubSelect %prec eq
{
$$ = expressions.NewCompareSubQuery(opcode.EQ, $1.(expression.Expression), $4.(*expressions.SubQuery), $3.(bool))
}
| Factor1
AnyOrAll:
"ANY"
{
$$ = false
}
| "SOME"
{
$$ = false
}
| "ALL"
{
$$ = true
}
Factor1:
PrimaryFactor NotOpt "IN" '(' ExpressionList ')'
{

View File

@ -285,6 +285,12 @@ func (s *testParserSuite) TestParser0(c *C) {
{"SHOW SESSION VARIABLES LIKE 'character_set_results'", true},
{"SHOW VARIABLES", true},
{"SHOW GLOBAL VARIABLES", true},
// For compare subquery
{"SELECT 1 > (select 1)", true},
{"SELECT 1 > ANY (select 1)", true},
{"SELECT 1 > ALL (select 1)", true},
{"SELECT 1 > SOME (select 1)", true},
}
for _, t := range table {

View File

@ -229,6 +229,7 @@ after {a}{f}{t}{e}{r}
all {a}{l}{l}
alter {a}{l}{t}{e}{r}
and {a}{n}{d}
any {a}{n}{y}
as {a}{s}
asc {a}{s}{c}
auto_increment {a}{u}{t}{o}_{i}{n}{c}{r}{e}{m}{e}{n}{t}
@ -319,6 +320,7 @@ session {s}{e}{s}{s}{i}{o}{n}
set {s}{e}{t}
share {s}{h}{a}{r}{e}
show {s}{h}{o}{w}
some {s}{o}{m}{e}
start {s}{t}{a}{r}{t}
substring {s}{u}{b}{s}{t}{r}{i}{n}{g}
table {t}{a}{b}{l}{e}
@ -459,6 +461,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident}
{all} return all
{alter} return alter
{and} return and
{any} lval.item = string(l.val)
return any
{asc} return asc
{as} return as
{auto_increment} lval.item = string(l.val)
@ -565,6 +569,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident}
{schemas} return schemas
{session} lval.item = string(l.val)
return session
{some} lval.item = string(l.val)
return some
{start} lval.item = string(l.val)
return start
{global} lval.item = string(l.val)