Files
tidb/expression/builtin_op.go
2017-01-28 18:44:02 +08:00

413 lines
9.9 KiB
Go

// Copyright 2016 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
// Compile-time checks that every operator function class declared in this
// file satisfies the functionClass interface.
var (
	_ functionClass = (*andandFunctionClass)(nil)
	_ functionClass = (*ororFunctionClass)(nil)
	_ functionClass = (*logicXorFunctionClass)(nil)
	_ functionClass = (*bitOpFunctionClass)(nil)
	_ functionClass = (*isTrueOpFunctionClass)(nil)
	_ functionClass = (*unaryOpFunctionClass)(nil)
	_ functionClass = (*isNullFunctionClass)(nil)
)
// Compile-time checks that every operator signature declared in this file
// satisfies the builtinFunc interface.
var (
	_ builtinFunc = (*builtinAndAndSig)(nil)
	_ builtinFunc = (*builtinOrOrSig)(nil)
	_ builtinFunc = (*builtinLogicXorSig)(nil)
	_ builtinFunc = (*builtinBitOpSig)(nil)
	_ builtinFunc = (*builtinIsTrueOpSig)(nil)
	_ builtinFunc = (*builtinUnaryOpSig)(nil)
	_ builtinFunc = (*builtinIsNullSig)(nil)
)
// andandFunctionClass is the function class that produces builtinAndAndSig,
// the logical AND signature.
type andandFunctionClass struct {
	baseFunctionClass
}

// getFunction wraps args and ctx in a builtinAndAndSig and returns it together
// with the traced result of verifyArgs. The signature is built before the
// arguments are verified, so it is returned (non-nil) even when verification
// fails; callers must check the error.
func (c *andandFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinAndAndSig{newBaseBuiltinFunc(args, ctx)}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinAndAndSig evaluates the logical AND of its two arguments.
type builtinAndAndSig struct {
	baseBuiltinFunc
}

// eval computes `args[0] AND args[1]` for row.
//
// The left operand is evaluated first; if it is non-NULL and false the result
// is 0 and the right operand is never evaluated. Otherwise the right operand
// is evaluated, and a non-NULL false value there also yields 0. If neither
// operand is false but at least one is NULL, d is returned untouched (its zero
// value, presumably the NULL datum — confirm against types.Datum). Only when
// both operands are non-NULL and true is the result 1.
func (b *builtinAndAndSig) eval(row []types.Datum) (d types.Datum, err error) {
	lhs, err := b.args[0].Eval(row)
	if err != nil {
		return d, errors.Trace(err)
	}
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	if !lhs.IsNull() {
		lhsBool, convErr := lhs.ToBool(stmtCtx)
		if convErr != nil {
			return d, errors.Trace(convErr)
		}
		if lhsBool == 0 {
			// Short circuit: false AND anything is false.
			d.SetInt64(lhsBool)
			return d, nil
		}
	}

	rhs, err := b.args[1].Eval(row)
	if err != nil {
		return d, errors.Trace(err)
	}
	if !rhs.IsNull() {
		rhsBool, convErr := rhs.ToBool(stmtCtx)
		if convErr != nil {
			return d, errors.Trace(convErr)
		}
		if rhsBool == 0 {
			// anything (including NULL) AND false is false.
			d.SetInt64(rhsBool)
			return d, nil
		}
	}

	// No operand is false here, so a NULL on either side leaves d unset.
	if lhs.IsNull() || rhs.IsNull() {
		return d, nil
	}
	d.SetInt64(1)
	return d, nil
}
// ororFunctionClass is the function class that produces builtinOrOrSig, the
// logical OR signature.
type ororFunctionClass struct {
	baseFunctionClass
}

// getFunction wraps args and ctx in a builtinOrOrSig and returns it together
// with the traced result of verifyArgs. The signature is built before the
// arguments are verified, so it is returned (non-nil) even when verification
// fails; callers must check the error.
func (c *ororFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinOrOrSig{newBaseBuiltinFunc(args, ctx)}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinOrOrSig evaluates the logical OR of its two arguments.
type builtinOrOrSig struct {
	baseBuiltinFunc
}

// eval computes `args[0] OR args[1]` for row.
//
// The left operand is evaluated first; if it is non-NULL and converts to 1 the
// result is 1 and the right operand is never evaluated. Otherwise the right
// operand is evaluated, and a non-NULL value converting to 1 there also yields
// 1. If neither operand is true but at least one is NULL, d is returned
// untouched (its zero value, presumably the NULL datum — confirm against
// types.Datum). Only when both operands are non-NULL and false is the result 0.
func (b *builtinOrOrSig) eval(row []types.Datum) (d types.Datum, err error) {
	lhs, err := b.args[0].Eval(row)
	if err != nil {
		return d, errors.Trace(err)
	}
	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	if !lhs.IsNull() {
		lhsBool, convErr := lhs.ToBool(stmtCtx)
		if convErr != nil {
			return d, errors.Trace(convErr)
		}
		if lhsBool == 1 {
			// Short circuit: true OR anything is true.
			d.SetInt64(lhsBool)
			return d, nil
		}
	}

	rhs, err := b.args[1].Eval(row)
	if err != nil {
		return d, errors.Trace(err)
	}
	if !rhs.IsNull() {
		rhsBool, convErr := rhs.ToBool(stmtCtx)
		if convErr != nil {
			return d, errors.Trace(convErr)
		}
		if rhsBool == 1 {
			// anything (including NULL) OR true is true.
			d.SetInt64(rhsBool)
			return d, nil
		}
	}

	// No operand is true here, so a NULL on either side leaves d unset.
	if lhs.IsNull() || rhs.IsNull() {
		return d, nil
	}
	d.SetInt64(0)
	return d, nil
}
// logicXorFunctionClass is the function class that produces
// builtinLogicXorSig, the logical XOR signature.
type logicXorFunctionClass struct {
	baseFunctionClass
}

// getFunction wraps args and ctx in a builtinLogicXorSig and returns it
// together with the traced result of verifyArgs. The signature is built before
// the arguments are verified, so it is returned (non-nil) even when
// verification fails; callers must check the error.
func (c *logicXorFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinLogicXorSig{newBaseBuiltinFunc(args, ctx)}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinLogicXorSig evaluates the logical XOR of its two arguments.
type builtinLogicXorSig struct {
	baseBuiltinFunc
}

// eval computes `args[0] XOR args[1]` for row.
//
// Both arguments are always evaluated (no short circuit). If either one is
// NULL, d is returned untouched (its zero value, presumably the NULL datum —
// confirm against types.Datum). Otherwise both are converted with ToBool and
// the result is oneI64 when the two boolean values differ, zeroI64 when they
// are equal.
func (b *builtinLogicXorSig) eval(row []types.Datum) (d types.Datum, err error) {
	args, err := b.evalArgs(row)
	if err != nil {
		return types.Datum{}, errors.Trace(err)
	}

	lhs, rhs := args[0], args[1]
	if lhs.IsNull() || rhs.IsNull() {
		return d, nil
	}

	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	lhsBool, err := lhs.ToBool(stmtCtx)
	if err != nil {
		return d, errors.Trace(err)
	}
	rhsBool, err := rhs.ToBool(stmtCtx)
	if err != nil {
		return d, errors.Trace(err)
	}

	if lhsBool != rhsBool {
		d.SetInt64(oneI64)
	} else {
		d.SetInt64(zeroI64)
	}
	return d, nil
}
// bitOpFunctionClass is the function class for the binary bit operators. The
// op field records which operator the produced signature will apply.
type bitOpFunctionClass struct {
	baseFunctionClass
	op opcode.Op
}

// getFunction wraps args and ctx in a builtinBitOpSig carrying c.op and
// returns it together with the traced result of verifyArgs. The signature is
// built before the arguments are verified, so it is returned (non-nil) even
// when verification fails; callers must check the error.
func (c *bitOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinBitOpSig{
		baseBuiltinFunc: newBaseBuiltinFunc(args, ctx),
		op:              c.op,
	}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinBitOpSig evaluates a binary bit operator; op selects which one
// (And, Or, Xor, RightShift or LeftShift).
type builtinBitOpSig struct {
	baseBuiltinFunc
	op opcode.Op
}

// eval applies s.op to the two arguments evaluated against row.
//
// The arguments are first passed through types.CoerceDatum; if either coerced
// value is NULL, d is returned untouched (its zero value, presumably the NULL
// datum — confirm against types.Datum). Otherwise both are converted with
// ToInt64, the operation is carried out, and the result is stored as a uint64.
// An op outside the five supported operators yields errInvalidOperation.
func (s *builtinBitOpSig) eval(row []types.Datum) (d types.Datum, err error) {
	args, err := s.evalArgs(row)
	if err != nil {
		return types.Datum{}, errors.Trace(err)
	}

	stmtCtx := s.ctx.GetSessionVars().StmtCtx
	lhs, rhs, err := types.CoerceDatum(stmtCtx, args[0], args[1])
	if err != nil {
		return d, errors.Trace(err)
	}
	if lhs.IsNull() || rhs.IsNull() {
		return d, nil
	}

	x, err := lhs.ToInt64(stmtCtx)
	if err != nil {
		return d, errors.Trace(err)
	}
	y, err := rhs.ToInt64(stmtCtx)
	if err != nil {
		return d, errors.Trace(err)
	}

	// Operands are handled as int64; the result is reported as uint64.
	var result uint64
	switch s.op {
	case opcode.And:
		result = uint64(x & y)
	case opcode.Or:
		result = uint64(x | y)
	case opcode.Xor:
		result = uint64(x ^ y)
	case opcode.RightShift:
		// Both operands are reinterpreted as unsigned before shifting.
		result = uint64(x) >> uint64(y)
	case opcode.LeftShift:
		result = uint64(x) << uint64(y)
	default:
		return d, errInvalidOperation.Gen("invalid op %v in bit operation", s.op)
	}
	d.SetUint64(result)
	return d, nil
}
// isTrueOpFunctionClass is the function class for the truth-test operators.
// The op field records which test the produced signature will perform.
type isTrueOpFunctionClass struct {
	baseFunctionClass
	op opcode.Op
}

// getFunction wraps args and ctx in a builtinIsTrueOpSig carrying c.op and
// returns it together with the traced result of verifyArgs. The signature is
// built before the arguments are verified, so it is returned (non-nil) even
// when verification fails; callers must check the error.
func (c *isTrueOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinIsTrueOpSig{
		baseBuiltinFunc: newBaseBuiltinFunc(args, ctx),
		op:              c.op,
	}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinIsTrueOpSig evaluates a truth test on its argument; op is expected
// to be opcode.IsTruth or opcode.IsFalsity.
type builtinIsTrueOpSig struct {
	baseBuiltinFunc
	op opcode.Op
}

// eval reports whether the first argument passes the truth test chosen by b.op.
//
// The result is never NULL: a NULL argument, or an op other than IsTruth /
// IsFalsity, produces boolToInt64(false). For a non-NULL argument the value is
// converted with ToBool; IsTruth matches a converted value of 1 and IsFalsity
// matches 0.
func (b *builtinIsTrueOpSig) eval(row []types.Datum) (d types.Datum, err error) {
	args, err := b.evalArgs(row)
	if err != nil {
		return types.Datum{}, errors.Trace(err)
	}

	matched := false
	if !args[0].IsNull() {
		truth, convErr := args[0].ToBool(b.ctx.GetSessionVars().StmtCtx)
		if convErr != nil {
			return d, errors.Trace(convErr)
		}
		switch b.op {
		case opcode.IsTruth:
			matched = truth == 1
		case opcode.IsFalsity:
			matched = truth == 0
		}
	}
	d.SetInt64(boolToInt64(matched))
	return d, nil
}
// unaryOpFunctionClass is the function class for the unary operators. The op
// field records which operator the produced signature will apply.
type unaryOpFunctionClass struct {
	baseFunctionClass
	op opcode.Op
}

// getFunction wraps args and ctx in a builtinUnaryOpSig carrying c.op and
// returns it together with the traced result of verifyArgs. The signature is
// built before the arguments are verified, so it is returned (non-nil) even
// when verification fails; callers must check the error.
func (c *unaryOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinUnaryOpSig{
		baseBuiltinFunc: newBaseBuiltinFunc(args, ctx),
		op:              c.op,
	}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinUnaryOpSig evaluates a unary operator; op selects which one
// (Not, BitNeg, Plus or Minus).
type builtinUnaryOpSig struct {
	baseBuiltinFunc
	op opcode.Op
}

// eval applies b.op to the first argument evaluated against row.
//
// A NULL operand leaves d untouched (its zero value, presumably the NULL
// datum — confirm against types.Datum). Otherwise:
//   - Not:    1 when the operand converts to boolean 0, else 0.
//   - BitNeg: bitwise complement of the operand's int64 value, stored as uint64.
//   - Plus:   the operand itself, for the supported kinds.
//   - Minus:  the negated operand; the result kind depends on the operand kind.
//
// An unsupported op or operand kind yields errInvalidOperation. A panic raised
// after the arguments have been evaluated is recovered and reported through
// err. The results are named because the deferred recover handler has to
// overwrite err.
func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) {
	args, err := b.evalArgs(row)
	if err != nil {
		return types.Datum{}, errors.Trace(err)
	}

	// Installed only after evalArgs, so a panic during argument evaluation is
	// not intercepted here.
	defer func() {
		if r := recover(); r != nil {
			err = errors.Errorf("%v", r)
		}
	}()

	operand := args[0]
	if operand.IsNull() {
		return d, nil
	}

	stmtCtx := b.ctx.GetSessionVars().StmtCtx
	switch b.op {
	case opcode.Not:
		truth, convErr := operand.ToBool(stmtCtx)
		switch {
		case convErr != nil:
			// d stays unset; the error is returned at the end of the function.
			err = errors.Trace(convErr)
		case truth == 0:
			d.SetInt64(1)
		default:
			d.SetInt64(0)
		}

	case opcode.BitNeg:
		// The operand is handled as int64; the result is reported as uint64.
		n, convErr := operand.ToInt64(stmtCtx)
		if convErr != nil {
			return d, errors.Trace(convErr)
		}
		d.SetUint64(uint64(^n))

	case opcode.Plus:
		switch operand.Kind() {
		case types.KindInt64, types.KindUint64,
			types.KindFloat64, types.KindFloat32,
			types.KindMysqlDuration, types.KindMysqlTime,
			types.KindString, types.KindMysqlDecimal, types.KindBytes,
			types.KindMysqlHex, types.KindMysqlBit,
			types.KindMysqlEnum, types.KindMysqlSet:
			// Unary plus is the identity for every supported kind.
			d = operand
		default:
			return d, errInvalidOperation.Gen("Unsupported type %v for op.Plus", operand.Kind())
		}

	case opcode.Minus:
		switch operand.Kind() {
		case types.KindInt64:
			d.SetInt64(-operand.GetInt64())
		case types.KindUint64:
			// NOTE(review): values above math.MaxInt64 wrap in this conversion.
			d.SetInt64(-int64(operand.GetUint64()))
		case types.KindFloat64:
			d.SetFloat64(-operand.GetFloat64())
		case types.KindFloat32:
			d.SetFloat32(-operand.GetFloat32())
		case types.KindMysqlDuration:
			// Negate by subtracting from a freshly allocated decimal.
			zero, negated := new(types.MyDecimal), new(types.MyDecimal)
			err = types.DecimalSub(zero, operand.GetMysqlDuration().ToNumber(), negated)
			d.SetMysqlDecimal(negated)
		case types.KindMysqlTime:
			zero, negated := new(types.MyDecimal), new(types.MyDecimal)
			err = types.DecimalSub(zero, operand.GetMysqlTime().ToNumber(), negated)
			d.SetMysqlDecimal(negated)
		case types.KindString, types.KindBytes:
			// d is set even if the conversion reports an error.
			f, convErr := types.StrToFloat(stmtCtx, operand.GetString())
			err = errors.Trace(convErr)
			d.SetFloat64(-f)
		case types.KindMysqlDecimal:
			zero, negated := new(types.MyDecimal), new(types.MyDecimal)
			err = types.DecimalSub(zero, operand.GetMysqlDecimal(), negated)
			d.SetMysqlDecimal(negated)
		case types.KindMysqlHex:
			d.SetFloat64(-operand.GetMysqlHex().ToNumber())
		case types.KindMysqlBit:
			d.SetFloat64(-operand.GetMysqlBit().ToNumber())
		case types.KindMysqlEnum:
			d.SetFloat64(-operand.GetMysqlEnum().ToNumber())
		case types.KindMysqlSet:
			d.SetFloat64(-operand.GetMysqlSet().ToNumber())
		default:
			return d, errInvalidOperation.Gen("Unsupported type %v for op.Minus", operand.Kind())
		}

	default:
		return d, errInvalidOperation.Gen("Unsupported op %v for unary op", b.op)
	}
	return d, err
}
// isNullFunctionClass is the function class that produces builtinIsNullSig,
// the ISNULL signature.
type isNullFunctionClass struct {
	baseFunctionClass
}

// getFunction wraps args and ctx in a builtinIsNullSig and returns it together
// with the traced result of verifyArgs. The signature is built before the
// arguments are verified, so it is returned (non-nil) even when verification
// fails; callers must check the error.
func (c *isNullFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
	sig := &builtinIsNullSig{newBaseBuiltinFunc(args, ctx)}
	return sig, errors.Trace(c.verifyArgs(args))
}
// builtinIsNullSig evaluates the ISNULL test on its argument.
type builtinIsNullSig struct {
	baseBuiltinFunc
}

// eval returns 1 when the first argument evaluates to NULL and 0 otherwise;
// the result itself is never NULL.
// See https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_isnull
func (b *builtinIsNullSig) eval(row []types.Datum) (d types.Datum, err error) {
	args, err := b.evalArgs(row)
	if err != nil {
		return types.Datum{}, errors.Trace(err)
	}

	var isNull int64
	if args[0].IsNull() {
		isNull = 1
	}
	d.SetInt64(isNull)
	return d, nil
}