From abeea90449527e57c7469188bdeb08530339ed17 Mon Sep 17 00:00:00 2001 From: zhaoxingyu Date: Tue, 22 Mar 2016 14:04:23 +0800 Subject: [PATCH 1/2] evaluator: change the remaining string functions. Change the remaining string functions. --- evaluator/builtin.go | 30 ++- evaluator/builtin_string.go | 219 ++++++++++++---------- evaluator/builtin_string_test.go | 308 ++++++++++++++++++++++++++++++- evaluator/evaluator_test.go | 290 ----------------------------- util/types/datum.go | 55 +++++- 5 files changed, 488 insertions(+), 414 deletions(-) diff --git a/evaluator/builtin.go b/evaluator/builtin.go index 5d44b616c4..d288d03015 100644 --- a/evaluator/builtin.go +++ b/evaluator/builtin.go @@ -53,15 +53,6 @@ var OldFuncs = map[string]OldFunc{ "ifnull": {builtinIfNull, 2, 2, true, false}, "nullif": {builtinNullIf, 2, 2, true, false}, - // string functions - "replace": {builtinReplace, 3, 3, true, false}, - "strcmp": {builtinStrcmp, 2, 2, true, false}, - "convert": {builtinConvert, 2, 2, true, false}, - "substring": {builtinSubstring, 2, 3, true, false}, - "substring_index": {builtinSubstringIndex, 3, 3, true, false}, - "locate": {builtinLocate, 2, 3, true, false}, - "trim": {builtinTrim, 1, 3, true, false}, - // information functions "current_user": {builtinCurrentUser, 0, 0, false, false}, "database": {builtinDatabase, 0, 0, false, false}, @@ -110,13 +101,20 @@ var Funcs = map[string]Func{ "date_arith": {builtinDateArith, 3, 3}, // string functions - "concat": {builtinConcat, 1, -1}, - "concat_ws": {builtinConcatWS, 2, -1}, - "left": {builtinLeft, 2, 2}, - "length": {builtinLength, 1, 1}, - "lower": {builtinLower, 1, 1}, - "repeat": {builtinRepeat, 2, 2}, - "upper": {builtinUpper, 1, 1}, + "concat": {builtinConcat, 1, -1}, + "concat_ws": {builtinConcatWS, 2, -1}, + "left": {builtinLeft, 2, 2}, + "length": {builtinLength, 1, 1}, + "lower": {builtinLower, 1, 1}, + "repeat": {builtinRepeat, 2, 2}, + "upper": {builtinUpper, 1, 1}, + "replace": {builtinReplace, 3, 3}, + "strcmp": {builtinStrcmp, 2, 2}, + "convert": {builtinConvert, 2, 2}, + "substring": {builtinSubstring, 2, 3}, + "substring_index": {builtinSubstringIndex, 3, 3}, + "locate": {builtinLocate, 2, 3}, + "trim": {builtinTrim, 1, 3}, } // See: http://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#function_coalesce diff --git a/evaluator/builtin_string.go b/evaluator/builtin_string.go index 1751c83f94..382f5d8d1c 100644 --- a/evaluator/builtin_string.go +++ b/evaluator/builtin_string.go @@ -180,103 +180,114 @@ func builtinUpper(args []types.Datum, _ context.Context) (d types.Datum, err err } // See: https://dev.mysql.com/doc/refman/5.7/en/string-comparison-functions.html -func builtinStrcmp(args []interface{}, _ context.Context) (interface{}, error) { - if args[0] == nil || args[1] == nil { - return nil, nil +func builtinStrcmp(args []types.Datum, _ context.Context) (d types.Datum, err error) { + if args[0].Kind() == types.KindNull || args[1].Kind() == types.KindNull { + d.SetNull() + return d, nil } - left, err := types.ToString(args[0]) + left, err := args[0].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } - right, err := types.ToString(args[1]) + right, err := args[1].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } res := types.CompareString(left, right) - return res, nil + d.SetInt64(int64(res)) + return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_replace -func builtinReplace(args []interface{}, _ context.Context) (interface{}, error) { +func builtinReplace(args []types.Datum, _ context.Context) (d types.Datum, err error) { for _, arg := range args { - if arg == nil { - return nil, nil + if arg.Kind() == types.KindNull { + d.SetNull() + return d, nil } } - str, err := types.ToString(args[0]) + str, err := args[0].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } - oldStr, err := types.ToString(args[1]) + oldStr, err := args[1].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } - newStr, err := types.ToString(args[2]) + newStr, err := args[2].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } + d.SetString(strings.Replace(str, oldStr, newStr, -1)) - return strings.Replace(str, oldStr, newStr, -1), nil + return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert -func builtinConvert(args []interface{}, _ context.Context) (interface{}, error) { - value := args[0] - Charset := args[1].(string) - +func builtinConvert(args []types.Datum, _ context.Context) (d types.Datum, err error) { // Casting nil to any type returns nil - if value == nil { - return nil, nil - } - str, ok := value.(string) - if !ok { - return nil, nil + if args[0].Kind() != types.KindString { + d.SetNull() + return d, nil } + + str := args[0].GetString() + Charset := args[1].GetString() + if strings.ToLower(Charset) == "ascii" { - return value, nil + d.SetString(str) + return d, nil } else if strings.ToLower(Charset) == "utf8mb4" { - return value, nil + d.SetString(str) + return d, nil } encoding, _ := charset.Lookup(Charset) if encoding == nil { - return nil, errors.Errorf("unknown encoding: %s", Charset) + d.SetNull() + return d, errors.Errorf("unknown encoding: %s", Charset) } target, _, err := transform.String(encoding.NewDecoder(), str) if err != nil { + d.SetNull() log.Errorf("Convert %s to %s with error: %v", str, Charset, err) - return nil, errors.Trace(err) + return d, errors.Trace(err) } - return target, nil + d.SetString(target) + return d, nil } -func builtinSubstring(args []interface{}, _ context.Context) (interface{}, error) { +func builtinSubstring(args []types.Datum, _ context.Context) (d types.Datum, err error) { // The meaning of the elements of args. // arg[0] -> StrExpr // arg[1] -> Pos // arg[2] -> Len (Optional) - str, err := types.ToString(args[0]) + str, err := args[0].ToString() if err != nil { - return nil, errors.Errorf("Substring invalid args, need string but get %T", args[0]) + d.SetNull() + return d, errors.Errorf("Substring invalid args, need string but get %v", args[0].Kind()) } - t := args[1] - p, ok := t.(int64) - if !ok { - return nil, errors.Errorf("Substring invalid pos args, need int but get %T", t) + if args[1].Kind() != types.KindInt64 { + d.SetNull() + return d, errors.Errorf("Substring invalid pos args, need int but get %v", args[1].Kind()) } - pos := int(p) + pos := args[1].GetInt64() - length := -1 + length := int64(-1) if len(args) == 3 { - t = args[2] - p, ok = t.(int64) - if !ok { - return nil, errors.Errorf("Substring invalid pos args, need int but get %T", t) + if args[2].Kind() != types.KindInt64 { + d.SetNull() + return d, errors.Errorf("Substring invalid pos args, need int but get %v", args[2].Kind()) } - length = int(p) + length = args[2].GetInt64() } // The forms without a len argument return a substring from string str starting at position pos. // The forms with a len argument return a substring len characters long from string str, starting at position pos. @@ -284,48 +295,50 @@ func builtinSubstring(args []interface{}, _ context.Context) (interface{}, error // In this case, the beginning of the substring is pos characters from the end of the string, rather than the beginning. // A negative value may be used for pos in any of the forms of this function. if pos < 0 { - pos = len(str) + pos + pos = int64(len(str)) + pos } else { pos-- } - if pos > len(str) || pos <= 0 { - pos = len(str) + if pos > int64(len(str)) || pos <= int64(0) { + pos = int64(len(str)) } - end := len(str) - if length != -1 { + end := int64(len(str)) + if length != int64(-1) { end = pos + length } - if end > len(str) { - end = len(str) + if end > int64(len(str)) { + end = int64(len(str)) } - return str[pos:end], nil + d.SetString(str[pos:end]) + return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring-index -func builtinSubstringIndex(args []interface{}, _ context.Context) (interface{}, error) { +func builtinSubstringIndex(args []types.Datum, _ context.Context) (d types.Datum, err error) { // The meaning of the elements of args. // args[0] -> StrExpr // args[1] -> Delim // args[2] -> Count - fs := args[0] - str, err := types.ToString(fs) + str, err := args[0].ToString() if err != nil { - return nil, errors.Errorf("Substring_Index invalid args, need string but get %T", fs) + d.SetNull() + return d, errors.Errorf("Substring_Index invalid args, need string but get %v", args[0].Kind()) } - t := args[1] - delim, err := types.ToString(t) + delim, err := args[1].ToString() if err != nil { - return nil, errors.Errorf("Substring_Index invalid delim, need string but get %T", t) + d.SetNull() + return d, errors.Errorf("Substring_Index invalid delim, need string but get %v", args[1].Kind()) } if len(delim) == 0 { - return "", nil + d.SetString("") + return d, nil } - t = args[2] - c, err := types.ToInt64(t) + c, err := args[2].ToInt64() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } count := int(c) strs := strings.Split(str, delim) @@ -346,83 +359,92 @@ func builtinSubstringIndex(args []interface{}, _ context.Context) (interface{}, } } substrs := strs[start:end] - return strings.Join(substrs, delim), nil + d.SetString(strings.Join(substrs, delim)) + return d, nil } // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_locate -func builtinLocate(args []interface{}, _ context.Context) (interface{}, error) { +func builtinLocate(args []types.Datum, _ context.Context) (d types.Datum, err error) { // The meaning of the elements of args. // args[0] -> SubStr // args[1] -> Str // args[2] -> Pos // eval str - fs := args[1] - if fs == nil { - return nil, nil + if args[1].Kind() == types.KindNull { + d.SetNull() + return d, nil } - str, err := types.ToString(fs) + str, err := args[1].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } // eval substr - fs = args[0] - if fs == nil { - return nil, nil + if args[0].Kind() == types.KindNull { + d.SetNull() + return d, nil } - subStr, err := types.ToString(fs) + subStr, err := args[0].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } // eval pos pos := int64(0) if len(args) == 3 { - t := args[2] - p, err := types.ToInt64(t) + p, err := args[2].ToInt64() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } pos = p - 1 if pos < 0 || pos > int64(len(str)) { - return 0, nil + d.SetInt64(0) + return d, nil } if pos > int64(len(str)-len(subStr)) { - return 0, nil + d.SetInt64(0) + return d, nil } } if len(subStr) == 0 { - return pos + 1, nil + d.SetInt64(pos + 1) + return d, nil } i := strings.Index(str[pos:], subStr) if i == -1 { - return 0, nil + d.SetInt64(0) + return d, nil } - return int64(i) + pos + 1, nil + d.SetInt64(int64(i) + pos + 1) + return d, nil } const spaceChars = "\n\t\r " // See: https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_trim -func builtinTrim(args []interface{}, _ context.Context) (interface{}, error) { +func builtinTrim(args []types.Datum, _ context.Context) (d types.Datum, err error) { // args[0] -> Str // args[1] -> RemStr // args[2] -> Direction // eval str - fs := args[0] - if fs == nil { - return nil, nil + if args[0].Kind() == types.KindNull { + d.SetNull() + return d, nil } - str, err := types.ToString(fs) + str, err := args[0].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } remstr := "" // eval remstr if len(args) > 1 { - fs = args[1] - if fs != nil { - remstr, err = types.ToString(fs) + if args[1].Kind() != types.KindNull { + remstr, err = args[1].ToString() if err != nil { - return nil, errors.Trace(err) + d.SetNull() + return d, errors.Trace(err) } } } @@ -430,7 +452,7 @@ func builtinTrim(args []interface{}, _ context.Context) (interface{}, error) { var result string var direction ast.TrimDirectionType if len(args) > 2 { - direction = args[2].(ast.TrimDirectionType) + direction = args[2].GetValue().(ast.TrimDirectionType) } else { direction = ast.TrimBothDefault } @@ -452,7 +474,8 @@ func builtinTrim(args []interface{}, _ context.Context) (interface{}, error) { } else { result = strings.Trim(str, spaceChars) } - return result, nil + d.SetString(result) + return d, nil } func trimLeft(str, remstr string) string { diff --git a/evaluator/builtin_string_test.go b/evaluator/builtin_string_test.go index a15309456e..bf6c2e6c85 100644 --- a/evaluator/builtin_string_test.go +++ b/evaluator/builtin_string_test.go @@ -19,7 +19,10 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/types" ) @@ -192,10 +195,11 @@ func (s *testEvaluatorSuite) TestStrcmp(c *C) { {[]interface{}{nil, ""}, nil}, } - for _, t := range tbl { - v, err := builtinStrcmp(t.Input, nil) + dtbl := tblToDtbl(tbl) + for _, t := range dtbl { + d, err := builtinStrcmp(t["Input"], nil) c.Assert(err, IsNil) - c.Assert(v, Equals, t.Expect) + c.Assert(d, DatumEquals, t["Expect"][0]) } } @@ -212,9 +216,301 @@ func (s *testEvaluatorSuite) TestReplace(c *C) { {[]interface{}{12345, 2, "aa"}, "1aa345"}, } - for _, t := range tbl { - v, err := builtinReplace(t.Input, nil) + dtbl := tblToDtbl(tbl) + + for _, t := range dtbl { + d, err := builtinReplace(t["Input"], nil) c.Assert(err, IsNil) - c.Assert(v, Equals, t.Expect) + c.Assert(d, DatumEquals, t["Expect"][0]) + } +} + +func (s *testEvaluatorSuite) TestSubstring(c *C) { + tbl := []struct { + str string + pos int64 + slen int64 + result string + }{ + {"Quadratically", 5, -1, "ratically"}, + {"foobarbar", 4, -1, "barbar"}, + {"Quadratically", 5, 6, "ratica"}, + {"Sakila", -3, -1, "ila"}, + {"Sakila", -5, 3, "aki"}, + {"Sakila", -4, 2, "ki"}, + {"Sakila", 1000, 2, ""}, + {"", 2, 3, ""}, + } + ctx := mock.NewContext() + for _, v := range tbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("SUBSTRING"), + Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.pos)}, + } + if v.slen != -1 { + f.Args = append(f.Args, ast.NewValueExpr(v.slen)) + } + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + s, ok := r.(string) + c.Assert(ok, IsTrue) + c.Assert(s, Equals, v.result) + + r1, err := Eval(ctx, f) + c.Assert(err, IsNil) + s1, ok := r1.(string) + c.Assert(ok, IsTrue) + c.Assert(s, Equals, s1) + } + errTbl := []struct { + str interface{} + pos interface{} + len interface{} + result string + }{ + {"foobarbar", "4", -1, "barbar"}, + {"Quadratically", 5, "6", "ratica"}, + } + for _, v := range errTbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("SUBSTRING"), + Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.pos)}, + } + if v.len != -1 { + f.Args = append(f.Args, ast.NewValueExpr(v.len)) + } + _, err := Eval(ctx, f) + c.Assert(err, NotNil) + } +} + +func (s *testEvaluatorSuite) TestConvert(c *C) { + ctx := mock.NewContext() + tbl := []struct { + str string + cs string + result string + }{ + {"haha", "utf8", "haha"}, + {"haha", "ascii", "haha"}, + } + for _, v := range tbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("CONVERT"), + Args: []ast.ExprNode{ + ast.NewValueExpr(v.str), + ast.NewValueExpr(v.cs), + }, + } + + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + s, ok := r.(string) + c.Assert(ok, IsTrue) + c.Assert(s, Equals, v.result) + } + + // Test case for error + errTbl := []struct { + str interface{} + cs string + result string + }{ + {"haha", "wrongcharset", "haha"}, + } + for _, v := range errTbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("CONVERT"), + Args: []ast.ExprNode{ + ast.NewValueExpr(v.str), + ast.NewValueExpr(v.cs), + }, + } + + _, err := Eval(ctx, f) + c.Assert(err, NotNil) + } +} + +func (s *testEvaluatorSuite) TestSubstringIndex(c *C) { + tbl := []struct { + str string + delim string + count int64 + result string + }{ + {"www.mysql.com", ".", 2, "www.mysql"}, + {"www.mysql.com", ".", -2, "mysql.com"}, + {"www.mysql.com", ".", 0, ""}, + {"www.mysql.com", ".", 3, "www.mysql.com"}, + {"www.mysql.com", ".", 4, "www.mysql.com"}, + {"www.mysql.com", ".", -3, "www.mysql.com"}, + {"www.mysql.com", ".", -4, "www.mysql.com"}, + + {"www.mysql.com", "d", 1, "www.mysql.com"}, + {"www.mysql.com", "d", 0, ""}, + {"www.mysql.com", "d", -1, "www.mysql.com"}, + + {"", ".", 2, ""}, + {"", ".", -2, ""}, + {"", ".", 0, ""}, + + {"www.mysql.com", "", 1, ""}, + {"www.mysql.com", "", -1, ""}, + {"www.mysql.com", "", 0, ""}, + } + ctx := mock.NewContext() + for _, v := range tbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("SUBSTRING_INDEX"), + Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.delim), ast.NewValueExpr(v.count)}, + } + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + s, ok := r.(string) + c.Assert(ok, IsTrue) + c.Assert(s, Equals, v.result) + } + errTbl := []struct { + str interface{} + delim interface{} + count interface{} + }{ + {nil, ".", 2}, + {nil, ".", -2}, + {nil, ".", 0}, + {"asdf", nil, 2}, + {"asdf", nil, -2}, + {"asdf", nil, 0}, + {"www.mysql.com", ".", nil}, + } + for _, v := range errTbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("SUBSTRING_INDEX"), + Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.delim), ast.NewValueExpr(v.count)}, + } + r, err := Eval(ctx, f) + c.Assert(err, NotNil) + c.Assert(r, IsNil) + } +} + +func (s *testEvaluatorSuite) TestLocate(c *C) { + tbl := []struct { + subStr string + Str string + result int64 + }{ + {"bar", "foobarbar", 4}, + {"xbar", "foobar", 0}, + {"", "foobar", 1}, + {"foobar", "", 0}, + {"", "", 1}, + } + ctx := mock.NewContext() + for _, v := range tbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("LOCATE"), + Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str)}, + } + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + s, ok := r.(int64) + c.Assert(ok, IsTrue) + c.Assert(s, Equals, v.result) + } + + tbl2 := []struct { + subStr string + Str string + pos int64 + result int64 + }{ + {"bar", "foobarbar", 5, 7}, + {"xbar", "foobar", 1, 0}, + {"", "foobar", 2, 2}, + {"foobar", "", 1, 0}, + {"", "", 2, 0}, + } + for _, v := range tbl2 { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("LOCATE"), + Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str), ast.NewValueExpr(v.pos)}, + } + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + s, ok := r.(int64) + c.Assert(ok, IsTrue) + c.Assert(s, Equals, v.result) + } + + errTbl := []struct { + subStr interface{} + Str interface{} + }{ + {nil, nil}, + {"", nil}, + {nil, ""}, + {"foo", nil}, + {nil, "bar"}, + } + for _, v := range errTbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("LOCATE"), + Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str)}, + } + r, _ := Eval(ctx, f) + c.Assert(r, IsNil) + } + + errTbl2 := []struct { + subStr interface{} + Str interface{} + pos interface{} + }{ + {nil, nil, 1}, + {"", nil, 1}, + {nil, "", 1}, + {"foo", nil, -1}, + {nil, "bar", 0}, + } + for _, v := range errTbl2 { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("LOCATE"), + Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str), ast.NewValueExpr(v.pos)}, + } + r, _ := Eval(ctx, f) + c.Assert(r, IsNil) + } +} + +func (s *testEvaluatorSuite) TestTrim(c *C) { + tbl := []struct { + str interface{} + remstr interface{} + dir ast.TrimDirectionType + result interface{} + }{ + {" bar ", nil, ast.TrimBothDefault, "bar"}, + {"xxxbarxxx", "x", ast.TrimLeading, "barxxx"}, + {"xxxbarxxx", "x", ast.TrimBoth, "bar"}, + {"barxxyz", "xyz", ast.TrimTrailing, "barx"}, + {nil, "xyz", ast.TrimBoth, nil}, + {1, 2, ast.TrimBoth, "1"}, + {" \t\rbar\n ", nil, ast.TrimBothDefault, "bar"}, + } + ctx := mock.NewContext() + for _, v := range tbl { + f := &ast.FuncCallExpr{ + FnName: model.NewCIStr("TRIM"), + Args: []ast.ExprNode{ + ast.NewValueExpr(v.str), + ast.NewValueExpr(v.remstr), + ast.NewValueExpr(v.dir), + }, + } + r, err := Eval(ctx, f) + c.Assert(err, IsNil) + c.Assert(r, Equals, v.result) } } diff --git a/evaluator/evaluator_test.go b/evaluator/evaluator_test.go index 3b8397aa2a..77691b6153 100644 --- a/evaluator/evaluator_test.go +++ b/evaluator/evaluator_test.go @@ -352,54 +352,6 @@ func (s *testEvaluatorSuite) TestCaseWhen(c *C) { c.Assert(v, IsNil) } -func (s *testEvaluatorSuite) TestConvert(c *C) { - ctx := mock.NewContext() - tbl := []struct { - str string - cs string - result string - }{ - {"haha", "utf8", "haha"}, - {"haha", "ascii", "haha"}, - } - for _, v := range tbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("CONVERT"), - Args: []ast.ExprNode{ - ast.NewValueExpr(v.str), - ast.NewValueExpr(v.cs), - }, - } - - r, err := Eval(ctx, f) - c.Assert(err, IsNil) - s, ok := r.(string) - c.Assert(ok, IsTrue) - c.Assert(s, Equals, v.result) - } - - // Test case for error - errTbl := []struct { - str interface{} - cs string - result string - }{ - {"haha", "wrongcharset", "haha"}, - } - for _, v := range errTbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("CONVERT"), - Args: []ast.ExprNode{ - ast.NewValueExpr(v.str), - ast.NewValueExpr(v.cs), - }, - } - - _, err := Eval(ctx, f) - c.Assert(err, NotNil) - } -} - func (s *testEvaluatorSuite) TestCall(c *C) { ctx := mock.NewContext() @@ -766,248 +718,6 @@ func (s *testEvaluatorSuite) TestRegexp(c *C) { } } -func (s *testEvaluatorSuite) TestSubstring(c *C) { - tbl := []struct { - str string - pos int64 - slen int64 - result string - }{ - {"Quadratically", 5, -1, "ratically"}, - {"foobarbar", 4, -1, "barbar"}, - {"Quadratically", 5, 6, "ratica"}, - {"Sakila", -3, -1, "ila"}, - {"Sakila", -5, 3, "aki"}, - {"Sakila", -4, 2, "ki"}, - {"Sakila", 1000, 2, ""}, - {"", 2, 3, ""}, - } - ctx := mock.NewContext() - for _, v := range tbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("SUBSTRING"), - Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.pos)}, - } - if v.slen != -1 { - f.Args = append(f.Args, ast.NewValueExpr(v.slen)) - } - r, err := Eval(ctx, f) - c.Assert(err, IsNil) - s, ok := r.(string) - c.Assert(ok, IsTrue) - c.Assert(s, Equals, v.result) - - r1, err := Eval(ctx, f) - c.Assert(err, IsNil) - s1, ok := r1.(string) - c.Assert(ok, IsTrue) - c.Assert(s, Equals, s1) - } - errTbl := []struct { - str interface{} - pos interface{} - len interface{} - result string - }{ - {"foobarbar", "4", -1, "barbar"}, - {"Quadratically", 5, "6", "ratica"}, - } - for _, v := range errTbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("SUBSTRING"), - Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.pos)}, - } - if v.len != -1 { - f.Args = append(f.Args, ast.NewValueExpr(v.len)) - } - _, err := Eval(ctx, f) - c.Assert(err, NotNil) - } -} - -func (s *testEvaluatorSuite) TestSubstringIndex(c *C) { - tbl := []struct { - str string - delim string - count int64 - result string - }{ - {"www.mysql.com", ".", 2, "www.mysql"}, - {"www.mysql.com", ".", -2, "mysql.com"}, - {"www.mysql.com", ".", 0, ""}, - {"www.mysql.com", ".", 3, "www.mysql.com"}, - {"www.mysql.com", ".", 4, "www.mysql.com"}, - {"www.mysql.com", ".", -3, "www.mysql.com"}, - {"www.mysql.com", ".", -4, "www.mysql.com"}, - - {"www.mysql.com", "d", 1, "www.mysql.com"}, - {"www.mysql.com", "d", 0, ""}, - {"www.mysql.com", "d", -1, "www.mysql.com"}, - - {"", ".", 2, ""}, - {"", ".", -2, ""}, - {"", ".", 0, ""}, - - {"www.mysql.com", "", 1, ""}, - {"www.mysql.com", "", -1, ""}, - {"www.mysql.com", "", 0, ""}, - } - ctx := mock.NewContext() - for _, v := range tbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("SUBSTRING_INDEX"), - Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.delim), ast.NewValueExpr(v.count)}, - } - r, err := Eval(ctx, f) - c.Assert(err, IsNil) - s, ok := r.(string) - c.Assert(ok, IsTrue) - c.Assert(s, Equals, v.result) - } - errTbl := []struct { - str interface{} - delim interface{} - count interface{} - }{ - {nil, ".", 2}, - {nil, ".", -2}, - {nil, ".", 0}, - {"asdf", nil, 2}, - {"asdf", nil, -2}, - {"asdf", nil, 0}, - {"www.mysql.com", ".", nil}, - } - for _, v := range errTbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("SUBSTRING_INDEX"), - Args: []ast.ExprNode{ast.NewValueExpr(v.str), ast.NewValueExpr(v.delim), ast.NewValueExpr(v.count)}, - } - r, err := Eval(ctx, f) - c.Assert(err, NotNil) - c.Assert(r, IsNil) - } -} - -func (s *testEvaluatorSuite) TestLocate(c *C) { - tbl := []struct { - subStr string - Str string - result int64 - }{ - {"bar", "foobarbar", 4}, - {"xbar", "foobar", 0}, - {"", "foobar", 1}, - {"foobar", "", 0}, - {"", "", 1}, - } - ctx := mock.NewContext() - for _, v := range tbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("LOCATE"), - Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str)}, - } - r, err := Eval(ctx, f) - c.Assert(err, IsNil) - s, ok := r.(int64) - c.Assert(ok, IsTrue) - c.Assert(s, Equals, v.result) - } - - tbl2 := []struct { - subStr string - Str string - pos int64 - result int64 - }{ - {"bar", "foobarbar", 5, 7}, - {"xbar", "foobar", 1, 0}, - {"", "foobar", 2, 2}, - {"foobar", "", 1, 0}, - {"", "", 2, 0}, - } - for _, v := range tbl2 { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("LOCATE"), - Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str), ast.NewValueExpr(v.pos)}, - } - r, err := Eval(ctx, f) - c.Assert(err, IsNil) - s, ok := r.(int64) - c.Assert(ok, IsTrue) - c.Assert(s, Equals, v.result) - } - - errTbl := []struct { - subStr interface{} - Str interface{} - }{ - {nil, nil}, - {"", nil}, - {nil, ""}, - {"foo", nil}, - {nil, "bar"}, - } - for _, v := range errTbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("LOCATE"), - Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str)}, - } - r, _ := Eval(ctx, f) - c.Assert(r, IsNil) - } - - errTbl2 := []struct { - subStr interface{} - Str interface{} - pos interface{} - }{ - {nil, nil, 1}, - {"", nil, 1}, - {nil, "", 1}, - {"foo", nil, -1}, - {nil, "bar", 0}, - } - for _, v := range errTbl2 { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("LOCATE"), - Args: []ast.ExprNode{ast.NewValueExpr(v.subStr), ast.NewValueExpr(v.Str), ast.NewValueExpr(v.pos)}, - } - r, _ := Eval(ctx, f) - c.Assert(r, IsNil) - } -} - -func (s *testEvaluatorSuite) TestTrim(c *C) { - tbl := []struct { - str interface{} - remstr interface{} - dir ast.TrimDirectionType - result interface{} - }{ - {" bar ", nil, ast.TrimBothDefault, "bar"}, - {"xxxbarxxx", "x", ast.TrimLeading, "barxxx"}, - {"xxxbarxxx", "x", ast.TrimBoth, "bar"}, - {"barxxyz", "xyz", ast.TrimTrailing, "barx"}, - {nil, "xyz", ast.TrimBoth, nil}, - {1, 2, ast.TrimBoth, "1"}, - {" \t\rbar\n ", nil, ast.TrimBothDefault, "bar"}, - } - ctx := mock.NewContext() - for _, v := range tbl { - f := &ast.FuncCallExpr{ - FnName: model.NewCIStr("TRIM"), - Args: []ast.ExprNode{ - ast.NewValueExpr(v.str), - ast.NewValueExpr(v.remstr), - ast.NewValueExpr(v.dir), - }, - } - r, err := Eval(ctx, f) - c.Assert(err, IsNil) - c.Assert(r, Equals, v.result) - } -} - func (s *testEvaluatorSuite) TestUnaryOp(c *C) { tbl := []struct { arg interface{} diff --git a/util/types/datum.go b/util/types/datum.go index ddb8a43253..7e3393b56c 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -24,10 +24,13 @@ import ( "github.com/pingcap/tidb/util/hack" ) +// KindType is a dummy type to avoid naming collision in context. +type KindType int + // Kind constants. const ( - KindNull int = 0 - KindInt64 int = iota + 1 + KindNull KindType = 0 + KindInt64 KindType = iota + 1 KindUint64 KindFloat32 KindFloat64 @@ -46,17 +49,61 @@ const ( KindMaxValue ) +// String defines a Stringer function for debugging and pretty printing. +func (k KindType) String() string { + switch k { + case KindNull: + return "KindNull" + case KindInt64: + return "KindInt64" + case KindUint64: + return "KindUint64" + case KindFloat32: + return "KindFloat32" + case KindFloat64: + return "KindFloat64" + case KindString: + return "KindString" + case KindBytes: + return "KindBytes" + case KindMysqlBit: + return "KindMysqlBit" + case KindMysqlDecimal: + return "KindMysqlDecimal" + case KindMysqlDuration: + return "KindMysqlDuration" + case KindMysqlEnum: + return "KindMysqlEnum" + case KindMysqlHex: + return "KindMysqlHex" + case KindMysqlSet: + return "KindMysqlSet" + case KindMysqlTime: + return "KindMysqlTime" + case KindRow: + return "KindRow" + case KindInterface: + return "KindInterface" + case KindMinNotNull: + return "KindMinNotNull" + case KindMaxValue: + return "KindMaxValue" + default: + return "" + } +} + // Datum is a data box holds different kind of data. // It has better performance and is easier to use than `interface{}`. type Datum struct { - k int // datum kind. + k KindType // datum kind. i int64 // i can hold int64 uint64 float64 values. b []byte // b can hold string or []byte values. x interface{} // f hold all other types. } // Kind gets the kind of the datum. -func (d *Datum) Kind() int { +func (d *Datum) Kind() KindType { return d.k } From 2ef0c56cccbc98cf5cfa0f1b423e4c43788660a3 Mon Sep 17 00:00:00 2001 From: zhaoxingyu Date: Tue, 22 Mar 2016 16:12:07 +0800 Subject: [PATCH 2/2] Make some change. --- evaluator/builtin_string.go | 47 ++++--------------------------- util/types/datum.go | 55 +++---------------------------------- 2 files changed, 9 insertions(+), 93 deletions(-) diff --git a/evaluator/builtin_string.go b/evaluator/builtin_string.go index 382f5d8d1c..172bdc2b4b 100644 --- a/evaluator/builtin_string.go +++ b/evaluator/builtin_string.go @@ -35,12 +35,10 @@ import ( func builtinLength(args []types.Datum, _ context.Context) (d types.Datum, err error) { switch args[0].Kind() { case types.KindNull: - d.SetNull() return d, nil default: s, err := args[0].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } d.SetInt64(int64(len(s))) @@ -53,13 +51,11 @@ func builtinConcat(args []types.Datum, _ context.Context) (d types.Datum, err er var s []byte for _, a := range args { if a.Kind() == types.KindNull { - d.SetNull() return d, nil } var ss string ss, err = a.ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } s = append(s, []byte(ss)...) @@ -75,14 +71,12 @@ func builtinConcatWS(args []types.Datum, _ context.Context) (d types.Datum, err for i, a := range args { if a.Kind() == types.KindNull { if i == 0 { - d.SetNull() return d, nil } continue } ss, err := a.ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } @@ -101,12 +95,10 @@ func builtinConcatWS(args []types.Datum, _ context.Context) (d types.Datum, err func builtinLeft(args []types.Datum, _ context.Context) (d types.Datum, err error) { str, err := args[0].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } length, err := args[1].ToInt64() if err != nil { - d.SetNull() return d, errors.Trace(err) } l := int(length) @@ -123,7 +115,6 @@ func builtinLeft(args []types.Datum, _ context.Context) (d types.Datum, err erro func builtinRepeat(args []types.Datum, _ context.Context) (d types.Datum, err error) { str, err := args[0].ToString() if err != nil { - d.SetNull() return d, err } ch := fmt.Sprintf("%v", str) @@ -148,12 +139,10 @@ func builtinLower(args []types.Datum, _ context.Context) (d types.Datum, err err x := args[0] switch x.Kind() { case types.KindNull: - d.SetNull() return d, nil default: s, err := x.ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } d.SetString(strings.ToLower(s)) @@ -166,12 +155,10 @@ func builtinUpper(args []types.Datum, _ context.Context) (d types.Datum, err err x := args[0] switch x.Kind() { case types.KindNull: - d.SetNull() return d, nil default: s, err := x.ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } d.SetString(strings.ToUpper(s)) @@ -182,17 +169,14 @@ func builtinUpper(args []types.Datum, _ context.Context) (d types.Datum, err err // See: https://dev.mysql.com/doc/refman/5.7/en/string-comparison-functions.html func builtinStrcmp(args []types.Datum, _ context.Context) (d types.Datum, err error) { if args[0].Kind() == types.KindNull || args[1].Kind() == types.KindNull { - d.SetNull() return d, nil } left, err := args[0].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } right, err := args[1].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } res := types.CompareString(left, right) @@ -204,24 +188,20 @@ func builtinStrcmp(args []types.Datum, _ context.Context) (d types.Datum, err er func builtinReplace(args []types.Datum, _ context.Context) (d types.Datum, err error) { for _, arg := range args { if arg.Kind() == types.KindNull { - d.SetNull() return d, nil } } str, err := args[0].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } oldStr, err := args[1].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } newStr, err := args[2].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } d.SetString(strings.Replace(str, oldStr, newStr, -1)) @@ -233,7 +213,6 @@ func builtinReplace(args []types.Datum, _ context.Context) (d types.Datum, err e func builtinConvert(args []types.Datum, _ context.Context) (d types.Datum, err error) { // Casting nil to any type returns nil if args[0].Kind() != types.KindString { - d.SetNull() return d, nil } @@ -250,13 +229,11 @@ func builtinConvert(args []types.Datum, _ context.Context) (d types.Datum, err e encoding, _ := charset.Lookup(Charset) if encoding == nil { - d.SetNull() return d, errors.Errorf("unknown encoding: %s", Charset) } target, _, err := transform.String(encoding.NewDecoder(), str) if err != nil { - d.SetNull() log.Errorf("Convert %s to %s with error: %v", str, Charset, err) return d, errors.Trace(err) } @@ -271,21 +248,18 @@ func builtinSubstring(args []types.Datum, _ context.Context) (d types.Datum, err // arg[2] -> Len (Optional) str, err := args[0].ToString() if err != nil { - d.SetNull() - return d, errors.Errorf("Substring invalid args, need string but get %v", args[0].Kind()) + return d, errors.Errorf("Substring invalid args, need string but get %T", args[0].GetValue()) } if args[1].Kind() != types.KindInt64 { - d.SetNull() - return d, errors.Errorf("Substring invalid pos args, need int but get %v", args[1].Kind()) + return d, errors.Errorf("Substring invalid pos args, need int but get %T", args[1].GetValue()) } pos := args[1].GetInt64() length := int64(-1) if len(args) == 3 { if args[2].Kind() != types.KindInt64 { - d.SetNull() - return d, errors.Errorf("Substring invalid pos args, need int but get %v", args[2].Kind()) + return d, errors.Errorf("Substring invalid pos args, need int but get %T", args[2].GetValue()) } length = args[2].GetInt64() } @@ -321,14 +295,12 @@ func builtinSubstringIndex(args []types.Datum, _ context.Context) (d types.Datum // args[2] -> Count str, err := args[0].ToString() if err != nil { - d.SetNull() - return d, errors.Errorf("Substring_Index invalid args, need string but get %v", args[0].Kind()) + return d, errors.Errorf("Substring_Index invalid args, need string but get %T", args[0].GetValue()) } delim, err := args[1].ToString() if err != nil { - d.SetNull() - return d, errors.Errorf("Substring_Index invalid delim, need string but get %v", args[1].Kind()) + return d, errors.Errorf("Substring_Index invalid delim, need string but get %T", args[1].GetValue()) } if len(delim) == 0 { d.SetString("") @@ -337,7 +309,6 @@ func builtinSubstringIndex(args []types.Datum, _ context.Context) (d types.Datum c, err := args[2].ToInt64() if err != nil { - d.SetNull() return d, errors.Trace(err) } count := int(c) @@ -371,22 +342,18 @@ func builtinLocate(args []types.Datum, _ context.Context) (d types.Datum, err er // args[2] -> Pos // eval str if args[1].Kind() == types.KindNull { - d.SetNull() return d, nil } str, err := args[1].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } // eval substr if args[0].Kind() == types.KindNull { - d.SetNull() return d, nil } subStr, err := args[0].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } // eval pos @@ -394,7 +361,6 @@ func builtinLocate(args []types.Datum, _ context.Context) (d types.Datum, err er if len(args) == 3 { p, err := args[2].ToInt64() if err != nil { - d.SetNull() return d, errors.Trace(err) } pos = p - 1 @@ -429,12 +395,10 @@ func builtinTrim(args []types.Datum, _ context.Context) (d types.Datum, err erro // args[2] -> Direction // eval str if args[0].Kind() == types.KindNull { - d.SetNull() return d, nil } str, err := args[0].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } remstr := "" @@ -443,7 +407,6 @@ func builtinTrim(args []types.Datum, _ context.Context) (d types.Datum, err erro if args[1].Kind() != types.KindNull { remstr, err = args[1].ToString() if err != nil { - d.SetNull() return d, errors.Trace(err) } } diff --git a/util/types/datum.go b/util/types/datum.go index 7e3393b56c..ddb8a43253 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -24,13 +24,10 @@ import ( "github.com/pingcap/tidb/util/hack" ) -// KindType is a dummy type to avoid naming collision in context. -type KindType int - // Kind constants. const ( - KindNull KindType = 0 - KindInt64 KindType = iota + 1 + KindNull int = 0 + KindInt64 int = iota + 1 KindUint64 KindFloat32 KindFloat64 @@ -49,61 +46,17 @@ const ( KindMaxValue ) -// String defines a Stringer function for debugging and pretty printing. -func (k KindType) String() string { - switch k { - case KindNull: - return "KindNull" - case KindInt64: - return "KindInt64" - case KindUint64: - return "KindUint64" - case KindFloat32: - return "KindFloat32" - case KindFloat64: - return "KindFloat64" - case KindString: - return "KindString" - case KindBytes: - return "KindBytes" - case KindMysqlBit: - return "KindMysqlBit" - case KindMysqlDecimal: - return "KindMysqlDecimal" - case KindMysqlDuration: - return "KindMysqlDuration" - case KindMysqlEnum: - return "KindMysqlEnum" - case KindMysqlHex: - return "KindMysqlHex" - case KindMysqlSet: - return "KindMysqlSet" - case KindMysqlTime: - return "KindMysqlTime" - case KindRow: - return "KindRow" - case KindInterface: - return "KindInterface" - case KindMinNotNull: - return "KindMinNotNull" - case KindMaxValue: - return "KindMaxValue" - default: - return "" - } -} - // Datum is a data box holds different kind of data. // It has better performance and is easier to use than `interface{}`. type Datum struct { - k KindType // datum kind. + k int // datum kind. i int64 // i can hold int64 uint64 float64 values. b []byte // b can hold string or []byte values. x interface{} // f hold all other types. } // Kind gets the kind of the datum. -func (d *Datum) Kind() KindType { +func (d *Datum) Kind() int { return d.k }