diff --git a/server/server_test.go b/server/server_test.go index 4af0a2704a..70b236f01f 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -237,17 +237,19 @@ func runTestPrepareResultFieldType(t *C) { func runTestSpecialType(t *C) { runTestsOnNewDB(t, nil, "SpecialType", func(dbt *DBTest) { - dbt.mustExec("create table test (a decimal(10, 5), b datetime, c time)") - dbt.mustExec("insert test values (1.4, '2012-12-21 12:12:12', '4:23:34')") + dbt.mustExec("create table test (a decimal(10, 5), b datetime, c time, d bit(8))") + dbt.mustExec("insert test values (1.4, '2012-12-21 12:12:12', '4:23:34', b'1000')") rows := dbt.mustQuery("select * from test where a > ?", 0) t.Assert(rows.Next(), IsTrue) var outA float64 var outB, outC string - err := rows.Scan(&outA, &outB, &outC) + var outD []byte + err := rows.Scan(&outA, &outB, &outC, &outD) t.Assert(err, IsNil) t.Assert(outA, Equals, 1.4) t.Assert(outB, Equals, "2012-12-21 12:12:12") t.Assert(outC, Equals, "04:23:34") + t.Assert(outD, BytesEquals, []byte{8}) }) } diff --git a/server/util.go b/server/util.go index be47d20716..32e0d3d15b 100644 --- a/server/util.go +++ b/server/util.go @@ -253,7 +253,7 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row types.Row) ([]byte, case mysql.TypeNewDecimal: v, _ := row.GetMyDecimal(i) buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String())) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBit, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: v, _ := row.GetBytes(i) buffer = dumpLengthEncodedString(buffer, v) @@ -323,7 +323,7 @@ func dumpTextRow(buffer []byte, columns []*ColumnInfo, row types.Row) ([]byte, e case mysql.TypeNewDecimal: v, _ := row.GetMyDecimal(i) buffer = dumpLengthEncodedString(buffer, hack.Slice(v.String())) - case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBit, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: v, _ := row.GetBytes(i) buffer = dumpLengthEncodedString(buffer, v) diff --git a/util/chunk/chunk.go b/util/chunk/chunk.go index b8de292135..653a42bcfd 100644 --- a/util/chunk/chunk.go +++ b/util/chunk/chunk.go @@ -55,6 +55,25 @@ func (c *Chunk) AddInterfaceColumn() { }) } +// AddColumnByFieldType adds a column by field type. +func (c *Chunk) AddColumnByFieldType(fieldTp byte, initCap int) { + switch fieldTp { + case mysql.TypeFloat: + c.AddFixedLenColumn(4, initCap) + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, + mysql.TypeDouble: + c.AddFixedLenColumn(8, initCap) + case mysql.TypeDuration: + c.AddFixedLenColumn(16, initCap) + case mysql.TypeNewDecimal: + c.AddFixedLenColumn(types.MyDecimalStructSize, initCap) + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp, mysql.TypeJSON: + c.AddInterfaceColumn() + default: + c.AddVarLenColumn(initCap) + } +} + // Reset resets the chunk, so the memory it allocated can be reused. // Make sure all the data in the chunk is not used anymore before you reuse this chunk. func (c *Chunk) Reset() { @@ -193,7 +212,10 @@ func (c *column) reset() { c.length = 0 c.nullCount = 0 c.nullBitmap = c.nullBitmap[:0] - c.offsets = c.offsets[:0] + if len(c.offsets) > 0 { + // The first offset is always 0, it makes slicing the data easier, we need to keep it. + c.offsets = c.offsets[:1] + } c.data = c.data[:0] c.ifaces = c.ifaces[:0] } @@ -320,7 +342,7 @@ func (r Row) GetUint64(colIdx int) (uint64, bool) { // GetFloat32 returns the float64 value and isNull with the colIdx. func (r Row) GetFloat32(colIdx int) (float32, bool) { col := r.c.columns[colIdx] - return *(*float32)(unsafe.Pointer(&col.data[r.idx*8])), col.isNull(r.idx) + return *(*float32)(unsafe.Pointer(&col.data[r.idx*4])), col.isNull(r.idx) } // GetFloat64 returns the float64 value and isNull with the colIdx. @@ -449,6 +471,11 @@ func (r Row) GetDatum(colIdx int, tp *types.FieldType) types.Datum { if !isNull { d.SetMysqlSet(val) } + case mysql.TypeBit: + val, isNull := r.GetBytes(colIdx) + if !isNull { + d.SetMysqlBit(val) + } case mysql.TypeJSON: val, isNull := r.GetJSON(colIdx) if !isNull { diff --git a/util/chunk/chunk_test.go b/util/chunk/chunk_test.go index 4e537f9847..d3ecd1decf 100644 --- a/util/chunk/chunk_test.go +++ b/util/chunk/chunk_test.go @@ -121,6 +121,22 @@ func (s *testChunkSuite) TestChunk(c *C) { iVal, _ = chk.GetRow(0).GetInt64(1) c.Assert(iVal, Equals, int64(1)) c.Assert(chk.NumRows(), Equals, 1) + + // Test Reset. + chk = newChunk(0) + chk.AppendString(0, "abcd") + chk.Reset() + chk.AppendString(0, "def") + strVal, _ := chk.GetRow(0).GetString(0) + c.Assert(strVal, Equals, "def") + + // Test float32 + chk = newChunk(4) + chk.AppendFloat32(0, 1) + chk.AppendFloat32(0, 1) + chk.AppendFloat32(0, 1) + f32, _ = chk.GetRow(2).GetFloat32(0) + c.Assert(f32, Equals, float32(1)) } // newChunk creates a new chunk and initialize columns with element length. diff --git a/util/codec/codec.go b/util/codec/codec.go index 5a102060f5..f6e674880b 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" ) // First byte in the encoded value which specifies the encoding type. @@ -383,3 +384,157 @@ func peekUvarint(b []byte) (int, error) { } return n, nil } + +// DecodeOneToChunk decodes one value to chunk and returns the remained bytes. +func DecodeOneToChunk(b []byte, chk *chunk.Chunk, colIdx int, ft *types.FieldType, loc *time.Location) (remain []byte, err error) { + if len(b) < 1 { + return nil, errors.New("invalid encoded key") + } + flag := b[0] + b = b[1:] + switch flag { + case intFlag: + var v int64 + b, v, err = DecodeInt(b) + if err != nil { + return nil, errors.Trace(err) + } + appendIntToChunk(v, chk, colIdx, ft) + case uintFlag: + var v uint64 + b, v, err = DecodeUint(b) + if err != nil { + return nil, errors.Trace(err) + } + err = appendUintToChunk(v, chk, colIdx, ft, loc) + case varintFlag: + var v int64 + b, v, err = DecodeVarint(b) + if err != nil { + return nil, errors.Trace(err) + } + appendIntToChunk(v, chk, colIdx, ft) + case uvarintFlag: + var v uint64 + b, v, err = DecodeUvarint(b) + if err != nil { + return nil, errors.Trace(err) + } + err = appendUintToChunk(v, chk, colIdx, ft, loc) + case floatFlag: + var v float64 + b, v, err = DecodeFloat(b) + if err != nil { + return nil, errors.Trace(err) + } + appendFloatToChunk(v, chk, colIdx, ft) + case bytesFlag: + var v []byte + b, v, err = DecodeBytes(b) + if err != nil { + return nil, errors.Trace(err) + } + chk.AppendBytes(colIdx, v) + case compactBytesFlag: + var v []byte + b, v, err = DecodeCompactBytes(b) + if err != nil { + return nil, errors.Trace(err) + } + chk.AppendBytes(colIdx, v) + case decimalFlag: + var d types.Datum + b, d, err = DecodeDecimal(b) + if err != nil { + return nil, errors.Trace(err) + } + chk.AppendMyDecimal(colIdx, d.GetMysqlDecimal()) + case durationFlag: + var r int64 + b, r, err = DecodeInt(b) + if err != nil { + return nil, errors.Trace(err) + } + v := types.Duration{Duration: time.Duration(r), Fsp: ft.Decimal} + chk.AppendDuration(colIdx, v) + case jsonFlag: + var size int + size, err = json.PeekBytesAsJSON(b) + if err != nil { + return nil, errors.Trace(err) + } + var j json.JSON + j, err = json.Deserialize(b) + if err != nil { + return nil, errors.Trace(err) + } + b = b[size:] + chk.AppendJSON(colIdx, j) + case NilFlag: + chk.AppendNull(colIdx) + default: + return nil, errors.Errorf("invalid encoded key flag %v", flag) + } + if err != nil { + return nil, errors.Trace(err) + } + return b, nil +} + +func appendIntToChunk(val int64, chk *chunk.Chunk, colIdx int, ft *types.FieldType) { + switch ft.Tp { + case mysql.TypeDuration: + v := types.Duration{Duration: time.Duration(val), Fsp: ft.Decimal} + chk.AppendDuration(colIdx, v) + default: + chk.AppendInt64(colIdx, val) + } +} + +func appendUintToChunk(val uint64, chk *chunk.Chunk, colIdx int, ft *types.FieldType, loc *time.Location) error { + switch ft.Tp { + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: + var t types.Time + t.Type = ft.Tp + t.Fsp = ft.Decimal + var err error + err = t.FromPackedUint(val) + if err != nil { + return errors.Trace(err) + } + if ft.Tp == mysql.TypeTimestamp && !t.IsZero() { + err = t.ConvertTimeZone(time.UTC, loc) + if err != nil { + return errors.Trace(err) + } + } + chk.AppendTime(colIdx, t) + case mysql.TypeEnum: + // ignore error deliberately, to read empty enum value. + enum, err := types.ParseEnumValue(ft.Elems, val) + if err != nil { + enum = types.Enum{} + } + chk.AppendEnum(colIdx, enum) + case mysql.TypeSet: + set, err := types.ParseSetValue(ft.Elems, val) + if err != nil { + return errors.Trace(err) + } + chk.AppendSet(colIdx, set) + case mysql.TypeBit: + byteSize := (ft.Flen + 7) >> 3 + chk.AppendBytes(colIdx, types.NewBinaryLiteralFromUint(val, byteSize)) + default: + chk.AppendUint64(colIdx, val) + } + return nil +} + +func appendFloatToChunk(val float64, chk *chunk.Chunk, colIdx int, ft *types.FieldType) { + if ft.Tp == mysql.TypeFloat { + chk.AppendFloat32(colIdx, float32(val)) + } else { + chk.AppendFloat64(colIdx, val) + } +} diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index 110d78cf3c..0e2c3aac87 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -17,12 +17,14 @@ import ( "bytes" "math" "testing" + "time" . "github.com/pingcap/check" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/testleak" ) @@ -866,3 +868,71 @@ func (s *testCodecSuite) TestSetRawValues(c *C) { c.Assert(encoded, BytesEquals, rawVal.GetBytes()) } } + +func (s *testCodecSuite) TestDecodeOneToChunk(c *C) { + defer testleak.AfterTest(c)() + table := []struct { + value interface{} + tp *types.FieldType + }{ + {nil, types.NewFieldType(mysql.TypeLonglong)}, + {int64(1), types.NewFieldType(mysql.TypeTiny)}, + {int64(1), types.NewFieldType(mysql.TypeShort)}, + {int64(1), types.NewFieldType(mysql.TypeInt24)}, + {int64(1), types.NewFieldType(mysql.TypeLong)}, + {int64(1), types.NewFieldType(mysql.TypeLonglong)}, + {float32(1), types.NewFieldType(mysql.TypeFloat)}, + {float64(1), types.NewFieldType(mysql.TypeDouble)}, + {types.NewDecFromInt(1), types.NewFieldType(mysql.TypeNewDecimal)}, + {"abc", types.NewFieldType(mysql.TypeString)}, + {"def", types.NewFieldType(mysql.TypeVarchar)}, + {"ghi", types.NewFieldType(mysql.TypeVarString)}, + {[]byte("abc"), types.NewFieldType(mysql.TypeBlob)}, + {[]byte("abc"), types.NewFieldType(mysql.TypeTinyBlob)}, + {[]byte("abc"), types.NewFieldType(mysql.TypeMediumBlob)}, + {[]byte("abc"), types.NewFieldType(mysql.TypeLongBlob)}, + {types.CurrentTime(mysql.TypeDatetime), types.NewFieldType(mysql.TypeDatetime)}, + {types.CurrentTime(mysql.TypeDate), types.NewFieldType(mysql.TypeDate)}, + {types.Time{ + Time: types.FromGoTime(time.Now()), + Type: mysql.TypeTimestamp, + TimeZone: time.Local, + }, types.NewFieldType(mysql.TypeTimestamp)}, + {types.Duration{Duration: time.Second, Fsp: 1}, types.NewFieldType(mysql.TypeDuration)}, + {types.Enum{"a", 0}, &types.FieldType{Tp: mysql.TypeEnum, Elems: []string{"a"}}}, + {types.Set{"a", 0}, &types.FieldType{Tp: mysql.TypeSet, Elems: []string{"a"}}}, + {types.BinaryLiteral{100}, &types.FieldType{Tp: mysql.TypeBit, Flen: 8}}, + {json.CreateJSON("abc"), types.NewFieldType(mysql.TypeJSON)}, + } + chk := new(chunk.Chunk) + var datums []types.Datum + for _, t := range table { + chk.AddColumnByFieldType(t.tp.Tp, 0) + datums = append(datums, types.NewDatum(t.value)) + } + + rowCount := 3 + for rowIdx := 0; rowIdx < rowCount; rowIdx++ { + encoded, err := EncodeValue(nil, datums...) + c.Assert(err, IsNil) + for colIdx, t := range table { + encoded, err = DecodeOneToChunk(encoded, chk, colIdx, t.tp, time.Local) + c.Assert(err, IsNil) + } + } + + sc := new(variable.StatementContext) + for colIdx, t := range table { + for rowIdx := 0; rowIdx < rowCount; rowIdx++ { + got := chk.GetRow(rowIdx).GetDatum(colIdx, t.tp) + expect := datums[colIdx] + if got.IsNull() { + c.Assert(expect.IsNull(), IsTrue) + } else { + cmp, err := got.CompareDatum(sc, &expect) + c.Assert(err, IsNil) + c.Assert(cmp, Equals, 0) + } + } + } +}