update: fix the updatable table name resolution in build update list (#30061)
This commit is contained in:
@ -235,6 +235,28 @@ func (s *testIntegrationSuite) TestIssue24571(c *C) {
|
||||
tk.MustExec("update (select 1 as a) as t, test.t set test.t.a=1;")
|
||||
}
|
||||
|
||||
func (s *testIntegrationSuite) TestBuildUpdateListResolver(c *C) {
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
|
||||
// For issue https://github.com/pingcap/tidb/issues/24567
|
||||
tk.MustExec("drop table if exists t")
|
||||
tk.MustExec("drop table if exists t1")
|
||||
tk.MustExec("create table t(a int)")
|
||||
tk.MustExec("create table t1(b int)")
|
||||
tk.MustGetErrCode("update (select 1 as a) as t set a=1", mysql.ErrNonUpdatableTable)
|
||||
tk.MustGetErrCode("update (select 1 as a) as t, t1 set a=1", mysql.ErrNonUpdatableTable)
|
||||
tk.MustExec("drop table if exists t")
|
||||
tk.MustExec("drop table if exists t1")
|
||||
|
||||
// For issue https://github.com/pingcap/tidb/issues/30031
|
||||
tk.MustExec("create table t(a int default -1, c int as (a+10) stored)")
|
||||
tk.MustExec("insert into t(a) values(1)")
|
||||
tk.MustExec("update test.t, (select 1 as b) as t set test.t.a=default")
|
||||
tk.MustQuery("select * from t").Check(testkit.Rows("-1 9"))
|
||||
tk.MustExec("drop table if exists t")
|
||||
}
|
||||
|
||||
func (s *testIntegrationSuite) TestIssue22828(c *C) {
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
|
||||
@ -4731,13 +4731,9 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
|
||||
proj.SetChildren(p)
|
||||
p = proj
|
||||
|
||||
// update subquery table should be forbidden
|
||||
var notUpdatableTbl []string
|
||||
notUpdatableTbl = extractTableSourceAsNames(update.TableRefs.TableRefs, notUpdatableTbl, true)
|
||||
|
||||
var updateTableList []*ast.TableName
|
||||
updateTableList = extractTableList(update.TableRefs.TableRefs, updateTableList, true)
|
||||
orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, updateTableList, update.List, p, notUpdatableTbl)
|
||||
utlr := &updatableTableListResolver{}
|
||||
update.Accept(utlr)
|
||||
orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, utlr.updatableTableList, update.List, p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -4820,8 +4816,7 @@ func isCTE(tl *ast.TableName) bool {
|
||||
return tl.TableInfo == nil
|
||||
}
|
||||
|
||||
func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan,
|
||||
notUpdatableTbl []string) (newList []*expression.Assignment, po LogicalPlan, allAssignmentsAreConstant bool, e error) {
|
||||
func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) (newList []*expression.Assignment, po LogicalPlan, allAssignmentsAreConstant bool, e error) {
|
||||
b.curClause = fieldList
|
||||
// modifyColumns indicates which columns are in set list,
|
||||
// and if it is set to `DEFAULT`
|
||||
@ -4844,21 +4839,22 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab
|
||||
columnsIdx[assign.Column] = idx
|
||||
}
|
||||
name := p.OutputNames()[idx]
|
||||
foundListItem := false
|
||||
for _, tl := range tableList {
|
||||
if (tl.Schema.L == "" || tl.Schema.L == name.DBName.L) && (tl.Name.L == name.TblName.L) {
|
||||
if isCTE(tl) || tl.TableInfo.IsView() || tl.TableInfo.IsSequence() {
|
||||
return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE")
|
||||
}
|
||||
// may be a subquery
|
||||
if tl.Schema.L == "" {
|
||||
for _, nTbl := range notUpdatableTbl {
|
||||
if nTbl == name.TblName.L {
|
||||
return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE")
|
||||
}
|
||||
}
|
||||
}
|
||||
foundListItem = true
|
||||
}
|
||||
}
|
||||
if !foundListItem {
|
||||
// For case like:
|
||||
// 1: update (select * from t1) t1 set b = 1111111 ----- (no updatable table here)
|
||||
// 2: update (select 1 as a) as t, t1 set a=1 ----- (updatable t1 don't have column a)
|
||||
// --- subQuery is not counted as updatable table.
|
||||
return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE")
|
||||
}
|
||||
columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L)
|
||||
// We save a flag for the column in map `modifyColumns`
|
||||
// This flag indicated if assign keyword `DEFAULT` to the column
|
||||
@ -4873,15 +4869,7 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab
|
||||
// And, fill virtualAssignments here; that's for generated columns.
|
||||
virtualAssignments := make([]*ast.Assignment, 0)
|
||||
for _, tn := range tableList {
|
||||
// Only generate virtual to updatable table, skip not updatable table(i.e. table in update's subQuery)
|
||||
updatable := true
|
||||
for _, nTbl := range notUpdatableTbl {
|
||||
if tn.Name.L == nTbl {
|
||||
updatable = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !updatable || isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() {
|
||||
if isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -5984,6 +5972,35 @@ func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectLi
|
||||
}
|
||||
}
|
||||
|
||||
type updatableTableListResolver struct {
|
||||
updatableTableList []*ast.TableName
|
||||
}
|
||||
|
||||
func (u *updatableTableListResolver) Enter(inNode ast.Node) (ast.Node, bool) {
|
||||
switch v := inNode.(type) {
|
||||
case *ast.UpdateStmt, *ast.TableRefsClause, *ast.Join, *ast.TableSource, *ast.TableName:
|
||||
return v, false
|
||||
}
|
||||
return inNode, true
|
||||
}
|
||||
|
||||
func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
|
||||
switch v := inNode.(type) {
|
||||
case *ast.TableSource:
|
||||
if s, ok := v.Source.(*ast.TableName); ok {
|
||||
if v.AsName.L != "" {
|
||||
newTableName := *s
|
||||
newTableName.Name = v.AsName
|
||||
newTableName.Schema = model.NewCIStr("")
|
||||
u.updatableTableList = append(u.updatableTableList, &newTableName)
|
||||
} else {
|
||||
u.updatableTableList = append(u.updatableTableList, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return inNode, true
|
||||
}
|
||||
|
||||
// extractTableList extracts all the TableNames from node.
|
||||
// If asName is true, extract AsName prior to OrigName.
|
||||
// Privilege check should use OrigName, while expression may use AsName.
|
||||
@ -6114,28 +6131,6 @@ func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, in
|
||||
}
|
||||
}
|
||||
|
||||
// extractTableSourceAsNames extracts TableSource.AsNames from node.
|
||||
// if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt
|
||||
func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string {
|
||||
switch x := node.(type) {
|
||||
case *ast.Join:
|
||||
input = extractTableSourceAsNames(x.Left, input, onlySelectStmt)
|
||||
input = extractTableSourceAsNames(x.Right, input, onlySelectStmt)
|
||||
case *ast.TableSource:
|
||||
if _, ok := x.Source.(*ast.SelectStmt); !ok && onlySelectStmt {
|
||||
break
|
||||
}
|
||||
if s, ok := x.Source.(*ast.TableName); ok {
|
||||
if x.AsName.L == "" {
|
||||
input = append(input, s.Name.L)
|
||||
break
|
||||
}
|
||||
}
|
||||
input = append(input, x.AsName.L)
|
||||
}
|
||||
return input
|
||||
}
|
||||
|
||||
func appendDynamicVisitInfo(vi []visitInfo, priv string, withGrant bool, err error) []visitInfo {
|
||||
return append(vi, visitInfo{
|
||||
privilege: mysql.ExtendedPriv,
|
||||
|
||||
Reference in New Issue
Block a user