From 2ed02b287df7b2bbc3757a920c84ea350aa3dacc Mon Sep 17 00:00:00 2001 From: Du Chuan Date: Wed, 9 Aug 2017 19:29:32 +0800 Subject: [PATCH] expression: rewrite builtin function: CHAR_LENGTH (#4105) * expression: rewrite builtin funtion 'CHAR_LENGTH' --- expression/builtin_string.go | 36 ++++++++++++++----------------- expression/builtin_string_test.go | 1 + expression/integration_test.go | 12 +++++++++++ plan/typeinfer_test.go | 16 ++++++++++++++ 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 15ae951a90..61e69e26b7 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -1645,33 +1645,29 @@ type charLengthFunctionClass struct { } func (c *charLengthFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) { - sig := &builtinCharLengthSig{newBaseBuiltinFunc(args, ctx)} - return sig.setSelf(sig), errors.Trace(c.verifyArgs(args)) + if argsErr := c.verifyArgs(args); argsErr != nil { + return nil, errors.Trace(argsErr) + } + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpString) + if err != nil { + return nil, errors.Trace(err) + } + sig := &builtinCharLengthSig{baseIntBuiltinFunc{bf}} + return sig.setSelf(sig), nil } type builtinCharLengthSig struct { - baseBuiltinFunc + baseIntBuiltinFunc } -// eval evals a builtinCharLengthSig. +// evalInt evals a builtinCharLengthSig. // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_char-length -func (b *builtinCharLengthSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return types.Datum{}, errors.Trace(err) - } - switch args[0].Kind() { - case types.KindNull: - return d, nil - default: - s, err := args[0].ToString() - if err != nil { - return d, errors.Trace(err) - } - r := []rune(s) - d.SetInt64(int64(len(r))) - return d, nil +func (b *builtinCharLengthSig) evalInt(row []types.Datum) (int64, bool, error) { + val, isNull, err := b.args[0].EvalString(row, b.ctx.GetSessionVars().StmtCtx) + if isNull || err != nil { + return 0, isNull, errors.Trace(err) } + return int64(len([]rune(val))), false, nil } type findInSetFunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 72fd31c2cc..85bbaf1715 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1211,6 +1211,7 @@ func (s *testEvaluatorSuite) TestCharLength(c *C) { fc := funcs[ast.CharLength] f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.input)), s.ctx) c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), Equals, true) r, err := f.eval(nil) c.Assert(err, IsNil) c.Assert(r, testutil.DatumEquals, types.NewDatum(v.result)) diff --git a/expression/integration_test.go b/expression/integration_test.go index 94a1cca777..a583c0863b 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -541,6 +541,18 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result = tk.MustQuery(`select bin("中文");`) result.Check(testkit.Rows("0")) + // for char_length + result = tk.MustQuery(`select char_length(null);`) + result.Check(testkit.Rows("")) + result = tk.MustQuery(`select char_length("Hello");`) + result.Check(testkit.Rows("5")) + result = tk.MustQuery(`select char_length("a中b文c");`) + result.Check(testkit.Rows("5")) + result = tk.MustQuery(`select char_length(123);`) + result.Check(testkit.Rows("3")) + result = tk.MustQuery(`select char_length(12.3456);`) + result.Check(testkit.Rows("7")) + } func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index 129239f95b..07df152d9a 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -195,6 +195,22 @@ func (s *testPlanSuite) createTestCase4StrFuncs() []typeInferTestCase { {"bin(c_set )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength}, {"bin(c_enum )", mysql.TypeVarString, charset.CharsetUTF8, 0, 64, types.UnspecifiedLength}, + {"char_length(c_int)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_float)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_double)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_decimal)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_datetime)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_time)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_timestamp)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_varchar)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_text)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_binary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_varbinary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_blob)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_set)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char_length(c_enum)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, + {"char(c_int )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength}, {"char(c_bigint )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength}, {"char(c_float )", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 4, types.UnspecifiedLength},