From f6e34c31b5074da838b2c159b547795121528068 Mon Sep 17 00:00:00 2001 From: shenli Date: Thu, 17 Sep 2015 13:07:18 +0800 Subject: [PATCH 1/3] stmts: Fix bug for update when JoinTable with alias table name. Fix bug found in beego/orm --- stmt/stmts/insert.go | 2 +- stmt/stmts/update.go | 23 ++++++++++++++++++++--- stmt/stmts/update_test.go | 22 ++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 70d41343d3..80888ee096 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -295,7 +295,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) // On duplicate key Update the duplicate row. // Evaluate the updated value. // TODO: report rows affected and last insert id. - toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate, false) + toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate, false, nil) if err != nil { return nil, errors.Trace(err) } diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index a4af62b61d..2f55d67649 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -85,14 +85,22 @@ func (s *UpdateStmt) SetText(text string) { s.Text = text } -func getUpdateColumns(t table.Table, assignList []expressions.Assignment, isMultipleTable bool) ([]*column.Col, error) { +func getUpdateColumns(t table.Table, assignList []expressions.Assignment, isMultipleTable bool, tblAliasMap map[string]string) ([]*column.Col, error) { // TODO: We should check the validate if assignList in somewhere else. Maybe in building plan. tcols := make([]*column.Col, 0, len(assignList)) tname := t.TableName() for _, asgn := range assignList { if isMultipleTable { if !strings.EqualFold(tname.O, asgn.TableName) { - continue + // Try to compare alias name with t.TableName() + if tblAliasMap == nil { + continue + } + if alias, ok := tblAliasMap[asgn.TableName]; !ok { + continue + } else if !strings.EqualFold(tname.O, alias) { + continue + } } } col := column.FindCol(t.Cols(), asgn.ColName) @@ -243,6 +251,15 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { updatedRowKeys := make(map[string]bool) // For single-table syntax, TableRef may contain multiple tables isMultipleTable := s.MultipleTable || s.TableRefs.MultipleTable() + + // Get table alias map. + fs := p.GetFields() + tblAliasMap := make(map[string]string) + for _, f := range fs { + if f.TableName != f.OrgTableName { + tblAliasMap[f.TableName] = f.OrgTableName + } + } for { row, err1 := p.Next(ctx) if err1 != nil { @@ -279,7 +296,7 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { end := start + len(tbl.Cols()) data := rowData[start:end] start = end - tcols, err2 := getUpdateColumns(tbl, s.List, isMultipleTable) + tcols, err2 := getUpdateColumns(tbl, s.List, isMultipleTable, tblAliasMap) if err2 != nil { return nil, errors.Trace(err2) } diff --git a/stmt/stmts/update_test.go b/stmt/stmts/update_test.go index 069db96215..b97c80f319 100644 --- a/stmt/stmts/update_test.go +++ b/stmt/stmts/update_test.go @@ -151,4 +151,26 @@ func (s *testStmtSuite) TestMultipleTableUpdate(c *C) { } rows.Close() mustCommit(c, tx) + + // JoinTable with alias table name. + r = mustExec(c, testDB, `UPDATE items T0 join month T1 on T0.id=T1.mid SET T0.price=T1.mprice;`) + c.Assert(r, NotNil) + tx = mustBegin(c, testDB) + rows, err = tx.Query("SELECT * FROM items") + c.Assert(err, IsNil) + expectedResult = map[int]string{ + 11: "month_price_11", + 12: "items_price_12", + 13: "month_price_13", + } + for rows.Next() { + var ( + id int + price string + ) + rows.Scan(&id, &price) + c.Assert(price, Equals, expectedResult[id]) + } + rows.Close() + mustCommit(c, tx) } From f7662b71db13aed0b9200eb8af364f485e08ca0d Mon Sep 17 00:00:00 2001 From: shenli Date: Thu, 17 Sep 2015 13:56:31 +0800 Subject: [PATCH 2/3] expressions: Process []byte in like.go Support "LIKE BINARY expr" --- expression/expressions/like.go | 7 ++++++- expression/expressions/like_test.go | 16 ++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/expression/expressions/like.go b/expression/expressions/like.go index 4f7ce4dcc9..3e480d5ef9 100644 --- a/expression/expressions/like.go +++ b/expression/expressions/like.go @@ -103,7 +103,12 @@ func (p *PatternLike) Eval(ctx context.Context, args map[interface{}]interface{} } spattern, ok := pattern.(string) if !ok { - return nil, errors.Errorf("non-string pattern in LIKE: %v (Value of type %T)", pattern, pattern) + bpattern, ok := pattern.([]byte) + if !ok { + return nil, errors.Errorf("non-string pattern in LIKE: %v (Value of type %T)", pattern, pattern) + } + // TODO: BINARY pattern match should be case-insensitive + spattern = string(bpattern) } p.patChars, p.patTypes = compilePattern(spattern) diff --git a/expression/expressions/like_test.go b/expression/expressions/like_test.go index a30e6b1c30..dfdb5c7692 100644 --- a/expression/expressions/like_test.go +++ b/expression/expressions/like_test.go @@ -105,4 +105,20 @@ func (*testLikeSuite) TestEval(c *C) { pattern.Expr = mockExpr{isStatic: false, val: nil} _, err = pattern.Eval(nil, nil) c.Assert(err, IsNil) + + // Testcase for "LIKE BINARY xxx" + pattern = &PatternLike{ + Expr: mockExpr{isStatic: true, val: "slien"}, + Pattern: mockExpr{isStatic: true, val: []byte("%E%")}, + } + v, err := pattern.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsTrue) + pattern = &PatternLike{ + Expr: mockExpr{isStatic: true, val: "slin"}, + Pattern: mockExpr{isStatic: true, val: []byte("%E%")}, + } + v, err = pattern.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsFalse) } From c6df21ad68eb16fc334063c17ed79bd814a04f33 Mon Sep 17 00:00:00 2001 From: shenli Date: Thu, 17 Sep 2015 15:01:53 +0800 Subject: [PATCH 3/3] *: Address comments --- expression/expressions/like.go | 17 ++++++++--------- stmt/stmts/update.go | 1 + 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/expression/expressions/like.go b/expression/expressions/like.go index 3e480d5ef9..d7abb38bd7 100644 --- a/expression/expressions/like.go +++ b/expression/expressions/like.go @@ -101,16 +101,15 @@ func (p *PatternLike) Eval(ctx context.Context, args map[interface{}]interface{} if pattern == nil { return nil, nil } - spattern, ok := pattern.(string) - if !ok { - bpattern, ok := pattern.([]byte) - if !ok { - return nil, errors.Errorf("non-string pattern in LIKE: %v (Value of type %T)", pattern, pattern) - } - // TODO: BINARY pattern match should be case-insensitive - spattern = string(bpattern) + var spattern string + switch v := pattern.(type) { + case string: + spattern = v + case []byte: + spattern = string(v) + default: + return nil, errors.Errorf("Pattern should be string or []byte in LIKE: %v (Value of type %T)", pattern, pattern) } - p.patChars, p.patTypes = compilePattern(spattern) } diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 2f55d67649..1115877279 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -87,6 +87,7 @@ func (s *UpdateStmt) SetText(text string) { func getUpdateColumns(t table.Table, assignList []expressions.Assignment, isMultipleTable bool, tblAliasMap map[string]string) ([]*column.Col, error) { // TODO: We should check the validate if assignList in somewhere else. Maybe in building plan. + // TODO: We should use field.GetFieldIndex to replace this function. tcols := make([]*column.Col, 0, len(assignList)) tname := t.TableName() for _, asgn := range assignList {