Files
tidb/expression/builtin_control.go
2017-01-25 20:27:54 +08:00

184 lines
4.5 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/util/types"
)
var (
_ functionClass = &caseWhenFunctionClass{}
_ functionClass = &ifFunctionClass{}
_ functionClass = &ifNullFunctionClass{}
_ functionClass = &nullIfFunctionClass{}
)
var (
_ builtinFunc = &builtinCaseWhenSig{}
_ builtinFunc = &builtinIfSig{}
_ builtinFunc = &builtinIfNullSig{}
_ builtinFunc = &builtinNullIfSig{}
)
type caseWhenFunctionClass struct {
baseFunctionClass
}
func (c *caseWhenFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
return &builtinCaseWhenSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
}
type builtinCaseWhenSig struct {
baseBuiltinFunc
}
// See https://dev.mysql.com/doc/refman/5.7/en/case.html
func (b *builtinCaseWhenSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
sc := b.ctx.GetSessionVars().StmtCtx
l := len(args)
for i := 0; i < l-1; i += 2 {
if args[i].IsNull() {
continue
}
b, err1 := args[i].ToBool(sc)
if err1 != nil {
return d, errors.Trace(err1)
}
if b == 1 {
d = args[i+1]
return
}
}
// when clause(condition, result) -> args[i], args[i+1]; (i >= 0 && i+1 < l-1)
// else clause -> args[l-1]
// If case clause has else clause, l%2 == 1.
if l%2 == 1 {
d = args[l-1]
}
return
}
type ifFunctionClass struct {
baseFunctionClass
}
func (c *ifFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
return &builtinIfSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
}
type builtinIfSig struct {
baseBuiltinFunc
}
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_if
func (s *builtinIfSig) eval(row []types.Datum) (types.Datum, error) {
args, err := s.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
// if(expr1, expr2, expr3)
// if expr1 is true, return expr2, otherwise, return expr3
v1 := args[0]
v2 := args[1]
v3 := args[2]
if v1.IsNull() {
return v3, nil
}
b, err := v1.ToBool(s.ctx.GetSessionVars().StmtCtx)
if err != nil {
d := types.Datum{}
return d, errors.Trace(err)
}
// TODO: check return type, must be numeric or string
if b == 1 {
return v2, nil
}
return v3, nil
}
type ifNullFunctionClass struct {
baseFunctionClass
}
func (c *ifNullFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
return &builtinIfNullSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
}
type builtinIfNullSig struct {
baseBuiltinFunc
}
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_ifnull
func (b *builtinIfNullSig) eval(row []types.Datum) (types.Datum, error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
// ifnull(expr1, expr2)
// if expr1 is not null, return expr1, otherwise, return expr2
v1 := args[0]
v2 := args[1]
if !v1.IsNull() {
return v1, nil
}
return v2, nil
}
type nullIfFunctionClass struct {
baseFunctionClass
}
func (c *nullIfFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
return &builtinNullIfSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
}
type builtinNullIfSig struct {
baseBuiltinFunc
}
// See https://dev.mysql.com/doc/refman/5.7/en/control-flow-functions.html#function_nullif
func (b *builtinNullIfSig) eval(row []types.Datum) (types.Datum, error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
}
// nullif(expr1, expr2)
// returns null if expr1 = expr2 is true, otherwise returns expr1
v1 := args[0]
v2 := args[1]
if v1.IsNull() || v2.IsNull() {
return v1, nil
}
if n, err1 := v1.CompareDatum(b.ctx.GetSessionVars().StmtCtx, v2); err1 != nil || n == 0 {
d := types.Datum{}
return d, errors.Trace(err1)
}
return v1, nil
}