diff --git a/bootstrap.go b/bootstrap.go index ead2091356..8de29f1d9b 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -22,7 +22,7 @@ import ( "runtime/debug" "github.com/ngaut/log" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/errors" "github.com/pingcap/tidb/util/errors2" ) diff --git a/column/column.go b/column/column.go index 6acdfe2c37..d6831c8f97 100644 --- a/column/column.go +++ b/column/column.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/column/column_test.go b/column/column_test.go index d4ae1a2a35..45710b588b 100644 --- a/column/column_test.go +++ b/column/column_test.go @@ -18,7 +18,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/ddl/ddl.go b/ddl/ddl.go index d8130e74cf..daab43b48c 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -30,7 +30,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" @@ -512,7 +512,7 @@ func updateOldRows(ctx context.Context, t *tables.Table, col *column.Col) error if err != nil { return errors.Trace(err) } - it, err := txn.Seek([]byte(t.FirstKey()), nil) + it, err := txn.Seek([]byte(t.FirstKey())) if err != nil { return errors.Trace(err) } @@ -685,7 +685,7 @@ func (d *ddl) buildIndex(ctx context.Context, t table.Table, idxInfo *model.Inde if err != nil { return errors.Trace(err) } - it, err := txn.Seek([]byte(firstKey), nil) + it, err := txn.Seek([]byte(firstKey)) if err != nil { return errors.Trace(err) } diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 2945b69388..864567a29f 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -23,7 +23,7 @@ import ( "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" diff --git a/driver.go b/driver.go index 31d58c5acd..ad2f08647c 100644 --- a/driver.go +++ b/driver.go @@ -31,7 +31,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" qerror "github.com/pingcap/tidb/util/errors" diff --git a/expression/binop.go b/expression/binop.go index d466a3ddd9..8ce8cf2b54 100644 --- a/expression/binop.go +++ b/expression/binop.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/binop_test.go b/expression/binop_test.go index 81f14ce804..d3d7c28776 100644 --- a/expression/binop_test.go +++ b/expression/binop_test.go @@ -20,7 +20,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/expression/builtin" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) @@ -186,7 +186,7 @@ func (s *testBinOpSuite) TestIdentRelOp(c *C) { f := func(name string) *Ident { return &Ident{ - model.NewCIStr(name), + CIStr: model.NewCIStr(name), } } diff --git a/expression/builtin/groupby.go b/expression/builtin/groupby.go index 69cd07253f..5566c3897f 100644 --- a/expression/builtin/groupby.go +++ b/expression/builtin/groupby.go @@ -24,7 +24,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/kv/memkv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/groupby_test.go b/expression/builtin/groupby_test.go index a6116bd968..0730725bda 100644 --- a/expression/builtin/groupby_test.go +++ b/expression/builtin/groupby_test.go @@ -15,7 +15,7 @@ package builtin import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/string_test.go b/expression/builtin/string_test.go index 878191115c..5c128a47d9 100644 --- a/expression/builtin/string_test.go +++ b/expression/builtin/string_test.go @@ -19,7 +19,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func (s *testBuiltinSuite) TestLength(c *C) { diff --git a/expression/builtin/time.go b/expression/builtin/time.go index 22e9e07468..ecfb62002e 100644 --- a/expression/builtin/time.go +++ b/expression/builtin/time.go @@ -21,7 +21,7 @@ import ( "time" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/builtin/time_test.go b/expression/builtin/time_test.go index 151b0ed9ba..c3e662933f 100644 --- a/expression/builtin/time_test.go +++ b/expression/builtin/time_test.go @@ -18,7 +18,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func (s *testBuiltinSuite) TestDate(c *C) { diff --git a/expression/cast.go b/expression/cast.go index 42f54420b1..4b0d549e79 100644 --- a/expression/cast.go +++ b/expression/cast.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/cast_test.go b/expression/cast_test.go index 7d969e5f06..72556f40f9 100644 --- a/expression/cast_test.go +++ b/expression/cast_test.go @@ -17,7 +17,7 @@ import ( "errors" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/date_add.go b/expression/date_add.go new file mode 100644 index 0000000000..95cd99a0de --- /dev/null +++ b/expression/date_add.go @@ -0,0 +1,104 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "fmt" + "strings" + + "github.com/juju/errors" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/types" +) + +// DateAdd is for time date_add function. +// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add +type DateAdd struct { + Unit string + Date Expression + Interval Expression +} + +// Clone implements the Expression Clone interface. +func (da *DateAdd) Clone() Expression { + n := *da + return &n +} + +// Eval implements the Expression Eval interface. +func (da *DateAdd) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { + dv, err := da.Date.Eval(ctx, args) + if dv == nil || err != nil { + return nil, errors.Trace(err) + } + + sv, err := types.ToString(dv) + if err != nil { + return nil, errors.Trace(err) + } + + f := types.NewFieldType(mysql.TypeDatetime) + f.Decimal = mysql.MaxFsp + + dv, err = types.Convert(sv, f) + if dv == nil || err != nil { + return nil, errors.Trace(err) + } + + t, ok := dv.(mysql.Time) + if !ok { + return nil, errors.Errorf("need time type, but got %T", dv) + } + + iv, err := da.Interval.Eval(ctx, args) + if iv == nil || err != nil { + return nil, errors.Trace(err) + } + + format, err := types.ToString(iv) + if err != nil { + return nil, errors.Trace(err) + } + + years, months, days, durations, err := mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(format)) + if err != nil { + return nil, errors.Trace(err) + } + + t.Time = t.Time.Add(durations) + t.Time = t.Time.AddDate(int(years), int(months), int(days)) + + // "2011-11-11 10:10:20.000000" outputs "2011-11-11 10:10:20". + if t.Time.Nanosecond() == 0 { + t.Fsp = 0 + } + + return t, nil +} + +// IsStatic implements the Expression IsStatic interface. +func (da *DateAdd) IsStatic() bool { + return da.Date.IsStatic() && da.Interval.IsStatic() +} + +// String implements the Expression String interface. +func (da *DateAdd) String() string { + return fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", da.Date, da.Interval, strings.ToUpper(da.Unit)) +} + +// Accept implements the Visitor Accept interface. +func (da *DateAdd) Accept(v Visitor) (Expression, error) { + return v.VisitDateAdd(da) +} diff --git a/expression/date_add_test.go b/expression/date_add_test.go new file mode 100644 index 0000000000..6b612dce40 --- /dev/null +++ b/expression/date_add_test.go @@ -0,0 +1,144 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/mysql" +) + +var _ = Suite(&testDateAddSuite{}) + +type testDateAddSuite struct { +} + +func (t *testDateAddSuite) TestDateAdd(c *C) { + input := "2011-11-11 10:10:10" + e := &DateAdd{ + Unit: "DAY", + Date: Value{Val: input}, + Interval: Value{Val: "1"}, + } + c.Assert(e.String(), Equals, `DATE_ADD("2011-11-11 10:10:10", INTERVAL "1" DAY)`) + c.Assert(e.Clone(), NotNil) + c.Assert(e.IsStatic(), IsTrue) + + _, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + + // Test null. + e = &DateAdd{ + Unit: "DAY", + Date: Value{Val: nil}, + Interval: Value{Val: "1"}, + } + + v, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsNil) + + e = &DateAdd{ + Unit: "DAY", + Date: Value{Val: input}, + Interval: Value{Val: nil}, + } + + v, err = e.Eval(nil, nil) + c.Assert(err, IsNil) + c.Assert(v, IsNil) + + // Test eval. + tbl := []struct { + Unit string + Interval interface{} + Expect string + }{ + {"MICROSECOND", "1000", "2011-11-11 10:10:10.001000"}, + {"MICROSECOND", 1000, "2011-11-11 10:10:10.001000"}, + {"SECOND", "10", "2011-11-11 10:10:20"}, + {"MINUTE", "10", "2011-11-11 10:20:10"}, + {"HOUR", "10", "2011-11-11 20:10:10"}, + {"DAY", "11", "2011-11-22 10:10:10"}, + {"WEEK", "2", "2011-11-25 10:10:10"}, + {"MONTH", "2", "2012-01-11 10:10:10"}, + {"QUARTER", "4", "2012-11-11 10:10:10"}, + {"YEAR", "2", "2013-11-11 10:10:10"}, + {"SECOND_MICROSECOND", "10.00100000", "2011-11-11 10:10:20.100000"}, + {"SECOND_MICROSECOND", "10.0010000000", "2011-11-11 10:10:30"}, + {"SECOND_MICROSECOND", "10.0010000010", "2011-11-11 10:10:30.000010"}, + {"MINUTE_MICROSECOND", "10:10.100", "2011-11-11 10:20:20.100000"}, + {"MINUTE_SECOND", "10:10", "2011-11-11 10:20:20"}, + {"HOUR_MICROSECOND", "10:10:10.100", "2011-11-11 20:20:20.100000"}, + {"HOUR_SECOND", "10:10:10", "2011-11-11 20:20:20"}, + {"HOUR_MINUTE", "10:10", "2011-11-11 20:20:10"}, + {"DAY_MICROSECOND", "11 10:10:10.100", "2011-11-22 20:20:20.100000"}, + {"DAY_SECOND", "11 10:10:10", "2011-11-22 20:20:20"}, + {"DAY_MINUTE", "11 10:10", "2011-11-22 20:20:10"}, + {"DAY_HOUR", "11 10", "2011-11-22 20:10:10"}, + {"YEAR_MONTH", "11-1", "2022-12-11 10:10:10"}, + {"YEAR_MONTH", "11-11", "2023-10-11 10:10:10"}, + } + + for _, t := range tbl { + e := &DateAdd{ + Unit: t.Unit, + Date: Value{Val: input}, + Interval: Value{Val: t.Interval}, + } + + v, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + + value, ok := v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(value.String(), Equals, t.Expect) + } + + // Test error. + errInput := "20111111 10:10:10" + errTbl := []struct { + Unit string + Interval interface{} + }{ + {"MICROSECOND", "abc1000"}, + {"MICROSECOND", ""}, + {"SECOND_MICROSECOND", "10"}, + {"MINUTE_MICROSECOND", "10.0000"}, + {"MINUTE_MICROSECOND", "10:10:10.0000"}, + + // MySQL support, but tidb not. + {"HOUR_MICROSECOND", "10:10.0000"}, + {"YEAR_MONTH", "10 1"}, + } + + for _, t := range errTbl { + e := &DateAdd{ + Unit: t.Unit, + Date: Value{Val: input}, + Interval: Value{Val: t.Interval}, + } + + _, err := e.Eval(nil, nil) + c.Assert(err, NotNil) + + e = &DateAdd{ + Unit: t.Unit, + Date: Value{Val: errInput}, + Interval: Value{Val: t.Interval}, + } + + v, err := e.Eval(nil, nil) + c.Assert(err, NotNil, Commentf("%s", v)) + } +} diff --git a/expression/extract.go b/expression/extract.go index 04d6c1f9e1..32608cdddb 100644 --- a/expression/extract.go +++ b/expression/extract.go @@ -19,12 +19,12 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) // Extract is for time extract function. -// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract +// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_extract type Extract struct { Unit string Date Expression @@ -56,7 +56,7 @@ func (e *Extract) Eval(ctx context.Context, args map[interface{}]interface{}) (i return nil, errors.Errorf("need time type, but got %T", v) } - n, err1 := extractTime(e.Unit, t) + n, err1 := mysql.ExtractTimeNum(e.Unit, t) if err1 != nil { return nil, errors.Trace(err1) } @@ -78,70 +78,3 @@ func (e *Extract) String() string { func (e *Extract) Accept(v Visitor) (Expression, error) { return v.VisitExtract(e) } - -func extractTime(unit string, t mysql.Time) (int64, error) { - switch strings.ToUpper(unit) { - case "MICROSECOND": - return int64(t.Nanosecond() / 1000), nil - case "SECOND": - return int64(t.Second()), nil - case "MINUTE": - return int64(t.Minute()), nil - case "HOUR": - return int64(t.Hour()), nil - case "DAY": - return int64(t.Day()), nil - case "WEEK": - _, week := t.ISOWeek() - return int64(week), nil - case "MONTH": - return int64(t.Month()), nil - case "QUARTER": - m := int64(t.Month()) - // 1 - 3 -> 1 - // 4 - 6 -> 2 - // 7 - 9 -> 3 - // 10 - 12 -> 4 - return (m + 2) / 3, nil - case "YEAR": - return int64(t.Year()), nil - case "SECOND_MICROSECOND": - return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil - case "MINUTE_MICROSECOND": - _, m, s := t.Clock() - return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil - case "MINUTE_SECOND": - _, m, s := t.Clock() - return int64(m*100 + s), nil - case "HOUR_MICROSECOND": - h, m, s := t.Clock() - return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil - case "HOUR_SECOND": - h, m, s := t.Clock() - return int64(h)*10000 + int64(m)*100 + int64(s), nil - case "HOUR_MINUTE": - h, m, _ := t.Clock() - return int64(h)*100 + int64(m), nil - case "DAY_MICROSECOND": - h, m, s := t.Clock() - d := t.Day() - return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil - case "DAY_SECOND": - h, m, s := t.Clock() - d := t.Day() - return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil - case "DAY_MINUTE": - h, m, _ := t.Clock() - d := t.Day() - return int64(d)*10000 + int64(h)*100 + int64(m), nil - case "DAY_HOUR": - h, _, _ := t.Clock() - d := t.Day() - return int64(d)*100 + int64(h), nil - case "YEAR_MONTH": - y, m, _ := t.Date() - return int64(y)*100 + int64(m), nil - default: - return 0, errors.Errorf("invalid unit %s", unit) - } -} diff --git a/expression/helper.go b/expression/helper.go index 69554c4354..563d907dc0 100644 --- a/expression/helper.go +++ b/expression/helper.go @@ -30,7 +30,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression/builtin" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/types" @@ -45,13 +45,15 @@ const ( ExprEvalPositionFunc = "$positionFunc" // ExprEvalValuesFunc is the key saving a function to retrieve value for column name. ExprEvalValuesFunc = "$valuesFunc" + // ExprEvalIdentReferFunc is the key saving a function to retrieve value with identifier reference index. + ExprEvalIdentReferFunc = "$identReferFunc" ) var ( // CurrentTimestamp is the keyword getting default value for datetime and timestamp type. CurrentTimestamp = "CURRENT_TIMESTAMP" // CurrentTimeExpr is the expression retireving default value for datetime and timestamp type. - CurrentTimeExpr = &Ident{model.NewCIStr(CurrentTimestamp)} + CurrentTimeExpr = &Ident{CIStr: model.NewCIStr(CurrentTimestamp)} // ZeroTimestamp shows the zero datetime and timestamp. ZeroTimestamp = "0000-00-00 00:00:00" ) @@ -142,22 +144,40 @@ func newMentionedAggregateFuncsVisitor() *mentionedAggregateFuncsVisitor { } func (v *mentionedAggregateFuncsVisitor) VisitCall(c *Call) (Expression, error) { + isAggregate := IsAggregateFunc(c.F) + + if isAggregate { + v.exprs = append(v.exprs, c) + } + + n := len(v.exprs) for _, e := range c.Args { _, err := e.Accept(v) if err != nil { return nil, errors.Trace(err) } } - f, ok := builtin.Funcs[strings.ToLower(c.F)] - if !ok { - return nil, errors.Errorf("unknown function %s", c.F) - } - if f.IsAggregate { - v.exprs = append(v.exprs, c) + + if isAggregate && len(v.exprs) != n { + // aggregate function can't use aggregate function as the arg. + // here means we have aggregate function in arg. + return nil, errors.Errorf("Invalid use of group function") } + return c, nil } +// IsAggregateFunc checks whether name is an aggregate function or not. +func IsAggregateFunc(name string) bool { + // TODO: use switch defined aggregate name "sum", "count", etc... directly. + // Maybe we can remove builtin IsAggregate field later. + f, ok := builtin.Funcs[strings.ToLower(name)] + if !ok { + return false + } + return f.IsAggregate +} + // MentionedColumns returns a list of names for Ident expression. func MentionedColumns(e Expression) []string { var names []string diff --git a/expression/helper_test.go b/expression/helper_test.go index 6bda7bc2ac..21e01e5127 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -5,9 +5,8 @@ import ( "time" . "github.com/pingcap/check" - "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/sessionctx/variable" ) @@ -47,6 +46,15 @@ func (s *testHelperSuite) TestContainAggFunc(c *C) { b := ContainAggregateFunc(t.Expr) c.Assert(b, Equals, t.Expect) } + + expr := &Call{ + F: "count", + Args: []Expression{ + &Call{F: "count", Args: []Expression{v}}, + }, + } + _, err := MentionedAggregateFuncs(expr) + c.Assert(err, NotNil) } func (s *testHelperSuite) TestMentionedColumns(c *C) { @@ -57,7 +65,7 @@ func (s *testHelperSuite) TestMentionedColumns(c *C) { }{ {Value{1}, 0}, {&BinaryOperation{L: v, R: v}, 0}, - {&Ident{model.NewCIStr("id")}, 1}, + {&Ident{CIStr: model.NewCIStr("id")}, 1}, {&Call{F: "count", Args: []Expression{v}}, 0}, {&IsNull{Expr: v}, 0}, {&PExpr{Expr: v}, 0}, @@ -106,7 +114,7 @@ func (s *testHelperSuite) TestBase(c *C) { {int64(1), int64(1)}, {&UnaryOperation{Op: opcode.Plus, V: Value{1}}, 1}, {&UnaryOperation{Op: opcode.Not, V: Value{1}}, nil}, - {&UnaryOperation{Op: opcode.Plus, V: &Ident{model.NewCIStr("id")}}, nil}, + {&UnaryOperation{Op: opcode.Plus, V: &Ident{CIStr: model.NewCIStr("id")}}, nil}, {nil, nil}, } @@ -228,7 +236,7 @@ func (s *testHelperSuite) TestGetTimeValue(c *C) { {Value{"2012-13-12 00:00:00"}}, {Value{0}}, {Value{int64(1)}}, - {&Ident{model.NewCIStr("xxx")}}, + {&Ident{CIStr: model.NewCIStr("xxx")}}, {NewUnaryOperation(opcode.Minus, Value{int64(1)})}, } diff --git a/expression/ident.go b/expression/ident.go index a2ddff3d23..427355fe79 100644 --- a/expression/ident.go +++ b/expression/ident.go @@ -29,10 +29,23 @@ var ( _ Expression = (*Ident)(nil) ) +const ( + // IdentReferSelectList means the identifier reference is in select list. + IdentReferSelectList = 1 + // IdentReferFromTable means the identifier reference is in FROM table. + IdentReferFromTable = 2 +) + // Ident is the identifier expression. type Ident struct { // model.CIStr contains origin identifier name and its lowercase name. model.CIStr + + // ReferScope means where the identifer reference is, select list or from. + ReferScope int + + // ReferIndex is the index to get the identifer data. + ReferIndex int } // Clone implements the Expression Clone interface. @@ -63,6 +76,14 @@ func (i *Ident) Eval(ctx context.Context, args map[interface{}]interface{}) (v i return nil, nil } + // TODO: we will unify ExprEvalIdentReferFunc and ExprEvalIdentFunc later, + // now just put them here for refactor step by step. + if f, ok := args[ExprEvalIdentReferFunc]; ok { + if got, ok := f.(func(string, int, int) (interface{}, error)); ok { + return got(i.L, i.ReferScope, i.ReferIndex) + } + } + if f, ok := args[ExprEvalIdentFunc]; ok { if got, ok := f.(func(string) (interface{}, error)); ok { return got(i.L) diff --git a/expression/ident_test.go b/expression/ident_test.go index f71289be6c..6d3dfe3719 100644 --- a/expression/ident_test.go +++ b/expression/ident_test.go @@ -26,7 +26,7 @@ type testIdentSuite struct { func (s *testIdentSuite) TestIdent(c *C) { e := Ident{ - model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), } c.Assert(e.IsStatic(), IsFalse) @@ -55,4 +55,14 @@ func (s *testIdentSuite) TestIdent(c *C) { v, err = e.Eval(nil, m) c.Assert(err, IsNil) c.Assert(v, Equals, 1) + + delete(m, ExprEvalIdentFunc) + e.ReferScope = IdentReferSelectList + e.ReferIndex = 1 + m[ExprEvalIdentReferFunc] = func(string, int, int) (interface{}, error) { + return 2, nil + } + v, err = e.Eval(nil, m) + c.Assert(err, IsNil) + c.Assert(v, Equals, 2) } diff --git a/expression/unary.go b/expression/unary.go index 7c58c128b1..f7c8e9ea6b 100644 --- a/expression/unary.go +++ b/expression/unary.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/unary_test.go b/expression/unary_test.go index bef54d7982..aff61503ba 100644 --- a/expression/unary_test.go +++ b/expression/unary_test.go @@ -19,7 +19,7 @@ import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/expression/visitor.go b/expression/visitor.go index b6b7c04bec..6ec16f765f 100644 --- a/expression/visitor.go +++ b/expression/visitor.go @@ -106,6 +106,9 @@ type Visitor interface { // VisitFunctionTrim visits FunctionTrim expression. VisitFunctionTrim(v *FunctionTrim) (Expression, error) + + // VisitDateAdd visits DateAdd expression. + VisitDateAdd(da *DateAdd) (Expression, error) } // BaseVisitor is the base implementation of Visitor. @@ -456,7 +459,7 @@ func (bv *BaseVisitor) VisitWhenClause(w *WhenClause) (Expression, error) { return w, nil } -// VisitExtract implements Visitor +// VisitExtract implements Visitor interface. func (bv *BaseVisitor) VisitExtract(v *Extract) (Expression, error) { var err error v.Date, err = v.Date.Accept(bv.V) @@ -478,3 +481,19 @@ func (bv *BaseVisitor) VisitFunctionTrim(ss *FunctionTrim) (Expression, error) { } return ss, nil } + +// VisitDateAdd implements Visitor interface. +func (bv *BaseVisitor) VisitDateAdd(da *DateAdd) (Expression, error) { + var err error + da.Date, err = da.Date.Accept(bv.V) + if err != nil { + return da, errors.Trace(err) + } + + da.Interval, err = da.Interval.Accept(bv.V) + if err != nil { + return da, errors.Trace(err) + } + + return da, nil +} diff --git a/expression/visitor_test.go b/expression/visitor_test.go index 281c86c00d..b88315bfa6 100644 --- a/expression/visitor_test.go +++ b/expression/visitor_test.go @@ -17,7 +17,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) diff --git a/field/field_test.go b/field/field_test.go index c72d29b6e6..49cb09663c 100644 --- a/field/field_test.go +++ b/field/field_test.go @@ -19,7 +19,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/field/result_field.go b/field/result_field.go index 2eb301b9d2..fd79ab9e77 100644 --- a/field/result_field.go +++ b/field/result_field.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/column" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) const ( diff --git a/field/result_field_test.go b/field/result_field_test.go index 39308aefd5..9d58d86a20 100644 --- a/field/result_field_test.go +++ b/field/result_field_test.go @@ -19,7 +19,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/types" ) diff --git a/infoschema/infoschema_test.go b/infoschema/infoschema_test.go index ae197c93f4..f10910ead9 100644 --- a/infoschema/infoschema_test.go +++ b/infoschema/infoschema_test.go @@ -19,7 +19,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/store/localstore" "github.com/pingcap/tidb/store/localstore/goleveldb" "github.com/pingcap/tidb/util/types" @@ -51,7 +51,7 @@ func (*testSuite) TestT(c *C) { ID: 3, Name: colName, Offset: 0, - FieldType: *types.NewFieldType(mysqldef.TypeLonglong), + FieldType: *types.NewFieldType(mysql.TypeLonglong), } idxInfo := &model.IndexInfo{ diff --git a/kv/compactor.go b/kv/compactor.go new file mode 100644 index 0000000000..47ebd7a619 --- /dev/null +++ b/kv/compactor.go @@ -0,0 +1,40 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package kv + +import "time" + +// CompactPolicy defines gc policy of MVCC storage. +type CompactPolicy struct { + // SafePoint specifies + SafePoint int + // TriggerInterval specifies how often should the compactor + // scans outdated data. + TriggerInterval time.Duration + // BatchDeleteCnt specifies the batch size for + // deleting outdated data transaction. + BatchDeleteCnt int +} + +// Compactor compacts MVCC storage. +type Compactor interface { + // OnGet is the hook point on Txn.Get. + OnGet(k Key) + // OnSet is the hook point on Txn.Set. + OnSet(k Key) + // OnDelete is the hook point on Txn.Delete. + OnDelete(k Key) + // Compact is the function removes the given key. + Compact(ctx interface{}, k Key) error +} diff --git a/kv/index_iter.go b/kv/index_iter.go index 98c60557ec..e19004ddeb 100644 --- a/kv/index_iter.go +++ b/kv/index_iter.go @@ -87,7 +87,7 @@ func (c *indexIter) Next() (k []interface{}, h int64, err error) { k = vv } // update new iter to next - newIt, err := c.it.Next(hasPrefix([]byte(c.prefix))) + newIt, err := c.it.Next() if err != nil { return nil, 0, errors.Trace(err) } @@ -189,16 +189,10 @@ func (c *kvIndex) Delete(txn Transaction, indexedValues []interface{}, h int64) return errors.Trace(err) } -func hasPrefix(prefix []byte) FnKeyCmp { - return func(k Key) bool { - return bytes.HasPrefix([]byte(k), prefix) - } -} - // Drop removes the KV index from store. func (c *kvIndex) Drop(txn Transaction) error { prefix := []byte(c.prefix) - it, err := txn.Seek(Key(prefix), hasPrefix(prefix)) + it, err := txn.Seek(Key(prefix)) if err != nil { return errors.Trace(err) } @@ -213,7 +207,7 @@ func (c *kvIndex) Drop(txn Transaction) error { if err != nil { return errors.Trace(err) } - it, err = it.Next(hasPrefix(prefix)) + it, err = it.Next() if err != nil { return errors.Trace(err) } @@ -227,7 +221,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index if err != nil { return nil, false, errors.Trace(err) } - it, err := txn.Seek(keyBuf, hasPrefix([]byte(c.prefix))) + it, err := txn.Seek(keyBuf) if err != nil { return nil, false, errors.Trace(err) } @@ -242,7 +236,7 @@ func (c *kvIndex) Seek(txn Transaction, indexedValues []interface{}) (iter Index // SeekFirst returns an iterator which points to the first entry of the KV index. func (c *kvIndex) SeekFirst(txn Transaction) (iter IndexIterator, err error) { prefix := []byte(c.prefix) - it, err := txn.Seek(prefix, hasPrefix(prefix)) + it, err := txn.Seek(prefix) if err != nil { return nil, errors.Trace(err) } diff --git a/kv/iter.go b/kv/iter.go index 7dc7beaecf..03a99c5a82 100644 --- a/kv/iter.go +++ b/kv/iter.go @@ -72,7 +72,7 @@ func DecodeValue(data []byte) ([]interface{}, error) { func NextUntil(it Iterator, fn FnKeyCmp) (Iterator, error) { var err error for it.Valid() && !fn([]byte(it.Key())) { - it, err = it.Next(nil) + it, err = it.Next() if err != nil { return nil, errors.Trace(err) } diff --git a/kv/kv.go b/kv/kv.go index b93e9dc3a2..6681cc6f0f 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -104,7 +104,7 @@ type Transaction interface { // Set sets the value for key k as v into KV store. Set(k Key, v []byte) error // Seek searches for the entry with key k in KV store. - Seek(k Key, fnKeyCmp func(key Key) bool) (Iterator, error) + Seek(k Key) (Iterator, error) // Inc increases the value for key k in KV store by step. Inc(k Key, step int64) (int64, error) // GetInt64 get int64 which created by Inc method. @@ -171,7 +171,7 @@ type FnKeyCmp func(key Key) bool // Iterator is the interface for a interator on KV store. type Iterator interface { - Next(FnKeyCmp) (Iterator, error) + Next() (Iterator, error) Value() []byte Key() string Valid() bool diff --git a/kv/union_iter.go b/kv/union_iter.go index 6b677c433c..a25f81e776 100644 --- a/kv/union_iter.go +++ b/kv/union_iter.go @@ -51,7 +51,7 @@ func (iter *UnionIter) dirtyNext() { // Go next and update valid status. func (iter *UnionIter) snapshotNext() { - iter.snapshotIt, _ = iter.snapshotIt.Next(nil) + iter.snapshotIt, _ = iter.snapshotIt.Next() iter.snapshotValid = iter.snapshotIt.Valid() } @@ -116,7 +116,7 @@ func (iter *UnionIter) updateCur() { } // Next implements the Iterator Next interface. -func (iter *UnionIter) Next(f FnKeyCmp) (Iterator, error) { +func (iter *UnionIter) Next() (Iterator, error) { if !iter.curIsDirty { iter.snapshotNext() } else { diff --git a/mysqldef/bit.go b/mysql/bit.go similarity index 99% rename from mysqldef/bit.go rename to mysql/bit.go index 985c5b1d58..4f76dceb18 100644 --- a/mysqldef/bit.go +++ b/mysql/bit.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "fmt" diff --git a/mysqldef/bit_test.go b/mysql/bit_test.go similarity index 98% rename from mysqldef/bit_test.go rename to mysql/bit_test.go index b2ca9d1080..49652206df 100644 --- a/mysqldef/bit_test.go +++ b/mysql/bit_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import . "github.com/pingcap/check" diff --git a/mysqldef/charset.go b/mysql/charset.go similarity index 99% rename from mysqldef/charset.go rename to mysql/charset.go index fb4809c362..962d792e50 100644 --- a/mysqldef/charset.go +++ b/mysql/charset.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // CharsetIDs maps charset name to its default collation ID. var CharsetIDs = map[string]uint8{ diff --git a/mysqldef/const.go b/mysql/const.go similarity index 99% rename from mysqldef/const.go rename to mysql/const.go index 3f483c1eb0..e0f49e4294 100644 --- a/mysqldef/const.go +++ b/mysql/const.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // Version informations. const ( diff --git a/mysqldef/decimal.go b/mysql/decimal.go similarity index 99% rename from mysqldef/decimal.go rename to mysql/decimal.go index 2ff129af19..87bd66e6b3 100644 --- a/mysqldef/decimal.go +++ b/mysql/decimal.go @@ -57,7 +57,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // Decimal implements an arbitrary precision fixed-point decimal. // diff --git a/mysqldef/decimal_test.go b/mysql/decimal_test.go similarity index 99% rename from mysqldef/decimal_test.go rename to mysql/decimal_test.go index 60bb409c88..ab9748e232 100644 --- a/mysqldef/decimal_test.go +++ b/mysql/decimal_test.go @@ -57,7 +57,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "encoding/json" diff --git a/mysqldef/enum.go b/mysql/enum.go similarity index 98% rename from mysqldef/enum.go rename to mysql/enum.go index 96c4712800..425a3e1b52 100644 --- a/mysqldef/enum.go +++ b/mysql/enum.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "strconv" diff --git a/mysqldef/enum_test.go b/mysql/enum_test.go similarity index 98% rename from mysqldef/enum_test.go rename to mysql/enum_test.go index 0c06660493..154ac73802 100644 --- a/mysqldef/enum_test.go +++ b/mysql/enum_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( . "github.com/pingcap/check" diff --git a/mysqldef/errcode.go b/mysql/errcode.go similarity index 99% rename from mysqldef/errcode.go rename to mysql/errcode.go index d486c36c90..8de53d7e71 100644 --- a/mysqldef/errcode.go +++ b/mysql/errcode.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // MySQL error code. // This value is numeric. It is not portable to other database systems. diff --git a/mysqldef/errname.go b/mysql/errname.go similarity index 99% rename from mysqldef/errname.go rename to mysql/errname.go index bc98fed16b..ad11ec746e 100644 --- a/mysqldef/errname.go +++ b/mysql/errname.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // MySQLErrName maps error code to MySQL error messages. var MySQLErrName = map[uint16]string{ diff --git a/mysqldef/error.go b/mysql/error.go similarity index 99% rename from mysqldef/error.go rename to mysql/error.go index 6b75670fc8..43246a4a2a 100644 --- a/mysqldef/error.go +++ b/mysql/error.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "errors" diff --git a/mysqldef/error_test.go b/mysql/error_test.go similarity index 98% rename from mysqldef/error_test.go rename to mysql/error_test.go index 10f9dd2d8e..1e499ff304 100644 --- a/mysqldef/error_test.go +++ b/mysql/error_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( . "github.com/pingcap/check" diff --git a/mysqldef/fsp.go b/mysql/fsp.go similarity index 90% rename from mysqldef/fsp.go rename to mysql/fsp.go index b8802e070f..820fabb3d1 100644 --- a/mysqldef/fsp.go +++ b/mysql/fsp.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "math" @@ -80,3 +80,13 @@ func parseFrac(s string, fsp int) (int, error) { // 0.0312 round 2 -> 3 -> 30000 return int(round * math.Pow10(MaxFsp-fsp)), nil } + +// alignFrac is used to generate alignment frac, like `100` -> `100000` +func alignFrac(s string, fsp int) string { + sl := len(s) + if sl < fsp { + return s + strings.Repeat("0", fsp-sl) + } + + return s +} diff --git a/mysqldef/hex.go b/mysql/hex.go similarity index 99% rename from mysqldef/hex.go rename to mysql/hex.go index 0a3b160d5d..5858b5d969 100644 --- a/mysqldef/hex.go +++ b/mysql/hex.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "encoding/hex" diff --git a/mysqldef/hex_test.go b/mysql/hex_test.go similarity index 98% rename from mysqldef/hex_test.go rename to mysql/hex_test.go index ba1e3e55b4..551e245b64 100644 --- a/mysqldef/hex_test.go +++ b/mysql/hex_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "strconv" diff --git a/mysqldef/set.go b/mysql/set.go similarity index 99% rename from mysqldef/set.go rename to mysql/set.go index 613a938f55..8c0788a614 100644 --- a/mysqldef/set.go +++ b/mysql/set.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "strconv" diff --git a/mysqldef/set_test.go b/mysql/set_test.go similarity index 99% rename from mysqldef/set_test.go rename to mysql/set_test.go index 0469da6f23..69887db4b6 100644 --- a/mysqldef/set_test.go +++ b/mysql/set_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( . "github.com/pingcap/check" diff --git a/mysqldef/state.go b/mysql/state.go similarity index 99% rename from mysqldef/state.go rename to mysql/state.go index 9dfb4e9c14..67fcd0e5af 100644 --- a/mysqldef/state.go +++ b/mysql/state.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql const ( // DefaultMySQLState is default state of the mySQL diff --git a/mysqldef/time.go b/mysql/time.go similarity index 69% rename from mysqldef/time.go rename to mysql/time.go index bb01d19926..7e094a5625 100644 --- a/mysqldef/time.go +++ b/mysql/time.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "bytes" @@ -979,3 +979,364 @@ func checkTimestamp(t Time) bool { return true } + +// ExtractTimeNum extracts time value number from time unit and format. +func ExtractTimeNum(unit string, t Time) (int64, error) { + switch strings.ToUpper(unit) { + case "MICROSECOND": + return int64(t.Nanosecond() / 1000), nil + case "SECOND": + return int64(t.Second()), nil + case "MINUTE": + return int64(t.Minute()), nil + case "HOUR": + return int64(t.Hour()), nil + case "DAY": + return int64(t.Day()), nil + case "WEEK": + _, week := t.ISOWeek() + return int64(week), nil + case "MONTH": + return int64(t.Month()), nil + case "QUARTER": + m := int64(t.Month()) + // 1 - 3 -> 1 + // 4 - 6 -> 2 + // 7 - 9 -> 3 + // 10 - 12 -> 4 + return (m + 2) / 3, nil + case "YEAR": + return int64(t.Year()), nil + case "SECOND_MICROSECOND": + return int64(t.Second())*1000000 + int64(t.Nanosecond())/1000, nil + case "MINUTE_MICROSECOND": + _, m, s := t.Clock() + return int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil + case "MINUTE_SECOND": + _, m, s := t.Clock() + return int64(m*100 + s), nil + case "HOUR_MICROSECOND": + h, m, s := t.Clock() + return int64(h)*10000000000 + int64(m)*100000000 + int64(s)*1000000 + int64(t.Nanosecond())/1000, nil + case "HOUR_SECOND": + h, m, s := t.Clock() + return int64(h)*10000 + int64(m)*100 + int64(s), nil + case "HOUR_MINUTE": + h, m, _ := t.Clock() + return int64(h)*100 + int64(m), nil + case "DAY_MICROSECOND": + h, m, s := t.Clock() + d := t.Day() + return int64(d*1000000+h*10000+m*100+s)*1000000 + int64(t.Nanosecond())/1000, nil + case "DAY_SECOND": + h, m, s := t.Clock() + d := t.Day() + return int64(d)*1000000 + int64(h)*10000 + int64(m)*100 + int64(s), nil + case "DAY_MINUTE": + h, m, _ := t.Clock() + d := t.Day() + return int64(d)*10000 + int64(h)*100 + int64(m), nil + case "DAY_HOUR": + h, _, _ := t.Clock() + d := t.Day() + return int64(d)*100 + int64(h), nil + case "YEAR_MONTH": + y, m, _ := t.Date() + return int64(y)*100 + int64(m), nil + default: + return 0, errors.Errorf("invalid unit %s", unit) + } +} + +func extractSingleTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { + iv, err := strconv.ParseInt(format, 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + v := time.Duration(iv) + switch strings.ToUpper(unit) { + case "MICROSECOND": + return 0, 0, 0, v * time.Microsecond, nil + case "SECOND": + return 0, 0, 0, v * time.Second, nil + case "MINUTE": + return 0, 0, 0, v * time.Minute, nil + case "HOUR": + return 0, 0, 0, v * time.Hour, nil + case "DAY": + return 0, 0, iv, 0, nil + case "WEEK": + return 0, 0, 7 * iv, 0, nil + case "MONTH": + return 0, iv, 0, 0, nil + case "QUARTER": + return 0, 3 * iv, 0, 0, nil + case "YEAR": + return iv, 0, 0, 0, nil + } + + return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) +} + +// Format is `SS.FFFFFF`. +func extractSecondMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ".") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + microseconds, err := strconv.ParseInt(alignFrac(fields[1], MaxFsp), 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(seconds)*time.Second + time.Duration(microseconds)*time.Microsecond, nil +} + +// Format is `MM:SS.FFFFFF`. +func extractMinuteMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractSecondMicrosecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Trace(err) + } + + return 0, 0, 0, time.Duration(minutes)*time.Minute + value, nil +} + +// Format is `MM:SS`. +func extractMinuteSecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil +} + +// Format is `HH:MM:SS.FFFFFF`. +func extractHourMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 3 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractSecondMicrosecond(fields[2]) + if err != nil { + return 0, 0, 0, 0, errors.Trace(err) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + value, nil +} + +// Format is `HH:MM:SS`. +func extractHourSecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 3 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + seconds, err := strconv.ParseInt(fields[2], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute + time.Duration(seconds)*time.Second, nil +} + +// Format is `HH:MM`. +func extractHourMinute(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, ":") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + minutes, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, 0, time.Duration(hours)*time.Hour + time.Duration(minutes)*time.Minute, nil +} + +// Format is `DD HH:MM:SS.FFFFFF`. +func extractDayMicrosecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourMicrosecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH:MM:SS`. +func extractDaySecond(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourSecond(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH:MM`. +func extractDayMinute(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + _, _, _, value, err := extractHourMinute(fields[1]) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, value, nil +} + +// Format is `DD HH`. +func extractDayHour(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, " ") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + days, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + hours, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return 0, 0, days, time.Duration(hours) * time.Hour, nil +} + +// Format is `YYYY-MM`. +func extractYearMonth(format string) (int64, int64, int64, time.Duration, error) { + fields := strings.Split(format, "-") + if len(fields) != 2 { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + years, err := strconv.ParseInt(fields[0], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + months, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, 0, 0, 0, errors.Errorf("invalid time format - %s", format) + } + + return years, months, 0, 0, nil +} + +// ExtractTimeValue extracts time value from time unit and format. +func ExtractTimeValue(unit string, format string) (int64, int64, int64, time.Duration, error) { + switch strings.ToUpper(unit) { + case "MICROSECOND", "SECOND", "MINUTE", "HOUR", "DAY", "WEEK", "MONTH", "QUARTER", "YEAR": + return extractSingleTimeValue(unit, format) + case "SECOND_MICROSECOND": + return extractSecondMicrosecond(format) + case "MINUTE_MICROSECOND": + return extractMinuteMicrosecond(format) + case "MINUTE_SECOND": + return extractMinuteSecond(format) + case "HOUR_MICROSECOND": + return extractHourMicrosecond(format) + case "HOUR_SECOND": + return extractHourSecond(format) + case "HOUR_MINUTE": + return extractHourMinute(format) + case "DAY_MICROSECOND": + return extractDayMicrosecond(format) + case "DAY_SECOND": + return extractDaySecond(format) + case "DAY_MINUTE": + return extractDayMinute(format) + case "DAY_HOUR": + return extractDayHour(format) + case "YEAR_MONTH": + return extractYearMonth(format) + default: + return 0, 0, 0, 0, errors.Errorf("invalid singel timeunit - %s", unit) + } +} diff --git a/mysqldef/time_test.go b/mysql/time_test.go similarity index 99% rename from mysqldef/time_test.go rename to mysql/time_test.go index f2db328247..f58851ace2 100644 --- a/mysqldef/time_test.go +++ b/mysql/time_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import ( "testing" diff --git a/mysqldef/type.go b/mysql/type.go similarity index 99% rename from mysqldef/type.go rename to mysql/type.go index c9230d62b7..36f3441a94 100644 --- a/mysqldef/type.go +++ b/mysql/type.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // MySQL type informations. const ( diff --git a/mysqldef/type_test.go b/mysql/type_test.go similarity index 98% rename from mysqldef/type_test.go rename to mysql/type_test.go index 47175d7d03..53d6fe646b 100644 --- a/mysqldef/type_test.go +++ b/mysql/type_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import . "github.com/pingcap/check" diff --git a/mysqldef/util.go b/mysql/util.go similarity index 98% rename from mysqldef/util.go rename to mysql/util.go index 11f08c03a3..7a5cf72e1f 100644 --- a/mysqldef/util.go +++ b/mysql/util.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql // GetDefaultFieldLength is used for Interger Types, Flen is the display length. // Call this when no Flen assigned in ddl. diff --git a/mysqldef/util_test.go b/mysql/util_test.go similarity index 98% rename from mysqldef/util_test.go rename to mysql/util_test.go index e2e98a9f9d..acda0a42b6 100644 --- a/mysqldef/util_test.go +++ b/mysql/util_test.go @@ -11,7 +11,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package mysqldef +package mysql import "testing" diff --git a/parser/coldef/col_def.go b/parser/coldef/col_def.go index 8010c08dec..d11609c373 100644 --- a/parser/coldef/col_def.go +++ b/parser/coldef/col_def.go @@ -21,7 +21,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" diff --git a/parser/coldef/opt.go b/parser/coldef/opt.go index 8287e674ed..9a31cb0add 100644 --- a/parser/coldef/opt.go +++ b/parser/coldef/opt.go @@ -18,7 +18,7 @@ import ( "strings" "github.com/pingcap/tidb/expression" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) // FloatOpt is used for parsing floating-point type option from SQL. diff --git a/parser/parser.y b/parser/parser.y index 8f8e4d4007..ef5ed414e8 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -29,7 +29,7 @@ import ( "fmt" "strings" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/ddl" @@ -113,6 +113,7 @@ import ( currentUser "CURRENT_USER" database "DATABASE" databases "DATABASES" + dateAdd "DATE_ADD" day "DAY" dayofmonth "DAYOFMONTH" dayofweek "DAYOFWEEK" @@ -164,6 +165,7 @@ import ( index "INDEX" inner "INNER" insert "INSERT" + interval "INTERVAL" into "INTO" is "IS" join "JOIN" @@ -707,7 +709,6 @@ ColumnPosition: } } - AlterSpecificationList: AlterSpecification { @@ -1702,7 +1703,7 @@ UnReservedKeyword: | "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" + "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" | "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" @@ -2239,6 +2240,14 @@ FunctionCallNonKeyword: return 1 } } +| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &expression.DateAdd{ + Unit: $7.(string), + Date: $3.(expression.Expression), + Interval: $6.(expression.Expression), + } + } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { $$ = &expression.Extract{ diff --git a/parser/parser_test.go b/parser/parser_test.go index 350b5274ef..65b890047c 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -403,6 +403,28 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true}, {`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true}, {`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true}, + + // For date_add + {`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, } s.RunTest(c, table) } diff --git a/parser/scanner.l b/parser/scanner.l index 66f31eff87..8ce33d9011 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -27,7 +27,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/util/stringutil" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) type lexer struct { @@ -284,6 +284,7 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} database {d}{a}{t}{a}{b}{a}{s}{e} databases {d}{a}{t}{a}{b}{a}{s}{e}{s} +date_add {d}{a}{t}{e}_{a}{d}{d} day {d}{a}{y} dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} @@ -331,6 +332,7 @@ in {i}{n} index {i}{n}{d}{e}{x} inner {i}{n}{n}{e}{r} insert {i}{n}{s}{e}{r}{t} +interval {i}{n}{t}{e}{r}{v}{a}{l} into {i}{n}{t}{o} is {i}{s} join {j}{o}{i}{n} @@ -630,6 +632,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {database} lval.item = string(l.val) return database {databases} return databases +{date_add} lval.item = string(l.val) + return dateAdd {day} lval.item = string(l.val) return day {dayofweek} lval.item = string(l.val) @@ -711,6 +715,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {index} return index {inner} return inner {insert} return insert +{interval} return interval {into} return into {in} return in {is} return is diff --git a/plan/plans/fields.go b/plan/plans/fields.go index c6198ab81e..6624f8c6a1 100644 --- a/plan/plans/fields.go +++ b/plan/plans/fields.go @@ -64,10 +64,9 @@ func (r *SelectFieldsDefaultPlan) Next(ctx context.Context) (row *plan.Row, err return nil, errors.Trace(err) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.DefaultFieldFlag) - if err0 == nil { - return v, nil + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return srcRow.FromData[index], nil } return getIdentValueFromOuterQuery(ctx, name) diff --git a/plan/plans/final_test.go b/plan/plans/final_test.go index 8ba45d898d..d9400ca03d 100644 --- a/plan/plans/final_test.go +++ b/plan/plans/final_test.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/util/charset" diff --git a/plan/plans/from.go b/plan/plans/from.go index a1f844fd86..a0c1450404 100644 --- a/plan/plans/from.go +++ b/plan/plans/from.go @@ -68,7 +68,7 @@ func (r *TableNilPlan) Next(ctx context.Context) (row *plan.Row, err error) { if err != nil { return nil, errors.Trace(err) } - r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil) + r.iter, err = txn.Seek([]byte(r.T.FirstKey())) if err != nil { return nil, errors.Trace(err) } @@ -274,7 +274,7 @@ func (r *TableDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error) if err != nil { return nil, errors.Trace(err) } - r.iter, err = txn.Seek([]byte(r.T.FirstKey()), nil) + r.iter, err = txn.Seek([]byte(r.T.FirstKey())) if err != nil { return nil, errors.Trace(err) } diff --git a/plan/plans/from_test.go b/plan/plans/from_test.go index 1c31d6cb04..9249ef8dee 100644 --- a/plan/plans/from_test.go +++ b/plan/plans/from_test.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/groupby.go b/plan/plans/groupby.go index ce81c002e6..14b6c422ed 100644 --- a/plan/plans/groupby.go +++ b/plan/plans/groupby.go @@ -89,15 +89,11 @@ type groupRow struct { func (r *GroupByDefaultPlan) evalGroupKey(ctx context.Context, k []interface{}, outRow []interface{}, in []interface{}) error { // group by items can not contain aggregate field, so we can eval them safely. m := map[interface{}]interface{}{} - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } - - v, err = r.getFieldValueByName(name, outRow) - if err == nil { - return v, nil + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return outRow[index], nil } // try to find in outer query @@ -132,10 +128,11 @@ func (r *GroupByDefaultPlan) getFieldValueByName(name string, out []interface{}) func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interface{}, m map[interface{}]interface{}, in []interface{}) error { - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return out[index], nil } // try to find in outer query @@ -157,42 +154,20 @@ func (r *GroupByDefaultPlan) evalNoneAggFields(ctx context.Context, out []interf func (r *GroupByDefaultPlan) evalAggFields(ctx context.Context, out []interface{}, m map[interface{}]interface{}, in []interface{}) error { - // Eval aggregate field results in ctx - for i := range r.AggFields { - if i < r.HiddenFieldOffset { - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } - // try to find in outer query - return getIdentValueFromOuterQuery(ctx, name) - } - } else { - // having may contain aggregate function and we will add it to hidden field, - // and this field can retrieve the data in select list, e.g. - // select c1 as a from t having count(a) = 1 - // because all the select list data is generated before, so we can get it - // when handling hidden field. - m[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.Src.GetFields(), in, field.DefaultFieldFlag) - if err == nil { - return v, nil - } - - // if we can not find in table, we will try to find in un-hidden select list - // only hidden field can use this - v, err = r.getFieldValueByName(name, out) - if err == nil { - return v, nil - } - - // try to find in outer query - return getIdentValueFromOuterQuery(ctx, name) - } + m[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return in[index], nil + } else if scope == expression.IdentReferSelectList { + return out[index], nil } + // try to find in outer query + return getIdentValueFromOuterQuery(ctx, name) + } + + // Eval aggregate field results in ctx + for i := range r.AggFields { // we must evaluate aggregate function only, e.g, select col1 + count(*) in (count(*)), // we cannot evaluate it directly here, because col1 + count(*) returns nil before AggDone phase, // so we don't evaluate count(*) in In expression, and will get an invalid data in AggDone phase for it. diff --git a/plan/plans/groupby_test.go b/plan/plans/groupby_test.go index 9fde4aea41..9c3caab894 100644 --- a/plan/plans/groupby_test.go +++ b/plan/plans/groupby_test.go @@ -42,12 +42,16 @@ func (t *testGroupBySuite) TestGroupBy(c *C) { Fields: []*field.Field{ { Expr: &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, }, { Expr: &expression.Ident{ - CIStr: model.NewCIStr("name"), + CIStr: model.NewCIStr("name"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 1, }, }, { @@ -55,7 +59,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) { F: "sum", Args: []expression.Expression{ &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, }, }, @@ -69,7 +75,9 @@ func (t *testGroupBySuite) TestGroupBy(c *C) { Src: tblPlan, By: []expression.Expression{ &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, }, } diff --git a/plan/plans/having.go b/plan/plans/having.go index 231f059eaa..62c64d3647 100644 --- a/plan/plans/having.go +++ b/plan/plans/having.go @@ -62,10 +62,11 @@ func (r *HavingPlan) Next(ctx context.Context) (row *plan.Row, err error) { if srcRow == nil || err != nil { return nil, errors.Trace(err) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err0 := GetIdentValue(name, r.Src.GetFields(), srcRow.Data, field.CheckFieldFlag) - if err0 == nil { - return v, nil + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return srcRow.FromData[index], nil + } else if scope == expression.IdentReferSelectList { + return srcRow.Data[index], nil } // try to find in outer query diff --git a/plan/plans/having_test.go b/plan/plans/having_test.go index a3777fe9f2..21a4fc19cf 100644 --- a/plan/plans/having_test.go +++ b/plan/plans/having_test.go @@ -38,7 +38,9 @@ func (t *testHavingPlan) TestHaving(c *C) { Expr: &expression.BinaryOperation{ Op: opcode.GE, L: &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferSelectList, + ReferIndex: 0, }, R: &expression.Value{ Val: 20, diff --git a/plan/plans/index_test.go b/plan/plans/index_test.go index 0cc49fbeac..1a713e0e27 100644 --- a/plan/plans/index_test.go +++ b/plan/plans/index_test.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/info.go b/plan/plans/info.go index 9b222c7a6a..6c45fb002c 100644 --- a/plan/plans/info.go +++ b/plan/plans/info.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/charset" diff --git a/plan/plans/join.go b/plan/plans/join.go index cd04caa6cd..bcea22bad0 100644 --- a/plan/plans/join.go +++ b/plan/plans/join.go @@ -155,6 +155,7 @@ func (r *JoinPlan) Next(ctx context.Context) (row *plan.Row, err error) { } func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) { + visitor := newJoinIdentVisitor(0) for { if r.cursor < len(r.matchedRows) { row = r.matchedRows[r.cursor] @@ -167,7 +168,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) return nil, errors.Trace(err) } tempExpr := r.On.Clone() - visitor := NewIdentEvalVisitor(r.Left.GetFields(), leftRow.Data) + visitor.row = leftRow.Data _, err = tempExpr.Accept(visitor) if err != nil { return nil, errors.Trace(err) @@ -189,6 +190,7 @@ func (r *JoinPlan) nextLeftJoin(ctx context.Context) (row *plan.Row, err error) } func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) { + visitor := newJoinIdentVisitor(len(r.Left.GetFields())) for { if r.cursor < len(r.matchedRows) { row = r.matchedRows[r.cursor] @@ -202,7 +204,7 @@ func (r *JoinPlan) nextRightJoin(ctx context.Context) (row *plan.Row, err error) } tempExpr := r.On.Clone() - visitor := NewIdentEvalVisitor(r.Right.GetFields(), rightRow.Data) + visitor.row = rightRow.Data _, err = tempExpr.Accept(visitor) if err != nil { return nil, errors.Trace(err) @@ -251,8 +253,8 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl } else { joined = append(append(joined, row.Data...), cmpRow.Data...) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - return GetIdentValue(name, r.Fields, joined, field.DefaultFieldFlag) + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + return joined[index], nil } var b bool b, err = expression.EvalBoolExpr(ctx, r.On, r.evalArgs) @@ -280,6 +282,7 @@ func (r *JoinPlan) findMatchedRows(ctx context.Context, row *plan.Row, p plan.Pl } func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) { + visitor := newJoinIdentVisitor(0) for { if r.curRow == nil { r.curRow, err = r.Left.Next(ctx) @@ -292,7 +295,7 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) if r.On != nil { tempExpr := r.On.Clone() - visitor := NewIdentEvalVisitor(r.Left.GetFields(), r.curRow.Data) + visitor.row = r.curRow.Data _, err = tempExpr.Accept(visitor) if err != nil { return nil, errors.Trace(err) @@ -323,8 +326,8 @@ func (r *JoinPlan) nextCrossJoin(ctx context.Context) (row *plan.Row, err error) joinedRow = append(append(joinedRow, r.curRow.Data...), rightRow.Data...) if r.On != nil { - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - return GetIdentValue(name, r.Fields, joinedRow, field.DefaultFieldFlag) + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + return joinedRow[index], nil } b, err := expression.EvalBoolExpr(ctx, r.On, r.evalArgs) @@ -356,32 +359,34 @@ func (r *JoinPlan) Close() error { return r.Left.Close() } -// IdentEvalVisitor converts Ident expression to value expression. -type IdentEvalVisitor struct { +// joinIdentVisitor converts Ident expression to value expression in ON expression. +type joinIdentVisitor struct { expression.BaseVisitor - fields []*field.ResultField row []interface{} + offset int } -// NewIdentEvalVisitor creates a new IdentEvalVisitor. -func NewIdentEvalVisitor(fields []*field.ResultField, row []interface{}) *IdentEvalVisitor { - iev := &IdentEvalVisitor{fields: fields, row: row} +// newJoinIdentVisitor creates a new joinIdentVisitor. +func newJoinIdentVisitor(offset int) *joinIdentVisitor { + iev := &joinIdentVisitor{offset: offset} iev.BaseVisitor.V = iev return iev } // VisitIdent implements Visitor interface. -func (iev *IdentEvalVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { - v, err := GetIdentValue(i.L, iev.fields, iev.row, field.CheckFieldFlag) - if err != nil { +func (iev *joinIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // the row here may be just left part or right part, but identifier may reference another part, + // so here we must check its reference index validation. + if i.ReferIndex < iev.offset || i.ReferIndex >= iev.offset+len(iev.row) { return i, nil } + v := iev.row[i.ReferIndex-iev.offset] return expression.Value{Val: v}, nil } // VisitBinaryOperation swaps the right side identifier to left side if left side expression is static. // So it can be used in index plan. -func (iev *IdentEvalVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) { +func (iev *joinIdentVisitor) VisitBinaryOperation(binop *expression.BinaryOperation) (expression.Expression, error) { var err error binop.L, err = binop.L.Accept(iev) if err != nil { diff --git a/plan/plans/join_test.go b/plan/plans/join_test.go index a14fa38525..c38824a22c 100644 --- a/plan/plans/join_test.go +++ b/plan/plans/join_test.go @@ -20,7 +20,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/orderby.go b/plan/plans/orderby.go index 347f536e48..4708339203 100644 --- a/plan/plans/orderby.go +++ b/plan/plans/orderby.go @@ -158,10 +158,11 @@ func (r *OrderByDefaultPlan) fetchAll(ctx context.Context) error { if row == nil { break } - evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err := GetIdentValue(name, r.ResultFields, row.Data, field.CheckFieldFlag) - if err == nil { - return v, nil + evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return row.FromData[index], nil + } else if scope == expression.IdentReferSelectList { + return row.Data[index], nil } // try to find in outer query diff --git a/plan/plans/orderby_test.go b/plan/plans/orderby_test.go index 2b434300ad..828081016d 100644 --- a/plan/plans/orderby_test.go +++ b/plan/plans/orderby_test.go @@ -48,7 +48,9 @@ func (t *testOrderBySuit) TestOrderBy(c *C) { Src: tblPlan, By: []expression.Expression{ &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferSelectList, + ReferIndex: 0, }, }, Ascs: []bool{false}, diff --git a/plan/plans/plans.go b/plan/plans/plans.go index 0b3c689ad7..393aa98f28 100644 --- a/plan/plans/plans.go +++ b/plan/plans/plans.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/format" diff --git a/plan/plans/row_stack.go b/plan/plans/row_stack.go index f4482bcfca..7b10919d29 100644 --- a/plan/plans/row_stack.go +++ b/plan/plans/row_stack.go @@ -199,5 +199,5 @@ func getIdentValueFromOuterQuery(ctx context.Context, name string) (interface{}, } } - return nil, errors.Trace(err) + return nil, errors.Errorf("unknown field %s", name) } diff --git a/plan/plans/select_list.go b/plan/plans/select_list.go index 7d918fceed..52b76fa26a 100644 --- a/plan/plans/select_list.go +++ b/plan/plans/select_list.go @@ -95,63 +95,43 @@ func (s *SelectList) GetFields() []*field.ResultField { } // UpdateAggFields adds aggregate function resultfield to select result field list. -func (s *SelectList) UpdateAggFields(expr expression.Expression, tableFields []*field.ResultField) (expression.Expression, error) { - // For aggregate function, the name can be in table or select list. - names := expression.MentionedColumns(expr) - - for _, name := range names { - if field.ContainFieldName(name, tableFields, field.DefaultFieldFlag) { - continue - } - - if field.ContainFieldName(name, s.ResultFields, field.DefaultFieldFlag) { - continue - } - - return nil, errors.Errorf("Unknown column '%s'", name) - } - +func (s *SelectList) UpdateAggFields(expr expression.Expression) (expression.Expression, error) { // We must add aggregate function to hidden select list // and use a position expression to fetch its value later. - exprName := expr.String() - idx := field.GetResultFieldIndex(exprName, s.ResultFields, field.CheckFieldFlag) - if len(idx) == 0 { - f := &field.Field{Expr: expr} - resultField := &field.ResultField{Name: exprName} - s.AddField(f, resultField) + name := strings.ToLower(expr.String()) + index := -1 + for i := 0; i < s.HiddenFieldOffset; i++ { + // only check origin name, e,g. "select sum(c1) as a from t order by sum(c1)" + // or "select sum(c1) from t order by sum(c1)" + if s.ResultFields[i].ColumnInfo.Name.L == name { + index = i + break + } + } - return &expression.Position{N: len(s.Fields), Name: exprName}, nil + if index == -1 { + f := &field.Field{Expr: expr} + s.AddField(f, nil) + + pos := len(s.Fields) + + s.AggFields[pos-1] = struct{}{} + + return &expression.Position{N: pos, Name: name}, nil } // select list has this field, use it directly. - return &expression.Position{N: idx[0] + 1, Name: exprName}, nil + return &expression.Position{N: index + 1, Name: name}, nil } -// CloneHiddenField checks and clones field and result field from table fields, -// and adds them to hidden field of select list. -func (s *SelectList) CloneHiddenField(name string, tableFields []*field.ResultField) bool { - // Check and add hidden field. - if field.ContainFieldName(name, tableFields, field.CheckFieldFlag) { - resultField, _ := field.CloneFieldByName(name, tableFields, field.CheckFieldFlag) - f := &field.Field{ - Expr: &expression.Ident{ - CIStr: resultField.ColumnInfo.Name, - }, - } - s.AddField(f, resultField) - return true - } - - return false -} - -// CheckReferAmbiguous checks whether an identifier reference is ambiguous or not in select list. +// CheckAmbiguous checks whether an identifier reference is ambiguous or not in select list. // e,g, "select c1 as a, c2 as a from t group by a" is ambiguous, // but "select c1 as a, c1 as a from t group by a" is not. -// For MySQL "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too, -// so we will only check identifier too. -// If no ambiguous, -1 means expr refers none in select list, else an index in select list returns. -func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error) { +// "select c1 as a, c2 + 1 as a from t group by a" is not ambiguous too, +// If no ambiguous, -1 means expr refers none in select list, else an index for first match. +// CheckAmbiguous will break the check when finding first matching which is not an indentifier, +// or an index for an identifier field in the end, -1 means none found. +func (s *SelectList) CheckAmbiguous(expr expression.Expression) (int, error) { if _, ok := expr.(*expression.Ident); !ok { return -1, nil } @@ -162,14 +142,22 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error return -1, nil } + // select c1 as a, 1 as a, c2 as a from t order by a is not ambiguous. + // select c1 as a, c2 as a from t order by a is ambiguous. + // select 1 as a, c1 as a from t order by a is not ambiguous. + // select c1 as a, sum(c1) as a from t group by a is error. + // select c1 as a, 1 as a, sum(c1) as a from t group by a is not error. + // so we will break the check if matching a none identifier field. lastIndex := -1 // only check origin select list, no hidden field. for i := 0; i < s.HiddenFieldOffset; i++ { if !strings.EqualFold(s.ResultFields[i].Name, name) { continue - } else if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok { - // not identfier, no check - continue + } + + if _, ok := s.Fields[i].Expr.(*expression.Ident); !ok { + // not identfier, return directly. + return i, nil } if lastIndex == -1 { @@ -179,6 +167,7 @@ func (s *SelectList) CheckReferAmbiguous(expr expression.Expression) (int, error } // check origin name, e,g. "select c1 as c2, c2 from t group by c2" is ambiguous. + if s.ResultFields[i].ColumnInfo.Name.L != s.ResultFields[lastIndex].ColumnInfo.Name.L { return -1, errors.Errorf("refer %s is ambiguous", expr) } @@ -233,6 +222,7 @@ func ResolveSelectList(selectFields []*field.Field, srcFields []*field.ResultFie continue } + // TODO: use fromIdentVisitor to cleanup. var result *field.ResultField for _, name := range names { idx := field.GetResultFieldIndex(name, srcFields, field.DefaultFieldFlag) diff --git a/plan/plans/select_list_test.go b/plan/plans/select_list_test.go index 383a14e69b..7db298422d 100644 --- a/plan/plans/select_list_test.go +++ b/plan/plans/select_list_test.go @@ -96,7 +96,7 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { sl, err := plans.ResolveSelectList(fs, rs) c.Assert(err, IsNil) - index, err := sl.CheckReferAmbiguous(&expression.Ident{ + idx, err := sl.CheckAmbiguous(&expression.Ident{ CIStr: model.NewCIStr(t.Name), }) @@ -106,6 +106,6 @@ func (s *testSelectListSuite) TestAmbiguous(c *C) { } c.Assert(err, IsNil) - c.Assert(t.Index, Equals, index) + c.Assert(t.Index, DeepEquals, idx) } } diff --git a/plan/plans/show.go b/plan/plans/show.go index 87eff2e91a..62a045412d 100644 --- a/plan/plans/show.go +++ b/plan/plans/show.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" diff --git a/plan/plans/show_test.go b/plan/plans/show_test.go index 7b36ebf810..acde709fe6 100644 --- a/plan/plans/show_test.go +++ b/plan/plans/show_test.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/union_test.go b/plan/plans/union_test.go index 478f7b8f22..70653ae024 100644 --- a/plan/plans/union_test.go +++ b/plan/plans/union_test.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" diff --git a/plan/plans/where.go b/plan/plans/where.go index 5f9bf59eeb..2a1a866ff1 100644 --- a/plan/plans/where.go +++ b/plan/plans/where.go @@ -56,15 +56,15 @@ func (r *FilterDefaultPlan) Next(ctx context.Context) (row *plan.Row, err error) if row == nil || err != nil { return nil, errors.Trace(err) } - r.evalArgs[expression.ExprEvalIdentFunc] = func(name string) (interface{}, error) { - v, err0 := GetIdentValue(name, r.GetFields(), row.Data, field.DefaultFieldFlag) - if err0 == nil { - return v, nil + r.evalArgs[expression.ExprEvalIdentReferFunc] = func(name string, scope int, index int) (interface{}, error) { + if scope == expression.IdentReferFromTable { + return row.Data[index], nil } // try to find in outer query return getIdentValueFromOuterQuery(ctx, name) } + var meet bool meet, err = r.meetCondition(ctx) if err != nil { diff --git a/plan/plans/where_test.go b/plan/plans/where_test.go index 6c890608aa..f1effcfbc6 100644 --- a/plan/plans/where_test.go +++ b/plan/plans/where_test.go @@ -44,7 +44,9 @@ func (t *testWhereSuit) TestWhere(c *C) { Expr: &expression.BinaryOperation{ Op: opcode.GE, L: &expression.Ident{ - CIStr: model.NewCIStr("id"), + CIStr: model.NewCIStr("id"), + ReferScope: expression.IdentReferFromTable, + ReferIndex: 0, }, R: expression.Value{ Val: 30, diff --git a/rset/rsets/fields.go b/rset/rsets/fields.go index 091519fa59..ffe6176409 100644 --- a/rset/rsets/fields.go +++ b/rset/rsets/fields.go @@ -39,10 +39,29 @@ type SelectFieldsRset struct { SelectList *plans.SelectList } +func updateSelectFieldsRefer(selectList *plans.SelectList) error { + visitor := newFromIdentVisitor(selectList.FromFields, fieldListClause) + + // we only fix un-hidden fields here, for hidden fields, it should be + // handled in their own clause, in order by or having. + for i, f := range selectList.Fields[0:selectList.HiddenFieldOffset] { + e, err := f.Expr.Accept(visitor) + if err != nil { + return errors.Trace(err) + } + selectList.Fields[i].Expr = e + } + return nil +} + // Plan gets SrcPlan/SelectFieldsDefaultPlan. // If all fields are equal to src plan fields, then gets SrcPlan. // Default gets SelectFieldsDefaultPlan. func (r *SelectFieldsRset) Plan(ctx context.Context) (plan.Plan, error) { + if err := updateSelectFieldsRefer(r.SelectList); err != nil { + return nil, errors.Trace(err) + } + fields := r.SelectList.Fields srcFields := r.Src.GetFields() if len(fields) == len(srcFields) { @@ -97,6 +116,7 @@ func (r *SelectFromDualRset) Plan(ctx context.Context) (plan.Plan, error) { // field cannot contain identifier for _, f := range r.Fields { if cs := expression.MentionedColumns(f.Expr); len(cs) > 0 { + // TODO: check in outer query, like select * from t where t.c = (select c limit 1); return nil, errors.Errorf("Unknown column '%s' in 'field list'", cs[0]) } } diff --git a/rset/rsets/fields_test.go b/rset/rsets/fields_test.go index 36ad26f297..aa63daa5b4 100644 --- a/rset/rsets/fields_test.go +++ b/rset/rsets/fields_test.go @@ -69,6 +69,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) { oldFld := s.sr.SelectList.Fields[1] s.sr.SelectList.Fields = s.sr.SelectList.Fields[:1] + s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields) p, err = s.sr.Plan(nil) c.Assert(err, IsNil) @@ -81,6 +82,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) { s.sr.Src = tdp s.sr.SelectList.Fields = []*field.Field{fld} + s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields) p, err = s.sr.Plan(nil) c.Assert(err, IsNil) @@ -94,6 +96,7 @@ func (s *testSelectFieldsPlannerSuite) TestDistinctPlanner(c *C) { // cover isConst check, like `select 1, c1 from t` s.sr.SelectList.Fields = []*field.Field{fld, oldFld} + s.sr.SelectList.HiddenFieldOffset = len(s.sr.SelectList.Fields) p, err = s.sr.Plan(nil) c.Assert(err, IsNil) diff --git a/rset/rsets/groupby.go b/rset/rsets/groupby.go index aa3560bf97..d95ce8f073 100644 --- a/rset/rsets/groupby.go +++ b/rset/rsets/groupby.go @@ -39,70 +39,110 @@ type GroupByRset struct { SelectList *plans.SelectList } +type groupByVisitor struct { + expression.BaseVisitor + selectList *plans.SelectList + rootIdent *expression.Ident +} + +func (v *groupByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // Group by ambiguous rule: + // select c1 as a, c2 as a from t group by a is ambiguous + // select c1 as a, c2 as a from t group by a + 1 is ambiguous + // select c1 as c2, c2 from t group by c2 is ambiguous + // select c1 as c2, c2 from t group by c2 + 1 is not ambiguous + + var ( + index int + err error + ) + + if v.rootIdent == i { + // The group by is an identifier, we must check it first. + index, err = checkIdentAmbiguous(i, v.selectList, groupByClause) + if err != nil { + return nil, errors.Trace(err) + } + } + + // first find this identifier in FROM. + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + if v.rootIdent != i { + // This identifier is the part of the group by, check ambiguous here. + index, err = checkIdentAmbiguous(i, v.selectList, groupByClause) + if err != nil { + return nil, errors.Trace(err) + } + } + + // try to find in select list, we have got index using checkIdent before. + if index >= 0 { + // group by can not reference aggregate fields + if _, ok := v.selectList.AggFields[index]; ok { + return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", i) + } + + // find in select list + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil + } + + // TODO: check in out query + return i, errors.Errorf("Unknown column '%s' in 'group statement'", i) +} + +func (v *groupByVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { + ok := expression.IsAggregateFunc(c.F) + if ok { + return nil, errors.Errorf("group by cannot contain aggregate function %s", c) + } + + var err error + for i, e := range c.Args { + c.Args[i], err = e.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + } + + return c, nil +} + // Plan gets GroupByDefaultPlan. func (r *GroupByRset) Plan(ctx context.Context) (plan.Plan, error) { - fields := r.SelectList.Fields + if err := updateSelectFieldsRefer(r.SelectList); err != nil { + return nil, errors.Trace(err) + } - r.SelectList.AggFields = GetAggFields(fields) - aggFields := r.SelectList.AggFields + visitor := &groupByVisitor{} + visitor.BaseVisitor.V = visitor + visitor.selectList = r.SelectList for i, e := range r.By { - if v, ok := e.(expression.Value); ok { - var position int - switch u := v.Val.(type) { - case int64: - position = int(u) - case uint64: - position = int(u) - default: - continue - } - - if position < 1 || position > len(fields) { - return nil, errors.Errorf("Unknown column '%d' in 'group statement'", position) - } - - index := position - 1 - if _, ok := aggFields[index]; ok { - return nil, errors.Errorf("Can't group on '%s'", fields[index]) - } - - // use Position expression for the associated field. - r.By[i] = &expression.Position{N: position} - } else { - index, err := r.SelectList.CheckReferAmbiguous(e) - if err != nil { - return nil, errors.Errorf("Column '%s' in group statement is ambiguous", e) - } else if _, ok := aggFields[index]; ok { - return nil, errors.Errorf("Can't group on '%s'", e) - } - - // TODO: check more ambiguous case - // Group by ambiguous rule: - // select c1 as a, c2 as a from t group by a is ambiguous - // select c1 as a, c2 as a from t group by a + 1 is ambiguous - // select c1 as c2, c2 from t group by c2 is ambiguous - // select c1 as c2, c2 from t group by c2 + 1 is ambiguous - - // TODO: use visitor to check aggregate function - names := expression.MentionedColumns(e) - for _, name := range names { - indices := field.GetFieldIndex(name, fields[0:r.SelectList.HiddenFieldOffset], field.DefaultFieldFlag) - if len(indices) == 1 { - // check reference to aggregate function, like `select c1, count(c1) as b from t group by b + 1`. - index := indices[0] - if _, ok := aggFields[index]; ok { - return nil, errors.Errorf("Reference '%s' not supported (reference to group function)", name) - } - } - } - - // group by should be an expression, a qualified field name or a select field position, - // but can not contain any aggregate function. - if e := r.By[i]; expression.ContainAggregateFunc(e) { - return nil, errors.Errorf("group by cannot contain aggregate function %s", e.String()) - } + pos, err := castPosition(e, r.SelectList, groupByClause) + if err != nil { + return nil, errors.Trace(err) } + + if pos != nil { + // use Position expression for the associated field. + r.By[i] = pos + continue + } + + visitor.rootIdent = castIdent(e) + by, err := e.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) + } + r.By[i] = by } return &plans.GroupByDefaultPlan{By: r.By, Src: r.Src, diff --git a/rset/rsets/groupby_test.go b/rset/rsets/groupby_test.go index 9a87172316..eebe20d9cb 100644 --- a/rset/rsets/groupby_test.go +++ b/rset/rsets/groupby_test.go @@ -50,6 +50,11 @@ func (s *testGroupByRsetSuite) SetUpSuite(c *C) { s.r = &GroupByRset{Src: tblPlan, SelectList: selectList, By: by} } +func resetAggFields(selectList *plans.SelectList) { + fields := selectList.Fields + selectList.AggFields = GetAggFields(fields) +} + func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) { // `select id, name from t group by name` p, err := s.r.Plan(nil) @@ -88,6 +93,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) { fld := &field.Field{Expr: fldExpr, AsName: "a"} s.r.SelectList.Fields[0] = fld + resetAggFields(s.r.SelectList) + s.r.By[0] = expression.Value{Val: int64(1)} _, err = s.r.Plan(nil) @@ -115,6 +122,8 @@ func (s *testGroupByRsetSuite) TestGroupByRsetPlan(c *C) { s.r.SelectList.ResultFields[0].Col.Name = model.NewCIStr("count(id)") s.r.SelectList.ResultFields[0].Name = "a" + resetAggFields(s.r.SelectList) + _, err = s.r.Plan(nil) c.Assert(err, NotNil) diff --git a/rset/rsets/having.go b/rset/rsets/having.go index 39ba3cc561..be43f3cc02 100644 --- a/rset/rsets/having.go +++ b/rset/rsets/having.go @@ -31,66 +31,274 @@ type HavingRset struct { Src plan.Plan Expr expression.Expression SelectList *plans.SelectList + // Group by clause + GroupBy []expression.Expression } -// CheckAndUpdateSelectList checks having fields validity and set hidden fields to selectList. -func (r *HavingRset) CheckAndUpdateSelectList(selectList *plans.SelectList, groupBy []expression.Expression, tableFields []*field.ResultField) error { +// CheckAggregate will check whether order by has aggregate function or not, +// if has, we will add it to select list hidden field. +func (r *HavingRset) CheckAggregate(selectList *plans.SelectList) error { if expression.ContainAggregateFunc(r.Expr) { - expr, err := selectList.UpdateAggFields(r.Expr, tableFields) + expr, err := selectList.UpdateAggFields(r.Expr) if err != nil { return errors.Errorf("%s in 'having clause'", err.Error()) } r.Expr = expr - } else { - // having can only contain group by column and select list, e.g, - // `select c1 from t group by c2 having c3 > 0` is invalid, - // because c3 is not in group by and select list. - names := expression.MentionedColumns(r.Expr) - for _, name := range names { - found := false - - // check name whether in select list. - // notice that `select c1 as c2 from t group by c1, c2, c3 having c2 > c3`, - // will use t.c2 not t.c1 here. - if field.ContainFieldName(name, selectList.ResultFields, field.OrgFieldNameFlag) { - continue - } - if field.ContainFieldName(name, selectList.ResultFields, field.FieldNameFlag) { - if field.ContainFieldName(name, tableFields, field.OrgFieldNameFlag) { - selectList.CloneHiddenField(name, tableFields) - } - continue - } - - // check name whether in group by. - // group by must only have column name, e.g, - // `select c1 from t group by c2 having c2 > 0` is valid, - // but `select c1 from t group by c2 + 1 having c2 > 0` is invalid. - for _, by := range groupBy { - if !field.CheckFieldsEqual(name, by.String()) { - continue - } - - // if name is not in table fields, it will get an unknown field error in GroupByRset, - // so no need to check return value. - selectList.CloneHiddenField(name, tableFields) - - found = true - break - } - - if !found { - return errors.Errorf("Unknown column '%s' in 'having clause'", name) - } - } } return nil } +type havingVisitor struct { + expression.BaseVisitor + selectList *plans.SelectList + + // for group by + groupBy []expression.Expression + + // true means we are visiting aggregate function arguments now. + inAggregate bool +} + +func (v *havingVisitor) visitIdentInAggregate(i *expression.Ident) (expression.Expression, error) { + // if we are visiting aggregate function arguments, the identifier first checks in from table, + // then in select list, and outer query finally. + + // find this identifier in FROM. + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + // check in select list. + index, err := checkIdentAmbiguous(i, v.selectList, havingClause) + if err != nil { + return i, errors.Trace(err) + } + + if index >= 0 { + // find in select list + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil + } + + // TODO: check in out query + // TODO: return unknown field error, but now just return directly. + // Because this may reference outer query. + return i, nil +} + +func (v *havingVisitor) checkIdentInGroupBy(i *expression.Ident) (*expression.Ident, bool, error) { + for _, by := range v.groupBy { + e := castIdent(by) + if e == nil { + // group by must be a identifier too + continue + } + + if !field.CheckFieldsEqual(e.L, i.L) { + // not same, continue + continue + } + + // if group by references select list or having is qualified identifier, + // no other check. + if e.ReferScope == expression.IdentReferSelectList || field.IsQualifiedName(i.L) { + i.ReferScope = e.ReferScope + i.ReferIndex = e.ReferIndex + return i, true, nil + } + + // having is unqualified name, e.g, select * from t1, t2 group by t1.c having c. + // both t1 and t2 have column c, we must check ambiguous here. + + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 1 { + return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i) + } + + i.ReferScope = e.ReferScope + i.ReferIndex = e.ReferIndex + return i, true, nil + } + + return i, false, nil +} + +func (v *havingVisitor) checkIdentInSelectList(i *expression.Ident) (*expression.Ident, bool, error) { + index, err := checkIdentAmbiguous(i, v.selectList, havingClause) + if err != nil { + return i, false, errors.Trace(err) + } + + if index >= 0 { + // identifier references a select field. use it directly. + // e,g. select c1 as c2 from t having c2, here c2 references c1. + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, true, nil + } + + lastIndex := -1 + var lastFieldName string + // we may meet this select c1 as c2 from t having c1, so we must check origin field name too. + for index := 0; index < v.selectList.HiddenFieldOffset; index++ { + e := castIdent(v.selectList.Fields[index].Expr) + if e == nil { + // not identifier + continue + } + + if !field.CheckFieldsEqual(e.L, i.L) { + // not same, continue + continue + } + + if field.IsQualifiedName(i.L) { + // qualified name, no need check + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, true, nil + } + + if lastIndex == -1 { + lastIndex = index + lastFieldName = e.L + continue + } + + // we may meet select t1.c as a, t2.c as b from t1, t2 having c, must check ambiguous here + if !field.CheckFieldsEqual(lastFieldName, e.L) { + return i, false, errors.Errorf("Column '%s' in having clause is ambiguous", i) + } + + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, true, nil + } + + if lastIndex != -1 { + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = lastIndex + return i, true, nil + } + + return i, false, nil +} + +func (v *havingVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // Having has the most complex rule for ambiguous and identifer reference check. + + // Having ambiguous rule: + // select c1 as a, c2 as a from t having a is ambiguous + // select c1 as a, c2 as a from t having a + 1 is ambiguous + // select c1 as c2, c2 from t having c2 is ambiguous + // select c1 as c2, c2 from t having c2 + 1 is ambiguous + + // The identifier in having must exist in group by, select list or outer query, so select c1 from t having c2 is wrong, + // the identifier will first try to find the reference in group by, then in select list and finally in outer query. + + // Having identifier reference check + // select c1 from t group by c2 having c2, c2 in having references group by c2. + // select c1 as c2 from t having c2, c2 in having references c1 in select field. + // select c1 as c2 from t group by c2 having c2, c2 in having references group by c2. + // select c1 as c2 from t having sum(c2), c2 in having sum(c2) references t.c2 in from table. + // select c1 as c2 from t having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, but another c2 references c1 in select list. + // select c1 as c2 from t group by c2 having sum(c2) + c2, c2 in having sum(c2) references t.c2 in from table, another c2 references c2 in group by. + // select c1 as a from t having a, a references c1 in select field. + // select c1 as a from t having sum(a), a in order by sum(a) references c1 in select field. + // select c1 as c2 from t having c1, c1 in having references c1 in select list. + + // TODO: unify VisitIdent for group by and order by visitor, they are little different. + + if v.inAggregate { + return v.visitIdentInAggregate(i) + } + + var ( + err error + ok bool + ) + + // first, check in group by + i, ok, err = v.checkIdentInGroupBy(i) + if err != nil || ok { + return i, errors.Trace(err) + } + + // then in select list + i, ok, err = v.checkIdentInSelectList(i) + if err != nil || ok { + return i, errors.Trace(err) + } + + // TODO: check in out query + + return i, errors.Errorf("Unknown column '%s' in 'having clause'", i) +} + +func (v *havingVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) { + n := p.N + + // we have added order by to hidden field before, now visit it here. + expr := v.selectList.Fields[n-1].Expr + e, err := expr.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + v.selectList.Fields[n-1].Expr = e + + return p, nil +} + +func (v *havingVisitor) VisitCall(c *expression.Call) (expression.Expression, error) { + isAggregate := expression.IsAggregateFunc(c.F) + + if v.inAggregate && isAggregate { + // aggregate function can't contain aggregate function + return nil, errors.Errorf("Invalid use of group function") + } + + if isAggregate { + // set true to let outer know we are in aggregate function. + v.inAggregate = true + } + + var err error + for i, e := range c.Args { + c.Args[i], err = e.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + } + + if isAggregate { + v.inAggregate = false + } + + return c, nil +} + // Plan gets HavingPlan. func (r *HavingRset) Plan(ctx context.Context) (plan.Plan, error) { + visitor := &havingVisitor{ + selectList: r.SelectList, + groupBy: r.GroupBy, + inAggregate: false, + } + visitor.BaseVisitor.V = visitor + + e, err := r.Expr.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) + } + + r.Expr = e + return &plans.HavingPlan{Src: r.Src, Expr: r.Expr, SelectList: r.SelectList}, nil } diff --git a/rset/rsets/having_test.go b/rset/rsets/having_test.go index 91a13e3052..180f57d779 100644 --- a/rset/rsets/having_test.go +++ b/rset/rsets/having_test.go @@ -50,59 +50,6 @@ func (s *testHavingRsetSuite) SetUpSuite(c *C) { s.r = &HavingRset{Src: tblPlan, Expr: expr, SelectList: selectList} } -func (s *testHavingRsetSuite) TestHavingRsetCheckAndUpdateSelectList(c *C) { - resultFields := s.r.Src.GetFields() - - selectList := s.r.SelectList - - groupBy := []expression.Expression{} - - // `select id, name from t having id > 1` - err := s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id having id > 1` - selectList.ResultFields = selectList.ResultFields[1:] - selectList.Fields = selectList.Fields[1:] - - groupBy = []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("id")}} - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id + 1 having id > 1` - expr := expression.NewBinaryOperation(opcode.Plus, &expression.Ident{CIStr: model.NewCIStr("id")}, expression.Value{Val: 1}) - - groupBy = []expression.Expression{expr} - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id + 1 having count(1) > 1` - aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false) - c.Assert(err, IsNil) - - s.r.Expr = aggExpr - - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, IsNil) - - // `select name from t group by id + 1 having count(xxx) > 1` - aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false) - c.Assert(err, IsNil) - - s.r.Expr = aggExpr - - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, NotNil) - - // `select name from t group by id having xxx > 1` - expr = expression.NewBinaryOperation(opcode.GT, &expression.Ident{CIStr: model.NewCIStr("xxx")}, expression.Value{Val: 1}) - - s.r.Expr = expr - - err = s.r.CheckAndUpdateSelectList(selectList, groupBy, resultFields) - c.Assert(err, NotNil) -} - func (s *testHavingRsetSuite) TestHavingRsetPlan(c *C) { p, err := s.r.Plan(nil) c.Assert(err, IsNil) diff --git a/rset/rsets/helper.go b/rset/rsets/helper.go index 53ce63b743..91aff99379 100644 --- a/rset/rsets/helper.go +++ b/rset/rsets/helper.go @@ -14,13 +14,15 @@ package rsets import ( + "github.com/juju/errors" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/plan/plans" ) // GetAggFields gets aggregate fields position map. func GetAggFields(fields []*field.Field) map[int]struct{} { - aggFields := map[int]struct{}{} + aggFields := make(map[int]struct{}, len(fields)) for i, v := range fields { if expression.ContainAggregateFunc(v.Expr) { aggFields[i] = struct{}{} @@ -34,3 +36,122 @@ func HasAggFields(fields []*field.Field) bool { aggFields := GetAggFields(fields) return len(aggFields) > 0 } + +// castIdent returns an Ident expression if e is or nil. +func castIdent(e expression.Expression) *expression.Ident { + i, ok := e.(*expression.Ident) + if !ok { + return nil + } + return i +} + +// TODO: export clause type and move to plan? +type clauseType int + +const ( + noneClause clauseType = iota + onClause + whereClause + groupByClause + fieldListClause + havingClause + orderByClause +) + +func (clause clauseType) String() string { + switch clause { + case onClause: + return "on clause" + case fieldListClause: + return "field list" + case whereClause: + return "where clause" + case groupByClause: + return "group statement" + case orderByClause: + return "order clause" + case havingClause: + return "having clause" + } + return "none" +} + +// castPosition returns an group/order by Position expression if e is a number. +func castPosition(e expression.Expression, selectList *plans.SelectList, clause clauseType) (*expression.Position, error) { + v, ok := e.(expression.Value) + if !ok { + return nil, nil + } + + var position int + switch u := v.Val.(type) { + case int64: + position = int(u) + case uint64: + position = int(u) + default: + return nil, nil + } + + if position < 1 || position > selectList.HiddenFieldOffset { + return nil, errors.Errorf("Unknown column '%d' in '%s'", position, clause) + } + + if clause == groupByClause { + index := position - 1 + if _, ok := selectList.AggFields[index]; ok { + return nil, errors.Errorf("Can't group on '%s'", selectList.Fields[index]) + } + } + + // use Position expression for the associated field. + return &expression.Position{N: position}, nil +} + +func checkIdentAmbiguous(i *expression.Ident, selectList *plans.SelectList, clause clauseType) (int, error) { + index, err := selectList.CheckAmbiguous(i) + if err != nil { + return -1, errors.Errorf("Column '%s' in %s is ambiguous", i, clause) + } else if index == -1 { + return -1, nil + } + + return index, nil +} + +// fromIdentVisitor can only handle identifier which reference FROM table or outer query. +// like in common select list, where or join on condition. +type fromIdentVisitor struct { + expression.BaseVisitor + fromFields []*field.ResultField + clause clauseType +} + +func (v *fromIdentVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + idx := field.GetResultFieldIndex(i.L, v.fromFields, field.DefaultFieldFlag) + if len(idx) == 1 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } else if len(idx) > 1 { + return nil, errors.Errorf("Column '%s' in %s is ambiguous", i, v.clause) + } + + if v.clause == onClause { + // on clause can't check outer query. + return nil, errors.Errorf("Unknown column '%s' in '%s'", i, v.clause) + } + + // TODO: check in outer query + return i, nil +} + +func newFromIdentVisitor(fromFields []*field.ResultField, clause clauseType) *fromIdentVisitor { + visitor := &fromIdentVisitor{} + visitor.BaseVisitor.V = visitor + visitor.fromFields = fromFields + visitor.clause = clause + + return visitor +} diff --git a/rset/rsets/join.go b/rset/rsets/join.go index e358a27dc7..97e7c8e583 100644 --- a/rset/rsets/join.go +++ b/rset/rsets/join.go @@ -141,6 +141,17 @@ func (r *JoinRset) buildJoinPlan(ctx context.Context, p *plans.JoinPlan, s *Join p.Fields = append(p.Fields, rightFields...) } + if p.On != nil { + visitor := newFromIdentVisitor(p.Fields, onClause) + + e, err := p.On.Accept(visitor) + if err != nil { + return errors.Trace(err) + } + + p.On = e + } + return nil } diff --git a/rset/rsets/orderby.go b/rset/rsets/orderby.go index e0b9e74834..9412cc9faa 100644 --- a/rset/rsets/orderby.go +++ b/rset/rsets/orderby.go @@ -65,46 +65,110 @@ func (r *OrderByRset) String() string { return strings.Join(a, ", ") } -// CheckAndUpdateSelectList checks order by fields validity and set hidden fields to selectList. -func (r *OrderByRset) CheckAndUpdateSelectList(selectList *plans.SelectList, tableFields []*field.ResultField) error { +// CheckAggregate will check whether order by has aggregate function or not, +// if has, we will add it to select list hidden field. +func (r *OrderByRset) CheckAggregate(selectList *plans.SelectList) error { for i, v := range r.By { if expression.ContainAggregateFunc(v.Expr) { - expr, err := selectList.UpdateAggFields(v.Expr, tableFields) + expr, err := selectList.UpdateAggFields(v.Expr) if err != nil { return errors.Errorf("%s in 'order clause'", err.Error()) } r.By[i].Expr = expr - } else { - if _, err := selectList.CheckReferAmbiguous(v.Expr); err != nil { - return errors.Errorf("Column '%s' in order statement is ambiguous", v.Expr) - } + } + } + return nil +} - // TODO: check more ambiguous case - // Order by ambiguous rule: - // select c1 as a, c2 as a from t order by a is ambiguous - // select c1 as a, c2 as a from t order by a + 1 is ambiguous - // select c1 as c2, c2 from t order by c2 is ambiguous - // select c1 as c2, c2 from t order by c2 + 1 is ambiguous +type orderByVisitor struct { + expression.BaseVisitor + selectList *plans.SelectList + rootIdent *expression.Ident +} - // TODO: use vistor to refactor all and combine following plan check. - names := expression.MentionedColumns(v.Expr) - for _, name := range names { - // try to find in select list - // TODO: mysql has confused result for this, see #555. - // now we use select list then order by, later we should make it easier. - if field.ContainFieldName(name, selectList.ResultFields, field.CheckFieldFlag) { - continue - } +func (v *orderByVisitor) VisitIdent(i *expression.Ident) (expression.Expression, error) { + // Order by ambiguous rule: + // select c1 as a, c2 as a from t order by a is ambiguous + // select c1 as a, c2 as a from t order by a + 1 is ambiguous + // select c1 as c2, c2 from t order by c2 is ambiguous + // select c1 as c2, c2 from t order by c2 + 1 is not ambiguous - if !selectList.CloneHiddenField(name, tableFields) { - return errors.Errorf("Unknown column '%s' in 'order clause'", name) - } - } + // Order by identifier reference check + // select c1 as c2 from t order by c2, c2 in order by references c1 in select field. + // select c1 as c2 from t order by c2 + 1, c2 in order by c2 + 1 references t.c2 in from table. + // select c1 as c2 from t order by sum(c2), c2 in order by sum(c2) references t.c2 in from table. + // select c1 as c2 from t order by sum(1) + c2, c2 in order by sum(1) + c2 references t.c2 in from table. + // select c1 as a from t order by a, a references c1 in select field. + // select c1 as a from t order by sum(a), a in order by sum(a) references c1 in select field. + + // TODO: unify VisitIdent for group by and order by visitor, they are little different. + + var ( + index int + err error + ) + + if v.rootIdent == i { + // The order by is an identifier, we must check it first. + index, err = checkIdentAmbiguous(i, v.selectList, orderByClause) + if err != nil { + return nil, errors.Trace(err) + } + + if index >= 0 { + // identifier references a select field. use it directly. + // e,g. select c1 as c2 from t order by c2, here c2 references c1. + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil } } - return nil + // find this identifier in FROM. + idx := field.GetResultFieldIndex(i.L, v.selectList.FromFields, field.DefaultFieldFlag) + if len(idx) > 0 { + i.ReferScope = expression.IdentReferFromTable + i.ReferIndex = idx[0] + return i, nil + } + + if v.rootIdent != i { + // This identifier is the part of the order by, check ambiguous here. + index, err = checkIdentAmbiguous(i, v.selectList, orderByClause) + if err != nil { + return nil, errors.Trace(err) + } + } + + // try to find in select list, we have got index using checkIdentAmbiguous before. + if index >= 0 { + // find in select list + i.ReferScope = expression.IdentReferSelectList + i.ReferIndex = index + return i, nil + } + + // TODO: check in out query + return i, errors.Errorf("Unknown column '%s' in 'order clause'", i) +} + +func (v *orderByVisitor) VisitPosition(p *expression.Position) (expression.Expression, error) { + n := p.N + if p.N <= v.selectList.HiddenFieldOffset { + // this position expression is not in hidden field, no need to check + return p, nil + } + + // we have added order by to hidden field before, now visit it here. + expr := v.selectList.Fields[n-1].Expr + e, err := expr.Accept(v) + if err != nil { + return nil, errors.Trace(err) + } + v.selectList.Fields[n-1].Expr = e + + return p, nil } // Plan get SrcPlan/OrderByDefaultPlan. @@ -120,45 +184,28 @@ func (r *OrderByRset) Plan(ctx context.Context) (plan.Plan, error) { ascs []bool ) - fields := r.Src.GetFields() + visitor := &orderByVisitor{} + visitor.BaseVisitor.V = visitor + visitor.selectList = r.SelectList + for i := range r.By { e := r.By[i].Expr - if v, ok := e.(expression.Value); ok { - var ( - position int - isPosition = true - ) - switch u := v.Val.(type) { - case int64: - position = int(u) - case uint64: - position = int(u) - default: - isPosition = false - // only const value - } + pos, err := castPosition(e, r.SelectList, orderByClause) + if err != nil { + return nil, errors.Trace(err) + } - if isPosition { - if position < 1 || position > len(fields) { - return nil, errors.Errorf("Unknown column '%d' in 'order clause'", position) - } - - // use Position expression for the associated field. - r.By[i].Expr = &expression.Position{N: position} - } + if pos != nil { + // use Position expression for the associated field. + r.By[i].Expr = pos } else { - // Don't check ambiguous here, only check field exists or not. - // TODO: use visitor to refactor. - colNames := expression.MentionedColumns(e) - for _, name := range colNames { - if idx := field.GetResultFieldIndex(name, r.SelectList.ResultFields, field.DefaultFieldFlag); len(idx) == 0 { - // find in from - if idx = field.GetResultFieldIndex(name, r.SelectList.FromFields, field.DefaultFieldFlag); len(idx) == 0 { - return nil, errors.Errorf("unknown field %s", name) - } - } + visitor.rootIdent = castIdent(e) + e, err = e.Accept(visitor) + if err != nil { + return nil, errors.Trace(err) } + r.By[i].Expr = e } by = append(by, r.By[i].Expr) diff --git a/rset/rsets/orderby_test.go b/rset/rsets/orderby_test.go index afd4b5313e..9a51c381b7 100644 --- a/rset/rsets/orderby_test.go +++ b/rset/rsets/orderby_test.go @@ -47,64 +47,6 @@ func (s *testOrderByRsetSuite) SetUpSuite(c *C) { s.r = &OrderByRset{Src: tblPlan, SelectList: selectList} } -func (s *testOrderByRsetSuite) TestOrderByRsetCheckAndUpdateSelectList(c *C) { - resultFields := s.r.Src.GetFields() - - fields := make([]*field.Field, len(resultFields)) - for i, resultField := range resultFields { - name := resultField.Name - fields[i] = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr(name)}} - } - - selectList := &plans.SelectList{ - HiddenFieldOffset: len(resultFields), - ResultFields: resultFields, - Fields: fields, - } - - expr := &expression.Ident{CIStr: model.NewCIStr("id")} - orderByItem := OrderByItem{Expr: expr, Asc: true} - by := []OrderByItem{orderByItem} - r := &OrderByRset{By: by, SelectList: selectList} - - // `select id, name from t order by id` - err := r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, IsNil) - - // `select id, name as id from t order by id` - selectList.Fields[1].AsName = "id" - selectList.ResultFields[1].Name = "id" - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, NotNil) - - // `select id, name from t order by count(1) > 1` - aggExpr, err := expression.NewCall("count", []expression.Expression{expression.Value{Val: 1}}, false) - c.Assert(err, IsNil) - - r.By[0].Expr = aggExpr - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, IsNil) - - // `select id, name from t order by count(xxx) > 1` - aggExpr, err = expression.NewCall("count", []expression.Expression{&expression.Ident{CIStr: model.NewCIStr("xxx")}}, false) - c.Assert(err, IsNil) - - r.By[0].Expr = aggExpr - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, NotNil) - - // `select id, name from t order by xxx` - r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")} - selectList.Fields[1].AsName = "name" - selectList.ResultFields[1].Name = "name" - - err = r.CheckAndUpdateSelectList(selectList, resultFields) - c.Assert(err, NotNil) -} - func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) { // `select id, name from t` _, err := s.r.Plan(nil) @@ -145,12 +87,6 @@ func (s *testOrderByRsetSuite) TestOrderByRsetPlan(c *C) { _, err = s.r.Plan(nil) c.Assert(err, IsNil) - // `select id, name from t order by xxx` - s.r.By[0].Expr = &expression.Ident{CIStr: model.NewCIStr("xxx")} - - _, err = s.r.Plan(nil) - c.Assert(err, NotNil) - // check src plan is NullPlan s.r.Src = &plans.NullPlan{Fields: s.r.SelectList.ResultFields} diff --git a/rset/rsets/where.go b/rset/rsets/where.go index 30729792a9..3a583fe010 100644 --- a/rset/rsets/where.go +++ b/rset/rsets/where.go @@ -162,7 +162,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl if err != nil { return nil, err } - if val == nil { // like `select * from t where null`. return &plans.NullPlan{Fields: r.Src.GetFields()}, nil @@ -172,7 +171,6 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl if err != nil { return nil, err } - if n == 0 { // like `select * from t where 0`. return &plans.NullPlan{Fields: r.Src.GetFields()}, nil @@ -181,8 +179,25 @@ func (r *WhereRset) planStatic(ctx context.Context, e expression.Expression) (pl return &plans.FilterDefaultPlan{Plan: r.Src, Expr: e}, nil } +func (r *WhereRset) updateWhereFieldsRefer() error { + visitor := newFromIdentVisitor(r.Src.GetFields(), whereClause) + + e, err := r.Expr.Accept(visitor) + if err != nil { + return errors.Trace(err) + } + + r.Expr = e + return nil +} + // Plan gets NullPlan/FilterDefaultPlan. func (r *WhereRset) Plan(ctx context.Context) (plan.Plan, error) { + // Update where fields refer expr by using fromIdentVisitor. + if err := r.updateWhereFieldsRefer(); err != nil { + return nil, errors.Trace(err) + } + expr := r.Expr.Clone() if expr.IsStatic() { // IsStaic means we have a const value for where condition, and we don't need any index. diff --git a/session.go b/session.go index 31c6a703a7..b7c9aa9add 100644 --- a/session.go +++ b/session.go @@ -31,7 +31,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/db" diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index a71b3558b8..5e03948130 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -15,7 +15,7 @@ package variable import ( "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) // SessionVars is to handle user-defined or global variables in current session. diff --git a/stmt/stmts/account_manage.go b/stmt/stmts/account_manage.go index 8c169ef354..cec31dad1e 100644 --- a/stmt/stmts/account_manage.go +++ b/stmt/stmts/account_manage.go @@ -23,7 +23,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/stmt" diff --git a/stmt/stmts/create_test.go b/stmt/stmts/create_test.go index d53e8270d2..15e0cbf1bc 100644 --- a/stmt/stmts/create_test.go +++ b/stmt/stmts/create_test.go @@ -21,7 +21,7 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { diff --git a/stmt/stmts/grant.go b/stmt/stmts/grant.go index 841ecb8769..7ba4152656 100644 --- a/stmt/stmts/grant.go +++ b/stmt/stmts/grant.go @@ -25,7 +25,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" @@ -221,7 +221,7 @@ func (s *GrantStmt) grantPriv(ctx context.Context, priv *coldef.PrivElem, user * } return s.grantColumnPriv(ctx, priv, user) default: - return errors.Errorf("Unknown grant level: %s", s.Level) + return errors.Errorf("Unknown grant level: %#v", s.Level) } } @@ -305,7 +305,7 @@ func composeGlobalPrivUpdate(priv mysql.PrivilegeType) (string, error) { } col, ok := mysql.Priv2UserCol[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } return fmt.Sprintf(`%s="Y"`, col), nil } @@ -317,7 +317,7 @@ func composeDBPrivUpdate(priv mysql.PrivilegeType) (string, error) { for _, p := range mysql.AllDBPrivs { v, ok := mysql.Priv2UserCol[p] if !ok { - return "", errors.Errorf("Unknown db privilege %s", priv) + return "", errors.Errorf("Unknown db privilege %v", priv) } strs = append(strs, fmt.Sprintf(`%s="Y"`, v)) } @@ -325,7 +325,7 @@ func composeDBPrivUpdate(priv mysql.PrivilegeType) (string, error) { } col, ok := mysql.Priv2UserCol[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } return fmt.Sprintf(`%s="Y"`, col), nil } @@ -363,7 +363,7 @@ func composeTablePrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name } p, ok := mysql.Priv2SetStr[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } if len(currTablePriv) == 0 { newTablePriv = p @@ -406,7 +406,7 @@ func composeColumnPrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name } p, ok := mysql.Priv2SetStr[priv] if !ok { - return "", errors.Errorf("Unknown priv: %s", priv) + return "", errors.Errorf("Unknown priv: %v", priv) } if len(currColumnPriv) == 0 { newColumnPriv = p diff --git a/stmt/stmts/grant_test.go b/stmt/stmts/grant_test.go index 2b4ace611b..eae406ccf9 100644 --- a/stmt/stmts/grant_test.go +++ b/stmt/stmts/grant_test.go @@ -18,7 +18,7 @@ import ( "strings" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func (s *testStmtSuite) TestGrantGlobal(c *C) { diff --git a/stmt/stmts/insert.go b/stmt/stmts/insert.go index 374349b975..61251f56f8 100644 --- a/stmt/stmts/insert.go +++ b/stmt/stmts/insert.go @@ -24,7 +24,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" diff --git a/stmt/stmts/select.go b/stmt/stmts/select.go index 8b889dc5a2..72b72594bd 100644 --- a/stmt/stmts/select.go +++ b/stmt/stmts/select.go @@ -175,20 +175,20 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { if s.Having != nil { // `having` may contain aggregate functions, and we will add this to hidden fields. - if err = s.Having.CheckAndUpdateSelectList(selectList, groupBy, r.GetFields()); err != nil { + if err = s.Having.CheckAggregate(selectList); err != nil { return nil, errors.Trace(err) } } if s.OrderBy != nil { // `order by` may contain aggregate functions, and we will add this to hidden fields. - if err = s.OrderBy.CheckAndUpdateSelectList(selectList, r.GetFields()); err != nil { + if err = s.OrderBy.CheckAggregate(selectList); err != nil { return nil, errors.Trace(err) } } switch { - case !rsets.HasAggFields(selectList.Fields) && s.GroupBy == nil: + case len(selectList.AggFields) == 0 && s.GroupBy == nil: // If no group by and no aggregate functions, we will use SelectFieldsPlan. if r, err = (&rsets.SelectFieldsRset{Src: r, SelectList: selectList, @@ -209,7 +209,8 @@ func (s *SelectStmt) Plan(ctx context.Context) (plan.Plan, error) { if r, err = (&rsets.HavingRset{ Src: r, Expr: s.Expr, - SelectList: selectList}).Plan(ctx); err != nil { + SelectList: selectList, + GroupBy: groupBy}).Plan(ctx); err != nil { return nil, err } } diff --git a/stmt/stmts/stmt_helper.go b/stmt/stmts/stmt_helper.go index d8ebe8682a..cca52a8f5b 100644 --- a/stmt/stmts/stmt_helper.go +++ b/stmt/stmts/stmt_helper.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/table" ) diff --git a/stmt/stmts/transaction.go b/stmt/stmts/transaction.go index b82c398f6b..292eb61936 100644 --- a/stmt/stmts/transaction.go +++ b/stmt/stmts/transaction.go @@ -19,7 +19,7 @@ package stmts import ( "github.com/pingcap/tidb/context" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" diff --git a/stmt/stmts/union.go b/stmt/stmts/union.go index 90390a7187..bd569f1c67 100644 --- a/stmt/stmts/union.go +++ b/stmt/stmts/union.go @@ -92,13 +92,14 @@ func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) { selectList := &plans.SelectList{} selectList.ResultFields = make([]*field.ResultField, len(fields)) selectList.HiddenFieldOffset = len(fields) + selectList.Fields = s.Selects[0].Fields // Union uses first select return column names and ignores table name. // We only care result name and type here. for i, f := range fields { - nf := &field.ResultField{} - nf.Name = f.Name - nf.FieldType = f.FieldType + nf := f.Clone() + nf.OrgTableName = "" + nf.TableName = "" selectList.ResultFields[i] = nf } @@ -109,6 +110,7 @@ func (s *UnionStmt) Plan(ctx context.Context) (plan.Plan, error) { r = &plans.UnionPlan{Srcs: srcs, Distincts: s.Distincts, RFields: selectList.ResultFields} + // TODO: check aggregate function later. if s := s.OrderBy; s != nil { if r, err = (&rsets.OrderByRset{By: s.By, Src: r, diff --git a/stmt/stmts/update.go b/stmt/stmts/update.go index 4ea5da59dc..1b6f46f4dc 100644 --- a/stmt/stmts/update.go +++ b/stmt/stmts/update.go @@ -26,7 +26,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/plan" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset" diff --git a/store/localstore/compactor.go b/store/localstore/compactor.go new file mode 100644 index 0000000000..d0eabb221b --- /dev/null +++ b/store/localstore/compactor.go @@ -0,0 +1,199 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package localstore + +import ( + "sync" + "time" + + "github.com/juju/errors" + "github.com/ngaut/log" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/localstore/engine" + "github.com/pingcap/tidb/util/bytes" +) + +var _ kv.Compactor = (*localstoreCompactor)(nil) + +const ( + deleteWorkerCnt = 3 +) + +var localCompactDefaultPolicy = kv.CompactPolicy{ + SafePoint: 20 * 1000, // in ms + TriggerInterval: 1 * time.Second, + BatchDeleteCnt: 100, +} + +type localstoreCompactor struct { + mu sync.Mutex + recentKeys map[string]struct{} + stopCh chan struct{} + delCh chan kv.EncodedKey + workerWaitGroup sync.WaitGroup + ticker *time.Ticker + db engine.DB + policy kv.CompactPolicy +} + +func (gc *localstoreCompactor) OnSet(k kv.Key) { + gc.mu.Lock() + defer gc.mu.Unlock() + gc.recentKeys[string(k)] = struct{}{} +} + +func (gc *localstoreCompactor) OnGet(k kv.Key) { + // Do nothing now. +} + +func (gc *localstoreCompactor) OnDelete(k kv.Key) { + gc.mu.Lock() + defer gc.mu.Unlock() + gc.recentKeys[string(k)] = struct{}{} +} + +func (gc *localstoreCompactor) getAllVersions(k kv.Key) ([]kv.EncodedKey, error) { + startKey := MvccEncodeVersionKey(k, kv.MaxVersion) + endKey := MvccEncodeVersionKey(k, kv.MinVersion) + + it, err := gc.db.Seek(startKey) + if err != nil { + return nil, errors.Trace(err) + } + defer it.Release() + + var ret []kv.EncodedKey + for it.Next() { + if kv.EncodedKey(it.Key()).Cmp(endKey) < 0 { + ret = append(ret, bytes.CloneBytes(kv.EncodedKey(it.Key()))) + } + } + return ret, nil +} + +func (gc *localstoreCompactor) deleteWorker() { + gc.workerWaitGroup.Add(1) + defer gc.workerWaitGroup.Done() + cnt := 0 + batch := gc.db.NewBatch() + for { + select { + case <-gc.stopCh: + return + case key := <-gc.delCh: + { + cnt++ + batch.Delete(key) + // Batch delete. + if cnt == gc.policy.BatchDeleteCnt { + err := gc.db.Commit(batch) + if err != nil { + log.Error(err) + } + batch = gc.db.NewBatch() + cnt = 0 + } + } + } + } +} + +func (gc *localstoreCompactor) checkExpiredKeysWorker() { + gc.workerWaitGroup.Add(1) + defer gc.workerWaitGroup.Done() + for { + select { + case <-gc.stopCh: + log.Info("GC stopped") + return + case <-gc.ticker.C: + log.Info("GC trigger") + gc.mu.Lock() + m := gc.recentKeys + gc.recentKeys = make(map[string]struct{}) + gc.mu.Unlock() + // Do Compactor + for k := range m { + err := gc.Compact(nil, []byte(k)) + if err != nil { + log.Error(err) + } + } + } + } +} + +func (gc *localstoreCompactor) filterExpiredKeys(keys []kv.EncodedKey) []kv.EncodedKey { + var ret []kv.EncodedKey + first := true + // keys are always in descending order. + for _, k := range keys { + _, ver, err := MvccDecode(k) + if err != nil { + // Should not happen. + panic(err) + } + ts := localVersionToTimestamp(ver) + currentTS := time.Now().UnixNano() / int64(time.Millisecond) + // Check timeout keys. + if currentTS-int64(ts) >= int64(gc.policy.SafePoint) { + // Skip first version. + if first { + first = false + continue + } + ret = append(ret, k) + } + } + return ret +} + +func (gc *localstoreCompactor) Compact(ctx interface{}, k kv.Key) error { + keys, err := gc.getAllVersions(k) + if err != nil { + return errors.Trace(err) + } + for _, key := range gc.filterExpiredKeys(keys) { + // Send timeout key to deleteWorker. + log.Info("GC send key to deleteWorker", key) + gc.delCh <- key + } + return nil +} + +func (gc *localstoreCompactor) Start() { + // Start workers. + go gc.checkExpiredKeysWorker() + for i := 0; i < deleteWorkerCnt; i++ { + go gc.deleteWorker() + } +} + +func (gc *localstoreCompactor) Stop() { + gc.ticker.Stop() + close(gc.stopCh) + // Wait for all workers to finish. + gc.workerWaitGroup.Wait() +} + +func newLocalCompactor(policy kv.CompactPolicy, db engine.DB) *localstoreCompactor { + return &localstoreCompactor{ + recentKeys: make(map[string]struct{}), + stopCh: make(chan struct{}), + delCh: make(chan kv.EncodedKey, 100), + ticker: time.NewTicker(policy.TriggerInterval), + policy: policy, + db: db, + } +} diff --git a/store/localstore/compactor_test.go b/store/localstore/compactor_test.go new file mode 100644 index 0000000000..027437be9c --- /dev/null +++ b/store/localstore/compactor_test.go @@ -0,0 +1,87 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package localstore + +import ( + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/store/localstore/engine" +) + +var _ = Suite(&localstoreCompactorTestSuite{}) + +type localstoreCompactorTestSuite struct { +} + +func count(db engine.DB) int { + it, _ := db.Seek([]byte{0}) + defer it.Release() + totalCnt := 0 + for it.Next() { + totalCnt++ + } + return totalCnt +} + +func (s *localstoreCompactorTestSuite) TestCompactor(c *C) { + store := createMemStore() + db := store.(*dbStore).db + store.(*dbStore).compactor.Stop() + + policy := kv.CompactPolicy{ + SafePoint: 500, + BatchDeleteCnt: 1, + TriggerInterval: 100 * time.Millisecond, + } + compactor := newLocalCompactor(policy, db) + store.(*dbStore).compactor = compactor + + compactor.Start() + + txn, _ := store.Begin() + txn.Set([]byte("a"), []byte("1")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("2")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("3")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("3")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("4")) + txn.Commit() + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("5")) + txn.Commit() + t := count(db) + c.Assert(t, Equals, 7) + + // Simulating timeout + time.Sleep(1 * time.Second) + // Touch a, tigger GC + txn, _ = store.Begin() + txn.Set([]byte("a"), []byte("b")) + txn.Commit() + time.Sleep(1 * time.Second) + // Do background GC + t = count(db) + c.Assert(t, Equals, 3) + + compactor.Stop() +} diff --git a/store/localstore/kv.go b/store/localstore/kv.go index 0bab034074..ae96bb5cf0 100644 --- a/store/localstore/kv.go +++ b/store/localstore/kv.go @@ -15,7 +15,6 @@ package localstore import ( "sync" - "time" "github.com/juju/errors" "github.com/ngaut/log" @@ -38,6 +37,7 @@ type dbStore struct { keysLocked map[string]uint64 uuid string path string + compactor *localstoreCompactor } type storeCache struct { @@ -85,10 +85,10 @@ func (d Driver) Open(schema string) (kv.Storage, error) { uuid: uuid.NewV4().String(), path: schema, db: db, + compactor: newLocalCompactor(localCompactDefaultPolicy, db), } - mc.cache[schema] = s - + s.compactor.Start() return s, nil } @@ -117,7 +117,6 @@ func (s *dbStore) Begin() (kv.Transaction, error) { return nil, err } txn := &dbTxn{ - startTs: time.Now(), tid: beginVer.Ver, valid: true, store: s, @@ -138,7 +137,7 @@ func (s *dbStore) Begin() (kv.Transaction, error) { func (s *dbStore) Close() error { mc.mu.Lock() defer mc.mu.Unlock() - + s.compactor.Stop() delete(mc.cache, s.path) return s.db.Close() } diff --git a/store/localstore/kv_test.go b/store/localstore/kv_test.go index d09c35043d..5b28af06f1 100644 --- a/store/localstore/kv_test.go +++ b/store/localstore/kv_test.go @@ -93,7 +93,7 @@ func valToStr(c *C, iter kv.Iterator) string { func checkSeek(c *C, txn kv.Transaction) { for i := startIndex; i < testCount; i++ { val := encodeInt(i) - iter, err := txn.Seek(val, nil) + iter, err := txn.Seek(val) c.Assert(err, IsNil) c.Assert(iter.Key(), Equals, string(val)) c.Assert(decodeInt([]byte(valToStr(c, iter))), Equals, i) @@ -103,12 +103,12 @@ func checkSeek(c *C, txn kv.Transaction) { // Test iterator Next() for i := startIndex; i < testCount-1; i++ { val := encodeInt(i) - iter, err := txn.Seek(val, nil) + iter, err := txn.Seek(val) c.Assert(err, IsNil) c.Assert(iter.Key(), Equals, string(val)) c.Assert(valToStr(c, iter), Equals, string(val)) - next, err := iter.Next(nil) + next, err := iter.Next() c.Assert(err, IsNil) c.Assert(next.Valid(), IsTrue) @@ -119,7 +119,7 @@ func checkSeek(c *C, txn kv.Transaction) { } // Non exist seek test - iter, err := txn.Seek(encodeInt(testCount), nil) + iter, err := txn.Seek(encodeInt(testCount)) c.Assert(err, IsNil) c.Assert(iter.Valid(), IsFalse) iter.Close() @@ -264,19 +264,19 @@ func (s *testKVSuite) TestDelete2(c *C) { txn, err = s.s.Begin() c.Assert(err, IsNil) - it, err := txn.Seek([]byte("DATA_test_tbl_department_record__0000000001_0003"), nil) + it, err := txn.Seek([]byte("DATA_test_tbl_department_record__0000000001_0003")) c.Assert(err, IsNil) for it.Valid() { err = txn.Delete([]byte(it.Key())) c.Assert(err, IsNil) - it, err = it.Next(nil) + it, err = it.Next() c.Assert(err, IsNil) } txn.Commit() txn, err = s.s.Begin() c.Assert(err, IsNil) - it, _ = txn.Seek([]byte("DATA_test_tbl_department_record__000000000"), nil) + it, _ = txn.Seek([]byte("DATA_test_tbl_department_record__000000000")) c.Assert(it.Valid(), IsFalse) txn.Commit() @@ -299,7 +299,7 @@ func (s *testKVSuite) TestBasicSeek(c *C) { c.Assert(err, IsNil) defer txn.Commit() - it, err := txn.Seek([]byte("2"), nil) + it, err := txn.Seek([]byte("2")) c.Assert(err, IsNil) c.Assert(it.Valid(), Equals, false) txn.Delete([]byte("1")) @@ -320,30 +320,30 @@ func (s *testKVSuite) TestBasicTable(c *C) { err = txn.Set([]byte("1"), []byte("1")) c.Assert(err, IsNil) - it, err := txn.Seek([]byte("0"), nil) + it, err := txn.Seek([]byte("0")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "1") err = txn.Set([]byte("0"), []byte("0")) c.Assert(err, IsNil) - it, err = txn.Seek([]byte("0"), nil) + it, err = txn.Seek([]byte("0")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "0") err = txn.Delete([]byte("0")) c.Assert(err, IsNil) txn.Delete([]byte("1")) - it, err = txn.Seek([]byte("0"), nil) + it, err = txn.Seek([]byte("0")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "2") err = txn.Delete([]byte("3")) c.Assert(err, IsNil) - it, err = txn.Seek([]byte("2"), nil) + it, err = txn.Seek([]byte("2")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "2") - it, err = txn.Seek([]byte("3"), nil) + it, err = txn.Seek([]byte("3")) c.Assert(err, IsNil) c.Assert(it.Key(), Equals, "4") err = txn.Delete([]byte("2")) @@ -401,13 +401,13 @@ func (s *testKVSuite) TestSeekMin(c *C) { txn.Set([]byte(kv.key), []byte(kv.value)) } - it, err := txn.Seek(nil, nil) + it, err := txn.Seek(nil) for it.Valid() { fmt.Printf("%s, %s\n", it.Key(), it.Value()) - it, _ = it.Next(nil) + it, _ = it.Next() } - it, err = txn.Seek([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000"), nil) + it, err = txn.Seek([]byte("DATA_test_main_db_tbl_tbl_test_record__00000000000000000000")) c.Assert(err, IsNil) c.Assert(string(it.Key()), Equals, "DATA_test_main_db_tbl_tbl_test_record__00000000000000000001") diff --git a/store/localstore/local_version_provider.go b/store/localstore/local_version_provider.go index 34edfab2f5..997bf786b6 100644 --- a/store/localstore/local_version_provider.go +++ b/store/localstore/local_version_provider.go @@ -15,9 +15,11 @@ var ErrOverflow = errors.New("overflow when allocating new version") // LocalVersionProvider uses local timestamp for version. type LocalVersionProvider struct { - mu sync.Mutex - lastTimeStampTs uint64 - n uint64 + mu sync.Mutex + lastTimestamp uint64 + // logical guaranteed version's monotonic increasing for calls when lastTimestamp + // are equal. + logical uint64 } const ( @@ -31,14 +33,18 @@ func (l *LocalVersionProvider) CurrentVersion() (kv.Version, error) { var ts uint64 ts = uint64((time.Now().UnixNano() / int64(time.Millisecond)) << timePrecisionOffset) - if l.lastTimeStampTs == uint64(ts) { - l.n++ - if l.n >= 1<= 1<> timePrecisionOffset +} diff --git a/store/localstore/mvcc.go b/store/localstore/mvcc.go index 9d475c4fad..f61296eee2 100644 --- a/store/localstore/mvcc.go +++ b/store/localstore/mvcc.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package localstore import ( diff --git a/store/localstore/mvcc_test.go b/store/localstore/mvcc_test.go index e0add9217d..9de9682e19 100644 --- a/store/localstore/mvcc_test.go +++ b/store/localstore/mvcc_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package localstore import ( @@ -125,11 +138,11 @@ func (t *testMvccSuite) TestMvccPutAndDel(c *C) { func (t *testMvccSuite) TestMvccNext(c *C) { txn, _ := t.s.Begin() - it, err := txn.Seek(encodeInt(2), nil) + it, err := txn.Seek(encodeInt(2)) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) for it.Valid() { - it, err = it.Next(nil) + it, err = it.Next() c.Assert(err, IsNil) } txn.Commit() @@ -186,7 +199,7 @@ func (t *testMvccSuite) TestMvccSuiteGetLatest(c *C) { c.Assert(err, IsNil) c.Assert(string(b), Equals, string(encodeInt(100+9))) // we can always scan newest data - it, err := tx.Seek(encodeInt(5), nil) + it, err := tx.Seek(encodeInt(5)) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) c.Assert(string(it.Value()), Equals, string(encodeInt(100+9))) @@ -231,7 +244,7 @@ func (t *testMvccSuite) TestMvccSnapshotScan(c *C) { if string(it.Value()) == "new" { found = true } - it, err = it.Next(nil) + it, err = it.Next() c.Assert(err, IsNil) } return found @@ -258,11 +271,11 @@ func (t *testMvccSuite) TestBufferedIterator(c *C) { tx.Commit() tx, _ = s.Begin() - iter, err := tx.Seek([]byte{0}, nil) + iter, err := tx.Seek([]byte{0}) c.Assert(err, IsNil) cnt := 0 for iter.Valid() { - iter, err = iter.Next(nil) + iter, err = iter.Next() c.Assert(err, IsNil) cnt++ } @@ -270,7 +283,7 @@ func (t *testMvccSuite) TestBufferedIterator(c *C) { c.Assert(cnt, Equals, 6) tx, _ = s.Begin() - it, err := tx.Seek([]byte{0xff, 0xee}, nil) + it, err := tx.Seek([]byte{0xff, 0xee}) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) c.Assert(it.Key(), Equals, "\xff\xff\xee\xff") @@ -278,11 +291,11 @@ func (t *testMvccSuite) TestBufferedIterator(c *C) { // no such key tx, _ = s.Begin() - it, err = tx.Seek([]byte{0xff, 0xff, 0xff, 0xff}, nil) + it, err = tx.Seek([]byte{0xff, 0xff, 0xff, 0xff}) c.Assert(err, IsNil) c.Assert(it.Valid(), IsFalse) - it, err = tx.Seek([]byte{0x0, 0xff}, nil) + it, err = tx.Seek([]byte{0x0, 0xff}) c.Assert(err, IsNil) c.Assert(it.Valid(), IsTrue) c.Assert(it.Value(), DeepEquals, []byte("2")) diff --git a/store/localstore/snapshot.go b/store/localstore/snapshot.go index ff53e90511..2340001520 100644 --- a/store/localstore/snapshot.go +++ b/store/localstore/snapshot.go @@ -136,11 +136,11 @@ func newDBIter(s *dbSnapshot, startKey kv.Key, exceptedVer kv.Version) *dbIter { valid: true, exceptedVersion: exceptedVer, } - it.Next(nil) + it.Next() return it } -func (it *dbIter) Next(fn kv.FnKeyCmp) (kv.Iterator, error) { +func (it *dbIter) Next() (kv.Iterator, error) { encKey := codec.EncodeBytes(nil, it.startKey) var retErr error var engineIter engine.Iterator diff --git a/store/localstore/txn.go b/store/localstore/txn.go index 741a2c7154..9725501685 100644 --- a/store/localstore/txn.go +++ b/store/localstore/txn.go @@ -17,7 +17,6 @@ import ( "fmt" "runtime/debug" "strconv" - "time" "github.com/juju/errors" "github.com/ngaut/log" @@ -38,7 +37,6 @@ var ( type dbTxn struct { kv.UnionStore store *dbStore // for commit - startTs time.Time tid uint64 valid bool version kv.Version // commit version @@ -95,7 +93,7 @@ func (txn *dbTxn) Inc(k kv.Key, step int64) (int64, error) { if err != nil { return 0, errors.Trace(err) } - + txn.store.compactor.OnSet(k) return intVal, nil } @@ -108,7 +106,6 @@ func (txn *dbTxn) GetInt64(k kv.Key) (int64, error) { if err != nil { return 0, errors.Trace(err) } - intVal, err := strconv.ParseInt(string(val), 10, 0) return intVal, errors.Trace(err) } @@ -130,7 +127,7 @@ func (txn *dbTxn) Get(k kv.Key) ([]byte, error) { if len(val) == 0 { return nil, errors.Trace(kv.ErrNotExist) } - + txn.store.compactor.OnGet(k) return val, nil } @@ -143,10 +140,15 @@ func (txn *dbTxn) Set(k kv.Key, data []byte) error { log.Debugf("set key:%q, txn:%d", k, txn.tid) k = kv.EncodeKey(k) - return txn.UnionStore.Set(k, data) + err := txn.UnionStore.Set(k, data) + if err != nil { + return errors.Trace(err) + } + txn.store.compactor.OnSet(k) + return nil } -func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error) { +func (txn *dbTxn) Seek(k kv.Key) (kv.Iterator, error) { log.Debugf("seek key:%q, txn:%d", k, txn.tid) k = kv.EncodeKey(k) @@ -158,19 +160,18 @@ func (txn *dbTxn) Seek(k kv.Key, fnKeyCmp func(kv.Key) bool) (kv.Iterator, error return &kv.UnionIter{}, nil } - if fnKeyCmp != nil { - if fnKeyCmp([]byte(iter.Key())[:1]) { - return &kv.UnionIter{}, nil - } - } - return iter, nil } func (txn *dbTxn) Delete(k kv.Key) error { log.Debugf("delete key:%q, txn:%d", k, txn.tid) k = kv.EncodeKey(k) - return txn.UnionStore.Delete(k) + err := txn.UnionStore.Delete(k) + if err != nil { + return errors.Trace(err) + } + txn.store.compactor.OnDelete(k) + return nil } func (txn *dbTxn) each(f func(iterator.Iterator) error) error { diff --git a/table/tables/tables.go b/table/tables/tables.go index fdb3f1e2f0..e8d90d46b5 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -31,7 +31,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util" @@ -530,7 +530,7 @@ func (t *Table) IterRecords(ctx context.Context, startKey string, cols []*column return err } - it, err := txn.Seek([]byte(startKey), nil) + it, err := txn.Seek([]byte(startKey)) if err != nil { return err } diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index 0221fa6147..edaeeb8c2d 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -22,7 +22,7 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/store/localstore" "github.com/pingcap/tidb/store/localstore/goleveldb" diff --git a/tidb-server/server/conn.go b/tidb-server/server/conn.go index 898f91bd84..3b2f3d8a55 100644 --- a/tidb-server/server/conn.go +++ b/tidb-server/server/conn.go @@ -45,7 +45,7 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/hack" diff --git a/tidb-server/server/conn_stmt.go b/tidb-server/server/conn_stmt.go index aacaf4933a..a4d32aa3f3 100644 --- a/tidb-server/server/conn_stmt.go +++ b/tidb-server/server/conn_stmt.go @@ -40,7 +40,7 @@ import ( "strconv" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/hack" ) diff --git a/tidb-server/server/driver_tidb.go b/tidb-server/server/driver_tidb.go index 79c23ae93c..1767f61798 100644 --- a/tidb-server/server/driver_tidb.go +++ b/tidb-server/server/driver_tidb.go @@ -18,7 +18,7 @@ import ( "github.com/pingcap/tidb" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/util/errors2" "github.com/pingcap/tidb/util/types" diff --git a/tidb-server/server/packetio.go b/tidb-server/server/packetio.go index 225922edc5..150cc7c4d2 100644 --- a/tidb-server/server/packetio.go +++ b/tidb-server/server/packetio.go @@ -40,7 +40,7 @@ import ( "net" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) type packetIO struct { diff --git a/tidb-server/server/server.go b/tidb-server/server/server.go index 49681e4da1..e3dd087663 100644 --- a/tidb-server/server/server.go +++ b/tidb-server/server/server.go @@ -37,7 +37,7 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/arena" ) diff --git a/tidb-server/server/util.go b/tidb-server/server/util.go index 3dbf420afc..08362bc440 100644 --- a/tidb-server/server/util.go +++ b/tidb-server/server/util.go @@ -42,7 +42,7 @@ import ( "time" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/arena" "github.com/pingcap/tidb/util/hack" ) diff --git a/tidb-server/server/util_test.go b/tidb-server/server/util_test.go index 81a14e35de..b5d1cb454b 100644 --- a/tidb-server/server/util_test.go +++ b/tidb-server/server/util_test.go @@ -15,7 +15,7 @@ package server import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testUtilSuite{}) diff --git a/tidb_test.go b/tidb_test.go index 6fa9a1b2cd..ef043b1678 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -26,7 +26,7 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb/kv" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/errors2" @@ -882,6 +882,8 @@ func (s *testSessionSuite) TestSelect(c *C) { c.Assert(err, IsNil) matches(c, rows, [][]interface{}{{1, nil, nil}, {2, 2, nil}}) + mustExecFailed(c, se, "select * from t1 left join t2 on t1.c1 = t3.c3 left join on t3 on t1.c1 = t2.c2") + // For issue 393 mustExecSQL(c, se, "drop table if exists t") mustExecSQL(c, se, "create table t (b blob)") @@ -891,7 +893,6 @@ func (s *testSessionSuite) TestSelect(c *C) { row, err = r.FirstRow() c.Assert(err, IsNil) match(c, row, 3) - } func (s *testSessionSuite) TestSubQuery(c *C) { @@ -1150,6 +1151,11 @@ func (s *testSessionSuite) TestGroupBy(c *C) { mustExecSQL(c, se, "drop table if exists t") mustExecSQL(c, se, "create table t (c1 int, c2 int)") mustExecSQL(c, se, "insert into t values (1,1), (2,2), (1,2), (1,3)") + mustExecMatch(c, se, "select nullif (count(*), 2);", [][]interface{}{{1}}) + mustExecMatch(c, se, "select 1 as a, sum(c1) as a from t group by a", [][]interface{}{{1, 5}}) + mustExecMatch(c, se, "select c1 as a, 1 as a, sum(c1) as a from t group by a", [][]interface{}{{1, 1, 5}}) + mustExecMatch(c, se, "select c1 as a, 1 as a, c2 as a from t group by a;", [][]interface{}{{1, 1, 1}}) + mustExecMatch(c, se, "select c1 as c2, sum(c1) as c2 from t group by c2;", [][]interface{}{{1, 1}, {2, 3}, {1, 1}}) mustExecMatch(c, se, "select c1 as c2, c2 from t group by c2 + 1", [][]interface{}{{1, 1}, {2, 2}, {1, 3}}) mustExecMatch(c, se, "select c1 as c2, count(c1) from t group by c2", [][]interface{}{{1, 1}, {2, 2}, {1, 1}}) @@ -1175,22 +1181,49 @@ func (s *testSessionSuite) TestOrderBy(c *C) { mustExecMatch(c, se, "select c1 as a, t.c1 as a from t order by a desc", [][]interface{}{{2, 2}, {1, 1}}) mustExecMatch(c, se, "select c1 as c2 from t order by c2", [][]interface{}{{1}, {2}}) mustExecMatch(c, se, "select sum(c1) from t order by sum(c1)", [][]interface{}{{3}}) - - // TODO: now this test result is not same as MySQL, we will update it later. - mustExecMatch(c, se, "select c1 as c2 from t order by c2 + 1", [][]interface{}{{1}, {2}}) + mustExecMatch(c, se, "select c1 as c2 from t order by c2 + 1", [][]interface{}{{2}, {1}}) mustExecFailed(c, se, "select c1 as a, c2 as a from t order by a") + + mustExecFailed(c, se, "(select c1 as c2, c2 from t) union (select c1, c2 from t) order by c2") + mustExecFailed(c, se, "(select c1 as c2, c2 from t) union (select c1, c2 from t) order by c1") } func (s *testSessionSuite) TestHaving(c *C) { store := newStore(c, s.dbName) se := newSession(c, store, s.dbName) mustExecSQL(c, se, "drop table if exists t") - mustExecSQL(c, se, "create table t (c1 int, c2 int)") - mustExecSQL(c, se, "insert into t values (1,2), (2, 1)") + mustExecSQL(c, se, "create table t (c1 int, c2 int, c3 int)") + mustExecSQL(c, se, "insert into t values (1,2,3), (2, 3, 1), (3, 1, 2)") - mustExecMatch(c, se, "select sum(c1) from t group by c1 having sum(c1)", [][]interface{}{{1}, {2}}) - mustExecMatch(c, se, "select sum(c1) - 1 from t group by c1 having sum(c1) - 1", [][]interface{}{{1}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t having c2 = 2", [][]interface{}{{2, 1}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c2 having c2 = 2;", [][]interface{}{{1, 3}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c2 having sum(c2) = 2;", [][]interface{}{{1, 3}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c3 having sum(c2) = 2;", [][]interface{}{{1, 3}}) + mustExecMatch(c, se, "select c1 as c2, c3 from t group by c3 having sum(0) + c2 = 2;", [][]interface{}{{2, 1}}) + mustExecMatch(c, se, "select c1 as a from t having c1 = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select t.c1 from t having c1 = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select a.c1 from t as a having c1 = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select c1 as a from t group by c3 having sum(a) = 1;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select c1 as a from t group by c3 having sum(a) + a = 2;", [][]interface{}{{1}}) + mustExecMatch(c, se, "select a.c1 as c, a.c1 as d from t as a, t as b having c1 = 1 limit 1;", [][]interface{}{{1, 1}}) + + mustExecMatch(c, se, "select sum(c1) from t group by c1 having sum(c1)", [][]interface{}{{1}, {2}, {3}}) + mustExecMatch(c, se, "select sum(c1) - 1 from t group by c1 having sum(c1) - 1", [][]interface{}{{1}, {2}}) + mustExecMatch(c, se, "select 1 from t group by c1 having sum(abs(c2 + c3)) = c1", [][]interface{}{{1}}) + + mustExecFailed(c, se, "select c1 from t having c2") + mustExecFailed(c, se, "select c1 from t having c2 + 1") + mustExecFailed(c, se, "select c1 from t group by c2 + 1 having c2") + mustExecFailed(c, se, "select c1 from t group by c2 + 1 having c2 + 1") + mustExecFailed(c, se, "select c1 as c2, c2 from t having c2") + mustExecFailed(c, se, "select c1 as c2, c2 from t having c2 + 1") + mustExecFailed(c, se, "select c1 as a, c2 as a from t having a") + mustExecFailed(c, se, "select c1 as a, c2 as a from t having a + 1") + mustExecFailed(c, se, "select c1 + 1 from t having c1") + mustExecFailed(c, se, "select c1 + 1 from t having c1 + 1") + mustExecFailed(c, se, "select a.c1 as c, b.c1 as d from t as a, t as b having c1") + mustExecFailed(c, se, "select 1 from t having sum(avg(c1))") } func newSession(c *C, store kv.Storage, dbName string) Session { @@ -1273,6 +1306,10 @@ func mustExecMatch(c *C, se Session, sql string, expected [][]interface{}) { } func mustExecFailed(c *C, se Session, sql string, args ...interface{}) { - _, err := exec(c, se, sql, args...) + r, err := exec(c, se, sql, args...) + if err == nil { + // sometimes we may meet error after executing first row. + _, err = r.FirstRow() + } c.Assert(err, NotNil) } diff --git a/util/codec/codec.go b/util/codec/codec.go index d679834575..edd951d533 100644 --- a/util/codec/codec.go +++ b/util/codec/codec.go @@ -18,7 +18,7 @@ import ( "time" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var ( diff --git a/util/codec/codec_test.go b/util/codec/codec_test.go index a3a0ccdeff..2334ea46f8 100644 --- a/util/codec/codec_test.go +++ b/util/codec/codec_test.go @@ -19,7 +19,7 @@ import ( "testing" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { diff --git a/util/codec/decimal.go b/util/codec/decimal.go index f33b5ab39f..13f39d0850 100644 --- a/util/codec/decimal.go +++ b/util/codec/decimal.go @@ -18,7 +18,7 @@ import ( "math/big" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) const ( diff --git a/util/codec/decimal_test.go b/util/codec/decimal_test.go index 28532e766e..f957ee5351 100644 --- a/util/codec/decimal_test.go +++ b/util/codec/decimal_test.go @@ -15,7 +15,7 @@ package codec import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testDecimalSuite{}) diff --git a/util/prefix_helper.go b/util/prefix_helper.go index 12aa114962..78294a1a40 100644 --- a/util/prefix_helper.go +++ b/util/prefix_helper.go @@ -28,15 +28,9 @@ import ( "github.com/pingcap/tidb/util/codec" ) -func hasPrefix(prefix []byte) kv.FnKeyCmp { - return func(k kv.Key) bool { - return bytes.HasPrefix(k, prefix) - } -} - // ScanMetaWithPrefix scans metadata with the prefix. func ScanMetaWithPrefix(txn kv.Transaction, prefix string, filter func([]byte, []byte) bool) error { - iter, err := txn.Seek([]byte(prefix), hasPrefix([]byte(prefix))) + iter, err := txn.Seek([]byte(prefix)) if err != nil { return errors.Trace(err) } @@ -51,7 +45,7 @@ func ScanMetaWithPrefix(txn kv.Transaction, prefix string, filter func([]byte, [ if !filter([]byte(iter.Key()), iter.Value()) { break } - iter, err = iter.Next(hasPrefix([]byte(prefix))) + iter, err = iter.Next() } else { break } @@ -68,7 +62,7 @@ func DelKeyWithPrefix(ctx context.Context, prefix string) error { } var keys []string - iter, err := txn.Seek([]byte(prefix), hasPrefix([]byte(prefix))) + iter, err := txn.Seek([]byte(prefix)) if err != nil { return errors.Trace(err) } @@ -81,7 +75,7 @@ func DelKeyWithPrefix(ctx context.Context, prefix string) error { if iter.Valid() && strings.HasPrefix(iter.Key(), prefix) { keys = append(keys, iter.Key()) - iter, err = iter.Next(hasPrefix([]byte(prefix))) + iter, err = iter.Next() } else { break } diff --git a/util/prefix_helper_test.go b/util/prefix_helper_test.go index bdb9340614..18e00e1366 100644 --- a/util/prefix_helper_test.go +++ b/util/prefix_helper_test.go @@ -161,11 +161,8 @@ func (s *testPrefixSuite) TestCode(c *C) { b1 := EncodeRecordKey("aa", 1, 0) b2 := EncodeRecordKey("aa", 1, 1) c.Logf("%#v, %#v", b2, b1) - raw, err := codec.StripEnd(b1) + _, err := codec.StripEnd(b1) c.Assert(err, IsNil) - f := hasPrefix(raw) - has := f(b2) - c.Assert(has, IsTrue) } func (s *testPrefixSuite) TestPrefixFilter(c *C) { diff --git a/util/types/compare.go b/util/types/compare.go index 3c894bd228..ff035adedb 100644 --- a/util/types/compare.go +++ b/util/types/compare.go @@ -19,7 +19,7 @@ package types import ( "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) // CompareInt64 returns an integer comparing the int64 x to y. diff --git a/util/types/compare_test.go b/util/types/compare_test.go index 7e492497df..0a17c92037 100644 --- a/util/types/compare_test.go +++ b/util/types/compare_test.go @@ -17,7 +17,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testCompareSuite{}) diff --git a/util/types/convert.go b/util/types/convert.go index f3c7c4079e..f28721d479 100644 --- a/util/types/convert.go +++ b/util/types/convert.go @@ -25,7 +25,7 @@ import ( "unicode" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" ) diff --git a/util/types/convert_test.go b/util/types/convert_test.go index 300d54aea6..1bc10f9d4a 100644 --- a/util/types/convert_test.go +++ b/util/types/convert_test.go @@ -20,7 +20,7 @@ import ( "fmt" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" ) diff --git a/util/types/etc.go b/util/types/etc.go index 1df8a83eff..fdbb1e7b3f 100644 --- a/util/types/etc.go +++ b/util/types/etc.go @@ -23,7 +23,7 @@ import ( "strings" "github.com/juju/errors" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/errors2" diff --git a/util/types/etc_test.go b/util/types/etc_test.go index 20aa8d6626..5fbf15d113 100644 --- a/util/types/etc_test.go +++ b/util/types/etc_test.go @@ -20,7 +20,7 @@ import ( "time" . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) func TestT(t *testing.T) { diff --git a/util/types/field_type.go b/util/types/field_type.go index ec96912bcf..daf8b354f5 100644 --- a/util/types/field_type.go +++ b/util/types/field_type.go @@ -21,7 +21,7 @@ import ( "fmt" "strings" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/charset" ) diff --git a/util/types/field_type_test.go b/util/types/field_type_test.go index 558c799d8a..a0a069867a 100644 --- a/util/types/field_type_test.go +++ b/util/types/field_type_test.go @@ -15,7 +15,7 @@ package types import ( . "github.com/pingcap/check" - mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testFieldTypeSuite{})