diff --git a/evaluator/builtin_time.go b/evaluator/builtin_time.go index b28778191f..baacb77d05 100644 --- a/evaluator/builtin_time.go +++ b/evaluator/builtin_time.go @@ -184,7 +184,7 @@ func builtinNow(args []types.Datum, _ context.Context) (d types.Datum, err error d.SetNull() return d, errors.Trace(err) } - d.SetMysqlTime(tr) + d.SetMysqlTime(&tr) return d, nil } @@ -373,7 +373,7 @@ func builtinCurrentDate(args []types.Datum, _ context.Context) (d types.Datum, e t := mysql.Time{ Time: time.Date(year, month, day, 0, 0, 0, 0, time.Local), Type: mysql.TypeDate, Fsp: 0} - d.SetMysqlTime(t) + d.SetMysqlTime(&t) return d, nil } diff --git a/mysql/time.go b/mysql/time.go index 7f98199415..678a5bc34f 100644 --- a/mysql/time.go +++ b/mysql/time.go @@ -280,7 +280,10 @@ func (t Time) ConvertToDuration() (Duration, error) { // Compare returns an integer comparing the time instant t to o. // If t is after o, return 1, equal o, return 0, before o, return -1. -func (t Time) Compare(o Time) int { +func (t *Time) Compare(o *Time) int { + if o == nil { + return 1 + } if t.Time.After(o.Time) { return 1 } else if t.Time.Equal(o.Time) { @@ -299,7 +302,7 @@ func (t Time) CompareString(str string) (int, error) { return 0, errors.Trace(err) } - return t.Compare(o), nil + return t.Compare(&o), nil } // RoundFrac rounds fractional seconds precision with new fsp and returns a new one. @@ -1027,7 +1030,10 @@ func checkTimestamp(t Time) bool { } // ExtractTimeNum extracts time value number from time unit and format. -func ExtractTimeNum(unit string, t Time) (int64, error) { +func ExtractTimeNum(unit string, t *Time) (int64, error) { + if t == nil { + return 0, errors.Errorf("invalid time t") + } switch strings.ToUpper(unit) { case "MICROSECOND": return int64(t.Nanosecond() / 1000), nil diff --git a/tidb-server/server/util.go b/tidb-server/server/util.go index 4d6c76d3d7..c143f48bf3 100644 --- a/tidb-server/server/util.go +++ b/tidb-server/server/util.go @@ -293,7 +293,7 @@ func dumpRowValuesBinary(alloc arena.Allocator, columns []*ColumnInfo, row []typ case types.KindMysqlDecimal: data = append(data, dumpLengthEncodedString(hack.Slice(val.GetMysqlDecimal().String()), alloc)...) case types.KindMysqlTime: - data = append(data, dumpBinaryDateTime(val.GetMysqlTime(), nil)...) + data = append(data, dumpBinaryDateTime(*val.GetMysqlTime(), nil)...) case types.KindMysqlDuration: data = append(data, dumpBinaryTime(val.GetMysqlDuration().Duration)...) case types.KindMysqlSet: diff --git a/util/types/datum.go b/util/types/datum.go index fc5744e8be..6d038342ed 100644 --- a/util/types/datum.go +++ b/util/types/datum.go @@ -233,12 +233,12 @@ func (d *Datum) SetMysqlSet(b mysql.Set) { } // GetMysqlTime gets mysql.Time value -func (d *Datum) GetMysqlTime() mysql.Time { - return d.x.(mysql.Time) +func (d *Datum) GetMysqlTime() *mysql.Time { + return d.x.(*mysql.Time) } // SetMysqlTime sets mysql.Time value -func (d *Datum) SetMysqlTime(b mysql.Time) { +func (d *Datum) SetMysqlTime(b *mysql.Time) { d.k = KindMysqlTime d.x = b } @@ -258,6 +258,8 @@ func (d *Datum) GetValue() interface{} { return d.GetString() case KindBytes: return d.GetBytes() + case KindMysqlTime: + return *d.GetMysqlTime() default: return d.x } @@ -306,9 +308,10 @@ func (d *Datum) SetValue(val interface{}) { case mysql.Set: d.x = x d.k = KindMysqlSet + case *mysql.Time: + d.SetMysqlTime(x) case mysql.Time: - d.x = x - d.k = KindMysqlTime + d.SetMysqlTime(&x) case []Datum: d.x = x d.k = KindRow @@ -458,7 +461,7 @@ func (d *Datum) compareString(s string) (int, error) { return d.GetMysqlDecimal().Cmp(dec), err case KindMysqlTime: dt, err := mysql.ParseDatetime(s) - return d.GetMysqlTime().Compare(dt), err + return d.GetMysqlTime().Compare(&dt), err case KindMysqlDuration: dur, err := mysql.ParseDuration(s, mysql.MaxFsp) return d.GetMysqlDuration().Compare(dur), err @@ -544,7 +547,10 @@ func (d *Datum) compareMysqlSet(set mysql.Set) (int, error) { } } -func (d *Datum) compareMysqlTime(time mysql.Time) (int, error) { +func (d *Datum) compareMysqlTime(time *mysql.Time) (int, error) { + if time == nil { + return 0, errors.Errorf("invalid time t") + } switch d.k { case KindString, KindBytes: dt, err := mysql.ParseDatetime(d.GetString())