[parser] support parsing SQL with encodings other than utf8 (#1312)

This commit is contained in:
tangenta
2021-09-13 17:36:05 +08:00
committed by Ti Chi Robot
parent 8c9c20d3ff
commit ba105bbd10
10 changed files with 335 additions and 44 deletions

View File

@ -107,6 +107,16 @@ func ValidCharsetAndCollation(cs string, co string) bool {
return ok
}
// GetDefaultCollationLegacy keeps the charset support of the old parser
// version: the default collation is looked up only when charset
// (case-insensitively) is one of CharsetUTF8, CharsetUTF8MB4, CharsetASCII,
// CharsetLatin1 or CharsetBin. Any other charset yields an
// "Unknown charset" error.
func GetDefaultCollationLegacy(charset string) (string, error) {
	lowered := strings.ToLower(charset)
	legacy := lowered == CharsetUTF8 ||
		lowered == CharsetUTF8MB4 ||
		lowered == CharsetASCII ||
		lowered == CharsetLatin1 ||
		lowered == CharsetBin
	if !legacy {
		return "", errors.Errorf("Unknown charset %s", charset)
	}
	// The lookup receives the caller's original spelling, not the lowered one.
	return GetDefaultCollation(charset)
}
// GetDefaultCollation returns the default collation for charset.
func GetDefaultCollation(charset string) (string, error) {
cs, err := GetCharsetInfo(charset)

137
parser/charset/encoding.go Normal file
View File

@ -0,0 +1,137 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package charset
import (
"strings"
"golang.org/x/text/encoding"
"golang.org/x/text/transform"
)
const (
// encodingBufferSizeDefault is the size, in bytes, of a freshly allocated
// conversion buffer.
encodingBufferSizeDefault = 1024
// encodingBufferSizeRecycleThreshold is the buffer size above which transform
// swaps the buffer for a default-sized one once a conversion has finished.
encodingBufferSizeRecycleThreshold = 4 * 1024
// encodingDefault is the encoding name for which NewEncoding and
// UpdateEncoding do not install a converter.
encodingDefault = "utf-8"
)
// EncodingLabel is an encoding label in the form the lookup table is keyed by;
// build one with Format or Formatted.
type EncodingLabel string

// Format turns an arbitrary label into an EncodingLabel by stripping leading
// and trailing tab, newline, carriage return, form feed and space characters
// and then lowercasing the rest.
func Format(label string) EncodingLabel {
	trimmed := strings.Trim(label, "\t\n\r\f ")
	lowered := strings.ToLower(trimmed)
	return EncodingLabel(lowered)
}

// Formatted wraps label unchanged. Use it only when the label is already
// trimmed and lowercase.
func Formatted(label string) EncodingLabel {
	formatted := EncodingLabel(label)
	return formatted
}
// Encoding provides an interface to encode/decode a string with a specific encoding.
type Encoding struct {
// enc supplies the encoder and decoder used by Encode and Decode; it is nil
// when no conversion is set up.
enc encoding.Encoding
// name is the encoding name reported by Name.
name string
// charLength, given the remaining input bytes, returns the byte length of the
// next character; transform uses it to size each chunk it converts.
charLength func([]byte) int
// buffer is the scratch destination that transform writes into and reuses
// across calls.
buffer []byte
}
// Enabled reports whether a non-utf8 encoding is in use, i.e. both the
// underlying encoding and its character-length function are set.
func (e *Encoding) Enabled() bool {
	if e.enc == nil {
		return false
	}
	return e.charLength != nil
}
// Name returns the name of the current encoding. It is the empty string when
// the Encoding was created without a label or the label was not found by lookup.
func (e *Encoding) Name() string {
return e.name
}
// NewEncoding builds an Encoding for label.
//
// An empty label gives a zero Encoding. Otherwise the label is looked up and
// the resulting name is recorded; the converter, the character-length function
// and the working buffer are only installed when the lookup found an encoding
// whose name is not encodingDefault.
func NewEncoding(label EncodingLabel) *Encoding {
	result := &Encoding{}
	if label == "" {
		return result
	}
	enc, name := lookup(label)
	result.name = name
	if enc == nil || name == encodingDefault {
		return result
	}
	result.enc = enc
	result.charLength = FindNextCharacterLength(name)
	result.buffer = make([]byte, encodingBufferSizeDefault)
	return result
}
// UpdateEncoding switches e to the encoding named by label while keeping the
// existing buffer (one is allocated only if e has none).
//
// The converter and the character-length function are always updated together,
// mirroring NewEncoding: they are installed when lookup finds an encoding whose
// name is not encodingDefault and cleared otherwise. Leaving either one behind
// would make Enabled report the state of the previous encoding and make
// transform chunk the input with the wrong character length.
func (e *Encoding) UpdateEncoding(label EncodingLabel) {
	enc, name := lookup(label)
	e.name = name
	if enc != nil && name != encodingDefault {
		e.enc = enc
		e.charLength = FindNextCharacterLength(name)
	} else {
		e.enc = nil
		e.charLength = nil
	}
	if len(e.buffer) == 0 {
		e.buffer = make([]byte, encodingBufferSizeDefault)
	}
}
// Encode converts src, which holds UTF-8 text, into the current encoding and
// returns the result as a string. The boolean is false when at least one
// position could not be converted and was replaced by '?'.
//
// It must only be called when an encoding is set (see Enabled).
func (e *Encoding) Encode(src []byte) (string, bool) {
	// transform cuts src into one-character chunks using e.charLength, which
	// describes the *target* encoding. When encoding, the source is UTF-8, so
	// chunk by UTF-8 character length for the duration of the call; otherwise
	// multi-byte runes are split and every one of them is reported as invalid.
	decodeCharLength := e.charLength
	e.charLength = FindNextCharacterLength(encodingDefault)
	defer func() { e.charLength = decodeCharLength }()
	return e.transform(e.enc.NewEncoder(), src)
}
// Decode runs src through the decoder of the current encoding and returns the
// converted text. The boolean is false when at least one position could not be
// converted and was replaced by '?'.
func (e *Encoding) Decode(src []byte) (string, bool) {
	decoder := e.enc.NewDecoder()
	return e.transform(decoder, src)
}
// transform runs src through transformer one character at a time and returns
// the converted text.
//
// The input is fed in chunks of one character (sized by e.charLength, or 4
// bytes when no length function is set) so that a conversion failure stays
// confined to a single position: a '?' is written to the output, one input
// byte is skipped, and the returned flag becomes false. The flag is true only
// when all of src was converted without such a failure.
//
// e.buffer is used as scratch space and grows on demand. If it has grown past
// encodingBufferSizeRecycleThreshold it is replaced by a default-sized buffer
// before returning, so the Encoding does not hold on to large allocations.
func (e *Encoding) transform(transformer transform.Transformer, src []byte) (string, bool) {
	if len(e.buffer) < len(src) {
		e.buffer = make([]byte, len(src)*2)
	}
	// growBuffer doubles the scratch buffer, preserving what was written so far.
	growBuffer := func() {
		size := len(e.buffer) * 2
		if size == 0 {
			size = encodingBufferSizeDefault
		}
		grown := make([]byte, size)
		copy(grown, e.buffer)
		e.buffer = grown
	}
	var destOffset, srcOffset int
	ok := true
	for {
		nextLen := 4
		if e.charLength != nil {
			nextLen = e.charLength(src[srcOffset:])
		}
		srcEnd := srcOffset + nextLen
		if srcEnd > len(src) {
			srcEnd = len(src)
		}
		nDest, nSrc, err := transformer.Transform(e.buffer[destOffset:], src[srcOffset:srcEnd], false)
		destOffset += nDest
		srcOffset += nSrc
		switch err {
		case nil:
			if srcOffset >= len(src) {
				result := string(e.buffer[:destOffset])
				if len(e.buffer) > encodingBufferSizeRecycleThreshold {
					// This prevents Encoding from holding too much memory.
					e.buffer = make([]byte, encodingBufferSizeDefault)
				}
				return result, ok
			}
		case transform.ErrShortDst:
			// Not enough room for this chunk: enlarge the buffer and retry.
			growBuffer()
		default:
			// The chunk cannot be converted. The buffer may be exactly full at
			// this point, so make room before writing the replacement byte.
			if destOffset >= len(e.buffer) {
				growBuffer()
			}
			e.buffer[destOffset] = '?'
			destOffset++
			srcOffset++
			ok = false
		}
	}
}

View File

@ -31,7 +31,11 @@ import (
// leading and trailing whitespace.
func Lookup(label string) (e encoding.Encoding, name string) {
label = strings.ToLower(strings.Trim(label, "\t\n\r\f "))
enc := encodings[label]
return lookup(Formatted(label))
}
// lookup returns the encoding and name registered in encodings under label.
// The label is used as the key as-is; for an unknown label both results are
// zero values.
func lookup(label EncodingLabel) (e encoding.Encoding, name string) {
	entry := encodings[string(label)]
	e, name = entry.e, entry.name
	return e, name
}
@ -258,3 +262,32 @@ var encodings = map[string]struct {
"utf-16le": {unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), "utf-16le"},
"x-user-defined": {charmap.XUserDefined, "x-user-defined"},
}
// FindNextCharacterLength returns the function registered for label in
// encodingNextCharacterLength, which reports the byte length of the next
// character of its input. It returns nil when no function is registered.
func FindNextCharacterLength(label string) func([]byte) int {
	// Indexing a missing key yields the zero value, i.e. a nil function.
	return encodingNextCharacterLength[label]
}
// encodingNextCharacterLength maps an encoding name to a function that looks
// at the first byte of its input and returns how many bytes the next character
// occupies. Every function returns 1 for empty input.
var encodingNextCharacterLength = map[string]func([]byte) int{
	// https://en.wikipedia.org/wiki/GBK_(character_encoding)#Layout_diagram
	"gbk": func(bs []byte) int {
		if len(bs) > 0 && bs[0] >= 0x80 {
			return 2
		}
		// A byte in the range 00–7F is a single byte that means the same thing as it does in ASCII.
		return 1
	},
	"utf-8": func(bs []byte) int {
		if len(bs) == 0 {
			return 1
		}
		switch lead := bs[0]; {
		case lead < 0x80:
			return 1
		case lead < 0xe0:
			return 2
		case lead < 0xf0:
			return 3
		default:
			return 4
		}
	},
}

View File

@ -129,11 +129,11 @@ func (hp *hintParser) parse(input string, sqlMode mysql.SQLMode, initPos Pos) ([
hp.result = nil
hp.lexer.reset(input[3:])
hp.lexer.SetSQLMode(sqlMode)
hp.lexer.r.p = Pos{
hp.lexer.r.updatePos(Pos{
Line: initPos.Line,
Col: initPos.Col + 3, // skipped the initial '/*+'
Offset: 0,
}
})
hp.lexer.inBangComment = true // skip the final '*/' (we need the '*/' for reporting warnings)
yyhintParse(&hp.lexer, hp)

View File

@ -21,6 +21,8 @@ import (
"unicode"
"unicode/utf8"
"github.com/pingcap/errors"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/mysql"
tidbfeature "github.com/pingcap/parser/tidb"
)
@ -39,6 +41,8 @@ type Scanner struct {
r reader
buf bytes.Buffer
encoding charset.Encoding
errs []error
warns []error
stmtStartPos int
@ -134,11 +138,28 @@ func (s *Scanner) AppendError(err error) {
s.errs = append(s.errs, err)
}
// tryDecodeToUTF8String converts sql from the scanner's configured encoding to
// UTF-8.
//
// When the encoding is not enabled, sql is returned unchanged. A warning is
// recorded only if an encoding name is set and it is not UTF-8: for UTF-8 the
// input needs no conversion, so there is nothing unsupported to report. When
// the conversion itself hits bytes it cannot convert, the converted text is
// still returned and a warning is recorded.
func (s *Scanner) tryDecodeToUTF8String(sql string) string {
	if !s.encoding.Enabled() {
		// "utf-8" is the name charset.NewEncoding records for UTF-8 labels while
		// deliberately leaving the converter unset; it must not be reported as
		// an unsupported encoding.
		if name := s.encoding.Name(); len(name) > 0 && name != "utf-8" {
			s.AppendError(errors.Errorf("Encoding %s is not supported", name))
			s.lastErrorAsWarn()
		}
		return sql
	}
	utf8Lit, ok := s.encoding.Decode(Slice(sql))
	if !ok {
		s.AppendError(errors.Errorf("Cannot convert string '%x' from %s to utf8mb4", sql, s.encoding.Name()))
		s.lastErrorAsWarn()
	}
	return utf8Lit
}
func (s *Scanner) getNextToken() int {
r := s.r
tok, pos, lit := s.scan()
if tok == identifier {
tok = handleIdent(&yySymType{})
tok = s.handleIdent(&yySymType{})
}
if tok == identifier {
if tok1 := s.isTokenIdentifier(lit, pos.Offset); tok1 != 0 {
@ -163,7 +184,7 @@ func (s *Scanner) Lex(v *yySymType) int {
v.offset = pos.Offset
v.ident = lit
if tok == identifier {
tok = handleIdent(v)
tok = s.handleIdent(v)
}
if tok == identifier {
if tok1 := s.isTokenIdentifier(lit, pos.Offset); tok1 != 0 {
@ -240,6 +261,7 @@ func (s *Scanner) EnableWindowFunc(val bool) {
func (s *Scanner) InheritScanner(sql string) *Scanner {
return &Scanner{
r: reader{s: sql},
encoding: s.encoding,
sqlMode: s.sqlMode,
supportWindowFunc: s.supportWindowFunc,
}
@ -250,6 +272,22 @@ func NewScanner(s string) *Scanner {
return &Scanner{r: reader{s: s}}
}
// handleIdent recognises a character set introducer. When lval.ident has the
// form "_name" and name is a charset known to charset.GetCharsetInfo, the
// identifier is replaced by the charset's name and underscoreCS is returned;
// in every other case the token stays an identifier.
//
// A character string literal may have an optional character set introducer and COLLATE clause:
// [_charset_name]'string' [COLLATE collation_name]
// See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html
func (s *Scanner) handleIdent(lval *yySymType) int {
	ident := lval.ident
	if strings.HasPrefix(ident, "_") {
		if info, err := charset.GetCharsetInfo(ident[1:]); err == nil {
			lval.ident = info.Name
			return underscoreCS
		}
	}
	return identifier
}
// skipWhitespace runs the reader's incAsLongAs with unicode.IsSpace as the
// predicate and returns its result.
// NOTE(review): presumably this consumes the leading whitespace and yields the
// first non-space rune — confirm against reader.incAsLongAs.
func (s *Scanner) skipWhitespace() rune {
return s.r.incAsLongAs(unicode.IsSpace)
}
@ -266,7 +304,7 @@ func (s *Scanner) scan() (tok int, pos Pos, lit string) {
return 0, pos, ""
}
if !s.r.eof() && isIdentExtend(ch0) {
if isIdentExtend(ch0) {
return scanIdentifier(s)
}
@ -302,7 +340,7 @@ func startWithXx(s *Scanner) (tok int, pos Pos, lit string) {
}
return
}
s.r.p = pos
s.r.updatePos(pos)
return scanIdentifier(s)
}
@ -334,7 +372,7 @@ func startWithBb(s *Scanner) (tok int, pos Pos, lit string) {
}
return
}
s.r.p = pos
s.r.updatePos(pos)
return scanIdentifier(s)
}
@ -762,7 +800,7 @@ func (s *Scanner) scanBit() {
}
func (s *Scanner) scanFloat(beg *Pos) (tok int, pos Pos, lit string) {
s.r.p = *beg
s.r.updatePos(*beg)
// float = D1 . D2 e D3
s.scanDigits()
ch0 := s.r.peek()
@ -784,7 +822,7 @@ func (s *Scanner) scanFloat(beg *Pos) (tok int, pos Pos, lit string) {
// D1 . D2 e XX when XX is not D3, parse the result to an identifier.
// 9e9e = 9e9(float) + e(identifier)
// 9est = 9est(identifier)
s.r.p = *beg
s.r.updatePos(*beg)
s.r.incAsLongAs(isIdentChar)
tok = identifier
}
@ -810,7 +848,7 @@ func (s *Scanner) scanVersionDigits(min, max int) {
if isDigit(ch) {
s.r.inc()
} else if i < min {
s.r.p = pos
s.r.updatePos(pos)
return
} else {
break
@ -832,7 +870,7 @@ func (s *Scanner) scanFeatureIDs() (featureIDs []string) {
state = expectChar
break
}
s.r.p = pos
s.r.updatePos(pos)
return nil
case expectChar:
if isIdentChar(ch) {
@ -840,7 +878,7 @@ func (s *Scanner) scanFeatureIDs() (featureIDs []string) {
state = obtainChar
break
}
s.r.p = pos
s.r.updatePos(pos)
return nil
case obtainChar:
if isIdentChar(ch) {
@ -856,11 +894,11 @@ func (s *Scanner) scanFeatureIDs() (featureIDs []string) {
featureIDs = append(featureIDs, b.String())
return featureIDs
}
s.r.p = pos
s.r.updatePos(pos)
return nil
}
}
s.r.p = pos
s.r.updatePos(pos)
return nil
}
@ -876,6 +914,9 @@ type reader struct {
s string
p Pos
w int
peekRune rune
peekRuneUpdated bool
}
var eof = Pos{-1, -1, -1}
@ -888,21 +929,22 @@ func (r *reader) eof() bool {
// if reader meets EOF, it will return unicode.ReplacementChar. to distinguish from
// the real unicode.ReplacementChar, the caller should call r.eof() again to check.
func (r *reader) peek() rune {
if r.peekRuneUpdated {
return r.peekRune
}
if r.eof() {
return unicode.ReplacementChar
}
v, w := rune(r.s[r.p.Offset]), 1
switch {
case v == 0:
r.w = w
return v // illegal UTF-8 encoding
case v >= 0x80:
if v >= 0x80 {
v, w = utf8.DecodeRuneInString(r.s[r.p.Offset:])
if v == utf8.RuneError && w == 1 {
v = rune(r.s[r.p.Offset]) // illegal UTF-8 encoding
v = rune(r.s[r.p.Offset]) // illegal encoding
}
}
r.w = w
r.peekRune = v
r.peekRuneUpdated = true
return v
}
@ -915,6 +957,7 @@ func (r *reader) inc() {
}
r.p.Offset += r.w
r.p.Col++
r.peekRuneUpdated = false
}
func (r *reader) incN(n int) {
@ -936,6 +979,13 @@ func (r *reader) pos() Pos {
return r.p
}
// updatePos moves the reader to pos. The rune cached by peek is discarded only
// when the byte offset actually changes.
func (r *reader) updatePos(pos Pos) {
	moved := pos.Offset != r.p.Offset
	r.p = pos
	if moved {
		r.peekRuneUpdated = false
	}
}
// data returns the part of the input that lies between the byte offset of from
// and the reader's current byte offset.
func (r *reader) data(from *Pos) string {
return r.s[from.Offset:r.p.Offset]
}

View File

@ -14,9 +14,8 @@
package parser
import (
"strings"
"github.com/pingcap/parser/charset"
"reflect"
"unsafe"
)
func isLetter(ch rune) bool {
@ -991,18 +990,13 @@ func (s *Scanner) isTokenIdentifier(lit string, offset int) int {
return tok
}
func handleIdent(lval *yySymType) int {
s := lval.ident
// A character string literal may have an optional character set introducer and COLLATE clause:
// [_charset_name]'string' [COLLATE collation_name]
// See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html
if !strings.HasPrefix(s, "_") {
return identifier
}
cs, err := charset.GetCharsetInfo(s[1:])
if err != nil {
return identifier
}
lval.ident = cs.Name
return underscoreCS
// Slice returns a byte slice that shares the memory of s instead of copying
// it; length and capacity both equal len(s).
// Use at your own risk: strings are immutable, so the result must not be
// written to.
func Slice(s string) (b []byte) {
	strHeader := (*reflect.StringHeader)(unsafe.Pointer(&s))
	sliceHeader := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	sliceHeader.Data = strHeader.Data
	sliceHeader.Len = strHeader.Len
	sliceHeader.Cap = strHeader.Len
	return b
}

View File

@ -14908,7 +14908,7 @@ yynewstate:
case 1166:
{
// See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html
co, err := charset.GetDefaultCollation(yyS[yypt-1].ident)
co, err := charset.GetDefaultCollationLegacy(yyS[yypt-1].ident)
if err != nil {
yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", yyS[yypt-1].ident))
return 1
@ -14932,7 +14932,7 @@ yynewstate:
}
case 1169:
{
co, err := charset.GetDefaultCollation(yyS[yypt-1].ident)
co, err := charset.GetDefaultCollationLegacy(yyS[yypt-1].ident)
if err != nil {
yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", yyS[yypt-1].ident))
return 1
@ -14948,7 +14948,7 @@ yynewstate:
}
case 1170:
{
co, err := charset.GetDefaultCollation(yyS[yypt-1].ident)
co, err := charset.GetDefaultCollationLegacy(yyS[yypt-1].ident)
if err != nil {
yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", yyS[yypt-1].ident))
return 1

View File

@ -6452,7 +6452,7 @@ Literal:
| "UNDERSCORE_CHARSET" stringLit
{
// See https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html
co, err := charset.GetDefaultCollation($1)
co, err := charset.GetDefaultCollationLegacy($1)
if err != nil {
yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", $1))
return 1
@ -6476,7 +6476,7 @@ Literal:
}
| "UNDERSCORE_CHARSET" hexLit
{
co, err := charset.GetDefaultCollation($1)
co, err := charset.GetDefaultCollationLegacy($1)
if err != nil {
yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", $1))
return 1
@ -6492,7 +6492,7 @@ Literal:
}
| "UNDERSCORE_CHARSET" bitLit
{
co, err := charset.GetDefaultCollation($1)
co, err := charset.GetDefaultCollationLegacy($1)
if err != nil {
yylex.AppendError(yylex.Errorf("Get collation error for charset: %s", $1))
return 1

View File

@ -6320,3 +6320,66 @@ func (s *testParserSuite) TestPlanRecreator(c *C) {
c.Assert(v.Stmt.Text(), Equals, "SELECT a FROM t")
c.Assert(v.Analyze, IsTrue)
}
// TestGBKEncoding checks that GBK-encoded SQL yields the original identifiers
// and literals only when the parser's CharsetClient is set to "gbk".
func (s *testParserSuite) TestGBKEncoding(c *C) {
p := parser.New()
gbkEncoding, _ := charset.Lookup("gbk")
encoder := gbkEncoding.NewEncoder()
// sql holds the statement re-encoded with the gbk encoder.
sql, err := encoder.String("create table 测试表 (测试列 varchar(255) default 'GBK测试用例');")
c.Assert(err, IsNil)
// Without a CharsetClient the statement still parses, but the table and column
// names seen in the AST differ from the original text.
stmt, err := p.ParseOneStmt(sql, "", "")
c.Assert(err, IsNil)
checker := &gbkEncodingChecker{}
_, _ = stmt.Accept(checker)
c.Assert(checker.tblName, Not(Equals), "测试表")
c.Assert(checker.colName, Not(Equals), "测试列")
// With CharsetClient "gbk" the same bytes give back the original names and
// default value.
p.SetParserConfig(parser.ParserConfig{CharsetClient: "gbk"})
stmt, err = p.ParseOneStmt(sql, "", "")
c.Assert(err, IsNil)
_, _ = stmt.Accept(checker)
c.Assert(checker.tblName, Equals, "测试表")
c.Assert(checker.colName, Equals, "测试列")
c.Assert(checker.expr, Equals, "GBK测试用例")
utf8SQL := "select '芢' from `玚`;"
sql, err = encoder.String(utf8SQL)
c.Assert(err, IsNil)
stmt, err = p.ParseOneStmt(sql, "", "")
c.Assert(err, IsNil)
// The raw bytes below end in 0x5c (ASCII backslash) and 0x60 (ASCII backtick).
// NOTE(review): presumably these are the GBK forms of the two characters above,
// so the parse must not treat the second byte as an escape or a closing quote —
// confirm.
stmt, err = p.ParseOneStmt("select '\xc6\x5c' from `\xab\x60`;", "", "")
c.Assert(err, IsNil)
// After clearing CharsetClient, a _gbk introducer is expected to be rejected.
// NOTE(review): presumably because gbk is not accepted by
// GetDefaultCollationLegacy — confirm.
p.SetParserConfig(parser.ParserConfig{CharsetClient: ""})
stmt, err = p.ParseOneStmt("select _gbk '\xc6\x5c' from dual;", "", "")
c.Assert(err, NotNil)
}
// gbkEncodingChecker is an AST visitor that records the names and the
// column-option value it encounters so the test can compare them.
type gbkEncodingChecker struct {
// tblName is the original name of the last *ast.TableName visited.
tblName string
// colName is the original name of the last *ast.ColumnName visited.
colName string
// expr is the string value of the last *ast.ColumnOption whose expression is
// an ast.ValueExpr.
expr string
}
// Enter records the original name of table and column name nodes and the string
// value of column options whose expression is a value expression. It always
// returns the node unchanged and never skips children.
func (g *gbkEncodingChecker) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
	switch v := n.(type) {
	case *ast.TableName:
		g.tblName = v.Name.O
	case *ast.ColumnName:
		g.colName = v.Name.O
	case *ast.ColumnOption:
		if value, isValue := v.Expr.(ast.ValueExpr); isValue {
			g.expr = value.GetString()
		}
	}
	return n, false
}
// Leave returns the node unchanged together with true; the checker does all of
// its work in Enter.
func (g *gbkEncodingChecker) Leave(n ast.Node) (node ast.Node, ok bool) {
return n, true
}

View File

@ -23,6 +23,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/charset"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
)
@ -72,6 +73,7 @@ type ParserConfig struct {
EnableWindowFunction bool
EnableStrictDoubleTypeCheck bool
SkipPositionRecording bool
CharsetClient string // CharsetClient indicates how to decode the original SQL.
}
// Parser represents a parser instance. Some temporary objects are stored in it to reduce object allocation during Parse function.
@ -132,11 +134,13 @@ func (parser *Parser) SetParserConfig(config ParserConfig) {
parser.EnableWindowFunc(config.EnableWindowFunction)
parser.SetStrictDoubleTypeCheck(config.EnableStrictDoubleTypeCheck)
parser.lexer.skipPositionRecording = config.SkipPositionRecording
parser.lexer.encoding = *charset.NewEncoding(charset.Format(config.CharsetClient))
}
// Parse parses a query string to raw ast.StmtNode.
// If charset or collation is "", default charset and collation will be used.
func (parser *Parser) Parse(sql, charset, collation string) (stmt []ast.StmtNode, warns []error, err error) {
sql = parser.lexer.tryDecodeToUTF8String(sql)
if charset == "" {
charset = mysql.DefaultCharset
}