From 81ffef96d78b0acd31f7bb7436b90d0eddbc9d39 Mon Sep 17 00:00:00 2001 From: qupeng Date: Wed, 13 Dec 2017 18:57:01 +0800 Subject: [PATCH] ddl: fix a bug when format generation expressions. (#5262) --- ast/expressions.go | 37 ++++++++++++++++++++++++++++- ast/format_test.go | 17 ++++++++++++- ast/functions.go | 31 ++++++++++++++++++++---- ddl/ddl_api.go | 10 +++++--- ddl/ddl_db_test.go | 10 +++++++- parser/parser.y | 15 ------------ util/stringutil/string_util.go | 15 ------------ util/stringutil/string_util_test.go | 14 ----------- 8 files changed, 94 insertions(+), 55 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index 0e9d1a85c9..4ae57cb141 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -17,10 +17,12 @@ import ( "fmt" "io" "regexp" + "strconv" "strings" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/types" ) @@ -60,7 +62,40 @@ type ValueExpr struct { // Format the ExprNode into a Writer. func (n *ValueExpr) Format(w io.Writer) { - fmt.Fprint(w, n.Text()) + var s string + switch n.Kind() { + case types.KindNull: + s = "NULL" + case types.KindInt64: + if n.Type.Flag&mysql.IsBooleanFlag != 0 { + if n.GetInt64() > 0 { + s = "TRUE" + } else { + s = "FALSE" + } + } else { + s = strconv.FormatInt(n.GetInt64(), 10) + } + case types.KindUint64: + s = strconv.FormatUint(n.GetUint64(), 10) + case types.KindFloat32: + s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 32) + case types.KindFloat64: + s = strconv.FormatFloat(n.GetFloat64(), 'e', -1, 64) + case types.KindString, types.KindBytes: + s = strconv.Quote(n.GetString()) + case types.KindMysqlDecimal: + s = n.GetMysqlDecimal().String() + case types.KindBinaryLiteral: + if n.Type.Flag&mysql.UnsignedFlag != 0 { + s = fmt.Sprintf("x'%x'", n.GetBytes()) + } else { + s = n.GetBinaryLiteral().ToBitLiteralString(true) + } + default: + panic("Can't format to string") + } + fmt.Fprint(w, s) } // NewValueExpr creates a ValueExpr with value, and sets default field type. diff --git a/ast/format_test.go b/ast/format_test.go index 3c14ba7ff8..0f4b3f96c2 100644 --- a/ast/format_test.go +++ b/ast/format_test.go @@ -24,13 +24,24 @@ func (ts *testAstFormatSuite) TestAstFormat(c *C) { output string }{ // Literals. + {`null`, `NULL`}, + {`true`, `TRUE`}, {`350`, `350`}, + {`001e-12`, `1e-12`}, // Float. {`345.678`, `345.678`}, - {`1e-12`, `1e-12`}, + {`00.0001000`, `0.0001000`}, // Decimal. {`null`, `NULL`}, {`"Hello, world"`, `"Hello, world"`}, {`'Hello, world'`, `"Hello, world"`}, {`'Hello, "world"'`, `"Hello, \"world\""`}, + {`_utf8'你好'`, `"你好"`}, + {`x'bcde'`, "x'bcde'"}, + {`x''`, "x''"}, + {`x'0035'`, "x'0035'"}, // Shouldn't trim leading zero. + {`b'00111111'`, `b'111111'`}, + {`time'10:10:10.123'`, `timeliteral("10:10:10.123")`}, + {`timestamp'1999-01-01 10:0:0.123'`, `timestampliteral("1999-01-01 10:0:0.123")`}, + {`date '1700-01-01'`, `dateliteral("1700-01-01")`}, // Expressions. {`f between 30 and 50`, "`f` BETWEEN 30 AND 50"}, @@ -54,6 +65,10 @@ func (ts *testAstFormatSuite) TestAstFormat(c *C) { // Functions. {` json_extract ( a,'$.b',"$.\"c d\"" ) `, "json_extract(`a`, \"$.b\", \"$.\\\"c d\\\"\")"}, {` length ( a )`, "length(`a`)"}, + {`a -> '$.a'`, "json_extract(`a`, \"$.a\")"}, + {`a.b ->> '$.a'`, "json_unquote(json_extract(`a`.`b`, \"$.a\"))"}, + {`DATE_ADD('1970-01-01', interval 3 second)`, `date_add("1970-01-01", INTERVAL 3 SECOND)`}, + {`TIMESTAMPDIFF(month, '2001-01-01', '2001-02-02 12:03:05.123')`, `timestampdiff(MONTH, "2001-01-01", "2001-02-02 12:03:05.123")`}, // Cast, Convert and Binary. // There should not be spaces between 'cast' and '(' unless 'IGNORE_SPACE' mode is set. // see: https://dev.mysql.com/doc/refman/5.7/en/function-resolution.html diff --git a/ast/functions.go b/ast/functions.go index b6b22c958d..e51a67f503 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -314,16 +314,37 @@ type FuncCallExpr struct { // Format the ExprNode into a Writer. func (n *FuncCallExpr) Format(w io.Writer) { - fmt.Fprintf(w, "%s(", n.FnName.String()) - for i, arg := range n.Args { - arg.Format(w) - if i != len(n.Args)-1 { - fmt.Fprintf(w, ", ") + fmt.Fprintf(w, "%s(", n.FnName.L) + if !n.specialFormatArgs(w) { + for i, arg := range n.Args { + arg.Format(w) + if i != len(n.Args)-1 { + fmt.Fprintf(w, ", ") + } } } fmt.Fprintf(w, ")") } +// specialFormatArgs formats argument list for some special functions. +func (n *FuncCallExpr) specialFormatArgs(w io.Writer) bool { + switch n.FnName.L { + case DateAdd, DateSub, AddDate, SubDate: + n.Args[0].Format(w) + fmt.Fprintf(w, ", INTERVAL ") + n.Args[1].Format(w) + fmt.Fprintf(w, " %s", n.Args[2].GetDatum().GetString()) + return true + case TimestampAdd, TimestampDiff: + fmt.Fprintf(w, "%s, ", n.Args[0].GetDatum().GetString()) + n.Args[1].Format(w) + fmt.Fprintf(w, ", ") + n.Args[2].Format(w) + return true + } + return false +} + // Accept implements Node interface. func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { newNode, skipChildren := v.Enter(n) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 3920ee00e5..7749d64f22 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -18,6 +18,7 @@ package ddl import ( + "bytes" "fmt" "strings" "time" @@ -34,7 +35,6 @@ import ( "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/charset" - "github.com/pingcap/tidb/util/stringutil" ) func (d *ddl) CreateSchema(ctx context.Context, schema model.CIStr, charsetInfo *ast.CharsetOpt) (err error) { @@ -318,7 +318,9 @@ func columnDefToCol(ctx context.Context, offset int, colDef *ast.ColumnDef) (*ta return nil, nil, errors.Trace(err) } case ast.ColumnOptionGenerated: - col.GeneratedExprString = stringutil.RemoveBlanks(v.Expr.Text()) + var buf = bytes.NewBuffer([]byte{}) + v.Expr.Format(buf) + col.GeneratedExprString = buf.String() col.GeneratedStored = v.Stored _, dependColNames := findDependedColumnNames(colDef) col.Dependences = dependColNames @@ -1135,7 +1137,9 @@ func setDefaultAndComment(ctx context.Context, col *table.Column, options []*ast col.Flag |= mysql.OnUpdateNowFlag setOnUpdateNow = true case ast.ColumnOptionGenerated: - col.GeneratedExprString = stringutil.RemoveBlanks(opt.Expr.Text()) + var buf = bytes.NewBuffer([]byte{}) + opt.Expr.Format(buf) + col.GeneratedExprString = buf.String() col.GeneratedStored = opt.Stored col.Dependences = make(map[string]struct{}) for _, colName := range findColumnNamesInExpr(opt.Expr) { diff --git a/ddl/ddl_db_test.go b/ddl/ddl_db_test.go index b86c0dee5f..5277eba873 100644 --- a/ddl/ddl_db_test.go +++ b/ddl/ddl_db_test.go @@ -1586,7 +1586,7 @@ func (s *testDBSuite) TestGeneratedColumnDDL(c *C) { // Check show create table with virtual generated column. result = s.tk.MustQuery(`show create table test_gv_ddl`) result.Check(testkit.Rows( - "test_gv_ddl CREATE TABLE `test_gv_ddl` (\n `a` int(11) DEFAULT NULL,\n `b` int(11) GENERATED ALWAYS AS (a+8) VIRTUAL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin", + "test_gv_ddl CREATE TABLE `test_gv_ddl` (\n `a` int(11) DEFAULT NULL,\n `b` int(11) GENERATED ALWAYS AS (`a` + 8) VIRTUAL DEFAULT NULL\n) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin", )) // Check alter table add a stored generated column. @@ -1594,6 +1594,14 @@ func (s *testDBSuite) TestGeneratedColumnDDL(c *C) { result = s.tk.MustQuery(`DESC test_gv_ddl`) result.Check(testkit.Rows(`a int(11) YES `, `b int(11) YES VIRTUAL GENERATED`, `c int(11) YES STORED GENERATED`)) + // Check generated expression with blanks. + s.tk.MustExec("create table table_with_gen_col_blanks (a int, b char(20) as (cast( \r\n\t a \r\n\tas char)))") + result = s.tk.MustQuery(`show create table table_with_gen_col_blanks`) + result.Check(testkit.Rows("table_with_gen_col_blanks CREATE TABLE `table_with_gen_col_blanks` (\n" + + " `a` int(11) DEFAULT NULL,\n" + + " `b` char(20) GENERATED ALWAYS AS (CAST(`a` AS CHAR)) VIRTUAL DEFAULT NULL\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin")) + genExprTests := []struct { stmt string err int diff --git a/parser/parser.y b/parser/parser.y index 2fb10eeec8..50c4a32183 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -27,7 +27,6 @@ package parser import ( "strings" - "strconv" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/ast" @@ -2684,32 +2683,26 @@ Literal: "FALSE" { $$ = ast.NewValueExpr(false) - $$.SetText("FALSE") } | "NULL" { $$ = ast.NewValueExpr(nil) - $$.SetText("NULL") } | "TRUE" { $$ = ast.NewValueExpr(true) - $$.SetText("TRUE") } | floatLit { $$ = ast.NewValueExpr($1) - $$.SetText(yyS[yypt].ident) } | decLit { $$ = ast.NewValueExpr($1) - $$.SetText(yyS[yypt].ident) } | intLit { $$ = ast.NewValueExpr($1) - $$.SetText(yyS[yypt].ident) } | StringLiteral %prec lowerThanStringLitToken { @@ -2731,18 +2724,14 @@ Literal: tp.Flag |= mysql.BinaryFlag } $$ = expr - // Because `Lexer` removes quotation marks, we add them back. - $$.SetText(strconv.Quote($2)) } | hexLit { $$ = ast.NewValueExpr($1) - $$.SetText(yyS[yypt].ident) } | bitLit { $$ = ast.NewValueExpr($1) - $$.SetText(yyS[yypt].ident) } StringLiteral: @@ -2750,8 +2739,6 @@ StringLiteral: { expr := ast.NewValueExpr($1) $$ = expr - // Because `Lexer` removes quotation marks, we add them back. - $$.SetText(strconv.Quote($1)) } | StringLiteral stringLit { @@ -2765,8 +2752,6 @@ StringLiteral: expr.SetProjectionOffset(len(strLit)) } $$ = expr - // Because `Lexer` removes quotation marks, we add them back. - $$.SetText(strconv.Quote(strLit + $2)) } diff --git a/util/stringutil/string_util.go b/util/stringutil/string_util.go index 9920a4e5bf..1cd4972666 100644 --- a/util/stringutil/string_util.go +++ b/util/stringutil/string_util.go @@ -14,7 +14,6 @@ package stringutil import ( - "bytes" "strings" "unicode/utf8" @@ -235,17 +234,3 @@ func DoMatch(str string, patChars, patTypes []byte) bool { } return sIdx == len(str) } - -// RemoveBlanks removes all blanks, returns a new string. -func RemoveBlanks(s string) string { - var buf = new(bytes.Buffer) - var cbuf [6]byte - for _, c := range s { - if c == rune(' ') || c == rune('\t') || c == rune('\r') || c == rune('\n') { - continue - } - len := utf8.EncodeRune(cbuf[0:], c) - buf.Write(cbuf[0:len]) - } - return buf.String() -} diff --git a/util/stringutil/string_util_test.go b/util/stringutil/string_util_test.go index 6668cb40df..526a19e24e 100644 --- a/util/stringutil/string_util_test.go +++ b/util/stringutil/string_util_test.go @@ -124,17 +124,3 @@ func (s *testStringUtilSuite) TestPatternMatch(c *C) { c.Assert(match, Equals, v.match, Commentf("%v", v)) } } - -func (s *testStringUtilSuite) TestRemoveBlanks(c *C) { - defer testleak.AfterTest(c)() - tests := []struct { - input string - output string - }{ - {"a\nb\rc d\te", "abcde"}, - {"hello, 世界\npeace", "hello,世界peace"}, - } - for _, tt := range tests { - c.Assert(RemoveBlanks(tt.input), Equals, tt.output) - } -}