573 lines
16 KiB
Go
573 lines
16 KiB
Go
// Copyright 2014 The ql Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSES/QL-LICENSE file.
|
|
|
|
// Copyright 2015 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package types
|
|
|
|
import (
|
|
"math"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/parser/mysql"
|
|
"github.com/pingcap/parser/terror"
|
|
"github.com/pingcap/tidb/sessionctx/stmtctx"
|
|
"github.com/pingcap/tidb/types/json"
|
|
"github.com/pingcap/tidb/util/hack"
|
|
)
|
|
|
|
func truncateStr(str string, flen int) string {
|
|
if flen != UnspecifiedLength && len(str) > flen {
|
|
str = str[:flen]
|
|
}
|
|
return str
|
|
}
|
|
|
|
// IntergerUnsignedUpperBound indicates the max uint64 values of different mysql types.
|
|
func IntergerUnsignedUpperBound(intType byte) uint64 {
|
|
switch intType {
|
|
case mysql.TypeTiny:
|
|
return math.MaxUint8
|
|
case mysql.TypeShort:
|
|
return math.MaxUint16
|
|
case mysql.TypeInt24:
|
|
return mysql.MaxUint24
|
|
case mysql.TypeLong:
|
|
return math.MaxUint32
|
|
case mysql.TypeLonglong:
|
|
return math.MaxUint64
|
|
case mysql.TypeBit:
|
|
return math.MaxUint64
|
|
case mysql.TypeEnum:
|
|
return math.MaxUint64
|
|
case mysql.TypeSet:
|
|
return math.MaxUint64
|
|
default:
|
|
panic("Input byte is not a mysql type")
|
|
}
|
|
}
|
|
|
|
// IntergerSignedUpperBound indicates the max int64 values of different mysql types.
|
|
func IntergerSignedUpperBound(intType byte) int64 {
|
|
switch intType {
|
|
case mysql.TypeTiny:
|
|
return math.MaxInt8
|
|
case mysql.TypeShort:
|
|
return math.MaxInt16
|
|
case mysql.TypeInt24:
|
|
return mysql.MaxInt24
|
|
case mysql.TypeLong:
|
|
return math.MaxInt32
|
|
case mysql.TypeLonglong:
|
|
return math.MaxInt64
|
|
default:
|
|
panic("Input byte is not a mysql type")
|
|
}
|
|
}
|
|
|
|
// IntergerSignedLowerBound indicates the min int64 values of different mysql types.
|
|
func IntergerSignedLowerBound(intType byte) int64 {
|
|
switch intType {
|
|
case mysql.TypeTiny:
|
|
return math.MinInt8
|
|
case mysql.TypeShort:
|
|
return math.MinInt16
|
|
case mysql.TypeInt24:
|
|
return mysql.MinInt24
|
|
case mysql.TypeLong:
|
|
return math.MinInt32
|
|
case mysql.TypeLonglong:
|
|
return math.MinInt64
|
|
default:
|
|
panic("Input byte is not a mysql type")
|
|
}
|
|
}
|
|
|
|
// ConvertFloatToInt converts a float64 value to a int value.
|
|
func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) {
|
|
val := RoundFloat(fval)
|
|
if val < float64(lowerBound) {
|
|
return lowerBound, overflow(val, tp)
|
|
}
|
|
|
|
if val >= float64(upperBound) {
|
|
if val == float64(upperBound) {
|
|
return upperBound, nil
|
|
}
|
|
return upperBound, overflow(val, tp)
|
|
}
|
|
return int64(val), nil
|
|
}
|
|
|
|
// ConvertIntToInt converts an int value to another int value of different precision.
|
|
func ConvertIntToInt(val int64, lowerBound int64, upperBound int64, tp byte) (int64, error) {
|
|
if val < lowerBound {
|
|
return lowerBound, overflow(val, tp)
|
|
}
|
|
|
|
if val > upperBound {
|
|
return upperBound, overflow(val, tp)
|
|
}
|
|
|
|
return val, nil
|
|
}
|
|
|
|
// ConvertUintToInt converts an uint value to an int value.
|
|
func ConvertUintToInt(val uint64, upperBound int64, tp byte) (int64, error) {
|
|
if val > uint64(upperBound) {
|
|
return upperBound, overflow(val, tp)
|
|
}
|
|
|
|
return int64(val), nil
|
|
}
|
|
|
|
// ConvertIntToUint converts an int value to an uint value.
|
|
func ConvertIntToUint(sc *stmtctx.StatementContext, val int64, upperBound uint64, tp byte) (uint64, error) {
|
|
if sc.ShouldClipToZero() && val < 0 {
|
|
return 0, overflow(val, tp)
|
|
}
|
|
|
|
if uint64(val) > upperBound {
|
|
return upperBound, overflow(val, tp)
|
|
}
|
|
|
|
return uint64(val), nil
|
|
}
|
|
|
|
// ConvertUintToUint converts an uint value to another uint value of different precision.
|
|
func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {
|
|
if val > upperBound {
|
|
return upperBound, overflow(val, tp)
|
|
}
|
|
|
|
return val, nil
|
|
}
|
|
|
|
// ConvertFloatToUint converts a float value to an uint value.
|
|
func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
|
|
val := RoundFloat(fval)
|
|
if val < 0 {
|
|
if sc.ShouldClipToZero() {
|
|
return 0, overflow(val, tp)
|
|
}
|
|
return uint64(int64(val)), overflow(val, tp)
|
|
}
|
|
|
|
if val > float64(upperBound) {
|
|
return upperBound, overflow(val, tp)
|
|
}
|
|
return uint64(val), nil
|
|
}
|
|
|
|
// StrToInt converts a string to an integer at the best-effort.
|
|
func StrToInt(sc *stmtctx.StatementContext, str string) (int64, error) {
|
|
str = strings.TrimSpace(str)
|
|
validPrefix, err := getValidIntPrefix(sc, str)
|
|
iVal, err1 := strconv.ParseInt(validPrefix, 10, 64)
|
|
if err1 != nil {
|
|
return iVal, ErrOverflow.GenWithStackByArgs("BIGINT", validPrefix)
|
|
}
|
|
return iVal, errors.Trace(err)
|
|
}
|
|
|
|
// StrToUint converts a string to an unsigned integer at the best-effortt.
|
|
func StrToUint(sc *stmtctx.StatementContext, str string) (uint64, error) {
|
|
str = strings.TrimSpace(str)
|
|
validPrefix, err := getValidIntPrefix(sc, str)
|
|
if validPrefix[0] == '+' {
|
|
validPrefix = validPrefix[1:]
|
|
}
|
|
uVal, err1 := strconv.ParseUint(validPrefix, 10, 64)
|
|
if err1 != nil {
|
|
return uVal, ErrOverflow.GenWithStackByArgs("BIGINT UNSIGNED", validPrefix)
|
|
}
|
|
return uVal, errors.Trace(err)
|
|
}
|
|
|
|
// StrToDateTime converts str to MySQL DateTime.
|
|
func StrToDateTime(sc *stmtctx.StatementContext, str string, fsp int) (Time, error) {
|
|
return ParseTime(sc, str, mysql.TypeDatetime, fsp)
|
|
}
|
|
|
|
// StrToDuration converts str to Duration. It returns Duration in normal case,
|
|
// and returns Time when str is in datetime format.
|
|
// when isDuration is true, the d is returned, when it is false, the t is returned.
|
|
// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-literals.html.
|
|
func StrToDuration(sc *stmtctx.StatementContext, str string, fsp int) (d Duration, t Time, isDuration bool, err error) {
|
|
str = strings.TrimSpace(str)
|
|
length := len(str)
|
|
if length > 0 && str[0] == '-' {
|
|
length--
|
|
}
|
|
// Timestamp format is 'YYYYMMDDHHMMSS' or 'YYMMDDHHMMSS', which length is 12.
|
|
// See #3923, it explains what we do here.
|
|
if length >= 12 {
|
|
t, err = StrToDateTime(sc, str, fsp)
|
|
if err == nil {
|
|
return d, t, false, nil
|
|
}
|
|
}
|
|
|
|
d, err = ParseDuration(sc, str, fsp)
|
|
if ErrTruncatedWrongVal.Equal(err) {
|
|
err = sc.HandleTruncate(err)
|
|
}
|
|
return d, t, true, errors.Trace(err)
|
|
}
|
|
|
|
// NumberToDuration converts number to Duration.
|
|
func NumberToDuration(number int64, fsp int) (Duration, error) {
|
|
if number > TimeMaxValue {
|
|
// Try to parse DATETIME.
|
|
if number >= 10000000000 { // '2001-00-00 00-00-00'
|
|
if t, err := ParseDatetimeFromNum(nil, number); err == nil {
|
|
dur, err1 := t.ConvertToDuration()
|
|
return dur, errors.Trace(err1)
|
|
}
|
|
}
|
|
dur, err1 := MaxMySQLTime(fsp).ConvertToDuration()
|
|
terror.Log(err1)
|
|
return dur, ErrOverflow.GenWithStackByArgs("Duration", strconv.Itoa(int(number)))
|
|
} else if number < -TimeMaxValue {
|
|
dur, err1 := MaxMySQLTime(fsp).ConvertToDuration()
|
|
terror.Log(err1)
|
|
dur.Duration = -dur.Duration
|
|
return dur, ErrOverflow.GenWithStackByArgs("Duration", strconv.Itoa(int(number)))
|
|
}
|
|
var neg bool
|
|
if neg = number < 0; neg {
|
|
number = -number
|
|
}
|
|
|
|
if number/10000 > TimeMaxHour || number%100 >= 60 || (number/100)%100 >= 60 {
|
|
return ZeroDuration, errors.Trace(ErrInvalidTimeFormat.GenWithStackByArgs(number))
|
|
}
|
|
t := Time{Time: FromDate(0, 0, 0, int(number/10000), int((number/100)%100), int(number%100), 0), Type: mysql.TypeDuration, Fsp: fsp}
|
|
dur, err := t.ConvertToDuration()
|
|
if err != nil {
|
|
return ZeroDuration, errors.Trace(err)
|
|
}
|
|
if neg {
|
|
dur.Duration = -dur.Duration
|
|
}
|
|
return dur, nil
|
|
}
|
|
|
|
// getValidIntPrefix gets prefix of the string which can be successfully parsed as int.
|
|
func getValidIntPrefix(sc *stmtctx.StatementContext, str string) (string, error) {
|
|
floatPrefix, err := getValidFloatPrefix(sc, str)
|
|
if err != nil {
|
|
return floatPrefix, errors.Trace(err)
|
|
}
|
|
return floatStrToIntStr(sc, floatPrefix, str)
|
|
}
|
|
|
|
// roundIntStr is to round int string base on the number following dot.
|
|
func roundIntStr(numNextDot byte, intStr string) string {
|
|
if numNextDot < '5' {
|
|
return intStr
|
|
}
|
|
retStr := []byte(intStr)
|
|
for i := len(intStr) - 1; i >= 0; i-- {
|
|
if retStr[i] != '9' {
|
|
retStr[i]++
|
|
break
|
|
}
|
|
if i == 0 {
|
|
retStr[i] = '1'
|
|
retStr = append(retStr, '0')
|
|
break
|
|
}
|
|
retStr[i] = '0'
|
|
}
|
|
return string(retStr)
|
|
}
|
|
|
|
// floatStrToIntStr converts a valid float string into valid integer string which can be parsed by
|
|
// strconv.ParseInt, we can't parse float first then convert it to string because precision will
|
|
// be lost. For example, the string value "18446744073709551615" which is the max number of unsigned
|
|
// int will cause some precision to lose. intStr[0] may be a positive and negative sign like '+' or '-'.
|
|
func floatStrToIntStr(sc *stmtctx.StatementContext, validFloat string, oriStr string) (intStr string, _ error) {
|
|
var dotIdx = -1
|
|
var eIdx = -1
|
|
for i := 0; i < len(validFloat); i++ {
|
|
switch validFloat[i] {
|
|
case '.':
|
|
dotIdx = i
|
|
case 'e', 'E':
|
|
eIdx = i
|
|
}
|
|
}
|
|
if eIdx == -1 {
|
|
if dotIdx == -1 {
|
|
return validFloat, nil
|
|
}
|
|
var digits []byte
|
|
if validFloat[0] == '-' || validFloat[0] == '+' {
|
|
dotIdx--
|
|
digits = []byte(validFloat[1:])
|
|
} else {
|
|
digits = []byte(validFloat)
|
|
}
|
|
if dotIdx == 0 {
|
|
intStr = "0"
|
|
} else {
|
|
intStr = string(digits)[:dotIdx]
|
|
}
|
|
if len(digits) > dotIdx+1 {
|
|
intStr = roundIntStr(digits[dotIdx+1], intStr)
|
|
}
|
|
if (len(intStr) > 1 || intStr[0] != '0') && validFloat[0] == '-' {
|
|
intStr = "-" + intStr
|
|
}
|
|
return intStr, nil
|
|
}
|
|
var intCnt int
|
|
digits := make([]byte, 0, len(validFloat))
|
|
if dotIdx == -1 {
|
|
digits = append(digits, validFloat[:eIdx]...)
|
|
intCnt = len(digits)
|
|
} else {
|
|
digits = append(digits, validFloat[:dotIdx]...)
|
|
intCnt = len(digits)
|
|
digits = append(digits, validFloat[dotIdx+1:eIdx]...)
|
|
}
|
|
exp, err := strconv.Atoi(validFloat[eIdx+1:])
|
|
if err != nil {
|
|
return validFloat, errors.Trace(err)
|
|
}
|
|
if exp > 0 && int64(intCnt) > (math.MaxInt64-int64(exp)) {
|
|
// (exp + incCnt) overflows MaxInt64.
|
|
sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr))
|
|
return validFloat[:eIdx], nil
|
|
}
|
|
intCnt += exp
|
|
if intCnt <= 0 {
|
|
intStr = "0"
|
|
if intCnt == 0 && len(digits) > 0 {
|
|
intStr = roundIntStr(digits[0], intStr)
|
|
}
|
|
return intStr, nil
|
|
}
|
|
if intCnt == 1 && (digits[0] == '-' || digits[0] == '+') {
|
|
intStr = "0"
|
|
if len(digits) > 1 {
|
|
intStr = roundIntStr(digits[1], intStr)
|
|
}
|
|
if intStr[0] == '1' {
|
|
intStr = string(digits[:1]) + intStr
|
|
}
|
|
return intStr, nil
|
|
}
|
|
if intCnt <= len(digits) {
|
|
intStr = string(digits[:intCnt])
|
|
if intCnt < len(digits) {
|
|
intStr = roundIntStr(digits[intCnt], intStr)
|
|
}
|
|
} else {
|
|
// convert scientific notation decimal number
|
|
extraZeroCount := intCnt - len(digits)
|
|
if extraZeroCount > 20 {
|
|
// Append overflow warning and return to avoid allocating too much memory.
|
|
sc.AppendWarning(ErrOverflow.GenWithStackByArgs("BIGINT", oriStr))
|
|
return validFloat[:eIdx], nil
|
|
}
|
|
intStr = string(digits) + strings.Repeat("0", extraZeroCount)
|
|
}
|
|
return intStr, nil
|
|
}
|
|
|
|
// StrToFloat converts a string to a float64 at the best-effort.
|
|
func StrToFloat(sc *stmtctx.StatementContext, str string) (float64, error) {
|
|
str = strings.TrimSpace(str)
|
|
validStr, err := getValidFloatPrefix(sc, str)
|
|
f, err1 := strconv.ParseFloat(validStr, 64)
|
|
if err1 != nil {
|
|
if err2, ok := err1.(*strconv.NumError); ok {
|
|
// value will truncate to MAX/MIN if out of range.
|
|
if err2.Err == strconv.ErrRange {
|
|
err1 = sc.HandleTruncate(ErrTruncatedWrongVal.GenWithStackByArgs("DOUBLE", str))
|
|
if math.IsInf(f, 1) {
|
|
f = math.MaxFloat64
|
|
} else if math.IsInf(f, -1) {
|
|
f = -math.MaxFloat64
|
|
}
|
|
}
|
|
}
|
|
return f, errors.Trace(err1)
|
|
}
|
|
return f, errors.Trace(err)
|
|
}
|
|
|
|
// ConvertJSONToInt casts JSON into int64.
|
|
func ConvertJSONToInt(sc *stmtctx.StatementContext, j json.BinaryJSON, unsigned bool) (int64, error) {
|
|
switch j.TypeCode {
|
|
case json.TypeCodeObject, json.TypeCodeArray:
|
|
return 0, nil
|
|
case json.TypeCodeLiteral:
|
|
switch j.Value[0] {
|
|
case json.LiteralNil, json.LiteralFalse:
|
|
return 0, nil
|
|
default:
|
|
return 1, nil
|
|
}
|
|
case json.TypeCodeInt64, json.TypeCodeUint64:
|
|
return j.GetInt64(), nil
|
|
case json.TypeCodeFloat64:
|
|
f := j.GetFloat64()
|
|
if !unsigned {
|
|
lBound := IntergerSignedLowerBound(mysql.TypeLonglong)
|
|
uBound := IntergerSignedUpperBound(mysql.TypeLonglong)
|
|
return ConvertFloatToInt(f, lBound, uBound, mysql.TypeDouble)
|
|
}
|
|
bound := IntergerUnsignedUpperBound(mysql.TypeLonglong)
|
|
u, err := ConvertFloatToUint(sc, f, bound, mysql.TypeDouble)
|
|
return int64(u), errors.Trace(err)
|
|
case json.TypeCodeString:
|
|
str := string(hack.String(j.GetString()))
|
|
return StrToInt(sc, str)
|
|
}
|
|
return 0, errors.New("Unknown type code in JSON")
|
|
}
|
|
|
|
// ConvertJSONToFloat casts JSON into float64.
|
|
func ConvertJSONToFloat(sc *stmtctx.StatementContext, j json.BinaryJSON) (float64, error) {
|
|
switch j.TypeCode {
|
|
case json.TypeCodeObject, json.TypeCodeArray:
|
|
return 0, nil
|
|
case json.TypeCodeLiteral:
|
|
switch j.Value[0] {
|
|
case json.LiteralNil, json.LiteralFalse:
|
|
return 0, nil
|
|
default:
|
|
return 1, nil
|
|
}
|
|
case json.TypeCodeInt64:
|
|
return float64(j.GetInt64()), nil
|
|
case json.TypeCodeUint64:
|
|
u, err := ConvertIntToUint(sc, j.GetInt64(), IntergerUnsignedUpperBound(mysql.TypeLonglong), mysql.TypeLonglong)
|
|
return float64(u), errors.Trace(err)
|
|
case json.TypeCodeFloat64:
|
|
return j.GetFloat64(), nil
|
|
case json.TypeCodeString:
|
|
str := string(hack.String(j.GetString()))
|
|
return StrToFloat(sc, str)
|
|
}
|
|
return 0, errors.New("Unknown type code in JSON")
|
|
}
|
|
|
|
// ConvertJSONToDecimal casts JSON into decimal.
|
|
func ConvertJSONToDecimal(sc *stmtctx.StatementContext, j json.BinaryJSON) (*MyDecimal, error) {
|
|
res := new(MyDecimal)
|
|
if j.TypeCode != json.TypeCodeString {
|
|
f64, err := ConvertJSONToFloat(sc, j)
|
|
if err != nil {
|
|
return res, errors.Trace(err)
|
|
}
|
|
err = res.FromFloat64(f64)
|
|
return res, errors.Trace(err)
|
|
}
|
|
err := sc.HandleTruncate(res.FromString([]byte(j.GetString())))
|
|
return res, errors.Trace(err)
|
|
}
|
|
|
|
// getValidFloatPrefix gets prefix of string which can be successfully parsed as float.
|
|
func getValidFloatPrefix(sc *stmtctx.StatementContext, s string) (valid string, err error) {
|
|
var (
|
|
sawDot bool
|
|
sawDigit bool
|
|
validLen int
|
|
eIdx int
|
|
)
|
|
for i := 0; i < len(s); i++ {
|
|
c := s[i]
|
|
if c == '+' || c == '-' {
|
|
if i != 0 && i != eIdx+1 { // "1e+1" is valid.
|
|
break
|
|
}
|
|
} else if c == '.' {
|
|
if sawDot || eIdx > 0 { // "1.1." or "1e1.1"
|
|
break
|
|
}
|
|
sawDot = true
|
|
if sawDigit { // "123." is valid.
|
|
validLen = i + 1
|
|
}
|
|
} else if c == 'e' || c == 'E' {
|
|
if !sawDigit { // "+.e"
|
|
break
|
|
}
|
|
if eIdx != 0 { // "1e5e"
|
|
break
|
|
}
|
|
eIdx = i
|
|
} else if c < '0' || c > '9' {
|
|
break
|
|
} else {
|
|
sawDigit = true
|
|
validLen = i + 1
|
|
}
|
|
}
|
|
valid = s[:validLen]
|
|
if valid == "" {
|
|
valid = "0"
|
|
}
|
|
if validLen == 0 || validLen != len(s) {
|
|
err = errors.Trace(handleTruncateError(sc))
|
|
}
|
|
return valid, err
|
|
}
|
|
|
|
// ToString converts an interface to a string.
|
|
func ToString(value interface{}) (string, error) {
|
|
switch v := value.(type) {
|
|
case bool:
|
|
if v {
|
|
return "1", nil
|
|
}
|
|
return "0", nil
|
|
case int:
|
|
return strconv.FormatInt(int64(v), 10), nil
|
|
case int64:
|
|
return strconv.FormatInt(v, 10), nil
|
|
case uint64:
|
|
return strconv.FormatUint(v, 10), nil
|
|
case float32:
|
|
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
|
|
case float64:
|
|
return strconv.FormatFloat(v, 'f', -1, 64), nil
|
|
case string:
|
|
return v, nil
|
|
case []byte:
|
|
return string(v), nil
|
|
case Time:
|
|
return v.String(), nil
|
|
case Duration:
|
|
return v.String(), nil
|
|
case *MyDecimal:
|
|
return v.String(), nil
|
|
case BinaryLiteral:
|
|
return v.ToString(), nil
|
|
case Enum:
|
|
return v.String(), nil
|
|
case Set:
|
|
return v.String(), nil
|
|
default:
|
|
return "", errors.Errorf("cannot convert %v(type %T) to string", value, value)
|
|
}
|
|
}
|