Files
tidb/util/codec/codec.go
2015-10-19 16:26:43 +08:00

212 lines
6.1 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package codec
import (
"bytes"
"time"
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
)
var (
// InfiniteValue is the greatest than any other encoded value.
InfiniteValue = []byte{0xFF, 0xFF}
// NilValue is the smallest than any other encoded value.
NilValue = []byte{0x00, 0x00}
// SmallestNoneNilValue is smaller than any other encoded value except nil value.
SmallestNoneNilValue = []byte{0x00, 0x01}
)
// TODO: use iota + 1 instead?
const (
formatNilFlag = 'n'
formatIntFlag = 'd'
formatUintFlag = 'u'
formatFloatFlag = 'f'
formatStringFlag = 's'
formatBytesFlag = 'b'
formatDurationFlag = 't'
formatDecimalFlag = 'c'
)
var sepKey = []byte{0x00, 0x00}
// EncodeKey encodes args to a slice which can be sorted lexicographically later.
// EncodeKey guarantees the encoded slice is in ascending order for comparison.
// TODO: we may add more test to check its valiadation, especially for null type and multi indices.
func EncodeKey(args ...interface{}) ([]byte, error) {
var b []byte
format := make([]byte, 0, len(args))
for _, arg := range args {
switch v := arg.(type) {
case bool:
if v {
b = EncodeInt(b, int64(1))
} else {
b = EncodeInt(b, int64(0))
}
format = append(format, formatIntFlag)
case int:
b = EncodeInt(b, int64(v))
format = append(format, formatIntFlag)
case int8:
b = EncodeInt(b, int64(v))
format = append(format, formatIntFlag)
case int16:
b = EncodeInt(b, int64(v))
format = append(format, formatIntFlag)
case int32:
b = EncodeInt(b, int64(v))
format = append(format, formatIntFlag)
case int64:
b = EncodeInt(b, int64(v))
format = append(format, formatIntFlag)
case uint:
b = EncodeUint(b, uint64(v))
format = append(format, formatUintFlag)
case uint8:
b = EncodeUint(b, uint64(v))
format = append(format, formatUintFlag)
case uint16:
b = EncodeUint(b, uint64(v))
format = append(format, formatUintFlag)
case uint32:
b = EncodeUint(b, uint64(v))
format = append(format, formatUintFlag)
case uint64:
b = EncodeUint(b, uint64(v))
format = append(format, formatUintFlag)
case float32:
b = EncodeFloat(b, float64(v))
format = append(format, formatFloatFlag)
case float64:
b = EncodeFloat(b, float64(v))
format = append(format, formatFloatFlag)
case string:
b = EncodeBytes(b, []byte(v))
format = append(format, formatStringFlag)
case []byte:
b = EncodeBytes(b, v)
format = append(format, formatBytesFlag)
case mysql.Time:
b = EncodeBytes(b, []byte(v.String()))
format = append(format, formatStringFlag)
case mysql.Duration:
// duration may have negative value, so we cannot use String to encode directly.
b = EncodeInt(b, int64(v.Duration))
format = append(format, formatDurationFlag)
case mysql.Decimal:
b = EncodeDecimal(b, v)
format = append(format, formatDecimalFlag)
case mysql.Hex:
b = EncodeInt(b, int64(v.ToNumber()))
format = append(format, formatIntFlag)
case mysql.Bit:
b = EncodeUint(b, uint64(v.ToNumber()))
format = append(format, formatUintFlag)
case mysql.Enum:
b = EncodeUint(b, uint64(v.ToNumber()))
format = append(format, formatUintFlag)
case mysql.Set:
b = EncodeUint(b, uint64(v.ToNumber()))
format = append(format, formatUintFlag)
case nil:
// We will 0x00, 0x00 for nil.
// The []byte{} will be encoded as 0x00, 0x01.
// The []byte{0x00} will be encode as 0x00, 0xFF, 0x00, 0x01.
// And any integer and float encoded values are greater than 0x00, 0x01.
// So maybe the smallest none null value is []byte{} and we can use it to skip null values.
b = append(b, sepKey...)
format = append(format, formatNilFlag)
default:
return nil, errors.Errorf("unsupport encode type %T", arg)
}
}
// The comma is the seperator,
// e.g: 0x00, 0x00
// We need more tests to check its validation.
b = append(b, sepKey...)
b = append(b, format...)
return b, nil
}
// StripEnd splits a slice b into two substrings separated by sepKey
// and returns a slice byte of the previous substrings.
func StripEnd(b []byte) ([]byte, error) {
n := bytes.LastIndex(b, sepKey)
if n == -1 || n+2 >= len(b) {
// No seperator or no proper format.
return nil, errors.Errorf("invalid encoded key")
}
return b[:n], nil
}
// DecodeKey decodes values from a byte slice generated with EncodeKey before.
func DecodeKey(b []byte) ([]interface{}, error) {
// At first read the format.
n := bytes.LastIndex(b, sepKey)
if n == -1 || n+2 >= len(b) {
// No seperator or no proper format.
return nil, errors.Errorf("invalid encoded key")
}
format := b[n+2:]
b = b[0:n]
v := make([]interface{}, len(format))
var err error
for i, flag := range format {
switch flag {
case formatIntFlag:
b, v[i], err = DecodeInt(b)
case formatUintFlag:
b, v[i], err = DecodeUint(b)
case formatFloatFlag:
b, v[i], err = DecodeFloat(b)
case formatStringFlag:
var r []byte
b, r, err = DecodeBytes(b)
if err == nil {
v[i] = string(r)
}
case formatBytesFlag:
b, v[i], err = DecodeBytes(b)
case formatDurationFlag:
var r int64
b, r, err = DecodeInt(b)
if err == nil {
// use max fsp, let outer to do round manually.
v[i] = mysql.Duration{Duration: time.Duration(r), Fsp: mysql.MaxFsp}
}
case formatDecimalFlag:
b, v[i], err = DecodeDecimal(b)
case formatNilFlag:
if len(b) < 2 || (b[0] != 0x00 && b[1] != 0x00) {
return nil, errors.Errorf("malformed encoded nil")
}
b, v[i] = b[2:], nil
default:
return nil, errors.Errorf("invalid encoded key format %v in %s", flag, format)
}
if err != nil {
return nil, errors.Trace(err)
}
}
return v, nil
}