Files
tidb/expression/cmp_subquery.go
2015-09-23 10:18:56 +08:00

176 lines
4.3 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"fmt"
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
// CompareSubQuery is the expression for "expr cmp (select ...)".
// On evaluation, the left expression L is compared with Op against the values
// produced by the sub query R. The per-row results are combined with ANY
// semantics or ALL semantics, as selected by the All flag.
// See: https://dev.mysql.com/doc/refman/5.7/en/comparisons-using-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/any-in-some-subqueries.html
// See: https://dev.mysql.com/doc/refman/5.7/en/all-subqueries.html
type CompareSubQuery struct {
	// L is the left expression.
	L Expression
	// Op is the comparison opcode.
	Op opcode.Op
	// R is the sub query for right expression.
	R SubQuery
	// All is true for "expr cmp ALL (select ...)": the comparison must hold for
	// every record in the sub query. When false, ANY semantics apply.
	All bool
}
// Clone implements the Expression Clone interface.
// The returned CompareSubQuery owns independent clones of the left expression
// and the sub query; Op and All are copied by value.
func (cs *CompareSubQuery) Clone() Expression {
	dup := &CompareSubQuery{Op: cs.Op, All: cs.All}
	dup.L = cs.L.Clone()
	dup.R = cs.R.Clone().(SubQuery)
	return dup
}
// IsStatic implements the Expression IsStatic interface.
// The comparison is static only when the left expression is static and the
// sub query is static as well; the sub query is not consulted if the left
// side is already non-static.
func (cs *CompareSubQuery) IsStatic() bool {
	if !cs.L.IsStatic() {
		return false
	}
	return cs.R.IsStatic()
}
// String implements the Expression String interface.
// The text has the shape "<L> <Op> ANY <R>" or "<L> <Op> ALL <R>".
func (cs *CompareSubQuery) String() string {
	var quantifier string
	switch {
	case cs.All:
		quantifier = "ALL"
	default:
		quantifier = "ANY"
	}
	return fmt.Sprintf("%s %s %s %s", cs.L, cs.Op, quantifier, cs.R)
}
// Eval implements the Expression Eval interface.
//
// It first checks that L and R have matching column counts. If the left
// operand evaluates to nil, the result is nil (SQL NULL) and the sub query is
// not evaluated. Otherwise the sub query rows are obtained, either from the
// value stored on R or by evaluating R, and then compared with the left value
// using ANY or ALL semantics.
func (cs *CompareSubQuery) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
	err := hasSameColumnCount(ctx, cs.L, cs.R)
	if err != nil {
		return nil, errors.Trace(err)
	}

	lv, err := cs.L.Eval(ctx, args)
	switch {
	case err != nil:
		return nil, errors.Trace(err)
	case lv == nil:
		return nil, nil
	}

	// A stored value is reused only when R does not use the outer query.
	// Presumably such a sub query yields the same rows on every evaluation.
	cached := !cs.R.UseOuterQuery() && cs.R.Value() != nil
	if !cached {
		rows, evalErr := cs.R.EvalRows(ctx, args, -1)
		if evalErr != nil {
			return nil, errors.Trace(evalErr)
		}
		cs.R.SetValue(rows)
	}
	return cs.checkResult(lv, cs.R.Value().([]interface{}))
}
// Accept implements Expression Accept interface.
// It hands this node to the visitor's VisitCompareSubQuery method and returns
// whatever expression and error the visitor produces.
func (cs *CompareSubQuery) Accept(v Visitor) (Expression, error) {
	visited, err := v.VisitCompareSubQuery(cs)
	return visited, err
}
// checkAllResult evaluates "lv Op ALL (result)" with SQL three-valued logic.
// It returns false as soon as one non-nil value fails the comparison. It
// returns nil (NULL) when no comparison failed but result contained a nil.
// Otherwise, including when result is empty, it returns true. Comparison
// errors are returned traced.
func (cs *CompareSubQuery) checkAllResult(lv interface{}, result []interface{}) (interface{}, error) {
	sawNull := false
	for i := range result {
		rv := result[i]
		if rv == nil {
			sawNull = true
			continue
		}
		ord, err := types.Compare(lv, rv)
		if err != nil {
			return nil, errors.Trace(err)
		}
		ok, err := getCompResult(cs.Op, ord)
		if err != nil {
			return nil, errors.Trace(err)
		}
		if !ok {
			return false, nil
		}
	}
	if !sawNull {
		return true, nil
	}
	// Every known value satisfied the comparison, but a NULL row is unknown.
	// E.g. with rows (1),(2),(NULL), `select 3 > all (select c from t)` is NULL.
	return nil, nil
}
// checkAnyResult evaluates "lv Op ANY (result)" with SQL three-valued logic.
// It returns true as soon as one non-nil value satisfies the comparison. It
// returns nil (NULL) when nothing matched but result contained a nil.
// Otherwise, including when result is empty, it returns false. Comparison
// errors are returned traced.
func (cs *CompareSubQuery) checkAnyResult(lv interface{}, result []interface{}) (interface{}, error) {
	sawNull := false
	for _, rv := range result {
		if rv != nil {
			ord, err := types.Compare(lv, rv)
			if err != nil {
				return nil, errors.Trace(err)
			}
			matched, err := getCompResult(cs.Op, ord)
			if err != nil {
				return nil, errors.Trace(err)
			}
			if matched {
				return true, nil
			}
			continue
		}
		sawNull = true
	}
	if !sawNull {
		return false, nil
	}
	// Nothing matched, but a NULL row might have matched.
	// E.g. with rows (1),(2),(NULL), `select 0 > any (select c from t)` is NULL.
	return nil, nil
}
// checkResult combines the comparison of lv against every value in result.
// It delegates to checkAllResult when cs.All is set and to checkAnyResult
// otherwise.
func (cs *CompareSubQuery) checkResult(lv interface{}, result []interface{}) (interface{}, error) {
	check := cs.checkAnyResult
	if cs.All {
		check = cs.checkAllResult
	}
	return check(lv, result)
}
// NewCompareSubQuery creates a CompareSubQuery object.
// op is the comparison operator, lhs is the left operand, and rhs is the sub
// query on the right. all selects ALL semantics when true and ANY semantics
// when false.
func NewCompareSubQuery(op opcode.Op, lhs Expression, rhs SubQuery, all bool) *CompareSubQuery {
	cs := new(CompareSubQuery)
	cs.L = lhs
	cs.Op = op
	cs.R = rhs
	cs.All = all
	return cs
}