diff --git a/ast/functions.go b/ast/functions.go index ce33003f91..2e0a461d78 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -265,12 +265,6 @@ const ( DateSub ) -// DateArithInterval is the struct of DateArith interval part. -type DateArithInterval struct { - Unit string - Interval ExprNode -} - const ( // AggFuncCount is the name of Count function. AggFuncCount = "count" diff --git a/expression/builtin.go b/expression/builtin.go index 193ad24f63..a3ea43cf83 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -162,7 +162,7 @@ var Funcs = map[string]Func{ ast.CurrentDate: {builtinCurrentDate, 0, 0}, ast.CurrentTime: {builtinCurrentTime, 0, 1}, ast.Date: {builtinDate, 1, 1}, - ast.DateArith: {builtinDateArith, 3, 3}, + ast.DateArith: {builtinDateArith, 4, 4}, ast.DateFormat: {builtinDateFormat, 2, 2}, ast.CurrentTimestamp: {builtinNow, 0, 1}, ast.Curtime: {builtinCurrentTime, 0, 1}, diff --git a/expression/builtin_time.go b/expression/builtin_time.go index 270246734c..5a6b5c43bf 100644 --- a/expression/builtin_time.go +++ b/expression/builtin_time.go @@ -638,15 +638,16 @@ func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, e // Op is used for distinguishing date_add and date_sub. // args[0] -> Op // args[1] -> Date - // args[2] -> DateArithInterval + // args[2] -> Interval Value + // args[3] -> Interval Unit // health check for date and interval - if args[1].IsNull() { + if args[1].IsNull() || args[2].IsNull() { return d, nil } nodeDate := args[1] - nodeInterval := args[2].GetInterface().(ast.DateArithInterval) - nodeIntervalIntervalDatum := nodeInterval.Interval.GetDatum() - if nodeIntervalIntervalDatum.IsNull() { + nodeIntervalValue := args[2] + nodeIntervalUnit := args[3].GetString() + if nodeIntervalValue.IsNull() { return d, nil } // parse date @@ -672,7 +673,7 @@ func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, e } } sc := ctx.GetSessionVars().StmtCtx - if types.IsClockUnit(nodeInterval.Unit) { + if types.IsClockUnit(nodeIntervalUnit) { fieldType = mysql.TypeDatetime } resultField = types.NewFieldType(fieldType) @@ -690,24 +691,24 @@ func builtinDateArith(args []types.Datum, ctx context.Context) (d types.Datum, e result := value.GetMysqlTime() // parse interval var interval string - if strings.ToLower(nodeInterval.Unit) == "day" { - day, err1 := parseDayInterval(sc, *nodeIntervalIntervalDatum) + if strings.ToLower(nodeIntervalUnit) == "day" { + day, err1 := parseDayInterval(sc, nodeIntervalValue) if err1 != nil { - return d, errInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalIntervalDatum.GetString()) + return d, errInvalidOperation.Gen("DateArith invalid day interval, need int but got %T", nodeIntervalValue.GetString()) } interval = fmt.Sprintf("%d", day) } else { - if nodeIntervalIntervalDatum.Kind() == types.KindString { - interval = fmt.Sprintf("%v", nodeIntervalIntervalDatum.GetString()) + if nodeIntervalValue.Kind() == types.KindString { + interval = fmt.Sprintf("%v", nodeIntervalValue.GetString()) } else { - ii, err1 := nodeIntervalIntervalDatum.ToInt64(sc) + ii, err1 := nodeIntervalValue.ToInt64(sc) if err1 != nil { return d, errors.Trace(err1) } interval = fmt.Sprintf("%v", ii) } } - year, month, day, duration, err := types.ExtractTimeValue(nodeInterval.Unit, interval) + year, month, day, duration, err := types.ExtractTimeValue(nodeIntervalUnit, interval) if err != nil { return d, errors.Trace(err) } diff --git a/expression/builtin_time_test.go b/expression/builtin_time_test.go index 4e8fcba955..b2f150bd38 100644 --- a/expression/builtin_time_test.go +++ b/expression/builtin_time_test.go @@ -19,6 +19,7 @@ import ( "time" . "github.com/pingcap/check" + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/testleak" "github.com/pingcap/tidb/util/testutil" @@ -533,3 +534,29 @@ func (s *testEvaluatorSuite) TestYearWeek(c *C) { c.Assert(err, IsNil) c.Assert(result.IsNull(), IsTrue) } + +func (s *testEvaluatorSuite) TestDateArith(c *C) { + defer testleak.AfterTest(c)() + + date := []string{"2016-12-31", "2017-01-01"} + + args := types.MakeDatums(ast.DateAdd, date[0], 1, "DAY") + v, err := builtinDateArith(args, s.ctx) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlTime().String(), Equals, date[1]) + + args = types.MakeDatums(ast.DateSub, date[1], 1, "DAY") + v, err = builtinDateArith(args, s.ctx) + c.Assert(err, IsNil) + c.Assert(v.GetMysqlTime().String(), Equals, date[0]) + + args = types.MakeDatums(ast.DateAdd, date[0], nil, "DAY") + v, err = builtinDateArith(args, s.ctx) + c.Assert(err, IsNil) + c.Assert(v.IsNull(), IsTrue) + + args = types.MakeDatums(ast.DateSub, date[1], nil, "DAY") + v, err = builtinDateArith(args, s.ctx) + c.Assert(err, IsNil) + c.Assert(v.IsNull(), IsTrue) +} diff --git a/parser/lexer_test.go b/parser/lexer_test.go index 26684a8260..c5f2751615 100644 --- a/parser/lexer_test.go +++ b/parser/lexer_test.go @@ -213,13 +213,11 @@ func (s *testLexerSuite) TestIdentifier(c *C) { func (s *testLexerSuite) TestSpecialComment(c *C) { l := NewScanner("/*!40101 select\n5*/") tok, pos, lit := l.scan() - fmt.Println(tok, pos, lit) c.Assert(tok, Equals, identifier) c.Assert(lit, Equals, "select") c.Assert(pos, Equals, Pos{0, 0, 9}) tok, pos, lit = l.scan() - fmt.Println(tok, pos, lit) c.Assert(tok, Equals, intLit) c.Assert(lit, Equals, "5") c.Assert(pos, Equals, Pos{1, 1, 16}) diff --git a/parser/parser.y b/parser/parser.y index 69f991df45..2e4e4b085b 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -474,7 +474,6 @@ import ( CreateUserStmt "CREATE User statement" DateArithOpt "Date arith dateadd or datesub option" DateArithMultiFormsOpt "Date arith adddate or subdate option" - DateArithInterval "Date arith interval part" DBName "Database Name" DeallocateStmt "Deallocate prepared statement" Default "DEFAULT clause" @@ -2648,35 +2647,39 @@ FunctionCallNonKeyword: { $$ = &ast.FuncCallExpr{FnName: model.NewCIStr($1), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } -| DateArithOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')' +| DateArithMultiFormsOpt '(' Expression ',' Expression ')' { - op := ast.NewValueExpr($1) - dateArithInterval := ast.NewValueExpr( - ast.DateArithInterval{ - Unit: $7, - Interval: $6.(ast.ExprNode), - }, - ) - $$ = &ast.FuncCallExpr{ FnName: model.NewCIStr("DATE_ARITH"), Args: []ast.ExprNode{ - op, + ast.NewValueExpr($1), $3.(ast.ExprNode), - dateArithInterval, + $5.(ast.ExprNode), + ast.NewValueExpr("DAY"), }, } } -| DateArithMultiFormsOpt '(' Expression ',' DateArithInterval')' +| DateArithMultiFormsOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { - op := ast.NewValueExpr($1) - dateArithInterval := ast.NewValueExpr($5) $$ = &ast.FuncCallExpr{ FnName: model.NewCIStr("DATE_ARITH"), Args: []ast.ExprNode{ - op, + ast.NewValueExpr($1), $3.(ast.ExprNode), - dateArithInterval, + $6.(ast.ExprNode), + ast.NewValueExpr($7), + }, + } + } +| DateArithOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncCallExpr{ + FnName: model.NewCIStr("DATE_ARITH"), + Args: []ast.ExprNode{ + ast.NewValueExpr($1), + $3.(ast.ExprNode), + $6.(ast.ExprNode), + ast.NewValueExpr($7), }, } } @@ -3086,19 +3089,6 @@ DateArithMultiFormsOpt: $$ = ast.DateSub } -DateArithInterval: - Expression - { - $$ = ast.DateArithInterval{ - Unit: "day", - Interval: $1.(ast.ExprNode), - } - } -| "INTERVAL" Expression TimeUnit - { - $$ = ast.DateArithInterval{Unit: $3, Interval: $2.(ast.ExprNode)} - } - TrimDirection: "BOTH" { diff --git a/parser/parser_test.go b/parser/parser_test.go index 8c91c81716..d3e2b73115 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -694,6 +694,9 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", 10)`, false}, + {`select date_add("2011-11-11 10:10:10.123456", 0.10)`, false}, + {`select date_add("2011-11-11 10:10:10.123456", "11,11")`, false}, // For strcmp {`select strcmp('abc', 'def')`, true}, @@ -747,6 +750,9 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, {`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", 10)`, false}, + {`select date_sub("2011-11-11 10:10:10.123456", 0.10)`, false}, + {`select date_sub("2011-11-11 10:10:10.123456", "11,11")`, false}, // For subdate {`select subdate("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, @@ -769,9 +775,9 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, {`select subdate("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, {`select subdate("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, - {`select adddate("2011-11-11 10:10:10.123456", 10)`, true}, - {`select adddate("2011-11-11 10:10:10.123456", 0.10)`, true}, - {`select adddate("2011-11-11 10:10:10.123456", "11,11")`, true}, + {`select subdate("2011-11-11 10:10:10.123456", 10)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", 0.10)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", "11,11")`, true}, // For misc functions {`SELECT GET_LOCK('lock1',10);`, true}, diff --git a/plan/expression_test.go b/plan/expression_test.go index 5d2c1645a1..bd743a1c9b 100644 --- a/plan/expression_test.go +++ b/plan/expression_test.go @@ -304,6 +304,9 @@ func (s *testExpressionSuite) TestDateArith(c *C) { // nil test {nil, 1, "DAY", nil, nil, false}, {"2011-11-11", nil, "DAY", nil, nil, false}, + // tests for inner function call + {"2011-11-11", s.parseExpr(c, "LEAST(1, 2)"), "DAY", "2011-11-12", "2011-11-10", false}, + {"2011-11-11", s.parseExpr(c, "LEAST(NULL, 2)"), "DAY", nil, nil, false}, // tests for different units {"2011-11-11 10:10:10", 1000, "MICROSECOND", "2011-11-11 10:10:10.001000", "2011-11-11 10:10:09.999000", false}, {"2011-11-11 10:10:10", "10", "SECOND", "2011-11-11 10:10:20", "2011-11-11 10:10:00", false}, @@ -364,19 +367,19 @@ func (s *testExpressionSuite) TestDateArith(c *C) { // run the test cases for _, t := range tests { op := ast.NewValueExpr(ast.DateAdd) - dateArithInterval := ast.NewValueExpr( - ast.DateArithInterval{ - Unit: t.Unit, - Interval: ast.NewValueExpr(t.Interval), - }, - ) - date := ast.NewValueExpr(t.Date) + var interval ast.ExprNode + if n, ok := t.Interval.(ast.ExprNode); ok { + interval = n + } else { + interval = ast.NewValueExpr(t.Interval) + } expr := &ast.FuncCallExpr{ FnName: model.NewCIStr("DATE_ARITH"), Args: []ast.ExprNode{ op, - date, - dateArithInterval, + ast.NewValueExpr(t.Date), + interval, + ast.NewValueExpr(t.Unit), }, } ast.SetFlag(expr)