From 1eed36b25eab772abffd84ebe0e0f4e2378cd545 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Thu, 24 Aug 2017 13:33:52 +0800 Subject: [PATCH] expression: refine built-in func ifnull and if (#4301) --- expression/builtin_control.go | 102 +++++++---------------------- expression/builtin_control_test.go | 1 + 2 files changed, 25 insertions(+), 78 deletions(-) diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 84895b62bc..23102ad644 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -410,20 +410,11 @@ func (b *builtinIfIntSig) evalInt(row []types.Datum) (ret int64, isNull bool, er return 0, false, errors.Trace(err) } arg1, isNull1, err := b.args[1].EvalInt(row, sc) - if err != nil { - return 0, false, errors.Trace(err) + if (!isNull0 && arg0 != 0) || err != nil { + return arg1, isNull1, errors.Trace(err) } arg2, isNull2, err := b.args[2].EvalInt(row, sc) - if err != nil { - return 0, false, errors.Trace(err) - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 - } - return + return arg2, isNull2, errors.Trace(err) } type builtinIfRealSig struct { @@ -437,20 +428,11 @@ func (b *builtinIfRealSig) evalReal(row []types.Datum) (ret float64, isNull bool return 0, false, errors.Trace(err) } arg1, isNull1, err := b.args[1].EvalReal(row, sc) - if err != nil { - return 0, false, errors.Trace(err) + if (!isNull0 && arg0 != 0) || err != nil { + return arg1, isNull1, errors.Trace(err) } arg2, isNull2, err := b.args[2].EvalReal(row, sc) - if err != nil { - return 0, false, errors.Trace(err) - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 - } - return + return arg2, isNull2, errors.Trace(err) } type builtinIfDecimalSig struct { @@ -464,20 +446,11 @@ func (b *builtinIfDecimalSig) evalDecimal(row []types.Datum) (ret *types.MyDecim return nil, false, errors.Trace(err) } arg1, isNull1, err := b.args[1].EvalDecimal(row, sc) - if err != nil { - return nil, false, errors.Trace(err) + if (!isNull0 && arg0 != 0) || err != nil { + return arg1, isNull1, errors.Trace(err) } arg2, isNull2, err := b.args[2].EvalDecimal(row, sc) - if err != nil { - return nil, false, errors.Trace(err) - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 - } - return + return arg2, isNull2, errors.Trace(err) } type builtinIfStringSig struct { @@ -491,20 +464,11 @@ func (b *builtinIfStringSig) evalString(row []types.Datum) (ret string, isNull b return "", false, errors.Trace(err) } arg1, isNull1, err := b.args[1].EvalString(row, sc) - if err != nil { - return "", false, errors.Trace(err) + if (!isNull0 && arg0 != 0) || err != nil { + return arg1, isNull1, errors.Trace(err) } arg2, isNull2, err := b.args[2].EvalString(row, sc) - if err != nil { - return "", false, errors.Trace(err) - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 - } - return + return arg2, isNull2, errors.Trace(err) } type builtinIfTimeSig struct { @@ -518,20 +482,11 @@ func (b *builtinIfTimeSig) evalTime(row []types.Datum) (ret types.Time, isNull b return ret, false, errors.Trace(err) } arg1, isNull1, err := b.args[1].EvalTime(row, sc) - if err != nil { - return ret, false, errors.Trace(err) + if (!isNull0 && arg0 != 0) || err != nil { + return arg1, isNull1, errors.Trace(err) } arg2, isNull2, err := b.args[2].EvalTime(row, sc) - if err != nil { - return ret, false, errors.Trace(err) - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 - } - return + return arg2, isNull2, errors.Trace(err) } type builtinIfDurationSig struct { @@ -545,20 +500,11 @@ func (b *builtinIfDurationSig) evalDuration(row []types.Datum) (ret types.Durati return ret, false, errors.Trace(err) } arg1, isNull1, err := b.args[1].EvalDuration(row, sc) - if err != nil { - return ret, false, errors.Trace(err) + if (!isNull0 && arg0 != 0) || err != nil { + return arg1, isNull1, errors.Trace(err) } arg2, isNull2, err := b.args[2].EvalDuration(row, sc) - if err != nil { - return ret, false, errors.Trace(err) - } - switch { - case isNull0 || arg0 == 0: - ret, isNull = arg2, isNull2 - case arg0 != 0: - ret, isNull = arg1, isNull1 - } - return + return arg2, isNull2, errors.Trace(err) } type ifNullFunctionClass struct { @@ -607,7 +553,7 @@ type builtinIfNullIntSig struct { func (b *builtinIfNullIntSig) evalInt(row []types.Datum) (int64, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx arg0, isNull, err := b.args[0].EvalInt(row, sc) - if !isNull { + if !isNull || err != nil { return arg0, false, errors.Trace(err) } arg1, isNull, err := b.args[1].EvalInt(row, sc) @@ -621,7 +567,7 @@ type builtinIfNullRealSig struct { func (b *builtinIfNullRealSig) evalReal(row []types.Datum) (float64, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx arg0, isNull, err := b.args[0].EvalReal(row, sc) - if !isNull { + if !isNull || err != nil { return arg0, false, errors.Trace(err) } arg1, isNull, err := b.args[1].EvalReal(row, sc) @@ -635,7 +581,7 @@ type builtinIfNullDecimalSig struct { func (b *builtinIfNullDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx arg0, isNull, err := b.args[0].EvalDecimal(row, sc) - if !isNull { + if !isNull || err != nil { return arg0, false, errors.Trace(err) } arg1, isNull, err := b.args[1].EvalDecimal(row, sc) @@ -649,7 +595,7 @@ type builtinIfNullStringSig struct { func (b *builtinIfNullStringSig) evalString(row []types.Datum) (string, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx arg0, isNull, err := b.args[0].EvalString(row, sc) - if !isNull { + if !isNull || err != nil { return arg0, false, errors.Trace(err) } arg1, isNull, err := b.args[1].EvalString(row, sc) @@ -663,7 +609,7 @@ type builtinIfNullTimeSig struct { func (b *builtinIfNullTimeSig) evalTime(row []types.Datum) (types.Time, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx arg0, isNull, err := b.args[0].EvalTime(row, sc) - if !isNull { + if !isNull || err != nil { return arg0, false, errors.Trace(err) } arg1, isNull, err := b.args[1].EvalTime(row, sc) @@ -677,7 +623,7 @@ type builtinIfNullDurationSig struct { func (b *builtinIfNullDurationSig) evalDuration(row []types.Datum) (types.Duration, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx arg0, isNull, err := b.args[0].EvalDuration(row, sc) - if !isNull { + if !isNull || err != nil { return arg0, false, errors.Trace(err) } arg1, isNull, err := b.args[1].EvalDuration(row, sc) diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go index 4853bf46b9..8de8fc3a25 100644 --- a/expression/builtin_control_test.go +++ b/expression/builtin_control_test.go @@ -114,6 +114,7 @@ func (s *testEvaluatorSuite) TestIfNull(c *C) { {nil, types.Hex{Value: 1}, "\x01", false, false}, {nil, types.Set{Value: 1, Name: "abc"}, "abc", false, false}, {"abc", nil, "abc", false, false}, + {errors.New(""), nil, "", true, true}, } for _, t := range tbl {