diff --git a/expression/expressions/like.go b/expression/expressions/like.go index 4f7ce4dcc9..d7abb38bd7 100644 --- a/expression/expressions/like.go +++ b/expression/expressions/like.go @@ -101,11 +101,15 @@ func (p *PatternLike) Eval(ctx context.Context, args map[interface{}]interface{} if pattern == nil { return nil, nil } - spattern, ok := pattern.(string) - if !ok { - return nil, errors.Errorf("non-string pattern in LIKE: %v (Value of type %T)", pattern, pattern) + var spattern string + switch v := pattern.(type) { + case string: + spattern = v + case []byte: + spattern = string(v) + default: + return nil, errors.Errorf("Pattern should be string or []byte in LIKE: %v (Value of type %T)", pattern, pattern) } - p.patChars, p.patTypes = compilePattern(spattern) } diff --git a/expression/expressions/like_test.go b/expression/expressions/like_test.go index a30e6b1c30..dfdb5c7692 100644 --- a/expression/expressions/like_test.go +++ b/expression/expressions/like_test.go @@ -105,4 +105,20 @@ func (*testLikeSuite) TestEval(c *C) { pattern.Expr = mockExpr{isStatic: false, val: nil} _, err = pattern.Eval(nil, nil) c.Assert(err, IsNil) + + // Testcase for "LIKE BINARY xxx" + pattern = &PatternLike{ + Expr: mockExpr{isStatic: true, val: "slien"}, + Pattern: mockExpr{isStatic: true, val: []byte("%E%")}, + } + v, err := pattern.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsTrue) + pattern = &PatternLike{ + Expr: mockExpr{isStatic: true, val: "slin"}, + Pattern: mockExpr{isStatic: true, val: []byte("%E%")}, + } + v, err = pattern.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsFalse) } diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 70d41343d3..80888ee096 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -295,7 +295,7 @@ func (s *InsertIntoStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) // On duplicate key Update the duplicate row. // Evaluate the updated value. // TODO: report rows affected and last insert id. - toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate, false) + toUpdateColumns, err := getUpdateColumns(t, s.OnDuplicate, false, nil) if err != nil { return nil, errors.Trace(err) } diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index a4af62b61d..1115877279 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -85,14 +85,23 @@ func (s *UpdateStmt) SetText(text string) { s.Text = text } -func getUpdateColumns(t table.Table, assignList []expressions.Assignment, isMultipleTable bool) ([]*column.Col, error) { +func getUpdateColumns(t table.Table, assignList []expressions.Assignment, isMultipleTable bool, tblAliasMap map[string]string) ([]*column.Col, error) { // TODO: We should check the validate if assignList in somewhere else. Maybe in building plan. + // TODO: We should use field.GetFieldIndex to replace this function. tcols := make([]*column.Col, 0, len(assignList)) tname := t.TableName() for _, asgn := range assignList { if isMultipleTable { if !strings.EqualFold(tname.O, asgn.TableName) { - continue + // Try to compare alias name with t.TableName() + if tblAliasMap == nil { + continue + } + if alias, ok := tblAliasMap[asgn.TableName]; !ok { + continue + } else if !strings.EqualFold(tname.O, alias) { + continue + } } } col := column.FindCol(t.Cols(), asgn.ColName) @@ -243,6 +252,15 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { updatedRowKeys := make(map[string]bool) // For single-table syntax, TableRef may contain multiple tables isMultipleTable := s.MultipleTable || s.TableRefs.MultipleTable() + + // Get table alias map. + fs := p.GetFields() + tblAliasMap := make(map[string]string) + for _, f := range fs { + if f.TableName != f.OrgTableName { + tblAliasMap[f.TableName] = f.OrgTableName + } + } for { row, err1 := p.Next(ctx) if err1 != nil { @@ -279,7 +297,7 @@ func (s *UpdateStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { end := start + len(tbl.Cols()) data := rowData[start:end] start = end - tcols, err2 := getUpdateColumns(tbl, s.List, isMultipleTable) + tcols, err2 := getUpdateColumns(tbl, s.List, isMultipleTable, tblAliasMap) if err2 != nil { return nil, errors.Trace(err2) } diff --git a/stmt/stmts/update_test.go b/stmt/stmts/update_test.go index 069db96215..b97c80f319 100644 --- a/stmt/stmts/update_test.go +++ b/stmt/stmts/update_test.go @@ -151,4 +151,26 @@ func (s *testStmtSuite) TestMultipleTableUpdate(c *C) { } rows.Close() mustCommit(c, tx) + + // JoinTable with alias table name. + r = mustExec(c, testDB, `UPDATE items T0 join month T1 on T0.id=T1.mid SET T0.price=T1.mprice;`) + c.Assert(r, NotNil) + tx = mustBegin(c, testDB) + rows, err = tx.Query("SELECT * FROM items") + c.Assert(err, IsNil) + expectedResult = map[int]string{ + 11: "month_price_11", + 12: "items_price_12", + 13: "month_price_13", + } + for rows.Next() { + var ( + id int + price string + ) + rows.Scan(&id, &price) + c.Assert(price, Equals, expectedResult[id]) + } + rows.Close() + mustCommit(c, tx) }