Files
tidb/expression/expressions/cmp_subquery.go
2015-09-14 10:27:57 +08:00

192 lines
4.5 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expressions
import (
"fmt"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
// CompareSubQuery is the expression for "expr cmp (select ...)".
// The comparison is quantified over the rows of the subquery: ALL when the
// All field is true, ANY otherwise.
// See: https://dev.mysql.com/doc/refman/5.7/en/comparisons-using-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/any-in-some-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/all-subqueries.html
type CompareSubQuery struct {
	// L is the left expression.
	L expression.Expression
	// Op is the comparison opcode.
	Op opcode.Op
	// R is the sub query for the right expression.
	R *SubQuery
	// All selects the quantifier. When true, the comparison must hold for
	// every record of the subquery (ALL); when false, a single matching
	// record is enough (ANY).
	All bool
}
// Clone implements the Expression Clone interface.
//
// Both operands are cloned through their own Clone methods, while Op and All
// are copied by value. The first clone error is returned wrapped with
// errors.Trace.
func (cs *CompareSubQuery) Clone() (expression.Expression, error) {
	lhs, err := cs.L.Clone()
	if err != nil {
		return nil, errors.Trace(err)
	}

	rhs, err := cs.R.Clone()
	if err != nil {
		return nil, errors.Trace(err)
	}

	cloned := &CompareSubQuery{
		L:  lhs,
		Op: cs.Op,
		// The cloned right side is asserted back to *SubQuery; any other
		// concrete type makes this assertion panic.
		R:   rhs.(*SubQuery),
		All: cs.All,
	}
	return cloned, nil
}
// IsStatic implements the Expression IsStatic interface.
// The expression is static only when both the left expression and the
// subquery report themselves as static.
func (cs *CompareSubQuery) IsStatic() bool {
	// The subquery is only consulted when the left side is static.
	if !cs.L.IsStatic() {
		return false
	}
	return cs.R.IsStatic()
}
// String implements the Expression String interface.
// The output has the form "<L> <Op> ANY <R>" or "<L> <Op> ALL <R>",
// depending on the All field.
func (cs *CompareSubQuery) String() string {
	var quantifier string
	if cs.All {
		quantifier = "ALL"
	} else {
		quantifier = "ANY"
	}
	return fmt.Sprintf("%s %s %s %s", cs.L, cs.Op, quantifier, cs.R)
}
// Eval implements the Expression Eval interface.
//
// It first checks that L and R have the same column count, then evaluates L.
// A nil (NULL) left value yields a nil result without touching the subquery.
// Otherwise the rows of the subquery are compared against the left value with
// ANY or ALL semantics, see checkResult.
//
// The rows produced by the subquery plan are stored in cs.R.Value and reused
// by later calls for as long as cs.R.Value is non-nil.
//
// NOTE(review): because the nil check on the left value comes first, a NULL
// left operand also yields NULL when the subquery is empty; confirm this is
// the intended MySQL-compatible behavior.
func (cs *CompareSubQuery) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
	err := hasSameColumnCount(ctx, cs.L, cs.R)
	if err != nil {
		return nil, errors.Trace(err)
	}

	lhs, err := cs.L.Eval(ctx, args)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if lhs == nil {
		return nil, nil
	}

	if cs.R.Value == nil {
		// Nothing cached yet: run the subquery plan and collect its rows.
		plan, planErr := cs.R.Plan(ctx)
		if planErr != nil {
			return nil, errors.Trace(planErr)
		}

		// Start from a non-nil empty slice so that an empty result set is
		// still stored as a non-nil value.
		rows := []interface{}{}
		collect := func(_ interface{}, data []interface{}) (bool, error) {
			// A single-column row is stored as its bare value; a
			// multi-column row is stored as the whole slice.
			var row interface{} = data
			if len(data) == 1 {
				row = data[0]
			}
			rows = append(rows, row)
			return true, nil
		}
		if doErr := plan.Do(ctx, collect); doErr != nil {
			return nil, errors.Trace(doErr)
		}
		cs.R.Value = rows
	}

	return cs.checkResult(lhs, cs.R.Value.([]interface{}))
}
// checkAllResult evaluates "lv Op ALL (result)".
//
// It returns false as soon as one non-nil value fails the comparison. If
// every non-nil value passes, it returns nil (NULL) when result contained at
// least one nil value and true otherwise. An empty result therefore gives
// true. Errors from types.Compare or getCompResult are returned traced.
func (cs *CompareSubQuery) checkAllResult(lv interface{}, result []interface{}) (interface{}, error) {
	sawNull := false
	for i := range result {
		item := result[i]
		if item == nil {
			// NULL values never decide the outcome on their own; remember
			// them and keep scanning.
			sawNull = true
			continue
		}

		order, err := types.Compare(lv, item)
		if err != nil {
			return nil, errors.Trace(err)
		}
		matched, err := getCompResult(cs.Op, order)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if !matched {
			return false, nil
		}
	}

	if !sawNull {
		return true, nil
	}
	// Every non-NULL value matched, but a NULL was seen, so the answer is
	// unknown. Like `insert t (c) values (1),(2),(null)`, then
	// `select 3 > all (select c from t)`, returns null.
	return nil, nil
}
// checkAnyResult evaluates "lv Op ANY (result)".
//
// It returns true as soon as one non-nil value satisfies the comparison. If
// none does, it returns nil (NULL) when result contained at least one nil
// value and false otherwise. An empty result therefore gives false. Errors
// from types.Compare or getCompResult are returned traced.
func (cs *CompareSubQuery) checkAnyResult(lv interface{}, result []interface{}) (interface{}, error) {
	sawNull := false
	for i := range result {
		item := result[i]
		if item == nil {
			// NULL values never decide the outcome on their own; remember
			// them and keep scanning.
			sawNull = true
			continue
		}

		order, err := types.Compare(lv, item)
		if err != nil {
			return nil, errors.Trace(err)
		}
		matched, err := getCompResult(cs.Op, order)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if matched {
			return true, nil
		}
	}

	if !sawNull {
		return false, nil
	}
	// Nothing matched, but a NULL was seen, so the answer is unknown.
	// Like `insert t (c) values (1),(2),(null)`, then
	// `select 0 > any (select c from t)`, returns null.
	return nil, nil
}
// checkResult compares lv against the subquery rows in result, using ALL
// semantics when cs.All is set and ANY semantics otherwise.
func (cs *CompareSubQuery) checkResult(lv interface{}, result []interface{}) (interface{}, error) {
	if !cs.All {
		return cs.checkAnyResult(lv, result)
	}
	return cs.checkAllResult(lv, result)
}
// NewCompareSubQuery creates a CompareSubQuery object.
// op is the comparison opcode, lhs the left expression, rhs the subquery on
// the right side, and all selects ALL (true) or ANY (false) semantics.
func NewCompareSubQuery(op opcode.Op, lhs expression.Expression, rhs *SubQuery, all bool) *CompareSubQuery {
	cs := new(CompareSubQuery)
	cs.L = lhs
	cs.Op = op
	cs.R = rhs
	cs.All = all
	return cs
}