expression: make locate function consider collation (#27079)

This commit is contained in:
xiongjiwei
2021-08-11 20:29:16 +08:00
committed by GitHub
parent 281b431292
commit 3ac7914b7e
3 changed files with 47 additions and 12 deletions

View File

@ -1479,15 +1479,8 @@ func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
if int64(len([]rune(subStr))) == 0 {
return 1, false, nil
}
if collate.IsCICollation(b.collation) {
str = strings.ToLower(str)
subStr = strings.ToLower(subStr)
}
ret, idx := 0, strings.Index(str, subStr)
if idx != -1 {
ret = utf8.RuneCountInString(str[:idx]) + 1
}
return int64(ret), false, nil
return locateStringWithCollation(str, subStr, b.collation), false, nil
}
type builtinLocate3ArgsSig struct {
@ -1569,9 +1562,10 @@ func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
return pos + 1, false, nil
}
slice := string([]rune(str)[pos:])
idx := strings.Index(slice, subStr)
if idx != -1 {
return pos + int64(utf8.RuneCountInString(slice[:idx])) + 1, false, nil
idx := locateStringWithCollation(slice, subStr, b.collation)
if idx != 0 {
return pos + idx, false, nil
}
return 0, false, nil
}

View File

@ -7112,6 +7112,14 @@ func (s *testIntegrationSerialSuite) TestCollateStringFunction(c *C) {
tk.MustQuery("select LOCATE(p4,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
tk.MustQuery("select LOCATE(p4,n4) from t1;").Check(testkit.Rows("0", "1", "1"))
tk.MustQuery("select locate('S', 's' collate utf8mb4_general_ci);").Check(testkit.Rows("1"))
tk.MustQuery("select locate('S', 'a' collate utf8mb4_general_ci);").Check(testkit.Rows("0"))
// MySQL return 0 here, I believe it is a bug in MySQL since 'ß' == 's' under utf8mb4_general_ci collation.
tk.MustQuery("select locate('ß', 's' collate utf8mb4_general_ci);").Check(testkit.Rows("1"))
tk.MustQuery("select locate('S', 's' collate utf8mb4_unicode_ci);").Check(testkit.Rows("1"))
tk.MustQuery("select locate('S', 'a' collate utf8mb4_unicode_ci);").Check(testkit.Rows("0"))
tk.MustQuery("select locate('ß', 'ss' collate utf8mb4_unicode_ci);").Check(testkit.Rows("1"))
tk.MustExec("truncate table t1;")
tk.MustExec("insert into t1 (a) values (1);")
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');")
@ -7667,6 +7675,14 @@ func (s *testIntegrationSuite) TestIssue17287(c *C) {
tk.MustQuery("execute stmt7 using @val2;").Check(testkit.Rows("1589873946"))
}
func (s *testIntegrationSuite) TestIssue26989(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;")
tk.MustQuery("select position('a' in 'AA');").Check(testkit.Rows("0"))
tk.MustQuery("select locate('a', 'AA');").Check(testkit.Rows("0"))
tk.MustQuery("select locate('a', 'a');").Check(testkit.Rows("1"))
}
func (s *testIntegrationSuite) TestIssue17898(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

View File

@ -14,12 +14,14 @@
package expression
import (
"bytes"
"context"
"math"
"strconv"
"strings"
"time"
"unicode"
"unicode/utf8"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
@ -359,6 +361,29 @@ func SubstituteCorCol2Constant(expr Expression) (Expression, error) {
return expr, nil
}
func locateStringWithCollation(str, substr, coll string) int64 {
collator := collate.GetCollator(coll)
strKey := collator.Key(str)
subStrKey := collator.Key(substr)
index := bytes.Index(strKey, subStrKey)
if index == -1 || index == 0 {
return int64(index + 1)
}
// todo: we can use binary search to make it faster.
count := int64(0)
for {
r, size := utf8.DecodeRuneInString(str)
count += 1
index -= len(collator.Key(string(r)))
if index == 0 {
return count + 1
}
str = str[size:]
}
}
// timeZone2Duration converts timezone whose format should satisfy the regular condition
// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to time.Duration.
func timeZone2Duration(tz string) time.Duration {