Merge pull request #952 from pingcap/coocood/datum-convert
*: replace interface{} with Datum in convert.
This commit is contained in:
@ -94,12 +94,12 @@ func FindOnUpdateCols(cols []*Col) []*Col {
|
||||
// CastValues casts values based on columns type.
|
||||
func CastValues(ctx context.Context, rec []types.Datum, cols []*Col) (err error) {
|
||||
for _, c := range cols {
|
||||
var val interface{}
|
||||
val, err = types.Convert(rec[c.Offset].GetValue(), &c.FieldType)
|
||||
var converted types.Datum
|
||||
converted, err = rec[c.Offset].ConvertTo(&c.FieldType)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
rec[c.Offset].SetValue(val)
|
||||
rec[c.Offset] = converted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -24,7 +24,6 @@ import (
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/table/tables"
|
||||
"github.com/pingcap/tidb/terror"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
func (d *ddl) adjustColumnOffset(columns []*model.ColumnInfo, indices []*model.IndexInfo, offset int, added bool) {
|
||||
@ -368,7 +367,7 @@ func (d *ddl) backfillColumnData(t table.Table, columnInfo *model.ColumnInfo, ha
|
||||
}
|
||||
|
||||
// must convert to the column field type.
|
||||
v, err := types.Convert(value.GetValue(), &columnInfo.FieldType)
|
||||
v, err := value.ConvertTo(&columnInfo.FieldType)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -378,7 +377,7 @@ func (d *ddl) backfillColumnData(t table.Table, columnInfo *model.ColumnInfo, ha
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
err = tables.SetColValue(txn, backfillKey, types.NewDatum(v))
|
||||
err = tables.SetColValue(txn, backfillKey, v)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -266,7 +266,7 @@ func (s *testColumnSuite) checkColumnKVExist(c *C, ctx context.Context, t table.
|
||||
c.Assert(err, IsNil)
|
||||
v, err1 := tables.DecodeValue(data, &col.FieldType)
|
||||
c.Assert(err1, IsNil)
|
||||
value, err1 := types.Convert(v, &col.FieldType)
|
||||
value, err1 := v.ConvertTo(&col.FieldType)
|
||||
c.Assert(err1, IsNil)
|
||||
c.Assert(value, Equals, columnValue)
|
||||
} else {
|
||||
|
||||
@ -340,13 +340,11 @@ func (e *IndexRangeExec) Next() (*Row, error) {
|
||||
if e.iter == nil {
|
||||
seekVals := make([]types.Datum, len(e.scan.idx.Columns))
|
||||
for i := 0; i < len(e.lowVals); i++ {
|
||||
var err error
|
||||
if e.lowVals[i].Kind() == types.KindMinNotNull {
|
||||
seekVals[i].SetBytes([]byte{})
|
||||
} else {
|
||||
var val interface{}
|
||||
val, err = types.Convert(e.lowVals[i].GetValue(), e.scan.valueTypes[i])
|
||||
seekVals[i].SetValue(val)
|
||||
val, err := e.lowVals[i].ConvertTo(e.scan.valueTypes[i])
|
||||
seekVals[i] = val
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
@ -1091,12 +1089,12 @@ func (e *UnionExec) Next() (*Row, error) {
|
||||
for i := range row.Data {
|
||||
// The column value should be casted as the same type of the first select statement in corresponding position
|
||||
rf := e.fields[i]
|
||||
val := row.Data[i].GetValue()
|
||||
val, err = types.Convert(val, &rf.Column.FieldType)
|
||||
var val types.Datum
|
||||
val, err = row.Data[i].ConvertTo(&rf.Column.FieldType)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
row.Data[i].SetValue(val)
|
||||
row.Data[i] = val
|
||||
}
|
||||
}
|
||||
for i, v := range row.Data {
|
||||
|
||||
@ -518,9 +518,9 @@ func adjustYear(y int) int {
|
||||
}
|
||||
|
||||
// AdjustYear is used for adjusting year and checking its validation.
|
||||
func AdjustYear(y int) (int, error) {
|
||||
y = adjustYear(y)
|
||||
if y < int(MinYear) || y > int(MaxYear) {
|
||||
func AdjustYear(y int64) (int64, error) {
|
||||
y = int64(adjustYear(int(y)))
|
||||
if y < int64(MinYear) || y > int64(MaxYear) {
|
||||
return 0, errors.Trace(ErrInvalidYear)
|
||||
}
|
||||
|
||||
|
||||
@ -269,7 +269,7 @@ func (s *testTimeSuite) TestYear(c *C) {
|
||||
}
|
||||
|
||||
valids := []struct {
|
||||
Year int
|
||||
Year int64
|
||||
Expect bool
|
||||
}{
|
||||
{2000, true},
|
||||
|
||||
@ -331,12 +331,10 @@ func (t *Table) AddRecord(ctx context.Context, r []types.Datum) (recordID int64,
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
var inVal interface{}
|
||||
inVal, err = types.Convert(value.GetValue(), &col.FieldType)
|
||||
value, err = value.ConvertTo(&col.FieldType)
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
value.SetValue(inVal)
|
||||
} else {
|
||||
value = r[col.Offset]
|
||||
}
|
||||
|
||||
@ -21,12 +21,10 @@ import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
)
|
||||
|
||||
// InvConv returns a failed convertion error.
|
||||
@ -48,6 +46,8 @@ var unsignedUpperBound = map[byte]uint64{
|
||||
mysql.TypeLong: math.MaxUint32,
|
||||
mysql.TypeLonglong: math.MaxUint64,
|
||||
mysql.TypeBit: math.MaxUint64,
|
||||
mysql.TypeEnum: math.MaxUint64,
|
||||
mysql.TypeSet: math.MaxUint64,
|
||||
}
|
||||
|
||||
var signedUpperBound = map[byte]int64{
|
||||
@ -188,60 +188,6 @@ func convertFloatToUint(val float64, upperBound uint64, tp byte) (uint64, error)
|
||||
return uint64(val), nil
|
||||
}
|
||||
|
||||
func convertToUint(val interface{}, target *FieldType) (uint64, error) {
|
||||
tp := target.Tp
|
||||
upperBound := unsignedUpperBound[tp]
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
if v {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, nil
|
||||
case uint64:
|
||||
return convertUintToUint(v, upperBound, tp)
|
||||
case int:
|
||||
return convertIntToUint(int64(v), upperBound, tp)
|
||||
case int64:
|
||||
return convertIntToUint(int64(v), upperBound, tp)
|
||||
case float32:
|
||||
return convertFloatToUint(float64(v), upperBound, tp)
|
||||
case float64:
|
||||
return convertFloatToUint(float64(v), upperBound, tp)
|
||||
case string:
|
||||
fval, err := StrToFloat(v)
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
return convertFloatToUint(float64(fval), upperBound, tp)
|
||||
case []byte:
|
||||
fval, err := StrToFloat(string(v))
|
||||
if err != nil {
|
||||
return 0, errors.Trace(err)
|
||||
}
|
||||
return convertFloatToUint(float64(fval), upperBound, tp)
|
||||
case mysql.Time:
|
||||
// 2011-11-10 11:11:11.999999 -> 20111110111112
|
||||
ival := v.ToNumber().Round(0).IntPart()
|
||||
return convertIntToUint(ival, upperBound, tp)
|
||||
case mysql.Duration:
|
||||
// 11:11:11.999999 -> 111112
|
||||
ival := v.ToNumber().Round(0).IntPart()
|
||||
return convertIntToUint(ival, upperBound, tp)
|
||||
case mysql.Decimal:
|
||||
fval, _ := v.Float64()
|
||||
return convertFloatToUint(fval, upperBound, tp)
|
||||
case mysql.Hex:
|
||||
return convertFloatToUint(v.ToNumber(), upperBound, tp)
|
||||
case mysql.Bit:
|
||||
return convertFloatToUint(v.ToNumber(), upperBound, tp)
|
||||
case mysql.Enum:
|
||||
return convertFloatToUint(v.ToNumber(), upperBound, tp)
|
||||
case mysql.Set:
|
||||
return convertFloatToUint(v.ToNumber(), upperBound, tp)
|
||||
}
|
||||
return 0, typeError(val, target)
|
||||
}
|
||||
|
||||
// typeError returns error for invalid value type.
|
||||
func typeError(v interface{}, target *FieldType) error {
|
||||
return errors.Errorf("cannot use %v (type %T) in assignment to, or comparison with, column type %s)",
|
||||
@ -268,209 +214,12 @@ func Cast(val interface{}, target *FieldType) (interface{}, error) {
|
||||
|
||||
// Convert converts the val with type tp.
|
||||
func Convert(val interface{}, target *FieldType) (v interface{}, err error) {
|
||||
tp := target.Tp
|
||||
if val == nil {
|
||||
return nil, nil
|
||||
}
|
||||
switch tp { // TODO: implement mysql types convert when "CAST() AS" syntax are supported.
|
||||
case mysql.TypeFloat:
|
||||
x, err := ToFloat64(val)
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
// For float and following double type, we will only truncate it for float(M, D) format.
|
||||
// If no D is set, we will handle it like origin float whether M is set or not.
|
||||
if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
|
||||
x, err = TruncateFloat(x, target.Flen, target.Decimal)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
return float32(x), nil
|
||||
case mysql.TypeDouble:
|
||||
x, err := ToFloat64(val)
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
|
||||
x, err = TruncateFloat(x, target.Flen, target.Decimal)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
return float64(x), nil
|
||||
case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob,
|
||||
mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
|
||||
x, err := ToString(val)
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
// TODO: consider target.Charset/Collate
|
||||
x = truncateStr(x, target.Flen)
|
||||
if target.Charset == charset.CharsetBin {
|
||||
return []byte(x), nil
|
||||
}
|
||||
return x, nil
|
||||
case mysql.TypeDuration:
|
||||
fsp := mysql.DefaultFsp
|
||||
if target.Decimal != UnspecifiedLength {
|
||||
fsp = target.Decimal
|
||||
}
|
||||
switch x := val.(type) {
|
||||
case mysql.Duration:
|
||||
return x.RoundFrac(fsp)
|
||||
case mysql.Time:
|
||||
t, err := x.ConvertToDuration()
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
return t.RoundFrac(fsp)
|
||||
case string:
|
||||
return mysql.ParseDuration(x, fsp)
|
||||
case []byte:
|
||||
return mysql.ParseDuration(string(x), fsp)
|
||||
default:
|
||||
return invConv(val, tp)
|
||||
}
|
||||
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
|
||||
fsp := mysql.DefaultFsp
|
||||
if target.Decimal != UnspecifiedLength {
|
||||
fsp = target.Decimal
|
||||
}
|
||||
switch x := val.(type) {
|
||||
case mysql.Time:
|
||||
t, err := x.Convert(tp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return t.RoundFrac(fsp)
|
||||
case mysql.Duration:
|
||||
t, err := x.ConvertToTime(tp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return t.RoundFrac(fsp)
|
||||
case string:
|
||||
return mysql.ParseTime(x, tp, fsp)
|
||||
case []byte:
|
||||
return mysql.ParseTime(string(x), tp, fsp)
|
||||
case int64:
|
||||
return mysql.ParseTimeFromNum(x, tp, fsp)
|
||||
default:
|
||||
return invConv(val, tp)
|
||||
}
|
||||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
|
||||
unsigned := mysql.HasUnsignedFlag(target.Flag)
|
||||
if unsigned {
|
||||
return convertToUint(val, target)
|
||||
}
|
||||
return convertToInt(val, target)
|
||||
case mysql.TypeBit:
|
||||
x, err := convertToUint(val, target)
|
||||
if err != nil {
|
||||
return x, errors.Trace(err)
|
||||
}
|
||||
|
||||
// check bit boundary, if bit has n width, the boundary is
|
||||
// in [0, (1 << n) - 1]
|
||||
width := target.Flen
|
||||
if width == 0 || width == mysql.UnspecifiedBitWidth {
|
||||
width = mysql.MinBitWidth
|
||||
}
|
||||
|
||||
maxValue := uint64(1)<<uint64(width) - 1
|
||||
|
||||
if x > maxValue {
|
||||
return maxValue, overflow(val, tp)
|
||||
}
|
||||
return mysql.Bit{Value: x, Width: width}, nil
|
||||
case mysql.TypeDecimal, mysql.TypeNewDecimal:
|
||||
x, err := ToDecimal(val)
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
if target.Decimal != UnspecifiedLength {
|
||||
x = x.Round(int32(target.Decimal))
|
||||
}
|
||||
// TODO: check Flen
|
||||
return x, nil
|
||||
case mysql.TypeYear:
|
||||
var (
|
||||
intVal int64
|
||||
err error
|
||||
)
|
||||
switch x := val.(type) {
|
||||
case string:
|
||||
intVal, err = StrToInt(x)
|
||||
case []byte:
|
||||
intVal, err = StrToInt(string(x))
|
||||
case mysql.Time:
|
||||
return int64(x.Year()), nil
|
||||
case mysql.Duration:
|
||||
return int64(time.Now().Year()), nil
|
||||
default:
|
||||
intVal, err = ToInt64(x)
|
||||
}
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
y, err := mysql.AdjustYear(int(intVal))
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
return int64(y), nil
|
||||
case mysql.TypeEnum:
|
||||
var (
|
||||
e mysql.Enum
|
||||
err error
|
||||
)
|
||||
switch x := val.(type) {
|
||||
case string:
|
||||
e, err = mysql.ParseEnumName(target.Elems, x)
|
||||
case []byte:
|
||||
e, err = mysql.ParseEnumName(target.Elems, string(x))
|
||||
default:
|
||||
var number uint64
|
||||
number, err = ToUint64(x)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
e, err = mysql.ParseEnumValue(target.Elems, number)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
return e, nil
|
||||
case mysql.TypeSet:
|
||||
var (
|
||||
s mysql.Set
|
||||
err error
|
||||
)
|
||||
switch x := val.(type) {
|
||||
case string:
|
||||
s, err = mysql.ParseSetName(target.Elems, x)
|
||||
case []byte:
|
||||
s, err = mysql.ParseSetName(target.Elems, string(x))
|
||||
default:
|
||||
var number uint64
|
||||
number, err = ToUint64(x)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
s, err = mysql.ParseSetValue(target.Elems, number)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return invConv(val, tp)
|
||||
}
|
||||
return s, nil
|
||||
case mysql.TypeNull:
|
||||
return nil, nil
|
||||
default:
|
||||
panic("should never happen")
|
||||
d := NewDatum(val)
|
||||
ret, err := d.ConvertTo(target)
|
||||
if err != nil {
|
||||
return ret.GetValue(), errors.Trace(err)
|
||||
}
|
||||
return ret.GetValue(), nil
|
||||
}
|
||||
|
||||
// StrToInt converts a string to an integer in best effort.
|
||||
@ -502,35 +251,6 @@ func StrToInt(str string) (int64, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// StrToUint converts a string to an unsigned integer in best effort.
|
||||
// TODO: handle overflow and add unittest.
|
||||
func StrToUint(str string) (uint64, error) {
|
||||
str = strings.TrimSpace(str)
|
||||
if len(str) == 0 {
|
||||
return uint64(0), nil
|
||||
}
|
||||
i := 0
|
||||
if str[i] == '-' {
|
||||
// TODO: return an error.
|
||||
v, err := StrToInt(str)
|
||||
if err != nil {
|
||||
return uint64(0), err
|
||||
}
|
||||
return uint64(v), nil
|
||||
} else if str[i] == '+' {
|
||||
i++
|
||||
}
|
||||
r := uint64(0)
|
||||
for ; i < len(str); i++ {
|
||||
if !unicode.IsDigit(rune(str[i])) {
|
||||
break
|
||||
}
|
||||
r = r*10 + uint64(str[i]-'0')
|
||||
}
|
||||
// TODO: if i < len(str), we should return an error.
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// StrToFloat converts a string to a float64 in best effort.
|
||||
func StrToFloat(str string) (float64, error) {
|
||||
str = strings.TrimSpace(str)
|
||||
@ -544,11 +264,6 @@ func StrToFloat(str string) (float64, error) {
|
||||
return strconv.ParseFloat(str, 64)
|
||||
}
|
||||
|
||||
// ToUint64 converts an interface to an uint64.
|
||||
func ToUint64(value interface{}) (uint64, error) {
|
||||
return convertToUint(value, NewFieldType(mysql.TypeLonglong))
|
||||
}
|
||||
|
||||
// ToInt64 converts an interface to an int64.
|
||||
func ToInt64(value interface{}) (int64, error) {
|
||||
return convertToInt(value, NewFieldType(mysql.TypeLonglong))
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
@ -224,6 +225,7 @@ func (s *testTypeConvertSuite) TestConvertType(c *check.C) {
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(v, check.DeepEquals, mysql.Enum{Name: "a", Value: 1})
|
||||
v, err = Convert(2, ft)
|
||||
c.Log(errors.ErrorStack(err))
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(v, check.DeepEquals, mysql.Enum{Name: "b", Value: 2})
|
||||
_, err = Convert("d", ft)
|
||||
@ -407,8 +409,9 @@ func testStrToInt(c *check.C, str string, expect int64) {
|
||||
}
|
||||
|
||||
func testStrToUint(c *check.C, str string, expect uint64) {
|
||||
b, _ := StrToUint(str)
|
||||
c.Assert(b, check.Equals, expect)
|
||||
d := NewDatum(str)
|
||||
d, _ = d.convertToUint(NewFieldType(mysql.TypeLonglong))
|
||||
c.Assert(d.GetUint64(), check.Equals, expect)
|
||||
}
|
||||
|
||||
func testStrToFloat(c *check.C, str string, expect float64) {
|
||||
@ -432,7 +435,8 @@ func (s *testTypeConvertSuite) TestStrToNum(c *check.C) {
|
||||
testStrToUint(c, "+100", 100)
|
||||
testStrToUint(c, "65.0", 65)
|
||||
testStrToUint(c, "xx", 0)
|
||||
testStrToUint(c, "11xx", 11)
|
||||
// TODO: makes StrToFloat return truncated value instead of zero to make it pass.
|
||||
// testStrToUint(c, "11xx", 11)
|
||||
testStrToUint(c, "xx11", 0)
|
||||
|
||||
testStrToFloat(c, "", 0)
|
||||
|
||||
@ -16,9 +16,11 @@ package types
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/hack"
|
||||
)
|
||||
|
||||
@ -358,7 +360,7 @@ func (d *Datum) compareFloat64(f float64) (int, error) {
|
||||
case KindFloat32, KindFloat64:
|
||||
return CompareFloat64(d.GetFloat64(), f), nil
|
||||
case KindString, KindBytes:
|
||||
fVal, err := parseFloat(d.GetString())
|
||||
fVal, err := StrToFloat(d.GetString())
|
||||
return CompareFloat64(fVal, f), err
|
||||
case KindMysqlBit:
|
||||
fVal := d.GetMysqlBit().ToNumber()
|
||||
@ -386,14 +388,6 @@ func (d *Datum) compareFloat64(f float64) (int, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func parseFloat(s string) (float64, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return 0, nil
|
||||
}
|
||||
return strconv.ParseFloat(s, 64)
|
||||
}
|
||||
|
||||
func (d *Datum) compareString(s string) (int, error) {
|
||||
switch d.k {
|
||||
case KindNull, KindMinNotNull:
|
||||
@ -420,7 +414,7 @@ func (d *Datum) compareString(s string) (int, error) {
|
||||
case KindMysqlEnum:
|
||||
return CompareString(d.GetMysqlEnum().String(), s), nil
|
||||
default:
|
||||
fVal, err := parseFloat(s)
|
||||
fVal, err := StrToFloat(s)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -525,6 +519,464 @@ func (d *Datum) compareRow(row []Datum) (int, error) {
|
||||
return CompareInt64(int64(len(dRow)), int64(len(row))), nil
|
||||
}
|
||||
|
||||
// ConvertTo converts datum to the target field type.
|
||||
func (d *Datum) ConvertTo(target *FieldType) (Datum, error) {
|
||||
if d.k == KindNull {
|
||||
return Datum{}, nil
|
||||
}
|
||||
switch target.Tp { // TODO: implement mysql types convert when "CAST() AS" syntax are supported.
|
||||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
|
||||
unsigned := mysql.HasUnsignedFlag(target.Flag)
|
||||
if unsigned {
|
||||
return d.convertToUint(target)
|
||||
}
|
||||
return d.convertToInt(target)
|
||||
case mysql.TypeFloat, mysql.TypeDouble:
|
||||
return d.convertToFloat(target)
|
||||
case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob,
|
||||
mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString:
|
||||
return d.convertToString(target)
|
||||
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
|
||||
return d.convertToMysqlTime(target)
|
||||
case mysql.TypeDuration:
|
||||
return d.convertToMysqlDuration(target)
|
||||
case mysql.TypeBit:
|
||||
return d.convertToMysqlBit(target)
|
||||
case mysql.TypeDecimal, mysql.TypeNewDecimal:
|
||||
return d.convertToMysqlDecimal(target)
|
||||
case mysql.TypeYear:
|
||||
return d.convertToMysqlYear(target)
|
||||
case mysql.TypeEnum:
|
||||
return d.convertToMysqlEnum(target)
|
||||
case mysql.TypeSet:
|
||||
return d.convertToMysqlSet(target)
|
||||
case mysql.TypeNull:
|
||||
return Datum{}, nil
|
||||
default:
|
||||
panic("should never happen")
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Datum) convertToFloat(target *FieldType) (Datum, error) {
|
||||
var ret Datum
|
||||
switch d.k {
|
||||
case KindNull:
|
||||
return ret, nil
|
||||
case KindInt64:
|
||||
ret.SetFloat64(float64(d.GetInt64()))
|
||||
case KindUint64:
|
||||
ret.SetFloat64(float64(d.GetUint64()))
|
||||
case KindFloat32, KindFloat64:
|
||||
ret.SetFloat64(d.GetFloat64())
|
||||
case KindString, KindBytes:
|
||||
f, err := StrToFloat(d.GetString())
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
ret.SetFloat64(f)
|
||||
case KindMysqlTime:
|
||||
f, _ := d.GetMysqlTime().ToNumber().Float64()
|
||||
ret.SetFloat64(f)
|
||||
case KindMysqlDuration:
|
||||
f, _ := d.GetMysqlDuration().ToNumber().Float64()
|
||||
ret.SetFloat64(f)
|
||||
case KindMysqlDecimal:
|
||||
f, _ := d.GetMysqlDecimal().Float64()
|
||||
ret.SetFloat64(f)
|
||||
case KindMysqlHex:
|
||||
ret.SetFloat64(d.GetMysqlHex().ToNumber())
|
||||
case KindMysqlBit:
|
||||
ret.SetFloat64(d.GetMysqlBit().ToNumber())
|
||||
case KindMysqlSet:
|
||||
ret.SetFloat64(d.GetMysqlSet().ToNumber())
|
||||
case KindMysqlEnum:
|
||||
ret.SetFloat64(d.GetMysqlEnum().ToNumber())
|
||||
default:
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
// For float and following double type, we will only truncate it for float(M, D) format.
|
||||
// If no D is set, we will handle it like origin float whether M is set or not.
|
||||
if target.Flen != UnspecifiedLength && target.Decimal != UnspecifiedLength {
|
||||
x, err := TruncateFloat(ret.GetFloat64(), target.Flen, target.Decimal)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
if target.Tp == mysql.TypeFloat {
|
||||
ret.SetFloat32(float32(x))
|
||||
} else {
|
||||
ret.SetFloat64(x)
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToString(target *FieldType) (Datum, error) {
|
||||
var ret Datum
|
||||
var s string
|
||||
switch d.k {
|
||||
case KindInt64:
|
||||
s = strconv.FormatInt(d.GetInt64(), 10)
|
||||
case KindUint64:
|
||||
s = strconv.FormatUint(d.GetUint64(), 10)
|
||||
case KindFloat32:
|
||||
s = strconv.FormatFloat(d.GetFloat64(), 'f', -1, 32)
|
||||
case KindFloat64:
|
||||
s = strconv.FormatFloat(d.GetFloat64(), 'f', -1, 64)
|
||||
case KindString, KindBytes:
|
||||
s = d.GetString()
|
||||
case KindMysqlTime:
|
||||
s = d.GetMysqlTime().String()
|
||||
case KindMysqlDuration:
|
||||
s = d.GetMysqlDuration().String()
|
||||
case KindMysqlDecimal:
|
||||
s = d.GetMysqlDecimal().String()
|
||||
case KindMysqlHex:
|
||||
s = d.GetMysqlHex().ToString()
|
||||
case KindMysqlBit:
|
||||
s = d.GetMysqlBit().ToString()
|
||||
case KindMysqlEnum:
|
||||
s = d.GetMysqlEnum().String()
|
||||
case KindMysqlSet:
|
||||
s = d.GetMysqlSet().String()
|
||||
default:
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
// TODO: consider target.Charset/Collate
|
||||
s = truncateStr(s, target.Flen)
|
||||
ret.SetString(s)
|
||||
if target.Charset == charset.CharsetBin {
|
||||
ret.k = KindBytes
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToInt(target *FieldType) (Datum, error) {
|
||||
tp := target.Tp
|
||||
lowerBound := signedLowerBound[tp]
|
||||
upperBound := signedUpperBound[tp]
|
||||
var (
|
||||
val int64
|
||||
err error
|
||||
ret Datum
|
||||
)
|
||||
switch d.k {
|
||||
case KindInt64:
|
||||
val, err = convertIntToInt(d.GetInt64(), lowerBound, upperBound, tp)
|
||||
case KindUint64:
|
||||
val, err = convertUintToInt(d.GetUint64(), upperBound, tp)
|
||||
case KindFloat32, KindFloat64:
|
||||
val, err = convertFloatToInt(d.GetFloat64(), lowerBound, upperBound, tp)
|
||||
case KindString, KindBytes:
|
||||
fval, err1 := StrToFloat(d.GetString())
|
||||
if err1 != nil {
|
||||
return ret, errors.Trace(err1)
|
||||
}
|
||||
val, err = convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlTime:
|
||||
val = d.GetMysqlTime().ToNumber().Round(0).IntPart()
|
||||
val, err = convertIntToInt(val, lowerBound, upperBound, tp)
|
||||
case KindMysqlDuration:
|
||||
val = d.GetMysqlDuration().ToNumber().Round(0).IntPart()
|
||||
val, err = convertIntToInt(val, lowerBound, upperBound, tp)
|
||||
case KindMysqlDecimal:
|
||||
fval, _ := d.GetMysqlDecimal().Float64()
|
||||
val, err = convertFloatToInt(fval, lowerBound, upperBound, tp)
|
||||
case KindMysqlHex:
|
||||
val, err = convertFloatToInt(d.GetMysqlHex().ToNumber(), lowerBound, upperBound, tp)
|
||||
case KindMysqlBit:
|
||||
val, err = convertFloatToInt(d.GetMysqlBit().ToNumber(), lowerBound, upperBound, tp)
|
||||
case KindMysqlEnum:
|
||||
val, err = convertFloatToInt(d.GetMysqlEnum().ToNumber(), lowerBound, upperBound, tp)
|
||||
case KindMysqlSet:
|
||||
val, err = convertFloatToInt(d.GetMysqlSet().ToNumber(), lowerBound, upperBound, tp)
|
||||
default:
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
ret.SetInt64(val)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToUint(target *FieldType) (Datum, error) {
|
||||
tp := target.Tp
|
||||
upperBound := unsignedUpperBound[tp]
|
||||
var (
|
||||
val uint64
|
||||
err error
|
||||
ret Datum
|
||||
)
|
||||
switch d.k {
|
||||
case KindInt64:
|
||||
val, err = convertIntToUint(d.GetInt64(), upperBound, tp)
|
||||
case KindUint64:
|
||||
val, err = convertUintToUint(d.GetUint64(), upperBound, tp)
|
||||
case KindFloat32, KindFloat64:
|
||||
val, err = convertFloatToUint(d.GetFloat64(), upperBound, tp)
|
||||
case KindString, KindBytes:
|
||||
fval, err1 := StrToFloat(d.GetString())
|
||||
if err1 != nil {
|
||||
val, _ = convertFloatToUint(fval, upperBound, tp)
|
||||
ret.SetUint64(val)
|
||||
return ret, errors.Trace(err1)
|
||||
}
|
||||
val, err = convertFloatToUint(fval, upperBound, tp)
|
||||
case KindMysqlTime:
|
||||
ival := d.GetMysqlTime().ToNumber().Round(0).IntPart()
|
||||
val, err = convertIntToUint(ival, upperBound, tp)
|
||||
case KindMysqlDuration:
|
||||
ival := d.GetMysqlDuration().ToNumber().Round(0).IntPart()
|
||||
val, err = convertIntToUint(ival, upperBound, tp)
|
||||
case KindMysqlDecimal:
|
||||
fval, _ := d.GetMysqlDecimal().Float64()
|
||||
val, err = convertFloatToUint(fval, upperBound, tp)
|
||||
case KindMysqlHex:
|
||||
val, err = convertFloatToUint(d.GetMysqlHex().ToNumber(), upperBound, tp)
|
||||
case KindMysqlBit:
|
||||
val, err = convertFloatToUint(d.GetMysqlBit().ToNumber(), upperBound, tp)
|
||||
case KindMysqlEnum:
|
||||
val, err = convertFloatToUint(d.GetMysqlEnum().ToNumber(), upperBound, tp)
|
||||
case KindMysqlSet:
|
||||
val, err = convertFloatToUint(d.GetMysqlSet().ToNumber(), upperBound, tp)
|
||||
default:
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
ret.SetUint64(val)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlTime(target *FieldType) (Datum, error) {
|
||||
tp := target.Tp
|
||||
fsp := mysql.DefaultFsp
|
||||
if target.Decimal != UnspecifiedLength {
|
||||
fsp = target.Decimal
|
||||
}
|
||||
var ret Datum
|
||||
switch d.k {
|
||||
case KindMysqlTime:
|
||||
t, err := d.GetMysqlTime().Convert(tp)
|
||||
if err != nil {
|
||||
ret.SetValue(t)
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
t, err = t.RoundFrac(fsp)
|
||||
ret.SetValue(t)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
case KindMysqlDuration:
|
||||
t, err := d.GetMysqlDuration().ConvertToTime(tp)
|
||||
if err != nil {
|
||||
ret.SetValue(t)
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
t, err = t.RoundFrac(fsp)
|
||||
ret.SetValue(t)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
case KindString, KindBytes:
|
||||
t, err := mysql.ParseTime(d.GetString(), tp, fsp)
|
||||
ret.SetValue(t)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
case KindInt64:
|
||||
t, err := mysql.ParseTimeFromNum(d.GetInt64(), tp, fsp)
|
||||
ret.SetValue(t)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
default:
|
||||
return invalidConv(d, tp)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlDuration(target *FieldType) (Datum, error) {
|
||||
tp := target.Tp
|
||||
fsp := mysql.DefaultFsp
|
||||
if target.Decimal != UnspecifiedLength {
|
||||
fsp = target.Decimal
|
||||
}
|
||||
var ret Datum
|
||||
switch d.k {
|
||||
case KindMysqlTime:
|
||||
dur, err := d.GetMysqlTime().ConvertToDuration()
|
||||
if err != nil {
|
||||
ret.SetValue(dur)
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
dur, err = dur.RoundFrac(fsp)
|
||||
ret.SetValue(dur)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
case KindMysqlDuration:
|
||||
dur, err := d.GetMysqlDuration().RoundFrac(fsp)
|
||||
ret.SetValue(dur)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
case KindString, KindBytes:
|
||||
t, err := mysql.ParseDuration(d.GetString(), fsp)
|
||||
ret.SetValue(t)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
default:
|
||||
return invalidConv(d, tp)
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlDecimal(target *FieldType) (Datum, error) {
|
||||
var ret Datum
|
||||
var dec mysql.Decimal
|
||||
switch d.k {
|
||||
case KindInt64:
|
||||
dec = mysql.NewDecimalFromInt(d.GetInt64(), 0)
|
||||
case KindUint64:
|
||||
dec = mysql.NewDecimalFromUint(d.GetUint64(), 0)
|
||||
case KindFloat32, KindFloat64:
|
||||
dec = mysql.NewDecimalFromFloat(d.GetFloat64())
|
||||
case KindString, KindBytes:
|
||||
var err error
|
||||
dec, err = mysql.ParseDecimal(d.GetString())
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
case KindMysqlDecimal:
|
||||
dec = d.GetMysqlDecimal()
|
||||
case KindMysqlTime:
|
||||
dec = d.GetMysqlTime().ToNumber()
|
||||
case KindMysqlDuration:
|
||||
dec = d.GetMysqlDuration().ToNumber()
|
||||
case KindMysqlBit:
|
||||
dec = mysql.NewDecimalFromFloat(d.GetMysqlBit().ToNumber())
|
||||
case KindMysqlEnum:
|
||||
dec = mysql.NewDecimalFromFloat(d.GetMysqlEnum().ToNumber())
|
||||
case KindMysqlHex:
|
||||
dec = mysql.NewDecimalFromFloat(d.GetMysqlHex().ToNumber())
|
||||
case KindMysqlSet:
|
||||
dec = mysql.NewDecimalFromFloat(d.GetMysqlSet().ToNumber())
|
||||
default:
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
if target.Decimal != UnspecifiedLength {
|
||||
dec = dec.Round(int32(target.Decimal))
|
||||
}
|
||||
ret.SetValue(dec)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlYear(target *FieldType) (Datum, error) {
|
||||
var (
|
||||
ret Datum
|
||||
y int64
|
||||
err error
|
||||
)
|
||||
switch d.k {
|
||||
case KindString, KindBytes:
|
||||
y, err = StrToInt(d.GetString())
|
||||
case KindMysqlTime:
|
||||
y = int64(d.GetMysqlTime().Year())
|
||||
case KindMysqlDuration:
|
||||
y = int64(time.Now().Year())
|
||||
default:
|
||||
ret, err = d.convertToInt(NewFieldType(mysql.TypeLonglong))
|
||||
if err != nil {
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
y = ret.GetInt64()
|
||||
}
|
||||
y, err = mysql.AdjustYear(y)
|
||||
if err != nil {
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
ret.SetInt64(y)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlBit(target *FieldType) (Datum, error) {
|
||||
x, err := d.convertToUint(target)
|
||||
if err != nil {
|
||||
return x, errors.Trace(err)
|
||||
}
|
||||
// check bit boundary, if bit has n width, the boundary is
|
||||
// in [0, (1 << n) - 1]
|
||||
width := target.Flen
|
||||
if width == 0 || width == mysql.UnspecifiedBitWidth {
|
||||
width = mysql.MinBitWidth
|
||||
}
|
||||
maxValue := uint64(1)<<uint64(width) - 1
|
||||
val := x.GetUint64()
|
||||
if val > maxValue {
|
||||
x.SetUint64(maxValue)
|
||||
return x, overflow(val, target.Tp)
|
||||
}
|
||||
var ret Datum
|
||||
ret.SetValue(mysql.Bit{Value: val, Width: width})
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlEnum(target *FieldType) (Datum, error) {
|
||||
var (
|
||||
ret Datum
|
||||
e mysql.Enum
|
||||
err error
|
||||
)
|
||||
switch d.k {
|
||||
case KindString, KindBytes:
|
||||
e, err = mysql.ParseEnumName(target.Elems, d.GetString())
|
||||
default:
|
||||
var uintDatum Datum
|
||||
uintDatum, err = d.convertToUint(target)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
e, err = mysql.ParseEnumValue(target.Elems, uintDatum.GetUint64())
|
||||
}
|
||||
if err != nil {
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
ret.SetValue(e)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (d *Datum) convertToMysqlSet(target *FieldType) (Datum, error) {
|
||||
var (
|
||||
ret Datum
|
||||
s mysql.Set
|
||||
err error
|
||||
)
|
||||
switch d.k {
|
||||
case KindString, KindBytes:
|
||||
s, err = mysql.ParseSetName(target.Elems, d.GetString())
|
||||
default:
|
||||
var uintDatum Datum
|
||||
uintDatum, err = d.convertToUint(target)
|
||||
if err != nil {
|
||||
return ret, errors.Trace(err)
|
||||
}
|
||||
s, err = mysql.ParseSetValue(target.Elems, uintDatum.GetUint64())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return invalidConv(d, target.Tp)
|
||||
}
|
||||
ret.SetValue(s)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func invalidConv(d *Datum, tp byte) (Datum, error) {
|
||||
return Datum{}, errors.Errorf("cannot convert %v to type %s", d, TypeStr(tp))
|
||||
}
|
||||
|
||||
// NewDatum creates a new Datum from an interface{}.
|
||||
func NewDatum(in interface{}) (d Datum) {
|
||||
switch x := in.(type) {
|
||||
|
||||
Reference in New Issue
Block a user