From d8c2c96c8e7ca4bd4cf5259c46214da3e6d3ba02 Mon Sep 17 00:00:00 2001 From: siddontang Date: Thu, 24 Sep 2015 10:56:07 +0800 Subject: [PATCH] types: support enum type --- util/types/compare.go | 24 ++++++++++++++++++++++++ util/types/compare_test.go | 9 +++++++++ util/types/convert.go | 37 +++++++++++++++++++++++++++++++++++++ util/types/convert_test.go | 18 ++++++++++++++++++ util/types/etc.go | 8 ++++++-- util/types/etc_test.go | 1 + 6 files changed, 95 insertions(+), 2 deletions(-) diff --git a/util/types/compare.go b/util/types/compare.go index e152f4a2ca..f9cda12142 100644 --- a/util/types/compare.go +++ b/util/types/compare.go @@ -187,6 +187,8 @@ func Compare(a, b interface{}) (int, error) { return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Bit: return CompareFloat64(float64(x), y.ToNumber()), nil + case mysql.Enum: + return CompareFloat64(float64(x), y.ToNumber()), nil } case uint64: switch y := b.(type) { @@ -200,6 +202,9 @@ func Compare(a, b interface{}) (int, error) { return CompareFloat64(float64(x), y.ToNumber()), nil case mysql.Bit: return CompareFloat64(float64(x), y.ToNumber()), nil + case mysql.Enum: + return CompareFloat64(float64(x), y.ToNumber()), nil + } case mysql.Decimal: switch y := b.(type) { @@ -238,6 +243,8 @@ func Compare(a, b interface{}) (int, error) { return CompareString(x, y.ToString()), nil case mysql.Bit: return CompareString(x, y.ToString()), nil + case mysql.Enum: + return CompareString(x, y.String()), nil } case mysql.Time: switch y := b.(type) { @@ -263,6 +270,8 @@ func Compare(a, b interface{}) (int, error) { return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case string: return CompareString(x.ToString(), y), nil + case mysql.Enum: + return CompareFloat64(x.ToNumber(), y.ToNumber()), nil } case mysql.Bit: switch y := b.(type) { @@ -274,6 +283,21 @@ func Compare(a, b interface{}) (int, error) { return CompareFloat64(x.ToNumber(), y.ToNumber()), nil case string: return CompareString(x.ToString(), y), nil + case mysql.Enum: + return CompareFloat64(x.ToNumber(), y.ToNumber()), nil + } + case mysql.Enum: + switch y := b.(type) { + case int64: + return CompareFloat64(x.ToNumber(), float64(y)), nil + case uint64: + return CompareFloat64(x.ToNumber(), float64(y)), nil + case mysql.Hex: + return CompareFloat64(x.ToNumber(), y.ToNumber()), nil + case string: + return CompareString(x.String(), y), nil + case mysql.Bit: + return CompareFloat64(x.ToNumber(), y.ToNumber()), nil } } diff --git a/util/types/compare_test.go b/util/types/compare_test.go index 391fb4e774..d207cd20e0 100644 --- a/util/types/compare_test.go +++ b/util/types/compare_test.go @@ -140,6 +140,15 @@ func (s *testCompareSuite) TestCompare(c *C) { {mysql.Bit{Value: 1, Width: 1}, float64(0), 1}, {mysql.Bit{Value: 1, Width: 1}, mysql.NewDecimalFromInt(1, 0), 0}, {mysql.Bit{Value: 1, Width: 1}, mysql.Hex{Value: 2}, -1}, + + {mysql.Enum{Name: "a", Value: 1}, 1, 0}, + {mysql.Enum{Name: "a", Value: 1}, "a", 0}, + {mysql.Enum{Name: "a", Value: 1}, uint64(10), -1}, + {mysql.Enum{Name: "a", Value: 1}, float64(0), 1}, + {mysql.Enum{Name: "a", Value: 1}, mysql.NewDecimalFromInt(1, 0), 0}, + {mysql.Enum{Name: "a", Value: 1}, mysql.Hex{Value: 2}, -1}, + {mysql.Enum{Name: "a", Value: 1}, mysql.Bit{Value: 1, Width: 1}, 0}, + {mysql.Enum{Name: "a", Value: 1}, mysql.Hex{Value: 1}, 0}, } for _, t := range cmpTbl { diff --git a/util/types/convert.go b/util/types/convert.go index edf9ddc4ed..487c7122dc 100644 --- a/util/types/convert.go +++ b/util/types/convert.go @@ -130,6 +130,8 @@ func convertToInt(val interface{}, target *FieldType) (converted int64, err erro return convertFloatToInt(v.ToNumber(), lowerBound, upperBound, tp) case mysql.Bit: return convertFloatToInt(v.ToNumber(), lowerBound, upperBound, tp) + case mysql.Enum: + return convertFloatToInt(v.ToNumber(), lowerBound, upperBound, tp) } return 0, typeError(val, target) } @@ -198,6 +200,8 @@ func convertToUint(val interface{}, target *FieldType) (converted uint64, err er return convertFloatToUint(v.ToNumber(), upperBound, tp) case mysql.Bit: return convertFloatToUint(v.ToNumber(), upperBound, tp) + case mysql.Enum: + return convertFloatToUint(v.ToNumber(), upperBound, tp) } return 0, typeError(val, target) } @@ -423,6 +427,29 @@ func Convert(val interface{}, target *FieldType) (v interface{}, err error) { // return invConv(val, tp) } return int64(y), nil + case mysql.TypeEnum: + var ( + e mysql.Enum + err error + ) + switch x := val.(type) { + case string: + e, err = mysql.ParseEnumName(target.Elems, x) + case []byte: + e, err = mysql.ParseEnumName(target.Elems, string(x)) + default: + var number uint64 + number, err = ToUint64(x) + if err != nil { + return nil, errors.Trace(err) + } + e, err = mysql.ParseEnumValue(target.Elems, number) + } + + if err != nil { + return invConv(val, tp) + } + return e, nil default: panic("should never happen") } @@ -531,6 +558,8 @@ func ToUint64(value interface{}) (uint64, error) { return uint64(v.ToNumber()), nil case mysql.Bit: return uint64(v.ToNumber()), nil + case mysql.Enum: + return uint64(v.ToNumber()), nil default: return 0, errors.Errorf("cannot convert %v(type %T) to int64", value, value) } @@ -573,6 +602,8 @@ func ToInt64(value interface{}) (int64, error) { return int64(v.ToNumber()), nil case mysql.Bit: return int64(v.ToNumber()), nil + case mysql.Enum: + return int64(v.ToNumber()), nil default: return 0, errors.Errorf("cannot convert %v(type %T) to int64", value, value) } @@ -613,6 +644,8 @@ func ToFloat64(value interface{}) (float64, error) { return v.ToNumber(), nil case mysql.Bit: return v.ToNumber(), nil + case mysql.Enum: + return v.ToNumber(), nil default: return 0, errors.Errorf("cannot convert %v(type %T) to float64", value, value) } @@ -670,6 +703,8 @@ func ToString(value interface{}) (string, error) { return v.ToString(), nil case mysql.Bit: return v.ToString(), nil + case mysql.Enum: + return v.String(), nil default: return "", errors.Errorf("cannot convert %v(type %T) to string", value, value) } @@ -723,6 +758,8 @@ func ToBool(value interface{}) (int64, error) { isZero = (v.ToNumber() == 0) case mysql.Bit: isZero = (v.ToNumber() == 0) + case mysql.Enum: + isZero = (v.ToNumber() == 0) default: return 0, errors.Errorf("cannot convert %v(type %T) to bool", value, value) } diff --git a/util/types/convert_test.go b/util/types/convert_test.go index ba8cabbe48..a2eeff18d8 100644 --- a/util/types/convert_test.go +++ b/util/types/convert_test.go @@ -217,6 +217,20 @@ func (s *testTypeConvertSuite) TestConvertType(c *C) { c.Assert(v, Equals, int64(2015)) v, err = Convert(mysql.ZeroDuration, ft) c.Assert(v, Equals, int64(time.Now().Year())) + + // For enum + ft = NewFieldType(mysql.TypeEnum) + ft.Elems = []string{"a", "b", "c"} + v, err = Convert("a", ft) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, mysql.Enum{Name: "a", Value: 1}) + v, err = Convert(2, ft) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, mysql.Enum{Name: "b", Value: 2}) + _, err = Convert("d", ft) + c.Assert(err, NotNil) + _, err = Convert(4, ft) + c.Assert(err, NotNil) } func testToInt64(c *C, val interface{}, expect int64) { @@ -234,6 +248,7 @@ func (s *testTypeConvertSuite) TestConvertToInt64(c *C) { testToInt64(c, float64(3.1), int64(3)) testToInt64(c, mysql.Hex{Value: 100}, int64(100)) testToInt64(c, mysql.Bit{Value: 100, Width: 8}, int64(100)) + testToInt64(c, mysql.Enum{Name: "a", Value: 1}, int64(1)) t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 0) c.Assert(err, IsNil) @@ -271,6 +286,7 @@ func (s *testTypeConvertSuite) TestConvertToFloat64(c *C) { testToFloat64(c, float64(3.1), float64(3.1)) testToFloat64(c, mysql.Hex{Value: 100}, float64(100)) testToFloat64(c, mysql.Bit{Value: 100, Width: 8}, float64(100)) + testToFloat64(c, mysql.Enum{Name: "a", Value: 1}, float64(1)) t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6) c.Assert(err, IsNil) @@ -308,6 +324,7 @@ func (s *testTypeConvertSuite) TestConvertToString(c *C) { testToString(c, []byte{1}, "\x01") testToString(c, mysql.Hex{Value: 0x4D7953514C}, "MySQL") testToString(c, mysql.Bit{Value: 0x41, Width: 8}, "A") + testToString(c, mysql.Enum{Name: "a", Value: 1}, "a") t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6) c.Assert(err, IsNil) @@ -345,6 +362,7 @@ func (s *testTypeConvertSuite) TestConvertToBool(c *C) { testToBool(c, []byte("0"), 0) testToBool(c, mysql.Hex{Value: 0}, 0) testToBool(c, mysql.Bit{Value: 0, Width: 8}, 0) + testToBool(c, mysql.Enum{Name: "a", Value: 1}, 1) t, err := mysql.ParseTime("2011-11-10 11:11:11.999999", mysql.TypeTimestamp, 6) c.Assert(err, IsNil) diff --git a/util/types/etc.go b/util/types/etc.go index a2f13a78d4..b9fc94ca5d 100644 --- a/util/types/etc.go +++ b/util/types/etc.go @@ -248,7 +248,7 @@ func IsOrderedType(v interface{}) (r bool) { uint, uint8, uint16, uint32, uint64, float32, float64, string, []byte, mysql.Decimal, mysql.Time, mysql.Duration, - mysql.Hex, mysql.Bit: + mysql.Hex, mysql.Bit, mysql.Enum: return true } return false @@ -264,7 +264,7 @@ func Clone(from interface{}) (interface{}, error) { case uint8, uint16, uint32, uint64, float32, float64, int16, int8, bool, string, int, int64, int32, mysql.Time, mysql.Duration, mysql.Decimal, - mysql.Hex, mysql.Bit: + mysql.Hex, mysql.Bit, mysql.Enum: return x, nil case []byte: target := make([]byte, len(from.([]byte))) @@ -356,6 +356,8 @@ func Coerce(a, b interface{}) (x, y interface{}) { x = v.ToNumber() case mysql.Bit: x = v.ToNumber() + case mysql.Enum: + x = v.ToNumber() } switch v := y.(type) { case int64: @@ -366,6 +368,8 @@ func Coerce(a, b interface{}) (x, y interface{}) { y = v.ToNumber() case mysql.Bit: y = v.ToNumber() + case mysql.Enum: + y = v.ToNumber() } } return diff --git a/util/types/etc_test.go b/util/types/etc_test.go index a2101682e3..a9229afde9 100644 --- a/util/types/etc_test.go +++ b/util/types/etc_test.go @@ -151,6 +151,7 @@ func (s *testTypeEtcSuite) TestClone(c *C) { checkClone(c, make(map[int]string), false) checkClone(c, mysql.Hex{Value: 1}, true) checkClone(c, mysql.Bit{Value: 1, Width: 1}, true) + checkClone(c, mysql.Enum{Name: "a", Value: 1}, true) } func checkCoerce(c *C, a, b interface{}) {