From d4bfcf02725bd2b59d4306cee602f5d491a7a5bb Mon Sep 17 00:00:00 2001 From: Jian Zhang Date: Mon, 11 Sep 2017 15:48:50 +0800 Subject: [PATCH] expression: rewrite builtin function: CONVERT_TZ (#4463) --- expression/builtin_time.go | 121 ++++++++++++++++++-------------- expression/builtin_time_test.go | 2 +- expression/integration_test.go | 7 ++ plan/typeinfer_test.go | 4 ++ 4 files changed, 82 insertions(+), 52 deletions(-) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index d95d332476..ea9e1be74b 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -2503,94 +2503,113 @@ type convertTzFunctionClass struct { baseFunctionClass } +func (c *convertTzFunctionClass) getDecimal(ctx context.Context, arg Expression) int { + decimal := types.MaxFsp + if dt, isConstant := arg.(*Constant); isConstant { + switch fieldTp2EvalTp(arg.GetType()) { + case tpInt: + decimal = 0 + case tpReal, tpDecimal: + decimal = arg.GetType().Decimal + case tpString: + str, isNull, err := dt.EvalString(nil, ctx.GetSessionVars().StmtCtx) + if err == nil && !isNull { + decimal = types.DateFSP(str) + } + } + } + if decimal > types.MaxFsp { + return types.MaxFsp + } + if decimal < types.MinFsp { + return types.MinFsp + } + return decimal +} + func (c *convertTzFunctionClass) getFunction(ctx context.Context, args []Expression) (builtinFunc, error) { if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - sig := &builtinConvertTzSig{newBaseBuiltinFunc(args, ctx)} + // tzRegex holds the regex to check whether a string is a time zone. + tzRegex, err := regexp.Compile(`(^(\+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^\+13:00$)`) + if err != nil { + return nil, errors.Trace(err) + } + + decimal := c.getDecimal(ctx, args[0]) + bf := newBaseBuiltinFuncWithTp(args, ctx, tpDatetime, tpDatetime, tpString, tpString) + bf.tp.Decimal = decimal + sig := &builtinConvertTzSig{ + baseTimeBuiltinFunc: baseTimeBuiltinFunc{bf}, + timezoneRegex: tzRegex, + } return sig.setSelf(sig), nil } type builtinConvertTzSig struct { - baseBuiltinFunc + baseTimeBuiltinFunc + timezoneRegex *regexp.Regexp } -// eval evals a builtinConvertTzSig. +// evalTime evals CONVERT_TZ(dt,from_tz,to_tz). // See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_convert-tz -func (b *builtinConvertTzSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return d, errors.Trace(err) - } - - if args[0].IsNull() || args[1].IsNull() || args[2].IsNull() { - return - } - +func (b *builtinConvertTzSig) evalTime(row []types.Datum) (types.Time, bool, error) { sc := b.ctx.GetSessionVars().StmtCtx - fsp := 0 - if args[0].Kind() == types.KindString { - fsp = types.DateFSP(args[0].GetString()) + dt, isNull, err := b.args[0].EvalTime(row, sc) + if isNull || err != nil { + return types.Time{}, true, errors.Trace(err) } - arg0, err := convertToTimeWithFsp(sc, args[0], mysql.TypeDatetime, fsp) - if err != nil { - return d, errors.Trace(err) + fromTzStr, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return types.Time{}, true, errors.Trace(err) } - if arg0.IsNull() { - return + toTzStr, isNull, err := b.args[2].EvalString(row, sc) + if isNull || err != nil { + return types.Time{}, true, errors.Trace(err) } - dt := arg0.GetMysqlTime() + fromTzMatched := b.timezoneRegex.MatchString(fromTzStr) + toTzMatched := b.timezoneRegex.MatchString(toTzStr) - fromTZ := args[1].GetString() - toTZ := args[2].GetString() - - const tzArgReg = `(^(\+|-)(0?[0-9]|1[0-2]):[0-5]?\d$)|(^\+13:00$)` - r, _ := regexp.Compile(tzArgReg) - fmatch := r.MatchString(fromTZ) - tmatch := r.MatchString(toTZ) - - if !fmatch && !tmatch { - ftz, err := time.LoadLocation(fromTZ) + if !fromTzMatched && !toTzMatched { + fromTz, err := time.LoadLocation(fromTzStr) if err != nil { - return d, errors.Trace(err) + return types.Time{}, true, errors.Trace(err) } - ttz, err := time.LoadLocation(toTZ) + toTz, err := time.LoadLocation(toTzStr) if err != nil { - return d, errors.Trace(err) + return types.Time{}, true, errors.Trace(err) } - t, err := dt.Time.GoTime(ftz) + t, err := dt.Time.GoTime(fromTz) if err != nil { - return d, errors.Trace(err) + return types.Time{}, true, errors.Trace(err) } - d.SetMysqlTime(types.Time{ - Time: types.FromGoTime(t.In(ttz)), + return types.Time{ + Time: types.FromGoTime(t.In(toTz)), Type: mysql.TypeDatetime, - Fsp: dt.Fsp, - }) - return d, nil + Fsp: b.tp.Decimal, + }, false, nil } - - if fmatch && tmatch { + if fromTzMatched && toTzMatched { t, err := dt.Time.GoTime(time.Local) if err != nil { - return d, errors.Trace(err) + return types.Time{}, true, errors.Trace(err) } - d.SetMysqlTime(types.Time{ - Time: types.FromGoTime(t.Add(timeZone2Duration(toTZ) - timeZone2Duration(fromTZ))), + return types.Time{ + Time: types.FromGoTime(t.Add(timeZone2Duration(toTzStr) - timeZone2Duration(fromTzStr))), Type: mysql.TypeDatetime, - Fsp: dt.Fsp, - }) + Fsp: b.tp.Decimal, + }, false, nil } - - return + return types.Time{}, true, nil } type makeDateFunctionClass struct { diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 2ff2a4ddc0..ba3cf8fb43 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -2068,7 +2068,7 @@ func (s *testEvaluatorSuite) TestConvertTz(c *C) { c.Assert(err, NotNil) } result, _ := d.ToString() - c.Assert(result, Equals, test.expect) + c.Assert(result, Equals, test.expect, Commentf("convert_tz(\"%v\", \"%s\", \"%s\")", test.t, test.fromTz, test.toTz)) } } diff --git a/expression/integration_test.go b/expression/integration_test.go index a6b6c8fb51..2cd2655cb8 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1485,6 +1485,13 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { GET_FORMAT(TIME,'ISO'), GET_FORMAT(TIME,'EUR'), GET_FORMAT(TIME,'INTERNAL')`) result.Check(testkit.Rows("%m.%d.%Y %Y-%m-%d %Y-%m-%d %d.%m.%Y %Y%m%d %Y-%m-%d %H.%i.%s %Y-%m-%d %H:%i:%s %Y-%m-%d %H:%i:%s %Y-%m-%d %H.%i.%s %Y%m%d%H%i%s %h:%i:%s %p %H:%i:%s %H:%i:%s %H.%i.%s %H%i%s")) + // for convert_tz + result = tk.MustQuery(`select convert_tz("2004-01-01 12:00:00", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00.01", "+00:00", "+10:32"), convert_tz("2004-01-01 12:00:00.01234567", "+00:00", "+10:32");`) + result.Check(testkit.Rows("2004-01-01 22:32:00 2004-01-01 22:32:00.01 2004-01-01 22:32:00.012346")) + // TODO: release the following test after fix #4462 + //result = tk.MustQuery(`select convert_tz(20040101, "+00:00", "+10:32"), convert_tz(20040101.01, "+00:00", "+10:32"), convert_tz(20040101.01234567, "+00:00", "+10:32");`) + //result.Check(testkit.Rows("2004-01-01 10:32:00 2004-01-01 10:32:00.00 2004-01-01 10:32:00.000000")) + // for from_unixtime tk.MustExec(`set @@session.time_zone = "+08:00"`) result = tk.MustQuery(`select from_unixtime(20170101), from_unixtime(20170101.9999999), from_unixtime(20170101.999), from_unixtime(20170101.999, "%Y %D %M %h:%i:%s %x"), from_unixtime(20170101.999, "%Y %D %M %h:%i:%s %x")`) diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index 15bb51ea67..32bfb3b547 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -1729,10 +1729,14 @@ func (s *testPlanSuite) createTestCase4TimeFuncs() []typeInferTestCase { {"period_diff(c_enum , c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, {"get_format(DATE, 'USA')", mysql.TypeVarString, charset.CharsetUTF8, 0, 17, types.UnspecifiedLength}, + + {"convert_tz(c_time_d, c_text_d, c_text_d)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, types.MaxFsp}, + {"from_unixtime(20170101.999)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, 3}, {"from_unixtime(20170101.1234567)", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, types.MaxFsp}, {"from_unixtime('20170101.999')", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, types.MaxFsp}, {"from_unixtime(20170101.123, '%H')", mysql.TypeVarString, charset.CharsetUTF8, 0, 2, types.UnspecifiedLength}, + {"extract(day from c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, {"extract(hour from c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxIntWidth, 0}, }