expression, plan: rewrite builtin function: IS TRUE && IS FALSE (#4086)
This commit is contained in:
@ -879,8 +879,8 @@ var funcs = map[string]functionClass{
|
||||
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
|
||||
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
|
||||
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
|
||||
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
|
||||
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
|
||||
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
|
||||
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
|
||||
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 2, 3}},
|
||||
ast.Regexp: ®expFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
|
||||
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
|
||||
|
||||
@ -28,7 +28,7 @@ var (
|
||||
_ functionClass = &logicAndFunctionClass{}
|
||||
_ functionClass = &logicOrFunctionClass{}
|
||||
_ functionClass = &logicXorFunctionClass{}
|
||||
_ functionClass = &isTrueOpFunctionClass{}
|
||||
_ functionClass = &isTrueOrFalseFunctionClass{}
|
||||
_ functionClass = &unaryOpFunctionClass{}
|
||||
_ functionClass = &unaryMinusFunctionClass{}
|
||||
_ functionClass = &isNullFunctionClass{}
|
||||
@ -39,7 +39,12 @@ var (
|
||||
_ builtinFunc = &builtinLogicAndSig{}
|
||||
_ builtinFunc = &builtinLogicOrSig{}
|
||||
_ builtinFunc = &builtinLogicXorSig{}
|
||||
_ builtinFunc = &builtinIsTrueOpSig{}
|
||||
_ builtinFunc = &builtinRealIsTrueSig{}
|
||||
_ builtinFunc = &builtinDecimalIsTrueSig{}
|
||||
_ builtinFunc = &builtinIntIsTrueSig{}
|
||||
_ builtinFunc = &builtinRealIsFalseSig{}
|
||||
_ builtinFunc = &builtinDecimalIsFalseSig{}
|
||||
_ builtinFunc = &builtinIntIsFalseSig{}
|
||||
_ builtinFunc = &builtinUnaryOpSig{}
|
||||
_ builtinFunc = &builtinUnaryMinusIntSig{}
|
||||
_ builtinFunc = &builtinDecimalIsNullSig{}
|
||||
@ -346,40 +351,141 @@ func (b *builtinRightShiftSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
return int64(uint64(arg0) >> uint64(arg1)), false, nil
|
||||
}
|
||||
|
||||
type isTrueOpFunctionClass struct {
|
||||
type isTrueOrFalseFunctionClass struct {
|
||||
baseFunctionClass
|
||||
|
||||
op opcode.Op
|
||||
}
|
||||
|
||||
func (c *isTrueOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
||||
sig := &builtinIsTrueOpSig{newBaseBuiltinFunc(args, ctx), c.op}
|
||||
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
|
||||
}
|
||||
func (c *isTrueOrFalseFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
||||
if err := c.verifyArgs(args); err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
type builtinIsTrueOpSig struct {
|
||||
baseBuiltinFunc
|
||||
|
||||
op opcode.Op
|
||||
}
|
||||
|
||||
func (b *builtinIsTrueOpSig) eval(row []types.Datum) (d types.Datum, err error) {
|
||||
args, err := b.evalArgs(row)
|
||||
argTp := tpInt
|
||||
switch args[0].GetTypeClass() {
|
||||
case types.ClassReal:
|
||||
argTp = tpReal
|
||||
case types.ClassDecimal:
|
||||
argTp = tpDecimal
|
||||
}
|
||||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, argTp)
|
||||
if err != nil {
|
||||
return types.Datum{}, errors.Trace(err)
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
var boolVal bool
|
||||
if !args[0].IsNull() {
|
||||
iVal, err := args[0].ToBool(b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return d, errors.Trace(err)
|
||||
bf.tp.Flen = 1
|
||||
|
||||
var sig builtinFunc
|
||||
switch c.op {
|
||||
case opcode.IsTruth:
|
||||
switch argTp {
|
||||
case tpReal:
|
||||
sig = &builtinRealIsTrueSig{baseIntBuiltinFunc{bf}}
|
||||
case tpDecimal:
|
||||
sig = &builtinDecimalIsTrueSig{baseIntBuiltinFunc{bf}}
|
||||
case tpInt:
|
||||
sig = &builtinIntIsTrueSig{baseIntBuiltinFunc{bf}}
|
||||
}
|
||||
if (b.op == opcode.IsTruth && iVal == 1) || (b.op == opcode.IsFalsity && iVal == 0) {
|
||||
boolVal = true
|
||||
case opcode.IsFalsity:
|
||||
switch argTp {
|
||||
case tpReal:
|
||||
sig = &builtinRealIsFalseSig{baseIntBuiltinFunc{bf}}
|
||||
case tpDecimal:
|
||||
sig = &builtinDecimalIsFalseSig{baseIntBuiltinFunc{bf}}
|
||||
case tpInt:
|
||||
sig = &builtinIntIsFalseSig{baseIntBuiltinFunc{bf}}
|
||||
}
|
||||
}
|
||||
d.SetInt64(boolToInt64(boolVal))
|
||||
return
|
||||
return sig.setSelf(sig), nil
|
||||
}
|
||||
|
||||
type builtinRealIsTrueSig struct {
|
||||
baseIntBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinRealIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return 0, true, errors.Trace(err)
|
||||
}
|
||||
if isNull || input == 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
type builtinDecimalIsTrueSig struct {
|
||||
baseIntBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinDecimalIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return 0, true, errors.Trace(err)
|
||||
}
|
||||
if isNull || input.IsZero() {
|
||||
return 0, false, nil
|
||||
}
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
type builtinIntIsTrueSig struct {
|
||||
baseIntBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinIntIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return 0, true, errors.Trace(err)
|
||||
}
|
||||
if isNull || input == 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
type builtinRealIsFalseSig struct {
|
||||
baseIntBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinRealIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return 0, true, errors.Trace(err)
|
||||
}
|
||||
if isNull || input != 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
type builtinDecimalIsFalseSig struct {
|
||||
baseIntBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinDecimalIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return 0, true, errors.Trace(err)
|
||||
}
|
||||
if isNull || !input.IsZero() {
|
||||
return 0, false, nil
|
||||
}
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
type builtinIntIsFalseSig struct {
|
||||
baseIntBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinIntIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
|
||||
input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
|
||||
if err != nil {
|
||||
return 0, true, errors.Trace(err)
|
||||
}
|
||||
if isNull || input != 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
return 1, false, nil
|
||||
}
|
||||
|
||||
type bitNegFunctionClass struct {
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/util/testleak"
|
||||
"github.com/pingcap/tidb/util/testutil"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -476,3 +477,77 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) {
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(f.isDeterministic(), IsTrue)
|
||||
}
|
||||
|
||||
func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
|
||||
defer testleak.AfterTest(c)()
|
||||
sc := s.ctx.GetSessionVars().StmtCtx
|
||||
origin := sc.IgnoreTruncate
|
||||
defer func() {
|
||||
sc.IgnoreTruncate = origin
|
||||
}()
|
||||
sc.IgnoreTruncate = true
|
||||
|
||||
testCases := []struct {
|
||||
args []interface{}
|
||||
isTrue interface{}
|
||||
isFalse interface{}
|
||||
}{
|
||||
{
|
||||
args: []interface{}{-12},
|
||||
isTrue: 1,
|
||||
isFalse: 0,
|
||||
},
|
||||
{
|
||||
args: []interface{}{12},
|
||||
isTrue: 1,
|
||||
isFalse: 0,
|
||||
},
|
||||
{
|
||||
args: []interface{}{0},
|
||||
isTrue: 0,
|
||||
isFalse: 1,
|
||||
},
|
||||
{
|
||||
args: []interface{}{float64(0)},
|
||||
isTrue: 0,
|
||||
isFalse: 1,
|
||||
},
|
||||
{
|
||||
args: []interface{}{"aaa"},
|
||||
isTrue: 0,
|
||||
isFalse: 1,
|
||||
},
|
||||
{
|
||||
args: []interface{}{""},
|
||||
isTrue: 0,
|
||||
isFalse: 1,
|
||||
},
|
||||
{
|
||||
args: []interface{}{nil},
|
||||
isTrue: 0,
|
||||
isFalse: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
isTrueSig, err := funcs[ast.IsTruth].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(isTrueSig, NotNil)
|
||||
c.Assert(isTrueSig.isDeterministic(), IsTrue)
|
||||
|
||||
isTrue, err := isTrueSig.eval(nil)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(isTrue, testutil.DatumEquals, types.NewDatum(tc.isTrue))
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
isFalseSig, err := funcs[ast.IsFalsity].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(isFalseSig, NotNil)
|
||||
c.Assert(isFalseSig.isDeterministic(), IsTrue)
|
||||
|
||||
isFalse, err := isFalseSig.eval(nil)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
|
||||
}
|
||||
}
|
||||
|
||||
@ -830,7 +830,7 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
|
||||
// for is true
|
||||
// for is true && is false
|
||||
tk.MustExec("drop table if exists t")
|
||||
tk.MustExec("create table t (a int, b int, index idx_b (b))")
|
||||
tk.MustExec("insert t values (1, 1)")
|
||||
@ -844,6 +844,11 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
|
||||
result.Check(nil)
|
||||
result = tk.MustQuery("select * from t where a is not true")
|
||||
result.Check(nil)
|
||||
result = tk.MustQuery(`select 1 is true, 0 is true, null is true, "aaa" is true, "" is true, -12.00 is true, 0.0 is true, 0.0000001 is true;`)
|
||||
result.Check(testkit.Rows("1 0 0 0 0 1 0 1"))
|
||||
result = tk.MustQuery(`select 1 is false, 0 is false, null is false, "aaa" is false, "" is false, -12.00 is false, 0.0 is false, 0.0000001 is false;`)
|
||||
result.Check(testkit.Rows("0 1 0 1 1 0 1 0"))
|
||||
|
||||
// for in
|
||||
result = tk.MustQuery("select * from t where b in (a)")
|
||||
result.Check(testkit.Rows("1 1", "2 2"))
|
||||
|
||||
@ -81,6 +81,7 @@ func (s *testPlanSuite) TestInferType(c *C) {
|
||||
tests = append(tests, s.createTestCase4EncryptionFuncs()...)
|
||||
tests = append(tests, s.createTestCase4CompareFuncs()...)
|
||||
tests = append(tests, s.createTestCase4Miscellaneous()...)
|
||||
tests = append(tests, s.createTestCase4OpFuncs()...)
|
||||
|
||||
for _, tt := range tests {
|
||||
ctx := testKit.Se.(context.Context)
|
||||
@ -584,3 +585,29 @@ func (s *testPlanSuite) createTestCase4Miscellaneous() []typeInferTestCase {
|
||||
{"sleep(c_binary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testPlanSuite) createTestCase4OpFuncs() []typeInferTestCase {
|
||||
return []typeInferTestCase{
|
||||
{"c_int is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_decimal is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_double is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_float is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_datetime is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_time is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_enum is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_text is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
|
||||
{"c_int is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_decimal is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_double is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_float is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_datetime is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_time is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_enum is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"c_text is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
{"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user