From c50bfbce0bb9422afe9566eef0a76d8d52eeec03 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Wed, 30 Aug 2017 08:30:57 +0800 Subject: [PATCH] *: rewrite builtin function: STR_TO_DATE (#4357) --- expression/builtin_time.go | 130 +++++++++++++++++++++++++++----- expression/builtin_time_test.go | 1 + expression/integration_test.go | 10 +++ plan/typeinfer_test.go | 6 ++ util/types/time.go | 43 +++++++++++ 5 files changed, 172 insertions(+), 18 deletions(-) diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 6fa5941ad9..1c8fde7667 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -131,7 +131,6 @@ var ( _ builtinFunc = &builtinYearWeekWithoutModeSig{} _ builtinFunc = &builtinFromUnixTimeSig{} _ builtinFunc = &builtinGetFormatSig{} - _ builtinFunc = &builtinStrToDateSig{} _ builtinFunc = &builtinSysDateWithFspSig{} _ builtinFunc = &builtinSysDateWithoutFspSig{} _ builtinFunc = &builtinCurrentDateSig{} @@ -161,6 +160,9 @@ var ( _ builtinFunc = &builtinTimestamp1ArgSig{} _ builtinFunc = &builtinTimestamp2ArgsSig{} _ builtinFunc = &builtinLastDaySig{} + _ builtinFunc = &builtinStrToDateDateSig{} + _ builtinFunc = &builtinStrToDateDatetimeSig{} + _ builtinFunc = &builtinStrToDateDurationSig{} ) // handleInvalidTimeError reports error or warning depend on the context. @@ -1257,37 +1259,129 @@ type strToDateFunctionClass struct { baseFunctionClass } -func (c *strToDateFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) { +func (c *strToDateFunctionClass) getRetTp(arg Expression, ctx context.Context) (tp byte, fsp int) { + tp = mysql.TypeDatetime + if _, ok := arg.(*Constant); !ok { + return tp, types.MaxFsp + } + strArg := WrapWithCastAsString(arg, ctx) + format, isNull, err := strArg.EvalString(nil, ctx.GetSessionVars().StmtCtx) + if err != nil || isNull { + return + } + isDuration, isDate := types.GetFormatType(format) + if isDuration && !isDate { + tp = mysql.TypeDuration + } else if !isDuration && isDate { + tp = mysql.TypeDate + } + if strings.Index(format, "%f") >= 0 { + fsp = types.MaxFsp + } + return +} + +// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_str-to-date +func (c *strToDateFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { if err := c.verifyArgs(args); err != nil { return nil, errors.Trace(err) } - sig := &builtinStrToDateSig{newBaseBuiltinFunc(args, ctx)} + retTp, fsp := c.getRetTp(args[1], ctx) + switch retTp { + case mysql.TypeDate: + bf := newBaseBuiltinFuncWithTp(args, ctx, tpDatetime, tpString, tpString) + bf.tp.Tp, bf.tp.Flen, bf.tp.Decimal = mysql.TypeDate, mysql.MaxDateWidth, types.MinFsp + sig = &builtinStrToDateDateSig{baseTimeBuiltinFunc{bf}} + case mysql.TypeDatetime: + bf := newBaseBuiltinFuncWithTp(args, ctx, tpDatetime, tpString, tpString) + if fsp == types.MinFsp { + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeWidthNoFsp, types.MinFsp + } else { + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDatetimeWidthWithFsp, types.MaxFsp + } + sig = &builtinStrToDateDatetimeSig{baseTimeBuiltinFunc{bf}} + case mysql.TypeDuration: + bf := newBaseBuiltinFuncWithTp(args, ctx, tpDuration, tpString, tpString) + if fsp == types.MinFsp { + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDurationWidthNoFsp, types.MinFsp + } else { + bf.tp.Flen, bf.tp.Decimal = mysql.MaxDurationWidthWithFsp, types.MaxFsp + } + sig = &builtinStrToDateDurationSig{baseDurationBuiltinFunc{bf}} + } return sig.setSelf(sig), nil } -type builtinStrToDateSig struct { - baseBuiltinFunc +type builtinStrToDateDateSig struct { + baseTimeBuiltinFunc } -// eval evals a builtinStrToDateSig. -// See https://dev.mysql.com/doc/refman/5.5/en/date-and-time-functions.html#function_str-to-date -func (b *builtinStrToDateSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return d, errors.Trace(err) +func (b *builtinStrToDateDateSig) evalTime(row []types.Datum) (types.Time, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + date, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return types.Time{}, isNull, errors.Trace(err) + } + format, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return types.Time{}, isNull, errors.Trace(err) } - date := args[0].GetString() - format := args[1].GetString() var t types.Time - succ := t.StrToDate(date, format) if !succ { - d.SetNull() - return + return types.Time{}, true, handleInvalidTimeError(b.ctx, types.ErrInvalidTimeFormat) } + t.Type, t.Fsp = mysql.TypeDate, types.MinFsp + return t, false, nil +} - d.SetMysqlTime(t) - return +type builtinStrToDateDatetimeSig struct { + baseTimeBuiltinFunc +} + +func (b *builtinStrToDateDatetimeSig) evalTime(row []types.Datum) (types.Time, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + date, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return types.Time{}, isNull, errors.Trace(err) + } + format, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return types.Time{}, isNull, errors.Trace(err) + } + var t types.Time + succ := t.StrToDate(date, format) + if !succ { + return types.Time{}, true, handleInvalidTimeError(b.ctx, types.ErrInvalidTimeFormat) + } + t.Type, t.Fsp = mysql.TypeDatetime, b.tp.Decimal + return t, false, nil +} + +type builtinStrToDateDurationSig struct { + baseDurationBuiltinFunc +} + +// TODO: If the NO_ZERO_DATE or NO_ZERO_IN_DATE SQL mode is enabled, zero dates or part of dates are disallowed. +// In that case, STR_TO_DATE() returns NULL and generates a warning. +func (b *builtinStrToDateDurationSig) evalDuration(row []types.Datum) (types.Duration, bool, error) { + sc := b.ctx.GetSessionVars().StmtCtx + date, isNull, err := b.args[0].EvalString(row, sc) + if isNull || err != nil { + return types.Duration{}, isNull, errors.Trace(err) + } + format, isNull, err := b.args[1].EvalString(row, sc) + if isNull || err != nil { + return types.Duration{}, isNull, errors.Trace(err) + } + var t types.Time + succ := t.StrToDate(date, format) + if !succ { + return types.Duration{}, true, handleInvalidTimeError(b.ctx, types.ErrInvalidTimeFormat) + } + t.Fsp = b.tp.Decimal + dur, err := t.ConvertToDuration() + return dur, false, errors.Trace(err) } type sysDateFunctionClass struct { diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 8038e54e7e..3f8d7c2d5b 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -1109,6 +1109,7 @@ func (s *testEvaluatorSuite) TestStrToDate(c *C) { format := types.NewStringDatum(test.Format) f, err := fc.getFunction(datumsToConstants([]types.Datum{date, format}), s.ctx) c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) result, err := f.eval(nil) c.Assert(err, IsNil) if !test.Success { diff --git a/expression/integration_test.go b/expression/integration_test.go index e5510f8ea8..690709a932 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -1296,6 +1296,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) { result = tk.MustQuery(`select dayname("2017-12-01"), dayname("0000-00-00"), dayname("0000-01-00"), dayname("0000-01-00 00:00:00")`) result.Check(testkit.Rows("Friday ")) tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|invalid time format", "Warning|1105|invalid time format", "Warning|1105|invalid time format")) + + // for str_to_date + result = tk.MustQuery("select str_to_date('01-01-2017', '%d-%m-%Y'), str_to_date('59:20:12 01-01-2017', '%s:%i:%H %d-%m-%Y'), str_to_date('59:20:12', '%s:%i:%H')") + result.Check(testkit.Rows("2017-01-01 2017-01-01 12:20:59 12:20:59")) + result = tk.MustQuery("select str_to_date('aaa01-01-2017', 'aaa%d-%m-%Y'), str_to_date('59:20:12 aaa01-01-2017', '%s:%i:%H aaa%d-%m-%Y'), str_to_date('59:20:12aaa', '%s:%i:%Haaa')") + result.Check(testkit.Rows("2017-01-01 2017-01-01 12:20:59 12:20:59")) + result = tk.MustQuery("select str_to_date('01-01-2017', '%d'), str_to_date('59', '%d-%Y')") + // TODO: MySQL returns " ". + result.Check(testkit.Rows("0000-00-01 ")) + tk.MustQuery("show warnings").Check(testutil.RowsWithSep("|", "Warning|1105|invalid time format")) } func (s *testIntegrationSuite) TestOpBuiltin(c *C) { diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index cf5cc1022b..a44a9e4510 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -1564,6 +1564,12 @@ func (s *testPlanSuite) createTestCase4TimeFuncs() []typeInferTestCase { {"quarter(c_set )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, {"quarter(c_enum )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"str_to_date(c_varchar, '%Y:%m:%d')", mysql.TypeDate, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDateWidth, types.MinFsp}, + {"str_to_date(c_varchar, '%Y:%m:%d %H:%i:%s')", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthNoFsp, types.MinFsp}, + {"str_to_date(c_varchar, '%Y:%m:%d %H:%i:%s.%f')", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthWithFsp, types.MaxFsp}, + {"str_to_date(c_varchar, '%H:%i:%s')", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDurationWidthNoFsp, types.MinFsp}, + {"str_to_date(c_varchar, '%H:%i:%s.%f')", mysql.TypeDuration, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDurationWidthWithFsp, types.MaxFsp}, + {"period_add(c_int_d , c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, {"period_add(c_bigint_d , c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, {"period_add(c_float_d , c_int_d)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 6, 0}, diff --git a/util/types/time.go b/util/types/time.go index 3c78d33602..25f7da2d5b 100644 --- a/util/types/time.go +++ b/util/types/time.go @@ -2043,6 +2043,49 @@ var dateFormatParserTable = map[string]dateFormatParser{ // "%y": yearTwoDigits, // Year, numeric (two digits) } +// GetFormatType checks the type(Duration, Date or Datetime) of a format string. +func GetFormatType(format string) (isDuration, isDate bool) { + durationTokens := map[string]struct{}{ + "%h": {}, + "%H": {}, + "%i": {}, + "%I": {}, + "%s": {}, + "%S": {}, + "%k": {}, + "%l": {}, + } + dateTokens := map[string]struct{}{ + "%y": {}, + "%Y": {}, + "%m": {}, + "%M": {}, + "%c": {}, + "%b": {}, + "%D": {}, + "%d": {}, + "%e": {}, + } + + format = skipWhiteSpace(format) + for token, formatRemain, succ := getFormatToken(format); len(token) != 0; format = formatRemain { + if !succ { + isDuration, isDate = false, false + break + } + if _, ok := durationTokens[token]; ok { + isDuration = true + } else if _, ok := dateTokens[token]; ok { + isDate = true + } + if isDuration && isDate { + break + } + token, formatRemain, succ = getFormatToken(format) + } + return +} + func matchDateWithToken(t *mysqlTime, date string, token string, ctx map[string]int) (remain string, succ bool) { if parse, ok := dateFormatParserTable[token]; ok { return parse(t, date, ctx)