diff --git a/parser/ast/expressions.go b/parser/ast/expressions.go index b1ea841051..67b7383e57 100755 --- a/parser/ast/expressions.go +++ b/parser/ast/expressions.go @@ -158,9 +158,15 @@ func (n *BinaryOperationExpr) Restore(ctx *RestoreCtx) error { if err := n.L.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred when restore BinaryOperationExpr.L") } + if ctx.Flags.HasSpacesAroundBinaryOperationFlag() { + ctx.WritePlain(" ") + } if err := n.Op.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred when restore BinaryOperationExpr.Op") } + if ctx.Flags.HasSpacesAroundBinaryOperationFlag() { + ctx.WritePlain(" ") + } if err := n.R.Restore(ctx); err != nil { return errors.Annotate(err, "An error occurred when restore BinaryOperationExpr.R") } diff --git a/parser/ast/expressions_test.go b/parser/ast/expressions_test.go index 911663c8e3..02f688dea7 100644 --- a/parser/ast/expressions_test.go +++ b/parser/ast/expressions_test.go @@ -16,6 +16,7 @@ package ast_test import ( . "github.com/pingcap/check" . "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/format" _ "github.com/pingcap/tidb/types/parser_driver" ) @@ -200,6 +201,24 @@ func (tc *testExpressionsSuite) TestBinaryOperationExpr(c *C) { RunNodeRestoreTest(c, testCases, "select %s", extractNodeFunc) } +func (tc *testExpressionsSuite) TestBinaryOperationExprWithFlags(c *C) { + testCases := []NodeRestoreTestCase{ + {"'a'!=1", "'a' != 1"}, + {"a!=1", "`a` != 1"}, + {"3<5", "3 < 5"}, + {"10>5", "10 > 5"}, + {"3+5", "3 + 5"}, + {"3-5", "3 - 5"}, + {"a<>5", "`a` != 5"}, + {"a=1", "`a` = 1"}, + } + extractNodeFunc := func(node Node) Node { + return node.(*SelectStmt).Fields.Fields[0].Expr + } + flags := format.DefaultRestoreFlags | format.RestoreSpacesAroundBinaryOperation + RunNodeRestoreTestWithFlags(c, testCases, "select %s", extractNodeFunc, flags) +} + func (tc *testExpressionsSuite) TestParenthesesExpr(c *C) { testCases := []NodeRestoreTestCase{ {"(1+2)*3", "(1+2)*3"}, diff --git a/parser/ast/functions.go b/parser/ast/functions.go index 5d931f4dff..6b0011bbcd 100755 --- a/parser/ast/functions.go +++ b/parser/ast/functions.go @@ -512,7 +512,7 @@ func (n *FuncCastExpr) Restore(ctx *RestoreCtx) error { return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr") } ctx.WriteKeyWord(" AS ") - n.Tp.FormatAsCastType(ctx.In) + n.Tp.RestoreAsCastType(ctx) ctx.WritePlain(")") case CastConvertFunction: ctx.WriteKeyWord("CONVERT") @@ -521,7 +521,7 @@ func (n *FuncCastExpr) Restore(ctx *RestoreCtx) error { return errors.Annotatef(err, "An error occurred while restore FuncCastExpr.Expr") } ctx.WritePlain(", ") - n.Tp.FormatAsCastType(ctx.In) + n.Tp.RestoreAsCastType(ctx) ctx.WritePlain(")") case CastBinaryOperator: ctx.WriteKeyWord("BINARY ") diff --git a/parser/ast/functions_test.go b/parser/ast/functions_test.go index 64c5383068..c10bd58f90 100644 --- a/parser/ast/functions_test.go +++ b/parser/ast/functions_test.go @@ -84,8 +84,8 @@ func (ts *testFunctionsSuite) TestFuncCallExprRestore(c *C) { func (ts *testFunctionsSuite) TestFuncCastExprRestore(c *C) { testCases := []NodeRestoreTestCase{ {"CONVERT('Müller' USING UtF8Mb4)", "CONVERT('Müller' USING UTF8MB4)"}, - {"CONVERT('Müller', CHAR(32) CHARACTER SET UtF8)", "CONVERT('Müller', CHAR(32) CHARACTER SET UtF8)"}, - {"CAST('test' AS CHAR CHARACTER SET UtF8)", "CAST('test' AS CHAR CHARACTER SET UtF8)"}, + {"CONVERT('Müller', CHAR(32) CHARACTER SET UtF8)", "CONVERT('Müller', CHAR(32) CHARSET UTF8)"}, + {"CAST('test' AS CHAR CHARACTER SET UtF8)", "CAST('test' AS CHAR CHARSET UTF8)"}, {"BINARY 'New York'", "BINARY 'New York'"}, } extractNodeFunc := func(node Node) Node { @@ -132,7 +132,7 @@ func (ts *testFunctionsSuite) TestConvert(c *C) { ErrorMessage string }{ {`SELECT CONVERT("abc" USING "latin1")`, "latin1", ""}, - {`SELECT CONVERT("abc" USING laTiN1)`, "laTiN1", ""}, + {`SELECT CONVERT("abc" USING laTiN1)`, "latin1", ""}, {`SELECT CONVERT("abc" USING "binary")`, "binary", ""}, {`SELECT CONVERT("abc" USING biNaRy)`, "binary", ""}, {`SELECT CONVERT(a USING a)`, "", `[parser:1115]Unknown character set: 'a'`}, // TiDB issue #4436. @@ -161,7 +161,7 @@ func (ts *testFunctionsSuite) TestChar(c *C) { ErrorMessage string }{ {`SELECT CHAR("abc" USING "latin1")`, "latin1", ""}, - {`SELECT CHAR("abc" USING laTiN1)`, "laTiN1", ""}, + {`SELECT CHAR("abc" USING laTiN1)`, "latin1", ""}, {`SELECT CHAR("abc" USING "binary")`, "binary", ""}, {`SELECT CHAR("abc" USING binary)`, "binary", ""}, {`SELECT CHAR(a USING a)`, "", `[parser:1115]Unknown character set: 'a'`}, diff --git a/parser/ast/util_test.go b/parser/ast/util_test.go index a390413715..756cd9418e 100755 --- a/parser/ast/util_test.go +++ b/parser/ast/util_test.go @@ -103,6 +103,10 @@ type NodeRestoreTestCase struct { } func RunNodeRestoreTest(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) { + RunNodeRestoreTestWithFlags(c, nodeTestCases, template, extractNodeFunc, DefaultRestoreFlags) +} + +func RunNodeRestoreTestWithFlags(c *C, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node, flags RestoreFlags) { parser := parser.New() parser.EnableWindowFunc(true) for _, testCase := range nodeTestCases { @@ -112,7 +116,7 @@ func RunNodeRestoreTest(c *C, nodeTestCases []NodeRestoreTestCase, template stri comment := Commentf("source %#v", testCase) c.Assert(err, IsNil, comment) var sb strings.Builder - err = extractNodeFunc(stmt).Restore(NewRestoreCtx(DefaultRestoreFlags, &sb)) + err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb)) c.Assert(err, IsNil, comment) restoreSql := fmt.Sprintf(template, sb.String()) comment = Commentf("source %#v; restore %v", testCase, restoreSql) diff --git a/parser/format/format.go b/parser/format/format.go index ad6c800252..214d8309b7 100755 --- a/parser/format/format.go +++ b/parser/format/format.go @@ -216,6 +216,8 @@ const ( RestoreNameLowercase RestoreNameDoubleQuotes RestoreNameBackQuotes + + RestoreSpacesAroundBinaryOperation ) const ( @@ -271,6 +273,11 @@ func (rf RestoreFlags) HasNameBackQuotesFlag() bool { return rf.has(RestoreNameBackQuotes) } +// HasSpacesAroundBinaryOperationFlag returns a boolean indicating whether `rf` has `RestoreSpacesAroundBinaryOperation` flag. +func (rf RestoreFlags) HasSpacesAroundBinaryOperationFlag() bool { + return rf.has(RestoreSpacesAroundBinaryOperation) +} + // RestoreCtx is `Restore` context to hold flags and writer. type RestoreCtx struct { Flags RestoreFlags diff --git a/parser/parser.go b/parser/parser.go index 2bcdcecae3..309e23316e 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -11568,14 +11568,14 @@ yynewstate: case 1088: { // Validate input charset name to keep the same behavior as parser of MySQL. - _, _, err := charset.GetCharsetInfo(yyS[yypt-0].item.(string)) + name, _, err := charset.GetCharsetInfo(yyS[yypt-0].item.(string)) if err != nil { yylex.AppendError(ErrUnknownCharacterSet.GenWithStackByArgs(yyS[yypt-0].item)) return 1 } - // Use $1 instead of charset name returned from charset.GetCharsetInfo(), - // to keep upper-lower case of input for restore. - parser.yyVAL.item = yyS[yypt-0].item + // Use charset name returned from charset.GetCharsetInfo(), + // to keep lower case of input for generated column restore. + parser.yyVAL.item = name } case 1089: { diff --git a/parser/parser.y b/parser/parser.y index a7ce646e2e..c5a389e8ba 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -5868,14 +5868,14 @@ CharsetName: StringName { // Validate input charset name to keep the same behavior as parser of MySQL. - _, _, err := charset.GetCharsetInfo($1.(string)) + name, _, err := charset.GetCharsetInfo($1.(string)) if err != nil { yylex.AppendError(ErrUnknownCharacterSet.GenWithStackByArgs($1)) return 1 } - // Use $1 instead of charset name returned from charset.GetCharsetInfo(), - // to keep upper-lower case of input for restore. - $$ = $1 + // Use charset name returned from charset.GetCharsetInfo(), + // to keep lower case of input for generated column restore. + $$ = name } | binaryType { diff --git a/parser/parser_test.go b/parser/parser_test.go index da81de5f91..b7772e881d 100755 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1003,7 +1003,7 @@ func (s *testParserSuite) TestBuiltin(c *C) { {"create table t (row int)", false, ""}, // for cast with charset - {"SELECT *, CAST(data AS CHAR CHARACTER SET utf8) FROM t;", true, "SELECT *,CAST(`data` AS CHAR CHARACTER SET utf8) FROM `t`"}, + {"SELECT *, CAST(data AS CHAR CHARACTER SET utf8) FROM t;", true, "SELECT *,CAST(`data` AS CHAR CHARSET UTF8) FROM `t`"}, // for cast as JSON {"SELECT *, CAST(data AS JSON) FROM t;", true, "SELECT *,CAST(`data` AS JSON) FROM `t`"}, @@ -1198,7 +1198,7 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`select from_unixtime(1447430881.1234567, "%Y %D %M %h:%i:%s %x")`, true, "SELECT FROM_UNIXTIME(1447430881.1234567, '%Y %D %M %h:%i:%s %x')"}, // for issue 224 - {`SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8) COLLATE utf8_bin;`, true, "SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8)"}, + {`SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8) COLLATE utf8_bin;`, true, "SELECT CAST('test collated returns' AS CHAR CHARSET UTF8)"}, // for string functions // trim diff --git a/parser/types/field_type.go b/parser/types/field_type.go index 83eaec3eef..0b7e5ddada 100644 --- a/parser/types/field_type.go +++ b/parser/types/field_type.go @@ -252,54 +252,63 @@ func (ft *FieldType) Restore(ctx *format.RestoreCtx) error { return nil } -// FormatAsCastType is used for write AST back to string. -func (ft *FieldType) FormatAsCastType(w io.Writer) { +// RestoreAsCastType is used for write AST back to string. +func (ft *FieldType) RestoreAsCastType(ctx *format.RestoreCtx) { switch ft.Tp { case mysql.TypeVarString: if ft.Charset == charset.CharsetBin && ft.Collate == charset.CollationBin { - fmt.Fprint(w, "BINARY") + ctx.WriteKeyWord("BINARY") } else { - fmt.Fprint(w, "CHAR") + ctx.WriteKeyWord("CHAR") } if ft.Flen != UnspecifiedLength { - fmt.Fprintf(w, "(%d)", ft.Flen) + ctx.WritePlainf("(%d)", ft.Flen) } if ft.Flag&mysql.BinaryFlag != 0 { - fmt.Fprint(w, " BINARY") + ctx.WriteKeyWord(" BINARY") } if ft.Charset != charset.CharsetBin && ft.Charset != mysql.DefaultCharset { - fmt.Fprintf(w, " CHARACTER SET %s", ft.Charset) + ctx.WriteKeyWord(" CHARSET ") + ctx.WriteKeyWord(ft.Charset) } case mysql.TypeDate: - fmt.Fprint(w, "DATE") + ctx.WriteKeyWord("DATE") case mysql.TypeDatetime: - fmt.Fprint(w, "DATETIME") + ctx.WriteKeyWord("DATETIME") if ft.Decimal > 0 { - fmt.Fprintf(w, "(%d)", ft.Decimal) + ctx.WritePlainf("(%d)", ft.Decimal) } case mysql.TypeNewDecimal: - fmt.Fprint(w, "DECIMAL") + ctx.WriteKeyWord("DECIMAL") if ft.Flen > 0 && ft.Decimal > 0 { - fmt.Fprintf(w, "(%d, %d)", ft.Flen, ft.Decimal) + ctx.WritePlainf("(%d, %d)", ft.Flen, ft.Decimal) } else if ft.Flen > 0 { - fmt.Fprintf(w, "(%d)", ft.Flen) + ctx.WritePlainf("(%d)", ft.Flen) } case mysql.TypeDuration: - fmt.Fprint(w, "TIME") + ctx.WriteKeyWord("TIME") if ft.Decimal > 0 { - fmt.Fprintf(w, "(%d)", ft.Decimal) + ctx.WritePlainf("(%d)", ft.Decimal) } case mysql.TypeLonglong: if ft.Flag&mysql.UnsignedFlag != 0 { - fmt.Fprint(w, "UNSIGNED") + ctx.WriteKeyWord("UNSIGNED") } else { - fmt.Fprint(w, "SIGNED") + ctx.WriteKeyWord("SIGNED") } case mysql.TypeJSON: - fmt.Fprint(w, "JSON") + ctx.WriteKeyWord("JSON") } } +// FormatAsCastType is used for write AST back to string. +func (ft *FieldType) FormatAsCastType(w io.Writer) { + var sb strings.Builder + restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, &sb) + ft.RestoreAsCastType(restoreCtx) + fmt.Fprint(w, sb.String()) +} + // VarStorageLen indicates this column is a variable length column. const VarStorageLen = -1