225 lines
6.3 KiB
Go
225 lines
6.3 KiB
Go
// Copyright 2021 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package charset
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"unicode"
|
|
"unsafe"
|
|
|
|
"github.com/cznic/mathutil"
|
|
"github.com/pingcap/tidb/parser/mysql"
|
|
"github.com/pingcap/tidb/parser/terror"
|
|
"golang.org/x/text/encoding"
|
|
"golang.org/x/text/transform"
|
|
)
|
|
|
|
var (
	// errInvalidCharacterString is the parser-class error built from
	// mysql.ErrInvalidCharacterString; generateErr fills in the charset name
	// and the hex dump of the offending bytes.
	errInvalidCharacterString = terror.ClassParser.NewStd(mysql.ErrInvalidCharacterString)
)
|
|
|
|
// EncodingLabel is a charset label in normalized form (no surrounding
// whitespace, lowercase). NewEncoding uses it as the key into encodingMap.
type EncodingLabel string

// Format normalizes label into an EncodingLabel: leading and trailing tab,
// LF, FF, CR and space characters are removed, then the result is lowercased.
func Format(label string) EncodingLabel {
	const asciiWhitespace = "\t\n\r\f "
	trimmed := strings.Trim(label, asciiWhitespace)
	return EncodingLabel(strings.ToLower(trimmed))
}

// Formatted converts label to an EncodingLabel verbatim, without trimming or
// lowercasing. Only use it when label is already in normalized form.
func Formatted(label string) EncodingLabel {
	normalized := EncodingLabel(label)
	return normalized
}
|
|
|
|
// Encoding provides an interface to encode/decode a string with a specific encoding.
type Encoding struct {
	// enc supplies the converters used by the Encode*/Decode* methods
	// (via enc.NewEncoder and enc.NewDecoder).
	enc encoding.Encoding
	// name is returned by Name and reported in invalid-character errors.
	name string
	// charLength returns the byte length of the next character in this
	// encoding. CharLength calls it directly; while decoding it decides how
	// many source bytes are fed to the decoder at once. It may be nil, in
	// which case decoding hands the whole remaining input over in one chunk.
	charLength func([]byte) int
	// specialCase holds charset-specific case mappings (unicode.SpecialCase).
	// NOTE(review): not referenced in this part of the file — confirm its users.
	specialCase unicode.SpecialCase
}
|
|
|
|
// enabled indicates whether the non-utf8 encoding is used.
|
|
func (e *Encoding) enabled() bool {
|
|
return e != UTF8Encoding
|
|
}
|
|
|
|
// Name returns the name of the current encoding.
|
|
func (e *Encoding) Name() string {
|
|
return e.name
|
|
}
|
|
|
|
// CharLength returns the next character length in bytes.
|
|
func (e *Encoding) CharLength(bs []byte) int {
|
|
return e.charLength(bs)
|
|
}
|
|
|
|
// NewEncoding creates a new Encoding.
|
|
func NewEncoding(label string) *Encoding {
|
|
if len(label) == 0 {
|
|
return UTF8Encoding
|
|
}
|
|
|
|
if e, exist := encodingMap[Format(label)]; exist {
|
|
return e
|
|
}
|
|
return UTF8Encoding
|
|
}
|
|
|
|
// Encode convert bytes from utf-8 charset to a specific charset.
|
|
func (e *Encoding) Encode(dest, src []byte) ([]byte, error) {
|
|
if !e.enabled() {
|
|
return src, nil
|
|
}
|
|
return e.transform(e.enc.NewEncoder(), dest, src, false)
|
|
}
|
|
|
|
// EncodeString convert a string from utf-8 charset to a specific charset.
|
|
func (e *Encoding) EncodeString(src string) (string, error) {
|
|
if !e.enabled() {
|
|
return src, nil
|
|
}
|
|
bs, err := e.transform(e.enc.NewEncoder(), nil, Slice(src), false)
|
|
return string(bs), err
|
|
}
|
|
|
|
// EncodeFirstChar convert first code point of bytes from utf-8 charset to a specific charset.
|
|
func (e *Encoding) EncodeFirstChar(dest, src []byte) ([]byte, error) {
|
|
srcNextLen := e.nextCharLenInSrc(src, false)
|
|
srcEnd := mathutil.Min(srcNextLen, len(src))
|
|
if !e.enabled() {
|
|
return src[:srcEnd], nil
|
|
}
|
|
return e.transform(e.enc.NewEncoder(), dest, src[:srcEnd], false)
|
|
}
|
|
|
|
// EncodeInternal convert bytes from utf-8 charset to a specific charset, we actually do not do the real convert, just find the inconvertible character and use ? replace.
|
|
// The code below is equivalent to
|
|
// expr, _ := e.Encode(dest, src)
|
|
// ret, _ := e.Decode(nil, expr)
|
|
// return ret
|
|
func (e *Encoding) EncodeInternal(dest, src []byte) []byte {
|
|
if !e.enabled() {
|
|
return src
|
|
}
|
|
if dest == nil {
|
|
dest = make([]byte, 0, len(src))
|
|
}
|
|
var srcOffset int
|
|
|
|
var buf [4]byte
|
|
transformer := e.enc.NewEncoder()
|
|
for srcOffset < len(src) {
|
|
length := UTF8Encoding.CharLength(src[srcOffset:])
|
|
_, _, err := transformer.Transform(buf[:], src[srcOffset:srcOffset+length], true)
|
|
if err != nil {
|
|
dest = append(dest, byte('?'))
|
|
} else {
|
|
dest = append(dest, src[srcOffset:srcOffset+length]...)
|
|
}
|
|
srcOffset += length
|
|
}
|
|
|
|
return dest
|
|
}
|
|
|
|
// Decode convert bytes from a specific charset to utf-8 charset.
|
|
func (e *Encoding) Decode(dest, src []byte) ([]byte, error) {
|
|
if !e.enabled() {
|
|
return src, nil
|
|
}
|
|
return e.transform(e.enc.NewDecoder(), dest, src, true)
|
|
}
|
|
|
|
// DecodeString convert a string from a specific charset to utf-8 charset.
|
|
func (e *Encoding) DecodeString(src string) (string, error) {
|
|
if !e.enabled() {
|
|
return src, nil
|
|
}
|
|
bs, err := e.transform(e.enc.NewDecoder(), nil, Slice(src), true)
|
|
return string(bs), err
|
|
}
|
|
|
|
func (e *Encoding) transform(transformer transform.Transformer, dest, src []byte, isDecoding bool) ([]byte, error) {
|
|
if len(dest) < len(src) {
|
|
dest = make([]byte, len(src)*2)
|
|
}
|
|
if len(src) == 0 {
|
|
return src, nil
|
|
}
|
|
var destOffset, srcOffset int
|
|
var encodingErr error
|
|
for {
|
|
srcNextLen := e.nextCharLenInSrc(src[srcOffset:], isDecoding)
|
|
srcEnd := mathutil.Min(srcOffset+srcNextLen, len(src))
|
|
nDest, nSrc, err := transformer.Transform(dest[destOffset:], src[srcOffset:srcEnd], false)
|
|
if err == transform.ErrShortDst {
|
|
dest = enlargeCapacity(dest)
|
|
} else if err != nil || isDecoding && beginWithReplacementChar(dest[destOffset:destOffset+nDest]) {
|
|
if encodingErr == nil {
|
|
encodingErr = e.generateErr(src[srcOffset:], srcNextLen)
|
|
}
|
|
dest[destOffset] = byte('?')
|
|
nDest, nSrc = 1, srcNextLen // skip the source bytes that cannot be decoded normally.
|
|
}
|
|
destOffset += nDest
|
|
srcOffset += nSrc
|
|
// The source bytes are exhausted.
|
|
if srcOffset >= len(src) {
|
|
return dest[:destOffset], encodingErr
|
|
}
|
|
}
|
|
}
|
|
|
|
func (e *Encoding) nextCharLenInSrc(srcRest []byte, isDecoding bool) int {
|
|
if isDecoding {
|
|
if e.charLength != nil {
|
|
return e.charLength(srcRest)
|
|
}
|
|
return len(srcRest)
|
|
}
|
|
return UTF8Encoding.CharLength(srcRest)
|
|
}
|
|
|
|
// enlargeCapacity returns a new zero-filled buffer twice as long as dest, with
// the contents of dest copied to its start; dest itself is not modified.
// An empty dest yields a buffer of minEnlargedLen bytes.
func enlargeCapacity(dest []byte) []byte {
	// Doubling zero gives zero again, and a caller that keeps enlarging until
	// the output fits (e.g. retrying on transform.ErrShortDst) would never
	// make progress.
	const minEnlargedLen = 16
	newLen := len(dest) * 2
	if newLen == 0 {
		newLen = minEnlargedLen
	}
	newDest := make([]byte, newLen)
	copy(newDest, dest)
	return newDest
}
|
|
|
|
func (e *Encoding) generateErr(srcRest []byte, srcNextLen int) error {
|
|
cutEnd := mathutil.Min(srcNextLen, len(srcRest))
|
|
invalidBytes := fmt.Sprintf("%X", string(srcRest[:cutEnd]))
|
|
return errInvalidCharacterString.GenWithStackByArgs(e.name, invalidBytes)
|
|
}
|
|
|
|
// replacementBytes is the utf-8 encoding (EF BF BD) of the replacement rune U+FFFD.
var replacementBytes = []byte("\uFFFD")

// beginWithReplacementChar reports whether dst starts with the utf-8 encoded
// replacement rune U+FFFD.
func beginWithReplacementChar(dst []byte) bool {
	n := len(replacementBytes)
	return len(dst) >= n && bytes.Equal(dst[:n], replacementBytes)
}
|
|
|
|
// Slice returns a []byte that shares the memory of s instead of copying it;
// its length and capacity are both len(s).
// Use at your own risk: strings are immutable, so the returned bytes must
// never be modified.
func Slice(s string) (b []byte) {
	str := (*reflect.StringHeader)(unsafe.Pointer(&s))
	slice := (*reflect.SliceHeader)(unsafe.Pointer(&b))
	slice.Data, slice.Len, slice.Cap = str.Data, str.Len, str.Len
	return b
}
|