diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 0e5af6c6bb..304939829a 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -1479,15 +1479,8 @@ func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) if int64(len([]rune(subStr))) == 0 { return 1, false, nil } - if collate.IsCICollation(b.collation) { - str = strings.ToLower(str) - subStr = strings.ToLower(subStr) - } - ret, idx := 0, strings.Index(str, subStr) - if idx != -1 { - ret = utf8.RuneCountInString(str[:idx]) + 1 - } - return int64(ret), false, nil + + return locateStringWithCollation(str, subStr, b.collation), false, nil } type builtinLocate3ArgsSig struct { @@ -1569,9 +1562,10 @@ func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error) return pos + 1, false, nil } slice := string([]rune(str)[pos:]) - idx := strings.Index(slice, subStr) - if idx != -1 { - return pos + int64(utf8.RuneCountInString(slice[:idx])) + 1, false, nil + + idx := locateStringWithCollation(slice, subStr, b.collation) + if idx != 0 { + return pos + idx, false, nil } return 0, false, nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index d03e57e2d7..6ebf5e35d1 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -7112,6 +7112,14 @@ func (s *testIntegrationSerialSuite) TestCollateStringFunction(c *C) { tk.MustQuery("select LOCATE(p4,n3) from t1;").Check(testkit.Rows("0", "0", "0")) tk.MustQuery("select LOCATE(p4,n4) from t1;").Check(testkit.Rows("0", "1", "1")) + tk.MustQuery("select locate('S', 's' collate utf8mb4_general_ci);").Check(testkit.Rows("1")) + tk.MustQuery("select locate('S', 'a' collate utf8mb4_general_ci);").Check(testkit.Rows("0")) + // MySQL return 0 here, I believe it is a bug in MySQL since 'ß' == 's' under utf8mb4_general_ci collation. + tk.MustQuery("select locate('ß', 's' collate utf8mb4_general_ci);").Check(testkit.Rows("1")) + tk.MustQuery("select locate('S', 's' collate utf8mb4_unicode_ci);").Check(testkit.Rows("1")) + tk.MustQuery("select locate('S', 'a' collate utf8mb4_unicode_ci);").Check(testkit.Rows("0")) + tk.MustQuery("select locate('ß', 'ss' collate utf8mb4_unicode_ci);").Check(testkit.Rows("1")) + tk.MustExec("truncate table t1;") tk.MustExec("insert into t1 (a) values (1);") tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');") @@ -7667,6 +7675,14 @@ func (s *testIntegrationSuite) TestIssue17287(c *C) { tk.MustQuery("execute stmt7 using @val2;").Check(testkit.Rows("1589873946")) } +func (s *testIntegrationSuite) TestIssue26989(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;") + tk.MustQuery("select position('a' in 'AA');").Check(testkit.Rows("0")) + tk.MustQuery("select locate('a', 'AA');").Check(testkit.Rows("0")) + tk.MustQuery("select locate('a', 'a');").Check(testkit.Rows("1")) +} + func (s *testIntegrationSuite) TestIssue17898(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/expression/util.go b/expression/util.go index 56d5264956..9cb9c0cbe1 100644 --- a/expression/util.go +++ b/expression/util.go @@ -14,12 +14,14 @@ package expression import ( + "bytes" "context" "math" "strconv" "strings" "time" "unicode" + "unicode/utf8" "github.com/pingcap/errors" "github.com/pingcap/failpoint" @@ -359,6 +361,29 @@ func SubstituteCorCol2Constant(expr Expression) (Expression, error) { return expr, nil } +func locateStringWithCollation(str, substr, coll string) int64 { + collator := collate.GetCollator(coll) + strKey := collator.Key(str) + subStrKey := collator.Key(substr) + + index := bytes.Index(strKey, subStrKey) + if index == -1 || index == 0 { + return int64(index + 1) + } + + // todo: we can use binary search to make it faster. + count := int64(0) + for { + r, size := utf8.DecodeRuneInString(str) + count += 1 + index -= len(collator.Key(string(r))) + if index == 0 { + return count + 1 + } + str = str[size:] + } +} + // timeZone2Duration converts timezone whose format should satisfy the regular condition // `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to time.Duration. func timeZone2Duration(tz string) time.Duration {