expression: make locate function consider collation (#27079)
This commit is contained in:
@ -1479,15 +1479,8 @@ func (b *builtinLocate2ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
|
||||
if int64(len([]rune(subStr))) == 0 {
|
||||
return 1, false, nil
|
||||
}
|
||||
if collate.IsCICollation(b.collation) {
|
||||
str = strings.ToLower(str)
|
||||
subStr = strings.ToLower(subStr)
|
||||
}
|
||||
ret, idx := 0, strings.Index(str, subStr)
|
||||
if idx != -1 {
|
||||
ret = utf8.RuneCountInString(str[:idx]) + 1
|
||||
}
|
||||
return int64(ret), false, nil
|
||||
|
||||
return locateStringWithCollation(str, subStr, b.collation), false, nil
|
||||
}
|
||||
|
||||
type builtinLocate3ArgsSig struct {
|
||||
@ -1569,9 +1562,10 @@ func (b *builtinLocate3ArgsUTF8Sig) evalInt(row chunk.Row) (int64, bool, error)
|
||||
return pos + 1, false, nil
|
||||
}
|
||||
slice := string([]rune(str)[pos:])
|
||||
idx := strings.Index(slice, subStr)
|
||||
if idx != -1 {
|
||||
return pos + int64(utf8.RuneCountInString(slice[:idx])) + 1, false, nil
|
||||
|
||||
idx := locateStringWithCollation(slice, subStr, b.collation)
|
||||
if idx != 0 {
|
||||
return pos + idx, false, nil
|
||||
}
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
@ -7112,6 +7112,14 @@ func (s *testIntegrationSerialSuite) TestCollateStringFunction(c *C) {
|
||||
tk.MustQuery("select LOCATE(p4,n3) from t1;").Check(testkit.Rows("0", "0", "0"))
|
||||
tk.MustQuery("select LOCATE(p4,n4) from t1;").Check(testkit.Rows("0", "1", "1"))
|
||||
|
||||
tk.MustQuery("select locate('S', 's' collate utf8mb4_general_ci);").Check(testkit.Rows("1"))
|
||||
tk.MustQuery("select locate('S', 'a' collate utf8mb4_general_ci);").Check(testkit.Rows("0"))
|
||||
// MySQL returns 0 here; I believe this is a bug in MySQL, since 'ß' == 's' under the utf8mb4_general_ci collation.
|
||||
tk.MustQuery("select locate('ß', 's' collate utf8mb4_general_ci);").Check(testkit.Rows("1"))
|
||||
tk.MustQuery("select locate('S', 's' collate utf8mb4_unicode_ci);").Check(testkit.Rows("1"))
|
||||
tk.MustQuery("select locate('S', 'a' collate utf8mb4_unicode_ci);").Check(testkit.Rows("0"))
|
||||
tk.MustQuery("select locate('ß', 'ss' collate utf8mb4_unicode_ci);").Check(testkit.Rows("1"))
|
||||
|
||||
tk.MustExec("truncate table t1;")
|
||||
tk.MustExec("insert into t1 (a) values (1);")
|
||||
tk.MustExec("insert into t1 (a,p1,p2,p3,p4,n1,n2,n3,n4) values (2,'0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0aA1!测试テストמבחן ','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן','0Aa1!测试テストמבחן');")
|
||||
@ -7667,6 +7675,14 @@ func (s *testIntegrationSuite) TestIssue17287(c *C) {
|
||||
tk.MustQuery("execute stmt7 using @val2;").Check(testkit.Rows("1589873946"))
|
||||
}
|
||||
|
||||
func (s *testIntegrationSuite) TestIssue26989(c *C) {
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("set names utf8mb4 collate utf8mb4_general_ci;")
|
||||
tk.MustQuery("select position('a' in 'AA');").Check(testkit.Rows("0"))
|
||||
tk.MustQuery("select locate('a', 'AA');").Check(testkit.Rows("0"))
|
||||
tk.MustQuery("select locate('a', 'a');").Check(testkit.Rows("1"))
|
||||
}
|
||||
|
||||
func (s *testIntegrationSuite) TestIssue17898(c *C) {
|
||||
tk := testkit.NewTestKit(c, s.store)
|
||||
tk.MustExec("use test")
|
||||
|
||||
@ -14,12 +14,14 @@
|
||||
package expression
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pingcap/errors"
|
||||
"github.com/pingcap/failpoint"
|
||||
@ -359,6 +361,29 @@ func SubstituteCorCol2Constant(expr Expression) (Expression, error) {
|
||||
return expr, nil
|
||||
}
|
||||
|
||||
func locateStringWithCollation(str, substr, coll string) int64 {
|
||||
collator := collate.GetCollator(coll)
|
||||
strKey := collator.Key(str)
|
||||
subStrKey := collator.Key(substr)
|
||||
|
||||
index := bytes.Index(strKey, subStrKey)
|
||||
if index == -1 || index == 0 {
|
||||
return int64(index + 1)
|
||||
}
|
||||
|
||||
// todo: we can use binary search to make it faster.
|
||||
count := int64(0)
|
||||
for {
|
||||
r, size := utf8.DecodeRuneInString(str)
|
||||
count += 1
|
||||
index -= len(collator.Key(string(r)))
|
||||
if index == 0 {
|
||||
return count + 1
|
||||
}
|
||||
str = str[size:]
|
||||
}
|
||||
}
|
||||
|
||||
// timeZone2Duration converts timezone whose format should satisfy the regular condition
|
||||
// `(^(+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^+13:00$)` to time.Duration.
|
||||
func timeZone2Duration(tz string) time.Duration {
|
||||
|
||||
Reference in New Issue
Block a user