From b5ef73dc57cdd34a4bf0c99a9708c1938361e728 Mon Sep 17 00:00:00 2001 From: arnkore Date: Fri, 25 Jan 2019 13:57:47 +0800 Subject: [PATCH] [parser] parser: implement restore for UpdateStmt (#190) * implement restore for UpdateStmt * add test case for UpdateStmt's restore functionality * format code * 1. put comma check at the start of loop; 2. check the errors for the restore of UpdateStmt.List[%d].Column and UpdateStmt.List[%d].Column.Expr; * clean origin HintName of UpdateStmt.TableHInts --- parser/ast/dml.go | 64 ++++++++++++++++++++++++++++++++++++++++++- parser/parser_test.go | 26 ++++++++++++++---- 2 files changed, 84 insertions(+), 6 deletions(-) diff --git a/parser/ast/dml.go b/parser/ast/dml.go index dbe794a200..da7c6774c6 100755 --- a/parser/ast/dml.go +++ b/parser/ast/dml.go @@ -1473,7 +1473,69 @@ type UpdateStmt struct { // Restore implements Node interface. func (n *UpdateStmt) Restore(ctx *RestoreCtx) error { - return errors.New("Not implemented") + ctx.WriteKeyWord("UPDATE ") + + if n.TableHints != nil && len(n.TableHints) != 0 { + ctx.WritePlain("/*+ ") + for i, tableHint := range n.TableHints { + if err := tableHint.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occurred while restore UpdateStmt.TableHints[%d]", i) + } + } + ctx.WritePlain("*/ ") + } + + switch n.Priority { + case mysql.LowPriority: + ctx.WriteKeyWord("LOW_PRIORITY ") + } + if n.IgnoreErr { + ctx.WriteKeyWord("IGNORE ") + } + + if err := n.TableRefs.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occur while restore UpdateStmt.TableRefs") + } + + ctx.WriteKeyWord(" SET ") + for i, assignment := range n.List { + if i != 0 { + ctx.WritePlain(", ") + } + + if err := assignment.Column.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occur while restore UpdateStmt.List[%d].Column", i) + } + + ctx.WritePlain("=") + + if err := assignment.Expr.Restore(ctx); err != nil { + return errors.Annotatef(err, "An error occur while restore UpdateStmt.List[%d].Expr", i) + } + } + + if n.Where != nil { + ctx.WriteKeyWord(" WHERE ") + if err := n.Where.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occur while restore UpdateStmt.Where") + } + } + + if n.Order != nil { + ctx.WritePlain(" ") + if err := n.Order.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occur while restore UpdateStmt.Order") + } + } + + if n.Limit != nil { + ctx.WritePlain(" ") + if err := n.Limit.Restore(ctx); err != nil { + return errors.Annotate(err, "An error occur while restore UpdateStmt.Limit") + } + } + + return nil } // Accept implements Node Accept interface. diff --git a/parser/parser_test.go b/parser/parser_test.go index 08911e3e2b..30414d4555 100755 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -509,12 +509,24 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {"DELETE /*+ TiDB_HJ(t1, t2) */ t1, t2 from t1, t2 where t1.id=t2.id", true, ""}, // for update statement - {"UPDATE t SET id = id + 1 ORDER BY id DESC;", true, ""}, - {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id;", true, ""}, + {"UPDATE LOW_PRIORITY IGNORE t SET id = id + 1 ORDER BY id DESC;", true, "UPDATE LOW_PRIORITY IGNORE `t` SET `id`=`id`+1 ORDER BY `id` DESC"}, + {"UPDATE t SET id = id + 1 ORDER BY id DESC;", true, "UPDATE `t` SET `id`=`id`+1 ORDER BY `id` DESC"}, + {"UPDATE t SET id = id + 1 ORDER BY id DESC limit 3 ;", true, "UPDATE `t` SET `id`=`id`+1 ORDER BY `id` DESC LIMIT 3"}, + {"UPDATE t SET id = id + 1, name = 'jojo';", true, "UPDATE `t` SET `id`=`id`+1, `name`='jojo'"}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id;", true, "UPDATE (`items`) JOIN `month` SET `items`.`price`=`month`.`price` WHERE `items`.`id`=`month`.`id`"}, + {"UPDATE user T0 LEFT OUTER JOIN user_profile T1 ON T1.id = T0.profile_id SET T0.profile_id = 1 WHERE T0.profile_id IN (1);", true, "UPDATE `user` AS `T0` LEFT JOIN `user_profile` AS `T1` ON `T1`.`id`=`T0`.`profile_id` SET `T0`.`profile_id`=1 WHERE `T0`.`profile_id` IN (1)"}, + {"UPDATE t1, t2 set t1.profile_id = 1, t2.profile_id = 1 where ta.a=t.ba", true, "UPDATE (`t1`) JOIN `t2` SET `t1`.`profile_id`=1, `t2`.`profile_id`=1 WHERE `ta`.`a`=`t`.`ba`"}, + // for optimizer hint in update statement + {"UPDATE /*+ TiDB_INLJ(t1, t2) */ t1, t2 set t1.profile_id = 1, t2.profile_id = 1 where ta.a=t.ba", true, "UPDATE /*+ TIDB_INLJ(`t1`, `t2`)*/ (`t1`) JOIN `t2` SET `t1`.`profile_id`=1, `t2`.`profile_id`=1 WHERE `ta`.`a`=`t`.`ba`"}, + {"UPDATE /*+ TiDB_SMJ(t1, t2) */ t1, t2 set t1.profile_id = 1, t2.profile_id = 1 where ta.a=t.ba", true, "UPDATE /*+ TIDB_SMJ(`t1`, `t2`)*/ (`t1`) JOIN `t2` SET `t1`.`profile_id`=1, `t2`.`profile_id`=1 WHERE `ta`.`a`=`t`.`ba`"}, + {"UPDATE /*+ TiDB_HJ(t1, t2) */ t1, t2 set t1.profile_id = 1, t2.profile_id = 1 where ta.a=t.ba", true, "UPDATE /*+ TIDB_HJ(`t1`, `t2`)*/ (`t1`) JOIN `t2` SET `t1`.`profile_id`=1, `t2`.`profile_id`=1 WHERE `ta`.`a`=`t`.`ba`"}, + // fail case for update statement {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id LIMIT 10;", false, ""}, - {"UPDATE user T0 LEFT OUTER JOIN user_profile T1 ON T1.id = T0.profile_id SET T0.profile_id = 1 WHERE T0.profile_id IN (1);", true, ""}, - {"UPDATE /*+ TiDB_INLJ(t1, t2) */ t1, t2 set t1.profile_id = 1, t2.profile_id = 1 where ta.a=t.ba", true, ""}, - {"UPDATE /*+ TiDB_SMJ(t1, t2) */ t1, t2 set t1.profile_id = 1, t2.profile_id = 1 where ta.a=t.ba", true, ""}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id order by month.id;", false, ""}, + // for "USE INDEX" in delete statement + {"UPDATE t1 USE INDEX(idx_a) SET t1.price=3.25 WHERE t1.id=1;", true, "UPDATE `t1` USE INDEX (`idx_a`) SET `t1`.`price`=3.25 WHERE `t1`.`id`=1"}, + {"UPDATE t1 USE INDEX(idx_a) JOIN t2 SET t1.price=t2.price WHERE t1.id=t2.id;", true, "UPDATE `t1` USE INDEX (`idx_a`) JOIN `t2` SET `t1`.`price`=`t2`.`price` WHERE `t1`.`id`=`t2`.`id`"}, + {"UPDATE t1 USE INDEX(idx_a) JOIN t2 USE INDEX(idx_a) SET t1.price=t2.price WHERE t1.id=t2.id;", true, "UPDATE `t1` USE INDEX (`idx_a`) JOIN `t2` USE INDEX (`idx_a`) SET `t1`.`price`=`t2`.`price` WHERE `t1`.`id`=`t2`.`id`"}, // for select with where clause {"SELECT * FROM t WHERE 1 = 1", true, "SELECT * FROM `t` WHERE 1=1"}, @@ -2981,6 +2993,10 @@ func (checker *nodeTextCleaner) Enter(in ast.Node) (out ast.Node, skipChildren b } } } + case *ast.UpdateStmt: + for _, tableHint := range node.TableHints { + tableHint.HintName.O = "" + } case *ast.Constraint: if node.Option != nil { if node.Option.KeyBlockSize == 0x0 && node.Option.Tp == 0 && node.Option.Comment == "" {