176 lines
4.3 KiB
Go
176 lines
4.3 KiB
Go
// Copyright 2015 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package expression
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/juju/errors"
|
|
"github.com/pingcap/tidb/context"
|
|
|
|
"github.com/pingcap/tidb/parser/opcode"
|
|
"github.com/pingcap/tidb/util/types"
|
|
)
|
|
|
|
// CompareSubQuery is the expression for "expr cmp (select ...)".
|
|
// See: https://dev.mysql.com/doc/refman/5.7/en/comparisons-using-subqueries.html
|
|
// See: https://dev.mysql.com/doc/refman/5.7/en/any-in-some-subqueries.html
|
|
// See: https://dev.mysql.com/doc/refman/5.7/en/all-subqueries.html
|
|
type CompareSubQuery struct {
|
|
// L is the left expression
|
|
L Expression
|
|
// Op is the comparison opcode.
|
|
Op opcode.Op
|
|
// R is the sub query for right expression.
|
|
R SubQuery
|
|
// All is true, we should compare all records in subquery.
|
|
All bool
|
|
}
|
|
|
|
// Clone implements the Expression Clone interface.
|
|
func (cs *CompareSubQuery) Clone() Expression {
|
|
l := cs.L.Clone()
|
|
r := cs.R.Clone()
|
|
return &CompareSubQuery{L: l, Op: cs.Op, R: r.(SubQuery), All: cs.All}
|
|
}
|
|
|
|
// IsStatic implements the Expression IsStatic interface.
|
|
func (cs *CompareSubQuery) IsStatic() bool {
|
|
return cs.L.IsStatic() && cs.R.IsStatic()
|
|
}
|
|
|
|
// String implements the Expression String interface.
|
|
func (cs *CompareSubQuery) String() string {
|
|
anyOrAll := "ANY"
|
|
if cs.All {
|
|
anyOrAll = "ALL"
|
|
}
|
|
|
|
return fmt.Sprintf("%s %s %s %s", cs.L, cs.Op, anyOrAll, cs.R)
|
|
}
|
|
|
|
// Eval implements the Expression Eval interface.
|
|
func (cs *CompareSubQuery) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
|
|
if err := hasSameColumnCount(ctx, cs.L, cs.R); err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
lv, err := cs.L.Eval(ctx, args)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
if lv == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
if !cs.R.UseOuterQuery() && cs.R.Value() != nil {
|
|
return cs.checkResult(lv, cs.R.Value().([]interface{}))
|
|
}
|
|
|
|
res, err := cs.R.EvalRows(ctx, args, -1)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
cs.R.SetValue(res)
|
|
return cs.checkResult(lv, cs.R.Value().([]interface{}))
|
|
}
|
|
|
|
// Accept implements Expression Accept interface.
|
|
func (cs *CompareSubQuery) Accept(v Visitor) (Expression, error) {
|
|
return v.VisitCompareSubQuery(cs)
|
|
}
|
|
|
|
func (cs *CompareSubQuery) checkAllResult(lv interface{}, result []interface{}) (interface{}, error) {
|
|
hasNull := false
|
|
for _, v := range result {
|
|
if v == nil {
|
|
hasNull = true
|
|
continue
|
|
}
|
|
|
|
comRes, err := types.Compare(lv, v)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
res, err := getCompResult(cs.Op, comRes)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
if !res {
|
|
return false, nil
|
|
}
|
|
}
|
|
|
|
if hasNull {
|
|
// If no matched but we get null, return null.
|
|
// Like `insert t (c) values (1),(2),(null)`, then
|
|
// `select 3 > all (select c from t)`, returns null.
|
|
return nil, nil
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (cs *CompareSubQuery) checkAnyResult(lv interface{}, result []interface{}) (interface{}, error) {
|
|
hasNull := false
|
|
for _, v := range result {
|
|
if v == nil {
|
|
hasNull = true
|
|
continue
|
|
}
|
|
|
|
comRes, err := types.Compare(lv, v)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
res, err := getCompResult(cs.Op, comRes)
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
if res {
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
if hasNull {
|
|
// If no matched but we get null, return null.
|
|
// Like `insert t (c) values (1),(2),(null)`, then
|
|
// `select 0 > any (select c from t)`, returns null.
|
|
return nil, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func (cs *CompareSubQuery) checkResult(lv interface{}, result []interface{}) (interface{}, error) {
|
|
if cs.All {
|
|
return cs.checkAllResult(lv, result)
|
|
}
|
|
|
|
return cs.checkAnyResult(lv, result)
|
|
}
|
|
|
|
// NewCompareSubQuery creates a CompareSubQuery object.
|
|
func NewCompareSubQuery(op opcode.Op, lhs Expression, rhs SubQuery, all bool) *CompareSubQuery {
|
|
return &CompareSubQuery{
|
|
Op: op,
|
|
L: lhs,
|
|
R: rhs,
|
|
All: all,
|
|
}
|
|
}
|